import torch
import torch.nn as nn
import streamlit as st
import os
import numpy as np
from PIL import Image
from torchvision import transforms as tr
from enum import Enum
from sklearn.preprocessing import minmax_scale
from io import BytesIO
from huggingface_hub import PyTorchModelHubMixin

class Generator(nn.Module):
    def __init__(
        self,
        channels_multiplier: int = 32
    ):
        super(Generator, self).__init__()
        # we just want to use convolutional layers: decrease height x width, then increase it back to the original size
        # use InstanceNorm2d everywhere, as described in the article
        self.conv1 = nn.Conv2d(3, channels_multiplier, kernel_size=7, stride=1, padding=3)  # 256 -> (256 + 6 - 6 - 1) / 1 + 1 = 256
        self.norm1 = nn.InstanceNorm2d(channels_multiplier)
        self.conv2 = nn.Conv2d(channels_multiplier, channels_multiplier * 2, kernel_size=3, stride=2, padding=1)  # 256 -> (256 + 2 - 2 - 1) / 2 + 1 = 128
        self.norm2 = nn.InstanceNorm2d(channels_multiplier * 2)
        self.conv3 = nn.Conv2d(channels_multiplier * 2, channels_multiplier * 4, kernel_size=3, stride=2, padding=1)  # 128 -> (128 + 2 - 2 - 1) / 2 + 1 = 64
        self.norm3 = nn.InstanceNorm2d(channels_multiplier * 4)
        self.conv4 = nn.Conv2d(channels_multiplier * 4, channels_multiplier * 4, kernel_size=3, stride=1, padding=1)  # (64 + 2 - 2 - 1) / 1 + 1 = 64 - we don't change size here
        self.norm4 = nn.InstanceNorm2d(channels_multiplier * 4)
        self.conv5 = nn.Conv2d(channels_multiplier * 4, channels_multiplier * 4, kernel_size=3, stride=1, padding=1)
        self.norm5 = nn.InstanceNorm2d(channels_multiplier * 4)
        self.deconv1 = nn.ConvTranspose2d(channels_multiplier * 4, channels_multiplier * 2, kernel_size=3, stride=2, padding=1, output_padding=1)  # 64 -> 128
        self.denorm1 = nn.InstanceNorm2d(channels_multiplier * 2)
        self.deconv2 = nn.ConvTranspose2d(channels_multiplier * 2, channels_multiplier, kernel_size=3, stride=2, padding=1, output_padding=1)  # 128 -> 256
        self.denorm2 = nn.InstanceNorm2d(channels_multiplier)
        self.convlast = nn.Conv2d(channels_multiplier, 3, kernel_size=7, stride=1, padding=3)  # 256 -> 256, so nothing changes in the end
    def forward(self, x):
        x = self.conv1(x)
        x = self.norm1(x)
        x = nn.LeakyReLU()(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = nn.LeakyReLU()(x)
        x = self.conv3(x)
        x = self.norm3(x)
        x = nn.LeakyReLU()(x)
        x = self.conv4(x)
        x = self.norm4(x)
        x = nn.LeakyReLU()(x)
        x = self.conv5(x)
        x = self.norm5(x)
        x = nn.LeakyReLU()(x)
        x = self.deconv1(x)
        x = self.denorm1(x)
        x = nn.LeakyReLU()(x)
        x = self.deconv2(x)
        x = self.denorm2(x)
        x = nn.LeakyReLU()(x)
        x = self.convlast(x)
        x = nn.LeakyReLU()(x)
        return x

class Discriminator(nn.Module):
    def __init__(
        self,
        channels_multiplier: int = 32
    ):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=channels_multiplier, kernel_size=4, stride=4, padding=1)  # 256 -> (256 + 2 - 3 - 1) / 4 + 1 = 64
        self.conv2 = nn.Conv2d(in_channels=channels_multiplier, out_channels=channels_multiplier * 2, kernel_size=4, stride=4, padding=1)  # 64 -> 16
        self.conv3 = nn.Conv2d(in_channels=channels_multiplier * 2, out_channels=channels_multiplier * 4, kernel_size=4, stride=4, padding=1)  # 16 -> 4
        self.conv4 = nn.Conv2d(in_channels=channels_multiplier * 4, out_channels=channels_multiplier * 8, kernel_size=4, stride=4, padding=1)  # (4 + 2 - 3 - 1) / 4 + 1 = 1
        self.conv5 = nn.Conv2d(in_channels=channels_multiplier * 8, out_channels=1, kernel_size=1, stride=1, padding=0)  # height, width don't change here: (1 - 0 - 1) / 1 + 1 = 1
    def forward(self, x):
        x = self.conv1(x)
        x = nn.LeakyReLU()(x)
        x = self.conv2(x)
        x = nn.LeakyReLU()(x)
        x = self.conv3(x)
        x = nn.LeakyReLU()(x)
        x = self.conv4(x)
        x = nn.LeakyReLU()(x)
        x = self.conv5(x)
        x = nn.Flatten()(x)
        x = nn.Sigmoid()(x)
        return x
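
# Shape sketch (illustrative only, not used by the app): with the default
# channels_multiplier the discriminator maps a (N, 3, 256, 256) batch to an
# (N, 1) real/fake score, e.g.
#     Discriminator()(torch.randn(2, 3, 256, 256)).shape  # torch.Size([2, 1])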

class CycleGAN(
    nn.Module,
    PyTorchModelHubMixin
):
    def __init__(
        self,
        channels_multiplier_generator: int = 32,
        channels_multiplier_discriminator: int = 64
    ):
        super(CycleGAN, self).__init__()
        self.generator_X_to_Y = Generator(channels_multiplier=channels_multiplier_generator)
        self.generator_Y_to_X = Generator(channels_multiplier=channels_multiplier_generator)
        self.discriminator_X = Discriminator(channels_multiplier=channels_multiplier_discriminator)
        self.discriminator_Y = Discriminator(channels_multiplier=channels_multiplier_discriminator)

    def forward(
        self,
        x
    ):
        fake = self.generator_X_to_Y(x)
        return self.generator_Y_to_X(fake)

    def forward_Y_to_X(
        self,
        x
    ):
        fake = self.generator_Y_to_X(x)
        return self.generator_X_to_Y(fake)
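
# Illustrative usage (a minimal sketch; the Streamlit flow below calls the individual
# generators directly): calling the model runs the full X -> Y -> X cycle, so after
# training the output should roughly reconstruct the input.
#     model = CycleGAN()
#     x = torch.randn(1, 3, 256, 256)
#     x_cycled = model(x)                 # generator_X_to_Y followed by generator_Y_to_X
#     y_cycled = model.forward_Y_to_X(x)  # the opposite cycle, Y -> X -> Y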

# caching: load the pretrained weights once and reuse the model across reruns
@st.cache_resource
def load_model():
    model = CycleGAN.from_pretrained(
        "bumchik2/summer-to-winter-model",
        channels_multiplier_generator=32,
        channels_multiplier_discriminator=64,
    )
    model.eval()
    return model

class Space(Enum):
    A = 'A'
    B = 'B'


SPACE_TO_MEAN = {
    Space.A: np.array([0.40429478, 0.40832175, 0.3835889]),
    Space.B: np.array([0.45099882, 0.42138782, 0.40148178])
}

SPACE_TO_STD = {
    Space.A: np.array([0.29130578, 0.25078464, 0.26218044]),
    Space.B: np.array([0.29352425, 0.26508255, 0.27024732])
}
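
# Per-channel RGB statistics used by Normalize / de_normalize below; judging by the
# mode mapping further down, Space.A is the summer domain and Space.B the winter
# domain (presumably computed over the respective training sets).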

def get_transform(space: Space):
    test_transform = tr.Compose([
        tr.ToPILImage(),
        tr.Resize((256, 256)),
        tr.ToTensor(),
        tr.Normalize(mean=SPACE_TO_MEAN[space], std=SPACE_TO_STD[space])
    ])
    return test_transform

def de_normalize(
    img,
    space: Space
):
    # invert Normalize (x * std + mean), rescale each channel to [0, 1],
    # then convert to an HWC uint8 image for st.image
    return (minmax_scale(
        img.reshape(3, -1) * SPACE_TO_STD[space][:, None] + SPACE_TO_MEAN[space][:, None],
        feature_range=(0., 1.),
        axis=1,
    ).reshape(*img.shape).transpose(1, 2, 0) * 255).astype(np.uint8)

model: CycleGAN = load_model()

mode = st.radio(
    'Choose how you want to transform the image',
    ['summer to winter', 'winter to summer'],
    captions=[
        'summer to winter',
        'winter to summer'
    ],
)

uploaded_file = st.file_uploader('Upload an image', accept_multiple_files=False)

if uploaded_file is not None:
    space_source, space_target = Space.A, Space.B
    if mode == 'winter to summer':
        space_source, space_target = Space.B, Space.A

    try:
        bytes_data = uploaded_file.read()
        image = Image.open(BytesIO(bytes_data))
        st.write('Original image:')
        st.image(image)
    except Exception:
        st.write('Could not read the uploaded file as an image')
    else:
        transform = get_transform(space=space_source)
        image_array = np.array(image)
        if image_array.shape[-1] == 4:  # drop the alpha channel if present
            image_array = image_array[:, :, :3]
        image_transformed = transform(image_array)
        with torch.no_grad():
            if mode == 'summer to winter':
                result = model.generator_X_to_Y(image_transformed).numpy()
            else:
                result = model.generator_Y_to_X(image_transformed).numpy()
        result_image = de_normalize(
            result,
            space_target
        )
        st.write(f'Result of {mode}:')
        st.image(result_image)