Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import config | |
| import os | |
| from PIL import Image | |
| from torch.utils.data import Dataset, DataLoader | |
| from torchvision.utils import save_image | |
class MapDataset(Dataset):
    """Paired dataset of face images (inputs) and comic images (targets).

    Files from ``root_dir_face`` and ``root_dir_comics`` are paired by their
    position in the filename-sorted directory listings: the i-th face file is
    matched with the i-th comic file.
    NOTE(review): this assumes both folders name corresponding images so that
    they sort into the same order — TODO confirm against the data layout.
    """

    def __init__(self, root_dir_comics, root_dir_face):
        """Index the files of both image directories.

        Args:
            root_dir_comics: directory holding the target (comic) images.
            root_dir_face: directory holding the input (face) images.
        """
        self.root_dir_face = root_dir_face
        # os.listdir order is arbitrary; sort so the index-based pairing in
        # __getitem__ is deterministic across platforms and runs.
        self.list_files_face = sorted(os.listdir(self.root_dir_face))
        self.root_dir_comics = root_dir_comics
        self.list_files_comics = sorted(os.listdir(self.root_dir_comics))

    def __len__(self):
        """Return the number of complete (face, comic) pairs."""
        # Bounded by the shorter listing so __getitem__ never indexes past
        # the end of either file list.
        return min(len(self.list_files_face), len(self.list_files_comics))

    @staticmethod
    def _load_image(path):
        """Read the image at *path* into a NumPy array, closing the file."""
        with Image.open(path) as img:
            # np.array copies the pixel data, so the array stays valid after
            # the underlying file is closed.
            return np.array(img)

    def __getitem__(self, index):
        """Return the transformed ``(input_image, target_image)`` pair.

        Both images are passed together to ``config.both_transform`` (under
        the ``image`` / ``image0`` keys). Then the input goes through
        ``config.transform_only_input`` and the target goes through
        ``config.transform_only_mask``.

        Args:
            index: position of the pair, ``0 <= index < len(self)``.

        Returns:
            Tuple of whatever the ``config`` transforms produce for the input
            and target images (presumably tensors — TODO confirm in config).
        """
        img_path_input = os.path.join(self.root_dir_face, self.list_files_face[index])
        img_path_target = os.path.join(self.root_dir_comics, self.list_files_comics[index])

        input_image = self._load_image(img_path_input)
        target_image = self._load_image(img_path_target)

        # A single joint call, so the two images go through the same
        # both_transform invocation.
        augmentations = config.both_transform(image=input_image, image0=target_image)
        input_image = augmentations["image"]
        target_image = augmentations["image0"]

        input_image = config.transform_only_input(image=input_image)["image"]
        target_image = config.transform_only_mask(image=target_image)["image"]
        return input_image, target_image
if __name__ == "__main__":
    # Quick visual sanity check: save the first batch of input/target pairs.
    # The original call passed a single "data/train/" argument, which raises
    # TypeError because MapDataset requires both directories.
    # NOTE(review): the sub-directory names below are assumptions — TODO
    # confirm the actual dataset layout.
    dataset = MapDataset(
        root_dir_comics="data/train/comics/",
        root_dir_face="data/train/face/",
    )
    loader = DataLoader(dataset, batch_size=5)
    for x, y in loader:
        print(x.shape)
        save_image(x, "x.png")
        save_image(y, "y.png")
        break  # only the first batch is needed for the check