# gurma-dashboard / eval_tab.py
# Emre Sarigöl
# Deploy GURMA.ai Dashboard - 2026-03-01 20:49
# ef90a4e
"""
GURMA.ai — Model Evaluation Tab
Displays benchmark results, test sample comparisons, and live re-inference.
Dual backend: MLX (local Apple Silicon) / HF Inference API (HF Spaces).
"""
import copy
import json
import os
import re
from collections import Counter
from pathlib import Path
import pandas as pd
import plotly.graph_objects as go
import streamlit as st
# ============================================================
# Environment & Paths
# ============================================================
# Truthy when running inside the Hugging Face Space container: either the
# HF_SPACE variable is set or the deployed research app file is present.
IS_HF_SPACE = os.getenv("HF_SPACE") or Path("/app/research.py").exists()

# Data root: fixed location inside the Space, otherwise the "data" directory
# under the third parent of this file.
DATA_ROOT = (
    Path("/app/data")
    if IS_HF_SPACE
    else Path(__file__).resolve().parent.parent.parent / "data"
)
EXPERIMENTS_DIR = DATA_ROOT / "experiments"
TRAINING_DIR = DATA_ROOT / "training"
ADAPTERS_DIR = DATA_ROOT / "adapters"

# MLX model → HF Hub model for Inference API
MODEL_HF_MAP = {
    "Qwen/Qwen3-8B-MLX-4bit": "Qwen/Qwen3-8B",
    "Qwen/Qwen3-8B-MLX-8bit": "Qwen/Qwen3-8B",
    "Qwen/Qwen3-30B-A3B-MLX-4bit": "Qwen/Qwen3-30B-A3B",
}

# System message sent with every live re-inference request.
SYSTEM_PROMPT = (
    "You are a rehabilitation robotics AI assistant for GURMA.ai. "
    "You help clinicians and engineers with therapy parameters, "
    "patient outcome interpretation, safety protocols, and session reporting. "
    "Be precise, evidence-aware, and flag uncertainty explicitly."
)

# ============================================================
# Colour palette (GURMA brand)
# ============================================================
C_BASE = "#6c757d"  # grey
C_ADAPTED = "#198754"  # green
C_BG = "rgba(0,0,0,0)"
C_GRID = "rgba(200,200,200,0.3)"
# ============================================================
# Helpers
# ============================================================
def _resolve_adapter(bench_data: dict) -> str:
    """Return the adapter path that represents this benchmark run.

    A standard benchmark records its adapter under ``adapter_path``; a routed
    benchmark records ``fallback_adapter`` (the general adapter) next to a
    per-category ``routing`` table. The first non-empty of the two wins, and
    an empty string is returned when neither is set.
    """
    return (
        bench_data.get("adapter_path")
        or bench_data.get("fallback_adapter")
        or ""
    )
def _is_routed(bench_data: dict) -> bool:
    """Tell whether the benchmark carries a non-empty ``routing`` table."""
    routing = bench_data.get("routing")
    return True if routing else False
# ============================================================
# Data Loading
# ============================================================
@st.cache_data(ttl=120)
def _load_bench(path: str) -> dict | None:
    """Load one benchmark JSON file, cached by Streamlit for 120 seconds.

    Args:
        path: Filesystem path of the benchmark file.

    Returns:
        The parsed JSON document, or ``None`` when the file cannot be read
        or does not contain valid JSON.
    """
    try:
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(path, encoding="utf-8") as f:
            return json.load(f)
    except (OSError, ValueError):
        # OSError: missing/unreadable file. ValueError: malformed JSON
        # (json.JSONDecodeError) or undecodable bytes (UnicodeDecodeError).
        return None
def _list_bench_files() -> list[Path]:
    """List ``bench-*.json`` files in EXPERIMENTS_DIR, newest first.

    Ordering uses the ``timestamp`` field stored inside each file rather than
    the file mtime, which is unreliable on HF Spaces where all files get the
    same mtime at deploy time. Files that cannot be read, are not JSON
    objects, or carry no timestamp sort last.

    Returns:
        Paths sorted by embedded timestamp (descending); an empty list when
        the experiments directory does not exist.
    """
    if not EXPERIMENTS_DIR.exists():
        return []

    def _timestamp_of(path: Path) -> str:
        """Return the embedded timestamp as text, or "" when unavailable."""
        try:
            with open(path, encoding="utf-8") as fh:
                payload = json.load(fh)
        except (OSError, ValueError):
            return ""
        if not isinstance(payload, dict):
            return ""
        # Coerce to str so a null or numeric timestamp cannot make the sort
        # compare incompatible types.
        return str(payload.get("timestamp") or "")

    # Pre-sort by name so files with equal timestamps come out in a
    # deterministic order (the timestamp sort below is stable).
    candidates = sorted(EXPERIMENTS_DIR.glob("bench-*.json"))
    return sorted(candidates, key=_timestamp_of, reverse=True)
def _format_bench_label(stem: str, data: dict) -> str:
    """Build the human-readable selector label for one benchmark.

    The label has the shape
    ``<timestamp> — [<model> —] <adapter> (<n> samples)``.

    Args:
        stem: File stem of the benchmark. Not used by the label itself; kept
            so existing callers keep working.
        data: Parsed benchmark document.

    Returns:
        The formatted label.
    """
    # A null or non-string timestamp/model must not break label rendering.
    ts = str(data.get("timestamp") or "")[:16].replace("T", " ")
    n = data.get("test_examples", "?")
    model_id = str(data.get("model") or "")
    model_short = model_id.split("/")[-1] if model_id else ""

    if _is_routed(data):
        n_routes = len(data.get("routing") or {})
        adapter = f"routed ({n_routes} specialized)"
    else:
        ap = _resolve_adapter(data)
        adapter = Path(ap).name if ap else "base only"

    parts = [ts, "—"]
    if model_short:
        parts += [model_short, "—"]
    parts += [adapter, f"({n} samples)"]
    return " ".join(parts)
# ============================================================
# Aggregate Recomputation for Routed Benchmarks
# ============================================================
def _recompute_specialized_aggregate(bench_data: dict) -> dict | None:
"""For routed benchmarks with a fallback adapter, recompute the adapted
aggregate using only the specialized categories so the headline metrics
reflect the dedicated adapters rather than being diluted by the general
fallback. Returns a patched copy of the aggregate dict, or None if
no recomputation is needed."""
routing = bench_data.get("routing", {})
if not routing or not bench_data.get("fallback_adapter"):
return None
examples = bench_data.get("per_example", [])
specialized_cats = set(routing.keys())
specialized = [ex for ex in examples if ex.get("category") in specialized_cats]
if not specialized:
return None
orig_agg = bench_data.get("aggregate", {})
if "adapted" not in orig_agg:
return None
def _mean(lst):
return round(sum(lst) / len(lst), 4) if lst else None
scores = {"rouge1_f1": [], "rouge2_f1": [], "rougeL_f1": [],
"bleu": [], "response_len": [],
"clinical_term_recall": [], "numeric_recall": [],
"structured_pct": [], "safety_pct": []}
pred_scores = {"fac_exact_match": [], "fac_error": [],
"fac_direction_match": [],
"speed_abs_error": [], "speed_direction_match": [],
"risk_count_match": []}
for r in specialized:
m = r.get("metrics_adapted")
if not m:
continue
scores["rouge1_f1"].append(m["rouge1"]["f1"])
scores["rouge2_f1"].append(m["rouge2"]["f1"])
scores["rougeL_f1"].append(m["rougeL"]["f1"])
scores["bleu"].append(m["bleu"]["score"])
scores["response_len"].append(m["response_len"])
d = m.get("domain", {})
if d.get("clinical_term_recall") is not None:
scores["clinical_term_recall"].append(d["clinical_term_recall"])
if d.get("numeric_recall") is not None:
scores["numeric_recall"].append(d["numeric_recall"])
scores["structured_pct"].append(1.0 if d.get("structured") else 0.0)
scores["safety_pct"].append(1.0 if d.get("safety_awareness") else 0.0)
pred = d.get("prediction")
if pred:
if pred.get("fac_exact_match") is not None:
pred_scores["fac_exact_match"].append(1.0 if pred["fac_exact_match"] else 0.0)
if pred.get("fac_error") is not None:
pred_scores["fac_error"].append(pred["fac_error"])
if pred.get("fac_direction_match") is not None:
pred_scores["fac_direction_match"].append(1.0 if pred["fac_direction_match"] else 0.0)
if pred.get("speed_abs_error") is not None:
pred_scores["speed_abs_error"].append(pred["speed_abs_error"])
if pred.get("speed_direction_match") is not None:
pred_scores["speed_direction_match"].append(1.0 if pred["speed_direction_match"] else 0.0)
if pred.get("risk_count_match") is not None:
pred_scores["risk_count_match"].append(1.0 if pred["risk_count_match"] else 0.0)
new_adapted = {
"rouge1_f1": _mean(scores["rouge1_f1"]),
"rouge2_f1": _mean(scores["rouge2_f1"]),
"rougeL_f1": _mean(scores["rougeL_f1"]),
"bleu": _mean(scores["bleu"]),
"avg_response_len": _mean(scores["response_len"]),
"clinical_term_recall": _mean(scores["clinical_term_recall"]),
"numeric_recall": _mean(scores["numeric_recall"]),
"structured_pct": _mean(scores["structured_pct"]),
"safety_awareness_pct": _mean(scores["safety_pct"]),
"prediction": {
"fac_exact_match": _mean(pred_scores["fac_exact_match"]),
"fac_mean_error": _mean(pred_scores["fac_error"]),
"fac_direction_accuracy": _mean(pred_scores["fac_direction_match"]),
"speed_mean_abs_error": _mean(pred_scores["speed_abs_error"]),
"speed_direction_accuracy": _mean(pred_scores["speed_direction_match"]),
"risk_count_accuracy": _mean(pred_scores["risk_count_match"]),
},
"by_category": orig_agg["adapted"].get("by_category", {}),
}
patched = copy.deepcopy(orig_agg)
patched["adapted"] = new_adapted
return patched
# ============================================================
# Inference Backends
# ============================================================
def _get_inference_backend():
    """Pick the inference backend available in this environment.

    Returns a ``(backend_name, run_fn)`` pair where
    ``run_fn(model_id, adapter_path, prompt, max_tokens) -> str``:

    * ``("mlx", _infer_mlx)`` when ``mlx.core`` can be imported,
    * ``("hf_api", _infer_hf_api)`` when an HF token is set in the
      environment,
    * ``("none", None)`` otherwise.
    """
    # Local MLX takes precedence over the hosted API.
    try:
        import mlx.core  # noqa: F401
    except ImportError:
        mlx_available = False
    else:
        mlx_available = True
    if mlx_available:
        return "mlx", _infer_mlx

    # The HF Inference API is only usable with a token.
    has_token = bool(
        os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
    )
    return ("hf_api", _infer_hf_api) if has_token else ("none", None)
def _infer_mlx(model_id: str, adapter_path: str | None,
               prompt: str, max_tokens: int = 512) -> str:
    """Generate a response with MLX on the local machine.

    The loaded ``(model, tokenizer)`` pair is kept in ``st.session_state``
    under a key built from the model id and adapter path, so each
    combination is loaded once per session.
    """
    from mlx_lm import generate, load
    from mlx_lm.sample_utils import make_sampler

    session_key = f"mlx_{model_id}_{adapter_path}"
    if session_key in st.session_state:
        model, tokenizer = st.session_state[session_key]
    else:
        model, tokenizer = load(model_id, adapter_path=adapter_path)
        st.session_state[session_key] = (model, tokenizer)

    if hasattr(tokenizer, "apply_chat_template"):
        chat = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ]
        text = tokenizer.apply_chat_template(
            chat,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
    else:
        # Tokenizer without a chat template: hand-built prompt layout.
        text = f"<|system|>\n{SYSTEM_PROMPT}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n"

    sampler = make_sampler(temp=0.1)
    return generate(
        model,
        tokenizer,
        prompt=text,
        max_tokens=max_tokens,
        sampler=sampler,
        verbose=False,
    )
def _infer_hf_api(model_id: str, adapter_path: str | None,
                  prompt: str, max_tokens: int = 512) -> str:
    """Generate a response through the Hugging Face Inference API.

    Args:
        model_id: Benchmark model id; MLX ids are mapped to their Hub
            counterpart through MODEL_HF_MAP, other ids are used as-is.
        adapter_path: Accepted for signature parity with the MLX backend.
            NOTE(review): it is not used here, so requests made through this
            backend always go to the mapped base model.
        prompt: User message.
        max_tokens: Generation limit passed to the API.

    Returns:
        The assistant message text; an empty string when the API returns a
        message without content.
    """
    from huggingface_hub import InferenceClient

    hf_model = MODEL_HF_MAP.get(model_id, model_id)
    token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
    client = InferenceClient(token=token)
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    resp = client.chat_completion(
        model=hf_model, messages=messages,
        max_tokens=max_tokens, temperature=0.1,
    )
    # `content` is optional in the response schema; honour the `-> str`
    # contract so callers can tokenise the result unconditionally.
    return resp.choices[0].message.content or ""
# ============================================================
# Inline Metric Computation (for live re-inference)
# ============================================================
def _tokenize(text: str) -> list[str]:
    """Lower-case *text* and split it into runs of word characters."""
    lowered = text.lower()
    return [match.group(0) for match in re.finditer(r"\w+", lowered)]
def _ngrams(tokens: list[str], n: int) -> Counter:
    """Count every contiguous window of *n* tokens, keyed by tuple."""
    counts = Counter()
    for start in range(len(tokens) - n + 1):
        counts[tuple(tokens[start:start + n])] += 1
    return counts
def _rouge_n_f1(ref_tokens: list[str], hyp_tokens: list[str], n: int) -> float:
    """Compute the ROUGE-N F1 between reference and hypothesis tokens.

    Returns 0.0 when either token list is empty or when the two sides share
    no n-gram.
    """
    if not (ref_tokens and hyp_tokens):
        return 0.0

    def _count(tokens):
        # Multiset of contiguous n-token windows.
        return Counter(
            tuple(tokens[k:k + n]) for k in range(len(tokens) - n + 1)
        )

    ref_counts = _count(ref_tokens)
    hyp_counts = _count(hyp_tokens)
    # Clipped overlap: each n-gram counts at most as often as on both sides.
    matched = sum((ref_counts & hyp_counts).values())

    hyp_total = sum(hyp_counts.values())
    ref_total = sum(ref_counts.values())
    precision = matched / hyp_total if hyp_total else 0.0
    recall = matched / ref_total if ref_total else 0.0

    denom = precision + recall
    if denom > 0:
        return 2 * precision * recall / denom
    return 0.0
def _compute_live_metrics(reference: str, response: str) -> dict:
    """Score a freshly generated response against its reference.

    Returns a dict with ROUGE-1 and ROUGE-2 F1 (rounded to four decimals)
    and the response length in tokens.
    """
    ref_tok, hyp_tok = _tokenize(reference), _tokenize(response)
    unigram_f1 = _rouge_n_f1(ref_tok, hyp_tok, 1)
    bigram_f1 = _rouge_n_f1(ref_tok, hyp_tok, 2)
    metrics = {}
    metrics["rouge1_f1"] = round(unigram_f1, 4)
    metrics["rouge2_f1"] = round(bigram_f1, 4)
    metrics["len"] = len(hyp_tok)
    return metrics
# ============================================================
# Visualisations
# ============================================================
def _render_model_info(bench_data: dict, agg: dict, has_adapted: bool,
                       n_test: int) -> None:
    """Collapsible explainer section for non-ML stakeholders.

    Renders a collapsed expander describing the base model, the LoRA
    fine-tuning approach (including the specialized-adapter table when a
    specialized manifest is present), the training data breakdown (when
    training metadata can be loaded) and a guide to reading the scores.

    Args:
        bench_data: Parsed benchmark document; ``model`` and the adapter
            fields are read from it.
        agg: Aggregate metrics. Not referenced in this function body.
        has_adapted: Not referenced in this function body.
        n_test: Not referenced in this function body.
    """
    model_id = bench_data.get("model", "Unknown")
    adapter_path = _resolve_adapter(bench_data)
    adapter_name = Path(adapter_path).name if adapter_path else None
    # Try to load training metadata
    meta = None
    # Only one dataset name is probed; the first existing metadata file wins,
    # and a file that fails to load leaves `meta` as None.
    for name in ["rehab_public_v1"]:
        meta_path = TRAINING_DIR / name / "metadata.json"
        if meta_path.exists():
            try:
                with open(meta_path) as f:
                    meta = json.load(f)
            except Exception:
                pass
            break
    # Check for specialized adapters
    spec_manifest_path = TRAINING_DIR / "specialized" / "manifest.json"
    has_specialized = spec_manifest_path.exists()
    n_adapters = 0
    if has_specialized:
        try:
            with open(spec_manifest_path) as f:
                spec_manifest = json.load(f)
            n_adapters = len(spec_manifest.get("subtasks", {}))
        except Exception:
            # An unreadable manifest is treated as "no specialized adapters".
            has_specialized = False
    with st.expander("About this model and evaluation", expanded=False):
        st.markdown(f"""
**What is this?**
This page evaluates how well our AI model answers rehabilitation-specific
clinical questions — comparing a general-purpose base model against our
fine-tuned version.
""")
        # --- Base Model ---
        model_short = model_id.split("/")[-1]
        # The architecture blurb is chosen from markers in the model name.
        is_moe = "A3B" in model_short or "MoE" in model_short
        arch_desc = ("a Mixture-of-Experts model (30B total, 3B active per token)"
                     if is_moe else "a dense transformer model")
        st.markdown(f"""
**Base model** — `{model_short}`
An open-source large language model — {arch_desc} — that has broad
medical knowledge but nothing specific about our domain: rehabilitation
robotics, gait analysis, or patient outcome prediction.
""")
        # --- Adapter Architecture ---
        if adapter_name or has_specialized:
            st.markdown("---")
            st.markdown("**Fine-tuning approach — LoRA adapters**")
            st.markdown("""
We teach the base model our domain through **LoRA** (Low-Rank Adaptation)
— a technique that adjusts a small fraction of the model's weights
(~0.1%) rather than retraining the whole thing. Think of it as adding a
specialized lens on top of general medical knowledge.
Training setup: batch size 4, 16 adapted layers, prompt-masked loss
(the model only learns from the answer, not the question).
""")
            if has_specialized:
                st.markdown(f"""
**Multi-adapter architecture** — We train **{n_adapters} specialized
adapters**, each focused on one prediction sub-task. During evaluation,
each test question is automatically routed to the right adapter:
""")
                # Static description table; it is not derived from the
                # manifest contents (only the adapter count above is).
                adapter_info = {
                    "prediction_trajectory": ("Trajectory", "Forecasts overall recovery path from a single initial visit"),
                    "prediction_fac": ("FAC", "Predicts the Functional Ambulation Category score change"),
                    "prediction_speed": ("Speed", "Predicts gait speed trajectory over coming weeks"),
                    "prediction_risk": ("Risk", "Identifies recovery risk factors from baseline data"),
                }
                adapter_rows = []
                for task, info in adapter_info.items():
                    adapter_rows.append({
                        "Adapter": info[0],
                        "Focus": info[1],
                    })
                st.dataframe(
                    pd.DataFrame(adapter_rows),
                    hide_index=True,
                    width="stretch",
                    height=35 * (len(adapter_rows) + 1),
                )
                st.markdown("""
All other tasks (clinical interpretation, session reporting, progress
analysis) use a single **general adapter** trained on the full dataset.
This split lets each prediction adapter focus deeply on its task without
being diluted by unrelated training data.
""")
            elif adapter_name:
                st.markdown(f"**Current adapter:** `{adapter_name}`")
        # --- Training Data ---
        if meta:
            cats = meta.get("categories", {})
            st.markdown("---")
            st.markdown(f"""
**Training data** — `{meta.get("total_pairs", "?")}` question-answer pairs
Built from a public stroke rehabilitation dataset
([Zenodo 10534055](https://zenodo.org/records/10534055)) — 10 patients
with longitudinal gait measurements across two therapy visits. We
converted the raw sensor data into structured clinical Q&A pairs across
{len(cats)} task types:
""")
            task_labels = {
                "progress_prediction": ("Progress Analysis", "Retrospective comparison of two visits"),
                "prediction_trajectory": ("Trajectory Prediction", "Forecast recovery from initial assessment only"),
                "prediction_fac": ("FAC Forecasting", "Predict functional ambulation category change"),
                "prediction_speed": ("Speed Prediction", "Predict gait speed trajectory"),
                "prediction_risk": ("Risk Assessment", "Identify recovery risks from baseline data"),
                "automated_reporting": ("Clinical Reporting", "Generate therapy session reports"),
                "clinical_interpretation": ("Parameter Interpretation", "Explain what gait measurements mean"),
            }
            rows = []
            # Largest categories first; unknown keys fall back to the raw key
            # with an empty description.
            for cat_key, count in sorted(cats.items(), key=lambda x: -x[1]):
                label, desc = task_labels.get(cat_key, (cat_key, ""))
                rows.append({"Task": label, "Pairs": count, "Description": desc})
            st.dataframe(
                pd.DataFrame(rows),
                hide_index=True,
                width="stretch",
                height=min(35 * (len(rows) + 1), 300),
            )
        # --- How to Read Scores ---
        st.markdown("---")
        st.markdown("""
**How to read the scores**
*Text overlap metrics:*
- **ROUGE / BLEU** — measure how closely the model's answer matches our
reference answer (1.0 = perfect match). Higher is better. These tell us
if the model produces the right *format and vocabulary*.
*Domain quality:*
- **Clinical Term Recall** — does the model mention the right medical
terms? (e.g., FAC, MCID, gait speed)
- **Numeric Recall** — does it use the correct numbers from the patient
data?
- **Safety Awareness** — does it flag risks, recommend monitoring, or
note limitations?
*Predictive accuracy (prediction tasks only):*
- **FAC Exact Match** — did the model predict the exact correct FAC score?
- **Speed / FAC Direction** — did it get the direction right (improving,
stable, declining)?
- **Speed Error** — how far off is the predicted gait speed from the
actual outcome, in m/s?
The **base model** gives generic textbook answers. The **adapted model**
produces structured, data-grounded responses in our specific clinical
format — the kind of output we need for automated reporting and decision
support in BAMA's rehabilitation workflow.
""")
def _render_metric_cards(agg: dict, has_adapted: bool):
    """Top-level KPI metric cards mixing text-quality and prediction metrics.

    One card per metric. When an adapted value exists the card shows it with
    the delta against the base model; otherwise it shows the base value (or
    a dash when that is missing too).

    Args:
        agg: Aggregate metrics with optional ``"base"`` / ``"adapted"``
            sections, each optionally holding a ``"prediction"`` sub-dict.
        has_adapted: Whether the benchmark includes an adapted model.
    """
    b = agg.get("base") or {}
    a = agg.get("adapted") or {}
    bp = b.get("prediction") or {}
    ap = (a.get("prediction") or {}) if has_adapted else {}

    # (label, key, source, fmt, higher_better, help)
    # source: "top" = agg[section][key], "pred" = agg[section]["prediction"][key]
    metrics = [
        ("ROUGE-1", "rouge1_f1", "top", ".4f", True,
         "Token overlap F1 at unigram level (higher is better)."),
        ("Term Recall", "clinical_term_recall", "top", ".0%", True,
         "Share of domain clinical terms recovered in output."),
        ("FAC Accuracy", "fac_exact_match", "pred", ".0%", True,
         "Exact FAC score match rate against reference (higher is better)."),
        ("FAC Direction", "fac_direction_accuracy", "pred", ".0%", True,
         "FAC trend direction match — improve/stable/decline (higher is better)."),
        ("Speed Error", "speed_mean_abs_error", "pred", ".3f", False,
         "Mean absolute gait speed error in m/s (lower is better)."),
        ("Risk Accuracy", "risk_count_accuracy", "pred", ".0%", True,
         "Exact match rate for extracted risk-factor count (higher is better)."),
    ]

    def _display(val, fmt):
        """Format a metric value, or a dash when it is missing."""
        return "—" if val is None else f"{val:{fmt}}"

    def _delta_str(base_val, adapted_val, fmt):
        """Format adapted-minus-base; percentages become percentage points."""
        if base_val is None or adapted_val is None:
            return None
        diff = adapted_val - base_val
        if fmt.endswith("%"):
            return f"{diff * 100:+.0f}pp"
        return f"{diff:+{fmt}}"

    cols = st.columns(len(metrics))
    for col, (label, key, source, fmt, higher_better, help_text) in zip(cols, metrics):
        b_src = bp if source == "pred" else b
        a_src = ap if source == "pred" else a
        bv = b_src.get(key)
        av = a_src.get(key) if has_adapted else None
        with col:
            if has_adapted and av is not None:
                # st.metric colours by the sign of the delta: "normal" paints
                # increases green, "inverse" paints decreases green. The
                # choice therefore depends only on the metric's direction,
                # never on the sign of this particular delta.
                st.metric(
                    label,
                    _display(av, fmt),
                    delta=_delta_str(bv, av, fmt),
                    delta_color="normal" if higher_better else "inverse",
                    help=help_text,
                )
            else:
                st.metric(label, _display(bv, fmt), help=help_text)
def _render_category_chart(agg: dict, has_adapted: bool):
    """Per-category ROUGE-1 chart for text-quality categories only.

    Prediction categories are excluded — their quality is shown via
    task-specific metrics (FAC accuracy, speed error, etc.) instead.
    Nothing is rendered when no text-quality category is present.

    Args:
        agg: Aggregate metrics; ``by_category`` is read from the ``"base"``
            and (when ``has_adapted``) ``"adapted"`` sections.
        has_adapted: Whether to draw the adapted-model bars.
    """
    b_cats = agg.get("base", {}).get("by_category", {})
    a_cats = agg.get("adapted", {}).get("by_category", {}) if has_adapted else {}
    categories = sorted(
        c for c in set(b_cats) | set(a_cats)
        if not c.startswith("prediction_")
    )
    if not categories:
        return

    def _bar_label(value):
        # A null score gets an empty label instead of breaking the format.
        return f"{value:.2f}" if value is not None else ""

    # A category missing from the base run is drawn as 0; an explicit null
    # score is passed through as a gap, like the adapted trace.
    base_vals = [b_cats.get(c, {}).get("rouge1_f1", 0) for c in categories]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        name="Base",
        x=categories,
        y=base_vals,
        marker_color=C_BASE,
        text=[_bar_label(v) for v in base_vals],
        textposition="outside",
    ))
    if has_adapted:
        adapted_vals = [a_cats.get(c, {}).get("rouge1_f1") for c in categories]
        fig.add_trace(go.Bar(
            name="Adapted",
            x=categories,
            y=adapted_vals,
            marker_color=C_ADAPTED,
            text=[_bar_label(v) for v in adapted_vals],
            textposition="outside",
        ))
    fig.update_layout(
        title="Text Quality by Category (ROUGE-1 F1)",
        barmode="group",
        yaxis_range=[0, 1.05],
        yaxis_title="F1 Score",
        height=340,
        margin=dict(t=40, b=20, l=40, r=20),
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
        plot_bgcolor=C_BG,
    )
    fig.update_yaxes(gridcolor=C_GRID)
    st.plotly_chart(fig, width="stretch")
def _render_prediction_accuracy(agg: dict, has_adapted: bool,
                                examples: list[dict]):
    """Show predictive accuracy metrics for prediction_* categories.

    Renders a 2x3 grid of metric cards followed by a collapsed per-sample
    table of the values extracted from reference and model output. Nothing
    is rendered when there is no ``prediction_*`` example or when no
    prediction metric is available; aggregates missing from the benchmark
    (e.g. older bench files) are recomputed from the per-example metrics.

    Args:
        agg: Aggregate metrics (``"base"`` / ``"adapted"`` sections).
        has_adapted: Whether the benchmark includes an adapted model.
        examples: Per-example benchmark records.
    """
    def _category(ex):
        return ex.get("category") or ""

    def _has_values(section):
        return bool(section) and any(v is not None for v in section.values())

    def _fmt_value(value, unit):
        """Format a metric as a percentage or as a 3-decimal number + unit."""
        if unit == "%":
            return f"{value * 100:.0f}%"
        return f"{value:.3f} {unit}"

    def _transition(pair):
        """Format a current/predicted pair, or a dash when incomplete."""
        if not pair:
            return "—"
        current, predicted = pair.get("current"), pair.get("predicted")
        if current is None or predicted is None:
            return "—"
        # Explicit separator so e.g. FAC 2 and 3 do not read as "23".
        return f"{current} → {predicted}"

    # Check if there are any prediction examples
    if not any(_category(ex).startswith("prediction_") for ex in examples):
        return

    bp = agg.get("base", {}).get("prediction") or {}
    ap = (agg.get("adapted", {}).get("prediction") or {}) if has_adapted else {}
    if not _has_values(bp) and not _has_values(ap):
        # Recompute from per-example data if aggregate is missing
        # (e.g. older bench files without prediction aggregate)
        bp, ap = _recompute_prediction_agg(examples, has_adapted)
        if not _has_values(bp) and not _has_values(ap):
            return

    st.divider()
    st.subheader("Predictive Accuracy")
    st.caption(
        "Structured value extraction from model output — "
        "compares predicted FAC, speed, and risk against actual outcomes"
    )

    # (label, aggregate key, unit, higher_better, help)
    metrics_def = [
        ("FAC Exact Match", "fac_exact_match", "%", True,
         "Exact FAC score match rate against reference."),
        ("FAC Direction", "fac_direction_accuracy", "%", True,
         "Whether FAC trend direction matches (improve/stable/decline)."),
        ("FAC Mean Error", "fac_mean_error", "levels", False,
         "Average absolute FAC level difference (lower is better)."),
        ("Speed Direction", "speed_direction_accuracy", "%", True,
         "Whether speed trend direction matches reference."),
        ("Speed Mean Error", "speed_mean_abs_error", "m/s", False,
         "Mean absolute speed error in m/s (lower is better)."),
        ("Risk Count Match", "risk_count_accuracy", "%", True,
         "Exact match rate for extracted risk-factor count."),
    ]

    # Render as metric columns (2 rows of 3)
    for row_start in range(0, len(metrics_def), 3):
        row_items = metrics_def[row_start:row_start + 3]
        cols = st.columns(len(row_items))
        for col, (label, key, unit, higher_better, help_text) in zip(cols, row_items):
            b_val = bp.get(key)
            a_val = ap.get(key) if ap else None
            with col:
                if has_adapted and a_val is not None:
                    display = _fmt_value(a_val, unit)
                    if b_val is None:
                        st.metric(label, display, help=help_text)
                        continue
                    delta = a_val - b_val
                    delta_str = (f"{delta * 100:+.0f}pp" if unit == "%"
                                 else f"{delta:+.3f}")
                    # st.metric colours by the sign of the delta: "normal"
                    # paints increases green, "inverse" paints decreases
                    # green — so only the metric direction decides.
                    st.metric(
                        label, display, delta_str,
                        delta_color="normal" if higher_better else "inverse",
                        help=help_text,
                    )
                elif b_val is not None:
                    st.metric(label, _fmt_value(b_val, unit), help=help_text)
                else:
                    st.metric(label, "—", help=help_text)

    # Per-example detail for prediction categories
    metrics_key = "metrics_adapted" if has_adapted else "metrics_base"
    with st.expander("Per-sample prediction extraction", expanded=False):
        rows = []
        for i, ex in enumerate(examples):
            cat = _category(ex)
            if not cat.startswith("prediction_"):
                continue
            m = ex.get(metrics_key) or {}
            pred = (m.get("domain") or {}).get("prediction") or {}
            if not pred:
                continue
            # "#" is the 1-based position in the full example list.
            row = {"#": i + 1, "Task": cat.replace("prediction_", "")}
            if "ref_fac" in pred:
                row["Ref FAC"] = _transition(pred.get("ref_fac"))
                row["Pred FAC"] = _transition(pred.get("resp_fac"))
            if "ref_speed" in pred:
                row["Ref Speed"] = _transition(pred.get("ref_speed"))
                row["Pred Speed"] = _transition(pred.get("resp_speed"))
            if "ref_risk_count" in pred:
                row["Ref Risks"] = pred.get("ref_risk_count", "—")
                row["Pred Risks"] = pred.get("resp_risk_count", "—")
            # Overall match indicators (an exact FAC match supersedes the
            # weaker direction-only match).
            checks = []
            if pred.get("fac_exact_match") is True:
                checks.append("FAC-exact")
            elif pred.get("fac_direction_match") is True:
                checks.append("FAC-dir")
            if pred.get("speed_direction_match") is True:
                checks.append("Speed-dir")
            if pred.get("risk_count_match") is True:
                checks.append("Risk-count")
            row["Matches"] = ", ".join(checks) if checks else "—"
            rows.append(row)
        if rows:
            st.dataframe(pd.DataFrame(rows), hide_index=True,
                         width="stretch")
        else:
            st.info("No extractable prediction values found in model outputs.")
def _recompute_prediction_agg(examples: list[dict],
                              has_adapted: bool) -> tuple[dict, dict]:
    """Derive prediction aggregates from the per-example metrics.

    Intended for older bench files that carry no prediction aggregate.
    Returns ``(base, adapted)``; each maps the aggregate key names to the
    mean of the available per-example values (rounded to four decimals), or
    ``None`` when no example provided that value. ``adapted`` is an empty
    dict when ``has_adapted`` is false.
    """
    # Per-example key -> aggregate key.
    renamed = (
        ("fac_exact_match", "fac_exact_match"),
        ("fac_error", "fac_mean_error"),
        ("speed_abs_error", "speed_mean_abs_error"),
        ("fac_direction_match", "fac_direction_accuracy"),
        ("speed_direction_match", "speed_direction_accuracy"),
        ("risk_count_match", "risk_count_accuracy"),
    )

    def _aggregate(metrics_key):
        buckets = {raw: [] for raw, _ in renamed}
        for example in examples:
            metrics = example.get(metrics_key, {})
            pred = metrics.get("domain", {}).get("prediction")
            if not pred:
                continue
            for raw, bucket in buckets.items():
                value = pred.get(raw)
                if value is None:
                    continue
                # Booleans are averaged as 0/1 rates.
                if isinstance(value, bool):
                    value = 1.0 if value else 0.0
                bucket.append(value)
        return {
            final: (round(sum(buckets[raw]) / len(buckets[raw]), 4)
                    if buckets[raw] else None)
            for raw, final in renamed
        }

    base_agg = _aggregate("metrics_base")
    adapted_agg = _aggregate("metrics_adapted") if has_adapted else {}
    return base_agg, adapted_agg
def _render_sample_browser(examples: list[dict], bench_data: dict,
                           has_adapted: bool):
    """Individual test sample viewer with side-by-side comparison.

    Shows a category filter and sample selector, the prompt, the reference
    next to the base (and, when available, adapted) response, prediction
    visualisations for ``prediction_*`` samples, and the live re-inference
    controls.

    Args:
        examples: Per-example benchmark records.
        bench_data: Parsed benchmark document (its ``id`` is part of the
            session keys used to look up live re-inference results).
        has_adapted: Whether the benchmark includes an adapted model.
    """
    st.subheader("Test Samples")
    if not examples:
        st.info("No test samples in this benchmark.")
        return

    def _category(ex):
        # Single definition so the filter options and the filtering agree:
        # samples without a category are listed and matched as "unknown".
        return ex.get("category") or "unknown"

    # Filters
    col_filter, col_cat = st.columns([3, 1])
    categories = sorted({_category(ex) for ex in examples})
    with col_cat:
        cat_filter = st.selectbox(
            "Category", ["All"] + categories, key="eval_cat_filter")
    filtered = examples if cat_filter == "All" else [
        ex for ex in examples if _category(ex) == cat_filter]

    with col_filter:
        if not filtered:
            st.info("No samples in this category.")
            return
        sample_labels = [
            f"[{i + 1}] {ex.get('category', '?')} — {ex.get('prompt', '')[:70]}..."
            for i, ex in enumerate(filtered)
        ]
        # Default to sample with most timepoints (richest forecast chart);
        # the first sample wins ties, including the all-zero case.
        tp_counts = [
            len(re.findall(r"^\s+T\d+:", ex.get("prompt", ""), re.MULTILINE))
            for ex in filtered
        ]
        default_idx = tp_counts.index(max(tp_counts))
        idx = st.selectbox("Sample", range(len(sample_labels)),
                           index=default_idx,
                           format_func=lambda i: sample_labels[i],
                           key="eval_sample_select")
    ex = filtered[idx]
    prompt = ex.get("prompt", "")
    reference = ex.get("reference", "")

    # Prompt
    st.markdown("**Prompt:**")
    st.text_area("prompt_display", prompt, height=100,
                 disabled=True, label_visibility="collapsed")

    # Side-by-side responses
    cols = st.columns(3 if has_adapted else 2)
    with cols[0]:
        st.markdown(f"**Reference** &nbsp; `{len(reference.split())} words`")
        st.text_area("ref_display", reference, height=300,
                     disabled=True, label_visibility="collapsed", key="ref_ta")
    with cols[1]:
        m_base = ex.get("metrics_base", {})
        r1_b = m_base.get("rouge1", {}).get("f1", "?")
        bl_b = m_base.get("bleu", {}).get("score", "?")
        st.markdown(f"**Base Model** &nbsp; `R1:{r1_b}` `B4:{bl_b}`")
        # Check for live re-inference result.
        # NOTE(review): `idx` is the position inside the *filtered* list, so
        # the same key can refer to different samples under different
        # category filters — confirm against the code that writes this key.
        live_key = f"live_base_{bench_data.get('id', '')}_{idx}"
        display_text = st.session_state.get(live_key, ex.get("base_response", ""))
        st.text_area("base_display", display_text, height=300,
                     disabled=True, label_visibility="collapsed", key="base_ta")
    if has_adapted:
        with cols[2]:
            m_adpt = ex.get("metrics_adapted", {})
            r1_a = m_adpt.get("rouge1", {}).get("f1", "?")
            bl_a = m_adpt.get("bleu", {}).get("score", "?")
            st.markdown(f"**Adapted** &nbsp; `R1:{r1_a}` `B4:{bl_a}`")
            live_key_a = f"live_adapted_{bench_data.get('id', '')}_{idx}"
            display_text_a = st.session_state.get(
                live_key_a, ex.get("adapted_response", ""))
            st.text_area("adapted_display", display_text_a, height=300,
                         disabled=True, label_visibility="collapsed", key="adpt_ta")

    # --- Prediction visualizations (prediction_* categories only) ---
    if (ex.get("category") or "").startswith("prediction_"):
        if _is_timeseries_prompt(prompt):
            # Longitudinal: forecast chart supersedes the bullet chart
            _render_ts_forecast_chart(ex, has_adapted)
        else:
            # Single-visit: bullet chart for point predictions
            _render_prediction_bullet(ex, has_adapted)

    # --- Live Re-Inference ---
    _render_reinference_controls(ex, bench_data, idx, has_adapted)
# ============================================================
# Time-Series Forecast Helpers
# ============================================================
def _is_timeseries_prompt(prompt: str) -> bool:
    """Tell whether *prompt* lists at least two indented ``T<n>:`` lines."""
    timepoint_lines = re.finditer(r"^\s+T\d+:", prompt, re.MULTILINE)
    return sum(1 for _ in timepoint_lines) >= 2
def _parse_ts_history(prompt: str) -> list[dict]:
    """Parse timepoint history from a longitudinal prompt.

    Each indented line of the form ``T<n>: FAC <int>[, cadence <x>]
    [... stride time <x> s][... speed (est.) <x>][... step regularity <x>]``
    yields one entry; the optional measurements must appear in that order.

    Returns list of dicts sorted by timepoint number:
        [{"tp": "T0", "fac": 2, "cadence": 50.4, "speed": 0.356,
          "stride_time": 1.191, "regularity": 0.816}, ...]
    Keys other than ``tp`` and ``fac`` are present only when found.
    """
    # A well-formed decimal. Unlike "[\d.]+", this never swallows a trailing
    # sentence period ("0.816."), which would make float() raise.
    num = r"(\d*\.?\d+)"
    pattern = re.compile(
        r"^\s+(T\d+):\s*FAC\s+(\d+)"
        r"(?:,\s*cadence\s+" + num + r")?"
        r"(?:.*?stride time\s+" + num + r"\s*s)?"
        r"(?:.*?speed\s*\(est\.\)\s*" + num + r")?"
        r"(?:.*?step regularity\s+" + num + r")?",
        re.MULTILINE,
    )
    # Regex group number -> output key for the optional measurements.
    optional_fields = ((3, "cadence"), (4, "stride_time"),
                       (5, "speed"), (6, "regularity"))

    results = []
    for m in pattern.finditer(prompt):
        tp = {"tp": m.group(1), "fac": int(m.group(2))}
        for group, key in optional_fields:
            if m.group(group):
                tp[key] = float(m.group(group))
        results.append(tp)
    results.sort(key=lambda x: int(x["tp"][1:]))
    return results
def _parse_ts_prediction(text: str) -> dict:
    """Extract predicted values from reference or model response text.

    Markdown emphasis markers (``*``) are stripped before matching.

    Returns:
        A dict with any of ``"speed"`` (float), ``"cadence"`` (float) and
        ``"fac"`` (int) that could be found; empty when nothing matches.
    """
    clean = text.replace("*", "")
    # A well-formed decimal. Unlike "[\d.]+", this never swallows a trailing
    # sentence period ("0.45."), which would make float() raise.
    num = r"(\d*\.?\d+)"
    pred: dict = {}

    m_spd = re.search(r"Predicted Gait Speed[:\s]*" + num, clean)
    if m_spd:
        pred["speed"] = float(m_spd.group(1))
    m_cad = re.search(r"Predicted Cadence[:\s]*" + num, clean)
    if m_cad:
        pred["cadence"] = float(m_cad.group(1))
    m_fac = re.search(r"Predicted FAC(?: Score)?[:\s]*(\d+)", clean)
    if m_fac:
        pred["fac"] = int(m_fac.group(1))

    # Also try speed patterns in speed-only responses
    if "speed" not in pred:
        m_spd2 = re.search(r"Gait Speed[:\s]*" + num + r"\s*m/s", clean)
        if m_spd2:
            pred["speed"] = float(m_spd2.group(1))
    return pred
# Forecast chart palette:
#   history -> observed timepoints
#   actual  -> reference / ground truth
#   model   -> model prediction
_C_HISTORY = "#6c757d"  # grey
_C_ACTUAL = "#0d6efd"  # blue
_C_MODEL = "#198754"  # green
def _render_ts_forecast_chart(ex: dict, has_adapted: bool):
    """Draw the observed trajectory of a longitudinal sample plus its forecast.

    The x-axis lists the history timepoints parsed from the prompt followed
    by the forecast timepoint. For the parameter picked in the radio
    selector the chart shows the observed history, the reference value and
    the model's predicted value at the forecast timepoint. Nothing is
    rendered when fewer than two history points are found, or when no
    parameter has both a history value and a forecast value.

    Args:
        ex: Benchmark example; reads ``prompt``, ``reference`` and either
            ``adapted_response`` or ``base_response``.
        has_adapted: Picks ``adapted_response`` (True) or ``base_response``
            (False) as the model output to parse.
    """
    prompt_text = ex.get("prompt", "")
    history = _parse_ts_history(prompt_text)
    if len(history) < 2:
        return

    reference_text = ex.get("reference", "")
    ref_pred = _parse_ts_prediction(reference_text)
    response_key = "adapted_response" if has_adapted else "base_response"
    model_text = ex.get(response_key, "")
    model_pred = _parse_ts_prediction(model_text) if model_text else {}

    # Default forecast label: one past the last observed timepoint. An
    # explicit "at T<n>" mention in the reference or prompt overrides it.
    forecast_tp = f"T{int(history[-1]['tp'][1:]) + 1}"
    for source_text in (reference_text, prompt_text):
        found = re.search(r"at (?:the next assessment \()?(T\d+)\)?", source_text)
        if found:
            forecast_tp = found.group(1)
            break

    param_labels = {
        "speed": "Gait Speed (m/s)",
        "cadence": "Cadence (steps/min)",
        "stride_time": "Stride Time (s)",
        "fac": "FAC Score",
        "regularity": "Step Regularity",
    }
    # Offer only parameters with at least one history point AND a forecast
    # value from the reference or the model.
    param_options = [
        key for key in param_labels
        if any(tp.get(key) is not None for tp in history)
        and (ref_pred.get(key) is not None or model_pred.get(key) is not None)
    ]
    if not param_options:
        return

    st.markdown("---")
    st.markdown("**Longitudinal Forecast**")

    # Widget key derived from the prompt so it stays the same across re-renders.
    widget_key = f"ts_param_{hash(prompt_text) & 0xFFFFFFFF}"
    selected_param = st.radio(
        "Parameter",
        param_options,
        format_func=lambda k: param_labels.get(k, k),
        horizontal=True,
        key=widget_key,
    )

    tp_labels = [tp["tp"] for tp in history]
    observed = [
        (tp["tp"], tp[selected_param])
        for tp in history
        if tp.get(selected_param) is not None
    ]
    valid_x = [label for label, _ in observed]
    valid_y = [value for _, value in observed]

    fig = go.Figure()

    if valid_x:
        fig.add_trace(go.Scatter(
            x=valid_x, y=valid_y,
            mode="lines+markers",
            line=dict(color=_C_HISTORY, width=2),
            marker=dict(size=8, color=_C_HISTORY),
            name="Observed",
        ))

    # (value, colour, marker symbol, label position, legend name); the
    # reference marker is added first, then the model marker.
    forecast_markers = (
        (ref_pred.get(selected_param), _C_ACTUAL, "circle", "top center", "Actual"),
        (model_pred.get(selected_param), _C_MODEL, "triangle-up", "bottom center", "Model"),
    )
    for value, colour, symbol, label_pos, trace_name in forecast_markers:
        if value is None:
            continue
        if valid_y:
            # Dotted connector from the last observed point to the forecast.
            fig.add_trace(go.Scatter(
                x=[valid_x[-1], forecast_tp],
                y=[valid_y[-1], value],
                mode="lines",
                line=dict(color=colour, width=1, dash="dot"),
                showlegend=False,
            ))
        fig.add_trace(go.Scatter(
            x=[forecast_tp], y=[value],
            mode="markers+text",
            marker=dict(size=12, color=colour, symbol=symbol,
                        line=dict(width=1, color="#fff")),
            text=[f"{value}"],
            textposition=label_pos,
            textfont=dict(size=10, color=colour),
            name=trace_name,
        ))

    all_tp = tp_labels + [forecast_tp]
    y_title = param_labels.get(selected_param, selected_param)
    fig.update_layout(
        height=280,
        margin=dict(t=30, b=30, l=50, r=20),
        xaxis=dict(
            categoryorder="array",
            categoryarray=all_tp,
            title="Assessment",
            showgrid=True, gridcolor=C_GRID,
        ),
        yaxis=dict(
            title=y_title,
            showgrid=True, gridcolor=C_GRID,
        ),
        plot_bgcolor=C_BG,
        legend=dict(orientation="h", yanchor="bottom", y=1.02,
                    xanchor="left", x=0, font=dict(size=10)),
    )

    # Vertical dashed line at the last history point, separating history
    # from forecast. add_shape is used instead of add_vline to avoid a
    # Plotly categorical-axis bug.
    last_hist_idx = len(tp_labels) - 1
    fig.add_shape(
        type="line",
        x0=last_hist_idx, x1=last_hist_idx,
        y0=0, y1=1, yref="paper",
        line=dict(dash="dash", color="#ccc", width=1),
    )
    fig.add_annotation(
        x=last_hist_idx + 0.5, y=1.0, yref="paper",
        text="forecast", showarrow=False,
        font=dict(size=9, color="#aaa"),
    )

    st.plotly_chart(fig, width="stretch")
def _render_prediction_bullet(ex: dict, has_adapted: bool):
"""Compact bullet chart showing predicted vs actual values.
Renders a horizontal number line for FAC and/or speed predictions,
with markers for current value, reference prediction, and model
prediction. Only shown for prediction_* categories.
"""
category = ex.get("category", "")
metrics_key = "metrics_adapted" if has_adapted else "metrics_base"
pred = ex.get(metrics_key, {}).get("domain", {}).get("prediction", {})
if not pred:
return
fig = go.Figure()
y_pos = 0 # Track vertical position for stacked rows
y_labels = []
has_any = False
# --- FAC bullet ---
ref_fac = pred.get("ref_fac", {})
resp_fac = pred.get("resp_fac", {})
if ref_fac.get("current") is not None or ref_fac.get("predicted") is not None:
y_labels.append("FAC Score")
# Current FAC (diamond, grey)
if ref_fac.get("current") is not None:
fig.add_trace(go.Scatter(
x=[ref_fac["current"]], y=[y_pos],
mode="markers+text",
marker=dict(symbol="diamond", size=14, color=C_BASE,
line=dict(width=1, color="#fff")),
text=["current"], textposition="bottom center",
textfont=dict(size=9, color="#888"),
name="Current",
showlegend=(y_pos == 0),
legendgroup="current",
))
# Reference predicted FAC (circle, blue)
if ref_fac.get("predicted") is not None:
fig.add_trace(go.Scatter(
x=[ref_fac["predicted"]], y=[y_pos],
mode="markers+text",
marker=dict(symbol="circle", size=14, color="#0d6efd",
line=dict(width=1, color="#fff")),
text=["actual"], textposition="top center",
textfont=dict(size=9, color="#0d6efd"),
name="Actual outcome",
showlegend=(y_pos == 0),
legendgroup="actual",
))
# Model predicted FAC (triangle, color-coded)
if resp_fac.get("predicted") is not None:
exact = pred.get("fac_exact_match")
dir_match = pred.get("fac_direction_match")
if exact:
color, label = "#198754", "exact match"
elif dir_match:
color, label = "#fd7e14", "direction correct"
else:
color, label = "#dc3545", "missed"
fig.add_trace(go.Scatter(
x=[resp_fac["predicted"]], y=[y_pos],
mode="markers+text",
marker=dict(symbol="triangle-up", size=16, color=color,
line=dict(width=1, color="#fff")),
text=[f"model ({label})"],
textposition="bottom center",
textfont=dict(size=9, color=color),
name="Model prediction",
showlegend=(y_pos == 0),
legendgroup="model",
))
has_any = True
elif resp_fac.get("direction"):
has_any = True
y_pos += 1
# --- Speed bullet ---
ref_spd = pred.get("ref_speed", {})
resp_spd = pred.get("resp_speed", {})
if ref_spd.get("current") is not None or ref_spd.get("predicted") is not None:
y_labels.append("Gait Speed (m/s)")
# Current speed (diamond, grey)
if ref_spd.get("current") is not None:
fig.add_trace(go.Scatter(
x=[ref_spd["current"]], y=[y_pos],
mode="markers+text",
marker=dict(symbol="diamond", size=14, color=C_BASE,
line=dict(width=1, color="#fff")),
text=["current"], textposition="bottom center",
textfont=dict(size=9, color="#888"),
name="Current",
showlegend=(y_pos == 0),
legendgroup="current",
))
# Reference predicted speed (circle, blue)
if ref_spd.get("predicted") is not None:
fig.add_trace(go.Scatter(
x=[ref_spd["predicted"]], y=[y_pos],
mode="markers+text",
marker=dict(symbol="circle", size=14, color="#0d6efd",
line=dict(width=1, color="#fff")),
text=["actual"], textposition="top center",
textfont=dict(size=9, color="#0d6efd"),
name="Actual outcome",
showlegend=(y_pos == 0),
legendgroup="actual",
))
# Model predicted speed (triangle, color-coded)
if resp_spd.get("predicted") is not None:
dir_match = pred.get("speed_direction_match")
abs_err = pred.get("speed_abs_error")
if abs_err is not None and abs_err < 0.05:
color, label = "#198754", f"close ({abs_err:.2f} m/s off)"
elif dir_match:
color, label = "#fd7e14", f"direction ok ({abs_err:.2f} off)" if abs_err else "direction ok"
elif dir_match is False:
color, label = "#dc3545", f"wrong direction ({abs_err:.2f} off)" if abs_err else "wrong direction"
else:
color, label = C_BASE, "extracted"
fig.add_trace(go.Scatter(
x=[resp_spd["predicted"]], y=[y_pos],
mode="markers+text",
marker=dict(symbol="triangle-up", size=16, color=color,
line=dict(width=1, color="#fff")),
text=[f"model ({label})"],
textposition="bottom center",
textfont=dict(size=9, color=color),
name="Model prediction",
showlegend=(y_pos == 0),
legendgroup="model",
))
has_any = True
y_pos += 1
# --- Risk count (simple display, no bullet needed) ---
ref_risk = pred.get("ref_risk_count")
resp_risk = pred.get("resp_risk_count")
if ref_risk is not None or resp_risk is not None:
match = pred.get("risk_count_match")
color = "#198754" if match else "#dc3545"
icon = "correct" if match else "incorrect"
st.markdown(
f"&nbsp;&nbsp; **Risk factors:** "
f"Actual **{ref_risk}** &nbsp;|&nbsp; "
f"Model predicted **{resp_risk}** "
f"&nbsp; <span style='color:{color}'>({icon})</span>",
unsafe_allow_html=True,
)
has_any = True
if not has_any or y_pos == 0:
return
# Compute x range from all marker data
all_x = []
for trace in fig.data:
all_x.extend(trace.x)
if not all_x:
return
x_min = min(all_x) - 0.3
x_max = max(all_x) + 0.3
fig.update_layout(
height=80 + y_pos * 50,
margin=dict(t=5, b=5, l=10, r=10),
xaxis=dict(range=[max(0, x_min), x_max],
showgrid=True, gridcolor=C_GRID),
yaxis=dict(tickvals=list(range(y_pos)), ticktext=y_labels,
showgrid=False),
plot_bgcolor=C_BG,
legend=dict(orientation="h", yanchor="bottom", y=1.0,
xanchor="left", x=0, font=dict(size=10)),
showlegend=True,
)
st.plotly_chart(fig, width="stretch")
def _render_reinference_controls(ex: dict, bench_data: dict,
                                 sample_idx: int, has_adapted: bool):
    """Render the live re-inference controls for a single sample.

    Shows a model selector (base model, plus the adapter variant when the
    bench run has one), the active backend and a run button. On click the
    prompt is sent through the backend and the response is stored in
    ``st.session_state`` under ``live_adapted_<bench id>_<sample_idx>`` or
    ``live_base_<bench id>_<sample_idx>`` before the script is rerun.

    Args:
        ex: Benchmark example; reads ``category``, ``prompt``, ``reference``.
        bench_data: Loaded bench run; reads ``model``, ``id`` and the
            adapter fields (``routing`` / ``fallback_adapter`` for routed
            runs, ``adapter_path`` otherwise).
        sample_idx: Index of the sample; used in widget and state keys.
        has_adapted: Accepted for signature parity with the other sample
            renderers; not used here.
    """
    backend_name, infer_fn = _get_inference_backend()
    if infer_fn is None:
        st.caption("Live inference unavailable — needs MLX (local) or HF_TOKEN (Spaces).")
        return

    backend_label = "MLX (local)" if backend_name == "mlx" else "HF API (serverless)"
    st.divider()
    col_btn, col_model, col_info = st.columns([1, 2, 2])

    with col_model:
        model_id = bench_data.get("model", "")
        # Resolve adapter: for a routed bench, pick the adapter that matches
        # the current example's category; fall back to the general adapter.
        ex_category = ex.get("category", "")
        if _is_routed(bench_data):
            routing = bench_data.get("routing", {})
            adapter = routing.get(ex_category, "") or bench_data.get("fallback_adapter", "")
        else:
            adapter = bench_data.get("adapter_path", "")

        # Build model choices; the base option is always first.
        base_choice = f"{model_id} (base)"
        choices = [base_choice]
        if adapter:
            adapter_short = Path(adapter).name
            choices.append(f"{model_id} + {adapter_short}")
        model_choice = st.selectbox("Run as", choices, key=f"reinfer_model_{sample_idx}",
                                    index=len(choices) - 1)

    with col_info:
        st.caption(f"Backend: **{backend_label}**")
        if backend_name == "hf_api" and adapter:
            st.caption("Note: LoRA adapter not available via HF API — base model only.")

    with col_btn:
        st.markdown("")  # spacing
        run = st.button("Re-run inference", key=f"reinfer_btn_{sample_idx}",
                        type="secondary")

    if run:
        # Decide from the selected option itself rather than from substrings
        # of its label: adapter directory names need not contain "adapter"
        # or "exp-", and a model id might.
        use_adapter = bool(adapter) and model_choice != base_choice
        # Only the MLX backend can load a LoRA adapter.
        effective_adapter = adapter if (use_adapter and backend_name == "mlx") else None

        with st.spinner(f"Generating via {backend_label}..."):
            try:
                response = infer_fn(model_id, effective_adapter,
                                    ex["prompt"], max_tokens=512)
                # Compute metrics
                metrics = _compute_live_metrics(ex["reference"], response)
                # Store in session state
                key_prefix = "live_adapted" if use_adapter else "live_base"
                live_key = f"{key_prefix}_{bench_data.get('id', '')}_{sample_idx}"
                st.session_state[live_key] = response
                # NOTE(review): the rerun below restarts the script right
                # away, so this message is visible only briefly, if at all.
                st.success(
                    f"Generated {metrics['len']} tokens — "
                    f"ROUGE-1: {metrics['rouge1_f1']:.4f}, "
                    f"ROUGE-2: {metrics['rouge2_f1']:.4f}"
                )
                st.rerun()
            except Exception as e:
                st.error(f"Inference failed: {e}")
# ============================================================
# Baseline Comparison Table
# ============================================================
def _render_baseline_comparison(bench_data: dict, bench_map: dict,
all_keys: list[str],
agg_override: dict | None = None):
"""Render a styled comparison table: fine-tuned model vs all baselines."""
n_samples = bench_data.get("test_examples")
effective_agg = agg_override if agg_override is not None else bench_data.get("aggregate", {})
adapted_agg = effective_agg.get("adapted", {})
adapted_pred = adapted_agg.get("prediction", {})
# Collect baseline runs (base-only, same test size)
baselines = []
for k in all_keys:
d = bench_map[k][1]
a = d.get("aggregate", {})
if "adapted" in a:
continue
if d.get("test_examples") != n_samples:
continue
b = a.get("base", {})
bp = b.get("prediction", {})
baselines.append({
"model": d.get("model", "?"),
"rouge1": b.get("rouge1_f1"),
"bleu": b.get("bleu"),
"term_recall": b.get("clinical_term_recall"),
"fac_exact": bp.get("fac_exact_match"),
"fac_dir": bp.get("fac_direction_accuracy"),
"speed_err": bp.get("speed_mean_abs_error"),
"risk_acc": bp.get("risk_count_accuracy"),
})
if not baselines:
return
# Build table rows
metrics = [
("ROUGE-1 F1", "rouge1", "{:.2f}", True),
("BLEU-4", "bleu", "{:.2f}", True),
("Term Recall", "term_recall", "{:.0%}", True),
("FAC Exact", "fac_exact", "{:.0%}", True),
("FAC Direction", "fac_dir", "{:.0%}", True),
("Speed Error", "speed_err", "{:.3f}", False),
("Risk Accuracy", "risk_acc", "{:.0%}", True),
]
def _fmt(val, fmt_str):
if val is None:
return "—"
try:
return fmt_str.format(val)
except Exception:
return str(val)
# Adapted model values
adapted_vals = {
"rouge1": adapted_agg.get("rouge1_f1"),
"bleu": adapted_agg.get("bleu"),
"term_recall": adapted_agg.get("clinical_term_recall"),
"fac_exact": adapted_pred.get("fac_exact_match"),
"fac_dir": adapted_pred.get("fac_direction_accuracy"),
"speed_err": adapted_pred.get("speed_mean_abs_error"),
"risk_acc": adapted_pred.get("risk_count_accuracy"),
}
# Shorten model names for display
def _short(model_id: str) -> str:
parts = model_id.split("/")
name = parts[-1] if len(parts) > 1 else model_id
# Remove common suffixes
for suffix in ["-MLX-4bit", "-MLX-8bit"]:
name = name.replace(suffix, "")
return name
# Build HTML table
header_cols = "".join(
f'<th style="padding:6px 10px;text-align:center;font-weight:400;'
f'color:#aaa;font-size:0.82em;">{_short(bl["model"])}</th>'
for bl in baselines
)
adapted_label = "Ours (LoRA)"
html_rows = []
for label, key, fmt, higher_better in metrics:
our_val = adapted_vals.get(key)
our_str = _fmt(our_val, fmt)
cells = ""
for bl in baselines:
bl_val = bl.get(key)
bl_str = _fmt(bl_val, fmt)
# Determine if our model wins on this metric
is_winner = False
if our_val is not None and bl_val is not None:
if higher_better:
is_winner = our_val > bl_val
else:
is_winner = our_val < bl_val
# Style: dim if our model is better
color = "#888" if is_winner else "#e0e0e0"
cells += (
f'<td style="padding:6px 10px;text-align:center;'
f'color:{color};font-size:0.9em;">{bl_str}</td>'
)
# Our column — bold green
our_cell = (
f'<td style="padding:6px 10px;text-align:center;'
f'color:#198754;font-weight:600;font-size:0.9em;">{our_str}</td>'
)
html_rows.append(
f'<tr>'
f'<td style="padding:6px 10px;color:#ccc;font-size:0.85em;">'
f'{label}</td>'
f'{cells}{our_cell}'
f'</tr>'
)
table_html = f"""
<div style="margin:0.8rem 0 0.5rem 0;">
<p style="color:#aaa;font-size:0.82em;margin-bottom:6px;">
Comparison across all evaluated models on {n_samples} held-out samples
</p>
<table style="width:100%;border-collapse:collapse;border:1px solid #333;
border-radius:6px;overflow:hidden;">
<thead>
<tr style="border-bottom:1px solid #333;">
<th style="padding:6px 10px;text-align:left;color:#888;
font-size:0.82em;font-weight:400;">Metric</th>
{header_cols}
<th style="padding:6px 10px;text-align:center;font-weight:600;
color:#198754;font-size:0.82em;
border-left:2px solid #198754;">{adapted_label}</th>
</tr>
</thead>
<tbody>
{"".join(html_rows)}
</tbody>
</table>
</div>
"""
# Fix: add green left-border to our column cells
table_html = table_html.replace(
f'color:#198754;font-weight:600;font-size:0.9em;">',
f'color:#198754;font-weight:600;font-size:0.9em;'
f'border-left:2px solid #198754;">'
)
st.markdown(table_html, unsafe_allow_html=True)
# ============================================================
# Main Entry Point
# ============================================================
def render_eval_tab():
    """Main entry point — called from app.py.

    Loads every benchmark file, lets the user pick a run (and, for adapted
    runs, an alternative baseline source), then renders the info bar, model
    info, baseline comparison table, metric cards, prediction accuracy,
    per-category chart and the sample browser.
    """
    st.title("Model Evaluation")
    st.caption("Quantitative benchmarks on held-out test data · Live re-inference")

    # Load bench files
    bench_files = _list_bench_files()
    if not bench_files:
        st.warning(
            "No benchmark results found. Run:\n\n"
            "```\npython src/models/lab.py bench "
            "--dataset rehab_public_v1 --experiment <id>\n```"
        )
        return

    # Load all bench data (including routed runs)
    bench_map = {}
    for f in bench_files:
        data = _load_bench(str(f))
        if data:
            bench_map[f.stem] = (f, data)
    if not bench_map:
        st.error("Could not load any benchmark files.")
        return

    # Bench run selector — only runs with adapted results on the latest
    # full test set (533 samples). Base-only runs are accessible via
    # the "Baseline source" selector instead.
    CURRENT_TEST_SIZE = 533
    all_keys = list(bench_map.keys())
    bench_keys = [
        k for k in all_keys
        if "adapted" in bench_map[k][1].get("aggregate", {})
        and bench_map[k][1].get("test_examples") == CURRENT_TEST_SIZE
    ]
    if not bench_keys:
        # Fallback: show everything if no matching runs exist
        bench_keys = all_keys

    bench_labels = [_format_bench_label(k, bench_map[k][1]) for k in bench_keys]
    sel_idx = st.selectbox(
        "Benchmark run",
        range(len(bench_keys)),
        index=0,
        format_func=lambda i: bench_labels[i],
        key="eval_bench_selector",
    )
    selected = bench_keys[sel_idx]
    _, bench_data = bench_map[selected]

    agg = bench_data.get("aggregate", {})
    specialized_agg = _recompute_specialized_aggregate(bench_data)
    if specialized_agg is not None:
        agg = specialized_agg
    examples = bench_data.get("per_example", [])
    has_adapted = "adapted" in agg

    # Info bar
    col1, col2, col3 = st.columns(3)
    with col1:
        st.caption(f"**Model:** `{bench_data.get('model', '?')}`")
    with col2:
        adapter = _resolve_adapter(bench_data)
        adapter_label = Path(adapter).name if adapter else "none"
        if _is_routed(bench_data):
            n_routes = len(bench_data.get("routing", {}))
            adapter_label = f"routed ({n_routes} specialized + general)"
        st.caption(f"**Adapter:** `{adapter_label}`")
    with col3:
        st.caption(f"**Samples:** `{len(examples)}`")

    # --- About This Model (collapsible) ---
    _render_model_info(bench_data, agg, has_adapted, len(examples))

    # --- Baseline Comparison Table ---
    if has_adapted:
        _render_baseline_comparison(bench_data, bench_map, all_keys,
                                    agg_override=agg)

    # --- Baseline source selector (swaps base metrics for sections below) ---
    if has_adapted:
        n_samples = len(examples)
        own_model = bench_data.get("model", "")
        own_key = selected
        # Candidates: other runs of the same size that are either base-only
        # or adapted runs of a different model.
        base_only_keys = [
            k for k in all_keys
            if k != own_key
            and bench_map[k][1].get("test_examples") == n_samples
            and (
                "adapted" not in bench_map[k][1].get("aggregate", {})
                or bench_map[k][1].get("model", "") != own_model
            )
        ]
        if base_only_keys:
            builtin_label = f"Built-in ({bench_data.get('model', 'Qwen')})"
            ext_labels = []
            for k in base_only_keys:
                d = bench_map[k][1]
                model = d.get("model", "?")
                # Tolerate a missing or null timestamp.
                ts = (d.get("timestamp") or "")[:16].replace("T", " ")
                has_ext_adapted = "adapted" in d.get("aggregate", {})
                tag = " [adapted]" if has_ext_adapted else ""
                ext_labels.append(f"{model}{tag} ({ts})")
            options = [builtin_label] + ext_labels
            # ext_labels is non-empty here, so the first external run
            # (index 1) is the default.
            default_idx = 1
            bl_sel = st.selectbox(
                "Baseline source",
                range(len(options)),
                index=default_idx,
                format_func=lambda i: options[i],
                key="eval_baseline_selector",
            )
            if bl_sel > 0:
                ext_key = base_only_keys[bl_sel - 1]
                _, ext_data = bench_map[ext_key]
                ext_agg = ext_data.get("aggregate", {})
                ext_examples = ext_data.get("per_example", [])
                ext_is_adapted = "adapted" in ext_agg

                # Copy before mutating so the loaded bench data is untouched.
                agg = dict(agg)
                if ext_is_adapted:
                    agg["base"] = ext_agg["adapted"]
                else:
                    agg["base"] = ext_agg.get("base", agg.get("base", {}))

                if len(ext_examples) == len(examples):
                    # Take per-example data from the same side of the
                    # external run as the aggregate above, so the sample
                    # browser agrees with the metric cards.
                    metrics_src = "metrics_adapted" if ext_is_adapted else "metrics_base"
                    response_src = "adapted_response" if ext_is_adapted else "base_response"
                    examples = [dict(ex) for ex in examples]
                    for own_ex, ext_ex in zip(examples, ext_examples):
                        own_ex["metrics_base"] = ext_ex.get(
                            metrics_src,
                            ext_ex.get("metrics_base", own_ex.get("metrics_base", {})))
                        own_ex["base_response"] = ext_ex.get(
                            response_src,
                            ext_ex.get("base_response", own_ex.get("base_response", "")))

    st.divider()

    # --- Metric Cards ---
    _render_metric_cards(agg, has_adapted)

    # --- Predictive Accuracy (hero section for specialized adapters) ---
    _render_prediction_accuracy(agg, has_adapted, examples)

    # --- Text Quality by Category ---
    _render_category_chart(agg, has_adapted)

    st.divider()

    # --- Sample Browser ---
    _render_sample_browser(examples, bench_data, has_adapted)