""" Dataset and augmentation pipelines for paired anime sketch / colored images. """
import torch
from typing import List, Tuple
import numpy as np
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2


class AnimeDataset(torch.utils.data.Dataset):
    """ Paired anime sketch and colored-image dataset. """

    def __init__(self, imgs_path: List[str], transforms: "Transforms") -> None:
        """ Store the file paths and the transform pipelines. """
        self.list_files = imgs_path
        self.transform = transforms

    def __len__(self) -> int:
        """ Return the number of image files. """
        return len(self.list_files)
    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Return the (sketch, colored image) pair at the given index. """
        # read the image file
        img_path = self.list_files[index]
        image = np.array(Image.open(img_path))
        # split the image into its two halves: the sketch is on the right
        # and the colored image on the left, per the dataset layout
        sketch = image[:, image.shape[1] // 2:, :]
        colored_img = image[:, :image.shape[1] // 2, :]
        # apply the shared augmentations to both halves so they stay spatially aligned
        augmentations = self.transform.both_transform(image=sketch, image0=colored_img)
        sketch, colored_img = augmentations['image'], augmentations['image0']
        # then apply the input-only and target-only augmentations respectively
        sketch = self.transform.transform_only_input(image=sketch)['image']
        colored_img = self.transform.transform_only_mask(image=colored_img)['image']
        return sketch, colored_img


class Transforms:
    """ Holds the albumentations pipelines used by AnimeDataset. """

    def __init__(self):
        # applied to both the sketch and the colored image
        self.both_transform = A.Compose(
            [
                A.Resize(width=1024, height=1024),
                A.HorizontalFlip(p=.5),
            ],
            additional_targets={'image0': 'image'},
        )
        # applied to the sketch only
        self.transform_only_input = A.Compose([
            # A.ColorJitter(p=.1),
            A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
            ToTensorV2(),
        ])
        # applied to the colored image only
        self.transform_only_mask = A.Compose([
            A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
            ToTensorV2(),
        ])
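

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file): wires the dataset
    # into a torch DataLoader. The "data/train/*.png" glob pattern and the
    # batch size are illustrative assumptions, not the original setup.
    import glob

    from torch.utils.data import DataLoader

    img_paths = sorted(glob.glob("data/train/*.png"))  # hypothetical data location
    dataset = AnimeDataset(imgs_path=img_paths, transforms=Transforms())
    loader = DataLoader(dataset, batch_size=4, shuffle=True)

    sketches, colored_imgs = next(iter(loader))
    # for RGB inputs, each batch holds normalized tensors of shape (B, 3, 1024, 1024)
    print(sketches.shape, colored_imgs.shape)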