"""
Unified trainer for the clean data-scaling study.

Usage:
    python train.py --model {segnet,unet,segformer_b0,segformer_b5} --share {25,50,100}

Example:
    python train.py --model unet --share 25
    python train.py --model segformer_b5 --share 100

Each run:
  - reads subset_{share}.txt for training filenames (cleaned dataset)
  - validates on the full cleaned val set every epoch
  - logs per-epoch metrics + timing to logs/{model}_{share}.json
  - saves two checkpoints:
        checkpoints/{model}_{share}_best.pth   (highest val Dice)
        checkpoints/{model}_{share}_final.pth  (last epoch)

Hyperparameters mirror each model's existing trainer in pv_panel_models/,
so the only differences vs. the original baselines are:
  (a) the deduplicated training set (no train↔val image leakage)
  (b) global confusion-matrix metrics (mIoU, IoU, Dice, PixelAcc)
  (c) reproducible seed
"""

import argparse
import json
import os
import random
import time
from pathlib import Path

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset import SubsetSolarPanelDataset
from metrics import SegMetrics
from models import MODEL_REGISTRY

# All paths are resolved relative to this file so the script can be launched
# from any working directory. The cleaned dataset lives two levels above it.
THIS_DIR = Path(__file__).resolve().parent
REPO_ROOT = THIS_DIR.parents[1]
CLEAN = REPO_ROOT / "final_data_clean"
TRAIN_IMG = CLEAN / "train" / "images"
TRAIN_MSK = CLEAN / "train" / "masks"
VAL_IMG = CLEAN / "val" / "images"
VAL_MSK = CLEAN / "val" / "masks"
SUBSETS_DIR = THIS_DIR / "subsets"
LOG_DIR = THIS_DIR / "logs"
CKPT_DIR = THIS_DIR / "checkpoints"


def _fmt(seconds: float) -> str:
    """Format a duration in seconds as ``M:SS``, or ``H:MM:SS`` from one hour up."""
    seconds = int(round(seconds))
    h, rem = divmod(seconds, 3600)
    m, s = divmod(rem, 60)
    return f"{h:d}:{m:02d}:{s:02d}" if h else f"{m:d}:{s:02d}"


def _seed_everything(seed: int) -> None:
    """Seed the Python, NumPy (when installed) and PyTorch global RNGs.

    ``torch.manual_seed`` alone leaves ``random`` and NumPy unseeded, so any
    main-process randomness drawn from them (e.g. augmentation when
    ``--num-workers 0``; the dataset code is not visible here) would differ
    between runs.
    """
    random.seed(seed)
    try:
        import numpy as np
    except ImportError:
        pass
    else:
        # NumPy's legacy seeding only accepts values in [0, 2**32).
        np.random.seed(seed % (2 ** 32))
    torch.manual_seed(seed)


def _write_json(path: Path, payload: dict) -> None:
    """Write *payload* to *path* as indented JSON via a temp file + rename.

    The log is rewritten every epoch; replacing it in one step means an
    interrupted run cannot leave a truncated, unparseable file behind.
    """
    tmp_path = path.with_name(path.name + ".tmp")
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)
    os.replace(tmp_path, path)


def _save_checkpoint(path: Path, model, epoch: int, val_metrics: dict,
                     args, output_is_prob: bool) -> None:
    """Save the model weights plus the run metadata needed to reload them."""
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "val_metrics": val_metrics,
        "model_name": args.model,
        "share": args.share,
        "output_is_prob": output_is_prob,
    }, path)


def run_epoch(model, loader, criterion, optimizer, device, train: bool,
              output_is_prob: bool):
    """Run one pass over *loader* and return ``(avg_loss, metrics)``.

    Args:
        model: network to run; switched to train/eval mode according to *train*.
        loader: iterable yielding ``(images, masks)`` batches.
        criterion: loss callable invoked as ``criterion(outputs, masks)``.
        optimizer: optimizer stepped once per batch; only used when *train*.
        device: device the batches are moved to.
        train: if True, gradients are enabled and the weights are updated;
            otherwise the pass runs under ``torch.no_grad()``.
        output_is_prob: forwarded unchanged to ``SegMetrics.update``.

    Returns:
        The mean per-batch loss (0.0 for an empty loader) and whatever
        ``SegMetrics.compute()`` returns for the accumulated batches.
    """
    model.train(mode=train)
    metrics = SegMetrics()
    total_loss = 0.0
    n_batches = 0
    desc = "Train" if train else "Val"
    ctx = torch.enable_grad() if train else torch.no_grad()
    with ctx:
        for images, masks in tqdm(loader, desc=desc, leave=False):
            images = images.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)
            if train:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            if train:
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            n_batches += 1
            # Detach so metric accumulation never holds on to the autograd graph.
            metrics.update(outputs.detach(), masks, output_is_prob=output_is_prob)
    # max(..., 1) guards against division by zero on an empty loader.
    avg_loss = total_loss / max(n_batches, 1)
    return avg_loss, metrics.compute()


def parse_args():
    """Parse the command-line options (model, data share and hyperparameters)."""
    p = argparse.ArgumentParser()
    p.add_argument("--model", required=True, choices=list(MODEL_REGISTRY.keys()))
    p.add_argument("--share", required=True, type=int, choices=[25, 50, 100])
    p.add_argument("--epochs", type=int, default=50)
    p.add_argument("--batch-size", type=int, default=16)
    p.add_argument("--image-size", type=int, default=128)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--num-workers", type=int, default=4)
    p.add_argument("--seed", type=int, default=42)
    return p.parse_args()


def main():
    """Train one (model, share) configuration, logging and checkpointing as it goes.

    Raises:
        FileNotFoundError: if the cleaned dataset directory or the subset
            file for the requested share is missing.
    """
    args = parse_args()
    _seed_everything(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[run] model={args.model} share={args.share}% device={device}")

    if not CLEAN.is_dir():
        raise FileNotFoundError(
            f"Cleaned dataset not found at {CLEAN}\n"
            f"Run dedupe_dataset.py first."
        )
    LOG_DIR.mkdir(parents=True, exist_ok=True)
    CKPT_DIR.mkdir(parents=True, exist_ok=True)

    subset_file = SUBSETS_DIR / f"subset_{args.share}.txt"
    if not subset_file.is_file():
        raise FileNotFoundError(
            f"{subset_file} not found. Run subsets/make_subsets.py first."
        )

    train_set = SubsetSolarPanelDataset(
        TRAIN_IMG, TRAIN_MSK, file_list=subset_file,
        image_size=args.image_size, augment=True,
    )
    val_set = SubsetSolarPanelDataset(
        VAL_IMG, VAL_MSK, file_list=None,
        image_size=args.image_size, augment=False,
    )
    print(f"[data] train={len(train_set)} val={len(val_set)}")

    # Pinned host memory only helps host->GPU copies; on CPU-only runs it is
    # useless and makes the DataLoader emit a warning.
    pin_memory = device.type == "cuda"
    train_loader = DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_set, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=pin_memory,
    )

    # Each registry entry is a zero-argument builder returning the model, its
    # loss, and a flag saying whether the model outputs probabilities or logits.
    builder = MODEL_REGISTRY[args.model]
    model, criterion, output_is_prob = builder()
    model = model.to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"[model] {args.model} params={n_params:,} "
          f"output={'prob' if output_is_prob else 'logits'}")

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # mode="max": the scheduler is stepped on validation Dice (higher is better).
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", patience=5, factor=0.5
    )

    history = {
        "model": args.model,
        "share": args.share,
        "n_train": len(train_set),
        "n_val": len(val_set),
        "n_params": n_params,
        "epochs": [],
    }
    best_dice = -1.0
    best_epoch = -1
    best_path = CKPT_DIR / f"{args.model}_{args.share}_best.pth"
    final_path = CKPT_DIR / f"{args.model}_{args.share}_final.pth"
    log_path = LOG_DIR / f"{args.model}_{args.share}.json"

    t0 = time.time()
    history["start_time_iso"] = time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime(t0))

    # Placeholder so the final checkpoint can still be written if the loop
    # body never runs (--epochs 0).
    val_loss, val_m = 0.0, {"dice": 0.0, "iou": 0.0, "miou": 0.0, "pixel_acc": 0.0}

    for epoch in range(args.epochs):
        print(f"\nEpoch {epoch + 1}/{args.epochs}")
        epoch_t0 = time.time()

        # Capture the LR *before* training: after scheduler.step() the
        # optimizer may already hold the next epoch's (reduced) rate, which
        # is not the rate that produced this epoch's metrics.
        epoch_lr = optimizer.param_groups[0]["lr"]

        train_t0 = time.time()
        train_loss, train_m = run_epoch(model, train_loader, criterion, optimizer,
                                        device, train=True,
                                        output_is_prob=output_is_prob)
        train_seconds = time.time() - train_t0

        val_t0 = time.time()
        val_loss, val_m = run_epoch(model, val_loader, criterion, optimizer,
                                    device, train=False,
                                    output_is_prob=output_is_prob)
        val_seconds = time.time() - val_t0

        scheduler.step(val_m["dice"])

        epoch_seconds = time.time() - epoch_t0
        elapsed = time.time() - t0
        avg_per_epoch = elapsed / (epoch + 1)
        eta = avg_per_epoch * (args.epochs - epoch - 1)

        epoch_record = {
            "epoch": epoch + 1,
            "lr": epoch_lr,
            "train_loss": train_loss,
            "val_loss": val_loss,
            **{f"train_{k}": v for k, v in train_m.items()},
            **{f"val_{k}": v for k, v in val_m.items()},
            "epoch_seconds": epoch_seconds,
            "train_seconds": train_seconds,
            "val_seconds": val_seconds,
        }
        history["epochs"].append(epoch_record)

        print(
            f" train loss={train_loss:.4f} dice={train_m['dice']:.4f} "
            f"iou={train_m['iou']:.4f} miou={train_m['miou']:.4f} "
            f"pixel_acc={train_m['pixel_acc']:.4f}"
        )
        print(
            f" val loss={val_loss:.4f} dice={val_m['dice']:.4f} "
            f"iou={val_m['iou']:.4f} miou={val_m['miou']:.4f} "
            f"pixel_acc={val_m['pixel_acc']:.4f}"
        )
        print(
            f" time epoch={_fmt(epoch_seconds)} "
            f"(train={_fmt(train_seconds)} val={_fmt(val_seconds)}) "
            f"elapsed={_fmt(elapsed)} ETA={_fmt(eta)}"
        )

        # Rewrite the full log every epoch so partial runs remain inspectable.
        _write_json(log_path, history)

        if val_m["dice"] > best_dice:
            best_dice = val_m["dice"]
            best_epoch = epoch + 1
            _save_checkpoint(best_path, model, epoch + 1, val_m, args,
                             output_is_prob)
            print(f" ↳ new best (dice={best_dice:.4f}) → {best_path.name}")

    _save_checkpoint(final_path, model, args.epochs, val_m, args, output_is_prob)

    total_seconds = time.time() - t0
    history["best_epoch"] = best_epoch
    history["best_val_dice"] = best_dice
    history["wall_clock_seconds"] = total_seconds
    history["end_time_iso"] = time.strftime("%Y-%m-%dT%H:%M:%S")
    _write_json(log_path, history)

    print(f"\n[done] best epoch {best_epoch} (dice={best_dice:.4f})")
    print(f" wall {_fmt(total_seconds)} ({total_seconds:.1f} s)")
    print(f" best → {best_path}")
    print(f" final → {final_path}")
    print(f" log → {log_path}")


if __name__ == "__main__":
    main()