Bmo411 commited on
Commit
1d48959
·
verified ·
1 Parent(s): c495696

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -0
app.py CHANGED
@@ -5,6 +5,36 @@ from PIL import Image
5
  from huggingface_hub import hf_hub_download
6
  import os
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  # -----------------------------
10
  # 1. CARGAR MODELO DESDE HUGGING FACE
 
5
  from huggingface_hub import hf_hub_download
6
  import os
7
 
8
+ class VAE(nn.Module):
9
+ def __init__(self, input_dim, h_dim=400, z_dim=20):
10
+ super().__init__()
11
+ #encoder
12
+ self.img_2hid = nn.Linear(input_dim, h_dim)
13
+ self.hid_2mu = nn.Linear(h_dim, z_dim)
14
+ self.hid_2sigma = nn.Linear(h_dim, z_dim)
15
+
16
+ #decoder
17
+ self.z_2hid = nn.Linear(z_dim, h_dim)
18
+ self.hid_2img = nn.Linear(h_dim, input_dim)
19
+
20
+ self.relu = nn.ReLU()
21
+ #self.sigmoid = nn.sigmoid()
22
+
23
+ def encode(self, x):
24
+ h = self.relu(self.img_2hid(x))
25
+ mu, sigma = self.hid_2mu(h), self.hid_2sigma(h)
26
+ return mu, sigma
27
+
28
+ def decode(self, z):
29
+ h = self.relu(self.z_2hid(z))
30
+ return torch.sigmoid(self.hid_2img(h))
31
+
32
+ def forward(self, x):
33
+ mu, sigma = self.encode(x)
34
+ epsilon = torch.randn_like(sigma)
35
+ z_reparametrized = mu + sigma * epsilon
36
+ x_reconstructed = self.decode(z_reparametrized)
37
+ return x_reconstructed, mu, sigma
38
 
39
  # -----------------------------
40
  # 1. CARGAR MODELO DESDE HUGGING FACE