Update app.py
Browse files
app.py
CHANGED
|
@@ -127,13 +127,13 @@ def inference(image, upscale, large_input_flag, color_fix):
|
|
| 127 |
print(f'input size: {img.shape}')
|
| 128 |
|
| 129 |
# img2tensor
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
|
| 134 |
# inference
|
| 135 |
if large_input_flag:
|
| 136 |
-
patches, idx, size = img2patch(
|
| 137 |
with torch.no_grad():
|
| 138 |
n = len(patches)
|
| 139 |
outs = []
|
|
@@ -153,24 +153,26 @@ def inference(image, upscale, large_input_flag, color_fix):
|
|
| 153 |
output = patch2img(output, idx, size, scale=upscale)
|
| 154 |
else:
|
| 155 |
with torch.no_grad():
|
| 156 |
-
output = model(
|
| 157 |
|
| 158 |
# color fix
|
| 159 |
if color_fix:
|
| 160 |
-
|
| 161 |
-
output = wavelet_reconstruction(output,
|
| 162 |
# tensor2img
|
| 163 |
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
| 164 |
if output.ndim == 3:
|
| 165 |
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
|
| 166 |
output = (output * 255.0).round().astype(np.uint8)
|
|
|
|
|
|
|
| 167 |
|
| 168 |
-
# save restored img
|
| 169 |
-
save_path = f'results/out.png'
|
| 170 |
-
cv2.imwrite(save_path, output)
|
| 171 |
|
| 172 |
-
output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
|
| 173 |
-
return output, save_path
|
| 174 |
|
| 175 |
|
| 176 |
|
|
@@ -223,10 +225,11 @@ demo = gr.Interface(
|
|
| 223 |
fn=inference,
|
| 224 |
inputs=[
|
| 225 |
gr.Image(value="real_testdata/004.png", type="pil", label="Input"),
|
| 226 |
-
gr.Number(minimum=2, maximum=4,
|
| 227 |
gr.Checkbox(value=False, label="Memory-efficient inference"),
|
| 228 |
gr.Checkbox(value=False, label="Color correction"),
|
| 229 |
],
|
|
|
|
| 230 |
outputs=ImageSlider(label="Super-Resolved Image",
|
| 231 |
type="pil",
|
| 232 |
show_download_button=True,
|
|
|
|
| 127 |
print(f'input size: {img.shape}')
|
| 128 |
|
| 129 |
# img2tensor
|
| 130 |
+
y = y.astype(np.float32) / 255.
|
| 131 |
+
y = torch.from_numpy(np.transpose(y[:, :, [2, 1, 0]], (2, 0, 1))).float()
|
| 132 |
+
y = y.unsqueeze(0).to(device)
|
| 133 |
|
| 134 |
# inference
|
| 135 |
if large_input_flag:
|
| 136 |
+
patches, idx, size = img2patch(y, scale=upscale)
|
| 137 |
with torch.no_grad():
|
| 138 |
n = len(patches)
|
| 139 |
outs = []
|
|
|
|
| 153 |
output = patch2img(output, idx, size, scale=upscale)
|
| 154 |
else:
|
| 155 |
with torch.no_grad():
|
| 156 |
+
output = model(y)
|
| 157 |
|
| 158 |
# color fix
|
| 159 |
if color_fix:
|
| 160 |
+
y = F.interpolate(y, scale_factor=upscale, mode='bilinear')
|
| 161 |
+
output = wavelet_reconstruction(output, y)
|
| 162 |
# tensor2img
|
| 163 |
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
| 164 |
if output.ndim == 3:
|
| 165 |
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
|
| 166 |
output = (output * 255.0).round().astype(np.uint8)
|
| 167 |
+
|
| 168 |
+
return (img, Image.fromarray(output))
|
| 169 |
|
| 170 |
+
# # save restored img
|
| 171 |
+
# save_path = f'results/out.png'
|
| 172 |
+
# cv2.imwrite(save_path, output)
|
| 173 |
|
| 174 |
+
# output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
|
| 175 |
+
# return output, save_path
|
| 176 |
|
| 177 |
|
| 178 |
|
|
|
|
| 225 |
fn=inference,
|
| 226 |
inputs=[
|
| 227 |
gr.Image(value="real_testdata/004.png", type="pil", label="Input"),
|
| 228 |
+
gr.Number(minimum=2, maximum=4, default_value=2, label="Upscaling factor (up to 4)"),
|
| 229 |
gr.Checkbox(value=False, label="Memory-efficient inference"),
|
| 230 |
gr.Checkbox(value=False, label="Color correction"),
|
| 231 |
],
|
| 232 |
+
|
| 233 |
outputs=ImageSlider(label="Super-Resolved Image",
|
| 234 |
type="pil",
|
| 235 |
show_download_button=True,
|