import os
from typing import List, Optional, Tuple

import albumentations as A
import albumentations.pytorch as al_pytorch
import numpy as np
import pytorch_lightning as pl
import torch
from PIL import Image

from app import config


class AnimeDataset(torch.utils.data.Dataset):
    """ Paired sketch / colored-image dataset. """

    def __init__(self, imgs_path: List[str], transforms: "Transforms") -> None:
        """ Store the file paths and the transform pipelines. """
        self.list_files = imgs_path
        self.transform = transforms

    def __len__(self) -> int:
        """ Number of image files. """
        return len(self.list_files)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Return the (sketch, colored image) pair at `index`. """
        # Read the combined image file.
        img_file = self.list_files[index]
        image = np.array(Image.open(img_file))
        # Each file holds both halves side by side: the right half is the
        # sketch, the left half is the colored image.
        sketches = image[:, image.shape[1] // 2:, :]
        colored_imgs = image[:, :image.shape[1] // 2, :]
        # Apply the shared augmentations to both halves so they stay aligned.
        augmentations = self.transform.both_transform(image=sketches, image0=colored_imgs)
        sketches, colored_imgs = augmentations['image'], augmentations['image0']
        # Then apply the input-only and target-only augmentations separately.
        sketches = self.transform.transform_only_input(image=sketches)['image']
        colored_imgs = self.transform.transform_only_mask(image=colored_imgs)['image']
        return sketches, colored_imgs


# Data augmentation
class Transforms:
    """ Augmentation pipelines for sketches and colored images. """

    def __init__(self):
        # Applied to both sketches and colored images; `additional_targets`
        # keeps the two halves spatially in sync.
        self.both_transform = A.Compose([
            A.Resize(width=256, height=256),
            A.HorizontalFlip(p=.5),
        ], additional_targets={'image0': 'image'})
        # Applied to sketches only.
        self.transform_only_input = A.Compose([
            A.ColorJitter(p=.1),
            A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
            al_pytorch.ToTensorV2(),
        ])
        # Applied to colored images only.
        self.transform_only_mask = A.Compose([
            A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
            al_pytorch.ToTensorV2(),
        ])


class Transforms_v1:
    """ Alternative transform set with U-Net-style resize variants. """

    def __init__(self):
        # Resize variants usable on both sketches and colored images.
        self.resize_572 = A.Compose([A.Resize(width=572, height=572)])
        self.resize_388 = A.Compose([A.Resize(width=388, height=388)])
        self.resize_256 = A.Compose([A.Resize(width=256, height=256)])
        # Applied to sketches only.
        self.transform_only_input = A.Compose([
            # A.ColorJitter(p=.1),
            A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
            al_pytorch.ToTensorV2(),
        ])
        # Applied to colored images only.
        self.transform_only_mask = A.Compose([
            A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
            al_pytorch.ToTensorV2(),
        ])


class AnimeSketchDataModule(pl.LightningDataModule):
    """ LightningDataModule for the anime sketch dataset. """

    def __init__(
            self,
            data_dir: str,
            train_folder_name: str = "train/",
            val_folder_name: str = "val/",
            train_batch_size: int = config.train_batch_size,
            val_batch_size: int = config.val_batch_size,
            train_num_images: int = 0,
            val_num_images: int = 0,
    ):
        super().__init__()
        self.train_dataset = None
        self.val_dataset = None
        self.data_dir: str = data_dir
        # Collect the train and val image paths.
        train_path: str = f"{self.data_dir}{train_folder_name}"
        train_images: List[str] = [f"{train_path}{x}" for x in os.listdir(train_path)]
        val_path: str = f"{self.data_dir}{val_folder_name}"
        val_images: List[str] = [f"{val_path}{x}" for x in os.listdir(val_path)]
        # Optionally cap the number of images used from each split.
        self.train_images = train_images[:train_num_images] if train_num_images else train_images
        self.val_images = val_images[:val_num_images] if val_num_images else val_images
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size

    def set_datasets(self) -> None:
        """ Build the train and val datasets. """
        self.train_dataset = AnimeDataset(
            imgs_path=self.train_images,
            transforms=Transforms()
        )
        self.val_dataset = AnimeDataset(
            imgs_path=self.val_images,
            transforms=Transforms()
        )
        print("Train and val dataset lengths:", len(self.train_dataset), len(self.val_dataset))

    def setup(self, stage: Optional[str] = None) -> None:
        self.set_datasets()

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            shuffle=True,  # shuffle training pairs each epoch
            num_workers=2,
            pin_memory=True
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=self.val_batch_size,
            shuffle=False,
            num_workers=2,
            pin_memory=True
        )
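# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): build the
# DataModule and pull one batch. The data directory below is a hypothetical
# placeholder; it should contain `train/` and `val/` sub-folders of
# side-by-side sketch|colored images. Batch sizes default to the values in
# `app.config`.
if __name__ == "__main__":
    dm = AnimeSketchDataModule(data_dir="data/anime_sketch/")  # hypothetical path
    dm.setup()
    sketches, colored_imgs = next(iter(dm.train_dataloader()))
    # Both tensors are (batch, 3, 256, 256), scaled to roughly [-1, 1]
    # by A.Normalize(mean=.5, std=.5).
    print(sketches.shape, colored_imgs.shape)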