pavanmutha commited on
Commit
f2d52e9
·
verified ·
1 Parent(s): 8115491

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -19
app.py CHANGED
@@ -31,8 +31,10 @@ login(token=hf_token)
31
  # SmolAgent initialization
32
  model = HfApiModel("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
33
 
 
34
  df_global = None
35
  target_column_global = None
 
36
 
37
  def clean_data(df):
38
  df = df.dropna(how='all', axis=1).dropna(how='all', axis=0)
@@ -42,24 +44,44 @@ def clean_data(df):
42
  df = df.fillna(df.mean(numeric_only=True))
43
  return df
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  def upload_file(file):
46
- global df_global
47
  if file is None:
48
  return pd.DataFrame({"Error": ["No file uploaded."]}), gr.update(choices=[])
 
49
  ext = os.path.splitext(file.name)[-1]
50
  df = pd.read_csv(file.name) if ext == ".csv" else pd.read_excel(file.name)
51
  df = clean_data(df)
52
  df_global = df
 
53
  return df.head(), gr.update(choices=df.columns.tolist())
54
 
55
-
56
-
57
  def set_target_column(col_name):
58
  global target_column_global
59
  target_column_global = col_name
60
  return f"✅ Target column set to: {col_name}"
61
 
62
 
 
63
  def format_analysis_report(raw_output, visuals):
64
  import json
65
 
@@ -195,7 +217,7 @@ def extract_json_from_codeagent_output(raw_output):
195
 
196
 
197
 
198
- def analyze_data(csv_file, additional_notes=""):
199
  try:
200
  start_time = time.time()
201
  process = psutil.Process(os.getpid())
@@ -211,7 +233,7 @@ def analyze_data(csv_file, additional_notes=""):
211
  run = wandb.init(project="huggingface-data-analysis", config={
212
  "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
213
  "additional_notes": additional_notes,
214
- "source_file": csv_file.name if csv_file else None
215
  })
216
 
217
  # Initialize Code Agent
@@ -221,8 +243,23 @@ def analyze_data(csv_file, additional_notes=""):
221
  additional_authorized_imports=["numpy", "pandas", "matplotlib.pyplot", "seaborn", "sklearn", "json"]
222
  )
223
 
224
- # Enhanced prompt with strict formatting requirements
225
- prompt = """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  You are a helpful data analysis agent. Please follow these very strict instructions and formatting:
227
 
228
  1. Load the data from the provided `source_file`.
@@ -259,10 +296,11 @@ Be concise and avoid any narrative outside this final dictionary.
259
  Never use unauthorized imports (only pandas, numpy, matplotlib, seaborn are allowed)
260
  """
261
 
262
- # Run the agent
263
- analysis_result = agent.run(prompt, additional_args={
264
  "additional_notes": additional_notes,
265
- "source_file": csv_file.name if csv_file else None
 
266
  })
267
 
268
  # Performance metrics
@@ -281,25 +319,19 @@ Never use unauthorized imports (only pandas, numpy, matplotlib, seaborn are allo
281
  if f.endswith(('.png', '.jpg', '.jpeg'))
282
  ])
283
 
284
- # Log visuals to WandB
285
  for viz in visuals:
286
  wandb.log({os.path.basename(viz): wandb.Image(viz)})
287
 
288
  run.finish()
289
-
290
  print("DEBUG - Raw agent output:", analysis_result[:500] + "...")
291
- print("Columns in data:", df_global.columns.tolist())
292
- print("Data types:", df_global.dtypes)
293
  with open("agent_output.txt", "w") as f:
294
  f.write(str(analysis_result))
295
- # Parse the agent output
296
- parsed_result = extract_json_from_codeagent_output(analysis_result)
297
- print(f"DEBUG - Parsed result: {parsed_result}") # Debug output
298
 
 
299
  if parsed_result:
300
  return format_analysis_report(parsed_result, visuals)
301
  else:
302
- # Fallback to showing raw output if parsing fails
303
  error_msg = f"Failed to parse agent output. Showing raw response:\n{str(analysis_result)[:2000]}"
304
  print(error_msg)
305
  return f"<pre>{error_msg}</pre>", visuals
@@ -309,7 +341,7 @@ Never use unauthorized imports (only pandas, numpy, matplotlib, seaborn are allo
309
  print(error_msg)
310
  return f"<pre>{error_msg}</pre>", []
311
 
312
-
313
  def compare_models():
314
  import seaborn as sns
315
  from sklearn.model_selection import cross_val_predict
 
31
  # SmolAgent initialization
32
  model = HfApiModel("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
33
 
34
+ # Globals
35
  df_global = None
36
  target_column_global = None
37
+ data_summary_global = None # ⬅️ Added for summarized data
38
 
39
  def clean_data(df):
40
  df = df.dropna(how='all', axis=1).dropna(how='all', axis=0)
 
44
  df = df.fillna(df.mean(numeric_only=True))
45
  return df
46
 
47
+ def summarize_data(df: pd.DataFrame, max_cols: int = 10, max_rows: int = 5) -> str:
48
+ summary = []
49
+ summary.append(f"Dataset shape: {df.shape}")
50
+ summary.append("\nColumn types:\n" + str(df.dtypes))
51
+
52
+ num_cols = df.select_dtypes(include=[np.number]).columns.tolist()
53
+ cat_cols = df.select_dtypes(exclude=[np.number]).columns.tolist()
54
+
55
+ summary.append("\nMissing values:\n" + str(df.isnull().sum()))
56
+
57
+ if num_cols:
58
+ summary.append("\nNumerical summary:\n" + str(df[num_cols].describe().T.head(max_rows)))
59
+ if cat_cols:
60
+ summary.append("\nCategorical value counts (top categories):")
61
+ for col in cat_cols[:max_cols]:
62
+ summary.append(f"\nColumn: {col}\n{df[col].value_counts().head(max_rows)}")
63
+
64
+ return "\n".join(summary)
65
+
66
  def upload_file(file):
67
+ global df_global, data_summary_global
68
  if file is None:
69
  return pd.DataFrame({"Error": ["No file uploaded."]}), gr.update(choices=[])
70
+
71
  ext = os.path.splitext(file.name)[-1]
72
  df = pd.read_csv(file.name) if ext == ".csv" else pd.read_excel(file.name)
73
  df = clean_data(df)
74
  df_global = df
75
+ data_summary_global = summarize_data(df) # ⬅️ Summarize here
76
  return df.head(), gr.update(choices=df.columns.tolist())
77
 
 
 
78
  def set_target_column(col_name):
79
  global target_column_global
80
  target_column_global = col_name
81
  return f"✅ Target column set to: {col_name}"
82
 
83
 
84
+
85
  def format_analysis_report(raw_output, visuals):
86
  import json
87
 
 
217
 
218
 
219
 
220
+ def analyze_data(csv_file=None, additional_notes="", use_summary=True):
221
  try:
222
  start_time = time.time()
223
  process = psutil.Process(os.getpid())
 
233
  run = wandb.init(project="huggingface-data-analysis", config={
234
  "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
235
  "additional_notes": additional_notes,
236
+ "source_file": csv_file.name if csv_file else "summarized_input"
237
  })
238
 
239
  # Initialize Code Agent
 
243
  additional_authorized_imports=["numpy", "pandas", "matplotlib.pyplot", "seaborn", "sklearn", "json"]
244
  )
245
 
246
+ # Choose prompt content
247
+ if use_summary and data_summary_global:
248
+ input_data = data_summary_global
249
+ data_instruction = """
250
+ You are analyzing summarized dataset information from a CSV file. Your job is to:
251
+
252
+ 1. Interpret the summary content as if it was produced from a real dataset.
253
+ 2. Derive at least 5 high-level insights based on column types, distributions, missing values, etc.
254
+ 3. Imagine or mock visualizations and describe what they would show. Use synthetic data simulation with numpy/pandas if needed.
255
+ 4. Save plots to './figures/' using matplotlib or seaborn.
256
+
257
+ Always respond in the structured dictionary format below.
258
+ """
259
+ else:
260
+ # Fall back to full file input
261
+ input_data = None # You load file within the agent
262
+ data_instruction = """
263
  You are a helpful data analysis agent. Please follow these very strict instructions and formatting:
264
 
265
  1. Load the data from the provided `source_file`.
 
296
  Never use unauthorized imports (only pandas, numpy, matplotlib, seaborn are allowed)
297
  """
298
 
299
+ # Run agent with either summarized content or CSV
300
+ analysis_result = agent.run(data_instruction, additional_args={
301
  "additional_notes": additional_notes,
302
+ "source_file": csv_file.name if csv_file and not use_summary else None,
303
+ "data_summary": input_data if use_summary else None
304
  })
305
 
306
  # Performance metrics
 
319
  if f.endswith(('.png', '.jpg', '.jpeg'))
320
  ])
321
 
 
322
  for viz in visuals:
323
  wandb.log({os.path.basename(viz): wandb.Image(viz)})
324
 
325
  run.finish()
326
+
327
  print("DEBUG - Raw agent output:", analysis_result[:500] + "...")
 
 
328
  with open("agent_output.txt", "w") as f:
329
  f.write(str(analysis_result))
 
 
 
330
 
331
+ parsed_result = extract_json_from_codeagent_output(analysis_result)
332
  if parsed_result:
333
  return format_analysis_report(parsed_result, visuals)
334
  else:
 
335
  error_msg = f"Failed to parse agent output. Showing raw response:\n{str(analysis_result)[:2000]}"
336
  print(error_msg)
337
  return f"<pre>{error_msg}</pre>", visuals
 
341
  print(error_msg)
342
  return f"<pre>{error_msg}</pre>", []
343
 
344
+
345
  def compare_models():
346
  import seaborn as sns
347
  from sklearn.model_selection import cross_val_predict