# EyePACS / train.py
# Hou
# add src
# ebbe758
# Raw
# History Blame Contribute Delete
# 13.9 kB
import argparse
import json
import random
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from augmentations import get_train_transforms, get_val_transforms
from dataloader import EyePACSDataset
from model import DeepSeeNet
# Number of output classes of the classifier (DR grades 0-4).
N_CLASSES = 5
class AlbumentationsTransform:
    """Adapt a keyword-style transform to a plain ``callable(image)``.

    The wrapped ``transform`` is invoked as ``transform(image=<ndarray>)`` and
    is expected to return a mapping; the value stored under ``"image"`` is
    what this adapter returns.
    """

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, image):
        # Convert whatever the dataset hands over into a NumPy array first.
        pixels = np.asarray(image)
        result = self.transform(image=pixels)
        return result["image"]
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            ``parse_args()`` callers behave exactly as before. Passing a list
            allows the parser to be driven programmatically.

    Returns:
        argparse.Namespace: The parsed options. ``--root`` is required; every
        other option has a default.
    """
    parser = argparse.ArgumentParser(description="Train EyePACS DR classifier.")

    # Data / output locations.
    parser.add_argument("--root", required=True, help="EyePACS root folder.")
    parser.add_argument("--output-dir", default="checkpoints/eyepacs_dr")

    # Model and input configuration.
    parser.add_argument("--backbone", default="inception_v3")
    parser.add_argument("--image-size", type=int, default=1024)

    # Training loop and data loading.
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--num-workers", type=int, default=8)

    # Cross-validation split and reproducibility.
    parser.add_argument("--fold", type=int, default=0)
    parser.add_argument("--n-folds", type=int, default=5)
    parser.add_argument("--seed", type=int, default=42)

    # Optimisation.
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--weight-decay", type=float, default=1e-4)
    parser.add_argument("--no-pretrained", action="store_true")
    parser.add_argument("--freeze-backbone", action="store_true")
    parser.add_argument("--no-class-weights", action="store_true")

    # Learning-rate schedule.
    parser.add_argument("--scheduler", choices=["none", "cosine", "step"], default="cosine")
    parser.add_argument("--min-lr", type=float, default=1e-6)
    parser.add_argument("--step-size", type=int, default=5)
    parser.add_argument("--gamma", type=float, default=0.5)

    # Mixed precision, gradient clipping, checkpointing and logging.
    parser.add_argument("--amp", action="store_true")
    parser.add_argument("--grad-clip", type=float, default=0.0)
    parser.add_argument("--save-every", type=int, default=0)
    parser.add_argument("--wandb", action="store_true")
    parser.add_argument("--wandb-project", default="eyepacs-dr")

    return parser.parse_args(argv)
def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
def unwrap_logits(output):
    """Return the first element of a tuple/list model output, else the output itself."""
    is_sequence = isinstance(output, (tuple, list))
    return output[0] if is_sequence else output
def get_class_weights(dataset, device):
    """Compute inverse-frequency class weights from ``dataset.samples``.

    Each entry of ``dataset.samples`` is read through its ``"label"`` key.
    Per-class counts are clamped to at least 1 (so an absent class does not
    divide by zero) and the weight of class ``c`` is
    ``sum(counts) / (N_CLASSES * counts[c])``.

    Returns:
        A tensor of ``N_CLASSES`` weights moved to ``device``.
    """
    label_list = [sample["label"] for sample in dataset.samples]
    label_tensor = torch.tensor(label_list, dtype=torch.long)
    per_class = torch.bincount(label_tensor, minlength=N_CLASSES).clamp_min(1)
    total = per_class.sum()
    balanced = total / (N_CLASSES * per_class)
    return balanced.to(device)
def build_scheduler(optimizer, args):
    """Create the learning-rate scheduler selected by ``args.scheduler``.

    ``"cosine"`` yields ``CosineAnnealingLR`` (``T_max=args.epochs``,
    ``eta_min=args.min_lr``); ``"step"`` yields ``StepLR``
    (``step_size=args.step_size``, ``gamma=args.gamma``). Any other value
    returns ``None``, meaning no scheduler.
    """
    lr_sched = torch.optim.lr_scheduler
    choice = args.scheduler

    scheduler = None
    if choice == "cosine":
        scheduler = lr_sched.CosineAnnealingLR(
            optimizer, T_max=args.epochs, eta_min=args.min_lr
        )
    elif choice == "step":
        scheduler = lr_sched.StepLR(
            optimizer, step_size=args.step_size, gamma=args.gamma
        )
    return scheduler
def make_loader(dataset, batch_size, num_workers, shuffle, device):
    """Wrap ``dataset`` in a ``DataLoader`` configured for training or validation.

    ``shuffle`` doubles as the ``drop_last`` flag, so a shuffled loader also
    drops its final incomplete batch. Memory pinning is enabled only for CUDA
    devices, and workers are kept alive between epochs whenever any are used.
    """
    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
        "pin_memory": device.type == "cuda",
        "drop_last": shuffle,
        "persistent_workers": num_workers > 0,
    }
    return DataLoader(dataset, **loader_options)
def train_one_epoch(
    model,
    loader,
    optimizer,
    scaler,
    criterion,
    device,
    use_amp=True,
    grad_clip=0.0,
):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; it is switched to ``train()`` mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        scaler: Gradient scaler for mixed precision, or ``None`` to run the
            plain backward/step path.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: Device the batches are moved to.
        use_amp: Enables CUDA autocast (only effective when ``device`` is CUDA).
        grad_clip: Max gradient norm; clipping is skipped when ``<= 0``.

    Returns:
        dict with the sample-weighted mean ``"loss"`` and ``"acc"``.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.train()

    autocast_on = use_amp and device.type == "cuda"
    clip_enabled = grad_clip > 0

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    progress = tqdm(loader, desc="Train", leave=False)
    for images, labels in progress:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast("cuda", enabled=autocast_on):
            logits = unwrap_logits(model(images))
            loss = criterion(logits, labels)

        if scaler is None:
            # Full-precision path.
            loss.backward()
            if clip_enabled:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
        else:
            # Mixed-precision path: gradients must be unscaled before their
            # norm is measured for clipping.
            scaler.scale(loss).backward()
            if clip_enabled:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()

        n_batch = labels.size(0)
        loss_sum += loss.item() * n_batch
        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += n_batch

        progress.set_postfix(
            loss=f"{loss_sum / n_seen:.4f}",
            acc=f"{n_correct / n_seen:.4f}",
        )

    return {"loss": loss_sum / n_seen, "acc": n_correct / n_seen}
def _safe_metric(metric_fn, *args, **kwargs):
    """Return ``float(metric_fn(*args, **kwargs))``, or NaN if it raises ``ValueError``.

    scikit-learn signals an undefined score (for example an AUC whose target
    contains a single class) with ``ValueError``; that must only blank the
    affected metric, not the others.
    """
    try:
        return float(metric_fn(*args, **kwargs))
    except ValueError:
        return float("nan")


@torch.no_grad()
def evaluate(model, loader, criterion, device, use_amp=True):
    """Evaluate ``model`` on ``loader`` and return a dict of validation metrics.

    Args:
        model: Network to evaluate; it is switched to ``eval()`` mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: Device the batches are moved to.
        use_amp: Enables CUDA autocast (only effective when ``device`` is CUDA).

    Returns:
        dict with, in this order: ``loss``, ``acc``, the binary accuracies
        ``referable_acc`` (grade >= 2), ``any_dr_acc`` (grade >= 1) and
        ``severe_or_pdr_acc`` (grade >= 3), then ``qwk`` (quadratic weighted
        kappa) and the matching ``referable_auc`` / ``any_dr_auc`` /
        ``severe_or_pdr_auc``. The scikit-learn based entries are NaN when
        scikit-learn is not installed; an individual entry is NaN when that
        metric is undefined for the data.
    """
    model.eval()
    autocast_on = use_amp and device.type == "cuda"

    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    all_labels = []
    all_probs = []
    all_preds = []

    pbar = tqdm(loader, desc="Val", leave=False)
    for images, labels in pbar:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        with torch.amp.autocast("cuda", enabled=autocast_on):
            logits = unwrap_logits(model(images))
            loss = criterion(logits, labels)

        # Under AMP the logits may be half precision; take the softmax in
        # float32 so the probabilities used for AUC keep full resolution.
        probs = F.softmax(logits.float(), dim=1)
        preds = probs.argmax(dim=1)

        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        total_correct += (preds == labels).sum().item()
        total_samples += batch_size

        all_labels.append(labels.detach().cpu())
        all_probs.append(probs.detach().cpu())
        all_preds.append(preds.detach().cpu())

        pbar.set_postfix(
            loss=f"{total_loss / total_samples:.4f}",
            acc=f"{total_correct / total_samples:.4f}",
        )

    labels = torch.cat(all_labels).numpy()
    probs = torch.cat(all_probs).numpy()
    preds = torch.cat(all_preds).numpy()

    # (metric prefix, minimum grade counted as positive)
    binary_tasks = (("referable", 2), ("any_dr", 1), ("severe_or_pdr", 3))

    metrics = {
        "loss": total_loss / total_samples,
        "acc": total_correct / total_samples,
    }
    for prefix, threshold in binary_tasks:
        agree = (labels >= threshold) == (preds >= threshold)
        metrics[f"{prefix}_acc"] = float(agree.mean())

    try:
        from sklearn.metrics import cohen_kappa_score, roc_auc_score
    except ImportError:
        # scikit-learn is optional: report every dependent metric as NaN.
        metrics["qwk"] = float("nan")
        for prefix, _ in binary_tasks:
            metrics[f"{prefix}_auc"] = float("nan")
        return metrics

    # Each metric is guarded on its own so that one undefined AUC can no
    # longer wipe out a valid QWK (the value used to select checkpoints).
    metrics["qwk"] = _safe_metric(
        cohen_kappa_score, labels, preds, weights="quadratic"
    )
    for prefix, threshold in binary_tasks:
        metrics[f"{prefix}_auc"] = _safe_metric(
            roc_auc_score,
            (labels >= threshold).astype(int),
            probs[:, threshold:].sum(axis=1),
        )
    return metrics
def save_checkpoint(path, model, optimizer, scheduler, epoch, best_metric, args, model_only=False):
    """Serialise a training checkpoint to ``path`` with ``torch.save``.

    The parent directory is created if needed. The checkpoint always holds
    ``epoch``, the model ``state_dict``, ``best_metric``, ``vars(args)`` and the
    ``id_to_label`` grade mapping. Unless ``model_only`` is true it also holds
    the optimizer state and, when ``scheduler`` is not ``None``, the scheduler
    state.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)

    grade_names = ("no_dr", "mild_npdr", "moderate_npdr", "severe_npdr", "pdr")
    payload = {
        "epoch": epoch,
        "model": model.state_dict(),
        "best_metric": best_metric,
        "args": vars(args),
        "id_to_label": dict(enumerate(grade_names)),
    }

    if not model_only:
        training_state = {"optimizer": optimizer.state_dict()}
        if scheduler is not None:
            training_state["scheduler"] = scheduler.state_dict()
        payload.update(training_state)

    torch.save(payload, target)
def main():
    """Train and validate the EyePACS DR classifier on one cross-validation fold.

    Parses the command line, seeds the RNGs, builds the train/val datasets and
    loaders, the model, loss functions, optimizer, optional scheduler and
    optional gradient scaler, then runs the epoch loop. After every epoch the
    metric history is rewritten to ``history.json``; the best epoch (by
    validation QWK, or by negative validation loss when QWK is NaN) is saved as
    ``best.pt`` and ``best_model_only.pt``; periodic ``epoch_XXX.pt`` files are
    written when ``--save-every`` is positive; ``last.pt`` is written at the
    end.
    """
    args = parse_args()
    set_seed(args.seed)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    history_path = output_dir / "history.json"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = args.amp and device.type == "cuda"

    # Both datasets share the same split definition; only the mode and the
    # transform pipeline differ.
    fold_kwargs = {
        "root": args.root,
        "split": "all",
        "seed": args.seed,
        "fold": args.fold,
        "n_folds": args.n_folds,
    }
    train_dataset = EyePACSDataset(
        all_mode="train",
        transform=AlbumentationsTransform(get_train_transforms(args.image_size)),
        **fold_kwargs,
    )
    val_dataset = EyePACSDataset(
        all_mode="val",
        transform=AlbumentationsTransform(get_val_transforms(args.image_size)),
        **fold_kwargs,
    )

    loader_kwargs = {
        "batch_size": args.batch_size,
        "num_workers": args.num_workers,
        "device": device,
    }
    train_loader = make_loader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = make_loader(val_dataset, shuffle=False, **loader_kwargs)

    model = DeepSeeNet(
        n_classes=N_CLASSES,
        backbone=args.backbone,
        pretrained=not args.no_pretrained,
        freeze_backbone=args.freeze_backbone,
    )
    model = model.to(device)

    if args.no_class_weights:
        class_weights = None
    else:
        class_weights = get_class_weights(train_dataset, device)

    # Class weights are applied to the training loss only; validation loss is
    # left unweighted.
    train_criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
    val_criterion = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )
    scheduler = build_scheduler(optimizer, args)
    scaler = torch.amp.GradScaler("cuda") if use_amp else None

    wandb = None
    if args.wandb:
        import wandb

        wandb.init(project=args.wandb_project, config=vars(args))

    banner = [
        "\nEyePACS DR training",
        "-------------------",
        f"Device: {device}",
        f"Root: {args.root}",
        f"Output: {output_dir}",
        f"Backbone: {args.backbone}",
        f"Image size: {args.image_size}",
        f"Fold: {args.fold}/{args.n_folds}",
        f"Train samples: {len(train_dataset)}",
        f"Val samples: {len(val_dataset)}",
        f"AMP: {use_amp}",
        f"Pretrained: {not args.no_pretrained}",
    ]
    if class_weights is not None:
        banner.append(f"Class weights: {class_weights.detach().cpu().tolist()}")
    for line in banner:
        print(line)

    # Arguments shared by every checkpoint written below.
    ckpt_common = {
        "model": model,
        "optimizer": optimizer,
        "scheduler": scheduler,
        "args": args,
    }

    best_qwk = -float("inf")
    history = []

    for epoch in range(1, args.epochs + 1):
        print(f"\nEpoch [{epoch:03d}/{args.epochs}]")

        train_metrics = train_one_epoch(
            model=model,
            loader=train_loader,
            optimizer=optimizer,
            scaler=scaler,
            criterion=train_criterion,
            device=device,
            use_amp=use_amp,
            grad_clip=args.grad_clip,
        )
        val_metrics = evaluate(
            model=model,
            loader=val_loader,
            criterion=val_criterion,
            device=device,
            use_amp=use_amp,
        )

        # Learning rate in effect for the epoch that just ran (the scheduler
        # is stepped at the end of the iteration).
        lr = optimizer.param_groups[0]["lr"]

        row = {"epoch": epoch, "lr": lr}
        row.update((f"train_{key}", value) for key, value in train_metrics.items())
        row.update((f"val_{key}", value) for key, value in val_metrics.items())
        history.append(row)

        summary = " ".join(
            [
                f"lr={lr:.2e}",
                f"train_loss={train_metrics['loss']:.4f}",
                f"train_acc={train_metrics['acc']:.4f}",
                f"val_loss={val_metrics['loss']:.4f}",
                f"val_acc={val_metrics['acc']:.4f}",
                f"val_qwk={val_metrics['qwk']:.4f}",
                f"val_ref_auc={val_metrics['referable_auc']:.4f}",
            ]
        )
        print(summary)

        if wandb is not None:
            wandb.log(row)

        # Rewrite the full history every epoch so it survives interruption.
        with history_path.open("w") as f:
            json.dump(history, f, indent=2)

        # Fall back to the negated loss when QWK could not be computed.
        qwk = val_metrics["qwk"]
        monitor = -val_metrics["loss"] if np.isnan(qwk) else qwk

        if monitor > best_qwk:
            best_qwk = monitor
            for filename, weights_only in (
                ("best.pt", False),
                ("best_model_only.pt", True),
            ):
                save_checkpoint(
                    output_dir / filename,
                    epoch=epoch,
                    best_metric=best_qwk,
                    model_only=weights_only,
                    **ckpt_common,
                )
            print(f"Saved best checkpoint: monitor={best_qwk:.4f}")

        periodic_save_due = args.save_every > 0 and epoch % args.save_every == 0
        if periodic_save_due:
            save_checkpoint(
                output_dir / f"epoch_{epoch:03d}.pt",
                epoch=epoch,
                best_metric=best_qwk,
                model_only=False,
                **ckpt_common,
            )

        if scheduler is not None:
            scheduler.step()

    save_checkpoint(
        output_dir / "last.pt",
        epoch=args.epochs,
        best_metric=best_qwk,
        model_only=False,
        **ckpt_common,
    )

    print("\nTraining complete.")
    print(f"Best monitor/QWK: {best_qwk:.4f}")
    print(f"Saved to: {output_dir}")
# Run training only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()