# artydemo/scripts/eval_cnn.py
# Pablo Dejuan - "add model export" (cd3dc59)
"""
Evaluate best (or last) checkpoint on the test set.
Reports genre, style, artist top-1 and artist top-5 accuracy.
Usage: python scripts/eval_cnn.py [--arch cnn|cnnrnn] [--last]
"""
from __future__ import annotations
import sys
from pathlib import Path
from typing import Any, Dict, TypedDict
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
# Directory two levels above this file (the parent of this file's directory).
ROOT = Path(__file__).resolve().parent.parent
# Put <ROOT>/src at the front of sys.path before the bare `config`, `dataset`
# and `model` imports below run (presumably those modules live in src/ -- confirm).
sys.path.insert(0, str(ROOT / "src"))
from config import INDEX_SELECTED, WIKIART_ROOT, checkpoint_dir_for_arch, N_STYLE, N_ARTIST, N_GENRE, BATCH_SIZE
from dataset import WikiArtDataset
from model import ResNet50BiLSTMThreeHeads, ResNet50ThreeHeads
# Reuse the train script's split and transforms. scripts/train_cnn.py is loaded
# by file path under the module name "train_cnn"; exec_module runs that file's
# top-level code. The module is not registered in sys.modules here.
import importlib.util
spec = importlib.util.spec_from_file_location("train_cnn", ROOT / "scripts" / "train_cnn.py")
train_cnn = importlib.util.module_from_spec(spec)
spec.loader.exec_module(train_cnn)
# Re-export the two helpers this script needs under short module-level names.
get_transforms = train_cnn.get_transforms
stratified_split = train_cnn.stratified_split
class EvalMetrics(TypedDict):
    """Result record produced by ``compute_test_metrics``.

    The accuracy fields are fractions (correct / total), not percentages.
    """

    # --- which model / checkpoint was evaluated ---
    arch: str
    checkpoint_name: str
    checkpoint_path: str
    # Copied verbatim from the checkpoint's "epoch" entry; None when absent.
    epoch: Any

    # --- number of test samples seen ---
    test_n: int

    # --- per-head accuracies ---
    genre_top1: float
    style_top1: float
    artist_top1: float
    artist_top5: float
def _select_device() -> torch.device:
    """Return the preferred torch device: CUDA first, then MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # getattr guard: torch builds without an ``mps`` backend attribute.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def compute_test_metrics(*, arch: str, last: bool = False) -> EvalMetrics:
    """
    Run the same evaluation as the CLI and return metrics as floats in [0, 1].
    Used by scripts/export_hf_model_card.py for Hub model card YAML.

    Args:
        arch: Architecture name used to locate the checkpoint directory. If the
            checkpoint stores its own "arch" entry, that value decides which
            model class is built and is the one reported in the result.
        last: Evaluate ``last.pt`` instead of ``best.pt``.

    Returns:
        An ``EvalMetrics`` dict with checkpoint info, the number of test
        samples, and genre/style/artist top-1 plus artist top-5 accuracy.
        When the artist head has fewer than 5 classes, "top-5" is computed
        over all available classes.

    Raises:
        FileNotFoundError: The checkpoint, the index file or the WikiArt root
            does not exist.
        KeyError: The checkpoint lacks ``n_genre``/``n_style``/``n_artist`` or
            ``model_state_dict``.
        RuntimeError: The test split produced no samples.
    """
    ckpt_name = "last.pt" if last else "best.pt"
    ckpt_path = checkpoint_dir_for_arch(arch) / ckpt_name
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    if not INDEX_SELECTED.exists() or not WIKIART_ROOT.exists():
        raise FileNotFoundError("Index or wikiart root missing.")

    device = _select_device()

    # NOTE(security): weights_only=False performs a full unpickle and can run
    # arbitrary code embedded in the file. Only load checkpoints you trust.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    n_genre = ckpt["n_genre"]
    n_style = ckpt["n_style"]
    n_artist = ckpt["n_artist"]
    ckpt_arch = ckpt.get("arch", arch)

    # Imported lazily so importing this module does not require pandas.
    import pandas as pd

    df = pd.read_csv(INDEX_SELECTED)
    # Only the third element of the split (test indices) is needed here.
    _, _, idx_test = stratified_split(df)
    ds = WikiArtDataset(INDEX_SELECTED, WIKIART_ROOT, transform=get_transforms(train=False))
    test_loader = DataLoader(
        Subset(ds, idx_test),
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=0,
    )

    model_cls = ResNet50BiLSTMThreeHeads if ckpt_arch == "cnnrnn" else ResNet50ThreeHeads
    model = model_cls(n_genre=n_genre, n_style=n_style, n_artist=n_artist).to(device)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    correct_g = correct_s = correct_a = correct_a5 = total = 0
    with torch.no_grad():
        # Batch layout (images, style, artist, genre) is taken from this
        # unpacking; it must match what WikiArtDataset yields.
        for images, style_id, artist_id, genre_id in test_loader:
            images = images.to(device)
            style_id = style_id.to(device)
            artist_id = artist_id.to(device)
            genre_id = genre_id.to(device)

            # The model returns logits in (genre, style, artist) order.
            logits_g, logits_s, logits_a = model(images)
            total += images.size(0)

            correct_g += (logits_g.argmax(1) == genre_id).sum().item()
            correct_s += (logits_s.argmax(1) == style_id).sum().item()
            correct_a += (logits_a.argmax(1) == artist_id).sum().item()

            # topk raises if k exceeds the class dimension, so clamp it.
            k = min(5, logits_a.size(1))
            _, top_k = logits_a.topk(k, dim=1)
            correct_a5 += (top_k == artist_id.unsqueeze(1)).any(1).sum().item()

    # Explicit check instead of ``assert``: asserts are stripped under -O,
    # which would turn an empty test split into a ZeroDivisionError below.
    if total <= 0:
        raise RuntimeError("Test split is empty; cannot compute accuracy.")

    return {
        "arch": str(ckpt_arch),
        "checkpoint_name": ckpt_name,
        "checkpoint_path": str(ckpt_path),
        "epoch": ckpt.get("epoch"),
        "test_n": int(total),
        "genre_top1": correct_g / total,
        "style_top1": correct_s / total,
        "artist_top1": correct_a / total,
        "artist_top5": correct_a5 / total,
    }
def main() -> None:
    """CLI entry point: evaluate one checkpoint and print its test accuracies.

    Exits with status 1 (after printing the error) when a required file is
    missing.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--arch",
        type=str,
        default="cnn",
        choices=["cnn", "cnnrnn"],
        help="Model architecture",
    )
    parser.add_argument(
        "--last",
        action="store_true",
        help="Evaluate last.pt instead of best.pt",
    )
    cli = parser.parse_args()

    try:
        metrics = compute_test_metrics(arch=cli.arch, last=cli.last)
    except FileNotFoundError as err:
        print(f"ERROR: {err}")
        sys.exit(1)

    # Header line: which model/checkpoint was evaluated and on how many samples.
    print(
        f"Arch: {metrics['arch']} Checkpoint: {metrics['checkpoint_name']} "
        f"(epoch {metrics['epoch']!r}) Test n={metrics['test_n']}"
    )

    # One line per accuracy, each formatted as a percentage with two decimals.
    report_rows = (
        (" genre acc (top-1): ", "genre_top1"),
        (" style acc (top-1): ", "style_top1"),
        (" artist acc (top-1): ", "artist_top1"),
        (" artist acc (top-5): ", "artist_top5"),
    )
    for prefix, key in report_rows:
        print(f"{prefix}{metrics[key]:.2%}")
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()