Spaces:
Running
Running
| """ | |
| Resolution-study trainer. | |
| Trains one (model, image_size) cell at 100% data on final_data_clean/. | |
| Usage: | |
| python train.py --model segformer_b0 --image-size 192 | |
| python train.py --model segformer_b0 --image-size 256 | |
| Each run produces: | |
| checkpoints/{model}_res{image_size}_best.pth (state at highest val Dice) | |
| logs/{model}_res{image_size}.json (per-epoch metrics + timing) | |
| one row appended to results/resolution_results.csv | |
| Hyperparameters held identical to the clean baseline at 128: | |
| Adam, lr=1e-4, ReduceLROnPlateau(mode='max', patience=5, factor=0.5), | |
| 50 epochs, CombinedLoss(0.5*BCE + 0.5*Dice), HFlip+VFlip+Rot15. | |
| Only image_size and (optionally) batch_size vary. | |
| """ | |
| import argparse | |
| import csv | |
| import json | |
| import time | |
| from pathlib import Path | |
| from threading import Lock | |
| import torch | |
| import torch.optim as optim | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| from dataset import SolarPanelDataset | |
| from metrics import SegMetrics | |
| from models import MODEL_REGISTRY | |
# --- Filesystem layout ------------------------------------------------------
# Everything is anchored on this script's own directory; the cleaned dataset
# lives under the directory two levels above it.
THIS_DIR = Path(__file__).resolve().parent
REPO_ROOT = THIS_DIR.parents[1]

CLEAN = REPO_ROOT.joinpath("final_data_clean")
TRAIN_IMG = CLEAN.joinpath("train", "images")
TRAIN_MSK = CLEAN.joinpath("train", "masks")
VAL_IMG = CLEAN.joinpath("val", "images")
VAL_MSK = CLEAN.joinpath("val", "masks")

# Per-run outputs are written next to this script.
LOG_DIR = THIS_DIR.joinpath("logs")
CKPT_DIR = THIS_DIR.joinpath("checkpoints")
RESULTS_DIR = THIS_DIR.joinpath("results")
RESULTS_CSV = RESULTS_DIR.joinpath("resolution_results.csv")

# Reference val-Dice values at 128 (from clean_data_scaling_study at 100% data).
# Models without an entry simply get no baseline/delta in the results.
BASELINE_128 = dict(
    segnet=0.9291,
    unet=0.9370,
    segformer_b0=0.9280,
    segformer_b5=0.9371,
)

# Column order of results/resolution_results.csv, grouped by meaning:
# run identity, best-epoch validation metrics, baseline comparison, timing.
CSV_FIELDS = (
    ["cfg_id", "model", "image_size", "batch_size"]
    + ["best_epoch", "best_val_dice", "best_val_miou", "best_val_iou", "best_val_pixel_acc"]
    + ["baseline_dice_at_128", "delta_vs_128"]
    + ["wall_clock_seconds"]
)

# Guards appends to RESULTS_CSV against other threads of this process.
_csv_lock = Lock()
def _fmt(seconds: float) -> str:
    """Render a duration as ``M:SS``, or ``H:MM:SS`` once it reaches an hour.

    The input is rounded to the nearest whole second before formatting.
    """
    whole = int(round(seconds))
    total_minutes, secs = divmod(whole, 60)
    hours, mins = divmod(total_minutes, 60)
    if not hours:
        # Under an hour: drop the hour field and leave minutes unpadded.
        return f"{mins:d}:{secs:02d}"
    return f"{hours:d}:{mins:02d}:{secs:02d}"
def append_csv_row(row: dict):
    """Append one result row to RESULTS_CSV, emitting the header when needed.

    Only the columns named in CSV_FIELDS are written, in that order. Keys of
    ``row`` that are not in CSV_FIELDS are ignored, and fields missing from
    ``row`` are passed to the writer as ``None``.

    The header is written when the file does not exist yet *or* exists but is
    empty (e.g. it was pre-created), so the CSV never ends up headerless.

    Args:
        row: mapping from column name to value for this run.
    """
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    # _csv_lock is a threading.Lock: it serialises writers within this process.
    with _csv_lock:
        needs_header = (
            not RESULTS_CSV.is_file() or RESULTS_CSV.stat().st_size == 0
        )
        # newline="" is required by the csv module so it controls line endings.
        with open(RESULTS_CSV, "a", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=CSV_FIELDS)
            if needs_header:
                writer.writeheader()
            writer.writerow({k: row.get(k) for k in CSV_FIELDS})
def run_epoch(model, loader, criterion, optimizer, device, train: bool,
              output_is_prob: bool, throttle: float | None = None):
    """Run one pass over ``loader`` in training or evaluation mode.

    Args:
        model: network to run; switched via ``model.train(mode=train)``.
        loader: iterable yielding ``(images, masks)`` batches.
        criterion: callable ``criterion(outputs, masks)`` returning the loss.
        optimizer: stepped once per batch when ``train`` is true, unused otherwise.
        device: target device for each batch.
        train: true to back-propagate and update weights; false to run under
            ``torch.no_grad()``.
        output_is_prob: forwarded unchanged to ``SegMetrics.update``.
        throttle: optional duty cycle in (0, 1). Applies to training passes
            only: after each step the loop sleeps long enough that the compute
            time is roughly ``throttle`` of the total. Any other value
            disables throttling.

    Returns:
        ``(mean_loss, metrics)``: the loss averaged over batches (0.0 when the
        loader is empty) and the result of ``SegMetrics.compute()``.
    """
    model.train(mode=train)
    tracker = SegMetrics()
    loss_sum = 0.0
    steps = 0

    # Sleep time per step = step time * (1 - throttle) / throttle.
    throttled = train and throttle is not None and 0 < throttle < 1.0
    if throttled:
        idle_ratio = (1.0 - throttle) / throttle
    else:
        idle_ratio = 0.0

    if train:
        label = "Train"
        grad_mode = torch.enable_grad()
    else:
        label = "Val"
        grad_mode = torch.no_grad()

    with grad_mode:
        for images, masks in tqdm(loader, desc=label, leave=False):
            started = time.perf_counter() if throttled else None

            images = images.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)

            if train:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            if train:
                loss.backward()
                optimizer.step()

            loss_sum += loss.item()
            steps += 1
            tracker.update(outputs.detach(), masks, output_is_prob=output_is_prob)

            if throttled:
                # Wait for queued GPU work so the measured step time is real.
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                busy = time.perf_counter() - started
                time.sleep(busy * idle_ratio)

    mean_loss = loss_sum / max(steps, 1)
    return mean_loss, tracker.compute()
def parse_args():
    """Parse the command-line options for a single training run.

    ``--model`` (one of the MODEL_REGISTRY keys) and ``--image-size`` are
    required; the remaining options have defaults.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()

    # Required: which cell of the study to train.
    parser.add_argument("--model", required=True, choices=list(MODEL_REGISTRY.keys()))
    parser.add_argument("--image-size", required=True, type=int,
                        help="square input resolution (e.g. 192, 256)")

    # Plain typed options with defaults, registered in their original order.
    typed_defaults = (
        ("--batch-size", int, 16),
        ("--epochs", int, 50),
        ("--lr", float, 1e-4),
        ("--num-workers", int, 4),
        ("--seed", int, 42),
    )
    for flag, kind, default in typed_defaults:
        parser.add_argument(flag, type=kind, default=default)

    parser.add_argument("--throttle", type=float, default=None,
                        help="optional duty-cycle cap, e.g. 0.4 for 40%% average util")

    return parser.parse_args()
def main():
    """Train one (model, image_size) cell and record its results.

    Steps:
      1. Parse CLI options, seed torch, and select CUDA when available.
      2. Build the train/val datasets and loaders at the requested resolution.
      3. Build the model from MODEL_REGISTRY and set up Adam with
         ReduceLROnPlateau stepped on validation Dice.
      4. Per epoch: train, validate, save a checkpoint whenever validation
         Dice improves, and rewrite the JSON log so an interrupted run still
         leaves a usable history.
      5. Append one summary row to the results CSV and print a short report.
    """
    args = parse_args()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    cfg_id = f"{args.model}_res{args.image_size}"
    print(f"[run] {cfg_id} image_size={args.image_size} batch_size={args.batch_size} device={device}")

    LOG_DIR.mkdir(parents=True, exist_ok=True)
    CKPT_DIR.mkdir(parents=True, exist_ok=True)

    # Augmentation is enabled for the training split only.
    train_set = SolarPanelDataset(TRAIN_IMG, TRAIN_MSK, image_size=args.image_size, augment=True)
    val_set = SolarPanelDataset(VAL_IMG, VAL_MSK, image_size=args.image_size, augment=False)
    print(f"[data] train={len(train_set)} val={len(val_set)}")

    train_loader = DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True,
    )
    val_loader = DataLoader(
        val_set, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True,
    )

    # Each registry entry builds (model, loss, whether outputs are probabilities).
    builder = MODEL_REGISTRY[args.model]
    model, criterion, output_is_prob = builder()
    model = model.to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"[model] {args.model} params={n_params:,} "
          f"output={'prob' if output_is_prob else 'logits'}")

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # mode="max": the scheduler is stepped with validation Dice (higher is better).
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", patience=5, factor=0.5)

    # None when the model has no recorded 128-px reference value.
    baseline = BASELINE_128.get(args.model)

    history = {
        "cfg_id": cfg_id,
        "model": args.model,
        "image_size": args.image_size,
        "batch_size": args.batch_size,
        "throttle": args.throttle,
        "n_train": len(train_set),
        "n_val": len(val_set),
        "n_params": n_params,
        "baseline_dice_at_128": baseline,
        "epochs": [],
    }

    best_dice = -1.0
    best_epoch = -1
    # Validation metrics of the best epoch, kept in memory so the summary never
    # has to unpickle the checkpoint (and can never pick up a stale file left
    # behind by an earlier run).
    best_val_m = {}
    best_path = CKPT_DIR / f"{cfg_id}_best.pth"
    log_path = LOG_DIR / f"{cfg_id}.json"

    if args.throttle is not None and 0 < args.throttle < 1:
        print(f"[throttle] training duty cycle capped at {args.throttle*100:.0f}%")

    t0 = time.time()
    history["start_time_iso"] = time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime(t0))

    for epoch in range(args.epochs):
        epoch_t0 = time.time()

        train_loss, train_m = run_epoch(
            model, train_loader, criterion, optimizer, device,
            train=True, output_is_prob=output_is_prob, throttle=args.throttle,
        )
        # Validation is never throttled.
        val_loss, val_m = run_epoch(
            model, val_loader, criterion, optimizer, device,
            train=False, output_is_prob=output_is_prob, throttle=None,
        )
        scheduler.step(val_m["dice"])

        epoch_seconds = time.time() - epoch_t0
        elapsed = time.time() - t0
        # ETA = mean seconds per finished epoch * epochs still to run.
        eta = (elapsed / (epoch + 1)) * (args.epochs - epoch - 1)

        improved = val_m["dice"] > best_dice
        if improved:
            best_dice = val_m["dice"]
            best_epoch = epoch + 1
            best_val_m = dict(val_m)
            torch.save({
                "epoch": epoch + 1,
                "model_state_dict": model.state_dict(),
                "val_metrics": val_m,
                "cfg": {
                    "cfg_id": cfg_id,
                    "model": args.model,
                    "image_size": args.image_size,
                },
                "output_is_prob": output_is_prob,
            }, best_path)

        history["epochs"].append({
            "epoch": epoch + 1,
            # LR after the scheduler step, i.e. the rate the next epoch will use.
            "lr": optimizer.param_groups[0]["lr"],
            "train_loss": train_loss,
            "val_loss": val_loss,
            **{f"train_{k}": v for k, v in train_m.items()},
            **{f"val_{k}": v for k, v in val_m.items()},
            "epoch_seconds": epoch_seconds,
        })

        marker = "★" if improved else " "
        print(
            f" ep {epoch+1:>2}/{args.epochs} "
            f"trL={train_loss:.4f} vL={val_loss:.4f} "
            f"vDice={val_m['dice']:.4f} {marker} vIoU={val_m['iou']:.4f} "
            f"({_fmt(elapsed)}/ETA {_fmt(eta)})"
        )

        # Rewrite the full log every epoch so a crash keeps what was finished.
        with open(log_path, "w") as f:
            json.dump(history, f, indent=2)

    total = time.time() - t0
    history["best_epoch"] = best_epoch
    history["best_val_dice"] = best_dice
    history["epochs_trained"] = args.epochs
    history["wall_clock_seconds"] = total
    with open(log_path, "w") as f:
        json.dump(history, f, indent=2)

    # These stay None when no epoch ever improved on the initial best_dice.
    best_val_iou = best_val_m.get("iou")
    best_val_miou = best_val_m.get("miou")
    best_val_pa = best_val_m.get("pixel_acc")

    delta = (best_dice - baseline) if baseline is not None else None

    append_csv_row({
        "cfg_id": cfg_id,
        "model": args.model,
        "image_size": args.image_size,
        "batch_size": args.batch_size,
        "best_epoch": best_epoch,
        "best_val_dice": best_dice,
        "best_val_miou": best_val_miou,
        "best_val_iou": best_val_iou,
        "best_val_pixel_acc": best_val_pa,
        "baseline_dice_at_128": baseline,
        "delta_vs_128": delta,
        "wall_clock_seconds": total,
    })

    print(f"\n[done] {cfg_id} best ep {best_epoch} vDice={best_dice:.4f}")
    if baseline is not None:
        # The "+" format flag already prints the sign; adding one by hand
        # would produce "++0.0012" for positive deltas.
        print(f" baseline at 128 = {baseline:.4f} delta = {delta:+.4f}")
    print(f" wall {_fmt(total)}")


if __name__ == "__main__":
    main()