# Source: YashNagraj75 — "Add the dataset and the training script" (commit 31677e7)
import io
import os
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
class ParquetImageDataset(Dataset):
    """Dataset of (image, caption) rows stored in parquet files.

    Each row is expected to have an ``"image"`` column holding a dict with a
    ``"bytes"`` key (raw encoded image bytes) and a ``"text"`` column with the
    caption. Images are resized, center-cropped to ``im_size`` and rescaled
    from [0, 1] to [-1, 1].

    Args:
        parquet_files: iterable of parquet file paths; all are concatenated.
        transform: optional callable applied to the decoded PIL image. When
            ``None``, a default Resize/CenterCrop/ToTensor pipeline is used.
        im_size: target side length for the default pipeline.
        condition_config: optional dict; when it contains a non-empty
            ``"condition_types"`` list, ``__getitem__`` also returns the caption.
    """

    def __init__(
        self, parquet_files, transform=None, im_size=256, condition_config=None
    ):
        # Concatenate all shards into a single frame so idx is global.
        self.data = pd.concat(
            [pd.read_parquet(file) for file in parquet_files], ignore_index=True
        )
        self.im_size = im_size
        # Bug fix: the original stored `transform` but never applied it, and
        # rebuilt a hard-coded Compose pipeline on every __getitem__ call.
        # Honor a caller-supplied transform; otherwise build the default
        # pipeline once here instead of per item.
        if transform is None:
            transform = transforms.Compose(
                [
                    transforms.Resize(im_size),
                    transforms.CenterCrop(im_size),
                    transforms.ToTensor(),
                ]
            )
        self.transform = transform
        self.condition_types = (
            [] if condition_config is None else condition_config["condition_types"]
        )

    def __len__(self):
        """Total number of rows across all parquet files."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the image tensor in [-1, 1]; include the caption when conditioning is configured."""
        row = self.data.iloc[idx]  # single lookup instead of two iloc calls
        caption = row["text"]
        # Context manager guarantees the decoded image is closed even if the
        # transform raises.
        with Image.open(io.BytesIO(row["image"]["bytes"])) as image:
            im_tensor = self.transform(image)
        # Rescale [0, 1] -> [-1, 1] (assumes the transform ends in ToTensor).
        im_tensor = (2 * im_tensor) - 1  # type: ignore
        if len(self.condition_types) == 0:
            return im_tensor
        return im_tensor, caption
def create_dataloader(parquet_dir, batch_size=32, shuffle=True, num_workers=4):
    """Build a DataLoader over every ``.parquet`` file in ``parquet_dir``.

    Args:
        parquet_dir: directory containing the parquet shards.
        batch_size: samples per batch.
        shuffle: whether to reshuffle each epoch.
        num_workers: DataLoader worker-process count.

    Returns:
        A ``torch.utils.data.DataLoader`` over a ``ParquetImageDataset``.
    """
    # sorted(): os.listdir order is filesystem-dependent, so without it the
    # dataset row order (and reproducibility) varies across machines.
    parquet_files = sorted(
        os.path.join(parquet_dir, f)
        for f in os.listdir(parquet_dir)
        if f.endswith(".parquet")
    )
    dataset = ParquetImageDataset(parquet_files)
    return DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
    )
# Quick smoke test: print the tensors of the first batch only.
# dataloader = create_dataloader('dataset')
# for i, batch in enumerate(dataloader):
#     if i > 0:
#         break
#     for img_tensor in batch:
#         print(img_tensor)