Spaces:
Runtime error
Runtime error
| import torch | |
| import os | |
| from typing import List | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| from torchvision import transforms | |
| import albumentations as A | |
| import numpy as np | |
| import albumentations.pytorch as al_pytorch | |
| from typing import Dict, Tuple | |
class AnimeDataset(torch.utils.data.Dataset):
    """Paired sketch / colored-image dataset.

    Each file holds one training pair side by side: the colored image in the
    left half and its sketch in the right half.
    """

    def __init__(self, imgs_path: List[str], transforms: "Transforms") -> None:
        """Store the image paths and the transform bundle.

        Args:
            imgs_path: paths of the side-by-side pair images.
            transforms: object exposing ``both_transform``,
                ``transform_only_input`` and ``transform_only_mask`` (e.g. a
                ``Transforms`` instance). The parameter keeps its original
                name for keyword callers, although it shadows the
                ``torchvision.transforms`` module inside this method.
        """
        self.list_files = imgs_path
        self.transform = transforms

    def __len__(self) -> int:
        """Return the number of image files."""
        return len(self.list_files)

    @staticmethod
    def _split_pair(image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Split a side-by-side (H, W, C) array into ``(sketch, colored)``.

        The colored image is the left half and the sketch the right half.
        Both halves are cropped to width ``W // 2``. For an odd ``W`` the last
        column is dropped, so the paired transform receives equally sized
        images (for even ``W`` this is exactly the two halves).
        """
        half = image.shape[1] // 2
        colored = image[:, :half, :]
        sketch = image[:, half:2 * half, :]
        return sketch, colored

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load pair ``index`` and return the transformed ``(sketch, colored_image)``."""
        img_path = self.list_files[index]
        # The context manager closes the underlying file handle.
        # convert("RGB") guarantees three channels: grayscale or palette
        # images would otherwise break the 3-axis slicing in _split_pair,
        # and RGBA would not match the 3-channel Normalize.
        with Image.open(img_path) as img:
            image = np.array(img.convert("RGB"))

        sketchs, colored_imgs = self._split_pair(image)

        # Joint augmentation: the same resize/flip is applied to both halves.
        augmentations = self.transform.both_transform(image=sketchs, image0=colored_imgs)
        sketchs, colored_imgs = augmentations["image"], augmentations["image0"]

        # Per-image normalization and tensor conversion.
        sketchs = self.transform.transform_only_input(image=sketchs)["image"]
        colored_imgs = self.transform.transform_only_mask(image=colored_imgs)["image"]
        return sketchs, colored_imgs
class Transforms:
    """Albumentations pipelines used by ``AnimeDataset``.

    Attributes:
        both_transform: applied in one call to the sketch (``image``) and
            the colored image (``image0``), so both receive the same resize
            and the same random horizontal flip.
        transform_only_input: normalization + tensor conversion for sketches.
        transform_only_mask: normalization + tensor conversion for colored
            images.
    """

    def __init__(self, image_size: int = 1024) -> None:
        """Build the pipelines.

        Args:
            image_size: side length, in pixels, of the square resized
                images. Defaults to 1024, the previously hard-coded value.
        """
        # additional_targets declares "image0" as a second image, so the
        # randomly sampled flip is shared across the pair.
        self.both_transform = A.Compose(
            [
                A.Resize(width=image_size, height=image_size),
                A.HorizontalFlip(p=0.5),
            ],
            additional_targets={"image0": "image"},
        )
        # Both images currently get identical per-image processing. Separate
        # pipeline objects are kept so either side can diverge later.
        self.transform_only_input = self._normalize_to_tensor()
        self.transform_only_mask = self._normalize_to_tensor()

    @staticmethod
    def _normalize_to_tensor() -> "A.Compose":
        """Return a pipeline mapping [0, 255] pixels to [-1, 1], then to a tensor."""
        return A.Compose([
            # (x / 255 - 0.5) / 0.5 maps 0 -> -1 and 255 -> 1.
            A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0),
            al_pytorch.ToTensorV2(),
        ])