from torchvision.datasets import Omniglot
from torchvision import transforms
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import torch
import json
import os
import random


class SiamesePairDataset(Dataset):
    """Yields (img1, img2, label) pairs: label 1.0 for same class, 0.0 otherwise."""

    def __init__(self, dataset, allowed_classes, transform=None, num_pairs=10000):
        self.transform = transform
        self.num_pairs = num_pairs
        allowed_classes = set(allowed_classes)
        # Group image indices by class (only classes belonging to the requested split)
        self.class_to_indices = {}
        for idx, (_, label) in enumerate(dataset):
            if label not in allowed_classes:
                continue
            self.class_to_indices.setdefault(label, []).append(idx)
        self.classes = list(self.class_to_indices.keys())
        self.dataset = dataset

    def __len__(self):
        return self.num_pairs

    def __getitem__(self, _):
        is_positive = random.random() > 0.5  # 50/50 positive/negative split
        if is_positive:
            # Two different images from the same class
            cls = random.choice(self.classes)
            i1, i2 = random.sample(self.class_to_indices[cls], 2)
        else:
            # One image from each of two different classes
            cls1, cls2 = random.sample(self.classes, 2)
            i1 = random.choice(self.class_to_indices[cls1])
            i2 = random.choice(self.class_to_indices[cls2])
        img1, _ = self.dataset[i1]
        img2, _ = self.dataset[i2]
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        label = torch.tensor(1.0 if is_positive else 0.0)
        return img1, img2, label


def dl_data():
    basic = transforms.ToTensor()
    bg = Omniglot(root=root, background=True, download=True, transform=basic)
    eval_set = Omniglot(root=root, background=False, download=True, transform=basic)
    print(f"Background split : {len(bg)} images")
    print(f"Evaluation split : {len(eval_set)} images")

    # Quick grid of sample images
    fig, axes = plt.subplots(2, 10, figsize=(16, 4))
    for i, ax in enumerate(axes.flat):
        img, _ = bg[i * 20]
        ax.imshow(img.squeeze(), cmap="gray")
        ax.axis("off")
    plt.tight_layout()
    os.makedirs("../logs", exist_ok=True)
    plt.savefig("../logs/sample_grid.png", dpi=100)
    plt.show()

    # Split the background classes into train / val / test
    class_split(bg)


def class_split(bg):
    all_classes = list(set(label for _, label in bg))
    random.seed(42)
    random.shuffle(all_classes)
    n = len(all_classes)
    train_classes = all_classes[:int(n * 0.7)]
    val_classes = all_classes[int(n * 0.7):int(n * 0.85)]
    test_classes = all_classes[int(n * 0.85):]  # NEVER touch until Day 5

    split = {"train": train_classes, "val": val_classes, "test": test_classes}
    with open(os.path.join(root, "class_split.json"), "w") as f:
        json.dump(split, f, indent=4)
    print(f"Train: {len(train_classes)} | Val: {len(val_classes)} | Test: {len(test_classes)}")


def validate_dataloader():
    from torch.utils.data import DataLoader

    bg = Omniglot(root=root, background=True, download=True, transform=None)
    with open(os.path.join(root, "class_split.json")) as f:
        split = json.load(f)

    # train_transform / eval_transform are defined in the __main__ block below
    train_ds = SiamesePairDataset(bg, split["train"], transform=train_transform, num_pairs=10000)
    val_ds = SiamesePairDataset(bg, split["val"], transform=eval_transform, num_pairs=2000)

    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)

    # Sanity check on one training batch
    img1, img2, labels = next(iter(train_loader))
    print(f"img1 shape : {img1.shape}")  # [32, 1, 105, 105]
    print(f"img2 shape : {img2.shape}")  # [32, 1, 105, 105]
    print(f"labels     : {labels[:8]}")
    print(f"Positive % : {labels.mean().item() * 100:.1f}%")  # should be ~50%
    assert img1.shape == img2.shape == torch.Size([32, 1, 105, 105])
    print("All assertions passed — DataLoader is ready")


if __name__ == "__main__":
    root = "../data"
    os.makedirs(root, exist_ok=True)
    if not os.listdir(root):
        dl_data()

    MEAN, STD = [0.9220], [0.2256]  # Omniglot stats (grayscale)
    train_transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((105, 105)),
        transforms.RandomCrop(105, padding=8),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ])
    eval_transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((105, 105)),
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ])
    validate_dataloader()
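

# --- Optional helper (sketch, not called above) -------------------------------
# The MEAN/STD constants in the __main__ block are assumed to be the per-pixel
# mean and standard deviation of the Omniglot background split in [0, 1]. If you
# want to verify or recompute them, a minimal sketch could look like this
# (batch_size and num_workers here are arbitrary choices, not values from above):
def compute_omniglot_stats(data_root="../data"):
    """Return ([mean], [std]) for the single grayscale channel of the background split."""
    from torch.utils.data import DataLoader

    ds = Omniglot(root=data_root, background=True, download=True,
                  transform=transforms.ToTensor())
    loader = DataLoader(ds, batch_size=256, num_workers=4)

    total, total_sq, n_pixels = 0.0, 0.0, 0
    for imgs, _ in loader:                       # imgs: [B, 1, 105, 105], values in [0, 1]
        total += imgs.sum().item()
        total_sq += (imgs ** 2).sum().item()
        n_pixels += imgs.numel()

    mean = total / n_pixels
    std = (total_sq / n_pixels - mean ** 2) ** 0.5  # population std from E[x^2] - E[x]^2
    return [round(mean, 4)], [round(std, 4)]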