pavanmutha commited on
Commit
c8f222f
·
verified ·
1 Parent(s): a6f26d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +913 -143
app.py CHANGED
@@ -1,10 +1,16 @@
1
- # Required imports (ensure all are present from previous code)
 
 
 
 
 
2
  import os
3
  import gradio as gr
4
  import pandas as pd
5
  import numpy as np
6
  import matplotlib.pyplot as plt
7
  import shap
 
8
  import lime.lime_tabular
9
  import optuna
10
  import wandb
@@ -13,42 +19,814 @@ import time
13
  import psutil
14
  import shutil
15
  import ast
16
- from smolagents import HfApiModel, CodeAgent
17
  from huggingface_hub import login
18
  from sklearn.model_selection import train_test_split, cross_val_score
19
  from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
20
  from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
21
  from sklearn.linear_model import LogisticRegression
22
- from sklearn.svm import SVC
23
  from sklearn.preprocessing import LabelEncoder, StandardScaler
24
  from sklearn.pipeline import Pipeline
25
  from datetime import datetime
26
- from PIL import Image
27
  import warnings
28
- import joblib
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- # (Keep all previous setup, functions like clean_data, upload_file, AI Agent, prepare_data, train_and_compare_models)
31
- # ... (paste the previous code here up to the explainability function) ...
 
 
32
 
33
- # --- Model Explainability (REVISED) ---
 
34
 
35
- def explainability(_=None): # Add dummy input for button click signature
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  """Generates SHAP and LIME explanations for the best performing model."""
37
  global split_data_global, best_model_details_global, wandb_run
38
  if split_data_global is None:
39
- print("Error: Data not split. Please run comparison first.")
40
- return [], None, "Error: Data not prepared. Run 'Train & Compare' first." # Return empty list for gallery
41
  if best_model_details_global is None:
42
- print("Error: Best model details not found. Please run comparison first.")
43
- return [], None, "Error: Best model not identified. Run 'Train & Compare' first." # Return empty list
44
 
45
- X_train, X_test, y_train, y_test = split_data_global
46
  best_model_name = best_model_details_global['name']
47
- best_model = best_model_details_global['model'] # Use the stored, already fitted best model
48
 
49
  print(f"--- Generating explanations for the best model: {best_model_name} ---")
50
 
51
- # Define paths dynamically
52
  output_dir = "./explainability_plots"
53
  if os.path.exists(output_dir): shutil.rmtree(output_dir)
54
  os.makedirs(output_dir)
@@ -59,16 +837,15 @@ def explainability(_=None): # Add dummy input for button click signature
59
  status_message = f"Explaining best model: {best_model_name}"
60
  all_shap_paths = [] # Initialize empty list for gallery output
61
 
 
62
  run_name = f"Explain_{best_model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
63
  config = {"task": "Explainability", "best_model": best_model_name, "explainers": ["SHAP", "LIME"]}
64
-
65
- # Init separate wandb run for explainability
66
  wandb_run_explain = None
67
  if wandb.run is None or wandb.run.mode != "disabled":
68
  try:
69
- # Ensure previous run is finished if still active from comparison
70
  if wandb.run and wandb.run.id:
71
- print(f"Finishing potentially active WandB run ({wandb.run.id}) before starting explanation run.")
72
  wandb.finish()
73
  wandb_run_explain = wandb.init(project="ai-data-analysis-gradio", name=run_name, config=config, reinit=True)
74
  print(f"WandB run '{run_name}' initialized for Explainability.")
@@ -76,6 +853,7 @@ def explainability(_=None): # Add dummy input for button click signature
76
  print(f"Error initializing Wandb run for Explainability: {e}")
77
  wandb_run_explain = None # Ensure it's None if init fails
78
  else:
 
79
  wandb_run_explain = None
80
 
81
  try:
@@ -83,146 +861,132 @@ def explainability(_=None): # Add dummy input for button click signature
83
  print("Calculating SHAP values...")
84
  shap_values = None
85
  explainer = None
86
- model_for_shap = None # Define variable to hold the model actually used by SHAP
87
- X_test_for_plot = X_test # Default unless subsetting happens
88
 
89
  # Determine explainer type and get SHAP values
90
  if isinstance(best_model, (RandomForestClassifier, GradientBoostingClassifier)):
91
- model_for_shap = best_model # Use directly
92
- explainer = shap.TreeExplainer(model_for_shap)
93
- shap_values = explainer.shap_values(X_test)
 
94
  elif isinstance(best_model, Pipeline):
95
- # Get the final estimator from the pipeline
96
  final_estimator_name, final_estimator = best_model.steps[-1]
97
  print(f"Handling Pipeline. Final estimator: {final_estimator_name} ({type(final_estimator)})")
 
98
 
99
  if isinstance(final_estimator, (RandomForestClassifier, GradientBoostingClassifier)):
100
- # Retrain the tree model part outside the pipeline for SHAP TreeExplainer
101
- print("Note: Retraining tree model component without pipeline for SHAP TreeExplainer.")
102
- model_for_shap = type(final_estimator)(**final_estimator.get_params())
103
- # Apply preprocessing steps if any before fitting
104
- # Simple case: assume only scaling before tree
105
- pipeline_transforms = Pipeline(best_model.steps[:-1])
106
- X_train_transformed = pipeline_transforms.fit_transform(X_train)
107
- X_test_transformed = pipeline_transforms.transform(X_test) # Use transformed data for SHAP
108
- model_for_shap.fit(X_train_transformed, y_train)
109
- explainer = shap.TreeExplainer(model_for_shap)
110
- shap_values = explainer.shap_values(X_test_transformed) # Explain on transformed test data
111
- X_test_for_plot = pd.DataFrame(X_test_transformed, columns=X_test.columns) # Use transformed data for plotting too
 
112
 
113
  elif isinstance(final_estimator, LogisticRegression):
114
  print("Using SHAP KernelExplainer for Logistic Regression in Pipeline (can be slow)...")
115
- # Create a predict_proba function for the *entire* pipeline
116
  predict_proba_pipeline = lambda x: best_model.predict_proba(pd.DataFrame(x, columns=X_test.columns))
117
- # Use a background dataset (summary) - kmeans is common
118
  print("Summarizing training data for KernelExplainer background...")
119
- X_train_summary = shap.kmeans(X_train, min(100, X_train.shape[0])) # Use min to avoid errors on small data
 
120
  explainer = shap.KernelExplainer(predict_proba_pipeline, X_train_summary)
121
- # Use a smaller subset of X_test for KernelExplainer speed
122
  subset_size = min(50, X_test.shape[0])
123
  print(f"Calculating SHAP values for {subset_size} test instances...")
124
- X_test_subset_np = X_test.sample(subset_size, random_state=42).values # Sample before converting to numpy
125
- # X_test_subset_np = shap.sample(X_test.values, subset_size) # Sample numpy array directly
126
- shap_values = explainer.shap_values(X_test_subset_np)
127
  # Create DataFrame from subset for plotting consistency
128
- X_test_for_plot = pd.DataFrame(X_test_subset_np, columns=X_test.columns)
129
  print("SHAP values calculated using KernelExplainer.")
130
-
131
  else:
132
- print(f"Warning: SHAP explainer not implemented for final pipeline step {type(final_estimator)}. Skipping SHAP.")
133
  else:
134
- print(f"Warning: SHAP explainer not explicitly handled for model type {type(best_model)}. Trying KernelExplainer as fallback.")
135
- # Fallback to KernelExplainer (might be slow)
136
- try:
137
- predict_proba_fallback = lambda x: best_model.predict_proba(pd.DataFrame(x, columns=X_test.columns))
138
- X_train_summary = shap.kmeans(X_train, min(100, X_train.shape[0]))
139
- explainer = shap.KernelExplainer(predict_proba_fallback, X_train_summary)
140
- subset_size = min(50, X_test.shape[0])
141
- X_test_subset_np = X_test.sample(subset_size, random_state=42).values
142
- shap_values = explainer.shap_values(X_test_subset_np)
143
- X_test_for_plot = pd.DataFrame(X_test_subset_np, columns=X_test.columns)
144
- print("Used KernelExplainer fallback.")
145
- except Exception as kernel_fallback_e:
146
- print(f"KernelExplainer fallback failed: {kernel_fallback_e}. Skipping SHAP.")
147
 
148
 
149
  # --- Generate SHAP Plots (if shap_values exist) ---
150
  if shap_values is not None:
151
- num_classes = len(np.unique(y_train)) # Get number of classes from original y_train
152
 
153
  # SHAP Summary Plot
154
  print("Generating SHAP summary plot...")
155
- plt.figure(figsize=(10, 6)) # Ensure figure context
156
  try:
157
- plot_shap_values = shap_values # Default
158
  title_suffix = f"({best_model_name})"
 
 
159
  if num_classes == 2 and isinstance(shap_values, list) and len(shap_values) == 2:
160
- plot_shap_values = shap_values[1] # Plot for class 1 in binary case
161
- title_suffix = f"({best_model_name} - Positive Class)"
162
- print("Plotting SHAP summary for positive class (binary)")
 
 
163
  elif num_classes > 2 and isinstance(shap_values, list):
164
  title_suffix = f"({best_model_name} - Multiclass Avg Impact)"
165
  print("Plotting SHAP summary for multiclass (average impact)")
166
 
167
- shap.summary_plot(plot_shap_values, X_test_for_plot, show=False, plot_type="dot")
168
  plt.title(f"SHAP Feature Importance Summary {title_suffix}")
169
  plt.tight_layout()
170
  plt.savefig(shap_summary_path, bbox_inches='tight')
171
  plt.clf()
172
  print(f"SHAP summary plot saved to {shap_summary_path}")
173
- all_shap_paths.append(shap_summary_path) # Add to list for gallery
174
  if wandb_run_explain: wandb.log({"shap_summary": wandb.Image(shap_summary_path)}, commit=False)
175
  except Exception as summary_e:
176
  print(f"Error generating SHAP summary plot: {summary_e}")
177
- plt.clf() # Clear figure even on error
178
 
179
  # SHAP Dependence Plots
180
  print("Calculating global feature importance for dependence plots...")
181
- global_shap_values = None
182
  try:
183
- if isinstance(shap_values, list): # Multi-class or binary list format
184
- # Handle case where list elements might have different shapes (e.g., KernelExplainer?)
185
- abs_shap_arrays = [np.abs(sv) for sv in shap_values if isinstance(sv, np.ndarray) and sv.ndim == 2]
186
- if abs_shap_arrays:
187
- # Ensure all arrays have the same number of features before stacking
188
- min_features = min(arr.shape[1] for arr in abs_shap_arrays)
189
- consistent_shap_arrays = [arr[:, :min_features] for arr in abs_shap_arrays]
190
- stacked_shap = np.stack(consistent_shap_arrays, axis=0) # Shape (n_classes, n_instances, n_features)
191
- global_shap_values = stacked_shap.mean(axis=(0, 1)) # Mean over classes and instances -> (n_features,)
192
- print(f"Calculated global SHAP values (list input), shape: {global_shap_values.shape}")
193
- else:
194
- print("Warning: Could not extract valid 2D arrays from SHAP values list.")
195
-
196
- elif isinstance(shap_values, np.ndarray) and shap_values.ndim == 2: # Regression or binary array format
197
- global_shap_values = np.abs(shap_values).mean(axis=0) # Mean over instances -> (n_features,)
198
- print(f"Calculated global SHAP values (array input), shape: {global_shap_values.shape}")
199
  else:
200
- print(f"Warning: Unexpected SHAP values type/shape for global importance: {type(shap_values)}. Skipping dependence plots.")
201
-
202
  except Exception as gsi_e:
203
  print(f"Error calculating global SHAP importance: {gsi_e}")
204
 
205
- # Generate plots if importance calculated successfully
206
- if global_shap_values is not None and len(global_shap_values) > 0 :
207
  try:
208
- feature_indices = np.argsort(global_shap_values)[::-1] # Indices sorted by importance
209
-
210
- num_features_to_plot = min(2, len(global_shap_values), len(X_test_for_plot.columns)) # Plot top 2 or fewer
211
  if num_features_to_plot > 0:
212
  top_feature_indices = feature_indices[:num_features_to_plot]
213
- top_features = X_test_for_plot.columns[top_feature_indices]
 
 
214
 
215
  print(f"Generating SHAP dependence plots for top features: {list(top_features)}")
216
- for feature_idx, feature_name in zip(top_feature_indices, top_features):
217
  plt.figure(figsize=(8, 5))
218
- # Select appropriate SHAP values for dependence plot
219
- # For binary list, use class 1; for multiclass list, use default (class 0 usually) or specify; for array, use array.
220
  shap_values_for_dep = shap_values
 
221
  if isinstance(shap_values, list):
222
- shap_values_for_dep = shap_values[1] if num_classes == 2 and len(shap_values)==2 else shap_values[0] # Default to class 0 for multi or if binary isn't len 2
 
 
 
 
223
 
224
- shap.dependence_plot(feature_idx, shap_values_for_dep, X_test_for_plot, interaction_index='auto', show=False)
225
- plt.title(f"SHAP Dependence Plot: {feature_name} ({best_model_name})")
 
 
226
  plt.tight_layout()
227
  dep_path = os.path.join(output_dir, f"shap_dependence_{best_model_name}_{feature_name}.png")
228
  plt.savefig(dep_path, bbox_inches='tight')
@@ -232,13 +996,14 @@ def explainability(_=None): # Add dummy input for button click signature
232
  print(f"Saved dependence plot: {dep_path}")
233
  if wandb_run_explain: wandb.log({f"shap_dependence_{feature_name}": wandb.Image(dep_path)}, commit=False)
234
  else:
235
- print("Skipping dependence plots: Not enough features.")
236
  except Exception as dep_e:
237
  print(f"Could not generate SHAP dependence plots: {dep_e}")
238
- plt.clf() # Ensure figure is cleared
 
 
239
  else:
240
  print("Skipping dependence plots due to issue calculating global SHAP values.")
241
-
242
  else:
243
  print("Skipping SHAP plots as SHAP values were not generated.")
244
 
@@ -247,51 +1012,47 @@ def explainability(_=None): # Add dummy input for button click signature
247
  print("Generating LIME explanation for the first test instance...")
248
  try:
249
  predict_fn_lime = None
250
- # Create predict_proba function needed by LIME
251
  if hasattr(best_model, 'predict_proba'):
252
- # Handle numpy vs pandas input for pipeline/model
253
  def predict_proba_wrapper(x_np):
254
- # Convert numpy array back to DataFrame for pipeline/model consistency
255
  x_df = pd.DataFrame(x_np, columns=X_train.columns)
256
  return best_model.predict_proba(x_df)
257
  predict_fn_lime = predict_proba_wrapper
258
  else:
259
- print("Warning: Model does not have predict_proba. LIME might not work as expected.")
260
- # Dummy fallback returning equal probabilities
261
  num_classes_lime = len(np.unique(y_train))
262
  predict_fn_lime = lambda x: np.ones((len(x), num_classes_lime)) / num_classes_lime
263
 
264
- # Get class names
265
- if hasattr(best_model, 'classes_'):
266
- class_names_str = [str(c) for c in best_model.classes_]
267
- else: # Infer from y_train if no classes_ attribute
268
- class_names_str = [str(c) for c in sorted(np.unique(y_train))]
269
 
270
  lime_explainer = lime.lime_tabular.LimeTabularExplainer(
271
- training_data=X_train.values, # LIME needs numpy array for background
272
  feature_names=X_train.columns.tolist(),
273
- class_names=class_names_str,
274
- mode='classification' if len(class_names_str) > 1 else 'regression'
275
  )
276
 
277
  instance_idx = 0
278
- instance_to_explain = X_test.iloc[instance_idx].values # Explain first instance
279
- true_class = y_test[instance_idx] if isinstance(y_test, (np.ndarray, list)) else y_test.iloc[instance_idx] # Get true class safely
 
 
280
 
281
  lime_exp = lime_explainer.explain_instance(
282
  data_row=instance_to_explain,
283
- predict_fn=predict_fn_lime, # Use the wrapper
284
  num_features=10,
285
- num_samples=1000 # Adjust as needed for speed/accuracy
286
  )
287
  print(f"LIME explanation generated for instance {instance_idx}.")
288
 
289
  lime_fig = lime_exp.as_pyplot_figure()
290
- # Attempt to get predicted class label for title
291
- predicted_class_idx = lime_exp.available_labels()[0] # Often the predicted class is first
292
- predicted_class_label = class_names_str[predicted_class_idx]
293
- lime_fig.suptitle(f"LIME Exp (Inst {instance_idx}, True: {true_class}, Pred: {predicted_class_label}, Model: {best_model_name})", y=1.03, fontsize=10)
294
- lime_fig.tight_layout(rect=[0, 0, 1, 0.98]) # Adjust layout
295
  lime_fig.savefig(lime_path, bbox_inches='tight')
296
  plt.clf()
297
  print(f"LIME plot saved to {lime_path}")
@@ -299,30 +1060,33 @@ def explainability(_=None): # Add dummy input for button click signature
299
 
300
  except Exception as lime_e:
301
  print(f"Error generating LIME explanation: {lime_e}")
 
 
302
  if wandb_run_explain: wandb.log({"lime_error": str(lime_e)}, commit=False)
303
  lime_path = None # Indicate failure
304
 
305
  # Final status message
306
- status_message = f"Explanations generated for {best_model_name}. Check plots."
307
- if not all_shap_paths: status_message += " (SHAP failed/skipped)."
308
  if not lime_path: status_message += " (LIME failed/skipped)."
309
 
310
  # Return paths to the plots and status
311
- # Ensure lime_path is valid before returning, otherwise None
312
  valid_lime_path = lime_path if lime_path and os.path.exists(lime_path) else None
313
- return all_shap_paths, valid_lime_path, status_message
 
314
 
315
  except Exception as e:
316
  print(f"An error occurred during explainability: {e}")
317
  import traceback
318
- traceback.print_exc() # Print full traceback for debugging
319
  status_message = f"Error during explanation: {e}"
320
  if wandb_run_explain: wandb_run_explain.finish(exit_code=1)
321
  return [], None, status_message # Return empty list/None for paths on error
322
  finally:
323
  plt.close('all') # Close all matplotlib figures
324
- if wandb_run_explain and wandb.run and wandb.run.id == wandb_run_explain.id: # Check if it's the correct run
325
  try:
 
326
  wandb_run_explain.finish()
327
  print(f"WandB run '{run_name}' finished.")
328
  except Exception as finish_e:
@@ -330,9 +1094,8 @@ def explainability(_=None): # Add dummy input for button click signature
330
  wandb_run_explain = None # Reset
331
 
332
 
333
- # --- Gradio Interface (Keep the same as the previous version) ---
334
- # ... (paste the Gradio Blocks UI code here) ...
335
-
336
  with gr.Blocks(theme=gr.themes.Soft(), title="AI Data Analysis & Model Comparison") as demo:
337
  gr.Markdown(
338
  """
@@ -354,7 +1117,7 @@ with gr.Blocks(theme=gr.themes.Soft(), title="AI Data Analysis & Model Compariso
354
  with gr.Row():
355
  with gr.Column(scale=1):
356
  agent_notes = gr.Textbox(label="Optional: Specific requests for the AI Agent", placeholder="e.g., 'Focus on correlations with column X'")
357
- agent_btn = gr.Button("Run AI Analysis", variant="secondary")
358
  with gr.Column(scale=2):
359
  insights_output = gr.HTML(label="AI Agent Analysis Report")
360
  with gr.Row():
@@ -369,7 +1132,7 @@ with gr.Blocks(theme=gr.themes.Soft(), title="AI Data Analysis & Model Compariso
369
  optuna_trials_slider = gr.Slider(minimum=5, maximum=50, value=10, step=5, label="Optuna Trials per Model")
370
  compare_btn = gr.Button("Train & Compare Models", variant="primary")
371
  with gr.Column(scale=2):
372
- comparison_output = gr.DataFrame(label="Model Comparison Results (Sorted by F1 Score)", interactive=False)
373
 
374
  # --- Row 4: Model Explainability ---
375
  with gr.Accordion("💡 Step 4: Explain Best Model (SHAP & LIME)", open=False):
@@ -378,41 +1141,47 @@ with gr.Blocks(theme=gr.themes.Soft(), title="AI Data Analysis & Model Compariso
378
  explain_status = gr.Textbox(label="Explanation Status", interactive=False)
379
  with gr.Row():
380
  # Use Gallery for SHAP as there can be multiple plots
381
- shap_gallery = gr.Gallery(label="SHAP Plots (Summary + Top Feature Dependence)", height=400, object_fit="contain", columns=2, preview=True)
382
  lime_img = gr.Image(label="LIME Explanation (for first test instance)", type="filepath", interactive=False)
383
 
384
 
385
  # --- Connect Components ---
 
 
386
  file_input.change(
387
  fn=upload_file,
388
  inputs=file_input,
389
  outputs=df_output
390
  )
391
 
 
392
  agent_btn.click(
393
  fn=analyze_data,
394
  inputs=[file_input, agent_notes],
395
  outputs=[insights_output, visual_output]
396
  )
397
 
 
398
  compare_btn.click(
399
  fn=train_and_compare_models,
400
  inputs=[tune_rf_checkbox, tune_gb_checkbox, optuna_trials_slider],
401
  outputs=[comparison_output]
402
  )
403
 
 
404
  explain_btn.click(
405
  fn=explainability,
406
  inputs=[], # Uses global best model details
407
  outputs=[shap_gallery, lime_img, explain_status] # Output list of SHAP plots, one LIME plot, and status
408
  )
409
-
410
 
411
  # --- Launch the App ---
412
  if __name__ == "__main__":
 
413
  # Clean up temporary files/dirs from previous runs before launching
414
- temp_dirs = ['./figures', './explainability_plots', './__pycache__'] # Add explainability dir
415
- temp_files = [f for f in os.listdir('.') if f.lower().endswith('.joblib')] # Only remove joblib now
416
 
417
  for d in temp_dirs:
418
  if os.path.exists(d):
@@ -429,5 +1198,6 @@ if __name__ == "__main__":
429
  except Exception as e:
430
  print(f"Warning: Could not clean up file {f}: {e}")
431
 
432
-
433
- demo.launch(debug=False) # Turn debug=True if you need Gradio's traceback
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Gradio App for AI Data Analysis, Model Comparison, and Explainability
4
+ Requires: HF_TOKEN and WANDB_API_KEY environment variables.
5
+ """
6
+
7
  import os
8
  import gradio as gr
9
  import pandas as pd
10
  import numpy as np
11
  import matplotlib.pyplot as plt
12
  import shap
13
+ import lime
14
  import lime.lime_tabular
15
  import optuna
16
  import wandb
 
19
  import psutil
20
  import shutil
21
  import ast
22
+ from smolagents import HfApiModel, CodeAgent # Assuming smolagents is installed
23
  from huggingface_hub import login
24
  from sklearn.model_selection import train_test_split, cross_val_score
25
  from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
26
  from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
27
  from sklearn.linear_model import LogisticRegression
28
+ from sklearn.svm import SVC # Kept import in case you add it later
29
  from sklearn.preprocessing import LabelEncoder, StandardScaler
30
  from sklearn.pipeline import Pipeline
31
  from datetime import datetime
32
+ # from PIL import Image # PIL often implicitly used by matplotlib/wandb, explicit import usually not needed unless manipulating images directly
33
  import warnings
34
+ import joblib # For saving models
35
+
36
+ # Suppress common warnings (especially from SHAP/LIME/Sklearn)
37
+ warnings.filterwarnings("ignore")
38
+
39
+ # --- Authentication and Setup ---
40
+ print("--- Initializing Setup ---")
41
+ hf_token = os.getenv("HF_TOKEN")
42
+ wandb_api_key = os.getenv("WANDB_API_KEY")
43
+
44
+ # Initialize wandb run variable globally, helps manage state across functions
45
+ wandb_run = None # Tracks the *current* active run (e.g., comparison or explainability)
46
+
47
+ if not hf_token:
48
+ print("Warning: HF_TOKEN environment variable not set. Hugging Face Hub features may fail.")
49
+ else:
50
+ try:
51
+ login(token=hf_token)
52
+ print("Hugging Face login successful.")
53
+ except Exception as e:
54
+ print(f"Hugging Face login failed: {e}")
55
+
56
+ if not wandb_api_key:
57
+ print("Warning: WANDB_API_KEY environment variable not set. WandB logging will be disabled.")
58
+ # Initialize wandb in disabled mode if no key and not already initialized
59
+ if wandb.run is None:
60
+ try:
61
+ wandb.init(mode="disabled")
62
+ print("WandB initialized in disabled mode.")
63
+ except Exception as e:
64
+ print(f"Failed to initialize WandB in disabled mode: {e}")
65
+ else:
66
+ try:
67
+ wandb.login(key=wandb_api_key)
68
+ print("WandB login successful.")
69
+ except Exception as e:
70
+ print(f"WandB login failed: {e}. Disabling WandB.")
71
+ if wandb.run is None:
72
+ try:
73
+ wandb.init(mode="disabled")
74
+ print("WandB initialized in disabled mode due to login failure.")
75
+ except Exception as e_init:
76
+ print(f"Failed to initialize WandB in disabled mode: {e_init}")
77
+
78
+
79
+ # SmolAgent initialization
80
+ agent = None # Initialize agent to None
81
+ try:
82
+ print("Initializing SmolAgent...")
83
+ model_api = HfApiModel("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
84
+ agent = CodeAgent(tools=[], model=model_api, additional_authorized_imports=[
85
+ "numpy", "pandas", "matplotlib.pyplot", "seaborn", "sklearn", "json", "os"
86
+ ])
87
+ print("SmolAgent initialized successfully.")
88
+ except Exception as e:
89
+ print(f"Error initializing SmolAgent: {e}. AI Agent features might fail.")
90
+
91
+
92
+ # Global variables
93
+ df_global = None
94
+ split_data_global = None # To store (X_train, X_test, y_train, y_test)
95
+ comparison_results_global = None # To store comparison DataFrame
96
+ best_model_details_global = None # To store {'name': best_name, 'model': best_model_obj, 'params': best_params}
97
+ print("Global variables initialized.")
98
+ print("--- Setup Complete ---")
99
+
100
+ # --- Data Handling ---
101
+
102
+ def clean_data(df):
103
+ """Cleans the input DataFrame."""
104
+ print("Starting data cleaning...")
105
+ df_cleaned = df.copy() # Work on a copy
106
+ # Drop columns/rows that are entirely empty
107
+ df_cleaned = df_cleaned.dropna(how='all', axis=1).dropna(how='all', axis=0)
108
+ print(f"Shape after dropping fully empty rows/cols: {df_cleaned.shape}")
109
+
110
+ # Encode object columns
111
+ object_cols = df_cleaned.select_dtypes(include='object').columns
112
+ if not object_cols.empty:
113
+ print(f"Encoding object columns: {list(object_cols)}")
114
+ for col in object_cols:
115
+ # Convert to string first to handle mixed types or NaN representations
116
+ df_cleaned[col] = df_cleaned[col].astype(str)
117
+ df_cleaned[col] = LabelEncoder().fit_transform(df_cleaned[col])
118
+ else:
119
+ print("No object columns found to encode.")
120
+
121
+ # Impute missing values in numeric columns with the mean
122
+ numeric_cols = df_cleaned.select_dtypes(include=np.number).columns
123
+ if not numeric_cols.empty:
124
+ cols_with_na = df_cleaned[numeric_cols].isnull().sum()
125
+ cols_to_impute = cols_with_na[cols_with_na > 0].index
126
+ if not cols_to_impute.empty:
127
+ print(f"Imputing NaNs with mean in columns: {list(cols_to_impute)}")
128
+ for col in cols_to_impute: # Iterate through columns needing imputation
129
+ mean_val = df_cleaned[col].mean()
130
+ df_cleaned[col] = df_cleaned[col].fillna(mean_val)
131
+ else:
132
+ print("No NaNs found in numeric columns to impute.")
133
+ else:
134
+ print("No numeric columns found for NaN imputation.")
135
+
136
+ print("Data cleaning finished.")
137
+ return df_cleaned
138
+
139
+ def upload_file(file):
140
+ """Handles file upload, cleaning, and global state update."""
141
+ global df_global, split_data_global, comparison_results_global, best_model_details_global
142
+ # Reset all global states when a new file is uploaded or file is cleared
143
+ df_global = None
144
+ split_data_global = None
145
+ comparison_results_global = None
146
+ best_model_details_global = None
147
+ print("Reset global data states on file change.")
148
+
149
+ if file is None:
150
+ # No file uploaded or file removed by user
151
+ return pd.DataFrame({"Status": ["No file uploaded or file removed."]})
152
+
153
+ print(f"Uploading file: {file.name}")
154
+ try:
155
+ ext = os.path.splitext(file.name)[-1].lower()
156
+ if ext == ".csv":
157
+ df = pd.read_csv(file.name)
158
+ elif ext in [".xls", ".xlsx"]:
159
+ df = pd.read_excel(file.name)
160
+ else:
161
+ return pd.DataFrame({"Error": [f"Unsupported file type: {ext}"]})
162
+
163
+ print(f"Original data shape: {df.shape}")
164
+ df_cleaned = clean_data(df)
165
+ print(f"Cleaned data shape: {df_cleaned.shape}")
166
+ df_global = df_cleaned # Store the cleaned data
167
+ # dependent globals (split_data_global, etc.) remain None until downstream functions are called
168
+ print("Global DataFrame updated with cleaned data.")
169
+ return df_global.head() # Return head of CLEANED data for preview
170
+ except Exception as e:
171
+ print(f"Error processing file {file.name}: {e}")
172
+ # Ensure globals are None on error
173
+ df_global = None
174
+ split_data_global = None
175
+ comparison_results_global = None
176
+ best_model_details_global = None
177
+ return pd.DataFrame({"Error": [f"Failed to process file: {e}"]})
178
+
179
+
180
+ # --- AI Agent Analysis ---
181
+
182
+ def format_observations(observations):
183
+ """Formats the observations dictionary into HTML list items."""
184
+ if not isinstance(observations, dict):
185
+ return f"<p style='color: orange;'>Observations data is not a dictionary: {type(observations)}</p>"
186
+ items_html = ""
187
+ for key, value in observations.items():
188
+ formatted_key = key.replace('_', ' ').title()
189
+ if isinstance(value, (dict, list)):
190
+ formatted_value = json.dumps(value, indent=2)
191
+ value_html = f"<pre style='margin: 0; padding: 8px; background: #ffffff; border: 1px solid #ccc; border-radius: 4px; font-size: 0.9em; white-space: pre-wrap; word-wrap: break-word;'>{formatted_value}</pre>"
192
+ else:
193
+ formatted_value = str(value)
194
+ value_html = f"<p style='margin: 0; padding: 8px; background: #ffffff; border: 1px solid #ccc; border-radius: 4px; font-size: 0.9em;'>{formatted_value}</p>"
195
+
196
+ items_html += f"""
197
+ <div style="margin-bottom: 12px; padding: 10px; background: #fdfefe; border-radius: 4px; box-shadow: 0 1px 3px rgba(0,0,0,0.1);">
198
+ <h4 style="margin: 0 0 8px 0; color: #34495e;">{formatted_key}</h4>
199
+ {value_html}
200
+ </div>
201
+ """
202
+ return items_html if items_html else "<p>No observations found.</p>"
203
+
204
def format_insights(insights, visuals):
    """Render insights as HTML cards, pairing each with the next visual.

    Args:
        insights: Mapping of insight name -> insight text.
        visuals: List of image file paths; any not consumed by an insight
            card are appended in an "Additional Visualizations" section.

    Returns:
        str: An HTML fragment, an empty-state paragraph, or a warning
        paragraph when `insights` is not a dict.
    """
    if not isinstance(insights, dict):
        return f"<p style='color: orange;'>Insights data is not a dictionary: {type(insights)}</p>"

    parts = []
    next_visual = 0

    for position, (name, text) in enumerate(insights.items(), start=1):
        title = name.replace('_', ' ').title()
        parts.append(f"""
        <div style="margin: 20px 0; padding: 15px; background: #ffffff; border-radius: 8px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);">
            <h4 style='margin-top: 0; margin-bottom: 10px; color: #16a085;'>Insight {position}: {title}</h4>
            <p style="margin-bottom: 15px;">{str(text)}</p>
        """)
        # Pair this insight with the next unused visualization, if any.
        # Gradio serves temporary files through its /file= route.
        if next_visual < len(visuals):
            img_path = visuals[next_visual]
            parts.append(f'<img src="/file={img_path}" alt="Visualization for {title}" style="max-width: 95%; height: auto; display: block; margin-top: 10px; border-radius: 6px; border: 1px solid #eee; box-shadow: 0 1px 3px rgba(0,0,0,0.1);">')
            next_visual += 1
        parts.append("</div>")

    # Any visuals left over get their own trailing section.
    if next_visual < len(visuals):
        parts.append("<h4 style='margin-top: 25px; color: #2980b9;'>Additional Visualizations:</h4>")
        for idx in range(next_visual, len(visuals)):
            parts.append(f"""
            <div style="margin: 20px 0; padding: 15px; background: #ffffff; border-radius: 8px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);">
                <img src="/file={visuals[idx]}" alt="Additional Visualization {idx+1}" style="max-width: 95%; height: auto; display: block; margin: auto; border-radius: 6px; border: 1px solid #eee; box-shadow: 0 1px 3px rgba(0,0,0,0.1);">
            </div>
            """)

    rendered = "".join(parts)
    return rendered if (rendered or visuals) else "<p>No insights or visuals generated/found.</p>"  # Show message only if both empty
241
+
242
def format_analysis_report(raw_output, visuals):
    """Format the AI agent's output into readable HTML.

    Args:
        raw_output: The agent's result — either a dict, or a string that
            (possibly fenced in ```python/```json) contains a dict literal.
        visuals: List of figure paths to embed alongside insights.

    Returns:
        tuple: (HTML report string, the `visuals` list passed through).
    """
    print("Formatting AI analysis report...")

    def _strip_fences(text):
        """Drop a leading ```python/```json fence and a trailing ``` fence."""
        text = text.strip()
        for fence in ("```python", "```json"):
            if text.startswith(fence):
                text = text[len(fence):].strip()
                break
        if text.endswith("```"):
            text = text[:-len("```")].strip()
        return text

    try:
        parsed = {}
        problem = None

        # Coerce the raw output into a dict, recording any parse failure.
        if isinstance(raw_output, str):
            cleaned = _strip_fences(raw_output)
            brace_at = cleaned.find('{')
            if brace_at == -1:
                problem = f"Could not find dictionary start '{{' in agent output.\nRaw output:\n{raw_output}"
                print(problem)
            else:
                try:
                    parsed = ast.literal_eval(cleaned[brace_at:])
                except (SyntaxError, ValueError, TypeError) as e:
                    problem = f"Error parsing agent output: {e}\nRaw output:\n{raw_output}"
                    print(problem)
        elif isinstance(raw_output, dict):
            parsed = raw_output  # Already a dict
        else:
            problem = f"Output is not a string or dictionary, type: {type(raw_output)}.\nRaw output:\n{str(raw_output)}"
            print(problem)

        # --- Build HTML Report ---
        pieces = ["""
        <div style="font-family: Arial, sans-serif; line-height: 1.6; color: #333; padding: 15px; border: 1px solid #ddd; border-radius: 8px; background-color: #f9f9f9;">
            <h1 style="color: #2c3e50; border-bottom: 2px solid #3498db; padding-bottom: 10px; margin-top: 0;">📊 AI Data Analysis Report</h1>
        """]

        # Surface any parsing error at the top of the report.
        if problem:
            pieces.append(f"<div style='background-color: #f8d7da; color: #721c24; border: 1px solid #f5c6cb; padding: 10px; border-radius: 5px; margin-bottom: 15px;'><pre>{problem}</pre></div>")

        # Observations section
        observations = parsed.get('observations', {})
        pieces.append("""
        <div style="margin-top: 20px; background: #ecf0f1; padding: 15px; border-radius: 5px;">
            <h2 style="color: #2980b9; margin-top: 0;">🔍 Key Observations</h2>
        """)
        pieces.append(format_observations(observations) if observations else "<p>No 'observations' found or parsed.</p>")
        pieces.append("</div>")

        # Insights section (with embedded visuals)
        insights = parsed.get('insights', {})
        pieces.append("""
        <div style="margin-top: 25px;">
            <h2 style="color: #2980b9;">💡 Insights & Visualizations</h2>
        """)
        pieces.append(format_insights(insights, visuals) if (insights or visuals) else "<p>No 'insights' or visuals found or parsed.</p>")
        pieces.append("</div>")

        pieces.append("</div>")  # Close main container
        print("Report formatting complete.")
        return "".join(pieces), visuals

    except Exception as e:
        print(f"Critical error in format_analysis_report: {e}")
        error_message = f"<p style='color: red; font-weight: bold;'>Error generating report:</p><pre>{str(e)}</pre>"
        raw_display = f"<p style='font-weight: bold;'>Raw Agent Output:</p><pre>{str(raw_output)}</pre>"
        return error_message + raw_display, visuals
316
+
317
def analyze_data(csv_file, additional_notes=""):
    """Runs the SmolAgent for data analysis and visualization.

    Args:
        csv_file: Uploaded file object; used only for its name in the WandB
            config — the actual data is read from the global `df_global`.
        additional_notes: Free-text context appended to the agent prompt.

    Returns:
        tuple[str, list]: (HTML report, list of generated figure paths), or
        an error paragraph plus an empty list on failure.

    Side effects: recreates ./figures, starts and finishes a WandB run, and
    logs timing/memory/figures to it when WandB is enabled.
    """
    global df_global, agent # Need agent globally
    if df_global is None:
        return "<p style='color:red;'>Please upload a file first.</p>", []
    if agent is None:
        return "<p style='color:red;'>AI Agent is not available (initialization failed).</p>", []
    if csv_file is None: # Check if file object exists
        return "<p style='color:red;'>File object missing, please re-upload.</p>", []

    print("--- Starting AI Agent Analysis ---")
    start_time = time.time()
    # Track this process's RSS so the delta over the agent run can be logged.
    process = psutil.Process(os.getpid())
    initial_memory = process.memory_info().rss / 1024 ** 2

    # Ensure figures directory exists and is empty (stale figures from a
    # previous run would otherwise be picked up as this run's output).
    figures_dir = './figures'
    try:
        if os.path.exists(figures_dir):
            shutil.rmtree(figures_dir)
            print(f"Cleaned existing directory: {figures_dir}")
        os.makedirs(figures_dir)
        print(f"Created directory: {figures_dir}")
    except Exception as e:
        print(f"Error managing figures directory: {e}")
        return f"<p style='color:red;'>Error setting up visualization directory: {e}</p>", []

    # --- WandB Setup for Agent Run ---
    wandb_run_agent = None
    run_name = f"AgentAnalysis_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    if wandb.run is None or wandb.run.mode != "disabled":
        try:
            # Finish any potentially lingering run first
            if wandb.run and wandb.run.id:
                print(f"Finishing potentially active WandB run ({wandb.run.id}) before Agent run.")
                wandb.finish()

            wandb_run_agent = wandb.init(
                project="ai-data-analysis-gradio",
                name=run_name,
                config={
                    "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
                    "agent_type": "CodeAgent",
                    "task": "EDA and Visualization",
                    "additional_notes": additional_notes,
                    "source_file": os.path.basename(csv_file.name) if csv_file else "N/A",
                    "data_shape": df_global.shape
                },
                reinit=True
            )
            print(f"WandB run '{run_name}' initialized for Agent Analysis.")
        except Exception as e:
            print(f"Error initializing WandB run for Agent Analysis: {e}")
            wandb_run_agent = None # Ensure run is None if init fails
    else:
        print("WandB disabled, skipping Agent run logging.")

    # --- Run Agent ---
    analysis_result = None
    visuals = []
    try:
        # Prompt instructs the agent to return a Python dict literal
        # ('observations' + 'insights') and save plots under ./figures/,
        # which format_analysis_report and the gallery rely on.
        prompt = f"""
        Analyze the provided dataset (available as `df_global`).
        Focus on generating clear insights and high-quality visualizations.

        **Tasks:**
        1. **Load Data:** The data is already loaded into the `df_global` pandas DataFrame. Use this DataFrame directly.
        2. **Understand Data:** Briefly describe the data (shape, columns, basic types). Put in 'observations'.
        3. **Generate Observations:** Provide at least 3 key statistical observations (e.g., correlations, value distributions, unique counts). Structure this under an 'observations' key in your output dictionary.
        4. **Generate Insights:** Extract at least 5 meaningful insights from the data. These should be understandable conclusions or patterns discovered. Structure this under an 'insights' key in your output dictionary.
        5. **Create Visualizations:** Generate exactly 5 publication-quality visualizations (e.g., histograms, scatter plots, heatmaps, bar charts) that support your insights.
            * **Save EACH plot** to the './figures/' directory with a unique name (e.g., './figures/plot_1.png', './figures/plot_2.png'). Use `plt.savefig('./figures/unique_name.png', bbox_inches='tight')` and `plt.clf()` after each plot.
            * Make sure plots have titles and clear labels.
            * **DO NOT** use `plt.show()`.

        **Output Format:**
        Return a Python dictionary strictly following this structure:
        {{
            'observations': {{
                'data_description': 'Brief description...',
                'observation_1_key': 'Description of observation 1.',
                # ... more observations
            }},
            'insights': {{
                'insight_1_key': 'Description of insight 1.',
                # ... more insights
            }}
        }}

        **Additional Context/Requests:** {additional_notes}
        Ensure all code is executable. Access the data using the variable `df_global`.
        """
        print("Running AI agent...")
        # Pass the DataFrame in additional_args
        # IMPORTANT: Ensure your SmolAgent setup correctly handles passing DataFrames.
        # If not, you might need to save df_global to a temporary CSV and pass the path.
        analysis_result = agent.run(prompt, additional_args={"df_global": df_global.copy()}) # Pass a copy
        print("AI agent finished.")
        # print(f"Raw Agent Output:\n{analysis_result}") # For debugging

        # Check for generated figures
        if os.path.exists(figures_dir):
            visuals = [os.path.join(figures_dir, f) for f in os.listdir(figures_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
            print(f"Found {len(visuals)} visualizations in {figures_dir}.")
            # Filter out invalid paths (e.g., if agent created non-image files)
            visuals = [v for v in visuals if os.path.isfile(v)]
        else:
            print(f"Warning: Figures directory '{figures_dir}' not found after agent run.")

    except Exception as e:
        print(f"Error during AI agent execution: {e}")
        import traceback
        traceback.print_exc()
        if wandb_run_agent: wandb_run_agent.finish(exit_code=1)
        return f"<p style='color:red;'>Error running AI agent: {e}</p>", []

    # --- Logging and Cleanup ---
    execution_time = time.time() - start_time
    final_memory = process.memory_info().rss / 1024 ** 2
    # NOTE(review): RSS delta is a rough proxy — other threads/allocations
    # in this process are included in the measurement.
    memory_usage = final_memory - initial_memory
    print(f"Agent execution time: {execution_time:.2f}s")
    print(f"Memory usage during agent execution: {memory_usage:.2f} MB")

    if wandb_run_agent:
        try:
            wandb.log({
                "agent_execution_time_sec": execution_time,
                "agent_memory_usage_mb": memory_usage,
                "visualizations_generated": len(visuals)
            }, commit=False) # Commit later or let finish handle it

            # Log visualizations
            logged_images_count = 0
            for viz_path in visuals:
                if os.path.exists(viz_path):
                    try:
                        img_name = os.path.basename(viz_path)
                        wandb.log({f"agent_visualization_{img_name}": wandb.Image(viz_path)}, commit=False)
                        logged_images_count += 1
                    except Exception as log_e:
                        print(f"Warning: Could not log image {viz_path} to WandB: {log_e}")
                else:
                    print(f"Warning: Visualization path not found for logging: {viz_path}")
            print(f"Attempted to log {logged_images_count}/{len(visuals)} visualizations to WandB.")

            # Log the raw analysis result text
            log_data = {}
            if analysis_result:
                log_data["agent_raw_output"] = str(analysis_result)[:10000] # Truncate long output

            wandb.log(log_data, commit=True) # Commit logs here
            print("Logged agent results to WandB.")

        except Exception as e:
            print(f"Error logging agent results to WandB: {e}")
        finally:
            # Always close the run so the next analysis can start cleanly.
            wandb_run_agent.finish()
            print(f"WandB run '{run_name}' finished.")

    # Format report
    return format_analysis_report(analysis_result, visuals)
478
+
479
+
480
+ # --- Model Training and Comparison ---
481
+
482
def prepare_data(df, target_column=None):
    """Prepares data for modeling (selects target, splits, handles encoding).

    Args:
        df: Input DataFrame containing features and the target column.
        target_column: Name of the target column; defaults to the last
            column when None.

    Returns:
        tuple: (X_train, X_test, y_train, y_test). The full split plus the
        fitted LabelEncoder (or None) is also stored in `split_data_global`.

    Raises:
        ValueError: If df is empty, the target column is missing, no numeric
            features remain, or the target has fewer than two classes.
    """
    global split_data_global # Allow modification of global state
    print("--- Preparing Data for Modeling ---")

    if df is None or df.empty:
        print("Error: DataFrame is None or empty in prepare_data.")
        raise ValueError("Cannot prepare data: DataFrame is empty.")

    # Determine target column if not specified
    if target_column is None:
        target_column = df.columns[-1] # Default to last column
        print(f"Target column automatically selected: '{target_column}'")
    elif target_column not in df.columns:
        print(f"Error: Specified target column '{target_column}' not found.")
        raise ValueError(f"Target column '{target_column}' not found.")
    else:
        print(f"Using specified target column: '{target_column}'")

    X = df.drop(columns=[target_column])
    y = df[target_column].copy() # Use copy to avoid SettingWithCopyWarning

    # Ensure target `y` is numeric for classification/regression models
    le = None # LabelEncoder object
    if y.dtype == 'object' or pd.api.types.is_categorical_dtype(y):
        print(f"Encoding target column '{target_column}' with LabelEncoder.")
        le = LabelEncoder()
        # BUGFIX: fit_transform returns a numpy ndarray, which has no
        # .nunique(); wrap it back into a Series so the class-count check
        # below works and the index stays aligned with X for the split.
        y = pd.Series(le.fit_transform(y), index=y.index, name=target_column)
        print(f"Target classes found: {le.classes_}") # Useful for interpretation later

    # Check for non-numeric features (should be handled by clean_data, but as safeguard)
    non_numeric_cols = X.select_dtypes(exclude=np.number).columns
    if not non_numeric_cols.empty:
        print(f"Warning: Non-numeric columns found in features after cleaning: {list(non_numeric_cols)}. Dropping them.")
        X = X.drop(columns=non_numeric_cols)

    if X.empty:
        print("Error: No features remaining after dropping non-numeric columns.")
        raise ValueError("No features remaining to train the model.")

    # Check if target has only one class after potential encoding/filtering
    if y.nunique() < 2:
        print(f"Error: Target column '{target_column}' has fewer than 2 unique values after processing. Cannot stratify or train meaningful classifier.")
        raise ValueError("Target column must have at least two unique classes for classification.")

    # Split data
    try:
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.3, random_state=42, stratify=y # Always try to stratify for classification
        )
        print(f"Data split: X_train {X_train.shape}, X_test {X_test.shape}, y_train {y_train.shape}, y_test {y_test.shape}")
    except ValueError as split_e:
        # This can happen if a class has too few members (e.g., only 1) for stratification
        print(f"Stratified split failed ({split_e}). Trying non-stratified split.")
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.3, random_state=42
        )
        print(f"Data split (non-stratified): X_train {X_train.shape}, X_test {X_test.shape}, y_train {y_train.shape}, y_test {y_test.shape}")

    # Store the split data AND the label encoder if used
    split_data_global = (X_train, X_test, y_train, y_test, le) # Add le to the tuple
    print("Data prepared and split stored globally.")
    return X_train, X_test, y_train, y_test
546
+
547
+
548
def train_and_compare_models(tune_rf=True, tune_gb=True, n_trials_optuna=10):
    """Trains, (optionally) tunes, evaluates multiple models, and logs comparison.

    Args:
        tune_rf: Tune RandomForest hyperparameters with Optuna before fitting.
        tune_gb: Tune GradientBoosting hyperparameters with Optuna before fitting.
        n_trials_optuna: Number of Optuna trials per tuned model.

    Returns:
        pd.DataFrame: One row per model with test metrics (sorted by weighted
        F1 when available), or a single-column Error/Status frame on failure.

    Side effects: may call prepare_data (updating `split_data_global`),
    updates `comparison_results_global` and `best_model_details_global`,
    manages the global `wandb_run`, and saves the best model to ./saved_models.
    """
    global df_global, split_data_global, comparison_results_global, best_model_details_global, wandb_run
    if df_global is None:
        print("Error: No data loaded for training/comparison.")
        return pd.DataFrame({"Error": ["Please upload data first."]})

    print("--- Starting Model Training and Comparison ---")
    run_name = f"CompareModels_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    models_to_compare = {
        # Use Pipelines for models benefiting from scaling
        "LogisticRegression": Pipeline([('scaler', StandardScaler()), ('logreg', LogisticRegression(max_iter=1000, random_state=42, class_weight='balanced'))]),
        "RandomForest": RandomForestClassifier(random_state=42, class_weight='balanced'), # Add class_weight
        "GradientBoosting": GradientBoostingClassifier(random_state=42) # GB doesn't have class_weight in quite the same way
    }
    # Filter models based on input type (add later if needed)
    # is_classification = True # Assume classification for now

    config = {
        "task": "Model Comparison",
        "models": list(models_to_compare.keys()),
        "tune_rf": tune_rf,
        "tune_gb": tune_gb,
        "optuna_trials": n_trials_optuna if (tune_rf or tune_gb) else 0,
        "data_shape": df_global.shape,
        "test_size": 0.3,
        "stratify": True # Assuming stratification was attempted
    }

    # --- WandB Setup for Comparison Run ---
    if wandb.run is None or wandb.run.mode != "disabled":
        try:
            # Finish any potentially lingering run first
            if wandb.run and wandb.run.id:
                print(f"Finishing potentially active WandB run ({wandb.run.id}) before Comparison run.")
                wandb.finish()

            wandb_run = wandb.init(project="ai-data-analysis-gradio", name=run_name, config=config, reinit=True)
            print(f"WandB run '{run_name}' initialized for Model Comparison.")
        except Exception as e:
            print(f"Error initializing WandB run for Comparison: {e}")
            wandb_run = None # Ensure it's None if init fails
    else:
        print("WandB disabled, skipping Comparison run logging.")
        wandb_run = None # Explicitly set to None if disabled

    results = []
    best_f1 = -1 # Initialize best F1 score
    best_model_obj = None
    best_model_name = None
    best_model_params = None

    try:
        # Prepare data if not already split
        if split_data_global:
            print("Using previously split data.")
            X_train, X_test, y_train, y_test, _ = split_data_global # Unpack (ignore label encoder here)
        else:
            print("Preparing data for comparison...")
            # Use default target (last column) if prepare_data hasn't been run
            X_train, X_test, y_train, y_test = prepare_data(df_global)

        # --- Optuna Objective Functions ---
        # (Ensure X_train, y_train are accessible within objectives)
        # Both objectives maximize 3-fold CV weighted F1 on the training set.
        def objective_rf(trial):
            params = {
                "n_estimators": trial.suggest_int("n_estimators", 50, 250, step=50),
                "max_depth": trial.suggest_int("max_depth", 5, 20, log=True), # Log scale for depth
                "min_samples_split": trial.suggest_int("min_samples_split", 2, 16),
                "min_samples_leaf": trial.suggest_int("min_samples_leaf", 1, 16),
                "criterion": trial.suggest_categorical("criterion", ["gini", "entropy"]),
                "class_weight": trial.suggest_categorical("class_weight", ["balanced", "balanced_subsample", None]),
                "random_state": 42
            }
            # Note: class_weight='balanced' might require specific sklearn version or handling
            try:
                model = RandomForestClassifier(**params)
                # Use smaller CV during tuning for speed
                score = cross_val_score(model, X_train, y_train, cv=3, scoring="f1_weighted", n_jobs=-1).mean() # Tune based on F1 weighted
                if wandb_run: wandb.log({"optuna_rf_trial": trial.number, "optuna_rf_cv_f1w": score, **params}, commit=False)
                return score
            except ValueError as e:
                print(f"Optuna RF trial error (params {params}): {e}")
                return -1 # Return poor score on error

        def objective_gb(trial):
            params = {
                "n_estimators": trial.suggest_int("n_estimators", 50, 250, step=50),
                "learning_rate": trial.suggest_float("learning_rate", 0.01, 0.3, log=True),
                "max_depth": trial.suggest_int("max_depth", 3, 10),
                "min_samples_split": trial.suggest_int("min_samples_split", 2, 16),
                "min_samples_leaf": trial.suggest_int("min_samples_leaf", 1, 16),
                "subsample": trial.suggest_float("subsample", 0.6, 1.0),
                "random_state": 42
            }
            try:
                model = GradientBoostingClassifier(**params)
                score = cross_val_score(model, X_train, y_train, cv=3, scoring="f1_weighted", n_jobs=-1).mean() # Tune based on F1 weighted
                if wandb_run: wandb.log({"optuna_gb_trial": trial.number, "optuna_gb_cv_f1w": score, **params}, commit=False)
                return score
            except ValueError as e:
                print(f"Optuna GB trial error (params {params}): {e}")
                return -1

        # --- Model Training Loop ---
        for name, model_pipeline in models_to_compare.items(): # Use model_pipeline for clarity
            print(f"--- Training and Evaluating: {name} ---")
            start_time = time.time()
            current_params = model_pipeline.get_params() # Start with default/pipeline params
            final_model = model_pipeline # The model/pipeline to be trained

            try:
                # Optional Tuning with Optuna (each study capped at 300s)
                if name == "RandomForest" and tune_rf:
                    print(f"Tuning {name} with Optuna ({n_trials_optuna} trials)...")
                    study_rf = optuna.create_study(direction="maximize", study_name=f"{name}_tune_{run_name}")
                    study_rf.optimize(objective_rf, n_trials=n_trials_optuna, timeout=300)
                    best_params_rf = study_rf.best_params
                    # Important: Re-initialize the model *without* pipeline for RF tuning
                    final_model = RandomForestClassifier(**best_params_rf, random_state=42)
                    current_params = final_model.get_params() # Update params to tuned ones
                    print(f"Best RF params (CV F1w: {study_rf.best_value:.4f}): {best_params_rf}")
                    if wandb_run: wandb.log({f"{name}_best_cv_f1w": study_rf.best_value, f"{name}_best_params": best_params_rf}, commit=False)

                elif name == "GradientBoosting" and tune_gb:
                    print(f"Tuning {name} with Optuna ({n_trials_optuna} trials)...")
                    study_gb = optuna.create_study(direction="maximize", study_name=f"{name}_tune_{run_name}")
                    study_gb.optimize(objective_gb, n_trials=n_trials_optuna, timeout=300)
                    best_params_gb = study_gb.best_params
                    final_model = GradientBoostingClassifier(**best_params_gb, random_state=42)
                    current_params = final_model.get_params()
                    print(f"Best GB params (CV F1w: {study_gb.best_value:.4f}): {best_params_gb}")
                    if wandb_run: wandb.log({f"{name}_best_cv_f1w": study_gb.best_value, f"{name}_best_params": best_params_gb}, commit=False)

                # Train the final model (tuned or default/pipeline)
                final_model.fit(X_train, y_train)

                # Evaluate on the test set
                y_pred = final_model.predict(X_test)
                accuracy = accuracy_score(y_test, y_pred)
                # Use weighted avg for metrics suitable for multi-class and imbalanced datasets
                precision = precision_score(y_test, y_pred, average="weighted", zero_division=0)
                recall = recall_score(y_test, y_pred, average="weighted", zero_division=0)
                f1 = f1_score(y_test, y_pred, average="weighted", zero_division=0)
                duration = time.time() - start_time

                print(f"{name} Test Set - Accuracy: {accuracy:.4f}, F1 (Weighted): {f1:.4f}, Time: {duration:.2f}s")

                metrics = {
                    "Model": name,
                    "Test Accuracy": accuracy, # Changed key for clarity
                    "Test Precision (Weighted)": precision,
                    "Test Recall (Weighted)": recall,
                    "Test F1 Score (Weighted)": f1,
                    "Training Time (s)": duration,
                    "Tuned": (name == "RandomForest" and tune_rf) or (name == "GradientBoosting" and tune_gb)
                }
                results.append(metrics)

                # Log individual model metrics to WandB
                if wandb_run:
                    # Create a flat dictionary for logging
                    log_metrics = {f"{name}_{k.lower().replace(' (weighted)','_w').replace(' ','_')}": v
                                   for k, v in metrics.items() if k not in ["Model", "Tuned"]}
                    log_metrics[f"{name}_tuned_flag"] = metrics["Tuned"]
                    wandb.log(log_metrics, commit=False)

                # Check if this is the best model so far based on F1 score
                if f1 > best_f1:
                    print(f"*** New best model found: {name} (F1: {f1:.4f}) ***")
                    best_f1 = f1
                    best_model_name = name
                    best_model_obj = final_model # Store the fitted model/pipeline object
                    best_model_params = current_params # Store its parameters

            except Exception as train_e:
                # One model failing should not abort the whole comparison.
                print(f"ERROR training/evaluating {name}: {train_e}")
                import traceback
                traceback.print_exc()
                results.append({"Model": name, "Error": str(train_e)})
                if wandb_run: wandb.log({f"{name}_error": str(train_e)}, commit=False)

        # --- Finalize Comparison ---
        if not results:
            print("No models were successfully trained.")
            return pd.DataFrame({"Status": ["Model training failed for all candidates."]})

        comparison_df = pd.DataFrame(results)
        # Handle cases where F1 score might be missing due to errors
        if "Test F1 Score (Weighted)" in comparison_df.columns:
            comparison_df = comparison_df.sort_values(by="Test F1 Score (Weighted)", ascending=False).reset_index(drop=True)
        else:
            print("Warning: F1 score column missing, cannot sort results.")

        comparison_results_global = comparison_df # Store globally
        print("\n--- Model Comparison Summary ---")
        print(comparison_df.to_string())

        # Store best model details globally if found
        if best_model_obj is not None:
            best_model_details_global = {
                'name': best_model_name,
                'model': best_model_obj, # Store the actual fitted model/pipeline
                'params': best_model_params, # Store params used (tuned or default)
                'f1_score': best_f1
            }
            print(f"Stored details for best model: {best_model_name}")

            # Optional: Save the best model artifact locally and log to WandB
            output_dir_models = "./saved_models"
            os.makedirs(output_dir_models, exist_ok=True)
            model_filename = os.path.join(output_dir_models, f"best_model_{best_model_name.lower().replace(' ','_')}.joblib")
            try:
                joblib.dump(best_model_obj, model_filename)
                print(f"Best model saved locally to {model_filename}")
                if wandb_run:
                    # Log artifact to WandB
                    # Clean params dict for artifact metadata (remove complex objects if any)
                    clean_params_meta = {k: str(v) for k, v in best_model_params.items() if isinstance(v, (str, int, float, bool, list))}

                    artifact = wandb.Artifact(f'best_model-{wandb_run.id}', type='model',
                                              metadata={'model_type': best_model_name, 'test_f1_score': best_f1, **clean_params_meta})
                    artifact.add_file(model_filename)
                    wandb_run.log_artifact(artifact)
                    print("Logged best model artifact to WandB.")
            except Exception as save_e:
                print(f"Error saving/logging best model artifact: {save_e}")

        # Log comparison table to WandB
        if wandb_run and not comparison_df.empty:
            try:
                # Filter out potential error rows before creating table
                valid_comparison_df = comparison_df.dropna(subset=[col for col in comparison_df.columns if col != 'Error'])
                if not valid_comparison_df.empty:
                    wandb_comparison_table = wandb.Table(dataframe=valid_comparison_df)
                    wandb_run.log({"model_comparison_summary": wandb_comparison_table}, commit=True) # Commit final logs
                    print("Logged comparison summary table to WandB.")
                else:
                    print("No valid results to log to WandB table.")
            except Exception as log_e:
                print(f"Error logging comparison table to WandB: {log_e}")

        return comparison_df

    except Exception as e:
        print(f"An error occurred during model comparison: {e}")
        import traceback
        traceback.print_exc()
        if wandb_run: wandb_run.finish(exit_code=1) # Mark run as failed
        return pd.DataFrame({"Error": [f"Comparison failed: {e}"]})
    finally:
        if wandb_run and wandb.run: # Check if wandb_run was initialized and is still active
            # Ensure logs are committed before finishing
            try:
                wandb.log({}, commit=True)
            except Exception:
                pass # Ignore if already finished or other issue
            wandb_run.finish()
            print(f"WandB run '{run_name}' finished.")
        wandb_run = None # Reset global run variable
809
+
810
+
811
+ # --- Model Explainability ---
812
+
813
+ def explainability(_=None):
814
  """Generates SHAP and LIME explanations for the best performing model."""
815
  global split_data_global, best_model_details_global, wandb_run
816
  if split_data_global is None:
817
+ print("Error: Data not split. Please run 'Train & Compare' first.")
818
+ return [], None, "Error: Data not prepared. Run 'Train & Compare' first."
819
  if best_model_details_global is None:
820
+ print("Error: Best model details not found. Please run 'Train & Compare' first.")
821
+ return [], None, "Error: Best model not identified. Run 'Train & Compare' first."
822
 
823
+ X_train, X_test, y_train, y_test, label_encoder = split_data_global # Unpack label encoder
824
  best_model_name = best_model_details_global['name']
825
+ best_model = best_model_details_global['model'] # Use the stored, already fitted best model/pipeline
826
 
827
  print(f"--- Generating explanations for the best model: {best_model_name} ---")
828
 
829
+ # Define paths dynamically within a dedicated directory
830
  output_dir = "./explainability_plots"
831
  if os.path.exists(output_dir): shutil.rmtree(output_dir)
832
  os.makedirs(output_dir)
 
837
  status_message = f"Explaining best model: {best_model_name}"
838
  all_shap_paths = [] # Initialize empty list for gallery output
839
 
840
+ # --- WandB Setup for Explainability Run ---
841
  run_name = f"Explain_{best_model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
842
  config = {"task": "Explainability", "best_model": best_model_name, "explainers": ["SHAP", "LIME"]}
 
 
843
  wandb_run_explain = None
844
  if wandb.run is None or wandb.run.mode != "disabled":
845
  try:
846
+ # Finish any potentially lingering run first
847
  if wandb.run and wandb.run.id:
848
+ print(f"Finishing potentially active WandB run ({wandb.run.id}) before Explainability run.")
849
  wandb.finish()
850
  wandb_run_explain = wandb.init(project="ai-data-analysis-gradio", name=run_name, config=config, reinit=True)
851
  print(f"WandB run '{run_name}' initialized for Explainability.")
 
853
  print(f"Error initializing Wandb run for Explainability: {e}")
854
  wandb_run_explain = None # Ensure it's None if init fails
855
  else:
856
+ print("WandB disabled, skipping Explainability run logging.")
857
  wandb_run_explain = None
858
 
859
  try:
 
861
  print("Calculating SHAP values...")
862
  shap_values = None
863
  explainer = None
864
+ X_test_for_shap = X_test # Data to pass to SHAP (might be transformed)
 
865
 
866
  # Determine explainer type and get SHAP values
867
  if isinstance(best_model, (RandomForestClassifier, GradientBoostingClassifier)):
868
+ print("Using SHAP TreeExplainer for standalone tree model.")
869
+ explainer = shap.TreeExplainer(best_model)
870
+ shap_values = explainer.shap_values(X_test) # Use original X_test
871
+
872
  elif isinstance(best_model, Pipeline):
 
873
  final_estimator_name, final_estimator = best_model.steps[-1]
874
  print(f"Handling Pipeline. Final estimator: {final_estimator_name} ({type(final_estimator)})")
875
+ pipeline_transforms = Pipeline(best_model.steps[:-1]) # Get all steps EXCEPT the last one
876
 
877
  if isinstance(final_estimator, (RandomForestClassifier, GradientBoostingClassifier)):
878
+ print("Using SHAP TreeExplainer for tree model within Pipeline.")
879
+ # Need to transform data using the pipeline's transform steps first
880
+ try:
881
+ print("Transforming data using pipeline steps before TreeExplainer...")
882
+ X_train_transformed = pipeline_transforms.fit_transform(X_train) # Fit on train
883
+ X_test_transformed = pipeline_transforms.transform(X_test) # Transform test
884
+ X_test_for_shap = pd.DataFrame(X_test_transformed, columns=X_test.columns, index=X_test.index) # Keep index/columns
885
+ print("Data transformed.")
886
+ # Explain the final estimator using the *transformed* data
887
+ explainer = shap.TreeExplainer(final_estimator)
888
+ shap_values = explainer.shap_values(X_test_for_shap)
889
+ except Exception as transform_e:
890
+ print(f"ERROR transforming data for TreeExplainer in Pipeline: {transform_e}. Skipping SHAP.")
891
 
892
  elif isinstance(final_estimator, LogisticRegression):
893
  print("Using SHAP KernelExplainer for Logistic Regression in Pipeline (can be slow)...")
 
894
  predict_proba_pipeline = lambda x: best_model.predict_proba(pd.DataFrame(x, columns=X_test.columns))
 
895
  print("Summarizing training data for KernelExplainer background...")
896
+ # Use original X_train for background summary as KernelExplainer handles model internally
897
+ X_train_summary = shap.kmeans(X_train, min(100, X_train.shape[0]), random_state=42)
898
  explainer = shap.KernelExplainer(predict_proba_pipeline, X_train_summary)
 
899
  subset_size = min(50, X_test.shape[0])
900
  print(f"Calculating SHAP values for {subset_size} test instances...")
901
+ X_test_subset_np = X_test.sample(subset_size, random_state=42).values
902
+ shap_values = explainer.shap_values(X_test_subset_np, nsamples='auto') # Let SHAP choose nsamples
 
903
  # Create DataFrame from subset for plotting consistency
904
+ X_test_for_shap = pd.DataFrame(X_test_subset_np, columns=X_test.columns) # Use subset for plotting
905
  print("SHAP values calculated using KernelExplainer.")
 
906
  else:
907
+ print(f"Warning: SHAP not implemented for final pipeline step {type(final_estimator)}. Skipping SHAP.")
908
  else:
909
+ print(f"Warning: SHAP not explicitly handled for model type {type(best_model)}. Skipping SHAP.")
 
 
 
 
 
 
 
 
 
 
 
 
910
 
911
 
912
  # --- Generate SHAP Plots (if shap_values exist) ---
913
  if shap_values is not None:
914
+ num_classes = len(np.unique(y_train))
915
 
916
  # SHAP Summary Plot
917
  print("Generating SHAP summary plot...")
918
+ plt.figure(figsize=(10, 6))
919
  try:
920
+ plot_shap_values = shap_values
921
  title_suffix = f"({best_model_name})"
922
+ class_names = getattr(label_encoder, 'classes_', [f'Class {i}' for i in range(num_classes)]) if label_encoder else [f'Class {i}' for i in range(num_classes)]
923
+
924
  if num_classes == 2 and isinstance(shap_values, list) and len(shap_values) == 2:
925
+ # Typically, shap_values[1] corresponds to the positive class
926
+ plot_shap_values = shap_values[1]
927
+ positive_class_name = class_names[1] if len(class_names) > 1 else "Class 1"
928
+ title_suffix = f"({best_model_name} - Class: {positive_class_name})"
929
+ print(f"Plotting SHAP summary for positive class ({positive_class_name})")
930
  elif num_classes > 2 and isinstance(shap_values, list):
931
  title_suffix = f"({best_model_name} - Multiclass Avg Impact)"
932
  print("Plotting SHAP summary for multiclass (average impact)")
933
 
934
+ shap.summary_plot(plot_shap_values, X_test_for_shap, show=False, plot_type="dot", class_names=class_names)
935
  plt.title(f"SHAP Feature Importance Summary {title_suffix}")
936
  plt.tight_layout()
937
  plt.savefig(shap_summary_path, bbox_inches='tight')
938
  plt.clf()
939
  print(f"SHAP summary plot saved to {shap_summary_path}")
940
+ all_shap_paths.append(shap_summary_path)
941
  if wandb_run_explain: wandb.log({"shap_summary": wandb.Image(shap_summary_path)}, commit=False)
942
  except Exception as summary_e:
943
  print(f"Error generating SHAP summary plot: {summary_e}")
944
+ plt.clf()
945
 
946
  # SHAP Dependence Plots
947
  print("Calculating global feature importance for dependence plots...")
948
+ global_shap_values_mean = None
949
  try:
950
+ if isinstance(shap_values, list): # Multi-class or binary list
951
+ abs_shap_arrays = [np.abs(sv) for sv in shap_values if isinstance(sv, np.ndarray) and sv.ndim == 2 and sv.shape[1] == X_test_for_shap.shape[1]]
952
+ if abs_shap_arrays:
953
+ stacked_shap = np.stack(abs_shap_arrays, axis=0)
954
+ global_shap_values_mean = stacked_shap.mean(axis=(0, 1))
955
+ print(f"Calculated global SHAP values (list input), shape: {global_shap_values_mean.shape}")
956
+ elif isinstance(shap_values, np.ndarray) and shap_values.ndim == 2 and shap_values.shape[1] == X_test_for_shap.shape[1]: # Regression or binary array
957
+ global_shap_values_mean = np.abs(shap_values).mean(axis=0)
958
+ print(f"Calculated global SHAP values (array input), shape: {global_shap_values_mean.shape}")
 
 
 
 
 
 
 
959
  else:
960
+ print(f"Warning: SHAP values structure not suitable for global importance calculation. Shape: {getattr(shap_values, 'shape', 'N/A')}, Type: {type(shap_values)}")
 
961
  except Exception as gsi_e:
962
  print(f"Error calculating global SHAP importance: {gsi_e}")
963
 
964
+ if global_shap_values_mean is not None and len(global_shap_values_mean) > 0:
 
965
  try:
966
+ feature_indices = np.argsort(global_shap_values_mean)[::-1]
967
+ num_features_to_plot = min(2, len(global_shap_values_mean))
 
968
  if num_features_to_plot > 0:
969
  top_feature_indices = feature_indices[:num_features_to_plot]
970
+ # Ensure indices are within bounds of columns
971
+ valid_indices = [idx for idx in top_feature_indices if idx < len(X_test_for_shap.columns)]
972
+ top_features = X_test_for_shap.columns[valid_indices]
973
 
974
  print(f"Generating SHAP dependence plots for top features: {list(top_features)}")
975
+ for feature_idx, feature_name in zip(valid_indices, top_features):
976
  plt.figure(figsize=(8, 5))
 
 
977
  shap_values_for_dep = shap_values
978
+ class_idx_dep = 0 # Default to first class for multiclass
979
  if isinstance(shap_values, list):
980
+ if num_classes == 2 and len(shap_values) == 2:
981
+ shap_values_for_dep = shap_values[1] # Use positive class for binary
982
+ class_idx_dep = 1
983
+ elif len(shap_values) > 0:
984
+ shap_values_for_dep = shap_values[0] # Default class 0 for multiclass
985
 
986
+ shap.dependence_plot(feature_idx, shap_values_for_dep, X_test_for_shap, interaction_index='auto', show=False)
987
+ dep_title = f"SHAP Dependence: {feature_name} ({best_model_name})"
988
+ if isinstance(shap_values, list): dep_title += f" (Class Index {class_idx_dep})"
989
+ plt.title(dep_title)
990
  plt.tight_layout()
991
  dep_path = os.path.join(output_dir, f"shap_dependence_{best_model_name}_{feature_name}.png")
992
  plt.savefig(dep_path, bbox_inches='tight')
 
996
  print(f"Saved dependence plot: {dep_path}")
997
  if wandb_run_explain: wandb.log({f"shap_dependence_{feature_name}": wandb.Image(dep_path)}, commit=False)
998
  else:
999
+ print("Skipping dependence plots: Not enough features.")
1000
  except Exception as dep_e:
1001
  print(f"Could not generate SHAP dependence plots: {dep_e}")
1002
+ import traceback
1003
+ traceback.print_exc()
1004
+ plt.clf()
1005
  else:
1006
  print("Skipping dependence plots due to issue calculating global SHAP values.")
 
1007
  else:
1008
  print("Skipping SHAP plots as SHAP values were not generated.")
1009
 
 
1012
  print("Generating LIME explanation for the first test instance...")
1013
  try:
1014
  predict_fn_lime = None
 
1015
  if hasattr(best_model, 'predict_proba'):
 
1016
  def predict_proba_wrapper(x_np):
 
1017
  x_df = pd.DataFrame(x_np, columns=X_train.columns)
1018
  return best_model.predict_proba(x_df)
1019
  predict_fn_lime = predict_proba_wrapper
1020
  else:
1021
+ print("Warning: Model lacks predict_proba. LIME results might be unreliable.")
 
1022
  num_classes_lime = len(np.unique(y_train))
1023
  predict_fn_lime = lambda x: np.ones((len(x), num_classes_lime)) / num_classes_lime
1024
 
1025
+ # Use class names from label encoder if available
1026
+ class_names_lime = getattr(label_encoder, 'classes_', [str(c) for c in sorted(np.unique(y_train))]) if label_encoder else [str(c) for c in sorted(np.unique(y_train))]
1027
+ # Ensure class names are strings
1028
+ class_names_lime = [str(cn) for cn in class_names_lime]
 
1029
 
1030
  lime_explainer = lime.lime_tabular.LimeTabularExplainer(
1031
+ training_data=X_train.values,
1032
  feature_names=X_train.columns.tolist(),
1033
+ class_names=class_names_lime,
1034
+ mode='classification' if len(class_names_lime) > 1 else 'regression'
1035
  )
1036
 
1037
  instance_idx = 0
1038
+ instance_to_explain = X_test.iloc[instance_idx].values
1039
+ true_class_encoded = y_test[instance_idx] if isinstance(y_test, np.ndarray) else y_test.iloc[instance_idx]
1040
+ # Decode true class if label encoder exists
1041
+ true_class_label = class_names_lime[true_class_encoded] if label_encoder and true_class_encoded < len(class_names_lime) else str(true_class_encoded)
1042
 
1043
  lime_exp = lime_explainer.explain_instance(
1044
  data_row=instance_to_explain,
1045
+ predict_fn=predict_fn_lime,
1046
  num_features=10,
1047
+ num_samples=1000
1048
  )
1049
  print(f"LIME explanation generated for instance {instance_idx}.")
1050
 
1051
  lime_fig = lime_exp.as_pyplot_figure()
1052
+ predicted_class_idx = lime_exp.available_labels()[0]
1053
+ predicted_class_label = class_names_lime[predicted_class_idx] if predicted_class_idx < len(class_names_lime) else f"Index {predicted_class_idx}"
1054
+ lime_fig.suptitle(f"LIME Exp (Inst {instance_idx}, True: {true_class_label}, Pred: {predicted_class_label}, Model: {best_model_name})", y=1.03, fontsize=10)
1055
+ lime_fig.tight_layout(rect=[0, 0, 1, 0.98])
 
1056
  lime_fig.savefig(lime_path, bbox_inches='tight')
1057
  plt.clf()
1058
  print(f"LIME plot saved to {lime_path}")
 
1060
 
1061
  except Exception as lime_e:
1062
  print(f"Error generating LIME explanation: {lime_e}")
1063
+ import traceback
1064
+ traceback.print_exc()
1065
  if wandb_run_explain: wandb.log({"lime_error": str(lime_e)}, commit=False)
1066
  lime_path = None # Indicate failure
1067
 
1068
  # Final status message
1069
+ status_message = f"Explanations generated for {best_model_name}."
1070
+ if not all_shap_paths: status_message += " (SHAP failed/skipped or generated no plots)."
1071
  if not lime_path: status_message += " (LIME failed/skipped)."
1072
 
1073
  # Return paths to the plots and status
 
1074
  valid_lime_path = lime_path if lime_path and os.path.exists(lime_path) else None
1075
+ valid_shap_paths = [p for p in all_shap_paths if p and os.path.exists(p)] # Filter out non-existent paths
1076
+ return valid_shap_paths, valid_lime_path, status_message
1077
 
1078
  except Exception as e:
1079
  print(f"An error occurred during explainability: {e}")
1080
  import traceback
1081
+ traceback.print_exc()
1082
  status_message = f"Error during explanation: {e}"
1083
  if wandb_run_explain: wandb_run_explain.finish(exit_code=1)
1084
  return [], None, status_message # Return empty list/None for paths on error
1085
  finally:
1086
  plt.close('all') # Close all matplotlib figures
1087
+ if wandb_run_explain and wandb.run and wandb.run.id == wandb_run_explain.id:
1088
  try:
1089
+ wandb.log({}, commit=True) # Commit final logs for explain run
1090
  wandb_run_explain.finish()
1091
  print(f"WandB run '{run_name}' finished.")
1092
  except Exception as finish_e:
 
1094
  wandb_run_explain = None # Reset
1095
 
1096
 
1097
+ # --- Gradio Interface ---
1098
+ print("--- Setting up Gradio Interface ---")
 
1099
  with gr.Blocks(theme=gr.themes.Soft(), title="AI Data Analysis & Model Comparison") as demo:
1100
  gr.Markdown(
1101
  """
 
1117
  with gr.Row():
1118
  with gr.Column(scale=1):
1119
  agent_notes = gr.Textbox(label="Optional: Specific requests for the AI Agent", placeholder="e.g., 'Focus on correlations with column X'")
1120
+ agent_btn = gr.Button("Run AI Analysis", variant="secondary", interactive=(agent is not None)) # Disable if agent failed init
1121
  with gr.Column(scale=2):
1122
  insights_output = gr.HTML(label="AI Agent Analysis Report")
1123
  with gr.Row():
 
1132
  optuna_trials_slider = gr.Slider(minimum=5, maximum=50, value=10, step=5, label="Optuna Trials per Model")
1133
  compare_btn = gr.Button("Train & Compare Models", variant="primary")
1134
  with gr.Column(scale=2):
1135
+ comparison_output = gr.DataFrame(label="Model Comparison Results (Sorted by Test F1 Score)", interactive=False)
1136
 
1137
  # --- Row 4: Model Explainability ---
1138
  with gr.Accordion("💡 Step 4: Explain Best Model (SHAP & LIME)", open=False):
 
1141
  explain_status = gr.Textbox(label="Explanation Status", interactive=False)
1142
  with gr.Row():
1143
  # Use Gallery for SHAP as there can be multiple plots
1144
+ shap_gallery = gr.Gallery(label="SHAP Plots (Summary + Top Feature Dependence)", elem_id="shap-gallery", height=450, object_fit="contain", columns=1, preview=True) # Better display
1145
  lime_img = gr.Image(label="LIME Explanation (for first test instance)", type="filepath", interactive=False)
1146
 
1147
 
1148
  # --- Connect Components ---
1149
+ print("Connecting Gradio components...")
1150
+ # Link file upload to function
1151
  file_input.change(
1152
  fn=upload_file,
1153
  inputs=file_input,
1154
  outputs=df_output
1155
  )
1156
 
1157
+ # Link AI agent button
1158
  agent_btn.click(
1159
  fn=analyze_data,
1160
  inputs=[file_input, agent_notes],
1161
  outputs=[insights_output, visual_output]
1162
  )
1163
 
1164
+ # Link model comparison button
1165
  compare_btn.click(
1166
  fn=train_and_compare_models,
1167
  inputs=[tune_rf_checkbox, tune_gb_checkbox, optuna_trials_slider],
1168
  outputs=[comparison_output]
1169
  )
1170
 
1171
+ # Link explain button
1172
  explain_btn.click(
1173
  fn=explainability,
1174
  inputs=[], # Uses global best model details
1175
  outputs=[shap_gallery, lime_img, explain_status] # Output list of SHAP plots, one LIME plot, and status
1176
  )
1177
+ print("Gradio components connected.")
1178
 
1179
  # --- Launch the App ---
1180
  if __name__ == "__main__":
1181
+ print("--- Cleaning up temporary directories/files ---")
1182
  # Clean up temporary files/dirs from previous runs before launching
1183
+ temp_dirs = ['./figures', './explainability_plots', './saved_models', './__pycache__']
1184
+ temp_files = [] # Don't delete pngs automatically if they are inside the cleaned dirs
1185
 
1186
  for d in temp_dirs:
1187
  if os.path.exists(d):
 
1198
  except Exception as e:
1199
  print(f"Warning: Could not clean up file {f}: {e}")
1200
 
1201
+ print("--- Launching Gradio App ---")
1202
+ demo.launch(debug=False, share=False) # Set debug=True for detailed Gradio errors if needed
1203
+ print("--- Gradio App Closed ---")