File size: 745 Bytes
8006486
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
"""LSTM baseline for the selective copy task."""

import torch
import torch.nn as nn


class LSTMModel(nn.Module):
    """Batch-first LSTM whose outputs are projected by a linear head.

    The recurrent stack maps inputs of width ``d_input`` to hidden states of
    width ``d_model``; the head then maps every hidden state to ``d_output``
    features.
    """

    def __init__(self, d_input, d_model, d_output, n_layers=1, dropout=0.0, **kwargs):
        """Build the recurrent stack and the output projection.

        Args:
            d_input: Feature width of each input step.
            d_model: Hidden size of the LSTM.
            d_output: Feature width produced by the linear head.
            n_layers: Number of stacked LSTM layers.
            dropout: Dropout probability handed to ``nn.LSTM``; only used
                when ``n_layers`` is greater than one.
            **kwargs: Accepted and ignored.
        """
        super().__init__()

        # With a single layer the dropout value is forced to zero rather
        # than forwarded to nn.LSTM.
        if n_layers > 1:
            inter_layer_dropout = dropout
        else:
            inter_layer_dropout = 0.0

        lstm_options = {
            "input_size": d_input,
            "hidden_size": d_model,
            "num_layers": n_layers,
            "batch_first": True,
            "dropout": inter_layer_dropout,
        }
        self.lstm = nn.LSTM(**lstm_options)
        self.head = nn.Linear(in_features=d_model, out_features=d_output)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a (B, T, d_input) batch to (B, T, d_output)."""
        # nn.LSTM returns (outputs, final_state); only the outputs are used.
        hidden_states = self.lstm(x)[0]
        return self.head(hidden_states)

    @staticmethod
    def extra_kwargs(model_cfg) -> dict:
        """Return an empty dict; ``model_cfg`` is not consulted.

        NOTE(review): presumably a hook for model-specific constructor
        arguments derived from the config -- confirm against the caller.
        """
        return dict()