import torch
import torch.nn.functional as F
from torchvision.datasets import Omniglot
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import json, os, random
import numpy as np
from tqdm import tqdm
from model import SiameseNet
# ── Episode Dataset ───────────────────────────────────────────
class EpisodeDataset(Dataset):
    """
    Dataset of N-way K-shot evaluation episodes.

    Each item is one episode:
      - N classes, K support images each -> support set
      - the same N classes, 1 query image each -> query set

    Returns (support, query, labels) where
      support: [N*K, C, H, W]  support images, grouped by class
      query:   [N, C, H, W]    one query image per class
      labels:  [N]             episode-local class index (0..N-1) of each query

    Raises:
        ValueError: if fewer than ``n_way`` classes have at least
            ``k_shot + 1`` samples (K support + 1 query).
    """

    def __init__(self, dataset, allowed_classes, transform, n_way=5, k_shot=1, n_episodes=600):
        self.transform = transform
        self.n_way = n_way
        self.k_shot = k_shot
        self.n_episodes = n_episodes
        self.dataset = dataset
        # Build a set once for O(1) membership tests instead of scanning
        # the allowed-classes list for every sample in the dataset.
        allowed = set(allowed_classes)
        self.class_to_indices = {}
        # NOTE: iterating the dataset loads each (sample, label) pair; for
        # file-backed image datasets this touches every file once.
        for idx, (_, label) in enumerate(dataset):
            if label in allowed:
                self.class_to_indices.setdefault(label, []).append(idx)
        # Only keep classes with enough samples for K support + 1 query.
        self.classes = [c for c, idxs in self.class_to_indices.items()
                        if len(idxs) >= k_shot + 1]
        # Fail fast with a clear message instead of a cryptic
        # random.sample ValueError the first time __getitem__ runs.
        if len(self.classes) < n_way:
            raise ValueError(
                f"Need at least {n_way} classes with {k_shot + 1} samples each, "
                f"but only {len(self.classes)} qualify."
            )

    def __len__(self):
        return self.n_episodes

    def __getitem__(self, _):
        # Episodes are sampled fresh on every access; the index is ignored.
        episode_classes = random.sample(self.classes, self.n_way)
        support_imgs, query_imgs, query_labels = [], [], []
        for label_idx, cls in enumerate(episode_classes):
            # K support indices + 1 query index, all distinct within the class.
            indices = random.sample(self.class_to_indices[cls], self.k_shot + 1)
            for i in indices[:self.k_shot]:
                img, _ = self.dataset[i]
                support_imgs.append(self.transform(img))
            img, _ = self.dataset[indices[self.k_shot]]
            query_imgs.append(self.transform(img))
            query_labels.append(label_idx)
        # support: [N*K, C, H, W] | query: [N, C, H, W] | labels: [N]
        return torch.stack(support_imgs), torch.stack(query_imgs), torch.tensor(query_labels)
# ── Evaluation function ───────────────────────────────────────
@torch.no_grad()
def evaluate_episodes(model, episode_ds, device, n_way, k_shot):
    """
    Run episodic nearest-prototype evaluation and return mean accuracy.

    Args:
        model: network exposing ``get_embedding(images) -> [B, D]``.
        episode_ds: dataset yielding (support, query, labels) episodes.
        device: torch device to run inference on.
        n_way: number of classes per episode.
        k_shot: number of support images per class.

    Returns:
        float accuracy in [0, 1] over all queries; 0.0 if there are no episodes.
    """
    model.eval()
    correct, total = 0, 0
    loader = DataLoader(episode_ds, batch_size=1, shuffle=False, num_workers=2)
    for support, query, labels in tqdm(loader, desc=f"{n_way}-way {k_shot}-shot"):
        # Remove the batch dimension added by the loader (batch_size=1).
        support = support.squeeze(0).to(device)   # [N*K, C, H, W]
        query = query.squeeze(0).to(device)       # [N, C, H, W]
        labels = labels.squeeze(0).to(device)     # [N]
        support_emb = model.get_embedding(support)  # [N*K, D]
        query_emb = model.get_embedding(query)      # [N, D]
        # Class prototypes: mean of the K support embeddings per class.
        support_emb = support_emb.view(n_way, k_shot, -1).mean(dim=1)  # [N, D]
        # Cosine similarity of every query against every prototype -> [N, N].
        sim = F.cosine_similarity(
            query_emb.unsqueeze(1),    # [N, 1, D]
            support_emb.unsqueeze(0),  # [1, N, D]
            dim=2,
        )
        preds = sim.argmax(dim=1)  # [N]
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    # Guard against an empty episode set to avoid ZeroDivisionError.
    return correct / total if total else 0.0
# ── Run all eval configurations ───────────────────────────────
def run_eval(checkpoint_path, data_root, split_path):
    """
    Evaluate a trained SiameseNet on held-out Omniglot classes.

    Loads the checkpoint, restores the model, then runs episodic
    evaluation for every (n_way, k_shot) configuration in
    {5, 10} x {1, 5}, 600 episodes each.

    Args:
        checkpoint_path: path to a checkpoint dict with keys
            "model_state" and "epoch".
        data_root: root directory containing the Omniglot data.
        split_path: JSON file whose "test" key lists held-out class labels.

    Returns:
        dict mapping "{n}-way {k}-shot" -> accuracy in [0, 1].
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Load model weights.
    model = SiameseNet(embedding_dim=128).to(device)
    # NOTE(review): torch.load runs a full unpickle by default; prefer
    # weights_only=True if the checkpoint holds only tensors/primitives
    # and comes from an untrusted source — confirm checkpoint contents.
    ckpt = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    print(f"Loaded checkpoint from epoch {ckpt['epoch']}")
    eval_transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((105, 105)),
        transforms.ToTensor(),
        # Dataset-specific normalization statistics (mean/std).
        transforms.Normalize([0.9220], [0.2256]),
    ])
    bg = Omniglot(root=data_root, background=True, download=False, transform=None)
    with open(split_path) as f:
        test_classes = json.load(f)["test"]
    print(f"Evaluating on {len(test_classes)} unseen test classes\n")
    results = {}
    for n_way in [5, 10]:
        for k_shot in [1, 5]:
            ep_ds = EpisodeDataset(
                bg, test_classes, eval_transform,
                n_way=n_way, k_shot=k_shot, n_episodes=600
            )
            acc = evaluate_episodes(model, ep_ds, device, n_way, k_shot)
            key = f"{n_way}-way {k_shot}-shot"
            results[key] = acc
            # "→" restores the arrow that was mojibake'd to "β" in source.
            print(f" {key:18s} → {acc*100:.2f}%")
    return results
if __name__ == "__main__":
    # Stray trailing "|" artifact removed (was a SyntaxError); keyword
    # arguments use PEP 8 spacing (no spaces around "=").
    results = run_eval(
        checkpoint_path="../checkpoints/best.pt",
        data_root="../data",
        split_path="../data/class_split.json",
    )