pavanmutha committed on
Commit
1fdfe63
·
verified ·
1 Parent(s): 3eedbb8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -8
app.py CHANGED
@@ -155,9 +155,12 @@ def analyze_data(csv_file, additional_notes=""):
155
  return format_analysis_report(analysis_result, visuals)
156
 
157
  def compare_models():
 
 
 
158
  if df_global is None:
159
- return "Please upload and preprocess a dataset first."
160
-
161
  target = df_global.columns[-1]
162
  X = df_global.drop(target, axis=1)
163
  y = df_global[target]
@@ -168,21 +171,43 @@ def compare_models():
168
  models = {
169
  "RandomForest": RandomForestClassifier(),
170
  "LogisticRegression": LogisticRegression(max_iter=1000),
171
- "SVC": SVC()
172
  }
173
 
174
  results = []
175
  for name, model in models.items():
 
176
  scores = cross_val_score(model, X, y, cv=5)
177
- results.append({
 
 
 
 
178
  "Model": name,
179
  "CV Mean Accuracy": np.mean(scores),
180
- "CV Std Dev": np.std(scores)
181
- })
182
- wandb.log({f"{name}_cv_mean": np.mean(scores), f"{name}_cv_std": np.std(scores)})
 
 
 
 
183
 
184
  results_df = pd.DataFrame(results)
185
- return results_df
 
 
 
 
 
 
 
 
 
 
 
 
 
186
 
187
  # 1. prepare_data should come first
188
  def prepare_data(df, target_column=None):
@@ -352,8 +377,15 @@ with gr.Blocks() as demo:
352
  shap_img = gr.Image(label="SHAP Summary Plot")
353
  lime_img = gr.Image(label="LIME Explanation")
354
 
 
 
 
 
 
355
  agent_btn.click(fn=analyze_data, inputs=[file_input], outputs=[insights_output, visual_output])
356
  train_btn.click(fn=train_model, inputs=[file_input], outputs=[metrics_output, trials_output])
357
  explain_btn.click(fn=explainability, inputs=[], outputs=[shap_img, lime_img])
 
 
358
 
359
  demo.launch(debug=True)
 
155
  return format_analysis_report(analysis_result, visuals)
156
 
157
  def compare_models():
158
+ import seaborn as sns
159
+ from sklearn.model_selection import cross_val_predict
160
+
161
  if df_global is None:
162
+ return pd.DataFrame({"Error": ["Please upload and preprocess a dataset first."]}), None
163
+
164
  target = df_global.columns[-1]
165
  X = df_global.drop(target, axis=1)
166
  y = df_global[target]
 
171
  models = {
172
  "RandomForest": RandomForestClassifier(),
173
  "LogisticRegression": LogisticRegression(max_iter=1000),
174
+ "GradientBoosting": GradientBoostingClassifier()
175
  }
176
 
177
  results = []
178
  for name, model in models.items():
179
+ # Cross-validation scores
180
  scores = cross_val_score(model, X, y, cv=5)
181
+
182
+ # Cross-validated predictions for metrics
183
+ y_pred = cross_val_predict(model, X, y, cv=5)
184
+
185
+ metrics = {
186
  "Model": name,
187
  "CV Mean Accuracy": np.mean(scores),
188
+ "CV Std Dev": np.std(scores),
189
+ "F1 Score": f1_score(y, y_pred, average="weighted", zero_division=0),
190
+ "Precision": precision_score(y, y_pred, average="weighted", zero_division=0),
191
+ "Recall": recall_score(y, y_pred, average="weighted", zero_division=0),
192
+ }
193
+ wandb.log({f"{name}_{k.replace(' ', '_').lower()}": v for k, v in metrics.items() if isinstance(v, (float, int))})
194
+ results.append(metrics)
195
 
196
  results_df = pd.DataFrame(results)
197
+
198
+ # Plotting
199
+ plt.figure(figsize=(8, 5))
200
+ sns.barplot(data=results_df, x="Model", y="CV Mean Accuracy", palette="Blues_d")
201
+ plt.title("Model Comparison (CV Mean Accuracy)")
202
+ plt.ylim(0, 1)
203
+ plt.tight_layout()
204
+
205
+ plot_path = "./model_comparison.png"
206
+ plt.savefig(plot_path)
207
+ plt.close()
208
+
209
+ return results_df, plot_path
210
+
211
 
212
  # 1. prepare_data should come first
213
  def prepare_data(df, target_column=None):
 
377
  shap_img = gr.Image(label="SHAP Summary Plot")
378
  lime_img = gr.Image(label="LIME Explanation")
379
 
380
+ with gr.Row():
381
+ compare_btn = gr.Button("Compare Models (A/B Testing)")
382
+ compare_output = gr.DataFrame(label="Model Comparison (CV + Metrics)")
383
+ compare_img = gr.Image(label="Model Accuracy Plot")
384
+
385
  agent_btn.click(fn=analyze_data, inputs=[file_input], outputs=[insights_output, visual_output])
386
  train_btn.click(fn=train_model, inputs=[file_input], outputs=[metrics_output, trials_output])
387
  explain_btn.click(fn=explainability, inputs=[], outputs=[shap_img, lime_img])
388
+ compare_btn.click(fn=compare_models, inputs=[], outputs=[compare_output, compare_img])
389
+
390
 
391
  demo.launch(debug=True)