Sourishdey05 commited on
Commit
1545518
·
verified ·
1 Parent(s): a25dbda

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -0
app.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from torchvision import transforms
4
+ from PIL import Image, ImageFilter
5
+ import os
6
+ import time
7
+ import gc
8
+
9
+ from RRDBNet_arch import RRDBNet # Ensure this file is in the same directory
10
+
11
+ # -------------------------
12
+ # Load ESRGAN Model (from root directory)
13
+ # -------------------------
14
+ @torch.no_grad()
15
+ def load_model():
16
+ model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23)
17
+ model_path = "RRDB_ESRGAN_x4.pth"
18
+ model.load_state_dict(torch.load(model_path, map_location="cpu"), strict=True)
19
+ model.eval()
20
+ return model
21
+
22
+ model = load_model()
23
+
24
+ # -------------------------
25
+ # Utility Functions
26
+ # -------------------------
27
+ def preprocess(img_pil):
28
+ transform = transforms.Compose([
29
+ transforms.ToTensor(),
30
+ transforms.Normalize((0.5,), (0.5,))
31
+ ])
32
+ return transform(img_pil).unsqueeze(0)
33
+
34
+ def postprocess(tensor):
35
+ tensor = tensor.squeeze().detach().cpu()
36
+ tensor = (tensor * 0.5 + 0.5).clamp(0, 1)
37
+ return transforms.ToPILImage()(tensor)
38
+
39
+ def fuse_images(img1, img2):
40
+ img1 = img1.resize((384, 384), Image.LANCZOS)
41
+ img2 = img2.resize((384, 384), Image.LANCZOS)
42
+ return Image.blend(img1, img2, alpha=0.5)
43
+
44
+ def sharpen_image(image: Image.Image) -> Image.Image:
45
+ return image.filter(ImageFilter.UnsharpMask(radius=1.5, percent=150, threshold=1))
46
+
47
+ def upscale_to_8k(img: Image.Image) -> Image.Image:
48
+ return img.resize((8000, 8000), Image.LANCZOS)
49
+
50
+ # -------------------------
51
+ # Inference Pipeline
52
+ # -------------------------
53
+ def esrgan_pipeline(img1, img2, _):
54
+ if not img1 or not img2:
55
+ return None, None, "Please upload two valid images."
56
+
57
+ img1 = img1.convert("RGB")
58
+ img2 = img2.convert("RGB")
59
+ fused_img = fuse_images(img1, img2)
60
+
61
+ start = time.time()
62
+
63
+ with torch.no_grad():
64
+ input_tensor = preprocess(fused_img)
65
+ sr_output = model(input_tensor)
66
+
67
+ base_output = postprocess(sr_output)
68
+
69
+ gc.collect()
70
+ torch.cuda.empty_cache()
71
+
72
+ upscaled_img = upscale_to_8k(base_output)
73
+ final_img = sharpen_image(upscaled_img)
74
+
75
+ elapsed = time.time() - start
76
+ sharpness_score = torch.var(torch.tensor(base_output.convert("L"))).item()
77
+ msg = f"✅ Done in {elapsed:.2f}s | Sharpness: {sharpness_score:.2f}"
78
+
79
+ return base_output, final_img, msg
80
+
81
+ # -------------------------
82
+ # Gradio UI
83
+ # -------------------------
84
+ with gr.Blocks(title="8000x8000 ESRGAN Ultra-HD Super-Resolution") as demo:
85
+ gr.Markdown("## 🧠 ESRGAN Ultra-HD Image Upscaler (8000 × 8000 Output)")
86
+ gr.Markdown("Upload **two low-res images** → Fuse → ESRGAN → Final **8000 × 8000** enhanced image with sharpening.")
87
+
88
+ with gr.Row():
89
+ with gr.Column():
90
+ img_input1 = gr.Image(type="pil", label="Low-Res Image 1")
91
+ img_input2 = gr.Image(type="pil", label="Low-Res Image 2")
92
+ dummy_resolution = gr.Radio(["8000x8000"], value="8000x8000", label="Output Resolution (Fixed)")
93
+
94
+ run_button = gr.Button("🚀 Run ESRGAN")
95
+
96
+ with gr.Column():
97
+ output_esrgan = gr.Image(label="🧠 ESRGAN Output")
98
+ output_final = gr.Image(label="🏞️ Final Enhanced Output (8000 × 8000)")
99
+ result_text = gr.Textbox(label="📊 Output Log")
100
+
101
+ gr.Markdown("---")
102
+ gr.HTML(
103
+ "<div style='text-align: center; font-size: 16px;'>"
104
+ "Made with ❤️ by <b>CodeKarma</b> as a part of <b>Bharatiya Antariksh Hackathon 2025</b>"
105
+ "</div>"
106
+ )
107
+
108
+ run_button.click(fn=esrgan_pipeline,
109
+ inputs=[img_input1, img_input2, dummy_resolution],
110
+ outputs=[output_esrgan, output_final, result_text])
111
+
112
+ # -------------------------
113
+ # Launch
114
+ # -------------------------
115
+ if __name__ == "__main__":
116
+ demo.launch()