import os

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import (
    DataLoader,
    Dataset,
    IterableDataset,
    SubsetRandomSampler,
    get_worker_info,
)

import clip.clip as clip


class CsvDataset(Dataset):
    """Image/caption pairs listed in a delimited index file.

    Image paths read from the file are resolved relative to the directory
    that contains the index file itself.
    """

    def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t"):
        """Load the sample index.

        Args:
            input_filename: Path to the delimited index file.
            transforms: Callable applied to each opened ``PIL.Image``.
            img_key: Column holding image paths (relative to the index
                file's directory).
            caption_key: Column holding caption text.
            sep: Field delimiter passed to ``pandas.read_csv``. Defaults to
                a tab, even when the file carries a ``.csv`` extension.
        """
        df = pd.read_csv(input_filename, sep=sep)
        self.location = os.path.dirname(input_filename)
        self.images = df[img_key].tolist()
        self.captions = df[caption_key].tolist()
        self.transforms = transforms

    def __len__(self):
        """Return the number of samples (rows) in the index file."""
        return len(self.captions)

    def __getitem__(self, idx):
        """Return ``(transformed_image, caption_tokens)`` for sample *idx*.

        ``str()`` guards against non-string cells (e.g. numeric file names
        or NaN captions) produced by pandas.
        """
        image_path = os.path.join(self.location, str(self.images[idx]))
        # Open with a context manager so the file handle is released
        # deterministically instead of waiting for garbage collection; with
        # many DataLoader workers, leaked handles can exhaust the open-file
        # limit. NOTE(review): assumes `transforms` materializes the pixels
        # (e.g. converts to a tensor) rather than returning the lazy image.
        with Image.open(image_path) as image:
            images = self.transforms(image)
        # clip.tokenize operates on a list of strings; keep the single row.
        texts = clip.tokenize([str(self.captions[idx])])[0]
        return images, texts


class conceptual_captions(Dataset):
    """Conceptual Captions wrapper that exposes a shuffled ``train_loader``.

    NOTE(review): the index file loaded here is the *validation* split
    (``Validation_GCC-1.1.0-Validation_output.csv``), yet it is stored as
    ``train_dataset`` / ``train_loader``. Confirm this is intentional.
    """

    def __init__(
        self, transforms, location, batch_size, *args, num_workers=16, **kwargs
    ):
        """Build the dataset and its DataLoader.

        Args:
            transforms: Per-image transform forwarded to ``CsvDataset``.
            location: Directory containing the GCC index file.
            batch_size: Batch size for ``train_loader``.
            *args: Accepted and ignored. Presumably kept so every dataset
                wrapper shares one constructor signature; confirm with
                callers.
            num_workers: Number of DataLoader worker processes.
            **kwargs: Accepted and ignored (see ``*args``).
        """
        file_name = "Validation_GCC-1.1.0-Validation_output.csv"
        file_path = os.path.join(location, file_name)

        # Prompt template for turning a class name into a caption.
        self.template = lambda c: f"a photo of a {c}."

        self.train_dataset = CsvDataset(
            input_filename=file_path,
            transforms=transforms,
            img_key="filepath",
            caption_key="title",
        )
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
        )