| import random |
| import json |
| import numpy as np |
| from pathlib import Path |
| from typing import Iterable |
| from omegaconf import ListConfig |
|
|
| import cv2 |
| import torch |
| from functools import partial |
| import torchvision as thv |
| from torch.utils.data import Dataset |
|
|
| from utils import util_sisr |
| from utils import util_image |
| from utils import util_common |
|
|
| from basicsr.data.transforms import augment |
| from basicsr.data.realesrgan_dataset import RealESRGANDataset |
|
|
def get_transforms(transform_type, kwargs):
    '''
    Build a torchvision transform pipeline.

    Args:
        transform_type (str): one of 'default', 'resize_ccrop_norm',
            'ccrop_norm', 'rcrop_aug_norm', 'aug_norm'.
        kwargs (dict): accepted options:
            mean: scalar or sequence, for normalization
            std: scalar or sequence, for normalization
            size: int or sequence, for resizing / center cropping
            interpolation: interpolation mode for SmallestMaxSize
            pch_size: int, patch size for random cropping
            only_hflip / only_vflip / only_hvflip: bool, restrict spatial augmentation
            max_value: scaling value passed to util_image.ToTensor

    Returns:
        A torchvision transforms.Compose pipeline.

    Raises:
        ValueError: if transform_type is not recognized.
    '''
    if transform_type == 'default':
        transform = thv.transforms.Compose([
            thv.transforms.ToTensor(),
            thv.transforms.Normalize(mean=kwargs.get('mean', 0.5), std=kwargs.get('std', 0.5)),
        ])
    elif transform_type == 'resize_ccrop_norm':
        transform = thv.transforms.Compose([
            util_image.SmallestMaxSize(
                max_size=kwargs.get('size'),
                interpolation=kwargs.get('interpolation'),
            ),
            thv.transforms.ToTensor(),
            thv.transforms.CenterCrop(size=kwargs.get('size', None)),
            thv.transforms.Normalize(mean=kwargs.get('mean', 0.5), std=kwargs.get('std', 0.5)),
        ])
    elif transform_type == 'ccrop_norm':
        transform = thv.transforms.Compose([
            thv.transforms.ToTensor(),
            thv.transforms.CenterCrop(size=kwargs.get('size', None)),
            thv.transforms.Normalize(mean=kwargs.get('mean', 0.5), std=kwargs.get('std', 0.5)),
        ])
    elif transform_type == 'rcrop_aug_norm':
        transform = thv.transforms.Compose([
            util_image.RandomCrop(pch_size=kwargs.get('pch_size', 256)),
            util_image.SpatialAug(
                only_hflip=kwargs.get('only_hflip', False),
                only_vflip=kwargs.get('only_vflip', False),
                only_hvflip=kwargs.get('only_hvflip', False),
            ),
            util_image.ToTensor(max_value=kwargs.get('max_value')),
            thv.transforms.Normalize(mean=kwargs.get('mean', 0.5), std=kwargs.get('std', 0.5)),
        ])
    elif transform_type == 'aug_norm':
        transform = thv.transforms.Compose([
            util_image.SpatialAug(
                only_hflip=kwargs.get('only_hflip', False),
                only_vflip=kwargs.get('only_vflip', False),
                only_hvflip=kwargs.get('only_hvflip', False),
            ),
            util_image.ToTensor(),
            thv.transforms.Normalize(mean=kwargs.get('mean', 0.5), std=kwargs.get('std', 0.5)),
        ])
    else:
        # Bug fix: the original referenced the undefined name `transform_variant`
        # here, which raised NameError instead of the intended ValueError.
        raise ValueError(f'Unexpected transform_type {transform_type}')
    return transform
|
|
def create_dataset(dataset_config):
    """Instantiate a dataset from a configuration dict.

    Args:
        dataset_config (dict): must contain 'type' ('base', 'base_meta',
            or 'realesrgan') and 'params' with the constructor arguments.

    Returns:
        A Dataset instance of the requested type.

    Raises:
        NotImplementedError: for an unrecognized 'type'.
    """
    dataset_type = dataset_config['type']
    if dataset_type == 'base':
        return BaseData(**dataset_config['params'])
    if dataset_type == 'base_meta':
        return BaseDataMetaCond(**dataset_config['params'])
    if dataset_type == 'realesrgan':
        # RealESRGANDataset takes the whole params dict, not keyword arguments.
        return RealESRGANDataset(dataset_config['params'])
    raise NotImplementedError(f"{dataset_config['type']}")
|
|
class BaseData(Dataset):
    """
    Image dataset backed by a folder and/or a txt file listing image paths.

    Each item is a dict with keys:
        'image', 'lq': the transformed image (same tensor under both keys)
        'gt': the transformed paired image, when extra_dir_path is given
        'path': the source image path, when need_path is True
    """
    def __init__(
            self,
            dir_path,
            txt_path=None,
            transform_type='default',
            transform_kwargs=None,
            extra_dir_path=None,
            extra_transform_type=None,
            extra_transform_kwargs=None,
            length=None,
            need_path=False,
            im_exts=None,
            recursive=False,
            ):
        """
        Args:
            dir_path: folder to scan for images (skipped when None).
            txt_path: txt file with one image path per line (skipped when None).
            transform_type: key understood by get_transforms.
            transform_kwargs: options for the main transform; defaults to
                {'mean': 0.0, 'std': 1.0}.
            extra_dir_path: folder with paired images (e.g. ground truth),
                matched to the main image by file name.
            extra_transform_type: transform key for the paired images;
                required when extra_dir_path is given.
            extra_transform_kwargs: options for the paired-image transform.
            length: if not None, keep a random subsample of this many paths.
            need_path: include the source path in each item.
            im_exts: image extensions to scan for; defaults to
                ['png', 'jpg', 'jpeg', 'JPEG', 'bmp'].
            recursive: scan dir_path recursively.
        """
        super().__init__()

        # Bug fix: avoid mutable default arguments (dict/list defaults are
        # shared across all calls and can be silently mutated).
        if transform_kwargs is None:
            transform_kwargs = {'mean': 0.0, 'std': 1.0}
        if im_exts is None:
            im_exts = ['png', 'jpg', 'jpeg', 'JPEG', 'bmp']

        file_paths_all = []
        if dir_path is not None:
            file_paths_all.extend(util_common.scan_files_from_folder(dir_path, im_exts, recursive))
        if txt_path is not None:
            file_paths_all.extend(util_common.readline_txt(txt_path))

        # Keep the full list so reset_dataset can re-draw a fresh subsample.
        self.file_paths_all = file_paths_all
        self.file_paths = file_paths_all if length is None else random.sample(file_paths_all, length)

        self.length = length
        self.need_path = need_path
        self.transform = get_transforms(transform_type, transform_kwargs)

        self.extra_dir_path = extra_dir_path
        if extra_dir_path is not None:
            assert extra_transform_type is not None
            self.extra_transform = get_transforms(extra_transform_type, extra_transform_kwargs)

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, index):
        im_path_base = self.file_paths[index]
        im_base = util_image.imread(im_path_base, chn='rgb', dtype='float32')

        im_target = self.transform(im_base)
        out = {'image': im_target, 'lq': im_target}

        if self.extra_dir_path is not None:
            # The paired image shares its file name with the base image.
            im_path_extra = Path(self.extra_dir_path) / Path(im_path_base).name
            im_extra = util_image.imread(im_path_extra, chn='rgb', dtype='float32')
            im_extra = self.extra_transform(im_extra)
            out['gt'] = im_extra

        if self.need_path:
            out['path'] = im_path_base

        return out

    def reset_dataset(self):
        """Re-draw a random subsample of the full file list.

        Bug fix: the original unconditionally called
        random.sample(self.file_paths_all, self.length), which raises
        TypeError when length is None (no subsampling requested).
        """
        if self.length is not None:
            self.file_paths = random.sample(self.file_paths_all, self.length)
|
|
class BaseDataMetaCond(Dataset):
    """
    Dataset driven by per-image JSON meta files, for conditional training.

    Each meta JSON is expected to contain:
        'source': path of the source image
        'prompt': text prompt
        <cond_key> ('canny' or 'seg'): path of the condition image
        'latent' (optional): path of a pre-computed latent (.npy)
    """
    def __init__(
            self,
            meta_dir,
            transform_type='default',
            transform_kwargs=None,
            length=None,
            need_path=False,
            cond_key='canny',
            cond_transform_type='default',
            cond_transform_kwargs=None,
            ):
        """
        Args:
            meta_dir: one directory (str/Path) or a list/tuple/ListConfig of
                directories containing *.json meta files.
            transform_type: transform key for the source image.
            transform_kwargs: options for the source transform; defaults to
                {'mean': 0.5, 'std': 0.5}.
            length: if not None, keep only the first `length` meta files.
            need_path: include the source image path in each item.
            cond_key: condition type, 'canny' (grayscale) or 'seg' (rgb).
            cond_transform_type: transform key for the condition image.
            cond_transform_kwargs: options for the condition transform;
                defaults to {'mean': 0.5, 'std': 0.5}.
        """
        super().__init__()

        # Bug fix: avoid mutable default arguments (dict defaults are shared
        # across all calls and can be silently mutated).
        if transform_kwargs is None:
            transform_kwargs = {'mean': 0.5, 'std': 0.5}
        if cond_transform_kwargs is None:
            cond_transform_kwargs = {'mean': 0.5, 'std': 0.5}

        # Generalization: accept any plain sequence of directories, not only
        # omegaconf's ListConfig (the original wrapped a plain list into a
        # one-element list, which then failed in Path(current_dir)).
        if not isinstance(meta_dir, (list, tuple, ListConfig)):
            meta_dir = [meta_dir, ]

        meta_list = []
        for current_dir in meta_dir:
            meta_list.extend(sorted([str(x) for x in Path(current_dir).glob("*.json")]))
        self.meta_list = meta_list if length is None else meta_list[:length]

        self.cond_key = cond_key
        self.length = length
        self.need_path = need_path
        self.transform = get_transforms(transform_type, transform_kwargs)
        self.cond_transform = get_transforms(cond_transform_type, cond_transform_kwargs)
        # Keep the original (misspelled) attribute name as an alias for
        # backward compatibility with any external code that reads it.
        self.cond_trasform = self.cond_transform

    def __len__(self):
        return len(self.meta_list)

    def __getitem__(self, index):
        """Load the meta JSON at `index` and assemble the training item."""
        json_path = self.meta_list[index]
        with open(json_path, 'r') as json_file:
            meta_info = json.load(json_file)

        # Source image
        im_path = meta_info['source']
        im_source = util_image.imread(im_path, chn='rgb', dtype='uint8')
        im_source = self.transform(im_source)
        out = {'image': im_source, }
        if self.need_path:
            out['path'] = im_path

        # Pre-computed latent (optional)
        if 'latent' in meta_info:
            latent_path = meta_info['latent']
            out['latent'] = np.load(latent_path)

        # Text prompt
        out['txt'] = meta_info['prompt']

        # Condition image
        cond_key = self.cond_key
        cond_path = meta_info[cond_key]
        if cond_key == 'canny':
            # Grayscale edge map kept as HxWx1 so the transform sees a channel axis.
            cond = util_image.imread(cond_path, chn='gray', dtype='uint8')[:, :, None]
        elif cond_key == 'seg':
            cond = util_image.imread(cond_path, chn='rgb', dtype='uint8')
        else:
            raise ValueError(f"Unexpected cond key: {cond_key}")
        cond = self.cond_transform(cond)
        out['cond'] = cond

        return out
|
|