File size: 570 Bytes
d1deeeb
 
 
 
 
 
c6d22c4
d1deeeb
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch
import pytorch_lightning as pl

class VAEModel(pl.LightningModule):
    """Small convolutional autoencoder wrapped as a LightningModule.

    Despite the class name, this is a deterministic autoencoder: the encoder
    emits a feature map directly, with no mean/log-variance heads and no
    reparameterised sampling, so there is no latent distribution or KL term.

    Expects single-channel images (the encoder's first layer is
    ``Conv2d(1, ...)``), i.e. ``(N, 1, H, W)`` or unbatched ``(1, H, W)``.

    NOTE(review): no ``training_step`` / ``configure_optimizers`` are defined
    in this class; ``Trainer.fit`` needs both — TODO confirm where training
    logic is supposed to live.
    """

    def __init__(self) -> None:
        super().__init__()
        # Encoder: a single stride-2 conv that halves H and W (rounding up)
        # and lifts 1 input channel to 16 feature channels.
        self.encoder = torch.nn.Sequential(
            torch.nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
            torch.nn.ReLU(),
        )
        # Decoder: mirrors the encoder back to 1 channel; Sigmoid squashes the
        # reconstruction into (0, 1). Kept as a Sequential so state_dict keys
        # (``decoder.0.*``) stay compatible with existing checkpoints.
        self.decoder = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1),
            torch.nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode then decode ``x``.

        Args:
            x: single-channel image batch ``(N, 1, H, W)`` (or ``(1, H, W)``).

        Returns:
            Reconstruction with the same spatial size (H, W) as ``x``, with
            values in (0, 1).
        """
        z = self.encoder(x)
        deconv, activation = self.decoder[0], self.decoder[1]
        # A stride-2 transposed conv cannot tell whether the pre-encoder size
        # was even or odd; left alone it always yields 2*h - 1 (e.g.
        # 28 -> 14 -> 27). Passing the input's spatial size makes PyTorch pick
        # the matching output padding so the reconstruction lines up with x.
        recon = deconv(z, output_size=x.shape[-2:])
        return activation(recon)