Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| from PIL import Image | |
| import torchvision.transforms as tr | |
| import os | |
| import torch.nn as nn | |
| from torchvision.transforms import v2 | |
| from sklearn.preprocessing import minmax_scale | |
| import itertools | |
| import functools | |
| device = torch.device('cpu') | |
| import numpy as np | |
| from torch.optim.lr_scheduler import LambdaLR | |
| from torch.optim.lr_scheduler import StepLR | |
# Discriminators train quickly — so why not reuse this off-the-shelf PatchGAN architecture?
| class NLayerDiscriminator(nn.Module): | |
| """Defines a PatchGAN discriminator""" | |
| def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): | |
| """Construct a PatchGAN discriminator | |
| Parameters: | |
| input_nc (int) -- the number of channels in input images | |
| ndf (int) -- the number of filters in the last conv layer | |
| n_layers (int) -- the number of conv layers in the discriminator | |
| norm_layer -- normalization layer | |
| """ | |
| super(NLayerDiscriminator, self).__init__() | |
| if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters | |
| use_bias = norm_layer.func == nn.InstanceNorm2d | |
| else: | |
| use_bias = norm_layer == nn.InstanceNorm2d | |
| kw = 4 | |
| padw = 1 | |
| sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] | |
| nf_mult = 1 | |
| nf_mult_prev = 1 | |
| for n in range(1, n_layers): # gradually increase the number of filters | |
| nf_mult_prev = nf_mult | |
| nf_mult = min(2 ** n, 8) | |
| sequence += [ | |
| nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), | |
| norm_layer(ndf * nf_mult), | |
| nn.LeakyReLU(0.2, True) | |
| ] | |
| nf_mult_prev = nf_mult | |
| nf_mult = min(2 ** n_layers, 8) | |
| sequence += [ | |
| nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), | |
| norm_layer(ndf * nf_mult), | |
| nn.LeakyReLU(0.2, True) | |
| ] | |
| sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map | |
| self.model = nn.Sequential(*sequence) | |
| def forward(self, input): | |
| """Standard forward.""" | |
| return self.model(input) | |
class DeepGenerator(nn.Module):
    """Encoder / residual-bottleneck / decoder generator with U-Net-style skips.

    The input is downsampled 8x through three conv stages, transformed by
    ``n_blocks`` residual blocks at the bottleneck, then upsampled back to
    full resolution; the two decoder stages receive the matching encoder
    activations via channel concatenation.
    """

    def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=6, use_dropout=False):
        super(DeepGenerator, self).__init__()

        def down(in_ch, out_ch, normalize=True):
            # Stride-2 conv halves the spatial resolution.
            stage = [nn.Conv2d(in_ch, out_ch, 4, 2, 1)]
            if normalize:
                stage.append(nn.InstanceNorm2d(out_ch))
            stage.append(nn.LeakyReLU(0.2))
            return nn.Sequential(*stage)

        # Encoder. Attribute names (enc*/dec*/res_blocks/dropout) are kept
        # identical to the original so checkpoint state_dicts still load.
        self.enc1 = down(input_nc, ngf, normalize=False)
        self.enc2 = down(ngf, ngf * 2)
        self.enc3 = down(ngf * 2, ngf * 4)

        # Bottleneck residual blocks.
        self.res_blocks = nn.Sequential(
            *(ResidualBlock(ngf * 4, use_dropout=use_dropout) for _ in range(n_blocks))
        )

        # Decoder: input channel counts are doubled where a skip connection
        # was concatenated before the stage.
        self.dec1 = nn.Sequential(
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1),
            nn.InstanceNorm2d(ngf * 2),
            nn.ReLU(True),
        )
        self.dec2 = nn.Sequential(
            nn.ConvTranspose2d(ngf * 2 * 2, ngf, 4, 2, 1),  # x2: skip from enc2
            nn.InstanceNorm2d(ngf),
            nn.ReLU(True),
        )
        self.dec3 = nn.Sequential(
            nn.ConvTranspose2d(ngf * 2, output_nc, 4, 2, 1),  # x2: skip from enc1
            nn.Tanh(),
        )

        # Optional dropout applied once, after the residual bottleneck.
        self.dropout = nn.Dropout(0.5) if use_dropout else None

    def forward(self, x):
        skip1 = self.enc1(x)       # (B, ngf,   H/2, W/2)
        skip2 = self.enc2(skip1)   # (B, ngf*2, H/4, W/4)
        bottom = self.enc3(skip2)  # (B, ngf*4, H/8, W/8)

        features = self.res_blocks(bottom)
        if self.dropout is not None:
            features = self.dropout(features)

        up1 = torch.cat([self.dec1(features), skip2], dim=1)  # skip connection
        up2 = torch.cat([self.dec2(up1), skip1], dim=1)       # skip connection
        return self.dec3(up2)      # (B, output_nc, H, W)
class ResidualBlock(nn.Module):
    """Residual block: two 3x3 convs with instance norm, added to the input."""

    def __init__(self, channels, use_dropout=False):
        super(ResidualBlock, self).__init__()
        # Kernel 3, stride 1, padding 1 keeps the spatial size unchanged,
        # which the residual addition in forward() requires.
        inner = [
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.InstanceNorm2d(channels),
            nn.ReLU(True),
            nn.Dropout(0.5) if use_dropout else nn.Identity(),
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.InstanceNorm2d(channels),
        ]
        self.block = nn.Sequential(*inner)

    def forward(self, x):
        # Identity shortcut around the conv stack.
        return x + self.block(x)
class CycleGAN(nn.Module):
    """Bundles the two generators and two discriminators of a CycleGAN.

    Domain A = Fallout 3 imagery, domain B = Fallout: New Vegas imagery
    (per how the app below feeds them). All images are treated as 3-channel
    RGB by design.
    """

    def __init__(self):
        super(CycleGAN, self).__init__()
        # One patch discriminator per domain.
        self.D_A = NLayerDiscriminator(input_nc=3, n_layers=3)
        self.D_B = NLayerDiscriminator(input_nc=3, n_layers=3)
        # Generators: G_A is used for the A->B direction, G_B for B->A.
        self.G_A = DeepGenerator(input_nc=3, output_nc=3)
        self.G_B = DeepGenerator(input_nc=3, output_nc=3)
def create_model_and_optimizer(model_class, model_params, lr_d=1e-4, lr_g=1.e-3, weight_decay=1.e-5, device=torch.device("cpu")):
    """Instantiate a CycleGAN-style model and its two Adam optimizers.

    Args:
        model_class: class exposing D_A/D_B discriminator and G_A/G_B
            generator submodules.
        model_params: keyword arguments forwarded to ``model_class``.
        lr_d: learning rate for the joint discriminator optimizer.
        lr_g: learning rate for the joint generator optimizer.
        weight_decay: L2 penalty applied by both optimizers.
        device: target device. FIX: the default used to be the module-level
            global ``device`` captured at def time; it is now the equivalent
            explicit CPU device, removing the hidden global coupling.

    Returns:
        (model, optimizer_d, optimizer_g)
    """
    model = model_class(**model_params).to(device)
    # One optimizer drives both discriminators (and likewise both generators),
    # so each pair steps together with shared hyperparameters.
    optimizer_d = torch.optim.Adam(
        itertools.chain(model.D_A.parameters(), model.D_B.parameters()),
        lr=lr_d,
        weight_decay=weight_decay,
    )
    optimizer_g = torch.optim.Adam(
        itertools.chain(model.G_A.parameters(), model.G_B.parameters()),
        lr=lr_g,
        weight_decay=weight_decay,
    )
    return model, optimizer_d, optimizer_g
# Name of the checkpoint file (without extension) inside ./model.
model_name = "cycle_gan_200e_fallout#1"
# NOTE(review): weights_only=False unpickles arbitrary objects from the file —
# acceptable only because the checkpoint ships with the app; never load
# user-supplied checkpoints this way.
checkpoint = torch.load(os.path.join("./model", f"{model_name}.pt"), weights_only = False, map_location=torch.device('cpu'))
# Re-create the same model/optimizer/scheduler setup that was used when the
# checkpoint was written.
total_epochs = 200
half_epochs = total_epochs/2  # epoch at which linear LR decay starts (100.0)
def get_linear_scheduler(optimizer, start_epoch, n_epochs):
    """Return a LambdaLR: constant LR until ``start_epoch``, then linear
    decay to 0 at ``n_epochs``.

    Args:
        optimizer: optimizer whose learning rate is scaled.
        start_epoch: last epoch with the full (factor 1.0) learning rate.
        n_epochs: epoch at which the factor reaches 0.

    Raises:
        ValueError: if the decay window is empty (n_epochs <= start_epoch);
            the original raised a bare ZeroDivisionError here.
    """
    if n_epochs <= start_epoch:
        raise ValueError("n_epochs must be greater than start_epoch")

    def lr_lambda(epoch):
        decay = max(0, epoch - start_epoch) / (n_epochs - start_epoch)
        # FIX: clamp at 0 so training past n_epochs cannot produce a
        # negative learning-rate factor (the original returned one).
        return max(0.0, 1.0 - decay)

    return LambdaLR(optimizer, lr_lambda)
model_params = dict()  # CycleGAN.__init__ takes no arguments
model, optimizer_d, optimizer_g = create_model_and_optimizer(
    model_class = CycleGAN,
    model_params = model_params,
    lr_d = 1e-3,
    lr_g= 1e-3,
    device = device,
)
# Schedulers mirror the training configuration; not used for pure inference,
# but kept so the objects match what the checkpoint was saved with.
scheduler_d = get_linear_scheduler(optimizer_d, start_epoch=half_epochs, n_epochs=total_epochs)
scheduler_g = get_linear_scheduler(optimizer_g, start_epoch=half_epochs, n_epochs=total_epochs)
# Load the trained weights from the checkpoint and switch to eval mode.
def load_model():
    # Reads the module-level ``model`` and ``checkpoint``; mutates ``model``
    # in place and returns it for convenience.
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model
load_model()
# Factory for a domain's train/val transforms, plus a de-normalization helper for display.
def get_transforms(mean, std, random_crop_size, p):
    """Build train/val transforms for one image domain plus a de-normalizer.

    Args:
        mean: per-channel mean (np.ndarray, shape (3,)) used by Normalize.
        std: per-channel std (np.ndarray, shape (3,)) used by Normalize.
        random_crop_size: crop size (train) / resize target (val) in pixels.
        p: probability of the random horizontal flip during training.

    Returns:
        (train_transform, val_transform, de_normalize)
    """
    train_transform = tr.Compose([
        tr.ToPILImage(),
        tr.RandomApply([
            tr.ColorJitter(0.2, 0.2, 0.2, 0.01)  # jitter brightness/contrast/saturation/hue
        ], p=0.8),
        v2.RandomCrop(size=random_crop_size),
        v2.RandomHorizontalFlip(p=p),
        tr.ToTensor(),
        tr.Normalize(mean=mean, std=std),
    ])
    val_transform = tr.Compose([
        tr.Resize((random_crop_size, random_crop_size)),
        tr.ToTensor(),
        tr.Normalize(mean=mean, std=std),
    ])

    def de_normalize(img, normalized=True):
        """Invert tr.Normalize on a (C, H, W) tensor; return an (H, W, C)
        float array clipped to [0, 1] for display.

        BUG FIX: the previous version computed ``(x + mean) * std`` on a
        channel-last tensor reshaped as channel-first — the wrong inverse
        (Normalize is x = (orig - mean) / std, so the inverse is
        x * std + mean) applied to rows that mixed pixels across channels —
        and then hid the damage behind per-row min-max rescaling.
        """
        with torch.no_grad():
            if normalized:
                arr = img.cpu().permute(1, 2, 0).numpy()  # -> (H, W, C)
                arr = arr * std + mean  # exact inverse of tr.Normalize
                return np.clip(arr, 0.0, 1.0)
            return img

    return train_transform, val_transform, de_normalize
# Shared transform hyperparameters for both domains.
hyperparams = dict(
    random_crop_size = 256,
    p = 0.5
)
# Per-channel normalization statistics for each domain — presumably computed
# from each domain's training images (TODO confirm against the training code).
# Domain A = Fallout 3, domain B = Fallout: New Vegas (see UI usage below).
channel_mean_a =np.array([0.2999074, 0.34460328, 0.27964595])
channel_std_a = np.array([0.26192185, 0.27246983, 0.239288 ])
channel_mean_b =np.array( [0.40723574, 0.35794826, 0.27471713])
channel_std_b = np.array( [0.25011656, 0.23407436, 0.21894803])
train_transform_a, val_transform_a, de_normalize_a = get_transforms(channel_mean_a, channel_std_a, **hyperparams)
train_transform_b, val_transform_b, de_normalize_b = get_transforms(channel_mean_b, channel_std_b, **hyperparams)
# Image processing
def process_image(model: "CycleGAN", image, transform, generator: str = "G_A"):
    """Run one image through a generator of ``model``.

    Args:
        model: CycleGAN-style module holding the generator submodules.
        image: input image accepted by ``transform``.
        transform: preprocessing callable returning a (C, H, W) tensor.
        generator: attribute name of the generator to use. Defaults to
            "G_A" (A->B) for backward compatibility; pass "G_B" for the
            reverse B->A direction — the original hard-coded G_A even for
            domain-B inputs.

    Returns:
        The generator output with a leading batch dimension of 1.
    """
    img_tensor = transform(image).unsqueeze(0)  # add batch dim
    net = getattr(model, generator)
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        output = net(img_tensor)
    return output
| # Интерфейс | |
| st.title("CycleGAN for Fallout 3 to Fallout: New Vegas Task") | |
| upload_A = st.file_uploader("Upload Fallout 3 image", type=["jpg", "png"]) | |
| if upload_A is not None: | |
| image = Image.open(upload_A) | |
| col1, col2 = st.columns(2) | |
| with col1: | |
| st.image(image, caption="Original") | |
| with col2: | |
| result = process_image(model, image, val_transform_a) | |
| st.image(de_normalize_b(result[0]), caption="Fallout: New Vegas style", clamp=True) | |
# Direction B -> A: a Fallout: New Vegas screenshot restyled as Fallout 3.
upload_B = st.file_uploader("Upload Fallout: New Vegas image", type=["jpg", "png"])
if upload_B is not None:
    image = Image.open(upload_B)
    col1, col2 = st.columns(2)
    with col1:
        st.image(image, caption="Original")
    with col2:
        # BUG FIX: this branch previously called process_image(), which always
        # runs model.G_A (the A->B generator), even though the result is
        # denormalized with domain-A statistics and labelled "Fallout 3 style".
        # The B->A translation must go through G_B.
        img_tensor = val_transform_b(image).unsqueeze(0)
        with torch.no_grad():
            result = model.G_B(img_tensor)
        st.image(de_normalize_a(result[0]), caption="Fallout 3 style", clamp=True)