Fiixq committed
Commit d1242b0 · verified · Parent(s): e6369fe

Update src/streamlit_app.py

Files changed (1):
  1. src/streamlit_app.py +411 -37
src/streamlit_app.py CHANGED
@@ -1,40 +1,414 @@
- import altair as alt
  import numpy as np
- import pandas as pd
  import streamlit as st

- """
- # Welcome to Streamlit!
-
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
- forums](https://discuss.streamlit.io).
-
- In the meantime, below is an example of what you can do with just a few lines of code:
- """
-
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
-
- indices = np.linspace(0, 1, num_points)
- theta = 2 * np.pi * num_turns * indices
- radius = indices
-
- x = radius * np.cos(theta)
- y = radius * np.sin(theta)
-
- df = pd.DataFrame({
-     "x": x,
-     "y": y,
-     "idx": indices,
-     "rand": np.random.randn(num_points),
- })
-
- st.altair_chart(alt.Chart(df, height=700, width=700)
-     .mark_point(filled=True)
-     .encode(
-         x=alt.X("x", axis=None),
-         y=alt.Y("y", axis=None),
-         color=alt.Color("idx", legend=None, scale=alt.Scale()),
-         size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
-     ))
+ import math
  import numpy as np
+ from PIL import Image
  import streamlit as st
+ import cv2
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision.models as models

+
+ def rgb2lab2(r0, g0, b0):
+     r = r0 / 255
+     g = g0 / 255
+     b = b0 / 255
+
+     y = 0.299 * r + 0.587 * g + 0.114 * b
+     x = 0.449 * r + 0.353 * g + 0.198 * b
+     z = 0.012 * r + 0.089 * g + 0.899 * b
+
+     l = y
+     a = (x - y) / 0.234
+     b = (y - z) / 0.785
+
+     return l, a, b
+
+
+ def lab22rgb(l, a, b):
+     a11 = 0.299
+     a12 = 0.587
+     a13 = 0.114
+     a21 = (0.15 / 0.234)
+     a22 = (-0.234 / 0.234)
+     a23 = (0.084 / 0.234)
+     a31 = (0.287 / 0.785)
+     a32 = (0.498 / 0.785)
+     a33 = (-0.785 / 0.785)
+
+     aa = np.array([[a11, a12, a13], [a21, a22, a23], [a31, a32, a33]])
+     c0 = np.zeros((l.shape[0], 3))
+     c0[:, 0] = l[:, 0]
+     c0[:, 1] = a[:, 0]
+     c0[:, 2] = b[:, 0]
+     c = np.transpose(c0)
+
+     x = np.linalg.inv(aa).dot(c)
+     x1_d = np.reshape(x, (x.shape[0] * x.shape[1], 1))
+     p0 = np.where(x1_d < 0)
+     x1_d[p0[0]] = 0
+     p1 = np.where(x1_d > 1)
+     x1_d[p1[0]] = 1
+     xr = np.reshape(x1_d, (x.shape[0], x.shape[1]))
+
+     Rr = xr[0][:]
+     Gr = xr[1][:]
+     Br = xr[2][:]
+
+     R = np.uint8(np.round(Rr * 255))
+     G = np.uint8(np.round(Gr * 255))
+     B = np.uint8(np.round(Br * 255))
+     return R, G, B
+
+
+ def psnr(img1, img2):
+     mse = np.mean((img1.astype("float") - img2.astype("float")) ** 2)
+     if mse == 0:
+         return 100
+     PIXEL_MAX = 255.0
+     return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
+
+
+ def mse(imageA, imageB, bands):
+     err = np.sum((imageA.astype("float") - imageB.astype("float")) ** 2)
+     err /= float(imageA.shape[0] * imageA.shape[1] * bands)
+     return err
+
+
+ def mae(imageA, imageB, bands):
+     err = np.sum(np.abs((imageA.astype("float") - imageB.astype("float"))))
+     err /= float(imageA.shape[0] * imageA.shape[1] * bands)
+     return err
+
+
+ def rmse(imageA, imageB, bands):
+     err = np.sum((imageA.astype("float") - imageB.astype("float")) ** 2)
+     err /= float(imageA.shape[0] * imageA.shape[1] * bands)
+     err = np.sqrt(err)
+     return err
+
+
+ class DoubleConv(nn.Module):
+     """Double Convolution Block"""
+
+     def __init__(self, in_channels, out_channels):
+         super(DoubleConv, self).__init__()
+         self.double_conv = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+             nn.ReLU(inplace=True)
+         )
+
+     def forward(self, x):
+         return self.double_conv(x)
+
+
+ class TripleConv(nn.Module):
+     """Triple Convolution Block"""
+
+     def __init__(self, in_channels, out_channels):
+         super(TripleConv, self).__init__()
+         self.triple_conv = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+             nn.ReLU(inplace=True)
+         )
+
+     def forward(self, x):
+         return self.triple_conv(x)
+
+
+ class UNet1(nn.Module):
+     def __init__(self, in_channels=1, out_channels=2):
+         super(UNet1, self).__init__()
+
+         # Encoder
+         self.conv1 = DoubleConv(in_channels, 64)
+         self.pool1 = nn.MaxPool2d(2)
+
+         self.conv2 = DoubleConv(64, 128)
+         self.pool2 = nn.MaxPool2d(2)
+
+         self.conv3 = TripleConv(128, 256)
+         self.pool3 = nn.MaxPool2d(2)
+
+         self.conv4 = TripleConv(256, 512)
+         self.pool4 = nn.MaxPool2d(2)
+
+         self.conv5 = TripleConv(512, 512)
+         self.pool5 = nn.MaxPool2d(2)
+
+         # Bottleneck
+         self.conv55 = TripleConv(512, 512)
+
+         # Decoder
+         self.up66 = nn.ConvTranspose2d(512, 512, kernel_size=2, stride=2)
+         self.conv66 = DoubleConv(1024, 512)  # 512 + 512 from skip connection
+
+         self.up6 = nn.ConvTranspose2d(512, 512, kernel_size=2, stride=2)
+         self.conv6 = DoubleConv(1024, 512)  # 512 + 512 from skip connection
+
+         self.up7 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
+         self.conv7 = DoubleConv(512, 256)  # 256 + 256 from skip connection
+
+         self.up8 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
+         self.conv8 = DoubleConv(256, 128)  # 128 + 128 from skip connection
+
+         self.up9 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
+         self.conv9 = DoubleConv(128, 64)  # 64 + 64 from skip connection
+
+         # Multi-scale feature fusion
+         self.up_f02 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+         self.up_f12 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+
+         # Final layers
+         self.conv11 = nn.Conv2d(384, 128, kernel_size=3, padding=1)  # 64 + 64 + 128 + 128
+         self.relu11 = nn.ReLU(inplace=True)
+
+         self.conv12 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
+         self.relu12 = nn.ReLU(inplace=True)
+
+         self.conv13 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
+         self.relu13 = nn.ReLU(inplace=True)
+
+         self.conv14 = nn.Conv2d(64, out_channels, kernel_size=3, padding=1)
+         self.tanh = nn.Tanh()  # tanh keeps the predicted ab channels in [-1, 1]
+
+     def forward(self, x):
+         # Encoder
+         conv1 = self.conv1(x)
+         x1 = self.pool1(conv1)
+
+         conv2 = self.conv2(x1)
+         x2 = self.pool2(conv2)
+
+         conv3 = self.conv3(x2)
+         x3 = self.pool3(conv3)
+
+         conv4 = self.conv4(x3)
+         x4 = self.pool4(conv4)
+
+         conv5 = self.conv5(x4)
+         x5 = self.pool5(conv5)
+
+         # Bottleneck
+         conv55 = self.conv55(x5)
+
+         # Decoder (re-align spatial size before each skip concatenation)
+         up66 = self.up66(conv55)
+         if up66.size()[2:] != conv5.size()[2:]:
+             up66 = F.interpolate(up66, size=conv5.size()[2:], mode="bilinear", align_corners=True)
+         merge66 = torch.cat([conv5, up66], dim=1)
+         conv66 = self.conv66(merge66)
+
+         up6 = self.up6(conv66)
+         if up6.size()[2:] != conv4.size()[2:]:
+             up6 = F.interpolate(up6, size=conv4.size()[2:], mode="bilinear", align_corners=True)
+         merge6 = torch.cat([conv4, up6], dim=1)
+         conv6 = self.conv6(merge6)
+
+         up7 = self.up7(conv6)
+         if up7.size()[2:] != conv3.size()[2:]:
+             up7 = F.interpolate(up7, size=conv3.size()[2:], mode="bilinear", align_corners=True)
+         merge7 = torch.cat([conv3, up7], dim=1)
+         conv7 = self.conv7(merge7)
+
+         up8 = self.up8(conv7)
+         if up8.size()[2:] != conv2.size()[2:]:
+             up8 = F.interpolate(up8, size=conv2.size()[2:], mode="bilinear", align_corners=True)
+         merge8 = torch.cat([conv2, up8], dim=1)
+         conv8 = self.conv8(merge8)
+
+         up9 = self.up9(conv8)
+         if up9.size()[2:] != conv1.size()[2:]:
+             up9 = F.interpolate(up9, size=conv1.size()[2:], mode="bilinear", align_corners=True)
+         merge9 = torch.cat([conv1, up9], dim=1)
+         conv9 = self.conv9(merge9)
+
+         # Multi-scale feature fusion
+         up_f01 = conv1
+         up_f11 = conv9
+         up_f02 = self.up_f02(conv2)
+         up_f12 = self.up_f12(conv8)
+
+         merge11 = torch.cat([up_f01, up_f11, up_f02, up_f12], dim=1)  # concatenate multi-scale features
+
+         # Final processing
+         conv11 = self.relu11(self.conv11(merge11))
+         conv12 = self.relu12(self.conv12(conv11))
+         conv13 = self.relu13(self.conv13(conv12))
+         output = self.tanh(self.conv14(conv13))
+
+         return output
+
+
+ def load_vgg16_weights(model):
+     """Load pretrained VGG16 weights into the U-Net encoder."""
+     vgg16 = models.vgg16(pretrained=True).to(device)
+     vgg_features = vgg16.features
+
+     with torch.no_grad():
+         # Average VGG16's RGB filters so the first conv accepts a single-channel (L) input
+         rgb_weights = vgg_features[0].weight
+         gray_weights = rgb_weights.mean(dim=1, keepdim=True)
+
+         model.conv1.double_conv[0].weight.data = gray_weights
+         model.conv1.double_conv[0].bias.data = vgg_features[0].bias.data
+
+         model.conv1.double_conv[2].weight.data = vgg_features[2].weight.data
+         model.conv1.double_conv[2].bias.data = vgg_features[2].bias.data
+
+         model.conv2.double_conv[0].weight.data = vgg_features[5].weight.data
+         model.conv2.double_conv[0].bias.data = vgg_features[5].bias.data
+         model.conv2.double_conv[2].weight.data = vgg_features[7].weight.data
+         model.conv2.double_conv[2].bias.data = vgg_features[7].bias.data
+
+         model.conv3.triple_conv[0].weight.data = vgg_features[10].weight.data
+         model.conv3.triple_conv[0].bias.data = vgg_features[10].bias.data
+         model.conv3.triple_conv[2].weight.data = vgg_features[12].weight.data
+         model.conv3.triple_conv[2].bias.data = vgg_features[12].bias.data
+         model.conv3.triple_conv[4].weight.data = vgg_features[14].weight.data
+         model.conv3.triple_conv[4].bias.data = vgg_features[14].bias.data
+
+         model.conv4.triple_conv[0].weight.data = vgg_features[17].weight.data
+         model.conv4.triple_conv[0].bias.data = vgg_features[17].bias.data
+         model.conv4.triple_conv[2].weight.data = vgg_features[19].weight.data
+         model.conv4.triple_conv[2].bias.data = vgg_features[19].bias.data
+         model.conv4.triple_conv[4].weight.data = vgg_features[21].weight.data
+         model.conv4.triple_conv[4].bias.data = vgg_features[21].bias.data
+
+         model.conv5.triple_conv[0].weight.data = vgg_features[24].weight.data
+         model.conv5.triple_conv[0].bias.data = vgg_features[24].bias.data
+         model.conv5.triple_conv[2].weight.data = vgg_features[26].weight.data
+         model.conv5.triple_conv[2].bias.data = vgg_features[26].bias.data
+         model.conv5.triple_conv[4].weight.data = vgg_features[28].weight.data
+         model.conv5.triple_conv[4].bias.data = vgg_features[28].bias.data
+
+
+ def load_model_for_inference(model_path, device):
+     model = UNet1(in_channels=1, out_channels=2).to(device)
+     model.load_state_dict(torch.load(model_path, map_location=device))
+     model.eval()
+     return model
+
+
+ def inference(model, l_channel):
+     model.eval()
+     with torch.no_grad():
+         if len(l_channel.shape) == 3:
+             l_channel = l_channel.unsqueeze(0)  # add batch dimension
+
+         l_tensor = l_channel.float().to(device)
+         ab_pred = model(l_tensor)
+
+     return ab_pred.cpu().numpy()
+
+
+ def prepare_test_image(img, dim=150):
+     if isinstance(img, Image.Image):
+         img = np.array(img)
+         img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+
+     img = cv2.resize(img, (dim, dim))
+
+     sz0, sz1 = img.shape[:2]
+     R1 = img[:, :, 2].reshape(-1, 1)
+     G1 = img[:, :, 1].reshape(-1, 1)
+     B1 = img[:, :, 0].reshape(-1, 1)
+
+     L, A, B = rgb2lab2(R1, G1, B1)  # convert to LAB2
+     L = L.reshape(sz0, sz1, 1)
+
+     L_tensor = torch.FloatTensor(L).permute(2, 0, 1)
+
+     return L_tensor, A.reshape(sz0, sz1), B.reshape(sz0, sz1)
+
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ model_path = "Hyper_U_NET_pytorch-MAE-30Epoch.pth"
+
+ test_model = load_model_for_inference(model_path, device)
+
+ st.markdown("<h1 style='text-align: center; color: #4CAF50;'>Image Colorization Demo</h1>", unsafe_allow_html=True)
+ st.markdown(
+     "<p style='text-align: center; color: gray;'>Upload a grayscale image and let the model colorize it for you.</p>",
+     unsafe_allow_html=True)
+
+ st.markdown(
+     """
+     <style>
+     .css-18e3th9 {padding-top: 2rem;}
+     div.stButton > button:first-child {
+         color: white;
+         border-radius: 10px;
+         height: 3em;
+         width: 100%;
+         font-size: 16px;
+         border: none;
+         transition: 0.3s;
+     }
+     div.stButton > button:hover {
+         background-color: #45a049;
+         color: white;
+     }
+     div.stButton > button:active {
+         background-color: #3e8e41 !important;
+         color: white !important;
+     }
+     div.stButton > button:focus {
+         box-shadow: none !important;
+         outline: none !important;
+         color: white !important;
+     }
+     </style>
+     """,
+     unsafe_allow_html=True
+ )
+
+ with st.container():
+     st.markdown("#### 📂 Upload a Grayscale Image")
+     uploaded_file = st.file_uploader("Drag and drop a file to upload", type=["jpg", "jpeg", "png"])
+
+ if uploaded_file is not None:
+     img = Image.open(uploaded_file).convert("RGB")
+
+     l_tensor, A_true, B_true = prepare_test_image(img, dim=150)
+
+     ab_pred = inference(test_model, l_tensor)
+     ab_pred = ab_pred.squeeze(0)
+     A_pred, B_pred = ab_pred[0], ab_pred[1]
+
+     sz0, sz1 = A_pred.shape
+     L = l_tensor.squeeze().numpy().reshape(-1, 1)
+     A = A_pred.reshape(-1, 1)
+     B = B_pred.reshape(-1, 1)
+
+     R, G, B = lab22rgb(L, A, B)
+     R = R.reshape(sz0, sz1)
+     G = G.reshape(sz0, sz1)
+     B = B.reshape(sz0, sz1)
+
+     rgb_pred = cv2.merge([B, G, R])
+
+     new_image = cv2.cvtColor(rgb_pred, cv2.COLOR_BGR2RGB)
+
+     new_image2 = cv2.resize(new_image, (img.width, img.height), interpolation=cv2.INTER_LANCZOS4)
+
+     if st.button("🎨 Colorize"):
+         with st.spinner("The model is running, please wait..."):
+             col1, col2 = st.columns(2)
+             with col1:
+                 st.markdown("**Input (Grayscale)**")
+                 st.image(img)
+
+             with col2:
+                 st.markdown("**Model Output (Colorized)**")
+                 st.image(np.array(new_image2))
+
+             st.success("Done!")
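
A note on the color pipeline: `rgb2lab2` / `lab22rgb` implement a custom linear luma/chroma transform rather than true CIELAB, and `lab22rgb` recovers RGB simply by inverting the same 3×3 matrix. A minimal round-trip check (NumPy only; the matrix constants are copied from the two functions above):

```python
import numpy as np

# Forward matrix applied by rgb2lab2 (rows: l, a, b; columns: r, g, b in [0, 1]).
M = np.array([
    [0.299, 0.587, 0.114],
    [0.15 / 0.234, -0.234 / 0.234, 0.084 / 0.234],
    [0.287 / 0.785, 0.498 / 0.785, -0.785 / 0.785],
])

rgb = np.random.rand(1000, 3)         # random pixels in [0, 1]
lab2 = rgb @ M.T                      # what rgb2lab2 computes
rgb_back = lab2 @ np.linalg.inv(M).T  # what lab22rgb computes (before clipping)

print("max round-trip error:", np.abs(rgb - rgb_back).max())  # ~1e-15
```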
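
On shapes: with the app's 150×150 input, the five poolings produce odd sizes (150 → 75 → 37 → 18 → 9 → 4), which is why each decoder stage guards with `F.interpolate` before concatenating its skip connection. A quick smoke test, assuming `UNet1` from the file above is in scope:

```python
import torch

# Hypothetical sanity check; assumes UNet1 has been copied from src/streamlit_app.py.
model = UNet1(in_channels=1, out_channels=2).eval()

x = torch.randn(1, 1, 150, 150)  # one L channel at the app's working size
with torch.no_grad():
    ab = model(x)

print(ab.shape)                          # torch.Size([1, 2, 150, 150])
print(float(ab.min()), float(ab.max()))  # both within [-1, 1], thanks to the tanh head
```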
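
To try the demo locally, the checkpoint `Hyper_U_NET_pytorch-MAE-30Epoch.pth` presumably needs to sit in the working directory (the `model_path` above is relative), after which `streamlit run src/streamlit_app.py` launches the app. Note that the model is loaded at import time, so the first page render includes the checkpoint load.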