# LETTER / src/eval.py
# Uploaded by Sharath33 via huggingface_hub (commit 02ac88d, verified)
import torch
import torch.nn.functional as F
from torchvision.datasets import Omniglot
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import json, os, random
import numpy as np
from tqdm import tqdm
from model import SiameseNet
# ── Episode Dataset ───────────────────────────────────────────
class EpisodeDataset(Dataset):
    """
    Dataset of N-way K-shot evaluation episodes.

    Each item is one episode built from ``n_way`` randomly drawn classes:
    ``k_shot`` support images per class plus exactly one query image per
    class. An item is the tuple ``(support, query, labels)`` where
    ``support`` is [N*K, C, H, W], ``query`` is [N, C, H, W], and
    ``labels`` holds the episode-local class index (0..N-1) of each query.
    """

    def __init__(self, dataset, allowed_classes, transform, n_way=5, k_shot=1, n_episodes=600):
        self.transform = transform
        self.n_way = n_way
        self.k_shot = k_shot
        self.n_episodes = n_episodes
        self.dataset = dataset
        # Map each allowed class label to the dataset positions that carry it.
        self.class_to_indices = {}
        for pos, (_, lbl) in enumerate(dataset):
            if lbl in allowed_classes:
                self.class_to_indices.setdefault(lbl, []).append(pos)
        # A class is usable only if it can supply K support images plus 1 query.
        self.classes = [lbl for lbl, positions in self.class_to_indices.items()
                        if len(positions) >= k_shot + 1]

    def __len__(self):
        return self.n_episodes

    def __getitem__(self, _):
        # Draw the N classes that make up this episode.
        chosen = random.sample(self.classes, self.n_way)
        support_imgs, query_imgs, query_labels = [], [], []
        for episode_label, cls in enumerate(chosen):
            picks = random.sample(self.class_to_indices[cls], self.k_shot + 1)
            # The first K picks become the support set; the last is the query.
            for pos in picks[:self.k_shot]:
                img, _ = self.dataset[pos]
                support_imgs.append(self.transform(img))
            img, _ = self.dataset[picks[self.k_shot]]
            query_imgs.append(self.transform(img))
            query_labels.append(episode_label)
        # support: [N*K, C, H, W] | query: [N, C, H, W] | labels: [N]
        return (torch.stack(support_imgs),
                torch.stack(query_imgs),
                torch.tensor(query_labels))
# ── Evaluation function ───────────────────────────────────────
@torch.no_grad()
def evaluate_episodes(model, episode_ds, device, n_way, k_shot, num_workers=2):
    """
    Run episodic few-shot evaluation and return overall query accuracy.

    Args:
        model: network exposing ``get_embedding(batch) -> [B, D]``.
        episode_ds: dataset whose items are ``(support, query, labels)``
            episodes (e.g. :class:`EpisodeDataset`).
        device: torch device to run inference on.
        n_way: number of classes per episode.
        k_shot: number of support images per class.
        num_workers: DataLoader worker processes (default 2, matching the
            previous hard-coded value).

    Returns:
        float accuracy in [0, 1]; 0.0 if the dataset yields no episodes.
    """
    model.eval()
    correct, total = 0, 0
    loader = DataLoader(episode_ds, batch_size=1, shuffle=False, num_workers=num_workers)
    for support, query, labels in tqdm(loader, desc=f"{n_way}-way {k_shot}-shot"):
        # Drop the singleton batch dimension added by batch_size=1.
        support = support.squeeze(0).to(device)   # [N*K, C, H, W]
        query = query.squeeze(0).to(device)       # [N, C, H, W]
        labels = labels.squeeze(0).to(device)     # [N]
        support_emb = model.get_embedding(support)  # [N*K, D]
        query_emb = model.get_embedding(query)      # [N, D]
        # Class prototype = mean of the K support embeddings of that class.
        prototypes = support_emb.view(n_way, k_shot, -1).mean(dim=1)  # [N, D]
        # Cosine similarity of every query against every prototype -> [N, N].
        sim = F.cosine_similarity(
            query_emb.unsqueeze(1),   # [N, 1, D]
            prototypes.unsqueeze(0),  # [1, N, D]
            dim=2,
        )
        preds = sim.argmax(dim=1)  # predicted episode-local class per query
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    # Guard against an empty dataset (previously a ZeroDivisionError).
    return correct / total if total else 0.0
# ── Run all eval configurations ───────────────────────────────
def run_eval(checkpoint_path, data_root, split_path):
    """
    Evaluate a trained SiameseNet checkpoint on held-out Omniglot classes.

    Runs a 600-episode evaluation for every combination of {5, 10}-way and
    {1, 5}-shot on the test split of classes.

    Args:
        checkpoint_path: path to a checkpoint dict with keys
            ``model_state`` and ``epoch``.
        data_root: Omniglot data root (must already be downloaded).
        split_path: JSON file whose ``"test"`` entry lists held-out classes.

    Returns:
        dict mapping "N-way K-shot" -> accuracy in [0, 1].
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Restore the trained embedding network.
    model = SiameseNet(embedding_dim=128).to(device)
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources (consider weights_only=True).
    ckpt = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    print(f"Loaded checkpoint from epoch {ckpt['epoch']}")

    # Same preprocessing as training (Omniglot is white-background, hence
    # the high normalization mean).
    eval_transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((105, 105)),
        transforms.ToTensor(),
        transforms.Normalize([0.9220], [0.2256]),
    ])

    # Raw images (transform=None): EpisodeDataset applies the transform itself.
    bg = Omniglot(root=data_root, background=True, download=False, transform=None)
    with open(split_path) as f:
        test_classes = json.load(f)["test"]
    print(f"Evaluating on {len(test_classes)} unseen test classes\n")

    results = {}
    for n_way in [5, 10]:
        for k_shot in [1, 5]:
            ep_ds = EpisodeDataset(
                bg, test_classes, eval_transform,
                n_way=n_way, k_shot=k_shot, n_episodes=600,
            )
            acc = evaluate_episodes(model, ep_ds, device, n_way, k_shot)
            key = f"{n_way}-way {k_shot}-shot"
            results[key] = acc
            # Fixed: the arrow here was mojibake ("β†’") in the original.
            print(f" {key:18s} → {acc*100:.2f}%")
    return results
if __name__ == "__main__":
    # Default paths assume the script is launched from the src/ directory.
    config = {
        "checkpoint_path": "../checkpoints/best.pt",
        "data_root": "../data",
        "split_path": "../data/class_split.json",
    }
    results = run_eval(**config)