| import torch |
| import torch.nn.functional as F |
| from torchvision.datasets import Omniglot |
| from torchvision import transforms |
| from torch.utils.data import Dataset, DataLoader |
| import json, os, random |
| import numpy as np |
| from tqdm import tqdm |
|
|
| from model import SiameseNet |
|
|
| |
class EpisodeDataset(Dataset):
    """Dataset of randomly sampled N-way K-shot evaluation episodes.

    Each item is one episode drawn from ``allowed_classes``:
      - support set: N classes with K images each,
      - query set:   one image per class,
      - labels:      each query's class index (0..N-1) within the episode.

    Args:
        dataset: Indexable of ``(image, label)`` pairs (e.g. torchvision
            Omniglot created with ``transform=None``).
        allowed_classes: Iterable of labels that episodes may be drawn from.
        transform: Callable applied to every sampled image before stacking.
        n_way: Number of classes per episode.
        k_shot: Number of support images per class.
        n_episodes: Length reported by ``__len__``. Episodes are resampled
            on every access, so the index passed to ``__getitem__`` is ignored.

    Raises:
        ValueError: If fewer than ``n_way`` allowed classes have at least
            ``k_shot + 1`` samples.
    """
    def __init__(self, dataset, allowed_classes, transform, n_way=5, k_shot=1, n_episodes=600):
        self.transform = transform
        self.n_way = n_way
        self.k_shot = k_shot
        self.n_episodes = n_episodes
        self.dataset = dataset

        # Set membership is O(1); testing against the raw list made this
        # loop O(len(dataset) * len(allowed_classes)).
        allowed = set(allowed_classes)

        # NOTE(review): iterating the dataset materializes every item just to
        # read its label — fine for Omniglot-sized data, but slow for large
        # datasets; a label array on `dataset` would avoid it. TODO confirm.
        self.class_to_indices = {}
        for idx, (_, label) in enumerate(dataset):
            if label in allowed:
                self.class_to_indices.setdefault(label, []).append(idx)

        # Keep only classes with enough samples for K support + 1 query.
        self.classes = [c for c, idxs in self.class_to_indices.items()
                        if len(idxs) >= k_shot + 1]
        # Fail fast here instead of with an opaque random.sample error
        # the first time an episode is requested.
        if len(self.classes) < n_way:
            raise ValueError(
                f"Need at least {n_way} classes with >= {k_shot + 1} samples, "
                f"but only {len(self.classes)} are available"
            )

    def __len__(self):
        # One "epoch" is n_episodes freshly sampled episodes.
        return self.n_episodes

    def __getitem__(self, _):
        """Sample one episode and return ``(support, query, labels)`` tensors."""
        episode_classes = random.sample(self.classes, self.n_way)

        support_imgs, query_imgs, query_labels = [], [], []

        for label_idx, cls in enumerate(episode_classes):
            # Draw k_shot support indices plus one disjoint query index.
            indices = random.sample(self.class_to_indices[cls], self.k_shot + 1)
            support_indices = indices[:self.k_shot]
            query_index = indices[self.k_shot]

            for i in support_indices:
                img, _ = self.dataset[i]
                support_imgs.append(self.transform(img))

            img, _ = self.dataset[query_index]
            query_imgs.append(self.transform(img))
            query_labels.append(label_idx)

        support = torch.stack(support_imgs)   # (n_way * k_shot, C, H, W)
        query = torch.stack(query_imgs)       # (n_way, C, H, W)
        labels = torch.tensor(query_labels)   # (n_way,), values 0..n_way-1
        return support, query, labels
|
|
|
|
| |
@torch.no_grad()
def evaluate_episodes(model, episode_ds, device, n_way, k_shot):
    """Run every episode in ``episode_ds`` and return mean query accuracy.

    Support embeddings are averaged per class into a prototype; each query
    is assigned to the prototype with the highest cosine similarity.
    """
    model.eval()
    n_correct = 0
    n_seen = 0

    episode_loader = DataLoader(episode_ds, batch_size=1, shuffle=False, num_workers=2)

    for support_batch, query_batch, label_batch in tqdm(episode_loader, desc=f"{n_way}-way {k_shot}-shot"):
        # batch_size=1: drop the leading batch dim so shapes are per-episode.
        support_batch = support_batch.squeeze(0).to(device)
        query_batch = query_batch.squeeze(0).to(device)
        label_batch = label_batch.squeeze(0).to(device)

        # Per-class prototypes: mean over the k_shot support embeddings.
        prototypes = model.get_embedding(support_batch).view(n_way, k_shot, -1).mean(dim=1)
        query_embs = model.get_embedding(query_batch)

        # Cosine similarity of every query against every prototype.
        scores = F.cosine_similarity(
            query_embs.unsqueeze(1),
            prototypes.unsqueeze(0),
            dim=2,
        )
        predictions = scores.argmax(dim=1)

        n_correct += (predictions == label_batch).sum().item()
        n_seen += label_batch.size(0)

    return n_correct / n_seen
|
|
|
|
| |
def run_eval(checkpoint_path, data_root, split_path):
    """Evaluate a trained SiameseNet on held-out Omniglot classes.

    Loads the checkpoint, builds episode datasets over the test split, and
    measures accuracy for every (n_way, k_shot) combination in {5, 10} x {1, 5}.

    Args:
        checkpoint_path: Checkpoint file holding "model_state" and "epoch".
        data_root: Root directory of the (already downloaded) Omniglot data.
        split_path: JSON file whose "test" entry lists held-out class labels.

    Returns:
        Dict mapping "N-way K-shot" strings to accuracy floats.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = SiameseNet(embedding_dim=128).to(device)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state"])
    print(f"Loaded checkpoint from epoch {checkpoint['epoch']}")

    # NOTE(review): normalization stats presumably match training — confirm.
    preprocess = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((105, 105)),
        transforms.ToTensor(),
        transforms.Normalize([0.9220], [0.2256]),
    ])

    # Raw PIL images; EpisodeDataset applies `preprocess` itself.
    base_ds = Omniglot(root=data_root, background=True, download=False, transform=None)

    with open(split_path) as split_file:
        test_classes = json.load(split_file)["test"]

    print(f"Evaluating on {len(test_classes)} unseen test classes\n")

    results = {}
    for n_way in [5, 10]:
        for k_shot in [1, 5]:
            episodes = EpisodeDataset(
                base_ds, test_classes, preprocess,
                n_way=n_way, k_shot=k_shot, n_episodes=600,
            )
            accuracy = evaluate_episodes(model, episodes, device, n_way, k_shot)
            label = f"{n_way}-way {k_shot}-shot"
            results[label] = accuracy
            print(f" {label:18s} β {accuracy*100:.2f}%")

    return results
|
|
|
|
if __name__ == "__main__":
    # Paths are relative to this script's working directory.
    results = run_eval(
        checkpoint_path="../checkpoints/best.pt",
        data_root="../data",
        split_path="../data/class_split.json",
    )