# Colorization / src/streamlit_app.py
# Last change: commit 10661c4 "Update src/streamlit_app.py" by Fiixq (verified)
import math
import numpy as np
from PIL import Image
import streamlit as st
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import os
def rgb2lab2(r0, g0, b0):
    """Convert 8-bit R, G, B values to the custom "LAB2" (l, a, b) space.

    Each channel is scaled to [0, 1] by dividing by 255. Three linear mixes
    Y, X and Z of the scaled channels are formed, and then::

        l = Y
        a = (X - Y) / 0.234
        b = (Y - Z) / 0.785

    Works element-wise on scalars or NumPy arrays; ``lab22rgb`` applies the
    inverse matrix.

    Returns:
        tuple: ``(l, a, b)``.
    """
    red, green, blue = r0 / 255, g0 / 255, b0 / 255
    luma = 0.299 * red + 0.587 * green + 0.114 * blue
    mix_x = 0.449 * red + 0.353 * green + 0.198 * blue
    mix_z = 0.012 * red + 0.089 * green + 0.899 * blue
    return luma, (mix_x - luma) / 0.234, (luma - mix_z) / 0.785
def lab22rgb(l, a, b):
    """Map (l, a, b) values in the "LAB2" space back to 8-bit R, G, B.

    Args:
        l, a, b: 2-D arrays with N rows; only column 0 of each is read
            (callers pass (N, 1) column vectors).

    Returns:
        tuple: three 1-D ``uint8`` arrays ``(R, G, B)`` of length N. The
        reconstructed intensities are clipped to [0, 1] before being scaled
        by 255 and rounded.
    """
    # Forward matrix of rgb2lab2: each row expresses l, a or b as a mix of
    # the [0, 1]-scaled (r, g, b) channels.
    forward = np.array([
        [0.299, 0.587, 0.114],
        [0.15 / 0.234, -0.234 / 0.234, 0.084 / 0.234],
        [0.287 / 0.785, 0.498 / 0.785, -0.785 / 0.785],
    ])
    lab = np.column_stack((l[:, 0], a[:, 0], b[:, 0])).astype(np.float64).T  # (3, N)
    rgb = np.clip(np.linalg.inv(forward).dot(lab), 0, 1)
    red, green, blue = (np.uint8(np.round(row * 255)) for row in rgb)
    return red, green, blue
def psnr(img1, img2):
    """Peak signal-to-noise ratio (dB) between two images, assuming a 0-255 range.

    Identical images (zero mean squared error) return 100 instead of infinity.
    """
    mean_sq_err = np.mean((img1.astype("float") - img2.astype("float")) ** 2)
    if mean_sq_err != 0:
        return 20 * math.log10(255.0 / math.sqrt(mean_sq_err))
    return 100
def mse(imageA, imageB, bands):
    """Mean squared error between two images.

    The summed squared difference is divided by ``height * width * bands``,
    where height and width are the first two dimensions of ``imageA``.
    """
    diff = imageA.astype("float") - imageB.astype("float")
    count = float(imageA.shape[0] * imageA.shape[1] * bands)
    return np.sum(diff ** 2) / count
def mae(imageA, imageB, bands):
    """Mean absolute error between two images.

    The summed absolute difference is divided by ``height * width * bands``,
    where height and width are the first two dimensions of ``imageA``.
    """
    abs_diff = np.abs(imageA.astype("float") - imageB.astype("float"))
    count = float(imageA.shape[0] * imageA.shape[1] * bands)
    return np.sum(abs_diff) / count
def rmse(imageA, imageB, bands):
    """Root mean squared error between two images.

    Square root of the summed squared difference divided by
    ``height * width * bands`` (first two dimensions of ``imageA``).
    """
    sq_diff = (imageA.astype("float") - imageB.astype("float")) ** 2
    return np.sqrt(np.sum(sq_diff) / float(imageA.shape[0] * imageA.shape[1] * bands))
class DoubleConv(nn.Module):
    """Two 3x3 convolutions (padding 1), each followed by an in-place ReLU.

    Spatial size is preserved; channels go ``in_channels -> out_channels``.
    The convolutions sit at indices 0 and 2 of ``self.double_conv``. That
    layout forms the saved checkpoint's state_dict keys and is indexed
    directly by ``load_vgg16_weights``, so keep it unchanged.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        for conv_in in (in_channels, out_channels):
            layers.append(nn.Conv2d(conv_in, out_channels, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)
class TripleConv(nn.Module):
    """Three 3x3 convolutions (padding 1), each followed by an in-place ReLU.

    Spatial size is preserved; channels go ``in_channels -> out_channels``.
    The convolutions sit at indices 0, 2 and 4 of ``self.triple_conv``. That
    layout forms the saved checkpoint's state_dict keys and is indexed
    directly by ``load_vgg16_weights``, so keep it unchanged.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        for conv_in in (in_channels, out_channels, out_channels):
            layers.append(nn.Conv2d(conv_in, out_channels, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))
        self.triple_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.triple_conv(x)
class UNet1(nn.Module):
    """U-Net colorization network with a multi-scale fusion head.

    Maps an ``in_channels`` image (default 1) to ``out_channels`` maps
    (default 2) at the input's spatial size, squashed to (-1, 1) by a final
    tanh.

    Layout:
      * Five-stage encoder (64-128-256-512-512 channels), each stage followed
        by 2x max-pooling.
      * A TripleConv bottleneck.
      * Five transposed-conv decoder stages with skip connections.
      * A head that concatenates full-resolution features (conv1, conv9) with
        2x-upsampled conv2 and conv8 before four 3x3 convolutions.

    Each spatial side of the input must survive five 2x max-pools, i.e. be at
    least 32. Odd sizes are handled by bilinearly resizing upsampled maps to
    the tensor they are concatenated with.

    NOTE: attribute names define the checkpoint's ``state_dict`` keys; do not
    rename them.
    """

    def __init__(self, in_channels=1, out_channels=2):
        super(UNet1, self).__init__()
        # Encoder
        self.conv1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = TripleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = TripleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = TripleConv(512, 512)
        self.pool5 = nn.MaxPool2d(2)
        # Bottleneck
        self.conv55 = TripleConv(512, 512)
        # Decoder (each DoubleConv input = upsampled channels + skip channels)
        self.up66 = nn.ConvTranspose2d(512, 512, kernel_size=2, stride=2)
        self.conv66 = DoubleConv(1024, 512)  # 512 + 512 from skip connection
        self.up6 = nn.ConvTranspose2d(512, 512, kernel_size=2, stride=2)
        self.conv6 = DoubleConv(1024, 512)  # 512 + 512 from skip connection
        self.up7 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv7 = DoubleConv(512, 256)  # 256 + 256 from skip connection
        self.up8 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv8 = DoubleConv(256, 128)  # 128 + 128 from skip connection
        self.up9 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv9 = DoubleConv(128, 64)  # 64 + 64 from skip connection
        # Multi-scale feature fusion: bring conv2 / conv8 up to conv1's resolution
        self.up_f02 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.up_f12 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # Final layers
        # 384 = 64 (conv1) + 64 (conv9) + 128 (upsampled conv2) + 128 (upsampled conv8)
        self.conv11 = nn.Conv2d(384, 128, kernel_size=3, padding=1)
        self.relu11 = nn.ReLU(inplace=True)
        self.conv12 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.relu12 = nn.ReLU(inplace=True)
        self.conv13 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.relu13 = nn.ReLU(inplace=True)
        self.conv14 = nn.Conv2d(64, out_channels, kernel_size=3, padding=1)
        # tanh keeps the predicted a/b channels in (-1, 1).
        self.tanh = nn.Tanh()

    @staticmethod
    def _match_size(tensor, reference):
        """Bilinearly resize ``tensor`` to ``reference``'s spatial size if they differ.

        Max-pooling floors odd sizes, so a 2x-upsampled map can be one pixel
        smaller than the full-resolution tensor it is concatenated with.
        """
        if tensor.size()[2:] != reference.size()[2:]:
            tensor = F.interpolate(tensor, size=reference.size()[2:], mode="bilinear", align_corners=True)
        return tensor

    def forward(self, x):
        # Encoder
        conv1 = self.conv1(x)
        conv2 = self.conv2(self.pool1(conv1))
        conv3 = self.conv3(self.pool2(conv2))
        conv4 = self.conv4(self.pool3(conv3))
        conv5 = self.conv5(self.pool4(conv4))
        # Bottleneck
        conv55 = self.conv55(self.pool5(conv5))
        # Decoder: upsample, align to the skip tensor, concatenate, convolve
        up66 = self._match_size(self.up66(conv55), conv5)
        conv66 = self.conv66(torch.cat([conv5, up66], dim=1))
        up6 = self._match_size(self.up6(conv66), conv4)
        conv6 = self.conv6(torch.cat([conv4, up6], dim=1))
        up7 = self._match_size(self.up7(conv6), conv3)
        conv7 = self.conv7(torch.cat([conv3, up7], dim=1))
        up8 = self._match_size(self.up8(conv7), conv2)
        conv8 = self.conv8(torch.cat([conv2, up8], dim=1))
        up9 = self._match_size(self.up9(conv8), conv1)
        conv9 = self.conv9(torch.cat([conv1, up9], dim=1))
        # Multi-scale feature fusion at full resolution. For odd input sides the
        # 2x upsampling lands one pixel short of conv1, so align these as well.
        up_f02 = self._match_size(self.up_f02(conv2), conv1)
        up_f12 = self._match_size(self.up_f12(conv8), conv1)
        merge11 = torch.cat([conv1, conv9, up_f02, up_f12], dim=1)
        # Final processing
        conv11 = self.relu11(self.conv11(merge11))
        conv12 = self.relu12(self.conv12(conv11))
        conv13 = self.relu13(self.conv13(conv12))
        return self.tanh(self.conv14(conv13))
def load_vgg16_weights(model):
    """Initialise the encoder of a UNet1 ``model`` from ImageNet-pretrained VGG16.

    The 13 conv layers of torchvision's VGG16 ``features`` are copied into
    ``model.conv1`` .. ``model.conv5``. When a target layer takes a single
    input channel (the default grayscale UNet1), VGG's RGB kernels are
    averaged over the input-channel axis to fit.

    Parameters are filled in place with ``Tensor.copy_``. Each parameter
    therefore keeps the model's own device and dtype, and a shape mismatch
    raises instead of silently replacing the parameter. The function does not
    depend on any module-level ``device``.

    torchvision may download the VGG16 weights on first use.
    """
    weights_enum = getattr(models, "VGG16_Weights", None)
    if weights_enum is not None:
        # Newer torchvision deprecates ``pretrained=True`` in favour of the
        # weights enum; IMAGENET1K_V1 is the checkpoint the flag used to load.
        vgg16 = models.vgg16(weights=weights_enum.IMAGENET1K_V1)
    else:
        vgg16 = models.vgg16(pretrained=True)
    vgg_features = vgg16.features

    # (target nn.Sequential, its conv indices, matching vgg16.features indices).
    # The index gaps skip VGG's ReLU and MaxPool layers.
    layer_map = (
        (model.conv1.double_conv, (0, 2), (0, 2)),
        (model.conv2.double_conv, (0, 2), (5, 7)),
        (model.conv3.triple_conv, (0, 2, 4), (10, 12, 14)),
        (model.conv4.triple_conv, (0, 2, 4), (17, 19, 21)),
        (model.conv5.triple_conv, (0, 2, 4), (24, 26, 28)),
    )
    with torch.no_grad():
        for target_seq, unet_indices, vgg_indices in layer_map:
            for unet_idx, vgg_idx in zip(unet_indices, vgg_indices):
                dst, src = target_seq[unet_idx], vgg_features[vgg_idx]
                weight = src.weight
                if dst.weight.shape[1] == 1 and weight.shape[1] != 1:
                    # RGB kernels -> single-channel kernels for grayscale input.
                    weight = weight.mean(dim=1, keepdim=True)
                dst.weight.copy_(weight)
                dst.bias.copy_(src.bias)
def load_model_for_inference(model_path, device):
    """Build a UNet1 (1 input channel, 2 output channels) on ``device``.

    Loads the state dict stored at ``model_path`` (tensors mapped onto
    ``device``) and returns the model in eval mode.
    """
    net = UNet1(in_channels=1, out_channels=2).to(device)
    # NOTE(review): torch.load unpickles the file; fine for the bundled
    # checkpoint, but do not point this at untrusted paths.
    state_dict = torch.load(model_path, map_location=device)
    net.load_state_dict(state_dict)
    return net.eval()
def inference(model, l_channel, device=None):
    """Run ``model`` on a luminance input and return its output as a NumPy array.

    Args:
        model: the colorization network; it is switched to eval mode here.
        l_channel: luminance as a torch tensor or NumPy array, shaped
            (H, W), (C, H, W) or (N, C, H, W). Missing channel/batch axes are
            added in front.
        device: device to run on. Defaults to the device of the model's first
            parameter (CPU if the model has no parameters).

    Returns:
        numpy.ndarray: the model output, moved to CPU.
    """
    model.eval()
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    with torch.no_grad():
        # Convert before reshaping: NumPy arrays have no .unsqueeze.
        l_tensor = torch.as_tensor(l_channel, dtype=torch.float32)
        if l_tensor.dim() == 2:
            l_tensor = l_tensor.unsqueeze(0)  # add channel axis
        if l_tensor.dim() == 3:
            l_tensor = l_tensor.unsqueeze(0)  # add batch axis
        ab_pred = model(l_tensor.to(device))
    return ab_pred.cpu().numpy()
def prepare_test_image(img, dim=150):
    """Resize an image to ``dim`` x ``dim`` and convert it to the rgb2lab2 space.

    Args:
        img: a PIL image of any mode (converted to RGB, then to OpenCV's BGR
            order), or a NumPy array that is single-channel (H, W) or already
            in BGR channel order (H, W, 3).
        dim: side length of the square output; the aspect ratio is not kept.

    Returns:
        tuple: ``(L_tensor, A, B)``. ``L_tensor`` is a float32 torch tensor of
        shape (1, dim, dim), used as the model input. ``A`` and ``B`` are
        (dim, dim) arrays holding the image's own chroma channels.
    """
    if isinstance(img, Image.Image):
        # Normalise the mode first: grayscale ("L") and palette ("P") images
        # become 2-D arrays, which COLOR_RGB2BGR rejects.
        img = cv2.cvtColor(np.array(img.convert("RGB")), cv2.COLOR_RGB2BGR)
    elif img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    img = cv2.resize(img, (dim, dim))
    height, width = img.shape[:2]
    # Channels are in OpenCV's B, G, R order.
    red = img[:, :, 2].reshape(-1, 1)
    green = img[:, :, 1].reshape(-1, 1)
    blue = img[:, :, 0].reshape(-1, 1)
    L, A, B = rgb2lab2(red, green, blue)
    L_tensor = torch.FloatTensor(L.reshape(height, width, 1)).permute(2, 0, 1)
    return L_tensor, A.reshape(height, width), B.reshape(height, width)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Resolve the checkpoint next to this script instead of hard-coding the
# container path /app/src. This assumes the .pth file sits beside
# streamlit_app.py, as in the deployed layout (confirm if the layout changes).
model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'Hyper_U_NET_pytorch-MAE-30Epoch.pth')


@st.cache_resource
def _cached_model(path, _device):
    """Load the colorization model once per server process.

    Streamlit re-executes this script on every widget interaction. Without
    caching, the network would be rebuilt and the checkpoint re-read each
    time. The leading underscore tells Streamlit not to hash ``_device``.
    """
    return load_model_for_inference(path, _device)


test_model = _cached_model(model_path, device)
# Page title and subtitle (subtitle: "Upload a grayscale image and the model
# will colorize it for you.").
title_html = "<h1 style='text-align: center; color: #4CAF50;'>Image Colorization Demo</h1>"
subtitle_html = "<p style='text-align: center; color: gray;'>Grayscale bir görüntü yükleyin, model sizin için renklendirsin.</p>"
for header_html in (title_html, subtitle_html):
    st.markdown(header_html, unsafe_allow_html=True)
# Inject page CSS for Streamlit buttons:
#   * base state: green fill with a white label (the fill keeps the label
#     readable before hover), rounded, full width;
#   * hover/active: darker greens;
#   * focus: no outline.
# ".css-18e3th9" appears to be a Streamlit-generated class name and may not
# match other Streamlit versions.
st.markdown(
    """
<style>
.css-18e3th9 {padding-top: 2rem;}
div.stButton > button:first-child {
background-color: #4CAF50;
color: white;
border-radius: 10px;
height: 3em;
width: 100%;
font-size: 16px;
border: none;
transition: 0.3s;
}
div.stButton > button:hover {
background-color: #45a049;
color: white;
}
div.stButton > button:active {
background-color: #3e8e41 !important;
color: white !important;
}
div.stButton > button:focus {
box-shadow: none !important;
outline: none !important;
color: white !important;
}
</style>
""",
    unsafe_allow_html=True
)
def _colorize(img, dim=150):
    """Colorize a PIL RGB image and return an RGB uint8 array at its original size.

    Steps:
      1. Downscale the image to ``dim`` x ``dim`` for the network.
      2. Combine the predicted a/b channels with the input's own L channel.
      3. Convert back to RGB via ``lab22rgb``.
      4. Upscale to the original size with Lanczos interpolation.
    """
    l_tensor, _, _ = prepare_test_image(img, dim=dim)
    ab_pred = inference(test_model, l_tensor).squeeze(0)  # drop batch axis
    a_pred, b_pred = ab_pred[0], ab_pred[1]
    height, width = a_pred.shape
    red, green, blue = lab22rgb(
        l_tensor.squeeze().numpy().reshape(-1, 1),
        a_pred.reshape(-1, 1),
        b_pred.reshape(-1, 1),
    )
    rgb_small = np.dstack([channel.reshape(height, width) for channel in (red, green, blue)])
    return cv2.resize(rgb_small, (img.width, img.height), interpolation=cv2.INTER_LANCZOS4)


# Upload section ("Upload grayscale image" / "Drag and drop to upload").
with st.container():
    st.markdown("#### 📂 Grayscale Görüntü Yükle")
    uploaded_file = st.file_uploader("Yüklemek için sürükleyip bırakın", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    img = Image.open(uploaded_file).convert("RGB")
    # "Colorize" button: inference runs only after it is pressed, while the
    # spinner ("Model is running, please wait...") is shown.
    if st.button("🎨 Renklendir"):
        with st.spinner("Model çalışıyor, lütfen bekleyin..."):
            colorized = _colorize(img)
        col1, col2 = st.columns(2)
        with col1:
            st.markdown("**Girdi (Grayscale)**")  # "Input (Grayscale)"
            st.image(img)
        with col2:
            st.markdown("**Model Çıkışı (Renkli)**")  # "Model Output (Color)"
            st.image(colorized)
        st.success("Tamamlandı!")  # "Done!"