# VAE / app.py — Hugging Face Space by Bmo411 (commit 0cbbd57, "Update app.py")
import gradio as gr
import torch
from torchvision.utils import save_image
from PIL import Image
from huggingface_hub import hf_hub_download
import os
import torch.nn as nn
# -----------------------------
# 1. VAE MODEL DEFINITION
# -----------------------------
class VAE(nn.Module):
    """Fully-connected variational autoencoder for flattened grayscale images.

    Encoder: input -> hidden -> (mu, sigma). Decoder: z -> hidden -> image,
    with a sigmoid so pixels land in [0, 1].

    NOTE(review): the second encoder head is used directly as sigma (a std
    deviation), not as log-variance, in the reparameterization — confirm this
    matches how the checkpoint was trained.
    """

    def __init__(self, input_dim, h_dim=400, z_dim=40):
        super().__init__()
        # Encoder layers
        self.img_2hid = nn.Linear(input_dim, h_dim)
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        self.hid_2sigma = nn.Linear(h_dim, z_dim)
        # Decoder layers; input_dim must equal the flattened image size
        # (10000 for 100x100 images in this app)
        self.z_2hid = nn.Linear(z_dim, h_dim)
        self.hid_2img = nn.Linear(h_dim, input_dim)
        self.relu = nn.ReLU()

    def encode(self, x):
        """Map a batch of flattened images to latent (mu, sigma)."""
        hidden = self.relu(self.img_2hid(x))
        return self.hid_2mu(hidden), self.hid_2sigma(hidden)

    def decode(self, z):
        """Map latent codes back to pixel space, squashed into [0, 1]."""
        hidden = self.relu(self.z_2hid(z))
        return torch.sigmoid(self.hid_2img(hidden))

    def forward(self, x):
        """Reconstruct x via the reparameterization trick.

        Returns (reconstruction, mu, sigma).
        """
        mu, sigma = self.encode(x)
        noise = torch.randn_like(sigma)
        z = mu + sigma * noise
        return self.decode(z), mu, sigma
# -----------------------------
# 2. LOAD THE MODEL FROM HUGGING FACE
# -----------------------------
REPO_ID = "Bmo411/VAE"  # <-- replace with your repo if it changes
MODEL_FILENAME = "vae_complete_model (1).pth"

# Download the checkpoint automatically from the Hub (cached locally).
model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)

# Base architecture; kept so the VAE class is instantiable/importable even
# though the checkpoint below stores the full pickled module.
input_dim = 100 * 100
dummy_model = VAE(input_dim=input_dim, z_dim=40)

# Allowlist the VAE class for safe deserialization.
# FIX: add_safe_globals expects a list of objects; the original passed
# {"VAE": VAE}, which iterates the dict's *string keys* instead of
# registering the class itself.
torch.serialization.add_safe_globals([VAE])

# Load the complete pickled module (not just a state_dict).
# SECURITY NOTE: weights_only=False unpickles arbitrary code from the
# checkpoint — only acceptable because the repo is trusted/owned.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.load(model_path, map_location=device, weights_only=False)
model.to(device)
model.eval()

# Recover the latent dimensionality from the loaded decoder's first layer.
z_dim = model.z_2hid.in_features
# -----------------------------
# 3. IMAGE GENERATION FUNCTION
# -----------------------------
def generate_image():
    """Sample one latent vector and decode it into a 100x100 grayscale image.

    Saves the decoded sample to generated_sample.png and returns it as a
    PIL image in mode "L" for Gradio to display.
    """
    with torch.no_grad():
        # Draw noise on CPU, then move it to the model's device.
        latent = torch.randn(1, z_dim).to(device)
        decoded = model.decode(latent)
        decoded = decoded.view(1, 1, 100, 100)
        output_path = "generated_sample.png"
        save_image(decoded, output_path)
        # Round-trip through disk, reloading as grayscale.
        return Image.open(output_path).convert("L")
# -----------------------------
# 4. GRADIO INTERFACE
# -----------------------------
# Zero-input interface: each click samples a fresh image from the VAE.
demo_description = f"Genera una imagen aleatoria desde el VAE entrenado. Dimensión latente del modelo detectada: {z_dim}"
iface = gr.Interface(
    fn=generate_image,
    inputs=[],
    outputs="image",
    title="Generador de Imagen con VAE",
    description=demo_description,
)
iface.launch()