# LETTER / src/dataset.py
# Uploaded by Sharath33 via huggingface_hub (commit 02ac88d, verified).
from torchvision.datasets import Omniglot
from torchvision import transforms
import matplotlib.pyplot as plt
import random, os
import json, random
import torch
from torch.utils.data import Dataset
from PIL import Image
class SiamesePairDataset(Dataset):
    """Sample random image pairs labelled 1.0 (same class) or 0.0 (different).

    Every ``__getitem__`` call draws a pair from ``dataset``, restricted to
    images whose label is in ``allowed_classes``. With probability ~0.5 both
    images share a class (positive, label ``1.0``); otherwise they come from
    two different classes (negative, label ``0.0``). ``__len__`` returns
    ``num_pairs`` and only defines the length of an epoch: the index is
    ignored unless ``seed`` is given.

    Args:
        dataset: Indexable dataset whose items are ``(image, label)`` tuples
            with hashable labels.
        allowed_classes: Labels that may appear in pairs (e.g. one list from
            ``class_split.json``); images of all other classes are skipped.
        transform: Optional callable applied to both images of a pair.
        num_pairs: Number of pairs reported by ``__len__``.
        seed: Optional seed. ``None`` (default) draws from the global
            ``random`` module, so each access yields a new pair. With a seed,
            the pair chosen for a given index is fixed (e.g. for a stable
            validation set). Only pair *selection* is seeded; a random
            ``transform`` still varies between calls.

    Raises:
        ValueError: if fewer than two allowed classes occur in ``dataset``
            (no negative pair is possible) or no allowed class has at least
            two images (no positive pair is possible).
    """

    def __init__(self, dataset, allowed_classes, transform=None,
                 num_pairs=10000, seed=None):
        self.transform = transform
        self.num_pairs = num_pairs
        self.seed = seed
        # allowed_classes is typically a JSON-loaded list; a set makes the
        # per-image membership test O(1) instead of a list scan.
        allowed = set(allowed_classes)
        # Group image indices by class. This iterates every item of
        # ``dataset`` (images included), so it dominates setup time.
        self.class_to_indices = {}
        for idx, (_, label) in enumerate(dataset):
            if label in allowed:
                self.class_to_indices.setdefault(label, []).append(idx)
        self.classes = list(self.class_to_indices)
        # Positive pairs need two distinct images of one class, so only
        # classes with >= 2 images are eligible for them.
        self._positive_classes = [
            cls for cls, indices in self.class_to_indices.items()
            if len(indices) >= 2
        ]
        if len(self.classes) < 2:
            raise ValueError(
                "need at least 2 allowed classes present in the dataset, "
                f"found {len(self.classes)}"
            )
        if not self._positive_classes:
            raise ValueError(
                "need at least one allowed class with 2 or more images"
            )
        self.dataset = dataset

    def __len__(self):
        return self.num_pairs

    def __getitem__(self, index):
        """Return ``(img1, img2, label)``; ``label`` is a scalar float tensor."""
        # A per-index RNG makes seeded pairs reproducible. random.Random
        # converts a str seed to an int deterministically (no hash salting).
        if self.seed is None:
            rng = random
        else:
            rng = random.Random(f"{self.seed}:{index}")
        is_positive = rng.random() > 0.5  # ~50/50 split
        if is_positive:
            cls = rng.choice(self._positive_classes)
            i1, i2 = rng.sample(self.class_to_indices[cls], 2)
        else:
            cls1, cls2 = rng.sample(self.classes, 2)
            i1 = rng.choice(self.class_to_indices[cls1])
            i2 = rng.choice(self.class_to_indices[cls2])
        img1, _ = self.dataset[i1]
        img2, _ = self.dataset[i2]
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        label = torch.tensor(1.0 if is_positive else 0.0)
        return img1, img2, label
def dl_data():
    """Download both Omniglot splits, save a preview grid and write the class split.

    Uses the module-level ``root`` (assigned in the ``__main__`` block) as the
    dataset directory; calling this before ``root`` is set raises NameError.

    Side effects: downloads Omniglot into ``root`` when needed, prints the
    size of each split, saves ``../logs/sample_grid.png`` (relative to the
    current working directory), shows the figure, and writes
    ``<root>/class_split.json`` via ``class_split``.
    """
    to_tensor = transforms.ToTensor()
    background = Omniglot(root=root, background=True, download=True,
                          transform=to_tensor)
    # Named ``evaluation`` rather than ``eval`` to avoid shadowing the builtin.
    evaluation = Omniglot(root=root, background=False, download=True,
                          transform=to_tensor)
    print(f"Background split : {len(background)} images")
    print(f"Evaluation split : {len(evaluation)} images")

    # Quick 2x10 grid of every 20th background image.
    fig, axes = plt.subplots(2, 10, figsize=(16, 4))
    for i, ax in enumerate(axes.flat):
        img, _ = background[i * 20]
        ax.imshow(img.squeeze(), cmap="gray")
        ax.axis("off")
    plt.tight_layout()
    grid_path = "../logs/sample_grid.png"
    # savefig does not create missing parent directories.
    os.makedirs(os.path.dirname(grid_path), exist_ok=True)
    plt.savefig(grid_path, dpi=100)
    plt.show()
    plt.close(fig)

    # Derive the train/val/test class split from the background set.
    class_split(background)
def class_split(bg, data_root=None, seed=42):
    """Split the classes of ``bg`` into disjoint train/val/test lists and save them.

    The distinct labels of ``bg`` are sorted, shuffled with a private
    ``random.Random(seed)`` and cut 70% / 15% / 15%. The split is written as
    JSON to ``<data_root>/class_split.json`` and also returned.

    Args:
        bg: Iterable of ``(image, label)`` items. Only labels are used, but
            every item is still read.
        data_root: Directory receiving ``class_split.json``. Defaults to the
            module-level ``root`` (assigned in the ``__main__`` block).
        seed: Seed for the class shuffle.

    Returns:
        dict with keys ``"train"``, ``"val"`` and ``"test"``, each a list of labels.
    """
    if data_root is None:
        data_root = root  # module-level global; NameError if never assigned
    # Sort so the shuffle input order is explicit, not set-iteration order.
    all_classes = sorted({label for _, label in bg})
    # Private RNG: same shuffle as random.seed(seed) + random.shuffle, without
    # resetting the global random state for the rest of the program.
    random.Random(seed).shuffle(all_classes)
    n = len(all_classes)
    train_end, val_end = int(n * 0.7), int(n * 0.85)
    train_classes = all_classes[:train_end]
    val_classes = all_classes[train_end:val_end]
    test_classes = all_classes[val_end:]  # NEVER touch until Day 5
    split = {"train": train_classes, "val": val_classes, "test": test_classes}
    with open(os.path.join(data_root, "class_split.json"), "w") as f:
        json.dump(split, f, indent=4)
    print(f"Train: {len(train_classes)} | Val: {len(val_classes)} | Test: {len(test_classes)}")
    return split
def validate_dataloader(batch_size=32, num_workers=4):
    """Build the train/val pair loaders and sanity-check one batch of each.

    Reads the module-level ``root``, ``train_transform`` and
    ``eval_transform`` (assigned in the ``__main__`` block) and the
    ``<root>/class_split.json`` file written by ``class_split``.

    Args:
        batch_size: DataLoader batch size. The shape checks expect the first
            batch of each loader to be full, so it must not exceed the 10000
            train / 2000 val pairs.
        num_workers: Worker processes per DataLoader.

    Raises:
        AssertionError: if a checked batch does not have the expected shape.
    """
    from torch.utils.data import DataLoader

    bg = Omniglot(root=root, background=True, download=True, transform=None)
    with open(os.path.join(root, "class_split.json")) as f:
        split = json.load(f)  # json is imported at module level
    train_ds = SiamesePairDataset(bg, split["train"], transform=train_transform, num_pairs=10000)
    val_ds = SiamesePairDataset(bg, split["val"], transform=eval_transform, num_pairs=2000)
    # Pinned host memory only helps copies to a CUDA device, so request it
    # only when CUDA is available.
    pin_memory = torch.cuda.is_available()
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin_memory)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=pin_memory)
    image_shape = torch.Size([batch_size, 1, 105, 105])

    # Sanity check: train loader
    img1, img2, labels = next(iter(train_loader))
    print(f"img1 shape : {img1.shape}")  # expected [batch_size, 1, 105, 105]
    print(f"img2 shape : {img2.shape}")  # expected [batch_size, 1, 105, 105]
    print(f"labels : {labels[:8]}")
    print(f"Positive % : {labels.mean().item()*100:.1f}%")  # should be ~50%
    assert img1.shape == img2.shape == image_shape

    # The val loader uses eval_transform, so check its output as well.
    val_img1, val_img2, val_labels = next(iter(val_loader))
    assert val_img1.shape == val_img2.shape == image_shape
    assert val_labels.shape == torch.Size([batch_size])
    print("All assertions passed — DataLoader is ready")
if __name__ == "__main__":
    root = "../data"
    # First run: download Omniglot, plot samples and write the class split.
    # Keyed on the split file that validate_dataloader needs, rather than on
    # the directory being empty. os.listdir also raised FileNotFoundError
    # when ``root`` did not exist yet.
    os.makedirs(root, exist_ok=True)
    if not os.path.isfile(os.path.join(root, "class_split.json")):
        dl_data()

    MEAN, STD = [0.9220], [0.2256]  # Omniglot stats (grayscale)
    # NOTE(review): horizontal flips mirror the characters; confirm that a
    # mirrored glyph should still count as the same class.
    train_transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((105, 105)),
        transforms.RandomCrop(105, padding=8),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ])
    eval_transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((105, 105)),
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ])
    validate_dataloader()