| import os |
| import pandas as pd |
| from PIL import Image |
| import torch |
| from torch.utils.data import ( |
| DataLoader, |
| Dataset, |
| IterableDataset, |
| SubsetRandomSampler, |
| get_worker_info, |
| ) |
| import clip.clip as clip |
|
|
|
|
class CsvDataset(Dataset):
    """Image/caption pairs listed in a delimited index file.

    Each row of the index names an image (column ``img_key``) and its caption
    (column ``caption_key``). Image names are joined onto the directory that
    contains the index file, so relative names resolve next to it (absolute
    names are used as-is, per ``os.path.join``).
    """

    def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t"):
        """Load the index and keep the image and caption columns as lists.

        Args:
            input_filename: Path to the index file read with ``pandas.read_csv``.
            transforms: Callable applied to each opened PIL image.
            img_key: Name of the column holding image file names.
            caption_key: Name of the column holding caption text.
            sep: Field delimiter for ``pandas.read_csv``; tab by default.
        """
        table = pd.read_csv(input_filename, sep=sep)

        # Image names in the table are resolved against the index file's folder.
        self.location = os.path.dirname(input_filename)
        self.images = table[img_key].tolist()
        self.captions = table[caption_key].tolist()
        self.transforms = transforms

    def __len__(self):
        """Return the number of rows in the index."""
        return len(self.captions)

    def __getitem__(self, idx):
        """Return ``(transformed image, tokenized caption)`` for row ``idx``."""
        # Cells are coerced with str() in case pandas parsed them as non-strings.
        full_path = os.path.join(self.location, str(self.images[idx]))
        pixels = self.transforms(Image.open(full_path))

        # clip.tokenize operates on a batch of texts; unwrap the single row.
        token_batch = clip.tokenize([str(self.captions[idx])])
        return pixels, token_batch[0]
|
|
|
|
class conceptual_captions(Dataset):
    """Container for the Conceptual Captions validation index and its loader.

    Although it subclasses ``Dataset``, this class does not define
    ``__len__``/``__getitem__`` itself. It builds a ``CsvDataset`` over
    ``Validation_GCC-1.1.0-Validation_output.csv`` inside ``location`` (image
    column ``filepath``, caption column ``title``) and a shuffled
    ``DataLoader`` over it. Extra positional/keyword arguments are accepted
    and ignored.
    """

    def __init__(
        self, transforms, location, batch_size, *args, num_workers=16, **kwargs
    ):
        index_path = os.path.join(
            location, "Validation_GCC-1.1.0-Validation_output.csv"
        )

        # Formats a name into the prompt "a photo of a {c}."; not used in this
        # class itself — presumably consumed by callers (verify).
        def _prompt(c):
            return f"a photo of a {c}."

        self.template = _prompt

        self.train_dataset = CsvDataset(
            input_filename=index_path,
            transforms=transforms,
            img_key="filepath",
            caption_key="title",
        )

        # NOTE(review): the validation CSV is exposed as `train_loader` and is
        # shuffled; confirm this is intended by the callers.
        loader_options = dict(
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
        )
        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset, **loader_options
        )
|
|