| import torch | |
| import torch.nn as nn | |
class AutoEncoder(nn.Module):
    """Single-hidden-layer autoencoder with a sigmoid-squashed code.

    The input is flattened to ``(-1, input_dim)``. A linear layer encodes it to
    ``hidden_dim`` units, and a sigmoid squashes those units. A second linear
    layer then decodes back to ``input_dim`` features.

    Args:
        input_dim: Number of features per sample after flattening. Defaults to
            343, the value this module originally hard-coded.
            NOTE(review): 343 == 7**3; possibly a flattened 7x7x7 volume. Confirm.
        hidden_dim: Width of the hidden code. Defaults to 410, the value this
            module originally hard-coded.
    """

    def __init__(self, input_dim: int = 343, hidden_dim: int = 410):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        # Attribute names are kept unchanged so existing state_dict checkpoints
        # ("encoder.*", "decoder.*") still load into this module.
        self.encoder = nn.Linear(input_dim, hidden_dim)
        # Sigmoid bounds every code unit to [0, 1], but does not by itself make
        # the code sparse. Presumably a sparsity penalty is applied to the
        # returned code by the training loop (TODO confirm).
        self.sparsify = nn.Sigmoid()
        self.decoder = nn.Linear(hidden_dim, input_dim)

    def forward(self, out: torch.Tensor):
        """Encode and reconstruct a batch.

        Args:
            out: Tensor whose element count is a multiple of ``input_dim``. It
                is flattened to shape ``(-1, input_dim)``. Non-contiguous
                tensors (e.g. transposed/permuted) are accepted.

        Returns:
            A tuple ``(reconstruction, code)``:
            - ``reconstruction`` has shape ``(N, input_dim)``.
            - ``code`` is the post-sigmoid hidden activation, with shape
              ``(N, hidden_dim)``.
        """
        # reshape instead of view. view raises on non-contiguous inputs whose
        # dims can't be merged. reshape still returns a view when possible
        # and copies only when it must.
        flat = out.reshape(-1, self.input_dim)
        code = self.sparsify(self.encoder(flat))
        reconstruction = self.decoder(code)
        return reconstruction, code