"""Train a 5-grade diabetic-retinopathy (DR) classifier on EyePACS.

Builds train/val datasets for one fold via ``EyePACSDataset``, trains
``DeepSeeNet`` with (optionally class-weighted) cross-entropy, evaluates after
every epoch, and writes ``history.json`` plus ``best.pt``,
``best_model_only.pt``, optional ``epoch_XXX.pt`` and ``last.pt`` checkpoints
to ``--output-dir``.
"""

import argparse
import json
import random
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from augmentations import get_train_transforms, get_val_transforms
from dataloader import EyePACSDataset
from model import DeepSeeNet

# Number of DR grades (0 = no DR ... 4 = PDR; see ``id_to_label`` in save_checkpoint).
N_CLASSES = 5


class AlbumentationsTransform:
    """Adapt an Albumentations-style transform to a single-argument callable.

    The wrapped ``transform`` is invoked as ``transform(image=array)`` and the
    ``"image"`` entry of its result is returned.
    """

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, image):
        # np.asarray lets the dataset pass non-array images (presumably PIL
        # images from EyePACSDataset -- confirm).
        return self.transform(image=np.asarray(image))["image"]


def parse_args():
    """Parse the command-line options for a training run."""
    parser = argparse.ArgumentParser(description="Train EyePACS DR classifier.")
    parser.add_argument("--root", required=True, help="EyePACS root folder.")
    parser.add_argument("--output-dir", default="checkpoints/eyepacs_dr")
    parser.add_argument("--backbone", default="inception_v3")
    parser.add_argument("--image-size", type=int, default=1024)
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--num-workers", type=int, default=8)
    parser.add_argument("--fold", type=int, default=0)
    parser.add_argument("--n-folds", type=int, default=5)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--weight-decay", type=float, default=1e-4)
    parser.add_argument("--no-pretrained", action="store_true")
    parser.add_argument("--freeze-backbone", action="store_true")
    parser.add_argument("--no-class-weights", action="store_true")
    parser.add_argument("--scheduler", choices=["none", "cosine", "step"], default="cosine")
    parser.add_argument("--min-lr", type=float, default=1e-6)
    parser.add_argument("--step-size", type=int, default=5)
    parser.add_argument("--gamma", type=float, default=0.5)
    parser.add_argument("--amp", action="store_true")
    parser.add_argument("--grad-clip", type=float, default=0.0)
    parser.add_argument("--save-every", type=int, default=0)
    parser.add_argument("--wandb", action="store_true")
    parser.add_argument("--wandb-project", default="eyepacs-dr")
    return parser.parse_args()


def set_seed(seed):
    """Seed Python ``random``, NumPy, and torch (CPU and all CUDA devices)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def unwrap_logits(output):
    """Return the primary logits from a model output.

    If the model returns a tuple/list (presumably main + auxiliary heads, as
    Inception-style backbones do in train mode), only the first element is
    used; auxiliary outputs are ignored.
    """
    if isinstance(output, (tuple, list)):
        return output[0]
    return output


def get_class_weights(dataset, device):
    """Compute "balanced" inverse-frequency class weights on ``device``.

    weight[c] = n_samples / (N_CLASSES * count[c]), read from the ``"label"``
    key of each entry in ``dataset.samples``. Empty classes are counted as 1
    to avoid division by zero.
    """
    labels = torch.tensor([s["label"] for s in dataset.samples], dtype=torch.long)
    counts = torch.bincount(labels, minlength=N_CLASSES).clamp_min(1)
    weights = counts.sum() / (N_CLASSES * counts)
    return weights.to(device)


def build_scheduler(optimizer, args):
    """Return the LR scheduler selected by ``args.scheduler`` (or ``None``).

    ``main`` steps the scheduler once per epoch, so cosine annealing uses
    ``T_max=args.epochs`` and StepLR's ``step_size`` is in epochs.
    """
    if args.scheduler == "cosine":
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=args.epochs,
            eta_min=args.min_lr,
        )
    if args.scheduler == "step":
        return torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=args.step_size,
            gamma=args.gamma,
        )
    return None


def make_loader(dataset, batch_size, num_workers, shuffle, device):
    """Build a DataLoader.

    Shuffled (training) loaders drop the last incomplete batch. Memory is
    pinned only on CUDA, and workers persist across epochs when
    ``num_workers > 0``.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=device.type == "cuda",
        drop_last=shuffle,
        persistent_workers=num_workers > 0,
    )


def train_one_epoch(
    model,
    loader,
    optimizer,
    scaler,
    criterion,
    device,
    use_amp=True,
    grad_clip=0.0,
):
    """Run one training epoch.

    Args:
        model: Network to train (put into train mode).
        loader: Iterable of ``(images, labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        scaler: ``torch.amp.GradScaler`` for mixed-precision backward, or
            ``None`` for a plain backward pass.
        criterion: Loss applied to ``(logits, labels)``.
        device: Device the batches are moved to.
        use_amp: Enable CUDA autocast (ignored on non-CUDA devices).
        grad_clip: Max global gradient norm; ``<= 0`` disables clipping.

    Returns:
        ``{"loss": float, "acc": float}``: the mean per-batch loss value,
        weighted by batch size, and accuracy over all samples seen.

    Raises:
        RuntimeError: If ``loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    pbar = tqdm(loader, desc="Train", leave=False)
    for images, labels in pbar:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast("cuda", enabled=use_amp and device.type == "cuda"):
            logits = unwrap_logits(model(images))
            loss = criterion(logits, labels)

        if scaler is not None:
            scaler.scale(loss).backward()
            if grad_clip > 0:
                # Gradients must be unscaled before clipping to a true norm.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        total_correct += (logits.argmax(dim=1) == labels).sum().item()
        total_samples += batch_size

        pbar.set_postfix(
            loss=f"{total_loss / total_samples:.4f}",
            acc=f"{total_correct / total_samples:.4f}",
        )

    if total_samples == 0:
        # Training loaders use drop_last=True (make_loader), so a dataset
        # smaller than one batch produces no batches at all.
        raise RuntimeError(
            "Training loader yielded no batches; check dataset size vs. batch size."
        )

    return {
        "loss": total_loss / total_samples,
        "acc": total_correct / total_samples,
    }


def _sklearn_metrics(labels, probs, preds):
    """Compute QWK and the three binary AUCs, each independently.

    A metric is NaN when it cannot be computed: all of them if scikit-learn is
    unavailable, and an individual AUC when its binary split contains only
    one class (``roc_auc_score`` raises ``ValueError``). One undefined metric
    no longer wipes out the others.
    """
    nan = float("nan")
    metrics = {
        "qwk": nan,
        "referable_auc": nan,
        "any_dr_auc": nan,
        "severe_or_pdr_auc": nan,
    }
    try:
        from sklearn.metrics import cohen_kappa_score, roc_auc_score
    except ImportError:
        return metrics

    try:
        metrics["qwk"] = float(cohen_kappa_score(labels, preds, weights="quadratic"))
    except ValueError:
        pass  # leave NaN

    # Binary split "grade >= threshold", scored by the summed probability of
    # those grades.
    for name, threshold in (
        ("referable_auc", 2),
        ("any_dr_auc", 1),
        ("severe_or_pdr_auc", 3),
    ):
        try:
            metrics[name] = float(
                roc_auc_score(
                    (labels >= threshold).astype(int),
                    probs[:, threshold:].sum(axis=1),
                )
            )
        except ValueError:
            pass  # only one class present (or invalid scores): leave NaN
    return metrics


@torch.no_grad()
def evaluate(model, loader, criterion, device, use_amp=True):
    """Evaluate ``model`` on ``loader`` and return a dict of metrics.

    Keys:
        loss, acc: batch-size-weighted mean loss and 5-class accuracy.
        referable_acc, any_dr_acc, severe_or_pdr_acc: accuracy after
            binarising labels and predictions at grade >= 2, >= 1 and >= 3.
        qwk: quadratic-weighted Cohen's kappa of the 5-class predictions.
        referable_auc, any_dr_auc, severe_or_pdr_auc: ROC AUC of the same
            binary splits, scored by summed softmax probabilities.

    ``qwk`` and the AUCs are NaN when they cannot be computed (see
    ``_sklearn_metrics``).

    Raises:
        RuntimeError: If ``loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    all_labels = []
    all_probs = []
    all_preds = []

    pbar = tqdm(loader, desc="Val", leave=False)
    for images, labels in pbar:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        with torch.amp.autocast("cuda", enabled=use_amp and device.type == "cuda"):
            logits = unwrap_logits(model(images))
            loss = criterion(logits, labels)

        # Softmax in float32. Under autocast the logits may be float16, and
        # float16 probabilities are coarse enough to create AUC ties.
        probs = F.softmax(logits.float(), dim=1)
        preds = probs.argmax(dim=1)

        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        total_correct += (preds == labels).sum().item()
        total_samples += batch_size

        all_labels.append(labels.detach().cpu())
        all_probs.append(probs.detach().cpu())
        all_preds.append(preds.detach().cpu())

        pbar.set_postfix(
            loss=f"{total_loss / total_samples:.4f}",
            acc=f"{total_correct / total_samples:.4f}",
        )

    if total_samples == 0:
        raise RuntimeError("Validation loader yielded no batches.")

    labels = torch.cat(all_labels).numpy()
    probs = torch.cat(all_probs).numpy()
    preds = torch.cat(all_preds).numpy()

    metrics = {
        "loss": total_loss / total_samples,
        "acc": total_correct / total_samples,
        "referable_acc": float(((labels >= 2) == (preds >= 2)).mean()),
        "any_dr_acc": float(((labels >= 1) == (preds >= 1)).mean()),
        "severe_or_pdr_acc": float(((labels >= 3) == (preds >= 3)).mean()),
    }
    metrics.update(_sklearn_metrics(labels, probs, preds))
    return metrics


def save_checkpoint(path, model, optimizer, scheduler, epoch, best_metric, args, model_only=False):
    """Write a checkpoint dict to ``path``, creating parent directories.

    The dict always holds ``epoch``, ``model`` (state dict), ``best_metric``,
    ``args`` (as a plain dict) and ``id_to_label``. Unless ``model_only``, it
    also holds the optimizer state and, when a scheduler is given, its state.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)

    ckpt = {
        "epoch": epoch,
        "model": model.state_dict(),
        "best_metric": best_metric,
        "args": vars(args),
        "id_to_label": {
            0: "no_dr",
            1: "mild_npdr",
            2: "moderate_npdr",
            3: "severe_npdr",
            4: "pdr",
        },
    }

    if not model_only:
        ckpt["optimizer"] = optimizer.state_dict()
        if scheduler is not None:
            ckpt["scheduler"] = scheduler.state_dict()

    torch.save(ckpt, path)


def main():
    """Parse CLI args, train for ``--epochs`` epochs, and save history/checkpoints."""
    args = parse_args()
    set_seed(args.seed)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = args.amp and device.type == "cuda"

    # Both datasets are built from split="all"; all_mode plus seed/fold/n_folds
    # choose the train vs. val portion (presumably k-fold inside
    # EyePACSDataset -- confirm).
    train_dataset = EyePACSDataset(
        root=args.root,
        split="all",
        all_mode="train",
        transform=AlbumentationsTransform(get_train_transforms(args.image_size)),
        seed=args.seed,
        fold=args.fold,
        n_folds=args.n_folds,
    )
    val_dataset = EyePACSDataset(
        root=args.root,
        split="all",
        all_mode="val",
        transform=AlbumentationsTransform(get_val_transforms(args.image_size)),
        seed=args.seed,
        fold=args.fold,
        n_folds=args.n_folds,
    )

    train_loader = make_loader(
        train_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=True,
        device=device,
    )
    val_loader = make_loader(
        val_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        device=device,
    )

    model = DeepSeeNet(
        n_classes=N_CLASSES,
        backbone=args.backbone,
        pretrained=not args.no_pretrained,
        freeze_backbone=args.freeze_backbone,
    ).to(device)

    class_weights = None
    if not args.no_class_weights:
        class_weights = get_class_weights(train_dataset, device)

    # Class weights apply to the training loss only; validation loss is unweighted.
    train_criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
    val_criterion = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    scheduler = build_scheduler(optimizer, args)
    scaler = torch.amp.GradScaler("cuda") if use_amp else None

    wandb = None
    if args.wandb:
        import wandb

        wandb.init(project=args.wandb_project, config=vars(args))

    print("\nEyePACS DR training")
    print("-------------------")
    print(f"Device: {device}")
    print(f"Root: {args.root}")
    print(f"Output: {output_dir}")
    print(f"Backbone: {args.backbone}")
    print(f"Image size: {args.image_size}")
    print(f"Fold: {args.fold}/{args.n_folds}")
    print(f"Train samples: {len(train_dataset)}")
    print(f"Val samples: {len(val_dataset)}")
    print(f"AMP: {use_amp}")
    print(f"Pretrained: {not args.no_pretrained}")
    if class_weights is not None:
        print(f"Class weights: {class_weights.detach().cpu().tolist()}")

    best_qwk = -float("inf")
    history = []

    for epoch in range(1, args.epochs + 1):
        print(f"\nEpoch [{epoch:03d}/{args.epochs}]")

        train_metrics = train_one_epoch(
            model=model,
            loader=train_loader,
            optimizer=optimizer,
            scaler=scaler,
            criterion=train_criterion,
            device=device,
            use_amp=use_amp,
            grad_clip=args.grad_clip,
        )

        val_metrics = evaluate(
            model=model,
            loader=val_loader,
            criterion=val_criterion,
            device=device,
            use_amp=use_amp,
        )

        # LR actually used this epoch (the scheduler is stepped at the end).
        lr = optimizer.param_groups[0]["lr"]

        row = {
            "epoch": epoch,
            "lr": lr,
            **{f"train_{k}": v for k, v in train_metrics.items()},
            **{f"val_{k}": v for k, v in val_metrics.items()},
        }
        history.append(row)

        print(
            f"lr={lr:.2e} "
            f"train_loss={train_metrics['loss']:.4f} "
            f"train_acc={train_metrics['acc']:.4f} "
            f"val_loss={val_metrics['loss']:.4f} "
            f"val_acc={val_metrics['acc']:.4f} "
            f"val_qwk={val_metrics['qwk']:.4f} "
            f"val_ref_auc={val_metrics['referable_auc']:.4f}"
        )

        if wandb is not None:
            wandb.log(row)

        # Rewritten in full every epoch, so it is current even if training stops early.
        with (output_dir / "history.json").open("w") as f:
            json.dump(history, f, indent=2)

        # Checkpoint selection uses validation QWK. If QWK is NaN, it falls
        # back to negative validation loss.
        # NOTE(review): the fallback is on a different scale than QWK, so
        # comparisons across epochs that mix the two are not like-for-like.
        monitor = val_metrics["qwk"]
        if np.isnan(monitor):
            monitor = -val_metrics["loss"]

        if monitor > best_qwk:
            best_qwk = monitor
            save_checkpoint(
                output_dir / "best.pt",
                model=model,
                optimizer=optimizer,
                scheduler=scheduler,
                epoch=epoch,
                best_metric=best_qwk,
                args=args,
                model_only=False,
            )
            save_checkpoint(
                output_dir / "best_model_only.pt",
                model=model,
                optimizer=optimizer,
                scheduler=scheduler,
                epoch=epoch,
                best_metric=best_qwk,
                args=args,
                model_only=True,
            )
            print(f"Saved best checkpoint: monitor={best_qwk:.4f}")

        if args.save_every > 0 and epoch % args.save_every == 0:
            save_checkpoint(
                output_dir / f"epoch_{epoch:03d}.pt",
                model=model,
                optimizer=optimizer,
                scheduler=scheduler,
                epoch=epoch,
                best_metric=best_qwk,
                args=args,
                model_only=False,
            )

        # Scheduler is stepped once per epoch (see build_scheduler).
        if scheduler is not None:
            scheduler.step()

    save_checkpoint(
        output_dir / "last.pt",
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        epoch=args.epochs,
        best_metric=best_qwk,
        args=args,
        model_only=False,
    )

    print("\nTraining complete.")
    print(f"Best monitor/QWK: {best_qwk:.4f}")
    print(f"Saved to: {output_dir}")


if __name__ == "__main__":
    main()