from __future__ import annotations

from pathlib import Path

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm

from model_loader import build_classifier
from dataset import get_loaders


def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one training pass; returns (mean loss, accuracy)."""
    model.train()
    running = 0.0
    correct = 0
    total = 0
    for imgs, labels in tqdm(loader, desc="train", leave=False):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(imgs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        # Accumulate sample-weighted loss so the epoch mean stays exact
        # even when the last batch is smaller than batch_size.
        running += loss.item() * imgs.size(0)
        preds = torch.argmax(logits, dim=1)
        correct += (preds == labels).sum().item()
        total += imgs.size(0)
    return running / total, correct / total


def eval_one_epoch(model, loader, criterion, device):
    """Run one validation pass without gradients; returns (mean loss, accuracy)."""
    model.eval()
    running = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in tqdm(loader, desc="val", leave=False):
            imgs, labels = imgs.to(device), labels.to(device)
            logits = model(imgs)
            loss = criterion(logits, labels)
            running += loss.item() * imgs.size(0)
            preds = torch.argmax(logits, dim=1)
            correct += (preds == labels).sum().item()
            total += imgs.size(0)
    return running / total, correct / total


def fit(data_root: str,
        base_repo: str,
        base_filename: str,
        epochs: int = 10,
        batch_size: int = 16,
        lr: float = 5e-4,
        weight_decay: float = 0.05,
        freeze_backbone: bool = True,
        out_dir: str | Path = "checkpoints",
        device: str = "cuda" if torch.cuda.is_available() else "cpu"):
    """Fine-tune the classifier; returns (best checkpoint path, classes, best val acc)."""
    train_dl, val_dl, classes = get_loaders(data_root, batch_size=batch_size)
    num_classes = len(classes)

    model = build_classifier(num_classes=num_classes,
                             base_repo=base_repo,
                             base_filename=base_filename,
                             device=device)

    # Optionally freeze backbone (everything except head)
    if freeze_backbone:
        for name, p in model.named_parameters():
            if not name.startswith("head"):
                p.requires_grad = False

    criterion = nn.CrossEntropyLoss()
    # Only trainable parameters go to the optimizer, so frozen backbone
    # weights receive neither updates nor weight decay.
    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                      lr=lr, weight_decay=weight_decay)
    scheduler = CosineAnnealingLR(optimizer, T_max=epochs)

    best_acc = 0.0
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    for ep in range(1, epochs + 1):
        tr_loss, tr_acc = train_one_epoch(model, train_dl, criterion, optimizer, device)
        va_loss, va_acc = eval_one_epoch(model, val_dl, criterion, device)
        scheduler.step()
        print(f"Epoch {ep}: train_loss={tr_loss:.4f} acc={tr_acc:.3f} | "
              f"val_loss={va_loss:.4f} acc={va_acc:.3f}")

        # Save last
        torch.save({
            "model": model.state_dict(),
            "classes": classes,
            "epoch": ep,
            "val_acc": va_acc,
        }, out_dir / "retfound_classifier_last.pth")

        # Save best
        if va_acc > best_acc:
            best_acc = va_acc
            torch.save({
                "model": model.state_dict(),
                "classes": classes,
                "epoch": ep,
                "val_acc": va_acc,
            }, out_dir / "retfound_classifier_best.pth")

    return str(out_dir / "retfound_classifier_best.pth"), classes, best_acc
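

# ---------------------------------------------------------------------------
# Usage sketch: example invocation of fit(). All argument values below are
# placeholders, not values confirmed by this repo; substitute the dataset
# root and the checkpoint repo/filename that your model_loader.build_classifier
# actually expects.
if __name__ == "__main__":
    best_ckpt, classes, best_acc = fit(
        data_root="data/fundus",                # placeholder dataset root
        base_repo="your-org/retfound-weights",  # placeholder repo identifier
        base_filename="retfound_base.pth",      # placeholder weights filename
        epochs=10,
        batch_size=16,
        freeze_backbone=True,                   # train only the head (linear-probe style)
    )
    print(f"Best checkpoint: {best_ckpt} "
          f"(val_acc={best_acc:.3f}, classes={classes})")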