pavanmutha commited on
Commit
37b56c1
·
verified ·
1 Parent(s): 182f873

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +431 -526
app.py CHANGED
@@ -1,580 +1,485 @@
1
- # -*- coding: utf-8 -*-
2
- """
3
- Gradio App for AI Data Analysis, Model Comparison, and Explainability
4
- Requires: HF_TOKEN and WANDB_API_KEY environment variables.
5
- """
6
-
7
  import os
8
  import gradio as gr
9
- import pandas as pd # Make sure pandas is imported
10
  import numpy as np
11
  import matplotlib.pyplot as plt
12
  import shap
13
  import lime
14
  import lime.lime_tabular
15
- import optuna
16
  import wandb
17
  import json
18
  import time
19
  import psutil
20
  import shutil
21
  import ast
22
- from smolagents import HfApiModel, CodeAgent # Assuming smolagents is installed
23
  from huggingface_hub import login
24
- from sklearn.model_selection import train_test_split, cross_val_score
25
  from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
26
- from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
 
 
27
  from sklearn.linear_model import LogisticRegression
28
- from sklearn.svm import SVC # Kept import in case you add it later
29
- from sklearn.preprocessing import LabelEncoder, StandardScaler
30
- from sklearn.pipeline import Pipeline
31
  from datetime import datetime
32
- # from PIL import Image
33
- import warnings
34
- import joblib # For saving models
35
- from typing import List, Tuple, Optional # Keep these
36
-
37
- # Suppress common warnings
38
- warnings.filterwarnings("ignore")
39
 
40
- # --- Authentication and Setup ---
41
- # (Keep Authentication and Setup block as before)
42
- print("--- Initializing Setup ---")
43
  hf_token = os.getenv("HF_TOKEN")
44
- wandb_api_key = os.getenv("WANDB_API_KEY")
45
- wandb_run = None
46
- if not hf_token: print("Warning: HF_TOKEN environment variable not set.")
47
- else:
48
- try: login(token=hf_token); print("Hugging Face login successful.")
49
- except Exception as e: print(f"Hugging Face login failed: {e}")
50
- if not wandb_api_key:
51
- print("Warning: WANDB_API_KEY environment variable not set. WandB logging will be disabled.")
52
- if wandb.run is None:
53
- try: wandb.init(mode="disabled"); print("WandB initialized in disabled mode.")
54
- except Exception as e: print(f"Failed to initialize WandB in disabled mode: {e}")
55
  else:
56
- try: wandb.login(key=wandb_api_key); print("WandB login successful.")
 
 
 
 
 
 
57
  except Exception as e:
58
- print(f"WandB login failed: {e}. Disabling WandB.")
59
- if wandb.run is None:
60
- try: wandb.init(mode="disabled"); print("WandB initialized in disabled mode due to login failure.")
61
- except Exception as e_init: print(f"Failed to initialize WandB in disabled mode: {e_init}")
62
- agent = None
 
 
63
  try:
64
- print("Initializing SmolAgent...")
65
  model_api = HfApiModel("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
66
- agent = CodeAgent(tools=[], model=model_api, additional_authorized_imports=["numpy", "pandas", "matplotlib.pyplot", "seaborn", "sklearn", "json", "os"])
67
- print("SmolAgent initialized successfully.")
68
- except Exception as e: print(f"Error initializing SmolAgent: {e}. AI Agent features might fail.")
 
 
69
  df_global = None
70
- split_data_global = None
71
- comparison_results_global = None
72
- best_model_details_global = None
73
- print("Global variables initialized.")
74
- print("--- Setup Complete ---")
75
-
76
-
77
- # --- Data Handling ---
78
-
79
- def clean_data(df: pd.DataFrame) -> pd.DataFrame: # Added type hints
80
- """Cleans the input DataFrame."""
81
- print("Starting data cleaning...")
82
- df_cleaned = df.copy()
83
- df_cleaned = df_cleaned.dropna(how='all', axis=1).dropna(how='all', axis=0)
84
- print(f"Shape after dropping fully empty rows/cols: {df_cleaned.shape}")
85
- object_cols = df_cleaned.select_dtypes(include='object').columns
86
- if not object_cols.empty:
87
- print(f"Encoding object columns: {list(object_cols)}")
88
- for col in object_cols:
89
- df_cleaned[col] = df_cleaned[col].astype(str)
90
- df_cleaned[col] = LabelEncoder().fit_transform(df_cleaned[col])
91
- numeric_cols = df_cleaned.select_dtypes(include=np.number).columns
92
  if not numeric_cols.empty:
93
- cols_with_na = df_cleaned[numeric_cols].isnull().sum()
94
- cols_to_impute = cols_with_na[cols_with_na > 0].index
95
- if not cols_to_impute.empty:
96
- print(f"Imputing NaNs with mean in columns: {list(cols_to_impute)}")
97
- for col in cols_to_impute:
98
- mean_val = df_cleaned[col].mean()
99
- df_cleaned[col] = df_cleaned[col].fillna(mean_val)
100
- print("Data cleaning finished.")
101
- return df_cleaned
102
-
103
- # ADDED TYPE HINT HERE
104
- def upload_file(file) -> pd.DataFrame:
105
- """Handles file upload, cleaning, and global state update."""
106
- global df_global, split_data_global, comparison_results_global, best_model_details_global
107
  df_global = None
108
  split_data_global = None
109
- comparison_results_global = None
110
- best_model_details_global = None
111
- print("Reset global data states on file change.")
112
-
113
  if file is None:
114
- # Return an empty DataFrame or one with a status message, matching hint
115
- return pd.DataFrame({"Status": ["No file uploaded or file removed."]})
116
-
117
- print(f"Uploading file: {file.name}")
118
  try:
119
- ext = os.path.splitext(file.name)[-1].lower()
120
- if ext == ".csv":
121
- df = pd.read_csv(file.name)
122
- elif ext in [".xls", ".xlsx"]:
123
- df = pd.read_excel(file.name)
124
- else:
125
- # Return DataFrame matching hint
126
- return pd.DataFrame({"Error": [f"Unsupported file type: {ext}"]})
127
-
128
- print(f"Original data shape: {df.shape}")
129
- df_cleaned = clean_data(df)
130
- print(f"Cleaned data shape: {df_cleaned.shape}")
131
- df_global = df_cleaned
132
- print("Global DataFrame updated with cleaned data.")
133
- # Return DataFrame matching hint
134
- return df_global.head()
135
  except Exception as e:
136
- print(f"Error processing file {file.name}: {e}")
137
- df_global = None; split_data_global = None; comparison_results_global = None; best_model_details_global = None
138
- # Return DataFrame matching hint
139
  return pd.DataFrame({"Error": [f"Failed to process file: {e}"]})
140
 
141
-
142
- # --- AI Agent Analysis ---
143
- # (Keep format_observations, format_insights, format_analysis_report as before)
144
- def format_observations(observations):
145
- """Formats the observations dictionary into HTML list items."""
146
- if not isinstance(observations, dict): return f"<p style='color: orange;'>Observations data is not a dictionary: {type(observations)}</p>"
147
- items_html = ""
148
- for key, value in observations.items():
149
- formatted_key = key.replace('_', ' ').title()
150
- if isinstance(value, (dict, list)):
151
- formatted_value = json.dumps(value, indent=2); value_html = f"<pre style='margin: 0; padding: 8px; background: #ffffff; border: 1px solid #ccc; border-radius: 4px; font-size: 0.9em; white-space: pre-wrap; word-wrap: break-word;'>{formatted_value}</pre>"
152
- else: formatted_value = str(value); value_html = f"<p style='margin: 0; padding: 8px; background: #ffffff; border: 1px solid #ccc; border-radius: 4px; font-size: 0.9em;'>{formatted_value}</p>"
153
- items_html += f"""<div style="margin-bottom: 12px; padding: 10px; background: #fdfefe; border-radius: 4px; box-shadow: 0 1px 3px rgba(0,0,0,0.1);"><h4 style="margin: 0 0 8px 0; color: #34495e;">{formatted_key}</h4>{value_html}</div>"""
154
- return items_html if items_html else "<p>No observations found.</p>"
155
-
156
- def format_insights(insights, visuals):
157
- """Formats insights and embeds corresponding visuals."""
158
- if not isinstance(insights, dict): return f"<p style='color: orange;'>Insights data is not a dictionary: {type(insights)}</p>"
159
- items_html = ""; visual_idx = 0; insight_keys = list(insights.keys())
160
- for i, key in enumerate(insight_keys):
161
- insight_text = str(insights[key]); formatted_key = key.replace('_', ' ').title()
162
- items_html += f"""<div style="margin: 20px 0; padding: 15px; background: #ffffff; border-radius: 8px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);"><h4 style='margin-top: 0; margin-bottom: 10px; color: #16a085;'>Insight {i+1}: {formatted_key}</h4><p style="margin-bottom: 15px;">{insight_text}</p>"""
163
- if visual_idx < len(visuals):
164
- img_path = visuals[visual_idx]
165
- items_html += f'<img src="/file={img_path}" alt="Visualization for {formatted_key}" style="max-width: 95%; height: auto; display: block; margin-top: 10px; border-radius: 6px; border: 1px solid #eee; box-shadow: 0 1px 3px rgba(0,0,0,0.1);">'
166
- visual_idx += 1
167
- items_html += "</div>"
168
- if visual_idx < len(visuals):
169
- items_html += "<h4 style='margin-top: 25px; color: #2980b9;'>Additional Visualizations:</h4>"
170
- for i in range(visual_idx, len(visuals)):
171
- img_path = visuals[i]
172
- items_html += f"""<div style="margin: 20px 0; padding: 15px; background: #ffffff; border-radius: 8px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);"><img src="/file={img_path}" alt="Additional Visualization {i+1}" style="max-width: 95%; height: auto; display: block; margin: auto; border-radius: 6px; border: 1px solid #eee; box-shadow: 0 1px 3px rgba(0,0,0,0.1);"></div>"""
173
- return items_html if (items_html or visuals) else "<p>No insights or visuals generated/found.</p>"
174
-
175
- def format_analysis_report(raw_output, visuals):
176
- """Formats the AI agent's output into readable HTML."""
177
- print("Formatting AI analysis report...")
178
- report_html = ""; analysis_dict = {}; parsing_error = None
179
  try:
180
- if isinstance(raw_output, str):
181
- cleaned_output = raw_output.strip().removeprefix("```python").removeprefix("```json").removesuffix("```").strip()
182
- dict_start_index = cleaned_output.find('{')
183
- if dict_start_index != -1:
184
- try: analysis_dict = ast.literal_eval(cleaned_output[dict_start_index:])
185
- except (SyntaxError, ValueError, TypeError) as e: parsing_error = f"Error parsing agent output: {e}\nRaw output:\n{raw_output}"; print(parsing_error)
186
- else: parsing_error = f"Could not find dictionary start '{{' in agent output.\nRaw output:\n{raw_output}"; print(parsing_error)
187
- elif isinstance(raw_output, dict): analysis_dict = raw_output
188
- else: parsing_error = f"Output is not a string or dictionary, type: {type(raw_output)}.\nRaw output:\n{str(raw_output)}"; print(parsing_error)
189
- report_html = """<div style="font-family: Arial, sans-serif; line-height: 1.6; color: #333; padding: 15px; border: 1px solid #ddd; border-radius: 8px; background-color: #f9f9f9;"><h1 style="color: #2c3e50; border-bottom: 2px solid #3498db; padding-bottom: 10px; margin-top: 0;">📊 AI Data Analysis Report</h1>"""
190
- if parsing_error: report_html += f"<div style='background-color: #f8d7da; color: #721c24; border: 1px solid #f5c6cb; padding: 10px; border-radius: 5px; margin-bottom: 15px;'><pre>{parsing_error}</pre></div>"
191
- observations = analysis_dict.get('observations', {}); report_html += """<div style="margin-top: 20px; background: #ecf0f1; padding: 15px; border-radius: 5px;"><h2 style="color: #2980b9; margin-top: 0;">🔍 Key Observations</h2>"""
192
- report_html += format_observations(observations) if observations else "<p>No 'observations' found or parsed.</p>"; report_html += "</div>"
193
- insights = analysis_dict.get('insights', {}); report_html += """<div style="margin-top: 25px;"><h2 style="color: #2980b9;">💡 Insights & Visualizations</h2>"""
194
- report_html += format_insights(insights, visuals) if (insights or visuals) else "<p>No 'insights' or visuals found or parsed.</p>"; report_html += "</div>"
195
- report_html += "</div>"; print("Report formatting complete.")
196
- return report_html, visuals
197
- except Exception as e:
198
- print(f"Critical error in format_analysis_report: {e}"); error_message = f"<p style='color: red; font-weight: bold;'>Error generating report:</p><pre>{str(e)}</pre>"; raw_display = f"<p style='font-weight: bold;'>Raw Agent Output:</p><pre>{str(raw_output)}</pre>"; return error_message + raw_display, visuals
199
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
- # ADDED TYPE HINT HERE
202
- def analyze_data(csv_file, additional_notes="") -> Tuple[str, List[str]]:
203
- """Runs the SmolAgent for data analysis and visualization."""
204
- global df_global, agent
205
- # Default return values matching the type hint
206
- default_error_html = "<p style='color:red;'>An error occurred.</p>"
207
- default_visuals = []
208
 
209
- if df_global is None: return "<p style='color:red;'>Please upload a file first.</p>", default_visuals
210
- if agent is None: return "<p style='color:red;'>AI Agent is not available (initialization failed).</p>", default_visuals
211
- if csv_file is None: return "<p style='color:red;'>File object missing, please re-upload.</p>", default_visuals
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
- print("--- Starting AI Agent Analysis ---")
214
- start_time = time.time(); process = psutil.Process(os.getpid()); initial_memory = process.memory_info().rss / 1024 ** 2
215
- figures_dir = './figures'
216
- try:
217
- if os.path.exists(figures_dir): shutil.rmtree(figures_dir); print(f"Cleaned existing directory: {figures_dir}")
218
- os.makedirs(figures_dir); print(f"Created directory: {figures_dir}")
219
- except Exception as e: print(f"Error managing figures directory: {e}"); return f"<p style='color:red;'>Error setting up visualization directory: {e}</p>", default_visuals
220
 
221
- wandb_run_agent = None; run_name = f"AgentAnalysis_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
222
- # ... (WandB init logic as before) ...
223
- if wandb.run is None or wandb.run.mode != "disabled":
224
- try:
225
- if wandb.run and wandb.run.id: wandb.finish()
226
- wandb_run_agent = wandb.init(project="ai-data-analysis-gradio", name=run_name, config={...}, reinit=True) # Simplified config display
227
- print(f"WandB run '{run_name}' initialized for Agent Analysis.")
228
- except Exception as e: print(f"Error initializing WandB run for Agent Analysis: {e}"); wandb_run_agent = None
229
- else: print("WandB disabled, skipping Agent run logging.")
230
-
231
- analysis_result = None; visuals = []
232
- try:
233
- # ... (prompt definition as before) ...
234
- prompt = f"""
235
- Analyze `df_global`. Tasks: 3 observations, 5 insights, 5 visualizations saved to './figures/'.
236
- Output Format: Python dictionary {{'observations':{{...}}, 'insights':{{...}}}}. Context: {additional_notes}
237
- Use `df_global`. Save plots with plt.savefig('./figures/unique_name.png', bbox_inches='tight') and plt.clf(). No plt.show().
238
- """ # Simplified prompt display
239
- print("Running AI agent..."); analysis_result = agent.run(prompt, additional_args={"df_global": df_global.copy()}); print("AI agent finished.")
240
- if os.path.exists(figures_dir):
241
- visuals = [os.path.join(figures_dir, f) for f in os.listdir(figures_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
242
- visuals = [v for v in visuals if os.path.isfile(v)]; print(f"Found {len(visuals)} visualizations in {figures_dir}.")
243
- else: print(f"Warning: Figures directory '{figures_dir}' not found after agent run.")
244
- except Exception as e:
245
- print(f"Error during AI agent execution: {e}"); import traceback; traceback.print_exc()
246
- if wandb_run_agent: wandb_run_agent.finish(exit_code=1)
247
- # Return values matching type hint on error
248
- return f"<p style='color:red;'>Error running AI agent: {e}</p>", default_visuals
249
-
250
- execution_time = time.time() - start_time; final_memory = process.memory_info().rss / 1024 ** 2; memory_usage = final_memory - initial_memory
251
- print(f"Agent execution time: {execution_time:.2f}s, Memory usage: {memory_usage:.2f} MB")
252
- # ... (WandB logging logic as before) ...
253
- if wandb_run_agent:
254
- try:
255
- wandb.log({"agent_execution_time_sec": execution_time, "agent_memory_usage_mb": memory_usage, "visualizations_generated": len(visuals)}, commit=False)
256
- # Log visualizations, etc.
257
- for viz_path in visuals:
258
- if os.path.exists(viz_path):
259
- try: wandb.log({f"agent_visualization_{os.path.basename(viz_path)}": wandb.Image(viz_path)}, commit=False)
260
- except Exception as log_e: print(f"Wandb img log error: {log_e}")
261
- if analysis_result: wandb.log({"agent_raw_output": str(analysis_result)[:10000]}, commit=True)
262
- print("Logged agent results to WandB.")
263
- except Exception as e: print(f"Error logging agent results to WandB: {e}")
264
- finally: wandb_run_agent.finish(); print(f"WandB run '{run_name}' finished.")
265
-
266
- # Ensure return matches type hint
267
- formatted_html, _ = format_analysis_report(analysis_result, visuals) # Get formatted HTML
268
- return formatted_html, visuals
269
-
270
-
271
- # --- Model Training and Comparison ---
272
-
273
- def prepare_data(df, target_column=None): # -> Tuple[pd.DataFrame, pd.DataFrame, pd.Series, pd.Series] - Internal use, hint optional
274
- """Prepares data for modeling (selects target, splits, handles encoding)."""
275
- global split_data_global
276
- print("--- Preparing Data for Modeling ---")
277
- if df is None or df.empty: raise ValueError("Cannot prepare data: DataFrame is empty.")
278
- # ... (logic for selecting target, dropping, encoding 'y' as before) ...
279
- if target_column is None: target_column = df.columns[-1]; print(f"Target column automatically selected: '{target_column}'")
280
- elif target_column not in df.columns: raise ValueError(f"Target column '{target_column}' not found.")
281
- else: print(f"Using specified target column: '{target_column}'")
282
- X = df.drop(columns=[target_column]); y = df[target_column].copy(); le = None
283
- if y.dtype == 'object' or pd.api.types.is_categorical_dtype(y): le = LabelEncoder(); y = le.fit_transform(y); print(f"Encoded target. Classes: {le.classes_}")
284
- non_numeric_cols = X.select_dtypes(exclude=np.number).columns
285
- if not non_numeric_cols.empty: print(f"Warning: Non-numeric columns found in features: {list(non_numeric_cols)}. Dropping them."); X = X.drop(columns=non_numeric_cols)
286
- if X.empty: raise ValueError("No features remaining to train the model.")
287
- if y.nunique() < 2: raise ValueError("Target column must have at least two unique classes for classification.")
288
- try: X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)
289
- except ValueError as split_e: print(f"Stratified split failed ({split_e}). Trying non-stratified split."); X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
290
- split_data_global = (X_train, X_test, y_train, y_test, le); print("Data prepared and split stored globally.")
291
- return X_train, X_test, y_train, y_test
292
-
293
- # --- Model Training and Comparison ---
294
-
295
- # ADDED TYPE HINT HERE
296
- def train_and_compare_models(tune_rf=True, tune_gb=True, n_trials_optuna=10) -> pd.DataFrame:
297
- """Trains, (optionally) tunes, evaluates multiple models, and logs comparison."""
298
- global df_global, split_data_global, comparison_results_global, best_model_details_global, wandb_run
299
- # Default return DataFrame matching hint
300
- default_error_df = pd.DataFrame({"Error": ["Comparison failed."]})
301
-
302
- if df_global is None: print("Error: No data loaded for training/comparison."); return pd.DataFrame({"Error": ["Please upload data first."]})
303
-
304
- print("--- Starting Model Training and Comparison ---")
305
- run_name = f"CompareModels_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
306
  models_to_compare = {
307
- "LogisticRegression": Pipeline([('scaler', StandardScaler()), ('logreg', LogisticRegression(max_iter=1000, random_state=42, class_weight='balanced'))]),
308
- "RandomForest": RandomForestClassifier(random_state=42, class_weight='balanced'),
309
- "GradientBoosting": GradientBoostingClassifier(random_state=42)
310
- }
311
- config = {
312
- "task": "Model Comparison", "models": list(models_to_compare.keys()), "tune_rf": tune_rf,
313
- "tune_gb": tune_gb, "optuna_trials": n_trials_optuna if (tune_rf or tune_gb) else 0,
314
- "data_shape": df_global.shape if df_global is not None else "N/A", "test_size": 0.3, "stratify": True
315
  }
316
 
317
- # --- WandB Setup ---
318
- # (WandB init logic...)
 
319
  if wandb.run is None or wandb.run.mode != "disabled":
320
  try:
321
- if wandb.run and wandb.run.id: wandb.finish()
322
- wandb_run = wandb.init(project="ai-data-analysis-gradio", name=run_name, config=config, reinit=True)
323
- print(f"WandB run '{run_name}' initialized for Model Comparison.")
324
- except Exception as e: print(f"Error initializing WandB run for Comparison: {e}"); wandb_run = None
325
- else: print("WandB disabled, skipping Comparison run logging."); wandb_run = None
326
-
327
- results = []; best_f1 = -1; best_model_obj = None; best_model_name = None; best_model_params = None
328
- try:
329
- if split_data_global: print("Using previously split data."); X_train, X_test, y_train, y_test, _ = split_data_global
330
- else: print("Preparing data for comparison..."); X_train, X_test, y_train, y_test = prepare_data(df_global)
331
-
332
- # --- Optuna Objective Functions ---
333
- def objective_rf(trial):
334
- # --- CORRECTED PARAMETER DEFINITION ---
335
- params = {
336
- "n_estimators": trial.suggest_int("n_estimators", 50, 250, step=50),
337
- "max_depth": trial.suggest_int("max_depth", 5, 20, log=True),
338
- "min_samples_split": trial.suggest_int("min_samples_split", 2, 16),
339
- "min_samples_leaf": trial.suggest_int("min_samples_leaf", 1, 16),
340
- "criterion": trial.suggest_categorical("criterion", ["gini", "entropy"]),
341
- "class_weight": trial.suggest_categorical("class_weight", ["balanced", "balanced_subsample", None]),
342
- "random_state": 42
343
- }
344
- # --- END CORRECTION ---
345
- try:
346
- model = RandomForestClassifier(**params)
347
- score = cross_val_score(model, X_train, y_train, cv=3, scoring="f1_weighted", n_jobs=-1).mean()
348
- if wandb_run: wandb.log({"optuna_rf_trial": trial.number, "optuna_rf_cv_f1w": score, **params}, commit=False)
349
- return score
350
- except ValueError as e: print(f"Optuna RF trial error (params {params}): {e}"); return -1
351
-
352
- def objective_gb(trial):
353
- # --- CORRECTED PARAMETER DEFINITION ---
354
- params = {
355
- "n_estimators": trial.suggest_int("n_estimators", 50, 250, step=50),
356
- "learning_rate": trial.suggest_float("learning_rate", 0.01, 0.3, log=True),
357
- "max_depth": trial.suggest_int("max_depth", 3, 10),
358
- "min_samples_split": trial.suggest_int("min_samples_split", 2, 16),
359
- "min_samples_leaf": trial.suggest_int("min_samples_leaf", 1, 16),
360
- "subsample": trial.suggest_float("subsample", 0.6, 1.0),
361
- "random_state": 42
362
  }
363
- # --- END CORRECTION ---
364
- try:
365
- model = GradientBoostingClassifier(**params)
366
- score = cross_val_score(model, X_train, y_train, cv=3, scoring="f1_weighted", n_jobs=-1).mean()
367
- if wandb_run: wandb.log({"optuna_gb_trial": trial.number, "optuna_gb_cv_f1w": score, **params}, commit=False)
368
- return score
369
- except ValueError as e: print(f"Optuna GB trial error (params {params}): {e}"); return -1
370
-
371
- # --- Model Training Loop ---
372
- for name, model_pipeline in models_to_compare.items():
373
- print(f"--- Training and Evaluating: {name} ---")
374
- start_time = time.time(); current_params = model_pipeline.get_params(); final_model = model_pipeline
375
- try:
376
- # Tuning logic as before...
377
- if name == "RandomForest" and tune_rf:
378
- study_rf = optuna.create_study(direction="maximize"); study_rf.optimize(objective_rf, n_trials=n_trials_optuna, timeout=300)
379
- final_model = RandomForestClassifier(**study_rf.best_params, random_state=42); current_params = final_model.get_params()
380
- print(f"Best RF params (CV F1w: {study_rf.best_value:.4f})")
381
- if wandb_run: wandb.log({f"{name}_best_cv_f1w": study_rf.best_value, f"{name}_best_params": study_rf.best_params}, commit=False)
382
- elif name == "GradientBoosting" and tune_gb:
383
- study_gb = optuna.create_study(direction="maximize"); study_gb.optimize(objective_gb, n_trials=n_trials_optuna, timeout=300)
384
- final_model = GradientBoostingClassifier(**study_gb.best_params, random_state=42); current_params = final_model.get_params()
385
- print(f"Best GB params (CV F1w: {study_gb.best_value:.4f})")
386
- if wandb_run: wandb.log({f"{name}_best_cv_f1w": study_gb.best_value, f"{name}_best_params": study_gb.best_params}, commit=False)
387
-
388
- # Train/Eval logic as before...
389
- final_model.fit(X_train, y_train)
390
- y_pred = final_model.predict(X_test)
391
- accuracy=accuracy_score(y_test, y_pred); precision=precision_score(y_test, y_pred, average="weighted", zero_division=0); recall=recall_score(y_test, y_pred, average="weighted", zero_division=0); f1=f1_score(y_test, y_pred, average="weighted", zero_division=0); duration = time.time() - start_time
392
- print(f"{name} Test - F1(w): {f1:.4f}, Acc: {accuracy:.4f}, Time: {duration:.2f}s")
393
- metrics = { "Model": name, "Test Accuracy": accuracy, "Test Precision (Weighted)": precision, "Test Recall (Weighted)": recall, "Test F1 Score (Weighted)": f1, "Training Time (s)": duration, "Tuned": (name == "RandomForest" and tune_rf) or (name == "GradientBoosting" and tune_gb) }
394
- results.append(metrics)
395
- # WandB logging logic...
396
- if wandb_run:
397
- log_metrics = {f"{name}_{k.lower().replace(' (weighted)','_w').replace(' ','_')}": v for k, v in metrics.items() if k not in ["Model", "Tuned"]}; log_metrics[f"{name}_tuned_flag"] = metrics["Tuned"]
398
- wandb.log(log_metrics, commit=False)
399
- # Update best model logic...
400
- if f1 > best_f1: best_f1 = f1; best_model_name = name; best_model_obj = final_model; best_model_params = current_params; print(f"*** New best model: {name} ***")
401
- except Exception as train_e: print(f"ERROR training/evaluating {name}: {train_e}"); results.append({"Model": name, "Error": str(train_e)}); import traceback; traceback.print_exc()
402
-
403
- # --- Finalize Comparison ---
404
- # (Logic as before: create DataFrame, sort, store globals, save artifact, log table)
405
- if not results: print("No models trained."); return pd.DataFrame({"Status": ["Model training failed."]})
406
- comparison_df = pd.DataFrame(results)
407
- if "Test F1 Score (Weighted)" in comparison_df.columns: comparison_df = comparison_df.sort_values(by="Test F1 Score (Weighted)", ascending=False).reset_index(drop=True)
408
- comparison_results_global = comparison_df
409
- print("\n--- Model Comparison Summary ---"); print(comparison_df.to_string())
410
- if best_model_obj is not None: best_model_details_global = {'name': best_model_name, 'model': best_model_obj, 'params': best_model_params, 'f1_score': best_f1}; print(f"Stored best model: {best_model_name}")
411
- # Save artifact logic...
412
- # Log table logic...
413
-
414
- return comparison_df
415
-
416
- except Exception as e:
417
- print(f"Error during model comparison: {e}"); import traceback; traceback.print_exc()
418
- if wandb_run: wandb_run.finish(exit_code=1)
419
- return pd.DataFrame({"Error": [f"Comparison failed: {e}"]})
420
- finally:
421
- # Finish WandB run...
422
- if wandb_run and wandb.run: wandb_run.finish(); print(f"WandB run '{run_name}' finished.")
423
- wandb_run = None
424
-
425
-
426
- # --- Model Explainability ---
427
-
428
- # TYPE HINT ALREADY ADDED HERE
429
- def explainability(_=None) -> Tuple[List[str], Optional[str], str]:
430
- """Generates SHAP and LIME explanations for the best performing model."""
431
- global split_data_global, best_model_details_global, wandb_run
432
- # Default returns match hint
433
- default_shap_paths = []
434
- default_lime_path = None
435
- default_status = "Error: Explainability could not run."
436
-
437
- if split_data_global is None: return default_shap_paths, default_lime_path, "Error: Data not prepared. Run 'Train & Compare' first."
438
- if best_model_details_global is None: return default_shap_paths, default_lime_path, "Error: Best model not identified. Run 'Train & Compare' first."
439
-
440
- # --- Get data and model ---
441
- X_train, X_test, y_train, y_test, label_encoder = split_data_global
442
- best_model_name = best_model_details_global['name']
443
- best_model = best_model_details_global['model']
444
-
445
- print(f"--- Generating explanations for the best model: {best_model_name} ---")
446
- # ... (Setup output dir, define paths, WandB init logic as before) ...
447
- output_dir = "./explainability_plots";
448
- if os.path.exists(output_dir): shutil.rmtree(output_dir)
449
- os.makedirs(output_dir)
450
- shap_summary_path = os.path.join(output_dir, f"shap_summary_{best_model_name}.png")
451
- lime_path = os.path.join(output_dir, f"lime_instance_{best_model_name}.png")
452
- all_shap_paths = []; status_message = f"Explaining best model: {best_model_name}" # Initialize gallery list
453
- # WandB Init...
454
- run_name = f"Explain_{best_model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
455
- wandb_run_explain = None
456
- if wandb.run is None or wandb.run.mode != "disabled":
457
- try:
458
- if wandb.run and wandb.run.id: wandb.finish()
459
- wandb_run_explain = wandb.init(project="ai-data-analysis-gradio", name=run_name, config={...}, reinit=True)
460
- print(f"WandB run '{run_name}' initialized for Explainability.")
461
- except Exception as e: print(f"Error initializing Wandb run for Explainability: {e}"); wandb_run_explain = None
462
- else: print("WandB disabled, skipping Explainability run logging."); wandb_run_explain = None
463
 
464
 
465
- try:
466
- # --- SHAP Explanation (logic as before) ---
467
- print("Calculating SHAP values...")
468
- shap_values = None; explainer = None; X_test_for_shap = X_test
469
- # ... (logic to determine explainer and calculate shap_values based on model type/pipeline) ...
470
- # Simplified example logic:
471
- if isinstance(best_model, Pipeline):
472
- final_estimator = best_model.steps[-1][1]
473
- if isinstance(final_estimator, (RandomForestClassifier, GradientBoostingClassifier)):
474
- print("Using TreeExplainer for Pipeline")
475
- # ... (transform data, init explainer, get shap_values) ...
476
- elif isinstance(final_estimator, LogisticRegression):
477
- print("Using KernelExplainer for Pipeline")
478
- # ... (init explainer, get shap_values for subset) ...
479
- else: print("SHAP not implemented for this pipeline step.")
480
- elif isinstance(best_model, (RandomForestClassifier, GradientBoostingClassifier)):
481
- print("Using TreeExplainer for standalone model")
482
- explainer = shap.TreeExplainer(best_model); shap_values = explainer.shap_values(X_test_for_shap)
483
- else: print("SHAP not implemented for this model type.")
484
-
485
-
486
- # --- Generate SHAP Plots (logic as before) ---
487
- if shap_values is not None:
488
- # ... (Generate summary plot, calculate global importance, generate dependence plots) ...
489
- # Important: Ensure generated paths are added to `all_shap_paths`
490
- # Example:
491
- # if summary plot saved: all_shap_paths.append(shap_summary_path)
492
- # if dep plot saved: all_shap_paths.append(dep_path)
493
- print("Generating SHAP plots...") # Placeholder print
494
- # ... (SHAP plot generation, saving, and appending to all_shap_paths) ...
495
- # Example: If summary plot is generated and saved:
496
- if os.path.exists(shap_summary_path):
497
- all_shap_paths.append(shap_summary_path)
498
- # Example: If dependence plots are generated and saved:
499
- # for dep_path in shap_dep_paths:
500
- # if os.path.exists(dep_path):
501
- # all_shap_paths.append(dep_path)
502
-
503
-
504
- # --- LIME Explanation (logic as before) ---
505
- print("Generating LIME explanation...")
506
  try:
507
- # ... (LIME explainer setup, explain_instance, plot saving logic) ...
508
- print("LIME explanation generated.") # Placeholder print
509
- except Exception as lime_e:
510
- print(f"Error generating LIME explanation: {lime_e}")
511
- lime_path = None # Indicate failure
512
-
513
- # --- Finalize and Return ---
514
- status_message = f"Explanations finished for {best_model_name}." # Update status
515
- valid_lime_path = lime_path if lime_path and os.path.exists(lime_path) else None
516
- valid_shap_paths = [p for p in all_shap_paths if p and os.path.exists(p)]
517
- print(f"Returning {len(valid_shap_paths)} SHAP paths, LIME path: {valid_lime_path}")
518
- # Ensure return matches type hint
519
- return valid_shap_paths, valid_lime_path, status_message
 
 
 
 
520
 
521
- except Exception as e:
522
- print(f"An error occurred during explainability: {e}"); import traceback; traceback.print_exc()
523
- status_message = f"Error during explanation: {e}"
524
- if wandb_run_explain: wandb_run_explain.finish(exit_code=1)
525
- # Ensure return matches type hint
526
- return [], None, status_message
527
- finally:
528
- plt.close('all')
529
- # Finish WandB run logic...
530
- if wandb_run_explain and wandb.run and wandb.run.id == wandb_run_explain.id: wandb_run_explain.finish(); print(f"WandB run '{run_name}' finished.")
531
- wandb_run_explain = None
532
-
533
-
534
- # --- Gradio Interface ---
535
- # (Keep Gradio UI layout and connections exactly as before)
536
- print("--- Setting up Gradio Interface ---")
537
- with gr.Blocks(theme=gr.themes.Soft(), title="AI Data Analysis & Model Comparison") as demo:
538
- gr.Markdown( ... ) # Title markdown
539
- # Row 1: File Upload ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
540
  with gr.Row():
541
- with gr.Column(scale=1): file_input = gr.File(...)
542
- with gr.Column(scale=2): df_output = gr.DataFrame(...)
543
- # Row 2: AI Agent ...
544
- with gr.Accordion("🤖 Step 2 (Optional): Run AI Agent for Insights & Visuals", open=False):
 
 
 
545
  with gr.Row():
546
- with gr.Column(scale=1): agent_notes = gr.Textbox(...); agent_btn = gr.Button(...)
547
- with gr.Column(scale=2): insights_output = gr.HTML(...)
548
- with gr.Row(): visual_output = gr.Gallery(...)
549
- # Row 3: Model Training ...
550
- with gr.Accordion("⚙️ Step 3: Train & Compare Models", open=True):
551
  with gr.Row():
552
- with gr.Column(scale=1): tune_rf_checkbox = gr.Checkbox(...); tune_gb_checkbox = gr.Checkbox(...); optuna_trials_slider = gr.Slider(...); compare_btn = gr.Button(...)
553
- with gr.Column(scale=2): comparison_output = gr.DataFrame(...)
554
- # Row 4: Explainability ...
555
- with gr.Accordion("💡 Step 4: Explain Best Model (SHAP & LIME)", open=False):
556
- with gr.Row(): explain_btn = gr.Button(...); explain_status = gr.Textbox(...)
557
- with gr.Row(): shap_gallery = gr.Gallery(...); lime_img = gr.Image(...)
 
 
 
 
 
 
 
 
 
558
 
559
  # --- Connect Components ---
560
- print("Connecting Gradio components...")
561
  file_input.change(fn=upload_file, inputs=file_input, outputs=df_output)
 
 
562
  agent_btn.click(fn=analyze_data, inputs=[file_input, agent_notes], outputs=[insights_output, visual_output])
563
- compare_btn.click(fn=train_and_compare_models, inputs=[tune_rf_checkbox, tune_gb_checkbox, optuna_trials_slider], outputs=[comparison_output])
564
- explain_btn.click(fn=explainability, inputs=[], outputs=[shap_gallery, lime_img, explain_status])
565
- print("Gradio components connected.")
566
-
567
-
568
- # --- Launch the App ---
569
- if __name__ == "__main__":
570
- print("--- Cleaning up temporary directories/files ---")
571
- # (Cleanup logic as before)
572
- temp_dirs = ['./figures', './explainability_plots', './saved_models', './__pycache__']
573
- for d in temp_dirs:
574
- if os.path.exists(d):
575
- try: shutil.rmtree(d); print(f"Cleaned up directory: {d}")
576
- except Exception as e: print(f"Warning: Could not clean up directory {d}: {e}")
577
-
578
- print("--- Launching Gradio App ---")
579
- demo.launch(debug=False, share=False)
580
- print("--- Gradio App Closed ---")
 
 
 
 
 
 
 
1
  import os
2
  import gradio as gr
3
+ import pandas as pd
4
  import numpy as np
5
  import matplotlib.pyplot as plt
6
  import shap
7
  import lime
8
  import lime.lime_tabular
9
+ # import optuna # Removing Optuna for this simplified approach
10
  import wandb
11
  import json
12
  import time
13
  import psutil
14
  import shutil
15
  import ast
16
+ from smolagents import HfApiModel, CodeAgent
17
  from huggingface_hub import login
18
+ from sklearn.model_selection import train_test_split, cross_val_score # Keep cross_val_score if needed elsewhere, but not primary for comparison here
19
  from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
20
+ # from sklearn.metrics import ConfusionMatrixDisplay # Not used currently
21
+ from sklearn.ensemble import RandomForestClassifier # Keep RF
22
+ # from sklearn.ensemble import GradientBoostingClassifier # Remove GB for simplicity now
23
  from sklearn.linear_model import LogisticRegression
24
+ from sklearn.preprocessing import LabelEncoder, StandardScaler # Added StandardScaler
25
+ from sklearn.pipeline import Pipeline # Added Pipeline
 
26
  from datetime import datetime
27
+ # from PIL import Image # Likely not needed directly
 
 
 
 
 
 
28
 
29
# --- Authentication and Setup (Keep as is) ---

# Credentials come from the environment; both services degrade gracefully
# when a token is missing or a login attempt fails.
hf_token = os.getenv("HF_TOKEN")
wandb_api_key = os.getenv("WANDB_API_KEY")  # Get WandB key

# Authenticate with Hugging Face (needed for the smolagents API model).
if hf_token:
    try:
        login(token=hf_token)
        print("HF Login successful.")
    except Exception as login_err:
        print(f"HF login failed: {login_err}")
else:
    print("HF_TOKEN not found.")

# Log in to Weights & Biases; on any problem fall back to a disabled run so
# later wandb.log() calls become harmless no-ops.
if wandb_api_key:
    try:
        wandb.login(key=wandb_api_key)
        print("WandB login successful.")
    except Exception as wandb_err:
        print(f"WandB login failed: {wandb_err}. Logging will be disabled.")
        wandb.init(mode="disabled")  # Disable if login fails
else:
    print("WANDB_API_KEY not found. WandB logging disabled.")
    wandb.init(mode="disabled")  # Disable if no key

# SmolAgent initialization: model_api stays None when unavailable, and the
# UI disables the agent button accordingly.
try:
    model_api = HfApiModel("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
    print("SmolAgent API Model initialized.")
except Exception as agent_err:
    print(f"SmolAgent initialization failed: {agent_err}")
    model_api = None  # Set to None if failed

# Module-level caches shared across Gradio callbacks.
df_global = None  # cleaned DataFrame from the last upload
split_data_global = None  # (X_train, X_test, y_train, y_test, label_encoder)
67
+ # --- clean_data, upload_file, AI Agent functions (Keep as is from your original code) ---
68
def clean_data(df):
    """Return a cleaned copy of *df*.

    Steps: drop columns/rows that are entirely empty, label-encode every
    object-dtype column (values are stringified first so NaNs encode too),
    then mean-impute remaining NaNs in numeric columns.
    """
    cleaned = df.copy()  # never mutate the caller's frame
    cleaned = cleaned.dropna(how='all', axis=1).dropna(how='all', axis=0)

    # Convert categorical/text columns to integer codes.
    for column in cleaned.select_dtypes(include='object').columns:
        cleaned[column] = cleaned[column].astype(str)
        cleaned[column] = LabelEncoder().fit_transform(cleaned[column])

    # Impute only when numeric columns exist.
    numeric_columns = cleaned.select_dtypes(include=np.number).columns
    if not numeric_columns.empty:
        cleaned[numeric_columns] = cleaned[numeric_columns].fillna(cleaned[numeric_columns].mean())
    return cleaned
80
+
81
def upload_file(file):
    """Load an uploaded CSV/Excel file, clean it, cache it, and return a preview.

    Args:
        file: either a plain path string (what `gr.File(type="filepath")`
            delivers) or a file-like object exposing `.name`. May be None.

    Returns:
        ``df.head()`` of the cleaned DataFrame on success, or a one-column
        "Error" DataFrame describing the failure.

    Side effects: sets `df_global` on success and always resets
    `split_data_global`, since any previous train/test split belongs to the
    old dataset.
    """
    global df_global, split_data_global  # Reset split data on new upload
    df_global = None
    split_data_global = None
    if file is None:
        return pd.DataFrame({"Error": ["No file uploaded."]})
    try:
        # BUG FIX: gr.File(type="filepath") passes a str, which has no
        # `.name` attribute; accept both str paths and file objects.
        path = file if isinstance(file, str) else file.name
        ext = os.path.splitext(path)[-1].lower()  # Use lower()
        df = pd.read_csv(path) if ext == ".csv" else pd.read_excel(path)
        df = clean_data(df)
        df_global = df
        print("File uploaded and cleaned.")
        return df.head()
    except Exception as e:
        print(f"Error in upload_file: {e}")
        return pd.DataFrame({"Error": [f"Failed to process file: {e}"]})
97
 
98
+ # --- AI Agent functions (Keep your original format_*, analyze_data) ---
99
+ # Placeholder for brevity - use your original functions
100
def format_analysis_report(raw_output, visuals):
    """Wrap the agent's raw output in a minimal HTML report; pass visuals through."""
    report_html = "<h2>AI Report</h2><pre>" + str(raw_output) + "</pre>"
    return report_html, visuals
101
def format_observations(observations):
    """Render observations as preformatted HTML text."""
    return "<pre>" + str(observations) + "</pre>"
102
def format_insights(insights, visuals):
    """Render insights as preformatted HTML text (visuals currently unused)."""
    return "<pre>" + str(insights) + "</pre>"
103
def analyze_data(csv_file, additional_notes=""):
    """Stub entry point for the optional AI-agent analysis step.

    Returns a (html_report, visuals) pair. The real implementation should
    drive the smolagents CodeAgent, save plots under './figures/', and use
    wandb.init(reinit=True) / wandb.finish() around each invocation.
    """
    print("Running AI Agent (stub)...")
    if not model_api:
        return "AI Agent not initialized.", []
    # Dummy payload standing in for the agent's structured output.
    analysis_result = {"observations": {"data": "desc"}, "insights": {"insight1": "text"}}
    visuals = []  # the agent is expected to drop plot files in './figures/'
    return format_analysis_report(analysis_result, visuals)
113
+
114
+ # --- MODIFIED: prepare_data ---
115
def prepare_data(df, target_column=None) -> bool:
    """Split *df* into train/test sets and store them in `split_data_global`.

    Target selection: an explicit ``target_column`` is validated against the
    frame; otherwise the first object/categorical column with repeated values
    is chosen, falling back to the last column. Non-numeric feature columns
    are dropped, a non-numeric target is label-encoded, and a stratified
    70/30 split is attempted before a plain split.

    Args:
        df: cleaned input DataFrame (may be None/empty -> failure).
        target_column: optional explicit target column name.

    Returns:
        True on success (split stored globally), False on any failure
        (`split_data_global` reset to None).
    """
    global split_data_global
    print("Preparing data split...")
    try:
        if df is None or df.empty:
            print("Error: DataFrame is empty in prepare_data.")
            split_data_global = None
            return False

        # --- Target Column Logic ---
        if target_column is None:
            # Prefer object columns that are not all-unique (IDs make bad targets).
            object_cols = df.select_dtypes(include=['object', 'category']).columns
            potential_targets = [col for col in object_cols if df[col].nunique() < len(df)]
            if potential_targets:
                target_column = potential_targets[0]
                print(f"Target column auto-selected (object): '{target_column}'")
            else:
                target_column = df.columns[-1]  # fallback: last column
                print(f"Target column auto-selected (last): '{target_column}'")
        elif target_column not in df.columns:
            print(f"Error: Specified target column '{target_column}' not found.")
            split_data_global = None
            return False

        X = df.drop(columns=[target_column])
        y = df[target_column].copy()

        # --- Feature Check (ensure numeric; clean_data should already do this) ---
        non_numeric_features = X.select_dtypes(exclude=np.number).columns
        if not non_numeric_features.empty:
            print(f"Warning: Dropping non-numeric feature columns: {list(non_numeric_features)}")
            X = X.drop(columns=non_numeric_features)
        if X.empty:
            print("Error: No numeric features left after dropping non-numeric ones.")
            split_data_global = None
            return False

        # --- Target Encoding ---
        label_encoder = None
        if not pd.api.types.is_numeric_dtype(y):
            print(f"Encoding target column '{target_column}' with LabelEncoder.")
            label_encoder = LabelEncoder()
            # BUG FIX: fit_transform returns an ndarray, which has no
            # .nunique(); wrap it back into a Series so the class-count check
            # and stratified split below keep working.
            y = pd.Series(label_encoder.fit_transform(y), index=y.index, name=target_column)
        elif pd.api.types.is_float_dtype(y) and np.all(y == y.astype(int)):
            # Float target that is really integer-valued -> treat as class labels.
            print(f"Target '{target_column}' is float but looks like integer. Converting to int.")
            y = y.astype(int)

        # --- Check for sufficient classes ---
        if y.nunique() < 2:
            print(f"Error: Target column '{target_column}' has less than 2 unique values after processing.")
            split_data_global = None
            return False

        # --- Perform Split (stratified when class sizes allow it) ---
        try:
            X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)
            print("Performed stratified split.")
        except ValueError:  # e.g. a class with a single member
            print("Stratified split failed, using non-stratified split.")
            X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

        split_data_global = (X_train, X_test, y_train, y_test, label_encoder)
        print(f"Data split successfully: Train {X_train.shape}, Test {X_test.shape}")
        return True

    except Exception as e:
        print(f"Error during data preparation: {e}")
        import traceback
        traceback.print_exc()
        split_data_global = None
        return False
191
+
192
+ # --- NEW: run_comparison_and_explainability ---
193
def run_comparison_and_explainability():
    """Train and compare candidate models, explain the best one, log to WandB.

    Uses the globally cached dataset (`df_global`) and split
    (`split_data_global`); prepares the split on demand when missing.

    Returns a 4-tuple matching the Gradio outputs:
        (comparison DataFrame,
         SHAP summary plot path or None,
         LIME plot path or None,
         status message string)
    """
    global df_global, split_data_global

    # --- 1. Check Prerequisites ---
    if df_global is None:
        return pd.DataFrame({"Error": ["No data uploaded."]}), None, None, "Error: Upload data first."
    if split_data_global is None:
        print("Split data not found globally, attempting to prepare now...")
        if not prepare_data(df_global):
            return pd.DataFrame({"Error": ["Data preparation failed."]}), None, None, "Error: Failed to prepare data for comparison."

    X_train, X_test, y_train, y_test, label_encoder = split_data_global
    # Human-readable class labels: original categories when a LabelEncoder
    # was used, otherwise the raw target values.
    class_names = getattr(label_encoder, 'classes_', [str(c) for c in np.unique(y_train)]) if label_encoder else [str(c) for c in np.unique(y_train)]
    class_names = [str(c) for c in class_names]  # Ensure strings

    print("--- Starting Model Comparison & Explainability ---")

    # --- 2. Define Models (pipelines where scaling is beneficial) ---
    models_to_compare = {
        "LogisticRegression": Pipeline([
            ('scaler', StandardScaler()),
            ('logreg', LogisticRegression(max_iter=1000, random_state=42, class_weight='balanced'))
        ]),
        "RandomForest": RandomForestClassifier(random_state=42, class_weight='balanced')
        # Add more models here if desired, e.g. GradientBoostingClassifier
    }

    # --- 3. Initialize WandB Run ---
    run_name = f"CompareExplain_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    wandb_run = None
    # FIX: a live wandb Run may not expose `.mode`; use getattr so a missing
    # attribute cannot crash the whole comparison before training starts.
    if wandb.run is None or getattr(wandb.run, "mode", None) != "disabled":
        try:
            if wandb.run:
                wandb.finish()  # finish any stale run first
            wandb_run = wandb.init(
                project="huggingface-data-analysis",
                name=run_name,
                config={
                    "task": "Comparison & Explainability",
                    "models": list(models_to_compare.keys()),
                    "data_shape_train": X_train.shape,
                    "data_shape_test": X_test.shape,
                },
                reinit=True
            )
            print(f"WandB Run '{run_name}' started.")
        except Exception as e:
            print(f"WandB init failed: {e}")
            wandb_run = None

    # --- 4. Train and Evaluate Models ---
    results = []
    trained_models = {}  # name -> fitted estimator
    print("Comparing models...")
    for name, model in models_to_compare.items():
        print(f"  Training {name}...")
        start_time = time.time()
        try:
            model.fit(X_train, y_train)
            y_pred = model.predict(X_test)
            duration = time.time() - start_time

            metrics = {
                "Model": name,
                "Accuracy": accuracy_score(y_test, y_pred),
                "Precision (Weighted)": precision_score(y_test, y_pred, average="weighted", zero_division=0),
                "Recall (Weighted)": recall_score(y_test, y_pred, average="weighted", zero_division=0),
                "F1 Score (Weighted)": f1_score(y_test, y_pred, average="weighted", zero_division=0),
                "Time (s)": duration
            }
            results.append(metrics)
            trained_models[name] = model
            print(f"  {name} - F1: {metrics['F1 Score (Weighted)']:.4f}, Time: {duration:.2f}s")

            if wandb_run:
                wandb.log({f"{name}_{k.lower().replace(' (weighted)','_w').replace(' ','_')}": v
                           for k, v in metrics.items() if k != "Model"}, commit=False)
        except Exception as e:
            print(f"  ERROR training/evaluating {name}: {e}")
            results.append({"Model": name, "Error": str(e)})
            if wandb_run:
                wandb.log({f"{name}_error": str(e)}, commit=False)

    # --- 5. Process Comparison Results ---
    if not results:
        if wandb_run:
            wandb.finish()
        return pd.DataFrame({"Error": ["No models trained successfully."]}), None, None, "Error: Model training failed."

    comparison_df = pd.DataFrame(results)
    best_model = None
    best_model_name = "N/A (F1 Missing)"
    if "F1 Score (Weighted)" in comparison_df.columns:
        # Failed models have NaN F1 and sort to the bottom.
        comparison_df = comparison_df.sort_values(by="F1 Score (Weighted)", ascending=False, na_position='last').reset_index(drop=True)
        best_model_row = comparison_df.iloc[0]
        best_model_name = best_model_row['Model']
        if pd.notna(best_model_row["F1 Score (Weighted)"]) and best_model_name in trained_models:
            best_model = trained_models[best_model_name]
            print(f"Best model determined: {best_model_name} (F1: {best_model_row['F1 Score (Weighted)']:.4f})")
        else:
            best_model = None
            best_model_name = "N/A (Error or No Valid Model)"
            print("Warning: Could not determine a valid best model from results.")
    else:
        print("Warning: F1 Score column missing, cannot determine best model.")

    # Log comparison table to WandB
    if wandb_run and not comparison_df.empty:
        try:
            wandb.log({"model_comparison": wandb.Table(dataframe=comparison_df)}, commit=False)
            print("Logged comparison table to WandB.")
        except Exception as e:
            print(f"Error logging comparison table: {e}")

    # --- 6. Explain Best Model (if found) ---
    shap_plot_path = None
    lime_plot_path = None
    explain_status = f"Compared {len(trained_models)} models. Best: {best_model_name}."

    if best_model:
        print(f"Generating explanations for {best_model_name}...")
        explain_dir = "./explain_plots"
        if os.path.exists(explain_dir):
            shutil.rmtree(explain_dir)
        os.makedirs(explain_dir)
        shap_plot_path = os.path.join(explain_dir, f"shap_{best_model_name}.png")
        lime_plot_path = os.path.join(explain_dir, f"lime_{best_model_name}.png")

        try:
            # --- SHAP ---
            explainer = None
            shap_values = None
            X_test_for_shap = X_test  # default: explain the full test set

            if isinstance(best_model, Pipeline):
                final_estimator = best_model.steps[-1][1]
                if isinstance(final_estimator, (RandomForestClassifier, LogisticRegression)):
                    # KernelExplainer is the safe generic choice for pipelines.
                    print("  Using SHAP KernelExplainer for Pipeline...")
                    predict_proba_pipe = lambda x_np: best_model.predict_proba(pd.DataFrame(x_np, columns=X_test.columns))
                    # FIX: shap.kmeans() accepts no random_state kwarg; the
                    # old call raised TypeError and silently disabled SHAP.
                    X_train_summary = shap.kmeans(X_train.values, min(50, X_train.shape[0]))
                    explainer = shap.KernelExplainer(predict_proba_pipe, X_train_summary)
                    # Explain only a sample of the test set for speed.
                    X_test_sample = X_test.sample(min(50, X_test.shape[0]), random_state=42)
                    shap_values = explainer.shap_values(X_test_sample.values, nsamples='auto')
                    X_test_for_shap = X_test_sample
                    print("  SHAP values calculated (Kernel).")
                else:
                    print(f"  SHAP not configured for pipeline step: {type(final_estimator)}")
            elif isinstance(best_model, RandomForestClassifier):
                print("  Using SHAP TreeExplainer...")
                explainer = shap.TreeExplainer(best_model)
                shap_values = explainer.shap_values(X_test)
                print("  SHAP values calculated (Tree).")
            else:
                print(f"  SHAP not configured for model type: {type(best_model)}")

            if shap_values is not None:
                plt.figure()
                plot_values = shap_values
                shap_title = f"SHAP Summary ({best_model_name})"
                # Older shap returns a per-class list; show the positive class
                # for binary problems.
                if isinstance(shap_values, list) and len(class_names) == 2:
                    plot_values = shap_values[1]
                    shap_title += f" - Class: {class_names[1]}"

                shap.summary_plot(plot_values, X_test_for_shap, plot_type="dot", show=False, class_names=class_names)
                plt.title(shap_title)
                plt.tight_layout()
                plt.savefig(shap_plot_path, bbox_inches='tight')
                plt.close('all')  # FIX: plt.clf() leaked the figure
                print(f"  SHAP plot saved: {shap_plot_path}")
                if wandb_run:
                    wandb.log({"shap_summary_best": wandb.Image(shap_plot_path)}, commit=False)
            else:
                shap_plot_path = None  # no plot generated

            # --- LIME ---
            print("  Generating LIME explanation...")
            if hasattr(best_model, 'predict_proba'):
                predict_fn_lime = lambda x_np: best_model.predict_proba(pd.DataFrame(x_np, columns=X_train.columns))
            else:
                # Fallback: uniform class probabilities when no predict_proba.
                num_classes_lime = len(class_names)
                predict_fn_lime = lambda x: np.ones((len(x), num_classes_lime)) / num_classes_lime

            lime_explainer = lime.lime_tabular.LimeTabularExplainer(
                training_data=X_train.values,
                feature_names=X_train.columns.tolist(),
                class_names=class_names,
                mode='classification'  # Assume classification
            )
            instance_idx = 0  # explain the first test instance
            lime_exp = lime_explainer.explain_instance(
                data_row=X_test.iloc[instance_idx].values,
                predict_fn=predict_fn_lime,
                num_features=10
            )
            lime_fig = lime_exp.as_pyplot_figure()
            lime_fig.suptitle(f"LIME Explanation ({best_model_name} - Instance {instance_idx})")
            lime_fig.savefig(lime_plot_path, bbox_inches='tight')
            plt.close(lime_fig)  # FIX: plt.clf() cleared the wrong figure and leaked this one
            print(f"  LIME plot saved: {lime_plot_path}")
            if wandb_run:
                wandb.log({"lime_explanation_best": wandb.Image(lime_plot_path)}, commit=False)

            explain_status += f" Explanations generated for {best_model_name}."

        except Exception as e:
            print(f"  ERROR during explanation: {e}")
            import traceback
            traceback.print_exc()
            explain_status += f" Explanation failed for {best_model_name}: {e}"
            # FIX: shap_plot_path may already be None here (no SHAP plot was
            # produced before LIME failed); os.path.exists(None) raises.
            if shap_plot_path and not os.path.exists(shap_plot_path):
                shap_plot_path = None
            if lime_plot_path and not os.path.exists(lime_plot_path):
                lime_plot_path = None
    else:
        explain_status += " No best model found to explain."

    # --- 7. Finish WandB Run and Return ---
    if wandb_run:
        wandb.log({}, commit=True)  # flush the accumulated commit=False logs
        wandb.finish()
        print(f"WandB Run '{run_name}' finished.")

    # Only return plot paths that actually exist on disk.
    valid_shap_path = shap_plot_path if shap_plot_path and os.path.exists(shap_plot_path) else None
    valid_lime_path = lime_plot_path if lime_plot_path and os.path.exists(lime_plot_path) else None

    return comparison_df, valid_shap_path, valid_lime_path, explain_status
432
+
433
+
434
# --- Gradio UI ---
# Layout: upload row, optional AI-agent accordion, then the combined
# compare-and-explain accordion. Component creation order defines the page
# order, so statements here must not be rearranged.
with gr.Blocks() as demo:
    gr.Markdown("## 📊 AI Data Analysis, Model Comparison & Explainability")

    # --- Row 1: Upload ---
    with gr.Row():
        with gr.Column(scale=1):
            # type="filepath" delivers a plain path string to upload_file.
            file_input = gr.File(label="1. Upload CSV or Excel", type="filepath", file_types=[".csv", ".xls", ".xlsx"])
        with gr.Column(scale=2):
            df_output = gr.DataFrame(label="Cleaned Data Preview", interactive=False)

    # --- Row 2: AI Agent (Optional) ---
    with gr.Accordion("🤖 Step 2 (Optional): Run AI Agent Insights", open=False):
        with gr.Row():
            with gr.Column(scale=1):
                agent_notes = gr.Textbox(label="Optional requests for Agent", placeholder="e.g., 'Focus on column X'")
                # Button stays disabled when the smolagents model failed to init.
                agent_btn = gr.Button("Run AI Analysis", interactive=(model_api is not None))
            with gr.Column(scale=2):
                insights_output = gr.HTML(label="AI Agent Report")
        with gr.Row():
            visual_output = gr.Gallery(label="AI Agent Visualizations", height=350, object_fit="contain", columns=3, preview=True)

    # --- Row 3: Compare & Explain ---
    with gr.Accordion("⚙️💡 Step 3: Compare Models & Explain Best", open=True):
        with gr.Row():
            compare_explain_btn = gr.Button("Run Comparison & Explain Best Model", variant="primary")
        with gr.Row():
            comparison_output = gr.DataFrame(label="Model Comparison Results", interactive=False)
        with gr.Row():
            explain_status_output = gr.Textbox(label="Status", interactive=False)
        with gr.Row():
            # Only one SHAP plot expected now (summary), plus one LIME plot.
            shap_img_output = gr.Image(label="SHAP Summary (Best Model)", type="filepath", interactive=False)
            lime_img_output = gr.Image(label="LIME Explanation (Best Model - Instance 0)", type="filepath", interactive=False)

    # --- Connect Components ---
    file_input.change(fn=upload_file, inputs=file_input, outputs=df_output)

    # AI Agent connection (optional analysis step).
    agent_btn.click(fn=analyze_data, inputs=[file_input, agent_notes], outputs=[insights_output, visual_output])

    # Combined comparison + explainability; reads the dataset from df_global,
    # so no Gradio inputs are passed.
    compare_explain_btn.click(
        fn=run_comparison_and_explainability,
        inputs=[],
        outputs=[comparison_output, shap_img_output, lime_img_output, explain_status_output]
    )

# --- Launch ---
# NOTE(review): launched at import time with no __main__ guard — fine for a
# HF Space entry point, but importing this module will start the server.
print("Launching Gradio App...")
demo.launch(debug=True)  # debug=True for detailed errors during development