Spaces:
Running
Running
| """ | |
| Unified trainer for the clean data-scaling study. | |
| Usage: | |
| python train.py --model {segnet,unet,segformer_b0,segformer_b5} --share {25,50,100} | |
| Example: | |
| python train.py --model unet --share 25 | |
| python train.py --model segformer_b5 --share 100 | |
| Each run: | |
| - reads subset_{share}.txt for training filenames (cleaned dataset) | |
| - validates on the full cleaned val set every epoch | |
| - logs per-epoch metrics + timing to logs/{model}_{share}.json | |
| - saves two checkpoints: | |
| checkpoints/{model}_{share}_best.pth (highest val Dice) | |
| checkpoints/{model}_{share}_final.pth (last epoch) | |
| Hyperparameters mirror each model's existing trainer in pv_panel_models/, so | |
| the only differences vs. the original baselines are: | |
| (a) the deduplicated training set (no train↔val image leakage) | |
| (b) global confusion-matrix metrics (mIoU, IoU, Dice, PixelAcc) | |
| (c) a fixed random seed (--seed, default 42) | |
| """ | |
| import argparse | |
| import json | |
| import os | |
| import time | |
| from pathlib import Path | |
| import torch | |
| import torch.optim as optim | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| from dataset import SubsetSolarPanelDataset | |
| from metrics import SegMetrics | |
| from models import MODEL_REGISTRY | |
# Filesystem layout. Every path is anchored to this script's resolved
# location, so the trainer does not depend on the current working directory.
THIS_DIR = Path(__file__).resolve().parent
REPO_ROOT = THIS_DIR.parents[1]  # two directory levels above THIS_DIR

# Cleaned (deduplicated) dataset splits.
CLEAN = REPO_ROOT / "final_data_clean"
TRAIN_IMG = CLEAN.joinpath("train", "images")
TRAIN_MSK = CLEAN.joinpath("train", "masks")
VAL_IMG = CLEAN.joinpath("val", "images")
VAL_MSK = CLEAN.joinpath("val", "masks")

# Per-run inputs (subset lists) and outputs (logs, checkpoints) live next to
# this script.
SUBSETS_DIR = THIS_DIR.joinpath("subsets")
LOG_DIR = THIS_DIR.joinpath("logs")
CKPT_DIR = THIS_DIR.joinpath("checkpoints")
| def _fmt(seconds: float) -> str: | |
| seconds = int(round(seconds)) | |
| h, rem = divmod(seconds, 3600) | |
| m, s = divmod(rem, 60) | |
| return f"{h:d}:{m:02d}:{s:02d}" if h else f"{m:d}:{s:02d}" | |
def run_epoch(model, loader, criterion, optimizer, device, train: bool, output_is_prob: bool):
    """Make one pass over ``loader`` and return ``(mean_batch_loss, metrics)``.

    With ``train=True`` the model is switched to training mode and every batch
    takes an optimizer step. Otherwise the model is switched to eval mode, and
    gradients are disabled (``optimizer`` is then unused).

    Segmentation metrics are accumulated over all batches in a ``SegMetrics``
    instance and the result of its ``compute()`` is returned. The loss is the
    unweighted mean of the per-batch ``criterion`` values (0.0 for an empty
    loader). ``output_is_prob`` is forwarded to ``SegMetrics.update`` and marks
    whether the model's outputs are probabilities or logits.
    """
    model.train(mode=train)
    tracker = SegMetrics()
    loss_sum, batch_count = 0.0, 0
    grad_mode = torch.enable_grad() if train else torch.no_grad()
    with grad_mode:
        progress = tqdm(loader, desc="Train" if train else "Val", leave=False)
        for images, masks in progress:
            images = images.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)
            if train:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            if train:
                loss.backward()
                optimizer.step()
            loss_sum += loss.item()
            batch_count += 1
            # Detach so metric bookkeeping never holds onto the autograd graph.
            tracker.update(outputs.detach(), masks, output_is_prob=output_is_prob)
    mean_loss = loss_sum / max(batch_count, 1)
    return mean_loss, tracker.compute()
def parse_args():
    """Parse command-line options for one (model, share) training run.

    Returns:
        argparse.Namespace with ``model``, ``share``, ``epochs``,
        ``batch_size``, ``image_size``, ``lr``, ``num_workers`` and ``seed``.
    """
    parser = argparse.ArgumentParser()
    # Required: the architecture to train and the training-subset percentage.
    parser.add_argument("--model", required=True, choices=[*MODEL_REGISTRY.keys()])
    parser.add_argument("--share", required=True, type=int, choices=[25, 50, 100])
    # Optional hyperparameters as (flag, type, default).
    optional = [
        ("--epochs", int, 50),
        ("--batch-size", int, 16),
        ("--image-size", int, 128),
        ("--lr", float, 1e-4),
        ("--num-workers", int, 4),
        ("--seed", int, 42),
    ]
    for flag, kind, default in optional:
        parser.add_argument(flag, type=kind, default=default)
    return parser.parse_args()
def _seed_everything(seed: int) -> None:
    """Seed the Python, NumPy (if installed) and PyTorch RNGs with ``seed``.

    ``torch.manual_seed`` covers model initialisation, DataLoader shuffling and
    the base seed of DataLoader workers. Python ``random`` and NumPy are also
    seeded in case the dataset's augmentation draws from them in the main
    process (e.g. with ``--num-workers 0``); that usage is not visible here.
    A fixed seed does not by itself make GPU runs bit-for-bit reproducible.
    """
    import random

    random.seed(seed)
    try:
        import numpy as np
    except ImportError:  # NumPy not installed: nothing to seed.
        pass
    else:
        # NumPy's legacy seeding only accepts values in [0, 2**32).
        np.random.seed(seed % 2**32)
    torch.manual_seed(seed)


def _save_checkpoint(path, model, epoch, val_metrics, model_name, share, output_is_prob):
    """Save the model weights plus the metadata needed to reload/evaluate them."""
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "val_metrics": val_metrics,
        "model_name": model_name,
        "share": share,
        "output_is_prob": output_is_prob,
    }, path)


def _write_json(path: Path, payload: dict) -> None:
    """Write ``payload`` as indented JSON to ``path`` without ever truncating it.

    The data goes to a sibling ``.tmp`` file first and is then moved into place
    with ``os.replace``. A crash or Ctrl-C mid-write therefore leaves the
    previous log intact.
    """
    tmp = path.with_name(path.name + ".tmp")
    with open(tmp, "w") as f:
        json.dump(payload, f, indent=2)
    os.replace(tmp, path)


def _format_split(name: str, loss: float, m: dict) -> str:
    """One console line summarising loss and metrics for a train/val split."""
    return (
        f" {name} loss={loss:.4f} dice={m['dice']:.4f} "
        f"iou={m['iou']:.4f} miou={m['miou']:.4f} "
        f"pixel_acc={m['pixel_acc']:.4f}"
    )


def main():
    """Train one (model, data share) configuration end to end.

    Trains on ``subset_{share}.txt`` of the cleaned training set and validates
    on the full cleaned val set every epoch. It uses Adam and a
    ReduceLROnPlateau schedule driven by validation Dice, and writes:

    * ``logs/{model}_{share}.json``, rewritten after every epoch and at the end;
    * ``checkpoints/{model}_{share}_best.pth``, whenever val Dice improves;
    * ``checkpoints/{model}_{share}_final.pth``, after the last epoch.

    Raises:
        FileNotFoundError: if the cleaned dataset or the subset file is missing.
    """
    args = parse_args()
    _seed_everything(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[run] model={args.model} share={args.share}% device={device}")

    if not CLEAN.is_dir():
        raise FileNotFoundError(
            f"Cleaned dataset not found at {CLEAN}\n"
            f"Run dedupe_dataset.py first."
        )
    LOG_DIR.mkdir(parents=True, exist_ok=True)
    CKPT_DIR.mkdir(parents=True, exist_ok=True)

    subset_file = SUBSETS_DIR / f"subset_{args.share}.txt"
    if not subset_file.is_file():
        raise FileNotFoundError(
            f"{subset_file} not found. Run subsets/make_subsets.py first."
        )

    train_set = SubsetSolarPanelDataset(
        TRAIN_IMG, TRAIN_MSK,
        file_list=subset_file,
        image_size=args.image_size,
        augment=True,
    )
    val_set = SubsetSolarPanelDataset(
        VAL_IMG, VAL_MSK,
        file_list=None,
        image_size=args.image_size,
        augment=False,
    )
    print(f"[data] train={len(train_set)} val={len(val_set)}")

    # Pinned host memory only speeds up host->GPU copies. On a CPU-only
    # machine it has no benefit (recent PyTorch versions warn about it).
    pin = device.type == "cuda"
    train_loader = DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=pin,
    )
    val_loader = DataLoader(
        val_set, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=pin,
    )

    builder = MODEL_REGISTRY[args.model]
    model, criterion, output_is_prob = builder()
    model = model.to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"[model] {args.model} params={n_params:,} "
          f"output={'prob' if output_is_prob else 'logits'}")

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # The scheduler is stepped with validation Dice (higher is better).
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", patience=5, factor=0.5
    )

    history = {
        "model": args.model,
        "share": args.share,
        "n_train": len(train_set),
        "n_val": len(val_set),
        "n_params": n_params,
        "epochs": [],
    }
    best_dice = -1.0
    best_epoch = -1
    best_path = CKPT_DIR / f"{args.model}_{args.share}_best.pth"
    final_path = CKPT_DIR / f"{args.model}_{args.share}_final.pth"
    log_path = LOG_DIR / f"{args.model}_{args.share}.json"

    t0 = time.time()
    history["start_time_iso"] = time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime(t0))
    # Fallback so the final checkpoint still carries metrics when --epochs 0.
    val_m = {"dice": 0.0, "iou": 0.0, "miou": 0.0, "pixel_acc": 0.0}

    for epoch in range(args.epochs):
        print(f"\nEpoch {epoch + 1}/{args.epochs}")
        # Learning rate actually used for this epoch's updates. Read it before
        # scheduler.step(), which may lower it for the *next* epoch.
        epoch_lr = optimizer.param_groups[0]["lr"]
        epoch_t0 = time.time()

        train_loss, train_m = run_epoch(model, train_loader, criterion, optimizer, device,
                                        train=True, output_is_prob=output_is_prob)
        train_seconds = time.time() - epoch_t0

        val_t0 = time.time()
        val_loss, val_m = run_epoch(model, val_loader, criterion, optimizer, device,
                                    train=False, output_is_prob=output_is_prob)
        val_seconds = time.time() - val_t0

        scheduler.step(val_m["dice"])

        epoch_seconds = time.time() - epoch_t0
        elapsed = time.time() - t0
        eta = elapsed / (epoch + 1) * (args.epochs - epoch - 1)

        history["epochs"].append({
            "epoch": epoch + 1,
            "lr": epoch_lr,
            "train_loss": train_loss,
            "val_loss": val_loss,
            **{f"train_{k}": v for k, v in train_m.items()},
            **{f"val_{k}": v for k, v in val_m.items()},
            "epoch_seconds": epoch_seconds,
            "train_seconds": train_seconds,
            "val_seconds": val_seconds,
        })

        print(_format_split("train", train_loss, train_m))
        print(_format_split("val", val_loss, val_m))
        print(
            f" time epoch={_fmt(epoch_seconds)} "
            f"(train={_fmt(train_seconds)} val={_fmt(val_seconds)}) "
            f"elapsed={_fmt(elapsed)} ETA={_fmt(eta)}"
        )

        # Persist progress every epoch so an interrupted run still leaves a log.
        _write_json(log_path, history)

        if val_m["dice"] > best_dice:
            best_dice = val_m["dice"]
            best_epoch = epoch + 1
            _save_checkpoint(best_path, model, epoch + 1, val_m,
                             args.model, args.share, output_is_prob)
            print(f" ↳ new best (dice={best_dice:.4f}) → {best_path.name}")

    _save_checkpoint(final_path, model, args.epochs, val_m,
                     args.model, args.share, output_is_prob)

    total_seconds = time.time() - t0
    history["best_epoch"] = best_epoch
    history["best_val_dice"] = best_dice
    history["wall_clock_seconds"] = total_seconds
    history["end_time_iso"] = time.strftime("%Y-%m-%dT%H:%M:%S")
    _write_json(log_path, history)

    print(f"\n[done] best epoch {best_epoch} (dice={best_dice:.4f})")
    print(f" wall {_fmt(total_seconds)} ({total_seconds:.1f} s)")
    print(f" best → {best_path}")
    print(f" final → {final_path}")
    print(f" log → {log_path}")
# Script entry point: run one training job when executed directly, but not
# when this module is imported.
if __name__ == "__main__":
    main()