pavanmutha commited on
Commit
68f9dc5
·
verified ·
1 Parent(s): 48f788e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -53
app.py CHANGED
@@ -198,59 +198,75 @@ def prepare_data(df, target_column=None):
198
  return train_test_split(X, y, test_size=0.3, random_state=42)
199
 
200
 
201
def train_model(_):
    """Tune a RandomForestClassifier with Optuna and log the run to W&B.

    Uses the module-level ``df_global`` (split via ``prepare_data``), runs a
    15-trial Optuna maximization of mean 3-fold CV accuracy, then retrains on
    the best hyperparameters and evaluates on the held-out test split.

    Args:
        _: Ignored (Gradio passes the file-input component's value).

    Returns:
        tuple: (metrics dict, DataFrame of the top-7 trials' params/scores).
        On any failure returns ``({}, empty DataFrame)`` so the UI degrades
        gracefully instead of surfacing a traceback.
    """
    try:
        wandb.login(key=os.environ.get("WANDB_API_KEY"))
        wandb_run = wandb.init(
            project="huggingface-data-analysis",
            name=f"Optuna_Run_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
            reinit=True,
        )

        X_train, X_test, y_train, y_test = prepare_data(df_global)

        def objective(trial):
            params = {
                "n_estimators": trial.suggest_int("n_estimators", 50, 200),
                "max_depth": trial.suggest_int("max_depth", 3, 10),
            }
            # Bug fix: the sampled params MUST be passed to the estimator;
            # previously RandomForestClassifier() used defaults on every
            # trial, making the whole search a no-op.
            model = RandomForestClassifier(**params, random_state=42)
            score = cross_val_score(model, X_train, y_train, cv=3).mean()
            wandb.log({**params, "cv_score": score})
            return score

        study = optuna.create_study(direction="maximize")
        study.optimize(objective, n_trials=15)

        # Bug fix: train the final model WITH the best params found;
        # previously best_params was computed but never used.
        best_params = study.best_params
        model = RandomForestClassifier(**best_params, random_state=42)
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)

        metrics = {
            "accuracy": accuracy_score(y_test, y_pred),
            "precision": precision_score(y_test, y_pred, average="weighted", zero_division=0),
            "recall": recall_score(y_test, y_pred, average="weighted", zero_division=0),
            "f1_score": f1_score(y_test, y_pred, average="weighted", zero_division=0),
        }
        wandb.log(metrics)
        wandb_run.finish()

        # Top 7 trials by CV score. Failed/pruned trials have value=None,
        # which would crash sorted(); rank them last instead.
        top_trials = sorted(
            study.trials,
            key=lambda t: t.value if t.value is not None else float("-inf"),
            reverse=True,
        )[:7]
        trial_rows = [{**t.params, "score": t.value} for t in top_trials]
        trials_df = pd.DataFrame(trial_rows)

        return metrics, trials_df

    except Exception as e:
        # Best-effort UI: report the failure and return empty outputs.
        print(f"Training Error: {e}")
        return {}, pd.DataFrame()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
 
255
 
256
  def explainability(_):
@@ -351,9 +367,13 @@ with gr.Blocks() as demo:
351
  explain_btn = gr.Button("SHAP + LIME Explainability")
352
  shap_img = gr.Image(label="SHAP Summary Plot")
353
  lime_img = gr.Image(label="LIME Explanation")
 
 
 
 
354
 
355
  agent_btn.click(fn=analyze_data, inputs=[file_input], outputs=[insights_output, visual_output])
356
- train_btn.click(fn=train_model, inputs=[file_input], outputs=[metrics_output, trials_output])
357
  explain_btn.click(fn=explainability, inputs=[], outputs=[shap_img, lime_img])
358
 
359
  demo.launch(debug=True)
 
198
  return train_test_split(X, y, test_size=0.3, random_state=42)
199
 
200
 
201
def _plot_confusion_matrix(y_true, y_pred):
    """Render a confusion-matrix PNG and return its file path."""
    fig, ax = plt.subplots(figsize=(6, 4))
    ConfusionMatrixDisplay.from_predictions(y_true, y_pred, ax=ax)
    plt.title("Confusion Matrix")
    path = "./conf_matrix.png"
    plt.savefig(path)
    plt.close()  # free the figure; this runs inside a long-lived server
    return path


def _plot_metric_bars(metrics):
    """Horizontal bar chart of all scalar metrics; returns the PNG path."""
    plt.figure(figsize=(6, 3))
    # Only plot scalar entries (skips the nested classification_report dict).
    keys = [k for k in metrics if isinstance(metrics[k], (int, float))]
    values = [metrics[k] for k in keys]
    plt.barh(keys, values, color="skyblue")
    plt.xlabel("Score")
    plt.title("Model Performance Metrics")
    path = "./metrics_plot.png"
    plt.tight_layout()
    plt.savefig(path)
    plt.close()
    return path


def train_model(file, ab_choice="A"):
    """Train a model on the uploaded CSV via SmolAgent, evaluate, log to W&B.

    Args:
        file: Gradio file upload; ``file.name`` is the CSV path on disk.
        ab_choice: A/B-test variant label recorded with the W&B run.

    Returns:
        tuple: (metrics dict, classification-report DataFrame,
        confusion-matrix PNG path, metrics bar-plot PNG path).
        On any failure returns ``({}, empty DataFrame, None, None)`` — four
        values matching the four Gradio outputs — instead of raising, which
        restores the graceful-degradation behavior the previous version had.
    """
    try:
        # Guard the two failure modes a user can trigger from the UI.
        if file is None:
            raise ValueError("No file uploaded")
        df = pd.read_csv(file.name)
        df = preprocess_data(df)
        if "target" not in df.columns:
            raise ValueError("Dataset must contain a 'target' column")

        X = df.drop("target", axis=1)
        y = df["target"]

        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42
        )

        # SmolAgent drives model selection / tuning (gpt-4 backed).
        tuner = SmolAgent(model="gpt-4")
        model = tuner.fit(X_train, y_train)

        # Evaluate on the held-out split.
        y_pred = model.predict(X_test)
        y_proba = model.predict_proba(X_test) if hasattr(model, "predict_proba") else None

        metrics = {
            "accuracy": accuracy_score(y_test, y_pred),
            "precision": precision_score(y_test, y_pred, average="weighted", zero_division=0),
            "recall": recall_score(y_test, y_pred, average="weighted", zero_division=0),
            "f1_score": f1_score(y_test, y_pred, average="weighted", zero_division=0),
        }

        # Optional AUC: binary vs. one-vs-rest multiclass. Best effort only —
        # AUC is informational, so never fail the run over it.
        if y_proba is not None:
            try:
                if y_proba.shape[1] == 2:
                    metrics["roc_auc"] = roc_auc_score(y_test, y_proba[:, 1])
                else:
                    metrics["roc_auc_ovr"] = roc_auc_score(y_test, y_proba, multi_class="ovr")
            except Exception:
                pass

        # Full per-class report (nested dict; rendered as a DataFrame below).
        metrics["classification_report"] = classification_report(
            y_test, y_pred, output_dict=True
        )

        conf_matrix_path = _plot_confusion_matrix(y_test, y_pred)
        metrics_plot_path = _plot_metric_bars(metrics)

        # Log everything to W&B, tagged with the A/B variant.
        wandb.init(project="ab-test", name=f"variant_{ab_choice}", reinit=True)
        wandb.log({
            **metrics,
            "confusion_matrix": wandb.Image(conf_matrix_path),
            "metrics_plot": wandb.Image(metrics_plot_path),
            "ab_variant": ab_choice,
        })
        wandb.finish()

        # Return everything to Gradio.
        return (
            metrics,
            pd.DataFrame.from_dict(metrics["classification_report"]).T,
            conf_matrix_path,
            metrics_plot_path,
        )

    except Exception as e:
        print(f"Training Error: {e}")
        return {}, pd.DataFrame(), None, None
270
 
271
 
272
  def explainability(_):
 
367
  explain_btn = gr.Button("SHAP + LIME Explainability")
368
  shap_img = gr.Image(label="SHAP Summary Plot")
369
  lime_img = gr.Image(label="LIME Explanation")
370
+ metrics_output = gr.JSON(label="Evaluation Metrics")
371
+ trials_output = gr.Dataframe(label="Classification Report")
372
+ conf_matrix_img = gr.Image(label="Confusion Matrix")
373
+ metric_plot_img = gr.Image(label="Metric Bar Plot")
374
 
375
  agent_btn.click(fn=analyze_data, inputs=[file_input], outputs=[insights_output, visual_output])
376
+ train_btn.click(fn=train_model, inputs=[file_input, ab_dropdown], outputs=[metrics_output, trials_output, conf_matrix_img, metric_plot_img])
377
  explain_btn.click(fn=explainability, inputs=[], outputs=[shap_img, lime_img])
378
 
379
  demo.launch(debug=True)