Spaces:
Runtime error
Runtime error
| # python3.7 | |
| """Contains the class of dataset.""" | |
| import os | |
| from PIL import Image | |
| from .process_image import ImageProcessor | |
| from torch.utils.data import Dataset | |
class InferenceDataset(Dataset):
    """Dataset that aligns and preprocesses every image file in a directory.

    Each item is a dict with the processed image under ``'image'`` and the
    file name (relative to ``root_dir``) under ``'name'``.
    """

    def __init__(self,
                 root_dir,
                 resolution=256,
                 aligner_path=None
                 ):
        """Initializes the dataset.

        Args:
            root_dir: Root directory containing the images. Only regular files
                directly inside it are used; sub-directories are skipped.
            resolution: The resolution of the returned image. Stored as
                ``self.resolution``; not read by this class itself.
                (default: 256)
            aligner_path: Passed unchanged as the only argument to
                ``ImageProcessor``. (default: None)
        """
        self.root_dir = root_dir
        self.resolution = resolution
        # Holds file names relative to root_dir (not full paths), sorted so
        # the sample order is deterministic across runs.
        self.image_paths = self._list_image_files(root_dir)
        self.num_samples = len(self.image_paths)
        self.processor = ImageProcessor(aligner_path)

    @staticmethod
    def _list_image_files(root_dir):
        """Returns the sorted names of the regular files directly in root_dir."""
        # Sub-directories would make Image.open fail in __getitem__, so they
        # are excluded here instead of crashing at load time.
        return sorted(
            name for name in os.listdir(root_dir)
            if os.path.isfile(os.path.join(root_dir, name)))

    def __len__(self):
        """Returns the number of image files found at construction time."""
        return self.num_samples

    def __getitem__(self, idx):
        """Loads, aligns and preprocesses the image at position ``idx``.

        Returns:
            A dict ``{'image': <output of processor.preprocess_image>,
            'name': <file name relative to root_dir>}``.
        """
        image_path = self.image_paths[idx]
        # PIL opens files lazily and keeps the handle open; the context
        # manager releases it as soon as processing is done.
        # NOTE(review): assumes preprocess_image returns data independent of
        # the opened file (e.g. a new array/tensor) — confirm in ImageProcessor.
        with Image.open(os.path.join(self.root_dir, image_path)) as image:
            image = self.processor.align_face(image)
            image = self.processor.preprocess_image(image)
        return {'image': image, 'name': image_path}