import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from PIL import Image
import torchvision.transforms.functional as TF

# --- 1. MODEL ARCHITECTURE ---


class PureResBlock(nn.Module):
    """Residual block: conv -> ReLU -> conv, added back onto the input.

    Both convolutions keep the channel count and spatial size unchanged
    (3x3 kernel, padding 1, replicate padding), so the residual sum is
    shape-compatible with the input.

    Args:
        channels: Number of input (and output) feature channels.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, padding_mode='replicate')
        self.act = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, padding_mode='replicate')
        # Plain float attribute (not a Parameter/buffer), so it is not stored in
        # the state dict; 1.0 adds the residual branch unscaled.
        self.res_scale = 1.0

    def forward(self, x):
        res = self.conv1(x)
        res = self.act(res)
        res = self.conv2(res)
        return x + (res * self.res_scale)


class FastEDSR(nn.Module):
    """EDSR-style super-resolution network predicting detail on top of bicubic upsampling.

    Output is ``bicubic_upsample(x) + details``, where ``details`` is produced by a
    conv head, a body of ``num_blocks`` PureResBlocks plus a tail conv with a
    global skip connection, and a sub-pixel (PixelShuffle) upsampler.

    Args:
        scale_factor: Integer upscaling factor for height and width.
        num_blocks: Number of PureResBlocks in the body.
        channels: Feature-map width used throughout the body.
    """

    def __init__(self, scale_factor=2, num_blocks=8, channels=64):
        super().__init__()
        self.scale_factor = scale_factor
        self.head = nn.Conv2d(3, channels, kernel_size=3, padding=1, padding_mode='replicate')
        self.body = nn.Sequential(*[PureResBlock(channels) for _ in range(num_blocks)])
        self.tail = nn.Conv2d(channels, channels, kernel_size=3, padding=1, padding_mode='replicate')
        # Predict scale_factor**2 sub-pixels per RGB channel, then rearrange them
        # into a (scale_factor x) larger image.
        self.sub_pixel = nn.Sequential(
            nn.Conv2d(channels, 3 * (scale_factor ** 2), kernel_size=3, padding=1, padding_mode='replicate'),
            nn.PixelShuffle(scale_factor),
        )

    def forward(self, x):
        """Upscale a batch of images.

        ``x`` must be a 4-D (N, 3, H, W) float tensor: bicubic interpolation
        requires 4-D input and ``head`` expects 3 channels. Values are presumably
        in [0, 1] (as produced by ``TF.to_tensor``); the output is not clamped here.
        """
        base_upscaled = F.interpolate(x, scale_factor=self.scale_factor, mode='bicubic', align_corners=False)
        f0 = self.head(x)
        f_body = self.body(f0)
        f_body = self.tail(f_body)
        # Global residual connection around the body.
        f_out = f0 + f_body
        details = self.sub_pixel(f_out)
        return base_upscaled + details


# --- 2. INITIALIZATION ---
device = torch.device('cpu')  # Hugging Face Free Tier runs on CPU

model = FastEDSR(scale_factor=2, num_blocks=8, channels=64)

# Load the weights (Update this string if your file is named differently in the HF root)
model_path = "FastEDSR_x2_31dB.pth"
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot run arbitrary code at load time (torch >= 1.13).
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
# Keep the model on the same device the inputs are sent to, so changing
# `device` above does not cause a device-mismatch error at inference.
model.to(device)
model.eval()
# --- 3. INFERENCE FUNCTION ---
def upscale_image(img):
    """Upscale a PIL image 2x with the loaded FastEDSR model.

    Inputs whose longest side exceeds 1024 px are first downscaled
    proportionally (bicubic), which caps the output at 2048 px.

    Args:
        img: Input PIL image, or None when the Gradio input is empty.

    Returns:
        The upscaled RGB PIL image, or None if ``img`` is None. Any alpha
        channel is discarded by the RGB conversion.
    """
    if img is None:
        return None

    # Convert before resizing: Pillow ignores the requested filter for "P"/"1"
    # mode images and always resizes them with NEAREST. The model needs 3 channels anyway.
    img = img.convert('RGB')

    # Enforce constraints to prevent CPU OOM timeouts
    # Max input 1024px -> Max output 2048px (2K)
    max_input_dim = 1024
    w, h = img.size
    longest = max(w, h)
    if longest > max_input_dim:
        scale = max_input_dim / longest
        # round() avoids float truncation dropping a pixel (e.g. 1023 instead of
        # 1024). max(1, ...) stops very thin images collapsing to 0 px, which
        # PIL's resize rejects.
        new_w = max(1, round(w * scale))
        new_h = max(1, round(h * scale))
        img = img.resize((new_w, new_h), Image.BICUBIC)

    # Preprocess: HWC uint8 image -> (1, 3, H, W) float tensor in [0, 1]
    input_tensor = TF.to_tensor(img).unsqueeze(0).to(device)

    # Forward Pass
    with torch.no_grad():
        output_tensor = model(input_tensor)

    # Postprocess: clamp, then quantize with rounding. TF.to_pil_image on a float
    # tensor truncates via mul(255).byte(), biasing every pixel down by up to one level.
    output_tensor = output_tensor.squeeze(0).clamp(0, 1)
    output_uint8 = output_tensor.mul(255).round().to(torch.uint8).cpu()
    return TF.to_pil_image(output_uint8)


# --- 4. GRADIO UI ---
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown(
        """
        # ⚡ FastEDSR 2x Image Upscaler
        Upload an image to enhance and upscale it by 2x.

        *Note: To ensure stability on CPU infrastructure, input images larger than 1024px are proportionally downscaled before processing to guarantee a maximum 2K output.*
        """
    )

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Low Resolution Input")
            upscale_btn = gr.Button("Upscale Image", variant="primary")

        with gr.Column():
            output_image = gr.Image(type="pil", label="2x High Resolution Output")

    upscale_btn.click(fn=upscale_image, inputs=input_image, outputs=output_image)

if __name__ == "__main__":
    app.launch()