pavanmutha commited on
Commit
1071fde
·
verified ·
1 Parent(s): 08666a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -5
app.py CHANGED
@@ -13,6 +13,7 @@ from sklearn.ensemble import RandomForestClassifier
13
  from sklearn.model_selection import train_test_split, cross_val_score
14
  from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
15
  from sklearn.preprocessing import LabelEncoder
 
16
 
17
  # Authenticate with Hugging Face
18
  hf_token = os.getenv("HF_TOKEN")
@@ -53,6 +54,8 @@ agent = CodeAgent(
53
  additional_authorized_imports=["numpy", "pandas", "matplotlib.pyplot", "seaborn"]
54
  )
55
 
 
 
56
  def run_agent(_):
57
  if df_global is None:
58
  return "Please upload a file first.", []
@@ -73,7 +76,7 @@ def run_agent(_):
73
  - At least 3 visualizations showing important trends.
74
  4. Derive at least 3 actionable real-world insights.
75
  5. Save all visualizations to ./figures/ directory.
76
- Return a dictionary with keys:
77
  - 'insights': clean bullet-point insights.
78
  - 'figures': list of file paths of generated visualizations.
79
  """
@@ -83,11 +86,22 @@ def run_agent(_):
83
  additional_args={"source_file": temp_file.name}
84
  )
85
 
86
- # Now, result is expected to be a dictionary
87
- insights = result.get("insights", "No insights generated.")
88
- image_paths = result.get("figures", [])
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
- return insights, image_paths
91
 
92
 
93
 
 
13
  from sklearn.model_selection import train_test_split, cross_val_score
14
  from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
15
  from sklearn.preprocessing import LabelEncoder
16
+ from smolagent.types import AgentText
17
 
18
  # Authenticate with Hugging Face
19
  hf_token = os.getenv("HF_TOKEN")
 
54
  additional_authorized_imports=["numpy", "pandas", "matplotlib.pyplot", "seaborn"]
55
  )
56
 
57
+
58
+
59
  def run_agent(_):
60
  if df_global is None:
61
  return "Please upload a file first.", []
 
76
  - At least 3 visualizations showing important trends.
77
  4. Derive at least 3 actionable real-world insights.
78
  5. Save all visualizations to ./figures/ directory.
79
+ Return a JSON object with keys:
80
  - 'insights': clean bullet-point insights.
81
  - 'figures': list of file paths of generated visualizations.
82
  """
 
86
  additional_args={"source_file": temp_file.name}
87
  )
88
 
89
+ # Convert AgentText object to string and parse it
90
+ if isinstance(result, AgentText):
91
+ import json
92
+ result_str = result.text.strip()
93
+
94
+ try:
95
+ result_dict = json.loads(result_str)
96
+ except json.JSONDecodeError:
97
+ return f"Error decoding agent response: {result_str}", []
98
+
99
+ insights = result_dict.get("insights", "No insights generated.")
100
+ image_paths = result_dict.get("figures", [])
101
+ return insights, image_paths
102
+ else:
103
+ return "Unexpected result type from agent", []
104
 
 
105
 
106
 
107