Spaces:
Runtime error
Runtime error
| from abc import abstractmethod | |
| from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset | |
| from PIL import Image, ImageFile | |
| from pathlib import Path | |
| from functools import partial | |
| from torchvision import transforms as T, utils | |
| from torch import nn | |
def exists(val):
    """Return True when *val* is anything other than None (falsy values count as existing)."""
    return None is not val
def cycle(dl):
    """Yield items from *dl* forever, restarting a fresh pass each time it is exhausted."""
    while True:
        yield from dl
def convert_image_to(img_type, image):
    """Convert *image* to PIL mode *img_type*; return it untouched if already in that mode."""
    if image.mode == img_type:
        return image
    return image.convert(img_type)
class Txt2ImgIterableBaseDataset(IterableDataset):
    """Interface making IterableDatasets for text2img data chainable.

    Holds bookkeeping attributes only; concrete subclasses are expected
    to provide the actual record iteration in ``__iter__``.
    """

    def __init__(self, num_records=0, valid_ids=None, size=256):
        super().__init__()
        self.num_records = num_records
        self.size = size
        # both attributes start out referencing the same id collection
        self.valid_ids = self.sample_ids = valid_ids

    def __iter__(self):
        # stub: subclasses override this with a real record generator
        pass
class BaseDataset(Dataset):
    """Map-style dataset of image tensors gathered recursively from *folder*.

    Args:
        folder: root directory searched recursively for image files.
        image_size: target side length used for resize and center crop.
        exts: file extensions (lowercase, without dot) to include.
        convert_image_to_type: optional PIL mode string (e.g. 'RGB') each
            image is converted to before the tensor transforms; None keeps
            the image's native mode.
    """

    def __init__(
        self,
        folder,
        image_size,
        exts = ['jpg', 'jpeg', 'png', 'tiff'],
        convert_image_to_type = None
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

        # optional mode conversion first in the pipeline; nn.Identity() is a
        # no-op callable when no conversion is requested
        convert_fn = partial(convert_image_to, convert_image_to_type) if exists(convert_image_to_type) else nn.Identity()

        self.transform = T.Compose([
            T.Lambda(convert_fn),
            T.Resize(image_size),
            T.RandomHorizontalFlip(),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        # Image.open is lazy and keeps the file handle open; without a
        # context manager the descriptor leaks until garbage collection,
        # which can exhaust fds on large datasets. ToTensor forces the
        # pixel data to load before the file is closed.
        with Image.open(path) as img:
            return self.transform(img)