import os
from typing import List, Optional, Tuple

import albumentations as A
import albumentations.pytorch as al_pytorch
import numpy as np
import pytorch_lightning as pl
import torch
from PIL import Image

from app import config
class AnimeDataset(torch.utils.data.Dataset):
    """ Sketch and colored-image pair dataset. """

    def __init__(self, imgs_path: List[str], transforms: "Transforms") -> None:
        """ Store the file paths and the transforms to apply. """
        self.list_files = imgs_path
        self.transform = transforms

    def __len__(self) -> int:
        """ Return the number of image files. """
        return len(self.list_files)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Return the (sketch, colored image) pair at `index`. """
        # Read the paired image file.
        img_file = self.list_files[index]
        image = np.array(Image.open(img_file))
        # Split the image down the middle: the right half is the sketch,
        # the left half is the colored image.
        sketches = image[:, image.shape[1] // 2:, :]
        colored_imgs = image[:, :image.shape[1] // 2, :]
        # Apply the shared augmentations (resize + flip) to both halves at once.
        augmentations = self.transform.both_transform(image=sketches, image0=colored_imgs)
        sketches, colored_imgs = augmentations['image'], augmentations['image0']
        # Then apply the input-only and target-only augmentations separately.
        sketches = self.transform.transform_only_input(image=sketches)['image']
        colored_imgs = self.transform.transform_only_mask(image=colored_imgs)['image']
        return sketches, colored_imgs
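

# Minimal usage sketch (not part of the original module): it assumes a paired
# image on disk whose left half is the colored image and right half is the
# sketch, matching the split in AnimeDataset.__getitem__. The path is a
# placeholder.
def _demo_dataset(img_path: str = "data/val/1.png") -> None:
    """ Load one pair and print the tensor shapes AnimeDataset produces. """
    dataset = AnimeDataset(imgs_path=[img_path], transforms=Transforms())
    sketch, colored = dataset[0]
    # After Resize(256, 256) and ToTensorV2, both tensors are CHW floats.
    print(sketch.shape, colored.shape)  # torch.Size([3, 256, 256]) twice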
# Data augmentation
class Transforms:
    """ Default train-time transforms. """

    def __init__(self):
        # Applied to both the sketch and the colored image, with the same
        # random parameters for each pair.
        self.both_transform = A.Compose([
            A.Resize(width=256, height=256),
            A.HorizontalFlip(p=.5),
        ], additional_targets={'image0': 'image'})
        # Applied to the sketch (network input) only.
        self.transform_only_input = A.Compose([
            A.ColorJitter(p=.1),
            A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
            al_pytorch.ToTensorV2(),
        ])
        # Applied to the colored image (network target) only.
        self.transform_only_mask = A.Compose([
            A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
            al_pytorch.ToTensorV2(),
        ])
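

# Hedged sketch of how `additional_targets` behaves: `both_transform` samples
# one set of random parameters and applies it to both `image` and `image0`, so
# the sketch and the colored image stay aligned. The random arrays here stand
# in for real images.
def _demo_paired_augmentation() -> None:
    """ Show that the paired transform resizes and flips both images together. """
    t = Transforms()
    sketch = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)
    colored = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)
    out = t.both_transform(image=sketch, image0=colored)
    print(out['image'].shape, out['image0'].shape)  # (256, 256, 3) twice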
class Transforms_v1:
    """ Alternative set of transforms with fixed resize sizes. """

    def __init__(self):
        # Resizes applied to both the sketch and the colored image.
        self.resize_572 = A.Compose([
            A.Resize(width=572, height=572),
        ])
        self.resize_388 = A.Compose([
            A.Resize(width=388, height=388),
        ])
        self.resize_256 = A.Compose([
            A.Resize(width=256, height=256),
        ])
        # Applied to the sketch (network input) only.
        self.transform_only_input = A.Compose([
            # A.ColorJitter(p=.1),
            A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
            al_pytorch.ToTensorV2(),
        ])
        # Applied to the colored image (network target) only.
        self.transform_only_mask = A.Compose([
            A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
            al_pytorch.ToTensorV2(),
        ])
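

# Hedged note: the 572/388 pair matches the classic valid-convolution U-Net
# (Ronneberger et al., 2015), where a 572x572 input shrinks to a 388x388
# output; whether that is why these sizes were chosen here is an assumption.
def _demo_unet_resizes() -> None:
    """ Resize a dummy array to the three sizes Transforms_v1 provides. """
    t = Transforms_v1()
    img = np.zeros((1024, 1024, 3), dtype=np.uint8)
    print(
        t.resize_572(image=img)['image'].shape,  # (572, 572, 3)
        t.resize_388(image=img)['image'].shape,  # (388, 388, 3)
        t.resize_256(image=img)['image'].shape,  # (256, 256, 3)
    )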
class AnimeSketchDataModule(pl.LightningDataModule):
    """ LightningDataModule for the anime sketch data. """

    def __init__(
        self,
        data_dir: str,
        train_folder_name: str = "train/",
        val_folder_name: str = "val/",
        train_batch_size: int = config.train_batch_size,
        val_batch_size: int = config.val_batch_size,
        train_num_images: int = 0,
        val_num_images: int = 0,
    ):
        super().__init__()
        self.val_dataset = None
        self.train_dataset = None
        self.data_dir: str = data_dir
        # Collect the train and val image paths.
        train_path: str = f"{self.data_dir}{train_folder_name}"
        train_images: List[str] = [f"{train_path}{x}" for x in os.listdir(train_path)]
        val_path: str = f"{self.data_dir}{val_folder_name}"
        val_images: List[str] = [f"{val_path}{x}" for x in os.listdir(val_path)]
        # Optionally limit the number of images used from each split.
        self.train_images = train_images[:train_num_images] if train_num_images else train_images
        self.val_images = val_images[:val_num_images] if val_num_images else val_images
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size

    def set_datasets(self) -> None:
        """ Build the train and val datasets. """
        self.train_dataset = AnimeDataset(
            imgs_path=self.train_images,
            transforms=Transforms(),
        )
        self.val_dataset = AnimeDataset(
            imgs_path=self.val_images,
            transforms=Transforms(),
        )
        print("The train/val dataset lengths are:", len(self.train_dataset), len(self.val_dataset))

    def setup(self, stage: Optional[str] = None) -> None:
        self.set_datasets()

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            shuffle=True,  # reshuffle training samples each epoch
            num_workers=2,
            pin_memory=True,
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=self.val_batch_size,
            shuffle=False,
            num_workers=2,
            pin_memory=True,
        )
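

# Hedged usage sketch: builds the DataModule and pulls one batch. The "data/"
# directory layout (containing "train/" and "val/") and the batch sizes read
# from app.config are assumptions about the surrounding project.
if __name__ == "__main__":
    dm = AnimeSketchDataModule(data_dir="data/", train_num_images=8, val_num_images=8)
    dm.setup()
    sketches, colored = next(iter(dm.train_dataloader()))
    print(sketches.shape, colored.shape)  # each (B, 3, 256, 256)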