# Mohamed-ENNHIRI
# Solar Panel Segmentation app for HF Spaces
# 52efd90
"""
Unified trainer for the clean data-scaling study.
Usage:
python train.py --model {segnet,unet,segformer_b0,segformer_b5} --share {25,50,100}
Example:
python train.py --model unet --share 25
python train.py --model segformer_b5 --share 100
Each run:
- reads subset_{share}.txt for training filenames (cleaned dataset)
- validates on the full cleaned val set every epoch
- logs per-epoch metrics + timing to logs/{model}_{share}.json
- saves two checkpoints:
checkpoints/{model}_{share}_best.pth (highest val Dice)
checkpoints/{model}_{share}_final.pth (last epoch)
Hyperparameters mirror each model's existing trainer in pv_panel_models/, so
the only differences vs. the original baselines are:
(a) the deduplicated training set (no train↔val image leakage)
(b) global confusion-matrix metrics (mIoU, IoU, Dice, PixelAcc)
(c) reproducible seed
"""
import argparse
import json
import os
import time
from pathlib import Path
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from dataset import SubsetSolarPanelDataset
from metrics import SegMetrics
from models import MODEL_REGISTRY
# Directory holding this script; subsets, logs and checkpoints live beside it.
THIS_DIR = Path(__file__).resolve().parent
# Two directory levels above THIS_DIR.
REPO_ROOT = THIS_DIR.parents[1]

# Deduplicated dataset (produced by dedupe_dataset.py).
CLEAN = REPO_ROOT / "final_data_clean"
_TRAIN_SPLIT = CLEAN / "train"
_VAL_SPLIT = CLEAN / "val"
TRAIN_IMG = _TRAIN_SPLIT / "images"
TRAIN_MSK = _TRAIN_SPLIT / "masks"
VAL_IMG = _VAL_SPLIT / "images"
VAL_MSK = _VAL_SPLIT / "masks"

# Per-run inputs (subset_{share}.txt) and outputs.
SUBSETS_DIR = THIS_DIR / "subsets"
LOG_DIR = THIS_DIR / "logs"
CKPT_DIR = THIS_DIR / "checkpoints"
def _fmt(seconds: float) -> str:
seconds = int(round(seconds))
h, rem = divmod(seconds, 3600)
m, s = divmod(rem, 60)
return f"{h:d}:{m:02d}:{s:02d}" if h else f"{m:d}:{s:02d}"
def run_epoch(model, loader, criterion, optimizer, device, train: bool, output_is_prob: bool):
    """Make one full pass over ``loader`` and return ``(avg_loss, metrics)``.

    With ``train=True`` the model is switched to training mode and every batch
    takes an optimizer step. With ``train=False`` it runs in eval mode with
    autograd disabled, and ``optimizer`` is never touched.

    Args:
        model: network called as ``model(images)``.
        loader: iterable yielding ``(images, masks)`` batches.
        criterion: loss called as ``criterion(outputs, masks)``.
        optimizer: zeroed and stepped only when ``train`` is true.
        device: device every batch is moved to.
        train: training pass if true, evaluation pass otherwise.
        output_is_prob: forwarded to ``SegMetrics.update``; tells it whether
            ``outputs`` are probabilities or logits.

    Returns:
        The unweighted mean of the per-batch losses (0.0 for an empty loader)
        and whatever ``SegMetrics.compute()`` returns.
    """
    model.train(mode=train)
    tracker = SegMetrics()
    loss_sum = 0.0
    batches_seen = 0
    # set_grad_enabled(True) == enable_grad(), set_grad_enabled(False) == no_grad().
    with torch.set_grad_enabled(train):
        for images, masks in tqdm(loader, desc="Train" if train else "Val", leave=False):
            images = images.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)
            if train:
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, masks)
                loss.backward()
                optimizer.step()
            else:
                outputs = model(images)
                loss = criterion(outputs, masks)
            loss_sum += loss.item()
            batches_seen += 1
            tracker.update(outputs.detach(), masks, output_is_prob=output_is_prob)
    avg_loss = loss_sum / max(batches_seen, 1)
    return avg_loss, tracker.compute()
def parse_args():
    """Parse the command-line options for a single training run.

    ``--model`` must be a key of ``MODEL_REGISTRY``. ``--share`` is the
    training-set percentage (25, 50 or 100). Every other option has a default.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True, choices=list(MODEL_REGISTRY.keys()))
    parser.add_argument("--share", required=True, type=int, choices=[25, 50, 100])
    # (flag, type, default) for the optional hyperparameters.
    optional_flags = (
        ("--epochs", int, 50),
        ("--batch-size", int, 16),
        ("--image-size", int, 128),
        ("--lr", float, 1e-4),
        ("--num-workers", int, 4),
        ("--seed", int, 42),
    )
    for flag, value_type, default in optional_flags:
        parser.add_argument(flag, type=value_type, default=default)
    return parser.parse_args()
def _seed_everything(seed: int) -> None:
    """Seed Python ``random``, NumPy (if installed) and torch with ``seed``.

    ``torch.manual_seed`` alone leaves ``random`` and NumPy unseeded. Any code
    in the main process that draws from them (for example dataset
    augmentation, if it uses them, with ``--num-workers 0``) would then differ
    between runs.
    """
    import random

    random.seed(seed)
    try:
        import numpy as np
    except ImportError:  # NumPy is optional here; torch seeding still applies.
        pass
    else:
        np.random.seed(seed)
    torch.manual_seed(seed)  # seeds the CPU and all CUDA devices


def _write_json(path: Path, payload: dict) -> None:
    """Write ``payload`` to ``path`` as indented JSON, atomically.

    The JSON is written to a sibling ``.tmp`` file first and then moved over
    ``path`` with ``os.replace``. An interrupted run therefore never leaves a
    truncated log in place of the previous complete one.
    """
    tmp_path = path.with_name(path.name + ".tmp")
    with open(tmp_path, "w") as f:
        json.dump(payload, f, indent=2)
    os.replace(tmp_path, path)


def _save_checkpoint(path: Path, *, epoch: int, model, val_metrics: dict,
                     args, output_is_prob: bool) -> None:
    """``torch.save`` the model's state dict plus run metadata to ``path``."""
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "val_metrics": val_metrics,
        "model_name": args.model,
        "share": args.share,
        "output_is_prob": output_is_prob,
    }, path)


def main():
    """Train one model on one training-set share and record logs/checkpoints.

    Parses the CLI arguments and seeds all RNGs. Builds train/val loaders from
    the cleaned dataset and trains for ``--epochs`` epochs, using Adam with
    ReduceLROnPlateau driven by validation Dice. After every epoch it rewrites
    the JSON log, and saves the "best" checkpoint whenever validation Dice
    improves. The "final" checkpoint is saved after the last epoch.

    Raises:
        FileNotFoundError: if the cleaned dataset directory or the subset
            file for ``--share`` does not exist.
    """
    args = parse_args()
    _seed_everything(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[run] model={args.model} share={args.share}% device={device}")

    if not CLEAN.is_dir():
        raise FileNotFoundError(
            f"Cleaned dataset not found at {CLEAN}\n"
            f"Run dedupe_dataset.py first."
        )
    LOG_DIR.mkdir(parents=True, exist_ok=True)
    CKPT_DIR.mkdir(parents=True, exist_ok=True)

    subset_file = SUBSETS_DIR / f"subset_{args.share}.txt"
    if not subset_file.is_file():
        raise FileNotFoundError(
            f"{subset_file} not found. Run subsets/make_subsets.py first."
        )

    train_set = SubsetSolarPanelDataset(
        TRAIN_IMG, TRAIN_MSK,
        file_list=subset_file,
        image_size=args.image_size,
        augment=True,
    )
    # file_list=None: validate on the full cleaned val split.
    val_set = SubsetSolarPanelDataset(
        VAL_IMG, VAL_MSK,
        file_list=None,
        image_size=args.image_size,
        augment=False,
    )
    print(f"[data] train={len(train_set)} val={len(val_set)}")

    # Pinned host memory only speeds up host->GPU copies; on a CPU-only
    # machine PyTorch merely warns that it is unused.
    pin_memory = device.type == "cuda"
    train_loader = DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_set, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=pin_memory,
    )

    # Each registry builder returns (model, loss, outputs-are-probabilities flag).
    model, criterion, output_is_prob = MODEL_REGISTRY[args.model]()
    model = model.to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"[model] {args.model} params={n_params:,} "
          f"output={'prob' if output_is_prob else 'logits'}")

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # Halve the LR when validation Dice (maximised, hence mode="max") plateaus.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", patience=5, factor=0.5
    )

    history = {
        "model": args.model,
        "share": args.share,
        "n_train": len(train_set),
        "n_val": len(val_set),
        "n_params": n_params,
        "epochs": [],
    }
    best_dice = -1.0
    best_epoch = -1
    best_path = CKPT_DIR / f"{args.model}_{args.share}_best.pth"
    final_path = CKPT_DIR / f"{args.model}_{args.share}_final.pth"
    log_path = LOG_DIR / f"{args.model}_{args.share}.json"

    # Wall-clock time for the ISO timestamp; a monotonic clock for durations
    # so system clock adjustments cannot skew timings or the ETA.
    start_wall = time.time()
    t0 = time.perf_counter()
    history["start_time_iso"] = time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime(start_wall))
    # Placeholder so the final checkpoint is well-formed even with --epochs 0.
    val_m = {"dice": 0.0, "iou": 0.0, "miou": 0.0, "pixel_acc": 0.0}

    for epoch in range(1, args.epochs + 1):
        print(f"\nEpoch {epoch}/{args.epochs}")
        epoch_t0 = time.perf_counter()
        # LR actually used for this epoch's updates. Read it before
        # scheduler.step(), which may lower it for the *next* epoch.
        epoch_lr = optimizer.param_groups[0]["lr"]

        train_t0 = time.perf_counter()
        train_loss, train_m = run_epoch(model, train_loader, criterion, optimizer, device,
                                        train=True, output_is_prob=output_is_prob)
        train_seconds = time.perf_counter() - train_t0

        val_t0 = time.perf_counter()
        val_loss, val_m = run_epoch(model, val_loader, criterion, optimizer, device,
                                    train=False, output_is_prob=output_is_prob)
        val_seconds = time.perf_counter() - val_t0

        scheduler.step(val_m["dice"])

        now = time.perf_counter()
        epoch_seconds = now - epoch_t0
        elapsed = now - t0
        eta = elapsed / epoch * (args.epochs - epoch)

        epoch_record = {
            "epoch": epoch,
            "lr": epoch_lr,
            "train_loss": train_loss,
            "val_loss": val_loss,
            **{f"train_{k}": v for k, v in train_m.items()},
            **{f"val_{k}": v for k, v in val_m.items()},
            "epoch_seconds": epoch_seconds,
            "train_seconds": train_seconds,
            "val_seconds": val_seconds,
        }
        history["epochs"].append(epoch_record)

        print(
            f" train loss={train_loss:.4f} dice={train_m['dice']:.4f} "
            f"iou={train_m['iou']:.4f} miou={train_m['miou']:.4f} "
            f"pixel_acc={train_m['pixel_acc']:.4f}"
        )
        print(
            f" val loss={val_loss:.4f} dice={val_m['dice']:.4f} "
            f"iou={val_m['iou']:.4f} miou={val_m['miou']:.4f} "
            f"pixel_acc={val_m['pixel_acc']:.4f}"
        )
        print(
            f" time epoch={_fmt(epoch_seconds)} "
            f"(train={_fmt(train_seconds)} val={_fmt(val_seconds)}) "
            f"elapsed={_fmt(elapsed)} ETA={_fmt(eta)}"
        )

        # Rewrite the log every epoch so a crashed run still leaves its history.
        _write_json(log_path, history)

        if val_m["dice"] > best_dice:
            best_dice = val_m["dice"]
            best_epoch = epoch
            _save_checkpoint(best_path, epoch=epoch, model=model, val_metrics=val_m,
                             args=args, output_is_prob=output_is_prob)
            print(f" ↳ new best (dice={best_dice:.4f}) → {best_path.name}")

    _save_checkpoint(final_path, epoch=args.epochs, model=model, val_metrics=val_m,
                     args=args, output_is_prob=output_is_prob)

    total_seconds = time.perf_counter() - t0
    history["best_epoch"] = best_epoch
    history["best_val_dice"] = best_dice
    history["wall_clock_seconds"] = total_seconds
    history["end_time_iso"] = time.strftime("%Y-%m-%dT%H:%M:%S")
    _write_json(log_path, history)

    print(f"\n[done] best epoch {best_epoch} (dice={best_dice:.4f})")
    print(f" wall {_fmt(total_seconds)} ({total_seconds:.1f} s)")
    print(f" best → {best_path}")
    print(f" final → {final_path}")
    print(f" log → {log_path}")
if __name__ == "__main__":
main()