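"""Episodic few-shot evaluation of a trained SiameseNet on Omniglot.

Builds N-way K-shot episodes from held-out test classes, classifies each
query against class prototypes (the mean of its K support embeddings) by
cosine similarity, and reports accuracy for 5/10-way x 1/5-shot over
600 episodes.
"""
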
import torch
import torch.nn.functional as F
from torchvision.datasets import Omniglot
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import json, os, random
import numpy as np
from tqdm import tqdm

from model import SiameseNet

# ── Episode Dataset ───────────────────────────────────────────
class EpisodeDataset(Dataset):
    """
    Each item is one N-way K-shot episode:
      - N classes, K support images each  β†’ support set
      - N classes, 1 query image each     β†’ query set
    Returns support embeddings + query image + correct label
    """
    def __init__(self, dataset, allowed_classes, transform, n_way=5, k_shot=1, n_episodes=600):
        self.transform   = transform
        self.n_way       = n_way
        self.k_shot      = k_shot
        self.n_episodes  = n_episodes
        self.dataset     = dataset

        allowed = set(allowed_classes)  # set membership is O(1) per label check
        self.class_to_indices = {}
        # NOTE: this pass decodes every image once just to read its label.
        for idx, (_, label) in enumerate(dataset):
            if label not in allowed:
                continue
            self.class_to_indices.setdefault(label, []).append(idx)

        # Only keep classes with enough samples for K support + 1 query
        self.classes = [c for c, idxs in self.class_to_indices.items()
                        if len(idxs) >= k_shot + 1]

    def __len__(self):
        return self.n_episodes  # virtual length: episodes are sampled on the fly

    def __getitem__(self, _):
        # Sample N classes for this episode
        episode_classes = random.sample(self.classes, self.n_way)

        support_imgs, query_imgs, query_labels = [], [], []

        for label_idx, cls in enumerate(episode_classes):
            indices = random.sample(self.class_to_indices[cls], self.k_shot + 1)
            support_indices = indices[:self.k_shot]
            query_index     = indices[self.k_shot]

            for i in support_indices:
                img, _ = self.dataset[i]
                support_imgs.append(self.transform(img))

            img, _ = self.dataset[query_index]
            query_imgs.append(self.transform(img))
            query_labels.append(label_idx)

        # support: [N*K, C, H, W] | query: [N, C, H, W]
        support = torch.stack(support_imgs)
        query   = torch.stack(query_imgs)
        labels  = torch.tensor(query_labels)
        return support, query, labels
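
# Example: with the default transform below, a 5-way 1-shot episode yields
# support [5, 1, 105, 105], query [5, 1, 105, 105], labels [5].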


# ── Evaluation function ───────────────────────────────────────
@torch.no_grad()
def evaluate_episodes(model, episode_ds, device, n_way, k_shot):
    model.eval()
    correct, total = 0, 0

    loader = DataLoader(episode_ds, batch_size=1, shuffle=False, num_workers=2)  # each item is one full episode

    for support, query, labels in tqdm(loader, desc=f"{n_way}-way {k_shot}-shot"):
        # Remove batch dim (batch_size=1)
        support = support.squeeze(0).to(device)   # [N*K, C, H, W]
        query   = query.squeeze(0).to(device)     # [N, C, H, W]
        labels  = labels.squeeze(0).to(device)    # [N]

        # Get embeddings
        support_emb = model.get_embedding(support)  # [N*K, 128]
        query_emb   = model.get_embedding(query)    # [N,   128]

        # Compute class prototypes (mean of K support embeddings per class)
        support_emb = support_emb.view(n_way, k_shot, -1).mean(dim=1)  # [N, 128]

        # Cosine similarity: each query vs each class prototype
        sim = F.cosine_similarity(
            query_emb.unsqueeze(1),      # [N, 1, 128]
            support_emb.unsqueeze(0),    # [1, N, 128]
            dim=2                        # → [N, N]
        )
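        # Note: this is a prototypical-network-style nearest-prototype
        # classifier, with cosine similarity in place of the usual
        # squared Euclidean distance.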

        preds    = sim.argmax(dim=1)     # [N]
        correct += (preds == labels).sum().item()
        total   += labels.size(0)

    accuracy = correct / total
    return accuracy
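

# A minimal sketch (not part of the original pipeline): few-shot results are
# often reported as mean accuracy with a 95% confidence interval over
# episodes. If evaluate_episodes were extended to record per-episode
# accuracies, a helper like this could summarize them:
def confidence_interval95(per_episode_acc):
    """Return (mean, 95% CI half-width) over a list of episode accuracies."""
    accs = np.asarray(per_episode_acc, dtype=np.float64)
    half_width = 1.96 * accs.std(ddof=1) / np.sqrt(len(accs))
    return float(accs.mean()), float(half_width)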


# ── Run all eval configurations ───────────────────────────────
def run_eval(checkpoint_path, data_root, split_path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model
    model = SiameseNet(embedding_dim=128).to(device)
    ckpt  = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    print(f"Loaded checkpoint from epoch {ckpt['epoch']}")

    eval_transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((105, 105)),
        transforms.ToTensor(),
        transforms.Normalize([0.9220], [0.2256]),  # grayscale mean/std (Omniglot is mostly white, hence mean ≈ 0.92)
    ])

    bg = Omniglot(root=data_root, background=True, download=False, transform=None)

    with open(split_path) as f:
        test_classes = json.load(f)["test"]

    print(f"Evaluating on {len(test_classes)} unseen test classes\n")

    results = {}
    for n_way in [5, 10]:
        for k_shot in [1, 5]:
            ep_ds = EpisodeDataset(
                bg, test_classes, eval_transform,
                n_way=n_way, k_shot=k_shot, n_episodes=600
            )
            acc = evaluate_episodes(model, ep_ds, device, n_way, k_shot)
            key = f"{n_way}-way {k_shot}-shot"
            results[key] = acc
            print(f"  {key:18s}  →  {acc*100:.2f}%")

    return results


if __name__ == "__main__":
    results = run_eval(
        checkpoint_path = "../checkpoints/best.pt",
        data_root       = "../data",
        split_path      = "../data/class_split.json",
    )
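
    # A minimal sketch: persist the results dict for later comparison
    # (the output path here is illustrative, not from the original script).
    with open("eval_results.json", "w") as f:
        json.dump(results, f, indent=2)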