"""
Streamlit dashboard for the clean data-scaling study.

Run from the experiments/clean_data_scaling_study/ directory:

    streamlit run dashboard/app.py

Three sections:
  1. Learning curves — per-epoch metrics for every (model, share) run
  2. Data share vs final — best/final val mIoU/Dice/IoU/PixelAcc as a function of data share
  3. Inference — upload an image, see the 4×3 grid of predictions

Reads logs from ../logs and checkpoints from ../checkpoints.
Both are re-read periodically, so runs that are still training show up
without restarting the app.
"""

import io
import json
import math
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

THIS_DIR = Path(__file__).resolve().parent
EXP_DIR = THIS_DIR.parent
LOGS_DIR = EXP_DIR / "logs"
CKPT_DIR = EXP_DIR / "checkpoints"

# Make the experiment's `models` module importable regardless of the CWD.
sys.path.insert(0, str(EXP_DIR))
from models import MODEL_REGISTRY, PRETTY_NAME  # noqa: E402

MODELS = ["segnet", "unet", "segformer_b0", "segformer_b5"]
SHARES = [25, 50, 100]
METRICS = [
    ("dice", "Dice"),
    ("miou", "mIoU"),
    ("iou", "Foreground IoU"),
    ("pixel_acc", "Pixel Accuracy"),
    ("loss", "Loss"),
]

# Square side length (px) every uploaded image is resized to before inference.
IMAGE_SIZE = 128
# How long (seconds) parsed JSON logs stay cached before being re-read from disk.
LOG_CACHE_TTL_S = 30
# Validation metrics summarised per (model, share) in the scaling table.
VAL_METRIC_KEYS = ("val_dice", "val_miou", "val_iou", "val_pixel_acc")

try:
    st.set_page_config(page_title="Clean Data Scaling Study", layout="wide")
except Exception:
    # already configured by a parent multi-page app
    pass


# ── Loaders ────────────────────────────────────────────────────────────────
@st.cache_data(show_spinner=False, ttl=LOG_CACHE_TTL_S)
def load_log(model: str, share: int):
    """Load ``LOGS_DIR/{model}_{share}.json``.

    Returns the parsed JSON object, or ``None`` when the file does not exist
    or cannot be parsed. Results are cached for ``LOG_CACHE_TTL_S`` seconds so
    newly written logs are eventually picked up.
    """
    p = LOGS_DIR / f"{model}_{share}.json"
    if not p.is_file():
        return None
    try:
        with open(p, encoding="utf-8") as f:
            return json.load(f)
    except json.JSONDecodeError:
        # A log that is being rewritten concurrently (e.g. by a training job
        # that is still running) can be read half-written. Treat it as "not
        # available yet"; the TTL above means it is retried shortly.
        return None


def log_to_df(log):
    """Turn a log's ``epochs`` list into a DataFrame tagged with model/share."""
    df = pd.DataFrame(log["epochs"])
    df["model"] = log["model"]
    df["share"] = log["share"]
    return df


@st.cache_data(show_spinner=False, ttl=LOG_CACHE_TTL_S)
def load_all_logs():
    """Return ``{(model, share): log}`` for every log file that loads successfully."""
    logs = {}
    for m in MODELS:
        for s in SHARES:
            log = load_log(m, s)
            if log is not None:
                logs[(m, s)] = log
    return logs


def fmt_hms(seconds):
    """Format a duration as ``H:MM:SS`` (or ``M:SS`` under an hour); ``None`` → "—"."""
    if seconds is None:
        return "—"
    seconds = int(round(seconds))
    h, rem = divmod(seconds, 3600)
    m, s = divmod(rem, 60)
    return f"{h:d}:{m:02d}:{s:02d}" if h else f"{m:d}:{s:02d}"


def _dice_sort_key(epoch_entry):
    """Rank an epoch by ``val_dice``; missing or NaN values rank below any real score."""
    value = epoch_entry.get("val_dice")
    if value is None:
        return -math.inf
    value = float(value)
    return -math.inf if math.isnan(value) else value


def scaling_row(log, kind="best"):
    """One summary row per (model, share). kind ∈ {'best', 'final'}.

    ``'best'`` picks the epoch with the highest ``val_dice`` (first one on
    ties). Any other value picks the last epoch. Every row carries the same
    keys, including ``epoch`` and the timing fields. Values are ``None`` when
    the log has no epochs, so the DataFrame built from these rows always has
    the columns the scaling tab indexes.
    """
    epochs = log["epochs"]
    row = {"model": PRETTY_NAME[log["model"]], "share": log["share"], "epoch": None}
    row.update(dict.fromkeys(VAL_METRIC_KEYS))

    if epochs:
        chosen = max(epochs, key=_dice_sort_key) if kind == "best" else epochs[-1]
        row["epoch"] = chosen["epoch"]
        for key in VAL_METRIC_KEYS:
            row[key] = chosen.get(key)

    wall = log.get("wall_clock_seconds")
    row["wall_clock_seconds"] = wall
    row["wall_clock"] = fmt_hms(wall)
    per = [e["epoch_seconds"] for e in epochs if e.get("epoch_seconds") is not None]
    row["sec_per_epoch"] = (sum(per) / len(per)) if per else None
    return row


# ── Inference helpers ──────────────────────────────────────────────────────
def _ckpt_path(model_name, share, kind):
    """Path of the ``{model}_{share}_{kind}.pth`` checkpoint."""
    return CKPT_DIR / f"{model_name}_{share}_{kind}.pth"


@st.cache_resource(show_spinner=False, max_entries=32)
def _load_model_from_disk(model_name, share, kind, device, mtime):
    """Build the model and load its checkpoint weights onto ``device``.

    ``mtime`` is not used in the body. It only exists so the cache key changes
    when the checkpoint file is rewritten (e.g. ``*_final.pth`` being updated).
    It must not be renamed with a leading underscore: Streamlit skips hashing
    such parameters.
    """
    builder = MODEL_REGISTRY[model_name]
    model, _, output_is_prob = builder()
    # NOTE(security): weights_only=False unpickles arbitrary objects. That is
    # acceptable only because these checkpoints are produced locally by this
    # study; never point CKPT_DIR at untrusted files.
    state = torch.load(
        _ckpt_path(model_name, share, kind), map_location=device, weights_only=False
    )
    model.load_state_dict(state["model_state_dict"])
    model.to(device).eval()
    output_is_prob = state.get("output_is_prob", output_is_prob)
    return model, output_is_prob


def load_ckpt(model_name: str, share: int, kind: str, device: str):
    """Return ``(model, output_is_prob)`` for a checkpoint, or ``(None, False)`` if absent.

    The existence check is deliberately outside the cache. A checkpoint that
    is missing now is found on a later rerun once it has been written.
    """
    p = _ckpt_path(model_name, share, kind)
    if not p.is_file():
        return None, False
    try:
        mtime = p.stat().st_mtime
    except FileNotFoundError:  # removed between the check and the stat
        return None, False
    return _load_model_from_disk(model_name, share, kind, device, mtime)


def preprocess(image: Image.Image, image_size: int = IMAGE_SIZE):
    """Resize ``image`` to ``image_size``² RGB and return a ``(1, 3, H, W)`` tensor in [0, 1]."""
    tf = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ])
    return tf(image.convert("RGB")).unsqueeze(0)


def _match_input_resolution(probs, target_hw):
    """Bilinearly resize a 3-D/4-D probability map so its last two dims equal ``target_hw``.

    Returns ``probs`` unchanged when it already matches or is not 3-/4-D.
    """
    target_hw = tuple(target_hw)
    if probs.ndim not in (3, 4) or tuple(probs.shape[-2:]) == target_hw:
        return probs
    if probs.ndim == 3:
        return F.interpolate(
            probs.unsqueeze(1), size=target_hw, mode="bilinear", align_corners=False
        ).squeeze(1)
    return F.interpolate(probs, size=target_hw, mode="bilinear", align_corners=False)


def run_inference(model, image_tensor, device, output_is_prob: bool, threshold=0.5):
    """Run one forward pass and return ``(probs, mask)`` as 2-D numpy arrays.

    If ``output_is_prob`` is False, the raw output goes through a sigmoid
    first. ``mask`` is ``probs > threshold`` as float32 (0.0 / 1.0).
    """
    with torch.no_grad():
        out = model(image_tensor.to(device))
        probs = out if output_is_prob else torch.sigmoid(out)
        # Some architectures may emit a lower-resolution map than their input.
        # Resize so the overlay/heatmap align with the displayed image. This is
        # a no-op when the model already returns the input resolution.
        probs = _match_input_resolution(probs, image_tensor.shape[-2:])
    probs = probs.squeeze().cpu().numpy()
    if probs.ndim != 2:
        probs = probs.reshape(probs.shape[-2], probs.shape[-1])
    mask = (probs > threshold).astype(np.float32)
    return probs, mask


def overlay(rgb: np.ndarray, mask: np.ndarray, color=(0, 255, 0), alpha=0.45):
    """Blend ``color`` into the pixels of ``rgb`` where ``mask`` is non-zero (alpha blend)."""
    out = rgb.copy()
    m = mask.astype(bool)
    out[m] = (alpha * np.array(color) + (1 - alpha) * out[m]).astype(np.uint8)
    return out


def heatmap(probs: np.ndarray) -> np.ndarray:
    """Map probabilities to RGB: red = p, blue = 1 - p, green peaking at p = 0.5."""
    p = np.clip(probs, 0.0, 1.0)
    rgb = np.zeros((p.shape[0], p.shape[1], 3), dtype=np.uint8)
    rgb[..., 0] = (p * 255).astype(np.uint8)
    rgb[..., 1] = (np.maximum(0, 1 - 2 * np.abs(p - 0.5)) * 255).astype(np.uint8)
    rgb[..., 2] = ((1 - p) * 255).astype(np.uint8)
    return rgb


# ── UI ─────────────────────────────────────────────────────────────────────
st.title("📊 Clean Data-Scaling Study — 4 architectures × 3 data shares")
st.caption(
    "Trained from scratch on a deduplicated dataset (final_data_clean/). "
    "All numbers use the same global confusion-matrix metric code."
)

logs = load_all_logs()
if not logs:
    st.warning("No logs found in `../logs/`. Run training first (`./run_all.sh`).")

tab_curves, tab_scaling, tab_infer = st.tabs(
    ["1 · Learning curves", "2 · Data share vs final", "3 · Inference"]
)

# ── Tab 1: Learning curves ─────────────────────────────────────────────────
with tab_curves:
    st.subheader("Per-epoch metrics")
    if not logs:
        st.info("Waiting for training logs.")
    else:
        col_m, col_split = st.columns([2, 2])
        with col_m:
            metric_key, metric_label = st.selectbox(
                "Metric", METRICS, format_func=lambda x: x[1],
            )
        with col_split:
            split = st.radio("Split", ["val", "train", "both"], horizontal=True, index=0)

        for model in MODELS:
            available = [s for s in SHARES if (model, s) in logs]
            if not available:
                continue
            st.markdown(f"#### {PRETTY_NAME[model]}")
            fig = go.Figure()
            for share in available:
                df = log_to_df(logs[(model, share)])
                if split in ("val", "both"):
                    col = f"val_{metric_key}"
                    if col in df.columns and df[col].notna().any():
                        sub = df.dropna(subset=[col])
                        fig.add_trace(go.Scatter(
                            x=sub["epoch"], y=sub[col], mode="lines",
                            name=f"{share}% val",
                        ))
                if split in ("train", "both"):
                    col = f"train_{metric_key}"
                    if col in df.columns and df[col].notna().any():
                        sub = df.dropna(subset=[col])
                        fig.add_trace(go.Scatter(
                            x=sub["epoch"], y=sub[col], mode="lines",
                            line=dict(dash="dot"),
                            name=f"{share}% train",
                        ))
            fig.update_layout(
                xaxis_title="Epoch",
                yaxis_title=metric_label,
                height=360,
                margin=dict(l=10, r=10, t=10, b=10),
                legend=dict(orientation="h", y=-0.2),
            )
            st.plotly_chart(fig, use_container_width=True)

# ── Tab 2: Data share vs final ─────────────────────────────────────────────
with tab_scaling:
    st.subheader("Val metrics vs data share")
    if not logs:
        st.info("Waiting for training logs.")
    else:
        kind = st.radio("Checkpoint", ["best", "final"], horizontal=True, index=0)
        rows = [scaling_row(log, kind=kind) for log in logs.values()]
        df = pd.DataFrame(rows).sort_values(["model", "share"]).reset_index(drop=True)
        display_df = df.drop(columns=["wall_clock_seconds", "sec_per_epoch"], errors="ignore")
        st.dataframe(display_df, use_container_width=True, hide_index=True)

        trained_seconds = df["wall_clock_seconds"].dropna().sum()
        if trained_seconds:
            st.caption(
                f"⏱ Total training wall-clock across all runs: "
                f"**{fmt_hms(trained_seconds)}** ({trained_seconds:,.0f} s)"
            )

        col1, col2 = st.columns(2)
        with col1:
            fig = px.line(
                df.dropna(subset=["val_miou"]),
                x="share", y="val_miou", color="model", markers=True,
                title=f"Val mIoU ({kind})",
                labels={"share": "Training data (%)", "val_miou": "Val mIoU"},
            )
            fig.update_xaxes(tickvals=SHARES)
            st.plotly_chart(fig, use_container_width=True)
        with col2:
            fig = px.line(
                df.dropna(subset=["val_dice"]),
                x="share", y="val_dice", color="model", markers=True,
                title=f"Val Dice ({kind})",
                labels={"share": "Training data (%)", "val_dice": "Val Dice"},
            )
            fig.update_xaxes(tickvals=SHARES)
            st.plotly_chart(fig, use_container_width=True)

        col3, col4 = st.columns(2)
        with col3:
            fig = px.bar(
                df.dropna(subset=["val_iou"]),
                x="share", y="val_iou", color="model", barmode="group",
                title=f"Val foreground IoU ({kind})",
                labels={"share": "Training data (%)", "val_iou": "Val IoU (foreground)"},
            )
            fig.update_xaxes(tickvals=SHARES)
            st.plotly_chart(fig, use_container_width=True)
        with col4:
            fig = px.bar(
                df.dropna(subset=["val_pixel_acc"]),
                x="share", y="val_pixel_acc", color="model", barmode="group",
                title=f"Val pixel accuracy ({kind})",
                labels={"share": "Training data (%)", "val_pixel_acc": "Val pixel acc"},
            )
            fig.update_xaxes(tickvals=SHARES)
            st.plotly_chart(fig, use_container_width=True)

        st.markdown("##### Training time")
        time_df = df.dropna(subset=["wall_clock_seconds"]).assign(
            wall_minutes=lambda d: d["wall_clock_seconds"] / 60.0
        )
        if not time_df.empty:
            tcol1, tcol2 = st.columns(2)
            with tcol1:
                fig = px.bar(
                    time_df, x="share", y="wall_minutes", color="model", barmode="group",
                    title="Total training time (minutes)",
                    labels={"share": "Training data (%)", "wall_minutes": "Wall clock (min)"},
                )
                fig.update_xaxes(tickvals=SHARES)
                st.plotly_chart(fig, use_container_width=True)
            with tcol2:
                fig = px.bar(
                    time_df.dropna(subset=["sec_per_epoch"]),
                    x="share", y="sec_per_epoch", color="model", barmode="group",
                    title="Average seconds per epoch",
                    labels={"share": "Training data (%)", "sec_per_epoch": "Seconds / epoch"},
                )
                fig.update_xaxes(tickvals=SHARES)
                st.plotly_chart(fig, use_container_width=True)

# ── Tab 3: Inference ───────────────────────────────────────────────────────
with tab_infer:
    st.subheader("Upload an image — 4 models × 3 shares = 12 segmentations")
    st.caption(
        "Each cell uses one (model, data-share) checkpoint. "
        "Toggle best vs final to see end-of-training behavior."
    )
    debug = st.checkbox("debug", value=False, key="infer_debug")

    col_a, col_b, col_c, col_d = st.columns([2, 2, 2, 2])
    with col_a:
        kind = st.radio("Checkpoint", ["best", "final"], horizontal=True, key="infer_kind")
    with col_b:
        threshold = st.slider("Threshold", 0.0, 1.0, 0.5, 0.05, key="infer_thr")
    with col_c:
        view = st.radio("View", ["mask", "overlay", "heatmap"], horizontal=True, key="infer_view")
    with col_d:
        cell_w = st.select_slider(
            "Cell size (px)", options=[140, 180, 220, 260], value=180, key="infer_cell"
        )

    uploaded = st.file_uploader(
        "Drop an image (jpg/png)", type=["jpg", "jpeg", "png"], key="infer_upload"
    )

    if uploaded is not None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        try:
            raw_bytes = uploaded.getvalue()
            img = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
        except Exception as e:
            st.error(f"Could not decode uploaded image: {e}")
            st.stop()

        st.caption(
            f"📁 `{uploaded.name}` — {img.size[0]}×{img.size[1]} px, {len(raw_bytes)/1024:.1f} KB"
        )
        x = preprocess(img, image_size=IMAGE_SIZE)
        # Back to HWC uint8 so the model input can be displayed / overlaid.
        rgb_small = (x.squeeze().permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        st.image(
            [img, rgb_small],
            width=cell_w * 2,
            caption=["original", f"{IMAGE_SIZE}×{IMAGE_SIZE} (model input)"],
        )
        if debug:
            st.write(f"tensor={tuple(x.shape)} device={device} models={MODELS} shares={SHARES}")

        st.markdown("##### Predictions (rows = model, columns = data share)")

        def render_cell(probs, mask, rgb):
            """Render one grid cell according to the selected ``view``."""
            if view == "mask":
                return (mask * 255).astype(np.uint8)
            if view == "overlay":
                return overlay(rgb, mask)
            return heatmap(probs)

        for model_name in MODELS:
            cols = st.columns(len(SHARES))
            for col, share in zip(cols, SHARES):
                with col:
                    st.markdown(f"**{PRETTY_NAME[model_name]} · {share}%**")
                    try:
                        with st.spinner(f"loading {model_name} {share}%…"):
                            m, output_is_prob = load_ckpt(model_name, share, kind, device)
                        if m is None:
                            st.warning(f"missing `{model_name}_{share}_{kind}.pth`")
                            continue
                        probs, mask = run_inference(m, x, device, output_is_prob, threshold)
                        cell_img = render_cell(probs, mask, rgb_small)
                        st.image(cell_img, width=cell_w)
                        st.caption(
                            f"cov={float(mask.mean())*100:.1f}% "
                            f"p[{probs.min():.2f},{probs.max():.2f}]"
                        )
                    except Exception as e:
                        # Per-cell boundary: one broken checkpoint must not
                        # blank the rest of the grid, so show it inline.
                        st.error(f"{model_name} {share}% — {type(e).__name__}")
                        st.exception(e)
    else:
        st.info("Upload an image to run inference across all 12 trained models.")