"""Streamlit front end for MVPpred, the lizard performance predictor.

On start-up the app makes sure the per-target model bundles from
``ARTIFACT_REPO`` are present in ``BUNDLE_DIR`` (downloading missing ones).
It then lets the user enter one specimen's phenotypic measurements and shows,
for every selected target, the output of ``infer.predict_with_confidence``.
"""

import os
import shutil
import sys
from pathlib import Path

import joblib
import numpy as np
import pandas as pd
import streamlit as st
from huggingface_hub import hf_hub_download

# st.set_page_config must be the first Streamlit command of the script. The
# spinner that @st.cache_resource shows while ensure_artifacts() runs counts
# as a Streamlit command, so the page config has to be set before it.
st.set_page_config(page_title="Lizard Performance Predictor", layout="wide")

# -------------------------
# Hardcoded config (no widgets)
# -------------------------
ARTIFACT_REPO = "wasicse/mvppred-artifacts"
ARTIFACT_FILES = [
    "angle_bundle.joblib",
    "bite_bundle.joblib",
    "distance_capacity_bundle.joblib",
    "endurance_bundle.joblib",
    "jump_accel_bundle.joblib",
    "jump_distance_bundle.joblib",
    "jump_power_bundle.joblib",
    "jump_vel_bundle.joblib",
    "sprint_bundle.joblib",
]
BUNDLE_DIR = "artifacts_inference"  # local dir (relative to CWD) holding *_bundle.joblib files
INTERVAL = "q90"  # interval passed to predict_with_confidence ("q90" or "q95")


@st.cache_resource
def ensure_artifacts():
    """Download every bundle in ARTIFACT_FILES that is missing from BUNDLE_DIR.

    Bundles already on disk are left untouched. Each download is first copied
    to a sibling ``.part`` file and then renamed into place. An interrupted
    copy therefore never leaves a truncated bundle behind that the existence
    check would later treat as complete.
    """
    outdir = Path(BUNDLE_DIR)
    outdir.mkdir(exist_ok=True)
    for name in ARTIFACT_FILES:
        target = outdir / name
        if target.exists():
            continue
        downloaded = hf_hub_download(
            repo_id=ARTIFACT_REPO,
            filename=name,
            repo_type="model",
        )
        partial = target.with_name(target.name + ".part")
        # copyfile streams the data (no full read into memory) and follows the
        # symlink that the HF cache may return.
        shutil.copyfile(downloaded, partial)
        os.replace(partial, target)  # atomic rename on the same filesystem


ensure_artifacts()

# Make sure project root is on sys.path (so `infer` / src/... imports work)
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from infer import predict_with_confidence  # noqa: E402  (needs the sys.path tweak above)

st.title("MVPpred: Lizard Performance Predictor")
st.caption("Enter phenotypic features manually (use -1 for missing) and view predictions + confidence.")


# -------------------------
# Cache model bundles (huge speedup)
# -------------------------
@st.cache_resource
def load_bundle(path: str):
    """Load one joblib model bundle, cached per file path for the app's lifetime.

    NOTE: joblib.load unpickles the file, which can execute arbitrary code;
    only point this at bundles from a trusted source such as ARTIFACT_REPO.
    """
    return joblib.load(path)


# -------------------------
# Load available targets
# -------------------------
if not os.path.isdir(BUNDLE_DIR):
    st.error(f"Bundle directory not found: {BUNDLE_DIR}")
    st.stop()

bundle_files = sorted(f for f in os.listdir(BUNDLE_DIR) if f.endswith("_bundle.joblib"))
if not bundle_files:
    st.error("No *_bundle.joblib files found in bundle directory.")
    st.stop()

targets = [f.replace("_bundle.joblib", "") for f in bundle_files]
selected_targets = st.multiselect("Targets to predict", targets, default=targets)

st.divider()

# -------------------------
# Default example sample (prefills the form)
# -------------------------
default_sample = {
    "taxon": 69,
    "genus": 22,
    "species": 68,
    "sex_num": 0,  # 0/1 coding: 0 -> "m", 1 -> "f"
    "mass": 3.04,
    "svl": 52.32,
    "hl": 12.905,
    "hw": -1.0,
    "hh": -1.0,
    "femur": 10.675,
    "tibia": 8.8325,
    "metat": 4.23,
    "hindtoe": 11.37,
    "humerus": 5.365,
    "radius": 6.31,
    "metac": 2.3175,
    "foretoe": 5.8125,
    "tail": 37.265,
}

# Morphological measurements, grouped by form column. Flattened in order, this
# is also the column order of the row sent to the model (after "sex").
MEASUREMENT_COLUMNS = (
    ("mass", "svl", "hl", "hw", "hh"),
    ("femur", "tibia", "metat", "hindtoe"),
    ("humerus", "radius", "metac", "foretoe", "tail"),
)
SEX_OPTIONS = ["m", "f"]

st.subheader("Enter one sample")

with st.form("manual_input_form"):
    # Taxonomy fields are shown for reference only; they are NOT included in
    # the model input row built below (the model may not use them).
    c0, c1, c2, c3 = st.columns(4)
    with c0:
        taxon = st.number_input("taxon", value=int(default_sample["taxon"]))
    with c1:
        genus = st.number_input("genus", value=int(default_sample["genus"]))
    with c2:
        species = st.number_input("species", value=int(default_sample["species"]))
    with c3:
        # The pipeline expects "m"/"f"; prefill from the 0/1 sex_num coding.
        default_sex = "m" if int(default_sample["sex_num"]) == 0 else "f"
        sex = st.selectbox("sex (m/f)", SEX_OPTIONS, index=SEX_OPTIONS.index(default_sex))

    measurements = {}
    for column, names in zip(st.columns(len(MEASUREMENT_COLUMNS)), MEASUREMENT_COLUMNS):
        with column:
            for name in names:
                # Four decimals so defaults such as 8.8325 are displayed in
                # full instead of being rounded by the default float format.
                measurements[name] = st.number_input(
                    name, value=float(default_sample[name]), format="%.4f"
                )

    run_btn = st.form_submit_button("Run predictions")

# -------------------------
# Run predictions (with progress)
# -------------------------
if run_btn:
    if not selected_targets:
        st.warning("Please select at least one target.")
        st.stop()

    # 1-row dataframe for the model: ONLY the columns used in training.
    df = pd.DataFrame([{"sex": sex, **measurements}])

    progress = st.progress(0)
    status = st.empty()

    all_outputs = []
    n = len(selected_targets)

    for i, t in enumerate(selected_targets, start=1):
        status.write(f"Running {t} ({i}/{n}) …")
        path = os.path.join(BUNDLE_DIR, f"{t}_bundle.joblib")
        bundle = load_bundle(path)  # cached
        out = predict_with_confidence(bundle, df, interval=INTERVAL)
        out.insert(0, "target", t)
        all_outputs.append(out.reset_index(drop=True))
        progress.progress(i / n)

    status.write("Prediction Complete.")

    result = pd.concat(all_outputs, axis=0, ignore_index=True)

    st.subheader("Predictions (with confidence)")
    st.dataframe(result, use_container_width=True)

    # st.subheader("Confidence summary")
    # st.write(result["confidence_label"].value_counts(dropna=False))

    # # Make per-target view optional (faster UI)
    # show_cards = st.checkbox("Show per-target view", value=False)
    # if show_cards:
    #     st.subheader("Per-target view")
    #     for _, row in result.iterrows():
    #         with st.expander(f"{row['target']} — {row['confidence_label']} (score={row['confidence_score']:.2f})"):
    #             st.write(
    #                 {
    #                     "prediction": float(row["prediction"]),
    #                     "lower": float(row["lower"]) if np.isfinite(row["lower"]) else None,
    #                     "upper": float(row["upper"]) if np.isfinite(row["upper"]) else None,
    #                     "confidence_score": float(row["confidence_score"]),
    #                     "confidence_label": row["confidence_label"],
    #                     "note": row.get("note", ""),
    #                 }
    #             )

    csv_out = result.to_csv(index=False).encode("utf-8")
    st.download_button(
        "Download results CSV",
        csv_out,
        file_name="predictions_with_confidence.csv",
        mime="text/csv",
    )