"""
Resolution-study trainer: train a single (model, image_size) cell on 100% of
the data in final_data_clean/.

Usage:
    python train.py --model segformer_b0 --image-size 192
    python train.py --model segformer_b0 --image-size 256

Artifacts written by each run:
    checkpoints/{model}_res{image_size}_best.pth   weights at the best val Dice
    logs/{model}_res{image_size}.json              per-epoch metrics and timing
    results/resolution_results.csv                 one summary row appended

Hyperparameters match the clean baseline trained at 128:
    Adam, lr=1e-4, ReduceLROnPlateau(mode='max', patience=5, factor=0.5),
    50 epochs, CombinedLoss(0.5*BCE + 0.5*Dice), HFlip+VFlip+Rot15.
Only image_size and (optionally) batch_size are meant to vary. The --epochs,
--lr and --seed flags default to the baseline settings. --throttle only paces
the training loop and does not change the optimisation.
"""
import argparse
import csv
import json
import time
from pathlib import Path
from threading import Lock

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset import SolarPanelDataset
from metrics import SegMetrics
from models import MODEL_REGISTRY

# All paths are anchored at this file's directory, so the script can be
# launched from any working directory.
THIS_DIR = Path(__file__).resolve().parent
# The repository root is taken to be two directory levels above THIS_DIR.
REPO_ROOT = THIS_DIR.parents[1]

CLEAN = REPO_ROOT / "final_data_clean"
TRAIN_IMG, TRAIN_MSK = (CLEAN / "train" / sub for sub in ("images", "masks"))
VAL_IMG, VAL_MSK = (CLEAN / "val" / sub for sub in ("images", "masks"))

# Outputs of this study live next to this script.
LOG_DIR, CKPT_DIR, RESULTS_DIR = (
    THIS_DIR / sub for sub in ("logs", "checkpoints", "results")
)
RESULTS_CSV = RESULTS_DIR / "resolution_results.csv"

# Reference val-Dice values at 128 (from clean_data_scaling_study at 100% data).
BASELINE_128 = { "segnet": 0.9291, "unet": 0.9370, "segformer_b0": 0.9280, "segformer_b5": 0.9371, } CSV_FIELDS = [ "cfg_id", "model", "image_size", "batch_size", "best_epoch", "best_val_dice", "best_val_miou", "best_val_iou", "best_val_pixel_acc", "baseline_dice_at_128", "delta_vs_128", "wall_clock_seconds", ] _csv_lock = Lock() def _fmt(seconds: float) -> str: seconds = int(round(seconds)) h, rem = divmod(seconds, 3600) m, s = divmod(rem, 60) return f"{h:d}:{m:02d}:{s:02d}" if h else f"{m:d}:{s:02d}" def append_csv_row(row: dict): RESULTS_DIR.mkdir(parents=True, exist_ok=True) with _csv_lock: write_header = not RESULTS_CSV.is_file() with open(RESULTS_CSV, "a", newline="") as f: writer = csv.DictWriter(f, fieldnames=CSV_FIELDS) if write_header: writer.writeheader() writer.writerow({k: row.get(k) for k in CSV_FIELDS}) def run_epoch(model, loader, criterion, optimizer, device, train: bool, output_is_prob: bool, throttle: float | None = None): model.train(mode=train) metrics = SegMetrics() total_loss = 0.0 n_batches = 0 do_throttle = train and throttle is not None and 0 < throttle < 1.0 sleep_factor = (1.0 - throttle) / throttle if do_throttle else 0.0 desc = "Train" if train else "Val" ctx = torch.enable_grad() if train else torch.no_grad() with ctx: for images, masks in tqdm(loader, desc=desc, leave=False): batch_t0 = time.perf_counter() if do_throttle else None images = images.to(device, non_blocking=True) masks = masks.to(device, non_blocking=True) if train: optimizer.zero_grad() outputs = model(images) loss = criterion(outputs, masks) if train: loss.backward() optimizer.step() total_loss += loss.item() n_batches += 1 metrics.update(outputs.detach(), masks, output_is_prob=output_is_prob) if do_throttle: if torch.cuda.is_available(): torch.cuda.synchronize() step_dt = time.perf_counter() - batch_t0 time.sleep(step_dt * sleep_factor) return total_loss / max(n_batches, 1), metrics.compute() def parse_args(): p = argparse.ArgumentParser() p.add_argument("--model", 
required=True, choices=list(MODEL_REGISTRY.keys())) p.add_argument("--image-size", required=True, type=int, help="square input resolution (e.g. 192, 256)") p.add_argument("--batch-size", type=int, default=16) p.add_argument("--epochs", type=int, default=50) p.add_argument("--lr", type=float, default=1e-4) p.add_argument("--num-workers", type=int, default=4) p.add_argument("--seed", type=int, default=42) p.add_argument("--throttle", type=float, default=None, help="optional duty-cycle cap, e.g. 0.4 for 40%% average util") return p.parse_args() def main(): args = parse_args() torch.manual_seed(args.seed) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") cfg_id = f"{args.model}_res{args.image_size}" print(f"[run] {cfg_id} image_size={args.image_size} batch_size={args.batch_size} device={device}") LOG_DIR.mkdir(parents=True, exist_ok=True) CKPT_DIR.mkdir(parents=True, exist_ok=True) train_set = SolarPanelDataset(TRAIN_IMG, TRAIN_MSK, image_size=args.image_size, augment=True) val_set = SolarPanelDataset(VAL_IMG, VAL_MSK, image_size=args.image_size, augment=False) print(f"[data] train={len(train_set)} val={len(val_set)}") train_loader = DataLoader( train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True, ) val_loader = DataLoader( val_set, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True, ) builder = MODEL_REGISTRY[args.model] model, criterion, output_is_prob = builder() model = model.to(device) n_params = sum(p.numel() for p in model.parameters()) print(f"[model] {args.model} params={n_params:,} " f"output={'prob' if output_is_prob else 'logits'}") optimizer = optim.Adam(model.parameters(), lr=args.lr) scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", patience=5, factor=0.5) baseline = BASELINE_128.get(args.model) history = { "cfg_id": cfg_id, "model": args.model, "image_size": args.image_size, "batch_size": args.batch_size, "throttle": 
args.throttle, "n_train": len(train_set), "n_val": len(val_set), "n_params": n_params, "baseline_dice_at_128": baseline, "epochs": [], } best_dice = -1.0 best_epoch = -1 best_path = CKPT_DIR / f"{cfg_id}_best.pth" log_path = LOG_DIR / f"{cfg_id}.json" val_m = {"dice": 0.0, "iou": 0.0, "miou": 0.0, "pixel_acc": 0.0} if args.throttle is not None and 0 < args.throttle < 1: print(f"[throttle] training duty cycle capped at {args.throttle*100:.0f}%") t0 = time.time() history["start_time_iso"] = time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime(t0)) for epoch in range(args.epochs): epoch_t0 = time.time() train_loss, train_m = run_epoch( model, train_loader, criterion, optimizer, device, train=True, output_is_prob=output_is_prob, throttle=args.throttle, ) val_loss, val_m = run_epoch( model, val_loader, criterion, optimizer, device, train=False, output_is_prob=output_is_prob, throttle=None, ) scheduler.step(val_m["dice"]) epoch_seconds = time.time() - epoch_t0 elapsed = time.time() - t0 eta = (elapsed / (epoch + 1)) * (args.epochs - epoch - 1) improved = val_m["dice"] > best_dice if improved: best_dice = val_m["dice"] best_epoch = epoch + 1 torch.save({ "epoch": epoch + 1, "model_state_dict": model.state_dict(), "val_metrics": val_m, "cfg": { "cfg_id": cfg_id, "model": args.model, "image_size": args.image_size, }, "output_is_prob": output_is_prob, }, best_path) history["epochs"].append({ "epoch": epoch + 1, "lr": optimizer.param_groups[0]["lr"], "train_loss": train_loss, "val_loss": val_loss, **{f"train_{k}": v for k, v in train_m.items()}, **{f"val_{k}": v for k, v in val_m.items()}, "epoch_seconds": epoch_seconds, }) marker = "★" if improved else " " print( f" ep {epoch+1:>2}/{args.epochs} " f"trL={train_loss:.4f} vL={val_loss:.4f} " f"vDice={val_m['dice']:.4f} {marker} vIoU={val_m['iou']:.4f} " f"({_fmt(elapsed)}/ETA {_fmt(eta)})" ) with open(log_path, "w") as f: json.dump(history, f, indent=2) total = time.time() - t0 history["best_epoch"] = best_epoch 
history["best_val_dice"] = best_dice history["epochs_trained"] = args.epochs history["wall_clock_seconds"] = total with open(log_path, "w") as f: json.dump(history, f, indent=2) if best_path.is_file(): st = torch.load(best_path, map_location="cpu", weights_only=False) bvm = st.get("val_metrics", {}) best_val_iou = bvm.get("iou") best_val_miou = bvm.get("miou") best_val_pa = bvm.get("pixel_acc") else: best_val_iou = best_val_miou = best_val_pa = None delta = (best_dice - baseline) if baseline is not None else None append_csv_row({ "cfg_id": cfg_id, "model": args.model, "image_size": args.image_size, "batch_size": args.batch_size, "best_epoch": best_epoch, "best_val_dice": best_dice, "best_val_miou": best_val_miou, "best_val_iou": best_val_iou, "best_val_pixel_acc": best_val_pa, "baseline_dice_at_128": baseline, "delta_vs_128": delta, "wall_clock_seconds": total, }) print(f"\n[done] {cfg_id} best ep {best_epoch} vDice={best_dice:.4f}") if baseline is not None: sign = "+" if delta >= 0 else "" print(f" baseline at 128 = {baseline:.4f} delta = {sign}{delta:+.4f}") print(f" wall {_fmt(total)}") if __name__ == "__main__": main()