File size: 4,930 Bytes
0e1a1b4
 
 
762e478
0e1a1b4
cd3dc59
 
0e1a1b4
 
cd3dc59
0e1a1b4
 
 
 
 
 
 
 
cd3dc59
0e1a1b4
762e478
0e1a1b4
 
 
 
 
 
 
 
 
 
cd3dc59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e1a1b4
cd3dc59
0e1a1b4
cd3dc59
0e1a1b4
 
 
 
cd3dc59
0e1a1b4
 
cd3dc59
 
0e1a1b4
 
 
cd3dc59
0e1a1b4
 
cd3dc59
0e1a1b4
 
 
 
 
cd3dc59
762e478
 
 
0e1a1b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd3dc59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e1a1b4
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""
Evaluate best (or last) checkpoint on the test set.
Reports genre, style, artist top-1 and artist top-5 accuracy.
Usage: python scripts/eval_cnn.py [--arch cnn|cnnrnn] [--last]
"""
from __future__ import annotations

import sys
from pathlib import Path
from typing import Any, Dict, TypedDict

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset

ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT / "src"))

from config import INDEX_SELECTED, WIKIART_ROOT, checkpoint_dir_for_arch, N_STYLE, N_ARTIST, N_GENRE, BATCH_SIZE
from dataset import WikiArtDataset
from model import ResNet50BiLSTMThreeHeads, ResNet50ThreeHeads

# Reuse the exact train/val/test split and eval-time transforms from the
# training script so reported test metrics match the training setup.
# scripts/ is not a package, so train_cnn.py is loaded by file path.
import importlib.util
spec = importlib.util.spec_from_file_location("train_cnn", ROOT / "scripts" / "train_cnn.py")
train_cnn = importlib.util.module_from_spec(spec)
spec.loader.exec_module(train_cnn)  # runs train_cnn's module level; any import-time side effects execute here
get_transforms = train_cnn.get_transforms  # called with train=False below for deterministic preprocessing
stratified_split = train_cnn.stratified_split  # returns (train, val, test) index lists -- presumably deterministic; confirm in train_cnn.py

class EvalMetrics(TypedDict):
    """Result bundle returned by compute_test_metrics().

    Accuracy fields are fractions in [0, 1]; consumers (e.g. the model-card
    exporter) format them for display.
    """

    arch: str             # architecture actually stored in the checkpoint ("cnn" or "cnnrnn")
    checkpoint_name: str  # "best.pt" or "last.pt"
    checkpoint_path: str  # absolute path of the evaluated checkpoint
    epoch: Any            # value of ckpt["epoch"], or None if absent
    test_n: int           # number of test samples evaluated
    genre_top1: float     # top-1 genre accuracy
    style_top1: float     # top-1 style accuracy
    artist_top1: float    # top-1 artist accuracy
    artist_top5: float    # top-5 artist accuracy


def _select_device() -> torch.device:
    """Pick the best available device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # getattr guard: the mps backend attribute is missing on older torch builds.
    if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def compute_test_metrics(*, arch: str, last: bool = False) -> EvalMetrics:
    """
    Run the same evaluation as the CLI and return metrics as floats in [0, 1].
    Used by scripts/export_hf_model_card.py for Hub model card YAML.

    Args:
        arch: Architecture whose checkpoint directory is searched ("cnn" or "cnnrnn").
        last: Evaluate last.pt instead of best.pt.

    Returns:
        EvalMetrics with top-1 genre/style/artist and top-5 artist accuracy.

    Raises:
        FileNotFoundError: If the checkpoint, index CSV, or wikiart root is missing.
        RuntimeError: If the test split is empty.
    """
    ckpt_name = "last.pt" if last else "best.pt"
    ckpt_path = checkpoint_dir_for_arch(arch) / ckpt_name
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    if not INDEX_SELECTED.exists() or not WIKIART_ROOT.exists():
        raise FileNotFoundError("Index or wikiart root missing.")

    device = _select_device()

    # weights_only=False: the checkpoint stores plain metadata (counts, epoch)
    # alongside the state dict. Only load checkpoints you trust.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    n_genre = ckpt["n_genre"]
    n_style = ckpt["n_style"]
    n_artist = ckpt["n_artist"]
    # Prefer the architecture recorded in the checkpoint over the CLI flag.
    ckpt_arch = ckpt.get("arch", arch)

    # Local import keeps pandas off the import path of callers that never run eval.
    import pandas as pd

    df = pd.read_csv(INDEX_SELECTED)
    _, _, idx_test = stratified_split(df)
    ds = WikiArtDataset(INDEX_SELECTED, WIKIART_ROOT, transform=get_transforms(train=False))
    test_loader = DataLoader(Subset(ds, idx_test), batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    if ckpt_arch == "cnnrnn":
        model = ResNet50BiLSTMThreeHeads(n_genre=n_genre, n_style=n_style, n_artist=n_artist).to(device)
    else:
        model = ResNet50ThreeHeads(n_genre=n_genre, n_style=n_style, n_artist=n_artist).to(device)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    correct_g = correct_s = correct_a = correct_a5 = total = 0
    with torch.no_grad():
        for images, style_id, artist_id, genre_id in test_loader:
            images = images.to(device)
            style_id = style_id.to(device)
            artist_id = artist_id.to(device)
            genre_id = genre_id.to(device)
            # Model returns one logits tensor per head: (genre, style, artist).
            logits_g, logits_s, logits_a = model(images)
            n = images.size(0)
            total += n
            correct_g += (logits_g.argmax(1) == genre_id).sum().item()
            correct_s += (logits_s.argmax(1) == style_id).sum().item()
            correct_a += (logits_a.argmax(1) == artist_id).sum().item()
            # Top-5 artist accuracy: hit if the true id appears anywhere in the top 5.
            _, top5 = logits_a.topk(5, dim=1)
            correct_a5 += (top5 == artist_id.unsqueeze(1)).any(1).sum().item()

    # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
    # which would turn an empty split into a ZeroDivisionError below.
    if total == 0:
        raise RuntimeError("Test split is empty; no samples were evaluated.")
    return {
        "arch": str(ckpt_arch),
        "checkpoint_name": ckpt_name,
        "checkpoint_path": str(ckpt_path),
        "epoch": ckpt.get("epoch", None),
        "test_n": int(total),
        "genre_top1": correct_g / total,
        "style_top1": correct_s / total,
        "artist_top1": correct_a / total,
        "artist_top5": correct_a5 / total,
    }


def main() -> None:
    """CLI entry point: evaluate a checkpoint and print test-set accuracies."""
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--arch", type=str, default="cnn", choices=["cnn", "cnnrnn"], help="Model architecture")
    parser.add_argument("--last", action="store_true", help="Evaluate last.pt instead of best.pt")
    ns = parser.parse_args()

    try:
        metrics = compute_test_metrics(arch=ns.arch, last=ns.last)
    except FileNotFoundError as exc:
        print(f"ERROR: {exc}")
        sys.exit(1)

    header = (
        f"Arch: {metrics['arch']}  Checkpoint: {metrics['checkpoint_name']}  "
        f"(epoch {metrics['epoch']!r})  Test n={metrics['test_n']}"
    )
    print(header)
    # Label/key pairs keep the four accuracy lines in one formatting path.
    for label, key in (
        ("genre acc (top-1)", "genre_top1"),
        ("style acc (top-1)", "style_top1"),
        ("artist acc (top-1)", "artist_top1"),
        ("artist acc (top-5)", "artist_top5"),
    ):
        print(f"  {label}: {metrics[key]:.2%}")


if __name__ == "__main__":
    main()