# Author: bumchik2 — "better errors handling" (commit 12a7f67)
import torch
import torch.nn as nn
import streamlit as st
import os
import numpy as np
from PIL import Image
from torchvision import transforms as tr
from enum import Enum
from sklearn.preprocessing import minmax_scale
from PIL import Image
from io import BytesIO
from huggingface_hub import PyTorchModelHubMixin
class Generator(nn.Module):
    """Fully convolutional image-to-image generator.

    Downsamples 256 -> 64, applies two size-preserving convs, then upsamples
    back to 256. InstanceNorm2d is used after every conv, as described in the
    CycleGAN article. Input and output are 3-channel images of the same size.
    """

    def __init__(self, channels_multiplier: int = 32):
        super().__init__()
        # 7x7 stem keeps spatial size: 256 -> (256 + 6 - 6 - 1) / 1 + 1 = 256
        self.conv1 = nn.Conv2d(3, channels_multiplier, kernel_size=7, stride=1, padding=3)
        self.norm1 = nn.InstanceNorm2d(channels_multiplier)
        # two stride-2 convs halve the resolution twice: 256 -> 128 -> 64
        self.conv2 = nn.Conv2d(channels_multiplier, channels_multiplier * 2, kernel_size=3, stride=2, padding=1)
        self.norm2 = nn.InstanceNorm2d(channels_multiplier * 2)
        self.conv3 = nn.Conv2d(channels_multiplier * 2, channels_multiplier * 4, kernel_size=3, stride=2, padding=1)
        self.norm3 = nn.InstanceNorm2d(channels_multiplier * 4)
        # two size-preserving convs at the bottleneck: (64 + 2 - 2 - 1) / 1 + 1 = 64
        self.conv4 = nn.Conv2d(channels_multiplier * 4, channels_multiplier * 4, kernel_size=3, stride=1, padding=1)
        self.norm4 = nn.InstanceNorm2d(channels_multiplier * 4)
        self.conv5 = nn.Conv2d(channels_multiplier * 4, channels_multiplier * 4, kernel_size=3, stride=1, padding=1)
        self.norm5 = nn.InstanceNorm2d(channels_multiplier * 4)
        # transposed convs restore the resolution: 64 -> 128 -> 256
        self.deconv1 = nn.ConvTranspose2d(channels_multiplier * 4, channels_multiplier * 2, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.denorm1 = nn.InstanceNorm2d(channels_multiplier * 2)
        self.deconv2 = nn.ConvTranspose2d(channels_multiplier * 2, channels_multiplier, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.denorm2 = nn.InstanceNorm2d(channels_multiplier)
        # final 7x7 projection back to 3 channels; spatial size unchanged
        self.convlast = nn.Conv2d(channels_multiplier, 3, kernel_size=7, stride=1, padding=3)
        # Fix: reuse a single activation module instead of allocating a fresh
        # nn.LeakyReLU() on every forward call. LeakyReLU is parameter-free,
        # so the state_dict (and hub checkpoint compatibility) is unchanged.
        self.activation = nn.LeakyReLU()

    def forward(self, x):
        """Map a (N, 3, H, W) batch to an output of identical shape.

        NOTE(review): a LeakyReLU after the last conv is unusual (Tanh is the
        common CycleGAN output activation) but is kept to match the weights
        this model was trained with.
        """
        act = self.activation
        x = act(self.norm1(self.conv1(x)))
        x = act(self.norm2(self.conv2(x)))
        x = act(self.norm3(self.conv3(x)))
        x = act(self.norm4(self.conv4(x)))
        x = act(self.norm5(self.conv5(x)))
        x = act(self.denorm1(self.deconv1(x)))
        x = act(self.denorm2(self.deconv2(x)))
        return act(self.convlast(x))
class Discriminator(nn.Module):
    """Convolutional discriminator mapping a 256x256 image to one probability.

    Four stride-4 convs shrink the spatial size 256 -> 64 -> 16 -> 4 -> 1,
    then a 1x1 conv collapses the channels to a single logit per sample,
    which is flattened and squashed with a sigmoid.
    """

    def __init__(self, channels_multiplier: int = 32):
        super().__init__()
        # 256 -> (256 + 2 - 3 - 1) / 4 + 1 = 64
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=channels_multiplier, kernel_size=4, stride=4, padding=1)
        # 64 -> 16
        self.conv2 = nn.Conv2d(in_channels=channels_multiplier, out_channels=channels_multiplier * 2, kernel_size=4, stride=4, padding=1)
        # 16 -> 4
        self.conv3 = nn.Conv2d(in_channels=channels_multiplier * 2, out_channels=channels_multiplier * 4, kernel_size=4, stride=4, padding=1)
        # 4 -> (4 + 2 - 3 - 1) / 4 + 1 = 1
        self.conv4 = nn.Conv2d(in_channels=channels_multiplier * 4, out_channels=channels_multiplier * 8, kernel_size=4, stride=4, padding=1)
        # 1x1 conv: spatial size unchanged, channels collapsed to 1
        self.conv5 = nn.Conv2d(in_channels=channels_multiplier * 8, out_channels=1, kernel_size=1, stride=1, padding=0)
        # Fix: share one activation module rather than constructing a new
        # nn.LeakyReLU() on each forward pass (parameter-free, so the
        # state_dict is unchanged).
        self.activation = nn.LeakyReLU()

    def forward(self, x):
        """Return a (N, 1) tensor of probabilities for a (N, 3, 256, 256) batch."""
        act = self.activation
        x = act(self.conv1(x))
        x = act(self.conv2(x))
        x = act(self.conv3(x))
        x = act(self.conv4(x))
        x = self.conv5(x)
        # torch.flatten / torch.sigmoid replace the per-call nn.Flatten() /
        # nn.Sigmoid() module construction; the results are identical.
        return torch.sigmoid(torch.flatten(x, start_dim=1))
class CycleGAN(nn.Module, PyTorchModelHubMixin):
    """Two generators and two discriminators for the two-domain cycle setup.

    The HuggingFace mixin provides `from_pretrained` / `push_to_hub` so the
    whole model can be loaded from the hub.
    """

    def __init__(self, channels_multiplier_generator: int = 32,
                 channels_multiplier_discriminator: int = 64):
        super().__init__()
        self.generator_X_to_Y = Generator(channels_multiplier=channels_multiplier_generator)
        self.generator_Y_to_X = Generator(channels_multiplier=channels_multiplier_generator)
        self.discriminator_X = Discriminator(channels_multiplier=channels_multiplier_discriminator)
        self.discriminator_Y = Discriminator(channels_multiplier=channels_multiplier_discriminator)

    def forward(self, x):
        """Cycle x through X -> Y and back Y -> X (reconstruction pass)."""
        return self.generator_Y_to_X(self.generator_X_to_Y(x))

    def forward_Y_to_X(self, x):
        """Cycle x through Y -> X and back X -> Y (reconstruction pass)."""
        return self.generator_X_to_Y(self.generator_Y_to_X(x))
@st.cache_resource  # cache the model across Streamlit reruns
def load_model():
    """Download the pretrained summer<->winter CycleGAN and return it in eval mode."""
    # Fix: from_pretrained is a classmethod — calling it on a freshly built
    # instance constructed (and immediately discarded) a randomly initialized
    # model for nothing. The original constructor args (32 / 64) match the
    # class defaults, so the loaded model is identical.
    model = CycleGAN.from_pretrained("bumchik2/summer-to-winter-model")
    model.eval()
    return model
class Space(Enum):
    """The two image domains handled by the app.

    Per the mode -> space mapping in the main script, A is the summer domain
    and B the winter domain.
    """
    A = 'A'
    B = 'B'
# Per-channel (RGB) normalization statistics for each domain, used by
# get_transform / de_normalize. Presumably computed over the respective
# training sets — TODO confirm.
SPACE_TO_MEAN = {
Space.A: np.array([0.40429478, 0.40832175, 0.3835889]),
Space.B: np.array([0.45099882, 0.42138782, 0.40148178])
}
SPACE_TO_STD = {
Space.A: np.array([0.29130578, 0.25078464, 0.26218044]),
Space.B: np.array([0.29352425, 0.26508255, 0.27024732])
}
def get_transform(space: Space):
    """Build the preprocessing pipeline for *space*.

    Converts the input array to a PIL image, resizes it to 256x256, turns it
    into a tensor, and normalizes with that domain's channel statistics.
    """
    return tr.Compose([
        tr.ToPILImage(),
        tr.Resize((256, 256)),
        tr.ToTensor(),
        tr.Normalize(mean=SPACE_TO_MEAN[space], std=SPACE_TO_STD[space]),
    ])
def de_normalize(img, space: Space):
    """Undo the Normalize transform for *space* and prepare the image for display.

    img: (3, H, W) float array produced by a generator forward pass.
    Returns an (H, W, 3) integer array with values in [0, 255].
    """
    flat = img.reshape(3, -1)
    # Fix: the inverse of Normalize((x - mean) / std) is x * std + mean.
    # The original computed (x + mean) * std; because each channel is min-max
    # rescaled below, any increasing affine map per channel yields the same
    # final output, so this correction is behavior-preserving — but the
    # correct order no longer misleads the reader.
    restored = flat * SPACE_TO_STD[space][:, None] + SPACE_TO_MEAN[space][:, None]
    # Per-channel min-max scaling to [0, 1] — the NumPy equivalent of
    # sklearn.preprocessing.minmax_scale(..., feature_range=(0, 1), axis=1).
    low = restored.min(axis=1, keepdims=True)
    span = restored.max(axis=1, keepdims=True) - low
    span[span == 0.0] = 1.0  # constant channel: avoid division by zero, as sklearn does
    scaled = (restored - low) / span
    return (scaled.reshape(*img.shape).transpose(1, 2, 0) * 255).astype(int)
# Load the (cached) CycleGAN once per Streamlit session.
model: CycleGAN = load_model()
# Direction selector ("Choose how you want to transform the image"); the
# option strings are matched verbatim further down to pick the generator.
mode = st.radio(
'Выберите, как вы хотите преобразовать изображение',
['summer to winter', 'winter to summer'],
captions=[
'summer to winter',
'winter to summer'
],
)
# Single-file image uploader ("Upload a picture").
uploaded_file = st.file_uploader('Загрузите картинку', accept_multiple_files=False)
if uploaded_file is not None:
    # Domain A is the summer space, B the winter space; swap for the
    # reverse direction.
    space_source, space_target = Space.A, Space.B
    if mode == 'winter to summer':
        space_source, space_target = Space.B, Space.A
    try:
        bytes_data = uploaded_file.read()
        image = Image.open(BytesIO(bytes_data))
        st.write('Исходное изображение:')
        st.image(image)
    except Exception:
        # Unreadable upload: report it ("Could not recognize the image") and stop.
        st.write('Не удалось корректно распознать изображение')
    else:
        transform = get_transform(space=space_source)
        image_array = np.array(image)
        if image_array.ndim == 2:
            # Fix: grayscale uploads arrive as (H, W) and would crash the
            # 3-channel Normalize / Conv2d pipeline — replicate the channel.
            image_array = np.stack([image_array] * 3, axis=-1)
        elif image_array.shape[-1] == 4:
            # Drop the alpha channel of RGBA uploads.
            image_array = image_array[:, :, :3]
        image_transformed = transform(image_array)
        with torch.no_grad():
            # Fix: add an explicit batch dimension — unbatched (3, H, W)
            # input to Conv2d is only accepted on newer PyTorch versions.
            batch = image_transformed.unsqueeze(0)
            if mode == 'summer to winter':
                generated = model.generator_X_to_Y(batch)
            else:
                generated = model.generator_Y_to_X(batch)
            result = generated.squeeze(0).numpy()
        result_image = de_normalize(
            result,
            space_target
        )
        st.write(f'Результат {mode}:')
        st.image(result_image)