KuunVo commited on
Commit
b77fd1a
·
1 Parent(s): ea2173c

First Commit

Browse files
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ __pycache__/*
2
+ temp/*
3
+ models/__pycache__/*
4
+ ui/__pycache__/*
main.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import streamlit as st
from ui import upscaler_ui, enhancer_ui

st.set_page_config(layout="wide")

# One top-level tab per tool; each ui module draws its own widgets in its tab.
pages = {"Upscaler": upscaler_ui, "Enhancer": enhancer_ui}
for tab, module in zip(st.tabs(list(pages)), pages.values()):
    with tab:
        module.ui()
14
+
models/base_model.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
class AttentionBlock(nn.Module):
    """Single-head self-attention over spatial positions, with a residual add.

    Input/output shape is (B, C, H, W); the H*W positions are the tokens and
    each token's channel vector is its embedding.
    """

    def __init__(self, in_ch):
        super().__init__()
        self.group_norm = nn.GroupNorm(32, in_ch)
        # 1x1 convolutions act as per-position linear projections.
        self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)

    def forward(self, x):
        batch, channels, height, width = x.shape
        tokens = height * width

        normed = self.group_norm(x)
        # Queries/values laid out (B, HW, C); keys (B, C, HW) for the matmul.
        q = self.proj_q(normed).permute(0, 2, 3, 1).reshape(batch, tokens, channels)
        k = self.proj_k(normed).reshape(batch, channels, tokens)
        v = self.proj_v(normed).permute(0, 2, 3, 1).reshape(batch, tokens, channels)

        # Scaled dot-product attention over positions.
        weights = torch.bmm(q, k) * (channels ** -0.5)
        weights = F.softmax(weights, dim=-1)

        attended = torch.bmm(weights, v)
        assert list(attended.shape) == [batch, tokens, channels]
        attended = attended.reshape(batch, height, width, channels).permute(0, 3, 1, 2)

        # Residual connection around the whole attention computation.
        return x + self.proj(attended)
34
+
35
+
36
class ResidualBlock(nn.Module):
    """Pre-activation residual block: (GroupNorm -> SiLU -> Conv) twice,
    a projected skip connection, and optional trailing self-attention.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 dropout: float,
                 n_groups: int = 32,
                 has_attn: bool = False):
        super().__init__()

        # First stage: may change the channel count.
        self.norm1 = nn.GroupNorm(n_groups, in_channels)
        self.act1 = nn.SiLU()
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=(3, 3), padding=(1, 1))

        # Second stage: keeps the channel count, with dropout before the conv.
        self.norm2 = nn.GroupNorm(n_groups, out_channels)
        self.act2 = nn.SiLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=(3, 3), padding=(1, 1))

        # 1x1 projection so the skip path matches the output channels.
        self.shortcut = (nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
                         if in_channels != out_channels else nn.Identity())

        self.attn = AttentionBlock(out_channels) if has_attn else nn.Identity()

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor):
        out = self.conv1(self.act1(self.norm1(x)))
        out = self.act2(self.norm2(out))
        out = self.conv2(self.dropout(out))
        return self.attn(out + self.shortcut(x))
72
+
73
+
74
class DownBlock(nn.Module):
    """Encoder-side stage unit: a single ResidualBlock (optionally with attention).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the residual block.
        has_attn: whether the residual block ends with self-attention.
        dropout: dropout probability (annotation fixed from ``int`` —
            callers pass a float such as 0.1).
    """

    def __init__(self, in_channels: int, out_channels: int, has_attn: bool, dropout: float):
        super().__init__()
        self.res = ResidualBlock(
            in_channels, out_channels, dropout=dropout, has_attn=has_attn)

    def forward(self, x: torch.Tensor):
        return self.res(x)
82
+
83
+
84
class UpBlock(nn.Module):
    """Decoder-side stage unit: a single ResidualBlock (optionally with attention).

    Args:
        in_channels: channels of the incoming feature map (doubled after a
            skip-connection concatenation in the UNET decoder).
        out_channels: channels produced by the residual block.
        has_attn: whether the residual block ends with self-attention.
        dropout: dropout probability (annotation fixed from ``int`` —
            callers pass a float such as 0.1).
    """

    def __init__(self, in_channels: int, out_channels: int, has_attn: bool, dropout: float):
        super().__init__()
        self.res = ResidualBlock(
            in_channels, out_channels, dropout=dropout, has_attn=has_attn)

    def forward(self, x: torch.Tensor):
        return self.res(x)
92
+
93
+
94
class MiddleBlock(nn.Module):
    """Bottleneck between encoder and decoder: an attention-equipped residual
    block followed by a plain one, both at the bottleneck width.

    Args:
        n_channels: channel count at the bottleneck.
        dropout: dropout probability (annotation fixed from ``int`` —
            callers pass a float such as 0.1).
    """

    def __init__(self, n_channels: int, dropout: float):
        super().__init__()
        self.res1 = ResidualBlock(
            n_channels, n_channels, dropout=dropout, has_attn=True)
        self.res2 = ResidualBlock(n_channels, n_channels, dropout=dropout)

    def forward(self, x: torch.Tensor):
        x = self.res1(x)
        x = self.res2(x)
        return x
105
+
106
+
107
class Downsample(nn.Module):
    """Halve the spatial resolution with a stride-2 3x3 convolution."""

    def __init__(self, n_channels):
        super().__init__()
        # Same channel count in and out; stride 2 does the downsampling.
        self.conv = nn.Conv2d(n_channels, n_channels,
                              kernel_size=3, stride=2, padding=1)

    def forward(self, feature_map: torch.Tensor):
        return self.conv(feature_map)
115
+
116
+
117
class Upsample(nn.Module):
    """Double the spatial resolution: a stride-2 transposed convolution
    followed by a 3x3 convolution that smooths the result."""

    def __init__(self, n_channels):
        super().__init__()
        self.convT = nn.ConvTranspose2d(
            n_channels, n_channels, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv = nn.Conv2d(n_channels, n_channels,
                              kernel_size=3, stride=1, padding=1)

    def forward(self, feature_map: torch.Tensor):
        upsampled = self.convT(feature_map)
        return self.conv(upsampled)
129
+
130
+
131
class MeanShift(nn.Conv2d):
    """Frozen 1x1 convolution that shifts an RGB image by a fixed mean.

    With the default sign=-1 it subtracts ``rgb_range * rgb_mean`` (scaled by
    ``rgb_std``); with sign=+1 it adds it back. The default mean is the one
    hard-coded here (0.4488, 0.4371, 0.4040). All parameters are frozen.
    """

    def __init__(
            self, rgb_range,
            rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):

        super().__init__(3, 3, kernel_size=1)
        std = torch.tensor(rgb_std)
        # Diagonal weight = identity divided by the per-channel std.
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * torch.tensor(rgb_mean) / std
        # Fixed transform: never trained.
        self.requires_grad_(False)
142
+
143
+
144
class UNET(nn.Module):
    """U-Net refinement model for single-image super-resolution.

    The caller feeds an image in [0, 1] that has already been bicubic-resized
    to the target resolution; this network refines it at constant spatial
    size. Internally the input is scaled to [0, 255] and mean-shifted via
    MeanShift, then shifted back and rescaled to [0, 1] on output.

    NOTE(review): two architectural quirks are kept deliberately — the shipped
    pretrained checkpoints were trained with them, and changing the module
    layout would break ``load_state_dict``:
      * ``left_unet`` appends a Downsample after *every* stage (its guard is
        always true; presumably ``len(...) - 1`` was intended).
      * ``right_unet`` indexes ``is_attn_layers[i - 1]``, which wraps to the
        last entry when i == 0.
    """

    def __init__(self,
                 in_channels: int = 3,
                 out_channels: int = 3,
                 n_features: int = 64,
                 dropout: float = 0.1,
                 block_out_channels=(64, 128, 128, 256),
                 layers_per_block=4,
                 is_attn_layers=(False, False, True, False),
                 ):
        """Build the encoder, bottleneck and decoder.

        Args:
            in_channels / out_channels: image channel counts (RGB by default).
            n_features: width of the shallow feature extractor.
            dropout: dropout probability for every residual block
                (annotation fixed from ``int``).
            block_out_channels: channel width per encoder/decoder stage.
                Default changed from a mutable list to an equivalent tuple
                (mutable default argument); it is only read, never mutated.
            layers_per_block: residual blocks per stage.
            is_attn_layers: per-stage flag enabling self-attention.
        """
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_features = n_features
        self.dropout = dropout
        self.block_out_channels = block_out_channels
        self.layers_per_block = layers_per_block
        self.is_attn_layers = is_attn_layers

        # Frozen RGB mean subtraction / re-addition on the 0-255 scale.
        self.sub_mean = MeanShift(255)
        self.add_mean = MeanShift(255, sign=1)

        self.shallow_feature_extraction = nn.Conv2d(
            in_channels, n_features, kernel_size=3, padding=1)
        # (sic) attribute name kept as-is for checkpoint compatibility.
        self.image_rescontruction = nn.Conv2d(
            n_features, in_channels, kernel_size=3, padding=1)

        self.left_model = self.left_unet()
        self.middle_model = MiddleBlock(
            block_out_channels[-1], dropout=self.dropout)
        self.right_model = self.right_unet()

    def left_unet(self):
        """Encoder: per stage, a Sequential of DownBlocks then a Downsample."""
        left_model = []

        in_channel = out_channel = self.n_features
        for i in range(len(self.block_out_channels)):
            out_channel = self.block_out_channels[i]

            down_block = [DownBlock(in_channel, out_channel, dropout=self.dropout, has_attn=self.is_attn_layers[i])] \
                + [DownBlock(out_channel, out_channel, dropout=self.dropout,
                             has_attn=self.is_attn_layers[i])] * (self.layers_per_block - 1)
            in_channel = out_channel
            left_model.append(nn.Sequential(*down_block))
            # NOTE(review): always true, so the last stage also gets a
            # Downsample; kept as trained (see class docstring).
            if i < len(self.block_out_channels):
                left_model.append(Downsample(out_channel))

        return nn.ModuleList(left_model)

    def right_unet(self):
        """Decoder: per stage, a Sequential of UpBlocks and an Upsample,
        plus a final stage consuming the outermost skip connection."""
        right_unet = []

        in_channel = out_channel = self.block_out_channels[-1]
        for i in reversed(range(len(self.block_out_channels))):

            out_channel = self.block_out_channels[i]

            # NOTE(review): [i - 1] wraps to the last entry when i == 0;
            # kept as trained (see class docstring).
            up_block = [UpBlock(in_channel, out_channel, dropout=self.dropout, has_attn=self.is_attn_layers[i - 1])] \
                + [UpBlock(out_channel, out_channel, dropout=self.dropout, has_attn=self.is_attn_layers[i - 1])
                   ] * (self.layers_per_block - 1)

            # After the following Upsample, a skip connection is concatenated
            # in forward(), doubling the channels seen by the next stage.
            in_channel = out_channel * 2
            right_unet.append(nn.Sequential(*up_block))
            right_unet.append(Upsample(out_channel))

        in_channel, out_channel = self.block_out_channels[0] * \
            2, self.n_features
        up_block = [UpBlock(in_channel, out_channel, dropout=self.dropout, has_attn=self.is_attn_layers[0])] \
            + [UpBlock(out_channel, out_channel, dropout=self.dropout, has_attn=self.is_attn_layers[0])
               ] * (self.layers_per_block - 1)
        right_unet.append(nn.Sequential(*up_block))
        return nn.ModuleList(right_unet)

    def forward(self, x):
        """Refine ``x`` (B, C, H, W, values in [0, 1]); same shape/range out.

        H and W must be multiples of the total downsampling factor —
        presumably 2**4 given four Downsample stages; see utils.find_padding.
        """
        x = x * 255
        x = self.sub_mean(x)

        feature_maps = self.shallow_feature_extraction(x)
        # Skip connections: one feature map per non-Downsample step.
        feature_x = [feature_maps]
        feature_block = feature_maps
        for block in self.left_model:
            feature_block = block(feature_block)
            if not isinstance(block, Downsample):
                feature_x.append(feature_block)

        bottleneck = self.middle_model(feature_block)

        # Consume the skips deepest-first on the way back up.
        feature_x.reverse()

        recover = bottleneck
        d = 0
        for block in self.right_model:
            if isinstance(block, Upsample):
                recover = block(recover)
                recover = torch.cat([recover, feature_x[d]], 1)
                d += 1
            else:
                recover = block(recover)

        recover = self.image_rescontruction(recover)
        recover = self.add_mean(recover) / 255
        return recover
pretrained/SRUNET_scale_x2.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5868beb1c314915e26a20aa21d6a3b4583c1bd30f9a761998e5bce1deea67f40
3
+ size 56710930
pretrained/SRUNET_scale_x234.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6ced8f065681ddf24272fbfce70ae55b0c3c3afe601bdeaf6012ed01f4ddb907
3
+ size 56711450
pretrained/SRUNET_scale_x3.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2cb821b0b33d54dc092b18b7a03fbcf889213c02230a70459bd6d2b2d5877acb
3
+ size 56710930
pretrained/SRUNET_scale_x4.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aee39dc4142754c6f883e417cda5e68c0a56de0bfe81d20619496bbdd44fc6ae
3
+ size 56710930
requirement.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
+ torchvision
+ torchaudio
+ streamlit
+ streamlit-image-comparison
+ Pillow
+ requests
ui/enhancer_ui.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
def ui():
    """Render the Enhancer tab. Not implemented yet — draws nothing."""
    return None
ui/upscaler_ui.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import requests
4
+ from io import BytesIO
5
+ from streamlit_image_comparison import image_comparison
6
+ from utils import upscale_image
7
+
8
+
9
def ui():
    """Render the Upscaler tab.

    Layout: a row of three selectboxes (image source, model, scale factor),
    an input widget for the chosen source, then — after Submit — a size
    check, a preview of the original, and a bicubic-vs-model comparison
    slider.
    """
    img = None
    url_text = None
    uploaded = None

    input_area = st.columns([2, 1, 1])

    with input_area[0]:
        source_choice = st.selectbox(
            "How do you want to provide the image?",
            ("Fetch from URL", "Upload from local machine")
        )

    with input_area[2]:
        scale_choice = st.selectbox(
            "Which factors do you want to upscale?",
            (2, 3, 4)
        )

    with input_area[1]:
        model_choice = st.selectbox(
            "Which model do you want to use?",
            ('SRUNET_x2', 'SRUNET_x3', 'SRUNET_x4', 'SRUNET_x234')
        )

    preview_area = st.columns([2, 2], vertical_alignment="top")
    with preview_area[0]:
        if source_choice == "Upload from local machine":
            uploaded = st.file_uploader(
                "Choose an image...", type=["jpg", "jpeg", "png"])
        elif source_choice == "Fetch from URL":
            url_text = st.text_input("Enter the image URL")

    if st.button("Submit"):
        if source_choice == "Upload from local machine" and uploaded is not None:
            try:
                img = Image.open(uploaded)
            except Exception as e:
                st.error(f"Error opening image: {e}")
        elif source_choice == "Fetch from URL" and url_text:
            try:
                response = requests.get(url_text)
                response.raise_for_status()
                img = Image.open(BytesIO(response.content))
            except requests.exceptions.RequestException as e:
                st.error(f"Error fetching image: {e}")

    # Reject images whose upscaled size would exceed the 1000x1000 cap.
    if img:
        w, h = img.size
        scale = int(scale_choice)
        if w * scale > 1000 or h * scale > 1000:
            st.error(
                "Unable to upscale. The size of upscaled image should be less than 1000x1000")
            img = None

    if img:
        st.header('Results')
        w, h = img.size
        scale = int(scale_choice)
        preview_area[1].text(
            f"Image size: {w}x{h} --> {w*scale}x{h*scale}")
        preview_area[1].image(
            img, caption="Original", use_column_width=True)

        # Baseline bicubic enlargement vs the model's prediction.
        bicubic = img.resize((w * scale, h * scale), Image.BICUBIC)
        upscaled = upscale_image(img, model_choice, scale)

        image_comparison(
            img1=bicubic,
            img2=upscaled,
        )
utils.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import torchvision.transforms.v2 as transforms
5
+ from models.base_model import UNET
6
+
7
def find_padding(img, depth=2**4):
    """Bottom/right padding needed to make H and W multiples of ``depth``.

    Args:
        img: 4-D tensor shaped (B, C, H, W).
        depth: alignment factor; 2**4 presumably matches the UNET's four
            2x downsampling stages — confirm against the model config.

    Returns:
        (h_pad, w_pad) — each zero when that dimension is already aligned.
    """
    _, _, height, width = img.shape
    # (-x) % d == (d - x % d) % d: distance up to the next multiple of d.
    return (-height) % depth, (-width) % depth
13
+
14
def get_pretrained_path(model_name):
    """Return the absolute path to the pretrained checkpoint for ``model_name``.

    Supported names: 'SRUNET_x2', 'SRUNET_x3', 'SRUNET_x4', 'SRUNET_x234'.

    Raises:
        ValueError: for an unknown model name. (Was a bare ``Exception``;
        ValueError subclasses Exception, so existing broad handlers still
        catch it, and the message now includes the offending name.)
    """
    checkpoints = {
        'SRUNET_x2': 'SRUNET_scale_x2.pt',
        'SRUNET_x3': 'SRUNET_scale_x3.pt',
        'SRUNET_x4': 'SRUNET_scale_x4.pt',
        'SRUNET_x234': 'SRUNET_scale_x234.pt',
    }
    if model_name not in checkpoints:
        raise ValueError(f'Model not found: {model_name!r}')
    # Normalize to forward slashes so the path works the same on Windows.
    current_path = os.path.dirname(os.path.abspath(__file__)).replace("\\", "/")
    return current_path + '/pretrained/' + checkpoints[model_name]
32
+
33
+
34
def upscale_image(img, model_name, scale_factor):
    """Upscale a PIL image by ``scale_factor`` using the named SRUNET model.

    The image is bicubic-resized to the target size first; the network then
    refines it at constant resolution (CPU only). The result is converted
    back to the input image's original mode.

    Args:
        img: PIL.Image of any mode (converted to RGB for the model).
        model_name: key accepted by ``get_pretrained_path``.
        scale_factor: integer upscaling factor.

    Returns:
        A PIL.Image of size (width * scale_factor, height * scale_factor).
    """
    width, height = img.size
    img_mode = img.mode
    if img.mode != "RGB":
        img = img.convert("RGB")

    out_h, out_w = height * scale_factor, width * scale_factor
    transform = transforms.Compose([
        transforms.Resize((out_h, out_w),
                          interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToImage(),
        transforms.ToDtype(torch.float32, scale=True),
    ])

    # Load model.
    # TODO(review): the ~57 MB checkpoint is reloaded on every call —
    # consider caching the model per model_name.
    checkpoint = torch.load(get_pretrained_path(
        model_name), map_location=torch.device('cpu'))
    model = UNET()
    model.load_state_dict(checkpoint['best_model_state_dict'])
    model.eval()

    data = transform(img).clamp(0, 1).unsqueeze(0)
    # Reflect-pad bottom/right so H and W are multiples of the network's
    # downsampling factor (see find_padding).
    h_pad, w_pad = find_padding(data)
    data = F.pad(data, (0, w_pad, 0, h_pad), mode='reflect')

    with torch.no_grad():
        img_scale_pred = model(data).clamp(0, 1)

    # The network preserves spatial size, so cropping to the exact target
    # dimensions removes the padding in one step (replaces the previous
    # four-way h_pad/w_pad branching, which included a no-op else branch).
    img_scale_pred = img_scale_pred[..., :out_h, :out_w]

    img_scale_pred = img_scale_pred.squeeze(0)
    return transforms.ToPILImage()(img_scale_pred).convert(img_mode)
75
+