# mvppred / scripts / app.py
# Md Wasi Ul Kabir
# Initial commit
# 8bb21fb
import os
import sys
from pathlib import Path
import joblib
import numpy as np
import pandas as pd
import streamlit as st
from huggingface_hub import hf_hub_download
# Hugging Face repo id that the inference bundles are downloaded from.
ARTIFACT_REPO = "wasicse/mvppred-artifacts"

# One bundle per prediction target; every file is named "<target>_bundle.joblib".
_ARTIFACT_TARGETS = (
    "angle",
    "bite",
    "distance_capacity",
    "endurance",
    "jump_accel",
    "jump_distance",
    "jump_power",
    "jump_vel",
    "sprint",
)
ARTIFACT_FILES = [f"{target}_bundle.joblib" for target in _ARTIFACT_TARGETS]
@st.cache_resource
def ensure_artifacts():
    """Make sure every file in ARTIFACT_FILES exists under ``artifacts_inference``.

    The directory is resolved relative to the current working directory and is
    created if needed. Each missing file is fetched from the ARTIFACT_REPO
    model repo with ``hf_hub_download`` and copied into the directory; files
    that are already present are left untouched.

    The copy goes to a temporary ``<name>.part`` file first and is then renamed
    onto the final name. An interrupted copy therefore never leaves a truncated
    bundle behind that would pass the ``exists()`` check on the next run.

    Returns:
        None. The function is called for its side effect only.
    """
    import shutil  # local import: keeps this block independent of the file header

    outdir = Path("artifacts_inference")
    outdir.mkdir(exist_ok=True)

    for name in ARTIFACT_FILES:
        target = outdir / name
        if target.exists():
            continue

        downloaded = hf_hub_download(
            repo_id=ARTIFACT_REPO,
            filename=name,
            repo_type="model",
        )

        # Stream the copy (no whole-file read into memory), then publish it
        # with a single rename so readers only ever see complete files.
        partial = target.with_name(target.name + ".part")
        shutil.copyfile(downloaded, partial)
        os.replace(partial, target)
# Configure the page before anything else touches Streamlit. Streamlit
# releases that require set_page_config() to be the first Streamlit command
# would otherwise reject it, because the cached ensure_artifacts() call below
# can render a spinner element while it runs.
st.set_page_config(page_title="Lizard Performance Predictor", layout="wide")

# Fetch any missing model bundles before the rest of the app looks for them.
ensure_artifacts()

# Make the directory one level above this script's folder importable, so that
# project-level modules can be found.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

# NOTE(review): presumably relies on the sys.path setup above; keep it after it.
from infer import predict_with_confidence

st.title("MVPpred: Lizard Performance Predictor")
st.caption("Enter phenotypic features manually (use -1 for missing) and view predictions + confidence.")
# -------------------------
# Fixed settings (edit here; there are no UI widgets for these)
# -------------------------
# Interval key passed to predict_with_confidence: "q90" or "q95".
INTERVAL = "q90"
# Directory that is scanned for *_bundle.joblib files.
BUNDLE_DIR = "artifacts_inference"
# -------------------------
# Model bundle loader (wrapped in st.cache_resource for speed)
# -------------------------
@st.cache_resource
def load_bundle(path: str):
    """Load and return the joblib-serialized object stored at *path*.

    Args:
        path: Filesystem path of a ``*_bundle.joblib`` file.

    Returns:
        Whatever object ``joblib.load`` reconstructs from the file.

    Note:
        ``joblib.load`` unpickles its input, so only point this at files from
        a trusted source.
    """
    bundle = joblib.load(path)
    return bundle
# -------------------------
# Load available targets
# -------------------------
# File-name suffix shared by all bundle files; the part before it is the
# target name.
BUNDLE_SUFFIX = "_bundle.joblib"

if not os.path.isdir(BUNDLE_DIR):
    st.error(f"Bundle directory not found: {BUNDLE_DIR}")
    st.stop()

bundle_files = sorted(f for f in os.listdir(BUNDLE_DIR) if f.endswith(BUNDLE_SUFFIX))
if not bundle_files:
    st.error("No *_bundle.joblib files found in bundle directory.")
    st.stop()

# Strip only the trailing suffix. str.replace would also remove the same text
# from the middle of a name, and the target could then no longer be mapped
# back to its file as f"{target}_bundle.joblib".
targets = [f[: -len(BUNDLE_SUFFIX)] for f in bundle_files]

# All discovered targets are preselected; the user can narrow the list.
selected_targets = st.multiselect("Targets to predict", targets, default=targets)

st.divider()
# -------------------------
# Example sample used to prefill the input form
# -------------------------
default_sample = dict(
    # Taxonomy codes.
    taxon=69,
    genus=22,
    species=68,
    sex_num=0,  # 0/1 coding; the form maps 0 -> "m" and anything else -> "f"
    # Body and head measurements; -1.0 marks a missing value.
    mass=3.04,
    svl=52.32,
    hl=12.905,
    hw=-1.0,
    hh=-1.0,
    # Hindlimb.
    femur=10.675,
    tibia=8.8325,
    metat=4.23,
    hindtoe=11.37,
    # Forelimb.
    humerus=5.365,
    radius=6.31,
    metac=2.3175,
    foretoe=5.8125,
    # Tail.
    tail=37.265,
)
st.subheader("Enter one sample")


def _int_field(name):
    """Render a number input labelled *name*, prefilled with the integer default."""
    return st.number_input(name, value=int(default_sample[name]))


def _float_field(name):
    """Render a number input labelled *name*, prefilled with the float default."""
    return st.number_input(name, value=float(default_sample[name]))


with st.form("manual_input_form"):
    # Taxonomy fields are collected for display; they are not part of the row
    # that is sent to the model.
    taxon_col, genus_col, species_col, sex_col = st.columns(4)
    with taxon_col:
        taxon = _int_field("taxon")
    with genus_col:
        genus = _int_field("genus")
    with species_col:
        species = _int_field("species")
    with sex_col:
        # The sample stores sex as 0/1; the widget works with "m"/"f".
        sex_options = ["m", "f"]
        default_sex = "f" if int(default_sample["sex_num"]) != 0 else "m"
        sex = st.selectbox("sex (m/f)", sex_options, index=sex_options.index(default_sex))

    body_col, hind_col, fore_col = st.columns(3)
    with body_col:
        mass = _float_field("mass")
        svl = _float_field("svl")
        hl = _float_field("hl")
        hw = _float_field("hw")
        hh = _float_field("hh")
    with hind_col:
        femur = _float_field("femur")
        tibia = _float_field("tibia")
        metat = _float_field("metat")
        hindtoe = _float_field("hindtoe")
    with fore_col:
        humerus = _float_field("humerus")
        radius = _float_field("radius")
        metac = _float_field("metac")
        foretoe = _float_field("foretoe")
        tail = _float_field("tail")

    # The submit button has to be created inside the form it belongs to.
    run_btn = st.form_submit_button("Run predictions")
# -------------------------
# Run predictions (with progress)
# -------------------------
def _predict_all_targets(target_names, features):
    """Run every selected target's bundle on *features* and stack the outputs.

    A progress bar and a status line are updated after each target. Each
    per-target output gets a leading "target" column before the frames are
    concatenated into one result with a fresh index.
    """
    bar = st.progress(0)
    note = st.empty()

    frames = []
    total = len(target_names)
    for done, name in enumerate(target_names, start=1):
        note.write(f"Running {name} ({done}/{total}) …")

        bundle_path = os.path.join(BUNDLE_DIR, f"{name}_bundle.joblib")
        prediction = predict_with_confidence(
            load_bundle(bundle_path),  # cached
            features,
            interval=INTERVAL,
        )
        prediction.insert(0, "target", name)
        frames.append(prediction.reset_index(drop=True))

        bar.progress(done / total)

    note.write("Prediction Complete.")
    return pd.concat(frames, axis=0, ignore_index=True)


if run_btn:
    if not selected_targets:
        st.warning("Please select at least one target.")
        st.stop()

    # One-row frame for the model; the taxonomy fields are left out on purpose.
    input_row = dict(
        sex=sex,  # passed as "m"/"f"
        mass=mass,
        svl=svl,
        hl=hl,
        hw=hw,
        hh=hh,
        femur=femur,
        tibia=tibia,
        metat=metat,
        hindtoe=hindtoe,
        humerus=humerus,
        radius=radius,
        metac=metac,
        foretoe=foretoe,
        tail=tail,
    )
    df = pd.DataFrame([input_row])

    result = _predict_all_targets(selected_targets, df)

    st.subheader("Predictions (with confidence)")
    st.dataframe(result, use_container_width=True)

    # NOTE: a confidence-label summary and an optional per-target expander
    # view were drafted here but are currently disabled.

    st.download_button(
        "Download results CSV",
        result.to_csv(index=False).encode("utf-8"),
        file_name="predictions_with_confidence.csv",
        mime="text/csv",
    )