| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
"""LSTM layers module."""
| |
|
from torch import nn
| |
|
| |
|
class SLSTM(nn.Module):
    """Wrapper around ``nn.LSTM`` that hides hidden-state management.

    Accepts and returns tensors in convolutional layout ``(batch, channels,
    time)`` and optionally adds a residual (skip) connection around the LSTM.
    The hidden state is never exposed: each forward call starts from the
    default zero state.
    """

    def __init__(
        self,
        dimension: int,
        num_layers: int = 2,
        skip: bool = True,
        bidirectional: bool = False,
    ):
        super().__init__()
        self.bidirectional = bidirectional
        self.skip = skip
        # Input and hidden sizes are tied to `dimension`; a bidirectional
        # LSTM therefore emits 2 * dimension channels.
        self.lstm = nn.LSTM(
            dimension, dimension, num_layers, bidirectional=bidirectional
        )

    def forward(self, x):
        # (batch, channels, time) -> (time, batch, channels), the layout
        # nn.LSTM expects with batch_first=False.
        seq_first = x.permute(2, 0, 1)
        out, _ = self.lstm(seq_first)
        if self.skip:
            residual = seq_first
            if self.bidirectional:
                # The bidirectional LSTM doubles the channel count, so the
                # residual is duplicated along channels to match.
                residual = seq_first.repeat(1, 1, 2)
            out = out + residual
        # Back to convolutional layout (batch, channels, time).
        return out.permute(1, 2, 0)
| |
|