# app.py — Streamlit demo: CycleGAN style transfer between Fallout 3 and
# Fallout: New Vegas screenshots.
# (Removed Hugging Face page header that was pasted into the source:
#  "T3FiO's picture / Update app.py / 465abbc verified" — not valid Python.)
import streamlit as st
import torch
from PIL import Image
import torchvision.transforms as tr
import os
import torch.nn as nn
from torchvision.transforms import v2
from sklearn.preprocessing import minmax_scale
import itertools
import functools
device = torch.device('cpu')
import numpy as np
from torch.optim.lr_scheduler import LambdaLR
from torch.optim.lr_scheduler import StepLR
# Discriminators train quickly... so why not use one? (PatchGAN discriminator
# from the pix2pix/CycleGAN reference implementation.)
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator: maps an image to a 1-channel patch-level real/fake map.

    Args:
        input_nc: number of channels in the input images.
        ndf: number of filters in the first conv layer.
        n_layers: number of conv layers in the discriminator.
        norm_layer: normalization layer class (possibly a functools.partial).
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        super(NLayerDiscriminator, self).__init__()
        # BatchNorm2d has affine parameters, so conv bias is redundant under it;
        # keep the bias only when the norm is InstanceNorm2d.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kernel, pad = 4, 1
        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kernel, stride=2, padding=pad),
            nn.LeakyReLU(0.2, True),
        ]

        # Gradually widen the feature maps, capping the width multiplier at 8.
        mult = 1
        for n in range(1, n_layers):
            prev, mult = mult, min(2 ** n, 8)
            layers += [
                nn.Conv2d(ndf * prev, ndf * mult, kernel_size=kernel,
                          stride=2, padding=pad, bias=use_bias),
                norm_layer(ndf * mult),
                nn.LeakyReLU(0.2, True),
            ]

        # One more conv at stride 1 before the prediction head.
        prev, mult = mult, min(2 ** n_layers, 8)
        layers += [
            nn.Conv2d(ndf * prev, ndf * mult, kernel_size=kernel,
                      stride=1, padding=pad, bias=use_bias),
            norm_layer(ndf * mult),
            nn.LeakyReLU(0.2, True),
        ]

        # Final 1-channel patch prediction map.
        layers += [nn.Conv2d(ndf * mult, 1, kernel_size=kernel, stride=1, padding=pad)]
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)
class DeepGenerator(nn.Module):
    """Encoder / residual-bottleneck / decoder generator with U-Net style skips."""

    def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=6, use_dropout=False):
        super(DeepGenerator, self).__init__()
        # --- encoder: three stride-2 downsampling stages ---
        self.enc1 = nn.Sequential(
            nn.Conv2d(input_nc, ngf, 4, 2, 1),
            nn.LeakyReLU(0.2),
        )
        self.enc2 = nn.Sequential(
            nn.Conv2d(ngf, ngf * 2, 4, 2, 1),
            nn.InstanceNorm2d(ngf * 2),
            nn.LeakyReLU(0.2),
        )
        self.enc3 = nn.Sequential(
            nn.Conv2d(ngf * 2, ngf * 4, 4, 2, 1),
            nn.InstanceNorm2d(ngf * 4),
            nn.LeakyReLU(0.2),
        )

        # --- bottleneck: stack of residual blocks at the lowest resolution ---
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(ngf * 4, use_dropout=use_dropout) for _ in range(n_blocks)]
        )

        # --- decoder with skip connections (U-Net like) ---
        self.dec1 = nn.Sequential(
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1),
            nn.InstanceNorm2d(ngf * 2),
            nn.ReLU(True),
        )
        self.dec2 = nn.Sequential(
            # input width doubled by the skip concatenation
            nn.ConvTranspose2d(ngf * 2 * 2, ngf, 4, 2, 1),
            nn.InstanceNorm2d(ngf),
            nn.ReLU(True),
        )
        self.dec3 = nn.Sequential(
            # input width doubled by the skip concatenation
            nn.ConvTranspose2d(ngf * 2, output_nc, 4, 2, 1),
            nn.Tanh(),
        )

        # Optional dropout applied after the residual stack.
        self.dropout = nn.Dropout(0.5) if use_dropout else None

    def forward(self, x):
        skip1 = self.enc1(x)       # (N, ngf,   H/2, W/2)
        skip2 = self.enc2(skip1)   # (N, ngf*2, H/4, W/4)
        bottom = self.enc3(skip2)  # (N, ngf*4, H/8, W/8)

        features = self.res_blocks(bottom)
        if self.dropout:
            features = self.dropout(features)

        # Upsample, concatenating encoder features at matching resolutions.
        up = torch.cat([self.dec1(features), skip2], dim=1)  # (N, ngf*4, H/4, W/4)
        up = torch.cat([self.dec2(up), skip1], dim=1)        # (N, ngf*2, H/2, W/2)
        return self.dec3(up)                                 # (N, output_nc, H, W)
class ResidualBlock(nn.Module):
def __init__(self, channels, use_dropout=False):
super(ResidualBlock, self).__init__()
self.block = nn.Sequential(
nn.Conv2d(channels, channels, 3, 1, 1),
nn.InstanceNorm2d(channels),
nn.ReLU(True),
nn.Dropout(0.5) if use_dropout else nn.Identity(),
nn.Conv2d(channels, channels, 3, 1, 1),
nn.InstanceNorm2d(channels)
)
def forward(self, x):
return x + self.block(x) # Skip-connection
class CycleGAN(nn.Module):
    """Container bundling the two generators and two discriminators of a CycleGAN."""

    def __init__(self):
        super(CycleGAN, self).__init__()
        # Deliberate simplification: all images are assumed to be 3-channel RGB.
        self.D_A = NLayerDiscriminator(input_nc=3, n_layers=3)
        self.D_B = NLayerDiscriminator(input_nc=3, n_layers=3)
        self.G_A = DeepGenerator(input_nc=3, output_nc=3)
        self.G_B = DeepGenerator(input_nc=3, output_nc=3)
def create_model_and_optimizer(model_class, model_params, lr_d=1e-4, lr_g=1e-3,
                               weight_decay=1e-5, device=None):
    """Instantiate a CycleGAN-style model and its two Adam optimizers.

    Args:
        model_class: class exposing ``D_A``/``D_B`` (discriminators) and
            ``G_A``/``G_B`` (generators) as submodules.
        model_params: keyword arguments forwarded to ``model_class``.
        lr_d: learning rate for the joint discriminator optimizer.
        lr_g: learning rate for the joint generator optimizer.
        weight_decay: L2 penalty shared by both optimizers.
        device: target device; ``None`` means CPU.  (The original default was
            bound to a module-level global at definition time, which silently
            froze whatever ``device`` happened to be when the module loaded.)

    Returns:
        Tuple ``(model, optimizer_d, optimizer_g)``.
    """
    if device is None:
        device = torch.device("cpu")
    model = model_class(**model_params).to(device)
    # One optimizer over both discriminators, one over both generators.
    optimizer_d = torch.optim.Adam(
        itertools.chain(model.D_A.parameters(), model.D_B.parameters()),
        lr=lr_d,
        weight_decay=weight_decay,
    )
    optimizer_g = torch.optim.Adam(
        itertools.chain(model.G_A.parameters(), model.G_B.parameters()),
        lr=lr_g,
        weight_decay=weight_decay,
    )
    return model, optimizer_d, optimizer_g
# Checkpoint file name (without extension) looked up under ./model.
model_name = "cycle_gan_200e_fallout#1"
# NOTE(review): weights_only=False unpickles arbitrary objects — acceptable only
# because this checkpoint ships with the app; never point this at untrusted files.
checkpoint = torch.load(os.path.join("./model", f"{model_name}.pt"), weights_only = False, map_location=torch.device('cpu'))
# Recreate the same classes that were saved inside the checkpoint.
total_epochs = 200
half_epochs = total_epochs/2  # epoch at which the linear LR decay is set to start
def get_linear_scheduler(optimizer, start_epoch, n_epochs):
    """Return a ``LambdaLR`` that keeps the base LR constant until
    ``start_epoch`` and then decays it linearly to zero at ``n_epochs``.

    Fixes vs. the original:
      * the factor is clamped at 0.0 so epochs past ``n_epochs`` no longer
        produce a *negative* learning rate;
      * ``n_epochs <= start_epoch`` raises instead of dividing by zero
        (or decaying in the wrong direction) at the first scheduler step.
    """
    span = n_epochs - start_epoch
    if span <= 0:
        raise ValueError("n_epochs must be greater than start_epoch")

    def lr_lambda(epoch):
        return max(0.0, 1.0 - max(0, epoch - start_epoch) / span)

    return LambdaLR(optimizer, lr_lambda)
# Build the model plus both optimizers, then attach linear-decay schedulers
# that kick in halfway through training.
model_params = dict()
model, optimizer_d, optimizer_g = create_model_and_optimizer(
    model_class=CycleGAN,
    model_params=model_params,
    lr_d=1e-3,
    lr_g=1e-3,
    device=device,
)
scheduler_d = get_linear_scheduler(optimizer_d, start_epoch=half_epochs, n_epochs=total_epochs)
scheduler_g = get_linear_scheduler(optimizer_g, start_epoch=half_epochs, n_epochs=total_epochs)
# Restore the trained weights from the checkpoint and switch to inference mode.
def load_model():
    """Load the checkpoint weights into the global ``model`` and set eval mode."""
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model


load_model()
# Build the train/val transforms and a de-normalization helper for display.
def get_transforms(mean, std, random_crop_size, p):
    """Build train/val transforms plus a de-normalization helper for one domain.

    Args:
        mean, std: per-channel numpy arrays of shape (3,) for ``Normalize``.
        random_crop_size: side length for RandomCrop (train) / Resize (val).
        p: probability of the random horizontal flip (train only).

    Returns:
        Tuple ``(train_transform, val_transform, de_normalize)``.
    """
    train_transform = tr.Compose([
        tr.ToPILImage(),
        tr.RandomApply([
            tr.ColorJitter(0.2, 0.2, 0.2, 0.01)  # jitter brightness/contrast/saturation/hue
        ], p=0.8),
        v2.RandomCrop(size=random_crop_size),
        v2.RandomHorizontalFlip(p=p),
        tr.ToTensor(),
        tr.Normalize(mean=mean, std=std),
    ])
    val_transform = tr.Compose([
        tr.Resize((random_crop_size, random_crop_size)),
        tr.ToTensor(),
        tr.Normalize(mean=mean, std=std),
    ])

    def de_normalize(img, normalized=True):
        """Invert ``Normalize`` on a (3, H, W) tensor and return an (H, W, 3)
        float array, min-max scaled per channel to [0, 1] for ``st.image``.

        Fixes vs. the original implementation:
          * the true inverse ``x * std + mean`` is used instead of
            ``(x + mean) * std``;
          * the un-normalization happens *before* the HWC permute — the old
            code permuted first and then reshaped the HWC tensor to (3, -1),
            which mixed pixels of different channels into each "channel" row,
            so the per-channel min-max scaling operated on wrong slices.
        """
        if not normalized:
            return img
        with torch.no_grad():
            arr = img.detach().cpu().numpy().reshape(3, -1)   # rows = channels
            arr = arr * std[:, None] + mean[:, None]          # undo Normalize
            scaled = minmax_scale(arr, feature_range=(0., 1.), axis=1)
            return scaled.reshape(*img.shape).transpose(1, 2, 0)  # HWC for display

    return train_transform, val_transform, de_normalize
# Shared transform hyper-parameters for both image domains.
hyperparams = dict(
    random_crop_size=256,
    p=0.5,
)

# Per-channel statistics: domain A = Fallout 3, domain B = Fallout: New Vegas.
channel_mean_a = np.array([0.2999074, 0.34460328, 0.27964595])
channel_std_a = np.array([0.26192185, 0.27246983, 0.239288])
channel_mean_b = np.array([0.40723574, 0.35794826, 0.27471713])
channel_std_b = np.array([0.25011656, 0.23407436, 0.21894803])

train_transform_a, val_transform_a, de_normalize_a = get_transforms(channel_mean_a, channel_std_a, **hyperparams)
train_transform_b, val_transform_b, de_normalize_b = get_transforms(channel_mean_b, channel_std_b, **hyperparams)
# Run a single uploaded image through a generator.
def process_image(model: "CycleGAN", image, transform, generator=None):
    """Transform ``image`` and push it through one CycleGAN generator.

    Args:
        model: CycleGAN instance (string annotation: avoids evaluating the
            name at definition time).
        image: PIL image, or anything ``transform`` accepts.
        transform: callable producing a (C, H, W) tensor.
        generator: optional generator module.  Defaults to ``model.G_A``
            (the A→B direction); pass ``model.G_B`` for B→A — the original
            version hard-coded G_A for both directions.

    Returns:
        The generator output, a (1, C, H, W) tensor.
    """
    net = model.G_A if generator is None else generator
    img_tensor = transform(image).unsqueeze(0)
    with torch.no_grad():
        return net(img_tensor)
# Интерфейс
st.title("CycleGAN for Fallout 3 to Fallout: New Vegas Task")
upload_A = st.file_uploader("Upload Fallout 3 image", type=["jpg", "png"])
if upload_A is not None:
image = Image.open(upload_A)
col1, col2 = st.columns(2)
with col1:
st.image(image, caption="Original")
with col2:
result = process_image(model, image, val_transform_a)
st.image(de_normalize_b(result[0]), caption="Fallout: New Vegas style", clamp=True)
# B→A direction: New Vegas screenshot → Fallout 3 style.
upload_B = st.file_uploader("Upload Fallout: New Vegas image", type=["jpg", "png"])
if upload_B is not None:
    image = Image.open(upload_B)
    col1, col2 = st.columns(2)
    with col1:
        st.image(image, caption="Original")
    with col2:
        # BUG FIX: this direction must use the G_B generator; the original went
        # through process_image, which always applied G_A — even to B-domain input.
        img_tensor = val_transform_b(image).unsqueeze(0)
        with torch.no_grad():
            result = model.G_B(img_tensor)
        st.image(de_normalize_a(result[0]), caption="Fallout 3 style", clamp=True)