Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -21,7 +21,6 @@ from sklearn.base import BaseEstimator
|
|
| 21 |
if not hasattr(BaseEstimator, "sklearn_tags"):
|
| 22 |
# scikit-learn < 1.6 only has get_tags(); provide sklearn_tags() alias
|
| 23 |
def _sklearn_tags(self):
|
| 24 |
-
# mimic 1.6 behavior by delegating to get_tags()
|
| 25 |
return self.get_tags()
|
| 26 |
BaseEstimator.sklearn_tags = _sklearn_tags
|
| 27 |
|
|
@@ -252,9 +251,7 @@ def update_suggestion_panel(target, species):
|
|
| 252 |
# Load and normalize real data (for allowed pairs + KNN imputer)
|
| 253 |
# -----------------------------
|
| 254 |
if not RAW_PATH.exists():
|
| 255 |
-
raise FileNotFoundError(
|
| 256 |
-
"Missing 'ai_al.csv'. Please upload it to the Space root (same folder as app.py)."
|
| 257 |
-
)
|
| 258 |
|
| 259 |
df_raw = pd.read_csv(RAW_PATH)
|
| 260 |
df_raw.columns = (
|
|
@@ -634,7 +631,6 @@ def _encode_df_for_bundle(bundle: EnsembleBundle, df_like: pd.DataFrame) -> pd.D
|
|
| 634 |
def _norm(x):
|
| 635 |
return "nan" if pd.isna(x) else str(x).strip().lower()
|
| 636 |
|
| 637 |
-
# Ensure all columns in the right order
|
| 638 |
X = pd.DataFrame({c: df_like[c] if c in df_like.columns else np.nan for c in bundle.feature_order})
|
| 639 |
|
| 640 |
# encode categoricals
|
|
@@ -660,19 +656,14 @@ def _encode_df_for_bundle(bundle: EnsembleBundle, df_like: pd.DataFrame) -> pd.D
|
|
| 660 |
if c in X.columns:
|
| 661 |
X[c] = X[c].apply(_extract_first_float)
|
| 662 |
|
| 663 |
-
# impute to exact training numeric space
|
| 664 |
X_imp = pd.DataFrame(bundle.imputer.transform(X[bundle.feature_order]), columns=bundle.feature_order)
|
| 665 |
return X_imp
|
| 666 |
|
| 667 |
def predict_stack_batch(target: str, df_raw_rows: pd.DataFrame) -> tuple[np.ndarray, dict]:
|
| 668 |
-
"""
|
| 669 |
-
Vectorized stacked prediction for multiple rows.
|
| 670 |
-
Returns (stack_preds, base_pred_dict) where base_pred_dict has arrays for each base.
|
| 671 |
-
"""
|
| 672 |
b = _load_ensemble(target)
|
| 673 |
X_imp = _encode_df_for_bundle(b, df_raw_rows)
|
| 674 |
|
| 675 |
-
# Base preds
|
| 676 |
pred_xgb = b.xgb.predict(X_imp)
|
| 677 |
if b.lgb_booster is not None:
|
| 678 |
pred_lgb = b.lgb_booster.predict(X_imp)
|
|
@@ -683,7 +674,6 @@ def predict_stack_batch(target: str, df_raw_rows: pd.DataFrame) -> tuple[np.ndar
|
|
| 683 |
X_mlp = b.scaler.transform(X_imp) if b.scaler is not None else X_imp
|
| 684 |
pred_mlp = b.mlp.predict(X_mlp, verbose=0).reshape(-1)
|
| 685 |
|
| 686 |
-
# STACK
|
| 687 |
meta_in = np.vstack([pred_xgb, pred_lgb, pred_cat, pred_mlp]).T
|
| 688 |
pred_stack = b.meta.predict(meta_in)
|
| 689 |
|
|
@@ -693,22 +683,59 @@ def predict_stack_batch(target: str, df_raw_rows: pd.DataFrame) -> tuple[np.ndar
|
|
| 693 |
def predict_with_ensemble_one(target: str, raw_row: dict) -> dict:
|
| 694 |
df = pd.DataFrame([raw_row])
|
| 695 |
stack, bases = predict_stack_batch(target, df)
|
| 696 |
-
return {
|
| 697 |
-
|
| 698 |
-
|
| 699 |
-
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 703 |
|
| 704 |
# -----------------------------
|
| 705 |
# Predict + Uncertainty + Plot (with bounds clamping)
|
| 706 |
# -----------------------------
|
| 707 |
def predict_and_plot_ui(
|
| 708 |
-
target, species, media, light, expo_day, expo_night, temp_c, ph, days, plot_var
|
| 709 |
):
|
| 710 |
try:
|
| 711 |
-
# 0) raw row for ensemble
|
| 712 |
raw_row = {
|
| 713 |
"species": species, "media": media, "light": light,
|
| 714 |
"expo_day": expo_day, "expo_night": expo_night,
|
|
@@ -718,9 +745,16 @@ def predict_and_plot_ui(
|
|
| 718 |
# 1) KNN point for uncertainty
|
| 719 |
X_one = preprocess_row(species, media, light, expo_day, expo_night, temp_c, ph, days)
|
| 720 |
|
| 721 |
-
# 2)
|
| 722 |
-
|
| 723 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 724 |
|
| 725 |
# 3) local uncertainty
|
| 726 |
qlo, qhi = _local_interval(target, X_one.values)
|
|
@@ -734,7 +768,7 @@ def predict_and_plot_ui(
|
|
| 734 |
lo_pt, _ = _clamp_scalar(lo_raw, b_lo, b_hi)
|
| 735 |
hi_pt, _ = _clamp_scalar(hi_raw, b_lo, b_hi)
|
| 736 |
|
| 737 |
-
# 5) response curve vs selected variable (
|
| 738 |
plot_var = (plot_var or "light").strip().lower()
|
| 739 |
if plot_var not in FEATURES: plot_var = "light"
|
| 740 |
j = FEATURES.index(plot_var)
|
|
@@ -743,18 +777,15 @@ def predict_and_plot_ui(
|
|
| 743 |
p05, p95 = _PERC[target][plot_var]
|
| 744 |
xs = np.linspace(p05, p95, 60)
|
| 745 |
|
| 746 |
-
# Build raw grid rows by sweeping only plot_var
|
| 747 |
grid_rows = []
|
| 748 |
-
x0_vals = X_one.values[0] # imputed numeric point (for reference only)
|
| 749 |
for xv in xs:
|
| 750 |
row = dict(raw_row)
|
| 751 |
-
# replace swept variable with numeric value
|
| 752 |
if plot_var in ["light","expo_day","expo_night","_c","ph","days"]:
|
| 753 |
row[plot_var] = float(xv)
|
| 754 |
grid_rows.append(row)
|
| 755 |
raw_grid_df = pd.DataFrame(grid_rows)
|
| 756 |
|
| 757 |
-
y_grid_raw
|
| 758 |
|
| 759 |
# KNN uncertainty band along the grid (independent of model)
|
| 760 |
X_grid = np.repeat(X_one.values, len(xs), axis=0)
|
|
@@ -766,11 +797,11 @@ def predict_and_plot_ui(
|
|
| 766 |
qlo_g, _ = _clamp_array(qlo_g_raw, b_lo, b_hi)
|
| 767 |
qhi_g, _ = _clamp_array(qhi_g_raw, b_lo, b_hi)
|
| 768 |
|
| 769 |
-
# 6) plot
|
| 770 |
fig, ax = plt.subplots(figsize=(7.0, 4.2))
|
| 771 |
if b_lo is not None and b_hi is not None:
|
| 772 |
ax.axhspan(b_lo, b_hi, alpha=0.10, label="Allowed range")
|
| 773 |
-
ax.plot(xs, y_grid, label="
|
| 774 |
ax.fill_between(xs, qlo_g, qhi_g, alpha=0.25, label=f"Local {int((Q_HI-Q_LO)*100)}% band")
|
| 775 |
|
| 776 |
x0 = float(X_one.values[0, j])
|
|
@@ -788,17 +819,20 @@ def predict_and_plot_ui(
|
|
| 788 |
|
| 789 |
clamp_note = " _(clamped to literature range)_" if clamped_point else ""
|
| 790 |
md = (
|
| 791 |
-
f"### Prediction (
|
| 792 |
f"**{target}** = **{yhat:.3f}**{clamp_note} \n"
|
| 793 |
-
f"Local {int((Q_HI-Q_LO)*100)}% interval: **[{lo_pt:.3f}, {hi_pt:.3f}]**
|
| 794 |
-
f"<details><summary>Base models</summary>\n"
|
| 795 |
-
f"XGB: {preds_point['XGB']:.4f} | "
|
| 796 |
-
f"LGBM: {preds_point['LGBM']:.4f} | "
|
| 797 |
-
f"CAT: {preds_point['CAT']:.4f} | "
|
| 798 |
-
f"MLP: {preds_point['MLP']:.4f}\n"
|
| 799 |
-
f"</details>"
|
| 800 |
+ ("" if not clamped_curve else "\n\n*Response curve clipped to species×medium range.*")
|
| 801 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 802 |
return md, fig
|
| 803 |
|
| 804 |
except Exception as e:
|
|
@@ -827,6 +861,15 @@ def update_media(species):
|
|
| 827 |
value = choices[0] if choices else None
|
| 828 |
return gr.update(choices=choices, value=value)
|
| 829 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 830 |
allowed_species = allowed_species_choices()
|
| 831 |
first_species = allowed_species[0] if allowed_species else None
|
| 832 |
first_media_choices = allowed_media_for(first_species) if first_species else []
|
|
@@ -835,8 +878,9 @@ first_media = first_media_choices[0] if first_media_choices else None
|
|
| 835 |
with gr.Blocks(title="Algae Yield Predictor", theme=theme, css=CSS) as demo:
|
| 836 |
gr.Markdown(
|
| 837 |
f"<h1>Algae Yield Predictor</h1>"
|
| 838 |
-
f"<div class='small'>Predict <b>biomass / lipid / protein / carbohydrate</b> with
|
| 839 |
-
f"<b>XGB
|
|
|
|
| 840 |
f"{'' if DOI_READY else ' <em>(DOI file missing or lacks a doi column.)</em>'}"
|
| 841 |
f"</div>",
|
| 842 |
elem_classes=["card"]
|
|
@@ -847,6 +891,7 @@ with gr.Blocks(title="Algae Yield Predictor", theme=theme, css=CSS) as demo:
|
|
| 847 |
with gr.Group(elem_classes=["card"]):
|
| 848 |
gr.Markdown("### Inputs")
|
| 849 |
target_dd = gr.Dropdown(choices=TARGETS, value="biomass", label="Target", info="Choose outcome to predict")
|
|
|
|
| 850 |
with gr.Row():
|
| 851 |
species_dd = gr.Dropdown(choices=allowed_species, value=first_species, label="Species", info="Only curated species")
|
| 852 |
media_dd = gr.Dropdown(choices=first_media_choices, value=first_media, label="Medium", info="Restricted by species")
|
|
@@ -875,11 +920,24 @@ with gr.Blocks(title="Algae Yield Predictor", theme=theme, css=CSS) as demo:
|
|
| 875 |
gr.Markdown("### Suggested Conditions")
|
| 876 |
suggest_md = gr.Markdown(value=_format_suggestion_md(first_species or "", "biomass"))
|
| 877 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 878 |
with gr.Column(scale=6):
|
| 879 |
with gr.Group(elem_classes=["card"]):
|
| 880 |
pred_md = gr.Markdown("Click **Predict + Plot** to run.")
|
| 881 |
with gr.Group(elem_classes=["card"]):
|
| 882 |
-
gr.Markdown("###
|
| 883 |
plot_out = gr.Plot()
|
| 884 |
with gr.Group(elem_classes=["card"]):
|
| 885 |
gr.Markdown("### Literature (DOI) Matches")
|
|
@@ -889,15 +947,16 @@ with gr.Blocks(title="Algae Yield Predictor", theme=theme, css=CSS) as demo:
|
|
| 889 |
species_dd.change(fn=update_media, inputs=species_dd, outputs=media_dd)
|
| 890 |
target_dd.change(update_suggestion_panel, inputs=[target_dd, species_dd], outputs=suggest_md)
|
| 891 |
species_dd.change(update_suggestion_panel, inputs=[target_dd, species_dd], outputs=suggest_md)
|
|
|
|
| 892 |
|
| 893 |
go.click(
|
| 894 |
fn=predict_and_plot_ui,
|
| 895 |
-
inputs=[target_dd, species_dd, media_dd, light_sl, day_sl, night_sl, temp_num, ph_num, days_sl, plot_var_dd],
|
| 896 |
outputs=[pred_md, plot_out]
|
| 897 |
)
|
| 898 |
doi_btn.click(
|
| 899 |
fn=doi_matches_ui,
|
| 900 |
-
inputs=[target_dd, species_dd, media_dd, light_sl, day_sl, night_sl, temp_num, ph_num, days_sl],
|
| 901 |
outputs=doi_md
|
| 902 |
)
|
| 903 |
|
|
|
|
| 21 |
if not hasattr(BaseEstimator, "sklearn_tags"):
    # Compatibility shim: expose a ``sklearn_tags()`` method on every estimator.
    # BaseEstimator has no public ``get_tags()`` method: scikit-learn < 1.6
    # offers the private ``_get_tags()`` and 1.6+ offers ``__sklearn_tags__()``.
    # Try each accessor in turn instead of assuming ``get_tags`` exists.
    def _sklearn_tags(self):
        """Return this estimator's tags using whichever accessor is available.

        ``get_tags`` is tried first so estimators that define it behave exactly
        as before; otherwise fall back to ``__sklearn_tags__`` and ``_get_tags``.

        Raises:
            AttributeError: if the estimator exposes none of the accessors.
        """
        for accessor_name in ("get_tags", "__sklearn_tags__", "_get_tags"):
            accessor = getattr(self, accessor_name, None)
            if callable(accessor):
                return accessor()
        raise AttributeError(
            f"{type(self).__name__} exposes no estimator-tags accessor"
        )
    BaseEstimator.sklearn_tags = _sklearn_tags
|
| 26 |
|
|
|
|
| 251 |
# Load and normalize real data (for allowed pairs + KNN imputer)
|
| 252 |
# -----------------------------
|
| 253 |
if not RAW_PATH.exists():
|
| 254 |
+
raise FileNotFoundError("Missing 'ai_al.csv'. Please upload it to the Space root (same folder as app.py).")
|
|
|
|
|
|
|
| 255 |
|
| 256 |
df_raw = pd.read_csv(RAW_PATH)
|
| 257 |
df_raw.columns = (
|
|
|
|
| 631 |
def _norm(x):
|
| 632 |
return "nan" if pd.isna(x) else str(x).strip().lower()
|
| 633 |
|
|
|
|
| 634 |
X = pd.DataFrame({c: df_like[c] if c in df_like.columns else np.nan for c in bundle.feature_order})
|
| 635 |
|
| 636 |
# encode categoricals
|
|
|
|
| 656 |
if c in X.columns:
|
| 657 |
X[c] = X[c].apply(_extract_first_float)
|
| 658 |
|
|
|
|
| 659 |
X_imp = pd.DataFrame(bundle.imputer.transform(X[bundle.feature_order]), columns=bundle.feature_order)
|
| 660 |
return X_imp
|
| 661 |
|
| 662 |
def predict_stack_batch(target: str, df_raw_rows: pd.DataFrame) -> tuple[np.ndarray, dict]:
|
| 663 |
+
"""Vectorized stacked prediction for multiple rows. Returns (stack_preds, base_pred_dict)."""
|
|
|
|
|
|
|
|
|
|
| 664 |
b = _load_ensemble(target)
|
| 665 |
X_imp = _encode_df_for_bundle(b, df_raw_rows)
|
| 666 |
|
|
|
|
| 667 |
pred_xgb = b.xgb.predict(X_imp)
|
| 668 |
if b.lgb_booster is not None:
|
| 669 |
pred_lgb = b.lgb_booster.predict(X_imp)
|
|
|
|
| 674 |
X_mlp = b.scaler.transform(X_imp) if b.scaler is not None else X_imp
|
| 675 |
pred_mlp = b.mlp.predict(X_mlp, verbose=0).reshape(-1)
|
| 676 |
|
|
|
|
| 677 |
meta_in = np.vstack([pred_xgb, pred_lgb, pred_cat, pred_mlp]).T
|
| 678 |
pred_stack = b.meta.predict(meta_in)
|
| 679 |
|
|
|
|
| 683 |
def predict_with_ensemble_one(target: str, raw_row: dict) -> dict:
    """Predict a single raw row with the stacked ensemble.

    Wraps ``raw_row`` in a one-row DataFrame, calls ``predict_stack_batch`` and
    returns the first element of each prediction as a plain float, keyed by
    "STACK", "XGB", "LGBM", "CAT" and "MLP".
    """
    single_row_df = pd.DataFrame([raw_row])
    stacked, per_model = predict_stack_batch(target, single_row_df)

    scalars = {"STACK": float(stacked[0])}
    for model_key in ("XGB", "LGBM", "CAT", "MLP"):
        scalars[model_key] = float(per_model[model_key][0])
    return scalars
|
| 688 |
+
|
| 689 |
+
# ---- New: model chooser support ----
|
| 690 |
+
MODEL_NAMES = ["STACK", "XGB", "LGBM", "CAT", "MLP"]
|
| 691 |
+
|
| 692 |
+
def _available_models_for_target(target: str) -> list[str]:
    """List the model names whose saved artifacts exist under ``MODEL_DIR / target``.

    A model counts as available when at least one of its artifact paths exists.
    The result follows the order of ``MODEL_NAMES``.
    """
    target_dir = MODEL_DIR / target
    # Model name -> artifact file/directory names; any one of them is enough.
    artifact_names = {
        "STACK": ("meta.joblib",),
        "XGB": ("xgb.json",),
        "LGBM": ("lgb.txt", "lgb.joblib"),
        "CAT": ("cat.cbm",),
        "MLP": ("mlp.keras", "mlp_savedmodel"),
    }
    present = {
        model_name
        for model_name, candidates in artifact_names.items()
        if any((target_dir / candidate).exists() for candidate in candidates)
    }
    return [model_name for model_name in MODEL_NAMES if model_name in present]
|
| 702 |
+
|
| 703 |
+
def _predict_with_model_choice(target: str, model_choice: str, df_rows: pd.DataFrame) -> np.ndarray:
    """Predict ``df_rows`` with one named model for ``target``.

    If ``model_choice`` is not among the models available for the target, the
    first available model is used instead.

    Raises:
        FileNotFoundError: when no saved model is available for the target.
        ValueError: when the resolved model name is not one handled here.
    """
    candidates = _available_models_for_target(target)
    if not candidates:
        raise FileNotFoundError(f"No saved models found under models/{target}")
    selected = model_choice if model_choice in candidates else candidates[0]

    # The stacked ensemble has its own batch entry point.
    if selected == "STACK":
        stacked, _ = predict_stack_batch(target, df_rows)
        return stacked

    # Every base model consumes the bundle-encoded feature frame.
    bundle = _load_ensemble(target)
    features = _encode_df_for_bundle(bundle, df_rows)

    if selected == "XGB":
        raw = bundle.xgb.predict(features)
    elif selected == "LGBM":
        # Use the booster when the bundle has one, otherwise the lgb_model attribute.
        lgb = bundle.lgb_booster if bundle.lgb_booster is not None else bundle.lgb_model
        raw = lgb.predict(features)
    elif selected == "CAT":
        raw = bundle.cat.predict(features)
    elif selected == "MLP":
        # The scaler is applied only when the bundle carries one.
        mlp_input = features if bundle.scaler is None else bundle.scaler.transform(features)
        return bundle.mlp.predict(mlp_input, verbose=0).reshape(-1).astype(float)
    else:
        raise ValueError(f"Unknown model choice: {model_choice}")

    return np.asarray(raw, dtype=float)
|
| 730 |
|
| 731 |
# -----------------------------
|
| 732 |
# Predict + Uncertainty + Plot (with bounds clamping)
|
| 733 |
# -----------------------------
|
| 734 |
def predict_and_plot_ui(
|
| 735 |
+
target, model_choice, species, media, light, expo_day, expo_night, temp_c, ph, days, plot_var
|
| 736 |
):
|
| 737 |
try:
|
| 738 |
+
# 0) raw row for ensemble/base models
|
| 739 |
raw_row = {
|
| 740 |
"species": species, "media": media, "light": light,
|
| 741 |
"expo_day": expo_day, "expo_night": expo_night,
|
|
|
|
| 745 |
# 1) KNN point for uncertainty
|
| 746 |
X_one = preprocess_row(species, media, light, expo_day, expo_night, temp_c, ph, days)
|
| 747 |
|
| 748 |
+
# 2) Model point prediction (selected model) + also compute bases (for info)
|
| 749 |
+
df_one = pd.DataFrame([raw_row])
|
| 750 |
+
avail = _available_models_for_target(target)
|
| 751 |
+
chosen = model_choice if model_choice in avail else (avail[0] if avail else "STACK")
|
| 752 |
+
|
| 753 |
+
y_point = _predict_with_model_choice(target, chosen, df_one)
|
| 754 |
+
yhat_raw = float(y_point[0])
|
| 755 |
+
|
| 756 |
+
# (Optional) show base outputs in details
|
| 757 |
+
preds_point = predict_with_ensemble_one(target, raw_row) if "STACK" in avail else {}
|
| 758 |
|
| 759 |
# 3) local uncertainty
|
| 760 |
qlo, qhi = _local_interval(target, X_one.values)
|
|
|
|
| 768 |
lo_pt, _ = _clamp_scalar(lo_raw, b_lo, b_hi)
|
| 769 |
hi_pt, _ = _clamp_scalar(hi_raw, b_lo, b_hi)
|
| 770 |
|
| 771 |
+
# 5) response curve vs selected variable (same chosen model)
|
| 772 |
plot_var = (plot_var or "light").strip().lower()
|
| 773 |
if plot_var not in FEATURES: plot_var = "light"
|
| 774 |
j = FEATURES.index(plot_var)
|
|
|
|
| 777 |
p05, p95 = _PERC[target][plot_var]
|
| 778 |
xs = np.linspace(p05, p95, 60)
|
| 779 |
|
|
|
|
| 780 |
grid_rows = []
|
|
|
|
| 781 |
for xv in xs:
|
| 782 |
row = dict(raw_row)
|
|
|
|
| 783 |
if plot_var in ["light","expo_day","expo_night","_c","ph","days"]:
|
| 784 |
row[plot_var] = float(xv)
|
| 785 |
grid_rows.append(row)
|
| 786 |
raw_grid_df = pd.DataFrame(grid_rows)
|
| 787 |
|
| 788 |
+
y_grid_raw = _predict_with_model_choice(target, chosen, raw_grid_df)
|
| 789 |
|
| 790 |
# KNN uncertainty band along the grid (independent of model)
|
| 791 |
X_grid = np.repeat(X_one.values, len(xs), axis=0)
|
|
|
|
| 797 |
qlo_g, _ = _clamp_array(qlo_g_raw, b_lo, b_hi)
|
| 798 |
qhi_g, _ = _clamp_array(qhi_g_raw, b_lo, b_hi)
|
| 799 |
|
| 800 |
+
# 6) plot
|
| 801 |
fig, ax = plt.subplots(figsize=(7.0, 4.2))
|
| 802 |
if b_lo is not None and b_hi is not None:
|
| 803 |
ax.axhspan(b_lo, b_hi, alpha=0.10, label="Allowed range")
|
| 804 |
+
ax.plot(xs, y_grid, label=f"{chosen} (predicted mean)")
|
| 805 |
ax.fill_between(xs, qlo_g, qhi_g, alpha=0.25, label=f"Local {int((Q_HI-Q_LO)*100)}% band")
|
| 806 |
|
| 807 |
x0 = float(X_one.values[0, j])
|
|
|
|
| 819 |
|
| 820 |
clamp_note = " _(clamped to literature range)_" if clamped_point else ""
|
| 821 |
md = (
|
| 822 |
+
f"### Prediction ({chosen})\n"
|
| 823 |
f"**{target}** = **{yhat:.3f}**{clamp_note} \n"
|
| 824 |
+
f"Local {int((Q_HI-Q_LO)*100)}% interval: **[{lo_pt:.3f}, {hi_pt:.3f}]**"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 825 |
+ ("" if not clamped_curve else "\n\n*Response curve clipped to species×medium range.*")
|
| 826 |
)
|
| 827 |
+
if preds_point:
|
| 828 |
+
md += (
|
| 829 |
+
"\n\n<details><summary>Base models</summary>\n"
|
| 830 |
+
f"XGB: {preds_point['XGB']:.4f} | "
|
| 831 |
+
f"LGBM: {preds_point['LGBM']:.4f} | "
|
| 832 |
+
f"CAT: {preds_point['CAT']:.4f} | "
|
| 833 |
+
f"MLP: {preds_point['MLP']:.4f}\n"
|
| 834 |
+
"</details>"
|
| 835 |
+
)
|
| 836 |
return md, fig
|
| 837 |
|
| 838 |
except Exception as e:
|
|
|
|
| 861 |
value = choices[0] if choices else None
|
| 862 |
return gr.update(choices=choices, value=value)
|
| 863 |
|
| 864 |
+
# ---- New: restrict model choices per target ----
|
| 865 |
+
def update_model_choices(target):
    """Build a Gradio update restricting the model dropdown to ``target``'s models.

    Falls back to a lone "STACK" entry when no models are found. "STACK" is
    preselected whenever it is offered; otherwise the first choice is.
    """
    options = _available_models_for_target(target) or ["STACK"]
    preselected = "STACK" if "STACK" in options else options[0]
    return gr.update(choices=options, value=preselected)
|
| 872 |
+
|
| 873 |
allowed_species = allowed_species_choices()
|
| 874 |
first_species = allowed_species[0] if allowed_species else None
|
| 875 |
first_media_choices = allowed_media_for(first_species) if first_species else []
|
|
|
|
| 878 |
with gr.Blocks(title="Algae Yield Predictor", theme=theme, css=CSS) as demo:
|
| 879 |
gr.Markdown(
|
| 880 |
f"<h1>Algae Yield Predictor</h1>"
|
| 881 |
+
f"<div class='small'>Predict <b>biomass / lipid / protein / carbohydrate</b> with "
|
| 882 |
+
f"a selectable model (<b>STACK / XGB / LGBM / CAT / MLP</b>), local uncertainty bands, "
|
| 883 |
+
f"and species×medium literature-range clamping."
|
| 884 |
f"{'' if DOI_READY else ' <em>(DOI file missing or lacks a doi column.)</em>'}"
|
| 885 |
f"</div>",
|
| 886 |
elem_classes=["card"]
|
|
|
|
| 891 |
with gr.Group(elem_classes=["card"]):
|
| 892 |
gr.Markdown("### Inputs")
|
| 893 |
target_dd = gr.Dropdown(choices=TARGETS, value="biomass", label="Target", info="Choose outcome to predict")
|
| 894 |
+
model_dd = gr.Dropdown(choices=MODEL_NAMES, value="STACK", label="Model", info="Choose which trained model to use")
|
| 895 |
with gr.Row():
|
| 896 |
species_dd = gr.Dropdown(choices=allowed_species, value=first_species, label="Species", info="Only curated species")
|
| 897 |
media_dd = gr.Dropdown(choices=first_media_choices, value=first_media, label="Medium", info="Restricted by species")
|
|
|
|
| 920 |
gr.Markdown("### Suggested Conditions")
|
| 921 |
suggest_md = gr.Markdown(value=_format_suggestion_md(first_species or "", "biomass"))
|
| 922 |
|
| 923 |
+
# ---- New: Model tips card ----
|
| 924 |
+
with gr.Group(elem_classes=["card"]):
|
| 925 |
+
gr.Markdown("### Model Tips")
|
| 926 |
+
model_tips_md = gr.Markdown("""\
|
| 927 |
+
**Recommendations**
|
| 928 |
+
- **STACK (Ensemble)** — best overall accuracy (offline metrics ~R² 0.89 / MAE ~0.66).
|
| 929 |
+
- **XGB / LGBM** — fast, strong single models (R² ~0.69).
|
| 930 |
+
- **CAT** — robust to categorical quirks (R² ~0.62).
|
| 931 |
+
- **MLP** — requires scaler; slower cold start (R² ~0.55 here).
|
| 932 |
+
|
| 933 |
+
**Pick**: Use **STACK** by default. Choose **XGB**/**LGBM** for speed or to sanity-check disagreement across models.
|
| 934 |
+
""")
|
| 935 |
+
|
| 936 |
with gr.Column(scale=6):
|
| 937 |
with gr.Group(elem_classes=["card"]):
|
| 938 |
pred_md = gr.Markdown("Click **Predict + Plot** to run.")
|
| 939 |
with gr.Group(elem_classes=["card"]):
|
| 940 |
+
gr.Markdown("### Response Plot")
|
| 941 |
plot_out = gr.Plot()
|
| 942 |
with gr.Group(elem_classes=["card"]):
|
| 943 |
gr.Markdown("### Literature (DOI) Matches")
|
|
|
|
| 947 |
species_dd.change(fn=update_media, inputs=species_dd, outputs=media_dd)
|
| 948 |
target_dd.change(update_suggestion_panel, inputs=[target_dd, species_dd], outputs=suggest_md)
|
| 949 |
species_dd.change(update_suggestion_panel, inputs=[target_dd, species_dd], outputs=suggest_md)
|
| 950 |
+
target_dd.change(fn=update_model_choices, inputs=target_dd, outputs=model_dd)
|
| 951 |
|
| 952 |
go.click(
|
| 953 |
fn=predict_and_plot_ui,
|
| 954 |
+
inputs=[target_dd, model_dd, species_dd, media_dd, light_sl, day_sl, night_sl, temp_num, ph_num, days_sl, plot_var_dd],
|
| 955 |
outputs=[pred_md, plot_out]
|
| 956 |
)
|
| 957 |
# Wire the DOI button: run doi_matches_ui on the current inputs and render
# its result into doi_md.
doi_btn.click(
    fn=doi_matches_ui,
    # temp_num is passed directly; the previous `temp_c := temp_num` walrus
    # bound a stray `temp_c` name as a side effect of building this list.
    inputs=[target_dd, species_dd, media_dd, light_sl, day_sl, night_sl, temp_num, ph_num, days_sl],
    outputs=doi_md
)
|
| 962 |
|