| | import cv2 |
| | import numpy as np |
| | import torch |
| | from torch.nn import functional as F |
| |
|
| |
|
| | def filter2D(img, kernel): |
| | """PyTorch version of cv2.filter2D |
| | |
| | Args: |
| | img (Tensor): (b, c, h, w) |
| | kernel (Tensor): (b, k, k) |
| | """ |
| | k = kernel.size(-1) |
| | b, c, h, w = img.size() |
| | if k % 2 == 1: |
| | img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect') |
| | else: |
| | raise ValueError('Wrong kernel size') |
| |
|
| | ph, pw = img.size()[-2:] |
| |
|
| | if kernel.size(0) == 1: |
| | |
| | img = img.view(b * c, 1, ph, pw) |
| | kernel = kernel.view(1, 1, k, k) |
| | return F.conv2d(img, kernel, padding=0).view(b, c, h, w) |
| | else: |
| | img = img.view(1, b * c, ph, pw) |
| | kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k) |
| | return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w) |
| |
|
| |
|
| | def usm_sharp(img, weight=0.5, radius=50, threshold=10): |
| | """USM sharpening. |
| | |
| | Input image: I; Blurry image: B. |
| | 1. sharp = I + weight * (I - B) |
| | 2. Mask = 1 if abs(I - B) > threshold, else: 0 |
| | 3. Blur mask: |
| | 4. Out = Mask * sharp + (1 - Mask) * I |
| | |
| | |
| | Args: |
| | img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. |
| | weight (float): Sharp weight. Default: 1. |
| | radius (float): Kernel size of Gaussian blur. Default: 50. |
| | threshold (int): |
| | """ |
| | if radius % 2 == 0: |
| | radius += 1 |
| | blur = cv2.GaussianBlur(img, (radius, radius), 0) |
| | residual = img - blur |
| | mask = np.abs(residual) * 255 > threshold |
| | mask = mask.astype('float32') |
| | soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) |
| |
|
| | sharp = img + weight * residual |
| | sharp = np.clip(sharp, 0, 1) |
| | return soft_mask * sharp + (1 - soft_mask) * img |
| |
|
| |
|
| | class USMSharp(torch.nn.Module): |
| |
|
| | def __init__(self, radius=50, sigma=0): |
| | super(USMSharp, self).__init__() |
| | if radius % 2 == 0: |
| | radius += 1 |
| | self.radius = radius |
| | kernel = cv2.getGaussianKernel(radius, sigma) |
| | kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0) |
| | self.register_buffer('kernel', kernel) |
| |
|
| | def forward(self, img, weight=0.5, threshold=10): |
| | blur = filter2D(img, self.kernel) |
| | residual = img - blur |
| |
|
| | mask = torch.abs(residual) * 255 > threshold |
| | mask = mask.float() |
| | soft_mask = filter2D(mask, self.kernel) |
| | sharp = img + weight * residual |
| | sharp = torch.clip(sharp, 0, 1) |
| | return soft_mask * sharp + (1 - soft_mask) * img |
| |
|