pavanmutha committed
Commit 3eedbb8 · verified · 1 parent: 4ce62b5

Update app.py

Files changed (1):
app.py +54 -78
app.py CHANGED
@@ -198,75 +198,59 @@ def prepare_data(df, target_column=None):
     return train_test_split(X, y, test_size=0.3, random_state=42)


-def train_model(file, ab_choice="A"):
-    df = pd.read_csv(file.name)
-    df = preprocess_data(df)
-    X = df.drop("target", axis=1)
-    y = df["target"]
-
-    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
-
-    # SmolAgent
-    tuner = SmolAgent(model="gpt-4")
-    model = tuner.fit(X_train, y_train)
-
-    # Evaluate
-    y_pred = model.predict(X_test)
-    y_proba = model.predict_proba(X_test) if hasattr(model, "predict_proba") else None
-
-    metrics = {
-        "accuracy": accuracy_score(y_test, y_pred),
-        "precision": precision_score(y_test, y_pred, average="weighted", zero_division=0),
-        "recall": recall_score(y_test, y_pred, average="weighted", zero_division=0),
-        "f1_score": f1_score(y_test, y_pred, average="weighted", zero_division=0),
-    }
-
-    # Optional AUC
-    if y_proba is not None:
-        try:
-            if y_proba.shape[1] == 2:
-                metrics["roc_auc"] = roc_auc_score(y_test, y_proba[:, 1])
-            else:
-                metrics["roc_auc_ovr"] = roc_auc_score(y_test, y_proba, multi_class="ovr")
-        except Exception:
-            pass
-
-    # Classification report
-    metrics["classification_report"] = classification_report(y_test, y_pred, output_dict=True)
-
-    # Confusion Matrix Plot
-    fig, ax = plt.subplots(figsize=(6, 4))
-    ConfusionMatrixDisplay.from_predictions(y_test, y_pred, ax=ax)
-    plt.title("Confusion Matrix")
-    conf_matrix_path = "./conf_matrix.png"
-    plt.savefig(conf_matrix_path)
-    plt.close()
-
-    # Bar Plot of Metrics
-    def plot_metrics(metrics):
-        plt.figure(figsize=(6, 3))
-        keys = [k for k in metrics if isinstance(metrics[k], (int, float))]
-        values = [metrics[k] for k in keys]
-        plt.barh(keys, values, color="skyblue")
-        plt.xlabel("Score")
-        plt.title("Model Performance Metrics")
-        path = "./metrics_plot.png"
-        plt.tight_layout()
-        plt.savefig(path)
-        plt.close()
-        return path
-
-    metrics_plot_path = plot_metrics(metrics)
-
-    # Log to WandB
-    wandb.init(project="ab-test", name=f"variant_{ab_choice}", reinit=True)
-    wandb.log({**metrics, "confusion_matrix": wandb.Image(conf_matrix_path),
-               "metrics_plot": wandb.Image(metrics_plot_path),
-               "ab_variant": ab_choice})
-    wandb.finish()
-
-    # Return everything to Gradio
-    return metrics, pd.DataFrame.from_dict(metrics["classification_report"]).T, conf_matrix_path, metrics_plot_path
+def train_model(_):
+    try:
+        wandb.login(key=os.environ.get("WANDB_API_KEY"))
+        wandb_run = wandb.init(
+            project="huggingface-data-analysis",
+            name=f"Optuna_Run_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
+            reinit=True
+        )
+
+        X_train, X_test, y_train, y_test = prepare_data(df_global)
+
+        def objective(trial):
+            params = {
+                "n_estimators": trial.suggest_int("n_estimators", 50, 200),
+                "max_depth": trial.suggest_int("max_depth", 3, 10),
+            }
+            model = RandomForestClassifier(**params)  # pass the sampled hyperparameters to the model
+            score = cross_val_score(model, X_train, y_train, cv=3).mean()
+            wandb.log({**params, "cv_score": score})  # stream each trial to the W&B run
+            return score
+
+        study = optuna.create_study(direction="maximize")
+        study.optimize(objective, n_trials=15)
+
+        best_params = study.best_params
+        model = RandomForestClassifier(**best_params)  # refit with the best trial's hyperparameters
+        model.fit(X_train, y_train)
+        y_pred = model.predict(X_test)
+
+        metrics = {
+            "accuracy": accuracy_score(y_test, y_pred),
+            "precision": precision_score(y_test, y_pred, average="weighted", zero_division=0),
+            "recall": recall_score(y_test, y_pred, average="weighted", zero_division=0),
+            "f1_score": f1_score(y_test, y_pred, average="weighted", zero_division=0),
+        }
+        wandb.log(metrics)
+        wandb_run.finish()
+
+        # Top 7 trials
+        top_trials = sorted(study.trials, key=lambda x: x.value, reverse=True)[:7]
+        trial_rows = []
+        for t in top_trials:
+            row = t.params.copy()
+            row["score"] = t.value
+            trial_rows.append(row)
+        trials_df = pd.DataFrame(trial_rows)
+
+        return metrics, trials_df
+
+    except Exception as e:
+        print(f"Training Error: {e}")
+        return {}, pd.DataFrame()


 def explainability(_):
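Note on the new training loop: below is a minimal, self-contained sketch of the suggest → cross-validate → refit pattern this hunk introduces. make_classification stands in for the uploaded CSV, and the per-trial wandb.log call is omitted so the sketch runs without a W&B API key; names and data here are illustrative, not the app's.

import optuna
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score, train_test_split

# Synthetic stand-in for the uploaded dataset.
X, y = make_classification(n_samples=500, n_features=10, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

def objective(trial):
    # One hyperparameter configuration per trial.
    params = {
        "n_estimators": trial.suggest_int("n_estimators", 50, 200),
        "max_depth": trial.suggest_int("max_depth", 3, 10),
    }
    # The sampled params must reach the model; otherwise every trial
    # evaluates an identical default forest.
    model = RandomForestClassifier(**params, random_state=42)
    # Mean 3-fold CV accuracy is the value Optuna maximizes.
    return cross_val_score(model, X_train, y_train, cv=3).mean()

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=15)

# Refit on the full training split with the winning configuration.
best_model = RandomForestClassifier(**study.best_params, random_state=42)
best_model.fit(X_train, y_train)
print(study.best_params, best_model.score(X_test, y_test))

The committed version additionally logs each trial's params and cv_score inside objective, tying every trial to the run opened by wandb.init; the hand-rolled top-7 loop could equally be built with study.trials_dataframe().nlargest(7, "value"), which Optuna derives from the same trial history.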
@@ -359,25 +343,17 @@ with gr.Blocks() as demo:
     agent_btn = gr.Button("Run AI Agent (5 Insights + 5 Visualizations)")

     with gr.Row():
-        ab_dropdown = gr.Dropdown(choices=["A", "B"], label="Choose Model Variant", value="A")
         train_btn = gr.Button("Train Model with Optuna + WandB")
-
-    with gr.Row():
         metrics_output = gr.JSON(label="Performance Metrics")
         trials_output = gr.DataFrame(label="Top 7 Hyperparameter Trials")

-    with gr.Row():
-        conf_matrix_img = gr.Image(label="Confusion Matrix")
-        metric_plot_img = gr.Image(label="Metric Bar Plot")
-
     with gr.Row():
         explain_btn = gr.Button("SHAP + LIME Explainability")
         shap_img = gr.Image(label="SHAP Summary Plot")
         lime_img = gr.Image(label="LIME Explanation")

-    # Button logic
     agent_btn.click(fn=analyze_data, inputs=[file_input], outputs=[insights_output, visual_output])
-    train_btn.click(fn=train_model, inputs=[file_input, ab_dropdown], outputs=[metrics_output, trials_output, conf_matrix_img, metric_plot_img])
+    train_btn.click(fn=train_model, inputs=[file_input], outputs=[metrics_output, trials_output])
     explain_btn.click(fn=explainability, inputs=[], outputs=[shap_img, lime_img])

-demo.launch(debug=True)
+demo.launch(debug=True)
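Note on the UI wiring: below is a minimal sketch of the Blocks layout this hunk converges on, with one button whose handler returns a (dict, DataFrame) pair rendered by gr.JSON and gr.DataFrame. The stub handler and reduced component set are illustrative; the app's real handler is the train_model above.

import gradio as gr
import pandas as pd

def train_model(_file):
    # Stub standing in for the Optuna + WandB training routine.
    metrics = {"accuracy": 0.90, "f1_score": 0.89}           # placeholder values
    trials_df = pd.DataFrame([{"n_estimators": 120, "max_depth": 6, "score": 0.90}])
    # Return order must match the outputs= list on the click binding.
    return metrics, trials_df

with gr.Blocks() as demo:
    file_input = gr.File(label="Upload CSV")
    train_btn = gr.Button("Train Model with Optuna + WandB")
    metrics_output = gr.JSON(label="Performance Metrics")
    trials_output = gr.DataFrame(label="Top 7 Hyperparameter Trials")
    train_btn.click(fn=train_model, inputs=[file_input], outputs=[metrics_output, trials_output])

demo.launch(debug=True)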