Update app.py
Browse files
app.py
CHANGED
|
@@ -155,9 +155,12 @@ def analyze_data(csv_file, additional_notes=""):
|
|
| 155 |
return format_analysis_report(analysis_result, visuals)
|
| 156 |
|
| 157 |
def compare_models():
|
|
|
|
|
|
|
|
|
|
| 158 |
if df_global is None:
|
| 159 |
-
return "Please upload and preprocess a dataset first."
|
| 160 |
-
|
| 161 |
target = df_global.columns[-1]
|
| 162 |
X = df_global.drop(target, axis=1)
|
| 163 |
y = df_global[target]
|
|
@@ -168,21 +171,43 @@ def compare_models():
|
|
| 168 |
models = {
|
| 169 |
"RandomForest": RandomForestClassifier(),
|
| 170 |
"LogisticRegression": LogisticRegression(max_iter=1000),
|
| 171 |
-
"
|
| 172 |
}
|
| 173 |
|
| 174 |
results = []
|
| 175 |
for name, model in models.items():
|
|
|
|
| 176 |
scores = cross_val_score(model, X, y, cv=5)
|
| 177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
"Model": name,
|
| 179 |
"CV Mean Accuracy": np.mean(scores),
|
| 180 |
-
"CV Std Dev": np.std(scores)
|
| 181 |
-
|
| 182 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
|
| 184 |
results_df = pd.DataFrame(results)
|
| 185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
|
| 187 |
# 1. prepare_data should come first
|
| 188 |
def prepare_data(df, target_column=None):
|
|
@@ -352,8 +377,15 @@ with gr.Blocks() as demo:
|
|
| 352 |
shap_img = gr.Image(label="SHAP Summary Plot")
|
| 353 |
lime_img = gr.Image(label="LIME Explanation")
|
| 354 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 355 |
agent_btn.click(fn=analyze_data, inputs=[file_input], outputs=[insights_output, visual_output])
|
| 356 |
train_btn.click(fn=train_model, inputs=[file_input], outputs=[metrics_output, trials_output])
|
| 357 |
explain_btn.click(fn=explainability, inputs=[], outputs=[shap_img, lime_img])
|
|
|
|
|
|
|
| 358 |
|
| 359 |
demo.launch(debug=True)
|
|
|
|
| 155 |
return format_analysis_report(analysis_result, visuals)
|
| 156 |
|
| 157 |
def compare_models():
    """Cross-validate several classifiers on the global dataset and compare them.

    Reads the module-level ``df_global`` (produced by the preprocessing step);
    by convention its last column is the target. For each candidate model it
    computes 5-fold CV accuracy (mean and std) plus pooled F1 / precision /
    recall from out-of-fold predictions, logs the numeric metrics to wandb,
    and saves a bar chart of mean accuracies.

    Returns:
        tuple: ``(results_df, plot_path)`` — a pandas DataFrame of per-model
        metrics and the path of the saved comparison plot. If no dataset has
        been preprocessed yet, returns an error DataFrame and ``None``.
    """
    import seaborn as sns
    from sklearn.model_selection import cross_val_predict
    # Imported locally as a safeguard: the top-of-file imports bring in
    # RandomForestClassifier, but this change also added GradientBoosting,
    # which would otherwise raise NameError if not imported at module level.
    from sklearn.ensemble import GradientBoostingClassifier

    if df_global is None:
        return pd.DataFrame({"Error": ["Please upload and preprocess a dataset first."]}), None

    # Convention used throughout this app: last column is the target.
    target = df_global.columns[-1]
    X = df_global.drop(target, axis=1)
    y = df_global[target]

    models = {
        "RandomForest": RandomForestClassifier(),
        "LogisticRegression": LogisticRegression(max_iter=1000),
        "GradientBoosting": GradientBoostingClassifier(),
    }

    results = []
    for name, model in models.items():
        # Per-fold accuracies -> mean/std columns.
        scores = cross_val_score(model, X, y, cv=5)

        # Pooled out-of-fold predictions for the classification metrics.
        # NOTE(review): this runs a second 5-fold CV per model (2x fit cost);
        # kept as-is since fold-mean accuracy and pooled metrics differ.
        y_pred = cross_val_predict(model, X, y, cv=5)

        metrics = {
            "Model": name,
            "CV Mean Accuracy": np.mean(scores),
            "CV Std Dev": np.std(scores),
            "F1 Score": f1_score(y, y_pred, average="weighted", zero_division=0),
            "Precision": precision_score(y, y_pred, average="weighted", zero_division=0),
            "Recall": recall_score(y, y_pred, average="weighted", zero_division=0),
        }
        # Log only numeric values; e.g. "CV Mean Accuracy" for RandomForest
        # becomes the wandb key "RandomForest_cv_mean_accuracy".
        wandb.log({f"{name}_{k.replace(' ', '_').lower()}": v for k, v in metrics.items() if isinstance(v, (float, int))})
        results.append(metrics)

    results_df = pd.DataFrame(results)

    # Bar chart of mean CV accuracy per model.
    plt.figure(figsize=(8, 5))
    sns.barplot(data=results_df, x="Model", y="CV Mean Accuracy", palette="Blues_d")
    plt.title("Model Comparison (CV Mean Accuracy)")
    plt.ylim(0, 1)
    plt.tight_layout()

    plot_path = "./model_comparison.png"
    plt.savefig(plot_path)
    plt.close()  # release the figure so repeated button clicks don't leak memory

    return results_df, plot_path
|
| 210 |
+
|
| 211 |
|
| 212 |
# 1. prepare_data should come first
|
| 213 |
def prepare_data(df, target_column=None):
|
|
|
|
| 377 |
shap_img = gr.Image(label="SHAP Summary Plot")
|
| 378 |
lime_img = gr.Image(label="LIME Explanation")
|
| 379 |
|
| 380 |
+
with gr.Row():
|
| 381 |
+
compare_btn = gr.Button("Compare Models (A/B Testing)")
|
| 382 |
+
compare_output = gr.DataFrame(label="Model Comparison (CV + Metrics)")
|
| 383 |
+
compare_img = gr.Image(label="Model Accuracy Plot")
|
| 384 |
+
|
| 385 |
agent_btn.click(fn=analyze_data, inputs=[file_input], outputs=[insights_output, visual_output])
|
| 386 |
train_btn.click(fn=train_model, inputs=[file_input], outputs=[metrics_output, trials_output])
|
| 387 |
explain_btn.click(fn=explainability, inputs=[], outputs=[shap_img, lime_img])
|
| 388 |
+
compare_btn.click(fn=compare_models, inputs=[], outputs=[compare_output, compare_img])
|
| 389 |
+
|
| 390 |
|
| 391 |
demo.launch(debug=True)
|