Mohamed-ENNHIRI
Add Tab 7: resolution study (segformer_b0 + U-Net at 192/256/512)
a3200e4
Raw
History Blame Contribute Delete
10.3 kB
"""
Resolution-study trainer.
Trains one (model, image_size) cell at 100% data on final_data_clean/.
Usage:
python train.py --model segformer_b0 --image-size 192
python train.py --model segformer_b0 --image-size 256
Each run produces:
checkpoints/{model}_res{image_size}_best.pth (state at highest val Dice)
logs/{model}_res{image_size}.json (per-epoch metrics + timing)
one row appended to results/resolution_results.csv
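Hyperparameters held identical to the clean baseline at 128:
    Adam, lr=1e-4, ReduceLROnPlateau(mode='max', patience=5, factor=0.5),
    50 epochs, CombinedLoss(0.5*BCE + 0.5*Dice), HFlip+VFlip+Rot15.
Only image_size and (optionally) batch_size vary.
--throttle only slows training down (it sleeps between training steps); it is
saved in the JSON log but not in the CSV, so compare wall_clock_seconds only
between runs that used the same throttle setting.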
"""
import argparse
import csv
import json
import time
from pathlib import Path
from threading import Lock
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from dataset import SolarPanelDataset
from metrics import SegMetrics
from models import MODEL_REGISTRY
# Filesystem layout: this script's directory sits two levels below the
# repository root, and every artifact the script writes lives beside it.
THIS_DIR = Path(__file__).resolve().parent
REPO_ROOT = THIS_DIR.parents[1]

# Cleaned dataset splits, each an images/ + masks/ folder pair.
CLEAN = REPO_ROOT.joinpath("final_data_clean")
TRAIN_IMG = CLEAN.joinpath("train", "images")
TRAIN_MSK = CLEAN.joinpath("train", "masks")
VAL_IMG = CLEAN.joinpath("val", "images")
VAL_MSK = CLEAN.joinpath("val", "masks")

# Output locations.
LOG_DIR = THIS_DIR.joinpath("logs")
CKPT_DIR = THIS_DIR.joinpath("checkpoints")
RESULTS_DIR = THIS_DIR.joinpath("results")
RESULTS_CSV = RESULTS_DIR.joinpath("resolution_results.csv")
# Reference val-Dice values at 128 (from clean_data_scaling_study at 100% data).
BASELINE_128 = {
"segnet": 0.9291,
"unet": 0.9370,
"segformer_b0": 0.9280,
"segformer_b5": 0.9371,
}
CSV_FIELDS = [
"cfg_id", "model", "image_size", "batch_size",
"best_epoch", "best_val_dice", "best_val_miou", "best_val_iou", "best_val_pixel_acc",
"baseline_dice_at_128", "delta_vs_128",
"wall_clock_seconds",
]
_csv_lock = Lock()
def _fmt(seconds: float) -> str:
seconds = int(round(seconds))
h, rem = divmod(seconds, 3600)
m, s = divmod(rem, 60)
return f"{h:d}:{m:02d}:{s:02d}" if h else f"{m:d}:{s:02d}"
def append_csv_row(row: dict):
    """Append one result row to RESULTS_CSV, writing the header when needed.

    Only the keys listed in CSV_FIELDS are written, in that order; keys missing
    from ``row`` become empty cells and extra keys are ignored.

    The header is written whenever the file is empty once opened. That covers
    both a brand-new file and one left empty by an interrupted run; the
    previous "file does not exist" check missed the second case.

    Note: _csv_lock only serialises threads of this process. Separate train.py
    processes appending to the same CSV at the same moment are not coordinated.
    """
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    with _csv_lock, open(RESULTS_CSV, "a", newline="", encoding="utf-8") as f:
        # Append mode creates the file if it is missing, so a size of 0 means
        # nothing (not even a header) has been written to it yet.
        write_header = RESULTS_CSV.stat().st_size == 0
        writer = csv.DictWriter(f, fieldnames=CSV_FIELDS)
        if write_header:
            writer.writeheader()
        writer.writerow({k: row.get(k) for k in CSV_FIELDS})
def run_epoch(model, loader, criterion, optimizer, device, train: bool,
              output_is_prob: bool, throttle: float | None = None):
    """Make one pass over ``loader`` and return ``(mean_batch_loss, metrics)``.

    With ``train=True`` the model is switched to training mode and updated by
    ``optimizer`` after every batch. Otherwise it is switched to eval mode and
    run under ``torch.no_grad()``; ``optimizer`` is not touched.

    ``throttle`` (used only when training) caps the average duty cycle. After
    each batch the loop sleeps ``(1 - throttle) / throttle`` times the time the
    batch took, so 0.4 means roughly 40% busy. ``None`` or a value outside
    (0, 1) disables throttling.

    The returned loss is the plain mean of the per-batch loss values (0.0 for
    an empty loader). The metrics are whatever ``SegMetrics.compute()`` returns.
    """
    model.train(mode=train)
    tracker = SegMetrics()
    loss_sum = 0.0
    steps = 0

    throttling = throttle is not None and train and 0 < throttle < 1.0
    # Sleeping (1 - t) / t seconds per second of work gives a busy fraction of t.
    pause_ratio = (1.0 - throttle) / throttle if throttling else 0.0

    grad_mode = torch.enable_grad() if train else torch.no_grad()
    with grad_mode:
        for images, masks in tqdm(loader, desc="Train" if train else "Val", leave=False):
            if throttling:
                started = time.perf_counter()

            images = images.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)

            if train:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            if train:
                loss.backward()
                optimizer.step()

            loss_sum += loss.item()
            steps += 1
            tracker.update(outputs.detach(), masks, output_is_prob=output_is_prob)

            if not throttling:
                continue
            # Wait for queued CUDA work so the measured time includes GPU compute.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            time.sleep((time.perf_counter() - started) * pause_ratio)

    return loss_sum / max(steps, 1), tracker.compute()
def parse_args():
p = argparse.ArgumentParser()
p.add_argument("--model", required=True, choices=list(MODEL_REGISTRY.keys()))
p.add_argument("--image-size", required=True, type=int,
help="square input resolution (e.g. 192, 256)")
p.add_argument("--batch-size", type=int, default=16)
p.add_argument("--epochs", type=int, default=50)
p.add_argument("--lr", type=float, default=1e-4)
p.add_argument("--num-workers", type=int, default=4)
p.add_argument("--seed", type=int, default=42)
p.add_argument("--throttle", type=float, default=None,
help="optional duty-cycle cap, e.g. 0.4 for 40%% average util")
return p.parse_args()
def main():
    """Train one (model, image_size) cell and record its results.

    Reads the CLI via parse_args(), trains for ``--epochs`` epochs and writes:
      * CKPT_DIR/{cfg_id}_best.pth - state from the epoch with the highest val Dice
      * LOG_DIR/{cfg_id}.json      - run metadata plus per-epoch metrics; rewritten
        after every epoch so an interrupted run still leaves a usable log
      * one row appended to RESULTS_CSV via append_csv_row()
    """
    args = parse_args()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Pinned host memory only speeds up host-to-GPU copies; skip it on CPU.
    pin_memory = device.type == "cuda"
    cfg_id = f"{args.model}_res{args.image_size}"
    print(f"[run] {cfg_id} image_size={args.image_size} batch_size={args.batch_size} device={device}")
    LOG_DIR.mkdir(parents=True, exist_ok=True)
    CKPT_DIR.mkdir(parents=True, exist_ok=True)

    train_set = SolarPanelDataset(TRAIN_IMG, TRAIN_MSK, image_size=args.image_size, augment=True)
    val_set = SolarPanelDataset(VAL_IMG, VAL_MSK, image_size=args.image_size, augment=False)
    print(f"[data] train={len(train_set)} val={len(val_set)}")
    train_loader = DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_set, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=pin_memory,
    )

    # A registry entry returns (model, loss, output_is_prob). output_is_prob
    # tells run_epoch/metrics whether the model emits probabilities or logits.
    builder = MODEL_REGISTRY[args.model]
    model, criterion, output_is_prob = builder()
    model = model.to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"[model] {args.model} params={n_params:,} "
          f"output={'prob' if output_is_prob else 'logits'}")

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # mode="max" because the monitored quantity is val Dice (higher is better).
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", patience=5, factor=0.5)

    baseline = BASELINE_128.get(args.model)  # None for models without a 128 px reference
    history = {
        "cfg_id": cfg_id,
        "model": args.model,
        "image_size": args.image_size,
        "batch_size": args.batch_size,
        "throttle": args.throttle,
        "n_train": len(train_set),
        "n_val": len(val_set),
        "n_params": n_params,
        "baseline_dice_at_128": baseline,
        "epochs": [],
    }

    best_dice = -1.0  # sentinel below any valid Dice score
    best_epoch = -1   # stays -1 if no epoch ever beats the sentinel
    # Val metrics of the best epoch, kept in memory for the CSV row. Re-reading
    # them from the checkpoint could pick up a stale file left by an earlier run
    # whenever this run never saved one.
    best_val_m = {}
    best_path = CKPT_DIR / f"{cfg_id}_best.pth"
    log_path = LOG_DIR / f"{cfg_id}.json"
    if args.throttle is not None and 0 < args.throttle < 1:
        print(f"[throttle] training duty cycle capped at {args.throttle*100:.0f}%")

    t0 = time.time()
    history["start_time_iso"] = time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime(t0))
    for epoch in range(args.epochs):
        epoch_t0 = time.time()
        train_loss, train_m = run_epoch(
            model, train_loader, criterion, optimizer, device,
            train=True, output_is_prob=output_is_prob, throttle=args.throttle,
        )
        # Validation is never throttled.
        val_loss, val_m = run_epoch(
            model, val_loader, criterion, optimizer, device,
            train=False, output_is_prob=output_is_prob, throttle=None,
        )
        scheduler.step(val_m["dice"])
        epoch_seconds = time.time() - epoch_t0
        elapsed = time.time() - t0
        eta = (elapsed / (epoch + 1)) * (args.epochs - epoch - 1)

        improved = val_m["dice"] > best_dice
        if improved:
            best_dice = val_m["dice"]
            best_epoch = epoch + 1
            best_val_m = dict(val_m)
            torch.save({
                "epoch": epoch + 1,
                "model_state_dict": model.state_dict(),
                "val_metrics": val_m,
                "cfg": {
                    "cfg_id": cfg_id,
                    "model": args.model,
                    "image_size": args.image_size,
                },
                "output_is_prob": output_is_prob,
            }, best_path)

        history["epochs"].append({
            "epoch": epoch + 1,
            # Read after scheduler.step(): this is the LR the *next* epoch starts
            # with, which differs from this epoch's LR right after a reduction.
            "lr": optimizer.param_groups[0]["lr"],
            "train_loss": train_loss,
            "val_loss": val_loss,
            **{f"train_{k}": v for k, v in train_m.items()},
            **{f"val_{k}": v for k, v in val_m.items()},
            "epoch_seconds": epoch_seconds,
        })
        marker = "★" if improved else " "
        print(
            f" ep {epoch+1:>2}/{args.epochs} "
            f"trL={train_loss:.4f} vL={val_loss:.4f} "
            f"vDice={val_m['dice']:.4f} {marker} vIoU={val_m['iou']:.4f} "
            f"({_fmt(elapsed)}/ETA {_fmt(eta)})"
        )
        # Rewritten every epoch so an interrupted run still leaves a log.
        with open(log_path, "w") as f:
            json.dump(history, f, indent=2)

    total = time.time() - t0
    # When no epoch beat the sentinel (e.g. --epochs 0, or val Dice was NaN
    # throughout), report "no result" instead of leaking -1.0 into the outputs.
    best_val_dice = best_dice if best_epoch > 0 else None
    history["best_epoch"] = best_epoch
    history["best_val_dice"] = best_val_dice
    history["epochs_trained"] = args.epochs
    history["wall_clock_seconds"] = total
    with open(log_path, "w") as f:
        json.dump(history, f, indent=2)

    if baseline is not None and best_val_dice is not None:
        delta = best_val_dice - baseline
    else:
        delta = None
    append_csv_row({
        "cfg_id": cfg_id,
        "model": args.model,
        "image_size": args.image_size,
        "batch_size": args.batch_size,
        "best_epoch": best_epoch,
        "best_val_dice": best_val_dice,
        "best_val_miou": best_val_m.get("miou"),
        "best_val_iou": best_val_m.get("iou"),
        "best_val_pixel_acc": best_val_m.get("pixel_acc"),
        "baseline_dice_at_128": baseline,
        "delta_vs_128": delta,
        "wall_clock_seconds": total,
    })

    if best_val_dice is not None:
        print(f"\n[done] {cfg_id} best ep {best_epoch} vDice={best_val_dice:.4f}")
    else:
        print(f"\n[done] {cfg_id} no epoch produced an improving val Dice; no checkpoint saved")
    if delta is not None:
        # The ':+' format spec already emits the sign; a separate "+" prefix
        # used to print "++0.0123" for positive deltas.
        print(f" baseline at 128 = {baseline:.4f} delta = {delta:+.4f}")
    print(f" wall {_fmt(total)}")


if __name__ == "__main__":
    main()