modell-name / src /models /lstm.py
RabidUmarell's picture
Add model checkpoint and source
8006486 verified
"""LSTM baseline for the selective copy task."""
import torch
import torch.nn as nn
class LSTMModel(nn.Module):
    """Stacked LSTM followed by a linear readout applied at every time step.

    Args:
        d_input: Size of the last dimension of the input sequence.
        d_model: Hidden size of the LSTM (and input size of the readout).
        d_output: Size of the last dimension of the output sequence.
        n_layers: Number of stacked LSTM layers.
        dropout: Dropout probability handed to ``nn.LSTM``. It is replaced
            by 0.0 unless ``n_layers > 1``.
        **kwargs: Accepted and ignored by this constructor.
    """

    def __init__(self, d_input, d_model, d_output, n_layers=1, dropout=0.0, **kwargs):
        super().__init__()

        # Only forward a non-zero dropout when there is more than one layer.
        if n_layers > 1:
            inter_layer_dropout = dropout
        else:
            inter_layer_dropout = 0.0

        recurrent_cfg = {
            "input_size": d_input,
            "hidden_size": d_model,
            "num_layers": n_layers,
            "batch_first": True,
            "dropout": inter_layer_dropout,
        }
        # Submodule names and creation order are kept as-is so that
        # state-dict keys and seeded initialisation stay unchanged.
        self.lstm = nn.LSTM(**recurrent_cfg)
        self.head = nn.Linear(in_features=d_model, out_features=d_output)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch-first sequence (B, T, d_input) to (B, T, d_output)."""
        # nn.LSTM returns (outputs, final_states); only the outputs are used.
        hidden_seq = self.lstm(x)[0]
        return self.head(hidden_seq)

    @staticmethod
    def extra_kwargs(model_cfg) -> dict:
        """Return an empty dict; ``model_cfg`` is not read.

        NOTE(review): presumably a hook for model-specific constructor
        arguments derived from the config -- confirm against the caller.
        """
        return dict()