File size: 2,858 Bytes
5edf65d aa836e9 5edf65d be0ab39 1d48959 15444ce be0ab39 15444ce be0ab39 a74e15b be0ab39 a74e15b 5edf65d a74e15b 5edf65d a74e15b 0cbbd57 5edf65d a74e15b 5edf65d a74e15b 249c35f be0ab39 a74e15b d49caa3 be0ab39 a74e15b d49caa3 5edf65d 7603cbb 5edf65d a74e15b 5edf65d a74e15b 5edf65d a74e15b 0cbbd57 5edf65d a74e15b 5edf65d a74e15b 5edf65d a74e15b 5edf65d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 | import gradio as gr
import torch
from torchvision.utils import save_image
from PIL import Image
from huggingface_hub import hf_hub_download
import os
import torch.nn as nn
# -----------------------------
# 1. DEFINICIÓN DEL MODELO VAE
# -----------------------------
class VAE(nn.Module):
def __init__(self, input_dim, h_dim=400, z_dim=40):
super().__init__()
# Encoder
self.img_2hid = nn.Linear(input_dim, h_dim)
self.hid_2mu = nn.Linear(h_dim, z_dim)
self.hid_2sigma = nn.Linear(h_dim, z_dim)
# Decoder
self.z_2hid = nn.Linear(z_dim, h_dim)
self.hid_2img = nn.Linear(h_dim, input_dim) # Aquí asegúrate de que input_dim sea 10000
self.relu = nn.ReLU()
def encode(self, x):
h = self.relu(self.img_2hid(x))
mu, sigma = self.hid_2mu(h), self.hid_2sigma(h)
return mu, sigma
def decode(self, z):
h = self.relu(self.z_2hid(z))
return torch.sigmoid(self.hid_2img(h))
def forward(self, x):
mu, sigma = self.encode(x)
epsilon = torch.randn_like(sigma)
z_reparametrized = mu + sigma * epsilon
x_reconstructed = self.decode(z_reparametrized)
return x_reconstructed, mu, sigma
# -----------------------------
# 2. CARGAR EL MODELO DESDE HUGGING FACE
# -----------------------------
REPO_ID = "Bmo411/VAE" # <-- reemplaza con tu repo si cambia
MODEL_FILENAME = "vae_complete_model (1).pth"
# Descargar modelo automáticamente
model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
# Inicializar arquitectura del modelo
input_dim = 100 * 100
dummy_model = VAE(input_dim=input_dim, z_dim=40) # la arquitectura base es necesaria para cargar pesos
# Permitir deserialización segura
torch.serialization.add_safe_globals({"VAE": VAE})
# Cargar modelo completo (no solo pesos)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.load(model_path, map_location=device, weights_only=False)
model.to(device)
model.eval()
z_dim = model.z_2hid.in_features
# -----------------------------
# 3. FUNCIÓN PARA GENERAR IMAGEN
# -----------------------------
def generate_image():
with torch.no_grad():
z = torch.randn(1, z_dim).to(device)
out = model.decode(z)
out = out.view(1, 1, 100, 100)
output_path = "generated_sample.png"
save_image(out, output_path)
img = Image.open(output_path).convert("L") # Convertir a escala de grises
return img
# -----------------------------
# 4. INTERFAZ GRADIO
# -----------------------------
iface = gr.Interface(
fn=generate_image,
inputs=[],
outputs="image",
title="Generador de Imagen con VAE",
description=f"Genera una imagen aleatoria desde el VAE entrenado. Dimensión latente del modelo detectada: {z_dim}"
)
iface.launch()
|