kebincontreras commited on
Commit
9ace6fe
·
verified ·
1 Parent(s): 77d2f88

Delete utils.py

Browse files
Files changed (1) hide show
  1. utils.py +0 -118
utils.py DELETED
@@ -1,118 +0,0 @@
1
- import torch
2
- import torch_dct as dct
3
-
4
- from torchvision.transforms import functional as F
5
-
6
- def flip_odd_lines(matrix):
7
- """
8
- Flip odd lines of a matrix
9
- """
10
- matrix = matrix.clone()
11
-
12
- matrix[..., 1::2, :] = matrix[..., 1::2, :].flip(-1)
13
-
14
- return matrix
15
-
16
- rotate = lambda m, r: torch.rot90(m, r, [-2, -1])
17
- sequency_vec = lambda m: flip_odd_lines(rotate(m, 0)).flatten(start_dim= m.dim()-2)
18
- sequency_mat = lambda v, s: rotate(flip_odd_lines(v.unflatten(-1, s)), 0)
19
-
20
-
21
- def modulo(x, L):
22
- positive = x > 0
23
- x = x % L
24
- x = torch.where( ( x == 0) & positive, L, x)
25
- return x
26
-
27
-
28
- def center_modulo(x, L):
29
- return modulo(x + L/2, L) - L/2
30
-
31
-
32
- def unmodulo(psi):
33
-
34
- psi = torch.nn.functional.pad(psi, (1,1), mode='constant', value=0)
35
- psi = torch.diff(psi, 1)
36
- psi = dct.dct(psi, norm='ortho')
37
- N = psi.shape[-1]
38
- k = torch.arange(0, N)
39
- denom = 2*( torch.cos( torch.pi * k / N ) - 1 )
40
- denom[0] = 1.0
41
- denom = denom.unsqueeze(0).unsqueeze(0) + 1e-7
42
- psi = psi / denom
43
- psi[..., 0] = 0.0
44
- psi = dct.idct(psi, norm='ortho')
45
- return psi
46
-
47
- RD = lambda x, L: torch.round( x / L) * L
48
-
49
-
50
- def hard_thresholding(x, t):
51
- return x * (torch.abs(x) > t)
52
-
53
-
54
- def stripe_estimation(psi, t=0.15):
55
-
56
- dx = torch.diff(psi, 1, dim=-1)
57
- dy = torch.diff(psi, 1, dim=-2)
58
-
59
- dx = hard_thresholding(dx, t)
60
- dy = hard_thresholding(dy, t)
61
-
62
- dx = F.pad(dx, (1, 0, 1, 0))
63
- dy = F.pad(dy, (0, 1, 0, 1))
64
-
65
-
66
- rho = torch.diff(dx, 1, dim=-1) + torch.diff(dy, 1, dim=-2)
67
- dct_rho = dct.dct_2d(rho, norm='ortho')
68
-
69
-
70
- MX = rho.shape[-2]
71
- NX = rho.shape[-1]
72
-
73
- I, J = torch.meshgrid(torch.arange(0, MX), torch.arange(0, NX), indexing="ij")
74
- I = I.to(rho.device)
75
- J = J.to(rho.device)
76
- denom = 2 * (torch.cos(torch.pi * I / MX ) + torch.cos(torch.pi * J / NX ) - 2)
77
- denom = denom.unsqueeze(0).unsqueeze(0)
78
- denom = denom.to(rho.device)
79
- dct_phi = dct_rho / denom
80
- dct_phi[..., 0, 0] = 0
81
- phi = dct.idct_2d(dct_phi, norm='ortho')
82
- phi = phi - torch.min(phi)
83
- # phi = phi - torch.amin(phi, dim=(-1, -2), keepdim=True)
84
- # phi = RD(phi, 1.0)
85
- return phi
86
-
87
-
88
- def recons(m_t, DO=1, L=1.0, vertical=False, t=0.3):
89
-
90
- if vertical:
91
- m_t = m_t.permute(0, 1, 3, 2)
92
-
93
- shape = m_t.shape[-2:]
94
-
95
- modulo_vec = sequency_vec(m_t)
96
- res = center_modulo( torch.diff(modulo_vec, n=DO), L) - torch.diff(modulo_vec, n=DO)
97
- bl = res
98
-
99
- for i in range(DO):
100
- bl = unmodulo(bl)
101
- bl = RD(bl, L)
102
-
103
- x_est = bl
104
-
105
- x_est = sequency_mat(x_est, shape)
106
- x_est = x_est + m_t
107
-
108
- if vertical:
109
- x_est = x_est.permute(0, 1, 3, 2)
110
-
111
- stripes = stripe_estimation(x_est, t=t)
112
- x_est = x_est - stripes
113
-
114
- # if vertical:
115
- # x_est = x_est.permute(0, 1, 3, 2)
116
- x_est = x_est - x_est.min()
117
- x_est = x_est / x_est.max()
118
- return x_est