#!/usr/bin/env python3
"""
ISDNet Multi-GPU Training Script for the FLAIR Dataset

Run with:     torchrun --nproc_per_node=4 train.py
Monitor with: tail -f training.log

NOTE: AMP is disabled and batched matrix multiplies are computed in a
loop to work around cuBLAS issues on L40S GPUs with CUDA 12.8.
"""
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD
from torch.optim.lr_scheduler import PolynomialLR
from tqdm import tqdm
from datetime import datetime

from isdnet import (
    ISDNet, FLAIRDataset, DATA_ROOT, STDC_PRETRAIN_PATH,
    BATCH_SIZE_PER_GPU, NUM_WORKERS, BASE_LR, WEIGHT_DECAY,
    NUM_EPOCHS, NUM_CLASSES, CROP_SIZE, DOWN_RATIO,
    IGNORE_INDEX, SAVE_INTERVAL,
)
from isdnet.utils import setup_distributed, cleanup_distributed, print_rank0


def train_epoch(model, loader, optimizer, epoch, rank, world_size):
    """Train for one epoch and return the mean loss across all processes."""
    model.train()
    total_loss = 0.0
    criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)

    pbar = tqdm(loader, desc=f"Train Epoch {epoch}") if rank == 0 else loader
    for batch_idx, (imgs, masks) in enumerate(pbar):
        imgs = imgs.cuda(non_blocking=True)
        masks = masks.cuda(non_blocking=True)
        optimizer.zero_grad()

        # No AMP: pure FP32 training (see module docstring).
        o = model(imgs, return_loss=True)
        # Main head + 0.4-weighted auxiliary heads, plus the reconstruction
        # and feature-alignment losses returned by the model.
        loss = (criterion(o['out'], masks)
                + 0.4 * (criterion(o['out_deep'], masks)
                         + criterion(o['out_aux16'], masks)
                         + criterion(o['out_aux8'], masks)
                         + criterion(o['aux_out'], masks))
                + o['losses_re']['recon_losses']
                + o['losses_fa']['fa_loss'])

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        if rank == 0:
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})
            if batch_idx > 0 and batch_idx % 500 == 0:
                print(f"[{datetime.now()}] Epoch {epoch}, "
                      f"Batch {batch_idx}/{len(loader)}, Loss: {loss.item():.4f}")

    # Average the epoch loss across all processes.
    avg_loss = torch.tensor(total_loss / len(loader), device='cuda')
    if world_size > 1:
        dist.all_reduce(avg_loss, op=dist.ReduceOp.AVG)
    return avg_loss.item()


@torch.no_grad()
def validate(model, loader, num_classes, rank, world_size):
    """Validate the model; return (loss, mIoU, per-class IoU)."""
    model.eval()
    criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
    total_loss = 0.0
    intersection = torch.zeros(num_classes, device='cuda')
    union = torch.zeros(num_classes, device='cuda')

    pbar = tqdm(loader, desc="Validation") if rank == 0 else loader
    for imgs, masks in pbar:
        imgs = imgs.cuda(non_blocking=True)
        masks = masks.cuda(non_blocking=True)
        outputs = model(imgs, return_loss=False)
        total_loss += criterion(outputs, masks).item()
        preds = outputs.argmax(dim=1)

        # Accumulate IoU statistics over valid (non-ignored) pixels only.
        valid_mask = (masks != IGNORE_INDEX)
        for cls in range(num_classes):
            cls_pred = (preds == cls) & valid_mask
            cls_gt = (masks == cls)
            intersection[cls] += (cls_pred & cls_gt).sum()
            union[cls] += (cls_pred | cls_gt).sum()

    # Reduce statistics across all processes.
    if world_size > 1:
        dist.all_reduce(intersection, op=dist.ReduceOp.SUM)
        dist.all_reduce(union, op=dist.ReduceOp.SUM)
        avg_loss = torch.tensor(total_loss / len(loader), device='cuda')
        dist.all_reduce(avg_loss, op=dist.ReduceOp.AVG)
        total_loss = avg_loss.item()
    else:
        total_loss = total_loss / len(loader)

    class_iou = intersection / (union + 1e-6)
    miou = class_iou.mean().item()
    return total_loss, miou, class_iou
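
# Optional reference (assumption, not used above): the per-class IoU loop
# in validate() can be replaced by a single confusion-matrix computation
# built with torch.bincount. This is a minimal equivalent sketch, kept
# for illustration only; the script itself runs the loop version.
def _confusion_iou(preds, masks, num_classes, ignore_index=IGNORE_INDEX):
    """Return per-class (intersection, union) via an NxN confusion matrix."""
    valid = masks != ignore_index
    # Encode each valid (gt, pred) pair as a single bin index, then count.
    idx = masks[valid].long() * num_classes + preds[valid].long()
    conf = torch.bincount(idx, minlength=num_classes ** 2)
    conf = conf.reshape(num_classes, num_classes).float()
    intersection = conf.diag()
    union = conf.sum(dim=0) + conf.sum(dim=1) - conf.diag()
    return intersection, union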

def main():
    rank, world_size, local_rank = setup_distributed()
    print_rank0(f"[{datetime.now()}] Starting ISDNet FLAIR Multi-GPU Training", rank)
    print_rank0(f"PyTorch: {torch.__version__}", rank)
    print_rank0(f"World size: {world_size} GPUs", rank)
    print_rank0(f"Batch size: {BATCH_SIZE_PER_GPU} per GPU = "
                f"{BATCH_SIZE_PER_GPU * world_size} total", rank)
    print_rank0(f"Crop size: {CROP_SIZE}x{CROP_SIZE}", rank)
    print_rank0(f"Classes: 0-14 ({NUM_CLASSES} classes), >=15 mapped to "
                f"{IGNORE_INDEX} (ignored)", rank)
    print_rank0("Augmentations: RandomCrop (cat_max_ratio), RandomRotate, "
                "RandomFlip, PhotoMetricDistortion", rank)
    lr = BASE_LR
    print_rank0(f"Learning rate: {lr}", rank)

    # Datasets
    print_rank0(f"\n[{datetime.now()}] Setting up datasets...", rank)
    train_ds = FLAIRDataset(DATA_ROOT, 'train', CROP_SIZE)
    val_ds = FLAIRDataset(DATA_ROOT, 'valid', CROP_SIZE, augment=False)
    train_sampler = (DistributedSampler(train_ds, num_replicas=world_size,
                                        rank=rank, shuffle=True)
                     if world_size > 1 else None)
    val_sampler = (DistributedSampler(val_ds, num_replicas=world_size,
                                      rank=rank, shuffle=False)
                   if world_size > 1 else None)
    train_loader = DataLoader(train_ds, BATCH_SIZE_PER_GPU,
                              shuffle=(train_sampler is None),
                              sampler=train_sampler, num_workers=NUM_WORKERS,
                              pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_ds, BATCH_SIZE_PER_GPU, shuffle=False,
                            sampler=val_sampler, num_workers=NUM_WORKERS,
                            pin_memory=True)
    print_rank0(f"Train: {len(train_ds)} samples, {len(train_loader)} batches/GPU", rank)
    print_rank0(f"Val: {len(val_ds)} samples, {len(val_loader)} batches/GPU", rank)

    # Model
    print_rank0(f"\n[{datetime.now()}] Building model...", rank)
    model = ISDNet(NUM_CLASSES, 'resnet18', 128, DOWN_RATIO,
                   stdc_pretrain=STDC_PRETRAIN_PATH).cuda()
    if world_size > 1:
        # SyncBatchNorm keeps BN statistics consistent across GPUs.
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = DDP(model, device_ids=[local_rank])
    print_rank0(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}", rank)

    optimizer = SGD(model.parameters(), lr=lr, momentum=0.9,
                    weight_decay=WEIGHT_DECAY)
    # The scheduler is stepped once per epoch, so the poly decay horizon
    # (total_iters) is the number of epochs.
    scheduler = PolynomialLR(optimizer, total_iters=NUM_EPOCHS, power=0.9)

    print_rank0(f"\n[{datetime.now()}] Starting training for {NUM_EPOCHS} epochs (FP32 mode)...", rank)
    print_rank0("=" * 60, rank)

    best_miou = 0.0
    for epoch in range(1, NUM_EPOCHS + 1):
        if train_sampler is not None:
            # Reshuffle each rank's shard differently every epoch.
            train_sampler.set_epoch(epoch)
        print_rank0(f"\n[{datetime.now()}] Epoch {epoch}/{NUM_EPOCHS}", rank)

        train_loss = train_epoch(model, train_loader, optimizer, epoch, rank, world_size)
        scheduler.step()
        print_rank0(f"[{datetime.now()}] Train Loss: {train_loss:.4f}, "
                    f"LR: {optimizer.param_groups[0]['lr']:.6f}", rank)

        val_loss, val_miou, class_iou = validate(model, val_loader, NUM_CLASSES, rank, world_size)
        print_rank0(f"[{datetime.now()}] Val Loss: {val_loss:.4f}, Val mIoU: {val_miou:.4f}", rank)

        # Save checkpoints from rank 0 only.
        if rank == 0:
            model_to_save = model.module if hasattr(model, 'module') else model
            if val_miou > best_miou:
                best_miou = val_miou
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model_to_save.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'miou': val_miou,
                    'class_iou': class_iou.cpu(),
                }, 'isdnet_flair_best.pth')
                print(f"[{datetime.now()}] Saved best model with mIoU: {val_miou:.4f}")
            if epoch % SAVE_INTERVAL == 0:
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model_to_save.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'miou': val_miou,
                }, f'isdnet_flair_epoch{epoch}.pth')
        if world_size > 1:
            dist.barrier()

    print_rank0(f"\n[{datetime.now()}] Training completed! Best mIoU: {best_miou:.4f}", rank)
    cleanup_distributed()


if __name__ == "__main__":
    main()
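
# ----------------------------------------------------------------------
# Reference sketch (assumption): the helpers imported from isdnet.utils
# live in isdnet/utils.py and are not shown in this file. Minimal
# implementations consistent with how they are used above could look
# like the functions below; they are prefixed with "_example_" so they
# do not shadow the real imports, and the actual isdnet.utils code may
# differ.
import os


def _example_setup_distributed():
    """Read torchrun env vars, init NCCL, and pin this process's GPU."""
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if world_size > 1:
        dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)
    return rank, world_size, local_rank


def _example_cleanup_distributed():
    """Tear down the process group if one was created."""
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def _example_print_rank0(msg, rank):
    """Log from rank 0 only so multi-GPU output stays readable."""
    if rank == 0:
        print(msg, flush=True)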