"""
ISDNet Multi-GPU Training Script for FLAIR Dataset

Run with: torchrun --nproc_per_node=4 train.py
Monitor with: tail -f training.log

NOTE: AMP is disabled and batched matrix multiplies run in a loop to work
around cuBLAS issues on L40S GPUs with CUDA 12.8.
"""
|
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD
from torch.optim.lr_scheduler import PolynomialLR
from tqdm import tqdm
from datetime import datetime
|
from isdnet import (
    ISDNet,
    FLAIRDataset,
    DATA_ROOT,
    STDC_PRETRAIN_PATH,
    BATCH_SIZE_PER_GPU,
    NUM_WORKERS,
    BASE_LR,
    WEIGHT_DECAY,
    NUM_EPOCHS,
    NUM_CLASSES,
    CROP_SIZE,
    DOWN_RATIO,
    IGNORE_INDEX,
    SAVE_INTERVAL,
)
from isdnet.utils import setup_distributed, cleanup_distributed, print_rank0
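# setup_distributed() is assumed to initialize the process group (NCCL) and
# return (rank, world_size, local_rank); print_rank0 logs only on rank 0.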


def train_epoch(model, loader, optimizer, scheduler, epoch, rank, world_size):
    """Train for one epoch."""
    model.train()
    total_loss = 0
    criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
|
| | pbar = tqdm(loader, desc=f"Train Epoch {epoch}") if rank == 0 else loader |
| |
|
| | for batch_idx, (imgs, masks) in enumerate(pbar): |
| | imgs, masks = imgs.cuda(non_blocking=True), masks.cuda(non_blocking=True) |
| | optimizer.zero_grad() |
| |
|
| | |
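        # Composite loss: CE on the main output, 0.4-weighted CE on the four
        # auxiliary heads, plus the model's reconstruction and
        # feature-alignment terms.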
        o = model(imgs, return_loss=True)
        loss = (criterion(o['out'], masks) +
                0.4 * (criterion(o['out_deep'], masks) +
                       criterion(o['out_aux16'], masks) +
                       criterion(o['out_aux8'], masks) +
                       criterion(o['aux_out'], masks)) +
                o['losses_re']['recon_losses'] +
                o['losses_fa']['fa_loss'])

        loss.backward()
        optimizer.step()
        scheduler.step()  # PolynomialLR is sized in iterations, so step per batch
        total_loss += loss.item()

        if rank == 0:
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})
            if batch_idx > 0 and batch_idx % 500 == 0:
                print(f"[{datetime.now()}] Epoch {epoch}, Batch {batch_idx}/{len(loader)}, Loss: {loss.item():.4f}")

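    # Average the epoch loss across ranks so every process returns the same
    # value (ReduceOp.AVG is available on the NCCL backend).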
    avg_loss = torch.tensor(total_loss / len(loader), device='cuda')
    if world_size > 1:
        dist.all_reduce(avg_loss, op=dist.ReduceOp.AVG)
    return avg_loss.item()

|
@torch.no_grad()
def validate(model, loader, num_classes, rank, world_size):
    """Validate the model."""
    model.eval()
    criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
    total_loss = 0
    intersection = torch.zeros(num_classes, device='cuda')
    union = torch.zeros(num_classes, device='cuda')

    pbar = tqdm(loader, desc="Validation") if rank == 0 else loader

    for imgs, masks in pbar:
        imgs, masks = imgs.cuda(non_blocking=True), masks.cuda(non_blocking=True)
        outputs = model(imgs, return_loss=False)
        total_loss += criterion(outputs, masks).item()
        preds = outputs.argmax(dim=1)

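        # Accumulate per-class intersection/union; predictions at ignored
        # pixels are masked out, and ignored labels never match a class id
        # (IGNORE_INDEX lies outside 0..num_classes-1).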
        valid_mask = (masks != IGNORE_INDEX)
        for cls in range(num_classes):
            cls_pred = (preds == cls) & valid_mask
            cls_gt = (masks == cls)
            intersection[cls] += (cls_pred & cls_gt).sum()
            union[cls] += (cls_pred | cls_gt).sum()

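    # Sum the IoU statistics over all ranks before computing mIoU so the
    # metric reflects the full validation set, not one shard.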
    if world_size > 1:
        dist.all_reduce(intersection, op=dist.ReduceOp.SUM)
        dist.all_reduce(union, op=dist.ReduceOp.SUM)
        avg_loss = torch.tensor(total_loss / len(loader), device='cuda')
        dist.all_reduce(avg_loss, op=dist.ReduceOp.AVG)
        total_loss = avg_loss.item()
    else:
        total_loss = total_loss / len(loader)

    class_iou = intersection / (union + 1e-6)
    miou = class_iou.mean().item()
    return total_loss, miou, class_iou

|
def main():
    rank, world_size, local_rank = setup_distributed()

    print_rank0(f"[{datetime.now()}] Starting ISDNet FLAIR Multi-GPU Training", rank)
    print_rank0(f"PyTorch: {torch.__version__}", rank)
    print_rank0(f"World size: {world_size} GPUs", rank)
    print_rank0(f"Batch size: {BATCH_SIZE_PER_GPU} per GPU = {BATCH_SIZE_PER_GPU * world_size} total", rank)
    print_rank0(f"Crop size: {CROP_SIZE}x{CROP_SIZE}", rank)
    print_rank0(f"Classes: 0-14 ({NUM_CLASSES} classes), >=15 mapped to {IGNORE_INDEX} (ignored)", rank)
    print_rank0("Augmentations: RandomCrop (cat_max_ratio), RandomRotate, RandomFlip, PhotoMetricDistortion", rank)

    lr = BASE_LR
    print_rank0(f"Learning rate: {lr}", rank)

| | print_rank0(f"\n[{datetime.now()}] Setting up datasets...", rank) |
| | train_ds = FLAIRDataset(DATA_ROOT, 'train', CROP_SIZE) |
| | val_ds = FLAIRDataset(DATA_ROOT, 'valid', CROP_SIZE, augment=False) |
| |
|
| | train_sampler = DistributedSampler(train_ds, num_replicas=world_size, rank=rank, shuffle=True) if world_size > 1 else None |
| | val_sampler = DistributedSampler(val_ds, num_replicas=world_size, rank=rank, shuffle=False) if world_size > 1 else None |
| |
|
| | train_loader = DataLoader(train_ds, BATCH_SIZE_PER_GPU, shuffle=(train_sampler is None), |
| | sampler=train_sampler, num_workers=NUM_WORKERS, pin_memory=True, drop_last=True) |
| | val_loader = DataLoader(val_ds, BATCH_SIZE_PER_GPU, shuffle=False, |
| | sampler=val_sampler, num_workers=NUM_WORKERS, pin_memory=True) |
| |
|
| | print_rank0(f"Train: {len(train_ds)} samples, {len(train_loader)} batches/GPU", rank) |
| | print_rank0(f"Val: {len(val_ds)} samples, {len(val_loader)} batches/GPU", rank) |
| |
|
| | |
| | print_rank0(f"\n[{datetime.now()}] Building model...", rank) |
    model = ISDNet(NUM_CLASSES, 'resnet18', 128, DOWN_RATIO, stdc_pretrain=STDC_PRETRAIN_PATH).cuda()

    if world_size > 1:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = DDP(model, device_ids=[local_rank])

    print_rank0(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}", rank)
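    # SGD with momentum; PolynomialLR counts total iterations
    # (epochs x batches) and is stepped once per batch in train_epoch.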
|
    optimizer = SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=WEIGHT_DECAY)
    scheduler = PolynomialLR(optimizer, NUM_EPOCHS * len(train_loader), power=0.9)

    print_rank0(f"\n[{datetime.now()}] Starting training for {NUM_EPOCHS} epochs (FP32 mode)...", rank)
    print_rank0("=" * 60, rank)

    best_miou = 0
    for epoch in range(1, NUM_EPOCHS + 1):
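        # Re-seed the sampler each epoch so the shuffle differs across epochs.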
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)

        print_rank0(f"\n[{datetime.now()}] Epoch {epoch}/{NUM_EPOCHS}", rank)
        train_loss = train_epoch(model, train_loader, optimizer, scheduler, epoch, rank, world_size)
        print_rank0(f"[{datetime.now()}] Train Loss: {train_loss:.4f}, LR: {optimizer.param_groups[0]['lr']:.6f}", rank)
|
        val_loss, val_miou, class_iou = validate(model, val_loader, NUM_CLASSES, rank, world_size)
        print_rank0(f"[{datetime.now()}] Val Loss: {val_loss:.4f}, Val mIoU: {val_miou:.4f}", rank)

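        # Checkpoint on rank 0 only; unwrap the DDP .module so the saved
        # state_dict loads cleanly into a non-distributed model.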
        if rank == 0:
            model_to_save = model.module if hasattr(model, 'module') else model
            if val_miou > best_miou:
                best_miou = val_miou
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model_to_save.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'miou': val_miou,
                    'class_iou': class_iou.cpu()
                }, 'isdnet_flair_best.pth')
                print(f"[{datetime.now()}] Saved best model with mIoU: {val_miou:.4f}")

            if epoch % SAVE_INTERVAL == 0:
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model_to_save.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'miou': val_miou
                }, f'isdnet_flair_epoch{epoch}.pth')

        if world_size > 1:
            dist.barrier()
|
| | print_rank0(f"\n[{datetime.now()}] Training completed! Best mIoU: {best_miou:.4f}", rank) |
| | cleanup_distributed() |
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|