# src/streamlit_app.py — Streamlit demo: Hyper U-Net grayscale image colorization.
import math
import numpy as np
from PIL import Image
import streamlit as st
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
def rgb2lab2(r0, g0, b0):
r = r0 / 255
g = g0 / 255
b = b0 / 255
y = 0.299 * r + 0.587 * g + 0.114 * b
x = 0.449 * r + 0.353 * g + 0.198 * b
z = 0.012 * r + 0.089 * g + 0.899 * b
l = y
a = (x - y) / 0.234
b = (y - z) / 0.785
return l, a, b
def lab22rgb(l, a, b):
    """Invert rgb2lab2: map LAB2 column vectors back to uint8 R, G, B.

    Parameters are (N, 1) numpy arrays (one value per pixel). Returns
    three length-N uint8 arrays. Out-of-gamut values are clamped to
    [0, 1] before scaling back to 8-bit.
    """
    # Forward transform matrix: rows express (l, a, b) as linear
    # combinations of (r, g, b); mirrors the coefficients in rgb2lab2.
    forward = np.array([
        [0.299, 0.587, 0.114],
        [0.15 / 0.234, -0.234 / 0.234, 0.084 / 0.234],
        [0.287 / 0.785, 0.498 / 0.785, -0.785 / 0.785],
    ])
    # Stack the three channels as rows of a 3 x N matrix.
    lab_rows = np.stack([l[:, 0], a[:, 0], b[:, 0]], axis=0)
    # Invert the linear transform, then clamp to the displayable range.
    rgb_rows = np.clip(np.linalg.inv(forward).dot(lab_rows), 0, 1)
    # Scale back to 8-bit channel values.
    R = np.uint8(np.round(rgb_rows[0] * 255))
    G = np.uint8(np.round(rgb_rows[1] * 255))
    B = np.uint8(np.round(rgb_rows[2] * 255))
    return R, G, B
def psnr(img1, img2):
    """Peak signal-to-noise ratio (dB) between two 8-bit images.

    Returns 100 for identical images (true PSNR would be infinite).
    """
    diff = img1.astype("float") - img2.astype("float")
    mean_sq = np.mean(diff ** 2)
    if not mean_sq:
        return 100
    # 255.0 is the peak value of an 8-bit channel.
    return 20 * math.log10(255.0 / math.sqrt(mean_sq))
def mse(imageA, imageB, bands):
    """Mean squared error per pixel per band between two images."""
    squared = (imageA.astype("float") - imageB.astype("float")) ** 2
    n_values = float(imageA.shape[0] * imageA.shape[1] * bands)
    return np.sum(squared) / n_values
def mae(imageA, imageB, bands):
    """Mean absolute error per pixel per band between two images."""
    total = np.sum(np.abs(imageA.astype("float") - imageB.astype("float")))
    return total / float(imageA.shape[0] * imageA.shape[1] * bands)
def rmse(imageA, imageB, bands):
    """Root-mean-square error per pixel per band between two images."""
    squared = (imageA.astype("float") - imageB.astype("float")) ** 2
    mean_sq = np.sum(squared) / float(imageA.shape[0] * imageA.shape[1] * bands)
    return np.sqrt(mean_sq)
class DoubleConv(nn.Module):
    """Two 3x3 same-padding convolutions, each followed by in-place ReLU.

    Layers are kept in an nn.Sequential so that code elsewhere can index
    the convs as double_conv[0] and double_conv[2] (VGG weight loading).
    """

    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)
class TripleConv(nn.Module):
    """Three 3x3 same-padding convolutions, each followed by in-place ReLU.

    Sequential indices 0/2/4 hold the convs, matching the layout that the
    VGG weight-loading code indexes into.
    """

    def __init__(self, in_channels, out_channels):
        super(TripleConv, self).__init__()
        layers = []
        channels = in_channels
        for _ in range(3):
            layers.append(nn.Conv2d(channels, out_channels, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))
            channels = out_channels
        self.triple_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.triple_conv(x)
class UNet1(nn.Module):
    """U-Net that predicts 2 chroma channels from a 1-channel luminance map.

    VGG16-shaped encoder (Double/Triple conv blocks with max-pooling),
    transpose-conv decoder with skip connections, plus a multi-scale fusion
    head that concatenates features from two encoder/decoder resolutions.
    The final tanh constrains the predicted ab channels to [-1, 1].
    """
    def __init__(self, in_channels=1, out_channels=2):
        super(UNet1, self).__init__()
        # Encoder — channel progression 64/128/256/512/512 mirrors VGG16,
        # so pretrained VGG weights can be copied in (see load_vgg16_weights).
        self.conv1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = TripleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = TripleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = TripleConv(512, 512)
        self.pool5 = nn.MaxPool2d(2)
        # Bottleneck
        self.conv55 = TripleConv(512, 512)
        # Decoder — each up-step doubles the spatial size, then fuses with
        # the matching encoder feature map via channel concatenation.
        self.up66 = nn.ConvTranspose2d(512, 512, kernel_size=2, stride=2)
        self.conv66 = DoubleConv(1024, 512)  # 512 + 512 from skip connection
        self.up6 = nn.ConvTranspose2d(512, 512, kernel_size=2, stride=2)
        self.conv6 = DoubleConv(1024, 512)  # 512 + 512 from skip connection
        self.up7 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv7 = DoubleConv(512, 256)  # 256 + 256 from skip connection
        self.up8 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv8 = DoubleConv(256, 128)  # 128 + 128 from skip connection
        self.up9 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv9 = DoubleConv(128, 64)  # 64 + 64 from skip connection
        # Multi-scale feature fusion — bilinear upsampling brings the
        # half-resolution maps (conv2 / conv8) up to full resolution.
        self.up_f02 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.up_f12 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # Final layers
        self.conv11 = nn.Conv2d(384, 128, kernel_size=3, padding=1)  # 64+64+128+128
        self.relu11 = nn.ReLU(inplace=True)
        self.conv12 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.relu12 = nn.ReLU(inplace=True)
        self.conv13 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.relu13 = nn.ReLU(inplace=True)
        self.conv14 = nn.Conv2d(64, out_channels, kernel_size=3, padding=1)
        self.tanh = nn.Tanh()  # tanh keeps predicted ab channels in [-1, 1]

    def forward(self, x):
        """Map a (B, 1, H, W) luminance tensor to (B, 2, H, W) ab channels."""
        # Encoder — keep each pre-pool activation for the skip connections.
        conv1 = self.conv1(x)
        x1 = self.pool1(conv1)
        conv2 = self.conv2(x1)
        x2 = self.pool2(conv2)
        conv3 = self.conv3(x2)
        x3 = self.pool3(conv3)
        conv4 = self.conv4(x3)
        x4 = self.pool4(conv4)
        conv5 = self.conv5(x4)
        x5 = self.pool5(conv5)
        # Bottleneck
        conv55 = self.conv55(x5)
        # Decoder — after each transpose-conv, resize to the exact skip
        # size when odd input dimensions make the shapes disagree.
        up66 = self.up66(conv55)
        if up66.size()[2:] != conv5.size()[2:]:
            up66 = F.interpolate(up66, size=conv5.size()[2:], mode="bilinear", align_corners=True)
        merge66 = torch.cat([conv5, up66], dim=1)
        conv66 = self.conv66(merge66)
        up6 = self.up6(conv66)
        if up6.size()[2:] != conv4.size()[2:]:
            up6 = F.interpolate(up6, size=conv4.size()[2:], mode="bilinear", align_corners=True)
        merge6 = torch.cat([conv4, up6], dim=1)
        conv6 = self.conv6(merge6)
        up7 = self.up7(conv6)
        if up7.size()[2:] != conv3.size()[2:]:
            up7 = F.interpolate(up7, size=conv3.size()[2:], mode="bilinear", align_corners=True)
        merge7 = torch.cat([conv3, up7], dim=1)
        conv7 = self.conv7(merge7)
        up8 = self.up8(conv7)
        if up8.size()[2:] != conv2.size()[2:]:
            up8 = F.interpolate(up8, size=conv2.size()[2:], mode="bilinear", align_corners=True)
        merge8 = torch.cat([conv2, up8], dim=1)
        conv8 = self.conv8(merge8)
        up9 = self.up9(conv8)
        if up9.size()[2:] != conv1.size()[2:]:
            up9 = F.interpolate(up9, size=conv1.size()[2:], mode="bilinear", align_corners=True)
        merge9 = torch.cat([conv1, up9], dim=1)
        conv9 = self.conv9(merge9)
        # Multi-scale feature fusion: full-resolution encoder (conv1) and
        # decoder (conv9) maps plus upsampled half-resolution maps.
        up_f01 = conv1
        up_f11 = conv9
        up_f02 = self.up_f02(conv2)
        up_f12 = self.up_f12(conv8)
        merge11 = torch.cat([up_f01, up_f11, up_f02, up_f12], dim=1)  # Concatenate multi-scale features
        # Final processing
        conv11 = self.relu11(self.conv11(merge11))
        conv12 = self.relu12(self.conv12(conv11))
        conv13 = self.relu13(self.conv13(conv12))
        output = self.tanh(self.conv14(conv13))
        return output
def load_vgg16_weights(model, device=None):
    """Copy pretrained VGG16 conv weights into the U-Net encoder, in place.

    The first conv's RGB kernels are averaged over the channel dimension so
    they fit the 1-channel grayscale input; every remaining encoder conv maps
    1:1 onto a VGG16 ``features`` layer (indices 2..28).

    Args:
        model: UNet1 instance whose encoder weights are overwritten.
        device: optional device for the temporary VGG16 model. Defaults to
            the device of ``model``'s parameters. (Previously this read a
            module-level ``device`` global, which broke if the function was
            imported before that global existed.)
    """
    if device is None:
        device = next(model.parameters()).device
    # NOTE(review): `pretrained=True` is deprecated in newer torchvision in
    # favor of `weights=VGG16_Weights.IMAGENET1K_V1`; kept for compatibility.
    vgg16 = models.vgg16(pretrained=True).to(device)
    vgg_features = vgg16.features
    with torch.no_grad():
        # First conv: collapse RGB kernels into one grayscale kernel.
        rgb_weights = vgg_features[0].weight
        gray_weights = rgb_weights.mean(dim=1, keepdim=True)
        model.conv1.double_conv[0].weight.data = gray_weights
        model.conv1.double_conv[0].bias.data = vgg_features[0].bias.data
        # Remaining encoder convs: (unet Sequential, conv index, vgg index).
        mapping = [
            (model.conv1.double_conv, 2, 2),
            (model.conv2.double_conv, 0, 5),
            (model.conv2.double_conv, 2, 7),
            (model.conv3.triple_conv, 0, 10),
            (model.conv3.triple_conv, 2, 12),
            (model.conv3.triple_conv, 4, 14),
            (model.conv4.triple_conv, 0, 17),
            (model.conv4.triple_conv, 2, 19),
            (model.conv4.triple_conv, 4, 21),
            (model.conv5.triple_conv, 0, 24),
            (model.conv5.triple_conv, 2, 26),
            (model.conv5.triple_conv, 4, 28),
        ]
        for block, conv_idx, vgg_idx in mapping:
            block[conv_idx].weight.data = vgg_features[vgg_idx].weight.data
            block[conv_idx].bias.data = vgg_features[vgg_idx].bias.data
def load_model_for_inference(model_path, device):
    """Build a UNet1, load weights from *model_path*, and set eval mode.

    Args:
        model_path: path to a saved state_dict checkpoint.
        device: torch.device the model should live on.

    Returns:
        The ready-to-use UNet1 in evaluation mode.
    """
    net = UNet1(in_channels=1, out_channels=2).to(device)
    state_dict = torch.load(model_path, map_location=device)
    net.load_state_dict(state_dict)
    net.eval()
    return net
def inference(model, l_channel):
    """Run the colorization model on a luminance tensor; return predicted ab.

    Args:
        model: network mapping (B, 1, H, W) float input to (B, 2, H, W).
        l_channel: tensor or array of shape (1, H, W) or (B, 1, H, W).

    Returns:
        numpy array of predicted ab channels, shape (B, 2, H, W).
    """
    model.eval()
    with torch.no_grad():
        # as_tensor avoids the legacy FloatTensor(...) constructor and is a
        # no-op when l_channel is already a float32 tensor.
        l_tensor = torch.as_tensor(l_channel, dtype=torch.float32)
        if l_tensor.dim() == 3:
            l_tensor = l_tensor.unsqueeze(0)  # add batch dimension
        # Run on the model's own device instead of a module-level `device`
        # global, so the function works wherever the model actually lives.
        try:
            target_device = next(model.parameters()).device
        except StopIteration:
            target_device = l_tensor.device  # parameter-free model
        ab_pred = model(l_tensor.to(target_device))
    return ab_pred.cpu().numpy()
def prepare_test_image(img, dim=150):
    """Resize an image to dim x dim and convert it to the LAB2 space.

    Args:
        img: PIL RGB image (converted to BGR ndarray internally) or an
            already-BGR numpy array.
        dim: target square size in pixels.

    Returns:
        Tuple of (L tensor with shape (1, dim, dim), A array (dim, dim),
        B array (dim, dim)) — the chroma arrays come from the input image.
    """
    if isinstance(img, Image.Image):
        # PIL gives RGB; OpenCV works in BGR channel order.
        img = np.array(img)
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    img = cv2.resize(img, (dim, dim))
    height, width = img.shape[:2]
    # BGR layout: channel 2 = red, 1 = green, 0 = blue; flatten to columns.
    red = img[:, :, 2].reshape(-1, 1)
    green = img[:, :, 1].reshape(-1, 1)
    blue = img[:, :, 0].reshape(-1, 1)
    lum, chroma_a, chroma_b = rgb2lab2(red, green, blue)  # convert to LAB2
    lum_map = lum.reshape(height, width, 1)
    lum_tensor = torch.FloatTensor(lum_map).permute(2, 0, 1)  # HWC -> CHW
    return lum_tensor, chroma_a.reshape(height, width), chroma_b.reshape(height, width)
# Use the GPU when available; everything below is moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Checkpoint file (name suggests MAE loss, 30 epochs — TODO confirm).
model_path = "Hyper_U_NET_pytorch-MAE-30Epoch.pth"
# NOTE(review): loaded on every script rerun; Streamlit re-executes this file
# per interaction, so consider caching the model with st.cache_resource.
test_model = load_model_for_inference(model_path, device)
st.markdown("<h1 style='text-align: center; color: #4CAF50;'>Image Colorization Demo</h1>", unsafe_allow_html=True)
st.markdown(
"<p style='text-align: center; color: gray;'>Grayscale bir görüntü yükleyin, model sizin için renklendirsin.</p>",
unsafe_allow_html=True)
st.markdown(
"""
<style>
.css-18e3th9 {padding-top: 2rem;}
div.stButton > button:first-child {
color: white;
border-radius: 10px;
height: 3em;
width: 100%;
font-size: 16px;
border: none;
transition: 0.3s;
}
div.stButton > button:hover {
background-color: #45a049;
color: white;
}
div.stButton > button:active {
background-color: #3e8e41 !important;
color: white !important;
}
div.stButton > button:focus {
box-shadow: none !important;
outline: none !important;
color: white !important;
}
</style>
""",
unsafe_allow_html=True
)
# --- Upload & colorize UI ---------------------------------------------------
with st.container():
    st.markdown("#### 📂 Grayscale Görüntü Yükle")
    # Accept a single JPG/JPEG/PNG image.
    uploaded_file = st.file_uploader("Yüklemek için sürükleyip bırakın", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    # Decode as RGB; grayscale uploads become 3 identical channels.
    img = Image.open(uploaded_file).convert("RGB")
    # NOTE(review): inference runs here — on upload, before the button below
    # is pressed — so the spinner further down covers only rendering.
    l_tensor, A_true, B_true = prepare_test_image(img, dim=150)
    ab_pred = inference(test_model, l_tensor)
    ab_pred = ab_pred.squeeze(0)  # drop batch dim -> (2, H, W)
    A_pred, B_pred = ab_pred[0], ab_pred[1]
    sz0, sz1 = A_pred.shape
    # Flatten each channel into an (N, 1) column, as lab22rgb expects.
    L = l_tensor.squeeze().numpy().reshape(-1, 1)
    A = A_pred.reshape(-1, 1)
    B = B_pred.reshape(-1, 1)
    # NOTE(review): the name B is reused — chroma channel in, blue channel out.
    R, G, B = lab22rgb(L, A, B)
    R = R.reshape(sz0, sz1)
    G = G.reshape(sz0, sz1)
    B = B.reshape(sz0, sz1)
    # cv2.merge builds a BGR image; convert to RGB for display.
    rgb_pred = cv2.merge([B, G, R])
    new_image = cv2.cvtColor(rgb_pred, cv2.COLOR_BGR2RGB)
    # Upscale the 150x150 result back to the upload's original resolution.
    new_image2 = cv2.resize(new_image, (img.width, img.height), interpolation=cv2.INTER_LANCZOS4)
    if st.button("🎨 Renklendir"):
        with st.spinner("Model çalışıyor, lütfen bekleyin..."):
            # Side-by-side: input on the left, colorized output on the right.
            col1, col2 = st.columns(2)
            with col1:
                st.markdown("**Girdi (Grayscale)**")
                st.image(img)
            with col2:
                st.markdown("**Model Çıkışı (Renkli)**")
                st.image(np.array(new_image2))
            st.success("Tamamlandı!")