pavanmutha commited on
Commit
6b33a99
·
verified ·
1 Parent(s): 201e59d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -10
app.py CHANGED
@@ -7,12 +7,14 @@ import shap
7
  import lime.lime_tabular
8
  import optuna
9
  import wandb
 
10
  from smolagents import HfApiModel, CodeAgent
11
  from huggingface_hub import login
12
  from sklearn.ensemble import RandomForestClassifier
13
  from sklearn.model_selection import train_test_split, cross_val_score
14
  from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
15
  from sklearn.preprocessing import LabelEncoder
 
16
 
17
  # Authenticate with Hugging Face
18
  hf_token = os.getenv("HF_TOKEN")
@@ -79,6 +81,7 @@ def run_agent(_):
79
  - At least 3 visualizations showing important trends.
80
  4. Derive at least 3 actionable real-world insights.
81
  5. Save all visualizations to ./figures/ directory.
 
82
  Return a JSON object with keys:
83
  - 'insights': clean bullet-point insights.
84
  - 'figures': list of file paths of generated visualizations.
@@ -89,17 +92,29 @@ def run_agent(_):
89
  additional_args={"source_file": temp_file.name}
90
  )
91
 
92
- # Process result as a dictionary if possible
 
 
 
 
 
 
93
  if isinstance(result, dict):
94
  insights = result.get("insights", "No insights generated.")
95
  image_paths = result.get("figures", [])
96
  else:
97
- insights = "Error: The result is not in the expected format."
98
  image_paths = []
99
 
100
- return insights, image_paths
101
-
 
 
 
 
 
102
 
 
103
  def train_model(_):
104
  wandb.login(key=os.environ.get("WANDB_API_KEY"))
105
  run_counter = 1
@@ -247,12 +262,6 @@ with gr.Blocks() as demo:
247
  shap_img = gr.Image(label="SHAP Summary Plot")
248
  lime_img = gr.Image(label="LIME Explanation")
249
 
250
- with gr.Row():
251
- agent_btn = gr.Button("Run AI Agent (5 Insights + 5 Visualizations)")
252
- insights_output = gr.Textbox(label="Insights from SmolAgent", lines=15)
253
- #visual_output = gr.Gallery(label="Generated Visualizations").style(grid=3, height="auto")
254
- visual_output = gr.Gallery(label="Generated Visualizations", columns=[3], height=400)
255
-
256
 
257
  #agent_btn.click(fn=run_agent, inputs=df_output, outputs=insights_output)
258
  agent_btn.click(fn=run_agent, inputs=df_output, outputs=[insights_output, visual_output])
 
7
  import lime.lime_tabular
8
  import optuna
9
  import wandb
10
+ import json
11
  from smolagents import HfApiModel, CodeAgent
12
  from huggingface_hub import login
13
  from sklearn.ensemble import RandomForestClassifier
14
  from sklearn.model_selection import train_test_split, cross_val_score
15
  from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
16
  from sklearn.preprocessing import LabelEncoder
17
+ from PIL import Image
18
 
19
  # Authenticate with Hugging Face
20
  hf_token = os.getenv("HF_TOKEN")
 
81
  - At least 3 visualizations showing important trends.
82
  4. Derive at least 3 actionable real-world insights.
83
  5. Save all visualizations to ./figures/ directory.
84
+ 6. "Ensure that all visualizations have figure size at least 8x6 and saved at 150+ dpi."
85
  Return a JSON object with keys:
86
  - 'insights': clean bullet-point insights.
87
  - 'figures': list of file paths of generated visualizations.
 
92
  additional_args={"source_file": temp_file.name}
93
  )
94
 
95
+ if isinstance(result, str):
96
+ try:
97
+ result = json.loads(result)
98
+ except json.JSONDecodeError:
99
+ insights = "Failed to parse result from agent."
100
+ return insights, []
101
+
102
  if isinstance(result, dict):
103
  insights = result.get("insights", "No insights generated.")
104
  image_paths = result.get("figures", [])
105
  else:
106
+ insights = "Error: Unexpected result format from agent."
107
  image_paths = []
108
 
109
+ images = []
110
+ for path in image_paths:
111
+ try:
112
+ images.append(Image.open(path))
113
+ except Exception as e:
114
+ print(f"Error loading image {path}: {e}")
115
+ return insights, images
116
 
117
+
118
  def train_model(_):
119
  wandb.login(key=os.environ.get("WANDB_API_KEY"))
120
  run_counter = 1
 
262
  shap_img = gr.Image(label="SHAP Summary Plot")
263
  lime_img = gr.Image(label="LIME Explanation")
264
 
 
 
 
 
 
 
265
 
266
  #agent_btn.click(fn=run_agent, inputs=df_output, outputs=insights_output)
267
  agent_btn.click(fn=run_agent, inputs=df_output, outputs=[insights_output, visual_output])