Spaces:
Running
Running
| # Copyright (c) 2020 Huawei Technologies Co., Ltd. | |
| # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode | |
| # | |
| # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # | |
| # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE | |
| import os | |
| import subprocess | |
| import torch.utils.data as data | |
| import numpy as np | |
| import time | |
| import torch | |
| import pickle | |
| class LRHR_PKLDataset(data.Dataset): | |
| def __init__(self, opt): | |
| super(LRHR_PKLDataset, self).__init__() | |
| self.opt = opt | |
| self.crop_size = opt.get("GT_size", None) | |
| self.scale = None | |
| self.random_scale_list = [1] | |
| hr_file_path = opt["dataroot_GT"] | |
| lr_file_path = opt["dataroot_LQ"] | |
| y_labels_file_path = opt['dataroot_y_labels'] | |
| gpu = True | |
| augment = True | |
| self.use_flip = opt["use_flip"] if "use_flip" in opt.keys() else False | |
| self.use_rot = opt["use_rot"] if "use_rot" in opt.keys() else False | |
| self.use_crop = opt["use_crop"] if "use_crop" in opt.keys() else False | |
| self.center_crop_hr_size = opt.get("center_crop_hr_size", None) | |
| n_max = opt["n_max"] if "n_max" in opt.keys() else int(1e8) | |
| t = time.time() | |
| self.lr_images = self.load_pkls(lr_file_path, n_max) | |
| self.hr_images = self.load_pkls(hr_file_path, n_max) | |
| min_val_hr = np.min([i.min() for i in self.hr_images[:20]]) | |
| max_val_hr = np.max([i.max() for i in self.hr_images[:20]]) | |
| min_val_lr = np.min([i.min() for i in self.lr_images[:20]]) | |
| max_val_lr = np.max([i.max() for i in self.lr_images[:20]]) | |
| t = time.time() - t | |
| print("Loaded {} HR images with [{:.2f}, {:.2f}] in {:.2f}s from {}". | |
| format(len(self.hr_images), min_val_hr, max_val_hr, t, hr_file_path)) | |
| print("Loaded {} LR images with [{:.2f}, {:.2f}] in {:.2f}s from {}". | |
| format(len(self.lr_images), min_val_lr, max_val_lr, t, lr_file_path)) | |
| self.gpu = gpu | |
| self.augment = augment | |
| self.measures = None | |
| def load_pkls(self, path, n_max): | |
| assert os.path.isfile(path), path | |
| images = [] | |
| with open(path, "rb") as f: | |
| images += pickle.load(f) | |
| assert len(images) > 0, path | |
| images = images[:n_max] | |
| images = [np.transpose(image, [2, 0, 1]) for image in images] | |
| return images | |
| def __len__(self): | |
| return len(self.hr_images) | |
| def __getitem__(self, item): | |
| hr = self.hr_images[item] | |
| lr = self.lr_images[item] | |
| if self.scale == None: | |
| self.scale = hr.shape[1] // lr.shape[1] | |
| assert hr.shape[1] == self.scale * lr.shape[1], ('non-fractional ratio', lr.shape, hr.shape) | |
| if self.use_crop: | |
| hr, lr = random_crop(hr, lr, self.crop_size, self.scale, self.use_crop) | |
| if self.center_crop_hr_size: | |
| hr, lr = center_crop(hr, self.center_crop_hr_size), center_crop(lr, self.center_crop_hr_size // self.scale) | |
| if self.use_flip: | |
| hr, lr = random_flip(hr, lr) | |
| if self.use_rot: | |
| hr, lr = random_rotation(hr, lr) | |
| hr = hr / 255.0 | |
| lr = lr / 255.0 | |
| if self.measures is None or np.random.random() < 0.05: | |
| if self.measures is None: | |
| self.measures = {} | |
| self.measures['hr_means'] = np.mean(hr) | |
| self.measures['hr_stds'] = np.std(hr) | |
| self.measures['lr_means'] = np.mean(lr) | |
| self.measures['lr_stds'] = np.std(lr) | |
| hr = torch.Tensor(hr) | |
| lr = torch.Tensor(lr) | |
| # if self.gpu: | |
| # hr = hr.cuda() | |
| # lr = lr.cuda() | |
| return {'LQ': lr, 'GT': hr, 'LQ_path': str(item), 'GT_path': str(item)} | |
| def print_and_reset(self, tag): | |
| m = self.measures | |
| kvs = [] | |
| for k in sorted(m.keys()): | |
| kvs.append("{}={:.2f}".format(k, m[k])) | |
| print("[KPI] " + tag + ": " + ", ".join(kvs)) | |
| self.measures = None | |
| def random_flip(img, seg): | |
| random_choice = np.random.choice([True, False]) | |
| img = img if random_choice else np.flip(img, 2).copy() | |
| seg = seg if random_choice else np.flip(seg, 2).copy() | |
| return img, seg | |
| def random_rotation(img, seg): | |
| random_choice = np.random.choice([0, 1, 3]) | |
| img = np.rot90(img, random_choice, axes=(1, 2)).copy() | |
| seg = np.rot90(seg, random_choice, axes=(1, 2)).copy() | |
| return img, seg | |
| def random_crop(hr, lr, size_hr, scale, random): | |
| size_lr = size_hr // scale | |
| size_lr_x = lr.shape[1] | |
| size_lr_y = lr.shape[2] | |
| start_x_lr = np.random.randint(low=0, high=(size_lr_x - size_lr) + 1) if size_lr_x > size_lr else 0 | |
| start_y_lr = np.random.randint(low=0, high=(size_lr_y - size_lr) + 1) if size_lr_y > size_lr else 0 | |
| # LR Patch | |
| lr_patch = lr[:, start_x_lr:start_x_lr + size_lr, start_y_lr:start_y_lr + size_lr] | |
| # HR Patch | |
| start_x_hr = start_x_lr * scale | |
| start_y_hr = start_y_lr * scale | |
| hr_patch = hr[:, start_x_hr:start_x_hr + size_hr, start_y_hr:start_y_hr + size_hr] | |
| return hr_patch, lr_patch | |
| def center_crop(img, size): | |
| assert img.shape[1] == img.shape[2], img.shape | |
| border_double = img.shape[1] - size | |
| assert border_double % 2 == 0, (img.shape, size) | |
| border = border_double // 2 | |
| return img[:, border:-border, border:-border] | |
| def center_crop_tensor(img, size): | |
| assert img.shape[2] == img.shape[3], img.shape | |
| border_double = img.shape[2] - size | |
| assert border_double % 2 == 0, (img.shape, size) | |
| border = border_double // 2 | |
| return img[:, :, border:-border, border:-border] | |