Spaces:
Running
Running
| """ | |
| Inspired by: https://onlinelibrary.wiley.com/doi/pdf/10.1002/smll.202002878?casa_token=OP1n_oLqe4kAAAAA%3Aiovq39gdeNfEIR8Vyi_FRd3Ec9lz8cDm3m9MtmCoOXbg6w1ohs5YPom5x9uVK9S3wsqmssIPFzfsCIBM9w | |
| """ | |
| import pywt | |
| import torch | |
class SSTEM(torch.nn.Module):
    """Iterative wavelet-shrinkage inpainting of a sparsely sampled image.

    Implemented from:
    https://github.com/ziatdinovmax/Notebooks-for-papers/blob/master/GP_spiral_scans_GP.ipynb

    The module holds no learnable parameters. ``forward`` delegates to the
    static routine :meth:`SSTEM.SSTEM`, using the hyperparameters stored on
    the instance.

    Parameters
    ----------
    itern: number of iterations used by ``forward`` (default 20).
    levels: stationary-wavelet decomposition level used by ``forward`` (default 2).
    lambd: soft-threshold value for the detail bands used by ``forward`` (default 0.8).
    wavelet: PyWavelets wavelet name used by ``forward`` (default "db2").
    """

    def __init__(
        self,
        itern: int = 20,
        levels: int = 2,
        lambd: float = 0.8,
        wavelet: str = "db2",
    ):
        super().__init__()
        self.itern = itern
        self.levels = levels
        self.lambd = lambd
        self.wavelet = wavelet

    @staticmethod
    def SSTEM(
        yspar: torch.Tensor,
        mask: torch.Tensor,
        itern: int = 20,
        levels: int = 2,
        lambd: float = 0.8,
        wavelet: str = "db2",
    ) -> "numpy.ndarray":
        """Inpaint a sparsely sampled image by alternating projection and wavelet shrinkage.

        Parameters
        ----------
        yspar: sparse image tensor. Values at unsampled locations serve only
            as the initial estimate. Rescaling values to [0, 1] tends to
            reduce the number of iterations needed.
        mask: binary tensor (bool or numeric), broadcastable against ``yspar``.
            A value of 1/True marks a sampled pixel location.
        itern: iteration count. Usually 20 is enough. With ``itern <= 0``,
            ``yspar`` is returned unchanged as a NumPy array.
        levels: SWT level. Common choices are 2, 3 or 4, and larger values
            suit larger features. Use a smaller value if the result is too
            blurry. ``pywt.swt2`` expects the transformed dimensions to be
            divisible by ``2**levels``.
        lambd: soft-threshold value applied to the detail (H, V, D) bands.
        wavelet: PyWavelets wavelet name (default "db2").

        Returns
        -------
        The inpainted image as a NumPy array. The last iterate is returned
        without a final data-consistency projection, so sampled pixels may
        differ slightly from ``yspar``.
        """
        # pywt works on ndarrays, so move everything to host NumPy first.
        measured = yspar.detach().cpu().numpy()
        omega = mask.float().detach().cpu().numpy()

        estimate = measured
        for _ in range(itern):
            # Data consistency: keep the measured values where omega == 1 and
            # the current estimate everywhere else.
            estimate = (1 - omega) * estimate + omega * measured
            coeffs = pywt.swt2(estimate, wavelet, levels)
            # Shrink only the detail bands. The approximation band is passed
            # through as-is: soft-thresholding it at 0 is mathematically the
            # identity, but pywt computes 1 - 0/|x|, which yields NaN wherever
            # a coefficient is exactly zero.
            shrunk = [
                (approx, tuple(pywt.threshold(band, lambd, "soft") for band in details))
                for approx, details in coeffs
            ]
            estimate = pywt.iswt2(shrunk, wavelet)
        return estimate

    def forward(self, y_sparse: torch.Tensor, y_mask: torch.Tensor) -> torch.Tensor:
        """Inpaint ``y_sparse`` given the sampling mask ``y_mask``.

        Returns a new tensor in PyTorch's default floating dtype, placed on
        ``y_sparse``'s device. No gradient flows through it, because the
        computation runs in NumPy on detached data.
        """
        x = self.SSTEM(
            y_sparse,
            y_mask,
            itern=self.itern,
            levels=self.levels,
            lambd=self.lambd,
            wavelet=self.wavelet,
        )
        # Follow the input's device instead of hard-coding .cuda(), so the
        # module also runs on CPU-only machines. torch.tensor always copies,
        # so the output never aliases the input buffer.
        return torch.tensor(x, dtype=torch.get_default_dtype(), device=y_sparse.device)
def get(weights=None):
    """Build and return a new :class:`SSTEM` module with its default settings.

    ``weights`` is accepted but ignored. The module has no learned parameters
    to load.
    """
    module = SSTEM()
    return module
if __name__ == "__main__":
    # Script entry point: only constructs the module; no inpainting is run.
    sstem = SSTEM()