# FastEDSR / app.py
# Author: JohanBeytell
# Commit e3fd7d4 (verified): "Create app.py"
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from PIL import Image
import torchvision.transforms.functional as TF
# --- 1. MODEL ARCHITECTURE ---
class PureResBlock(nn.Module):
    """Residual block made of two 3x3 convolutions and no normalisation layer.

    The block computes ``x + res_scale * conv2(relu(conv1(x)))``. Both
    convolutions keep the channel count and use one pixel of replicate
    padding, so the output has the same shape as the input.

    Args:
        channels: Number of input and output feature channels.
        res_scale: Factor applied to the residual branch before it is added
            back to the input. Defaults to 1.0, i.e. a plain residual sum.
    """

    def __init__(self, channels: int, res_scale: float = 1.0) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, padding_mode='replicate')
        self.act = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, padding_mode='replicate')
        # Stored as a plain attribute (not a parameter or buffer), so it adds
        # no key to the state dict.
        self.res_scale = res_scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the scaled output of the conv-ReLU-conv branch."""
        res = self.conv1(x)
        # The in-place ReLU overwrites conv1's freshly allocated output,
        # never the caller's ``x``.
        res = self.act(res)
        res = self.conv2(res)
        return x + (res * self.res_scale)
class FastEDSR(nn.Module):
    """Super-resolution network that adds learned detail to a bicubic upscale.

    Pipeline: a head convolution lifts the 3-channel input to ``channels``
    features, a stack of ``num_blocks`` residual blocks followed by a tail
    convolution refines them, and a global skip adds the head features back.
    A final convolution plus pixel shuffle turns the result into a 3-channel
    image enlarged by ``scale_factor``, which is summed with the bicubic
    interpolation of the input.

    Args:
        scale_factor: Upscaling factor used by both the bicubic base and the
            pixel shuffle.
        num_blocks: Number of residual blocks in the body.
        channels: Width of the internal feature maps.
    """

    def __init__(self, scale_factor=2, num_blocks=8, channels=64):
        super().__init__()

        def conv3x3(in_ch, out_ch):
            # Every convolution in the network shares this geometry:
            # 3x3 kernel with one pixel of replicate padding (size-preserving).
            return nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, padding_mode='replicate')

        self.scale_factor = scale_factor

        # Sub-modules are created in the order head, body, tail, sub_pixel.
        self.head = conv3x3(3, channels)
        self.body = nn.Sequential(*(PureResBlock(channels) for _ in range(num_blocks)))
        self.tail = conv3x3(channels, channels)

        # Pixel shuffle folds scale_factor**2 channel groups into space,
        # leaving 3 output channels.
        shuffle_channels = 3 * (scale_factor ** 2)
        self.sub_pixel = nn.Sequential(
            conv3x3(channels, shuffle_channels),
            nn.PixelShuffle(scale_factor),
        )

    def forward(self, x):
        """Return the bicubic upscale of ``x`` plus the predicted detail image."""
        shallow = self.head(x)
        deep = self.tail(self.body(shallow))

        # Global skip connection around body + tail.
        details = self.sub_pixel(shallow + deep)

        base_upscaled = F.interpolate(
            x,
            scale_factor=self.scale_factor,
            mode='bicubic',
            align_corners=False,
        )
        return base_upscaled + details
# --- 2. INITIALIZATION ---
# Built once at import time and shared by every request.
device = torch.device('cpu')  # Hugging Face Free Tier runs on CPU
model = FastEDSR(scale_factor=2, num_blocks=8, channels=64)

# Load the weights (Update this string if your file is named differently in the HF root)
model_path = "FastEDSR_x2_31dB.pth"
# weights_only=True restricts unpickling to tensors and plain containers, so
# loading the checkpoint cannot execute arbitrary pickled code.
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
# Keep the model on the same device the inference code moves its inputs to.
model.to(device)
# Inference only: switch the module out of training mode.
model.eval()
# --- 3. INFERENCE FUNCTION ---
def upscale_image(img):
    """Upscale a PIL image by the model's factor and return the result.

    Inputs whose longest side exceeds 1024 px are first shrunk
    proportionally so that side becomes 1024 px, which bounds the work done
    on CPU. The image is converted to RGB, run through the module-level
    ``model`` without gradient tracking, clamped to [0, 1] and converted
    back to a PIL image.

    Args:
        img: A PIL image, or ``None`` when no image was supplied.

    Returns:
        The upscaled PIL image, or ``None`` if ``img`` is ``None``.
    """
    if img is None:
        return None

    # Enforce constraints to prevent CPU OOM timeouts
    # Max input 1024px -> Max output 2048px (2K)
    max_input_dim = 1024

    # Convert before resizing: Pillow resamples palette ("P") and bilevel
    # ("1") images with nearest-neighbour regardless of the requested filter,
    # so resizing first would silently skip the bicubic downscale for them.
    img = img.convert('RGB')

    w, h = img.size
    longest_side = max(w, h)
    if longest_side > max_input_dim:
        scale = max_input_dim / longest_side
        # Clamp each side to at least 1 px; with an extreme aspect ratio the
        # short side would otherwise truncate to 0 and make resize() raise.
        new_w = max(1, int(w * scale))
        new_h = max(1, int(h * scale))
        img = img.resize((new_w, new_h), Image.BICUBIC)

    # Preprocess: to_tensor output gets a leading batch axis for the model.
    input_tensor = TF.to_tensor(img).unsqueeze(0).to(device)

    # Forward pass
    with torch.no_grad():
        output_tensor = model(input_tensor)

    # Postprocess: the bicubic base plus predicted detail can overshoot the
    # valid range, so clamp before converting back to an image.
    output_tensor = output_tensor.squeeze(0).clamp(0, 1)
    return TF.to_pil_image(output_tensor)
# --- 4. GRADIO UI ---
# Header text rendered above the controls.
INTRO_MARKDOWN = """
# ⚡ FastEDSR 2x Image Upscaler
Upload an image to enhance and upscale it by 2x.
*Note: To ensure stability on CPU infrastructure, input images larger than 1024px are proportionally downscaled before processing to guarantee a maximum 2K output.*
"""

with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown(INTRO_MARKDOWN)

    # Two side-by-side columns: input plus trigger on the left, result on the right.
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Low Resolution Input", type="pil")
            upscale_btn = gr.Button("Upscale Image", variant="primary")
        with gr.Column():
            output_image = gr.Image(label="2x High Resolution Output", type="pil")

    # Clicking the button sends the input image through upscale_image and
    # shows the returned image in the output component.
    upscale_btn.click(
        fn=upscale_image,
        inputs=input_image,
        outputs=output_image,
    )

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    app.launch()