pavanmutha commited on
Commit
cb1d344
·
verified ·
1 Parent(s): f2d52e9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +123 -203
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  import re
3
  import gradio as gr
@@ -24,45 +25,22 @@ from sklearn.preprocessing import LabelEncoder
24
  from datetime import datetime
25
  from PIL import Image
26
 
 
27
  # Authenticate with Hugging Face
28
  hf_token = os.getenv("HF_TOKEN")
29
  login(token=hf_token)
30
 
 
31
  # SmolAgent initialization
32
  model = HfApiModel("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
33
 
 
34
  # Globals
35
  df_global = None
36
  target_column_global = None
37
- data_summary_global = None # ⬅️ Added for summarized data
38
 
39
- def clean_data(df):
40
- df = df.dropna(how='all', axis=1).dropna(how='all', axis=0)
41
- for col in df.select_dtypes(include='object').columns:
42
- df[col] = df[col].astype(str)
43
- df[col] = LabelEncoder().fit_transform(df[col])
44
- df = df.fillna(df.mean(numeric_only=True))
45
- return df
46
-
47
- def summarize_data(df: pd.DataFrame, max_cols: int = 10, max_rows: int = 5) -> str:
48
- summary = []
49
- summary.append(f"Dataset shape: {df.shape}")
50
- summary.append("\nColumn types:\n" + str(df.dtypes))
51
-
52
- num_cols = df.select_dtypes(include=[np.number]).columns.tolist()
53
- cat_cols = df.select_dtypes(exclude=[np.number]).columns.tolist()
54
-
55
- summary.append("\nMissing values:\n" + str(df.isnull().sum()))
56
-
57
- if num_cols:
58
- summary.append("\nNumerical summary:\n" + str(df[num_cols].describe().T.head(max_rows)))
59
- if cat_cols:
60
- summary.append("\nCategorical value counts (top categories):")
61
- for col in cat_cols[:max_cols]:
62
- summary.append(f"\nColumn: {col}\n{df[col].value_counts().head(max_rows)}")
63
-
64
- return "\n".join(summary)
65
 
 
66
  def upload_file(file):
67
  global df_global, data_summary_global
68
  if file is None:
@@ -80,6 +58,124 @@ def set_target_column(col_name):
80
  target_column_global = col_name
81
  return f"✅ Target column set to: {col_name}"
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
 
85
  def format_analysis_report(raw_output, visuals):
@@ -165,182 +261,6 @@ def format_insights(insights, visuals):
165
 
166
 
167
 
168
- ### ✅ 2. Add a pre-check fallback for non-compliant agent outputs
169
-
170
- def extract_json_from_codeagent_output(raw_output):
171
- import re, json, ast
172
-
173
- try:
174
- # Extract code blocks from ```python ... ```
175
- code_blocks = re.findall(r"```(?:py|python)?\n(.*?)```", raw_output, re.DOTALL)
176
- for block in code_blocks:
177
- # Try extracting from print(json.dumps({...}))
178
- json_match = re.search(
179
- r"print\(\s*json\.dumps\(\s*(\{[\s\S]*?\})\s*\)\s*\)",
180
- block,
181
- re.DOTALL
182
- ) or re.search(
183
- r"json\.dumps\(\s*(\{[\s\S]*?\})\s*\)",
184
- block,
185
- re.DOTALL
186
- )
187
- if json_match:
188
- return json.loads(json_match.group(1))
189
-
190
- # Try extracting from: result = {...}
191
- result_match = re.search(
192
- r"result\s*=\s*(\{[\s\S]*?\})",
193
- block,
194
- re.DOTALL
195
- )
196
- if result_match:
197
- raw_dict = result_match.group(1)
198
- try:
199
- return json.loads(raw_dict) # Try strict JSON
200
- except json.JSONDecodeError:
201
- return ast.literal_eval(raw_dict) # Try Python dict parsing
202
-
203
- # Final fallback: look for any dict-like thing in entire output
204
- fallback_match = re.search(r"\{[\s\S]+\}", raw_output)
205
- if fallback_match:
206
- raw_dict = fallback_match.group(0)
207
- try:
208
- return json.loads(raw_dict)
209
- except json.JSONDecodeError:
210
- return ast.literal_eval(raw_dict)
211
-
212
- except Exception as e:
213
- print(f"extract_json_from_codeagent_output() failed: {e}")
214
- return None
215
-
216
-
217
-
218
-
219
-
220
- def analyze_data(csv_file=None, additional_notes="", use_summary=True):
221
- try:
222
- start_time = time.time()
223
- process = psutil.Process(os.getpid())
224
- initial_memory = process.memory_info().rss / 1024 ** 2
225
-
226
- # Clean up and prepare directories
227
- if os.path.exists('./figures'):
228
- shutil.rmtree('./figures')
229
- os.makedirs('./figures', exist_ok=True)
230
-
231
- # Initialize WandB
232
- wandb.login(key=os.environ.get('WANDB_API_KEY'))
233
- run = wandb.init(project="huggingface-data-analysis", config={
234
- "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
235
- "additional_notes": additional_notes,
236
- "source_file": csv_file.name if csv_file else "summarized_input"
237
- })
238
-
239
- # Initialize Code Agent
240
- agent = CodeAgent(
241
- tools=[],
242
- model=model,
243
- additional_authorized_imports=["numpy", "pandas", "matplotlib.pyplot", "seaborn", "sklearn", "json"]
244
- )
245
-
246
- # Choose prompt content
247
- if use_summary and data_summary_global:
248
- input_data = data_summary_global
249
- data_instruction = """
250
- You are analyzing summarized dataset information from a CSV file. Your job is to:
251
-
252
- 1. Interpret the summary content as if it was produced from a real dataset.
253
- 2. Derive at least 5 high-level insights based on column types, distributions, missing values, etc.
254
- 3. Imagine or mock visualizations and describe what they would show. Use synthetic data simulation with numpy/pandas if needed.
255
- 4. Save plots to './figures/' using matplotlib or seaborn.
256
-
257
- Always respond in the structured dictionary format below.
258
- """
259
- else:
260
- # Fall back to full file input
261
- input_data = None # You load file within the agent
262
- data_instruction = """
263
- You are a helpful data analysis agent. Please follow these very strict instructions and formatting:
264
-
265
- 1. Load the data from the provided `source_file`.
266
- 2. FIRST analyze the data structure (column names and types)
267
- 3. THEN generate visualizations using EXISTING columns with least 5 visualizations and 5 insights.
268
- 4. Save all plots to `./figures/` as PNGs using matplotlib or seaborn.
269
- 5. DO NOT use open() or print() statements.
270
- 6. Use only authorized imports: `pandas`, `numpy`, `matplotlib.pyplot`, `seaborn`, `json`.
271
- 7. DO NOT return any explanations, thoughts, or narration outside the final output block.
272
- 8. DO NOT use `...` in any dictionary values, arrays, or code blocks.
273
- 9. Use empty lists like [] or strings like "N/A" instead.
274
- 10.Respond only with a JSON-serializable dictionary in Python syntax. Do not include any thoughts, comments, or explanation.
275
- 11. Any logging or warnings must be disabled or redirected; the only stdout must be the single print(json.dumps(result)) call.
276
- 12. FINALLY return ONLY this exact format:
277
-
278
- ```python
279
- import json
280
-
281
- result = {
282
- "observations": {
283
- "numeric_columns": [...],
284
- "categorical_columns": [...],
285
- "data_issues": "..."
286
- },
287
- "insights": [
288
- {"category": "Insight A", "insight": "Description of insight A"},
289
- {"category": "Insight B", "insight": "Description of insight B"}
290
- ]
291
- }
292
-
293
- print(json.dumps(result))
294
- ```<end_code>
295
- Be concise and avoid any narrative outside this final dictionary.
296
- Never use unauthorized imports (only pandas, numpy, matplotlib, seaborn are allowed)
297
- """
298
-
299
- # Run agent with either summarized content or CSV
300
- analysis_result = agent.run(data_instruction, additional_args={
301
- "additional_notes": additional_notes,
302
- "source_file": csv_file.name if csv_file and not use_summary else None,
303
- "data_summary": input_data if use_summary else None
304
- })
305
-
306
- # Performance metrics
307
- execution_time = time.time() - start_time
308
- final_memory = process.memory_info().rss / 1024 ** 2
309
- memory_usage = final_memory - initial_memory
310
-
311
- wandb.log({
312
- "execution_time_sec": execution_time,
313
- "memory_usage_mb": memory_usage
314
- })
315
-
316
- # Collect visualizations
317
- visuals = sorted([
318
- os.path.join('./figures', f) for f in os.listdir('./figures')
319
- if f.endswith(('.png', '.jpg', '.jpeg'))
320
- ])
321
-
322
- for viz in visuals:
323
- wandb.log({os.path.basename(viz): wandb.Image(viz)})
324
-
325
- run.finish()
326
-
327
- print("DEBUG - Raw agent output:", analysis_result[:500] + "...")
328
- with open("agent_output.txt", "w") as f:
329
- f.write(str(analysis_result))
330
-
331
- parsed_result = extract_json_from_codeagent_output(analysis_result)
332
- if parsed_result:
333
- return format_analysis_report(parsed_result, visuals)
334
- else:
335
- error_msg = f"Failed to parse agent output. Showing raw response:\n{str(analysis_result)[:2000]}"
336
- print(error_msg)
337
- return f"<pre>{error_msg}</pre>", visuals
338
-
339
- except Exception as e:
340
- error_msg = f"Analysis failed with error: {str(e)}"
341
- print(error_msg)
342
- return f"<pre>{error_msg}</pre>", []
343
-
344
 
345
  def compare_models():
346
  import seaborn as sns
 
1
+ # Initialization and Imports
2
  import os
3
  import re
4
  import gradio as gr
 
25
  from datetime import datetime
26
  from PIL import Image
27
 
28
+
29
  # Authenticate with Hugging Face
30
  hf_token = os.getenv("HF_TOKEN")
31
  login(token=hf_token)
32
 
33
+
34
  # SmolAgent initialization
35
  model = HfApiModel("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
36
 
37
+
38
  # Globals
39
  df_global = None
40
  target_column_global = None
 
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
+ # File upload and cleanup
44
  def upload_file(file):
45
  global df_global, data_summary_global
46
  if file is None:
 
58
  target_column_global = col_name
59
  return f"✅ Target column set to: {col_name}"
60
 
61
def clean_data(df):
    """Prepare a raw DataFrame for modeling.

    Drops rows/columns that are entirely empty, label-encodes every
    object-dtype column, and fills remaining numeric NaNs with the
    column mean. Returns the cleaned DataFrame.
    """
    df = df.dropna(how='all', axis=1).dropna(how='all', axis=0)
    # Encode each text column as integer labels (NaN becomes the 'nan' string
    # category, matching the astype(str) conversion).
    for column in df.select_dtypes(include='object').columns:
        as_text = df[column].astype(str)
        df[column] = LabelEncoder().fit_transform(as_text)
    return df.fillna(df.mean(numeric_only=True))
68
+
69
+
70
+ # Fallback: extract JSON when the CodeAgent output is not in the expected format
71
+
72
def extract_json_from_codeagent_output(raw_output):
    """Best-effort recovery of the result dict from a CodeAgent transcript.

    Tries, in order: a ``print(json.dumps({...}))`` (or bare ``json.dumps``)
    call inside a fenced code block, a ``result = {...}`` assignment inside a
    fenced block, and finally any dict-looking text anywhere in the output.
    Returns the parsed dict, or None when nothing parseable is found.
    """
    import re, json, ast

    def parse_dict(text):
        # Strict JSON first; fall back to Python-literal syntax
        # (single quotes, trailing commas the agent tends to emit).
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            return ast.literal_eval(text)

    try:
        fenced = re.findall(r"```(?:py|python)?\n(.*?)```", raw_output, re.DOTALL)
        for snippet in fenced:
            dumped = re.search(
                r"print\(\s*json\.dumps\(\s*(\{[\s\S]*?\})\s*\)\s*\)",
                snippet,
                re.DOTALL,
            ) or re.search(
                r"json\.dumps\(\s*(\{[\s\S]*?\})\s*\)",
                snippet,
                re.DOTALL,
            )
            if dumped:
                # json.dumps arguments are expected to already be valid JSON.
                return json.loads(dumped.group(1))

            assigned = re.search(r"result\s*=\s*(\{[\s\S]*?\})", snippet, re.DOTALL)
            if assigned:
                return parse_dict(assigned.group(1))

        # Last resort: grab the widest brace-delimited span in the raw text.
        anywhere = re.search(r"\{[\s\S]+\}", raw_output)
        if anywhere:
            return parse_dict(anywhere.group(0))
    except Exception as e:
        print(f"extract_json_from_codeagent_output() failed: {e}")
    return None
117
+
118
+
119
+
120
+
121
+ # Data Analysis Function with CodeAgent
122
def analyze_data(csv_file, additional_notes=""):
    """Run the CodeAgent over an uploaded CSV and build an analysis report.

    Args:
        csv_file: Uploaded file object (Gradio file) forwarded to the agent
            as ``source_file``; may be None.
        additional_notes: Free-text notes forwarded to the agent.

    Returns:
        The value of ``format_analysis_report(analysis_result, visuals)`` on
        success, or a ``("<pre>…</pre>", [])`` error pair on failure so the
        UI degrades gracefully instead of raising.
    """
    try:
        start_time = time.time()
        process = psutil.Process(os.getpid())
        initial_memory = process.memory_info().rss / 1024 ** 2

        # Start from a clean figures directory so stale plots are never reported.
        if os.path.exists('./figures'):
            shutil.rmtree('./figures')
        os.makedirs('./figures', exist_ok=True)

        wandb.login(key=os.environ.get('WANDB_API_KEY'))
        run = wandb.init(project="huggingface-data-analysis", config={
            "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
            "additional_notes": additional_notes,
            "source_file": csv_file.name if csv_file else None
        })

        try:
            agent = CodeAgent(tools=[], model=model, additional_authorized_imports=["numpy", "pandas", "matplotlib.pyplot", "seaborn", "sklearn", "json"])
            analysis_result = agent.run("""
            You are a helpful data analysis agent. Just return insight information and visualization.
            Load the data that is passed.do not create your own.
            Automatically detect numeric columns and names.
            2. 5 data visualizations
            3. at least 5 insights from data
            5. Generate publication-quality visualizations and save to './figures/'.
            Do not use 'open()' or write to files. Just return variables and plots.
            The dictionary should have the following structure:
            {
            'observations': {
            'observation_1_key': 'observation_1_value',
            'observation_2_key': 'observation_2_value',
            ...
            },
            'insights': {
            'insight_1_key': 'insight_1_value',
            'insight_2_key': 'insight_2_value',
            ...
            }
            }
            """, additional_args={"additional_notes": additional_notes, "source_file": csv_file})

            execution_time = time.time() - start_time
            final_memory = process.memory_info().rss / 1024 ** 2
            memory_usage = final_memory - initial_memory
            wandb.log({"execution_time_sec": execution_time, "memory_usage_mb": memory_usage})

            # Sorted for a deterministic gallery order across runs.
            visuals = sorted(
                os.path.join('./figures', f)
                for f in os.listdir('./figures')
                if f.endswith(('.png', '.jpg', '.jpeg'))
            )
            for viz in visuals:
                wandb.log({os.path.basename(viz): wandb.Image(viz)})
        finally:
            # Always close the W&B run, even when the agent errors out,
            # so runs are not left dangling in the dashboard.
            run.finish()

        return format_analysis_report(analysis_result, visuals)
    except Exception as e:
        error_msg = f"Analysis failed with error: {str(e)}"
        print(error_msg)
        return f"<pre>{error_msg}</pre>", []
174
+
175
+
176
+
177
+
178
+
179
 
180
 
181
  def format_analysis_report(raw_output, visuals):
 
261
 
262
 
263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
 
265
  def compare_models():
266
  import seaborn as sns