|
|
import gradio as gr |
|
|
import torch |
|
|
from torchvision.utils import save_image |
|
|
from PIL import Image |
|
|
from huggingface_hub import hf_hub_download |
|
|
import os |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VAE(nn.Module):
    """Fully-connected variational autoencoder over flattened images.

    Encoder: input -> hidden -> (mu, sigma); decoder: latent -> hidden ->
    sigmoid-activated reconstruction in [0, 1].

    NOTE: the second encoder head is consumed directly as a standard
    deviation in the reparameterization (mu + sigma * eps) — it is NOT
    treated as a log-variance here.
    """

    def __init__(self, input_dim, h_dim=400, z_dim=40):
        super().__init__()
        # Encoder layers (creation order kept stable for RNG reproducibility).
        self.img_2hid = nn.Linear(input_dim, h_dim)
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        self.hid_2sigma = nn.Linear(h_dim, z_dim)
        # Decoder layers.
        self.z_2hid = nn.Linear(z_dim, h_dim)
        self.hid_2img = nn.Linear(h_dim, input_dim)
        self.relu = nn.ReLU()

    def encode(self, x):
        """Map a batch of flattened images to the (mu, sigma) heads."""
        hidden = self.relu(self.img_2hid(x))
        return self.hid_2mu(hidden), self.hid_2sigma(hidden)

    def decode(self, z):
        """Map latent codes back to pixel space, squashed into [0, 1]."""
        hidden = self.relu(self.z_2hid(z))
        return torch.sigmoid(self.hid_2img(hidden))

    def forward(self, x):
        """Encode, reparameterize, decode.

        Returns a tuple ``(reconstruction, mu, sigma)``.
        """
        mu, sigma = self.encode(x)
        noise = torch.randn_like(sigma)
        latent = mu + sigma * noise
        return self.decode(latent), mu, sigma
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face Hub repository that hosts the trained VAE checkpoint.
REPO_ID = "Bmo411/VAE"


# Exact artifact filename inside the repo — the " (1)" suffix is part of the
# uploaded file's name and must match byte-for-byte.
MODEL_FILENAME = "vae_complete_model (1).pth"


# Download the checkpoint (or reuse the local HF cache) at import time.
# Returns the local filesystem path to the cached file.
model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
|
|
|
|
|
|
|
|
# Expected flattened image size (100x100 grayscale).  The architecture itself
# is restored from the pickled checkpoint below, so this constant only
# documents the image geometry used elsewhere in this script.
input_dim = 100 * 100

# Allowlist the VAE class for torch's weights_only loader.  The API takes a
# *list* of objects (the original passed a dict, which would register the
# string keys instead of the class).  NOTE: with weights_only=False below
# this allowlist is not consulted; it is kept so the load keeps working if
# weights_only is ever switched back on.
torch.serialization.add_safe_globals([VAE])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# SECURITY: weights_only=False unpickles arbitrary objects (and can execute
# code) from the checkpoint file — only acceptable because the repo is
# trusted.  Prefer weights_only=True with a state_dict checkpoint if the
# artifact format can be changed.
model = torch.load(model_path, map_location=device, weights_only=False)
model.to(device)
model.eval()

# Recover the latent dimensionality from the loaded decoder's input layer
# rather than hard-coding it.
z_dim = model.z_2hid.in_features
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_image():
    """Sample a random latent vector and decode it into a grayscale image.

    Returns:
        PIL.Image.Image: a 100x100 mode-"L" image decoded from a latent
        vector drawn from the standard normal prior.
    """
    with torch.no_grad():
        z = torch.randn(1, z_dim, device=device)
        out = model.decode(z)

    # Convert the [0, 1] float tensor to an 8-bit grayscale PIL image
    # entirely in memory.  The original wrote to a fixed on-disk PNG and
    # re-opened it, which is needless I/O and races with itself when
    # Gradio serves concurrent requests against the same filename.
    # mul(255).add_(0.5) rounding matches torchvision's save_image.
    pixels = out.view(100, 100).clamp(0, 1).mul(255).add(0.5).clamp(0, 255)
    arr = pixels.to(torch.uint8).cpu().numpy()
    return Image.fromarray(arr, mode="L")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Zero-input Gradio UI: the single action is sampling a fresh image from
# the decoder.  The latent dimension shown was read off the loaded model.
demo_description = (
    "Genera una imagen aleatoria desde el VAE entrenado. "
    f"Dimensión latente del modelo detectada: {z_dim}"
)

iface = gr.Interface(
    fn=generate_image,
    inputs=[],
    outputs="image",
    title="Generador de Imagen con VAE",
    description=demo_description,
)

iface.launch()
|
|
|