| import io |
| import os |
|
|
| import pandas as pd |
| import torch |
| from PIL import Image |
| from torch.utils.data import DataLoader, Dataset |
| from torchvision import transforms |
|
|
|
|
class ParquetImageDataset(Dataset):
    """Dataset serving images (and optional captions) stored in parquet files.

    Each parquet row is expected to hold an ``image`` column containing a
    mapping with the raw encoded bytes under ``"bytes"``, and a ``text``
    column with the caption.  NOTE(review): this matches the HuggingFace
    parquet image-dump schema — confirm against the files actually used.
    """

    def __init__(
        self, parquet_files, transform=None, im_size=256, condition_config=None
    ):
        """
        Args:
            parquet_files: iterable of parquet file paths; all files are
                concatenated into one DataFrame.
            transform: optional torchvision transform applied to each PIL
                image.  Falls back to a resize / center-crop / to-tensor
                pipeline when None.  (Bug fix: the original stored this
                argument but never used it.)
            im_size: target square size for the default pipeline.
            condition_config: optional dict with a ``"condition_types"``
                list; when non-empty, ``__getitem__`` also returns the
                caption.
        """
        self.data = pd.concat(
            [pd.read_parquet(file) for file in parquet_files], ignore_index=True
        )
        self.im_size = im_size
        # Build the pipeline once here rather than re-creating a Compose on
        # every __getitem__ call, and honor a caller-supplied transform.
        if transform is None:
            transform = transforms.Compose(
                [
                    transforms.Resize(im_size),
                    transforms.CenterCrop(im_size),
                    transforms.ToTensor(),
                ]
            )
        self.transform = transform
        self.condition_types = (
            [] if condition_config is None else condition_config["condition_types"]
        )

    def __len__(self):
        """Number of rows across all loaded parquet files."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return an image tensor normalized to [-1, 1]; with the caption
        appended when condition types are configured."""
        row = self.data.iloc[idx]
        # Force RGB so grayscale/RGBA sources batch consistently (3 channels);
        # the context manager closes the file handle even if decoding fails.
        with Image.open(io.BytesIO(row["image"]["bytes"])) as image:
            im_tensor = self.transform(image.convert("RGB"))
        # ToTensor yields [0, 1]; remap to the [-1, 1] range expected downstream.
        im_tensor = (2 * im_tensor) - 1
        if len(self.condition_types) == 0:
            return im_tensor
        return im_tensor, row["text"]
|
|
|
|
def create_dataloader(
    parquet_dir,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    transform=None,
    im_size=256,
    condition_config=None,
):
    """Build a DataLoader over every ``*.parquet`` file in *parquet_dir*.

    Args:
        parquet_dir: directory scanned (non-recursively) for ``.parquet`` files.
        batch_size, shuffle, num_workers: forwarded to ``DataLoader``.
        transform, im_size, condition_config: forwarded to
            ``ParquetImageDataset`` (previously not configurable from here;
            defaults preserve the old behavior).

    Raises:
        ValueError: if the directory contains no parquet files — same
            exception type ``pd.concat`` raised before, but with a message
            that names the offending directory.
    """
    # Sort so dataset order is deterministic: os.listdir order is arbitrary
    # and varies across filesystems.
    parquet_files = sorted(
        os.path.join(parquet_dir, f)
        for f in os.listdir(parquet_dir)
        if f.endswith(".parquet")
    )
    if not parquet_files:
        raise ValueError(f"No .parquet files found in {parquet_dir!r}")
    dataset = ParquetImageDataset(
        parquet_files,
        transform=transform,
        im_size=im_size,
        condition_config=condition_config,
    )
    return DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
    )
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|