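"""Fine-tune the RETFound-based image classifier.

Trains the classifier assembled by model_loader.build_classifier on the data
served by dataset.get_loaders, logging per-epoch metrics and writing last/best
checkpoints to disk.
"""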
from __future__ import annotations

from pathlib import Path

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm

from model_loader import build_classifier
from dataset import get_loaders
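
# The two local modules above are assumed to expose the interfaces used below
# (inferred from the call sites; their implementations are not shown here):
#   build_classifier(num_classes, base_repo, base_filename, device) -> nn.Module
#   get_loaders(data_root, batch_size) -> (train loader, val loader, class names)
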
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one training pass; return (mean per-sample loss, accuracy)."""
    model.train()
    running = 0.0
    correct = 0
    total = 0
    for imgs, labels in tqdm(loader, desc="train", leave=False):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(imgs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        # Weight the batch-mean loss by batch size so the epoch average is per sample.
        running += loss.item() * imgs.size(0)
        preds = torch.argmax(logits, dim=1)
        correct += (preds == labels).sum().item()
        total += imgs.size(0)
    return running / total, correct / total

def eval_one_epoch(model, loader, criterion, device):
    """Evaluate without gradient tracking; return (mean per-sample loss, accuracy)."""
    model.eval()
    running = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in tqdm(loader, desc="val", leave=False):
            imgs, labels = imgs.to(device), labels.to(device)
            logits = model(imgs)
            loss = criterion(logits, labels)
            running += loss.item() * imgs.size(0)
            preds = torch.argmax(logits, dim=1)
            correct += (preds == labels).sum().item()
            total += imgs.size(0)
    return running / total, correct / total

def fit(data_root: str,
        base_repo: str,
        base_filename: str,
        epochs: int = 10,
        batch_size: int = 16,
        lr: float = 5e-4,
        weight_decay: float = 0.05,
        freeze_backbone: bool = True,
        out_dir: str | Path = "checkpoints",
        device: str = "cuda" if torch.cuda.is_available() else "cpu"):
    """Fine-tune the classifier; return (best checkpoint path, classes, best val accuracy)."""
    train_dl, val_dl, classes = get_loaders(data_root, batch_size=batch_size)
    num_classes = len(classes)
    model = build_classifier(num_classes=num_classes,
                             base_repo=base_repo,
                             base_filename=base_filename,
                             device=device)

    # Optionally freeze the backbone (everything except the classification head)
    # so that only head parameters receive gradient updates.
    if freeze_backbone:
        for name, p in model.named_parameters():
            if not name.startswith("head"):
                p.requires_grad = False

    criterion = nn.CrossEntropyLoss()
    # Only trainable parameters are handed to the optimizer.
    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                      lr=lr, weight_decay=weight_decay)
    scheduler = CosineAnnealingLR(optimizer, T_max=epochs)

    best_acc = 0.0
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    for ep in range(1, epochs + 1):
        tr_loss, tr_acc = train_one_epoch(model, train_dl, criterion, optimizer, device)
        va_loss, va_acc = eval_one_epoch(model, val_dl, criterion, device)
        scheduler.step()
        print(f"Epoch {ep}: train_loss={tr_loss:.4f} acc={tr_acc:.3f} | "
              f"val_loss={va_loss:.4f} acc={va_acc:.3f}")

        # Always save the most recent checkpoint.
        torch.save({
            "model": model.state_dict(),
            "classes": classes,
            "epoch": ep,
            "val_acc": va_acc,
        }, out_dir / "retfound_classifier_last.pth")

        # Keep a separate copy of the best checkpoint by validation accuracy.
        if va_acc > best_acc:
            best_acc = va_acc
            torch.save({
                "model": model.state_dict(),
                "classes": classes,
                "epoch": ep,
                "val_acc": va_acc,
            }, out_dir / "retfound_classifier_best.pth")

    return str(out_dir / "retfound_classifier_best.pth"), classes, best_acc
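
# --- Usage sketch -------------------------------------------------------------
# A minimal example of invoking fit(). The data path, repo id, and weights
# filename below are placeholders, not values taken from this project.
if __name__ == "__main__":
    best_path, classes, best_acc = fit(
        data_root="data",                       # placeholder dataset root
        base_repo="your-org/retfound-weights",  # placeholder Hugging Face repo id
        base_filename="retfound_weights.pth",   # placeholder weights filename
        epochs=10,
    )
    print(f"Best checkpoint: {best_path} (val_acc={best_acc:.3f}, {len(classes)} classes)")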