Update src/streamlit_app.py
Browse files- src/streamlit_app.py +26 -5
src/streamlit_app.py
CHANGED
|
@@ -555,18 +555,39 @@ with tabs[4]:
|
|
| 555 |
max_depth=trial.suggest_int("max_depth", 4, 30),
|
| 556 |
random_state=42, n_jobs=-1)
|
| 557 |
else:
|
|
|
|
| 558 |
m = RandomForestRegressor(random_state=42)
|
|
|
|
| 559 |
try:
|
| 560 |
return np.mean(cross_val_score(m, X_local, y_local, cv=3, scoring="r2"))
|
| 561 |
except Exception:
|
| 562 |
return -999.0
|
| 563 |
-
|
|
|
|
| 564 |
study = optuna.create_study(direction="maximize")
|
| 565 |
-
|
| 566 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 567 |
if fam == "RandomForest":
|
| 568 |
-
model = RandomForestRegressor(**
|
| 569 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 570 |
|
| 571 |
# --- Run button ---
|
| 572 |
if st.button("Run AutoML + SHAP"):
|
|
|
|
| 555 |
max_depth=trial.suggest_int("max_depth", 4, 30),
|
| 556 |
random_state=42, n_jobs=-1)
|
| 557 |
else:
|
| 558 |
+
# fallback
|
| 559 |
m = RandomForestRegressor(random_state=42)
|
| 560 |
+
|
| 561 |
try:
|
| 562 |
return np.mean(cross_val_score(m, X_local, y_local, cv=3, scoring="r2"))
|
| 563 |
except Exception:
|
| 564 |
return -999.0
|
| 565 |
+
|
| 566 |
+
# --- Run Optuna optimization ---
|
| 567 |
study = optuna.create_study(direction="maximize")
|
| 568 |
+
try:
|
| 569 |
+
study.optimize(obj, n_trials=n_trials, show_progress_bar=False)
|
| 570 |
+
params = study.best_trial.params if study.trials else {}
|
| 571 |
+
best_score = study.best_value if study.trials else -999.0
|
| 572 |
+
except Exception as e:
|
| 573 |
+
st.warning(f"Optuna failed for {fam}: {e}")
|
| 574 |
+
params, best_score = {}, -999.0
|
| 575 |
+
|
| 576 |
+
# --- Always safely initialize a model, even if trials failed ---
|
| 577 |
if fam == "RandomForest":
|
| 578 |
+
model = RandomForestRegressor(**params, random_state=42, n_jobs=-1)
|
| 579 |
+
elif fam == "ExtraTrees":
|
| 580 |
+
model = ExtraTreesRegressor(**params, random_state=42, n_jobs=-1)
|
| 581 |
+
else:
|
| 582 |
+
model = RandomForestRegressor(random_state=42, n_jobs=-1)
|
| 583 |
+
|
| 584 |
+
return {
|
| 585 |
+
"family": fam,
|
| 586 |
+
"model_obj": model,
|
| 587 |
+
"best_params": params,
|
| 588 |
+
"cv_score": best_score
|
| 589 |
+
}
|
| 590 |
+
|
| 591 |
|
| 592 |
# --- Run button ---
|
| 593 |
if st.button("Run AutoML + SHAP"):
|