import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF


def _add_leading_dims(t, n):
    """Return ``t`` with ``n`` singleton dimensions prepended (``unsqueeze(0)`` n times)."""
    for _ in range(n):
        t = t.unsqueeze(0)
    return t


def _drop_leading_dims(t, n):
    """Return ``t`` after applying ``squeeze(0)`` n times (undoes ``_add_leading_dims``)."""
    for _ in range(n):
        t = t.squeeze(0)
    return t


def pad_linear_extrapolation(x):
    """Pad the last two dimensions of ``x`` by one element on each side.

    Each new border row/column is linearly extrapolated from the two
    outermost rows/columns (``2 * edge - next_to_edge``). Along an axis of
    length 1 the single row/column is replicated instead.

    Args:
        x: Tensor of shape ``(..., H, W)``.

    Returns:
        Tensor of shape ``(..., H + 2, W + 2)``.

    NOTE(review): for integer dtypes ``2 * edge - next`` may overflow/wrap;
    presumably inputs are floating point -- confirm with callers.
    """
    h, w = x.shape[-2:]

    # Pad along H.
    if h > 1:
        top_pad = 2 * x[..., 0:1, :] - x[..., 1:2, :]
        bot_pad = 2 * x[..., -1:, :] - x[..., -2:-1, :]
    else:
        top_pad = x[..., 0:1, :]
        bot_pad = x[..., -1:, :]
    x = torch.cat([top_pad, x, bot_pad], dim=-2)

    # Pad along W. The H padding is already included here, so the corners
    # are extrapolated in both directions.
    if w > 1:
        left_pad = 2 * x[..., :, 0:1] - x[..., :, 1:2]
        right_pad = 2 * x[..., :, -1:] - x[..., :, -2:-1]
    else:
        left_pad = x[..., :, 0:1]
        right_pad = x[..., :, -1:]
    return torch.cat([left_pad, x, right_pad], dim=-1)


def resize_extrapolated(x, size, interpolation=TF.InterpolationMode.BILINEAR, **kwargs):
    """Resize ``x`` to an explicit ``(height, width)`` using extrapolated borders.

    The input is first padded by one linearly extrapolated pixel per side
    (see ``pad_linear_extrapolation``). It is then resized so that this
    padding grows to roughly ``scale`` pixels, and finally cropped back to
    ``size``. The output's border pixels are therefore influenced by the
    extrapolated values rather than by the resize's own edge handling.

    The crop offset is ``round(scale)``, so alignment is exact for integer
    scale factors and approximate otherwise.

    Args:
        x: Tensor of shape ``(..., H, W)``.
        size: Target ``(height, width)``. An int or a 1-element sequence is
            passed straight to ``TF.resize`` without extrapolation.
        interpolation: Interpolation mode forwarded to ``TF.resize``.
        **kwargs: Extra keyword arguments forwarded to ``TF.resize``.

    Returns:
        The resized tensor.
    """
    if not isinstance(size, (tuple, list)) or len(size) == 1:
        return TF.resize(x, size, interpolation=interpolation, **kwargs)

    target_h, target_w = size
    h, w = x.shape[-2:]
    scale_h = target_h / h
    scale_w = target_w / w

    x_padded = pad_linear_extrapolation(x)
    # Each 1-pixel pad grows to ~scale pixels after resizing.
    new_h = int(round(target_h + 2 * scale_h))
    new_w = int(round(target_w + 2 * scale_w))
    out = TF.resize(x_padded, (new_h, new_w), interpolation=interpolation, **kwargs)

    pad_h = int(round(scale_h))
    pad_w = int(round(scale_w))
    return out[..., pad_h:pad_h + target_h, pad_w:pad_w + target_w]


def laplacian_encode(x, downsample_size, sigma, interp_mode=TF.InterpolationMode.BILINEAR, extrapolate=False):
    """Split ``x`` into a blurred low-resolution image and a full-resolution residual.

    ``lowres`` is ``x`` resized to ``downsample_size`` and then Gaussian
    blurred. ``residual`` is ``x`` minus ``lowres`` upsampled back to the
    input size. Calling ``laplacian_decode(residual, lowres, interp_mode,
    extrapolate)`` with the same settings therefore reconstructs ``x`` up to
    rounding.

    Args:
        x: ``torch.Tensor`` or ``np.ndarray`` holding the image in its last
            two dimensions. Inputs with fewer than 4 dims are temporarily
            unsqueezed to 4D.
        downsample_size: Size passed to ``TF.resize``, either an int or
            ``(h, w)``.
        sigma: Gaussian blur standard deviation applied to the low-res
            image. The kernel size is ``2 * floor(sigma) + 1``. A falsy value
            (``0`` / ``None``) skips the blur.
        interp_mode: Interpolation mode used for both resizes.
        extrapolate: Upsample with ``resize_extrapolated`` instead of plain
            ``TF.resize``.

    Returns:
        ``(residual, lowres)`` with the input's number of dims. These are
        numpy arrays if ``x`` was a numpy array, otherwise tensors.
    """
    is_numpy = isinstance(x, np.ndarray)
    if is_numpy:
        x = torch.from_numpy(x)

    # Work in 4D and remember how many leading dims were added.
    n_added = max(0, 4 - x.ndim)
    x = _add_leading_dims(x, n_added)

    lowres = TF.resize(x, downsample_size, interpolation=interp_mode)
    if sigma:
        # Odd kernel of size 2 * floor(sigma) + 1 (radius floor(sigma)).
        kernel_size = int(sigma * 2) // 2 * 2 + 1
        lowres = TF.gaussian_blur(lowres, kernel_size=kernel_size, sigma=sigma)

    if extrapolate:
        lowres_up = resize_extrapolated(lowres, x.shape[-2:], interpolation=interp_mode)
    else:
        lowres_up = TF.resize(lowres, x.shape[-2:], interpolation=interp_mode)
    residual = x - lowres_up

    residual = _drop_leading_dims(residual, n_added)
    lowres = _drop_leading_dims(lowres, n_added)
    if is_numpy:
        residual = residual.numpy()
        lowres = lowres.numpy()
    return residual, lowres


def laplacian_decode(residual, lowres, interp_mode=TF.InterpolationMode.BILINEAR, extrapolate=False, pre_padded=False):
    """Reconstruct an image from a ``laplacian_encode``-style ``(residual, lowres)`` pair.

    Upsamples ``lowres`` to the residual's spatial size and adds it to
    ``residual``. ``interp_mode`` and ``extrapolate`` should match the values
    used when encoding.

    Args:
        residual: Full-resolution residual (tensor or numpy array).
        lowres: Low-resolution image of the same container type as
            ``residual``.
        interp_mode: Interpolation mode used for upsampling.
        extrapolate: Upsample with ``resize_extrapolated``.
        pre_padded: ``lowres`` already carries one extra border pixel on
            every side. Per axis, that border is upsampled to
            ``residual_size // (lowres_size - 2)`` pixels and cropped off
            afterwards.

    Returns:
        ``residual + upsampled lowres`` with the residual's number of dims.
        This is a numpy array if ``residual`` was one.

    Raises:
        ValueError: If ``pre_padded`` is set and ``lowres`` has 2 or fewer
            pixels along a spatial axis, leaving nothing once the border is
            removed.
    """
    is_numpy = isinstance(residual, np.ndarray)
    if is_numpy:
        residual = torch.from_numpy(residual)
        lowres = torch.from_numpy(lowres)

    # Work in 4D. Both inputs get the same number of added dims.
    n_added = max(0, 4 - residual.ndim)
    residual = _add_leading_dims(residual, n_added)
    lowres = _add_leading_dims(lowres, n_added)

    out_h, out_w = residual.shape[-2:]
    pad_h = pad_w = 0
    if pre_padded:
        inner_h = lowres.shape[-2] - 2
        inner_w = lowres.shape[-1] - 2
        if inner_h <= 0 or inner_w <= 0:
            raise ValueError(
                "pre_padded lowres must be larger than 2 pixels along each spatial axis"
            )
        # Upsampled border width per axis (the integer scale factor).
        pad_h = out_h // inner_h
        pad_w = out_w // inner_w
        resize_shape = (out_h + 2 * pad_h, out_w + 2 * pad_w)
    else:
        resize_shape = (out_h, out_w)

    if extrapolate:
        lowres_up = resize_extrapolated(lowres, resize_shape, interpolation=interp_mode)
    else:
        lowres_up = TF.resize(lowres, resize_shape, interpolation=interp_mode)

    if pre_padded:
        # Use explicit stop indices: a `-pad` stop would give an empty slice
        # when pad == 0.
        lowres_up = lowres_up[..., pad_h:pad_h + out_h, pad_w:pad_w + out_w]

    residual = _drop_leading_dims(residual, n_added)
    lowres_up = _drop_leading_dims(lowres_up, n_added)
    if is_numpy:
        residual = residual.numpy()
        lowres_up = lowres_up.numpy()
    return residual + lowres_up


def laplacian_denoise(residual, lowres, sigma, interp_mode=TF.InterpolationMode.BILINEAR, extrapolate=True):
    """Recompute ``lowres`` from the decoded image while keeping ``residual``.

    Decodes ``(residual, lowres)`` to a full-resolution image. It then
    re-encodes that image at ``lowres``'s spatial size with blur ``sigma``.
    The freshly computed low-res image replaces ``lowres``; ``residual`` is
    returned unchanged.

    Args:
        residual: Full-resolution residual (tensor or numpy array).
        lowres: Low-resolution image of the same container type.
        sigma: Gaussian blur sigma used when re-encoding.
        interp_mode: Interpolation mode for decoding and re-encoding.
        extrapolate: Passed to ``laplacian_decode`` (default ``True``).

    Returns:
        ``(residual, new_lowres)``.
    """
    decoded = laplacian_decode(residual, lowres, interp_mode, extrapolate=extrapolate)
    # Pass the full (h, w): an int size would make TF.resize match only the
    # smaller edge, which gives the wrong shape for non-square lowres.
    target_size = tuple(lowres.shape[-2:])
    _, new_lowres = laplacian_encode(decoded, target_size, sigma, interp_mode)
    return residual, new_lowres