| import os |
| import numpy as np |
| import cv2 |
| import glob |
| import math |
| import yaml |
| import random |
| from collections import OrderedDict |
| import torch |
| import torch.nn.functional as F |
|
|
| from basicsr.data.transforms import augment |
| from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels |
| from basicsr.utils import DiffJPEG, USMSharp, img2tensor, tensor2img |
| from basicsr.utils.img_process_util import filter2D |
| from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt |
| from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation, |
| normalize, rgb_to_grayscale) |
|
|
| cur_path = os.path.dirname(os.path.abspath(__file__)) |
|
|
| def ordered_yaml(): |
| """Support OrderedDict for yaml. |
| |
| Returns: |
| yaml Loader and Dumper. |
| """ |
| try: |
| from yaml import CDumper as Dumper |
| from yaml import CLoader as Loader |
| except ImportError: |
| from yaml import Dumper, Loader |
|
|
| _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG |
|
|
| def dict_representer(dumper, data): |
| return dumper.represent_dict(data.items()) |
|
|
| def dict_constructor(loader, node): |
| return OrderedDict(loader.construct_pairs(node)) |
|
|
| Dumper.add_representer(OrderedDict, dict_representer) |
| Loader.add_constructor(_mapping_tag, dict_constructor) |
| return Loader, Dumper |
|
|
| def opt_parse(opt_path): |
| with open(opt_path, mode='r') as f: |
| Loader, _ = ordered_yaml() |
| opt = yaml.load(f, Loader=Loader) |
|
|
| return opt |
|
|
| class RealESRGAN_degradation(object): |
| def __init__(self, opt_name='params_realesrgan.yml', device='cpu'): |
| opt_path = f'{cur_path}/{opt_name}' |
| self.opt = opt_parse(opt_path) |
| self.device = device |
| optk = self.opt['kernel_info'] |
|
|
| |
| self.blur_kernel_size = optk['blur_kernel_size'] |
| self.kernel_list = optk['kernel_list'] |
| self.kernel_prob = optk['kernel_prob'] |
| self.blur_sigma = optk['blur_sigma'] |
| self.betag_range = optk['betag_range'] |
| self.betap_range = optk['betap_range'] |
| self.sinc_prob = optk['sinc_prob'] |
|
|
| |
| self.blur_kernel_size2 = optk['blur_kernel_size2'] |
| self.kernel_list2 = optk['kernel_list2'] |
| self.kernel_prob2 = optk['kernel_prob2'] |
| self.blur_sigma2 = optk['blur_sigma2'] |
| self.betag_range2 = optk['betag_range2'] |
| self.betap_range2 = optk['betap_range2'] |
| self.sinc_prob2 = optk['sinc_prob2'] |
|
|
| |
| self.final_sinc_prob = optk['final_sinc_prob'] |
|
|
| self.kernel_range = [2 * v + 1 for v in range(3, 11)] |
| self.pulse_tensor = torch.zeros(21, 21).float() |
| self.pulse_tensor[10, 10] = 1 |
|
|
| self.jpeger = DiffJPEG(differentiable=False).to(self.device) |
| self.usm_shaper = USMSharp().to(self.device) |
| |
| def color_jitter_pt(self, img, brightness, contrast, saturation, hue): |
| fn_idx = torch.randperm(4) |
| for fn_id in fn_idx: |
| if fn_id == 0 and brightness is not None: |
| brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item() |
| img = adjust_brightness(img, brightness_factor) |
|
|
| if fn_id == 1 and contrast is not None: |
| contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item() |
| img = adjust_contrast(img, contrast_factor) |
|
|
| if fn_id == 2 and saturation is not None: |
| saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item() |
| img = adjust_saturation(img, saturation_factor) |
|
|
| if fn_id == 3 and hue is not None: |
| hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item() |
| img = adjust_hue(img, hue_factor) |
| return img |
|
|
| def random_augment(self, img_gt): |
| |
| img_gt, status = augment(img_gt, hflip=True, rotation=False, return_status=True) |
| """ |
| # random color jitter |
| if np.random.uniform() < self.opt['color_jitter_prob']: |
| jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32) |
| img_gt = img_gt + jitter_val |
| img_gt = np.clip(img_gt, 0, 1) |
| |
| # random grayscale |
| if np.random.uniform() < self.opt['gray_prob']: |
| #img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY) |
| img_gt = cv2.cvtColor(img_gt, cv2.COLOR_RGB2GRAY) |
| img_gt = np.tile(img_gt[:, :, None], [1, 1, 3]) |
| """ |
| |
| img_gt = img2tensor([img_gt], bgr2rgb=False, float32=True)[0].unsqueeze(0) |
| return img_gt |
|
|
| def random_kernels(self): |
| |
| kernel_size = random.choice(self.kernel_range) |
| if np.random.uniform() < self.sinc_prob: |
| |
| if kernel_size < 13: |
| omega_c = np.random.uniform(np.pi / 3, np.pi) |
| else: |
| omega_c = np.random.uniform(np.pi / 5, np.pi) |
| kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False) |
| else: |
| kernel = random_mixed_kernels( |
| self.kernel_list, |
| self.kernel_prob, |
| kernel_size, |
| self.blur_sigma, |
| self.blur_sigma, [-math.pi, math.pi], |
| self.betag_range, |
| self.betap_range, |
| noise_range=None) |
| |
| pad_size = (21 - kernel_size) // 2 |
| kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size))) |
|
|
| |
| kernel_size = random.choice(self.kernel_range) |
| if np.random.uniform() < self.sinc_prob2: |
| if kernel_size < 13: |
| omega_c = np.random.uniform(np.pi / 3, np.pi) |
| else: |
| omega_c = np.random.uniform(np.pi / 5, np.pi) |
| kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False) |
| else: |
| kernel2 = random_mixed_kernels( |
| self.kernel_list2, |
| self.kernel_prob2, |
| kernel_size, |
| self.blur_sigma2, |
| self.blur_sigma2, [-math.pi, math.pi], |
| self.betag_range2, |
| self.betap_range2, |
| noise_range=None) |
|
|
| |
| pad_size = (21 - kernel_size) // 2 |
| kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size))) |
|
|
| |
| if np.random.uniform() < self.final_sinc_prob: |
| kernel_size = random.choice(self.kernel_range) |
| omega_c = np.random.uniform(np.pi / 3, np.pi) |
| sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21) |
| sinc_kernel = torch.FloatTensor(sinc_kernel) |
| else: |
| sinc_kernel = self.pulse_tensor |
|
|
| kernel = torch.FloatTensor(kernel) |
| kernel2 = torch.FloatTensor(kernel2) |
|
|
| return kernel, kernel2, sinc_kernel |
|
|
| @torch.no_grad() |
| def degrade_process(self, img_gt, resize_bak=False): |
| img_gt = self.random_augment(img_gt) |
| kernel1, kernel2, sinc_kernel = self.random_kernels() |
| img_gt, kernel1, kernel2, sinc_kernel = img_gt.to(self.device), kernel1.to(self.device), kernel2.to(self.device), sinc_kernel.to(self.device) |
| |
| ori_h, ori_w = img_gt.size()[2:4] |
|
|
| |
| scale_final = 4 |
|
|
| |
| |
| out = filter2D(img_gt, kernel1) |
| |
| updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0] |
| if updown_type == 'up': |
| scale = np.random.uniform(1, self.opt['resize_range'][1]) |
| elif updown_type == 'down': |
| scale = np.random.uniform(self.opt['resize_range'][0], 1) |
| else: |
| scale = 1 |
| mode = random.choice(['area', 'bilinear', 'bicubic']) |
| out = F.interpolate(out, scale_factor=scale, mode=mode) |
| |
| gray_noise_prob = self.opt['gray_noise_prob'] |
| if np.random.uniform() < self.opt['gaussian_noise_prob']: |
| out = random_add_gaussian_noise_pt( |
| out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob) |
| else: |
| out = random_add_poisson_noise_pt( |
| out, |
| scale_range=self.opt['poisson_scale_range'], |
| gray_prob=gray_noise_prob, |
| clip=True, |
| rounds=False) |
| |
| jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range']) |
| out = torch.clamp(out, 0, 1) |
| out = self.jpeger(out, quality=jpeg_p) |
|
|
| |
| |
| if self.opt['second_phase_prob'] > random.random(): |
| |
| if np.random.uniform() < self.opt['second_blur_prob']: |
| out = filter2D(out, kernel2) |
| |
| updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0] |
| if updown_type == 'up': |
| scale = np.random.uniform(1, self.opt['resize_range2'][1]) |
| elif updown_type == 'down': |
| scale = np.random.uniform(self.opt['resize_range2'][0], 1) |
| else: |
| scale = 1 |
| mode = random.choice(['area', 'bilinear', 'bicubic']) |
| out = F.interpolate( |
| out, size=(int(ori_h / scale_final * scale), int(ori_w / scale_final * scale)), mode=mode) |
| |
| gray_noise_prob = self.opt['gray_noise_prob2'] |
| if np.random.uniform() < self.opt['gaussian_noise_prob2']: |
| out = random_add_gaussian_noise_pt( |
| out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob) |
| else: |
| out = random_add_poisson_noise_pt( |
| out, |
| scale_range=self.opt['poisson_scale_range2'], |
| gray_prob=gray_noise_prob, |
| clip=True, |
| rounds=False) |
|
|
| |
| |
| |
| |
| |
| |
| |
| if np.random.uniform() < 0.5: |
| |
| mode = random.choice(['area', 'bilinear', 'bicubic']) |
| out = F.interpolate(out, size=(ori_h // scale_final, ori_w // scale_final), mode=mode) |
| out = filter2D(out, sinc_kernel) |
| |
| jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) |
| out = torch.clamp(out, 0, 1) |
| out = self.jpeger(out, quality=jpeg_p) |
| else: |
| |
| jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) |
| out = torch.clamp(out, 0, 1) |
| out = self.jpeger(out, quality=jpeg_p) |
| |
| mode = random.choice(['area', 'bilinear', 'bicubic']) |
| out = F.interpolate(out, size=(ori_h // scale_final, ori_w // scale_final), mode=mode) |
| out = filter2D(out, sinc_kernel) |
|
|
| if np.random.uniform() < self.opt['gray_prob']: |
| out = rgb_to_grayscale(out, num_output_channels=1) |
|
|
| if np.random.uniform() < self.opt['color_jitter_prob']: |
| brightness = self.opt.get('brightness', (0.5, 1.5)) |
| contrast = self.opt.get('contrast', (0.5, 1.5)) |
| saturation = self.opt.get('saturation', (0, 1.5)) |
| hue = self.opt.get('hue', (-0.1, 0.1)) |
| out = self.color_jitter_pt(out, brightness, contrast, saturation, hue) |
| |
| out1 = out |
| img_lq_noresize = torch.clamp((out1 * 255.0).round(), 0, 255) / 255. |
|
|
| if resize_bak: |
| mode = random.choice(['area', 'bilinear', 'bicubic']) |
| out = F.interpolate(out, size=(ori_h, ori_w), mode=mode) |
| |
| img_lq = torch.clamp((out * 255.0).round(), 0, 255) / 255. |
|
|
| return img_gt, img_lq, img_lq_noresize |