"""Streamlit app: summer <-> winter photo restyling with a pretrained CycleGAN.

The user uploads a picture, chooses a direction (summer to winter or winter
to summer), and the matching generator of a CycleGAN downloaded from the
Hugging Face Hub produces the restyled image.
"""

import os

import numpy as np
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from enum import Enum
from io import BytesIO

from huggingface_hub import PyTorchModelHubMixin
from PIL import Image
from sklearn.preprocessing import minmax_scale
from torchvision import transforms as tr


class Generator(nn.Module):
    """CycleGAN generator.

    Only convolutional layers are used: height x width is decreased
    256 -> 128 -> 64, then increased back to the original size with
    transposed convolutions.  InstanceNorm2d is used everywhere, as
    described in the CycleGAN article; activations are LeakyReLU.
    """

    def __init__(self, channels_multiplier: int = 32):
        """channels_multiplier: base number of feature channels."""
        super(Generator, self).__init__()
        # Size bookkeeping (for 256x256 input): out = (in + 2p - k) / s + 1.
        self.conv1 = nn.Conv2d(3, channels_multiplier, kernel_size=7, stride=1, padding=3)  # 256 -> 256
        self.norm1 = nn.InstanceNorm2d(channels_multiplier)
        self.conv2 = nn.Conv2d(channels_multiplier, channels_multiplier * 2, kernel_size=3, stride=2, padding=1)  # 256 -> 128
        self.norm2 = nn.InstanceNorm2d(channels_multiplier * 2)
        self.conv3 = nn.Conv2d(channels_multiplier * 2, channels_multiplier * 4, kernel_size=3, stride=2, padding=1)  # 128 -> 64
        self.norm3 = nn.InstanceNorm2d(channels_multiplier * 4)
        self.conv4 = nn.Conv2d(channels_multiplier * 4, channels_multiplier * 4, kernel_size=3, stride=1, padding=1)  # 64 -> 64, size unchanged
        self.norm4 = nn.InstanceNorm2d(channels_multiplier * 4)
        self.conv5 = nn.Conv2d(channels_multiplier * 4, channels_multiplier * 4, kernel_size=3, stride=1, padding=1)  # 64 -> 64, size unchanged
        self.norm5 = nn.InstanceNorm2d(channels_multiplier * 4)
        self.deconv1 = nn.ConvTranspose2d(channels_multiplier * 4, channels_multiplier * 2, kernel_size=3, stride=2, padding=1, output_padding=1)  # 64 -> 128
        self.denorm1 = nn.InstanceNorm2d(channels_multiplier * 2)
        self.deconv2 = nn.ConvTranspose2d(channels_multiplier * 2, channels_multiplier, kernel_size=3, stride=2, padding=1, output_padding=1)  # 128 -> 256
        self.denorm2 = nn.InstanceNorm2d(channels_multiplier)
        self.convlast = nn.Conv2d(channels_multiplier, 3, kernel_size=7, stride=1, padding=3)  # 256 -> 256, nothing changes in the end

    def forward(self, x):
        """Map an image tensor to a restyled image tensor of the same shape."""
        # F.leaky_relu uses the same default negative_slope (0.01) as
        # nn.LeakyReLU(), so this matches the trained model exactly while
        # avoiding re-instantiating a module on every call.
        x = F.leaky_relu(self.norm1(self.conv1(x)))
        x = F.leaky_relu(self.norm2(self.conv2(x)))
        x = F.leaky_relu(self.norm3(self.conv3(x)))
        x = F.leaky_relu(self.norm4(self.conv4(x)))
        x = F.leaky_relu(self.norm5(self.conv5(x)))
        x = F.leaky_relu(self.denorm1(self.deconv1(x)))
        x = F.leaky_relu(self.denorm2(self.deconv2(x)))
        # NOTE(review): a final LeakyReLU (rather than tanh) is unusual for a
        # CycleGAN output layer, but it is what the published weights were
        # trained with, so it must stay.
        x = F.leaky_relu(self.convlast(x))
        return x


class Discriminator(nn.Module):
    """PatchGAN-style discriminator: 256x256 image -> single probability."""

    def __init__(self, channels_multiplier: int = 32):
        """channels_multiplier: base number of feature channels."""
        super(Discriminator, self).__init__()
        # Each stride-4 convolution divides height/width by 4: 256 -> 64 -> 16 -> 4 -> 1.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=channels_multiplier, kernel_size=4, stride=4, padding=1)
        self.conv2 = nn.Conv2d(in_channels=channels_multiplier, out_channels=channels_multiplier * 2, kernel_size=4, stride=4, padding=1)
        self.conv3 = nn.Conv2d(in_channels=channels_multiplier * 2, out_channels=channels_multiplier * 4, kernel_size=4, stride=4, padding=1)
        self.conv4 = nn.Conv2d(in_channels=channels_multiplier * 4, out_channels=channels_multiplier * 8, kernel_size=4, stride=4, padding=1)
        # 1x1 convolution: height/width stay at 1, channels collapse to 1.
        self.conv5 = nn.Conv2d(in_channels=channels_multiplier * 8, out_channels=1, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return a (batch, 1) tensor of real/fake probabilities in (0, 1)."""
        x = F.leaky_relu(self.conv1(x))
        x = F.leaky_relu(self.conv2(x))
        x = F.leaky_relu(self.conv3(x))
        x = F.leaky_relu(self.conv4(x))
        x = self.conv5(x)
        x = x.flatten(start_dim=1)  # same as nn.Flatten(): (N, 1, 1, 1) -> (N, 1)
        return torch.sigmoid(x)


class CycleGAN(nn.Module, PyTorchModelHubMixin):
    """Two generators (X->Y, Y->X) plus a discriminator per domain.

    The HF-Hub mixin makes the whole bundle loadable with ``from_pretrained``.
    """

    def __init__(
        self,
        channels_multiplier_generator: int = 32,
        channels_multiplier_discriminator: int = 64
    ):
        super(CycleGAN, self).__init__()
        self.generator_X_to_Y = Generator(channels_multiplier=channels_multiplier_generator)
        self.generator_Y_to_X = Generator(channels_multiplier=channels_multiplier_generator)
        self.discriminator_X = Discriminator(channels_multiplier=channels_multiplier_discriminator)
        self.discriminator_Y = Discriminator(channels_multiplier=channels_multiplier_discriminator)

    def forward(self, x):
        """Cycle X -> Y -> X: returns the reconstruction of ``x``."""
        fake = self.generator_X_to_Y(x)
        return self.generator_Y_to_X(fake)

    def forward_Y_to_X(self, x):
        """Cycle Y -> X -> Y: returns the reconstruction of ``x``."""
        fake = self.generator_Y_to_X(x)
        return self.generator_X_to_Y(fake)


@st.cache_resource  # cache the model between Streamlit reruns
def load_model():
    """Download the pretrained CycleGAN from the HF Hub and switch to eval mode.

    ``from_pretrained`` is a classmethod of ``PyTorchModelHubMixin``; the
    original code called it on a freshly constructed instance, which built a
    model only to throw it away.  Calling it on the class is equivalent.
    """
    model = CycleGAN.from_pretrained("bumchik2/summer-to-winter-model")
    model.eval()
    return model


class Space(Enum):
    # Presumably A = summer domain, B = winter domain (inferred from the
    # model name and usage below) — TODO confirm against the training setup.
    A = 'A'
    B = 'B'


# Per-channel normalization statistics of the two training domains.
SPACE_TO_MEAN = {
    Space.A: np.array([0.40429478, 0.40832175, 0.3835889]),
    Space.B: np.array([0.45099882, 0.42138782, 0.40148178])
}

SPACE_TO_STD = {
    Space.A: np.array([0.29130578, 0.25078464, 0.26218044]),
    Space.B: np.array([0.29352425, 0.26508255, 0.27024732])
}


def get_transform(space: Space):
    """Build the preprocessing pipeline for the given source domain."""
    test_transform = tr.Compose([
        tr.ToPILImage(),
        tr.Resize((256, 256)),
        tr.ToTensor(),
        tr.Normalize(mean=SPACE_TO_MEAN[space], std=SPACE_TO_STD[space])
    ])
    return test_transform


def de_normalize(img, space: Space):
    """Invert the normalization of ``img`` (CHW) and return a uint8 HWC image.

    Bug fix: the inverse of ``Normalize`` is ``img * std + mean`` — the
    original code applied the operations in the wrong order
    (``(img + mean) * std``).  The per-channel min-max rescale below maps each
    channel onto [0, 1] regardless of any per-channel affine transform, so the
    displayed result is unchanged, but the formula is now actually the inverse.
    Also, ``.astype(int)`` produced int64, which PIL/st.image cannot render;
    uint8 is the expected dtype and the values already lie in [0, 255].
    """
    flat = img.reshape(3, -1) * SPACE_TO_STD[space][:, None] + SPACE_TO_MEAN[space][:, None]
    scaled = minmax_scale(flat, feature_range=(0., 1.), axis=1)
    return (scaled.reshape(*img.shape).transpose(1, 2, 0) * 255).astype(np.uint8)


model: CycleGAN = load_model()

mode = st.radio(
    'Выберите, как вы хотите преобразовать изображение',
    ['summer to winter', 'winter to summer'],
    captions=[
        'summer to winter',
        'winter to summer'
    ],
)

uploaded_file = st.file_uploader('Загрузите картинку', accept_multiple_files=False)

if uploaded_file is not None:
    space_source, space_target = Space.A, Space.B
    if mode == 'winter to summer':
        space_source, space_target = Space.B, Space.A

    try:
        bytes_data = uploaded_file.read()
        image = Image.open(BytesIO(bytes_data))
        st.write('Исходное изображение:')
        st.image(image)
    except Exception:
        st.write('Не удалось корректно распознать изображение')
    else:
        transform = get_transform(space=space_source)
        image_array = np.array(image)
        if image_array.shape[-1] == 4:
            # Drop the alpha channel of RGBA uploads; the model expects 3 channels.
            image_array = image_array[:, :, :3]
        image_transformed = transform(image_array)
        with torch.no_grad():
            # NOTE(review): the tensor is fed without a batch dimension;
            # Conv2d/InstanceNorm2d accept unbatched (C, H, W) input in recent
            # PyTorch — confirm the deployed torch version supports this.
            if mode == 'summer to winter':
                result = model.generator_X_to_Y(image_transformed).numpy()
            else:
                result = model.generator_Y_to_X(image_transformed).numpy()
        result_image = de_normalize(result, space_target)
        st.write(f'Результат {mode}:')
        st.image(result_image)