| """LSTM baseline for the selective copy task.""" | |
| import torch | |
| import torch.nn as nn | |
class LSTMModel(nn.Module):
    """Recurrent baseline: an ``nn.LSTM`` followed by a per-step linear read-out.

    Args:
        d_input: Feature size of each input time step.
        d_model: LSTM hidden size; also the input size of the read-out layer.
        d_output: Feature size of each output time step.
        n_layers: Number of stacked LSTM layers.
        dropout: Inter-layer dropout probability. Replaced by 0.0 when
            ``n_layers`` is not greater than 1, because PyTorch only applies
            LSTM dropout between stacked layers.
        **kwargs: Accepted and ignored. NOTE(review): presumably lets a shared
            model factory pass config keys meant for other models — confirm.
    """

    def __init__(
        self, d_input, d_model, d_output, n_layers=1, dropout=0.0, **kwargs
    ):
        super().__init__()
        # With a single layer there is no gap between layers for dropout to act
        # on; passing a nonzero value would only make PyTorch emit a warning.
        stacked = n_layers > 1
        # Build the LSTM before the head: creation order fixes RNG consumption
        # during init, parameter order, and state_dict key order.
        self.lstm = nn.LSTM(
            input_size=d_input,
            hidden_size=d_model,
            num_layers=n_layers,
            dropout=dropout if stacked else 0.0,
            batch_first=True,
        )
        self.head = nn.Linear(d_model, d_output)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(B, T, d_input)`` inputs to ``(B, T, d_output)`` outputs."""
        # nn.LSTM returns (output, (h_n, c_n)); only the per-step output is used.
        hidden_states = self.lstm(x)[0]
        return self.head(hidden_states)
def extra_kwargs(model_cfg) -> dict:
    """Return additional constructor kwargs for ``LSTMModel``.

    The LSTM baseline needs none, so ``model_cfg`` is not inspected and a new,
    empty dict is returned on every call (callers may mutate it safely).
    NOTE(review): presumably a per-model hook called by a shared factory — confirm.
    """
    no_extras = dict()
    return no_extras