File size: 2,858 Bytes
5edf65d
 
 
 
 
 
aa836e9
5edf65d
be0ab39
 
 
1d48959
15444ce
be0ab39
 
 
 
 
 
 
 
15444ce
be0ab39
 
 
 
 
a74e15b
be0ab39
 
 
 
 
 
 
 
 
a74e15b
 
 
5edf65d
 
a74e15b
5edf65d
a74e15b
0cbbd57
5edf65d
a74e15b
5edf65d
 
a74e15b
 
249c35f
be0ab39
a74e15b
d49caa3
be0ab39
a74e15b
 
d49caa3
5edf65d
 
7603cbb
5edf65d
 
a74e15b
5edf65d
a74e15b
5edf65d
 
a74e15b
0cbbd57
5edf65d
 
 
 
a74e15b
5edf65d
 
 
 
 
 
 
a74e15b
5edf65d
 
a74e15b
5edf65d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import gradio as gr
import torch
from torchvision.utils import save_image
from PIL import Image
from huggingface_hub import hf_hub_download
import os
import torch.nn as nn

# -----------------------------
# 1. DEFINICIÓN DEL MODELO VAE
# -----------------------------
class VAE(nn.Module):
    def __init__(self, input_dim, h_dim=400, z_dim=40):
        super().__init__()
        # Encoder
        self.img_2hid = nn.Linear(input_dim, h_dim)
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        self.hid_2sigma = nn.Linear(h_dim, z_dim)

        # Decoder
        self.z_2hid = nn.Linear(z_dim, h_dim)
        self.hid_2img = nn.Linear(h_dim, input_dim)  # Aquí asegúrate de que input_dim sea 10000

        self.relu = nn.ReLU()

    def encode(self, x):
        h = self.relu(self.img_2hid(x))
        mu, sigma = self.hid_2mu(h), self.hid_2sigma(h)
        return mu, sigma

    def decode(self, z):
        h = self.relu(self.z_2hid(z))
        return torch.sigmoid(self.hid_2img(h))

    def forward(self, x):
        mu, sigma = self.encode(x)
        epsilon = torch.randn_like(sigma)
        z_reparametrized = mu + sigma * epsilon
        x_reconstructed = self.decode(z_reparametrized)
        return x_reconstructed, mu, sigma

# -----------------------------
# 2. CARGAR EL MODELO DESDE HUGGING FACE
# -----------------------------
REPO_ID = "Bmo411/VAE"  # <-- reemplaza con tu repo si cambia
MODEL_FILENAME = "vae_complete_model (1).pth"

# Descargar modelo automáticamente
model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)

# Inicializar arquitectura del modelo
input_dim = 100 * 100
dummy_model = VAE(input_dim=input_dim, z_dim=40)  # la arquitectura base es necesaria para cargar pesos

# Permitir deserialización segura
torch.serialization.add_safe_globals({"VAE": VAE})

# Cargar modelo completo (no solo pesos)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.load(model_path, map_location=device, weights_only=False)
model.to(device)
model.eval()
z_dim = model.z_2hid.in_features

# -----------------------------
# 3. FUNCIÓN PARA GENERAR IMAGEN
# -----------------------------
def generate_image():
    with torch.no_grad():
        z = torch.randn(1, z_dim).to(device)
        out = model.decode(z)
        out = out.view(1, 1, 100, 100)

        output_path = "generated_sample.png"
        save_image(out, output_path)

        img = Image.open(output_path).convert("L")  # Convertir a escala de grises
        return img

# -----------------------------
# 4. INTERFAZ GRADIO
# -----------------------------
iface = gr.Interface(
    fn=generate_image,
    inputs=[],
    outputs="image",
    title="Generador de Imagen con VAE",
    description=f"Genera una imagen aleatoria desde el VAE entrenado. Dimensión latente del modelo detectada: {z_dim}"
)

iface.launch()