import os
import sys

import numpy as np
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image

import config


class MapDataset(Dataset):
    """Paired image dataset: face images as inputs, comic images as targets.

    The i-th face file is paired with the i-th comics file after both
    directory listings are sorted by name.
    NOTE(review): this assumes corresponding images sort into the same
    position in both directories (e.g. identical file names) — confirm
    against the data-preparation step.
    """

    def __init__(self, root_dir_comics, root_dir_face):
        """Index the files of both image directories.

        Args:
            root_dir_comics: directory holding the target (comics) images.
            root_dir_face: directory holding the input (face) images.
        """
        self.root_dir_face = root_dir_face
        # os.listdir returns entries in arbitrary, filesystem-dependent order.
        # Sort both listings so index i pairs the same files on every
        # machine and every run.
        self.list_files_face = sorted(os.listdir(self.root_dir_face))
        self.root_dir_comics = root_dir_comics
        self.list_files_comics = sorted(os.listdir(self.root_dir_comics))

    def __len__(self) -> int:
        """Return the number of complete (input, target) pairs."""
        # Bounded by the shorter listing, so __getitem__ never indexes past
        # the end of either directory.
        return min(len(self.list_files_face), len(self.list_files_comics))

    @staticmethod
    def _load_image(path):
        """Read the image at *path* into a NumPy array, closing the file."""
        # np.array forces PIL to decode the pixel data, so the file handle
        # can be released as soon as the conversion is done.
        with Image.open(path) as img:
            return np.array(img)

    def __getitem__(self, index: int):
        """Return the transformed ``(input_image, target_image)`` pair at *index*.

        Both images go through ``config.both_transform`` together, then
        through their side-specific transforms
        (``config.transform_only_input`` / ``config.transform_only_mask``).
        The returned objects are whatever those transforms produce
        (presumably tensors, since ``save_image`` consumes them below —
        TODO confirm in config).
        """
        input_path = os.path.join(self.root_dir_face, self.list_files_face[index])
        target_path = os.path.join(self.root_dir_comics, self.list_files_comics[index])
        input_image = self._load_image(input_path)
        target_image = self._load_image(target_path)

        # Joint transform: both images are passed in a single call under the
        # keys "image" and "image0". Presumably this makes random augmentations
        # apply identically to both (e.g. albumentations additional_targets)
        # — TODO confirm in config.
        augmentations = config.both_transform(image=input_image, image0=target_image)
        input_image = augmentations["image"]
        target_image = augmentations["image0"]

        # Side-specific transforms, applied independently to each image.
        input_image = config.transform_only_input(image=input_image)["image"]
        target_image = config.transform_only_mask(image=target_image)["image"]
        return input_image, target_image


if __name__ == "__main__":
    # Smoke test: write the first batch of inputs and targets to x.png / y.png.
    # The dataset needs both a comics directory and a face directory, so
    # take them from the command line.
    if len(sys.argv) != 3:
        sys.exit(f"usage: {sys.argv[0]} ROOT_DIR_COMICS ROOT_DIR_FACE")
    dataset = MapDataset(sys.argv[1], sys.argv[2])
    loader = DataLoader(dataset, batch_size=5)
    for x, y in loader:
        print(x.shape)
        save_image(x, "x.png")
        save_image(y, "y.png")
        break  # one batch is enough for a visual check