Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import time | |
| import cv2 | |
| import numpy as np | |
| # model part | |
| import json | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision import datasets, transforms as tr | |
| from torchvision.transforms import v2 | |
| from sklearn.preprocessing import minmax_scale | |
| from collections import OrderedDict | |
# Initialise per-session state only once. Streamlit re-executes this whole
# script on every widget interaction, so unconditional assignments here would
# reset the state on each rerun (e.g. the ``calls`` counter could never grow).
if 'image' not in st.session_state:
    st.session_state.image = None
if 'calls' not in st.session_state:
    st.session_state.calls = 0
def get_transforms(mean, std):
    """Build the model's input transform and a display-side inverse.

    Args:
        mean: per-channel normalisation means, in the [0, 1] pixel scale
            produced by ``ToTensor`` (any sequence / array of floats).
        std: per-channel normalisation standard deviations, same scale.

    Returns:
        ``(val_transform, de_normalize)`` where

        * ``val_transform`` takes an HxWxC uint8 image array (the app passes
          cv2-decoded images), resizes its shorter side to 256, converts it to
          a float CHW tensor in [0, 1] and normalises it with ``mean``/``std``;
        * ``de_normalize`` takes a CHW tensor or array, undoes the
          normalisation, min-max rescales every channel independently to
          [0, 1] and returns an HxWxC float array suitable for ``st.image``.
    """
    # Accept plain lists as well as arrays; the [:, None] indexing below needs ndarrays.
    mean = np.asarray(mean, dtype=np.float64)
    std = np.asarray(std, dtype=np.float64)

    val_transform = tr.Compose([
        tr.ToPILImage(),
        v2.Resize(size=256),
        tr.ToTensor(),
        tr.Normalize(mean=mean, std=std),
    ])

    def de_normalize(img):
        if isinstance(img, torch.Tensor):
            # detach() so this also works on tensors that track gradients
            # (numpy conversion would otherwise raise).
            array = img.detach().cpu().numpy()
        else:
            array = np.asarray(img)
        channels = array.shape[0]
        # Invert Normalize: x_norm = (x - mean) / std  =>  x = x_norm * std + mean.
        flat = array.reshape(channels, -1) * std[:, None] + mean[:, None]
        # axis=1 rescales each row (= one channel) to [0, 1] independently.
        # This is a positive affine stretch per channel, so the output equals
        # the min-max scaling of the raw network output; the mean/std step
        # above only matters if this rescale is ever removed.
        scaled = minmax_scale(flat, feature_range=(0., 1.), axis=1)
        return scaled.reshape(array.shape).transpose(1, 2, 0)

    return val_transform, de_normalize
class Conv7Stride1(nn.Module):
    """Reflection-padded 7x7 convolution with stride 1 (a "c7s1-k" layer).

    With ``use_norm=True`` the convolution is followed by InstanceNorm2d and
    ReLU. With ``use_norm=False`` it is followed by Tanh instead; this file uses
    that variant as the generator's final layer. Padding by 3 keeps the
    spatial size unchanged.
    """

    def __init__(self, in_channels, out_channels, use_norm=True):
        super().__init__()
        layers = [
            ('pad', nn.ReflectionPad2d(3)),
            ('conv', torch.nn.Conv2d(in_channels, out_channels, kernel_size=7)),
        ]
        if use_norm:
            layers += [
                ('norm', nn.InstanceNorm2d(out_channels)),
                ('relu', nn.ReLU()),
            ]
        else:
            layers.append(('tanh', nn.Tanh()))
        # Submodule names are part of the state_dict keys; keep them stable.
        self.model = nn.Sequential(OrderedDict(layers))

    def forward(self, x):
        return self.model(x)
class Down(nn.Module):
    """Downsampling layer "dk": 3x3 conv, stride 2, k//2 -> k channels.

    With padding 1 the spatial size becomes ceil(H/2) x ceil(W/2). The conv is
    followed by InstanceNorm2d and ReLU.
    """

    def __init__(self, k):
        super().__init__()
        in_ch, out_ch = k // 2, k
        # Keyword order is preserved, so the submodule order is conv -> norm -> relu.
        self.model = nn.Sequential(OrderedDict(
            conv=torch.nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1),
            norm=nn.InstanceNorm2d(out_ch),
            relu=nn.ReLU(),
        ))

    def forward(self, x):
        return self.model(x)
class ResBlock(nn.Module):
    """Residual block with k channels: ``forward(x) = x + model(x)``.

    ``model`` is two identical sub-blocks, each ReflectionPad(1) -> 3x3 conv ->
    BatchNorm2d -> ReLU. When ``use_dropout`` is True, Dropout(0.5) sits
    between them.

    NOTE: inside each sub-block the BatchNorm2d layer is registered under the
    name 'dropout'. The name is misleading, but it is part of the state_dict
    keys (e.g. ``model.block1.dropout.running_mean``). Renaming it would break
    loading of existing checkpoints.
    """

    def __init__(self, k, use_dropout=False):
        super().__init__()
        # Plain list kept for attribute compatibility; the modules are
        # registered through self.model below.
        self.blocks = [self._make_block(k) for _ in range(2)]
        stages = [('block1', self.blocks[0])]
        if use_dropout:
            stages.append(('dropout', nn.Dropout(0.5)))
        stages.append(('block2', self.blocks[1]))
        self.model = nn.Sequential(OrderedDict(stages))

    @staticmethod
    def _make_block(k):
        # One pad -> conv -> norm -> relu unit; see the class note on 'dropout'.
        return nn.Sequential(OrderedDict([
            ('pad', nn.ReflectionPad2d(1)),
            ('conv', torch.nn.Conv2d(k, k, kernel_size=3)),
            ('dropout', nn.BatchNorm2d(k)),
            ('relu', nn.ReLU()),
        ]))

    def forward(self, x):
        return x + self.model(x)
class Up(nn.Module):
    """Upsampling layer "uk": transposed 3x3 conv, stride 2, 2k -> k channels.

    With ``padding=1`` and ``output_padding=1`` the output is exactly twice the
    input height and width. The conv is followed by InstanceNorm2d and ReLU.
    """

    def __init__(self, k):
        super().__init__()
        upsample = nn.ConvTranspose2d(
            in_channels=2 * k,
            out_channels=k,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
        )
        norm = nn.InstanceNorm2d(k)
        act = nn.ReLU()
        self.model = nn.Sequential(OrderedDict(
            [('conv_transpose', upsample), ('norm', norm), ('relu', act)]
        ))

    def forward(self, x):
        return self.model(x)
class ResGenerator(nn.Module):
    """ResNet-style image-to-image generator.

    Layer sequence: c7s1-64, d128, d256, ``res_blocks`` x R256, u128, u64,
    c7s1-3 (Tanh output).

    The residual stack is registered twice: as ``self.residual_blocks`` and as
    ``model.res_blocks``. PyTorch's ``state_dict`` does not de-duplicate shared
    submodules, so both key prefixes exist in saved checkpoints. Keep both
    registrations (in this order) so strict ``load_state_dict`` still matches.
    """

    def __init__(self, res_blocks=9, use_dropout=False):
        super().__init__()
        residual = OrderedDict()
        for idx in range(1, res_blocks + 1):
            residual[f'R256_{idx}'] = ResBlock(256, use_dropout=use_dropout)
        self.residual_blocks = nn.Sequential(residual)

        stages = OrderedDict()
        stages['c7s1-64'] = Conv7Stride1(3, 64)
        stages['d128'] = Down(128)
        stages['d256'] = Down(256)
        stages['res_blocks'] = self.residual_blocks
        stages['u128'] = Up(128)
        stages['u64'] = Up(64)
        stages['c7s1-3'] = Conv7Stride1(64, 3, use_norm=False)
        self.model = nn.Sequential(stages)

    def forward(self, x):
        return self.model(x)
class ConvForDisc(nn.Module):
    """Discriminator stage: 4x4 conv (padding 1), optional InstanceNorm, LeakyReLU(0.2).

    ``channels`` is either ``(out,)``, meaning ``in = out // 2``, or
    ``(in, out, ...)``; any further values are ignored. ``stride`` defaults
    to 2. ``use_norm=False`` omits the InstanceNorm2d layer.
    """

    def __init__(self, *channels, stride=2, use_norm=True):
        super().__init__()
        if len(channels) == 1:
            (out_ch,) = channels
            in_ch = out_ch // 2
        else:
            in_ch, out_ch = channels[0], channels[1]

        steps = [('conv', nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=stride, padding=1))]
        if use_norm:
            steps.append(('norm', nn.InstanceNorm2d(out_ch)))
        # In-place LeakyReLU is safe here: it only overwrites the conv/norm output.
        steps.append(('relu', nn.LeakyReLU(0.2, True)))
        self.model = nn.Sequential(OrderedDict(steps))

    def forward(self, x):
        return self.model(x)
class ConvDiscriminator(nn.Module):
    """Convolutional discriminator: C64 -> C128 -> C256 -> C512 -> 1-channel conv.

    C64 has no normalisation, and C512 uses stride 1. ``forward`` returns the
    final one-channel map of logits flattened to shape ``(batch, N)``.
    """

    def __init__(self):
        super().__init__()
        stages = OrderedDict()
        stages['C64'] = ConvForDisc(3, 64, use_norm=False)
        stages['C128'] = ConvForDisc(128)
        stages['C256'] = ConvForDisc(256)
        stages['C512'] = ConvForDisc(512, stride=1)
        stages['conv1channel'] = nn.Conv2d(512, 1, kernel_size=4, padding=1)
        self.model = nn.Sequential(stages)

    def forward(self, x):
        # predicts logits
        logit_map = self.model(x)
        return logit_map.flatten(start_dim=1)
class CycleGAN(nn.Module):
    """Container for a CycleGAN's two generators and two discriminators.

    Defines no ``forward``; callers use the generators directly. The app runs
    ``a2b_generator`` for 'day2night' and ``b2a_generator`` for 'night2day'.

    Args:
        res_blocks: number of residual blocks in each generator.
        use_dropout: whether each generator's residual blocks use Dropout(0.5).
    """

    def __init__(self, res_blocks=9, use_dropout=False):
        super().__init__()
        # Forward the constructor arguments; they were previously ignored and
        # hard-coded to 9 / False. Registration order is kept unchanged.
        self.a2b_generator = ResGenerator(res_blocks=res_blocks, use_dropout=use_dropout)
        self.a_discriminator = ConvDiscriminator()
        self.b2a_generator = ResGenerator(res_blocks=res_blocks, use_dropout=use_dropout)
        self.b_discriminator = ConvDiscriminator()
@st.cache_resource
def load_model(checkpoint_path='cycle_gan#21.pt'):
    """Build a CycleGAN and load its trained weights onto the CPU.

    Wrapped in ``st.cache_resource`` so the checkpoint is read, unpickled and
    loaded once per server process. Previously this happened on every
    Streamlit rerun, i.e. on every widget interaction.

    Args:
        checkpoint_path: path to a ``torch.save``d dict containing a
            ``'model_state_dict'`` entry.

    Returns:
        The loaded ``CycleGAN`` instance. It is shared across sessions because
        it is cached; callers must not mutate it.
    """
    # SECURITY: weights_only=False unpickles arbitrary objects. This is only
    # acceptable for a trusted checkpoint (presumably the file shipped with
    # the app); never pass a user-supplied path here.
    checkpoint = torch.load(checkpoint_path, weights_only=False,
                            map_location=torch.device('cpu'))
    model = CycleGAN()
    model.load_state_dict(checkpoint['model_state_dict'])
    # NOTE(review): the model is left in training mode. The BatchNorm2d layers
    # in ResBlock therefore normalise with per-image batch statistics (and
    # update their running stats on each forward). Confirm whether
    # model.eval() is intended.
    return model
# Per-channel normalisation statistics for each domain, in the [0, 1] scale
# produced by ToTensor. They are passed to get_transforms(): the day stats for
# day2night input, the night stats for night2day input.
# NOTE(review): the channel order presumably matches the training images
# (the app feeds cv2-decoded BGR arrays) — confirm.
# NOTE(review): the *night* means (~0.5) are much brighter than the *day*
# means (~0.19), which looks inverted — confirm these were not swapped.
mean_day = np.array([0.18620284, 0.18614635, 0.20172116])
std_day = np.array([0.16982935, 0.14963816, 0.14965146])
mean_night = np.array([0.46207718, 0.52259593, 0.54372674])
std_night = np.array([0.21945059, 0.20839803, 0.2328357])
# ---- Front end: Streamlit page layout ----
st.markdown(
    "<h1 style='text-align: center;'>Change daytime!</h1>",
    unsafe_allow_html=True,
)
def add_calls():
    """Increment the per-session ``calls`` counter and echo it to the page."""
    st.session_state.calls = st.session_state.calls + 1
    st.write(f'{st.session_state.calls=}')
def _render_conversion(generator_name, mean, std):
    """Show the session image beside its translation by one CycleGAN generator.

    Args:
        generator_name: attribute of the loaded CycleGAN to run
            ('a2b_generator' or 'b2a_generator').
        mean: per-channel input normalisation means for ``get_transforms``.
        std: per-channel input normalisation stds for ``get_transforms``.
    """
    image = st.session_state.image
    if image is None:
        # Nothing uploaded yet; the transform would fail on None.
        st.warning("Upload an image first.")
        return
    col1, col2 = st.columns(2)
    with col1:
        st.write("Left Column")
        # Use the session image. The module-global ``opencv_image`` that was
        # used here only exists after the main script has decoded an upload.
        st.image(image, channels="BGR", use_container_width=True)
    with col2:
        st.write("Center Column")
        model = load_model()
        generator = getattr(model, generator_name)
        with torch.no_grad():
            transform, de_norm = get_transforms(mean, std)
            batch = transform(image)[None, :, :, :]
            batch_tr = generator(batch)
            img_tr = de_norm(batch_tr[0, :, :, :])
        st.write(img_tr.shape)
        st.image([image, img_tr], channels="BGR", use_container_width=True, clamp=True)


def convert_day2night():
    """Render the session image next to its day -> night translation."""
    _render_conversion('a2b_generator', mean_day, std_day)


def convert_night2day():
    """Render the session image next to its night -> day translation."""
    _render_conversion('b2a_generator', mean_night, std_night)
def zero_calls():
    """Reset the per-session ``calls`` counter to zero."""
    st.session_state['calls'] = 0
# Main page: pick a direction, upload a JPEG, show the original and the translation.
st.session_state.option = st.selectbox('day2night OR night2day', ['day2night', 'night2day'])
uploaded_file = st.file_uploader("Choose a image file", type="jpg")
if uploaded_file is not None:
    # getvalue() returns the whole upload regardless of the buffer position,
    # unlike read(), which returns nothing once the stream has been consumed.
    file_bytes = np.asarray(bytearray(uploaded_file.getvalue()), dtype=np.uint8)
    # Decode to an OpenCV colour image (BGR channel order, uint8).
    opencv_image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    if opencv_image is None:
        # imdecode signals undecodable data by returning None rather than raising.
        st.error("Could not decode the uploaded file as an image.")
    else:
        st.session_state.image = opencv_image
        image = st.session_state.image
        # Choose the generator and input statistics for the selected direction
        # (this replaces two copy-pasted branches).
        if st.session_state.option == 'day2night':
            generator_name, mean, std = 'a2b_generator', mean_day, std_day
        else:
            generator_name, mean, std = 'b2a_generator', mean_night, std_night
        col1, col2 = st.columns(2)
        with col1:
            st.write("Original")
            st.image(opencv_image, channels="BGR", use_container_width=True)
        with col2:
            st.write("Transformed")
            model = load_model()
            generator = getattr(model, generator_name)
            with torch.no_grad():
                transform, de_norm = get_transforms(mean, std)
                batch = transform(image)[None, :, :, :]
                batch_tr = generator(batch)
                img_tr = de_norm(batch_tr[0, :, :, :])
            st.image(img_tr, channels="BGR", use_container_width=True, clamp=True)