ISDNet-pytorch / train.py
#!/usr/bin/env python3
"""
ISDNet Multi-GPU Training Script for FLAIR Dataset
Run with: torchrun --nproc_per_node=4 train.py
Monitor with: tail -f training.log
NOTE: AMP is disabled and batched matrix multiplies run in a Python loop to avoid CUBLAS issues on L40S / CUDA 12.8
"""
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD
from torch.optim.lr_scheduler import PolynomialLR
from tqdm import tqdm
from datetime import datetime
from isdnet import (
    ISDNet,
    FLAIRDataset,
    DATA_ROOT,
    STDC_PRETRAIN_PATH,
    BATCH_SIZE_PER_GPU,
    NUM_WORKERS,
    BASE_LR,
    WEIGHT_DECAY,
    NUM_EPOCHS,
    NUM_CLASSES,
    CROP_SIZE,
    DOWN_RATIO,
    IGNORE_INDEX,
    SAVE_INTERVAL,
)
from isdnet.utils import setup_distributed, cleanup_distributed, print_rank0
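
# setup_distributed is assumed to read the environment variables that torchrun
# provides (RANK, LOCAL_RANK, WORLD_SIZE) and to initialize the NCCL process
# group; the exact behavior lives in isdnet.utils and is not shown here.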


def train_epoch(model, loader, optimizer, scheduler, epoch, rank, world_size):
    """Train for one epoch."""
    model.train()
    total_loss = 0
    criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
    pbar = tqdm(loader, desc=f"Train Epoch {epoch}") if rank == 0 else loader
    for batch_idx, (imgs, masks) in enumerate(pbar):
        imgs, masks = imgs.cuda(non_blocking=True), masks.cuda(non_blocking=True)
        optimizer.zero_grad()
        # No AMP - pure FP32 training
        o = model(imgs, return_loss=True)
        # Deep supervision: main head plus 0.4-weighted auxiliary heads, plus
        # ISDNet's reconstruction and feature-alignment losses.
        loss = (criterion(o['out'], masks) +
                0.4 * (criterion(o['out_deep'], masks) +
                       criterion(o['out_aux16'], masks) +
                       criterion(o['out_aux8'], masks) +
                       criterion(o['aux_out'], masks)) +
                o['losses_re']['recon_losses'] +
                o['losses_fa']['fa_loss'])
        loss.backward()
        optimizer.step()
        scheduler.step()  # poly LR is decayed per iteration (see main)
        total_loss += loss.item()
        if rank == 0:
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})
            if batch_idx > 0 and batch_idx % 500 == 0:
                print(f"[{datetime.now()}] Epoch {epoch}, Batch {batch_idx}/{len(loader)}, Loss: {loss.item():.4f}")
    # Average the epoch loss across all processes
    avg_loss = torch.tensor(total_loss / len(loader), device='cuda')
    if world_size > 1:
        dist.all_reduce(avg_loss, op=dist.ReduceOp.AVG)
    return avg_loss.item()


@torch.no_grad()
def validate(model, loader, num_classes, rank, world_size):
    """Validate the model."""
    model.eval()
    criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
    total_loss = 0
    intersection = torch.zeros(num_classes, device='cuda')
    union = torch.zeros(num_classes, device='cuda')
    pbar = tqdm(loader, desc="Validation") if rank == 0 else loader
    for imgs, masks in pbar:
        imgs, masks = imgs.cuda(non_blocking=True), masks.cuda(non_blocking=True)
        outputs = model(imgs, return_loss=False)
        total_loss += criterion(outputs, masks).item()
        preds = outputs.argmax(dim=1)
        # Only compute IoU for valid pixels
        valid_mask = (masks != IGNORE_INDEX)
        for cls in range(num_classes):
            cls_pred = (preds == cls) & valid_mask
            cls_gt = (masks == cls)
            intersection[cls] += (cls_pred & cls_gt).sum()
            union[cls] += (cls_pred | cls_gt).sum()
    # Reduce across all processes
    if world_size > 1:
        dist.all_reduce(intersection, op=dist.ReduceOp.SUM)
        dist.all_reduce(union, op=dist.ReduceOp.SUM)
        avg_loss = torch.tensor(total_loss / len(loader), device='cuda')
        dist.all_reduce(avg_loss, op=dist.ReduceOp.AVG)
        total_loss = avg_loss.item()
    else:
        total_loss = total_loss / len(loader)
    class_iou = intersection / (union + 1e-6)
    miou = class_iou.mean().item()
    return total_loss, miou, class_iou


def main():
    rank, world_size, local_rank = setup_distributed()
    print_rank0(f"[{datetime.now()}] Starting ISDNet FLAIR Multi-GPU Training", rank)
    print_rank0(f"PyTorch: {torch.__version__}", rank)
    print_rank0(f"World size: {world_size} GPUs", rank)
    print_rank0(f"Batch size: {BATCH_SIZE_PER_GPU} per GPU = {BATCH_SIZE_PER_GPU * world_size} total", rank)
    print_rank0(f"Crop size: {CROP_SIZE}x{CROP_SIZE}", rank)
    print_rank0(f"Classes: 0-14 ({NUM_CLASSES} classes), >=15 mapped to {IGNORE_INDEX} (ignored)", rank)
    print_rank0("Augmentations: RandomCrop (cat_max_ratio), RandomRotate, RandomFlip, PhotoMetricDistortion", rank)
    lr = BASE_LR
    print_rank0(f"Learning rate: {lr}", rank)

    # Datasets
    print_rank0(f"\n[{datetime.now()}] Setting up datasets...", rank)
    train_ds = FLAIRDataset(DATA_ROOT, 'train', CROP_SIZE)
    val_ds = FLAIRDataset(DATA_ROOT, 'valid', CROP_SIZE, augment=False)
    train_sampler = DistributedSampler(train_ds, num_replicas=world_size, rank=rank, shuffle=True) if world_size > 1 else None
    val_sampler = DistributedSampler(val_ds, num_replicas=world_size, rank=rank, shuffle=False) if world_size > 1 else None
    train_loader = DataLoader(train_ds, BATCH_SIZE_PER_GPU, shuffle=(train_sampler is None),
                              sampler=train_sampler, num_workers=NUM_WORKERS, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_ds, BATCH_SIZE_PER_GPU, shuffle=False,
                            sampler=val_sampler, num_workers=NUM_WORKERS, pin_memory=True)
    print_rank0(f"Train: {len(train_ds)} samples, {len(train_loader)} batches/GPU", rank)
    print_rank0(f"Val: {len(val_ds)} samples, {len(val_loader)} batches/GPU", rank)

    # Model
    print_rank0(f"\n[{datetime.now()}] Building model...", rank)
    model = ISDNet(NUM_CLASSES, 'resnet18', 128, DOWN_RATIO, stdc_pretrain=STDC_PRETRAIN_PATH).cuda()
    if world_size > 1:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = DDP(model, device_ids=[local_rank])
    print_rank0(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}", rank)
    optimizer = SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=WEIGHT_DECAY)
    # Poly decay over the total number of iterations; the scheduler is stepped
    # once per batch inside train_epoch, matching total_iters here.
    scheduler = PolynomialLR(optimizer, NUM_EPOCHS * len(train_loader), power=0.9)

    print_rank0(f"\n[{datetime.now()}] Starting training for {NUM_EPOCHS} epochs (FP32 mode)...", rank)
    print_rank0("=" * 60, rank)
    best_miou = 0
    for epoch in range(1, NUM_EPOCHS + 1):
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)
        print_rank0(f"\n[{datetime.now()}] Epoch {epoch}/{NUM_EPOCHS}", rank)
        train_loss = train_epoch(model, train_loader, optimizer, scheduler, epoch, rank, world_size)
        print_rank0(f"[{datetime.now()}] Train Loss: {train_loss:.4f}, LR: {optimizer.param_groups[0]['lr']:.6f}", rank)
        val_loss, val_miou, class_iou = validate(model, val_loader, NUM_CLASSES, rank, world_size)
        print_rank0(f"[{datetime.now()}] Val Loss: {val_loss:.4f}, Val mIoU: {val_miou:.4f}", rank)

        # Save only from rank 0
        if rank == 0:
            model_to_save = model.module if hasattr(model, 'module') else model
            if val_miou > best_miou:
                best_miou = val_miou
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model_to_save.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'miou': val_miou,
                    'class_iou': class_iou.cpu()
                }, 'isdnet_flair_best.pth')
                print(f"[{datetime.now()}] Saved best model with mIoU: {val_miou:.4f}")
            if epoch % SAVE_INTERVAL == 0:
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model_to_save.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'miou': val_miou
                }, f'isdnet_flair_epoch{epoch}.pth')
        if world_size > 1:
            dist.barrier()

    print_rank0(f"\n[{datetime.now()}] Training completed! Best mIoU: {best_miou:.4f}", rank)
    cleanup_distributed()


if __name__ == "__main__":
    main()
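
# Launch (from the module docstring): torchrun --nproc_per_node=4 train.py
# A plain `python train.py` single-GPU run should also work, assuming
# setup_distributed falls back to rank 0 / world_size 1 when the torchrun
# environment variables are absent (behavior of isdnet.utils, not verified here).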