import functools
import itertools
import os

import numpy as np
import streamlit as st
import torch
import torch.nn as nn
import torchvision.transforms as tr
from PIL import Image
from sklearn.preprocessing import minmax_scale
from torch.optim.lr_scheduler import LambdaLR, StepLR
from torchvision.transforms import v2

# Inference-only app: everything runs on CPU.
device = torch.device('cpu')


# Discriminators train quickly... so why not use this one?
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator (pix2pix / CycleGAN style).

    Maps an image to a 1-channel map of patch-level real/fake scores.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator.

        Parameters:
            input_nc (int) -- the number of channels in input images
            ndf (int)      -- the number of filters in the first conv layer
            n_layers (int) -- the number of conv layers in the discriminator
            norm_layer     -- normalization layer class (or functools.partial)
        """
        super(NLayerDiscriminator, self).__init__()
        # InstanceNorm2d has no affine parameters by default, so convs feeding
        # it need their own bias; BatchNorm2d's affine makes bias redundant.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
                    nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw,
                          stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        # One more stride-1 stage before the prediction head.
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw,
                      stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        # Output a 1-channel prediction map.
        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)
class DeepGenerator(nn.Module):
    """Encoder / residual-bottleneck / decoder generator with U-Net-style skips.

    Three stride-2 encoder stages mean input H and W must be divisible by 8
    for the skip concatenations to line up (the app resizes inputs to 256).
    """

    def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=6, use_dropout=False):
        """
        Parameters:
            input_nc (int)     -- channels of the input image
            output_nc (int)    -- channels of the output image
            ngf (int)          -- filters in the first encoder stage
            n_blocks (int)     -- number of residual blocks in the bottleneck
            use_dropout (bool) -- enable dropout after/inside the bottleneck
        """
        super(DeepGenerator, self).__init__()
        # Encoder: three stride-2 downsampling stages.
        self.enc1 = nn.Sequential(
            nn.Conv2d(input_nc, ngf, 4, 2, 1),
            nn.LeakyReLU(0.2)
        )
        self.enc2 = nn.Sequential(
            nn.Conv2d(ngf, ngf * 2, 4, 2, 1),
            nn.InstanceNorm2d(ngf * 2),
            nn.LeakyReLU(0.2)
        )
        self.enc3 = nn.Sequential(
            nn.Conv2d(ngf * 2, ngf * 4, 4, 2, 1),
            nn.InstanceNorm2d(ngf * 4),
            nn.LeakyReLU(0.2)
        )
        # Bottleneck residual blocks.
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(ngf * 4, use_dropout=use_dropout) for _ in range(n_blocks)]
        )
        # Decoder with skip connections (U-Net like); input channels are
        # doubled wherever an encoder feature map is concatenated.
        self.dec1 = nn.Sequential(
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1),
            nn.InstanceNorm2d(ngf * 2),
            nn.ReLU(True)
        )
        self.dec2 = nn.Sequential(
            nn.ConvTranspose2d(ngf * 2 * 2, ngf, 4, 2, 1),  # x2 from skip connection
            nn.InstanceNorm2d(ngf),
            nn.ReLU(True)
        )
        self.dec3 = nn.Sequential(
            nn.ConvTranspose2d(ngf * 2, output_nc, 4, 2, 1),  # x2 from skip connection
            nn.Tanh()
        )
        # Optional dropout applied to the bottleneck output.
        self.dropout = nn.Dropout(0.5) if use_dropout else None

    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)   # [batch, ngf,   H/2, W/2]
        e2 = self.enc2(e1)  # [batch, ngf*2, H/4, W/4]
        e3 = self.enc3(e2)  # [batch, ngf*4, H/8, W/8]

        # Bottleneck
        m = self.res_blocks(e3)
        if self.dropout is not None:  # explicit None check, not module truthiness
            m = self.dropout(m)

        # Decoder with skip connections
        d1 = self.dec1(m)                # [batch, ngf*2, H/4, W/4]
        d1 = torch.cat([d1, e2], dim=1)  # skip connection
        d2 = self.dec2(d1)               # [batch, ngf,   H/2, W/2]
        d2 = torch.cat([d2, e1], dim=1)  # skip connection
        output = self.dec3(d2)           # [batch, output_nc, H, W]
        return output


class ResidualBlock(nn.Module):
    """conv-norm-ReLU-(dropout)-conv-norm block with an identity skip."""

    def __init__(self, channels, use_dropout=False):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.InstanceNorm2d(channels),
            nn.ReLU(True),
            nn.Dropout(0.5) if use_dropout else nn.Identity(),
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.InstanceNorm2d(channels)
        )

    def forward(self, x):
        return x + self.block(x)  # skip connection


class CycleGAN(nn.Module):
    """Container for the two generators and two discriminators of a CycleGAN."""

    def __init__(self):
        super(CycleGAN, self).__init__()
        # Deliberate decision: all images are three-channel RGB.
        self.D_A = NLayerDiscriminator(input_nc=3, n_layers=3)
        self.D_B = NLayerDiscriminator(input_nc=3, n_layers=3)
        self.G_A = DeepGenerator(input_nc=3, output_nc=3)  # translates A -> B
        self.G_B = DeepGenerator(input_nc=3, output_nc=3)  # translates B -> A


def create_model_and_optimizer(model_class, model_params, lr_d=1e-4, lr_g=1.e-3,
                               weight_decay=1.e-5, device=device):
    """Instantiate the model on `device` plus one Adam optimizer per side.

    Returns:
        (model, optimizer over both discriminators, optimizer over both
        generators).
    """
    model = model_class(**model_params).to(device)
    optimizer_d = torch.optim.Adam(
        itertools.chain(model.D_A.parameters(), model.D_B.parameters()),
        lr=lr_d,
        weight_decay=weight_decay,
    )
    optimizer_g = torch.optim.Adam(
        itertools.chain(model.G_A.parameters(), model.G_B.parameters()),
        lr=lr_g,
        weight_decay=weight_decay,
    )
    return model, optimizer_d, optimizer_g


model_name = "cycle_gan_200e_fallout#1"
# NOTE(review): weights_only=False unpickles arbitrary objects — acceptable
# only because the checkpoint is a local, trusted file.
checkpoint = torch.load(os.path.join("./model", f"{model_name}.pt"),
                        weights_only=False, map_location=torch.device('cpu'))

# Re-create the same objects that existed when the checkpoint was written.
total_epochs = 200
half_epochs = total_epochs // 2  # integer division: epoch counts are whole numbers


def get_linear_scheduler(optimizer, start_epoch, n_epochs):
    """LR stays constant until `start_epoch`, then decays linearly to 0 at `n_epochs`."""
    def lr_lambda(epoch):
        return 1.0 - max(0, epoch - start_epoch) / (n_epochs - start_epoch)
    return LambdaLR(optimizer, lr_lambda)


model_params = dict()
model, optimizer_d, optimizer_g = create_model_and_optimizer(
    model_class=CycleGAN,
    model_params=model_params,
    lr_d=1e-3,
    lr_g=1e-3,
    device=device,
)
# Unused at inference time; kept so the setup mirrors training exactly.
scheduler_d = get_linear_scheduler(optimizer_d, start_epoch=half_epochs, n_epochs=total_epochs)
scheduler_g = get_linear_scheduler(optimizer_g, start_epoch=half_epochs, n_epochs=total_epochs)


def load_model():
    """Restore weights from the checkpoint and switch the model to eval mode."""
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model


load_model()


def get_transforms(mean, std, random_crop_size, p):
    """Build train/val transforms and a de-normalization helper for one domain.

    Parameters:
        mean, std        -- per-channel normalization stats (np.ndarray, 3 values)
        random_crop_size -- crop size (train) / resize target (val)
        p                -- horizontal-flip probability at train time

    Returns:
        (train_transform, val_transform, de_normalize)
    """
    train_transform = tr.Compose([
        tr.ToPILImage(),
        tr.RandomApply([
            tr.ColorJitter(0.2, 0.2, 0.2, 0.01)  # brightness/contrast/saturation/hue jitter
        ], p=0.8),
        v2.RandomCrop(size=random_crop_size),
        v2.RandomHorizontalFlip(p=p),
        tr.ToTensor(),
        tr.Normalize(mean=mean, std=std),
    ])
    val_transform = tr.Compose([
        tr.Resize((random_crop_size, random_crop_size)),
        tr.ToTensor(),
        tr.Normalize(mean=mean, std=std),
    ])

    def de_normalize(img, normalized=True):
        """Convert a normalized CHW tensor into an HWC float array in [0, 1].

        BUG FIX: the original permuted to HWC *before* reshape(3, -1), so the
        rows fed to minmax_scale were thirds of the flattened image rather
        than colour channels; it also inverted tr.Normalize as
        (x + mean) * std instead of x * std + mean.
        """
        with torch.no_grad():
            if normalized:
                arr = img.cpu().numpy()  # (3, H, W)
                # Undo Normalize: x * std + mean, per channel.
                restored = arr * std[:, None, None] + mean[:, None, None]
                scaled = minmax_scale(
                    restored.reshape(3, -1),
                    feature_range=(0., 1.),
                    axis=1,  # rescale each channel row independently
                ).reshape(arr.shape)
                return scaled.transpose(1, 2, 0)  # HWC for st.image
            return img

    return train_transform, val_transform, de_normalize


hyperparams = dict(
    random_crop_size=256,
    p=0.5
)

# Per-channel statistics measured on each domain's training set.
channel_mean_a = np.array([0.2999074, 0.34460328, 0.27964595])
channel_std_a = np.array([0.26192185, 0.27246983, 0.239288])
channel_mean_b = np.array([0.40723574, 0.35794826, 0.27471713])
channel_std_b = np.array([0.25011656, 0.23407436, 0.21894803])

train_transform_a, val_transform_a, de_normalize_a = get_transforms(channel_mean_a, channel_std_a, **hyperparams)
train_transform_b, val_transform_b, de_normalize_b = get_transforms(channel_mean_b, channel_std_b, **hyperparams)


def process_image(model: CycleGAN, image, transform, generator=None):
    """Translate one PIL image with a CycleGAN generator.

    Parameters:
        model     -- the CycleGAN container
        image     -- input PIL image
        transform -- preprocessing pipeline producing a normalized CHW tensor
        generator -- generator module to run; defaults to model.G_A (A -> B)
                     for backward compatibility with existing callers

    Returns the raw generator output, shape [1, 3, H, W].
    """
    gen = model.G_A if generator is None else generator
    img_tensor = transform(image).unsqueeze(0)
    with torch.no_grad():  # inference only — no autograd graph
        output = gen(img_tensor)
    return output
# --- Streamlit interface ---
st.title("CycleGAN for Fallout 3 to Fallout: New Vegas Task")

upload_A = st.file_uploader("Upload Fallout 3 image", type=["jpg", "png"])
if upload_A is not None:
    # convert("RGB") guards against RGBA PNGs / grayscale uploads: the
    # generators and Normalize() expect exactly three channels.
    image = Image.open(upload_A).convert("RGB")
    col1, col2 = st.columns(2)
    with col1:
        st.image(image, caption="Original")
    with col2:
        # A -> B direction: G_A translates Fallout 3 -> New Vegas.
        img_tensor = val_transform_a(image).unsqueeze(0)
        with torch.no_grad():
            result = model.G_A(img_tensor)
        st.image(de_normalize_b(result[0]), caption="Fallout: New Vegas style", clamp=True)

upload_B = st.file_uploader("Upload Fallout: New Vegas image", type=["jpg", "png"])
if upload_B is not None:
    image = Image.open(upload_B).convert("RGB")
    col1, col2 = st.columns(2)
    with col1:
        st.image(image, caption="Original")
    with col2:
        # BUG FIX: the original routed this through process_image(), which
        # always runs G_A; the B -> A direction must use G_B.
        img_tensor = val_transform_b(image).unsqueeze(0)
        with torch.no_grad():
            result = model.G_B(img_tensor)
        st.image(de_normalize_a(result[0]), caption="Fallout 3 style", clamp=True)