Spaces:
Sleeping
Sleeping
| import argparse | |
| import json | |
| import random | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| from augmentations import get_train_transforms, get_val_transforms | |
| from dataloader import EyePACSDataset | |
| from model import DeepSeeNet | |
# Number of diabetic-retinopathy grades the classifier predicts (labels 0..4,
# matching the id_to_label mapping that save_checkpoint writes).
N_CLASSES: int = 5
class AlbumentationsTransform:
    """Adapt an Albumentations-style transform to a plain ``f(image)`` callable.

    The wrapped transform is invoked as ``transform(image=<ndarray>)`` and must
    return a mapping whose ``"image"`` entry holds the transformed result.
    """

    def __init__(self, transform):
        # Kept as-is; invoked on every call.
        self.transform = transform

    def __call__(self, image):
        # Array-like inputs (e.g. a PIL image) are converted to an ndarray;
        # np.asarray does not copy when given an ndarray already.
        array = np.asarray(image)
        outputs = self.transform(image=array)
        return outputs["image"]
def parse_args():
    """Define the EyePACS DR training CLI and parse ``sys.argv``.

    Returns:
        argparse.Namespace with one attribute per option; dashes in flag names
        become underscores (e.g. ``--output-dir`` -> ``args.output_dir``).
        Options are registered in a fixed order, which is also the key order
        of ``vars(args)``.
    """
    parser = argparse.ArgumentParser(description="Train EyePACS DR classifier.")
    add = parser.add_argument

    # Data location and outputs.
    add("--root", required=True, help="EyePACS root folder.")
    add("--output-dir", default="checkpoints/eyepacs_dr")

    # Model and run sizing.
    add("--backbone", default="inception_v3")
    integer_options = (
        ("--image-size", 1024),
        ("--epochs", 20),
        ("--batch-size", 32),
        ("--num-workers", 8),
        ("--fold", 0),
        ("--n-folds", 5),
        ("--seed", 42),
    )
    for flag, default in integer_options:
        add(flag, type=int, default=default)

    # Optimizer.
    for flag, default in (("--lr", 1e-4), ("--weight-decay", 1e-4)):
        add(flag, type=float, default=default)

    # Boolean switches controlling model setup and loss weighting.
    for flag in ("--no-pretrained", "--freeze-backbone", "--no-class-weights"):
        add(flag, action="store_true")

    # Learning-rate schedule.
    add("--scheduler", choices=["none", "cosine", "step"], default="cosine")
    add("--min-lr", type=float, default=1e-6)
    add("--step-size", type=int, default=5)
    add("--gamma", type=float, default=0.5)

    # Training mechanics, checkpointing and logging.
    add("--amp", action="store_true")
    add("--grad-clip", type=float, default=0.0)
    add("--save-every", type=int, default=0)
    add("--wandb", action="store_true")
    add("--wandb-project", default="eyepacs-dr")

    return parser.parse_args()
def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU and all CUDA devices).

    cuDNN determinism flags are not touched, so GPU runs may still differ
    slightly between executions.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
def unwrap_logits(output):
    """Return the primary logits from a model's forward output.

    A tuple or list (including namedtuples) is reduced to its first element;
    any other value is returned unchanged. Presumably this handles backbones
    such as Inception v3 whose training-mode forward returns
    ``(logits, aux_logits)`` -- confirm against DeepSeeNet.
    """
    is_sequence = isinstance(output, (tuple, list))
    return output[0] if is_sequence else output
def get_class_weights(dataset, device, n_classes=None):
    """Compute inverse-frequency ("balanced") class weights for CrossEntropyLoss.

    ``weight[c] = total / (n_classes * count[c])``, so a perfectly balanced
    dataset yields all-ones weights. A class with no samples is counted as
    having one, which avoids division by zero and gives it the largest weight.

    Args:
        dataset: object exposing ``samples``, an iterable of dicts each with an
            integer ``"label"`` entry.
        device: torch device (or device string) for the returned tensor.
        n_classes: number of classes; defaults to the module-level ``N_CLASSES``.

    Returns:
        1-D floating-point tensor of length ``n_classes`` on ``device``.

    Raises:
        ValueError: if any label lies outside ``[0, n_classes)``. Without this
            check, such labels either make ``torch.bincount`` fail (negative
            labels) or produce a weight vector longer than ``n_classes``, which
            only fails later inside the loss with a less helpful message.
    """
    if n_classes is None:
        n_classes = N_CLASSES
    labels = torch.tensor([s["label"] for s in dataset.samples], dtype=torch.long)
    if labels.numel() > 0:
        lo, hi = int(labels.min()), int(labels.max())
        if lo < 0 or hi >= n_classes:
            raise ValueError(
                f"Class labels must lie in [0, {n_classes}); got range [{lo}, {hi}]."
            )
    counts = torch.bincount(labels, minlength=n_classes).clamp_min(1)
    # Integer / integer tensor division is true division -> floating weights.
    weights = counts.sum() / (n_classes * counts)
    return weights.to(device)
def build_scheduler(optimizer, args):
    """Create the learning-rate scheduler selected by ``args.scheduler``.

    * ``"step"``   -> ``StepLR``: multiply the LR by ``args.gamma`` every
      ``args.step_size`` scheduler steps.
    * ``"cosine"`` -> ``CosineAnnealingLR``: anneal towards ``args.min_lr`` over
      ``args.epochs`` scheduler steps.
    * anything else (``"none"``) -> ``None``.
    """
    lr_schedulers = torch.optim.lr_scheduler
    kind = args.scheduler
    if kind == "step":
        return lr_schedulers.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    if kind == "cosine":
        return lr_schedulers.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=args.min_lr)
    return None
def make_loader(dataset, batch_size, num_workers, shuffle, device):
    """Wrap ``dataset`` in a DataLoader configured for training or evaluation.

    ``shuffle`` doubles as the training switch: when True, batches are shuffled
    and the trailing partial batch is dropped. Host memory is pinned only when
    ``device`` is CUDA, and worker processes persist across epochs whenever any
    workers are used.
    """
    on_cuda = device.type == "cuda"
    uses_workers = num_workers > 0
    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=on_cuda,
        drop_last=shuffle,
        persistent_workers=uses_workers,
    )
    return DataLoader(dataset, **loader_kwargs)
def train_one_epoch(
    model,
    loader,
    optimizer,
    scaler,
    criterion,
    device,
    use_amp=True,
    grad_clip=0.0,
):
    """Run one optimization pass over ``loader`` and return running metrics.

    Args:
        model: network whose forward output is reduced with ``unwrap_logits``.
        loader: iterable of ``(images, labels)`` batches.
        optimizer: optimizer stepped once per batch.
        scaler: gradient scaler for mixed precision, or ``None`` for a plain
            ``loss.backward()`` / ``optimizer.step()``.
        criterion: loss called as ``criterion(logits, labels)``.
        device: target device; autocast is only enabled when it is CUDA.
        use_amp: request autocast (effective on CUDA only).
        grad_clip: maximum global gradient norm; ``<= 0`` disables clipping.

    Returns:
        ``{"loss": mean_loss, "acc": top1_accuracy}``, both weighted by batch
        size. When ``loader`` yields no batches (e.g. a dataset smaller than one
        batch with ``drop_last=True``), both are ``nan`` instead of raising
        ``ZeroDivisionError``.
    """
    model.train()
    amp_enabled = use_amp and device.type == "cuda"
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    pbar = tqdm(loader, desc="Train", leave=False)
    for images, labels in pbar:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast("cuda", enabled=amp_enabled):
            logits = unwrap_logits(model(images))
            loss = criterion(logits, labels)

        if scaler is not None:
            scaler.scale(loss).backward()
            if grad_clip > 0:
                # Unscale first so the clipping threshold applies to the true
                # gradient magnitudes rather than the loss-scaled ones.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        total_correct += (logits.argmax(dim=1) == labels).sum().item()
        total_samples += batch_size
        pbar.set_postfix(
            loss=f"{total_loss / total_samples:.4f}",
            acc=f"{total_correct / total_samples:.4f}",
        )

    if total_samples == 0:
        # No batches were produced; report undefined metrics instead of crashing.
        return {"loss": float("nan"), "acc": float("nan")}
    return {
        "loss": total_loss / total_samples,
        "acc": total_correct / total_samples,
    }
def _score_or_nan(score_fn, *args, **kwargs):
    """Return ``float(score_fn(*args, **kwargs))``, or nan when the score is undefined.

    scikit-learn raises ``ValueError`` for undefined scores, e.g. ``roc_auc_score``
    when the binarised targets contain a single class. Catching it per metric
    keeps the remaining metrics intact.
    """
    try:
        return float(score_fn(*args, **kwargs))
    except ValueError:
        return float("nan")


def evaluate(model, loader, criterion, device, use_amp=True):
    """Evaluate ``model`` on ``loader`` and return DR grading metrics.

    Runs under ``torch.no_grad()``; autocast is only enabled on CUDA.

    Returns:
        dict with, in this order:
            ``loss``, ``acc``: mean ``criterion`` value and top-1 accuracy.
            ``referable_acc``, ``any_dr_acc``, ``severe_or_pdr_acc``: binary
                accuracy after thresholding labels and predictions at
                grade >= 2, >= 1 and >= 3 respectively.
            ``qwk``: quadratic-weighted Cohen's kappa of predicted grades.
            ``referable_auc``, ``any_dr_auc``, ``severe_or_pdr_auc``: ROC AUC of
                the softmax mass at or above the same thresholds.
        The kappa/AUC entries are nan when scikit-learn is not installed, and
        individually nan when a score is undefined (e.g. a validation split
        with no positives for one threshold). If ``loader`` yields no batches,
        every entry is nan.
    """
    model.eval()
    amp_enabled = use_amp and device.type == "cuda"
    nan = float("nan")
    # Binary endpoints: (metric-name prefix, minimum grade counted as positive).
    thresholds = (("referable", 2), ("any_dr", 1), ("severe_or_pdr", 3))

    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    all_labels = []
    all_probs = []
    all_preds = []

    pbar = tqdm(loader, desc="Val", leave=False)
    # Nothing is backpropagated here; without no_grad() every forward pass
    # would still record an autograd graph, wasting activation memory and time.
    with torch.no_grad():
        for images, labels in pbar:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            with torch.amp.autocast("cuda", enabled=amp_enabled):
                logits = unwrap_logits(model(images))
                loss = criterion(logits, labels)
            # Softmax in fp32 so the AUC scores are not rounded to fp16 under AMP.
            probs = F.softmax(logits.float(), dim=1)
            preds = probs.argmax(dim=1)

            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            total_correct += (preds == labels).sum().item()
            total_samples += batch_size
            all_labels.append(labels.cpu())
            all_probs.append(probs.cpu())
            all_preds.append(preds.cpu())
            pbar.set_postfix(
                loss=f"{total_loss / total_samples:.4f}",
                acc=f"{total_correct / total_samples:.4f}",
            )

    if total_samples == 0:
        # Empty loader: torch.cat([]) would raise, so report undefined metrics.
        keys = ["loss", "acc"]
        keys += [f"{name}_acc" for name, _ in thresholds]
        keys += ["qwk"]
        keys += [f"{name}_auc" for name, _ in thresholds]
        return dict.fromkeys(keys, nan)

    y_true = torch.cat(all_labels).numpy()
    y_prob = torch.cat(all_probs).numpy()
    y_pred = torch.cat(all_preds).numpy()

    metrics = {
        "loss": total_loss / total_samples,
        "acc": total_correct / total_samples,
    }
    for name, t in thresholds:
        metrics[f"{name}_acc"] = float(((y_true >= t) == (y_pred >= t)).mean())

    try:
        from sklearn.metrics import cohen_kappa_score, roc_auc_score
    except ImportError:
        # scikit-learn is optional; its metrics are reported as nan.
        metrics["qwk"] = nan
        for name, _ in thresholds:
            metrics[f"{name}_auc"] = nan
        return metrics

    metrics["qwk"] = _score_or_nan(cohen_kappa_score, y_true, y_pred, weights="quadratic")
    for name, t in thresholds:
        # Positive-class score = probability mass on grades >= t.
        metrics[f"{name}_auc"] = _score_or_nan(
            roc_auc_score,
            (y_true >= t).astype(int),
            y_prob[:, t:].sum(axis=1),
        )
    return metrics
def save_checkpoint(path, model, optimizer, scheduler, epoch, best_metric, args, model_only=False):
    """Serialize model weights and training metadata to ``path`` via ``torch.save``.

    The checkpoint always contains ``epoch``, ``model`` (the state_dict),
    ``best_metric``, ``args`` (``vars(args)``) and the ``id_to_label`` grade
    mapping. Unless ``model_only`` is set, the optimizer state, and the
    scheduler state when ``scheduler`` is not ``None``, are included so
    training can be resumed.

    The data is first written to a sibling ``<name>.tmp`` file and then moved
    over ``path``. An interruption mid-write (pre-emption, OOM kill, full disk)
    therefore leaves any existing checkpoint at ``path`` intact instead of
    truncating it. Parent directories are created as needed.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    ckpt = {
        "epoch": epoch,
        "model": model.state_dict(),
        "best_metric": best_metric,
        "args": vars(args),
        "id_to_label": {
            0: "no_dr",
            1: "mild_npdr",
            2: "moderate_npdr",
            3: "severe_npdr",
            4: "pdr",
        },
    }
    if not model_only:
        ckpt["optimizer"] = optimizer.state_dict()
        if scheduler is not None:
            ckpt["scheduler"] = scheduler.state_dict()

    tmp_path = path.with_name(path.name + ".tmp")
    try:
        torch.save(ckpt, tmp_path)
        # Path.replace overwrites an existing target (os.replace semantics).
        tmp_path.replace(path)
    except BaseException:
        # Don't leave a half-written temp file behind; the old checkpoint survives.
        tmp_path.unlink(missing_ok=True)
        raise
def main() -> None:
    """Train DeepSeeNet on one EyePACS cross-validation fold.

    Parses CLI arguments, builds train/val datasets and loaders, then trains for
    ``args.epochs`` epochs. Writes the following into ``args.output_dir``:
      * ``history.json`` -- per-epoch metrics, rewritten after every epoch;
      * ``best.pt`` / ``best_model_only.pt`` -- whenever the monitored metric improves;
      * ``epoch_XXX.pt`` -- every ``--save-every`` epochs (if > 0);
      * ``last.pt`` -- after the final epoch.
    """
    args = parse_args()
    # Seed the RNGs before any dataset or model is constructed.
    set_seed(args.seed)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Mixed precision is only used on CUDA, even if --amp was given.
    use_amp = args.amp and device.type == "cuda"
    # Both datasets read the same root with split="all"; all_mode selects the
    # train or val side. Presumably seed/fold/n_folds define a deterministic
    # k-fold partition so the two sides are disjoint -- confirm in EyePACSDataset.
    train_dataset = EyePACSDataset(
        root=args.root,
        split="all",
        all_mode="train",
        transform=AlbumentationsTransform(get_train_transforms(args.image_size)),
        seed=args.seed,
        fold=args.fold,
        n_folds=args.n_folds,
    )
    val_dataset = EyePACSDataset(
        root=args.root,
        split="all",
        all_mode="val",
        transform=AlbumentationsTransform(get_val_transforms(args.image_size)),
        seed=args.seed,
        fold=args.fold,
        n_folds=args.n_folds,
    )
    # shuffle=True also makes make_loader drop the last partial training batch.
    train_loader = make_loader(
        train_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=True,
        device=device,
    )
    val_loader = make_loader(
        val_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        device=device,
    )
    model = DeepSeeNet(
        n_classes=N_CLASSES,
        backbone=args.backbone,
        pretrained=not args.no_pretrained,
        freeze_backbone=args.freeze_backbone,
    ).to(device)
    # Inverse-frequency class weights (computed from the training split) are
    # applied to the training loss only; the validation loss is unweighted.
    class_weights = None
    if not args.no_class_weights:
        class_weights = get_class_weights(train_dataset, device)
    train_criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
    val_criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    scheduler = build_scheduler(optimizer, args)
    scaler = torch.amp.GradScaler("cuda") if use_amp else None
    # wandb is imported lazily so it is only required when --wandb is passed;
    # the local import rebinds the name from None to the module.
    wandb = None
    if args.wandb:
        import wandb
        wandb.init(project=args.wandb_project, config=vars(args))
    print("\nEyePACS DR training")
    print("-------------------")
    print(f"Device: {device}")
    print(f"Root: {args.root}")
    print(f"Output: {output_dir}")
    print(f"Backbone: {args.backbone}")
    print(f"Image size: {args.image_size}")
    print(f"Fold: {args.fold}/{args.n_folds}")
    print(f"Train samples: {len(train_dataset)}")
    print(f"Val samples: {len(val_dataset)}")
    print(f"AMP: {use_amp}")
    print(f"Pretrained: {not args.no_pretrained}")
    if class_weights is not None:
        print(f"Class weights: {class_weights.detach().cpu().tolist()}")
    # Best value of the monitored metric so far (QWK, or -val_loss when QWK is nan).
    best_qwk = -float("inf")
    history = []
    for epoch in range(1, args.epochs + 1):
        print(f"\nEpoch [{epoch:03d}/{args.epochs}]")
        train_metrics = train_one_epoch(
            model=model,
            loader=train_loader,
            optimizer=optimizer,
            scaler=scaler,
            criterion=train_criterion,
            device=device,
            use_amp=use_amp,
            grad_clip=args.grad_clip,
        )
        val_metrics = evaluate(
            model=model,
            loader=val_loader,
            criterion=val_criterion,
            device=device,
            use_amp=use_amp,
        )
        # LR read before scheduler.step() below, i.e. the LR used this epoch.
        lr = optimizer.param_groups[0]["lr"]
        row = {
            "epoch": epoch,
            "lr": lr,
            **{f"train_{k}": v for k, v in train_metrics.items()},
            **{f"val_{k}": v for k, v in val_metrics.items()},
        }
        history.append(row)
        print(
            f"lr={lr:.2e} "
            f"train_loss={train_metrics['loss']:.4f} "
            f"train_acc={train_metrics['acc']:.4f} "
            f"val_loss={val_metrics['loss']:.4f} "
            f"val_acc={val_metrics['acc']:.4f} "
            f"val_qwk={val_metrics['qwk']:.4f} "
            f"val_ref_auc={val_metrics['referable_auc']:.4f}"
        )
        if wandb is not None:
            wandb.log(row)
        # Rewritten every epoch so a crash still leaves the history so far.
        # NOTE(review): json.dump emits bare NaN for nan metrics, which is not
        # strict JSON; non-Python readers may reject the file.
        with (output_dir / "history.json").open("w") as f:
            json.dump(history, f, indent=2)
        # Model selection: QWK, falling back to -val_loss when QWK is nan.
        # NOTE(review): if QWK is nan in some epochs but not others, QWK and
        # -val_loss are compared against each other despite different scales.
        monitor = val_metrics["qwk"]
        if np.isnan(monitor):
            monitor = -val_metrics["loss"]
        if monitor > best_qwk:
            best_qwk = monitor
            # Full resumable checkpoint (optimizer + scheduler state).
            save_checkpoint(
                output_dir / "best.pt",
                model=model,
                optimizer=optimizer,
                scheduler=scheduler,
                epoch=epoch,
                best_metric=best_qwk,
                args=args,
                model_only=False,
            )
            # Weights + metadata only (no optimizer/scheduler state).
            save_checkpoint(
                output_dir / "best_model_only.pt",
                model=model,
                optimizer=optimizer,
                scheduler=scheduler,
                epoch=epoch,
                best_metric=best_qwk,
                args=args,
                model_only=True,
            )
            print(f"Saved best checkpoint: monitor={best_qwk:.4f}")
        if args.save_every > 0 and epoch % args.save_every == 0:
            save_checkpoint(
                output_dir / f"epoch_{epoch:03d}.pt",
                model=model,
                optimizer=optimizer,
                scheduler=scheduler,
                epoch=epoch,
                best_metric=best_qwk,
                args=args,
                model_only=False,
            )
        # Stepped once per epoch, after this epoch's checkpoints were written.
        if scheduler is not None:
            scheduler.step()
    save_checkpoint(
        output_dir / "last.pt",
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        epoch=args.epochs,
        best_metric=best_qwk,
        args=args,
        model_only=False,
    )
    print("\nTraining complete.")
    print(f"Best monitor/QWK: {best_qwk:.4f}")
    print(f"Saved to: {output_dir}")
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()