Update app.py
Browse files
app.py
CHANGED
|
@@ -182,9 +182,6 @@ def compare_models():
|
|
| 182 |
results_df = pd.DataFrame(results)
|
| 183 |
return results_df
|
| 184 |
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
def train_model(_):
|
| 189 |
wandb.login(key=os.environ.get("WANDB_API_KEY"))
|
| 190 |
run_counter = 1
|
|
@@ -209,34 +206,29 @@ def train_model(_):
|
|
| 209 |
common_errors = error_df[error_df["error"]].groupby(["actual", "predicted"]).size().reset_index(name='count')
|
| 210 |
|
| 211 |
def generate_report(metrics_df, trials_df, common_errors_df):
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
## Metrics
|
| 216 |
-
{metrics_df.to_markdown(index=False)}
|
| 217 |
-
|
| 218 |
-
## Top Trials
|
| 219 |
-
{trials_df.to_markdown(index=False)}
|
| 220 |
|
| 221 |
-
|
| 222 |
-
|
| 223 |
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
with open("model_report.md", "w") as f:
|
| 227 |
-
f.write(report)
|
| 228 |
-
return "Report saved to model_report.md"
|
| 229 |
|
| 230 |
-
|
|
|
|
| 231 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
|
| 233 |
fig, ax = plt.subplots(figsize=(6, 4))
|
| 234 |
ConfusionMatrixDisplay.from_estimator(best_model, X_test, y_test, ax=ax)
|
| 235 |
plt.savefig("confusion_matrix.png")
|
| 236 |
wandb.log({"confusion_matrix": wandb.Image("confusion_matrix.png")})
|
| 237 |
|
| 238 |
-
|
| 239 |
-
|
| 240 |
# Inside your layout:
|
| 241 |
compare_button = gr.Button("Compare Models")
|
| 242 |
compare_output = gr.Dataframe()
|
|
@@ -251,40 +243,38 @@ report_button.click(
|
|
| 251 |
outputs=report_status
|
| 252 |
)
|
| 253 |
|
| 254 |
-
|
| 255 |
# Log common misclassifications to wandb
|
| 256 |
wandb.log({"common_errors": wandb.Table(dataframe=common_errors)})
|
| 257 |
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
"max_depth": trial.suggest_int("max_depth", 3, 10),
|
| 263 |
-
}
|
| 264 |
-
model = RandomForestClassifier(**params)
|
| 265 |
-
score = cross_val_score(model, X_train, y_train, cv=3).mean()
|
| 266 |
-
wandb.log(params | {"cv_score": score})
|
| 267 |
-
return score
|
| 268 |
-
|
| 269 |
-
study = optuna.create_study(direction="maximize")
|
| 270 |
-
study.optimize(objective, n_trials=15)
|
| 271 |
-
|
| 272 |
-
best_params = study.best_params
|
| 273 |
-
model = RandomForestClassifier(**best_params)
|
| 274 |
-
model.fit(X_train, y_train)
|
| 275 |
-
y_pred = model.predict(X_test)
|
| 276 |
-
|
| 277 |
-
metrics = {
|
| 278 |
-
"accuracy": accuracy_score(y_test, y_pred),
|
| 279 |
-
"precision": precision_score(y_test, y_pred, average="weighted", zero_division=0),
|
| 280 |
-
"recall": recall_score(y_test, y_pred, average="weighted", zero_division=0),
|
| 281 |
-
"f1_score": f1_score(y_test, y_pred, average="weighted", zero_division=0),
|
| 282 |
}
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
|
| 289 |
def explainability(_):
|
| 290 |
import warnings
|
|
@@ -361,9 +351,6 @@ def explainability(_):
|
|
| 361 |
|
| 362 |
return shap_path, lime_path
|
| 363 |
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
with gr.Blocks() as demo:
|
| 368 |
gr.Markdown("## 📊 AI-Powered Data Analysis with Hyperparameter Optimization")
|
| 369 |
|
|
@@ -374,7 +361,6 @@ with gr.Blocks() as demo:
|
|
| 374 |
file_input.change(fn=upload_file, inputs=file_input, outputs=df_output)
|
| 375 |
|
| 376 |
with gr.Column():
|
| 377 |
-
|
| 378 |
insights_output = gr.HTML(label="Insights from SmolAgent")
|
| 379 |
visual_output = gr.Gallery(label="Visualizations (Auto-generated by Agent)", columns=2)
|
| 380 |
agent_btn = gr.Button("Run AI Agent (5 Insights + 5 Visualizations)")
|
|
@@ -389,11 +375,8 @@ with gr.Blocks() as demo:
|
|
| 389 |
shap_img = gr.Image(label="SHAP Summary Plot")
|
| 390 |
lime_img = gr.Image(label="LIME Explanation")
|
| 391 |
|
| 392 |
-
|
| 393 |
agent_btn.click(fn=analyze_data, inputs=[file_input], outputs=[insights_output, visual_output])
|
| 394 |
train_btn.click(fn=train_model, inputs=[], outputs=[metrics_output, trials_output])
|
| 395 |
explain_btn.click(fn=explainability, inputs=[], outputs=[shap_img, lime_img])
|
| 396 |
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
demo.launch(debug=True)
|
|
|
|
| 182 |
results_df = pd.DataFrame(results)
|
| 183 |
return results_df
|
| 184 |
|
|
|
|
|
|
|
|
|
|
| 185 |
def train_model(_):
|
| 186 |
wandb.login(key=os.environ.get("WANDB_API_KEY"))
|
| 187 |
run_counter = 1
|
|
|
|
| 206 |
common_errors = error_df[error_df["error"]].groupby(["actual", "predicted"]).size().reset_index(name='count')
|
| 207 |
|
| 208 |
def generate_report(metrics_df, trials_df, common_errors_df):
|
| 209 |
+
report = f"""
|
| 210 |
+
# Model Training Report
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
|
| 212 |
+
## Metrics
|
| 213 |
+
{metrics_df.to_markdown(index=False)}
|
| 214 |
|
| 215 |
+
## Top Trials
|
| 216 |
+
{trials_df.to_markdown(index=False)}
|
|
|
|
|
|
|
|
|
|
| 217 |
|
| 218 |
+
## Common Errors
|
| 219 |
+
{common_errors_df.to_markdown(index=False)}
|
| 220 |
|
| 221 |
+
_Generated on {time.strftime('%Y-%m-%d %H:%M:%S')}_
|
| 222 |
+
"""
|
| 223 |
+
with open("model_report.md", "w") as f:
|
| 224 |
+
f.write(report)
|
| 225 |
+
return "Report saved to model_report.md"
|
| 226 |
|
| 227 |
fig, ax = plt.subplots(figsize=(6, 4))
|
| 228 |
ConfusionMatrixDisplay.from_estimator(best_model, X_test, y_test, ax=ax)
|
| 229 |
plt.savefig("confusion_matrix.png")
|
| 230 |
wandb.log({"confusion_matrix": wandb.Image("confusion_matrix.png")})
|
| 231 |
|
|
|
|
|
|
|
| 232 |
# Inside your layout:
|
| 233 |
compare_button = gr.Button("Compare Models")
|
| 234 |
compare_output = gr.Dataframe()
|
|
|
|
| 243 |
outputs=report_status
|
| 244 |
)
|
| 245 |
|
|
|
|
| 246 |
# Log common misclassifications to wandb
|
| 247 |
wandb.log({"common_errors": wandb.Table(dataframe=common_errors)})
|
| 248 |
|
| 249 |
+
def objective(trial):
|
| 250 |
+
params = {
|
| 251 |
+
"n_estimators": trial.suggest_int("n_estimators", 50, 200),
|
| 252 |
+
"max_depth": trial.suggest_int("max_depth", 3, 10),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
}
|
| 254 |
+
model = RandomForestClassifier(**params)
|
| 255 |
+
score = cross_val_score(model, X_train, y_train, cv=3).mean()
|
| 256 |
+
wandb.log(params | {"cv_score": score})
|
| 257 |
+
return score
|
| 258 |
+
|
| 259 |
+
study = optuna.create_study(direction="maximize")
|
| 260 |
+
study.optimize(objective, n_trials=15)
|
| 261 |
+
|
| 262 |
+
best_params = study.best_params
|
| 263 |
+
model = RandomForestClassifier(**best_params)
|
| 264 |
+
model.fit(X_train, y_train)
|
| 265 |
+
y_pred = model.predict(X_test)
|
| 266 |
+
|
| 267 |
+
metrics = {
|
| 268 |
+
"accuracy": accuracy_score(y_test, y_pred),
|
| 269 |
+
"precision": precision_score(y_test, y_pred, average="weighted", zero_division=0),
|
| 270 |
+
"recall": recall_score(y_test, y_pred, average="weighted", zero_division=0),
|
| 271 |
+
"f1_score": f1_score(y_test, y_pred, average="weighted", zero_division=0),
|
| 272 |
+
}
|
| 273 |
+
wandb.log(metrics)
|
| 274 |
+
wandb_run.finish()
|
| 275 |
+
|
| 276 |
+
top_trials = pd.DataFrame(study.trials_dataframe().sort_values(by="value", ascending=False).head(7))
|
| 277 |
+
return metrics, top_trials
|
| 278 |
|
| 279 |
def explainability(_):
|
| 280 |
import warnings
|
|
|
|
| 351 |
|
| 352 |
return shap_path, lime_path
|
| 353 |
|
|
|
|
|
|
|
|
|
|
| 354 |
with gr.Blocks() as demo:
|
| 355 |
gr.Markdown("## 📊 AI-Powered Data Analysis with Hyperparameter Optimization")
|
| 356 |
|
|
|
|
| 361 |
file_input.change(fn=upload_file, inputs=file_input, outputs=df_output)
|
| 362 |
|
| 363 |
with gr.Column():
|
|
|
|
| 364 |
insights_output = gr.HTML(label="Insights from SmolAgent")
|
| 365 |
visual_output = gr.Gallery(label="Visualizations (Auto-generated by Agent)", columns=2)
|
| 366 |
agent_btn = gr.Button("Run AI Agent (5 Insights + 5 Visualizations)")
|
|
|
|
| 375 |
shap_img = gr.Image(label="SHAP Summary Plot")
|
| 376 |
lime_img = gr.Image(label="LIME Explanation")
|
| 377 |
|
|
|
|
| 378 |
agent_btn.click(fn=analyze_data, inputs=[file_input], outputs=[insights_output, visual_output])
|
| 379 |
train_btn.click(fn=train_model, inputs=[], outputs=[metrics_output, trials_output])
|
| 380 |
explain_btn.click(fn=explainability, inputs=[], outputs=[shap_img, lime_img])
|
| 381 |
|
| 382 |
+
demo.launch(debug=True)
|
|
|
|
|
|