pavanmutha committed on
Commit
b2cb237
·
verified ·
1 Parent(s): d7e44a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +230 -364
app.py CHANGED
@@ -4,9 +4,8 @@ import pandas as pd
4
  import numpy as np
5
  import matplotlib.pyplot as plt
6
  import shap
7
- import lime
8
  import lime.lime_tabular
9
- # import optuna # Removing Optuna for this simplified approach
10
  import wandb
11
  import json
12
  import time
@@ -15,390 +14,257 @@ import shutil
15
  import ast
16
  from smolagents import HfApiModel, CodeAgent
17
  from huggingface_hub import login
18
- from sklearn.model_selection import train_test_split, cross_val_score # Keep cross_val_score if needed elsewhere, but not primary for comparison here
19
- from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
20
- # from sklearn.metrics import ConfusionMatrixDisplay # Not used currently
21
- from sklearn.ensemble import RandomForestClassifier # Keep RF
22
- # from sklearn.ensemble import GradientBoostingClassifier # Remove GB for simplicity now
23
  from sklearn.linear_model import LogisticRegression
24
- from sklearn.preprocessing import LabelEncoder, StandardScaler # Added StandardScaler
25
- from sklearn.pipeline import Pipeline # Added Pipeline
26
  from datetime import datetime
27
- # from PIL import Image # Likely not needed directly
28
-
29
- # --- Authentication and Setup (Keep as is) ---
30
- hf_token = os.getenv("HF_TOKEN")
31
- wandb_api_key = os.getenv("WANDB_API_KEY") # Get WandB key
32
 
33
  # Authenticate with Hugging Face
34
- if hf_token:
35
- try:
36
- login(token=hf_token)
37
- print("HF Login successful.")
38
- except Exception as e:
39
- print(f"HF login failed: {e}")
40
- else:
41
- print("HF_TOKEN not found.")
42
-
43
- # Login to WandB if key exists
44
- if wandb_api_key:
45
- try:
46
- wandb.login(key=wandb_api_key)
47
- print("WandB login successful.")
48
- except Exception as e:
49
- print(f"WandB login failed: {e}. Logging will be disabled.")
50
- wandb.init(mode="disabled") # Disable if login fails
51
- else:
52
- print("WANDB_API_KEY not found. WandB logging disabled.")
53
- wandb.init(mode="disabled") # Disable if no key
54
 
55
- # SmolAgent initialization (Keep as is)
56
- try:
57
- model_api = HfApiModel("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
58
- print("SmolAgent API Model initialized.")
59
- except Exception as e:
60
- print(f"SmolAgent initialization failed: {e}")
61
- model_api = None # Set to None if failed
62
 
63
  df_global = None
64
- # --- NEW: Global variable for split data ---
65
- split_data_global = None # Will store (X_train, X_test, y_train, y_test, label_encoder)
66
 
67
- # --- clean_data, upload_file, AI Agent functions (Keep as is from your original code) ---
68
  def clean_data(df):
69
- # Your original clean_data implementation
70
- df = df.copy() # Work on copy
71
  df = df.dropna(how='all', axis=1).dropna(how='all', axis=0)
72
  for col in df.select_dtypes(include='object').columns:
73
  df[col] = df[col].astype(str)
74
  df[col] = LabelEncoder().fit_transform(df[col])
75
- # Impute only if numeric columns exist
76
- numeric_cols = df.select_dtypes(include=np.number).columns
77
- if not numeric_cols.empty:
78
- df[numeric_cols] = df[numeric_cols].fillna(df[numeric_cols].mean())
79
  return df
80
 
81
  def upload_file(file):
82
- global df_global, split_data_global # Reset split data on new upload
83
- df_global = None
84
- split_data_global = None
85
  if file is None:
86
- return pd.DataFrame({"Error": ["No file uploaded."]})
 
 
 
 
 
 
 
87
  try:
88
- ext = os.path.splitext(file.name)[-1].lower() # Use lower()
89
- df = pd.read_csv(file.name) if ext == ".csv" else pd.read_excel(file.name)
90
- df = clean_data(df)
91
- df_global = df
92
- print("File uploaded and cleaned.")
93
- return df.head()
94
- except Exception as e:
95
- print(f"Error in upload_file: {e}")
96
- return pd.DataFrame({"Error": [f"Failed to process file: {e}"]})
97
-
98
- # --- AI Agent functions (Keep your original format_*, analyze_data) ---
99
- # Placeholder for brevity - use your original functions
100
- def format_analysis_report(raw_output, visuals): return f"<h2>AI Report</h2><pre>{str(raw_output)}</pre>", visuals
101
- def format_observations(observations): return f"<pre>{str(observations)}</pre>"
102
- def format_insights(insights, visuals): return f"<pre>{str(insights)}</pre>"
103
- def analyze_data(csv_file, additional_notes=""):
104
- print("Running AI Agent (stub)...")
105
- # Your original analyze_data logic here
106
- # Ensure it uses wandb.init(reinit=True) if called multiple times
107
- # and finishes the run: wandb.finish()
108
- if not model_api: return "AI Agent not initialized.", []
109
- # Dummy result
110
- analysis_result = {"observations": {"data": "desc"}, "insights": {"insight1": "text"}}
111
- visuals = [] # Agent should save plots to './figures/'
112
- return format_analysis_report(analysis_result, visuals)
113
-
114
- # --- MODIFIED: prepare_data ---
115
- def prepare_data(df, target_column=None) -> bool:
116
- """Splits data and stores it globally. Returns True on success, False on failure."""
117
- global split_data_global
118
- print("Preparing data split...")
119
- try:
120
- if df is None or df.empty:
121
- print("Error: DataFrame is empty in prepare_data.")
122
- split_data_global = None
123
- return False
124
-
125
- # --- Target Column Logic ---
126
- if target_column is None:
127
- # Prioritize object columns if they exist and are not all unique
128
- object_cols = df.select_dtypes(include=['object', 'category']).columns
129
- potential_targets = [col for col in object_cols if df[col].nunique() < len(df)]
130
- if potential_targets:
131
- target_column = potential_targets[0] # Take the first suitable object col
132
- print(f"Target column auto-selected (object): '{target_column}'")
133
- else:
134
- target_column = df.columns[-1] # Fallback to last column
135
- print(f"Target column auto-selected (last): '{target_column}'")
136
- elif target_column not in df.columns:
137
- print(f"Error: Specified target column '{target_column}' not found.")
138
- split_data_global = None
139
- return False
140
-
141
- X = df.drop(columns=[target_column])
142
- y = df[target_column].copy()
143
-
144
- # --- Feature Check (ensure numeric) ---
145
- # (Should be handled by clean_data, but double-check)
146
- non_numeric_features = X.select_dtypes(exclude=np.number).columns
147
- if not non_numeric_features.empty:
148
- print(f"Warning: Dropping non-numeric feature columns: {list(non_numeric_features)}")
149
- X = X.drop(columns=non_numeric_features)
150
- if X.empty:
151
- print("Error: No numeric features left after dropping non-numeric ones.")
152
- split_data_global = None
153
- return False
154
-
155
- # --- Target Encoding ---
156
- label_encoder = None
157
- if not pd.api.types.is_numeric_dtype(y):
158
- print(f"Encoding target column '{target_column}' with LabelEncoder.")
159
- label_encoder = LabelEncoder()
160
- y = label_encoder.fit_transform(y)
161
  else:
162
- # Check if float target should be treated as classification (e.g., integers represented as float)
163
- if pd.api.types.is_float_dtype(y) and np.all(y == y.astype(int)):
164
- print(f"Target '{target_column}' is float but looks like integer. Converting to int.")
165
- y = y.astype(int)
166
-
167
- # --- Check for sufficient classes ---
168
- if y.nunique() < 2:
169
- print(f"Error: Target column '{target_column}' has less than 2 unique values after processing.")
170
- split_data_global = None
171
- return False
172
-
173
- # --- Perform Split ---
174
- try:
175
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)
176
- print("Performed stratified split.")
177
- except ValueError: # Handle cases where stratification is not possible
178
- print("Stratified split failed, using non-stratified split.")
179
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
180
-
181
- split_data_global = (X_train, X_test, y_train, y_test, label_encoder)
182
- print(f"Data split successfully: Train {X_train.shape}, Test {X_test.shape}")
183
- return True
184
-
185
  except Exception as e:
186
- print(f"Error during data preparation: {e}")
187
- import traceback
188
- traceback.print_exc()
189
- split_data_global = None
190
- return False
191
-
192
- # --- NEW: run_comparison_and_explainability ---
193
- # MODIFIED TYPE HINT: Returns str for paths now
194
- def run_comparison_and_explainability() -> Tuple[pd.DataFrame, str, str, str]:
195
- """Compares models, explains the best one, and logs to WandB."""
196
- global df_global, split_data_global
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
- # Default returns matching the modified hint
199
- default_error_df = pd.DataFrame({"Error": ["Comparison failed."]})
200
- default_shap_path = ""
201
- default_lime_path = ""
202
- default_status = "Error: Could not run comparison."
203
-
204
- # --- 1. Check Prerequisites ---
205
- if df_global is None:
206
- return pd.DataFrame({"Error": ["No data uploaded."]}), default_shap_path, default_lime_path, "Error: Upload data first."
207
- if split_data_global is None:
208
- print("Split data not found globally, attempting to prepare now...")
209
- if not prepare_data(df_global):
210
- return pd.DataFrame({"Error": ["Data preparation failed."]}), default_shap_path, default_lime_path, "Error: Failed to prepare data."
211
-
212
- # Unpack the globally stored split data
213
- X_train, X_test, y_train, y_test, label_encoder = split_data_global
214
- class_names = getattr(label_encoder, 'classes_', [str(c) for c in np.unique(y_train)]) if label_encoder else [str(c) for c in np.unique(y_train)]
215
- class_names = [str(c) for c in class_names] # Ensure strings
216
-
217
- print("--- Starting Model Comparison & Explainability ---")
218
-
219
- # --- 2. Define Models ---
220
- models_to_compare = {
221
- "LogisticRegression": Pipeline([
222
- ('scaler', StandardScaler()),
223
- ('logreg', LogisticRegression(max_iter=1000, random_state=42, class_weight='balanced'))
224
- ]),
225
- "RandomForest": RandomForestClassifier(random_state=42, class_weight='balanced')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  }
227
-
228
- # --- 3. Initialize WandB Run ---
229
- run_name = f"CompareExplain_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
230
- wandb_run = None
231
- if wandb.run is None or wandb.run.mode != "disabled":
232
- try:
233
- if wandb.run: wandb.finish()
234
- wandb_run = wandb.init(project="huggingface-data-analysis", name=run_name, config={...}, reinit=True) # Simplified config
235
- print(f"WandB Run '{run_name}' started.")
236
- except Exception as e: print(f"WandB init failed: {e}"); wandb_run = None
237
- else: wandb_run = None
238
-
239
- # --- 4. Train and Evaluate Models ---
240
  results = []
241
- trained_models = {}
242
- print("Comparing models...")
243
- # (Keep the model training loop exactly as in the previous version)
244
- for name, model in models_to_compare.items():
245
- print(f" Training {name}...")
246
- start_time = time.time()
247
- try:
248
- model.fit(X_train, y_train)
249
- y_pred = model.predict(X_test)
250
- duration = time.time() - start_time
251
- metrics = { "Model": name, "Accuracy": accuracy_score(y_test, y_pred), "Precision (Weighted)": precision_score(y_test, y_pred, average="weighted", zero_division=0), "Recall (Weighted)": recall_score(y_test, y_pred, average="weighted", zero_division=0), "F1 Score (Weighted)": f1_score(y_test, y_pred, average="weighted", zero_division=0), "Time (s)": duration }
252
- results.append(metrics)
253
- trained_models[name] = model
254
- print(f" {name} - F1: {metrics['F1 Score (Weighted)']:.4f}, Time: {duration:.2f}s")
255
- if wandb_run: wandb.log({f"{name}_{k.lower().replace(' (weighted)','_w').replace(' ','_')}": v for k, v in metrics.items() if k != "Model"}, commit=False)
256
- except Exception as e: print(f" ERROR training/evaluating {name}: {e}"); results.append({"Model": name, "Error": str(e)}); import traceback; traceback.print_exc(); if wandb_run: wandb.log({f"{name}_error": str(e)}, commit=False)
257
-
258
-
259
- # --- 5. Process Comparison Results ---
260
- if not results:
261
- if wandb_run: wandb.finish()
262
- return pd.DataFrame({"Error": ["No models trained successfully."]}), default_shap_path, default_lime_path, "Error: Model training failed."
263
-
264
- comparison_df = pd.DataFrame(results)
265
- best_model = None
266
- best_model_name = "N/A"
267
- # (Keep logic for sorting and finding best model as before)
268
- if "F1 Score (Weighted)" in comparison_df.columns:
269
- comparison_df = comparison_df.sort_values(by="F1 Score (Weighted)", ascending=False, na_position='last').reset_index(drop=True)
270
- if not comparison_df.empty:
271
- best_model_row = comparison_df.iloc[0]
272
- potential_best_name = best_model_row['Model']
273
- if pd.notna(best_model_row.get("F1 Score (Weighted)")) and potential_best_name in trained_models:
274
- best_model = trained_models[potential_best_name]
275
- best_model_name = potential_best_name
276
- print(f"Best model determined: {best_model_name} (F1: {best_model_row['F1 Score (Weighted)']:.4f})")
277
- else: print("Warning: Could not determine a valid best model from results.")
278
- else: print("Warning: F1 Score column missing.")
279
-
280
- # (Keep WandB table logging as before)
281
- if wandb_run and not comparison_df.empty:
282
- try: wandb.log({"model_comparison": wandb.Table(dataframe=comparison_df)}, commit=False); print("Logged comparison table to WandB.")
283
- except Exception as e: print(f"Error logging comparison table: {e}")
284
-
285
- # --- 6. Explain Best Model (if found) ---
286
- shap_plot_path = None
287
- lime_plot_path = None
288
- explain_status = f"Compared {len(trained_models)} models. Best: {best_model_name}."
289
-
290
- if best_model:
291
- print(f"Generating explanations for {best_model_name}...")
292
- explain_dir = "./explain_plots"
293
- if os.path.exists(explain_dir): shutil.rmtree(explain_dir)
294
- os.makedirs(explain_dir)
295
- shap_plot_path = os.path.join(explain_dir, f"shap_{best_model_name}.png")
296
- lime_plot_path = os.path.join(explain_dir, f"lime_{best_model_name}.png")
297
-
298
- try:
299
- # --- SHAP (Keep logic as before, but ensure shap_plot_path becomes None on failure) ---
300
- # Simplified SHAP logic for display
301
- print(" Generating SHAP...")
302
- # ... (Your detailed SHAP logic from previous attempts)
303
- # Example placeholder:
304
- try:
305
- plt.figure(); plt.text(0.5, 0.5, 'SHAP Placeholder'); plt.savefig(shap_plot_path); plt.clf()
306
- print(f" SHAP plot saved: {shap_plot_path}")
307
- if wandb_run: wandb.log({"shap_summary_best": wandb.Image(shap_plot_path)}, commit=False)
308
- except Exception as shap_e:
309
- print(f" ERROR generating SHAP: {shap_e}")
310
- shap_plot_path = None # Set to None on error
311
-
312
- # --- LIME (Keep logic as before, but ensure lime_plot_path becomes None on failure) ---
313
- print(" Generating LIME...")
314
- # ... (Your detailed LIME logic from previous attempts)
315
- # Example placeholder:
316
- try:
317
- plt.figure(); plt.text(0.5, 0.5, 'LIME Placeholder'); plt.savefig(lime_plot_path); plt.clf()
318
- print(f" LIME plot saved: {lime_plot_path}")
319
- if wandb_run: wandb.log({"lime_explanation_best": wandb.Image(lime_plot_path)}, commit=False)
320
- except Exception as lime_e:
321
- print(f" ERROR generating LIME: {lime_e}")
322
- lime_plot_path = None # Set to None on error
323
-
324
- explain_status += f" Explanations attempted for {best_model_name}."
325
-
326
- except Exception as e:
327
- print(f" ERROR during explanation block: {e}")
328
- import traceback; traceback.print_exc()
329
- explain_status += f" Explanation failed: {e}"
330
- if not os.path.exists(str(shap_plot_path)): shap_plot_path = None # Check path validity
331
- if not os.path.exists(str(lime_plot_path)): lime_plot_path = None
332
-
333
- else:
334
- explain_status += " No best model found to explain."
335
-
336
-
337
- # --- 7. Finish WandB Run and Return ---
338
- if wandb_run:
339
- try:
340
- wandb.log({}, commit=True) # Ensure final commit
341
- wandb.finish()
342
- print(f"WandB Run '{run_name}' finished.")
343
- except Exception as finish_e:
344
- print(f"Error finishing WandB run: {finish_e}")
345
-
346
- # MODIFIED: Return empty strings instead of None for paths
347
- valid_shap_path = shap_plot_path if shap_plot_path and os.path.exists(shap_plot_path) else ""
348
- valid_lime_path = lime_plot_path if lime_plot_path and os.path.exists(lime_plot_path) else ""
349
-
350
- print(f"DEBUG Final Return: DF shape {comparison_df.shape}, SHAP path '{valid_shap_path}', LIME path '{valid_lime_path}', Status '{explain_status}'")
351
- return comparison_df, valid_shap_path, valid_lime_path, explain_status
352
-
353
- # --- Gradio UI ---
354
- with gr.Blocks() as demo:
355
- gr.Markdown("## 📊 AI Data Analysis, Model Comparison & Explainability")
356
-
357
- # --- Row 1: Upload ---
358
- with gr.Row():
359
- with gr.Column(scale=1):
360
- file_input = gr.File(label="1. Upload CSV or Excel", type="filepath", file_types=[".csv", ".xls", ".xlsx"])
361
- with gr.Column(scale=2):
362
- df_output = gr.DataFrame(label="Cleaned Data Preview", interactive=False)
363
-
364
- # --- Row 2: AI Agent (Optional) ---
365
- with gr.Accordion("🤖 Step 2 (Optional): Run AI Agent Insights", open=False):
366
- with gr.Row():
367
- with gr.Column(scale=1):
368
- agent_notes = gr.Textbox(label="Optional requests for Agent", placeholder="e.g., 'Focus on column X'")
369
- agent_btn = gr.Button("Run AI Analysis", interactive=(model_api is not None))
370
- with gr.Column(scale=2):
371
- insights_output = gr.HTML(label="AI Agent Report")
372
- with gr.Row():
373
- visual_output = gr.Gallery(label="AI Agent Visualizations", height=350, object_fit="contain", columns=3, preview=True)
374
-
375
- # --- Row 3: Compare & Explain ---
376
- with gr.Accordion("⚙️💡 Step 3: Compare Models & Explain Best", open=True):
377
- with gr.Row():
378
- compare_explain_btn = gr.Button("Run Comparison & Explain Best Model", variant="primary")
379
- with gr.Row():
380
- comparison_output = gr.DataFrame(label="Model Comparison Results", interactive=False)
381
  with gr.Row():
382
- explain_status_output = gr.Textbox(label="Status", interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
383
  with gr.Row():
384
- # Only one SHAP plot expected now (summary)
385
- shap_img_output = gr.Image(label="SHAP Summary (Best Model)", type="filepath", interactive=False)
386
- lime_img_output = gr.Image(label="LIME Explanation (Best Model - Instance 0)", type="filepath", interactive=False)
387
-
388
-
389
- # --- Connect Components ---
390
- file_input.change(fn=upload_file, inputs=file_input, outputs=df_output)
391
-
392
- # AI Agent connection (Keep as is)
393
- agent_btn.click(fn=analyze_data, inputs=[file_input, agent_notes], outputs=[insights_output, visual_output])
394
-
395
- # NEW: Connection for combined comparison and explainability
396
- compare_explain_btn.click(
397
- fn=run_comparison_and_explainability,
398
- inputs=[], # Takes data from global df_global
399
- outputs=[comparison_output, shap_img_output, lime_img_output, explain_status_output]
400
- )
401
-
402
- # --- Launch ---
403
- print("Launching Gradio App...")
404
- demo.launch(debug=True) # Use debug=True for more detailed errors during development
 
4
  import numpy as np
5
  import matplotlib.pyplot as plt
6
  import shap
 
7
  import lime.lime_tabular
8
+ import optuna
9
  import wandb
10
  import json
11
  import time
 
14
  import ast
15
  from smolagents import HfApiModel, CodeAgent
16
  from huggingface_hub import login
17
+ from sklearn.model_selection import train_test_split, cross_val_score
18
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix, ConfusionMatrixDisplay
19
+ from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
 
 
20
  from sklearn.linear_model import LogisticRegression
21
+ from sklearn.preprocessing import LabelEncoder
 
22
  from datetime import datetime
23
+ from PIL import Image
24
+ from sklearn.svm import SVC
 
 
 
25
 
26
# --- Authentication and model setup ---

# Authenticate with Hugging Face. login() raises when the token is missing or
# invalid, which previously crashed the whole app at import time; guard it so
# the UI still starts (agent features will simply fail later with a clear error).
hf_token = os.getenv("HF_TOKEN")
try:
    login(token=hf_token)
    print("Hugging Face login successful.")
except Exception as e:
    print(f"Hugging Face login failed: {e}")

# SmolAgent initialization: remote Mixtral endpoint used by the analysis agent.
model = HfApiModel("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)

# Most recently uploaded & cleaned DataFrame, shared across Gradio callbacks.
df_global = None
 
 
34
 
 
35
def clean_data(df):
    """Return a cleaned copy: drop all-empty rows/columns, label-encode object
    columns, and impute remaining numeric NaNs with the column mean."""
    # dropna returns a new frame, so the caller's DataFrame is never mutated.
    cleaned = df.dropna(how='all', axis=1).dropna(how='all', axis=0)
    object_columns = cleaned.select_dtypes(include='object').columns
    for column in object_columns:
        # Cast to str first so mixed-type columns encode without errors.
        cleaned[column] = LabelEncoder().fit_transform(cleaned[column].astype(str))
    # Mean-impute only the numeric columns; non-numeric ones are left alone.
    return cleaned.fillna(cleaned.mean(numeric_only=True))
42
 
43
def upload_file(file):
    """Load an uploaded CSV/Excel file, clean it, and cache it globally.

    Returns (preview_head, full_dataframe); on failure returns an error
    DataFrame and None so downstream components degrade gracefully.
    """
    global df_global
    if file is None:
        return pd.DataFrame({"Error": ["No file uploaded."]}), None
    try:
        # gr.File(type="filepath") passes a plain path string; older Gradio
        # versions pass a tempfile-like object exposing .name. Support both.
        path = file if isinstance(file, str) else file.name
        # lower() so ".CSV" is routed to the CSV reader, not read_excel.
        ext = os.path.splitext(path)[-1].lower()
        df = pd.read_csv(path) if ext == ".csv" else pd.read_excel(path)
        df = clean_data(df)
        df_global = df
        return df.head(), df
    except Exception as e:
        # A bad file must not crash the Gradio callback; surface the error in the UI.
        print(f"Error processing uploaded file: {e}")
        return pd.DataFrame({"Error": [f"Failed to process file: {e}"]}), None
52
+
53
def format_analysis_report(raw_output, visuals):
    """Render the agent's output dict as an HTML report.

    Falls back to the raw output string whenever parsing or rendering fails,
    so the UI always shows *something*.
    """
    try:
        if isinstance(raw_output, dict):
            parsed = raw_output
        else:
            # The agent may return the dict serialized as a string literal.
            try:
                parsed = ast.literal_eval(str(raw_output))
            except (SyntaxError, ValueError) as e:
                print(f"Error parsing CodeAgent output: {e}")
                return str(raw_output), visuals

        observations_html = format_observations(parsed.get('observations', {}))
        insights_html = format_insights(parsed.get('insights', {}), visuals)
        report = f"""
        <div style="font-family: Arial, sans-serif; padding: 20px; color: #333;">
            <h1 style="color: #2B547E; border-bottom: 2px solid #2B547E; padding-bottom: 10px;">📊 Data Analysis Report</h1>
            <div style="margin-top: 25px; background: #f8f9fa; padding: 20px; border-radius: 8px;">
                <h2 style="color: #2B547E;">🔍 Key Observations</h2>
                {observations_html}
            </div>
            <div style="margin-top: 30px;">
                <h2 style="color: #2B547E;">💡 Insights & Visualizations</h2>
                {insights_html}
            </div>
        </div>
        """
        return report, visuals
    except Exception as e:
        print(f"Error in format_analysis_report: {e}")
        return str(raw_output), visuals
81
+
82
def format_observations(observations):
    """Render proportion-related observations as styled HTML cards."""
    cards = []
    for key, value in observations.items():
        # Only distribution/proportion observations are surfaced in the report.
        if 'proportions' not in key:
            continue
        title = key.replace('_', ' ').title()
        cards.append(f"""
    <div style="margin: 15px 0; padding: 15px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
        <h3 style="margin: 0 0 10px 0; color: #4A708B;">{title}</h3>
        <pre style="margin: 0; padding: 10px; background: #f8f9fa; border-radius: 4px;">{value}</pre>
    </div>
    """)
    return '\n'.join(cards)
91
+
92
def format_insights(insights, visuals):
    """Render numbered insight cards, pairing each with a saved figure when one exists."""
    blocks = []
    for idx, insight in enumerate(insights.values()):
        # Figures are matched to insights positionally; extra insights get no image.
        if idx < len(visuals):
            img_html = (f'<img src="/file={visuals[idx]}" style="max-width: 100%; height: auto; '
                        f'margin-top: 10px; border-radius: 6px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">')
        else:
            img_html = ''
        blocks.append(f"""
    <div style="margin: 20px 0; padding: 20px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
        <div style="display: flex; align-items: center; gap: 10px;">
            <div style="background: #2B547E; color: white; width: 30px; height: 30px; border-radius: 50%; display: flex; align-items: center; justify-content: center;">{idx+1}</div>
            <p style="margin: 0; font-size: 16px;">{insight}</p>
        </div>
        {img_html}
    </div>
    """)
    return '\n'.join(blocks)
104
 
105
def analyze_data(csv_file, additional_notes=""):
    """Run the CodeAgent over the uploaded file and return (HTML report, visuals).

    Logs execution time, memory usage, and every generated figure to
    Weights & Biases under the 'huggingface-data-analysis' project.
    """
    # FIX: the original called psutil without importing it anywhere in the file
    # (NameError at runtime). tracemalloc is stdlib and tracks the Python heap.
    import tracemalloc

    start_time = time.time()
    tracemalloc.start()

    # Fresh figures directory so stale plots from a previous run are not shown.
    if os.path.exists('./figures'):
        shutil.rmtree('./figures')
    os.makedirs('./figures', exist_ok=True)

    # gr.File(type="filepath") passes a str; older Gradio passes an object with .name.
    source_name = None
    if csv_file is not None:
        source_name = csv_file if isinstance(csv_file, str) else csv_file.name

    wandb.login(key=os.environ.get('WANDB_API_KEY'))
    run = wandb.init(project="huggingface-data-analysis", config={
        "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "additional_notes": additional_notes,
        "source_file": source_name
    })

    agent = CodeAgent(tools=[], model=model, additional_authorized_imports=["numpy", "pandas", "matplotlib.pyplot", "seaborn", "sklearn", "json"])
    analysis_result = agent.run("""
        You are a helpful data analysis agent. Just return insight information and visualization.
        Load the data that is passed.do not create your own.
        Automatically detect numeric columns and names.
        2. 5 data visualizations
        3. at least 5 insights from data
        5. Generate publication-quality visualizations and save to './figures/'.
        Do not use 'open()' or write to files. Just return variables and plots.
        The dictionary should have the following structure:
        {
            'observations': {
                'observation_1_key': 'observation_1_value',
                'observation_2_key': 'observation_2_value',
                ...
            },
            'insights': {
                'insight_1_key': 'insight_1_value',
                'insight_2_key': 'insight_2_value',
                ...
            }
        }
        """, additional_args={"additional_notes": additional_notes, "source_file": csv_file})

    execution_time = time.time() - start_time
    # Peak traced allocation during the run, converted to MiB.
    _, peak_bytes = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    memory_usage = peak_bytes / 1024 ** 2
    wandb.log({"execution_time_sec": execution_time, "memory_usage_mb": memory_usage})

    # Collect whatever figures the agent saved and mirror them to WandB.
    visuals = [os.path.join('./figures', f) for f in os.listdir('./figures') if f.endswith(('.png', '.jpg', '.jpeg'))]
    for viz in visuals:
        wandb.log({os.path.basename(viz): wandb.Image(viz)})

    run.finish()
    return format_analysis_report(analysis_result, visuals)
156
+
157
def compare_models(selected_models, df):
    """Train each selected classifier on a shared 70/30 split and report metrics.

    The last column of ``df`` is treated as the target. Returns
    (metrics DataFrame, list of confusion-matrix image paths).
    """
    if df is None or len(selected_models) == 0:
        return pd.DataFrame(), []
    target = df.columns[-1]
    X = df.drop(target, axis=1)
    y = df[target]
    if y.dtype == 'object':
        y = LabelEncoder().fit_transform(y)
    model_dict = {
        "RandomForest": RandomForestClassifier(),
        "LogisticRegression": LogisticRegression(max_iter=1000),
        "SVC": SVC(probability=True)
    }
    # FIX: split once outside the loop. random_state is fixed, so the original
    # per-model re-split produced the identical partition each iteration anyway;
    # hoisting removes the redundant work without changing results.
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
    is_binary = len(np.unique(y)) == 2
    results = []
    confusion_imgs = []
    for name in selected_models:
        # FIX: ignore unknown choices instead of raising KeyError and killing the callback.
        if name not in model_dict:
            print(f"Skipping unknown model: {name}")
            continue
        model = model_dict[name]
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)
        # ROC-AUC needs class-1 probabilities and is only defined here for binary targets.
        y_proba = model.predict_proba(X_test)[:, 1] if hasattr(model, "predict_proba") and is_binary else None
        metrics = {
            "Model": name,
            "Accuracy": accuracy_score(y_test, y_pred),
            "Precision": precision_score(y_test, y_pred, average="weighted", zero_division=0),
            "Recall": recall_score(y_test, y_pred, average="weighted", zero_division=0),
            "F1": f1_score(y_test, y_pred, average="weighted", zero_division=0),
            "ROC-AUC": roc_auc_score(y_test, y_proba) if y_proba is not None else "N/A"
        }
        results.append(metrics)
        # Confusion matrix plot, saved to disk for the Gradio gallery.
        fig, ax = plt.subplots()
        ConfusionMatrixDisplay.from_estimator(model, X_test, y_test, ax=ax)
        img_path = f"conf_matrix_{name}.png"
        plt.savefig(img_path)
        confusion_imgs.append(img_path)
        plt.close(fig)
    return pd.DataFrame(results), confusion_imgs
196
+
197
def ab_test_models(model_a, model_b, df):
    """A/B test two classifiers: train both on the same training set, then
    score each on a disjoint half of the held-out test set."""
    if df is None or model_a == model_b:
        return pd.DataFrame()
    target = df.columns[-1]
    X = df.drop(target, axis=1)
    y = df[target]
    if y.dtype == 'object':
        y = LabelEncoder().fit_transform(y)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
    # Split the test set into two positional halves, one per model.
    mid = len(X_test) // 2
    groups = [
        (model_a, X_test[:mid], y_test[:mid]),
        (model_b, X_test[mid:], y_test[mid:]),
    ]
    model_dict = {
        "RandomForest": RandomForestClassifier(),
        "LogisticRegression": LogisticRegression(max_iter=1000),
        "SVC": SVC(probability=True)
    }
    rows = []
    for name, X_grp, y_grp in groups:
        clf = model_dict[name]
        clf.fit(X_train, y_train)
        preds = clf.predict(X_grp)
        rows.append({
            "Model": name,
            "Accuracy": accuracy_score(y_grp, preds),
            "Precision": precision_score(y_grp, preds, average="weighted", zero_division=0),
            "Recall": recall_score(y_grp, preds, average="weighted", zero_division=0),
            "F1": f1_score(y_grp, preds, average="weighted", zero_division=0),
        })
    return pd.DataFrame(rows)
228
+
229
def get_model_choices():
    """Names of the classifiers offered in the UI (keys of the model dicts)."""
    choices = ("RandomForest", "LogisticRegression", "SVC")
    return list(choices)
231
+
232
def clear_confusion_imgs():
    """Delete any confusion-matrix PNGs left over from a previous session."""
    stale_paths = (f"conf_matrix_{name}.png" for name in get_model_choices())
    for path in stale_paths:
        if os.path.exists(path):
            os.remove(path)
237
+
238
def main():
    """Build the Gradio Blocks UI: upload/clean panel, model-comparison panel,
    and an A/B testing section. Returns the (unlaunched) demo object."""
    with gr.Blocks() as demo:
        gr.Markdown("# 🤖 Model Comparison & A/B Testing (Hugging Face + Gradio)")
        with gr.Row():
            with gr.Column():
                # NOTE(review): type="filepath" passes a plain str path to
                # upload_file — confirm upload_file handles str, not just .name objects.
                file_input = gr.File(label="Upload CSV or Excel", type="filepath")
                df_output = gr.DataFrame(label="Cleaned Data Preview")
                # Holds the full cleaned DataFrame for the other callbacks.
                state = gr.State()
                file_input.change(fn=upload_file, inputs=file_input, outputs=[df_output, state])
            with gr.Column():
                model_choices = gr.CheckboxGroup(
                    choices=get_model_choices(),
                    value=["RandomForest", "LogisticRegression"],
                    label="Select Models to Compare"
                )
                compare_btn = gr.Button("Compare Models")
                metrics_output = gr.DataFrame(label="Model Performance Metrics")
                confusion_gallery = gr.Gallery(label="Confusion Matrices", columns=3)
                compare_btn.click(fn=compare_models, inputs=[model_choices, state], outputs=[metrics_output, confusion_gallery])
        gr.Markdown("## A/B Test: Compare Two Models on Test Set")
        with gr.Row():
            ab_model_a = gr.Dropdown(get_model_choices(), value="RandomForest", label="Model A")
            ab_model_b = gr.Dropdown(get_model_choices(), value="LogisticRegression", label="Model B")
        ab_btn = gr.Button("Run A/B Test")
        ab_output = gr.DataFrame(label="A/B Test Results")
        ab_btn.click(fn=ab_test_models, inputs=[ab_model_a, ab_model_b, state], outputs=ab_output)
        gr.Markdown("---\nBuilt for Hugging Face Spaces with Gradio. Upload your data, select models, and compare!")
    return demo
266
+
267
if __name__ == "__main__":
    # Remove stale confusion-matrix images from a prior session, then serve the UI.
    clear_confusion_imgs()
    main().launch()