import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image

from src import constants


# DATASET EXAMPLE
class OlfactionVisionDataset(Dataset):
    """Dataset yielding (image, olfaction vector, label) triples.

    The three inputs are indexed with the same integer in ``__getitem__``,
    so they are assumed to be index-aligned sequences of equal length.
    NOTE(review): lengths are not validated here; the dataset length is
    taken from ``image_paths`` alone -- confirm callers keep them aligned.
    """

    def __init__(self, image_paths, olfaction_vectors, labels):
        """Store the parallel inputs and build the image preprocessing pipeline.

        Args:
            image_paths: Indexable collection of image locations accepted by
                ``PIL.Image.open``.
            olfaction_vectors: Indexable collection of numeric vectors, one
                per image; each is converted to a float32 tensor on access.
            labels: Indexable collection of labels, returned unchanged.
        """
        self.image_paths = image_paths
        self.olfaction_vectors = olfaction_vectors
        self.labels = labels
        # Every image is resized to a square of side constants.IMG_DIM and
        # then converted from a PIL image to a tensor.
        self.transform = transforms.Compose([
            transforms.Resize((constants.IMG_DIM, constants.IMG_DIM)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Return the number of samples (the number of image paths)."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load and return the sample at position ``idx``.

        Returns:
            A tuple ``(image, olf_vec, label)`` where ``image`` is the
            transformed RGB image, ``olf_vec`` is a ``torch.float32`` tensor
            built from ``olfaction_vectors[idx]``, and ``label`` is
            ``labels[idx]`` as stored.
        """
        img_path = self.image_paths[idx]
        # Use a context manager so the underlying file handle is released
        # deterministically instead of waiting for garbage collection.
        # convert('RGB') yields an independent, fully loaded image, so the
        # transform result does not depend on the file staying open.
        with Image.open(img_path) as img:
            image = self.transform(img.convert('RGB'))
        olf_vec = self.olfaction_vectors[idx]
        label = self.labels[idx]
        return image, torch.tensor(olf_vec, dtype=torch.float32), label