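# NOTE: the text "Spaces: Sleeping" that preceded this file is page-header residue from the hosting site, not part of the program.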
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.utils.data import Dataset, DataLoader | |
| from torchvision import transforms | |
| from PIL import Image | |
| import os | |
| from types import SimpleNamespace | |
| import random | |
| from torchvision.utils import save_image | |
| import gradio as gr | |
| import numpy as np | |
| import io | |
| import tempfile # Importar tempfile | |
| import math | |
# Make sure the required helper functions are defined (if they are not already).
def resize(img, size):
    """Resample a batched image tensor to ``size`` with bilinear interpolation.

    Args:
        img: Tensor handed to ``F.interpolate``.
        size: Target spatial size, forwarded unchanged to ``F.interpolate``.

    Returns:
        The resampled tensor.
    """
    resampled = F.interpolate(
        img,
        size=size,
        mode='bilinear',
        align_corners=False,
    )
    return resampled
def denormalize(x):
    """Map values from the [-1, 1] range to the [0, 1] range.

    Computes ``(x + 1) / 2``; no clamping is applied, so inputs outside
    [-1, 1] produce outputs outside [0, 1].
    """
    shifted = x + 1
    return shifted / 2
# Model class definitions (Generator, StyleEncoder, MappingNetwork, ResBlk, AdaIN, AdainResBlk).
class ResBlk(nn.Module):
    """Residual block: two 3x3 convolutions plus a 1x1-convolution shortcut.

    Args:
        dim_in: Number of input channels.
        dim_out: Number of output channels.
        normalize: When true, an affine ``InstanceNorm2d`` follows each 3x3
            convolution; otherwise an ``Identity`` takes its place.
        downsample: When true, both branches are average-pooled by 2.
    """

    def __init__(self, dim_in, dim_out, normalize=False, downsample=False):
        super().__init__()
        self.normalize = normalize
        self.downsample = downsample

        def make_norm():
            # An Identity placeholder keeps the Sequential indices (and hence
            # the state_dict keys) the same whether or not normalisation is on.
            if normalize:
                return nn.InstanceNorm2d(dim_out, affine=True)
            return nn.Identity()

        layers = [
            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1),
            make_norm(),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1),
            make_norm(),
        ]
        self.main = nn.Sequential(*layers)

        if downsample:
            self.downsample_layer = nn.AvgPool2d(2)
        else:
            self.downsample_layer = nn.Identity()

        # The shortcut is always a bias-free 1x1 convolution, even when
        # dim_in == dim_out.
        self.skip = nn.Conv2d(
            dim_in, dim_out, kernel_size=1, stride=1, padding=0, bias=False
        )

    def forward(self, x):
        """Return ``(pool(main(x)) + pool(skip(x))) / sqrt(2)``."""
        residual = self.downsample_layer(self.main(x))
        shortcut = self.downsample_layer(self.skip(x))
        return (residual + shortcut) / math.sqrt(2)
class AdaIN(nn.Module):
    """Style-conditioned per-channel scale and shift.

    A linear layer maps the style code to ``2 * num_features`` values, which
    are split into a scale and a shift. Note that this implementation applies
    the modulation directly to ``x``; it does not normalise ``x`` first.

    Args:
        num_features: Number of channels in the feature map being modulated.
        style_dim: Size of the style code.
    """

    def __init__(self, num_features, style_dim):
        super().__init__()
        self.fc = nn.Linear(style_dim, 2 * num_features)

    def forward(self, x, s):
        """Return ``(1 + gamma) * x + beta`` with gamma/beta derived from ``s``."""
        modulation = self.fc(s)
        scale, shift = modulation.chunk(2, dim=1)
        # Add two trailing singleton axes so the per-channel values broadcast
        # over the spatial dimensions of x.
        scale = scale[:, :, None, None]
        shift = shift[:, :, None, None]
        return (1 + scale) * x + shift
class AdainResBlk(nn.Module):
    """Residual block whose two stages are modulated by a style code via AdaIN.

    Args:
        dim_in: Number of input channels.
        dim_out: Number of output channels.
        style_dim: Size of the style code passed to each ``AdaIN``.
        w_hpf: Stored on the instance but not used anywhere in this block.
        upsample: When true, the input is up-sampled by 2 (nearest neighbour)
            before both the main branch and the shortcut.
    """

    def __init__(self, dim_in, dim_out, style_dim=64, w_hpf=1, upsample=False):
        super().__init__()
        self.upsample = upsample
        self.w_hpf = w_hpf
        self.norm1 = AdaIN(dim_in, style_dim)
        self.norm2 = AdaIN(dim_out, style_dim)
        self.actv = nn.LeakyReLU(0.2)
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        # A 1x1 convolution is needed on the shortcut only when the channel
        # count changes.
        if dim_in != dim_out:
            self.skip = nn.Conv2d(dim_in, dim_out, 1, 1, 0)
        else:
            self.skip = nn.Identity()

    def forward(self, x, s):
        """Return ``(main(x, s) + skip(x)) / sqrt(2)``.

        Args:
            x: Input feature map.
            s: Style code fed to both AdaIN layers.
        """
        if self.upsample:
            # Both branches start from the same up-sampled tensor, so it is
            # computed once and shared (nothing below modifies it in place).
            x = F.interpolate(x, scale_factor=2, mode='nearest')

        h = self.norm1(x, s)
        h = self.actv(h)
        h = self.conv1(h)
        h = self.norm2(h, s)
        h = self.actv(h)
        h = self.conv2(h)

        skip = self.skip(x)
        return (h + skip) / math.sqrt(2)
class Generator(nn.Module):
    """Encoder/decoder image generator conditioned on a style code.

    The encoder is a 3x3 convolution followed by ``int(log2(img_size)) - 4``
    down-sampling ``ResBlk`` stages (channels double, capped at
    ``max_conv_dim``). The decoder has the same number of up-sampling
    ``AdainResBlk`` stages, each halving the channel count, and a final
    normalise/ReLU/1x1-convolution head produces 3 output channels.

    Args:
        img_size: Nominal input resolution; only used to choose the depth.
        style_dim: Size of the style code consumed by the decoder blocks.
        max_conv_dim: Upper bound on encoder channel width.
    """

    def __init__(self, img_size=256, style_dim=64, max_conv_dim=512):
        super().__init__()
        channels = 64
        num_stages = int(np.log2(img_size)) - 4

        encoder_layers = [nn.Conv2d(3, channels, 3, 1, 1)]
        for _ in range(num_stages):
            widened = min(channels * 2, max_conv_dim)
            encoder_layers.append(
                ResBlk(channels, widened, normalize=True, downsample=True)
            )
            channels = widened
        self.encode = nn.Sequential(*encoder_layers)

        self.decode = nn.ModuleList()
        for _ in range(num_stages):
            narrowed = channels // 2
            self.decode.append(
                AdainResBlk(channels, narrowed, style_dim, upsample=True)
            )
            channels = narrowed

        self.to_rgb = nn.Sequential(
            nn.InstanceNorm2d(channels, affine=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, 3, 1, 1, 0),
        )

    def forward(self, x, s):
        """Encode ``x``, decode it under style ``s`` and project to 3 channels."""
        features = self.encode(x)
        for stage in self.decode:
            features = stage(features, s)
        return self.to_rgb(features)
class MappingNetwork(nn.Module):
    """MLP that maps a latent code to a style code for a chosen domain.

    A shared trunk of four ``Linear`` + ``ReLU`` pairs feeds one ``Linear``
    output head per domain.

    Args:
        latent_dim: Size of the input latent code.
        style_dim: Size of the produced style code.
        num_domains: Number of per-domain output heads.
        hidden_dim: Width of the shared trunk.
    """

    def __init__(self, latent_dim=16, style_dim=64, num_domains=2, hidden_dim=512):
        super().__init__()
        trunk = [nn.Linear(latent_dim, hidden_dim), nn.ReLU()]
        for _ in range(3):
            trunk.extend([nn.Linear(hidden_dim, hidden_dim), nn.ReLU()])
        self.shared = nn.Sequential(*trunk)

        heads = [nn.Linear(hidden_dim, style_dim) for _ in range(num_domains)]
        self.unshared = nn.ModuleList(heads)

    def forward(self, z, y):
        """Return, for each sample, the style code from the head indexed by ``y``.

        Args:
            z: Batch of latent codes.
            y: Integer domain index per sample, used to index the stacked
                head outputs.
        """
        features = self.shared(z)
        # Every head is evaluated; the stacked result has the domain on dim 1.
        per_domain = torch.stack([head(features) for head in self.unshared], dim=1)
        rows = torch.arange(y.size(0), device=y.device)
        return per_domain[rows, y]
class StyleEncoder(nn.Module):
    """Convolutional encoder that extracts a per-domain style code from an image.

    The shared trunk is a 3x3 convolution, ``int(log2(img_size)) - 2``
    down-sampling ``ResBlk`` stages (channels double, capped at
    ``max_conv_dim``) and a final ``LeakyReLU``. Its output is globally
    average-pooled and fed to one ``Linear`` head per domain.

    Args:
        img_size: Nominal input resolution; only used to choose the depth.
        style_dim: Size of the produced style code.
        num_domains: Number of per-domain output heads.
        max_conv_dim: Upper bound on channel width.
    """

    def __init__(self, img_size=256, style_dim=64, num_domains=2, max_conv_dim=512):
        super().__init__()
        channels = 64
        num_stages = int(np.log2(img_size)) - 2

        trunk = [nn.Conv2d(3, channels, 3, 1, 1)]
        for _ in range(num_stages):
            widened = min(channels * 2, max_conv_dim)
            trunk.append(ResBlk(channels, widened, normalize=True, downsample=True))
            channels = widened
        trunk.append(nn.LeakyReLU(0.2))
        self.shared = nn.Sequential(*trunk)

        self.unshared = nn.ModuleList()
        for _ in range(num_domains):
            self.unshared.append(nn.Linear(channels, style_dim))

    def forward(self, x, y):
        """Return, for each image, the style code from the head indexed by ``y``.

        Args:
            x: Batch of images.
            y: Integer domain index per sample, used to index the stacked
                head outputs.
        """
        features = self.shared(x)
        pooled = F.adaptive_avg_pool2d(features, (1, 1))
        flat = pooled.view(pooled.size(0), -1)
        # Every head is evaluated; the stacked result has the domain on dim 1.
        per_domain = torch.stack([head(flat) for head in self.unshared], dim=1)
        rows = torch.arange(y.size(0), device=y.device)
        return per_domain[rows, y]
# Dataset class for loading images.
class ImageFolder(Dataset):
    """Dataset of ``(image, domain_label)`` pairs read from ``root/<domain>/<file>``.

    Each sub-directory of ``root`` is one domain. Domains are sorted by name
    and labelled ``0 .. num_domains - 1``; entries of ``root`` that are not
    directories are ignored and do not consume a label. Files inside each
    domain are listed in sorted order so the sample order is reproducible.

    Args:
        root: Directory containing one sub-directory per domain.
        transform: Callable applied to each RGB PIL image in ``__getitem__``.
        mode: Dataset mode; only the value ``'train'`` is inspected.
        which: Role of the dataset; when ``mode == 'train'`` and
            ``which == 'reference'`` the sample list is shuffled once at
            construction time using the ``random`` module.
    """

    def __init__(self, root, transform, mode, which='source'):
        self.transform = transform
        self.paths = []
        # Filter to directories BEFORE numbering so a stray file in ``root``
        # cannot shift the labels or leave gaps in them.
        domains = sorted(
            entry for entry in os.listdir(root)
            if os.path.isdir(os.path.join(root, entry))
        )
        for label, domain in enumerate(domains):
            domain_dir = os.path.join(root, domain)
            # os.listdir returns entries in arbitrary order; sort them.
            for fname in sorted(os.listdir(domain_dir)):
                self.paths.append((os.path.join(domain_dir, fname), label))
        if mode == 'train' and which == 'reference':
            random.shuffle(self.paths)

    def __getitem__(self, index):
        """Load sample ``index`` as RGB and return ``(transform(image), label)``."""
        path, label = self.paths[index]
        img = Image.open(path).convert('RGB')
        return self.transform(img), label

    def __len__(self):
        """Return the number of samples."""
        return len(self.paths)
# Functions for building the data loaders.
def get_transform(img_size, mode='train', prob=0.5):
    """Build the image preprocessing pipeline.

    Images are resized to ``img_size`` x ``img_size``, converted to tensors
    and normalised with mean 0.5 / std 0.5 per channel. In ``'train'`` mode a
    random horizontal flip and, with probability ``prob``, a random resized
    crop (scale 0.8-1.0) are inserted after the resize.

    Args:
        img_size: Output height and width.
        mode: ``'train'`` enables the augmentations; any other value skips them.
        prob: Probability of applying the random resized crop.

    Returns:
        A ``transforms.Compose`` of the steps above.
    """
    steps = [transforms.Resize((img_size, img_size))]
    if mode == 'train':
        random_crop = transforms.RandomResizedCrop(img_size, scale=(0.8, 1.0))
        steps += [
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([random_crop], p=prob),
        ]
    steps += [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
    return transforms.Compose(steps)
def get_train_loader(root, which='source', img_size=256, batch_size=8, prob=0.5, num_workers=4):
    """Build the shuffled training ``DataLoader`` over ``ImageFolder(root)``.

    Args:
        root: Directory containing one sub-directory per domain.
        which: Role of the dataset (``'source'`` or ``'reference'``),
            forwarded to ``ImageFolder`` as ``which``.
        img_size: Images are resized to ``img_size`` x ``img_size``.
        batch_size: Batch size; the last incomplete batch is dropped.
        prob: Probability of the random horizontal flip.
        num_workers: Number of loader worker processes.

    Returns:
        A ``DataLoader`` with ``shuffle=True`` and ``drop_last=True``.
    """
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.RandomHorizontalFlip(p=prob),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    # This is the training loader, so the dataset mode is always 'train';
    # ``which`` must be forwarded as ``which`` (not as ``mode``) for the
    # dataset's reference-shuffling branch to be reachable.
    dataset = ImageFolder(root=root, transform=transform, mode='train', which=which)
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
    )
def get_test_loader(root, img_size=256, batch_size=8, shuffle=False, num_workers=4, mode='reference'):
    """Build the evaluation ``DataLoader`` over ``ImageFolder(root)``.

    Images are resized to ``img_size`` x ``img_size``, converted to tensors
    and normalised with mean 0.5 / std 0.5; no augmentation is applied.

    Args:
        root: Directory containing one sub-directory per domain.
        img_size: Output height and width.
        batch_size: Batch size; the final incomplete batch is kept.
        shuffle: Forwarded to ``DataLoader``.
        num_workers: Number of loader worker processes.
        mode: Forwarded to ``ImageFolder`` as its ``mode`` argument.

    Returns:
        A ``DataLoader`` with ``drop_last=False``.
    """
    preprocessing = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    images = ImageFolder(root=root, transform=preprocessing, mode=mode)
    return DataLoader(
        dataset=images,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=False,
    )
# Solver class (adapted for inference).
class Solver(object):
    """Holds the generator, mapping network and style encoder for inference.

    Args:
        args: Namespace providing ``img_size``, ``style_dim``, ``latent_dim``
            and ``num_domains``.
    """

    def __init__(self, args):
        self.args = args
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Build the models on the selected device.
        self.G = Generator(args.img_size, args.style_dim).to(self.device)
        self.M = MappingNetwork(args.latent_dim, args.style_dim, args.num_domains).to(self.device)
        self.S = StyleEncoder(args.img_size, args.style_dim, args.num_domains).to(self.device)

    def load_checkpoint(self, checkpoint_path):
        """Load weights for G, M and S from ``checkpoint_path``.

        The checkpoint must be a mapping with the keys ``'generator'``,
        ``'mapping_network'`` and ``'style_encoder'``.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
            Exception: For any other loading failure; the original error is
                attached as ``__cause__``.
        """
        try:
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            self.G.load_state_dict(checkpoint['generator'])
            self.M.load_state_dict(checkpoint['mapping_network'])
            self.S.load_state_dict(checkpoint['style_encoder'])
        except FileNotFoundError as e:
            print(f"Error: No se encontró el checkpoint en {checkpoint_path}.")
            raise FileNotFoundError(
                f"No se encontró el checkpoint en {checkpoint_path}"
            ) from e
        except Exception as e:
            print(f"Error al cargar el checkpoint: {e}.")
            raise Exception(f"Error al cargar el checkpoint: {e}") from e
        else:
            print(f"Checkpoint cargado exitosamente desde {checkpoint_path}.")

    def _preprocess(self, image, transform):
        """Convert one input image into a normalised 1 x 3 x H x W tensor."""
        # Force three channels: grayscale or RGBA input would otherwise break
        # the 3-channel Normalize and the generator's first convolution.
        pil_image = Image.fromarray(image).convert('RGB')
        return transform(pil_image).unsqueeze(0).to(self.device)

    def transfer_style(self, source_image, reference_image, ref_domain=0):
        """Render ``source_image`` with the style extracted from ``reference_image``.

        Args:
            source_image: Array accepted by ``Image.fromarray`` (presumably
                the NumPy array supplied by the Gradio image input).
            reference_image: Same kind of array; its style is encoded.
            ref_domain: Index of the style-encoder head used for the
                reference image. Defaults to 0, the previously hard-coded value.

        Returns:
            An ``H x W x 3`` ``uint8`` NumPy array.

        Raises:
            ValueError: If ``ref_domain`` is outside ``[0, num_domains)``.
        """
        if not 0 <= ref_domain < self.args.num_domains:
            raise ValueError(
                f"ref_domain must be in [0, {self.args.num_domains}), got {ref_domain}"
            )

        # Make sure the models used here are in evaluation mode.
        self.G.eval()
        self.S.eval()

        transform = transforms.Compose([
            transforms.Resize((self.args.img_size, self.args.img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        with torch.no_grad():
            source = self._preprocess(source_image, transform)
            reference = self._preprocess(reference_image, transform)

            # Encode the style of the reference image.
            domain = torch.tensor([ref_domain], device=self.device)
            s_ref = self.S(reference, domain)

            # Generate the image with the transferred style.
            generated = self.G(source, s_ref)

            # Map from [-1, 1] back to [0, 1] for display.
            generated = denormalize(generated.squeeze(0)).cpu()

        # Scale to [0, 255], clamp, and reorder to height x width x channel.
        return (generated * 255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
# Main entry point for inference.
# Single-entry cache: (checkpoint_path, id(args)) -> (args, solver).
_SOLVER_CACHE = {}


def main(source_image, reference_image, checkpoint_path, args):
    """Run style transfer for one request and return the generated image.

    The ``Solver`` (three networks plus the checkpoint weights) is built on
    the first call and reused while ``checkpoint_path`` and the ``args``
    object stay the same, instead of being rebuilt and reloaded from disk on
    every request.

    Args:
        source_image: Image whose content is kept.
        reference_image: Image whose style is transferred.
        checkpoint_path: Path of the checkpoint to load.
        args: Configuration namespace passed to ``Solver``.

    Returns:
        Whatever ``Solver.transfer_style`` returns for the two images.

    Raises:
        gr.Error: If either image is missing.
    """
    if source_image is None or reference_image is None:
        raise gr.Error("Por favor, proporciona ambas imágenes (fuente y referencia).")

    # id(args) is safe as a key because the cache entry holds a reference to
    # args, so the id cannot be reused while the entry exists.
    key = (checkpoint_path, id(args))
    cached = _SOLVER_CACHE.get(key)
    if cached is None:
        solver = Solver(args)
        solver.load_checkpoint(checkpoint_path)
        # Cache only after a successful load, and keep just the most recent
        # configuration so the cache cannot grow without bound.
        _SOLVER_CACHE.clear()
        _SOLVER_CACHE[key] = (args, solver)
    else:
        solver = cached[1]

    return solver.transfer_style(source_image, reference_image)
def gradio_interface():
    """Assemble the Gradio interface for the style-transfer demo.

    The inference configuration and the checkpoint path are fixed here and
    captured by the request handler, which delegates to ``main``.

    Returns:
        The configured ``gr.Interface``.
    """
    # Inference configuration.
    args = SimpleNamespace(
        img_size=128,
        num_domains=3,
        latent_dim=16,
        style_dim=64,
        num_workers=0,
        seed=8365,
    )
    # Path to the checkpoint.
    checkpoint_path = "iter/27000_nets_ema.ckpt"

    def run(source_image, reference_image):
        return main(source_image, reference_image, checkpoint_path, args)

    source_input = gr.Image(label="Source Image (Car to change style)")
    reference_input = gr.Image(label="Reference Image (Style to transfer)")
    result_output = gr.Image(label="Generated Image (Car with transferred style)")

    return gr.Interface(
        fn=run,
        inputs=[source_input, reference_input],
        outputs=result_output,
        title="AutoStyleGAN: Car Style Transfer",
        description="Transfer the style of one car to another. Upload a source car image and a reference car image.",
    )
if __name__ == '__main__':
    # Build the interface and start the server; share=True is forwarded to
    # Gradio's launch() (see its documentation for the public-link behaviour).
    gradio_interface().launch(share=True)