Bmo411 commited on
Commit
a74e15b
·
verified ·
1 Parent(s): be0ab39

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -24
app.py CHANGED
@@ -10,9 +10,8 @@ import torch.nn as nn
10
  # 1. DEFINICIÓN DEL MODELO VAE
11
  # -----------------------------
12
  class VAE(nn.Module):
13
- def __init__(self, input_dim, h_dim=400, z_dim=40):
14
  super().__init__()
15
- self.z_dim = z_dim
16
  # Encoder
17
  self.img_2hid = nn.Linear(input_dim, h_dim)
18
  self.hid_2mu = nn.Linear(h_dim, z_dim)
@@ -26,8 +25,7 @@ class VAE(nn.Module):
26
 
27
  def encode(self, x):
28
  h = self.relu(self.img_2hid(x))
29
- mu = self.hid_2mu(h)
30
- sigma = self.hid_2sigma(h)
31
  return mu, sigma
32
 
33
  def decode(self, z):
@@ -37,47 +35,48 @@ class VAE(nn.Module):
37
  def forward(self, x):
38
  mu, sigma = self.encode(x)
39
  epsilon = torch.randn_like(sigma)
40
- z_reparam = mu + sigma * epsilon
41
- x_recon = self.decode(z_reparam)
42
- return x_recon, mu, sigma
43
 
44
  # -----------------------------
45
- # 2. CARGAR MODELO DESDE HUGGING FACE
46
  # -----------------------------
47
- REPO_ID = "Bmo411/VAE" # tu repo
48
  MODEL_FILENAME = "vae_complete_model.pth"
49
 
50
- # Descargar el modelo
51
  model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
52
 
53
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
54
 
55
- # Registrar la clase si se guardó como modelo completo
56
  torch.serialization.add_safe_globals({"VAE": VAE})
57
 
58
- # Cargar el modelo completo
 
59
  model = torch.load(model_path, map_location=device, weights_only=False)
60
  model.to(device)
61
  model.eval()
62
 
 
 
 
63
  # -----------------------------
64
- # 3. GENERAR IMAGEN ALEATORIA
65
  # -----------------------------
66
- def generate_image(z_dim=40):
67
  with torch.no_grad():
68
- # Muestra del espacio latente
69
  z = torch.randn(1, z_dim).to(device)
70
- out = model.decode(z) # tamaño: (1, 10000)
71
-
72
- # Convertir a forma imagen (1, 1, 100, 100)
73
  out = out.view(1, 1, 100, 100)
74
 
75
- # Guardar imagen temporal
76
  output_path = "generated_sample.png"
77
  save_image(out, output_path)
78
 
79
- # Leer imagen para mostrar en Gradio
80
- img = Image.open(output_path).convert("L")
81
  return img
82
 
83
  # -----------------------------
@@ -85,10 +84,10 @@ def generate_image(z_dim=40):
85
  # -----------------------------
86
  iface = gr.Interface(
87
  fn=generate_image,
88
- inputs=gr.Slider(10, 100, value=40, step=1, label="Dimensión latente (z_dim)"),
89
  outputs="image",
90
  title="Generador de Imagen con VAE",
91
- description="Genera una imagen aleatoria a partir del espacio latente del VAE entrenado."
92
  )
93
 
94
  iface.launch()
 
10
  # 1. DEFINICIÓN DEL MODELO VAE
11
  # -----------------------------
12
  class VAE(nn.Module):
13
+ def __init__(self, input_dim, h_dim=400, z_dim=20): # NOTA: z_dim por defecto en 20
14
  super().__init__()
 
15
  # Encoder
16
  self.img_2hid = nn.Linear(input_dim, h_dim)
17
  self.hid_2mu = nn.Linear(h_dim, z_dim)
 
25
 
26
  def encode(self, x):
27
  h = self.relu(self.img_2hid(x))
28
+ mu, sigma = self.hid_2mu(h), self.hid_2sigma(h)
 
29
  return mu, sigma
30
 
31
  def decode(self, z):
 
35
  def forward(self, x):
36
  mu, sigma = self.encode(x)
37
  epsilon = torch.randn_like(sigma)
38
+ z_reparametrized = mu + sigma * epsilon
39
+ x_reconstructed = self.decode(z_reparametrized)
40
+ return x_reconstructed, mu, sigma
41
 
42
  # -----------------------------
43
+ # 2. CARGAR EL MODELO DESDE HUGGING FACE
44
  # -----------------------------
45
+ REPO_ID = "Bmo411/VAE" # <-- reemplaza con tu repo si cambia
46
  MODEL_FILENAME = "vae_complete_model.pth"
47
 
48
+ # Descargar modelo automáticamente
49
  model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
50
 
51
+ # Inicializar arquitectura del modelo
52
+ input_dim = 100 * 100
53
+ dummy_model = VAE(input_dim=input_dim, z_dim=20) # la arquitectura base es necesaria para cargar pesos
54
 
55
+ # Permitir deserialización segura
56
  torch.serialization.add_safe_globals({"VAE": VAE})
57
 
58
+ # Cargar modelo completo (no solo pesos)
59
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
60
  model = torch.load(model_path, map_location=device, weights_only=False)
61
  model.to(device)
62
  model.eval()
63
 
64
+ # Detectar z_dim automáticamente desde el decoder
65
+ z_dim = model.z_2hid.in_features
66
+
67
  # -----------------------------
68
+ # 3. FUNCIÓN PARA GENERAR IMAGEN
69
  # -----------------------------
70
+ def generate_image():
71
  with torch.no_grad():
 
72
  z = torch.randn(1, z_dim).to(device)
73
+ out = model.decode(z)
 
 
74
  out = out.view(1, 1, 100, 100)
75
 
 
76
  output_path = "generated_sample.png"
77
  save_image(out, output_path)
78
 
79
+ img = Image.open(output_path).convert("L") # Convertir a escala de grises
 
80
  return img
81
 
82
  # -----------------------------
 
84
  # -----------------------------
85
  iface = gr.Interface(
86
  fn=generate_image,
87
+ inputs=[],
88
  outputs="image",
89
  title="Generador de Imagen con VAE",
90
+ description=f"Genera una imagen aleatoria desde el VAE entrenado. Dimensión latente del modelo detectada: {z_dim}"
91
  )
92
 
93
  iface.launch()