# Spaces status at capture time: Sleeping
| import gradio as gr | |
| import torch | |
| from torchvision import transforms | |
| from PIL import Image, ImageFilter | |
| import os | |
| import time | |
| import gc | |
| from RRDBNet_arch import RRDBNet # Ensure this file is in the same directory | |
| # ------------------------- | |
| # Load ESRGAN Model (from root directory) | |
| # ------------------------- | |
def load_model():
    """Build the RRDBNet generator and load the ESRGAN x4 checkpoint on CPU.

    Returns:
        The network, switched to evaluation mode, with the weights from
        ``RRDB_ESRGAN_x4.pth`` loaded.
    """
    network = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23)
    # The checkpoint path is relative, so it is resolved against the
    # current working directory.
    # NOTE(review): torch.load unpickles the file. If the checkpoint can come
    # from an untrusted source, consider weights_only=True (supported by
    # recent PyTorch releases) - confirm the installed version first.
    weights = torch.load("RRDB_ESRGAN_x4.pth", map_location="cpu")
    # strict=True: every key in the checkpoint must match the architecture.
    network.load_state_dict(weights, strict=True)
    network.eval()
    return network
# Load the network once at import time; esrgan_pipeline reads this
# module-level instance on every request.
model = load_model()
| # ------------------------- | |
| # Utility Functions | |
| # ------------------------- | |
def preprocess(img_pil):
    """Turn a PIL image into a normalised, batched tensor for the model.

    Args:
        img_pil: Input image (the pipeline passes an RGB PIL image).

    Returns:
        The image as a tensor with a leading batch dimension of size 1,
        after ``(x - 0.5) / 0.5`` normalisation of every channel.
    """
    as_tensor = transforms.ToTensor()(img_pil)
    # The single-element mean/std is broadcast over all channels.
    # NOTE(review): confirm the checkpoint expects inputs centred like this
    # rather than plain [0, 1] values; postprocess() undoes the same scaling.
    normalised = transforms.Normalize((0.5,), (0.5,))(as_tensor)
    return normalised.unsqueeze(0)
def postprocess(tensor):
    """Convert a model output tensor back into a PIL image.

    Args:
        tensor: Model output, expected to carry a leading batch dimension
            of size 1.

    Returns:
        A PIL image built from the de-normalised tensor, with values
        clamped to [0, 1].
    """
    # Drop only the batch axis. A bare squeeze() would also remove a
    # height/width (or channel) axis of size 1 and scramble the layout.
    tensor = tensor.squeeze(0).detach().cpu()
    # Invert the (x - 0.5) / 0.5 normalisation applied in preprocess().
    tensor = (tensor * 0.5 + 0.5).clamp(0, 1)
    return transforms.ToPILImage()(tensor)
def fuse_images(img1, img2):
    """Resize both images to 384x384 and average them pixel-wise.

    Args:
        img1: First PIL image.
        img2: Second PIL image.

    Returns:
        A 384x384 image that is an equal-weight blend of the two inputs.
    """
    target_size = (384, 384)
    first, second = (
        picture.resize(target_size, Image.LANCZOS) for picture in (img1, img2)
    )
    return Image.blend(first, second, alpha=0.5)
def sharpen_image(image: Image.Image) -> Image.Image:
    """Return a copy of *image* sharpened with an unsharp mask.

    The mask uses radius 1.5, 150 percent strength and a threshold of 1.
    """
    unsharp_mask = ImageFilter.UnsharpMask(radius=1.5, percent=150, threshold=1)
    return image.filter(unsharp_mask)
def upscale_to_8k(img: Image.Image) -> Image.Image:
    """Resample *img* to a fixed 8000x8000 canvas with the Lanczos filter.

    The target size is fixed, so a non-square input is not kept in
    proportion.
    """
    side = 8000
    return img.resize((side, side), resample=Image.LANCZOS)
| # ------------------------- | |
| # Inference Pipeline | |
| # ------------------------- | |
def esrgan_pipeline(img1, img2, _):
    """Fuse two images, super-resolve them and build the 8000x8000 result.

    Args:
        img1: First input as a PIL image, or None when nothing was uploaded.
        img2: Second input as a PIL image, or None when nothing was uploaded.
        _: Unused. The UI passes the value of its fixed-resolution radio
            control here.

    Returns:
        A tuple ``(esrgan_output, final_image, message)``. When either input
        is missing, the two images are None and the message asks for two
        valid images.
    """
    if img1 is None or img2 is None:
        return None, None, "Please upload two valid images."

    fused_img = fuse_images(img1.convert("RGB"), img2.convert("RGB"))

    start = time.time()
    with torch.no_grad():
        input_tensor = preprocess(fused_img)
        sr_output = model(input_tensor)
        base_output = postprocess(sr_output)

    # Free Python garbage and any cached CUDA blocks before the large
    # 8000x8000 resize allocates its buffers.
    gc.collect()
    torch.cuda.empty_cache()

    upscaled_img = upscale_to_8k(base_output)
    final_img = sharpen_image(upscaled_img)
    elapsed = time.time() - start

    # Sharpness is the variance of the greyscale intensities. torch.var()
    # needs a floating-point tensor, and torch.tensor() is not documented to
    # accept a PIL image, so convert with ToTensor (float values in [0, 1])
    # and rescale to the 0-255 intensity range.
    gray = transforms.ToTensor()(base_output.convert("L")) * 255.0
    sharpness_score = torch.var(gray).item()

    # NOTE(review): the leading character of the message looks like a
    # mis-encoded symbol; it is kept as found - confirm the intended glyph.
    msg = f"β Done in {elapsed:.2f}s | Sharpness: {sharpness_score:.2f}"
    return base_output, final_img, msg
| # ------------------------- | |
| # Gradio UI | |
| # ------------------------- | |
# Layout: inputs and the run button in the left column, results in the right
# column, then a footer. The run button is wired to esrgan_pipeline.
# NOTE(review): nesting was reconstructed from a flattened copy of the file;
# confirm it matches the intended layout.
with gr.Blocks(title="8000x8000 ESRGAN Ultra-HD Super-Resolution") as demo:
    gr.Markdown(value="## π§ ESRGAN Ultra-HD Image Upscaler (8000 Γ 8000 Output)")
    gr.Markdown(
        value="Upload **two low-res images** β Fuse β ESRGAN β Final **8000 Γ 8000** enhanced image with sharpening."
    )

    with gr.Row():
        # Left column: the two uploads, the fixed resolution choice and the
        # trigger button.
        with gr.Column():
            img_input1 = gr.Image(type="pil", label="Low-Res Image 1")
            img_input2 = gr.Image(type="pil", label="Low-Res Image 2")
            dummy_resolution = gr.Radio(
                choices=["8000x8000"],
                value="8000x8000",
                label="Output Resolution (Fixed)",
            )
            run_button = gr.Button(value="π Run ESRGAN")

        # Right column: raw ESRGAN output, final enhanced image and the log.
        with gr.Column():
            output_esrgan = gr.Image(label="π§ ESRGAN Output")
            output_final = gr.Image(label="ποΈ Final Enhanced Output (8000 Γ 8000)")
            result_text = gr.Textbox(label="π Output Log")

    gr.Markdown(value="---")
    gr.HTML(
        value=(
            "<div style='text-align: center; font-size: 16px;'>"
            "Made with β€οΈ by <b>CodeKarma</b> as a part of <b>Bharatiya Antariksh Hackathon 2025</b>"
            "</div>"
        )
    )

    # The radio value is passed as the pipeline's third, unused argument.
    run_button.click(
        esrgan_pipeline,
        [img_input1, img_input2, dummy_resolution],
        [output_esrgan, output_final, result_text],
    )
| # ------------------------- | |
| # Launch | |
| # ------------------------- | |
# Start the Gradio app only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    demo.launch()