File size: 2,336 Bytes
0917e8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""
Inpainting of sparsely sampled images by iterative wavelet soft-thresholding.

Inspired by: https://doi.org/10.1002/smll.202002878
"""

import pywt
import torch


class SSTEM(torch.nn.Module):
    """
    Inpaint sparsely sampled images by iterative wavelet soft-thresholding.

    Implemented from: https://github.com/ziatdinovmax/Notebooks-for-papers/blob/master/GP_spiral_scans_GP.ipynb

    Parameters
    ---
    itern: number of projection/thresholding iterations used by ``forward``
    levels: number of stationary-wavelet decomposition levels used by ``forward``
    lambd: soft-threshold applied to the detail coefficients in ``forward``
    """

    def __init__(
        self,
        itern: int = 20,
        levels: int = 2,
        lambd: float = 0.8,
    ):
        super().__init__()
        # Stored so ``forward`` can be configured per instance; the defaults
        # are identical to those of the static ``SSTEM`` routine, so
        # ``SSTEM()`` behaves exactly as before.
        self.itern = itern
        self.levels = levels
        self.lambd = lambd

    @staticmethod
    def SSTEM(
        yspar: torch.Tensor,
        mask: torch.Tensor,
        itern: int = 20,
        levels: int = 2,
        lambd: float = 0.8,
    ) -> "numpy.ndarray":
        """
        Fill in the unsampled pixels of a sparse image.

        Each iteration:
        1. re-imposes the measured values at the sampled locations;
        2. takes a stationary 2D wavelet transform ("db2", ``levels`` levels);
        3. soft-thresholds the detail coefficients by ``lambd``;
        4. inverts the transform.

        Parameters
        ---
        yspar: sparse image; rescaling its values to [0, 1] tends to reduce
            the number of iterations needed
        mask: binary array, 1 indicating sampled pixel locations
        itern: iteration count; 20 is usually enough. With 0 the input is
            returned unchanged (as a NumPy array).
        levels: wavelet decomposition level. Common choices are 2, 3 or 4;
            larger values suit larger features. Use a smaller value if the
            result is too blurry.
            NOTE: per the PyWavelets docs, ``swt2`` requires the transformed
            axes to be multiples of ``2**levels``.
        lambd: soft-threshold for the detail coefficients; 0.8 is usually fine

        Returns
        ---
        The inpainted image as a NumPy array (not a torch.Tensor).
        The last iteration ends with the inverse transform, so measured
        values are not re-imposed on the returned image.
        """
        # bool/int mask -> float so it can act as a blending weight
        mask = mask.float()

        # PyWavelets operates on NumPy arrays
        yspar = yspar.detach().cpu().numpy()
        mask = mask.detach().cpu().numpy()

        def proj_c(f):
            # Keep the current estimate where unsampled; force the measured
            # values where sampled.
            return (1 - mask) * f + mask * yspar

        f_spars = yspar
        for _ in range(itern):
            f_spars = proj_c(f_spars)
            coeffs = pywt.swt2(f_spars, "db2", levels)
            # One (cA, (cH, cV, cD)) tuple per level, the layout iswt2 expects.
            thresholded = [
                (
                    # a threshold of 0 leaves the approximation untouched
                    pywt.threshold(s_a, 0, "soft"),
                    (
                        pywt.threshold(s_h, lambd, "soft"),
                        pywt.threshold(s_v, lambd, "soft"),
                        pywt.threshold(s_d, lambd, "soft"),
                    ),
                )
                for s_a, (s_h, s_v, s_d) in coeffs
            ]
            f_spars = pywt.iswt2(thresholded, "db2")
        return f_spars

    def forward(self, y_sparse: torch.Tensor, y_mask: torch.Tensor) -> torch.Tensor:
        """
        Inpaint ``y_sparse`` using the hyper-parameters stored on the module.

        Returns a new tensor in the default floating dtype. It is placed on
        the same device as ``y_sparse``.
        """
        x = SSTEM.SSTEM(
            y_sparse,
            y_mask,
            itern=self.itern,
            levels=self.levels,
            lambd=self.lambd,
        )
        # Follow the input's device instead of a hard-coded ``.cuda()``: that
        # failed on CPU-only hosts and ignored which GPU the input was on.
        # ``torch.tensor`` always copies, so the result never aliases the
        # NumPy view of the input (which happens when itern == 0).
        return torch.tensor(x, dtype=torch.get_default_dtype(), device=y_sparse.device)

    @staticmethod
    def get(weights=None):
        """
        Factory hook returning a default-configured module.

        ``weights`` is accepted for interface compatibility but ignored. This
        module registers no parameters to load.
        """
        return SSTEM()


if __name__ == "__main__":
    # Manual smoke check: build the inpainting module through its factory hook.
    sstem = SSTEM.get()