import json
import os
import random

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import Omniglot
from tqdm import tqdm

from model import SiameseNet


# ── Episode Dataset ───────────────────────────────────────────
class EpisodeDataset(Dataset):
    """
    Each item is one N-way K-shot episode:
      - N classes, K support images each → support set
      - N classes, 1 query image each    → query set

    Returns (support, query, labels) where
      support: [N*K, C, H, W] transformed support images,
      query:   [N, C, H, W]   transformed query images,
      labels:  [N]            episode-local class index (0..N-1) per query.
    """

    def __init__(self, dataset, allowed_classes, transform,
                 n_way=5, k_shot=1, n_episodes=600):
        """
        Args:
            dataset:         an indexable dataset yielding (image, label) pairs.
            allowed_classes: iterable of labels eligible for episodes.
            transform:       callable applied to each raw image.
            n_way:           number of classes per episode.
            k_shot:          number of support images per class.
            n_episodes:      virtual length of this dataset (episodes are
                             sampled randomly, so __getitem__ ignores its index).
        """
        self.transform = transform
        self.n_way = n_way
        self.k_shot = k_shot
        self.n_episodes = n_episodes
        self.dataset = dataset

        # Set gives O(1) membership tests during the full-dataset scan below
        # (the original tested against a list: O(len(allowed)) per sample).
        allowed = set(allowed_classes)

        # Map label → list of dataset indices for that label.
        # NOTE(review): iterating `dataset` loads every image once just to
        # read its label; acceptable for Omniglot, slow for larger datasets.
        self.class_to_indices = {}
        for idx, (_, label) in enumerate(dataset):
            if label in allowed:
                self.class_to_indices.setdefault(label, []).append(idx)

        # Only keep classes with enough samples for K support + 1 query.
        self.classes = [c for c, idxs in self.class_to_indices.items()
                        if len(idxs) >= k_shot + 1]

        # Fail fast with a clear message instead of letting random.sample
        # raise an opaque ValueError mid-iteration.
        if len(self.classes) < n_way:
            raise ValueError(
                f"Only {len(self.classes)} classes have >= {k_shot + 1} "
                f"samples, but n_way={n_way} was requested."
            )

    def __len__(self):
        # Episodes are sampled on the fly; this just fixes the epoch length.
        return self.n_episodes

    def __getitem__(self, _):
        # Sample N classes for this episode (index argument is ignored:
        # every call produces a fresh random episode).
        episode_classes = random.sample(self.classes, self.n_way)

        support_imgs, query_imgs, query_labels = [], [], []
        for label_idx, cls in enumerate(episode_classes):
            # K+1 distinct samples: K for support, the last one as the query.
            indices = random.sample(self.class_to_indices[cls], self.k_shot + 1)
            support_indices = indices[:self.k_shot]
            query_index = indices[self.k_shot]

            for i in support_indices:
                img, _ = self.dataset[i]
                support_imgs.append(self.transform(img))

            img, _ = self.dataset[query_index]
            query_imgs.append(self.transform(img))
            query_labels.append(label_idx)

        # support: [N*K, C, H, W] | query: [N, C, H, W]
        support = torch.stack(support_imgs)
        query = torch.stack(query_imgs)
        labels = torch.tensor(query_labels)
        return support, query, labels


# ── Evaluation function ───────────────────────────────────────
@torch.no_grad()
def evaluate_episodes(model, episode_ds, device, n_way, k_shot):
    """
    Run every episode in `episode_ds` through `model` and return the
    fraction of queries whose nearest class prototype (by cosine
    similarity of embeddings) is the correct class.
    """
    model.eval()
    correct, total = 0, 0
    loader = DataLoader(episode_ds, batch_size=1, shuffle=False, num_workers=2)

    for support, query, labels in tqdm(loader, desc=f"{n_way}-way {k_shot}-shot"):
        # Remove batch dim (batch_size=1)
        support = support.squeeze(0).to(device)   # [N*K, C, H, W]
        query = query.squeeze(0).to(device)       # [N, C, H, W]
        labels = labels.squeeze(0).to(device)     # [N]

        # Get embeddings
        support_emb = model.get_embedding(support)  # [N*K, 128]
        query_emb = model.get_embedding(query)      # [N, 128]

        # Compute class prototypes (mean of K support embeddings per class).
        # Relies on __getitem__ emitting support images grouped by class.
        support_emb = support_emb.view(n_way, k_shot, -1).mean(dim=1)  # [N, 128]

        # Cosine similarity: each query vs each class prototype
        sim = F.cosine_similarity(
            query_emb.unsqueeze(1),    # [N, 1, 128]
            support_emb.unsqueeze(0),  # [1, N, 128]
            dim=2                      # → [N, N]
        )
        preds = sim.argmax(dim=1)  # [N]

        correct += (preds == labels).sum().item()
        total += labels.size(0)

    accuracy = correct / total
    return accuracy


# ── Run all eval configurations ───────────────────────────────
def run_eval(checkpoint_path, data_root, split_path):
    """
    Load the checkpointed SiameseNet and evaluate it on episodes built
    from the held-out test classes, for all combinations of
    {5, 10}-way × {1, 5}-shot. Returns {"<n>-way <k>-shot": accuracy}.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model.
    # NOTE(review): torch.load with pickled checkpoints executes arbitrary
    # code if the file is untrusted — only load checkpoints you produced.
    model = SiameseNet(embedding_dim=128).to(device)
    ckpt = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    print(f"Loaded checkpoint from epoch {ckpt['epoch']}")

    eval_transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((105, 105)),
        # Omniglot is mostly white background, hence the high mean.
        transforms.ToTensor(),
        transforms.Normalize([0.9220], [0.2256]),
    ])

    # transform=None: EpisodeDataset applies eval_transform itself.
    bg = Omniglot(root=data_root, background=True, download=False, transform=None)

    with open(split_path) as f:
        test_classes = json.load(f)["test"]
    print(f"Evaluating on {len(test_classes)} unseen test classes\n")

    results = {}
    for n_way in [5, 10]:
        for k_shot in [1, 5]:
            ep_ds = EpisodeDataset(
                bg, test_classes, eval_transform,
                n_way=n_way, k_shot=k_shot, n_episodes=600
            )
            acc = evaluate_episodes(model, ep_ds, device, n_way, k_shot)
            key = f"{n_way}-way {k_shot}-shot"
            results[key] = acc
            print(f"  {key:18s} → {acc*100:.2f}%")

    return results


if __name__ == "__main__":
    results = run_eval(
        checkpoint_path="../checkpoints/best.pt",
        data_root="../data",
        split_path="../data/class_split.json",
    )