Spaces:
Sleeping
Sleeping
File size: 7,356 Bytes
e3fd7d4 f0ea2eb e3fd7d4 f0ea2eb e3fd7d4 f0ea2eb e3fd7d4 d4a7e1c e3fd7d4 3800d7b e3fd7d4 f0ea2eb e3fd7d4 f0ea2eb 3800d7b f0ea2eb d4a7e1c f0ea2eb 3800d7b f0ea2eb 0867bd7 f0ea2eb e3fd7d4 1f0e31a e3fd7d4 d4a7e1c 3800d7b e3fd7d4 f0ea2eb d4a7e1c f0ea2eb 3800d7b f0ea2eb 0867bd7 f0ea2eb 0867bd7 f0ea2eb e3fd7d4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 | import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from PIL import Image
import torchvision.transforms.functional as TF
# --- 1. MODEL ARCHITECTURE ---
class PureResBlock(nn.Module):
    """Residual block computing ``x + res_scale * conv2(relu(conv1(x)))``.

    Both convolutions keep the channel count and the spatial size (3x3
    kernel, padding 1, replicate padding), so the skip connection is a plain
    element-wise add. There is no normalisation layer.

    Args:
        channels: number of input (and output) feature channels.
    """

    def __init__(self, channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, padding_mode='replicate')
        # Attribute names become state_dict keys of saved checkpoints, and
        # conv1 must be built before conv2 to keep the initialisation order.
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.act = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        # Constant multiplier on the residual branch (1.0 = unscaled).
        self.res_scale = 1.0

    def forward(self, x):
        """Return ``x`` plus the scaled residual branch evaluated on ``x``."""
        # The in-place ReLU only overwrites conv1's output, never ``x``.
        residual = self.conv2(self.act(self.conv1(x)))
        return x + residual * self.res_scale
class FastEDSR(nn.Module):
    """EDSR-style super-resolution network predicting a residual over bicubic.

    The output is a bicubic upsampling of the input plus a learned
    correction:

    ``head`` (3 -> ``channels`` conv) -> ``body`` (``num_blocks`` PureResBlocks)
    -> ``tail`` conv, whose result is added back onto the head features
    (global skip). ``sub_pixel`` then maps those features to
    ``3 * scale_factor**2`` channels and rearranges them with PixelShuffle into
    an RGB image ``scale_factor`` times larger in each spatial dimension.

    Args:
        scale_factor: upscaling factor used by both the bicubic path and
            PixelShuffle.
        num_blocks: number of residual blocks in ``body``.
        channels: width of the internal feature maps.
    """

    def __init__(self, scale_factor=2, num_blocks=8, channels=64):
        super().__init__()
        self.scale_factor = scale_factor
        conv_kwargs = dict(kernel_size=3, padding=1, padding_mode='replicate')
        # Creation order (head, body, tail, sub_pixel) fixes both the
        # state_dict keys and the order in which parameters are initialised.
        self.head = nn.Conv2d(3, channels, **conv_kwargs)
        blocks = [PureResBlock(channels) for _ in range(num_blocks)]
        self.body = nn.Sequential(*blocks)
        self.tail = nn.Conv2d(channels, channels, **conv_kwargs)
        shuffle_channels = 3 * scale_factor ** 2
        self.sub_pixel = nn.Sequential(
            nn.Conv2d(channels, shuffle_channels, **conv_kwargs),
            nn.PixelShuffle(scale_factor),
        )

    def forward(self, x):
        """Upscale ``x`` (assumed (N, 3, H, W)) by ``scale_factor``."""
        features = self.head(x)
        # Global residual: deep features are added onto the shallow ones.
        features = features + self.tail(self.body(features))
        correction = self.sub_pixel(features)
        upsampled = F.interpolate(
            x, scale_factor=self.scale_factor, mode='bicubic', align_corners=False
        )
        return upsampled + correction
# --- 2. INITIALIZATION ---
device = torch.device('cpu')
model = FastEDSR(scale_factor=2, num_blocks=8, channels=64)
# Load the weights. weights_only=True restricts unpickling to tensors and
# plain containers, so a tampered checkpoint cannot execute code on load.
model_path = "FastEDSR_x2_31dB.pth"
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
# Switch to inference mode before serving requests.
model.eval()
def calc_psnr(pred, target, max_val=1.0):
    """Return the peak signal-to-noise ratio, in dB, between two tensors.

    Args:
        pred: predicted image tensor.
        target: reference image tensor, compared element-wise with ``pred``.
        max_val: peak value of the data range. The default 1.0 suits data in
            [0, 1]; pass 255.0 for 8-bit-range data.

    Returns:
        ``10 * log10(max_val**2 / MSE)`` as a Python float, or the sentinel
        100.0 when the inputs are identical (MSE == 0, where PSNR would be
        infinite).
    """
    mse = torch.mean((pred - target) ** 2)
    if mse == 0:
        # Identical images: PSNR is unbounded, so report a fixed cap instead.
        return 100.0
    return 10 * torch.log10(max_val ** 2 / mse).item()
# --- 3. INFERENCE FUNCTIONS ---
def standard_upscale(img):
    """Upscale a PIL image 2x with FastEDSR for the "Standard Upscaling" tab.

    If either side of the input exceeds 2048 px, a Gradio warning is shown and
    the image is first shrunk proportionally (bicubic) so that its longer side
    is 2048 px.

    Args:
        img: PIL image from the Gradio input component, or None when empty.

    Returns:
        ``(output_img, details)``: the upscaled PIL image and a Markdown string
        with the before/after resolutions; ``(None, "")`` when ``img`` is None.
    """
    if img is None:
        return None, ""
    # Convert first so any resize runs on RGB data (Pillow always uses
    # nearest-neighbour for "1"/"P" images), matching benchmark_upscale.
    img = img.convert('RGB')
    max_input_dim = 2048
    w, h = img.size
    if w > max_input_dim or h > max_input_dim:
        gr.Warning(f"Input image exceeded the 2K ({max_input_dim}px) limit. It has been proportionally downscaled to ensure the 4K output fits in server memory and constraints.")
        scale = max_input_dim / max(w, h)
        # Clamp to 1 px so extreme aspect ratios never yield a zero-sized
        # side, which Image.resize rejects.
        w, h = max(1, int(w * scale)), max(1, int(h * scale))
        img = img.resize((w, h), Image.BICUBIC)
    input_tensor = TF.to_tensor(img).unsqueeze(0).to(device)
    with torch.no_grad():
        output_tensor = model(input_tensor)
    # Drop the batch dimension and clip to the valid [0, 1] range before
    # converting back to a PIL image.
    output_tensor = output_tensor.squeeze(0).clamp(0, 1)
    output_img = TF.to_pil_image(output_tensor)
    new_w, new_h = output_img.size
    details = (
        f"### Resolution Details\n"
        f"- **Before:** {w} x {h} ({w * h:,} pixels)\n\n"
        f"- **After:** {new_w} x {new_h} ({new_w * new_h:,} pixels)"
    )
    return output_img, details
def benchmark_upscale(hr_img):
    """Run the "Benchmark Mode" tab: degrade, upscale and score an image.

    Steps:

    1. Convert the ground-truth image to RGB and crop it to even dimensions.
    2. If either side exceeds 4096 px, shrink it proportionally (with a
       Gradio warning).
    3. Bicubic-downscale it 2x to make the low-resolution input.
    4. Upscale that input back with FastEDSR.
    5. Score the prediction against the ground truth with PSNR.

    Args:
        hr_img: ground-truth PIL image from Gradio, or None when empty.

    Returns:
        ``(details, (lr_preview, prediction), (ground_truth, prediction))``:
        a Markdown summary and two image pairs for the ImageSliders;
        ``("", None, None)`` when ``hr_img`` is None.

    Raises:
        gr.Error: if the image is smaller than 2x2 px after cropping to even
            dimensions, since no half-resolution input can be made from it.
    """
    if hr_img is None:
        return "", None, None
    hr_img = hr_img.convert('RGB')
    w, h = hr_img.size
    # The model upscales exactly 2x, so both sides must be even for the
    # prediction to line up pixel-for-pixel with the ground truth.
    w = w - (w % 2)
    h = h - (h % 2)
    if w < 2 or h < 2:
        raise gr.Error("Ground truth image must be at least 2x2 pixels.")
    hr_img = hr_img.crop((0, 0, w, h))
    max_hr_dim = 4096
    if w > max_hr_dim or h > max_hr_dim:
        gr.Warning(f"Ground truth image exceeded the 4K ({max_hr_dim}px) limit. It has been proportionally downscaled.")
        scale = max_hr_dim / max(w, h)
        # Keep each side >= 2 px so the even-trim below can never reach 0.
        w = max(2, int(w * scale))
        h = max(2, int(h * scale))
        w = w - (w % 2)
        h = h - (h % 2)
        hr_img = hr_img.resize((w, h), Image.BICUBIC)
    lr_w, lr_h = w // 2, h // 2
    lr_img = hr_img.resize((lr_w, lr_h), Image.BICUBIC)
    lr_tensor = TF.to_tensor(lr_img).unsqueeze(0).to(device)
    hr_tensor = TF.to_tensor(hr_img).unsqueeze(0).to(device)
    with torch.no_grad():
        pred_tensor = model(lr_tensor).clamp(0, 1)
        psnr = calc_psnr(pred_tensor, hr_tensor)
    pred_img = TF.to_pil_image(pred_tensor.squeeze(0))
    # Nearest-neighbour blow-up of the LR input to HR size, so the slider
    # compares equal-sized images and the pixelation stays visible.
    lr_slider_img = lr_img.resize((w, h), Image.NEAREST)
    details = (
        f"### Benchmark Results\n"
        f"- **PSNR:** {psnr:.2f} dB\n\n"
        f"- **Low-Res Input:** {lr_w} x {lr_h} ({lr_w * lr_h:,} pixels)\n\n"
        f"- **Model Output & Ground Truth:** {w} x {h} ({w * h:,} pixels)"
    )
    # Gradio's native ImageSlider expects a tuple of (image1, image2)
    return details, (lr_slider_img, pred_img), (hr_img, pred_img)
# --- 4. GRADIO UI ---
# Build the Gradio UI: a header followed by two tabs (standard upscaling and
# a PSNR benchmark) wired to standard_upscale / benchmark_upscale.
with gr.Blocks() as app:
    # Header: title, usage note, project link and the DIV2K dataset citation.
    gr.Markdown(
        """
# ⚡ FastEDSR 2x Image Upscaler
Upload an image to enhance and upscale it by 2x. Supports up to 4K resolution output.
For more information on the model, training, and use, and for a local demo, visit our [website](https://infinitode.netlify.app).
This model was trained on DIV2K:
```
@inproceedings{Agustsson2017,
title={NTIRE 2017 Challenge on Single Image Super-Resolution: Dataset and Study},
author={Agustsson, Eirikur and Timofte, Radu},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition Workshops},
year={2017}
}
```
"""
    )
    with gr.Tabs():
        # TAB 1: STANDARD
        # One PIL image in; the 2x image and a Markdown resolution summary out.
        with gr.TabItem("⚡ Standard Upscaling"):
            gr.Markdown("Directly upscale any low-resolution image. Inputs over 2K (2048px) will be scaled down to prevent memory limits.")
            with gr.Row():
                with gr.Column():
                    std_input = gr.Image(type="pil", label="Low Resolution Input")
                    std_btn = gr.Button("Upscale Image", variant="primary")
                with gr.Column():
                    std_output = gr.Image(type="pil", label="2x High Resolution Output")
                    std_details = gr.Markdown()
            # Output order must match standard_upscale's return tuple.
            std_btn.click(fn=standard_upscale, inputs=std_input, outputs=[std_output, std_details])
        # TAB 2: BENCHMARK
        # One ground-truth image in; a Markdown PSNR report and two
        # before/after sliders out.
        with gr.TabItem("📊 Benchmark Mode"):
            gr.Markdown("Upload a high-quality image (up to 4K). The app will compress it to 2x lower resolution, upscale it using FastEDSR, and measure the PSNR quality against the original. It will also generate side-by-side comparisons for you to view.")
            with gr.Row():
                with gr.Column():
                    bm_input = gr.Image(type="pil", label="Ground Truth (High Res) Image")
                    bm_btn = gr.Button("Run Benchmark", variant="primary")
                    bm_details = gr.Markdown()
                with gr.Column():
                    gr.Markdown("### Low-Res vs. Model Prediction")
                    slider_lr_pred = gr.ImageSlider(label="Left: Pixelated Low-Res | Right: FastEDSR")
                    gr.Markdown("### Ground Truth vs. Model Prediction")
                    slider_hr_pred = gr.ImageSlider(label="Left: Original HR | Right: FastEDSR")
            # Output order must match benchmark_upscale's return tuple.
            bm_btn.click(fn=benchmark_upscale, inputs=bm_input, outputs=[bm_details, slider_lr_pred, slider_hr_pred])
if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    app.launch()