File size: 2,983 Bytes
91487c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import torch
import torch_dct as dct
from torchvision.transforms import functional as F
def flip_odd_lines(matrix):
    """Return a copy of ``matrix`` with every odd-indexed row reversed.

    Rows are taken along the second-to-last axis; rows 1, 3, 5, ... are
    mirrored along the last axis, so that flattening the result walks the
    matrix in a snake order. Any leading axes are treated as batch axes.
    The input tensor is left unmodified.

    Args:
        matrix: Tensor with at least two dimensions.

    Returns:
        A new tensor with the same shape as ``matrix``.
    """
    out = matrix.clone()
    # ``flip`` materialises a new tensor, so writing it back into ``out``
    # does not alias the slice being read.
    reversed_odd_rows = out[..., 1::2, :].flip(-1)
    out[..., 1::2, :] = reversed_odd_rows
    return out
def rotate(m, r):
    """Rotate ``m`` by ``r`` quarter turns in the plane of its last two axes."""
    return torch.rot90(m, r, [-2, -1])


def sequency_vec(m):
    """Flatten the last two axes of ``m`` into one axis in snake order.

    Odd rows are reversed (``flip_odd_lines``) before flattening, so
    consecutive entries of the result are spatially adjacent pixels.
    ``rotate(m, 0)`` is a zero-turn rotation kept for symmetry with
    ``sequency_mat``.
    """
    snake = flip_odd_lines(rotate(m, 0))
    return snake.flatten(start_dim=m.ndim - 2)


def sequency_mat(v, s):
    """Inverse of ``sequency_vec``: fold the last axis of ``v`` into shape ``s``."""
    folded = v.unflatten(-1, s)
    unsnaked = flip_odd_lines(folded)
    return rotate(unsnaked, 0)
def modulo(x, L):
    """Fold ``x`` into one period of length ``L`` (assumes ``L > 0``).

    Unlike a plain ``x % L``, a strictly positive exact multiple of ``L``
    maps to ``L`` rather than ``0``. Positive inputs therefore land in
    ``(0, L]``, while zero and negative inputs land in ``[0, L)``.

    Args:
        x: Tensor of values to fold.
        L: Period (scalar).

    Returns:
        Tensor of the same shape as ``x``.
    """
    was_positive = x > 0
    folded = x % L
    hit_zero = folded == 0
    return torch.where(hit_zero & was_positive, L, folded)
def center_modulo(x, L):
    """Fold ``x`` into a window of width ``L`` centred on zero.

    The input is shifted by half a period, folded with ``modulo``, and
    shifted back. Results lie between ``-L/2`` and ``L/2``; which endpoint
    is included depends on the sign of ``x + L/2`` (see ``modulo``).
    """
    half_period = L / 2
    shifted = x + half_period
    return modulo(shifted, L) - half_period
def unmodulo(psi):
    """Undo a first-order difference along the last axis via a DCT Poisson solve.

    ``psi`` is zero-padded by one sample at each end and differenced. The
    result is one sample longer than ``psi``. Its orthonormal DCT
    coefficients are divided by ``2 * (cos(pi * k / N) - 1)``, the DCT-domain
    eigenvalues of the 1-D discrete Laplacian. The undetermined DC
    coefficient is set to zero, and the inverse DCT is returned. ``recons``
    applies this once per ``torch.diff`` order to restore the original
    length.

    Args:
        psi: Tensor whose last axis is the signal axis; any leading axes are
            treated as batch axes.

    Returns:
        Tensor with the same leading shape as ``psi`` and a last axis one
        element longer.
    """
    padded = torch.nn.functional.pad(psi, (1, 1), mode='constant', value=0)
    grad = torch.diff(padded, 1)
    coeffs = dct.dct(grad, norm='ortho')
    N = coeffs.shape[-1]
    # Build the eigenvalue grid on the input's device; a CPU-only arange
    # raises a device-mismatch error for CUDA inputs.
    k = torch.arange(0, N, device=coeffs.device)
    denom = 2 * (torch.cos(torch.pi * k / N) - 1)
    # The k = 0 eigenvalue is exactly 0; replace it so the division stays
    # finite (that coefficient is zeroed below anyway).
    denom[0] = 1.0
    denom = denom + 1e-7
    # denom has shape (N,) and broadcasts over any leading batch axes.
    coeffs = coeffs / denom
    coeffs[..., 0] = 0.0
    return dct.idct(coeffs, norm='ortho')
def RD(x, L):
    """Round ``x`` to the nearest integer multiple of ``L``.

    Ties follow ``torch.round`` (round half to even).
    """
    n_periods = torch.round(x / L)
    return n_periods * L
def hard_thresholding(x, t):
    """Zero every entry of ``x`` whose magnitude is not strictly above ``t``.

    Surviving entries keep their original value (no shrinkage).
    """
    keep = x.abs() > t
    return x * keep
def stripe_estimation(psi, t=0.15):
    """Estimate the large-jump (stripe) component of ``psi`` by a 2-D Poisson solve.

    Row- and column-wise finite differences of ``psi`` are hard-thresholded
    at ``t``, so only jumps larger than ``t`` survive. Their divergence
    ``rho`` is computed. The equation ``laplacian(phi) = rho`` is then solved
    in the orthonormal 2-D DCT domain, with the undetermined DC term set to
    zero. Finally the result is shifted so that its minimum over the whole
    tensor, including all batch entries, is 0.

    Args:
        psi: Tensor whose last two axes are the image plane (rows, cols);
            any leading axes are treated as batch axes.
        t: Magnitude threshold; differences with ``|d| <= t`` are dropped.

    Returns:
        Tensor ``phi`` with the same shape as ``psi``.
    """
    dx = hard_thresholding(torch.diff(psi, 1, dim=-1), t)
    dy = hard_thresholding(torch.diff(psi, 1, dim=-2), t)
    # torchvision's pad order is (left, top, right, bottom): dx gains a zero
    # column on both sides and dy a zero row on both sides, so both
    # divergence terms below come out with psi's (rows, cols) shape.
    dx = F.pad(dx, (1, 0, 1, 0))
    dy = F.pad(dy, (0, 1, 0, 1))
    rho = torch.diff(dx, 1, dim=-1) + torch.diff(dy, 1, dim=-2)
    dct_rho = dct.dct_2d(rho, norm='ortho')

    MX, NX = rho.shape[-2], rho.shape[-1]
    # Column/row index vectors built directly on rho's device; they
    # broadcast to an (MX, NX) eigenvalue grid.
    I = torch.arange(0, MX, device=rho.device).unsqueeze(-1)
    J = torch.arange(0, NX, device=rho.device).unsqueeze(0)
    denom = 2 * (torch.cos(torch.pi * I / MX) + torch.cos(torch.pi * J / NX) - 2)
    # The (0, 0) eigenvalue is exactly 0. Replace it so the division stays
    # finite: no inf/NaN in the forward pass and no NaN gradients. The DC
    # coefficient is zeroed right after.
    denom[0, 0] = 1.0
    dct_phi = dct_rho / denom
    dct_phi[..., 0, 0] = 0
    phi = dct.idct_2d(dct_phi, norm='ortho')
    # NOTE: the offset is global across the batch; a per-image offset would
    # use torch.amin(phi, dim=(-1, -2), keepdim=True).
    return phi - torch.min(phi)
def recons(m_t, DO=1, L=1.0, vertical=False, t=0.3):
    """Reconstruct an image from modulo samples, then remove stripes and normalise.

    Steps:
      1. Flatten the image plane in snake order (``sequency_vec``).
      2. Take the ``DO``-th order difference ``d`` and form the residual
         ``center_modulo(d, L) - d``.
      3. Undo the differencing ``DO`` times with ``unmodulo`` and round the
         result to multiples of ``L`` (``RD``).
      4. Fold back to the image plane and add the result to ``m_t``.
      5. Subtract the stripe estimate from ``stripe_estimation`` (threshold
         ``t``), then min-max normalise over the whole tensor.

    Args:
        m_t: Tensor whose last two axes are the image plane; the original
            code assumed a 4-D layout.
        DO: Order of the finite difference used in step 2.
        L: Modulo period.
        vertical: If True, scan along the other image axis by swapping the
            last two axes before reconstruction and swapping them back after.
        t: Threshold passed to ``stripe_estimation``.

    Returns:
        Tensor of the same shape as ``m_t``. Values lie in ``[0, 1]``; a
        constant result is returned as all zeros instead of NaN.
    """
    if vertical:
        # Same as permute(0, 1, 3, 2) for 4-D input, but rank-agnostic.
        m_t = m_t.transpose(-2, -1)
    shape = m_t.shape[-2:]
    modulo_vec = sequency_vec(m_t)
    diffs = torch.diff(modulo_vec, n=DO)
    bl = center_modulo(diffs, L) - diffs
    for _ in range(DO):
        bl = unmodulo(bl)
    bl = RD(bl, L)
    x_est = sequency_mat(bl, shape) + m_t
    if vertical:
        x_est = x_est.transpose(-2, -1)
    x_est = x_est - stripe_estimation(x_est, t=t)
    x_est = x_est - x_est.min()
    peak = x_est.max()
    # Guard against 0/0 = NaN when the reconstruction is constant.
    if peak > 0:
        x_est = x_est / peak
    return x_est