Spaces:
Runtime error
Runtime error
- app.py +29 -46
- uhdm_checkpoint.pth +0 -3
app.py
CHANGED
|
@@ -24,16 +24,6 @@ load_path1 = "./mix.pth"
|
|
| 24 |
model_state_dict1 = torch.load(load_path1, map_location=device)
|
| 25 |
model1.load_state_dict(model_state_dict1)
|
| 26 |
|
| 27 |
-
model2 = my_model(en_feature_num=48,
|
| 28 |
-
en_inter_num=32,
|
| 29 |
-
de_feature_num=64,
|
| 30 |
-
de_inter_num=32,
|
| 31 |
-
sam_number=1,
|
| 32 |
-
).to(device)
|
| 33 |
-
|
| 34 |
-
load_path2 = "./uhdm_checkpoint.pth"
|
| 35 |
-
model_state_dict2 = torch.load(load_path2, map_location=device)
|
| 36 |
-
model2.load_state_dict(model_state_dict2)
|
| 37 |
|
| 38 |
def default_toTensor(img):
|
| 39 |
t_list = [transforms.ToTensor()]
|
|
@@ -59,25 +49,6 @@ def predict1(img):
|
|
| 59 |
|
| 60 |
return out_1
|
| 61 |
|
| 62 |
-
def predict2(img):
|
| 63 |
-
in_img = transforms.ToTensor()(img).to(device).unsqueeze(0)
|
| 64 |
-
b, c, h, w = in_img.size()
|
| 65 |
-
# pad image such that the resolution is a multiple of 32
|
| 66 |
-
w_pad = (math.ceil(w / 32) * 32 - w) // 2
|
| 67 |
-
h_pad = (math.ceil(h / 32) * 32 - h) // 2
|
| 68 |
-
in_img = img_pad(in_img, w_r=w_pad, h_r=h_pad)
|
| 69 |
-
with torch.no_grad():
|
| 70 |
-
out_1, out_2, out_3 = model2(in_img)
|
| 71 |
-
if h_pad != 0:
|
| 72 |
-
out_1 = out_1[:, :, h_pad:-h_pad, :]
|
| 73 |
-
if w_pad != 0:
|
| 74 |
-
out_1 = out_1[:, :, :, w_pad:-w_pad]
|
| 75 |
-
out_1 = out_1.squeeze(0)
|
| 76 |
-
out_1 = PIL.Image.fromarray(torch.clamp(out_1 * 255, min=0, max=255
|
| 77 |
-
).byte().permute(1, 2, 0).cpu().numpy())
|
| 78 |
-
|
| 79 |
-
return out_1
|
| 80 |
-
|
| 81 |
def img_pad(x, h_r=0, w_r=0):
|
| 82 |
'''
|
| 83 |
Here the padding values are determined by the average r,g,b values across the training set
|
|
@@ -93,21 +64,33 @@ def img_pad(x, h_r=0, w_r=0):
|
|
| 93 |
return y
|
| 94 |
|
| 95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
iface1 = gr.Interface(fn=predict1,
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
)
|
| 112 |
-
|
| 113 |
-
iface_all.launch()
|
|
|
|
| 24 |
model_state_dict1 = torch.load(load_path1, map_location=device)
|
| 25 |
model1.load_state_dict(model_state_dict1)
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
def default_toTensor(img):
|
| 29 |
t_list = [transforms.ToTensor()]
|
|
|
|
| 49 |
|
| 50 |
return out_1
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
def img_pad(x, h_r=0, w_r=0):
|
| 53 |
'''
|
| 54 |
Here the padding values are determined by the average r,g,b values across the training set
|
|
|
|
| 64 |
return y
|
| 65 |
|
| 66 |
|
| 67 |
+
title = "Clean Your Moire Images!"
|
| 68 |
+
description = """
|
| 69 |
+
|
| 70 |
+
The model was trained to remove the moire patterns from your captured screen images! Specially, this model is capable of tackling
|
| 71 |
+
images up to 4K resolution, which adapts to most of the modern mobile phones.
|
| 72 |
+
(Note: It may cost 80s per 4K image (e.g., iPhone's resolution: 4032x3024) since this demo runs on the CPU. The model can run
|
| 73 |
+
on a NVIDIA 3090 GPU 17ms per standard 4K image)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
article = "Check out the [ECCV 2022 paper](https://arxiv.org/abs/2207.09935) and the \
|
| 79 |
+
[official training code](https://github.com/CVMI-Lab/UHDM) which the demo is based on."
|
| 80 |
+
|
| 81 |
+
|
| 82 |
iface1 = gr.Interface(fn=predict1,
|
| 83 |
+
inputs=gr.inputs.Image(type="pil"),
|
| 84 |
+
outputs=gr.inputs.Image(type="pil"),
|
| 85 |
+
examples=['001.jpg',
|
| 86 |
+
'002.jpg',
|
| 87 |
+
'003.jpg',
|
| 88 |
+
'004.jpg',
|
| 89 |
+
'005.jpg'],
|
| 90 |
+
title = title,
|
| 91 |
+
description = description,
|
| 92 |
+
article = article
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
iface1.launch()
|
|
|
|
|
|
|
|
|
uhdm_checkpoint.pth
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:254235cd25f90a3f1785885385dc6cb3f2178e053291ab53d1943bd7c2f7de65
|
| 3 |
-
size 23895301
|
|
|
|
|
|
|
|
|
|
|
|