| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| from torch.distributions.poisson import Poisson |
| import random |
|
|
|
|
def crop_to_bounding_box(image, offset_height, offset_width, target_height,
                         target_width, is_batch):
    """Slice a batched NHWC tensor to the given window.

    The batch dimension is dropped from the result when ``is_batch`` is False
    (the input is still expected to carry a leading batch axis).
    """
    rows = slice(offset_height, offset_height + target_height)
    cols = slice(offset_width, offset_width + target_width)
    window = image[:, rows, cols, :]
    return window if is_batch else window[0]
|
|
def crop_to_bounding_box_list(image, offset_height, offset_width, target_height,
                              target_width):
    """Crop every HWC tensor in the iterable *image* to the same bounding box.

    Returns a list of views, one per input tensor.
    """
    row_end = offset_height + target_height
    col_end = offset_width + target_width
    out = []
    for frame in image:
        out.append(frame[offset_height:row_end, offset_width:col_end, :])
    return out
|
|
def pad_to_bounding_box(image, offset_height, offset_width, target_height,
                        target_width, is_batch):
    """Zero-pad a batched NHWC tensor so the original content sits at
    (offset_height, offset_width) inside a target_height x target_width canvas.

    Drops the batch dimension from the result when ``is_batch`` is False.
    """
    _, height, width, _ = image.shape
    pad_right = target_width - offset_width - width
    pad_bottom = target_height - offset_height - height

    # F.pad consumes pairs from the LAST dimension backwards: (C, W, H, N).
    padded = F.pad(
        image,
        (0, 0, offset_width, pad_right, offset_height, pad_bottom, 0, 0),
    )
    return padded if is_batch else padded[0]
|
|
def resize_with_crop_or_pad_torch(image, target_height, target_width):
    """Center-crop and/or zero-pad a channels-last image to an exact size.

    Mirrors ``tf.image.resize_with_crop_or_pad``: each spatial dimension is
    center-cropped when larger than the target and zero-padded when smaller.

    Args:
        image: (H, W, C) or (B, H, W, C) tensor.
        target_height: output height.
        target_width: output width.

    Returns:
        Tensor with spatial size (target_height, target_width); the batch
        dimension is present iff the input had one.

    Fixes over the previous version: the dead ``equal_`` helper and the
    per-call ``max_``/``min_`` wrappers around builtins are removed, and the
    crop/pad is done inline on the batched tensor, eliminating the redundant
    unbatch/re-batch round-trip through the helper functions.
    """
    is_batch = image.ndim != 3
    if not is_batch:
        image = image[None]

    _, height, width, _ = image.shape

    width_diff = target_width - width
    offset_crop_width = max(-width_diff // 2, 0)
    offset_pad_width = max(width_diff // 2, 0)

    height_diff = target_height - height
    offset_crop_height = max(-height_diff // 2, 0)
    offset_pad_height = max(height_diff // 2, 0)

    # Center crop; a no-op along any dimension already <= target.
    cropped = image[
        :,
        offset_crop_height:offset_crop_height + min(target_height, height),
        offset_crop_width:offset_crop_width + min(target_width, width),
        :,
    ]

    # Zero-pad out to the exact target size; a no-op along any dimension
    # already == target. F.pad pairs run from the last dim back: (C, W, H, N).
    _, c_height, c_width, _ = cropped.shape
    paddings = (
        0, 0,
        offset_pad_width, target_width - offset_pad_width - c_width,
        offset_pad_height, target_height - offset_pad_height - c_height,
        0, 0,
    )
    resized = F.pad(cropped, paddings)

    if not is_batch:
        resized = resized[0]
    return resized
|
|
|
|
|
|
def psf2otf(psf, h=None, w=None, permute=False):
    """Convert a PSF to its OTF via a centered 2-D FFT.

    psf: (h, w, c) or (b, h, w, c), channels-last. When *h*/*w* are given the
    PSF is first center-cropped/padded to that size; *permute* moves channels
    in front of the spatial dims before the transform.
    """
    if h is not None:
        psf = resize_with_crop_or_pad_torch(psf, h, w)
    if permute:
        # Channels-last -> channels-first so the FFT runs over (-2, -1).
        psf = psf.permute(2, 0, 1) if psf.ndim == 3 else psf.permute(0, 3, 1, 2)
    centered = torch.fft.fftshift(psf.to(torch.complex64), dim=(-1, -2))
    return torch.fft.fft2(centered)
|
|
def fft(img):
    """Return the 2-D FFT of *img* after casting it to complex64."""
    return torch.fft.fft2(img.to(torch.complex64))
|
|
def ifft(Fimg):
    """Inverse 2-D FFT of *Fimg*, returned as the float32 magnitude."""
    spatial = torch.fft.ifft2(Fimg)
    return torch.abs(spatial).to(torch.float32)
|
|
|
|
def create_contrast_mask(image):
    """One minus the per-image mean over the two trailing (spatial) dims.

    Keeps the reduced dims, so the mask broadcasts back over the input;
    darker images produce values closer to 1.
    """
    mean = torch.mean(image, dim=(-1, -2), keepdim=True)
    return 1 - mean
|
|
def apply_tikhonov(lr_img, psf, K, norm=True, otf=None):
    """Tikhonov-regularized (Wiener-style) deconvolution in the Fourier domain.

    The regularization strength *K* is modulated per image by a contrast mask
    (darker images get a larger effective K). When *otf* is None it is built
    from *psf*, optionally normalized to unit sum over its first two dims.
    Output is clamped to [0, 1].
    """
    h, w = lr_img.shape[-2:]
    if otf is None:
        psf_resized = resize_with_crop_or_pad_torch(psf, h, w)
        if norm:
            psf_resized = psf_resized / psf_resized.sum((0, 1))
        otf = psf2otf(psf_resized, h, w, permute=True)

    # Insert a broadcast axis so the OTF and mask align with the image batch.
    otf = otf[:, None, ...]
    mask = create_contrast_mask(lr_img)[:, None, ...]
    K_eff = K * mask

    # Regularized inverse filter: conj(H) / (|H|^2 + K_eff).
    inv_filter = torch.conj(otf) / (torch.abs(otf) ** 2 + K_eff)
    restored_fft = fft(lr_img)[:, None, ...] * inv_filter
    restored = torch.fft.ifft2(restored_fft).real
    return torch.clamp(restored, min=0.0, max=1.0)
| |
|
|
def add_noise_all_new(image, poss=4e-5, gaus=1e-5):
    """Simulate sensor noise on an intensity image in [0, 1].

    Shot noise: the image is scaled to photon counts by 1/poss, Poisson
    sampled, and rescaled. Read noise: additive Gaussian with std *gaus*.
    The sum is clamped back to [0, 1].
    """
    # Same RNG call sequence as before: one Poisson draw, one randn_like.
    shot = Poisson(image / poss).sample((1,))[0] * poss
    read = torch.randn_like(image) * gaus
    return torch.clamp(shot + read, 0.0, 1.0)
|
|
|
|
def apply_convolution(image, psf, pad):
    '''
    Blur a clean image with a PSF in the Fourier domain and add sensor noise.

    input: hr img (b,c,h,w, [0,1])
    output: noised lr img (b,c,h+P,w+P, [0,1]), plus the OTF used
    '''
    # Zero-pad both spatial dims so the circular FFT convolution does not
    # wrap image content around the borders.
    padded = F.pad(image, (pad, pad, pad, pad))
    h, w = padded.shape[-2:]

    # Build the OTF at the padded resolution and apply it spectrally.
    psf_resized = resize_with_crop_or_pad_torch(psf, h, w)
    otf = psf2otf(psf_resized, h, w, permute=True)
    blurred = ifft(fft(padded) * otf)
    lr_img = torch.clamp(blurred, min=1e-20, max=1.0)

    # Poisson + Gaussian sensor-noise model.
    noised_img = add_noise_all_new(lr_img)
    return noised_img, otf
|
|
def apply_conv_n_deconv(image, otf, padding, M, psize, ks=None, ph=135, num_psf=9, sensor_h=1215, crop=True, conv=True):
    '''
    Simulate lens-array capture (optional), optionally deconvolve, and cut
    matched low-quality / ground-truth training patches.

    input: hr img (b,c,h,w)
    otf: 1,N,C,H,W
    output: noised lr img (N,c,h,w)

    NOTE(review): assumes N == num_psf**2 patches of size ph*3 with stride ph,
    3 color channels, and sensor_h == num_psf * ph — confirm against callers.
    '''

    b,_,_,_ = image.shape
    if conv:
        # Unfold into num_psf**2 overlapping (ph*3 x ph*3) patches, one per
        # PSF position: (b, N, 3, ph*3, ph*3).
        img_patch = F.unfold(image, kernel_size=ph*3, stride=ph).view(b,3,ph*3,ph*3,num_psf**2).permute(0,4,1,2,3).contiguous()

        # Per-patch blur in the Fourier domain, clamped to a valid intensity
        # range (1e-20 floor keeps the Poisson rate strictly positive later).
        lr_img = fft(img_patch) * otf
        lr_img = torch.clamp(ifft(lr_img), min=1e-20, max=1.0)

        # Poisson + Gaussian sensor-noise model.
        lr_img = add_noise_all_new(lr_img)

    else:
        # Input is already a simulated capture; treated as a single sample.
        b = 1
        lr_img = image

    if ks is not None:
        # Tikhonov deconvolution (M regularization settings), drop the ph
        # border of each patch, then stitch the num_psf x num_psf grid back
        # into a (b, M, 3, sensor_h, sensor_h) image.
        lr_img = apply_tikhonov(lr_img, None, ks, otf=otf)
        lr_img = lr_img[..., ph:-ph, ph:-ph]
        lr_img = lr_img.view(b, M, num_psf, num_psf, 3, ph, ph).permute(0,1,4,2,5,3,6).reshape(b,M,3,sensor_h,sensor_h)
    else:
        # No deconvolution: stitch directly to (b, 3, sensor_h, sensor_h).
        lr_img = lr_img[..., ph:-ph, ph:-ph]
        lr_img = lr_img.view(b, num_psf, num_psf, 3, ph, ph).permute(0,3,1,4,2,5).reshape(b,3,sensor_h,sensor_h)

    lq_patches = []
    gt_patches = []
    for i in range(b):
        cur = lr_img[i]
        cur_gt = image[i]

        # Strip the letterbox padding recorded per sample; the GT is trimmed
        # by an extra ph so it stays spatially aligned with the stitched LQ.
        # NOTE(review): branches assume padding is either vertical-only or
        # horizontal-only, never both — confirm with the data loader.
        pt,pb,pl,pr = padding[i]
        if pb and pt:
            cur = cur[...,pt: -pb, :]
            cur_gt = cur_gt[...,pt+ph: -(pb+ph), ph:-ph]
        elif pl and pr:
            cur = cur[...,pl:-pr]
            cur_gt = cur_gt[...,ph:-ph, pl+ph: -(pr+ph)]
        else:
            cur_gt = cur_gt[...,ph:-ph, ph: -ph]
        h,w = cur.shape[-2:]

        # Random, spatially aligned psize x psize crop for training.
        if crop:
            top = random.randint(0, h - psize)
            left = random.randint(0, w - psize)
            lq_patches.append(cur[..., top:top + psize, left:left + psize])
            gt_patches.append(cur_gt[..., top:top + psize, left:left + psize])
    if crop:
        lq_patches = torch.stack(lq_patches)
        gt_patches = torch.stack(gt_patches)
    else:
        # NOTE(review): with crop=False only the LAST iteration's pair is
        # returned — fine when b == 1 (the conv=False path), lossy otherwise.
        return cur, cur_gt

    return lq_patches, gt_patches
|
|
|
|
def apply_convolution_square_val(image, otf, padding, M, psize, ks=None, ph=135, num_psf=9, sensor_h=1215, crop=False):
    '''
    merge to above one.
    image = lr_image

    Validation-time variant: stitch an already-captured low-res image (and
    optionally deconvolve it) back to sensor size, then strip its recorded
    letterbox padding. The batch size is fixed at 1.
    '''
    b = 1
    lr_img = image
    if M:
        # Deconvolve with M regularization settings, trim the ph border of
        # each patch, and stitch the num_psf x num_psf grid to sensor size.
        lr_img = apply_tikhonov(lr_img, None, ks, otf=otf)
        lr_img = lr_img[..., ph:-ph, ph:-ph]
        lr_img = (lr_img.view(b, M, num_psf, num_psf, 3, ph, ph)
                        .permute(0, 1, 4, 2, 5, 3, 6)
                        .reshape(b, M, 3, sensor_h, sensor_h))
    else:
        lr_img = lr_img[..., ph:-ph, ph:-ph]
        lr_img = (lr_img.view(b, num_psf, num_psf, 3, ph, ph)
                        .permute(0, 3, 1, 4, 2, 5)
                        .reshape(b, 3, sensor_h, sensor_h))

    # b is fixed at 1, so only the first (and only) sample is processed.
    cur = lr_img[0]
    pt, pb, pl, pr = padding[0]
    if pb and pt:
        cur = cur[..., pt:-pb, :]
    elif pl and pr:
        cur = cur[..., pl:-pr]
    return cur