# Source: leharris3 — "Minimal HF Space deployment with gradio 5.x fix" (commit 0917e8d)
"""
Sparse-scan image inpainting via iterative stationary-wavelet soft-thresholding.

Inspired by: https://onlinelibrary.wiley.com/doi/pdf/10.1002/smll.202002878
"""
import pywt
import torch
class SSTEM(torch.nn.Module):
    """
    Inpaint a sparsely sampled image by iterative wavelet soft-thresholding.

    Implemented from: https://github.com/ziatdinovmax/Notebooks-for-papers/blob/master/GP_spiral_scans_GP.ipynb

    The module has no learnable parameters. The reconstruction runs in
    NumPy / PyWavelets on the CPU, and the result is converted back to a
    torch tensor.

    Parameters
    ---
    itern: number of projection / thresholding iterations used by ``forward``
    levels: number of stationary-wavelet levels used by ``forward``
    lambd: soft-threshold for the detail coefficients used by ``forward``
    wavelet: PyWavelets wavelet name used by ``forward``
    """

    def __init__(
        self,
        itern: int = 20,
        levels: int = 2,
        lambd: float = 0.8,
        wavelet: str = "db2",
    ):
        super().__init__()
        # Stored so ``forward`` can be configured. The defaults reproduce the
        # values that used to be hard-coded.
        self.itern = itern
        self.levels = levels
        self.lambd = lambd
        self.wavelet = wavelet

    @staticmethod
    def _soft_threshold(coeffs, value: float):
        """Soft-threshold ``coeffs`` by ``value``; a zero threshold is the identity.

        PyWavelets' soft mode scales coefficients by ``1 - value / |x|``. With
        ``value == 0``, an exactly-zero coefficient becomes ``0 / 0`` (NaN), and
        sparse inputs can contain many exact zeros. The zero case is therefore
        short-circuited instead of letting NaNs spread through the inverse
        transform.
        """
        if value == 0:
            return coeffs
        return pywt.threshold(coeffs, value, "soft")

    @staticmethod
    def SSTEM(
        yspar: torch.Tensor,
        mask: torch.Tensor,
        itern: int = 20,
        levels: int = 2,
        lambd: float = 0.8,
        wavelet: str = "db2",
    ) -> "numpy.ndarray":
        """
        Fill unsampled pixels by alternating data projection and wavelet shrinkage.

        Parameters
        ---
        yspar: sparse image; to reduce the iteration count, rescale values to [0, 1]
        mask: binary tensor (bool or numeric), 1 indicating sampled pixel locations
        itern: iteration number; 20 is usually enough
        levels: wavelet level; common choices are 2, 3, 4. Larger values suit
            larger feature sizes; if the result is too blurry, use a smaller one
        lambd: soft-threshold for the detail coefficients; 0.8 is usually fine
        wavelet: PyWavelets wavelet name (default ``"db2"``)

        Returns
        ---
        The inpainted image as a NumPy array (not a tensor). With ``itern == 0``
        this is simply the NumPy conversion of ``yspar``.

        Notes
        ---
        ``pywt.swt2`` requires the transformed axes to be divisible by
        ``2 ** levels`` and raises otherwise.
        """
        # PyWavelets works on NumPy arrays, so leave torch here.
        omega = mask.float().detach().cpu().numpy()
        observed = yspar.detach().cpu().numpy()

        # Data-consistency projection: keep the current estimate where
        # mask == 0 and restore the measurements where mask == 1. The
        # measured term never changes, so compute it once outside the loop.
        keep = 1 - omega
        measured = omega * observed

        estimate = observed
        for _ in range(itern):
            estimate = keep * estimate + measured
            coeffs = pywt.swt2(estimate, wavelet, levels)
            shrunk = [
                (
                    # Soft-thresholding at 0 is the identity, so the
                    # approximation band is passed through untouched.
                    approx,
                    tuple(SSTEM._soft_threshold(band, lambd) for band in details),
                )
                for approx, details in coeffs
            ]
            estimate = pywt.iswt2(shrunk, wavelet)
        return estimate

    def forward(self, y_sparse: torch.Tensor, y_mask: torch.Tensor) -> torch.Tensor:
        """
        Inpaint ``y_sparse`` at the locations where ``y_mask`` is 0.

        Returns a new float32 tensor. It is placed on CUDA when CUDA is
        available, matching the previous unconditional ``.cuda()``. Otherwise
        it is placed on the device of ``y_sparse``, so CPU-only hosts work.
        """
        x = SSTEM.SSTEM(
            y_sparse,
            y_mask,
            itern=self.itern,
            levels=self.levels,
            lambd=self.lambd,
            wavelet=self.wavelet,
        )
        device = torch.device("cuda") if torch.cuda.is_available() else y_sparse.device
        # torch.tensor always copies, so the result never aliases the input
        # (x can be a view of y_sparse when itern == 0).
        return torch.tensor(x, dtype=torch.float32, device=device)

    @staticmethod
    def get(weights=None):
        """Build an ``SSTEM`` module with default settings.

        ``weights`` is ignored because the module has no learnable parameters.
        Presumably it is accepted to match a shared ``get(weights=...)`` factory
        signature; confirm against callers.
        """
        return SSTEM()
if __name__ == "__main__":
    # Smoke check: the module can be built with its default configuration.
    model = SSTEM.get()