import os

import gradio as gr
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision.utils import save_image


# -----------------------------
# 1. VAE MODEL DEFINITION
# -----------------------------
class VAE(nn.Module):
    """Simple fully-connected variational autoencoder.

    Args:
        input_dim: Flattened image size (10000 = 100*100 pixels here).
        h_dim: Hidden-layer width shared by encoder and decoder.
        z_dim: Latent-space dimensionality.
    """

    def __init__(self, input_dim, h_dim=400, z_dim=40):
        super().__init__()
        # Encoder: image -> hidden -> (mu, sigma)
        self.img_2hid = nn.Linear(input_dim, h_dim)
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        self.hid_2sigma = nn.Linear(h_dim, z_dim)
        # Decoder: latent -> hidden -> image (input_dim outputs)
        self.z_2hid = nn.Linear(z_dim, h_dim)
        self.hid_2img = nn.Linear(h_dim, input_dim)
        self.relu = nn.ReLU()

    def encode(self, x):
        """Map a flattened image batch to latent (mu, sigma)."""
        h = self.relu(self.img_2hid(x))
        return self.hid_2mu(h), self.hid_2sigma(h)

    def decode(self, z):
        """Map latent codes back to pixel space in [0, 1] via sigmoid."""
        h = self.relu(self.z_2hid(z))
        return torch.sigmoid(self.hid_2img(h))

    def forward(self, x):
        """Reconstruct x; returns (reconstruction, mu, sigma)."""
        mu, sigma = self.encode(x)
        # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I).
        # NOTE(review): sigma is used directly (not exp(0.5*log_var));
        # assumes the checkpoint was trained with the same convention.
        epsilon = torch.randn_like(sigma)
        z_reparametrized = mu + sigma * epsilon
        return self.decode(z_reparametrized), mu, sigma


# -----------------------------
# 2. LOAD THE MODEL FROM HUGGING FACE
# -----------------------------
REPO_ID = "Bmo411/VAE"  # replace if the repo changes
MODEL_FILENAME = "vae_complete_model (1).pth"

# Download the checkpoint (cached locally by huggingface_hub).
model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)

# BUGFIX: add_safe_globals expects a list of allowed globals, not a dict.
# Passing {"VAE": VAE} registered the string key "VAE", not the class itself.
torch.serialization.add_safe_globals([VAE])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the complete pickled model (not just a state_dict).
# SECURITY NOTE: weights_only=False executes arbitrary pickle code; this is
# only acceptable because the checkpoint comes from our own trusted repo.
# (The previously-built throwaway VAE instance was dead code and is removed.)
model = torch.load(model_path, map_location=device, weights_only=False)
model.to(device)
model.eval()

# Recover the latent dimension from the loaded decoder's first layer.
z_dim = model.z_2hid.in_features
# -----------------------------
# 3. IMAGE GENERATION FUNCTION
# -----------------------------
def generate_image():
    """Sample a random latent vector and decode it to a grayscale image.

    Returns:
        PIL.Image.Image: 100x100 grayscale ("L") image produced by the VAE
        decoder from z ~ N(0, I).
    """
    with torch.no_grad():
        # Draw a random latent code in the model's latent space.
        z = torch.randn(1, z_dim).to(device)
        out = model.decode(z)  # sigmoid output: pixels in [0, 1]
    # BUGFIX: convert in memory instead of writing "generated_sample.png"
    # to disk and re-reading it with PIL — the fixed filename raced between
    # concurrent Gradio requests and littered the working directory.
    pixels = out.view(100, 100).cpu().numpy()
    arr = (pixels * 255.0).round().clip(0, 255).astype("uint8")
    return Image.fromarray(arr, mode="L")


# -----------------------------
# 4. GRADIO INTERFACE
# -----------------------------
iface = gr.Interface(
    fn=generate_image,
    inputs=[],
    outputs="image",
    title="Generador de Imagen con VAE",
    description=f"Genera una imagen aleatoria desde el VAE entrenado. Dimensión latente del modelo detectada: {z_dim}",
)

iface.launch()