JohanBeytell commited on
Commit
0867bd7
·
verified ·
1 Parent(s): 3a79fec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -10
app.py CHANGED
@@ -2,7 +2,6 @@ import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
  import gradio as gr
5
- from gradio_imageslider import ImageSlider
6
  from PIL import Image
7
  import torchvision.transforms.functional as TF
8
 
@@ -63,7 +62,6 @@ def calc_psnr(pred, target):
63
  def standard_upscale(img):
64
  if img is None: return None, ""
65
 
66
- # Bumped max input to 2K (2048px) for a 4K output
67
  max_input_dim = 2048
68
  w, h = img.size
69
 
@@ -98,12 +96,10 @@ def benchmark_upscale(hr_img):
98
  hr_img = hr_img.convert('RGB')
99
  w, h = hr_img.size
100
 
101
- # Enforce even dimensions so 2x scaling mathematically matches
102
  w = w - (w % 2)
103
  h = h - (h % 2)
104
  hr_img = hr_img.crop((0, 0, w, h))
105
 
106
- # HR can now be 4K (4096px) because the LR input to the model will be max 2K
107
  max_hr_dim = 4096
108
  if w > max_hr_dim or h > max_hr_dim:
109
  gr.Warning(f"Ground truth image exceeded the 4K ({max_hr_dim}px) limit. It has been proportionally downscaled.")
@@ -113,22 +109,18 @@ def benchmark_upscale(hr_img):
113
  h = h - (h % 2)
114
  hr_img = hr_img.resize((w, h), Image.BICUBIC)
115
 
116
- # Create the simulated Low-Res image
117
  lr_w, lr_h = w // 2, h // 2
118
  lr_img = hr_img.resize((lr_w, lr_h), Image.BICUBIC)
119
 
120
- # Run Inference
121
  lr_tensor = TF.to_tensor(lr_img).unsqueeze(0).to(device)
122
  hr_tensor = TF.to_tensor(hr_img).unsqueeze(0).to(device)
123
 
124
  with torch.no_grad():
125
  pred_tensor = model(lr_tensor).clamp(0, 1)
126
 
127
- # Calculate PSNR
128
  psnr = calc_psnr(pred_tensor, hr_tensor)
129
  pred_img = TF.to_pil_image(pred_tensor.squeeze(0))
130
 
131
- # Resize LR using NEAREST so it looks accurately pixelated in the slider comparison
132
  lr_slider_img = lr_img.resize((w, h), Image.NEAREST)
133
 
134
  details = (
@@ -138,6 +130,7 @@ def benchmark_upscale(hr_img):
138
  f"**Model Output & Ground Truth:** {w} x {h} ({w * h:,} pixels)"
139
  )
140
 
 
141
  return details, (lr_slider_img, pred_img), (hr_img, pred_img)
142
 
143
  # --- 4. GRADIO UI ---
@@ -173,10 +166,10 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
173
  bm_details = gr.Markdown()
174
  with gr.Column():
175
  gr.Markdown("### Low-Res vs. Model Prediction")
176
- slider_lr_pred = ImageSlider(label="Left: Pixelated Low-Res | Right: FastEDSR")
177
 
178
  gr.Markdown("### Ground Truth vs. Model Prediction")
179
- slider_hr_pred = ImageSlider(label="Left: Original HR | Right: FastEDSR")
180
 
181
  bm_btn.click(fn=benchmark_upscale, inputs=bm_input, outputs=[bm_details, slider_lr_pred, slider_hr_pred])
182
 
 
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
  import gradio as gr
 
5
  from PIL import Image
6
  import torchvision.transforms.functional as TF
7
 
 
62
  def standard_upscale(img):
63
  if img is None: return None, ""
64
 
 
65
  max_input_dim = 2048
66
  w, h = img.size
67
 
 
96
  hr_img = hr_img.convert('RGB')
97
  w, h = hr_img.size
98
 
 
99
  w = w - (w % 2)
100
  h = h - (h % 2)
101
  hr_img = hr_img.crop((0, 0, w, h))
102
 
 
103
  max_hr_dim = 4096
104
  if w > max_hr_dim or h > max_hr_dim:
105
  gr.Warning(f"Ground truth image exceeded the 4K ({max_hr_dim}px) limit. It has been proportionally downscaled.")
 
109
  h = h - (h % 2)
110
  hr_img = hr_img.resize((w, h), Image.BICUBIC)
111
 
 
112
  lr_w, lr_h = w // 2, h // 2
113
  lr_img = hr_img.resize((lr_w, lr_h), Image.BICUBIC)
114
 
 
115
  lr_tensor = TF.to_tensor(lr_img).unsqueeze(0).to(device)
116
  hr_tensor = TF.to_tensor(hr_img).unsqueeze(0).to(device)
117
 
118
  with torch.no_grad():
119
  pred_tensor = model(lr_tensor).clamp(0, 1)
120
 
 
121
  psnr = calc_psnr(pred_tensor, hr_tensor)
122
  pred_img = TF.to_pil_image(pred_tensor.squeeze(0))
123
 
 
124
  lr_slider_img = lr_img.resize((w, h), Image.NEAREST)
125
 
126
  details = (
 
130
  f"**Model Output & Ground Truth:** {w} x {h} ({w * h:,} pixels)"
131
  )
132
 
133
+ # Gradio's native ImageSlider expects a tuple of (image1, image2)
134
  return details, (lr_slider_img, pred_img), (hr_img, pred_img)
135
 
136
  # --- 4. GRADIO UI ---
 
166
  bm_details = gr.Markdown()
167
  with gr.Column():
168
  gr.Markdown("### Low-Res vs. Model Prediction")
169
+ slider_lr_pred = gr.ImageSlider(label="Left: Pixelated Low-Res | Right: FastEDSR")
170
 
171
  gr.Markdown("### Ground Truth vs. Model Prediction")
172
+ slider_hr_pred = gr.ImageSlider(label="Left: Original HR | Right: FastEDSR")
173
 
174
  bm_btn.click(fn=benchmark_upscale, inputs=bm_input, outputs=[bm_details, slider_lr_pred, slider_hr_pred])
175