JohanBeytell commited on
Commit
f0ea2eb
·
verified ·
1 Parent(s): c6cf678

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -24
app.py CHANGED
@@ -2,6 +2,7 @@ import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
  import gradio as gr
 
5
  from PIL import Image
6
  import torchvision.transforms.functional as TF
7
 
@@ -43,61 +44,138 @@ class FastEDSR(nn.Module):
43
  return base_upscaled + details
44
 
45
  # --- 2. INITIALIZATION ---
46
- device = torch.device('cpu') # Hugging Face Free Tier runs on CPU
47
  model = FastEDSR(scale_factor=2, num_blocks=8, channels=64)
48
 
49
- # Load the weights (Update this string if your file is named differently in the HF root)
50
  model_path = "FastEDSR_x2_31dB.pth"
51
  model.load_state_dict(torch.load(model_path, map_location=device))
52
  model.eval()
53
 
54
- # --- 3. INFERENCE FUNCTION ---
55
- def upscale_image(img):
56
- if img is None:
57
- return None
 
 
 
 
 
 
58
 
59
- # Enforce constraints to prevent CPU OOM timeouts
60
- # Max input 1024px -> Max output 2048px (2K)
61
  max_input_dim = 1024
62
  w, h = img.size
63
 
64
  if w > max_input_dim or h > max_input_dim:
65
  scale = max_input_dim / max(w, h)
66
- new_w, new_h = int(w * scale), int(h * scale)
67
- img = img.resize((new_w, new_h), Image.BICUBIC)
68
 
69
- # Preprocess
70
  img = img.convert('RGB')
71
  input_tensor = TF.to_tensor(img).unsqueeze(0).to(device)
72
 
73
- # Forward Pass
74
  with torch.no_grad():
75
  output_tensor = model(input_tensor)
76
 
77
- # Postprocess
78
  output_tensor = output_tensor.squeeze(0).clamp(0, 1)
79
  output_img = TF.to_pil_image(output_tensor)
80
 
81
- return output_img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  # --- 4. GRADIO UI ---
84
  with gr.Blocks(theme=gr.themes.Soft()) as app:
85
  gr.Markdown(
86
  """
87
  # ⚡ FastEDSR 2x Image Upscaler
88
- Upload an image to enhance and upscale it by 2x.
89
- *Note: To ensure stability on CPU infrastructure, input images larger than 1024px are proportionally downscaled before processing to guarantee a maximum 2K output.*
90
  """
91
  )
92
 
93
- with gr.Row():
94
- with gr.Column():
95
- input_image = gr.Image(type="pil", label="Low Resolution Input")
96
- upscale_btn = gr.Button("Upscale Image", variant="primary")
97
- with gr.Column():
98
- output_image = gr.Image(type="pil", label="2x High Resolution Output")
99
-
100
- upscale_btn.click(fn=upscale_image, inputs=input_image, outputs=output_image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
  if __name__ == "__main__":
103
  app.launch()
 
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
  import gradio as gr
5
+ from gradio_imageslider import ImageSlider
6
  from PIL import Image
7
  import torchvision.transforms.functional as TF
8
 
 
44
  return base_upscaled + details
45
 
46
  # --- 2. INITIALIZATION ---
47
+ device = torch.device('cpu')
48
  model = FastEDSR(scale_factor=2, num_blocks=8, channels=64)
49
 
50
+ # Load the weights
51
  model_path = "FastEDSR_x2_31dB.pth"
52
  model.load_state_dict(torch.load(model_path, map_location=device))
53
  model.eval()
54
 
55
+ def calc_psnr(pred, target):
56
+ mse = torch.mean((pred - target) ** 2)
57
+ if mse == 0:
58
+ return 100.0
59
+ return 10 * torch.log10(1.0 / mse).item()
60
+
61
+ # --- 3. INFERENCE FUNCTIONS ---
62
+
63
+ def standard_upscale(img):
64
+ if img is None: return None, ""
65
 
 
 
66
  max_input_dim = 1024
67
  w, h = img.size
68
 
69
  if w > max_input_dim or h > max_input_dim:
70
  scale = max_input_dim / max(w, h)
71
+ w, h = int(w * scale), int(h * scale)
72
+ img = img.resize((w, h), Image.BICUBIC)
73
 
 
74
  img = img.convert('RGB')
75
  input_tensor = TF.to_tensor(img).unsqueeze(0).to(device)
76
 
 
77
  with torch.no_grad():
78
  output_tensor = model(input_tensor)
79
 
 
80
  output_tensor = output_tensor.squeeze(0).clamp(0, 1)
81
  output_img = TF.to_pil_image(output_tensor)
82
 
83
+ new_w, new_h = output_img.size
84
+
85
+ details = (
86
+ f"### Resolution Details\n"
87
+ f"**Before:** {w} x {h} ({w * h:,} pixels)\n\n"
88
+ f"**After:** {new_w} x {new_h} ({new_w * new_h:,} pixels)"
89
+ )
90
+
91
+ return output_img, details
92
+
93
+ def benchmark_upscale(hr_img):
94
+ if hr_img is None: return "", None, None
95
+
96
+ hr_img = hr_img.convert('RGB')
97
+ w, h = hr_img.size
98
+
99
+ # Enforce even dimensions so 2x scaling mathematically matches
100
+ w = w - (w % 2)
101
+ h = h - (h % 2)
102
+ hr_img = hr_img.crop((0, 0, w, h))
103
+
104
+ max_input_dim = 2048 # HR can be 2048 because LR will be 1024
105
+ if w > max_input_dim or h > max_input_dim:
106
+ scale = max_input_dim / max(w, h)
107
+ w, h = int(w * scale), int(h * scale)
108
+ # Ensure even dimensions again after resize
109
+ w = w - (w % 2)
110
+ h = h - (h % 2)
111
+ hr_img = hr_img.resize((w, h), Image.BICUBIC)
112
+
113
+ # Create the simulated Low-Res image
114
+ lr_w, lr_h = w // 2, h // 2
115
+ lr_img = hr_img.resize((lr_w, lr_h), Image.BICUBIC)
116
+
117
+ # Run Inference
118
+ lr_tensor = TF.to_tensor(lr_img).unsqueeze(0).to(device)
119
+ hr_tensor = TF.to_tensor(hr_img).unsqueeze(0).to(device)
120
+
121
+ with torch.no_grad():
122
+ pred_tensor = model(lr_tensor).clamp(0, 1)
123
+
124
+ # Calculate PSNR
125
+ psnr = calc_psnr(pred_tensor, hr_tensor)
126
+ pred_img = TF.to_pil_image(pred_tensor.squeeze(0))
127
+
128
+ # Resize LR using NEAREST so it looks accurately pixelated in the slider comparison
129
+ lr_slider_img = lr_img.resize((w, h), Image.NEAREST)
130
+
131
+ details = (
132
+ f"### Benchmark Results\n"
133
+ f"**PSNR:** {psnr:.2f} dB\n\n"
134
+ f"**Low-Res Input:** {lr_w} x {lr_h} ({lr_w * lr_h:,} pixels)\n\n"
135
+ f"**Model Output & Ground Truth:** {w} x {h} ({w * h:,} pixels)"
136
+ )
137
+
138
+ return details, (lr_slider_img, pred_img), (hr_img, pred_img)
139
 
140
  # --- 4. GRADIO UI ---
141
  with gr.Blocks(theme=gr.themes.Soft()) as app:
142
  gr.Markdown(
143
  """
144
  # ⚡ FastEDSR 2x Image Upscaler
145
+ Upload an image to enhance and upscale it by 2x.
 
146
  """
147
  )
148
 
149
+ with gr.Tabs():
150
+ # TAB 1: STANDARD
151
+ with gr.TabItem(" Standard Upscaling"):
152
+ gr.Markdown("Directly upscale any low-resolution image.")
153
+ with gr.Row():
154
+ with gr.Column():
155
+ std_input = gr.Image(type="pil", label="Low Resolution Input")
156
+ std_btn = gr.Button("Upscale Image", variant="primary")
157
+ with gr.Column():
158
+ std_output = gr.Image(type="pil", label="2x High Resolution Output")
159
+ std_details = gr.Markdown()
160
+
161
+ std_btn.click(fn=standard_upscale, inputs=std_input, outputs=[std_output, std_details])
162
+
163
+ # TAB 2: BENCHMARK
164
+ with gr.TabItem("📊 Benchmark Mode"):
165
+ gr.Markdown("Upload a high-quality image. The app will compress it, upscale it, and measure the PSNR quality against the original.")
166
+ with gr.Row():
167
+ with gr.Column():
168
+ bm_input = gr.Image(type="pil", label="Ground Truth (High Res) Image")
169
+ bm_btn = gr.Button("Run Benchmark", variant="primary")
170
+ bm_details = gr.Markdown()
171
+ with gr.Column():
172
+ gr.Markdown("### Low-Res vs. Model Prediction")
173
+ slider_lr_pred = ImageSlider(label="Left: Pixelated Low-Res | Right: FastEDSR")
174
+
175
+ gr.Markdown("### Ground Truth vs. Model Prediction")
176
+ slider_hr_pred = ImageSlider(label="Left: Original HR | Right: FastEDSR")
177
+
178
+ bm_btn.click(fn=benchmark_upscale, inputs=bm_input, outputs=[bm_details, slider_lr_pred, slider_hr_pred])
179
 
180
  if __name__ == "__main__":
181
  app.launch()