| """ |
| Main training script for immunogold CenterNet. |
| |
| Usage: |
| python train.py --fold S1 --seed 42 --config config/config.yaml |
| python train.py --fold S1 --seed 42 --config config/config.yaml --dry-run |
| python train.py --fold S1 --seed 42 --config config/config.yaml --device cuda:0 |
| """ |
|
|
| import argparse |
| import os |
| import random |
| import sys |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import yaml |
| from torch.utils.data import DataLoader |
| from torch.utils.tensorboard import SummaryWriter |
|
|
| from src.dataset import ImmunogoldDataset |
| from src.evaluate import match_detections_to_gt |
| from src.heatmap import extract_peaks |
| from src.loss import total_loss |
| from src.model import ImmunogoldCenterNet |
| from src.preprocessing import discover_synapse_data, load_synapse |
| from src.ensemble import sliding_window_inference |
| from src.postprocess import cross_class_nms |


def set_seed(seed: int):
    """Set all random seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
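    # Force deterministic cuDNN kernels and disable autotuning: repeatable
    # runs at some throughput cost.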
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def parse_args():
    parser = argparse.ArgumentParser(description="Train immunogold CenterNet")
    parser.add_argument("--fold", type=str, required=True,
                        help="Synapse ID to hold out (e.g., S1)")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--config", type=str, default="config/config.yaml")
    parser.add_argument("--device", type=str, default="auto",
                        help="Device: auto, cpu, cuda, cuda:0, etc.")
    parser.add_argument("--dry-run", action="store_true",
                        help="Load data, build model, run 1 batch, exit")
    parser.add_argument("--resume", type=str, default=None,
                        help="Path to checkpoint to resume from")
    return parser.parse_args()


def get_device(device_str: str) -> torch.device:
    if device_str == "auto":
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return torch.device(device_str)


def validate_epoch(
    model, val_data, device, cfg, conf_threshold=0.3,
):
    """
    Run validation: sliding-window inference on the held-out image.

    Returns dict with val_f1_6nm, val_f1_12nm, and val_f1_mean, plus the raw
    detections and per-class match results.
    """
    model.eval()
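    # Synapses listed under data.incomplete_6nm lack complete 6nm annotations,
    # so the 6nm F1 is reported as NaN for them (see the return dict below).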
    has_6nm = val_data["synapse_id"] not in cfg["data"].get("incomplete_6nm", [])
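
    # Tile the held-out image into patches and stitch the predictions back
    # into full-resolution heatmap and offset maps.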
    with torch.no_grad():
        heatmap_np, offset_np = sliding_window_inference(
            model, val_data["image"],
            patch_size=cfg["data"]["patch_size"],
            device=device,
        )
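
    # extract_peaks operates on tensors, so convert the stitched numpy maps.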
    heatmap_t = torch.from_numpy(heatmap_np)
    offset_t = torch.from_numpy(offset_np)

    detections = extract_peaks(
        heatmap_t, offset_t,
        stride=cfg["data"]["stride"],
        conf_threshold=conf_threshold,
        nms_kernel_sizes=cfg["postprocessing"]["nms_kernel_size"],
    )
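    # Suppress detections of different classes that fall within the configured
    # pixel radius of one another (likely duplicates of a single particle).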
    detections = cross_class_nms(
        detections,
        cfg["postprocessing"]["cross_class_nms_distance_px"],
    )

    gt = val_data["annotations"]
    results = match_detections_to_gt(
        detections,
        gt.get("6nm", np.empty((0, 2))),
        gt.get("12nm", np.empty((0, 2))),
        match_radii={k: float(v) for k, v in cfg["evaluation"]["match_radii_px"].items()},
    )

    return {
        "val_f1_6nm": results["6nm"]["f1"] if has_6nm else float("nan"),
        "val_f1_12nm": results["12nm"]["f1"],
        "val_f1_mean": results["mean_f1"],
        "detections": detections,
        "results": results,
    }


def train_phase(
    model, train_loader, optimizer, scheduler, device, cfg,
    phase_num, n_epochs, writer, global_epoch, val_data,
    best_f1, checkpoint_dir, snapshot_epochs,
):
    """Train one phase, return updated global_epoch and best_f1."""
    focal_alpha = cfg["training"]["loss"]["focal_alpha"]
    focal_beta = cfg["training"]["loss"]["focal_beta"]
    lambda_offset = cfg["training"]["loss"]["lambda_offset"]
    patience = cfg["training"]["early_stopping"]["patience"]
    val_interval = 5  # epochs between validation passes
    no_improve = 0

    for epoch in range(n_epochs):
        global_epoch += 1
        epoch_loss = 0.0
        epoch_hm_loss = 0.0
        epoch_off_loss = 0.0
        n_batches = 0
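
        # validate_epoch() switches the model to eval mode, so re-enter train
        # mode at the start of every epoch.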
        model.train()
        for batch in train_loader:
            images = batch["image"].to(device)
            hm_gt = batch["heatmap"].to(device)
            off_gt = batch["offsets"].to(device)
            off_mask = batch["offset_mask"].to(device)
            conf_map = batch["conf_map"].to(device)

            optimizer.zero_grad()
            hm_pred, off_pred = model(images)

            loss, hm_loss, off_loss = total_loss(
                hm_pred, hm_gt, off_pred, off_gt, off_mask,
                lambda_offset=lambda_offset,
                focal_alpha=focal_alpha,
                focal_beta=focal_beta,
                conf_weights=conf_map,
            )

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()
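
            # .item()/float() pull out Python scalars, so no autograd graph is
            # retained while accumulating; float() also handles components
            # that arrive as plain floats.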
            epoch_loss += loss.item()
            epoch_hm_loss += float(hm_loss)
            epoch_off_loss += float(off_loss)
            n_batches += 1

        if scheduler is not None:
            scheduler.step()

        avg_loss = epoch_loss / max(n_batches, 1)
        avg_hm = epoch_hm_loss / max(n_batches, 1)
        avg_off = epoch_off_loss / max(n_batches, 1)

        writer.add_scalar(f"Phase{phase_num}/train_loss", avg_loss, global_epoch)
        writer.add_scalar(f"Phase{phase_num}/hm_loss", avg_hm, global_epoch)
        writer.add_scalar(f"Phase{phase_num}/off_loss", avg_off, global_epoch)
        val_metrics = None
        if global_epoch % val_interval == 0 or epoch == n_epochs - 1:
            val_metrics = validate_epoch(model, val_data, device, cfg)
            writer.add_scalar(f"Phase{phase_num}/val_f1_mean", val_metrics["val_f1_mean"], global_epoch)
            if not np.isnan(val_metrics["val_f1_6nm"]):
                writer.add_scalar(f"Phase{phase_num}/val_f1_6nm", val_metrics["val_f1_6nm"], global_epoch)
            writer.add_scalar(f"Phase{phase_num}/val_f1_12nm", val_metrics["val_f1_12nm"], global_epoch)

            if val_metrics["val_f1_mean"] > best_f1:
                best_f1 = val_metrics["val_f1_mean"]
                no_improve = 0
                torch.save({
                    "epoch": global_epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "val_f1_mean": best_f1,
                    "phase": phase_num,
                }, checkpoint_dir / f"phase{phase_num}_best.pth")
            else:
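                # Validation runs every val_interval epochs, so count the
                # whole interval as epochs without improvement.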
                no_improve += val_interval
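
        # Save the fixed-epoch snapshot checkpoints listed in the config
        # (training.n_snapshot_epochs).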
        if global_epoch in snapshot_epochs:
            torch.save({
                "epoch": global_epoch,
                "model_state_dict": model.state_dict(),
                "val_f1_mean": best_f1,
                "phase": phase_num,
            }, checkpoint_dir / f"phase{phase_num}_{global_epoch}.pth")

        f1_str = f", val_f1={val_metrics['val_f1_mean']:.4f}" if val_metrics else ""
        print(
            f" Phase {phase_num} | Epoch {global_epoch:3d} | "
            f"Loss {avg_loss:.4f} (hm={avg_hm:.4f}, off={avg_off:.4f})"
            f"{f1_str}"
        )

        if no_improve >= patience:
            print(f" Early stopping at epoch {global_epoch} (patience={patience})")
            break

    return global_epoch, best_f1


def main():
    args = parse_args()
    with open(args.config) as f:
        cfg = yaml.safe_load(f)

    set_seed(args.seed)
    device = get_device(args.device)
    print(f"Device: {device}, Fold: {args.fold}, Seed: {args.seed}")

    records = discover_synapse_data(
        cfg["data"]["root"], cfg["data"]["synapse_ids"]
    )
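
    # Leave-one-synapse-out: the --fold synapse is held out for validation.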
    val_record = [r for r in records if r.synapse_id == args.fold]
    if not val_record:
        raise ValueError(f"Fold {args.fold} not found in synapse IDs")
    val_data = load_synapse(val_record[0])

    train_dataset = ImmunogoldDataset(
        records=records,
        fold_id=args.fold,
        mode="train",
        patch_size=cfg["data"]["patch_size"],
        stride=cfg["data"]["stride"],
        hard_mining_fraction=cfg["training"]["hard_mining_fraction"],
        copy_paste_per_class=cfg["training"]["copy_paste_per_class"],
        sigmas=cfg["heatmap"]["sigmas"],
        samples_per_epoch=500,
        seed=args.seed,
    )
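
    # Per-worker setup (e.g., RNG seeding) is delegated to the dataset's
    # worker_init_fn.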
    train_loader = DataLoader(
        train_dataset,
        batch_size=cfg["training"]["batch_size"],
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=True,
        worker_init_fn=ImmunogoldDataset.worker_init_fn,
    )
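
    # Fall back to ImageNet initialization if the CEM500K weights are missing.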
    pretrained = cfg["model"]["pretrained_weights"]
    if pretrained and not Path(pretrained).exists():
        print(f"Warning: CEM500K weights not found at {pretrained}, using ImageNet")
        pretrained = None

    model = ImmunogoldCenterNet(
        pretrained_path=pretrained,
        bifpn_channels=cfg["model"]["bifpn_channels"],
        bifpn_rounds=cfg["model"]["bifpn_rounds"],
        num_classes=cfg["model"]["num_classes"],
    ).to(device)

    param_count = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {param_count:,}")

    checkpoint_dir = Path("checkpoints") / f"fold_{args.fold}_seed{args.seed}"
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    writer = SummaryWriter(log_dir=f"logs/fold_{args.fold}_seed{args.seed}")

    snapshot_epochs = set(cfg["training"]["n_snapshot_epochs"])
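
    # Dry run: push one batch through the model and loss to verify shapes and
    # config, then exit.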
    if args.dry_run:
        print("=== DRY RUN ===")
        batch = next(iter(train_loader))
        images = batch["image"].to(device)
        print(f"Input shape: {images.shape}")
        hm, off = model(images)
        print(f"Heatmap shape: {hm.shape}, Offset shape: {off.shape}")

        loss_val, hm_loss, off_loss = total_loss(
            hm, batch["heatmap"].to(device),
            off, batch["offsets"].to(device),
            batch["offset_mask"].to(device),
        )
        print(f"Loss: {loss_val.item():.4f} (hm={hm_loss:.4f}, off={off_loss:.4f})")
        print("=== DRY RUN PASSED ===")
        writer.close()
        return
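
    # Three-phase schedule: (1) frozen encoder, (2) deep encoder layers
    # unfrozen, (3) full fine-tuning.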
    print("\n=== Phase 1: Frozen encoder ===")
    phase1_cfg = cfg["training"]["phases"]["phase1"]
    model.freeze_encoder()

    param_groups = model.get_param_groups(1, phase1_cfg)
    optimizer = torch.optim.AdamW(
        param_groups, weight_decay=phase1_cfg["weight_decay"]
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=20, T_mult=2
    )

    global_epoch = 0
    best_f1 = 0.0

    global_epoch, best_f1 = train_phase(
        model, train_loader, optimizer, scheduler, device, cfg,
        phase_num=1, n_epochs=phase1_cfg["epochs"],
        writer=writer, global_epoch=global_epoch,
        val_data=val_data, best_f1=best_f1,
        checkpoint_dir=checkpoint_dir,
        snapshot_epochs=snapshot_epochs,
    )

    print("\n=== Phase 2: Unfreeze layer3+layer4 ===")
    phase2_cfg = cfg["training"]["phases"]["phase2"]
    model.unfreeze_deep_layers()

    param_groups = model.get_param_groups(2, phase2_cfg)
    optimizer = torch.optim.AdamW(
        param_groups, weight_decay=phase2_cfg["weight_decay"]
    )
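    # No LR scheduler in phase 2; the param-group learning rates stay fixed.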
    scheduler = None

    global_epoch, best_f1 = train_phase(
        model, train_loader, optimizer, scheduler, device, cfg,
        phase_num=2, n_epochs=phase2_cfg["epochs"],
        writer=writer, global_epoch=global_epoch,
        val_data=val_data, best_f1=best_f1,
        checkpoint_dir=checkpoint_dir,
        snapshot_epochs=snapshot_epochs,
    )

    print("\n=== Phase 3: Full fine-tuning ===")
    phase3_cfg = cfg["training"]["phases"]["phase3"]
    model.unfreeze_all()

    param_groups = model.get_param_groups(3, phase3_cfg)
    optimizer = torch.optim.AdamW(
        param_groups, weight_decay=phase3_cfg["weight_decay"]
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=phase3_cfg["epochs"],
        eta_min=phase3_cfg["eta_min"],
    )

    global_epoch, best_f1 = train_phase(
        model, train_loader, optimizer, scheduler, device, cfg,
        phase_num=3, n_epochs=phase3_cfg["epochs"],
        writer=writer, global_epoch=global_epoch,
        val_data=val_data, best_f1=best_f1,
        checkpoint_dir=checkpoint_dir,
        snapshot_epochs=snapshot_epochs,
    )

    print(f"\nTraining complete. Best val F1: {best_f1:.4f}")
    print(f"Checkpoints saved to: {checkpoint_dir}")
    writer.close()


if __name__ == "__main__":
    main()