rahimizadeh commited on
Commit
d1deeeb
·
verified ·
1 Parent(s): bd7790d

Update models/vae.py

Browse files
Files changed (1) hide show
  1. models/vae.py +19 -0
models/vae.py CHANGED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pytorch_lightning as pl
3
+
4
+ class VAEModel(pl.LightningModule):
5
+ def __init__(self):
6
+ super().__init__()
7
+ # Your VAE implementation here
8
+ self.encoder = torch.nn.Sequential(
9
+ torch.nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
10
+ torch.nn.ReLU()
11
+ )
12
+ self.decoder = torch.nn.Sequential(
13
+ torch.nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1),
14
+ torch.nn.Sigmoid()
15
+ )
16
+
17
+ def forward(self, x):
18
+ z = self.encoder(x)
19
+ return self.decoder(z)