Spaces:
Running
Running
| import argparse | |
| from pathlib import Path | |
| from tqdm import tqdm | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from augmentations import get_train_transforms, get_val_transforms | |
| from datasets.FIVES import FIVESDataset | |
| from models import build_model | |
| from losses import BCEDiceLoss, compute_dice_score | |
def train_one_epoch(model, loader, optimizer, scaler, criterion, device, use_amp=True):
    """Run one training epoch.

    Args:
        model: segmentation network; called as ``model(images)`` -> logits.
        loader: iterable of batches shaped like ``{"image": ..., "label": ...}``.
        optimizer: optimizer whose step is mediated by ``scaler``.
        scaler: ``torch.amp.GradScaler`` (a no-op when disabled).
        criterion: loss taking ``(logits, labels)``.
        device: target ``torch.device``.
        use_amp: enable autocast (only effective on CUDA devices).

    Returns:
        Tuple ``(mean_loss, mean_dice)`` averaged over the number of batches.
    """
    model.train()
    running_loss = 0.0
    running_dice = 0.0
    pbar = tqdm(loader, desc="Train", leave=False)
    # Count steps ourselves: tqdm only syncs ``pbar.n`` to the true iteration
    # count at display intervals (miniters), so dividing by ``pbar.n + 1``
    # can report an incorrect running average.
    for step, batch in enumerate(pbar, start=1):
        images = batch["image"].to(device)
        labels = batch["label"].to(device)
        optimizer.zero_grad(set_to_none=True)
        # Autocast only when AMP is requested and we are actually on CUDA.
        with torch.amp.autocast("cuda", enabled=use_amp and device.type == "cuda"):
            logits = model(images)
            loss = criterion(logits, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # detach(): the metric must not keep the autograd graph alive.
        dice = compute_dice_score(logits.detach(), labels)
        running_loss += loss.item()
        running_dice += dice
        pbar.set_postfix(
            loss=f"{running_loss / step:.4f}",
            dice=f"{running_dice / step:.4f}",
        )
    return running_loss / len(loader), running_dice / len(loader)
@torch.no_grad()
def validate(model, loader, criterion, device, use_amp=True):
    """Evaluate the model over ``loader`` without updating weights.

    Runs under ``torch.no_grad()`` so no autograd graph is built — the
    original built (and kept) graphs for every validation batch, wasting
    memory for no benefit.

    Args:
        model: segmentation network; called as ``model(images)`` -> logits.
        loader: iterable of batches shaped like ``{"image": ..., "label": ...}``.
        criterion: loss taking ``(logits, labels)``.
        device: target ``torch.device``.
        use_amp: enable autocast (only effective on CUDA devices).

    Returns:
        Tuple ``(mean_loss, mean_dice)`` averaged over the number of batches.
    """
    model.eval()
    running_loss = 0.0
    running_dice = 0.0
    pbar = tqdm(loader, desc="Val", leave=False)
    # Explicit step counter: ``pbar.n`` lags the true iteration count
    # between tqdm display refreshes.
    for step, batch in enumerate(pbar, start=1):
        images = batch["image"].to(device)
        labels = batch["label"].to(device)
        with torch.amp.autocast("cuda", enabled=use_amp and device.type == "cuda"):
            logits = model(images)
            loss = criterion(logits, labels)
        dice = compute_dice_score(logits, labels)
        running_loss += loss.item()
        running_dice += dice
        pbar.set_postfix(
            loss=f"{running_loss / step:.4f}",
            dice=f"{running_dice / step:.4f}",
        )
    return running_loss / len(loader), running_dice / len(loader)
def save_checkpoint(path, model, optimizer, epoch, best_dice, args):
    """Serialize full training state to ``path``, creating parent dirs.

    The checkpoint stores the epoch number, model and optimizer state
    dicts, the best validation Dice so far, and the CLI args as a plain
    dict (via ``vars``) so it can be reloaded without argparse.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    state = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "best_dice": best_dice,
        "args": vars(args),
    }
    torch.save(state, target)
def main(args):
    """Train a retinal vessel segmentation model on the FIVES dataset.

    Builds datasets and loaders, constructs the model, loss, optimizer and
    AMP scaler, then runs the epoch loop — checkpointing the best validation
    Dice, periodic snapshots, and the final state.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): the FIVES "test" split is used for validation here —
    # confirm this matches the intended evaluation protocol.
    train_ds = FIVESDataset(
        root=args.data_root,
        split="train",
        transform=get_train_transforms(image_size=args.image_size),
    )
    val_ds = FIVESDataset(
        root=args.data_root,
        split="test",
        transform=get_val_transforms(image_size=args.image_size),
    )

    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    model = build_model(
        model_name=args.model,
        num_classes=1,
        in_channels=3,
        image_size=args.image_size,
        backbone=args.backbone,
        pretrained=not args.no_pretrained,
        base_channels=args.base_channels,
        dropout=args.dropout,
    ).to(device)

    criterion = BCEDiceLoss(bce_weight=args.bce_weight, dice_weight=args.dice_weight)
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )
    # The scaler is a no-op when AMP is off or the device is CPU.
    scaler = torch.amp.GradScaler(enabled=args.amp and device.type == "cuda")

    best_dice = -1.0
    print(f"Device: {device}")
    print(f"Train samples: {len(train_ds)}")
    print(f"Val samples: {len(val_ds)}")
    print(f"Image size: {args.image_size}")
    print(f"Batch size: {args.batch_size}")
    print(f"Pretrained: {not args.no_pretrained}")

    for epoch in range(1, args.epochs + 1):
        print(f"\nEpoch [{epoch:03d}/{args.epochs}]")
        train_loss, train_dice = train_one_epoch(
            model=model,
            loader=train_loader,
            optimizer=optimizer,
            scaler=scaler,
            criterion=criterion,
            device=device,
            use_amp=args.amp,
        )
        val_loss, val_dice = validate(
            model=model,
            loader=val_loader,
            criterion=criterion,
            device=device,
            use_amp=args.amp,
        )
        print(
            f"train_loss={train_loss:.4f} "
            f"train_dice={train_dice:.4f} "
            f"val_loss={val_loss:.4f} "
            f"val_dice={val_dice:.4f}"
        )

        # Keep the checkpoint with the highest validation Dice so far.
        if val_dice > best_dice:
            best_dice = val_dice
            save_checkpoint(
                Path(args.output_dir) / "best.pt",
                model, optimizer, epoch, best_dice, args,
            )
            print(f"Saved best checkpoint: val_dice={best_dice:.4f}")

        # Periodic snapshot for resuming or later inspection.
        if epoch % args.save_every == 0:
            save_checkpoint(
                Path(args.output_dir) / f"epoch_{epoch:03d}.pt",
                model, optimizer, epoch, best_dice, args,
            )

    # Always persist the final state.
    save_checkpoint(
        Path(args.output_dir) / "last.pt",
        model, optimizer, args.epochs, best_dice, args,
    )
    print("Training complete.")
    print(f"Best val Dice: {best_dice:.4f}")
def parse_args():
    """Build and parse the command-line interface for FIVES training."""
    p = argparse.ArgumentParser(
        description="Train retinal vessel segmentation model on FIVES."
    )
    # Data and output locations.
    p.add_argument("--data-root", type=str, required=True)
    p.add_argument("--output-dir", type=str, default="checkpoints/fives")
    p.add_argument("--image-size", type=int, default=512)
    # Training schedule.
    p.add_argument("--epochs", type=int, default=100)
    p.add_argument("--batch-size", type=int, default=4)
    p.add_argument("--num-workers", type=int, default=4)
    # Model architecture.
    p.add_argument("--model", type=str, default="resunet",
                   choices=["resunet", "deeplabv3", "vit"])
    p.add_argument("--backbone", type=str, default="resnet50")
    p.add_argument("--base-channels", type=int, default=32)
    p.add_argument("--dropout", type=float, default=0.0)
    p.add_argument("--no-pretrained", action="store_true")
    # Optimization and loss weighting.
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--weight-decay", type=float, default=1e-4)
    p.add_argument("--bce-weight", type=float, default=1.0)
    p.add_argument("--dice-weight", type=float, default=1.0)
    # Checkpointing and mixed precision.
    p.add_argument("--save-every", type=int, default=25)
    p.add_argument("--amp", action="store_true")
    return p.parse_args()
if __name__ == "__main__":
    # Parse CLI arguments and kick off training.
    main(parse_args())