File size: 404 Bytes
import torch
import torch.nn as nn
class LSTMAutoencoder(nn.Module):
    """Sequence-to-sequence LSTM autoencoder.

    An encoder LSTM maps each timestep from ``input_dim`` to ``hidden_dim``
    features, and a decoder LSTM maps the encoded sequence back to
    ``input_dim`` features. Both LSTMs are built with ``batch_first=True``,
    so batched inputs are laid out as ``(batch, seq_len, input_dim)``.

    Note: the full per-timestep encoder output is fed to the decoder; there
    is no compression over time into a single latent vector.

    Args:
        input_dim: Number of features per timestep in the input (and output).
        hidden_dim: Number of features per timestep in the encoded sequence.
        project_output: When ``False`` (the default, original architecture)
            the decoder LSTM emits ``input_dim`` features directly. Because
            an LSTM's output is ``o_t * tanh(c_t)``, every reconstructed
            value then lies in ``[-1, 1]``, so inputs outside that range
            cannot be reconstructed. When ``True``, the decoder keeps
            ``hidden_dim`` features and a linear head (``output_proj``) maps
            them to ``input_dim`` with an unbounded range.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        *,
        project_output: bool = False,
    ) -> None:
        super().__init__()
        self.encoder = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        if project_output:
            self.decoder = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
            self.output_proj = nn.Linear(hidden_dim, input_dim)
        else:
            # Original architecture: identical submodules and state_dict
            # keys, so existing checkpoints still load with strict=True.
            self.decoder = nn.LSTM(hidden_dim, input_dim, batch_first=True)
            self.output_proj = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and return its reconstruction.

        Args:
            x: Input sequence, batch-first, with ``input_dim`` features per
                timestep.

        Returns:
            The reconstruction, with the same leading dimensions as ``x`` and
            ``input_dim`` features per timestep. The LSTMs' final hidden and
            cell states are discarded.
        """
        encoded, _ = self.encoder(x)
        decoded, _ = self.decoder(encoded)
        # The getattr default keeps whole-model pickles of the earlier class
        # (which had no ``output_proj`` attribute) working after unpickling.
        proj = getattr(self, "output_proj", None)
        if proj is not None:
            decoded = proj(decoded)
        return decoded