Delete utils.py
Browse files
utils.py
DELETED
|
@@ -1,118 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import torch_dct as dct
|
| 3 |
-
|
| 4 |
-
from torchvision.transforms import functional as F
|
| 5 |
-
|
| 6 |
-
def flip_odd_lines(matrix):
|
| 7 |
-
"""
|
| 8 |
-
Flip odd lines of a matrix
|
| 9 |
-
"""
|
| 10 |
-
matrix = matrix.clone()
|
| 11 |
-
|
| 12 |
-
matrix[..., 1::2, :] = matrix[..., 1::2, :].flip(-1)
|
| 13 |
-
|
| 14 |
-
return matrix
|
| 15 |
-
|
| 16 |
-
rotate = lambda m, r: torch.rot90(m, r, [-2, -1])
|
| 17 |
-
sequency_vec = lambda m: flip_odd_lines(rotate(m, 0)).flatten(start_dim= m.dim()-2)
|
| 18 |
-
sequency_mat = lambda v, s: rotate(flip_odd_lines(v.unflatten(-1, s)), 0)
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
def modulo(x, L):
|
| 22 |
-
positive = x > 0
|
| 23 |
-
x = x % L
|
| 24 |
-
x = torch.where( ( x == 0) & positive, L, x)
|
| 25 |
-
return x
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
def center_modulo(x, L):
|
| 29 |
-
return modulo(x + L/2, L) - L/2
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
def unmodulo(psi):
|
| 33 |
-
|
| 34 |
-
psi = torch.nn.functional.pad(psi, (1,1), mode='constant', value=0)
|
| 35 |
-
psi = torch.diff(psi, 1)
|
| 36 |
-
psi = dct.dct(psi, norm='ortho')
|
| 37 |
-
N = psi.shape[-1]
|
| 38 |
-
k = torch.arange(0, N)
|
| 39 |
-
denom = 2*( torch.cos( torch.pi * k / N ) - 1 )
|
| 40 |
-
denom[0] = 1.0
|
| 41 |
-
denom = denom.unsqueeze(0).unsqueeze(0) + 1e-7
|
| 42 |
-
psi = psi / denom
|
| 43 |
-
psi[..., 0] = 0.0
|
| 44 |
-
psi = dct.idct(psi, norm='ortho')
|
| 45 |
-
return psi
|
| 46 |
-
|
| 47 |
-
RD = lambda x, L: torch.round( x / L) * L
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
def hard_thresholding(x, t):
|
| 51 |
-
return x * (torch.abs(x) > t)
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
def stripe_estimation(psi, t=0.15):
|
| 55 |
-
|
| 56 |
-
dx = torch.diff(psi, 1, dim=-1)
|
| 57 |
-
dy = torch.diff(psi, 1, dim=-2)
|
| 58 |
-
|
| 59 |
-
dx = hard_thresholding(dx, t)
|
| 60 |
-
dy = hard_thresholding(dy, t)
|
| 61 |
-
|
| 62 |
-
dx = F.pad(dx, (1, 0, 1, 0))
|
| 63 |
-
dy = F.pad(dy, (0, 1, 0, 1))
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
rho = torch.diff(dx, 1, dim=-1) + torch.diff(dy, 1, dim=-2)
|
| 67 |
-
dct_rho = dct.dct_2d(rho, norm='ortho')
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
MX = rho.shape[-2]
|
| 71 |
-
NX = rho.shape[-1]
|
| 72 |
-
|
| 73 |
-
I, J = torch.meshgrid(torch.arange(0, MX), torch.arange(0, NX), indexing="ij")
|
| 74 |
-
I = I.to(rho.device)
|
| 75 |
-
J = J.to(rho.device)
|
| 76 |
-
denom = 2 * (torch.cos(torch.pi * I / MX ) + torch.cos(torch.pi * J / NX ) - 2)
|
| 77 |
-
denom = denom.unsqueeze(0).unsqueeze(0)
|
| 78 |
-
denom = denom.to(rho.device)
|
| 79 |
-
dct_phi = dct_rho / denom
|
| 80 |
-
dct_phi[..., 0, 0] = 0
|
| 81 |
-
phi = dct.idct_2d(dct_phi, norm='ortho')
|
| 82 |
-
phi = phi - torch.min(phi)
|
| 83 |
-
# phi = phi - torch.amin(phi, dim=(-1, -2), keepdim=True)
|
| 84 |
-
# phi = RD(phi, 1.0)
|
| 85 |
-
return phi
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
def recons(m_t, DO=1, L=1.0, vertical=False, t=0.3):
|
| 89 |
-
|
| 90 |
-
if vertical:
|
| 91 |
-
m_t = m_t.permute(0, 1, 3, 2)
|
| 92 |
-
|
| 93 |
-
shape = m_t.shape[-2:]
|
| 94 |
-
|
| 95 |
-
modulo_vec = sequency_vec(m_t)
|
| 96 |
-
res = center_modulo( torch.diff(modulo_vec, n=DO), L) - torch.diff(modulo_vec, n=DO)
|
| 97 |
-
bl = res
|
| 98 |
-
|
| 99 |
-
for i in range(DO):
|
| 100 |
-
bl = unmodulo(bl)
|
| 101 |
-
bl = RD(bl, L)
|
| 102 |
-
|
| 103 |
-
x_est = bl
|
| 104 |
-
|
| 105 |
-
x_est = sequency_mat(x_est, shape)
|
| 106 |
-
x_est = x_est + m_t
|
| 107 |
-
|
| 108 |
-
if vertical:
|
| 109 |
-
x_est = x_est.permute(0, 1, 3, 2)
|
| 110 |
-
|
| 111 |
-
stripes = stripe_estimation(x_est, t=t)
|
| 112 |
-
x_est = x_est - stripes
|
| 113 |
-
|
| 114 |
-
# if vertical:
|
| 115 |
-
# x_est = x_est.permute(0, 1, 3, 2)
|
| 116 |
-
x_est = x_est - x_est.min()
|
| 117 |
-
x_est = x_est / x_est.max()
|
| 118 |
-
return x_est
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|