File size: 3,613 Bytes
39ec591
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from __future__ import annotations
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from pathlib import Path
from tqdm import tqdm

from model_loader import build_classifier
from dataset import get_loaders


def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over *loader*.

    Returns a ``(mean_loss, accuracy)`` tuple, where the loss is the
    sample-weighted average over the epoch and accuracy is the fraction
    of correctly classified samples.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    for batch_imgs, batch_labels in tqdm(loader, desc="train", leave=False):
        batch_imgs = batch_imgs.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        out = model(batch_imgs)
        batch_loss = criterion(out, batch_labels)
        batch_loss.backward()
        optimizer.step()

        bs = batch_imgs.size(0)
        # Weight the per-batch mean loss by batch size so the epoch
        # average is exact even when the last batch is smaller.
        loss_sum += batch_loss.item() * bs
        n_correct += (out.argmax(dim=1) == batch_labels).sum().item()
        n_seen += bs
    return loss_sum / n_seen, n_correct / n_seen


def eval_one_epoch(model, loader, criterion, device):
    """Evaluate *model* over *loader* without gradient tracking.

    Returns a ``(mean_loss, accuracy)`` tuple computed the same way as
    in :func:`train_one_epoch`.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch_imgs, batch_labels in tqdm(loader, desc="val", leave=False):
            batch_imgs = batch_imgs.to(device)
            batch_labels = batch_labels.to(device)

            out = model(batch_imgs)
            batch_loss = criterion(out, batch_labels)

            bs = batch_imgs.size(0)
            # Same sample-weighted averaging as the training pass.
            loss_sum += batch_loss.item() * bs
            n_correct += (out.argmax(dim=1) == batch_labels).sum().item()
            n_seen += bs
    return loss_sum / n_seen, n_correct / n_seen


def _save_checkpoint(model, classes, epoch, val_acc, path):
    """Serialize model weights plus training metadata to *path*."""
    torch.save({
        "model": model.state_dict(),
        "classes": classes,
        "epoch": epoch,
        "val_acc": val_acc,
    }, path)


def fit(data_root: str,
        base_repo: str,
        base_filename: str,
        epochs: int = 10,
        batch_size: int = 16,
        lr: float = 5e-4,
        weight_decay: float = 0.05,
        freeze_backbone: bool = True,
        out_dir: str | Path = "checkpoints",
        device: str = "cuda" if torch.cuda.is_available() else "cpu"):
    """Fine-tune a classifier on the dataset rooted at *data_root*.

    Builds the model via ``build_classifier`` (presumably loading
    pretrained weights from *base_repo*/*base_filename* — see
    ``model_loader``), trains for *epochs* epochs with AdamW + cosine
    annealing, and checkpoints after every epoch.

    Args:
        data_root: Dataset root passed to ``get_loaders``.
        base_repo: Repository identifier for the pretrained backbone.
        base_filename: Weights filename within *base_repo*.
        epochs: Number of training epochs.
        batch_size: Mini-batch size for both loaders.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        freeze_backbone: If True, only parameters whose names start
            with ``"head"`` are trained.
        out_dir: Directory for checkpoint files (created if missing).
        device: Torch device string. NOTE: the default is evaluated
            once at import time.

    Returns:
        ``(best_checkpoint_path, classes, best_val_acc)``.
    """
    train_dl, val_dl, classes = get_loaders(data_root, batch_size=batch_size)
    num_classes = len(classes)

    model = build_classifier(num_classes=num_classes,
                             base_repo=base_repo,
                             base_filename=base_filename,
                             device=device)

    # Optionally freeze backbone (everything except the "head.*" params),
    # so only the classification head is fine-tuned.
    if freeze_backbone:
        for name, p in model.named_parameters():
            if not name.startswith("head"):
                p.requires_grad = False

    criterion = nn.CrossEntropyLoss()
    # Only pass trainable parameters so frozen ones carry no optimizer state.
    optimizer = AdamW((p for p in model.parameters() if p.requires_grad),
                      lr=lr, weight_decay=weight_decay)
    # Cosine schedule: one step per epoch, annealing to ~0 over `epochs`.
    scheduler = CosineAnnealingLR(optimizer, T_max=epochs)

    best_acc = 0.0
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    for ep in range(1, epochs + 1):
        tr_loss, tr_acc = train_one_epoch(model, train_dl, criterion, optimizer, device)
        va_loss, va_acc = eval_one_epoch(model, val_dl, criterion, device)
        scheduler.step()
        print(f"Epoch {ep}: train_loss={tr_loss:.4f} acc={tr_acc:.3f} | val_loss={va_loss:.4f} acc={va_acc:.3f}")

        # Always keep the most recent epoch for resuming/debugging.
        _save_checkpoint(model, classes, ep, va_acc,
                         out_dir / "retfound_classifier_last.pth")

        # Keep the best-by-validation-accuracy checkpoint separately.
        if va_acc > best_acc:
            best_acc = va_acc
            _save_checkpoint(model, classes, ep, va_acc,
                             out_dir / "retfound_classifier_best.pth")

    return str(out_dir / "retfound_classifier_best.pth"), classes, best_acc