File size: 4,651 Bytes
02ac88d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 | from torchvision.datasets import Omniglot
from torchvision import transforms
import matplotlib.pyplot as plt
import random, os
import json, random
import torch
from torch.utils.data import Dataset
from PIL import Image
class SiamesePairDataset(Dataset):
    """Sample random same-class / different-class image pairs for a Siamese network.

    Each item is ``(img1, img2, label)`` where ``label`` is a float tensor equal to
    ``1.0`` when both images come from the same class and ``0.0`` otherwise.
    Positive and negative pairs are drawn with roughly equal probability. Pairs are
    drawn at random on every access, so the index passed to ``__getitem__`` only
    matters when ``seed`` is given.

    Args:
        dataset: Indexable dataset whose items are ``(image, label)`` tuples.
        allowed_classes: Labels to sample from; items with any other label are skipped.
        transform: Optional callable applied independently to each image of a pair.
        num_pairs: Value reported by ``__len__`` (number of pairs per epoch).
        seed: Optional seed. When given, the pair returned for a given index is
            deterministic (e.g. for a fixed validation set). When ``None`` the
            global ``random`` module is used.

    Raises:
        ValueError: If fewer than two allowed classes are present in ``dataset``,
            or if no allowed class has at least two images (no positive pair is
            possible).
    """

    def __init__(self, dataset, allowed_classes, transform=None, num_pairs=10000, seed=None):
        self.transform = transform
        self.num_pairs = num_pairs
        self.seed = seed
        # Build the set once: allowed_classes may be a (JSON-loaded) list, and list
        # membership would cost O(len(allowed_classes)) per dataset item.
        allowed = set(allowed_classes)
        # Group image indices by class. NOTE: this iterates the whole dataset, so
        # every item is fetched once at construction time.
        self.class_to_indices = {}
        for idx, (_, label) in enumerate(dataset):
            if label in allowed:
                self.class_to_indices.setdefault(label, []).append(idx)
        self.classes = list(self.class_to_indices.keys())
        # Only classes with at least two images can produce a positive pair.
        self._positive_classes = [
            cls for cls in self.classes if len(self.class_to_indices[cls]) >= 2
        ]
        if len(self.classes) < 2:
            raise ValueError(
                "SiamesePairDataset needs at least 2 allowed classes present in the "
                f"dataset to form negative pairs, got {len(self.classes)}"
            )
        if not self._positive_classes:
            raise ValueError(
                "SiamesePairDataset needs at least one allowed class with 2+ images "
                "to form positive pairs"
            )
        self.dataset = dataset

    def __len__(self):
        return self.num_pairs

    def __getitem__(self, index):
        # With a seed, derive a per-index RNG (str seeds are hashed deterministically
        # by random.Random) so the same index always yields the same pair.
        rng = random if self.seed is None else random.Random(f"{self.seed}-{index}")
        is_positive = rng.random() > 0.5  # ~50/50 split
        if is_positive:
            cls = rng.choice(self._positive_classes)
            i1, i2 = rng.sample(self.class_to_indices[cls], 2)
        else:
            cls1, cls2 = rng.sample(self.classes, 2)
            i1 = rng.choice(self.class_to_indices[cls1])
            i2 = rng.choice(self.class_to_indices[cls2])
        img1, _ = self.dataset[i1]
        img2, _ = self.dataset[i2]
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        label = torch.tensor(1.0 if is_positive else 0.0)
        return img1, img2, label
def dl_data():
    """Download both Omniglot splits, save a preview grid, and write the class split.

    Reads the module-level ``root`` (set in the ``__main__`` block) as the data
    directory. Side effects:
    - constructs both Omniglot splits with ``download=True``;
    - prints their sizes;
    - writes ``../logs/sample_grid.png`` (creating ``../logs`` if needed);
    - calls ``plt.show()``;
    - calls ``class_split`` on the background split.
    """
    to_tensor = transforms.ToTensor()
    bg = Omniglot(root=root, background=True, download=True, transform=to_tensor)
    # Named eval_ds rather than ``eval`` so the builtin is not shadowed.
    eval_ds = Omniglot(root=root, background=False, download=True, transform=to_tensor)
    print(f"Background split : {len(bg)} images")
    print(f"Evaluation split : {len(eval_ds)} images")

    # Quick 2x10 grid of sample images: every 20th image of the background split.
    fig, axes = plt.subplots(2, 10, figsize=(16, 4))
    for i, ax in enumerate(axes.flat):
        img, _ = bg[i * 20]
        ax.imshow(img.squeeze(), cmap="gray")
        ax.axis("off")
    plt.tight_layout()
    grid_path = "../logs/sample_grid.png"
    # savefig does not create missing parent directories.
    os.makedirs(os.path.dirname(grid_path), exist_ok=True)
    plt.savefig(grid_path, dpi=100)
    plt.show()
    plt.close(fig)

    # Split background classes into train / val / test.
    class_split(bg)
def class_split(bg, root_dir=None, seed=42):
    """Shuffle the classes of ``bg`` into disjoint 70/15/15 train/val/test splits.

    The split is written to ``<root_dir>/class_split.json`` as
    ``{"train": [...], "val": [...], "test": [...]}`` and also returned.

    Args:
        bg: Iterable of ``(image, label)`` pairs (e.g. the Omniglot background split).
        root_dir: Output directory; defaults to the module-level ``root``.
        seed: Shuffle seed. A private ``random.Random`` is used, so the global
            ``random`` state is not modified. It produces the same order as
            ``random.seed(seed); random.shuffle(...)``.

    Returns:
        The split dictionary that was written to disk.
    """
    out_dir = root if root_dir is None else root_dir
    # Sort before shuffling so the split depends only on the seed, not on set
    # iteration order.
    all_classes = sorted({label for _, label in bg})
    random.Random(seed).shuffle(all_classes)
    n = len(all_classes)
    train_end, val_end = int(n * 0.7), int(n * 0.85)
    train_classes = all_classes[:train_end]
    val_classes = all_classes[train_end:val_end]
    test_classes = all_classes[val_end:]  # NEVER touch until Day 5 (held-out test classes)
    split = {"train": train_classes, "val": val_classes, "test": test_classes}
    with open(os.path.join(out_dir, "class_split.json"), "w") as f:
        json.dump(split, f, indent=4)
    print(f"Train: {len(train_classes)} | Val: {len(val_classes)} | Test: {len(test_classes)}")
    return split
def validate_dataloader():
    """Build the train/val Siamese DataLoaders and sanity-check one batch of each.

    Relies on the module-level ``root``, ``train_transform`` and ``eval_transform``
    (defined in the ``__main__`` block) and on ``<root>/class_split.json``.

    Raises:
        FileNotFoundError: If ``class_split.json`` has not been written yet.
        AssertionError: If a batch does not have the expected
            ``[32, 1, 105, 105]`` image shape.
    """
    from torch.utils.data import DataLoader

    bg = Omniglot(root=root, background=True, download=True, transform=None)
    with open(os.path.join(root, "class_split.json")) as f:
        split = json.load(f)
    train_ds = SiamesePairDataset(bg, split["train"], transform=train_transform, num_pairs=10000)
    val_ds = SiamesePairDataset(bg, split["val"], transform=eval_transform, num_pairs=2000)
    # Pinned memory only speeds up host->GPU copies; requesting it without CUDA
    # makes PyTorch emit a warning, so enable it only when CUDA is available.
    pin = torch.cuda.is_available()
    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=4, pin_memory=pin)
    val_loader = DataLoader(val_ds, batch_size=32, shuffle=False, num_workers=4, pin_memory=pin)
    expected_shape = torch.Size([32, 1, 105, 105])
    # Sanity check: one training batch.
    img1, img2, labels = next(iter(train_loader))
    print(f"img1 shape : {img1.shape}")  # [32, 1, 105, 105]
    print(f"img2 shape : {img2.shape}")  # [32, 1, 105, 105]
    print(f"labels : {labels[:8]}")
    print(f"Positive % : {labels.mean().item()*100:.1f}%")  # should be ~50%
    assert img1.shape == img2.shape == expected_shape
    # Sanity check: one validation batch, so the eval transform is exercised too.
    val_img1, val_img2, _ = next(iter(val_loader))
    assert val_img1.shape == val_img2.shape == expected_shape
    print("All assertions passed — DataLoader is ready")
if __name__ == "__main__":
    root = "../data"
    # Download and split whenever the split file is missing. This also covers a
    # missing data directory (os.listdir would raise FileNotFoundError). It also
    # covers a non-empty directory without a split, where validate_dataloader()
    # would otherwise fail.
    if not os.path.isfile(os.path.join(root, "class_split.json")):
        dl_data()
    MEAN, STD = [0.9220], [0.2256]  # Omniglot stats (grayscale)
    train_transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((105, 105)),
        # MEAN ~0.92 means images are mostly white background: pad with white (255)
        # instead of RandomCrop's default black (0) so crops don't add dark borders.
        transforms.RandomCrop(105, padding=8, fill=255),
        # NOTE(review): horizontal flips mirror glyphs; confirm this is a
        # label-preserving augmentation for the alphabets used.
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ])
    eval_transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((105, 105)),
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ])
    validate_dataloader()
|