| from torchvision.datasets import Omniglot |
| from torchvision import transforms |
| import matplotlib.pyplot as plt |
| import random, os |
| import json, random |
| import torch |
| from torch.utils.data import Dataset |
| from PIL import Image |
|
|
|
|
class SiamesePairDataset(Dataset):
    """Yields random (img1, img2, label) pairs for Siamese-network training.

    label is 1.0 when both images come from the same class and 0.0
    otherwise. Pairs are sampled on the fly, so an "epoch" of num_pairs
    items is a fresh random draw every time.
    """

    def __init__(self, dataset, allowed_classes, transform=None, num_pairs=10000):
        """
        Args:
            dataset: indexable of (image, label) pairs (e.g. Omniglot).
            allowed_classes: iterable of labels this split may draw from.
            transform: optional callable applied to both images of a pair.
            num_pairs: nominal epoch length reported by __len__.
        """
        self.transform = transform
        self.num_pairs = num_pairs

        # allowed_classes typically arrives as a list loaded from JSON;
        # a set gives O(1) membership tests while scanning the dataset.
        allowed = set(allowed_classes)

        # Map label -> list of dataset indices, restricted to allowed classes.
        # NOTE(review): iterating `dataset` loads every image just to read its
        # label; fine at Omniglot scale, but slow for larger datasets.
        self.class_to_indices = {}
        for idx, (_, label) in enumerate(dataset):
            if label not in allowed:
                continue
            self.class_to_indices.setdefault(label, []).append(idx)

        # A positive pair needs two distinct samples from one class, so drop
        # singleton classes up front instead of crashing in random.sample.
        self.classes = [c for c, idxs in self.class_to_indices.items() if len(idxs) >= 2]
        self.dataset = dataset

    def __len__(self):
        # Nominal epoch length: items are generated on demand, not stored.
        return self.num_pairs

    def __getitem__(self, _):
        # 50/50 coin flip between a same-class and a different-class pair;
        # the requested index is ignored on purpose (pure random sampling).
        is_positive = random.random() > 0.5

        if is_positive:
            cls = random.choice(self.classes)
            i1, i2 = random.sample(self.class_to_indices[cls], 2)
        else:
            cls1, cls2 = random.sample(self.classes, 2)
            i1 = random.choice(self.class_to_indices[cls1])
            i2 = random.choice(self.class_to_indices[cls2])

        img1, _ = self.dataset[i1]
        img2, _ = self.dataset[i2]

        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        label = torch.tensor(1.0 if is_positive else 0.0)
        return img1, img2, label
|
|
def dl_data():
    """Download both Omniglot splits, save a sample grid, and build the class split.

    Relies on the module-level `root` data directory (set under __main__).
    """
    basic = transforms.ToTensor()

    bg = Omniglot(root=root, background=True, download=True, transform=basic)
    # Named eval_set so we don't shadow the builtin `eval`.
    eval_set = Omniglot(root=root, background=False, download=True, transform=basic)

    print(f"Background split : {len(bg)} images")
    print(f"Evaluation split : {len(eval_set)} images")

    # Preview grid: stepping by 20 samples a different index per panel
    # (Omniglot stores 20 drawings per character).
    fig, axes = plt.subplots(2, 10, figsize=(16, 4))
    for i, ax in enumerate(axes.flat):
        img, label = bg[i * 20]
        ax.imshow(img.squeeze(), cmap="gray")
        ax.axis("off")
    plt.tight_layout()
    # savefig does not create parent directories — make sure ../logs exists.
    os.makedirs("../logs", exist_ok=True)
    plt.savefig("../logs/sample_grid.png", dpi=100)
    plt.show()

    class_split(bg)
|
|
def class_split(bg):
    """Split Omniglot character classes 70/15/15 into train/val/test, saved as JSON.

    Splitting by class (not by image) guarantees val/test characters are
    never seen during training. Writes `class_split.json` under the
    module-level `root` directory.
    """
    # Sort before shuffling so the split is reproducible regardless of
    # set iteration order.
    all_classes = sorted(set(label for _, label in bg))

    # Dedicated RNG: same deterministic shuffle as seed(42)+shuffle, but
    # without reseeding the global `random` state used for pair sampling.
    random.Random(42).shuffle(all_classes)

    n = len(all_classes)
    train_classes = all_classes[:int(n * 0.7)]
    val_classes = all_classes[int(n * 0.7):int(n * 0.85)]
    test_classes = all_classes[int(n * 0.85):]

    split = {"train": train_classes, "val": val_classes, "test": test_classes}
    with open(os.path.join(root, "class_split.json"), "w") as f:
        json.dump(split, f, indent=4)

    print(f"Train: {len(train_classes)} | Val: {len(val_classes)} | Test: {len(test_classes)}")
|
|
def validate_dataloader():
    """Smoke-test the pair pipeline: build both loaders and check one train batch.

    Relies on module-level `root`, `train_transform`, and `eval_transform`.
    """
    from torch.utils.data import DataLoader  # only needed for this check

    # transform=None: the pair dataset applies its own transform per image.
    bg = Omniglot(root=root, background=True, download=True, transform=None)
    with open(os.path.join(root, "class_split.json")) as f:
        split = json.load(f)

    train_ds = SiamesePairDataset(bg, split["train"], transform=train_transform, num_pairs=10000)
    val_ds = SiamesePairDataset(bg, split["val"], transform=eval_transform, num_pairs=2000)

    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)
    # val_loader is not iterated here; constructing it verifies the val
    # pipeline builds without error.
    val_loader = DataLoader(val_ds, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)

    # Pull a single batch and sanity-check shapes and label balance.
    img1, img2, labels = next(iter(train_loader))
    print(f"img1 shape : {img1.shape}")
    print(f"img2 shape : {img2.shape}")
    print(f"labels : {labels[:8]}")
    print(f"Positive % : {labels.mean().item()*100:.1f}%")
    assert img1.shape == img2.shape == torch.Size([32, 1, 105, 105])
    print("All assertions passed — DataLoader is ready")
|
|
|
|
| if __name__ == "__main__": |
| root = "../data" |
| if os.listdir(root) == []: |
| dl_data() |
| |
|
|
| MEAN, STD = [0.9220], [0.2256] |
|
|
| train_transform = transforms.Compose([ |
| transforms.Grayscale(), |
| transforms.Resize((105, 105)), |
| transforms.RandomCrop(105, padding=8), |
| transforms.RandomHorizontalFlip(), |
| transforms.ColorJitter(brightness=0.2, contrast=0.2), |
| transforms.ToTensor(), |
| transforms.Normalize(MEAN, STD), |
| ]) |
|
|
| eval_transform = transforms.Compose([ |
| transforms.Grayscale(), |
| transforms.Resize((105, 105)), |
| transforms.ToTensor(), |
| transforms.Normalize(MEAN, STD), |
| ]) |
| validate_dataloader() |
|
|