# seg-models / app.py
# Mohamed-ENNHIRI
# Solar Panel Segmentation app for HF Spaces
# 52efd90
"""
Unified single-page Streamlit dashboard.
Combines the two studies into one app with six top-level tabs:
1. Baseline learning curves (4 architectures × 3 shares)
2. Baseline data-scaling charts
3. Baseline inference grid (upload → 4 × 3 mask grid)
4. Fine-tune grid heatmap (54 configs)
5. Fine-tune top configs table
6. Fine-tune per-config curves
Run locally:
streamlit run app.py
Deploy on Streamlit Cloud: Main file path = app.py
"""
import io
import json
import os
import sys
from pathlib import Path
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
import torch
from PIL import Image
from torchvision import transforms
# Repository layout. Paths are resolved relative to this file (not the
# current working directory), so `streamlit run` works from anywhere.
ROOT = Path(__file__).resolve().parent
# Baseline study: 4 architectures x 3 training-data shares.
BASELINE_DIR = ROOT / "experiments" / "clean_data_scaling_study"
# Fine-tune grid-search study.
GRID_DIR = ROOT / "experiments" / "finetune_grid_search"
BASELINE_LOGS = BASELINE_DIR / "logs"  # per-run JSON logs named {model}_{share}.json
BASELINE_CKPT = BASELINE_DIR / "checkpoints"  # local weights named {model}_{share}_{kind}.pth
GRID_LOGS = GRID_DIR / "logs"  # per-config JSON logs named {cfg_id}.json
GRID_CSV = GRID_DIR / "results" / "grid_results.csv"  # grid-search results table read by tabs 4-6
# HF model repo holding the .pth files. Override with the HF_WEIGHTS_REPO
# environment variable (e.g. on Spaces); otherwise the default below is used.
# Checkpoints found neither on disk nor in this repo are reported as
# "missing" in the inference tab.
HF_WEIGHTS_REPO = os.environ.get("HF_WEIGHTS_REPO", "phiniqs/seg-models-weights")
def _resolve_ckpt(local_path: Path, hf_filename: str) -> Path | None:
"""Return a local path to the checkpoint. Try disk first, then HF Hub."""
if local_path.is_file():
return local_path
try:
from huggingface_hub import hf_hub_download
downloaded = hf_hub_download(
repo_id=HF_WEIGHTS_REPO,
filename=hf_filename,
)
return Path(downloaded)
except Exception:
return None
# Make the baseline experiment directory importable so its `models` module
# can be loaded: MODEL_REGISTRY maps a model key to a builder (called in
# load_baseline_ckpt), and PRETTY_NAME maps a model key to its display label.
# The path must be inserted before this import runs, hence the E402 waiver.
sys.path.insert(0, str(BASELINE_DIR))
from models import MODEL_REGISTRY, PRETTY_NAME  # noqa: E402
# Baseline architectures in display (row) order. These keys index
# MODEL_REGISTRY / PRETTY_NAME and appear in log and checkpoint filenames.
BASELINE_MODELS = ["segnet", "unet", "segformer_b0", "segformer_b5"]
# Training-data shares (percent) used by the baseline study.
SHARES = [25, 50, 100]
# Fine-tune grid axes: architectures, learning rates, BCE weights and
# augmentation presets. Tab 4 uses them to fix heatmap row/column order.
GRID_MODELS = ["unet", "segformer_b0"]
GRID_LRS = [1e-5, 5e-5, 1e-4]
GRID_BCES = [0.3, 0.5, 0.7]
GRID_AUGS = ["none", "default", "strong"]
# (log key suffix, display label) pairs. Per-epoch log entries store each
# metric as "train_<key>" / "val_<key>".
METRICS = [
    ("dice", "Dice"),
    ("miou", "mIoU"),
    ("iou", "Foreground IoU"),
    ("pixel_acc", "Pixel Accuracy"),
    ("loss", "Loss"),
]
# Browser tab title and wide layout, configured before any UI is rendered.
st.set_page_config(page_title="Solar Panel Segmentation β€” Dashboards", layout="wide")
# ── Loaders ────────────────────────────────────────────────────────────────
def _mtime(path: Path) -> float:
return path.stat().st_mtime if path.is_file() else 0.0
@st.cache_data(show_spinner=False)
def _read_baseline_log(model: str, share: int, mtime: float):
    """Parse ``{BASELINE_LOGS}/{model}_{share}.json`` for one baseline run.

    Args:
        model: baseline model key.
        share: training-data share in percent.
        mtime: not used in the body. It exists only so the cache key changes
            when the file is rewritten. It must NOT start with an underscore,
            because ``st.cache_data`` leaves underscore-prefixed parameters out
            of the cache key. That is why the previous ``_mt`` name never
            invalidated anything.

    Returns:
        The decoded JSON object, or ``None`` if the file does not exist.
    """
    p = BASELINE_LOGS / f"{model}_{share}.json"
    if not p.is_file():
        return None
    with open(p, encoding="utf-8") as f:
        return json.load(f)
def load_baseline_log(model: str, share: int):
    """Return the parsed JSON log for one baseline run, or None if absent.

    Passes the file's current mtime to the cached reader, which is meant to
    serve as a cache-busting key.
    """
    log_file = BASELINE_LOGS / f"{model}_{share}.json"
    return _read_baseline_log(model, share, _mtime(log_file))
def load_all_baseline_logs():
    """Collect every available baseline log, keyed by ``(model, share)``.

    Combinations without a log file are left out of the result. Deliberately
    NOT wrapped in ``st.cache_data``: an argument-less cache never invalidates
    on its own, so freshly written epochs or newly finished runs would only
    show up after the manual "Reload" button, bypassing the per-file mtime
    keys in ``load_baseline_log``. Each file read is already cached there.
    This matches ``load_grid_results``, which is also uncached at this level.
    """
    out = {}
    for m in BASELINE_MODELS:
        for s in SHARES:
            log = load_baseline_log(m, s)
            if log is not None:
                out[(m, s)] = log
    return out
@st.cache_data(show_spinner=False)
def _read_grid_csv(mtime: float):
    """Load the grid-search results CSV, or an empty DataFrame if it is absent.

    Args:
        mtime: not used in the body. It is part of the cache key so a
            rewritten CSV produces a fresh entry. It must not be
            underscore-prefixed, because ``st.cache_data`` excludes such
            parameters from the key. That made the previous ``_mt`` name a
            no-op.
    """
    if not GRID_CSV.is_file():
        return pd.DataFrame()
    return pd.read_csv(GRID_CSV)
def load_grid_results():
    """Return the grid-search results table (empty DataFrame if no CSV yet).

    The CSV's current mtime is handed to the cached reader as its key.
    """
    csv_mtime = _mtime(GRID_CSV)
    return _read_grid_csv(csv_mtime)
@st.cache_data(show_spinner=False)
def _read_grid_log(cfg_id: str, mtime: float):
    """Parse ``{GRID_LOGS}/{cfg_id}.json`` for one fine-tune config.

    Args:
        cfg_id: grid configuration id (matches the ``cfg_id`` CSV column).
        mtime: not used in the body. It is part of the cache key so a
            rewritten log yields a fresh entry. It must not start with an
            underscore, since ``st.cache_data`` does not hash such parameters.

    Returns:
        The decoded JSON object, or ``None`` if the file does not exist.
    """
    p = GRID_LOGS / f"{cfg_id}.json"
    if not p.is_file():
        return None
    with open(p, encoding="utf-8") as f:
        return json.load(f)
def load_grid_log(cfg_id: str):
    """Return the parsed JSON log for one grid config, or None if absent.

    The file's current mtime is forwarded to the cached reader as its key.
    """
    log_file = GRID_LOGS / f"{cfg_id}.json"
    return _read_grid_log(cfg_id, _mtime(log_file))
# ── Shared helpers ─────────────────────────────────────────────────────────
def fmt_hms(seconds):
    """Format a duration in seconds as ``M:SS``, or ``H:MM:SS`` from one hour up.

    The value is rounded to the nearest whole second. Missing values (``None``
    or NaN, e.g. a log without timing data) render as a placeholder dash.
    Previously a NaN reached ``int(round(nan))`` and raised ValueError.
    """
    if seconds is None or pd.isna(seconds):
        return "β€”"
    seconds = int(round(seconds))
    h, rem = divmod(seconds, 3600)
    m, s = divmod(rem, 60)
    return f"{h:d}:{m:02d}:{s:02d}" if h else f"{m:d}:{s:02d}"
def baseline_log_to_df(log):
    """Flatten a baseline log into a per-epoch DataFrame.

    Each entry of ``log["epochs"]`` becomes one row. The run's ``model`` key
    and ``share`` are added as constant columns so rows from different runs
    remain distinguishable.
    """
    frame = pd.DataFrame(log["epochs"])
    return frame.assign(model=log["model"], share=log["share"])
def baseline_scaling_row(log, kind="best"):
    """Summarise one baseline run as a single table row.

    Args:
        log: parsed baseline log. Reads ``model``, ``share``, ``epochs`` (a
            list of per-epoch dicts) and optionally ``wall_clock_seconds``.
        kind: ``"best"`` picks the epoch with the highest ``val_dice``.
            Epochs with a missing or NaN dice are never preferred, and ties
            go to the earliest epoch. Any other value picks the last epoch.

    Returns:
        A dict with the same keys for every log, in display order: ``model``
        (pretty name), ``share``, ``epoch``, ``val_dice``, ``val_miou``,
        ``val_iou``, ``val_pixel_acc``, ``wall_clock_seconds``, ``wall_clock``
        and ``sec_per_epoch``. Unavailable values are ``None``. The uniform
        schema keeps downstream column lookups (e.g. ``df["wall_clock_seconds"]``
        in the scaling tab) from raising KeyError when every log is still
        empty.
    """
    epochs = log["epochs"]
    wall = log.get("wall_clock_seconds")
    row = {
        "model": PRETTY_NAME[log["model"]],
        "share": log["share"],
        "epoch": None,
        "val_dice": None,
        "val_miou": None,
        "val_iou": None,
        "val_pixel_acc": None,
        "wall_clock_seconds": wall,
        "wall_clock": fmt_hms(wall),
        "sec_per_epoch": None,
    }
    if not epochs:
        return row

    def _dice_score(epoch):
        value = epoch.get("val_dice")
        # NaN compares False against everything, so a NaN seen first would
        # otherwise always "win" max(); treat missing/NaN as worst.
        return -1 if value is None or value != value else value

    chosen = max(epochs, key=_dice_score) if kind == "best" else epochs[-1]
    row["epoch"] = chosen["epoch"]
    for key in ("val_dice", "val_miou", "val_iou", "val_pixel_acc"):
        row[key] = chosen.get(key)

    per = [e["epoch_seconds"] for e in epochs if e.get("epoch_seconds") is not None]
    if per:
        row["sec_per_epoch"] = sum(per) / len(per)
    return row
# ── Inference helpers (used by tab 3) ──────────────────────────────────────
@st.cache_resource(show_spinner=False)
def load_baseline_ckpt(model_name: str, share: int, kind: str, device: str):
    """Build a baseline model and load the weights of one checkpoint.

    Results are cached with ``st.cache_resource`` keyed on the arguments, so a
    given checkpoint is deserialised once and then reused across reruns.

    Returns:
        ``(model, output_is_prob)``: the model moved to ``device`` in eval
        mode, and whether its output is already a probability (taken from the
        checkpoint when stored there, else from the registry builder).
        Returns ``(None, False)`` when the checkpoint is found neither
        locally nor on the HF Hub.
    """
    ckpt_name = f"{model_name}_{share}_{kind}.pth"
    ckpt_path = _resolve_ckpt(BASELINE_CKPT / ckpt_name, f"baseline/{ckpt_name}")
    if ckpt_path is None:
        return None, False
    net, _, default_is_prob = MODEL_REGISTRY[model_name]()
    # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects and
    # can execute code embedded in the file. This is acceptable only while the
    # local checkpoints and HF_WEIGHTS_REPO are trusted. Consider
    # weights_only=True if the checkpoint contents allow it.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
    net.load_state_dict(checkpoint["model_state_dict"])
    net.to(device).eval()
    return net, checkpoint.get("output_is_prob", default_is_prob)
def preprocess(image: Image.Image, image_size: int = 128):
    """Turn a PIL image into a ``(1, 3, image_size, image_size)`` float batch.

    The image is converted to RGB, resized to a square (ignoring its aspect
    ratio) and scaled to [0, 1] by ``ToTensor``.
    """
    pipeline = transforms.Compose(
        [transforms.Resize((image_size, image_size)), transforms.ToTensor()]
    )
    tensor = pipeline(image.convert("RGB"))
    return tensor.unsqueeze(0)
def run_inference(model, image_tensor, device, output_is_prob: bool, threshold=0.5):
    """Run ``model`` on one input batch and binarise the result.

    Args:
        model: callable returning either probabilities or logits.
        image_tensor: input batch, moved to ``device`` before the forward pass.
        device: torch device string.
        output_is_prob: True if the model already outputs probabilities.
            Otherwise a sigmoid is applied to its output.
        threshold: pixels whose probability is strictly greater than this are
            foreground.

    Returns:
        ``(probs, mask)`` as 2-D NumPy arrays. ``mask`` is float32 0/1.
    """
    with torch.no_grad():
        raw = model(image_tensor.to(device))
        prob_t = raw if output_is_prob else torch.sigmoid(raw)
    prob_map = prob_t.squeeze().cpu().numpy()
    # squeeze() removes every size-1 axis. If the result is not 2-D, reshape
    # onto the trailing two axes (only valid when the sizes multiply out).
    if prob_map.ndim != 2:
        prob_map = prob_map.reshape(prob_map.shape[-2], prob_map.shape[-1])
    binary = (prob_map > threshold).astype(np.float32)
    return prob_map, binary
def overlay(rgb: np.ndarray, mask: np.ndarray, color=(0, 255, 0), alpha=0.45):
    """Tint the masked pixels of ``rgb`` with ``color`` at opacity ``alpha``.

    Returns a new array. ``rgb`` itself is not modified. Unmasked pixels are
    copied unchanged, and blended pixels are truncated to uint8.
    """
    blended = rgb.copy()
    selected = mask.astype(bool)
    tint = alpha * np.array(color)
    blended[selected] = (tint + (1 - alpha) * blended[selected]).astype(np.uint8)
    return blended
def heatmap_rgb(probs: np.ndarray) -> np.ndarray:
    """Colour a 2-D probability map as a blue -> green -> red uint8 image.

    Probabilities are clipped to [0, 1]. Red rises with p, blue falls with p,
    and green is a triangle peaking at p = 0.5.
    """
    clipped = np.clip(probs, 0.0, 1.0)
    height, width = clipped.shape[0], clipped.shape[1]
    channels = (
        clipped,
        np.maximum(0, 1 - 2 * np.abs(clipped - 0.5)),
        1 - clipped,
    )
    image = np.zeros((height, width, 3), dtype=np.uint8)
    for idx, chan in enumerate(channels):
        image[..., idx] = (chan * 255).astype(np.uint8)
    return image
# ── UI ─────────────────────────────────────────────────────────────────────
# Page header. The reload button clears every st.cache_data entry (the log
# and CSV readers) and re-runs the script. Only cache_data is cleared here,
# so objects cached with st.cache_resource (loaded model checkpoints) are kept.
st.title("Solar Panel Segmentation β€” Unified Dashboard")
st.caption(
    "Two studies in one app. Tabs 1–3 are the 4-model Γ— 3-share baseline; "
    "tabs 4–6 are the U-Net & SegFormer-B0 fine-tune grid search."
)
if st.button("πŸ”„ Reload from disk"):
    st.cache_data.clear()
    st.rerun()
# Loaded once per script run and shared by all six tabs below.
baseline_logs = load_all_baseline_logs()
grid_df = load_grid_results()
t1, t2, t3, t4, t5, t6 = st.tabs([
    "1 Β· Baseline curves",
    "2 Β· Baseline scaling",
    "3 Β· Baseline inference",
    "4 Β· Grid heatmap",
    "5 Β· Top fine-tune configs",
    "6 Β· Per-config curves",
])
# ─── Tab 1: Baseline learning curves ──────────────────────────────────────
with t1:
    # Tab 1: one chart per baseline architecture, with one trace per data
    # share and split (val, train, or both).
    st.subheader("Per-epoch metrics for the 4 baseline architectures")
    if not baseline_logs:
        st.info("No baseline logs found at experiments/clean_data_scaling_study/logs/.")
    else:
        c1, c2 = st.columns(2)
        with c1:
            # Options are (log key suffix, label) pairs; only the label is shown.
            metric_key, metric_label = st.selectbox(
                "Metric", METRICS, format_func=lambda x: x[1], key="t1_metric",
            )
        with c2:
            split = st.radio("Split", ["val", "train", "both"], horizontal=True, index=0, key="t1_split")
        for model in BASELINE_MODELS:
            # Architectures with no logs at all get no chart.
            available = [s for s in SHARES if (model, s) in baseline_logs]
            if not available:
                continue
            st.markdown(f"#### {PRETTY_NAME[model]}")
            fig = go.Figure()
            for share in available:
                df = baseline_log_to_df(baseline_logs[(model, share)])
                if split in ("val", "both"):
                    col = f"val_{metric_key}"
                    # The column may be absent or contain nulls; plot only the
                    # epochs that have a value.
                    if col in df.columns and df[col].notna().any():
                        sub = df.dropna(subset=[col])
                        fig.add_trace(go.Scatter(
                            x=sub["epoch"], y=sub[col], mode="lines",
                            name=f"{share}% val",
                        ))
                if split in ("train", "both"):
                    col = f"train_{metric_key}"
                    if col in df.columns and df[col].notna().any():
                        sub = df.dropna(subset=[col])
                        # Dotted lines distinguish train traces from val traces.
                        fig.add_trace(go.Scatter(
                            x=sub["epoch"], y=sub[col], mode="lines",
                            line=dict(dash="dot"), name=f"{share}% train",
                        ))
            fig.update_layout(
                xaxis_title="Epoch", yaxis_title=metric_label, height=340,
                margin=dict(l=10, r=10, t=10, b=10),
                legend=dict(orientation="h", y=-0.2),
            )
            st.plotly_chart(fig, use_container_width=True)
# ─── Tab 2: Baseline data-scaling charts ──────────────────────────────────
with t2:
    # Tab 2: one summary row per baseline run (best or final epoch), shown as a
    # table, metric-vs-share charts and training-time bars.
    st.subheader("Val metrics vs training-data share")
    if not baseline_logs:
        st.info("No baseline logs found.")
    else:
        kind = st.radio("Checkpoint", ["best", "final"], horizontal=True, index=0, key="t2_kind")
        rows = [baseline_scaling_row(log, kind=kind) for log in baseline_logs.values()]
        df = pd.DataFrame(rows).sort_values(["model", "share"]).reset_index(drop=True)
        # The raw numeric timing columns feed the charts below. The table shows
        # the formatted "wall_clock" string instead.
        display_df = df.drop(columns=["wall_clock_seconds", "sec_per_epoch"], errors="ignore")
        st.dataframe(display_df, use_container_width=True, hide_index=True)
        # NOTE(review): raises KeyError if no row carries "wall_clock_seconds".
        # Confirm every row produced by baseline_scaling_row includes that key.
        total_seconds = df["wall_clock_seconds"].dropna().sum()
        if total_seconds:
            st.caption(
                f"⏱ Total baseline wall-clock: **{fmt_hms(total_seconds)}** "
                f"({total_seconds:,.0f} s)"
            )
        # Metric-vs-share charts: lines for mIoU/Dice, grouped bars for IoU/accuracy.
        c1, c2 = st.columns(2)
        with c1:
            fig = px.line(
                df.dropna(subset=["val_miou"]),
                x="share", y="val_miou", color="model", markers=True,
                title=f"Val mIoU ({kind})",
                labels={"share": "Training data (%)", "val_miou": "Val mIoU"},
            )
            fig.update_xaxes(tickvals=SHARES)
            st.plotly_chart(fig, use_container_width=True)
        with c2:
            fig = px.line(
                df.dropna(subset=["val_dice"]),
                x="share", y="val_dice", color="model", markers=True,
                title=f"Val Dice ({kind})",
                labels={"share": "Training data (%)", "val_dice": "Val Dice"},
            )
            fig.update_xaxes(tickvals=SHARES)
            st.plotly_chart(fig, use_container_width=True)
        c3, c4 = st.columns(2)
        with c3:
            fig = px.bar(
                df.dropna(subset=["val_iou"]),
                x="share", y="val_iou", color="model", barmode="group",
                title=f"Val foreground IoU ({kind})",
                labels={"share": "Training data (%)", "val_iou": "Val IoU (foreground)"},
            )
            fig.update_xaxes(tickvals=SHARES)
            st.plotly_chart(fig, use_container_width=True)
        with c4:
            fig = px.bar(
                df.dropna(subset=["val_pixel_acc"]),
                x="share", y="val_pixel_acc", color="model", barmode="group",
                title=f"Val pixel accuracy ({kind})",
                labels={"share": "Training data (%)", "val_pixel_acc": "Val pixel acc"},
            )
            fig.update_xaxes(tickvals=SHARES)
            st.plotly_chart(fig, use_container_width=True)
        st.markdown("##### Training time")
        # Only runs with a recorded wall-clock total; minutes read better on the axis.
        time_df = df.dropna(subset=["wall_clock_seconds"]).assign(
            wall_minutes=lambda d: d["wall_clock_seconds"] / 60.0
        )
        if not time_df.empty:
            tcol1, tcol2 = st.columns(2)
            with tcol1:
                fig = px.bar(
                    time_df, x="share", y="wall_minutes",
                    color="model", barmode="group",
                    title="Total training time (minutes)",
                    labels={"share": "Training data (%)",
                            "wall_minutes": "Wall clock (min)"},
                )
                fig.update_xaxes(tickvals=SHARES)
                st.plotly_chart(fig, use_container_width=True)
            with tcol2:
                fig = px.bar(
                    time_df.dropna(subset=["sec_per_epoch"]),
                    x="share", y="sec_per_epoch",
                    color="model", barmode="group",
                    title="Average seconds per epoch",
                    labels={"share": "Training data (%)",
                            "sec_per_epoch": "Seconds / epoch"},
                )
                fig.update_xaxes(tickvals=SHARES)
                st.plotly_chart(fig, use_container_width=True)
        else:
            st.caption("No timing data available yet.")
# ─── Tab 3: Baseline inference grid ───────────────────────────────────────
with t3:
    # Tab 3: run every (model, share) baseline checkpoint on one uploaded
    # image and lay the predictions out as a models x shares grid.
    st.subheader("Upload an image β€” 4 models Γ— 3 shares = 12 segmentations")
    st.caption(
        "Each cell uses one (model, share) baseline checkpoint. "
        "Cells with no checkpoint locally show a 'missing' warning."
    )
    a, b, c, d = st.columns([2, 2, 2, 2])
    with a:
        kind = st.radio("Checkpoint", ["best", "final"], horizontal=True, key="t3_kind")
    with b:
        # Probability cut-off for the binary mask (strictly greater than).
        threshold = st.slider("Threshold", 0.0, 1.0, 0.5, 0.05, key="t3_thr")
    with c:
        view = st.radio("View", ["mask", "overlay", "heatmap"], horizontal=True, key="t3_view")
    with d:
        cell_w = st.select_slider(
            "Cell size (px)", options=[140, 180, 220, 260], value=180, key="t3_cell"
        )
    uploaded = st.file_uploader(
        "Drop an image (jpg/png)", type=["jpg", "jpeg", "png"], key="t3_upload"
    )
    if uploaded is not None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        try:
            raw_bytes = uploaded.getvalue()
            img = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
        except Exception as e:
            # Report and skip only this tab. Calling st.stop() here would end
            # the whole script run and leave tabs 4-6 unrendered.
            st.error(f"Could not decode uploaded image: {e}")
        else:
            st.caption(f"πŸ“ `{uploaded.name}` β€” {img.size[0]}Γ—{img.size[1]} px, {len(raw_bytes)/1024:.1f} KB")
            # Every model receives the same 128x128 tensor. rgb_small is that
            # tensor converted back to a uint8 HWC image for display/overlay.
            x = preprocess(img, image_size=128)
            rgb_small = (x.squeeze().permute(1, 2, 0).numpy() * 255).astype(np.uint8)
            st.image([img, rgb_small], width=cell_w * 2, caption=["original", "128Γ—128 (model input)"])
            st.markdown("##### Predictions (rows = model, columns = data share)")

            def render_cell(probs, mask, rgb):
                """Return the image for one grid cell in the selected view mode."""
                if view == "mask":
                    return (mask * 255).astype(np.uint8)
                if view == "overlay":
                    return overlay(rgb, mask)
                return heatmap_rgb(probs)

            for model_name in BASELINE_MODELS:
                cols = st.columns(len(SHARES))
                for col, share in zip(cols, SHARES):
                    with col:
                        st.markdown(f"**{PRETTY_NAME[model_name]} Β· {share}%**")
                        # Errors are caught per cell so one broken checkpoint
                        # does not blank the rest of the grid.
                        try:
                            with st.spinner(f"loading {model_name} {share}%…"):
                                m, output_is_prob = load_baseline_ckpt(model_name, share, kind, device)
                            if m is None:
                                st.warning(f"missing `{model_name}_{share}_{kind}.pth`")
                                continue
                            probs, mask = run_inference(m, x, device, output_is_prob, threshold)
                            cell_img = render_cell(probs, mask, rgb_small)
                            st.image(cell_img, width=cell_w)
                            st.caption(
                                f"cov={float(mask.mean())*100:.1f}% "
                                f"p[{probs.min():.2f},{probs.max():.2f}]"
                            )
                        except Exception as e:
                            st.error(f"{model_name} {share}% β€” {type(e).__name__}")
                            st.exception(e)
    else:
        st.info("Upload an image to run inference across all 12 baseline checkpoints.")
# ─── Tab 4: Fine-tune grid heatmap ────────────────────────────────────────
with t4:
    # Tab 4: per-architecture heatmap of one metric over the
    # (learning rate, BCE weight) x augmentation grid, plus the best config.
    st.subheader("Ξ” Dice across the 54 fine-tune configurations")
    if grid_df.empty:
        st.info("No grid results found at experiments/finetune_grid_search/results/grid_results.csv.")
    else:
        df = grid_df.copy()
        df["model_pretty"] = df["model"].map(PRETTY_NAME)
        # Scientific-notation label (e.g. "1e-04"). It must use the same format
        # as ordered_rows below so the pivot rows line up.
        df["lr_label"] = df["lr"].apply(lambda x: f"{x:.0e}")
        n_runs = len(df)
        n_unet = (df["model"] == "unet").sum()
        n_seg = (df["model"] == "segformer_b0").sum()
        total_seconds = float(df["wall_clock_seconds"].sum())
        st.markdown(
            f"**{n_runs} runs** "
            f"(U-Net: {n_unet}, SegFormer-B0: {n_seg}); "
            f"total compute: **{fmt_hms(total_seconds)}**"
        )
        # The radio returns a (column, label) pair; keep only the column name.
        metric = st.radio(
            "Metric",
            [("delta_dice", "Ξ” Dice (vs. baseline)"),
             ("best_val_dice", "Best val Dice (absolute)"),
             ("best_val_miou", "Best val mIoU (absolute)"),
             ("best_val_iou", "Best val IoU (absolute)")],
            format_func=lambda x: x[1], horizontal=True, key="t4_metric",
        )[0]
        # Diverging palette centred on 0 for the signed delta; sequential otherwise.
        color_scale = "RdYlGn" if metric == "delta_dice" else "Viridis"
        zmid = 0 if metric == "delta_dice" else None
        for model in GRID_MODELS:
            sub = df[df["model"] == model].copy()
            if sub.empty:
                continue
            st.markdown(f"#### {PRETTY_NAME[model]}")
            sub["row"] = sub.apply(
                lambda r: f"lr={r['lr_label']} bce={r['bce_weight']:.1f}", axis=1
            )
            pivot = sub.pivot_table(
                index="row", columns="augment", values=metric, aggfunc="first"
            ).reindex(columns=GRID_AUGS)
            # Fix row order to the full lr x BCE grid. Combinations with no
            # result become all-NaN rows (blank cells).
            ordered_rows = [
                f"lr={lr:.0e} bce={bw:.1f}" for lr in GRID_LRS for bw in GRID_BCES
            ]
            pivot = pivot.reindex(ordered_rows)
            z = pivot.values
            # Cell labels: the value to 4 decimals, or empty text for NaN cells.
            text = np.where(np.isnan(z), "",
                            np.vectorize(lambda v: f"{v:.4f}")(np.nan_to_num(z, nan=0.0)))
            fig = go.Figure(data=go.Heatmap(
                z=z, x=GRID_AUGS, y=ordered_rows,
                colorscale=color_scale, zmid=zmid,
                text=text, texttemplate="%{text}", textfont=dict(size=12),
                colorbar=dict(title=metric.replace("_", " ")),
            ))
            fig.update_layout(
                height=360, margin=dict(l=10, r=10, t=10, b=10),
                xaxis_title="Augmentation", yaxis_title="(learning rate, BCE weight)",
            )
            st.plotly_chart(fig, use_container_width=True)
            # NOTE(review): if every value of `metric` in `sub` is NaN, idxmax()
            # does not return a usable label (pandas warns or raises depending
            # on version) and the lookup below fails. Confirm the CSV never
            # holds all-NaN metrics for a model.
            i = sub[metric].idxmax()
            best = sub.loc[i]
            st.success(
                f"Best for {PRETTY_NAME[model]} on **{metric}**: "
                f"`{best['cfg_id']}` (lr={best['lr_label']}, "
                f"bce={best['bce_weight']:.1f}, aug={best['augment']}) "
                f"β†’ {metric}={best[metric]:.4f} "
                f"(baseline Dice={best['baseline_val_dice']:.4f})"
            )
# ─── Tab 5: Top fine-tune configs ─────────────────────────────────────────
with t5:
    # Tab 5: sortable, filterable table of every fine-tune configuration.
    st.subheader("All 54 fine-tune configurations")
    if grid_df.empty:
        st.info("No grid results yet.")
    else:
        df = grid_df.copy()
        df["model_pretty"] = df["model"].map(PRETTY_NAME)
        df["lr_label"] = df["lr"].apply(lambda x: f"{x:.0e}")
        sort_by = st.selectbox(
            "Sort by",
            ["delta_dice", "best_val_dice", "best_val_miou", "best_val_iou", "wall_clock_seconds"],
            index=0, key="t5_sort",
        )
        ascending = st.toggle("ascending", value=False, key="t5_asc")
        only_improved = st.checkbox(
            "only configs that improved over baseline", value=False, key="t5_imp"
        )
        show = df.copy()
        if only_improved:
            # "Improved" means a strictly positive Dice delta vs. the baseline.
            show = show[show["delta_dice"] > 0]
        show = show.sort_values(sort_by, ascending=ascending)
        cols_to_show = [
            "cfg_id", "model_pretty", "lr_label", "bce_weight", "augment",
            "best_epoch", "epochs_trained", "early_stopped",
            "best_val_dice", "best_val_miou", "best_val_iou", "best_val_pixel_acc",
            "baseline_val_dice", "delta_dice", "wall_clock_seconds",
        ]
        # Keep only columns present in the CSV so a missing one doesn't raise.
        cols_to_show = [c for c in cols_to_show if c in show.columns]
        st.dataframe(show[cols_to_show], use_container_width=True, hide_index=True)
# ─── Tab 6: Per-config fine-tune curves ───────────────────────────────────
with t6:
    # Tab 6: overlay the per-epoch train/val curves of selected grid configs,
    # with each config's baseline Dice as a dashed reference line.
    st.subheader("Per-config learning curves")
    if grid_df.empty:
        st.info("No grid results yet.")
    else:
        cfg_options = sorted(grid_df["cfg_id"].unique().tolist())
        chosen = st.multiselect(
            "Pick configs to overlay",
            cfg_options,
            default=cfg_options[:3] if cfg_options else [],
            key="t6_pick",
        )
        metric_key, metric_label = st.selectbox(
            "Metric",
            [("dice", "Dice"), ("miou", "mIoU"), ("iou", "Foreground IoU"),
             ("pixel_acc", "Pixel accuracy"), ("loss", "Loss")],
            format_func=lambda x: x[1], key="t6_metric",
        )
        if not chosen:
            st.info("Select at least one config.")
        else:
            fig = go.Figure()
            val_col, train_col = f"val_{metric_key}", f"train_{metric_key}"
            missing = []
            for cfg_id in chosen:
                # Each log is loaded once and used for both traces and the baseline.
                log = load_grid_log(cfg_id)
                if log is None:
                    missing.append(cfg_id)
                    continue
                epochs = log.get("epochs", [])
                xs = [e["epoch"] for e in epochs]
                # .get(): an epoch lacking this metric becomes a gap (None) in
                # the line instead of a KeyError that aborts the page.
                fig.add_trace(go.Scatter(
                    x=xs, y=[e.get(val_col) for e in epochs],
                    mode="lines+markers", name=f"{cfg_id} val",
                ))
                fig.add_trace(go.Scatter(
                    x=xs, y=[e.get(train_col) for e in epochs],
                    mode="lines", line=dict(dash="dot"),
                    name=f"{cfg_id} train",
                ))
                baseline = log.get("baseline_val_dice")
                # Only meaningful on the Dice axis; skip null baselines, which
                # would crash the :.4f format.
                if metric_key == "dice" and baseline is not None:
                    fig.add_hline(
                        y=baseline, line_dash="dash",
                        annotation_text=f"{cfg_id} baseline ({baseline:.4f})",
                        annotation_position="bottom right",
                    )
            fig.update_layout(
                xaxis_title="Epoch", yaxis_title=metric_label, height=480,
                margin=dict(l=10, r=10, t=10, b=10),
                legend=dict(orientation="h", y=-0.2),
            )
            st.plotly_chart(fig, use_container_width=True)
            if missing:
                st.warning(f"No log file found for: {', '.join(missing)}")