|
|
import torch
|
|
|
import torch_dct as dct
|
|
|
|
|
|
from torchvision.transforms import functional as F
|
|
|
|
|
|
def flip_odd_lines(matrix):
    """Return a copy of ``matrix`` whose odd-indexed rows are reversed.

    Rows are indexed along the second-to-last dimension and reversed along
    the last dimension: rows 1, 3, 5, ... read right-to-left, even rows are
    left as they are. Any leading dimensions are handled element-wise. The
    input tensor itself is not modified.
    """
    reversed_cols = matrix.flip(-1)
    out = matrix.clone()
    out[..., 1::2, :] = reversed_cols[..., 1::2, :]
    return out
|
|
|
|
|
|
def rotate(m, r):
    """Rotate the last two dimensions of ``m`` by ``r`` quarter turns.

    Thin wrapper over ``torch.rot90`` with ``dims=[-2, -1]`` (rotation goes
    from dim -2 towards dim -1 for positive ``r``).
    """
    return torch.rot90(m, r, dims=[-2, -1])


def sequency_vec(m):
    """Serpentine-scan the last two dims of ``m`` into a single last dim.

    Odd rows are reversed (via ``flip_odd_lines``) and the rows are then
    concatenated, so the scan zig-zags through the matrix. ``rotate(m, 0)``
    is zero quarter turns, i.e. an identity step (presumably kept as a hook
    for other scan orientations).
    """
    snake = flip_odd_lines(rotate(m, 0))
    return snake.flatten(start_dim=m.dim() - 2)


def sequency_mat(v, s):
    """Inverse of ``sequency_vec``: fold the last dim of ``v`` back to shape ``s``.

    The flat serpentine scan is reshaped to the 2-D size ``s`` and the odd
    rows are reversed again to restore their original order.
    """
    grid = v.unflatten(-1, s)
    return rotate(flip_odd_lines(grid), 0)
|
|
|
|
|
|
|
|
|
def modulo(x, L):
    """Wrap ``x`` into one period of length ``L``.

    Uses floor-style remainder (result in ``[0, L)``), with one adjustment:
    strictly positive inputs that are exact multiples of ``L`` map to ``L``
    instead of 0. Positive values therefore land in ``(0, L]`` and
    non-positive values in ``[0, L)``.
    """
    wrapped = torch.remainder(x, L)
    # Only inputs that were > 0 *before* wrapping get pushed up to L.
    hits_top = (wrapped == 0) & (x > 0)
    return torch.where(hits_top, L, wrapped)
|
|
|
|
|
|
|
|
|
def center_modulo(x, L):
    """Wrap ``x`` into the zero-centred interval ``[-L/2, L/2]``.

    Shifts by half a period, wraps with ``modulo``, and shifts back.
    """
    half_period = L / 2
    shifted = modulo(x + half_period, L)
    return shifted - half_period
|
|
|
|
|
|
|
|
|
def unmodulo(psi):
    """Invert a first-order finite difference along the last dimension.

    Recovers a signal whose forward difference best matches ``psi`` by a
    DCT-based Poisson solve. The differences are zero-padded at both ends
    and differenced once more. That second difference is diagonalised by
    the orthonormal DCT, divided by the operator's eigenvalues, and
    transformed back. The DC coefficient is set to zero, so the result has
    zero mean.

    Args:
        psi: tensor of difference values. The operation acts on the last
            dimension and broadcasts over any leading dimensions.

    Returns:
        Tensor with the same leading dimensions as ``psi`` and a last
        dimension of length ``psi.shape[-1] + 1``.
    """
    # Zero-pad both ends and difference again; the result is n + 1 long.
    padded = torch.nn.functional.pad(psi, (1, 1), mode='constant', value=0)
    rho = torch.diff(padded, 1)
    coeffs = dct.dct(rho, norm='ortho')

    n = coeffs.shape[-1]
    # Eigenvalues 2 * (cos(pi * k / n) - 1) of the second-difference
    # operator. Build them on the data's device/dtype so CUDA inputs work.
    k = torch.arange(n, device=coeffs.device, dtype=coeffs.dtype)
    denom = 2 * (torch.cos(torch.pi * k / n) - 1)
    # The k = 0 eigenvalue is exactly 0: use a dummy divisor and zero that
    # coefficient afterwards instead of dividing by zero.
    denom[0] = 1.0
    # A 1-D denom broadcasts over any leading dims; no fixed (1, 1, n) shape.
    coeffs = coeffs / (denom + 1e-7)
    coeffs[..., 0] = 0.0
    return dct.idct(coeffs, norm='ortho')
|
|
|
|
|
|
def RD(x, L):
    """Round ``x`` to the nearest integer multiple of ``L``.

    Ties follow ``torch.round`` (round half to even).
    """
    steps = torch.round(x / L)
    return steps * L
|
|
|
|
|
|
|
|
|
def hard_thresholding(x, t):
    """Zero every entry of ``x`` whose magnitude is not strictly above ``t``.

    Entries with ``|x| > t`` pass through unchanged. The result is computed
    as ``x`` times a boolean keep-mask.
    """
    keep = x.abs().gt(t)
    return x * keep
|
|
|
|
|
|
|
|
|
def stripe_estimation(psi, t=0.15):
    """Estimate the large-jump ("stripe") component of an image.

    Forward differences of ``psi`` along both image axes are hard-thresholded,
    keeping only jumps with magnitude above ``t``. The image whose gradient
    matches those retained jumps is recovered with a 2-D DCT Poisson solve.
    The gradients are zero-padded, i.e. there is no flux across the borders.
    The result is then shifted so that its minimum is 0.

    Args:
        psi: tensor whose last two dimensions are (rows, columns). Any
            leading dimensions are broadcast over.
        t: magnitude threshold applied to the finite differences.

    Returns:
        Tensor with the same shape as ``psi``.

    Note:
        The final shift subtracts the minimum over the *entire* tensor (all
        leading dimensions together), not a per-image minimum.
    """
    # Forward differences along columns (last dim) and rows (dim -2),
    # keeping only the large jumps.
    dx = hard_thresholding(torch.diff(psi, 1, dim=-1), t)
    dy = hard_thresholding(torch.diff(psi, 1, dim=-2), t)

    # Zero-pad each gradient on both ends of its own axis, so that one more
    # difference returns to the input size. torch.nn.functional.pad takes
    # (last-left, last-right, dim-2-top, dim-2-bottom).
    dx = torch.nn.functional.pad(dx, (1, 1))
    dy = torch.nn.functional.pad(dy, (0, 0, 1, 1))
    rho = torch.diff(dx, 1, dim=-1) + torch.diff(dy, 1, dim=-2)
    dct_rho = dct.dct_2d(rho, norm='ortho')

    rows, cols = rho.shape[-2], rho.shape[-1]
    # Build the eigenvalue grid directly on the data's device/dtype. Its
    # (rows, cols) shape broadcasts over any leading dims.
    i = torch.arange(rows, device=rho.device, dtype=rho.dtype).unsqueeze(-1)
    j = torch.arange(cols, device=rho.device, dtype=rho.dtype)
    denom = 2 * (torch.cos(torch.pi * i / rows) + torch.cos(torch.pi * j / cols) - 2)
    # denom[0, 0] is exactly 0. Use a dummy divisor so no inf/NaN is
    # produced (it would leak into gradients); the coefficient is zeroed below.
    denom[0, 0] = 1.0

    dct_phi = dct_rho / denom
    dct_phi[..., 0, 0] = 0
    phi = dct.idct_2d(dct_phi, norm='ortho')

    return phi - torch.min(phi)
|
|
|
|
|
|
|
|
|
def recons(m_t, DO=1, L=1.0, vertical=False, t=0.3):
    """Reconstruct an image from its modulo-folded samples.

    Steps:
      1. Serpentine-scan the image (row by row, or column by column when
         ``vertical``).
      2. Take the ``DO``-th finite difference and compute its centred-modulo
         residual.
      3. Undo the ``DO`` differences one at a time, rounding to multiples of
         ``L`` after each step.
      4. Add the recovered offsets to the input.
      5. Subtract the stripe estimate.
      6. Min-max normalise the result.

    Args:
        m_t: folded image tensor whose last two dims are (rows, cols). TODO
            confirm layout with callers; the original code only supported
            4-D input (it permuted dims (0, 1, 3, 2)).
        DO: order of the finite difference used for unwrapping.
        L: modulo period.
        vertical: scan columns instead of rows.
        t: jump threshold forwarded to ``stripe_estimation``.

    Returns:
        Tensor shaped like ``m_t``, min-max normalised over the whole tensor.
        If the reconstruction is constant it is returned as all zeros rather
        than NaN.
    """
    if vertical:
        m_t = m_t.transpose(-2, -1)
    shape = m_t.shape[-2:]

    samples = sequency_vec(m_t)
    diffs = torch.diff(samples, n=DO)
    # Assumes the DO-th difference of the true signal lies inside the
    # centred modulo range. Under that assumption this residual is the DO-th
    # difference of the (multiple-of-L) folding offsets.
    residual = center_modulo(diffs, L) - diffs

    offsets = residual
    for _ in range(DO):
        offsets = RD(unmodulo(offsets), L)

    x_est = sequency_mat(offsets, shape) + m_t
    if vertical:
        x_est = x_est.transpose(-2, -1)

    x_est = x_est - stripe_estimation(x_est, t=t)

    x_est = x_est - x_est.min()
    peak = x_est.max()
    # A constant reconstruction has peak == 0; avoid 0/0 producing all-NaN.
    return torch.where(peak > 0, x_est / peak, x_est)