Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import gradio as gr | |
| from PIL import Image | |
| import torchvision.transforms.functional as TF | |
# --- 1. MODEL ARCHITECTURE ---
class PureResBlock(nn.Module):
    """Normalization-free residual block: conv -> ReLU -> conv plus a skip connection.

    Both convolutions are 3x3 with one pixel of replicate padding, so the
    output has the same channel count and spatial size as the input.

    Args:
        channels: Number of input and output feature channels.
        res_scale: Factor applied to the residual branch before it is added
            back to the input. Defaults to 1.0 (plain residual addition).
    """

    def __init__(self, channels, res_scale=1.0):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, padding_mode='replicate')
        self.act = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, padding_mode='replicate')
        # Stored as a plain float (not a parameter or buffer), so it never
        # appears in the state_dict and existing checkpoints keep loading.
        self.res_scale = res_scale

    def forward(self, x):
        """Return ``x`` plus the scaled output of the conv-ReLU-conv branch."""
        res = self.conv1(x)
        # The in-place ReLU overwrites conv1's fresh output, never ``x`` itself.
        res = self.act(res)
        res = self.conv2(res)
        return x + (res * self.res_scale)
class FastEDSR(nn.Module):
    """Residual super-resolution network that adds learned detail to a bicubic upscale.

    A 3-channel input goes through a head convolution, a stack of
    ``PureResBlock`` modules and a tail convolution. The head features are
    added back to the tail output (global skip), and a convolution followed by
    ``PixelShuffle`` turns the result into a 3-channel image enlarged by
    ``scale_factor``. That image is added to a bicubic upscale of the input.

    Args:
        scale_factor: Upscaling factor used for both the bicubic branch and
            the pixel shuffle.
        num_blocks: Number of residual blocks in the body.
        channels: Feature width used throughout the network.
    """

    def __init__(self, scale_factor=2, num_blocks=8, channels=64):
        super().__init__()

        def conv3x3(in_ch, out_ch):
            # Every convolution in the network shares this configuration.
            return nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, padding_mode='replicate')

        self.scale_factor = scale_factor

        # Submodules are created in head/body/tail/sub_pixel order so that
        # state_dict keys and seeded initialization stay unchanged.
        self.head = conv3x3(3, channels)
        residual_stack = [PureResBlock(channels) for _ in range(num_blocks)]
        self.body = nn.Sequential(*residual_stack)
        self.tail = conv3x3(channels, channels)

        # PixelShuffle folds scale_factor**2 channel groups into spatial resolution.
        shuffle_channels = 3 * (scale_factor ** 2)
        self.sub_pixel = nn.Sequential(
            conv3x3(channels, shuffle_channels),
            nn.PixelShuffle(scale_factor),
        )

    def forward(self, x):
        """Return the bicubic upscale of ``x`` plus the predicted detail image."""
        shallow = self.head(x)
        deep = self.tail(self.body(shallow))
        # Global skip: head features are re-injected before upsampling.
        details = self.sub_pixel(shallow + deep)
        bicubic = F.interpolate(x, scale_factor=self.scale_factor, mode='bicubic', align_corners=False)
        return bicubic + details
# --- 2. INITIALIZATION ---
device = torch.device('cpu')  # Hugging Face Free Tier runs on CPU
model = FastEDSR(scale_factor=2, num_blocks=8, channels=64)

# Load the weights (Update this string if your file is named differently in the HF root)
model_path = "FastEDSR_x2_31dB.pth"
# torch.load is pickle-based; weights_only=True restricts unpickling to tensors
# and plain containers, which is all a state_dict needs, so a tampered
# checkpoint file cannot execute arbitrary code at load time.
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
# Inputs are moved to `device` at inference time; keep the model on the same
# device so changing `device` above does not cause a device mismatch.
model.to(device)
# Inference only: switch modules to evaluation behaviour.
model.eval()
# --- 3. INFERENCE FUNCTION ---
def _fit_within(width, height, max_dim):
    """Return (width, height) shrunk proportionally so neither side exceeds max_dim."""
    longest = max(width, height)
    if longest <= max_dim:
        return width, height
    scale = max_dim / longest
    # Clamp to 1px: truncation yields 0 for extreme aspect ratios
    # (e.g. 5000x1), and Image.resize rejects a zero-sized target.
    return max(1, int(width * scale)), max(1, int(height * scale))


def upscale_image(img, max_input_dim=1024):
    """Run the module-level super-resolution model on a PIL image.

    Args:
        img: PIL image to upscale, or None (e.g. nothing uploaded yet).
        max_input_dim: Largest allowed input side in pixels. Bigger images are
            proportionally downscaled first to bound CPU memory and run time
            (1024px in -> 2048px out with the x2 model built in this file).

    Returns:
        The model output as a PIL image, or None when ``img`` is None.
    """
    if img is None:
        return None

    # Convert before resizing: Pillow always uses nearest-neighbour when
    # resizing palette ("P") or bilevel ("1") images, ignoring BICUBIC.
    if img.mode != 'RGB':
        img = img.convert('RGB')

    # Enforce constraints to prevent CPU OOM timeouts
    target_size = _fit_within(*img.size, max_input_dim)
    if target_size != img.size:
        img = img.resize(target_size, Image.BICUBIC)

    # Preprocess: float tensor in [0, 1] with a leading batch dimension
    input_tensor = TF.to_tensor(img).unsqueeze(0).to(device)

    # Forward pass without autograd bookkeeping
    with torch.no_grad():
        output_tensor = model(input_tensor)

    # Postprocess: drop the batch dimension and clamp to the valid image range
    output_tensor = output_tensor.squeeze(0).clamp(0, 1).cpu()
    return TF.to_pil_image(output_tensor)
# --- 4. GRADIO UI ---
# Markdown header rendered at the top of the page.
_INTRO_MARKDOWN = """
# ⚡ FastEDSR 2x Image Upscaler
Upload an image to enhance and upscale it by 2x.
*Note: To ensure stability on CPU infrastructure, input images larger than 1024px are proportionally downscaled before processing to guarantee a maximum 2K output.*
"""

with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown(_INTRO_MARKDOWN)

    with gr.Row():
        # First column: image upload and the trigger button.
        with gr.Column():
            input_image = gr.Image(type="pil", label="Low Resolution Input")
            upscale_btn = gr.Button("Upscale Image", variant="primary")

        # Second column: the upscaled result.
        with gr.Column():
            output_image = gr.Image(type="pil", label="2x High Resolution Output")

    # Clicking the button feeds the uploaded image through upscale_image.
    upscale_btn.click(fn=upscale_image, inputs=input_image, outputs=output_image)

# Start the web server only when executed as a script.
if __name__ == "__main__":
    app.launch()