Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms as transforms | |
class Encoder(nn.Module):
    """Convolutional encoder that maps an image batch to a latent vector.

    The network has three stages of two 3x3 convolutions each. The second
    and third stages open with a stride-2 convolution, so the spatial size
    is halved twice while the channel count doubles each time. The final
    feature map is flattened and projected to ``latent_dim`` by a linear
    layer. The activation is applied after every layer, including the
    final linear projection.

    Args:
        in_channels: Number of channels in the input images.
        out_channels: Channel count of the first stage; later stages use
            ``2 * out_channels`` and ``4 * out_channels``.
        latent_dim: Size of the returned latent vector.
        act_fn: Activation module inserted after every layer. The same
            instance is reused at every position (and, for the default,
            across all ``Encoder`` instances), so it should be stateless.
        input_size: ``(height, width)`` of the input images. The default
            ``(480, 360)`` matches the size this encoder was written for.

    Raises:
        ValueError: If ``input_size`` contains a non-positive dimension.
    """

    def __init__(
        self,
        in_channels=1,
        out_channels=16,
        latent_dim=64,
        act_fn=nn.ReLU(),
        *,
        input_size=(480, 360),
    ):
        super().__init__()
        height, width = input_size
        if height < 1 or width < 1:
            raise ValueError(
                f"input_size must be positive, got {tuple(input_size)}"
            )

        # Remembered so forward() can reshape to the configured layout
        # instead of assuming a single-channel 480x360 image.
        self.in_channels = in_channels
        self.input_size = (height, width)

        # Two stride-2 convolutions shrink each spatial dimension twice.
        feat_h = self._downsampled(self._downsampled(height))
        feat_w = self._downsampled(self._downsampled(width))

        self.net = nn.Sequential(
            # Stage 1: full resolution, (480, 360) by default.
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            act_fn,
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            act_fn,
            # Stage 2: half resolution, (240, 180) by default.
            nn.Conv2d(out_channels, 2 * out_channels, 3, padding=1, stride=2),
            act_fn,
            nn.Conv2d(2 * out_channels, 2 * out_channels, 3, padding=1),
            act_fn,
            # Stage 3: quarter resolution, (120, 90) by default.
            nn.Conv2d(2 * out_channels, 4 * out_channels, 3, padding=1, stride=2),
            act_fn,
            nn.Conv2d(4 * out_channels, 4 * out_channels, 3, padding=1),
            act_fn,
            nn.Flatten(),
            nn.Linear(4 * out_channels * feat_h * feat_w, latent_dim),
            act_fn,
        )

    @staticmethod
    def _downsampled(size):
        """Return the output length of a kernel-3, padding-1, stride-2 conv."""
        # floor((size + 2*1 - 3) / 2) + 1, per the Conv2d output-shape formula.
        return (size - 1) // 2 + 1

    def forward(self, x):
        """Encode ``x`` and return a tensor of shape ``(batch, latent_dim)``.

        ``x`` is reshaped to ``(-1, in_channels, height, width)`` first, so
        any input whose element count fits that layout is accepted.
        """
        height, width = self.input_size
        x = x.view(-1, self.in_channels, height, width)
        return self.net(x)