# NOTE: removed web-scrape artifact lines ("blee's picture", "Upload 53 files",
# "6670ec8 verified") that were not valid Python and broke the module.
import torch
import torch.nn.functional as F
import numpy as np
from torch.distributions.poisson import Poisson
import random
def crop_to_bounding_box(image, offset_height, offset_width, target_height,
                         target_width, is_batch):
    """Crop a BHWC tensor to the given box.

    The crop starts at (offset_height, offset_width) and spans
    target_height x target_width. When ``is_batch`` is False the leading
    batch axis is squeezed away, yielding an HWC tensor.
    """
    rows = slice(offset_height, offset_height + target_height)
    cols = slice(offset_width, offset_width + target_width)
    boxed = image[:, rows, cols, :]
    return boxed if is_batch else boxed[0]
def crop_to_bounding_box_list(image, offset_height, offset_width, target_height,
                              target_width):
    """Crop every HWC tensor in the iterable ``image`` to the same box.

    Returns a list of crops, one per input tensor.
    """
    crops = []
    for single in image:
        crops.append(single[offset_height:offset_height + target_height,
                            offset_width:offset_width + target_width, :])
    return crops
def pad_to_bounding_box(image, offset_height, offset_width, target_height,
                        target_width, is_batch):
    """Zero-pad a BHWC tensor into a target_height x target_width canvas.

    The original image is placed with its top-left corner at
    (offset_height, offset_width). When ``is_batch`` is False the batch
    axis is squeezed from the result.
    """
    _, height, width, _ = image.shape
    pad_right = target_width - offset_width - width
    pad_bottom = target_height - offset_height - height
    # F.pad consumes pad pairs from the LAST dim backwards: (C, W, H, B).
    padded = F.pad(image, (0, 0,
                           offset_width, pad_right,
                           offset_height, pad_bottom,
                           0, 0))
    return padded if is_batch else padded[0]
def resize_with_crop_or_pad_torch(image, target_height, target_width):
    """Center-crop and/or zero-pad an image to the target spatial size.

    Torch port of ``tf.image.resize_with_crop_or_pad``: dimensions larger
    than the target are center-cropped, smaller ones are zero-padded.

    Args:
        image: BHWC tensor, or HWC (a batch axis is added and removed).
        target_height, target_width: desired spatial size.

    Returns:
        Tensor of spatial size (target_height, target_width), BHWC when the
        input was batched, HWC otherwise.
    """
    # BHWC -> BHWC
    is_batch = True
    if image.ndim == 3:
        is_batch = False
        image = image[None]  # 1HWC
    _, height, width, _ = image.shape
    # Removed the trivial max_/min_ wrappers and the unused equal_ helper;
    # the builtins express the same thing directly.
    width_diff = target_width - width
    offset_crop_width = max(-width_diff // 2, 0)
    offset_pad_width = max(width_diff // 2, 0)
    height_diff = target_height - height
    offset_crop_height = max(-height_diff // 2, 0)
    offset_pad_height = max(height_diff // 2, 0)
    # Maybe crop if needed.
    cropped = crop_to_bounding_box(image, offset_crop_height, offset_crop_width,
                                   min(target_height, height),
                                   min(target_width, width), is_batch)
    # Maybe pad if needed (the crop squeezed the batch axis for HWC input,
    # so restore it before padding).
    if not is_batch and cropped.ndim == 3:
        cropped = cropped[None]
    resized = pad_to_bounding_box(cropped, offset_pad_height, offset_pad_width,
                                  target_height, target_width, is_batch)
    return resized
def psf2otf(psf, h=None, w=None, permute=False):
    '''
    psf = (b) h,w,c

    Convert a point-spread function into its optical transfer function:
    optionally center-crop/pad to (h, w), optionally move channels first,
    then fftshift and FFT over the last two dims.
    '''
    if h is not None:
        psf = resize_with_crop_or_pad_torch(psf, h, w)
    if permute:
        # HWC -> CHW (or BHWC -> BCHW for batched input)
        psf = psf.permute(2, 0, 1) if psf.ndim == 3 else psf.permute(0, 3, 1, 2)
    shifted = torch.fft.fftshift(psf.to(torch.complex64), dim=(-1, -2))
    return torch.fft.fft2(shifted)
def fft(img): # CHW
    """2-D FFT over the last two axes, after casting to complex64."""
    return torch.fft.fft2(img.to(torch.complex64))
def ifft(Fimg):
    """Inverse 2-D FFT over the last two axes; returns the magnitude as float32."""
    return torch.fft.ifft2(Fimg).abs().to(torch.float32)
def create_contrast_mask(image):
    """Per-channel weight 1 - mean(image) over the spatial dims -> (B),C,1,1.

    Darker channels (lower spatial mean) get weights closer to 1.
    """
    channel_mean = torch.mean(image, dim=(-1, -2), keepdim=True)
    return 1 - channel_mean
def apply_tikhonov(lr_img, psf, K, norm=True, otf=None):
    """Tikhonov-regularized (Wiener-style) deconvolution in the Fourier domain.

    When ``otf`` is None it is derived from ``psf`` (resized to the image's
    spatial size and optionally normalized). The regularization strength K is
    modulated per image/channel by the contrast mask, then the filter
    conj(OTF) / (|OTF|^2 + K') is applied to the input spectrum.

    Returns the real part of the inverse FFT, clamped to [0, 1] -> B,M,C,H,W.
    """
    h, w = lr_img.shape[-2:]
    if otf is None:
        kernel = resize_with_crop_or_pad_torch(psf, h, w)
        if norm:
            kernel = kernel / kernel.sum((0, 1))
        otf = psf2otf(kernel, h, w, permute=True)
    otf = otf[:, None, ...]  # insert filter-bank axis -> B,1,C,H,W
    weight = create_contrast_mask(lr_img)[:, None, ...]  # B,1,C,1,1
    reg = K * weight  # contrast-weighted regularization, B,M,C,1,1
    wiener = torch.conj(otf) / (torch.abs(otf) ** 2 + reg)  # B,M,C,H,W
    spectrum = fft(lr_img)[:, None, ...]  # B,1,C,H,W
    restored = torch.fft.ifft2(spectrum * wiener).real
    return torch.clamp(restored, min=0.0, max=1.0)  # B,M,C,H,W
def add_noise_all_new(image, poss=4e-5, gaus=1e-5):
    """Simulate sensor noise on an image in [0, 1].

    Poisson shot noise at photon scale ``poss`` plus zero-mean Gaussian read
    noise with std ``gaus``; the result is clamped back to [0, 1].
    """
    shot = Poisson(image / poss).sample((1,))[0] * poss
    read = torch.randn_like(image) * gaus
    return torch.clamp(shot + read, 0.0, 1.0)
def apply_convolution(image, psf, pad):
    '''
    input: hr img (b,c,h,w, [0,1])
    output: noised lr img (b,c,h+P,w+P, [0,1])

    Metalens simulation: pad the image, blur it with the PSF via an FFT
    product, clamp, and add sensor noise. Also returns the OTF so callers
    can reuse it for deconvolution.
    '''
    # metalens simulation
    padded = F.pad(image, (pad,) * 4)
    h, w = padded.shape[-2:]
    kernel = resize_with_crop_or_pad_torch(psf, h, w)
    otf = psf2otf(kernel, h, w, permute=True)
    blurred = ifft(fft(padded) * otf)
    blurred = torch.clamp(blurred, min=1e-20, max=1.0)
    # noise addition
    return add_noise_all_new(blurred), otf
def apply_conv_n_deconv(image, otf, padding, M, psize, ks=None, ph=135, num_psf=9, sensor_h=1215, crop=True, conv=True):
    '''
    Simulate spatially-varying metalens blur patch-wise, optionally
    deconvolve, and cut lq/gt patches.

    input: hr img (b,c,h,w)
    otf: 1,N,C,H,W  (N = num_psf**2 per-position OTFs)
    output: noised lr img (N,c,h,w)

    ks: Tikhonov regularization passed to apply_tikhonov; None skips deconvolution.
    ph: PSF tile size; overlapping 3*ph patches are taken at stride ph.
    crop: True -> random psize crops for training; False -> full images (validation).
    conv: True -> convolve + add noise; False -> `image` is an already-convolved LR image.
    '''
    b,_,_,_ = image.shape
    if conv:
        # Extract overlapping 3ph x 3ph patches, one per PSF position -> B,N,C,H,W
        img_patch = F.unfold(image, kernel_size=ph*3, stride=ph).view(b,3,ph*3,ph*3,num_psf**2).permute(0,4,1,2,3).contiguous() # B,N,C,H,W
        # metalens simulation: per-patch convolution as an FFT product
        lr_img = fft(img_patch) * otf
        lr_img = torch.clamp(ifft(lr_img), min=1e-20, max=1.0)
        # noise addition (Poisson + Gaussian)
        lr_img = add_noise_all_new(lr_img)
    else: # load convolved image for validation
        b = 1
        lr_img = image
    # apply deconvolution
    if ks is not None:
        lr_img = apply_tikhonov(lr_img, None, ks, otf=otf) # B,M,N,C,405,405
        # drop the ph overlap border of each patch, then tile the num_psf x num_psf
        # grid back into a full sensor image (sensor_h = num_psf * ph)
        lr_img = lr_img[..., ph:-ph, ph:-ph] # BMNCHW
        lr_img = lr_img.view(b, M, num_psf, num_psf, 3, ph, ph).permute(0,1,4,2,5,3,6).reshape(b,M,3,sensor_h,sensor_h)
    else:
        lr_img = lr_img[..., ph:-ph, ph:-ph] # BNCHW
        lr_img = lr_img.view(b, num_psf, num_psf, 3, ph, ph).permute(0,3,1,4,2,5).reshape(b,3,sensor_h,sensor_h)
    lq_patches = []
    gt_patches = []
    for i in range(b):
        cur = lr_img[i] # (M),C,H,W
        cur_gt = image[i]
        # remove padding for lq and gt; the gt additionally loses the ph
        # border consumed by the patch extraction above
        pt,pb,pl,pr = padding[i]
        if pb and pt:
            cur = cur[...,pt: -pb, :]
            cur_gt = cur_gt[...,pt+ph: -(pb+ph), ph:-ph]
        elif pl and pr:
            cur = cur[...,pl:-pr]
            cur_gt = cur_gt[...,ph:-ph, pl+ph: -(pr+ph)]
        else:
            cur_gt = cur_gt[...,ph:-ph, ph: -ph]
        h,w = cur.shape[-2:]
        # randomly crop patch for training
        if crop: # train
            top = random.randint(0, h - psize)
            left = random.randint(0, w - psize)
            lq_patches.append(cur[..., top:top + psize, left:left + psize])
            gt_patches.append(cur_gt[..., top:top + psize, left:left + psize])
    if crop: # training
        lq_patches = torch.stack(lq_patches)
        gt_patches = torch.stack(gt_patches)
    else: # validation
        # NOTE(review): returns only the last loop iteration's tensors; this is
        # safe while b == 1 in the validation path — confirm if that changes.
        return cur, cur_gt
    return lq_patches, gt_patches # B,(M),C,H,W
def apply_convolution_square_val(image, otf, padding, M, psize, ks=None, ph=135, num_psf=9, sensor_h=1215, crop=False):
    '''
    merge to above one.
    image = lr_image

    Validation-time variant: optionally deconvolve an already-convolved LR
    image, drop each patch's ph border, tile the num_psf x num_psf grid back
    to the full sensor, and strip the recorded padding.
    '''
    restored = image
    batch = 1
    if M: # apply deconvolution
        restored = apply_tikhonov(restored, None, ks, otf=otf) # B,M,N,C,H,W
        restored = restored[..., ph:-ph, ph:-ph] # B,M,N,C,H,W
        restored = (restored.view(batch, M, num_psf, num_psf, 3, ph, ph)
                    .permute(0, 1, 4, 2, 5, 3, 6)
                    .reshape(batch, M, 3, sensor_h, sensor_h))
    else:
        restored = restored[..., ph:-ph, ph:-ph] # B,N,C,H,W
        restored = (restored.view(batch, num_psf, num_psf, 3, ph, ph)
                    .permute(0, 3, 1, 4, 2, 5)
                    .reshape(batch, 3, sensor_h, sensor_h))
    for idx in range(batch):
        sample = restored[idx] # (M),C,H,W
        # remove padding for lq
        top, bottom, left, right = padding[idx]
        if bottom and top:
            sample = sample[..., top:-bottom, :]
        elif left and right:
            sample = sample[..., left:-right]
    return sample