|
|
'''This module handles task-dependent operations (A) and noises (n) to simulate a measurement y=Ax+n.''' |
|
|
|
|
|
from abc import ABC, abstractmethod |
|
|
from functools import partial |
|
|
import yaml |
|
|
from torch.nn import functional as F |
|
|
from torchvision import torch |
|
|
|
|
|
from util.resizer import Resizer |
|
|
from util.img_utils import Blurkernel, fft2_m |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__OPERATOR__ = {}


def register_operator(name: str):
    '''Build a class decorator that records the class in the operator registry.

    Args:
        name: key under which the decorated class is stored.

    Returns:
        A decorator that stores the class and returns it unchanged.

    Raises:
        NameError: (from the decorator) if ``name`` already maps to a truthy
            entry in the registry.
    '''
    def decorator(cls):
        existing = __OPERATOR__.get(name)
        if existing:
            raise NameError(f"Name {name} is already registered!")
        __OPERATOR__[name] = cls
        return cls

    return decorator
|
|
|
|
|
|
|
|
def get_operator(name: str, **kwargs):
    '''Instantiate the operator class registered under ``name``.

    Args:
        name: registry key given to ``register_operator``.
        **kwargs: forwarded verbatim to the operator's constructor.

    Raises:
        NameError: if nothing (or a falsy-None entry) is registered for ``name``.
    '''
    operator_cls = __OPERATOR__.get(name)
    if operator_cls is None:
        raise NameError(f"Name {name} is not defined.")
    return operator_cls(**kwargs)
|
|
|
|
|
|
|
|
class LinearOperator(ABC):
    '''Abstract linear measurement operator A for the model y = A x + n.

    Subclasses implement ``forward`` (x -> A x) and ``transpose``
    (y -> A^T y, or whatever approximation of it the subclass chooses).
    '''

    @abstractmethod
    def forward(self, data, **kwargs):
        '''Apply the operator: return A(data).'''
        pass

    @abstractmethod
    def transpose(self, data, **kwargs):
        '''Apply the (possibly approximate) adjoint: return A^T(data).'''
        pass

    def ortho_project(self, data, **kwargs):
        '''Return (I - A^T A) data.'''
        return data - self.transpose(self.forward(data, **kwargs), **kwargs)

    def project(self, data, measurement, **kwargs):
        '''Return (I - A^T A) data + A^T measurement.

        Keeps the part of ``data`` that A does not observe and replaces the
        observed part with the back-projected ``measurement``. The result lives
        in the same space as ``data``.

        Args:
            data: current estimate x.
            measurement: observation y.
            **kwargs: forwarded to ``forward``/``transpose``.
        '''
        # The measurement is back-projected with A^T before it is combined
        # with data; applying (I - A^T A) to the measurement itself would mix
        # measurement space with data space (and fails for non-square A).
        return self.ortho_project(data, **kwargs) + self.transpose(measurement, **kwargs)
|
|
|
|
|
|
|
|
@register_operator(name='noise')
class DenoiseOperator(LinearOperator):
    '''Identity operator (A = I) used for pure denoising.

    Every method returns ``data`` unchanged. The methods accept the same
    ``**kwargs`` (and ``measurement``) as the ``LinearOperator`` interface, so
    callers written against the base class can use this operator without
    hitting a TypeError.
    '''

    def __init__(self, device):
        self.device = device

    def forward(self, data, **kwargs):
        '''Return ``data`` unchanged (A x = x).'''
        return data

    def transpose(self, data, **kwargs):
        '''Return ``data`` unchanged (A^T = I).'''
        return data

    def ortho_project(self, data, **kwargs):
        '''Return ``data`` unchanged.'''
        # NOTE(review): for A = I the exact projector (I - A^T A) is zero;
        # returning data unchanged preserves the existing behaviour.
        return data

    def project(self, data, measurement=None, **kwargs):
        '''Return ``data`` unchanged; ``measurement`` is accepted but ignored.'''
        return data
|
|
|
|
|
|
|
|
@register_operator(name='super_resolution')
class SuperResolutionOperator(LinearOperator):
    '''Downsampling operator A for super-resolution.

    ``forward`` applies the project's ``Resizer`` with factor
    ``1/scale_factor``. ``transpose`` stands in for A^T with
    ``F.interpolate`` upsampling by ``scale_factor`` (default mode).
    '''

    def __init__(self, in_shape, scale_factor, device):
        self.device = device
        self.up_sample = partial(F.interpolate, scale_factor=scale_factor)
        self.down_sample = Resizer(in_shape, 1 / scale_factor).to(device)

    def forward(self, data, **kwargs):
        '''Downsample ``data``.'''
        return self.down_sample(data)

    def transpose(self, data, **kwargs):
        '''Upsample ``data`` back to the input resolution.'''
        return self.up_sample(data)

    def project(self, data, measurement, **kwargs):
        '''Return data - A^T A data + A^T measurement.'''
        round_trip = self.transpose(self.forward(data))
        return data - round_trip + self.transpose(measurement)
|
|
|
|
|
|
|
|
|
|
|
@register_operator(name='motion_blur')
class MotionBlurOperator(LinearOperator):
    '''Motion-blur operator: convolution with a motion kernel.

    The kernel comes from ``Kernel(size, intensity).kernelMatrix``. It is
    loaded into a ``Blurkernel(blur_type='motion')`` convolution.
    '''

    def __init__(self, kernel_size, intensity, device):
        self.device = device
        self.kernel_size = kernel_size
        self.conv = Blurkernel(blur_type='motion',
                               kernel_size=kernel_size,
                               std=intensity,
                               device=device).to(device)

        # NOTE(review): `Kernel` is not imported in this module's visible
        # import block; confirm it is brought into scope, otherwise
        # construction raises NameError.
        self.kernel = Kernel(size=(kernel_size, kernel_size), intensity=intensity)
        kernel = torch.tensor(self.kernel.kernelMatrix, dtype=torch.float32)
        self.conv.update_weights(kernel)

    def forward(self, data, **kwargs):
        '''Blur ``data`` with the motion kernel.'''
        return self.conv(data)

    def transpose(self, data, **kwargs):
        '''Return ``data`` unchanged (A^T is approximated by the identity).'''
        return data

    def get_kernel(self):
        '''Return the kernel as a float32 tensor of shape (1, 1, k, k) on ``self.device``.'''
        # kernelMatrix is wrapped with torch.tensor in __init__, so it need not
        # be a tensor (it has no `.type` method if it is e.g. a NumPy array);
        # torch.as_tensor handles array-likes and tensors alike.
        kernel = torch.as_tensor(self.kernel.kernelMatrix,
                                 dtype=torch.float32,
                                 device=self.device)
        return kernel.reshape(1, 1, self.kernel_size, self.kernel_size)
|
|
|
|
|
|
|
|
@register_operator(name='colorization')
class ColorizationOperator(LinearOperator):
    '''Grayscale operator: sum over dim 1 scaled by 1/3.

    This is the per-pixel mean when dim 1 has three entries (presumably RGB
    channels; confirm with callers). ``transpose`` is the identity.
    '''

    def __init__(self, device):
        self.device = device

    def forward(self, data, **kwargs):
        '''Collapse dim 1 to size 1 by summing and scaling by 1/3.'''
        channel_sum = data.sum(dim=1, keepdim=True)
        return channel_sum * (1 / 3)

    def transpose(self, data, **kwargs):
        '''Return ``data`` unchanged.'''
        return data
|
|
|
|
|
|
|
|
|
|
|
@register_operator(name='gaussian_blur')
class GaussialBlurOperator(LinearOperator):
    '''Gaussian-blur operator built on ``Blurkernel(blur_type='gaussian')``.

    The class name's spelling ("Gaussial") is its public identifier and is
    kept as-is. ``transpose`` is the identity. ``project`` computes
    data - A data + measurement.
    '''

    def __init__(self, kernel_size, intensity, device):
        self.device = device
        self.kernel_size = kernel_size
        blur = Blurkernel(blur_type='gaussian',
                          kernel_size=kernel_size,
                          std=intensity,
                          device=device)
        self.conv = blur.to(device)
        # Read back the kernel Blurkernel built and load it as float32 weights.
        self.kernel = self.conv.get_kernel()
        self.conv.update_weights(self.kernel.type(torch.float32))

    def forward(self, data, **kwargs):
        '''Blur ``data``.'''
        return self.conv(data)

    def transpose(self, data, **kwargs):
        '''Return ``data`` unchanged.'''
        return data

    def get_kernel(self):
        '''Return the kernel viewed as (1, 1, kernel_size, kernel_size).'''
        side = self.kernel_size
        return self.kernel.view(1, 1, side, side)

    def project(self, data, measurement, **kwargs):
        '''Return data - A(data) + measurement.'''
        blurred = self.forward(data, **kwargs)
        return data - blurred + measurement
|
|
|
|
|
@register_operator(name='inpainting')
class InpaintingOperator(LinearOperator):
    '''Masking operator for inpainting: A x = x * mask.

    The mask must be supplied with ``set_mask`` before ``forward`` is used.
    '''

    def __init__(self, device):
        self.device = device

    def set_mask(self, mask):
        '''Store the mask that ``forward`` multiplies into the data.'''
        self.mask = mask

    def forward(self, data, **kwargs):
        '''Return ``data * mask``, with the mask moved to ``self.device``.

        Raises:
            ValueError: if no mask has been set (or it was set to None).
        '''
        # Detect the missing mask explicitly. A bare ``except`` around the
        # multiplication also relabelled shape/device errors (and even
        # KeyboardInterrupt) as "Require mask", hiding the real failure.
        mask = getattr(self, 'mask', None)
        if mask is None:
            raise ValueError("Require mask")
        return data * mask.to(self.device)

    def transpose(self, data, **kwargs):
        '''Return ``data`` unchanged (identity, not the mask itself).'''
        return data

    def ortho_project(self, data, **kwargs):
        '''Return data - A(data), i.e. the unobserved (masked-out) part.'''
        return data - self.forward(data, **kwargs)

    def project(self, data, measurement, **kwargs):
        '''Return data - A(data) + measurement.'''
        return data - self.forward(data, **kwargs) + measurement
|
|
|
|
|
|
|
|
class NonLinearOperator(ABC):
    '''Abstract nonlinear measurement operator y = A(x) + n.

    Subclasses implement ``forward``. ``project`` applies a residual
    correction that adds the gap between the measurement and A(data).
    '''

    @abstractmethod
    def forward(self, data, **kwargs):
        '''Return A(data).'''

    def project(self, data, measurement, **kwargs):
        '''Return data + measurement - A(data).'''
        predicted = self.forward(data)
        return data + measurement - predicted
|
|
|
|
|
@register_operator(name='phase_retrieval')
class PhaseRetrievalOperator(NonLinearOperator):
    '''Fourier phase retrieval: y = |fft2_m(pad(x))|.'''

    def __init__(self, oversample, device):
        # Zero-padding per side: (oversample / 8) of a hard-coded 256
        # (presumably the image side length -- confirm).
        self.pad = int(256 * (oversample / 8.0))
        self.device = device

    def forward(self, data, **kwargs):
        '''Zero-pad the last two dims by ``self.pad`` and return the FFT magnitude.'''
        margins = (self.pad,) * 4  # (left, right, top, bottom)
        return fft2_m(F.pad(data, margins)).abs()
|
|
|
|
|
@register_operator(name='nonlinear_blur')
class NonlinearBlurOperator(NonLinearOperator):
    '''Nonlinear blur through a pretrained bkse ``KernelWizard`` model.'''

    def __init__(self, opt_yml_path, device):
        self.device = device
        self.blur_model = self.prepare_nonlinear_blur_model(opt_yml_path)

    def prepare_nonlinear_blur_model(self, opt_yml_path):
        '''
        Nonlinear deblur requires external codes (bkse).

        Reads the ``KernelWizard`` section of the YAML file at ``opt_yml_path``,
        builds the model, and loads the weights from its ``pretrained`` path.
        Returns the model in eval mode on ``self.device``.
        '''
        from bkse.models.kernel_encoding.kernel_wizard import KernelWizard

        with open(opt_yml_path, "r") as f:
            opt = yaml.safe_load(f)["KernelWizard"]
        model_path = opt["pretrained"]

        blur_model = KernelWizard(opt)
        blur_model.eval()
        # map_location lets a checkpoint saved on another device (e.g. CUDA)
        # load when self.device differs. NOTE(review): torch.load unpickles
        # the file; "pretrained" must point at a trusted checkpoint.
        state_dict = torch.load(model_path, map_location=self.device)
        blur_model.load_state_dict(state_dict)
        return blur_model.to(self.device)

    def forward(self, data, **kwargs):
        '''Blur ``data`` with a freshly sampled random kernel latent.

        A new latent of shape (1, 512, 2, 2) is drawn on every call, so this
        is stochastic. ``data`` is mapped from [-1, 1] to [0, 1] for the
        model, and the output is mapped back and clamped to [-1, 1].
        '''
        random_kernel = torch.randn(1, 512, 2, 2).to(self.device) * 1.2
        data = (data + 1.0) / 2.0
        blurred = self.blur_model.adaptKernel(data, kernel=random_kernel)
        blurred = (blurred * 2.0 - 1.0).clamp(-1, 1)
        return blurred
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__NOISE__ = {}


def register_noise(name: str):
    '''Build a class decorator that records a noise class under ``name``.

    Raises:
        NameError: (from the decorator) if ``name`` already maps to a truthy
            entry in the registry.
    '''
    def decorator(cls):
        if __NOISE__.get(name):
            raise NameError(f"Name {name} is already defined!")
        __NOISE__[name] = cls
        return cls

    return decorator
|
|
|
|
|
def get_noise(name: str, **kwargs):
    '''Instantiate the noise model registered under ``name``.

    The instance also gets ``__name__`` set to ``name``.

    Raises:
        NameError: if no noise class is registered for ``name``.
    '''
    noise_cls = __NOISE__.get(name)
    if noise_cls is None:
        raise NameError(f"Name {name} is not defined.")
    noiser = noise_cls(**kwargs)
    noiser.__name__ = name
    return noiser
|
|
|
|
|
class Noise(ABC):
    '''Abstract noise model. Calling an instance is shorthand for ``forward``.'''

    @abstractmethod
    def forward(self, data):
        '''Return ``data`` with noise applied.'''

    def __call__(self, data):
        return self.forward(data)
|
|
|
|
|
@register_noise(name='clean')
class Clean(Noise):
    '''Noise-free model, used when the measurement should be exact.'''

    def forward(self, data):
        '''Return ``data`` untouched.'''
        return data
|
|
|
|
|
@register_noise(name='gaussian')
class GaussianNoise(Noise):
    '''Additive Gaussian noise.

    A standard-normal sample shaped like ``data`` is scaled by
    ``sigma * 2`` and added. NOTE(review): the factor 2 presumably expresses
    ``sigma`` relative to a [0, 1] intensity range while data lives in
    [-1, 1] -- confirm this is intended.
    '''

    def __init__(self, sigma):
        self.sigma = sigma

    def forward(self, data):
        noise = torch.randn_like(data, device=data.device)
        return data + noise * self.sigma * 2
|
|
|
|
|
|
|
|
@register_noise(name='poisson')
class PoissonNoise(Noise):
    '''Poisson (shot) noise whose strength is controlled by ``rate``.'''

    def __init__(self, rate):
        self.rate = rate

    def forward(self, data):
        '''
        Follow skimage.util.random_noise.

        Maps ``data`` from [-1, 1] to [0, 1] (clamped) and draws Poisson counts
        with mean ``data * 255 * rate``. The counts are rescaled by
        ``1 / (255 * rate)`` and mapped back to [-1, 1] (clamped). The result
        keeps the input's device and dtype.
        '''
        import numpy as np

        # np.random.poisson returns float64; remember the caller's dtype and
        # device so the output does not silently change precision.
        device, dtype = data.device, data.dtype

        data = ((data + 1.0) / 2.0).clamp(0, 1)
        expected = data.detach().cpu().numpy() * 255.0 * self.rate
        noisy = torch.from_numpy(np.random.poisson(expected) / 255.0 / self.rate)
        noisy = (noisy * 2.0 - 1.0).clamp(-1, 1)
        return noisy.to(device=device, dtype=dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|