"""LSTM baseline for the selective copy task.""" import torch import torch.nn as nn class LSTMModel(nn.Module): def __init__(self, d_input, d_model, d_output, n_layers=1, dropout=0.0, **kwargs): super().__init__() self.lstm = nn.LSTM( input_size=d_input, hidden_size=d_model, num_layers=n_layers, batch_first=True, dropout=dropout if n_layers > 1 else 0.0, ) self.head = nn.Linear(d_model, d_output) def forward(self, x: torch.Tensor) -> torch.Tensor: """(B, T, d_input) → (B, T, d_output)""" out, _ = self.lstm(x) return self.head(out) @staticmethod def extra_kwargs(model_cfg) -> dict: return {}