Spaces:
Running
Running
"""
Streamlit dashboard for the clean data-scaling study.

Run from the experiments/clean_data_scaling_study/ directory:

    streamlit run dashboard/app.py

Three sections:
  1. Learning curves — per-epoch metrics for every (model, share) run
  2. Data share vs final — best/final val mIoU/Dice/IoU/PixelAcc as a function of data share
  3. Inference — upload an image, see the 4×3 grid of predictions

Reads logs from ../logs and checkpoints from ../checkpoints.
"""
import io
import json
import math
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
import torch
from PIL import Image
from torchvision import transforms
THIS_DIR = Path(__file__).resolve().parent
EXP_DIR = THIS_DIR.parent
LOGS_DIR = EXP_DIR / "logs"
CKPT_DIR = EXP_DIR / "checkpoints"

# Make the experiment's `models` package importable. Streamlit re-executes this
# whole script on every widget interaction, so only insert the entry once
# instead of growing sys.path with a duplicate on every rerun.
if str(EXP_DIR) not in sys.path:
    sys.path.insert(0, str(EXP_DIR))

from models import MODEL_REGISTRY, PRETTY_NAME  # noqa: E402

# Registry keys of the architectures compared in the study (display order).
MODELS = ["segnet", "unet", "segformer_b0", "segformer_b5"]
# Percentage of the training set used for each run.
SHARES = [25, 50, 100]
# (column suffix, display label) pairs; the learning-curve tab reads the
# `val_<suffix>` / `train_<suffix>` columns of each run's per-epoch log.
METRICS = [
    ("dice", "Dice"),
    ("miou", "mIoU"),
    ("iou", "Foreground IoU"),
    ("pixel_acc", "Pixel Accuracy"),
    ("loss", "Loss"),
]

try:
    st.set_page_config(page_title="Clean Data Scaling Study", layout="wide")
except Exception:
    # Deliberate best-effort: the page may already be configured by a parent
    # multi-page app, in which case keep that configuration.
    pass
# ── Loaders ─────────────────────────────────────────────────────────────────
def load_log(model: str, share: int, logs_dir=None):
    """Load the JSON training log for one (model, share) run.

    Args:
        model: registry key of the architecture, e.g. ``"unet"``.
        share: percentage of the training data used for the run.
        logs_dir: directory containing ``{model}_{share}.json`` files;
            defaults to the module-level ``LOGS_DIR``.

    Returns:
        The decoded JSON object, or ``None`` when the file does not exist or
        cannot be decoded (for example a log caught mid-write), so a single
        unreadable file does not take down the whole dashboard.
    """
    directory = LOGS_DIR if logs_dir is None else Path(logs_dir)
    path = directory / f"{model}_{share}.json"
    if not path.is_file():
        return None
    try:
        return json.loads(path.read_text(encoding="utf-8"))
    except (json.JSONDecodeError, UnicodeDecodeError):
        return None
def log_to_df(log):
    """Turn a run log's per-epoch records into a DataFrame.

    Each row is one entry of ``log["epochs"]``; two extra columns, ``model``
    and ``share``, carry the run's identifiers on every row.
    """
    per_epoch = pd.DataFrame(log["epochs"])
    return per_epoch.assign(model=log["model"], share=log["share"])
def load_all_logs():
    """Collect the logs of every (model, share) combination in the study.

    Returns a dict keyed by ``(model, share)`` tuples, iterating ``MODELS``
    then ``SHARES``; combinations for which ``load_log`` returns ``None`` are
    simply left out.
    """
    found = {}
    for model_name in MODELS:
        for share in SHARES:
            entry = load_log(model_name, share)
            if entry is None:
                continue
            found[(model_name, share)] = entry
    return found
def fmt_hms(seconds):
    """Format a duration in seconds as ``H:MM:SS``, or ``M:SS`` under an hour.

    ``None`` and non-finite values (NaN / inf) render as the placeholder "—"
    instead of raising. Fractional seconds are rounded with Python's
    ``round`` (ties to even).
    """
    if seconds is None or not math.isfinite(seconds):
        return "—"
    total = int(round(seconds))
    hours, rem = divmod(total, 3600)
    minutes, secs = divmod(rem, 60)
    if hours:
        return f"{hours:d}:{minutes:02d}:{secs:02d}"
    return f"{minutes:d}:{secs:02d}"
# Validation metrics copied from the chosen epoch into each summary row.
_SCALING_METRIC_KEYS = ("val_dice", "val_miou", "val_iou", "val_pixel_acc")


def _dice_score(epoch_record):
    """Ranking key for 'best' selection: val Dice, with missing/NaN treated as -1."""
    dice = epoch_record.get("val_dice")
    if dice is None or (isinstance(dice, float) and math.isnan(dice)):
        return -1.0
    return dice


def scaling_row(log, kind="best"):
    """Build one summary row for a (model, share) run. kind ∈ {'best', 'final'}.

    Args:
        log: decoded run log with ``model``, ``share`` and ``epochs`` keys and
            an optional ``wall_clock_seconds``.
        kind: ``"best"`` picks the epoch with the highest ``val_dice`` (missing
            or NaN counts as -1; the earliest epoch wins ties); any other
            value (normally ``"final"``) picks the last logged epoch.

    Returns:
        A dict that always has the same keys, in this order: model (pretty
        name), share, epoch, val_dice, val_miou, val_iou, val_pixel_acc,
        wall_clock_seconds, wall_clock (``fmt_hms`` string), sec_per_epoch
        (mean of the logged ``epoch_seconds``). Values that are not available
        are None. Keeping the key set fixed guarantees the DataFrame built from
        these rows has every column even before any run has logged an epoch.
    """
    epochs = log["epochs"]
    wall = log.get("wall_clock_seconds")

    row = {"model": PRETTY_NAME[log["model"]], "share": log["share"], "epoch": None}
    row.update(dict.fromkeys(_SCALING_METRIC_KEYS))
    row["wall_clock_seconds"] = wall
    row["wall_clock"] = fmt_hms(wall)
    row["sec_per_epoch"] = None

    if not epochs:
        return row

    # max() returns the first maximal element, i.e. the earliest best epoch.
    chosen = max(epochs, key=_dice_score) if kind == "best" else epochs[-1]
    row["epoch"] = chosen["epoch"]
    for key in _SCALING_METRIC_KEYS:
        row[key] = chosen.get(key)

    per_epoch = [e["epoch_seconds"] for e in epochs if e.get("epoch_seconds") is not None]
    if per_epoch:
        row["sec_per_epoch"] = sum(per_epoch) / len(per_epoch)
    return row
# ── Inference helpers ───────────────────────────────────────────────────────
@st.cache_resource(show_spinner=False, max_entries=len(MODELS) * len(SHARES))
def _load_ckpt_cached(path_str: str, mtime_ns: int, model_name: str, device: str):
    """Build ``model_name`` and load the checkpoint at ``path_str`` onto ``device``.

    Cached across Streamlit reruns and sessions. ``mtime_ns`` is unused in the
    body: it is only part of the cache key, so an overwritten checkpoint file is
    reloaded. The cached model object is shared, which is fine because callers
    only use it in eval mode under ``torch.no_grad()``.
    """
    builder = MODEL_REGISTRY[model_name]
    # NOTE(review): builder() is assumed to return a 3-tuple whose last item is
    # the default "output is already a probability map" flag; confirm in models.py.
    model, _, output_is_prob = builder()
    # weights_only=False unpickles arbitrary objects. That is only acceptable
    # because checkpoints are read from the local CKPT_DIR (training outputs),
    # never from anything a dashboard user uploads.
    state = torch.load(path_str, map_location=device, weights_only=False)
    model.load_state_dict(state["model_state_dict"])
    model.to(device).eval()
    return model, state.get("output_is_prob", output_is_prob)


def load_ckpt(model_name: str, share: int, kind: str, device: str):
    """Load the ``{model_name}_{share}_{kind}.pth`` checkpoint onto ``device``.

    Returns:
        ``(model, output_is_prob)`` with the model in eval mode, or
        ``(None, False)`` when the checkpoint file does not exist.
        ``output_is_prob`` comes from the checkpoint when it stores one,
        otherwise from the model builder.

    Loaded models are cached (keyed by path, file mtime, model and device), so
    Streamlit reruns triggered by unrelated widgets such as the threshold
    slider do not rebuild and reload every checkpoint from disk.
    """
    path = CKPT_DIR / f"{model_name}_{share}_{kind}.pth"
    if not path.is_file():
        return None, False
    return _load_ckpt_cached(str(path), path.stat().st_mtime_ns, model_name, device)
def preprocess(image: Image.Image, image_size: int = 128):
    """Convert a PIL image into a single-image model batch.

    The image is forced to RGB and resized to a ``image_size`` square (aspect
    ratio is not preserved). ``ToTensor`` then scales it to [0, 1]; no
    mean/std normalization is applied. A leading batch axis of size 1 is added.
    """
    rgb = image.convert("RGB")
    pipeline = transforms.Compose(
        [transforms.Resize((image_size, image_size)), transforms.ToTensor()]
    )
    return pipeline(rgb)[None, ...]
def run_inference(model, image_tensor, device, output_is_prob: bool, threshold=0.5):
    """Run ``model`` on a one-image batch and return ``(probs, mask)`` as 2-D arrays.

    Args:
        model: callable (e.g. an ``nn.Module``) applied to the batch; presumably
            it returns a single-channel map for a batch of one — TODO confirm
            per architecture.
        image_tensor: input batch, ``(1, C, H, W)``.
        device: torch device string the input is moved to.
        output_is_prob: True when the model already outputs probabilities;
            otherwise the output is treated as logits and passed through a sigmoid.
        threshold: probabilities strictly above it become foreground (1.0).

    Returns:
        ``probs`` (float array, H×W) and ``mask`` (float32 array of 0/1, H×W).
        If a 4-D output's spatial size differs from the input's, it is first
        bilinearly resized to ``(H, W)``. This keeps the mask aligned with the
        input image (``overlay`` requires matching shapes).
    """
    target_hw = tuple(image_tensor.shape[-2:])
    with torch.no_grad():
        out = model(image_tensor.to(device))
        if out.ndim == 4 and tuple(out.shape[-2:]) != target_hw:
            out = torch.nn.functional.interpolate(
                out.float(), size=target_hw, mode="bilinear", align_corners=False
            )
        probs = out if output_is_prob else torch.sigmoid(out)
    probs = probs.squeeze().cpu().numpy()
    if probs.ndim != 2:
        probs = probs.reshape(probs.shape[-2], probs.shape[-1])
    mask = (probs > threshold).astype(np.float32)
    return probs, mask
def overlay(rgb: np.ndarray, mask: np.ndarray, color=(0, 255, 0), alpha=0.45):
    """Return a copy of ``rgb`` with masked pixels alpha-blended toward ``color``.

    Pixels where ``mask`` is truthy become ``alpha * color + (1 - alpha) * pixel``,
    truncated to uint8. All other pixels are copied unchanged, and the input
    array is not modified. ``mask`` must index ``rgb``'s leading (H, W) axes.
    """
    blended = rgb.copy()
    selected = mask.astype(bool)
    tint = np.asarray(color) * alpha
    blended[selected] = (tint + blended[selected] * (1 - alpha)).astype(np.uint8)
    return blended
def heatmap(probs: np.ndarray) -> np.ndarray:
    """Render a probability map as an RGB uint8 image.

    Values are clipped to [0, 1]. Red grows with p, blue shrinks with p, and
    green peaks at p = 0.5 (falling to 0 at both ends), so uncertain pixels
    stand out.
    """
    clipped = np.clip(probs, 0.0, 1.0)
    planes = (
        clipped,                                          # red
        np.maximum(0, 1 - 2 * np.abs(clipped - 0.5)),     # green
        1 - clipped,                                      # blue
    )
    image = np.zeros((clipped.shape[0], clipped.shape[1], 3), dtype=np.uint8)
    for channel, plane in enumerate(planes):
        image[..., channel] = (plane * 255).astype(np.uint8)
    return image
# ── UI ──────────────────────────────────────────────────────────────────────
# Page header. The scraped title contained mis-decoded characters (a lost
# emoji and "β"/"Γ" in place of "—"/"×"); they are restored here.
st.title("Clean Data-Scaling Study — 4 architectures × 3 data shares")
st.caption(
    "Trained from scratch on a deduplicated dataset (final_data_clean/). "
    "All numbers use the same global confusion-matrix metric code."
)

# Logs are re-read on every Streamlit rerun, so newly written runs show up.
logs = load_all_logs()
if not logs:
    st.warning("No logs found in `../logs/`. Run training first (`./run_all.sh`).")

tab_curves, tab_scaling, tab_infer = st.tabs(
    ["1 · Learning curves", "2 · Data share vs final", "3 · Inference"]
)
# ── Tab 1: Learning curves ──────────────────────────────────────────────────
def _add_epoch_trace(fig, frame, column, label, line=None):
    """Add ``frame[column]`` vs ``epoch`` to ``fig`` as a line trace.

    Nothing is added when the column is absent or entirely NaN. Rows where the
    column is NaN are dropped; ``line`` (e.g. a dash style) is only applied
    when given.
    """
    if column not in frame.columns or not frame[column].notna().any():
        return
    points = frame.dropna(subset=[column])
    trace_kwargs = {"x": points["epoch"], "y": points[column], "mode": "lines"}
    if line is not None:
        trace_kwargs["line"] = line
    fig.add_trace(go.Scatter(**trace_kwargs, name=label))


with tab_curves:
    st.subheader("Per-epoch metrics")
    if logs:
        picker_col, split_col = st.columns([2, 2])
        with picker_col:
            metric_key, metric_label = st.selectbox(
                "Metric", METRICS, format_func=lambda x: x[1],
            )
        with split_col:
            split = st.radio("Split", ["val", "train", "both"], horizontal=True, index=0)
        show_val = split in ("val", "both")
        show_train = split in ("train", "both")

        # One chart per architecture; within it, one line per (share, split).
        for arch in MODELS:
            shares_present = [s for s in SHARES if (arch, s) in logs]
            if not shares_present:
                continue
            st.markdown(f"#### {PRETTY_NAME[arch]}")
            fig = go.Figure()
            for share in shares_present:
                run_df = log_to_df(logs[(arch, share)])
                if show_val:
                    _add_epoch_trace(fig, run_df, f"val_{metric_key}", f"{share}% val")
                if show_train:
                    _add_epoch_trace(
                        fig, run_df, f"train_{metric_key}", f"{share}% train",
                        line=dict(dash="dot"),
                    )
            fig.update_layout(
                xaxis_title="Epoch", yaxis_title=metric_label, height=360,
                margin=dict(l=10, r=10, t=10, b=10),
                legend=dict(orientation="h", y=-0.2),
            )
            st.plotly_chart(fig, use_container_width=True)
    else:
        st.info("Waiting for training logs.")
# ── Tab 2: Data share vs final ──────────────────────────────────────────────
# Summary-table columns that must exist and be numeric for the filters, sums
# and charts below.
_SCALING_NUMERIC_COLUMNS = (
    "epoch", "val_dice", "val_miou", "val_iou", "val_pixel_acc",
    "wall_clock_seconds", "sec_per_epoch",
)


def _share_figure(frame, y, y_label, title, *, bar=False):
    """Plot ``y`` against training-data share, one series per model.

    Rows with NaN in ``y`` are dropped. ``bar=True`` draws grouped bars;
    otherwise draws lines with markers. The x-axis ticks are pinned to ``SHARES``.
    """
    data = frame.dropna(subset=[y])
    labels = {"share": "Training data (%)", y: y_label}
    if bar:
        fig = px.bar(
            data, x="share", y=y, color="model", barmode="group",
            title=title, labels=labels,
        )
    else:
        fig = px.line(
            data, x="share", y=y, color="model",
            markers=True, title=title, labels=labels,
        )
    fig.update_xaxes(tickvals=SHARES)
    return fig


with tab_scaling:
    st.subheader("Val metrics vs data share")
    if not logs:
        st.info("Waiting for training logs.")
    else:
        kind = st.radio("Checkpoint", ["best", "final"], horizontal=True, index=0)
        df = pd.DataFrame([scaling_row(log, kind=kind) for log in logs.values()])
        # Runs that have not logged an epoch yet can leave these columns absent
        # or all-None. Normalise them to numeric (NaN) so the code below never
        # hits a KeyError or an object-dtype column.
        for column in _SCALING_NUMERIC_COLUMNS:
            if column in df.columns:
                df[column] = pd.to_numeric(df[column], errors="coerce")
            else:
                df[column] = float("nan")
        df = df.sort_values(["model", "share"]).reset_index(drop=True)

        display_df = df.drop(columns=["wall_clock_seconds", "sec_per_epoch"], errors="ignore")
        st.dataframe(display_df, use_container_width=True, hide_index=True)

        trained_seconds = df["wall_clock_seconds"].dropna().sum()
        if trained_seconds:
            st.caption(
                "⏱ Total training wall-clock across all runs: "
                f"**{fmt_hms(trained_seconds)}** ({trained_seconds:,.0f} s)"
            )

        col1, col2 = st.columns(2)
        with col1:
            st.plotly_chart(
                _share_figure(df, "val_miou", "Val mIoU", f"Val mIoU ({kind})"),
                use_container_width=True,
            )
        with col2:
            st.plotly_chart(
                _share_figure(df, "val_dice", "Val Dice", f"Val Dice ({kind})"),
                use_container_width=True,
            )

        col3, col4 = st.columns(2)
        with col3:
            st.plotly_chart(
                _share_figure(
                    df, "val_iou", "Val IoU (foreground)",
                    f"Val foreground IoU ({kind})", bar=True,
                ),
                use_container_width=True,
            )
        with col4:
            st.plotly_chart(
                _share_figure(
                    df, "val_pixel_acc", "Val pixel acc",
                    f"Val pixel accuracy ({kind})", bar=True,
                ),
                use_container_width=True,
            )

        st.markdown("##### Training time")
        time_df = df.dropna(subset=["wall_clock_seconds"]).assign(
            wall_minutes=lambda d: d["wall_clock_seconds"] / 60.0
        )
        if not time_df.empty:
            tcol1, tcol2 = st.columns(2)
            with tcol1:
                st.plotly_chart(
                    _share_figure(
                        time_df, "wall_minutes", "Wall clock (min)",
                        "Total training time (minutes)", bar=True,
                    ),
                    use_container_width=True,
                )
            with tcol2:
                st.plotly_chart(
                    _share_figure(
                        time_df, "sec_per_epoch", "Seconds / epoch",
                        "Average seconds per epoch", bar=True,
                    ),
                    use_container_width=True,
                )
# ── Tab 3: Inference ────────────────────────────────────────────────────────
with tab_infer:
    # User-facing strings here had mis-decoded characters ("β", "Γ", "Β·",
    # "β¦", a lost emoji); they are restored to —, ×, ·, … respectively.
    st.subheader("Upload an image — 4 models × 3 shares = 12 segmentations")
    st.caption(
        "Each cell uses one (model, data-share) checkpoint. "
        "Toggle best vs final to see end-of-training behavior."
    )
    debug = st.checkbox("debug", value=False, key="infer_debug")
    col_a, col_b, col_c, col_d = st.columns([2, 2, 2, 2])
    with col_a:
        kind = st.radio("Checkpoint", ["best", "final"], horizontal=True, key="infer_kind")
    with col_b:
        threshold = st.slider("Threshold", 0.0, 1.0, 0.5, 0.05, key="infer_thr")
    with col_c:
        view = st.radio("View", ["mask", "overlay", "heatmap"], horizontal=True, key="infer_view")
    with col_d:
        cell_w = st.select_slider(
            "Cell size (px)", options=[140, 180, 220, 260], value=180, key="infer_cell"
        )
    uploaded = st.file_uploader(
        "Drop an image (jpg/png)", type=["jpg", "jpeg", "png"], key="infer_upload"
    )

    if uploaded is not None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        try:
            raw_bytes = uploaded.getvalue()
            img = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
        except Exception as e:
            st.error(f"Could not decode uploaded image: {e}")
            st.stop()
        st.caption(
            f"`{uploaded.name}` — {img.size[0]}×{img.size[1]} px, "
            f"{len(raw_bytes) / 1024:.1f} KB"
        )

        input_size = 128
        x = preprocess(img, image_size=input_size)
        # Undo ToTensor's [0, 1] scaling to get a displayable uint8 H×W×3 image.
        rgb_small = (x.squeeze().permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        st.image(
            [img, rgb_small], width=cell_w * 2,
            caption=["original", f"{input_size}×{input_size} (model input)"],
        )
        if debug:
            st.write(f"tensor={tuple(x.shape)} device={device} models={MODELS} shares={SHARES}")
        st.markdown("##### Predictions (rows = model, columns = data share)")

        def render_cell(probs, mask, rgb):
            """Turn one prediction into the image shown in its grid cell, per `view`."""
            if view == "mask":
                return (mask * 255).astype(np.uint8)
            if view == "overlay":
                return overlay(rgb, mask)
            return heatmap(probs)

        for model_name in MODELS:
            cols = st.columns(len(SHARES))
            for col, share in zip(cols, SHARES):
                with col:
                    st.markdown(f"**{PRETTY_NAME[model_name]} · {share}%**")
                    # One failing checkpoint must not blank the rest of the grid.
                    try:
                        with st.spinner(f"loading {model_name} {share}%…"):
                            m, output_is_prob = load_ckpt(model_name, share, kind, device)
                        if m is None:
                            st.warning(f"missing `{model_name}_{share}_{kind}.pth`")
                            continue
                        probs, mask = run_inference(m, x, device, output_is_prob, threshold)
                        cell_img = render_cell(probs, mask, rgb_small)
                        st.image(cell_img, width=cell_w)
                        st.caption(
                            f"cov={float(mask.mean())*100:.1f}% "
                            f"p[{probs.min():.2f},{probs.max():.2f}]"
                        )
                    except Exception as e:
                        st.error(f"{model_name} {share}% — {type(e).__name__}")
                        st.exception(e)
    else:
        st.info("Upload an image to run inference across all 12 trained models.")