""" Unified single-page Streamlit dashboard. Combines the two studies into one app with six top-level tabs: 1. Baseline learning curves (4 architectures × 3 shares) 2. Baseline data-scaling charts 3. Baseline inference grid (uploads → 4 × 3 mask grid) 4. Fine-tune grid heatmap (54 configs) 5. Fine-tune top configs table 6. Fine-tune per-config curves Run locally: streamlit run app.py Deploy on Streamlit Cloud: Main file path = app.py """ import io import json import os import sys from pathlib import Path import numpy as np import pandas as pd import plotly.express as px import plotly.graph_objects as go import streamlit as st import torch from PIL import Image from torchvision import transforms ROOT = Path(__file__).resolve().parent BASELINE_DIR = ROOT / "experiments" / "clean_data_scaling_study" GRID_DIR = ROOT / "experiments" / "finetune_grid_search" BASELINE_LOGS = BASELINE_DIR / "logs" BASELINE_CKPT = BASELINE_DIR / "checkpoints" GRID_LOGS = GRID_DIR / "logs" GRID_CSV = GRID_DIR / "results" / "grid_results.csv" # HF Model repo holding the .pth files. Set via env var on Spaces; falls back # to the default below for local runs that haven't touched the env. If the # repo doesn't exist or the file isn't there, the app shows a "missing" warning. HF_WEIGHTS_REPO = os.environ.get("HF_WEIGHTS_REPO", "phiniqs/seg-models-weights") def _resolve_ckpt(local_path: Path, hf_filename: str) -> Path | None: """Return a local path to the checkpoint. Try disk first, then HF Hub.""" if local_path.is_file(): return local_path try: from huggingface_hub import hf_hub_download downloaded = hf_hub_download( repo_id=HF_WEIGHTS_REPO, filename=hf_filename, ) return Path(downloaded) except Exception: return None # Make the baseline experiment importable so we can load its model registry. 
sys.path.insert(0, str(BASELINE_DIR))
from models import MODEL_REGISTRY, PRETTY_NAME  # noqa: E402

BASELINE_MODELS = ["segnet", "unet", "segformer_b0", "segformer_b5"]
SHARES = [25, 50, 100]

GRID_MODELS = ["unet", "segformer_b0"]
GRID_LRS = [1e-5, 5e-5, 1e-4]
GRID_BCES = [0.3, 0.5, 0.7]
GRID_AUGS = ["none", "default", "strong"]

METRICS = [
    ("dice", "Dice"),
    ("miou", "mIoU"),
    ("iou", "Foreground IoU"),
    ("pixel_acc", "Pixel Accuracy"),
    ("loss", "Loss"),
]

st.set_page_config(page_title="Solar Panel Segmentation — Dashboards", layout="wide")


# ── Loaders ────────────────────────────────────────────────────────────────
#
# The cached readers take the file's mtime as an extra argument purely so that
# a changed file produces a new cache key. The parameter must NOT start with an
# underscore: st.cache_data skips underscore-prefixed parameters when hashing,
# which would turn the mtime into a no-op.

def _mtime(path: Path) -> float:
    """Return the modification time of *path*, or 0.0 if it is not a file."""
    return path.stat().st_mtime if path.is_file() else 0.0


@st.cache_data(show_spinner=False)
def _read_baseline_log(model: str, share: int, mtime: float):
    """Read ``{model}_{share}.json`` from BASELINE_LOGS; None if absent.

    ``mtime`` is unused in the body; it only participates in the cache key.
    """
    p = BASELINE_LOGS / f"{model}_{share}.json"
    if not p.is_file():
        return None
    with open(p, encoding="utf-8") as f:
        return json.load(f)


def load_baseline_log(model: str, share: int):
    """Return the parsed baseline log for (model, share), or None."""
    return _read_baseline_log(model, share, _mtime(BASELINE_LOGS / f"{model}_{share}.json"))


def load_all_baseline_logs():
    """Return ``{(model, share): log}`` for every baseline log that exists.

    Not cached itself: the per-file reads are cached and keyed on mtime, and
    an outer argument-less cache would hide file changes.
    """
    out = {}
    for m in BASELINE_MODELS:
        for s in SHARES:
            log = load_baseline_log(m, s)
            if log is not None:
                out[(m, s)] = log
    return out


@st.cache_data(show_spinner=False)
def _read_grid_csv(mtime: float):
    """Read GRID_CSV into a DataFrame (empty frame if the file is absent).

    ``mtime`` is unused in the body; it only participates in the cache key.
    """
    if not GRID_CSV.is_file():
        return pd.DataFrame()
    return pd.read_csv(GRID_CSV)


def load_grid_results():
    """Return the grid-search results table, re-read when the CSV changes."""
    return _read_grid_csv(_mtime(GRID_CSV))


@st.cache_data(show_spinner=False)
def _read_grid_log(cfg_id: str, mtime: float):
    """Read ``{cfg_id}.json`` from GRID_LOGS; None if absent.

    ``mtime`` is unused in the body; it only participates in the cache key.
    """
    p = GRID_LOGS / f"{cfg_id}.json"
    if not p.is_file():
        return None
    with open(p, encoding="utf-8") as f:
        return json.load(f)


def load_grid_log(cfg_id: str):
    """Return the parsed fine-tune log for *cfg_id*, or None."""
    return _read_grid_log(cfg_id, _mtime(GRID_LOGS / f"{cfg_id}.json"))


# ── Shared helpers ─────────────────────────────────────────────────────────

def fmt_hms(seconds):
    """Format a duration as ``H:MM:SS`` (or ``M:SS`` under one hour).

    Returns "—" when *seconds* is None or NaN.
    """
    if seconds is None or pd.isna(seconds):
        return "—"
    seconds = int(round(seconds))
    h, rem = divmod(seconds, 3600)
    m, s = divmod(rem, 60)
    return f"{h:d}:{m:02d}:{s:02d}" if h else f"{m:d}:{s:02d}"


def baseline_log_to_df(log):
    """Turn a baseline log into a per-epoch DataFrame tagged with model/share."""
    df = pd.DataFrame(log["epochs"])
    df["model"] = log["model"]
    df["share"] = log["share"]
    return df


def baseline_scaling_row(log, kind="best"):
    """Summarise one baseline log as a single table row.

    Args:
        log: Parsed baseline log with "model", "share" and "epochs" keys.
        kind: "best" picks the epoch with the highest ``val_dice``; any other
            value picks the last epoch.

    Returns:
        A dict that always carries the same keys, so a DataFrame built from
        these rows has every column even when all logs have no epochs.
    """
    epochs = log["epochs"]
    wall = log.get("wall_clock_seconds")
    row = {
        "model": PRETTY_NAME[log["model"]],
        "share": log["share"],
        "epoch": None,
        "val_dice": None,
        "val_miou": None,
        "val_iou": None,
        "val_pixel_acc": None,
        "wall_clock_seconds": wall,
        "wall_clock": fmt_hms(wall),
        "sec_per_epoch": None,
    }
    if not epochs:
        return row

    if kind == "best":
        # Missing / None / 0 dice all rank as -1; max() keeps the first winner.
        chosen = max(epochs, key=lambda e: e.get("val_dice", -1) or -1)
    else:
        chosen = epochs[-1]

    row["epoch"] = chosen["epoch"]
    for k in ("val_dice", "val_miou", "val_iou", "val_pixel_acc"):
        row[k] = chosen.get(k)

    per = [e.get("epoch_seconds") for e in epochs if e.get("epoch_seconds") is not None]
    row["sec_per_epoch"] = (sum(per) / len(per)) if per else None
    return row


# ── Inference helpers (used by tab 3) ──────────────────────────────────────

@st.cache_resource(show_spinner=False)
def load_baseline_ckpt(model_name: str, share: int, kind: str, device: str):
    """Load ``{model_name}_{share}_{kind}.pth`` and return (model, output_is_prob).

    Returns ``(None, False)`` when the checkpoint cannot be resolved. That
    result is cached by st.cache_resource like any other.
    """
    fname = f"{model_name}_{share}_{kind}.pth"
    p = _resolve_ckpt(BASELINE_CKPT / fname, f"baseline/{fname}")
    if p is None:
        return None, False
    builder = MODEL_REGISTRY[model_name]
    model, _, output_is_prob = builder()
    # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects.
    # The file may have been downloaded from HF_WEIGHTS_REPO, so only point
    # that at a repository you trust.
    state = torch.load(p, map_location=device, weights_only=False)
    model.load_state_dict(state["model_state_dict"])
    model.to(device).eval()
    # A flag stored in the checkpoint overrides the registry default.
    output_is_prob = state.get("output_is_prob", output_is_prob)
    return model, output_is_prob


def preprocess(image: Image.Image, image_size: int = 128):
    """Convert to RGB, resize to a square, and return a batched tensor."""
    tf = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ])
    return tf(image.convert("RGB")).unsqueeze(0)


def run_inference(model, image_tensor, device, output_is_prob: bool, threshold=0.5):
    """Run *model* on *image_tensor* and return ``(probs, mask)``.

    ``probs`` is a 2-D NumPy array (sigmoid is applied unless
    *output_is_prob*); ``mask`` is a float32 array of ``probs > threshold``.
    """
    with torch.no_grad():
        out = model(image_tensor.to(device))
        probs = out if output_is_prob else torch.sigmoid(out)
    probs = probs.squeeze().cpu().numpy()
    if probs.ndim != 2:
        probs = probs.reshape(probs.shape[-2], probs.shape[-1])
    mask = (probs > threshold).astype(np.float32)
    return probs, mask


def overlay(rgb: np.ndarray, mask: np.ndarray, color=(0, 255, 0), alpha=0.45):
    """Return a copy of *rgb* with *color* alpha-blended where *mask* is set."""
    out = rgb.copy()
    m = mask.astype(bool)
    out[m] = (alpha * np.array(color) + (1 - alpha) * out[m]).astype(np.uint8)
    return out


def heatmap_rgb(probs: np.ndarray) -> np.ndarray:
    """Map probabilities to a uint8 RGB image (blue at 0, green at 0.5, red at 1)."""
    p = np.clip(probs, 0.0, 1.0)
    rgb = np.zeros((p.shape[0], p.shape[1], 3), dtype=np.uint8)
    rgb[..., 0] = (p * 255).astype(np.uint8)
    rgb[..., 1] = (np.maximum(0, 1 - 2 * np.abs(p - 0.5)) * 255).astype(np.uint8)
    rgb[..., 2] = ((1 - p) * 255).astype(np.uint8)
    return rgb


# ── UI ─────────────────────────────────────────────────────────────────────

st.title("Solar Panel Segmentation — Unified Dashboard")
st.caption(
    "Two studies in one app. Tabs 1–3 are the 4-model × 3-share baseline; "
    "tabs 4–6 are the U-Net & SegFormer-B0 fine-tune grid search."
)

if st.button("🔄 Reload from disk"):
    st.cache_data.clear()
    st.rerun()

baseline_logs = load_all_baseline_logs()
grid_df = load_grid_results()

t1, t2, t3, t4, t5, t6 = st.tabs([
    "1 · Baseline curves",
    "2 · Baseline scaling",
    "3 · Baseline inference",
    "4 · Grid heatmap",
    "5 · Top fine-tune configs",
    "6 · Per-config curves",
])


# ─── Tab 1: Baseline learning curves ──────────────────────────────────────
with t1:
    st.subheader("Per-epoch metrics for the 4 baseline architectures")
    if not baseline_logs:
        st.info("No baseline logs found at experiments/clean_data_scaling_study/logs/.")
    else:
        c1, c2 = st.columns(2)
        with c1:
            metric_key, metric_label = st.selectbox(
                "Metric", METRICS, format_func=lambda x: x[1], key="t1_metric",
            )
        with c2:
            split = st.radio("Split", ["val", "train", "both"], horizontal=True,
                             index=0, key="t1_split")

        for model in BASELINE_MODELS:
            available = [s for s in SHARES if (model, s) in baseline_logs]
            if not available:
                continue
            st.markdown(f"#### {PRETTY_NAME[model]}")
            fig = go.Figure()
            for share in available:
                df = baseline_log_to_df(baseline_logs[(model, share)])
                if split in ("val", "both"):
                    col = f"val_{metric_key}"
                    if col in df.columns and df[col].notna().any():
                        sub = df.dropna(subset=[col])
                        fig.add_trace(go.Scatter(
                            x=sub["epoch"], y=sub[col], mode="lines",
                            name=f"{share}% val",
                        ))
                if split in ("train", "both"):
                    col = f"train_{metric_key}"
                    if col in df.columns and df[col].notna().any():
                        sub = df.dropna(subset=[col])
                        fig.add_trace(go.Scatter(
                            x=sub["epoch"], y=sub[col], mode="lines",
                            line=dict(dash="dot"),
                            name=f"{share}% train",
                        ))
            fig.update_layout(
                xaxis_title="Epoch", yaxis_title=metric_label,
                height=340, margin=dict(l=10, r=10, t=10, b=10),
                legend=dict(orientation="h", y=-0.2),
            )
            st.plotly_chart(fig, use_container_width=True)


# ─── Tab 2: Baseline data-scaling charts ──────────────────────────────────
with t2:
    st.subheader("Val metrics vs training-data share")
    if not baseline_logs:
        st.info("No baseline logs found.")
    else:
        kind = st.radio("Checkpoint", ["best", "final"], horizontal=True,
                        index=0, key="t2_kind")
        rows = [baseline_scaling_row(log, kind=kind) for log in baseline_logs.values()]
        df = pd.DataFrame(rows).sort_values(["model", "share"]).reset_index(drop=True)

        display_df = df.drop(columns=["wall_clock_seconds", "sec_per_epoch"], errors="ignore")
        st.dataframe(display_df, use_container_width=True, hide_index=True)

        total_seconds = df["wall_clock_seconds"].dropna().sum()
        if total_seconds:
            st.caption(
                f"⏱ Total baseline wall-clock: **{fmt_hms(total_seconds)}** "
                f"({total_seconds:,.0f} s)"
            )

        c1, c2 = st.columns(2)
        with c1:
            fig = px.line(
                df.dropna(subset=["val_miou"]),
                x="share", y="val_miou", color="model", markers=True,
                title=f"Val mIoU ({kind})",
                labels={"share": "Training data (%)", "val_miou": "Val mIoU"},
            )
            fig.update_xaxes(tickvals=SHARES)
            st.plotly_chart(fig, use_container_width=True)
        with c2:
            fig = px.line(
                df.dropna(subset=["val_dice"]),
                x="share", y="val_dice", color="model", markers=True,
                title=f"Val Dice ({kind})",
                labels={"share": "Training data (%)", "val_dice": "Val Dice"},
            )
            fig.update_xaxes(tickvals=SHARES)
            st.plotly_chart(fig, use_container_width=True)

        c3, c4 = st.columns(2)
        with c3:
            fig = px.bar(
                df.dropna(subset=["val_iou"]),
                x="share", y="val_iou", color="model", barmode="group",
                title=f"Val foreground IoU ({kind})",
                labels={"share": "Training data (%)", "val_iou": "Val IoU (foreground)"},
            )
            fig.update_xaxes(tickvals=SHARES)
            st.plotly_chart(fig, use_container_width=True)
        with c4:
            fig = px.bar(
                df.dropna(subset=["val_pixel_acc"]),
                x="share", y="val_pixel_acc", color="model", barmode="group",
                title=f"Val pixel accuracy ({kind})",
                labels={"share": "Training data (%)", "val_pixel_acc": "Val pixel acc"},
            )
            fig.update_xaxes(tickvals=SHARES)
            st.plotly_chart(fig, use_container_width=True)

        st.markdown("##### Training time")
        time_df = df.dropna(subset=["wall_clock_seconds"]).assign(
            wall_minutes=lambda d: d["wall_clock_seconds"] / 60.0
        )
        if not time_df.empty:
            tcol1, tcol2 = st.columns(2)
            with tcol1:
                fig = px.bar(
                    time_df, x="share", y="wall_minutes", color="model", barmode="group",
                    title="Total training time (minutes)",
                    labels={"share": "Training data (%)", "wall_minutes": "Wall clock (min)"},
                )
                fig.update_xaxes(tickvals=SHARES)
                st.plotly_chart(fig, use_container_width=True)
            with tcol2:
                fig = px.bar(
                    time_df.dropna(subset=["sec_per_epoch"]),
                    x="share", y="sec_per_epoch", color="model", barmode="group",
                    title="Average seconds per epoch",
                    labels={"share": "Training data (%)", "sec_per_epoch": "Seconds / epoch"},
                )
                fig.update_xaxes(tickvals=SHARES)
                st.plotly_chart(fig, use_container_width=True)
        else:
            st.caption("No timing data available yet.")


# ─── Tab 3: Baseline inference grid ───────────────────────────────────────
with t3:
    st.subheader("Upload an image — 4 models × 3 shares = 12 segmentations")
    st.caption(
        "Each cell uses one (model, share) baseline checkpoint. "
        "Cells with no checkpoint locally show a 'missing' warning."
    )

    a, b, c, d = st.columns([2, 2, 2, 2])
    with a:
        kind = st.radio("Checkpoint", ["best", "final"], horizontal=True, key="t3_kind")
    with b:
        threshold = st.slider("Threshold", 0.0, 1.0, 0.5, 0.05, key="t3_thr")
    with c:
        view = st.radio("View", ["mask", "overlay", "heatmap"], horizontal=True, key="t3_view")
    with d:
        cell_w = st.select_slider(
            "Cell size (px)", options=[140, 180, 220, 260], value=180, key="t3_cell"
        )

    uploaded = st.file_uploader(
        "Drop an image (jpg/png)", type=["jpg", "jpeg", "png"], key="t3_upload"
    )

    # Decode first. On failure only report the error: calling st.stop() here
    # would abort the whole script and leave tabs 4–6 empty.
    img = None
    if uploaded is None:
        st.info("Upload an image to run inference across all 12 baseline checkpoints.")
    else:
        try:
            raw_bytes = uploaded.getvalue()
            img = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
        except Exception as e:
            st.error(f"Could not decode uploaded image: {e}")

    if img is not None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        st.caption(f"📁 `{uploaded.name}` — {img.size[0]}×{img.size[1]} px, {len(raw_bytes)/1024:.1f} KB")

        x = preprocess(img, image_size=128)
        rgb_small = (x.squeeze().permute(1, 2, 0).numpy() * 255).astype(np.uint8)

        st.image([img, rgb_small], width=cell_w * 2,
                 caption=["original", "128×128 (model input)"])

        st.markdown("##### Predictions (rows = model, columns = data share)")

        def render_cell(probs, mask, rgb):
            """Build the image for one grid cell according to the chosen view."""
            if view == "mask":
                return (mask * 255).astype(np.uint8)
            if view == "overlay":
                return overlay(rgb, mask)
            return heatmap_rgb(probs)

        for model_name in BASELINE_MODELS:
            cols = st.columns(len(SHARES))
            for col, share in zip(cols, SHARES):
                with col:
                    st.markdown(f"**{PRETTY_NAME[model_name]} · {share}%**")
                    # One failing checkpoint must not take down the other cells.
                    try:
                        with st.spinner(f"loading {model_name} {share}%…"):
                            m, output_is_prob = load_baseline_ckpt(model_name, share, kind, device)
                        if m is None:
                            st.warning(f"missing `{model_name}_{share}_{kind}.pth`")
                            continue
                        probs, mask = run_inference(m, x, device, output_is_prob, threshold)
                        cell_img = render_cell(probs, mask, rgb_small)
                        st.image(cell_img, width=cell_w)
                        st.caption(
                            f"cov={float(mask.mean())*100:.1f}% "
                            f"p[{probs.min():.2f},{probs.max():.2f}]"
                        )
                    except Exception as e:
                        st.error(f"{model_name} {share}% — {type(e).__name__}")
                        st.exception(e)


# ─── Tab 4: Fine-tune grid heatmap ────────────────────────────────────────
with t4:
    st.subheader("Δ Dice across the 54 fine-tune configurations")
    if grid_df.empty:
        st.info("No grid results found at experiments/finetune_grid_search/results/grid_results.csv.")
    else:
        df = grid_df.copy()
        df["model_pretty"] = df["model"].map(PRETTY_NAME)
        df["lr_label"] = df["lr"].apply(lambda x: f"{x:.0e}")

        n_runs = len(df)
        n_unet = (df["model"] == "unet").sum()
        n_seg = (df["model"] == "segformer_b0").sum()
        total_seconds = float(df["wall_clock_seconds"].sum())
        st.markdown(
            f"**{n_runs} runs** "
            f"(U-Net: {n_unet}, SegFormer-B0: {n_seg}); "
            f"total compute: **{fmt_hms(total_seconds)}**"
        )

        metric = st.radio(
            "Metric",
            [("delta_dice", "Δ Dice (vs. baseline)"),
             ("best_val_dice", "Best val Dice (absolute)"),
             ("best_val_miou", "Best val mIoU (absolute)"),
             ("best_val_iou", "Best val IoU (absolute)")],
            format_func=lambda x: x[1], horizontal=True, key="t4_metric",
        )[0]

        # Diverging scale centred on zero for the delta, sequential otherwise.
        color_scale = "RdYlGn" if metric == "delta_dice" else "Viridis"
        zmid = 0 if metric == "delta_dice" else None

        for model in GRID_MODELS:
            sub = df[df["model"] == model].copy()
            if sub.empty:
                continue
            st.markdown(f"#### {PRETTY_NAME[model]}")
            sub["row"] = sub.apply(
                lambda r: f"lr={r['lr_label']} bce={r['bce_weight']:.1f}", axis=1
            )
            pivot = sub.pivot_table(
                index="row", columns="augment", values=metric, aggfunc="first"
            ).reindex(columns=GRID_AUGS)
            # Row labels here must use the same format as sub["row"] above.
            ordered_rows = [
                f"lr={lr:.0e} bce={bw:.1f}"
                for lr in GRID_LRS for bw in GRID_BCES
            ]
            pivot = pivot.reindex(ordered_rows)

            z = pivot.values
            text = np.where(np.isnan(z), "",
                            np.vectorize(lambda v: f"{v:.4f}")(np.nan_to_num(z, nan=0.0)))
            fig = go.Figure(data=go.Heatmap(
                z=z, x=GRID_AUGS, y=ordered_rows,
                colorscale=color_scale, zmid=zmid,
                text=text, texttemplate="%{text}", textfont=dict(size=12),
                colorbar=dict(title=metric.replace("_", " ")),
            ))
            fig.update_layout(
                height=360, margin=dict(l=10, r=10, t=10, b=10),
                xaxis_title="Augmentation",
                yaxis_title="(learning rate, BCE weight)",
            )
            st.plotly_chart(fig, use_container_width=True)

            # idxmax has no usable answer when every value is NaN.
            if sub[metric].notna().any():
                best = sub.loc[sub[metric].idxmax()]
                st.success(
                    f"Best for {PRETTY_NAME[model]} on **{metric}**: "
                    f"`{best['cfg_id']}` (lr={best['lr_label']}, "
                    f"bce={best['bce_weight']:.1f}, aug={best['augment']}) "
                    f"→ {metric}={best[metric]:.4f} "
                    f"(baseline Dice={best['baseline_val_dice']:.4f})"
                )


# ─── Tab 5: Top fine-tune configs ─────────────────────────────────────────
with t5:
    st.subheader("All 54 fine-tune configurations")
    if grid_df.empty:
        st.info("No grid results yet.")
    else:
        df = grid_df.copy()
        df["model_pretty"] = df["model"].map(PRETTY_NAME)
        df["lr_label"] = df["lr"].apply(lambda x: f"{x:.0e}")

        sort_by = st.selectbox(
            "Sort by",
            ["delta_dice", "best_val_dice", "best_val_miou", "best_val_iou",
             "wall_clock_seconds"],
            index=0, key="t5_sort",
        )
        ascending = st.toggle("ascending", value=False, key="t5_asc")
        only_improved = st.checkbox(
            "only configs that improved over baseline", value=False, key="t5_imp"
        )

        show = df.copy()
        if only_improved:
            show = show[show["delta_dice"] > 0]
        show = show.sort_values(sort_by, ascending=ascending)

        cols_to_show = [
            "cfg_id", "model_pretty", "lr_label", "bce_weight", "augment",
            "best_epoch", "epochs_trained", "early_stopped",
            "best_val_dice", "best_val_miou", "best_val_iou", "best_val_pixel_acc",
            "baseline_val_dice", "delta_dice", "wall_clock_seconds",
        ]
        # Tolerate CSVs that lack some of the optional columns.
        cols_to_show = [c for c in cols_to_show if c in show.columns]
        st.dataframe(show[cols_to_show], use_container_width=True, hide_index=True)


# ─── Tab 6: Per-config fine-tune curves ───────────────────────────────────
with t6:
    st.subheader("Per-config learning curves")
    if grid_df.empty:
        st.info("No grid results yet.")
    else:
        cfg_options = sorted(grid_df["cfg_id"].unique().tolist())
        chosen = st.multiselect(
            "Pick configs to overlay", cfg_options,
            default=cfg_options[:3] if cfg_options else [],
            key="t6_pick",
        )
        metric_key, metric_label = st.selectbox(
            "Metric",
            [("dice", "Dice"), ("miou", "mIoU"), ("iou", "Foreground IoU"),
             ("pixel_acc", "Pixel accuracy"), ("loss", "Loss")],
            format_func=lambda x: x[1], key="t6_metric",
        )

        if not chosen:
            st.info("Select at least one config.")
        else:
            fig = go.Figure()
            baselines = []  # (cfg_id, baseline dice), drawn after all traces
            for cfg_id in chosen:
                log = load_grid_log(cfg_id)
                if log is None:
                    continue
                epochs = log["epochs"]
                xs = [e["epoch"] for e in epochs]
                # .get(): an epoch lacking the metric yields None instead of
                # raising KeyError.
                ys_val = [e.get(f"val_{metric_key}") for e in epochs]
                fig.add_trace(go.Scatter(
                    x=xs, y=ys_val, mode="lines+markers",
                    name=f"{cfg_id} val",
                ))
                ys_train = [e.get(f"train_{metric_key}") for e in epochs]
                fig.add_trace(go.Scatter(
                    x=xs, y=ys_train, mode="lines",
                    line=dict(dash="dot"),
                    name=f"{cfg_id} train",
                ))
                if metric_key == "dice" and "baseline_val_dice" in log:
                    baselines.append((cfg_id, log["baseline_val_dice"]))

            for cfg_id, base in baselines:
                fig.add_hline(
                    y=base, line_dash="dash",
                    annotation_text=f"{cfg_id} baseline ({base:.4f})",
                    annotation_position="bottom right",
                )
            fig.update_layout(
                xaxis_title="Epoch", yaxis_title=metric_label,
                height=480, margin=dict(l=10, r=10, t=10, b=10),
                legend=dict(orientation="h", y=-0.2),
            )
            st.plotly_chart(fig, use_container_width=True)