| """Causal Transformer for the selective copy task.""" |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention restricted to current and earlier positions.

    One bias-free linear layer produces queries, keys and values in a single
    pass; the attention itself is delegated to
    ``torch.nn.functional.scaled_dot_product_attention`` with
    ``is_causal=True``, followed by a bias-free output projection.
    """

    def __init__(self, d_model, n_heads, dropout=0.0):
        """Create the fused QKV projection and the output projection.

        Args:
            d_model: Feature width of the input and of the output.
            n_heads: Number of attention heads; must be positive and divide
                ``d_model`` evenly.
            dropout: Attention dropout probability, used only while the
                module is in training mode.

        Raises:
            ValueError: If ``n_heads`` is not positive or does not divide
                ``d_model``.
        """
        super().__init__()
        # Raise explicitly rather than ``assert``: assertions are removed
        # under ``python -O`` and a bad configuration would then only show
        # up later as an opaque ``view`` shape error inside ``forward``.
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive "
                f"n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)
        self.attn_drop_p = dropout

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape ``(B, T, C)`` where ``C`` equals the
                ``d_model`` given at construction.

        Returns:
            Tensor of shape ``(B, T, C)``.
        """
        B, T, C = x.shape
        # The fused projection yields 3*C features; cut them into q, k, v.
        q, k, v = self.qkv(x).split(C, dim=-1)

        def split_heads(t):
            # (B, T, C) -> (B, n_heads, T, d_head)
            return t.view(B, T, self.n_heads, self.d_head).transpose(1, 2)

        out = torch.nn.functional.scaled_dot_product_attention(
            split_heads(q),
            split_heads(k),
            split_heads(v),
            is_causal=True,
            # Attention dropout must be disabled explicitly in eval mode;
            # the functional API has no notion of ``self.training``.
            dropout_p=self.attn_drop_p if self.training else 0.0,
        )
        # (B, n_heads, T, d_head) -> (B, T, C); ``contiguous`` is needed
        # because ``view`` cannot operate on the transposed layout.
        merged = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(merged)
|
|
|
|
class TransformerBlock(nn.Module):
    """Residual block: LayerNorm -> causal attention, then LayerNorm -> MLP.

    Each sub-layer sees a layer-normalised copy of the stream and its output
    is added back onto the un-normalised stream.
    """

    def __init__(self, d_model, n_heads, dropout=0.0):
        """Build both sub-layers.

        Args:
            d_model: Feature width of the residual stream.
            n_heads: Number of heads handed to ``CausalSelfAttention``.
            dropout: Probability used for attention dropout and for the
                dropout layer that ends the MLP.
        """
        super().__init__()
        hidden = 4 * d_model

        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads, dropout)
        self.ln2 = nn.LayerNorm(d_model)

        # Feed-forward path: expand to 4x width, GELU, project back, dropout.
        feed_forward = [
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, x):
        """Run the attention sub-layer, then the MLP sub-layer.

        Args:
            x: Input tensor; it is not modified in place.

        Returns:
            Tensor with the same shape as ``x``.
        """
        after_attn = x + self.attn(self.ln1(x))
        after_mlp = after_attn + self.mlp(self.ln2(after_attn))
        return after_mlp
|
|
|
|
class TransformerModel(nn.Module):
    """Stack of causal Transformer blocks over a linearly projected input.

    The input is projected to ``d_model``, a learned positional embedding is
    added, the blocks are applied in order, and a final LayerNorm plus linear
    head produces ``d_output`` features per position.
    """

    def __init__(self, d_input, d_model, d_output, n_layers=2, n_heads=4, max_len=4096, dropout=0.0, **kwargs):
        """Build the model.

        Args:
            d_input: Feature width of the raw input.
            d_model: Width of the residual stream.
            d_output: Feature width produced by the head.
            n_layers: Number of ``TransformerBlock`` layers.
            n_heads: Attention heads per block.
            max_len: Number of rows in the positional embedding table, i.e.
                the longest sequence ``forward`` accepts.
            dropout: Dropout probability forwarded to every block.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.input_proj = nn.Linear(d_input, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)

        self.blocks = nn.ModuleList()
        for _ in range(n_layers):
            self.blocks.append(TransformerBlock(d_model, n_heads, dropout))

        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, d_output)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """(B, T, d_input) → (B, T, d_output)

        Raises:
            ValueError: If ``T`` is larger than the positional embedding
                table.
        """
        _, T, _ = x.shape
        if T > self.pos_emb.num_embeddings:
            raise ValueError(f"Sequence length {T} exceeds max_len {self.pos_emb.num_embeddings}")

        projected = self.input_proj(x)
        # Position ids 0..T-1, created on the same device as the input.
        position_ids = torch.arange(T, device=x.device)
        hidden = projected + self.pos_emb(position_ids)

        for layer in self.blocks:
            hidden = layer(hidden)

        normed = self.ln_f(hidden)
        return self.head(normed)

    @staticmethod
    def extra_kwargs(model_cfg) -> dict:
        """Return the constructor kwargs read from ``model_cfg`` (``n_heads`` only)."""
        return dict(n_heads=model_cfg.n_heads)
|
|