pavanmutha committed on
Commit
4949145
·
verified ·
1 Parent(s): e7399a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -59
app.py CHANGED
@@ -182,9 +182,6 @@ def compare_models():
182
  results_df = pd.DataFrame(results)
183
  return results_df
184
 
185
-
186
-
187
-
188
  def train_model(_):
189
  wandb.login(key=os.environ.get("WANDB_API_KEY"))
190
  run_counter = 1
@@ -209,34 +206,29 @@ def train_model(_):
209
  common_errors = error_df[error_df["error"]].groupby(["actual", "predicted"]).size().reset_index(name='count')
210
 
211
  def generate_report(metrics_df, trials_df, common_errors_df):
212
- report = f"""
213
- # Model Training Report
214
-
215
- ## Metrics
216
- {metrics_df.to_markdown(index=False)}
217
-
218
- ## Top Trials
219
- {trials_df.to_markdown(index=False)}
220
 
221
- ## Common Errors
222
- {common_errors_df.to_markdown(index=False)}
223
 
224
- _Generated on {time.strftime('%Y-%m-%d %H:%M:%S')}_
225
- """
226
- with open("model_report.md", "w") as f:
227
- f.write(report)
228
- return "Report saved to model_report.md"
229
 
230
-
 
231
 
 
 
 
 
 
232
 
233
  fig, ax = plt.subplots(figsize=(6, 4))
234
  ConfusionMatrixDisplay.from_estimator(best_model, X_test, y_test, ax=ax)
235
  plt.savefig("confusion_matrix.png")
236
  wandb.log({"confusion_matrix": wandb.Image("confusion_matrix.png")})
237
 
238
-
239
-
240
  # Inside your layout:
241
  compare_button = gr.Button("Compare Models")
242
  compare_output = gr.Dataframe()
@@ -251,40 +243,38 @@ report_button.click(
251
  outputs=report_status
252
  )
253
 
254
-
255
  # Log common misclassifications to wandb
256
  wandb.log({"common_errors": wandb.Table(dataframe=common_errors)})
257
 
258
-
259
- def objective(trial):
260
- params = {
261
- "n_estimators": trial.suggest_int("n_estimators", 50, 200),
262
- "max_depth": trial.suggest_int("max_depth", 3, 10),
263
- }
264
- model = RandomForestClassifier(**params)
265
- score = cross_val_score(model, X_train, y_train, cv=3).mean()
266
- wandb.log(params | {"cv_score": score})
267
- return score
268
-
269
- study = optuna.create_study(direction="maximize")
270
- study.optimize(objective, n_trials=15)
271
-
272
- best_params = study.best_params
273
- model = RandomForestClassifier(**best_params)
274
- model.fit(X_train, y_train)
275
- y_pred = model.predict(X_test)
276
-
277
- metrics = {
278
- "accuracy": accuracy_score(y_test, y_pred),
279
- "precision": precision_score(y_test, y_pred, average="weighted", zero_division=0),
280
- "recall": recall_score(y_test, y_pred, average="weighted", zero_division=0),
281
- "f1_score": f1_score(y_test, y_pred, average="weighted", zero_division=0),
282
  }
283
- wandb.log(metrics)
284
- wandb_run.finish()
285
-
286
- top_trials = pd.DataFrame(study.trials_dataframe().sort_values(by="value", ascending=False).head(7))
287
- return metrics, top_trials
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
 
289
  def explainability(_):
290
  import warnings
@@ -361,9 +351,6 @@ def explainability(_):
361
 
362
  return shap_path, lime_path
363
 
364
-
365
-
366
-
367
  with gr.Blocks() as demo:
368
  gr.Markdown("## 📊 AI-Powered Data Analysis with Hyperparameter Optimization")
369
 
@@ -374,7 +361,6 @@ with gr.Blocks() as demo:
374
  file_input.change(fn=upload_file, inputs=file_input, outputs=df_output)
375
 
376
  with gr.Column():
377
-
378
  insights_output = gr.HTML(label="Insights from SmolAgent")
379
  visual_output = gr.Gallery(label="Visualizations (Auto-generated by Agent)", columns=2)
380
  agent_btn = gr.Button("Run AI Agent (5 Insights + 5 Visualizations)")
@@ -389,11 +375,8 @@ with gr.Blocks() as demo:
389
  shap_img = gr.Image(label="SHAP Summary Plot")
390
  lime_img = gr.Image(label="LIME Explanation")
391
 
392
-
393
  agent_btn.click(fn=analyze_data, inputs=[file_input], outputs=[insights_output, visual_output])
394
  train_btn.click(fn=train_model, inputs=[], outputs=[metrics_output, trials_output])
395
  explain_btn.click(fn=explainability, inputs=[], outputs=[shap_img, lime_img])
396
 
397
-
398
-
399
- demo.launch(debug=True)
 
182
  results_df = pd.DataFrame(results)
183
  return results_df
184
 
 
 
 
185
  def train_model(_):
186
  wandb.login(key=os.environ.get("WANDB_API_KEY"))
187
  run_counter = 1
 
206
  common_errors = error_df[error_df["error"]].groupby(["actual", "predicted"]).size().reset_index(name='count')
207
 
208
  def generate_report(metrics_df, trials_df, common_errors_df):
209
+ report = f"""
210
+ # Model Training Report
 
 
 
 
 
 
211
 
212
+ ## Metrics
213
+ {metrics_df.to_markdown(index=False)}
214
 
215
+ ## Top Trials
216
+ {trials_df.to_markdown(index=False)}
 
 
 
217
 
218
+ ## Common Errors
219
+ {common_errors_df.to_markdown(index=False)}
220
 
221
+ _Generated on {time.strftime('%Y-%m-%d %H:%M:%S')}_
222
+ """
223
+ with open("model_report.md", "w") as f:
224
+ f.write(report)
225
+ return "Report saved to model_report.md"
226
 
227
  fig, ax = plt.subplots(figsize=(6, 4))
228
  ConfusionMatrixDisplay.from_estimator(best_model, X_test, y_test, ax=ax)
229
  plt.savefig("confusion_matrix.png")
230
  wandb.log({"confusion_matrix": wandb.Image("confusion_matrix.png")})
231
 
 
 
232
  # Inside your layout:
233
  compare_button = gr.Button("Compare Models")
234
  compare_output = gr.Dataframe()
 
243
  outputs=report_status
244
  )
245
 
 
246
  # Log common misclassifications to wandb
247
  wandb.log({"common_errors": wandb.Table(dataframe=common_errors)})
248
 
249
+ def objective(trial):
250
+ params = {
251
+ "n_estimators": trial.suggest_int("n_estimators", 50, 200),
252
+ "max_depth": trial.suggest_int("max_depth", 3, 10),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
  }
254
+ model = RandomForestClassifier(**params)
255
+ score = cross_val_score(model, X_train, y_train, cv=3).mean()
256
+ wandb.log(params | {"cv_score": score})
257
+ return score
258
+
259
+ study = optuna.create_study(direction="maximize")
260
+ study.optimize(objective, n_trials=15)
261
+
262
+ best_params = study.best_params
263
+ model = RandomForestClassifier(**best_params)
264
+ model.fit(X_train, y_train)
265
+ y_pred = model.predict(X_test)
266
+
267
+ metrics = {
268
+ "accuracy": accuracy_score(y_test, y_pred),
269
+ "precision": precision_score(y_test, y_pred, average="weighted", zero_division=0),
270
+ "recall": recall_score(y_test, y_pred, average="weighted", zero_division=0),
271
+ "f1_score": f1_score(y_test, y_pred, average="weighted", zero_division=0),
272
+ }
273
+ wandb.log(metrics)
274
+ wandb_run.finish()
275
+
276
+ top_trials = pd.DataFrame(study.trials_dataframe().sort_values(by="value", ascending=False).head(7))
277
+ return metrics, top_trials
278
 
279
  def explainability(_):
280
  import warnings
 
351
 
352
  return shap_path, lime_path
353
 
 
 
 
354
  with gr.Blocks() as demo:
355
  gr.Markdown("## 📊 AI-Powered Data Analysis with Hyperparameter Optimization")
356
 
 
361
  file_input.change(fn=upload_file, inputs=file_input, outputs=df_output)
362
 
363
  with gr.Column():
 
364
  insights_output = gr.HTML(label="Insights from SmolAgent")
365
  visual_output = gr.Gallery(label="Visualizations (Auto-generated by Agent)", columns=2)
366
  agent_btn = gr.Button("Run AI Agent (5 Insights + 5 Visualizations)")
 
375
  shap_img = gr.Image(label="SHAP Summary Plot")
376
  lime_img = gr.Image(label="LIME Explanation")
377
 
 
378
  agent_btn.click(fn=analyze_data, inputs=[file_input], outputs=[insights_output, visual_output])
379
  train_btn.click(fn=train_model, inputs=[], outputs=[metrics_output, trials_output])
380
  explain_btn.click(fn=explainability, inputs=[], outputs=[shap_img, lime_img])
381
 
382
+ demo.launch(debug=True)