| from glob import glob | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| from torch.utils.data import Dataset | |
def make_transform(
    smaller_edge_size: int,
    patch_size: int,
    center_crop: bool = False,
    max_edge_size: int = 812,
) -> transforms.Compose:
    """Build the preprocessing pipeline for a ViT-style patch model.

    Resizes the smaller image edge to ``smaller_edge_size`` (bicubic,
    antialiased), optionally center-crops to a square of that size, converts
    to a tensor, normalizes with ImageNet statistics, and finally trims the
    spatial dims so each is a multiple of ``patch_size`` and no larger than
    ``max_edge_size``.

    Args:
        smaller_edge_size: Target length of the smaller edge after resize;
            must be positive.
        patch_size: Patch size of the downstream model; output H and W are
            trimmed to multiples of this.
        center_crop: If True, insert a ``CenterCrop(smaller_edge_size)``
            after the resize.
        max_edge_size: Hard cap on the output H and W (in pixels).

    Returns:
        A ``transforms.Compose`` implementing the pipeline above.
    """
    IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
    assert smaller_edge_size > 0

    def _trim_to_patch_multiple(img):
        # img is (C, H, W); cut H and W down to the nearest multiple of
        # patch_size, capped at max_edge_size, so the model can tile the
        # image into whole patches.
        return img[
            :,
            : min(max_edge_size, img.shape[1] - img.shape[1] % patch_size),
            : min(max_edge_size, img.shape[2] - img.shape[2] % patch_size),
        ]

    # The two previous branches duplicated the whole pipeline except for
    # CenterCrop; build the step list once and insert the crop conditionally.
    steps = [
        transforms.Resize(
            size=smaller_edge_size,
            interpolation=transforms.InterpolationMode.BICUBIC,
            antialias=True,
        ),
    ]
    if center_crop:
        steps.append(transforms.CenterCrop(smaller_edge_size))
    steps += [
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
        transforms.Lambda(_trim_to_patch_multiple),
    ]
    return transforms.Compose(steps)
class VisualDataset(Dataset):
    """Dataset over an explicit list of image file paths.

    Falls back to a small bundled set of demo images under ``resources/``
    when no list is supplied. Each item is the image opened as RGB, passed
    through ``transform`` when one is given.
    """

    def __init__(self, transform, imgs=None):
        self.transform = transform
        default_files = [
            'resources/example.jpg',
            'resources/villa.png',
            'resources/000000037740.jpg',
            'resources/000000064359.jpg',
            'resources/000000066635.jpg',
            'resources/000000078420.jpg',
        ]
        self.files = default_files if imgs is None else imgs

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        # Force RGB so grayscale / RGBA files all yield 3-channel images.
        image = Image.open(self.files[index]).convert('RGB')
        return self.transform(image) if self.transform else image
class ImageNetDataset(Dataset):
    """ImageNet train split, evenly subsampled to at most ~num_train_max images.

    Globs ``data/imagenet/train/*/*.JPEG`` and keeps every ``step``-th file so
    roughly ``num_train_max`` images remain. Each item is the transformed RGB
    image; ``transform`` is assumed to be callable (it is always applied).
    """

    def __init__(self, transform, num_train_max=1000000):
        self.transform = transform
        # Sort so the subsample is deterministic; glob order is
        # filesystem-dependent.
        self.files = sorted(glob('data/imagenet/train/*/*.JPEG'))
        # BUG FIX: the previous integer division could yield step == 0 when
        # fewer than num_train_max files exist (e.g. an empty/missing data
        # dir), and files[::0] raises "slice step cannot be zero".
        step = max(1, len(self.files) // num_train_max)
        self.files = self.files[::step]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        img = Image.open(self.files[index]).convert('RGB')
        img = self.transform(img)
        return img
def load_data(args, model):
    """Build the center-cropped ImageNet training dataset configured by args.

    Args:
        args: Must provide ``resolution`` and ``num_train_max``.
        model: Must provide ``patch_size``.

    Returns:
        An ``ImageNetDataset`` using a center-crop transform at
        ``args.resolution``.
    """
    return ImageNetDataset(
        transform=make_transform(
            args.resolution, model.patch_size, center_crop=True
        ),
        num_train_max=args.num_train_max,
    )
def load_visual_data(args, model):
    """Build the visualization dataset configured by args.

    Args:
        args: Must provide ``visual_size``; may optionally provide ``imgs``
            (list of image paths — the dataset's bundled defaults are used
            when absent).
        model: Must provide ``patch_size``.

    Returns:
        A ``VisualDataset`` with a non-cropping transform at
        ``args.visual_size`` and a larger 1792-pixel edge cap.
    """
    visual_transform = make_transform(
        args.visual_size, model.patch_size, max_edge_size=1792
    )
    # 'imgs' is optional on args; fall back to the dataset's default list.
    return VisualDataset(transform=visual_transform, imgs=getattr(args, 'imgs', None))