dev1461 commited on
Commit
3dbf5d4
·
verified ·
1 Parent(s): 7958db9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -0
app.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ import numpy as np
7
+
8
+ # ---------------------------
9
+ # MODEL ARCHITECTURE
10
+ # ---------------------------
11
+
12
+ class ResidualBlock(nn.Module):
13
+ def __init__(self, channels):
14
+ super().__init__()
15
+ self.block = nn.Sequential(
16
+ nn.Conv2d(channels, channels, 3, 1, 1),
17
+ nn.ReLU(),
18
+ nn.Conv2d(channels, channels, 3, 1, 1)
19
+ )
20
+
21
+ def forward(self, x):
22
+ return x + self.block(x)
23
+
24
+
25
+ class Generator(nn.Module):
26
+ def __init__(self):
27
+ super().__init__()
28
+
29
+ self.entry = nn.Conv2d(3, 64, 3, 1, 1)
30
+
31
+ self.res_blocks = nn.Sequential(
32
+ ResidualBlock(64),
33
+ ResidualBlock(64),
34
+ ResidualBlock(64)
35
+ )
36
+
37
+ self.exit = nn.Sequential(
38
+ nn.Conv2d(64, 3, 3, 1, 1),
39
+ nn.Sigmoid()
40
+ )
41
+
42
+ def forward(self, x):
43
+ x = self.entry(x)
44
+ x = self.res_blocks(x)
45
+ return self.exit(x)
46
+
47
+ # ---------------------------
48
+ # LOAD MODEL
49
+ # ---------------------------
50
+
51
+ device = torch.device("cpu")
52
+
53
+ model = Generator().to(device)
54
+ checkpoint = torch.load("final_sr_model_v3.pth", map_location=device)
55
+ model.load_state_dict(checkpoint['generator'])
56
+ model.eval()
57
+
58
+ # ---------------------------
59
+ # TRANSFORM
60
+ # ---------------------------
61
+
62
+ transform = transforms.Compose([
63
+ transforms.Resize((128,128)),
64
+ transforms.ToTensor()
65
+ ])
66
+
67
+ # ---------------------------
68
+ # INFERENCE FUNCTION
69
+ # ---------------------------
70
+
71
+ def enhance_image(input_image):
72
+ img = input_image.convert("RGB")
73
+
74
+ input_tensor = transform(img).unsqueeze(0).to(device)
75
+
76
+ with torch.no_grad():
77
+ output = model(input_tensor)
78
+
79
+ output_img = output.squeeze().permute(1,2,0).cpu().numpy()
80
+ output_img = (output_img * 255).astype(np.uint8)
81
+
82
+ return output_img
83
+
84
+
85
+ with gr.Blocks() as demo:
86
+ gr.Markdown("# 🔍 Image Super Resolution")
87
+
88
+ input_img = gr.Image(type="pil", label="Upload Image")
89
+ output_img = gr.Image(type="numpy", label="Enhanced Image")
90
+
91
+ btn = gr.Button("Enhance Image")
92
+
93
+ btn.click(fn=enhance_image, inputs=input_img, outputs=output_img)
94
+
95
+ gr.DownloadButton(label="Download Enhanced Image", data=output_img)
96
+
97
+ demo.launch()
98
+
99
+ # ---------------------------
100
+ # GRADIO UI
101
+ # ---------------------------
102
+
103
+ interface = gr.Interface(
104
+ fn=enhance_image,
105
+ inputs=gr.Image(type="pil", label="Upload Image"),
106
+ outputs=gr.Image(type="numpy", label="Enhanced Image"),
107
+ title="🔍 Super Resolution App",
108
+ description="Upload a low-quality image and enhance it using deep learning",
109
+ allow_flagging="never"
110
+ )
111
+
112
+ interface.launch()