import gc
import os
import time

import gradio as gr
import torch
from PIL import Image, ImageFilter
from torchvision import transforms

from RRDBNet_arch import RRDBNet  # Ensure this file is in the same directory


# -------------------------
# Load ESRGAN Model (from root directory)
# -------------------------
def load_model(model_path="RRDB_ESRGAN_x4.pth"):
    """Build an RRDBNet generator and load its weights from *model_path*.

    The network is constructed with ``in_nc=3, out_nc=3, nf=64, nb=23`` and
    the state dict is loaded onto the CPU with ``strict=True``, so every key
    in the checkpoint must match the architecture exactly.

    Args:
        model_path: Path to the ``.pth`` state-dict file. Defaults to
            ``"RRDB_ESRGAN_x4.pth"`` relative to the working directory.

    Returns:
        The RRDBNet model, switched to eval mode.
    """
    model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23)
    # NOTE(review): torch.load unpickles the file; only load checkpoints you
    # trust (recent torch versions accept ``weights_only=True`` for this).
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    return model


model = load_model()


# -------------------------
# Utility Functions
# -------------------------
# Built once rather than per call: ToTensor scales pixels to [0, 1], then
# Normalize((0.5,), (0.5,)) maps them to [-1, 1].
_PREPROCESS = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])


def preprocess(img_pil):
    """Convert a PIL image into a ``(1, C, H, W)`` float tensor in [-1, 1]."""
    # NOTE(review): the reference ESRGAN test script feeds RRDB_ESRGAN_x4 with
    # RGB values in [0, 1] and clamps its output to [0, 1]. If this is that
    # checkpoint, the [-1, 1] mapping here (and its inverse in postprocess)
    # may not match what the network was trained on — confirm.
    return _PREPROCESS(img_pil).unsqueeze(0)


def postprocess(tensor):
    """Convert a model output tensor back into a PIL image.

    Singleton dimensions are squeezed away, the [-1, 1] mapping applied by
    ``preprocess`` is undone, and values are clamped to [0, 1] before the
    conversion to PIL.
    """
    tensor = tensor.squeeze().detach().cpu()
    tensor = (tensor * 0.5 + 0.5).clamp(0, 1)
    return transforms.ToPILImage()(tensor)


def fuse_images(img1, img2, size=(384, 384)):
    """Resize both images to *size* (Lanczos) and blend them 50/50.

    ``Image.blend`` needs both images in the same mode; ``esrgan_pipeline``
    converts its inputs to RGB before calling this.
    """
    img1 = img1.resize(size, Image.LANCZOS)
    img2 = img2.resize(size, Image.LANCZOS)
    return Image.blend(img1, img2, alpha=0.5)


def sharpen_image(image: Image.Image) -> Image.Image:
    """Apply an unsharp mask (radius 1.5, 150 %, threshold 1) to *image*."""
    return image.filter(ImageFilter.UnsharpMask(radius=1.5, percent=150, threshold=1))


def upscale_to_8k(img: Image.Image, size=(8000, 8000)) -> Image.Image:
    """Resize *img* to *size* (default 8000x8000) with Lanczos resampling.

    The aspect ratio is not preserved; the image is stretched to *size*.
    """
    return img.resize(size, Image.LANCZOS)


def _sharpness_score(img):
    """Return the variance of *img*'s grayscale intensities on a 0-255 scale."""
    # torch.tensor() cannot infer a dtype from a PIL image, and torch.var only
    # accepts floating/complex tensors, so convert through ToTensor (float in
    # [0, 1]) and rescale to the 0-255 range.
    gray = transforms.ToTensor()(img.convert("L")) * 255.0
    return torch.var(gray).item()


# -------------------------
# Inference Pipeline
# -------------------------
def esrgan_pipeline(img1, img2, _):
    """Fuse two images, super-resolve the blend, then upscale and sharpen it.

    Args:
        img1: First input PIL image (or None when missing).
        img2: Second input PIL image (or None when missing).
        _: Unused third input — presumably the fixed output-resolution
            control; the UI wiring is not visible here.

    Returns:
        ``(base_output, final_img, msg)``: the raw model output as a PIL
        image, that output resized to 8000x8000 and sharpened, and a status
        string with the elapsed time and a sharpness score. If either input
        is missing, returns ``(None, None, <error message>)``.
    """
    if img1 is None or img2 is None:
        return None, None, "Please upload two valid images."

    # Image.blend requires matching modes, so normalize both inputs to RGB.
    img1 = img1.convert("RGB")
    img2 = img2.convert("RGB")
    fused_img = fuse_images(img1, img2)

    start = time.perf_counter()
    with torch.no_grad():
        input_tensor = preprocess(fused_img)
        sr_output = model(input_tensor)
        base_output = postprocess(sr_output)
    # Drop the tensor references first so collection can actually free them.
    del input_tensor, sr_output
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    upscaled_img = upscale_to_8k(base_output)
    final_img = sharpen_image(upscaled_img)
    elapsed = time.perf_counter() - start

    sharpness_score = _sharpness_score(base_output)
    msg = f"✅ Done in {elapsed:.2f}s | Sharpness: {sharpness_score:.2f}"
    return base_output, final_img, msg


# ------------------------- # Gradio UI # ------------------------- with gr.Blocks(title="8000x8000 ESRGAN Ultra-HD Super-Resolution") as demo: gr.Markdown("## 🧠 ESRGAN Ultra-HD Image Upscaler (8000 × 8000 Output)") gr.Markdown("Upload **two low-res images** → Fuse → ESRGAN → Final **8000 × 8000** enhanced image with sharpening.") with gr.Row(): with gr.Column(): img_input1 = gr.Image(type="pil", label="Low-Res Image 1") img_input2 = gr.Image(type="pil", label="Low-Res Image 2") dummy_resolution = gr.Radio(["8000x8000"], value="8000x8000", label="Output Resolution (Fixed)") run_button = gr.Button("🚀 Run ESRGAN") with gr.Column(): output_esrgan = gr.Image(label="🧠 ESRGAN Output") output_final = gr.Image(label="🏞️ Final Enhanced Output (8000 × 8000)") result_text = gr.Textbox(label="📊 Output Log") gr.Markdown("---") gr.HTML( "