Spaces:
Running
Running
| """ | |
| Unified single-page Streamlit dashboard. | |
| Combines the two studies into one app with six top-level tabs: | |
| 1. Baseline learning curves (4 architectures × 3 shares) | |
| 2. Baseline data-scaling charts | |
| 3. Baseline inference grid (uploads → 4 × 3 mask grid) | |
| 4. Fine-tune grid heatmap (54 configs) | |
| 5. Fine-tune top configs table | |
| 6. Fine-tune per-config curves | |
| Run locally: | |
| streamlit run app.py | |
| Deploy on Streamlit Cloud: Main file path = app.py | |
| """ | |
| import io | |
| import json | |
| import os | |
| import sys | |
| from pathlib import Path | |
| import numpy as np | |
| import pandas as pd | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| import streamlit as st | |
| import torch | |
| from PIL import Image | |
| from torchvision import transforms | |
ROOT = Path(__file__).resolve().parent

# Output trees written by the two experiment studies, relative to this file.
BASELINE_DIR = ROOT.joinpath("experiments", "clean_data_scaling_study")
GRID_DIR = ROOT.joinpath("experiments", "finetune_grid_search")

BASELINE_LOGS = BASELINE_DIR.joinpath("logs")
BASELINE_CKPT = BASELINE_DIR.joinpath("checkpoints")
GRID_LOGS = GRID_DIR.joinpath("logs")
GRID_CSV = GRID_DIR.joinpath("results", "grid_results.csv")

# Hugging Face model repo that hosts the .pth checkpoints. On Spaces this is
# configured through the HF_WEIGHTS_REPO environment variable; local runs
# without it use the default below. When the repo or file is unavailable the
# app reports the checkpoint as "missing".
HF_WEIGHTS_REPO = os.environ.get("HF_WEIGHTS_REPO", "phiniqs/seg-models-weights")
| def _resolve_ckpt(local_path: Path, hf_filename: str) -> Path | None: | |
| """Return a local path to the checkpoint. Try disk first, then HF Hub.""" | |
| if local_path.is_file(): | |
| return local_path | |
| try: | |
| from huggingface_hub import hf_hub_download | |
| downloaded = hf_hub_download( | |
| repo_id=HF_WEIGHTS_REPO, | |
| filename=hf_filename, | |
| ) | |
| return Path(downloaded) | |
| except Exception: | |
| return None | |
| # Make the baseline experiment importable so we can load its model registry. | |
| sys.path.insert(0, str(BASELINE_DIR)) | |
| from models import MODEL_REGISTRY, PRETTY_NAME # noqa: E402 | |
# Baseline study: four architectures, each trained on three data shares (%).
BASELINE_MODELS = "segnet unet segformer_b0 segformer_b5".split()
SHARES = [25, 50, 100]

# Fine-tune grid axes: models x learning rates x BCE weights x augmentation presets.
GRID_MODELS = ["unet", "segformer_b0"]
GRID_LRS = [1e-5, 5e-5, 1e-4]
GRID_BCES = [0.3, 0.5, 0.7]
GRID_AUGS = "none default strong".split()

# (metric key suffix used as "val_<key>" / "train_<key>" in logs, display label)
METRICS = list(zip(
    ("dice", "miou", "iou", "pixel_acc", "loss"),
    ("Dice", "mIoU", "Foreground IoU", "Pixel Accuracy", "Loss"),
))
# Page-level settings; kept ahead of every other st.* call in the script.
# The title had a mis-encoded em dash ("β"); restored to "—".
st.set_page_config(page_title="Solar Panel Segmentation — Dashboards", layout="wide")
| # ββ Loaders ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _mtime(path: Path) -> float: | |
| return path.stat().st_mtime if path.is_file() else 0.0 | |
| def _read_baseline_log(model: str, share: int, _mt: float): | |
| p = BASELINE_LOGS / f"{model}_{share}.json" | |
| if not p.is_file(): | |
| return None | |
| with open(p) as f: | |
| return json.load(f) | |
| def load_baseline_log(model: str, share: int): | |
| return _read_baseline_log(model, share, _mtime(BASELINE_LOGS / f"{model}_{share}.json")) | |
| def load_all_baseline_logs(): | |
| out = {} | |
| for m in BASELINE_MODELS: | |
| for s in SHARES: | |
| log = load_baseline_log(m, s) | |
| if log is not None: | |
| out[(m, s)] = log | |
| return out | |
| def _read_grid_csv(_mt: float): | |
| if not GRID_CSV.is_file(): | |
| return pd.DataFrame() | |
| return pd.read_csv(GRID_CSV) | |
| def load_grid_results(): | |
| return _read_grid_csv(_mtime(GRID_CSV)) | |
| def _read_grid_log(cfg_id: str, _mt: float): | |
| p = GRID_LOGS / f"{cfg_id}.json" | |
| if not p.is_file(): | |
| return None | |
| with open(p) as f: | |
| return json.load(f) | |
| def load_grid_log(cfg_id: str): | |
| return _read_grid_log(cfg_id, _mtime(GRID_LOGS / f"{cfg_id}.json")) | |
# ── Shared helpers ───────────────────────────────────────────────────────────
def fmt_hms(seconds):
    """Format a duration as ``H:MM:SS``, or ``M:SS`` when under an hour.

    Args:
        seconds: duration in seconds (int or float). ``None`` or NaN means
            "unknown".

    Returns:
        The formatted string (rounded to whole seconds), or an em-dash
        placeholder for an unknown duration.
    """
    # pd.isna also catches float NaN (e.g. read from a CSV/JSON log), which
    # previously crashed int(round(...)). The placeholder was mis-encoded as "β".
    if seconds is None or pd.isna(seconds):
        return "—"
    total = int(round(seconds))
    hours, rem = divmod(total, 3600)
    minutes, secs = divmod(rem, 60)
    return f"{hours:d}:{minutes:02d}:{secs:02d}" if hours else f"{minutes:d}:{secs:02d}"
def baseline_log_to_df(log):
    """Flatten a baseline log's per-epoch records into a DataFrame.

    Produces one row per entry of ``log["epochs"]`` and appends constant
    ``model`` and ``share`` columns copied from the log itself.
    """
    per_epoch = pd.DataFrame(log["epochs"])
    return per_epoch.assign(model=log["model"], share=log["share"])
def baseline_scaling_row(log, kind="best"):
    """Summarize one baseline run as a single row of the data-scaling table.

    Args:
        log: parsed baseline log with ``"model"``, ``"share"``, ``"epochs"``
            (list of per-epoch dicts) and optionally ``"wall_clock_seconds"``.
        kind: ``"best"`` selects the epoch with the highest ``val_dice``;
            any other value selects the last epoch.

    Returns:
        dict with keys model (pretty name), share, epoch, val_dice, val_miou,
        val_iou, val_pixel_acc, wall_clock_seconds, wall_clock (formatted via
        fmt_hms) and sec_per_epoch (mean of recorded ``epoch_seconds``).
        The same keys are present even when the log has no epochs (values
        are then None), so rows built from mixed logs share one schema.
    """
    def dice_rank(epoch):
        # A missing/None/NaN val_dice ranks below every real score, including
        # 0.0 (the old `x or -1` treated a genuine 0.0 as missing, and NaN
        # broke max()).
        value = epoch.get("val_dice")
        return float("-inf") if value is None or pd.isna(value) else value

    epochs = log["epochs"]
    if not epochs:
        chosen = {}
    elif kind == "best":
        chosen = max(epochs, key=dice_rank)  # first epoch wins ties
    else:
        chosen = epochs[-1]

    wall = log.get("wall_clock_seconds")
    per_epoch = [e["epoch_seconds"] for e in epochs if e.get("epoch_seconds") is not None]
    return {
        "model": PRETTY_NAME[log["model"]],
        "share": log["share"],
        "epoch": chosen.get("epoch"),
        "val_dice": chosen.get("val_dice"),
        "val_miou": chosen.get("val_miou"),
        "val_iou": chosen.get("val_iou"),
        "val_pixel_acc": chosen.get("val_pixel_acc"),
        "wall_clock_seconds": wall,
        "wall_clock": fmt_hms(wall),
        "sec_per_epoch": (sum(per_epoch) / len(per_epoch)) if per_epoch else None,
    }
# ── Inference helpers (used by tab 3) ────────────────────────────────────────
def load_baseline_ckpt(model_name: str, share: int, kind: str, device: str):
    """Build a baseline network and load its ``{model}_{share}_{kind}.pth`` weights.

    The checkpoint is looked up under BASELINE_CKPT first, then under
    ``baseline/`` in the Hugging Face weights repo (via _resolve_ckpt).

    Returns:
        ``(model, output_is_prob)``: the model moved to *device* and switched
        to eval mode, plus whether its output is already a probability (the
        checkpoint's ``"output_is_prob"`` entry, else the registry builder's
        default). ``(None, False)`` if no checkpoint could be resolved.

    NOTE(review): ``weights_only=False`` lets torch.load unpickle arbitrary
    objects; only point HF_WEIGHTS_REPO at a repository you trust.
    """
    ckpt_name = f"{model_name}_{share}_{kind}.pth"
    ckpt_path = _resolve_ckpt(BASELINE_CKPT / ckpt_name, f"baseline/{ckpt_name}")
    if ckpt_path is None:
        return None, False
    net, _unused, default_is_prob = MODEL_REGISTRY[model_name]()
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
    net.load_state_dict(checkpoint["model_state_dict"])
    net.to(device).eval()
    return net, checkpoint.get("output_is_prob", default_is_prob)
def preprocess(image: Image.Image, image_size: int = 128):
    """Turn a PIL image into a single-image batch for the baseline models.

    The image is converted to RGB, resized to ``image_size`` x ``image_size``
    (aspect ratio is not preserved), converted with ``ToTensor`` and given
    a leading batch axis of size 1.
    """
    resize_then_tensor = transforms.Compose(
        [transforms.Resize((image_size, image_size)), transforms.ToTensor()]
    )
    tensor = resize_then_tensor(image.convert("RGB"))
    return tensor.unsqueeze(0)
def run_inference(model, image_tensor, device, output_is_prob: bool, threshold=0.5):
    """Run *model* on one image and return its probability map and binary mask.

    Args:
        model: segmentation network, called as ``model(batch)``.
        image_tensor: input batch; moved to *device* before the forward pass.
        device: torch device string (e.g. ``"cpu"`` or ``"cuda"``).
        output_is_prob: True if the model already outputs probabilities;
            otherwise the output is treated as logits and passed through a
            sigmoid.
        threshold: a pixel is foreground when its probability is strictly
            greater than this value.

    Returns:
        ``(probs, mask)``: 2-D numpy arrays shaped like the last two axes of
        the model output; ``mask`` is float32 with values 0.0 / 1.0.

    Raises:
        ValueError: if the output holds more than one value per pixel
            (e.g. a multi-channel map), since it cannot be reshaped to 2-D.
    """
    with torch.no_grad():
        out = model(image_tensor.to(device))
        probs = out if output_is_prob else torch.sigmoid(out)
    probs = probs.cpu().numpy()
    # Collapse the leading singleton batch/channel axes by reshaping to the
    # last two axes. The previous squeeze() also dropped a spatial axis of
    # size 1, after which reshape(shape[-2], shape[-1]) failed.
    probs = probs.reshape(probs.shape[-2:])
    mask = (probs > threshold).astype(np.float32)
    return probs, mask
def overlay(rgb: np.ndarray, mask: np.ndarray, color=(0, 255, 0), alpha=0.45):
    """Alpha-blend *color* onto the pixels of *rgb* selected by *mask*.

    Args:
        rgb: H x W x 3 uint8 image; it is not modified.
        mask: H x W array; nonzero entries mark pixels to tint.
        color: RGB tint.
        alpha: weight of the tint (1 - alpha keeps the original pixel).

    Returns:
        A tinted copy of *rgb* (blended values truncated to uint8).
    """
    selected = mask.astype(bool)
    blended = rgb.copy()
    tint = np.array(color)
    blended[selected] = (alpha * tint + (1 - alpha) * blended[selected]).astype(np.uint8)
    return blended
def heatmap_rgb(probs: np.ndarray) -> np.ndarray:
    """Render a 2-D probability map as an RGB uint8 heatmap.

    Probabilities are clipped to [0, 1]. Red grows with p, blue with 1 - p,
    and green peaks at p = 0.5 (falling linearly to 0 at both ends).
    """
    clipped = np.clip(probs, 0.0, 1.0)
    height, width = clipped.shape[0], clipped.shape[1]
    channels = (
        clipped * 255,                                       # red
        np.maximum(0, 1 - 2 * np.abs(clipped - 0.5)) * 255,  # green
        (1 - clipped) * 255,                                 # blue
    )
    image = np.zeros((height, width, 3), dtype=np.uint8)
    for idx, channel in enumerate(channels):
        image[..., idx] = channel.astype(np.uint8)
    return image
# ── UI ───────────────────────────────────────────────────────────────────────
# Mis-encoded characters in the title/caption/button ("β", "Γ", "π") restored.
st.title("Solar Panel Segmentation — Unified Dashboard")
st.caption(
    "Two studies in one app. Tabs 1–3 are the 4-model × 3-share baseline; "
    "tabs 4–6 are the U-Net & SegFormer-B0 fine-tune grid search."
)
# Clear st.cache_data and re-run the script so logs and the CSV are re-read.
if st.button("🔄 Reload from disk"):
    st.cache_data.clear()
    st.rerun()

# Loaded once per script run and shared by all tabs below.
baseline_logs = load_all_baseline_logs()
grid_df = load_grid_results()
# Tab labels: the "·" separators were mis-encoded as "Β·".
t1, t2, t3, t4, t5, t6 = st.tabs([
    "1 · Baseline curves",
    "2 · Baseline scaling",
    "3 · Baseline inference",
    "4 · Grid heatmap",
    "5 · Top fine-tune configs",
    "6 · Per-config curves",
])
# ─── Tab 1: Baseline learning curves ─────────────────────────────────────────
with t1:
    st.subheader("Per-epoch metrics for the 4 baseline architectures")
    if baseline_logs:
        left, right = st.columns(2)
        with left:
            metric_key, metric_label = st.selectbox(
                "Metric", METRICS, format_func=lambda opt: opt[1], key="t1_metric",
            )
        with right:
            split = st.radio("Split", ["val", "train", "both"], horizontal=True, index=0, key="t1_split")

        # Curves to draw for the chosen split, val first: (column prefix, extra Scatter kwargs).
        curve_specs = [
            (prefix, style)
            for prefix, style in (("val", {}), ("train", {"line": dict(dash="dot")}))
            if split in (prefix, "both")
        ]

        # One figure per architecture; one curve per (share, split) that has data.
        for model in BASELINE_MODELS:
            shares_present = [s for s in SHARES if (model, s) in baseline_logs]
            if not shares_present:
                continue
            st.markdown(f"#### {PRETTY_NAME[model]}")
            fig = go.Figure()
            for share in shares_present:
                epochs_df = baseline_log_to_df(baseline_logs[(model, share)])
                for prefix, style in curve_specs:
                    column = f"{prefix}_{metric_key}"
                    if column not in epochs_df.columns or not epochs_df[column].notna().any():
                        continue
                    points = epochs_df.dropna(subset=[column])
                    fig.add_trace(go.Scatter(
                        x=points["epoch"], y=points[column], mode="lines",
                        name=f"{share}% {prefix}", **style,
                    ))
            fig.update_layout(
                xaxis_title="Epoch", yaxis_title=metric_label, height=340,
                margin=dict(l=10, r=10, t=10, b=10),
                legend=dict(orientation="h", y=-0.2),
            )
            st.plotly_chart(fig, use_container_width=True)
    else:
        st.info("No baseline logs found at experiments/clean_data_scaling_study/logs/.")
# ─── Tab 2: Baseline data-scaling charts ─────────────────────────────────────
with t2:
    st.subheader("Val metrics vs training-data share")
    if not baseline_logs:
        st.info("No baseline logs found.")
    else:
        kind = st.radio("Checkpoint", ["best", "final"], horizontal=True, index=0, key="t2_kind")
        rows = [baseline_scaling_row(log, kind=kind) for log in baseline_logs.values()]
        df = pd.DataFrame(rows)
        # Rows for logs with no recorded epochs may lack some keys; if every
        # log is like that the column is missing entirely and the indexing
        # below would raise KeyError. Backfill such columns as all-missing.
        for col in ("val_dice", "val_miou", "val_iou", "val_pixel_acc",
                    "wall_clock_seconds", "sec_per_epoch"):
            if col not in df.columns:
                df[col] = np.nan
        df = df.sort_values(["model", "share"]).reset_index(drop=True)
        display_df = df.drop(columns=["wall_clock_seconds", "sec_per_epoch"], errors="ignore")
        st.dataframe(display_df, use_container_width=True, hide_index=True)

        total_seconds = df["wall_clock_seconds"].dropna().sum()
        if total_seconds:
            st.caption(
                f"⏱ Total baseline wall-clock: **{fmt_hms(total_seconds)}** "
                f"({total_seconds:,.0f} s)"
            )

        # (y column, plot function, title, y-axis label, extra px kwargs);
        # rendered two charts per row.
        charts = [
            ("val_miou", px.line, f"Val mIoU ({kind})", "Val mIoU", {"markers": True}),
            ("val_dice", px.line, f"Val Dice ({kind})", "Val Dice", {"markers": True}),
            ("val_iou", px.bar, f"Val foreground IoU ({kind})", "Val IoU (foreground)",
             {"barmode": "group"}),
            ("val_pixel_acc", px.bar, f"Val pixel accuracy ({kind})", "Val pixel acc",
             {"barmode": "group"}),
        ]
        for start in range(0, len(charts), 2):
            for slot, (y_col, plot_fn, title, y_label, extra) in zip(
                st.columns(2), charts[start:start + 2]
            ):
                with slot:
                    fig = plot_fn(
                        df.dropna(subset=[y_col]),
                        x="share", y=y_col, color="model", title=title,
                        labels={"share": "Training data (%)", y_col: y_label},
                        **extra,
                    )
                    fig.update_xaxes(tickvals=SHARES)
                    st.plotly_chart(fig, use_container_width=True)

        st.markdown("##### Training time")
        time_df = df.dropna(subset=["wall_clock_seconds"]).assign(
            wall_minutes=lambda d: d["wall_clock_seconds"] / 60.0
        )
        if time_df.empty:
            st.caption("No timing data available yet.")
        else:
            tcol1, tcol2 = st.columns(2)
            with tcol1:
                fig = px.bar(
                    time_df, x="share", y="wall_minutes",
                    color="model", barmode="group",
                    title="Total training time (minutes)",
                    labels={"share": "Training data (%)",
                            "wall_minutes": "Wall clock (min)"},
                )
                fig.update_xaxes(tickvals=SHARES)
                st.plotly_chart(fig, use_container_width=True)
            with tcol2:
                fig = px.bar(
                    time_df.dropna(subset=["sec_per_epoch"]),
                    x="share", y="sec_per_epoch",
                    color="model", barmode="group",
                    title="Average seconds per epoch",
                    labels={"share": "Training data (%)",
                            "sec_per_epoch": "Seconds / epoch"},
                )
                fig.update_xaxes(tickvals=SHARES)
                st.plotly_chart(fig, use_container_width=True)
# ─── Tab 3: Baseline inference grid ──────────────────────────────────────────
with t3:
    st.subheader("Upload an image → 4 models × 3 shares = 12 segmentations")
    st.caption(
        "Each cell uses one (model, share) baseline checkpoint. "
        "Cells with no checkpoint locally show a 'missing' warning."
    )
    a, b, c, d = st.columns([2, 2, 2, 2])
    with a:
        kind = st.radio("Checkpoint", ["best", "final"], horizontal=True, key="t3_kind")
    with b:
        threshold = st.slider("Threshold", 0.0, 1.0, 0.5, 0.05, key="t3_thr")
    with c:
        view = st.radio("View", ["mask", "overlay", "heatmap"], horizontal=True, key="t3_view")
    with d:
        cell_w = st.select_slider(
            "Cell size (px)", options=[140, 180, 220, 260], value=180, key="t3_cell"
        )
    uploaded = st.file_uploader(
        "Drop an image (jpg/png)", type=["jpg", "jpeg", "png"], key="t3_upload"
    )

    img = None
    if uploaded is None:
        st.info("Upload an image to run inference across all 12 baseline checkpoints.")
    else:
        try:
            raw_bytes = uploaded.getvalue()
            img = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
        except Exception as e:  # UI boundary: report any decode failure to the user
            # Report and skip the grid. The previous st.stop() ended the whole
            # script run, so tabs 4-6 were never rendered after a bad upload.
            st.error(f"Could not decode uploaded image: {e}")

    if img is not None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        st.caption(
            f"📄 `{uploaded.name}` — {img.size[0]}×{img.size[1]} px, "
            f"{len(raw_bytes)/1024:.1f} KB"
        )
        x = preprocess(img, image_size=128)
        rgb_small = (x.squeeze().permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        st.image([img, rgb_small], width=cell_w * 2, caption=["original", "128×128 (model input)"])
        st.markdown("##### Predictions (rows = model, columns = data share)")

        def render_cell(probs, mask, rgb):
            """Return the image shown in one grid cell for the selected view mode."""
            if view == "mask":
                return (mask * 255).astype(np.uint8)
            if view == "overlay":
                return overlay(rgb, mask)
            return heatmap_rgb(probs)

        for model_name in BASELINE_MODELS:
            for col, share in zip(st.columns(len(SHARES)), SHARES):
                with col:
                    st.markdown(f"**{PRETTY_NAME[model_name]} · {share}%**")
                    try:
                        with st.spinner(f"loading {model_name} {share}%…"):
                            m, output_is_prob = load_baseline_ckpt(model_name, share, kind, device)
                        if m is None:
                            st.warning(f"missing `{model_name}_{share}_{kind}.pth`")
                            continue
                        probs, mask = run_inference(m, x, device, output_is_prob, threshold)
                        st.image(render_cell(probs, mask, rgb_small), width=cell_w)
                        st.caption(
                            f"cov={float(mask.mean())*100:.1f}% "
                            f"p[{probs.min():.2f},{probs.max():.2f}]"
                        )
                    except Exception as e:  # one broken cell must not kill the grid
                        st.error(f"{model_name} {share}% — {type(e).__name__}")
                        st.exception(e)
# ─── Tab 4: Fine-tune grid heatmap ───────────────────────────────────────────
with t4:
    st.subheader("Δ Dice across the 54 fine-tune configurations")
    if grid_df.empty:
        st.info("No grid results found at experiments/finetune_grid_search/results/grid_results.csv.")
    else:
        df = grid_df.copy()
        df["lr_label"] = df["lr"].apply(lambda x: f"{x:.0e}")
        n_runs = len(df)
        n_unet = (df["model"] == "unet").sum()
        n_seg = (df["model"] == "segformer_b0").sum()
        total_seconds = float(df["wall_clock_seconds"].sum())
        st.markdown(
            f"**{n_runs} runs** "
            f"(U-Net: {n_unet}, SegFormer-B0: {n_seg}); "
            f"total compute: **{fmt_hms(total_seconds)}**"
        )
        metric = st.radio(
            "Metric",
            [("delta_dice", "Δ Dice (vs. baseline)"),
             ("best_val_dice", "Best val Dice (absolute)"),
             ("best_val_miou", "Best val mIoU (absolute)"),
             ("best_val_iou", "Best val IoU (absolute)")],
            format_func=lambda x: x[1], horizontal=True, key="t4_metric",
        )[0]
        # Δ Dice is signed, so use a diverging scale centred on 0; the
        # absolute scores use a sequential scale.
        color_scale = "RdYlGn" if metric == "delta_dice" else "Viridis"
        zmid = 0 if metric == "delta_dice" else None
        # Heatmap rows in a fixed (lr, bce) order. Formatting matches lr_label
        # ("{:.0e}") so the labels line up with the pivot index.
        ordered_rows = [f"lr={lr:.0e} bce={bw:.1f}" for lr in GRID_LRS for bw in GRID_BCES]

        for model in GRID_MODELS:
            sub = df[df["model"] == model].copy()
            if sub.empty:
                continue
            st.markdown(f"#### {PRETTY_NAME[model]}")
            sub["row"] = sub.apply(
                lambda r: f"lr={r['lr_label']} bce={r['bce_weight']:.1f}", axis=1
            )
            pivot = sub.pivot_table(
                index="row", columns="augment", values=metric, aggfunc="first"
            ).reindex(index=ordered_rows, columns=GRID_AUGS)
            # Force a float array: pivot_table can return an empty frame (e.g.
            # when every value is NaN), and the reindexed result may then be
            # object-typed, which np.isnan rejects.
            z = pivot.to_numpy(dtype=float)
            text = np.where(np.isnan(z), "",
                            np.vectorize(lambda v: f"{v:.4f}")(np.nan_to_num(z, nan=0.0)))
            fig = go.Figure(data=go.Heatmap(
                z=z, x=GRID_AUGS, y=ordered_rows,
                colorscale=color_scale, zmid=zmid,
                text=text, texttemplate="%{text}", textfont=dict(size=12),
                colorbar=dict(title=metric.replace("_", " ")),
            ))
            fig.update_layout(
                height=360, margin=dict(l=10, r=10, t=10, b=10),
                xaxis_title="Augmentation", yaxis_title="(learning rate, BCE weight)",
            )
            st.plotly_chart(fig, use_container_width=True)

            # idxmax on an all-NaN column fails, so only rank recorded values.
            scores = sub[metric].dropna()
            if scores.empty:
                st.warning(f"No {metric} values recorded for {PRETTY_NAME[model]}.")
                continue
            best = sub.loc[scores.idxmax()]
            st.success(
                f"Best for {PRETTY_NAME[model]} on **{metric}**: "
                f"`{best['cfg_id']}` (lr={best['lr_label']}, "
                f"bce={best['bce_weight']:.1f}, aug={best['augment']}) "
                f"→ {metric}={best[metric]:.4f} "
                f"(baseline Dice={best['baseline_val_dice']:.4f})"
            )
# ─── Tab 5: Top fine-tune configs ────────────────────────────────────────────
with t5:
    st.subheader("All 54 fine-tune configurations")
    if grid_df.empty:
        st.info("No grid results yet.")
    else:
        # Add human-readable model name and learning-rate label columns.
        annotated = grid_df.assign(
            model_pretty=grid_df["model"].map(PRETTY_NAME),
            lr_label=grid_df["lr"].apply(lambda lr: f"{lr:.0e}"),
        )
        sort_by = st.selectbox(
            "Sort by",
            ["delta_dice", "best_val_dice", "best_val_miou", "best_val_iou", "wall_clock_seconds"],
            index=0, key="t5_sort",
        )
        ascending = st.toggle("ascending", value=False, key="t5_asc")
        only_improved = st.checkbox(
            "only configs that improved over baseline", value=False, key="t5_imp"
        )
        visible = annotated[annotated["delta_dice"] > 0] if only_improved else annotated
        visible = visible.sort_values(sort_by, ascending=ascending)
        # Preferred column order; columns absent from the CSV are skipped.
        preferred_columns = (
            "cfg_id", "model_pretty", "lr_label", "bce_weight", "augment",
            "best_epoch", "epochs_trained", "early_stopped",
            "best_val_dice", "best_val_miou", "best_val_iou", "best_val_pixel_acc",
            "baseline_val_dice", "delta_dice", "wall_clock_seconds",
        )
        present = [name for name in preferred_columns if name in visible.columns]
        st.dataframe(visible[present], use_container_width=True, hide_index=True)
# ─── Tab 6: Per-config fine-tune curves ──────────────────────────────────────
with t6:
    st.subheader("Per-config learning curves")
    if grid_df.empty:
        st.info("No grid results yet.")
    else:
        cfg_options = sorted(grid_df["cfg_id"].unique().tolist())
        chosen = st.multiselect(
            "Pick configs to overlay",
            cfg_options,
            default=cfg_options[:3],  # slicing an empty list already yields []
            key="t6_pick",
        )
        # Reuse the shared METRICS list so labels match tab 1.
        metric_key, metric_label = st.selectbox(
            "Metric", METRICS, format_func=lambda x: x[1], key="t6_metric",
        )
        if not chosen:
            st.info("Select at least one config.")
        else:
            # Read each selected log once; it feeds both the curves and the
            # baseline reference lines (previously every log was read twice).
            logs = {cfg_id: load_grid_log(cfg_id) for cfg_id in chosen}
            missing = [cfg_id for cfg_id, log in logs.items() if log is None]
            if missing:
                st.warning("No log file found for: " + ", ".join(missing))

            fig = go.Figure()
            for cfg_id, log in logs.items():
                if log is None:
                    continue
                epochs = log.get("epochs", [])
                xs = [e.get("epoch") for e in epochs]
                # .get(): an epoch without this metric becomes a gap (None)
                # instead of raising KeyError and breaking the whole tab.
                fig.add_trace(go.Scatter(
                    x=xs, y=[e.get(f"val_{metric_key}") for e in epochs],
                    mode="lines+markers", name=f"{cfg_id} val",
                ))
                fig.add_trace(go.Scatter(
                    x=xs, y=[e.get(f"train_{metric_key}") for e in epochs],
                    mode="lines", line=dict(dash="dot"), name=f"{cfg_id} train",
                ))

            if metric_key == "dice":
                # Several selected configs may report the same baseline value;
                # draw each distinct baseline once, labelled with all its
                # configs, instead of stacking identical lines and annotations.
                by_baseline = {}
                for cfg_id, log in logs.items():
                    if log and log.get("baseline_val_dice") is not None:
                        by_baseline.setdefault(log["baseline_val_dice"], []).append(cfg_id)
                for baseline, cfg_ids in by_baseline.items():
                    fig.add_hline(
                        y=baseline, line_dash="dash",
                        annotation_text=f"{', '.join(cfg_ids)} baseline ({baseline:.4f})",
                        annotation_position="bottom right",
                    )

            fig.update_layout(
                xaxis_title="Epoch", yaxis_title=metric_label, height=480,
                margin=dict(l=10, r=10, t=10, b=10),
                legend=dict(orientation="h", y=-0.2),
            )
            st.plotly_chart(fig, use_container_width=True)