"""
Unified training script for deterministic segmentation baselines.

Uses official libraries: smp (U-Net, U-Net++), MONAI (Attention U-Net),
smp MAnet (TransUNet-style).
"""

import argparse
import os
import sys
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split

# Make the sibling project modules importable when the script is run directly.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from dataset import LIDCFlatDataset
from models import get_model


class DiceBCELoss(nn.Module):
    """Combined Dice + BCE loss for binary segmentation.

    The loss is ``dice_weight * (1 - mean_dice) + bce_weight * bce`` where the
    soft Dice score is computed per (sample, channel) by summing over tensor
    dims 2 and 3, and BCE is ``nn.BCEWithLogitsLoss`` on the raw logits.

    Args:
        dice_weight: Multiplier applied to the Dice term.
        bce_weight: Multiplier applied to the BCE term.
        smooth: Additive constant in the Dice numerator and denominator that
            keeps the ratio defined when both prediction and target are empty.
    """

    def __init__(self, dice_weight=0.5, bce_weight=0.5, smooth=1e-5):
        super().__init__()
        self.dice_weight = dice_weight
        self.bce_weight = bce_weight
        self.smooth = smooth
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, logits, targets):
        """Return the weighted Dice + BCE loss as a scalar tensor.

        Args:
            logits: Raw (pre-sigmoid) model outputs. Must have at least four
                dims because the Dice reduction runs over dims 2 and 3
                (presumably H and W of an (N, C, H, W) tensor - confirm
                against the dataset).
            targets: Ground-truth masks, same shape as ``logits``.
        """
        bce_loss = self.bce(logits, targets)

        # Soft Dice on probabilities, reduced over dims 2 and 3 only.
        probs = torch.sigmoid(logits)
        reduce_dims = (2, 3)
        intersection = (probs * targets).sum(dim=reduce_dims)
        denominator = probs.sum(dim=reduce_dims) + targets.sum(dim=reduce_dims)
        dice = (2. * intersection + self.smooth) / (denominator + self.smooth)
        dice_loss = 1 - dice.mean()

        return self.dice_weight * dice_loss + self.bce_weight * bce_loss


def compute_dice(pred, target, smooth=1e-5):
    """Compute Dice coefficient for evaluation.

    The logits are passed through a sigmoid and thresholded at 0.5, then a
    single Dice score is computed over *all* elements of the tensors (the
    sums are global, not per-sample).

    Args:
        pred: Raw (pre-sigmoid) model outputs.
        target: Ground-truth mask tensor with the same shape as ``pred``.
        smooth: Additive constant; when both the thresholded prediction and
            the target are empty the score evaluates to 1.

    Returns:
        A zero-dimensional tensor holding the Dice score.
    """
    pred = (torch.sigmoid(pred) > 0.5).float()
    intersection = (pred * target).sum()
    return (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)


def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; switched to train mode.
        loader: Iterable yielding ``(images, masks, extra)`` triples; the
            third element is ignored. Must yield at least one batch, since
            the averages divide by ``len(loader)``.
        optimizer: Optimizer stepped once per batch.
        criterion: Callable ``(outputs, masks) -> scalar loss tensor``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean_loss, mean_dice)``: unweighted means of the per-batch
        loss and per-batch Dice.
    """
    model.train()
    total_loss = 0
    total_dice = 0

    for images, masks, _ in loader:
        images = images.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_dice += compute_dice(outputs, masks).item()

    n_batches = len(loader)
    return total_loss / n_batches, total_dice / n_batches


@torch.no_grad()
def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Iterable yielding ``(images, masks, extra)`` triples; the
            third element is ignored. Must yield at least one batch, since
            the averages divide by ``len(loader)``.
        criterion: Callable ``(outputs, masks) -> scalar loss tensor``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean_loss, mean_dice)``: unweighted means of the per-batch
        loss and per-batch Dice.
    """
    model.eval()
    total_loss = 0
    total_dice = 0

    for images, masks, _ in loader:
        images = images.to(device)
        masks = masks.to(device)

        outputs = model(images)
        loss = criterion(outputs, masks)

        total_loss += loss.item()
        total_dice += compute_dice(outputs, masks).item()

    n_batches = len(loader)
    return total_loss / n_batches, total_dice / n_batches


def main():
    """Parse CLI arguments, train the selected model and save the best checkpoint.

    The checkpoint with the highest validation Dice is written to
    ``<checkpoint_dir>/<model>_best.pth``. Training stops early once the
    validation Dice has not improved for ``--patience`` consecutive epochs.

    Raises:
        ValueError: If the train/validation split would leave either side
            empty, or if the augmented and non-augmented dataset views do
            not have the same length.
    """
    parser = argparse.ArgumentParser(description="Train deterministic segmentation baseline")
    parser.add_argument("--model", type=str, required=True,
                        choices=["unet", "attention_unet", "unetpp", "transunet", "nnunet"],
                        help="Model architecture")
    parser.add_argument("--data_dir", type=str, default="data/flat_train",
                        help="Path to flat training data")
    parser.add_argument("--epochs", type=int, default=100, help="Number of epochs")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
    parser.add_argument("--val_split", type=float, default=0.1,
                        help="Validation split ratio")
    parser.add_argument("--checkpoint_dir", type=str, default="checkpoints",
                        help="Checkpoint directory")
    parser.add_argument("--num_workers", type=int, default=4, help="DataLoader workers")
    parser.add_argument("--patience", type=int, default=20, help="Early stopping patience")
    args = parser.parse_args()

    # A ratio outside (0, 1) leaves one side of the split empty, which would
    # only surface later as a ZeroDivisionError when averaging over batches.
    if not 0.0 < args.val_split < 1.0:
        parser.error("--val_split must be strictly between 0 and 1")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    os.makedirs(args.checkpoint_dir, exist_ok=True)

    # Create dataset
    full_dataset = LIDCFlatDataset(args.data_dir, augment=True)

    # Split into train/val
    val_size = int(len(full_dataset) * args.val_split)
    train_size = len(full_dataset) - val_size
    if train_size == 0 or val_size == 0:
        raise ValueError(
            f"Dataset of {len(full_dataset)} samples with val_split="
            f"{args.val_split} yields train={train_size}, val={val_size}; "
            "both splits must be non-empty."
        )

    # Fixed seed so the same samples land in the validation split every run.
    generator = torch.Generator().manual_seed(42)
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size],
                                              generator=generator)

    # Disable augmentation for validation: reuse the validation indices on a
    # second, non-augmented view of the same directory.
    # NOTE(review): this relies on LIDCFlatDataset enumerating samples in the
    # same order for both instances - confirm in dataset.py.
    val_dataset_no_aug = LIDCFlatDataset(args.data_dir, augment=False)
    if len(val_dataset_no_aug) != len(full_dataset):
        raise ValueError(
            "Augmented and non-augmented datasets differ in length "
            f"({len(full_dataset)} vs {len(val_dataset_no_aug)}); "
            "validation indices cannot be reused safely."
        )
    val_indices = val_dataset.indices
    val_dataset_final = torch.utils.data.Subset(val_dataset_no_aug, val_indices)

    # Pinned host memory only helps host-to-GPU copies; skip it on CPU.
    pin_memory = device.type == "cuda"
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset_final, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.num_workers, pin_memory=pin_memory)

    print(f"Train: {train_size}, Val: {val_size}")

    # Create model
    model = get_model(args.model, in_channels=1, num_classes=1)
    model = model.to(device)

    # Count parameters
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model: {args.model}, Parameters: {n_params:,}")

    # Loss, optimizer, scheduler
    criterion = DiceBCELoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    # Training loop
    best_val_dice = 0
    best_epoch = None  # stays None until a checkpoint has actually been written
    patience_counter = 0
    checkpoint_path = os.path.join(args.checkpoint_dir, f"{args.model}_best.pth")

    print(f"\nTraining {args.model} for {args.epochs} epochs...")
    print("-" * 70)

    start_time = time.time()

    for epoch in range(1, args.epochs + 1):
        train_loss, train_dice = train_one_epoch(model, train_loader, optimizer,
                                                 criterion, device)
        val_loss, val_dice = validate(model, val_loader, criterion, device)
        scheduler.step()

        # Save best model
        if val_dice > best_val_dice:
            best_val_dice = val_dice
            best_epoch = epoch
            patience_counter = 0
            torch.save({
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "val_dice": val_dice,
                "model_name": args.model,
            }, checkpoint_path)
        else:
            patience_counter += 1

        # Print progress every 5 epochs or at start/end
        if epoch % 5 == 0 or epoch == 1 or epoch == args.epochs:
            elapsed = time.time() - start_time
            lr = optimizer.param_groups[0]['lr']
            print(f"Epoch {epoch:3d}/{args.epochs} | "
                  f"Train Loss: {train_loss:.4f} Dice: {train_dice:.4f} | "
                  f"Val Loss: {val_loss:.4f} Dice: {val_dice:.4f} | "
                  f"Best: {best_val_dice:.4f} | "
                  f"LR: {lr:.2e} | "
                  f"Time: {elapsed:.0f}s")

        # Early stopping
        if patience_counter >= args.patience:
            print(f"\nEarly stopping at epoch {epoch} (patience={args.patience})")
            break

    total_time = time.time() - start_time
    print("-" * 70)
    print(f"Training complete! Best val Dice: {best_val_dice:.4f}")
    print(f"Total time: {total_time:.0f}s ({total_time/60:.1f}min)")
    if best_epoch is not None:
        print(f"Checkpoint saved: {checkpoint_path}")
    else:
        print("No checkpoint was saved (validation Dice never improved).")


if __name__ == "__main__":
    main()