Sourishdey05 committed on
Commit
78cfd22
Β·
verified Β·
1 Parent(s): 6d731d9

Upload 3 files

Browse files
Files changed (3) hide show
  1. RRDBNet_arch.py +78 -0
  2. app.py +131 -131
  3. requirements.txt +6 -0
RRDBNet_arch.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+
7
def make_layer(block, n_layers):
    """Return an nn.Sequential of `n_layers` modules produced by the `block` factory."""
    return nn.Sequential(*(block() for _ in range(n_layers)))
12
+
13
+
14
class ResidualDenseBlock_5C(nn.Module):
    """Five-convolution residual dense block (ESRGAN building block).

    Each of the first four 3x3 convs sees the concatenation of the block
    input and every previous conv output (dense connectivity); the fifth
    conv maps back to `nf` channels and is added to the input with a 0.2
    residual scaling.

    Args:
        nf: number of input/output feature channels.
        gc: growth channels, i.e. width of the intermediate convs.
        bias: whether the convolutions carry a bias term.
    """

    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock_5C, self).__init__()
        # Attribute names conv1..conv5/lrelu are fixed: pretrained state
        # dicts are loaded with strict=True elsewhere.
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        # Dense connectivity: every conv consumes all features produced so far.
        feats = [x]
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            feats.append(self.lrelu(conv(torch.cat(feats, 1))))
        residual = self.conv5(torch.cat(feats, 1))
        return residual * 0.2 + x
35
+
36
+
37
class RRDB(nn.Module):
    """Residual in Residual Dense Block.

    Chains three dense blocks and wraps them in an outer skip connection
    scaled by 0.2, following the ESRGAN design.
    """

    def __init__(self, nf, gc=32):
        super(RRDB, self).__init__()
        # Names RDB1..RDB3 must match pretrained checkpoints (strict load).
        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
        self.RDB3 = ResidualDenseBlock_5C(nf, gc)

    def forward(self, x):
        out = x
        for dense_block in (self.RDB1, self.RDB2, self.RDB3):
            out = dense_block(out)
        return out * 0.2 + x
51
+
52
+
53
class RRDBNet(nn.Module):
    """ESRGAN generator: shallow conv, RRDB trunk with a long skip, two
    nearest-neighbour 2x upsampling stages (4x total), and a reconstruction head.

    Args:
        in_nc: input image channels.
        out_nc: output image channels.
        nf: feature width of the trunk.
        nb: number of RRDB blocks in the trunk.
        gc: growth channels inside each dense block.
    """

    def __init__(self, in_nc, out_nc, nf, nb, gc=32):
        super(RRDBNet, self).__init__()

        def conv3x3(c_in, c_out):
            # Every conv in the network is 3x3, stride 1, padding 1, with bias.
            return nn.Conv2d(c_in, c_out, 3, 1, 1, bias=True)

        # Attribute names are fixed by pretrained checkpoints (strict load).
        self.conv_first = conv3x3(in_nc, nf)
        self.RRDB_trunk = make_layer(functools.partial(RRDB, nf=nf, gc=gc), nb)
        self.trunk_conv = conv3x3(nf, nf)
        #### upsampling (each conv is paired with a 2x F.interpolate in forward)
        self.upconv1 = conv3x3(nf, nf)
        self.upconv2 = conv3x3(nf, nf)
        self.HRconv = conv3x3(nf, nf)
        self.conv_last = conv3x3(nf, out_nc)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        shallow = self.conv_first(x)
        # Long skip: trunk output is added back onto the shallow features.
        feat = shallow + self.trunk_conv(self.RRDB_trunk(shallow))

        for upconv in (self.upconv1, self.upconv2):
            feat = self.lrelu(upconv(F.interpolate(feat, scale_factor=2, mode='nearest')))
        return self.conv_last(self.lrelu(self.HRconv(feat)))
app.py CHANGED
@@ -1,131 +1,131 @@
1
- import gradio as gr
2
- import torch
3
- from torchvision import transforms
4
- from PIL import Image, ImageFilter
5
- import os
6
- import time
7
- import gc
8
- import gdown
9
-
10
- from RRDBNet_arch import RRDBNet
11
-
12
- # -------------------------
13
- # Download from Google Drive if not present
14
- # -------------------------
15
- def ensure_model_downloaded():
16
- model_path = "models/RRDB_ESRGAN_x4.pth"
17
- if not os.path.exists(model_path):
18
- os.makedirs("models", exist_ok=True)
19
- file_id = "1P3Hbr51ZNsbNJIiWxrsHgl-D3I9n5ItN"
20
- gdown.download(f"https://drive.google.com/uc?id={file_id}", model_path, quiet=False)
21
-
22
- # -------------------------
23
- # Load ESRGAN Model
24
- # -------------------------
25
- @torch.no_grad()
26
- def load_model():
27
- ensure_model_downloaded()
28
- model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23)
29
- model_path = os.path.join("models", "RRDB_ESRGAN_x4.pth")
30
- model.load_state_dict(torch.load(model_path, map_location="cpu"), strict=True)
31
- model.eval()
32
- return model
33
-
34
- model = load_model()
35
-
36
- # -------------------------
37
- # Utility Functions
38
- # -------------------------
39
- def preprocess(img_pil):
40
- transform = transforms.Compose([
41
- transforms.ToTensor(),
42
- transforms.Normalize((0.5,), (0.5,))
43
- ])
44
- return transform(img_pil).unsqueeze(0)
45
-
46
- def postprocess(tensor):
47
- tensor = tensor.squeeze().detach().cpu()
48
- tensor = (tensor * 0.5 + 0.5).clamp(0, 1)
49
- return transforms.ToPILImage()(tensor)
50
-
51
- def fuse_images(img1, img2):
52
- img1 = img1.resize((384, 384), Image.LANCZOS)
53
- img2 = img2.resize((384, 384), Image.LANCZOS)
54
- return Image.blend(img1, img2, alpha=0.5)
55
-
56
- def sharpen_image(image: Image.Image) -> Image.Image:
57
- return image.filter(ImageFilter.UnsharpMask(radius=1.5, percent=150, threshold=1))
58
-
59
- def upscale_to_resolution(img: Image.Image, resolution: str = "4K") -> Image.Image:
60
- target_size = (3840, 2160) if resolution == "4K" else (7680, 4320)
61
- return img.resize(target_size, Image.LANCZOS)
62
-
63
- # -------------------------
64
- # Inference Pipeline
65
- # -------------------------
66
- def esrgan_pipeline(img1, img2, resolution):
67
- if not img1 or not img2:
68
- return None, None, "Please upload two valid images."
69
-
70
- img1 = img1.convert("RGB")
71
- img2 = img2.convert("RGB")
72
- fused_img = fuse_images(img1, img2)
73
-
74
- start = time.time()
75
-
76
- with torch.no_grad():
77
- input_tensor = preprocess(fused_img)
78
- sr1 = model(input_tensor)
79
- sr2 = model(sr1)
80
- sr3 = model(sr2)
81
-
82
- base_output = postprocess(sr3)
83
-
84
- gc.collect()
85
- torch.cuda.empty_cache()
86
-
87
- upscaled_img = upscale_to_resolution(base_output, resolution)
88
- final_img = sharpen_image(upscaled_img)
89
-
90
- elapsed = time.time() - start
91
- sharpness_score = torch.var(torch.tensor(base_output.convert("L"))).item()
92
- msg = f"βœ… Done in {elapsed:.2f}s | Sharpness: {sharpness_score:.2f}"
93
-
94
- return base_output, final_img, msg
95
-
96
- # -------------------------
97
- # Gradio UI
98
- # -------------------------
99
- with gr.Blocks(title="Triple-Pass ESRGAN Super-Resolution") as demo:
100
- gr.Markdown("## 🧠 Triple-Pass ESRGAN Ultra-HD Upscaler")
101
- gr.Markdown("Upload **two low-res images** β†’ ESRGAN (3 passes) β†’ Final **4K/8K** enhanced image with sharpening.")
102
-
103
- with gr.Row():
104
- with gr.Column():
105
- img_input1 = gr.Image(type="pil", label="Low-Res Image 1")
106
- img_input2 = gr.Image(type="pil", label="Low-Res Image 2")
107
- resolution_choice = gr.Radio(["4K", "8K"], value="4K", label="Select Output Resolution")
108
- run_button = gr.Button("πŸš€ Run ESRGAN")
109
-
110
- with gr.Column():
111
- output_esrgan = gr.Image(label="🧠 ESRGAN 3x Output")
112
- output_final = gr.Image(label="🏞️ Final Enhanced Output")
113
- result_text = gr.Textbox(label="πŸ“Š Output Log")
114
-
115
- gr.Markdown("---")
116
- gr.Markdown(
117
- "<div style='text-align: center; font-size: 16px;'>"
118
- "Made with ❀️ by <b>CodeKarma</b> as a part of <b>Bharatiya Antariksh Hackathon 2025</b>"
119
- "</div>",
120
- unsafe_allow_html=True
121
- )
122
-
123
- run_button.click(fn=esrgan_pipeline,
124
- inputs=[img_input1, img_input2, resolution_choice],
125
- outputs=[output_esrgan, output_final, result_text])
126
-
127
- # -------------------------
128
- # Launch
129
- # -------------------------
130
- if __name__ == "__main__":
131
- demo.launch()
 
1
+ import gradio as gr
2
+ import torch
3
+ from torchvision import transforms
4
+ from PIL import Image, ImageFilter
5
+ import os
6
+ import time
7
+ import gc
8
+ import gdown
9
+
10
+ from RRDBNet_arch import RRDBNet
11
+
12
+ # -------------------------
13
+ # Download from Google Drive if not present
14
+ # -------------------------
15
def ensure_model_downloaded():
    """Fetch the pretrained ESRGAN x4 weights from Google Drive on first run."""
    weights_file = "models/RRDB_ESRGAN_x4.pth"
    if os.path.exists(weights_file):
        return
    os.makedirs("models", exist_ok=True)
    drive_id = "1P3Hbr51ZNsbNJIiWxrsHgl-D3I9n5ItN"
    gdown.download(f"https://drive.google.com/uc?id={drive_id}", weights_file, quiet=False)
21
+
22
+ # -------------------------
23
+ # Load ESRGAN Model
24
+ # -------------------------
25
@torch.no_grad()
def load_model():
    """Build the RRDBNet generator and load the pretrained ESRGAN x4 weights.

    Returns:
        The model on CPU, in eval mode.
    """
    ensure_model_downloaded()
    model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23)
    model_path = os.path.join("models", "RRDB_ESRGAN_x4.pth")
    # weights_only=True (torch>=1.13; requirements pin torch>=2.0) refuses to
    # unpickle arbitrary objects from the downloaded file — the checkpoint is
    # a plain state_dict, so this is safe and closes a code-execution vector.
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    return model
33
+
# Instantiate the model once at import time; all Gradio requests share it.
model = load_model()
35
+
36
+ # -------------------------
37
+ # Utility Functions
38
+ # -------------------------
39
def preprocess(img_pil):
    """Convert a PIL image to a normalized NCHW float tensor (batch of 1).

    NOTE(review): Normalize((0.5,), (0.5,)) maps pixels to [-1, 1] and
    postprocess() mirrors that, but stock ESRGAN checkpoints were trained
    on [0, 1] inputs — confirm this matches the weights in use.
    """
    to_model_input = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    return to_model_input(img_pil).unsqueeze(0)
45
+
46
def postprocess(tensor):
    """Map a model output tensor back to a PIL image (inverse of preprocess)."""
    img = tensor.squeeze().detach().cpu()
    # Undo the [-1, 1] normalization and clip to the valid pixel range.
    img = img.mul(0.5).add(0.5).clamp(0, 1)
    return transforms.ToPILImage()(img)
50
+
51
def fuse_images(img1, img2):
    """Resize both inputs to 384x384 and blend them 50/50."""
    size = (384, 384)
    first = img1.resize(size, Image.LANCZOS)
    second = img2.resize(size, Image.LANCZOS)
    return Image.blend(first, second, alpha=0.5)
55
+
56
def sharpen_image(image: Image.Image) -> Image.Image:
    """Apply an unsharp mask as the final crispness pass."""
    unsharp = ImageFilter.UnsharpMask(radius=1.5, percent=150, threshold=1)
    return image.filter(unsharp)
58
+
59
def upscale_to_resolution(img: Image.Image, resolution: str = "4K") -> Image.Image:
    """Resize to 3840x2160 for "4K"; any other value yields 7680x4320 (8K)."""
    if resolution == "4K":
        target_size = (3840, 2160)
    else:
        target_size = (7680, 4320)
    return img.resize(target_size, Image.LANCZOS)
62
+
63
+ # -------------------------
64
+ # Inference Pipeline
65
+ # -------------------------
66
def esrgan_pipeline(img1, img2, resolution):
    """Fuse two images, run three ESRGAN passes, then upscale and sharpen.

    Args:
        img1, img2: input PIL images (None when the user left a slot empty).
        resolution: "4K" or "8K" target for the final resize.

    Returns:
        (esrgan_output, final_image, status_message), or (None, None, msg)
        when an input is missing.
    """
    if not img1 or not img2:
        return None, None, "Please upload two valid images."

    img1 = img1.convert("RGB")
    img2 = img2.convert("RGB")
    fused_img = fuse_images(img1, img2)

    start = time.time()

    # NOTE(review): each pass multiplies resolution by 4 (384 -> 1536 ->
    # 6144 -> 24576), so three passes are extremely memory-hungry on CPU.
    with torch.no_grad():
        input_tensor = preprocess(fused_img)
        sr1 = model(input_tensor)
        sr2 = model(sr1)
        sr3 = model(sr2)

    base_output = postprocess(sr3)

    gc.collect()
    torch.cuda.empty_cache()

    upscaled_img = upscale_to_resolution(base_output, resolution)
    final_img = sharpen_image(upscaled_img)

    elapsed = time.time() - start
    # BUG FIX: torch.tensor() cannot consume a PIL Image directly (it raised
    # TypeError on every successful run). Convert via ToTensor(), which also
    # yields floats in [0, 1] so the variance is well-defined.
    gray = transforms.ToTensor()(base_output.convert("L"))
    sharpness_score = torch.var(gray).item()
    msg = f"βœ… Done in {elapsed:.2f}s | Sharpness: {sharpness_score:.2f}"

    return base_output, final_img, msg
95
+
96
# -------------------------
# Gradio UI
# -------------------------
with gr.Blocks(title="Triple-Pass ESRGAN Super-Resolution") as demo:
    gr.Markdown("## 🧠 Triple-Pass ESRGAN Ultra-HD Upscaler")
    gr.Markdown("Upload **two low-res images** β†’ ESRGAN (3 passes) β†’ Final **4K/8K** enhanced image with sharpening.")

    with gr.Row():
        with gr.Column():
            img_input1 = gr.Image(type="pil", label="Low-Res Image 1")
            img_input2 = gr.Image(type="pil", label="Low-Res Image 2")
            resolution_choice = gr.Radio(["4K", "8K"], value="4K", label="Select Output Resolution")
            run_button = gr.Button("πŸš€ Run ESRGAN")

        with gr.Column():
            output_esrgan = gr.Image(label="🧠 ESRGAN 3x Output")
            output_final = gr.Image(label="🏞️ Final Enhanced Output")
            result_text = gr.Textbox(label="πŸ“Š Output Log")

    gr.Markdown("---")
    # BUG FIX: `unsafe_allow_html=True` is a Streamlit argument; gr.Markdown
    # has no such parameter, so passing it raised TypeError at startup.
    # Gradio's Markdown component renders inline HTML without it.
    gr.Markdown(
        "<div style='text-align: center; font-size: 16px;'>"
        "Made with ❀️ by <b>CodeKarma</b> as a part of <b>Bharatiya Antariksh Hackathon 2025</b>"
        "</div>"
    )

    run_button.click(fn=esrgan_pipeline,
                     inputs=[img_input1, img_input2, resolution_choice],
                     outputs=[output_esrgan, output_final, result_text])

# -------------------------
# Launch
# -------------------------
if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio>=4.0
2
+ torch>=2.0.0
3
+ torchvision>=0.15.0
4
+ Pillow>=9.5.0
5
+ numpy>=1.24.0
6
+ gdown