File size: 3,748 Bytes
1545518
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image, ImageFilter
import os
import time
import gc

from RRDBNet_arch import RRDBNet  # Ensure this file is in the same directory

# -------------------------
# Load ESRGAN Model (from root directory)
# -------------------------
@torch.no_grad()
def load_model(model_path="RRDB_ESRGAN_x4.pth"):
    """Build the RRDBNet generator and load ESRGAN weights into it.

    Args:
        model_path: Path to the state-dict checkpoint. Relative paths are
            resolved against the current working directory. Defaults to
            ``"RRDB_ESRGAN_x4.pth"``, the checkpoint name this app uses.

    Returns:
        The ``RRDBNet`` model with weights loaded onto the CPU, in eval mode.
    """
    # 3-channel in/out, 64 features, 23 RRDB blocks. ``strict=True`` below
    # means the checkpoint must match this architecture exactly.
    model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23)
    # NOTE(review): torch.load unpickles the file; only point this at
    # trusted checkpoints.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    return model


# Loaded once at import time and shared by every request.
model = load_model()

# -------------------------
# Utility Functions
# -------------------------
def preprocess(img_pil) -> torch.Tensor:
    """Convert a PIL image into a batched float tensor in [0, 1].

    For an RGB image the result has shape ``(1, 3, H, W)``.

    No mean/std ``Normalize`` is applied on purpose. The reference ESRGAN
    inference script feeds the pretrained x4 generator RGB tensors scaled to
    [0, 1]. Mapping pixels to [-1, 1] pushes the input outside the range the
    weights were trained on. This assumes ``RRDB_ESRGAN_x4.pth`` is the
    official xinntao/ESRGAN checkpoint; TODO confirm if it was retrained.
    """
    # ToTensor: HWC uint8 in [0, 255] -> CHW float32 in [0, 1].
    return transforms.ToTensor()(img_pil).unsqueeze(0)


def postprocess(tensor: torch.Tensor):
    """Convert a batched model output back into a PIL image.

    Values are clamped to [0, 1], the same range ``preprocess`` produces,
    because the raw network output may fall slightly outside it.
    """
    # squeeze(0) drops only the batch axis. A bare squeeze() would also
    # collapse any size-1 channel/height/width axis.
    image = tensor.detach().squeeze(0).cpu().clamp(0, 1)
    return transforms.ToPILImage()(image)

def fuse_images(img1, img2, size=(384, 384), alpha=0.5):
    """Resize two images to a common size and blend them into one.

    Args:
        img1, img2: PIL images. ``Image.blend`` requires both to have the
            same mode; ``esrgan_pipeline`` converts both to RGB first.
        size: ``(width, height)`` both inputs are resized to with Lanczos
            resampling. Aspect ratio is not preserved. Defaults to 384x384.
        alpha: Weight of ``img2`` in the blend
            (``img1 * (1 - alpha) + img2 * alpha``); 0.5 is an even mix.

    Returns:
        The blended PIL image, of dimensions ``size``.
    """
    first = img1.resize(size, Image.LANCZOS)
    second = img2.resize(size, Image.LANCZOS)
    return Image.blend(first, second, alpha=alpha)

def sharpen_image(image: Image.Image) -> Image.Image:
    """Return a sharpened copy of *image* using an unsharp mask.

    Settings: blur radius 1.5 px, 150 % strength, and threshold 1
    (the minimum brightness change that gets sharpened).
    """
    unsharp = ImageFilter.UnsharpMask(radius=1.5, percent=150, threshold=1)
    sharpened = image.filter(unsharp)
    return sharpened

def upscale_to_8k(img: Image.Image) -> Image.Image:
    """Resize *img* to exactly 8000 x 8000 pixels with Lanczos resampling.

    The target is a fixed square, so non-square inputs are stretched.
    """
    side = 8000
    target_size = (side, side)
    return img.resize(target_size, resample=Image.LANCZOS)

# -------------------------
# Inference Pipeline
# -------------------------
def _intensity_variance(img):
    """Return the variance of *img*'s grayscale pixel values on a 0-255 scale."""
    # torch.tensor() cannot build a tensor from a PIL image ("Could not infer
    # dtype"), and torch.var() rejects integer tensors. So go through
    # ToTensor (float in [0, 1]) and rescale to the 0-255 range.
    # NOTE: this is a global contrast measure, not a true sharpness metric
    # such as the variance of the Laplacian.
    gray = transforms.ToTensor()(img.convert("L")) * 255.0
    return torch.var(gray).item()


def esrgan_pipeline(img1, img2, _):
    """Fuse two images, super-resolve with ESRGAN, then upscale and sharpen.

    Args:
        img1, img2: PIL images from the Gradio inputs, or ``None`` when a
            slot was left empty.
        _: Value of the fixed output-resolution radio; ignored.

    Returns:
        ``(base_output, final_img, msg)``: the ESRGAN output image, the
        8000x8000 sharpened image, and a status string. If either input is
        missing, both images are ``None`` and ``msg`` says so.
    """
    if img1 is None or img2 is None:
        return None, None, "Please upload two valid images."

    fused_img = fuse_images(img1.convert("RGB"), img2.convert("RGB"))

    start = time.perf_counter()

    with torch.no_grad():
        input_tensor = preprocess(fused_img)
        sr_output = model(input_tensor)

    base_output = postprocess(sr_output)

    # Release the tensors before asking the allocator to hand memory back.
    del input_tensor, sr_output
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    upscaled_img = upscale_to_8k(base_output)
    final_img = sharpen_image(upscaled_img)

    elapsed = time.perf_counter() - start
    sharpness_score = _intensity_variance(base_output)
    msg = f"✅ Done in {elapsed:.2f}s | Sharpness: {sharpness_score:.2f}"

    return base_output, final_img, msg

# -------------------------
# Gradio UI
# -------------------------
with gr.Blocks(title="8000x8000 ESRGAN Ultra-HD Super-Resolution") as demo:
    gr.Markdown("## 🧠 ESRGAN Ultra-HD Image Upscaler (8000 × 8000 Output)")
    gr.Markdown("Upload **two low-res images** → Fuse → ESRGAN → Final **8000 × 8000** enhanced image with sharpening.")

    with gr.Row():
        # Left column: inputs and the run trigger.
        with gr.Column():
            img_input1 = gr.Image(type="pil", label="Low-Res Image 1")
            img_input2 = gr.Image(type="pil", label="Low-Res Image 2")
            # Display-only: esrgan_pipeline receives this value but ignores it.
            dummy_resolution = gr.Radio(["8000x8000"], value="8000x8000", label="Output Resolution (Fixed)")

            run_button = gr.Button("🚀 Run ESRGAN")

        # Right column: the two result images and the status log.
        with gr.Column():
            output_esrgan = gr.Image(label="🧠 ESRGAN Output")
            output_final = gr.Image(label="🏞️ Final Enhanced Output (8000 × 8000)")
            result_text = gr.Textbox(label="📊 Output Log")

    gr.Markdown("---")
    gr.HTML(
        "<div style='text-align: center; font-size: 16px;'>"
        "Made with ❤️ by <b>CodeKarma</b> as a part of <b>Bharatiya Antariksh Hackathon 2025</b>"
        "</div>"
    )

    # The output order matches esrgan_pipeline's (base, final, message) tuple.
    run_button.click(fn=esrgan_pipeline,
                     inputs=[img_input1, img_input2, dummy_resolution],
                     outputs=[output_esrgan, output_final, result_text])

# -------------------------
# Launch
# -------------------------
if __name__ == "__main__":
    demo.launch()