"""Train a retinal vessel segmentation model on the FIVES dataset.

The script builds FIVES train/validation loaders, trains the model selected
with ``--model`` using ``BCEDiceLoss`` and AdamW (optionally with CUDA AMP),
and writes checkpoints under ``--output-dir``.
"""

import argparse
from pathlib import Path

from tqdm import tqdm

import torch
from torch.utils.data import DataLoader

from augmentations import get_train_transforms, get_val_transforms
from datasets.FIVES import FIVESDataset
from models import build_model
from losses import BCEDiceLoss, compute_dice_score


def _epoch_means(total_loss, total_dice, num_batches, phase):
    """Return ``(mean_loss, mean_dice)`` over ``num_batches`` batches.

    Raises:
        ValueError: If no batches were processed (empty loader), instead of an
            opaque ZeroDivisionError.
    """
    if num_batches == 0:
        raise ValueError(
            f"{phase} loader produced no batches; check the dataset split and batch size."
        )
    return total_loss / num_batches, total_dice / num_batches


def train_one_epoch(model, loader, optimizer, scaler, criterion, device, use_amp=True):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Segmentation network; put into training mode here.
        loader: Iterable of batches; each batch is indexed with ``"image"``
            and ``"label"`` and both entries are moved to ``device``.
        optimizer: Optimizer stepping the model parameters.
        scaler: ``torch.amp.GradScaler`` used for the backward pass and step.
        criterion: Loss callable invoked as ``criterion(logits, labels)``.
        device: Target ``torch.device`` for the batch tensors.
        use_amp: Enable CUDA autocast; ignored when ``device`` is not CUDA.

    Returns:
        ``(mean_loss, mean_dice)``: per-batch loss and Dice averaged over the
        batches of the epoch (each batch weighted equally).

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.train()
    amp_enabled = use_amp and device.type == "cuda"

    running_loss = 0.0
    running_dice = 0.0
    num_batches = 0

    pbar = tqdm(loader, desc="Train", leave=False)
    for num_batches, batch in enumerate(pbar, start=1):
        images = batch["image"].to(device, non_blocking=True)
        labels = batch["label"].to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast("cuda", enabled=amp_enabled):
            logits = model(images)
            loss = criterion(logits, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        dice = compute_dice_score(logits.detach(), labels)

        running_loss += loss.item()
        # float() accepts a Python number or a one-element tensor and keeps the
        # accumulator (and the returned metric) a plain Python float.
        running_dice += float(dice)

        # Use our own batch counter: tqdm only writes its iteration count back
        # to ``pbar.n`` when it redraws, so ``pbar.n + 1`` can lag behind the
        # number of batches actually processed.
        pbar.set_postfix(
            loss=f"{running_loss / num_batches:.4f}",
            dice=f"{running_dice / num_batches:.4f}",
        )

    return _epoch_means(running_loss, running_dice, num_batches, "Train")


@torch.no_grad()
def validate(model, loader, criterion, device, use_amp=True):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Segmentation network; put into eval mode here.
        loader: Iterable of batches indexed with ``"image"`` and ``"label"``.
        criterion: Loss callable invoked as ``criterion(logits, labels)``.
        device: Target ``torch.device`` for the batch tensors.
        use_amp: Enable CUDA autocast; ignored when ``device`` is not CUDA.

    Returns:
        ``(mean_loss, mean_dice)`` averaged over batches (each batch weighted
        equally).

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()
    amp_enabled = use_amp and device.type == "cuda"

    running_loss = 0.0
    running_dice = 0.0
    num_batches = 0

    pbar = tqdm(loader, desc="Val", leave=False)
    for num_batches, batch in enumerate(pbar, start=1):
        images = batch["image"].to(device, non_blocking=True)
        labels = batch["label"].to(device, non_blocking=True)

        with torch.amp.autocast("cuda", enabled=amp_enabled):
            logits = model(images)
            loss = criterion(logits, labels)

        dice = compute_dice_score(logits, labels)

        running_loss += loss.item()
        running_dice += float(dice)

        # Own counter for the same reason as in train_one_epoch: ``pbar.n``
        # is only refreshed when tqdm redraws.
        pbar.set_postfix(
            loss=f"{running_loss / num_batches:.4f}",
            dice=f"{running_dice / num_batches:.4f}",
        )

    return _epoch_means(running_loss, running_dice, num_batches, "Val")


def save_checkpoint(path, model, optimizer, epoch, best_dice, args):
    """Serialize model/optimizer state plus run metadata to ``path``.

    Parent directories are created as needed. The saved dict holds ``epoch``,
    ``model_state_dict``, ``optimizer_state_dict``, ``best_dice`` and the
    parsed CLI arguments (``vars(args)``) under ``args``.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)

    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "best_dice": best_dice,
            "args": vars(args),
        },
        path,
    )


def _build_loaders(args, device):
    """Create the FIVES datasets and their DataLoaders.

    Returns:
        ``(train_dataset, val_dataset, train_loader, val_loader)``.
    """
    train_dataset = FIVESDataset(
        root=args.data_root,
        split="train",
        transform=get_train_transforms(image_size=args.image_size),
    )
    # NOTE(review): validation -- and therefore best-checkpoint selection --
    # uses the "test" split, so the reported best val Dice is not an unbiased
    # held-out estimate.
    val_dataset = FIVESDataset(
        root=args.data_root,
        split="test",
        transform=get_val_transforms(image_size=args.image_size),
    )

    # Pinned host memory only benefits host-to-GPU copies; on CPU-only runs
    # PyTorch ignores the request and emits a warning.
    pin_memory = device.type == "cuda"

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=pin_memory,
    )
    return train_dataset, val_dataset, train_loader, val_loader


def main(args):
    """Train, validate and checkpoint the configured model.

    Under ``args.output_dir`` this writes ``best.pt`` whenever validation Dice
    improves, ``epoch_XXX.pt`` every ``args.save_every`` epochs (skipped when
    ``save_every <= 0``), and ``last.pt`` after the final epoch.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    output_dir = Path(args.output_dir)

    train_dataset, val_dataset, train_loader, val_loader = _build_loaders(args, device)

    model = build_model(
        model_name=args.model,
        num_classes=1,
        in_channels=3,
        image_size=args.image_size,
        backbone=args.backbone,
        pretrained=not args.no_pretrained,
        base_channels=args.base_channels,
        dropout=args.dropout,
    ).to(device)

    criterion = BCEDiceLoss(
        bce_weight=args.bce_weight,
        dice_weight=args.dice_weight,
    )

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
    )

    # Per the torch.amp docs, a disabled GradScaler makes scale/step/update
    # pass-throughs, so the same loop serves both AMP and full-precision runs.
    scaler = torch.amp.GradScaler(enabled=args.amp and device.type == "cuda")

    best_dice = -1.0

    print(f"Device: {device}")
    print(f"Train samples: {len(train_dataset)}")
    print(f"Val samples: {len(val_dataset)}")
    print(f"Image size: {args.image_size}")
    print(f"Batch size: {args.batch_size}")
    print(f"Pretrained: {not args.no_pretrained}")

    for epoch in range(1, args.epochs + 1):
        print(f"\nEpoch [{epoch:03d}/{args.epochs}]")

        train_loss, train_dice = train_one_epoch(
            model=model,
            loader=train_loader,
            optimizer=optimizer,
            scaler=scaler,
            criterion=criterion,
            device=device,
            use_amp=args.amp,
        )

        val_loss, val_dice = validate(
            model=model,
            loader=val_loader,
            criterion=criterion,
            device=device,
            use_amp=args.amp,
        )

        print(
            f"train_loss={train_loss:.4f} "
            f"train_dice={train_dice:.4f} "
            f"val_loss={val_loss:.4f} "
            f"val_dice={val_dice:.4f}"
        )

        if val_dice > best_dice:
            best_dice = val_dice
            save_checkpoint(
                output_dir / "best.pt",
                model,
                optimizer,
                epoch,
                best_dice,
                args,
            )
            print(f"Saved best checkpoint: val_dice={best_dice:.4f}")

        # Guard against --save-every 0 (or negative), which would otherwise
        # raise ZeroDivisionError in the modulo below.
        if args.save_every > 0 and epoch % args.save_every == 0:
            save_checkpoint(
                output_dir / f"epoch_{epoch:03d}.pt",
                model,
                optimizer,
                epoch,
                best_dice,
                args,
            )

    save_checkpoint(
        output_dir / "last.pt",
        model,
        optimizer,
        args.epochs,
        best_dice,
        args,
    )

    print("Training complete.")
    print(f"Best val Dice: {best_dice:.4f}")


def parse_args(argv=None):
    """Parse command-line options.

    Args:
        argv: Optional list of argument strings; ``None`` (the default) reads
            ``sys.argv[1:]`` as before, which also makes this callable from tests.
    """
    parser = argparse.ArgumentParser(description="Train retinal vessel segmentation model on FIVES.")

    parser.add_argument("--data-root", type=str, required=True)
    parser.add_argument("--output-dir", type=str, default="checkpoints/fives")

    parser.add_argument("--image-size", type=int, default=512)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--num-workers", type=int, default=4)

    parser.add_argument("--model", type=str, default="resunet", choices=["resunet", "deeplabv3", "vit"])
    parser.add_argument("--backbone", type=str, default="resnet50")
    parser.add_argument("--base-channels", type=int, default=32)
    parser.add_argument("--dropout", type=float, default=0.0)
    parser.add_argument("--no-pretrained", action="store_true")

    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--weight-decay", type=float, default=1e-4)

    parser.add_argument("--bce-weight", type=float, default=1.0)
    parser.add_argument("--dice-weight", type=float, default=1.0)

    parser.add_argument("--save-every", type=int, default=25)
    parser.add_argument("--amp", action="store_true")

    return parser.parse_args(argv)


if __name__ == "__main__":
    args = parse_args()
    main(args)