Spaces:
Running
Running
| """ | |
| GURMA.ai — Model Evaluation Tab | |
| Displays benchmark results, test sample comparisons, and live re-inference. | |
| Dual backend: MLX (local Apple Silicon) / HF Inference API (HF Spaces). | |
| """ | |
| import copy | |
| import json | |
| import os | |
| import re | |
| from collections import Counter | |
| from pathlib import Path | |
| import pandas as pd | |
| import plotly.graph_objects as go | |
| import streamlit as st | |
# ============================================================
# Environment & Paths
# ============================================================
# Detect Hugging Face Spaces: either the HF_SPACE env var is set, or the
# app is deployed under /app (the Spaces container layout).
# NOTE(review): os.getenv returns a string, so IS_HF_SPACE may be str or
# bool — it is only ever used in truth-value contexts below.
IS_HF_SPACE = os.getenv("HF_SPACE") or Path("/app/research.py").exists()
if IS_HF_SPACE:
    DATA_ROOT = Path("/app/data")
else:
    # Local dev: <repo root>/data, three directory levels above this file.
    DATA_ROOT = Path(__file__).resolve().parent.parent.parent / "data"
EXPERIMENTS_DIR = DATA_ROOT / "experiments"  # bench-*.json benchmark results
TRAINING_DIR = DATA_ROOT / "training"  # training datasets, metadata, manifests
ADAPTERS_DIR = DATA_ROOT / "adapters"  # LoRA adapter weights
# MLX model → HF Hub model for Inference API
MODEL_HF_MAP = {
    "Qwen/Qwen3-8B-MLX-4bit": "Qwen/Qwen3-8B",
    "Qwen/Qwen3-8B-MLX-8bit": "Qwen/Qwen3-8B",
    "Qwen/Qwen3-30B-A3B-MLX-4bit": "Qwen/Qwen3-30B-A3B",
}
# System prompt shared by both live re-inference backends (MLX and HF API).
SYSTEM_PROMPT = (
    "You are a rehabilitation robotics AI assistant for GURMA.ai. "
    "You help clinicians and engineers with therapy parameters, "
    "patient outcome interpretation, safety protocols, and session reporting. "
    "Be precise, evidence-aware, and flag uncertainty explicitly."
)
# ============================================================
# Colour palette (GURMA brand)
# ============================================================
C_BASE = "#6c757d"  # grey — base (un-adapted) model series
C_ADAPTED = "#198754"  # green — fine-tuned / adapted model series
C_BG = "rgba(0,0,0,0)"  # transparent plot background
C_GRID = "rgba(200,200,200,0.3)"  # subtle gridlines
| # ============================================================ | |
| # Helpers | |
| # ============================================================ | |
| def _resolve_adapter(bench_data: dict) -> str: | |
| """Get the effective adapter path from bench data. | |
| Standard bench stores 'adapter_path'. Routed bench stores | |
| 'fallback_adapter' + per-category 'routing'. Return the | |
| fallback (general) adapter for routed, or adapter_path for standard. | |
| """ | |
| ap = bench_data.get("adapter_path") | |
| if ap: | |
| return ap | |
| return bench_data.get("fallback_adapter") or "" | |
| def _is_routed(bench_data: dict) -> bool: | |
| """Whether this bench used per-category adapter routing.""" | |
| return bool(bench_data.get("routing")) | |
| # ============================================================ | |
| # Data Loading | |
| # ============================================================ | |
| def _load_bench(path: str) -> dict | None: | |
| try: | |
| with open(path) as f: | |
| return json.load(f) | |
| except Exception: | |
| return None | |
def _list_bench_files() -> list[Path]:
    """Return bench-*.json files from EXPERIMENTS_DIR, newest first.

    Ordering uses the 'timestamp' field embedded in each file rather than
    file mtime, which is unreliable on HF Spaces (all files share the
    deploy-time mtime). Unreadable files sort last (empty timestamp).
    """
    if not EXPERIMENTS_DIR.exists():
        return []

    def _embedded_ts(path: Path) -> str:
        try:
            return json.loads(path.read_text()).get("timestamp", "")
        except Exception:
            return ""

    return sorted(EXPERIMENTS_DIR.glob("bench-*.json"),
                  key=_embedded_ts, reverse=True)
def _format_bench_label(stem: str, data: dict) -> str:
    """Build a human-readable selectbox label for a bench file.

    Combines timestamp, short model name, adapter description, and sample
    count, e.g. "2025-01-02 12:00 — Qwen3-8B — routed (4 specialized)
    (25 samples)". `stem` is accepted for signature compatibility but
    unused. Fix: empty segments are now dropped instead of producing a
    doubled "— —" separator when the model id is missing.
    """
    ts = data.get("timestamp", "")[:16].replace("T", " ")
    n = data.get("test_examples", "?")
    model_id = data.get("model", "")
    model_short = model_id.split("/")[-1] if model_id else ""
    if _is_routed(data):
        n_routes = len(data.get("routing", {}))
        adapter = f"routed ({n_routes} specialized)"
    else:
        ap = _resolve_adapter(data)
        adapter = Path(ap).name if ap else "base only"
    # Join only the non-empty segments with an em-dash separator.
    segments = [s for s in (ts, model_short, adapter) if s]
    return " — ".join(segments) + f" ({n} samples)"
| # ============================================================ | |
| # Aggregate Recomputation for Routed Benchmarks | |
| # ============================================================ | |
def _recompute_specialized_aggregate(bench_data: dict) -> dict | None:
    """For routed benchmarks with a fallback adapter, recompute the adapted
    aggregate using only the specialized categories so the headline metrics
    reflect the dedicated adapters rather than being diluted by the general
    fallback. Returns a patched copy of the aggregate dict, or None if
    no recomputation is needed."""
    routing = bench_data.get("routing", {})
    # Only applies when categories were routed AND a general fallback ran.
    if not routing or not bench_data.get("fallback_adapter"):
        return None
    examples = bench_data.get("per_example", [])
    specialized_cats = set(routing.keys())
    # Keep only examples whose category was served by a dedicated adapter.
    specialized = [ex for ex in examples if ex.get("category") in specialized_cats]
    if not specialized:
        return None
    orig_agg = bench_data.get("aggregate", {})
    if "adapted" not in orig_agg:
        return None

    def _mean(lst):
        # None (not 0) when a metric was never observed, so the UI can show
        # a placeholder instead of a misleading zero.
        return round(sum(lst) / len(lst), 4) if lst else None

    # Text-quality / domain metric accumulators (one list per metric).
    scores = {"rouge1_f1": [], "rouge2_f1": [], "rougeL_f1": [],
              "bleu": [], "response_len": [],
              "clinical_term_recall": [], "numeric_recall": [],
              "structured_pct": [], "safety_pct": []}
    # Structured-prediction metric accumulators.
    pred_scores = {"fac_exact_match": [], "fac_error": [],
                   "fac_direction_match": [],
                   "speed_abs_error": [], "speed_direction_match": [],
                   "risk_count_match": []}
    for r in specialized:
        m = r.get("metrics_adapted")
        if not m:
            continue
        # Core overlap metrics — assumed always present in the bench schema
        # (direct indexing would raise on older/partial files — TODO confirm).
        scores["rouge1_f1"].append(m["rouge1"]["f1"])
        scores["rouge2_f1"].append(m["rouge2"]["f1"])
        scores["rougeL_f1"].append(m["rougeL"]["f1"])
        scores["bleu"].append(m["bleu"]["score"])
        scores["response_len"].append(m["response_len"])
        d = m.get("domain", {})
        # Optional domain metrics: skip missing values rather than count 0.
        if d.get("clinical_term_recall") is not None:
            scores["clinical_term_recall"].append(d["clinical_term_recall"])
        if d.get("numeric_recall") is not None:
            scores["numeric_recall"].append(d["numeric_recall"])
        # Booleans become 1.0/0.0 so the mean reads as a rate.
        scores["structured_pct"].append(1.0 if d.get("structured") else 0.0)
        scores["safety_pct"].append(1.0 if d.get("safety_awareness") else 0.0)
        pred = d.get("prediction")
        if pred:
            # Each prediction metric is optional; only fold in present values.
            if pred.get("fac_exact_match") is not None:
                pred_scores["fac_exact_match"].append(1.0 if pred["fac_exact_match"] else 0.0)
            if pred.get("fac_error") is not None:
                pred_scores["fac_error"].append(pred["fac_error"])
            if pred.get("fac_direction_match") is not None:
                pred_scores["fac_direction_match"].append(1.0 if pred["fac_direction_match"] else 0.0)
            if pred.get("speed_abs_error") is not None:
                pred_scores["speed_abs_error"].append(pred["speed_abs_error"])
            if pred.get("speed_direction_match") is not None:
                pred_scores["speed_direction_match"].append(1.0 if pred["speed_direction_match"] else 0.0)
            if pred.get("risk_count_match") is not None:
                pred_scores["risk_count_match"].append(1.0 if pred["risk_count_match"] else 0.0)
    # Rebuild the adapted aggregate using the aggregate-level key names.
    new_adapted = {
        "rouge1_f1": _mean(scores["rouge1_f1"]),
        "rouge2_f1": _mean(scores["rouge2_f1"]),
        "rougeL_f1": _mean(scores["rougeL_f1"]),
        "bleu": _mean(scores["bleu"]),
        "avg_response_len": _mean(scores["response_len"]),
        "clinical_term_recall": _mean(scores["clinical_term_recall"]),
        "numeric_recall": _mean(scores["numeric_recall"]),
        "structured_pct": _mean(scores["structured_pct"]),
        "safety_awareness_pct": _mean(scores["safety_pct"]),
        "prediction": {
            "fac_exact_match": _mean(pred_scores["fac_exact_match"]),
            "fac_mean_error": _mean(pred_scores["fac_error"]),
            "fac_direction_accuracy": _mean(pred_scores["fac_direction_match"]),
            "speed_mean_abs_error": _mean(pred_scores["speed_abs_error"]),
            "speed_direction_accuracy": _mean(pred_scores["speed_direction_match"]),
            "risk_count_accuracy": _mean(pred_scores["risk_count_match"]),
        },
        # Per-category breakdown is already split by category, so the
        # original values are kept unchanged.
        "by_category": orig_agg["adapted"].get("by_category", {}),
    }
    # Patch a deep copy so the caller's bench_data stays untouched.
    patched = copy.deepcopy(orig_agg)
    patched["adapted"] = new_adapted
    return patched
| # ============================================================ | |
| # Inference Backends | |
| # ============================================================ | |
def _get_inference_backend():
    """Pick the best available inference backend.

    Returns a (backend_name, run_fn) pair where
    run_fn(model_id, adapter_path, prompt, max_tokens) -> str.
    Preference order: local MLX (Apple Silicon), then the HF Inference
    API when a token is configured, otherwise ("none", None).
    """
    try:
        import mlx.core  # noqa: F401
    except ImportError:
        # No MLX runtime — fall back to the serverless HF Inference API.
        token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
        if token:
            return "hf_api", _infer_hf_api
        return "none", None
    return "mlx", _infer_mlx
def _infer_mlx(model_id: str, adapter_path: str | None,
               prompt: str, max_tokens: int = 512) -> str:
    """Generate a response locally via MLX.

    The loaded (model, tokenizer) pair is memoized in st.session_state,
    keyed by model id + adapter path, so repeated calls skip the
    expensive model load.
    """
    from mlx_lm import load, generate
    from mlx_lm.sample_utils import make_sampler

    cache_key = f"mlx_{model_id}_{adapter_path}"
    if cache_key not in st.session_state:
        st.session_state[cache_key] = load(model_id, adapter_path=adapter_path)
    model, tokenizer = st.session_state[cache_key]

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    if hasattr(tokenizer, "apply_chat_template"):
        formatted = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True,
            enable_thinking=False,
        )
    else:
        # Generic chat markup for tokenizers without a chat template.
        formatted = f"<|system|>\n{SYSTEM_PROMPT}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
    sampler = make_sampler(temp=0.1)
    return generate(model, tokenizer, prompt=formatted,
                    max_tokens=max_tokens, sampler=sampler, verbose=False)
def _infer_hf_api(model_id: str, adapter_path: str | None,
                  prompt: str, max_tokens: int = 512) -> str:
    """Generate a response via the serverless HF Inference API.

    `adapter_path` is ignored (the API serves base models only); the MLX
    model id is translated to its HF Hub equivalent through MODEL_HF_MAP.
    """
    from huggingface_hub import InferenceClient

    token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
    client = InferenceClient(token=token)
    completion = client.chat_completion(
        model=MODEL_HF_MAP.get(model_id, model_id),
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ],
        max_tokens=max_tokens,
        temperature=0.1,
    )
    return completion.choices[0].message.content
| # ============================================================ | |
| # Inline Metric Computation (for live re-inference) | |
| # ============================================================ | |
| def _tokenize(text: str) -> list[str]: | |
| return re.findall(r"\w+", text.lower()) | |
| def _ngrams(tokens: list[str], n: int) -> Counter: | |
| return Counter(tuple(tokens[i:i+n]) for i in range(len(tokens) - n + 1)) | |
| def _rouge_n_f1(ref_tokens: list[str], hyp_tokens: list[str], n: int) -> float: | |
| if not ref_tokens or not hyp_tokens: | |
| return 0.0 | |
| ref_ng = _ngrams(ref_tokens, n) | |
| hyp_ng = _ngrams(hyp_tokens, n) | |
| overlap = sum((ref_ng & hyp_ng).values()) | |
| p = overlap / sum(hyp_ng.values()) if hyp_ng else 0.0 | |
| r = overlap / sum(ref_ng.values()) if ref_ng else 0.0 | |
| return 2 * p * r / (p + r) if (p + r) > 0 else 0.0 | |
def _compute_live_metrics(reference: str, response: str) -> dict:
    """Quick text-overlap metrics for a live re-inference result.

    Returns ROUGE-1/ROUGE-2 F1 (rounded to 4 dp) plus the response token
    count, mirroring the offline benchmark's headline metrics.
    """
    ref_tokens = _tokenize(reference)
    resp_tokens = _tokenize(response)
    metrics = {}
    for label, order in (("rouge1_f1", 1), ("rouge2_f1", 2)):
        metrics[label] = round(_rouge_n_f1(ref_tokens, resp_tokens, order), 4)
    metrics["len"] = len(resp_tokens)
    return metrics
| # ============================================================ | |
| # Visualisations | |
| # ============================================================ | |
def _render_model_info(bench_data: dict, agg: dict, has_adapted: bool,
                       n_test: int):
    """Collapsible explainer section for non-ML stakeholders.

    Describes the base model, the LoRA adapter architecture, the training
    data (when metadata files exist under TRAINING_DIR), and how to read
    the benchmark scores. NOTE(review): `agg`, `has_adapted` and `n_test`
    are currently unused in this body — presumably kept for signature
    parity with the other render functions; confirm before removing.
    """
    model_id = bench_data.get("model", "Unknown")
    adapter_path = _resolve_adapter(bench_data)
    adapter_name = Path(adapter_path).name if adapter_path else None
    # Try to load training metadata (best-effort: absence simply hides the
    # "Training data" section below).
    meta = None
    for name in ["rehab_public_v1"]:
        meta_path = TRAINING_DIR / name / "metadata.json"
        if meta_path.exists():
            try:
                with open(meta_path) as f:
                    meta = json.load(f)
            except Exception:
                pass
            break
    # Check for specialized adapters
    spec_manifest_path = TRAINING_DIR / "specialized" / "manifest.json"
    has_specialized = spec_manifest_path.exists()
    n_adapters = 0
    if has_specialized:
        try:
            with open(spec_manifest_path) as f:
                spec_manifest = json.load(f)
                n_adapters = len(spec_manifest.get("subtasks", {}))
        except Exception:
            # Unreadable manifest: treat as "no specialized adapters".
            has_specialized = False
    with st.expander("About this model and evaluation", expanded=False):
        st.markdown(f"""
**What is this?**
This page evaluates how well our AI model answers rehabilitation-specific
clinical questions — comparing a general-purpose base model against our
fine-tuned version.
""")
        # --- Base Model ---
        model_short = model_id.split("/")[-1]
        # Heuristic: "A3B" in Qwen model names marks the MoE variant.
        is_moe = "A3B" in model_short or "MoE" in model_short
        arch_desc = ("a Mixture-of-Experts model (30B total, 3B active per token)"
                     if is_moe else "a dense transformer model")
        st.markdown(f"""
**Base model** — `{model_short}`
An open-source large language model — {arch_desc} — that has broad
medical knowledge but nothing specific about our domain: rehabilitation
robotics, gait analysis, or patient outcome prediction.
""")
        # --- Adapter Architecture ---
        if adapter_name or has_specialized:
            st.markdown("---")
            st.markdown("**Fine-tuning approach — LoRA adapters**")
            st.markdown("""
We teach the base model our domain through **LoRA** (Low-Rank Adaptation)
— a technique that adjusts a small fraction of the model's weights
(~0.1%) rather than retraining the whole thing. Think of it as adding a
specialized lens on top of general medical knowledge.
Training setup: batch size 4, 16 adapted layers, prompt-masked loss
(the model only learns from the answer, not the question).
""")
            if has_specialized:
                st.markdown(f"""
**Multi-adapter architecture** — We train **{n_adapters} specialized
adapters**, each focused on one prediction sub-task. During evaluation,
each test question is automatically routed to the right adapter:
""")
                # Display names and one-line focus per specialized adapter.
                adapter_info = {
                    "prediction_trajectory": ("Trajectory", "Forecasts overall recovery path from a single initial visit"),
                    "prediction_fac": ("FAC", "Predicts the Functional Ambulation Category score change"),
                    "prediction_speed": ("Speed", "Predicts gait speed trajectory over coming weeks"),
                    "prediction_risk": ("Risk", "Identifies recovery risk factors from baseline data"),
                }
                adapter_rows = []
                for task, info in adapter_info.items():
                    adapter_rows.append({
                        "Adapter": info[0],
                        "Focus": info[1],
                    })
                st.dataframe(
                    pd.DataFrame(adapter_rows),
                    hide_index=True,
                    width="stretch",
                    # ~35px per row plus header keeps the table scroll-free.
                    height=35 * (len(adapter_rows) + 1),
                )
                st.markdown("""
All other tasks (clinical interpretation, session reporting, progress
analysis) use a single **general adapter** trained on the full dataset.
This split lets each prediction adapter focus deeply on its task without
being diluted by unrelated training data.
""")
            elif adapter_name:
                st.markdown(f"**Current adapter:** `{adapter_name}`")
        # --- Training Data ---
        if meta:
            cats = meta.get("categories", {})
            st.markdown("---")
            st.markdown(f"""
**Training data** — `{meta.get("total_pairs", "?")}` question-answer pairs
Built from a public stroke rehabilitation dataset
([Zenodo 10534055](https://zenodo.org/records/10534055)) — 10 patients
with longitudinal gait measurements across two therapy visits. We
converted the raw sensor data into structured clinical Q&A pairs across
{len(cats)} task types:
""")
            # Friendly label + description per training category key.
            task_labels = {
                "progress_prediction": ("Progress Analysis", "Retrospective comparison of two visits"),
                "prediction_trajectory": ("Trajectory Prediction", "Forecast recovery from initial assessment only"),
                "prediction_fac": ("FAC Forecasting", "Predict functional ambulation category change"),
                "prediction_speed": ("Speed Prediction", "Predict gait speed trajectory"),
                "prediction_risk": ("Risk Assessment", "Identify recovery risks from baseline data"),
                "automated_reporting": ("Clinical Reporting", "Generate therapy session reports"),
                "clinical_interpretation": ("Parameter Interpretation", "Explain what gait measurements mean"),
            }
            rows = []
            # Largest categories first.
            for cat_key, count in sorted(cats.items(), key=lambda x: -x[1]):
                label, desc = task_labels.get(cat_key, (cat_key, ""))
                rows.append({"Task": label, "Pairs": count, "Description": desc})
            st.dataframe(
                pd.DataFrame(rows),
                hide_index=True,
                width="stretch",
                height=min(35 * (len(rows) + 1), 300),
            )
        # --- How to Read Scores ---
        st.markdown("---")
        st.markdown("""
**How to read the scores**
*Text overlap metrics:*
- **ROUGE / BLEU** — measure how closely the model's answer matches our
reference answer (1.0 = perfect match). Higher is better. These tell us
if the model produces the right *format and vocabulary*.
*Domain quality:*
- **Clinical Term Recall** — does the model mention the right medical
terms? (e.g., FAC, MCID, gait speed)
- **Numeric Recall** — does it use the correct numbers from the patient
data?
- **Safety Awareness** — does it flag risks, recommend monitoring, or
note limitations?
*Predictive accuracy (prediction tasks only):*
- **FAC Exact Match** — did the model predict the exact correct FAC score?
- **Speed / FAC Direction** — did it get the direction right (improving,
stable, declining)?
- **Speed Error** — how far off is the predicted gait speed from the
actual outcome, in m/s?
The **base model** gives generic textbook answers. The **adapted model**
produces structured, data-grounded responses in our specific clinical
format — the kind of output we need for automated reporting and decision
support in BAMA's rehabilitation workflow.
""")
def _render_metric_cards(agg: dict, has_adapted: bool):
    """Top-level KPI metric cards mixing text-quality and prediction metrics.

    For each metric, shows the adapted value with a delta against the base
    model when both exist, otherwise the base value alone. The delta colour
    flips to 'inverse' when the change moves the wrong way for that metric
    (e.g. an error that increased). Fixes: the old `_display` helper had
    two byte-identical branches (dead conditional), and both helpers were
    redefined on every loop iteration — they are now hoisted and take the
    format string as an explicit argument.
    """
    base = agg.get("base", {})
    adapted = agg.get("adapted", {})
    base_pred = base.get("prediction", {})
    adapted_pred = adapted.get("prediction", {}) if has_adapted else {}

    def _display(val, fmt: str) -> str:
        # Em-dash placeholder when the metric was not computed for this run.
        return "—" if val is None else f"{val:{fmt}}"

    def _delta_str(b_val, a_val, fmt: str) -> str | None:
        if b_val is None or a_val is None:
            return None
        diff = a_val - b_val
        # Percent-formatted metrics report their delta in percentage points.
        if fmt.endswith("%"):
            return f"{diff * 100:+.0f}pp"
        return f"{diff:+{fmt}}"

    # (label, key, source, fmt, higher_better, help)
    # source: "top" = agg[section][key], "pred" = agg[section]["prediction"][key]
    metrics = [
        ("ROUGE-1", "rouge1_f1", "top", ".4f", True,
         "Token overlap F1 at unigram level (higher is better)."),
        ("Term Recall", "clinical_term_recall", "top", ".0%", True,
         "Share of domain clinical terms recovered in output."),
        ("FAC Accuracy", "fac_exact_match", "pred", ".0%", True,
         "Exact FAC score match rate against reference (higher is better)."),
        ("FAC Direction", "fac_direction_accuracy", "pred", ".0%", True,
         "FAC trend direction match — improve/stable/decline (higher is better)."),
        ("Speed Error", "speed_mean_abs_error", "pred", ".3f", False,
         "Mean absolute gait speed error in m/s (lower is better)."),
        ("Risk Accuracy", "risk_count_accuracy", "pred", ".0%", True,
         "Exact match rate for extracted risk-factor count (higher is better)."),
    ]
    cols = st.columns(len(metrics))
    for col, (label, key, source, fmt, higher_better, help_text) in zip(cols, metrics):
        b_val = (base_pred if source == "pred" else base).get(key)
        a_val = ((adapted_pred if source == "pred" else adapted).get(key)
                 if has_adapted else None)
        with col:
            if has_adapted and a_val is not None:
                delta = _delta_str(b_val, a_val, fmt)
                delta_color = "normal"
                # A non-None delta implies b_val is not None.
                if delta is not None:
                    diff = a_val - b_val
                    if (higher_better and diff < 0) or (not higher_better and diff > 0):
                        delta_color = "inverse"
                st.metric(label, _display(a_val, fmt), delta=delta,
                          delta_color=delta_color, help=help_text)
            else:
                st.metric(label, _display(b_val, fmt), help=help_text)
def _render_category_chart(agg: dict, has_adapted: bool):
    """Per-category ROUGE-1 chart for text-quality categories only.

    Prediction categories are excluded — their quality is shown via
    task-specific metrics (FAC accuracy, speed error, etc.) instead.
    Renders grouped base/adapted bars; categories missing an adapted
    score pass None through, which Plotly simply skips. (Removed the old
    no-op identity comprehension over a_vals.)
    """
    b_cats = agg.get("base", {}).get("by_category", {})
    a_cats = agg.get("adapted", {}).get("by_category", {}) if has_adapted else {}
    all_cats = sorted(set(b_cats) | set(a_cats))
    categories = [c for c in all_cats if not c.startswith("prediction_")]
    if not categories:
        return
    fig = go.Figure()
    fig.add_trace(go.Bar(
        name="Base",
        x=categories,
        y=[b_cats.get(c, {}).get("rouge1_f1", 0) for c in categories],
        marker_color=C_BASE,
        text=[f"{b_cats.get(c, {}).get('rouge1_f1', 0):.2f}" for c in categories],
        textposition="outside",
    ))
    if has_adapted:
        a_vals = [a_cats.get(c, {}).get("rouge1_f1") for c in categories]
        fig.add_trace(go.Bar(
            name="Adapted",
            x=categories,
            y=a_vals,
            marker_color=C_ADAPTED,
            text=[f"{v:.2f}" if v is not None else "" for v in a_vals],
            textposition="outside",
        ))
    fig.update_layout(
        title="Text Quality by Category (ROUGE-1 F1)",
        barmode="group",
        # Small headroom above 1.0 so "outside" bar labels stay visible.
        yaxis_range=[0, 1.05],
        yaxis_title="F1 Score",
        height=340,
        margin=dict(t=40, b=20, l=40, r=20),
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
        plot_bgcolor=C_BG,
    )
    fig.update_yaxes(gridcolor=C_GRID)
    st.plotly_chart(fig, width="stretch")
def _render_prediction_accuracy(agg: dict, has_adapted: bool,
                                examples: list[dict]):
    """Show predictive accuracy metrics for prediction_* categories.

    Only renders if prediction metrics exist in the aggregate data.
    Falls back to recomputing the aggregate from per-example metrics for
    older bench files, then renders six metric cards (2 rows of 3) and a
    per-sample extraction table inside an expander.
    """
    # Check if there are any prediction examples
    pred_examples = [ex for ex in examples
                     if ex.get("category", "").startswith("prediction_")]
    if not pred_examples:
        return
    bp = agg.get("base", {}).get("prediction", {})
    ap = agg.get("adapted", {}).get("prediction", {}) if has_adapted else {}
    # Only render if we have at least some non-None values (in base OR adapted)
    has_any_bp = any(v is not None for v in bp.values()) if bp else False
    has_any_ap = any(v is not None for v in ap.values()) if ap else False
    if not has_any_bp and not has_any_ap:
        # Recompute from per-example data if aggregate is missing
        # (e.g. older bench files without prediction aggregate)
        bp, ap = _recompute_prediction_agg(examples, has_adapted)
        has_any_bp = any(v is not None for v in bp.values()) if bp else False
        has_any_ap = any(v is not None for v in ap.values()) if ap else False
        if not has_any_bp and not has_any_ap:
            return
    st.divider()
    st.subheader("Predictive Accuracy")
    st.caption(
        "Structured value extraction from model output — "
        "compares predicted FAC, speed, and risk against actual outcomes"
    )
    # Build metric rows: (label, aggregate key, unit, higher_better, help)
    metrics_def = [
        ("FAC Exact Match", "fac_exact_match", "%", True,
         "Exact FAC score match rate against reference."),
        ("FAC Direction", "fac_direction_accuracy", "%", True,
         "Whether FAC trend direction matches (improve/stable/decline)."),
        ("FAC Mean Error", "fac_mean_error", "levels", False,
         "Average absolute FAC level difference (lower is better)."),
        ("Speed Direction", "speed_direction_accuracy", "%", True,
         "Whether speed trend direction matches reference."),
        ("Speed Mean Error", "speed_mean_abs_error", "m/s", False,
         "Mean absolute speed error in m/s (lower is better)."),
        ("Risk Count Match", "risk_count_accuracy", "%", True,
         "Exact match rate for extracted risk-factor count."),
    ]
    # Render as metric columns (2 rows of 3)
    for row_start in range(0, len(metrics_def), 3):
        row_items = metrics_def[row_start:row_start + 3]
        cols = st.columns(len(row_items))
        for col, (label, key, unit, higher_better, help_text) in zip(cols, row_items):
            b_val = bp.get(key)
            a_val = ap.get(key) if ap else None
            with col:
                if has_adapted and a_val is not None:
                    # Adapted value is the headline; delta vs. base below it.
                    if unit == "%":
                        display = f"{a_val * 100:.0f}%"
                    else:
                        display = f"{a_val:.3f} {unit}"
                    if b_val is not None:
                        delta = a_val - b_val
                        if unit == "%":
                            # Percent metrics: delta in percentage points.
                            delta_str = f"{delta * 100:+.0f}pp"
                        else:
                            delta_str = f"{delta:+.3f}"
                        # Green when the change moves in the "good" direction
                        # for this metric, red otherwise.
                        delta_color = ("normal" if
                                       (higher_better and delta >= 0) or
                                       (not higher_better and delta <= 0)
                                       else "inverse")
                        st.metric(label, display, delta_str,
                                  delta_color=delta_color,
                                  help=help_text)
                    else:
                        st.metric(label, display, help=help_text)
                elif b_val is not None:
                    # Base-only bench: show the base value without a delta.
                    if unit == "%":
                        display = f"{b_val * 100:.0f}%"
                    else:
                        display = f"{b_val:.3f} {unit}"
                    st.metric(label, display, help=help_text)
                else:
                    # Metric not computable for this bench run.
                    st.metric(label, "—", help=help_text)
    # Per-example detail for prediction categories
    with st.expander("Per-sample prediction extraction", expanded=False):
        rows = []
        for i, ex in enumerate(examples):
            cat = ex.get("category", "")
            if not cat.startswith("prediction_"):
                continue
            short_cat = cat.replace("prediction_", "")
            # Prefer adapted metrics when the bench included an adapter run.
            metrics_key = "metrics_adapted" if has_adapted else "metrics_base"
            m = ex.get(metrics_key, {})
            pred = m.get("domain", {}).get("prediction", {})
            if not pred:
                continue
            row = {"#": i + 1, "Task": short_cat}
            # FAC "current→predicted" pairs, reference vs. model output.
            # NOTE(review): assumes resp_fac/resp_speed accompany their
            # ref_* counterparts — TODO confirm against the bench schema.
            if "ref_fac" in pred:
                rf = pred["ref_fac"]
                pf = pred["resp_fac"]
                row["Ref FAC"] = (f"{rf['current']}→{rf['predicted']}"
                                  if rf["current"] is not None and
                                  rf["predicted"] is not None else "—")
                row["Pred FAC"] = (f"{pf['current']}→{pf['predicted']}"
                                   if pf["current"] is not None and
                                   pf["predicted"] is not None else "—")
            if "ref_speed" in pred:
                rs = pred["ref_speed"]
                ps = pred["resp_speed"]
                row["Ref Speed"] = (f"{rs['current']}→{rs['predicted']}"
                                    if rs["current"] is not None and
                                    rs["predicted"] is not None else "—")
                row["Pred Speed"] = (f"{ps['current']}→{ps['predicted']}"
                                     if ps["current"] is not None and
                                     ps["predicted"] is not None else "—")
            if "ref_risk_count" in pred:
                row["Ref Risks"] = pred.get("ref_risk_count", "—")
                row["Pred Risks"] = pred.get("resp_risk_count", "—")
            # Overall match indicators
            checks = []
            if pred.get("fac_exact_match") is True:
                checks.append("FAC-exact")
            elif pred.get("fac_direction_match") is True:
                checks.append("FAC-dir")
            if pred.get("speed_direction_match") is True:
                checks.append("Speed-dir")
            if pred.get("risk_count_match") is True:
                checks.append("Risk-count")
            row["Matches"] = ", ".join(checks) if checks else "—"
            rows.append(row)
        if rows:
            st.dataframe(pd.DataFrame(rows), hide_index=True,
                         width="stretch")
        else:
            st.info("No extractable prediction values found in model outputs.")
| def _recompute_prediction_agg(examples: list[dict], | |
| has_adapted: bool) -> tuple[dict, dict]: | |
| """Recompute prediction aggregate from per-example metrics. | |
| Used for older bench files that don't have prediction aggregate. | |
| """ | |
| bp = {"fac_exact_match": [], "fac_error": [], "fac_direction_match": [], | |
| "speed_abs_error": [], "speed_direction_match": [], | |
| "risk_count_match": []} | |
| ap = {"fac_exact_match": [], "fac_error": [], "fac_direction_match": [], | |
| "speed_abs_error": [], "speed_direction_match": [], | |
| "risk_count_match": []} | |
| for ex in examples: | |
| for target, mkey in [(bp, "metrics_base"), (ap, "metrics_adapted")]: | |
| if mkey == "metrics_adapted" and not has_adapted: | |
| continue | |
| m = ex.get(mkey, {}) | |
| pred = m.get("domain", {}).get("prediction") | |
| if not pred: | |
| continue | |
| for k in target: | |
| v = pred.get(k) | |
| if v is not None: | |
| if isinstance(v, bool): | |
| target[k].append(1.0 if v else 0.0) | |
| else: | |
| target[k].append(v) | |
| def _mean(lst): | |
| return round(sum(lst) / len(lst), 4) if lst else None | |
| bp_agg = {k: _mean(v) for k, v in bp.items()} | |
| ap_agg = {k: _mean(v) for k, v in ap.items()} if has_adapted else {} | |
| # Remap to aggregate keys | |
| key_map = {"fac_error": "fac_mean_error", | |
| "speed_abs_error": "speed_mean_abs_error", | |
| "fac_direction_match": "fac_direction_accuracy", | |
| "speed_direction_match": "speed_direction_accuracy", | |
| "risk_count_match": "risk_count_accuracy"} | |
| for old, new in key_map.items(): | |
| if old in bp_agg: | |
| bp_agg[new] = bp_agg.pop(old) | |
| if old in ap_agg: | |
| ap_agg[new] = ap_agg.pop(old) | |
| return bp_agg, ap_agg | |
def _render_sample_browser(examples: list[dict], bench_data: dict,
                           has_adapted: bool):
    """Individual test sample viewer with side-by-side comparison.

    Renders a category filter plus sample selector, the selected prompt,
    and two or three response columns (reference / base model / adapted
    model), then delegates to the prediction visualisations and the live
    re-inference controls.

    Args:
        examples: ``per_example`` records from the benchmark JSON.
        bench_data: Full benchmark run dict; its ``id`` keys any live
            re-inference responses stored in ``st.session_state``.
        has_adapted: Whether adapted-model responses/metrics are present
            (adds the third column).
    """
    st.subheader("Test Samples")
    if not examples:
        st.info("No test samples in this benchmark.")
        return
    # Filters: wide sample selector on the left, category filter on the right
    col_filter, col_cat = st.columns([3, 1])
    categories = sorted(set(ex.get("category", "unknown") for ex in examples))
    with col_cat:
        cat_filter = st.selectbox(
            "Category", ["All"] + categories, key="eval_cat_filter")
    filtered = examples if cat_filter == "All" else [
        ex for ex in examples if ex.get("category") == cat_filter]
    with col_filter:
        sample_labels = [
            f"[{i+1}] {ex.get('category', '?')} — {ex['prompt'][:70]}..."
            for i, ex in enumerate(filtered)
        ]
        if not sample_labels:
            st.info("No samples in this category.")
            return
        # Default to sample with most timepoints (richest forecast chart);
        # counts indented "T<n>:" rows, same pattern as _is_timeseries_prompt.
        default_idx = 0
        max_tps = 0
        for i, ex in enumerate(filtered):
            n_tps = len(re.findall(r"^\s+T\d+:", ex.get("prompt", ""), re.MULTILINE))
            if n_tps > max_tps:
                max_tps = n_tps
                default_idx = i
        idx = st.selectbox("Sample", range(len(sample_labels)),
                           index=default_idx,
                           format_func=lambda i: sample_labels[i],
                           key="eval_sample_select")
    ex = filtered[idx]
    # Prompt (read-only)
    st.markdown("**Prompt:**")
    st.text_area("prompt_display", ex["prompt"], height=100,
                 disabled=True, label_visibility="collapsed")
    # Side-by-side responses: reference | base [| adapted]
    n_cols = 3 if has_adapted else 2
    cols = st.columns(n_cols)
    with cols[0]:
        st.markdown(f"**Reference** `{len(ex['reference'].split())} words`")
        # NOTE(review): fixed widget keys ("ref_ta"/"base_ta"/"adpt_ta") are
        # reused across samples; this relies on Streamlit refreshing the
        # value of disabled text areas on rerun — confirm on upgrades.
        st.text_area("ref_display", ex["reference"], height=300,
                     disabled=True, label_visibility="collapsed", key="ref_ta")
    with cols[1]:
        m_base = ex.get("metrics_base", {})
        r1_b = m_base.get("rouge1", {}).get("f1", "?")
        bl_b = m_base.get("bleu", {}).get("score", "?")
        st.markdown(f"**Base Model** `R1:{r1_b}` `B4:{bl_b}`")
        # Check for live re-inference result (written by
        # _render_reinference_controls); fall back to the stored response.
        live_key = f"live_base_{bench_data.get('id', '')}_{idx}"
        display_text = st.session_state.get(live_key, ex.get("base_response", ""))
        st.text_area("base_display", display_text, height=300,
                     disabled=True, label_visibility="collapsed", key="base_ta")
    if has_adapted and n_cols == 3:
        with cols[2]:
            m_adpt = ex.get("metrics_adapted", {})
            r1_a = m_adpt.get("rouge1", {}).get("f1", "?")
            bl_a = m_adpt.get("bleu", {}).get("score", "?")
            st.markdown(f"**Adapted** `R1:{r1_a}` `B4:{bl_a}`")
            # Same live-result override for the adapted column
            live_key_a = f"live_adapted_{bench_data.get('id', '')}_{idx}"
            display_text_a = st.session_state.get(
                live_key_a, ex.get("adapted_response", ""))
            st.text_area("adapted_display", display_text_a, height=300,
                         disabled=True, label_visibility="collapsed", key="adpt_ta")
    # --- Prediction visualizations (prediction_* categories only) ---
    if ex.get("category", "").startswith("prediction_"):
        if _is_timeseries_prompt(ex.get("prompt", "")):
            # Longitudinal: forecast chart supersedes the bullet chart
            _render_ts_forecast_chart(ex, has_adapted)
        else:
            # Single-visit: bullet chart for point predictions
            _render_prediction_bullet(ex, has_adapted)
    # --- Live Re-Inference ---
    _render_reinference_controls(ex, bench_data, idx, has_adapted)
| # ============================================================ | |
| # Time-Series Forecast Helpers | |
| # ============================================================ | |
| def _is_timeseries_prompt(prompt: str) -> bool: | |
| """Check if prompt contains longitudinal multi-timepoint data.""" | |
| tp_matches = re.findall(r"^\s+T\d+:", prompt, re.MULTILINE) | |
| return len(tp_matches) >= 2 | |
| def _parse_ts_history(prompt: str) -> list[dict]: | |
| """Parse timepoint history from a longitudinal prompt. | |
| Returns list of dicts sorted by timepoint: | |
| [{"tp": "T0", "fac": 2, "cadence": 50.4, "speed": 0.356, | |
| "stride_time": 1.191, "regularity": 0.816}, ...] | |
| """ | |
| results = [] | |
| for m in re.finditer( | |
| r"^\s+(T\d+):\s*FAC\s+(\d+)" | |
| r"(?:,\s*cadence\s+([\d.]+))?" | |
| r"(?:.*?stride time\s+([\d.]+)\s*s)?" | |
| r"(?:.*?speed\s*\(est\.\)\s*([\d.]+))?" | |
| r"(?:.*?step regularity\s+([\d.]+))?", | |
| prompt, re.MULTILINE, | |
| ): | |
| tp = { | |
| "tp": m.group(1), | |
| "fac": int(m.group(2)), | |
| } | |
| if m.group(3): | |
| tp["cadence"] = float(m.group(3)) | |
| if m.group(4): | |
| tp["stride_time"] = float(m.group(4)) | |
| if m.group(5): | |
| tp["speed"] = float(m.group(5)) | |
| if m.group(6): | |
| tp["regularity"] = float(m.group(6)) | |
| results.append(tp) | |
| results.sort(key=lambda x: int(x["tp"][1:])) | |
| return results | |
| def _parse_ts_prediction(text: str) -> dict: | |
| """Extract predicted values from reference or model response text. | |
| Strips markdown bold markers before matching. | |
| """ | |
| clean = text.replace("*", "") | |
| pred: dict = {} | |
| m_spd = re.search(r"Predicted Gait Speed[:\s]*([\d.]+)", clean) | |
| if m_spd: | |
| pred["speed"] = float(m_spd.group(1)) | |
| m_cad = re.search(r"Predicted Cadence[:\s]*([\d.]+)", clean) | |
| if m_cad: | |
| pred["cadence"] = float(m_cad.group(1)) | |
| m_fac = re.search(r"Predicted FAC(?: Score)?[:\s]*(\d+)", clean) | |
| if m_fac: | |
| pred["fac"] = int(m_fac.group(1)) | |
| # Also try speed patterns in speed-only responses | |
| if "speed" not in pred: | |
| m_spd2 = re.search(r"Gait Speed[:\s]*([\d.]+)\s*m/s", clean) | |
| if m_spd2: | |
| pred["speed"] = float(m_spd2.group(1)) | |
| return pred | |
# Trace colours for the longitudinal forecast chart
# (used by _render_ts_forecast_chart)
_C_HISTORY = "#6c757d"  # grey — observed history points
_C_ACTUAL = "#0d6efd"   # blue — reference / ground-truth forecast value
_C_MODEL = "#198754"    # green — model-predicted forecast value
def _render_ts_forecast_chart(ex: dict, has_adapted: bool):
    """Render a line chart showing longitudinal trajectory + forecast.

    X-axis: timepoints (T0, T1, ..., T_predicted)
    Traces: observed history, reference prediction, model prediction.

    Args:
        ex: Per-example benchmark record (prompt, reference, responses).
        has_adapted: If True, the adapted model's response supplies the
            "Model" trace; otherwise the base response is used.
    """
    prompt = ex.get("prompt", "")
    history = _parse_ts_history(prompt)
    if len(history) < 2:
        # Not enough observed points to draw a trajectory
        return
    ref_pred = _parse_ts_prediction(ex.get("reference", ""))
    model_text = ex.get("adapted_response" if has_adapted else "base_response", "")
    model_pred = _parse_ts_prediction(model_text) if model_text else {}
    # Determine the forecast timepoint label: default is one step past
    # the last observed timepoint ...
    last_tp_num = int(history[-1]["tp"][1:])
    forecast_tp = f"T{last_tp_num + 1}"
    # ... but prefer an explicit "at (the next assessment (Tn))" mention
    # from the reference or prompt when present.
    for src in [ex.get("reference", ""), ex.get("prompt", "")]:
        m_tp = re.search(r"at (?:the next assessment \()?(T\d+)\)?", src)
        if m_tp:
            forecast_tp = m_tp.group(1)
            break
    # Available parameters: only those that have a forecast value
    # (from reference or model) AND at least one history point
    param_options = []
    param_labels = {
        "speed": "Gait Speed (m/s)",
        "cadence": "Cadence (steps/min)",
        "stride_time": "Stride Time (s)",
        "fac": "FAC Score",
        "regularity": "Step Regularity",
    }
    for key in ["speed", "cadence", "stride_time", "fac", "regularity"]:
        has_history = any(tp.get(key) is not None for tp in history)
        has_forecast = ref_pred.get(key) is not None or model_pred.get(key) is not None
        if has_history and has_forecast:
            param_options.append(key)
    if not param_options:
        return
    st.markdown("---")
    st.markdown("**Longitudinal Forecast**")
    # Use prompt hash for stable key across re-renders.
    # NOTE(review): hash() of str is salted per process — stable within a
    # session (sufficient here), but not across app restarts.
    prompt_hash = hash(ex.get("prompt", "")) & 0xFFFFFFFF
    selected_param = st.radio(
        "Parameter",
        param_options,
        format_func=lambda k: param_labels.get(k, k),
        horizontal=True,
        key=f"ts_param_{prompt_hash}",
    )
    # Build data for the selected parameter
    tp_labels = [tp["tp"] for tp in history]
    hist_values = [tp.get(selected_param) for tp in history]
    fig = go.Figure()
    # Observed history line — skip timepoints missing this parameter
    valid_x = [tp_labels[i] for i in range(len(hist_values)) if hist_values[i] is not None]
    valid_y = [v for v in hist_values if v is not None]
    if valid_x:
        fig.add_trace(go.Scatter(
            x=valid_x, y=valid_y,
            mode="lines+markers",
            line=dict(color=_C_HISTORY, width=2),
            marker=dict(size=8, color=_C_HISTORY),
            name="Observed",
        ))
    # Reference (ground truth) at forecast timepoint
    ref_val = ref_pred.get(selected_param)
    if ref_val is not None:
        # Dashed connector from last observed to reference
        if valid_y:
            fig.add_trace(go.Scatter(
                x=[valid_x[-1], forecast_tp],
                y=[valid_y[-1], ref_val],
                mode="lines",
                line=dict(color=_C_ACTUAL, width=1, dash="dot"),
                showlegend=False,
            ))
        fig.add_trace(go.Scatter(
            x=[forecast_tp], y=[ref_val],
            mode="markers+text",
            marker=dict(size=12, color=_C_ACTUAL, symbol="circle",
                        line=dict(width=1, color="#fff")),
            text=[f"{ref_val}"],
            textposition="top center",
            textfont=dict(size=10, color=_C_ACTUAL),
            name="Actual",
        ))
    # Model prediction at forecast timepoint
    model_val = model_pred.get(selected_param)
    if model_val is not None:
        # Dashed connector from last observed to model prediction
        if valid_y:
            fig.add_trace(go.Scatter(
                x=[valid_x[-1], forecast_tp],
                y=[valid_y[-1], model_val],
                mode="lines",
                line=dict(color=_C_MODEL, width=1, dash="dot"),
                showlegend=False,
            ))
        fig.add_trace(go.Scatter(
            x=[forecast_tp], y=[model_val],
            mode="markers+text",
            marker=dict(size=12, color=_C_MODEL, symbol="triangle-up",
                        line=dict(width=1, color="#fff")),
            text=[f"{model_val}"],
            textposition="bottom center",
            textfont=dict(size=10, color=_C_MODEL),
            name="Model",
        ))
    # Layout — force categorical x order so forecast TP lands last
    all_tp = tp_labels + [forecast_tp]
    fig.update_layout(
        height=280,
        margin=dict(t=30, b=30, l=50, r=20),
        xaxis=dict(
            categoryorder="array",
            categoryarray=all_tp,
            title="Assessment",
            showgrid=True, gridcolor=C_GRID,
        ),
        yaxis=dict(
            title=param_labels.get(selected_param, selected_param),
            showgrid=True, gridcolor=C_GRID,
        ),
        plot_bgcolor=C_BG,
        legend=dict(orientation="h", yanchor="bottom", y=1.02,
                    xanchor="left", x=0, font=dict(size=10)),
    )
    # Add a vertical dashed line separating history from forecast.
    # Use add_shape instead of add_vline to avoid Plotly categorical axis
    # bug; numeric x on a category axis addresses the category's 0-based
    # position.
    last_hist_idx = len(tp_labels) - 1
    fig.add_shape(
        type="line",
        x0=last_hist_idx, x1=last_hist_idx,
        y0=0, y1=1, yref="paper",
        line=dict(dash="dash", color="#ccc", width=1),
    )
    fig.add_annotation(
        x=last_hist_idx + 0.5, y=1.0, yref="paper",
        text="forecast", showarrow=False,
        font=dict(size=9, color="#aaa"),
    )
    st.plotly_chart(fig, width="stretch")
def _render_prediction_bullet(ex: dict, has_adapted: bool):
    """Compact bullet chart showing predicted vs actual values.

    Renders a horizontal number line for FAC and/or speed predictions,
    with markers for current value, reference prediction, and model
    prediction. Only shown for prediction_* categories.

    Args:
        ex: Per-example benchmark record; prediction metrics are read
            from ``metrics_adapted`` (or ``metrics_base``) under
            ``domain.prediction``.
        has_adapted: Selects which metrics dict supplies the values.
    """
    # NOTE(review): `category` is assigned but never used below —
    # candidate for removal.
    category = ex.get("category", "")
    metrics_key = "metrics_adapted" if has_adapted else "metrics_base"
    pred = ex.get(metrics_key, {}).get("domain", {}).get("prediction", {})
    if not pred:
        return
    fig = go.Figure()
    y_pos = 0  # Track vertical position for stacked rows
    y_labels = []
    has_any = False
    # --- FAC bullet ---
    ref_fac = pred.get("ref_fac", {})
    resp_fac = pred.get("resp_fac", {})
    if ref_fac.get("current") is not None or ref_fac.get("predicted") is not None:
        y_labels.append("FAC Score")
        # Current FAC (diamond, grey). showlegend only on the first row
        # so each legendgroup appears once in the legend.
        if ref_fac.get("current") is not None:
            fig.add_trace(go.Scatter(
                x=[ref_fac["current"]], y=[y_pos],
                mode="markers+text",
                marker=dict(symbol="diamond", size=14, color=C_BASE,
                            line=dict(width=1, color="#fff")),
                text=["current"], textposition="bottom center",
                textfont=dict(size=9, color="#888"),
                name="Current",
                showlegend=(y_pos == 0),
                legendgroup="current",
            ))
        # Reference predicted FAC (circle, blue)
        if ref_fac.get("predicted") is not None:
            fig.add_trace(go.Scatter(
                x=[ref_fac["predicted"]], y=[y_pos],
                mode="markers+text",
                marker=dict(symbol="circle", size=14, color="#0d6efd",
                            line=dict(width=1, color="#fff")),
                text=["actual"], textposition="top center",
                textfont=dict(size=9, color="#0d6efd"),
                name="Actual outcome",
                showlegend=(y_pos == 0),
                legendgroup="actual",
            ))
        # Model predicted FAC (triangle, colour-coded by match quality:
        # green exact, orange direction-only, red miss)
        if resp_fac.get("predicted") is not None:
            exact = pred.get("fac_exact_match")
            dir_match = pred.get("fac_direction_match")
            if exact:
                color, label = "#198754", "exact match"
            elif dir_match:
                color, label = "#fd7e14", "direction correct"
            else:
                color, label = "#dc3545", "missed"
            fig.add_trace(go.Scatter(
                x=[resp_fac["predicted"]], y=[y_pos],
                mode="markers+text",
                marker=dict(symbol="triangle-up", size=16, color=color,
                            line=dict(width=1, color="#fff")),
                text=[f"model ({label})"],
                textposition="bottom center",
                textfont=dict(size=9, color=color),
                name="Model prediction",
                showlegend=(y_pos == 0),
                legendgroup="model",
            ))
            has_any = True
        elif resp_fac.get("direction"):
            # Model gave a direction but no numeric value — nothing to
            # plot, but the section still counts as populated.
            has_any = True
        y_pos += 1
    # --- Speed bullet ---
    ref_spd = pred.get("ref_speed", {})
    resp_spd = pred.get("resp_speed", {})
    if ref_spd.get("current") is not None or ref_spd.get("predicted") is not None:
        y_labels.append("Gait Speed (m/s)")
        # Current speed (diamond, grey)
        if ref_spd.get("current") is not None:
            fig.add_trace(go.Scatter(
                x=[ref_spd["current"]], y=[y_pos],
                mode="markers+text",
                marker=dict(symbol="diamond", size=14, color=C_BASE,
                            line=dict(width=1, color="#fff")),
                text=["current"], textposition="bottom center",
                textfont=dict(size=9, color="#888"),
                name="Current",
                showlegend=(y_pos == 0),
                legendgroup="current",
            ))
        # Reference predicted speed (circle, blue)
        if ref_spd.get("predicted") is not None:
            fig.add_trace(go.Scatter(
                x=[ref_spd["predicted"]], y=[y_pos],
                mode="markers+text",
                marker=dict(symbol="circle", size=14, color="#0d6efd",
                            line=dict(width=1, color="#fff")),
                text=["actual"], textposition="top center",
                textfont=dict(size=9, color="#0d6efd"),
                name="Actual outcome",
                showlegend=(y_pos == 0),
                legendgroup="actual",
            ))
        # Model predicted speed (triangle, colour-coded by abs error /
        # direction; 0.05 m/s is the "close" threshold used here)
        if resp_spd.get("predicted") is not None:
            dir_match = pred.get("speed_direction_match")
            abs_err = pred.get("speed_abs_error")
            if abs_err is not None and abs_err < 0.05:
                color, label = "#198754", f"close ({abs_err:.2f} m/s off)"
            elif dir_match:
                color, label = "#fd7e14", f"direction ok ({abs_err:.2f} off)" if abs_err else "direction ok"
            elif dir_match is False:
                color, label = "#dc3545", f"wrong direction ({abs_err:.2f} off)" if abs_err else "wrong direction"
            else:
                color, label = C_BASE, "extracted"
            fig.add_trace(go.Scatter(
                x=[resp_spd["predicted"]], y=[y_pos],
                mode="markers+text",
                marker=dict(symbol="triangle-up", size=16, color=color,
                            line=dict(width=1, color="#fff")),
                text=[f"model ({label})"],
                textposition="bottom center",
                textfont=dict(size=9, color=color),
                name="Model prediction",
                showlegend=(y_pos == 0),
                legendgroup="model",
            ))
            has_any = True
        y_pos += 1
    # --- Risk count (simple display, no bullet needed) ---
    ref_risk = pred.get("ref_risk_count")
    resp_risk = pred.get("resp_risk_count")
    if ref_risk is not None or resp_risk is not None:
        match = pred.get("risk_count_match")
        color = "#198754" if match else "#dc3545"
        icon = "correct" if match else "incorrect"
        st.markdown(
            f" **Risk factors:** "
            f"Actual **{ref_risk}** | "
            f"Model predicted **{resp_risk}** "
            f" <span style='color:{color}'>({icon})</span>",
            unsafe_allow_html=True,
        )
        has_any = True
    # y_pos == 0 means only the risk-count line was shown — no chart rows
    if not has_any or y_pos == 0:
        return
    # Compute x range from all marker data
    all_x = []
    for trace in fig.data:
        all_x.extend(trace.x)
    if not all_x:
        return
    x_min = min(all_x) - 0.3
    x_max = max(all_x) + 0.3
    fig.update_layout(
        height=80 + y_pos * 50,
        margin=dict(t=5, b=5, l=10, r=10),
        xaxis=dict(range=[max(0, x_min), x_max],
                   showgrid=True, gridcolor=C_GRID),
        yaxis=dict(tickvals=list(range(y_pos)), ticktext=y_labels,
                   showgrid=False),
        plot_bgcolor=C_BG,
        legend=dict(orientation="h", yanchor="bottom", y=1.0,
                    xanchor="left", x=0, font=dict(size=10)),
        showlegend=True,
    )
    st.plotly_chart(fig, width="stretch")
def _render_reinference_controls(ex: dict, bench_data: dict,
                                 sample_idx: int, has_adapted: bool):
    """Re-inference controls for a single sample.

    Lets the user re-run generation for the selected test prompt against
    either the base model or base + LoRA adapter, stores the fresh
    response in ``st.session_state`` (where the sample browser picks it
    up), and reports live ROUGE metrics against the reference.

    Args:
        ex: Per-example benchmark record (needs ``prompt`` and
            ``reference``; ``category`` drives routed-adapter selection).
        bench_data: Full benchmark run dict (model id, adapter routing,
            run ``id`` used to key session-state results).
        sample_idx: Index of the sample in the filtered list — used to
            build unique widget and session-state keys.
        has_adapted: Kept for signature stability (not read here).
    """
    backend_name, infer_fn = _get_inference_backend()
    if infer_fn is None:
        st.caption("Live inference unavailable — needs MLX (local) or HF_TOKEN (Spaces).")
        return
    backend_label = "MLX (local)" if backend_name == "mlx" else "HF API (serverless)"
    st.divider()
    col_btn, col_model, col_info = st.columns([1, 2, 2])
    with col_model:
        model_id = bench_data.get("model", "")
        # Resolve adapter: for routed bench, pick the adapter that matches
        # the current example's category; fall back to general adapter
        ex_category = ex.get("category", "")
        if _is_routed(bench_data):
            routing = bench_data.get("routing", {})
            adapter = routing.get(ex_category, "") or bench_data.get("fallback_adapter", "")
        else:
            adapter = bench_data.get("adapter_path", "")
        # Build model choices: index 0 is always the bare base model;
        # index 1 (when an adapter exists) is base + adapter.
        choices = [f"{model_id} (base)"]
        if adapter:
            adapter_short = Path(adapter).name
            choices.append(f"{model_id} + {adapter_short}")
        model_choice = st.selectbox("Run as", choices, key=f"reinfer_model_{sample_idx}",
                                    index=len(choices) - 1)
    with col_info:
        st.caption(f"Backend: **{backend_label}**")
        if backend_name == "hf_api" and adapter:
            st.caption("Note: LoRA adapter not available via HF API — base model only.")
    with col_btn:
        st.markdown("")  # vertical spacing so the button aligns with the selectbox
        run = st.button("Re-run inference", key=f"reinfer_btn_{sample_idx}",
                        type="secondary")
    if run:
        # Fix: detect adapter selection by comparing against the known
        # base choice. The previous substring heuristic
        # ('"adapter" in model_choice or "exp-" in model_choice') silently
        # ran the base model when the adapter dir name contained neither
        # token, and misfired when the model id itself contained one.
        use_adapter = model_choice != choices[0]
        # LoRA adapters can only be applied on the local MLX backend
        effective_adapter = adapter if (use_adapter and backend_name == "mlx") else None
        with st.spinner(f"Generating via {backend_label}..."):
            try:
                response = infer_fn(model_id, effective_adapter,
                                    ex["prompt"], max_tokens=512)
                # Score the fresh response against the reference
                metrics = _compute_live_metrics(ex["reference"], response)
                # Store under the key the sample browser reads back
                key_prefix = "live_adapted" if use_adapter else "live_base"
                live_key = f"{key_prefix}_{bench_data.get('id', '')}_{sample_idx}"
                st.session_state[live_key] = response
                st.success(
                    f"Generated {metrics['len']} tokens — "
                    f"ROUGE-1: {metrics['rouge1_f1']:.4f}, "
                    f"ROUGE-2: {metrics['rouge2_f1']:.4f}"
                )
                st.rerun()
            except Exception as e:
                st.error(f"Inference failed: {e}")
| # ============================================================ | |
| # Baseline Comparison Table | |
| # ============================================================ | |
def _render_baseline_comparison(bench_data: dict, bench_map: dict,
                                all_keys: list[str],
                                agg_override: dict | None = None):
    """Render a styled comparison table: fine-tuned model vs all baselines.

    Baselines are the base-only benchmark runs in ``bench_map`` that were
    evaluated on the same number of test examples as the selected run.
    The adapted ("Ours") column is highlighted and baseline cells are
    dimmed on metrics where our model wins.

    Args:
        bench_data: The selected benchmark run dict.
        bench_map: ``{stem: (path, data)}`` for all loaded bench files.
        all_keys: Keys of ``bench_map`` to scan for baseline candidates.
        agg_override: Optional aggregate dict to use instead of
            ``bench_data["aggregate"]`` (e.g. a recomputed specialized
            aggregate).
    """
    n_samples = bench_data.get("test_examples")
    effective_agg = agg_override if agg_override is not None else bench_data.get("aggregate", {})
    adapted_agg = effective_agg.get("adapted", {})
    adapted_pred = adapted_agg.get("prediction", {})
    # Collect baseline runs (base-only, same test size)
    baselines = []
    for k in all_keys:
        d = bench_map[k][1]
        a = d.get("aggregate", {})
        if "adapted" in a:
            # Skip runs that themselves include adapted results
            continue
        if d.get("test_examples") != n_samples:
            # Only compare runs on the same held-out set size
            continue
        b = a.get("base", {})
        bp = b.get("prediction", {})
        baselines.append({
            "model": d.get("model", "?"),
            "rouge1": b.get("rouge1_f1"),
            "bleu": b.get("bleu"),
            "term_recall": b.get("clinical_term_recall"),
            "fac_exact": bp.get("fac_exact_match"),
            "fac_dir": bp.get("fac_direction_accuracy"),
            "speed_err": bp.get("speed_mean_abs_error"),
            "risk_acc": bp.get("risk_count_accuracy"),
        })
    if not baselines:
        return
    # Table rows: (label, value key, format string, higher-is-better)
    metrics = [
        ("ROUGE-1 F1", "rouge1", "{:.2f}", True),
        ("BLEU-4", "bleu", "{:.2f}", True),
        ("Term Recall", "term_recall", "{:.0%}", True),
        ("FAC Exact", "fac_exact", "{:.0%}", True),
        ("FAC Direction", "fac_dir", "{:.0%}", True),
        ("Speed Error", "speed_err", "{:.3f}", False),
        ("Risk Accuracy", "risk_acc", "{:.0%}", True),
    ]

    def _fmt(val, fmt_str):
        # Format a metric value; em-dash for missing, raw str on failure
        if val is None:
            return "—"
        try:
            return fmt_str.format(val)
        except Exception:
            return str(val)
    # Adapted model values
    adapted_vals = {
        "rouge1": adapted_agg.get("rouge1_f1"),
        "bleu": adapted_agg.get("bleu"),
        "term_recall": adapted_agg.get("clinical_term_recall"),
        "fac_exact": adapted_pred.get("fac_exact_match"),
        "fac_dir": adapted_pred.get("fac_direction_accuracy"),
        "speed_err": adapted_pred.get("speed_mean_abs_error"),
        "risk_acc": adapted_pred.get("risk_count_accuracy"),
    }

    # Shorten model names for display (drop org prefix and MLX suffixes)
    def _short(model_id: str) -> str:
        parts = model_id.split("/")
        name = parts[-1] if len(parts) > 1 else model_id
        # Remove common suffixes
        for suffix in ["-MLX-4bit", "-MLX-8bit"]:
            name = name.replace(suffix, "")
        return name
    # Build HTML table
    header_cols = "".join(
        f'<th style="padding:6px 10px;text-align:center;font-weight:400;'
        f'color:#aaa;font-size:0.82em;">{_short(bl["model"])}</th>'
        for bl in baselines
    )
    adapted_label = "Ours (LoRA)"
    html_rows = []
    for label, key, fmt, higher_better in metrics:
        our_val = adapted_vals.get(key)
        our_str = _fmt(our_val, fmt)
        cells = ""
        for bl in baselines:
            bl_val = bl.get(key)
            bl_str = _fmt(bl_val, fmt)
            # Determine if our model wins on this metric
            is_winner = False
            if our_val is not None and bl_val is not None:
                if higher_better:
                    is_winner = our_val > bl_val
                else:
                    is_winner = our_val < bl_val
            # Style: dim the baseline cell if our model is better
            color = "#888" if is_winner else "#e0e0e0"
            cells += (
                f'<td style="padding:6px 10px;text-align:center;'
                f'color:{color};font-size:0.9em;">{bl_str}</td>'
            )
        # Our column — bold green
        our_cell = (
            f'<td style="padding:6px 10px;text-align:center;'
            f'color:#198754;font-weight:600;font-size:0.9em;">{our_str}</td>'
        )
        html_rows.append(
            f'<tr>'
            f'<td style="padding:6px 10px;color:#ccc;font-size:0.85em;">'
            f'{label}</td>'
            f'{cells}{our_cell}'
            f'</tr>'
        )
    table_html = f"""
<div style="margin:0.8rem 0 0.5rem 0;">
  <p style="color:#aaa;font-size:0.82em;margin-bottom:6px;">
    Comparison across all evaluated models on {n_samples} held-out samples
  </p>
  <table style="width:100%;border-collapse:collapse;border:1px solid #333;
                border-radius:6px;overflow:hidden;">
    <thead>
      <tr style="border-bottom:1px solid #333;">
        <th style="padding:6px 10px;text-align:left;color:#888;
                   font-size:0.82em;font-weight:400;">Metric</th>
        {header_cols}
        <th style="padding:6px 10px;text-align:center;font-weight:600;
                   color:#198754;font-size:0.82em;
                   border-left:2px solid #198754;">{adapted_label}</th>
      </tr>
    </thead>
    <tbody>
      {"".join(html_rows)}
    </tbody>
  </table>
</div>
"""
    # Fix: add green left-border to our column cells by rewriting the
    # exact style string emitted for them above (keeps the row-building
    # loop simple at the cost of this string coupling).
    table_html = table_html.replace(
        f'color:#198754;font-weight:600;font-size:0.9em;">',
        f'color:#198754;font-weight:600;font-size:0.9em;'
        f'border-left:2px solid #198754;">'
    )
    st.markdown(table_html, unsafe_allow_html=True)
| # ============================================================ | |
| # Main Entry Point | |
| # ============================================================ | |
def render_eval_tab():
    """Main entry point — called from app.py.

    Loads benchmark result files, lets the user pick a run and an
    alternative baseline source, then renders the info bar, comparison
    table, metric cards, prediction accuracy, category chart, and the
    per-sample browser.
    """
    st.title("Model Evaluation")
    st.caption("Quantitative benchmarks on held-out test data · Live re-inference")
    # Load bench files
    bench_files = _list_bench_files()
    if not bench_files:
        st.warning(
            "No benchmark results found. Run:\n\n"
            "```\npython src/models/lab.py bench "
            "--dataset rehab_public_v1 --experiment <id>\n```"
        )
        return
    # Load all bench data (including routed runs)
    bench_map = {}
    for f in bench_files:
        data = _load_bench(str(f))
        if data:
            bench_map[f.stem] = (f, data)
    if not bench_map:
        st.error("Could not load any benchmark files.")
        return
    # Bench run selector — only runs with adapted results on the latest
    # full test set (533 samples). Base-only runs are accessible via
    # the "Baseline source" selector instead.
    CURRENT_TEST_SIZE = 533
    all_keys = list(bench_map.keys())
    bench_keys = [
        k for k in all_keys
        if "adapted" in bench_map[k][1].get("aggregate", {})
        and bench_map[k][1].get("test_examples") == CURRENT_TEST_SIZE
    ]
    if not bench_keys:
        # Fallback: show everything if no matching runs exist
        bench_keys = all_keys
    bench_labels = [_format_bench_label(k, bench_map[k][1]) for k in bench_keys]
    sel_idx = st.selectbox(
        "Benchmark run",
        range(len(bench_keys)),
        index=0,
        format_func=lambda i: bench_labels[i],
        key="eval_bench_selector",
    )
    selected = bench_keys[sel_idx]
    _, bench_data = bench_map[selected]
    agg = bench_data.get("aggregate", {})
    # Prefer the recomputed specialized aggregate when available
    specialized_agg = _recompute_specialized_aggregate(bench_data)
    if specialized_agg is not None:
        agg = specialized_agg
    examples = bench_data.get("per_example", [])
    has_adapted = "adapted" in agg
    # Info bar: model / adapter / sample count
    col1, col2, col3 = st.columns(3)
    with col1:
        st.caption(f"**Model:** `{bench_data.get('model', '?')}`")
    with col2:
        adapter = _resolve_adapter(bench_data)
        adapter_label = Path(adapter).name if adapter else "none"
        if _is_routed(bench_data):
            n_routes = len(bench_data.get("routing", {}))
            adapter_label = f"routed ({n_routes} specialized + general)"
        st.caption(f"**Adapter:** `{adapter_label}`")
    with col3:
        st.caption(f"**Samples:** `{len(examples)}`")
    # --- About This Model (collapsible) ---
    _render_model_info(bench_data, agg, has_adapted, len(examples))
    # --- Baseline Comparison Table ---
    if has_adapted:
        _render_baseline_comparison(bench_data, bench_map, all_keys,
                                    agg_override=agg)
    # --- Baseline source selector (swaps base metrics for sections below) ---
    if has_adapted:
        n_samples = len(examples)
        own_model = bench_data.get("model", "")
        own_key = selected
        # Candidates: other runs on the same sample count that are either
        # base-only or adapted runs of a different model
        base_only_keys = [
            k for k in all_keys
            if k != own_key
            and bench_map[k][1].get("test_examples") == n_samples
            and (
                "adapted" not in bench_map[k][1].get("aggregate", {})
                or bench_map[k][1].get("model", "") != own_model
            )
        ]
        if base_only_keys:
            builtin_label = f"Built-in ({bench_data.get('model', 'Qwen')})"
            ext_labels = []
            for k in base_only_keys:
                d = bench_map[k][1]
                model = d.get("model", "?")
                ts = d.get("timestamp", "")[:16].replace("T", " ")
                has_ext_adapted = "adapted" in d.get("aggregate", {})
                tag = " [adapted]" if has_ext_adapted else ""
                ext_labels.append(f"{model}{tag} ({ts})")
            options = [builtin_label] + ext_labels
            # Default to the first external baseline when one exists
            default_idx = 1 if len(options) > 1 else 0
            bl_sel = st.selectbox(
                "Baseline source",
                range(len(options)),
                index=default_idx,
                format_func=lambda i: options[i],
                key="eval_baseline_selector",
            )
            if bl_sel > 0:
                # Swap in the external run's metrics as the "base" side.
                # Copy agg/examples first so the originals stay untouched.
                ext_key = base_only_keys[bl_sel - 1]
                _, ext_data = bench_map[ext_key]
                ext_agg = ext_data.get("aggregate", {})
                ext_examples = ext_data.get("per_example", [])
                agg = dict(agg)
                if "adapted" in ext_agg:
                    agg["base"] = ext_agg["adapted"]
                else:
                    agg["base"] = ext_agg.get("base", agg.get("base", {}))
                # Per-example swap only when sample counts line up;
                # assumes both runs iterate the same test set order —
                # TODO confirm ordering is guaranteed by the bench script.
                if len(ext_examples) == len(examples):
                    examples = [dict(ex) for ex in examples]
                    for i, ext_ex in enumerate(ext_examples):
                        examples[i]["metrics_base"] = ext_ex.get(
                            "metrics_base", examples[i].get("metrics_base", {}))
                        examples[i]["base_response"] = ext_ex.get(
                            "base_response", examples[i].get("base_response", ""))
    st.divider()
    # --- Metric Cards ---
    _render_metric_cards(agg, has_adapted)
    # --- Predictive Accuracy (hero section for specialized adapters) ---
    _render_prediction_accuracy(agg, has_adapted, examples)
    # --- Text Quality by Category ---
    _render_category_chart(agg, has_adapted)
    st.divider()
    # --- Sample Browser ---
    _render_sample_browser(examples, bench_data, has_adapted)