'''
- This file defines a custom dataset class that inherits from
  torchvision.datasets.ImageFolder. The only reason we need it is to filter
  out children's drawings and fossil images of dinosaurs using CLIP.
- The __getitem__ method is overridden to return the image, its label, and
  its file path. This lets us use a DataLoader for batch processing and
  clean the data much faster.
'''

import os

from torchvision.datasets import ImageFolder


class DinoDataset(ImageFolder):
    '''
    ImageFolder variant whose items also carry the image's file path.

    Attributes:
        paths: File path of every sample, in the same order as
            ``self.samples``.
    '''

    def __init__(self, root, transform=None):
        '''
        Args:
            root: Root directory laid out as ``root/<class_name>/<image>``.
            transform: Optional callable applied to each loaded image.
        '''
        super().__init__(root, transform)
        # ImageFolder builds each sample path underneath ``root`` already, so
        # the paths are used as-is. Joining ``root`` onto them again would
        # duplicate the prefix whenever ``root`` is a relative path.
        self.paths = [path for path, _ in self.samples]

    def __getitem__(self, idx):
        '''
        Return the sample at ``idx`` together with its file path.

        Args:
            idx: Index of the sample.

        Returns:
            A tuple ``(img, label, path)``: ``img`` and ``label`` come
            from ImageFolder.__getitem__, and ``path`` is the matching
            entry of ``self.paths``.
        '''
        img, label = super().__getitem__(idx)
        return img, label, self.paths[idx]