# Retina_Training/train.py
from __future__ import annotations

from pathlib import Path

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm

from dataset import get_loaders
from model_loader import build_classifier

def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one training pass over `loader`; returns (mean loss, accuracy)."""
    model.train()
    running = 0.0
    correct = 0
    total = 0
    for imgs, labels in tqdm(loader, desc="train", leave=False):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(imgs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        # Accumulate the sum of per-sample losses so the epoch mean is exact
        # even when the final batch is smaller than batch_size.
        running += loss.item() * imgs.size(0)
        preds = torch.argmax(logits, dim=1)
        correct += (preds == labels).sum().item()
        total += imgs.size(0)
    return running / total, correct / total

def eval_one_epoch(model, loader, criterion, device):
    """Run one validation pass without gradients; returns (mean loss, accuracy)."""
    model.eval()
    running = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in tqdm(loader, desc="val", leave=False):
            imgs, labels = imgs.to(device), labels.to(device)
            logits = model(imgs)
            loss = criterion(logits, labels)
            running += loss.item() * imgs.size(0)
            preds = torch.argmax(logits, dim=1)
            correct += (preds == labels).sum().item()
            total += imgs.size(0)
    return running / total, correct / total

def fit(data_root: str,
        base_repo: str,
        base_filename: str,
        epochs: int = 10,
        batch_size: int = 16,
        lr: float = 5e-4,
        weight_decay: float = 0.05,
        freeze_backbone: bool = True,
        out_dir: str | Path = "checkpoints",
        device: str = "cuda" if torch.cuda.is_available() else "cpu"):
    """Fine-tune the classifier, checkpointing both the last and the best epoch.

    Returns (path to best checkpoint, class names, best validation accuracy).
    """
    train_dl, val_dl, classes = get_loaders(data_root, batch_size=batch_size)
    num_classes = len(classes)
    model = build_classifier(num_classes=num_classes,
                             base_repo=base_repo,
                             base_filename=base_filename,
                             device=device)

    # Optionally freeze the backbone so only the classification head trains.
    if freeze_backbone:
        for name, p in model.named_parameters():
            if not name.startswith("head"):
                p.requires_grad = False

    criterion = nn.CrossEntropyLoss()
    # Only parameters left trainable (e.g. the head) go to the optimizer.
    optimizer = AdamW([p for p in model.parameters() if p.requires_grad],
                      lr=lr, weight_decay=weight_decay)
    scheduler = CosineAnnealingLR(optimizer, T_max=epochs)

    best_acc = 0.0
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    for ep in range(1, epochs + 1):
        tr_loss, tr_acc = train_one_epoch(model, train_dl, criterion, optimizer, device)
        va_loss, va_acc = eval_one_epoch(model, val_dl, criterion, device)
        scheduler.step()
        print(f"Epoch {ep}: train_loss={tr_loss:.4f} acc={tr_acc:.3f} | "
              f"val_loss={va_loss:.4f} acc={va_acc:.3f}")

        # Always save the most recent epoch.
        torch.save({
            "model": model.state_dict(),
            "classes": classes,
            "epoch": ep,
            "val_acc": va_acc,
        }, out_dir / "retfound_classifier_last.pth")

        # Save a separate checkpoint whenever validation accuracy improves.
        if va_acc > best_acc:
            best_acc = va_acc
            torch.save({
                "model": model.state_dict(),
                "classes": classes,
                "epoch": ep,
                "val_acc": va_acc,
            }, out_dir / "retfound_classifier_best.pth")

    return str(out_dir / "retfound_classifier_best.pth"), classes, best_acc
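

# --- Usage sketch (not part of the original module) -------------------------
# A minimal example of how fit() might be invoked and how the checkpoint it
# writes could be reloaded for inference. It assumes build_classifier() can
# rebuild the same architecture so the saved state_dict loads cleanly. The
# data_root, base_repo, and base_filename values below are placeholders, not
# values taken from this repository; substitute your own.
def load_classifier(ckpt_path: str | Path, base_repo: str, base_filename: str,
                    device: str = "cpu"):
    # Rebuild the architecture, then overwrite its weights with the
    # fine-tuned state_dict that fit() saved.
    ckpt = torch.load(ckpt_path, map_location=device)
    model = build_classifier(num_classes=len(ckpt["classes"]),
                             base_repo=base_repo,
                             base_filename=base_filename,
                             device=device)
    model.load_state_dict(ckpt["model"])
    model.eval()
    return model, ckpt["classes"]


if __name__ == "__main__":
    # Placeholder paths and identifiers for illustration only.
    best_path, classes, best_acc = fit(
        data_root="data/retina",              # hypothetical dataset directory
        base_repo="some-org/retfound-weights",  # hypothetical weights repo id
        base_filename="retfound_base.pth",      # hypothetical weights filename
        epochs=10,
    )
    print(f"Best checkpoint: {best_path} (val_acc={best_acc:.3f}, classes={classes})")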