Bmo411 commited on
Commit
be0ab39
·
verified ·
1 Parent(s): 782e369

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -39
app.py CHANGED
@@ -6,69 +6,78 @@ from huggingface_hub import hf_hub_download
6
  import os
7
  import torch.nn as nn
8
 
 
 
 
9
  class VAE(nn.Module):
10
- def __init__(self, input_dim, h_dim=400, z_dim=40):
11
- super().__init__()
12
- #encoder
13
- self.img_2hid = nn.Linear(input_dim, h_dim)
14
- self.hid_2mu = nn.Linear(h_dim, z_dim)
15
- self.hid_2sigma = nn.Linear(h_dim, z_dim)
16
-
17
- #decoder
18
- self.z_2hid = nn.Linear(z_dim, h_dim)
19
- self.hid_2img = nn.Linear(h_dim, input_dim)
20
-
21
- self.relu = nn.ReLU()
22
- #self.sigmoid = nn.sigmoid()
23
-
24
- def encode(self, x):
25
- h = self.relu(self.img_2hid(x))
26
- mu, sigma = self.hid_2mu(h), self.hid_2sigma(h)
27
- return mu, sigma
28
-
29
- def decode(self, z):
30
- h = self.relu(self.z_2hid(z))
31
- return torch.sigmoid(self.hid_2img(h))
32
-
33
- def forward(self, x):
34
- mu, sigma = self.encode(x)
35
- epsilon = torch.randn_like(sigma)
36
- z_reparametrized = mu + sigma * epsilon
37
- x_reconstructed = self.decode(z_reparametrized)
38
- return x_reconstructed, mu, sigma
 
39
 
40
  # -----------------------------
41
- # 1. CARGAR MODELO DESDE HUGGING FACE
42
  # -----------------------------
43
- # Reemplaza estos datos con los tuyos
44
- REPO_ID = "Bmo411/VAE" # <-- cámbialo por el tuyo
45
  MODEL_FILENAME = "vae_complete_model.pth"
46
 
47
- # Descargar modelo automáticamente
48
  model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
49
 
50
- # Inicializar modelo y cargar pesos
51
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
52
- model = VAE(100*100)
 
53
  torch.serialization.add_safe_globals({"VAE": VAE})
 
 
54
  model = torch.load(model_path, map_location=device, weights_only=False)
55
  model.to(device)
56
  model.eval()
57
 
58
  # -----------------------------
59
- # 3. FUNCIÓN PARA GENERAR IMAGEN
60
  # -----------------------------
61
  def generate_image(z_dim=40):
62
  with torch.no_grad():
 
63
  z = torch.randn(1, z_dim).to(device)
64
- out = model.decode(z)
65
- out = torch.sigmoid(out)
 
66
  out = out.view(1, 1, 100, 100)
67
 
 
68
  output_path = "generated_sample.png"
69
  save_image(out, output_path)
70
 
71
- img = Image.open(output_path)
 
72
  return img
73
 
74
  # -----------------------------
 
6
  import os
7
  import torch.nn as nn
8
 
9
+ # -----------------------------
10
+ # 1. DEFINICIÓN DEL MODELO VAE
11
+ # -----------------------------
12
  class VAE(nn.Module):
13
+ def __init__(self, input_dim, h_dim=400, z_dim=40):
14
+ super().__init__()
15
+ self.z_dim = z_dim
16
+ # Encoder
17
+ self.img_2hid = nn.Linear(input_dim, h_dim)
18
+ self.hid_2mu = nn.Linear(h_dim, z_dim)
19
+ self.hid_2sigma = nn.Linear(h_dim, z_dim)
20
+
21
+ # Decoder
22
+ self.z_2hid = nn.Linear(z_dim, h_dim)
23
+ self.hid_2img = nn.Linear(h_dim, input_dim)
24
+
25
+ self.relu = nn.ReLU()
26
+
27
+ def encode(self, x):
28
+ h = self.relu(self.img_2hid(x))
29
+ mu = self.hid_2mu(h)
30
+ sigma = self.hid_2sigma(h)
31
+ return mu, sigma
32
+
33
+ def decode(self, z):
34
+ h = self.relu(self.z_2hid(z))
35
+ return torch.sigmoid(self.hid_2img(h))
36
+
37
+ def forward(self, x):
38
+ mu, sigma = self.encode(x)
39
+ epsilon = torch.randn_like(sigma)
40
+ z_reparam = mu + sigma * epsilon
41
+ x_recon = self.decode(z_reparam)
42
+ return x_recon, mu, sigma
43
 
44
  # -----------------------------
45
+ # 2. CARGAR MODELO DESDE HUGGING FACE
46
  # -----------------------------
47
+ REPO_ID = "Bmo411/VAE" # tu repo
 
48
  MODEL_FILENAME = "vae_complete_model.pth"
49
 
50
+ # Descargar el modelo
51
  model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
52
 
 
53
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
+
55
+ # Registrar la clase si se guardó como modelo completo
56
  torch.serialization.add_safe_globals({"VAE": VAE})
57
+
58
+ # Cargar el modelo completo
59
  model = torch.load(model_path, map_location=device, weights_only=False)
60
  model.to(device)
61
  model.eval()
62
 
63
  # -----------------------------
64
+ # 3. GENERAR IMAGEN ALEATORIA
65
  # -----------------------------
66
  def generate_image(z_dim=40):
67
  with torch.no_grad():
68
+ # Muestra del espacio latente
69
  z = torch.randn(1, z_dim).to(device)
70
+ out = model.decode(z) # tamaño: (1, 10000)
71
+
72
+ # Convertir a forma imagen (1, 1, 100, 100)
73
  out = out.view(1, 1, 100, 100)
74
 
75
+ # Guardar imagen temporal
76
  output_path = "generated_sample.png"
77
  save_image(out, output_path)
78
 
79
+ # Leer imagen para mostrar en Gradio
80
+ img = Image.open(output_path).convert("L")
81
  return img
82
 
83
  # -----------------------------