import json
import os.path

from PIL import Image
from torch.utils.data import DataLoader, Dataset
from transformers import CLIPProcessor
from torchvision.transforms import transforms
import pytorch_lightning as pl
class WikiArtDataset(Dataset):
    """Image/caption pairs parsed from a WikiArt metadata file."""

    def __init__(self, meta_file):
        super().__init__()
        self.files = []
        with open(meta_file, 'r') as f:
            js = json.load(f)
        for img_path in js:
            # File names look like 'artist_title-words-1234.jpg'; take the part
            # after the last '_' and strip trailing all-digit segments (years,
            # plate numbers) to recover a plain-text caption.
            img_name = os.path.splitext(os.path.basename(img_path))[0]
            caption = img_name.split('_')[-1]
            caption = caption.split('-')
            j = len(caption) - 1
            while j >= 0:
                if not caption[j].isdigit():
                    break
                j -= 1
            if j < 0:
                # Every segment was numeric, so no usable caption remains.
                continue
            sentence = ' '.join(caption[:j + 1])
            self.files.append({'img_path': os.path.join('datasets/wikiart', img_path),
                               'sentence': sentence})
        version = 'openai/clip-vit-large-patch14'
        self.processor = CLIPProcessor.from_pretrained(version)
        self.jpg_transform = transforms.Compose([
            transforms.Resize(512),
            transforms.RandomCrop(512),
            transforms.ToTensor(),
        ])
    def __getitem__(self, idx):
        file = self.files[idx]
        # Some WikiArt scans are grayscale or palette-mode; force RGB so tensor
        # shapes stay consistent across the batch.
        im = Image.open(file['img_path']).convert('RGB')
        im_tensor = self.jpg_transform(im)
        # CLIP pixel values for the style encoder; [0] drops the batch dim.
        clip_im = self.processor(images=im, return_tensors="pt")['pixel_values'][0]
        return {'jpg': im_tensor, 'style': clip_im, 'txt': file['sentence']}

    def __len__(self):
        return len(self.files)
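
# A minimal sketch of pulling one sample, assuming the metadata JSON lists
# relative paths such as 'Impressionism/claude-monet_water-lilies-1919.jpg'
# (the path and file name are illustrative, not from the original repo):
#
#   dataset = WikiArtDataset('datasets/wikiart/meta.json')
#   sample = dataset[0]
#   sample['jpg'].shape    # torch.Size([3, 512, 512]) from the crop pipeline
#   sample['style'].shape  # torch.Size([3, 224, 224]) from the CLIP processor
#   sample['txt']          # 'water lilies' (trailing year stripped)
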
class WikiArtDataModule(pl.LightningDataModule):
    """Lightning wrapper exposing the WikiArt training dataloader."""

    def __init__(self, meta_file, batch_size, num_workers):
        super().__init__()
        self.train_dataset = WikiArtDataset(meta_file)
        self.batch_size = batch_size
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers, pin_memory=True)
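
# A minimal usage sketch, assuming a metadata file at 'datasets/wikiart/meta.json';
# the path, batch size, and worker count are illustrative, not from the original
# code.
if __name__ == '__main__':
    datamodule = WikiArtDataModule('datasets/wikiart/meta.json', batch_size=4, num_workers=2)
    batch = next(iter(datamodule.train_dataloader()))
    # Default collation stacks the image tensors and leaves the captions as a
    # list of strings.
    print(batch['jpg'].shape, batch['style'].shape, batch['txt'][:2])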