File size: 2,629 Bytes
8006486
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""Causal Transformer for the selective copy task."""

import torch
import torch.nn as nn


class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention built on ``scaled_dot_product_attention``.

    A single bias-free linear layer produces queries, keys and values. The heads
    are attended independently with ``is_causal=True``, so each position sees only
    itself and earlier positions. The merged result then goes through a bias-free
    output projection.

    Args:
        d_model: input/output feature width ``C``.
        n_heads: number of attention heads; must be positive and divide ``d_model``.
        dropout: attention-weight dropout probability, applied only in training mode.

    Raises:
        ValueError: if ``n_heads`` is not positive or does not divide ``d_model``.
    """

    def __init__(self, d_model, n_heads, dropout=0.0):
        super().__init__()
        # Explicit checks instead of ``assert``: asserts vanish under ``python -O``,
        # and a zero head count would otherwise surface as a ZeroDivisionError.
        if n_heads <= 0:
            raise ValueError(f"n_heads must be positive, got {n_heads}")
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.n_heads     = n_heads
        self.d_head      = d_model // n_heads
        self.qkv         = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj        = nn.Linear(d_model, d_model, bias=False)
        self.attn_drop_p = dropout

    def forward(self, x):
        """Attend causally over ``x`` of shape (B, T, C); returns a (B, T, C) tensor."""
        B, T, C = x.shape
        q, k, v = self.qkv(x).split(C, dim=-1)

        def split_heads(t):
            # (B, T, C) -> (B, n_heads, T, d_head), the layout SDPA expects.
            return t.view(B, T, self.n_heads, self.d_head).transpose(1, 2)

        out = torch.nn.functional.scaled_dot_product_attention(
            split_heads(q), split_heads(k), split_heads(v),
            is_causal=True,
            # Dropout on attention weights only while training.
            dropout_p=self.attn_drop_p if self.training else 0.0,
        )
        # Merge heads back: (B, n_heads, T, d_head) -> (B, T, C).
        return self.proj(out.transpose(1, 2).contiguous().view(B, T, C))


class TransformerBlock(nn.Module):
    """Pre-LayerNorm Transformer block: causal self-attention, then a GELU MLP.

    Each sub-layer normalises its input first and adds its output back onto the
    residual stream.

    Args:
        d_model: width of the residual stream.
        n_heads: number of attention heads, forwarded to the attention layer.
        dropout: probability used both for attention dropout and for the dropout
            applied at the end of the MLP.
    """

    def __init__(self, d_model, n_heads, dropout=0.0):
        super().__init__()
        hidden = 4 * d_model  # feed-forward hidden width: 4x the model width
        # Keep this construction order: it fixes how parameter initialisation
        # consumes the RNG and the resulting state_dict key layout.
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply both residual sub-layers; the output has the same shape as ``x``."""
        attended = self.attn(self.ln1(x))
        x = x + attended
        mixed = self.mlp(self.ln2(x))
        return x + mixed


class TransformerModel(nn.Module):
    """Causal Transformer mapping a (B, T, d_input) sequence to (B, T, d_output).

    Processing steps:

    1. Project inputs linearly to ``d_model``.
    2. Add learned absolute position embeddings (one per index below ``max_len``).
    3. Run ``n_layers`` pre-LN blocks.
    4. Read out through a final LayerNorm and a linear head.

    Args:
        d_input: feature size of each input step.
        d_model: hidden width.
        d_output: feature size of each output step.
        n_layers: number of TransformerBlock layers.
        n_heads: attention heads per layer.
        max_len: number of rows in the position-embedding table, i.e. the
            longest sequence ``forward`` accepts.
        dropout: dropout probability forwarded to every block.
        **kwargs: accepted and ignored (presumably so a generic model factory
            can pass options meant for other models — confirm with callers).
    """

    def __init__(self, d_input, d_model, d_output, n_layers=2, n_heads=4, max_len=4096, dropout=0.0, **kwargs):
        super().__init__()
        # Construction order is significant for RNG-driven init and state_dict keys.
        self.input_proj = nn.Linear(d_input, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.blocks = nn.ModuleList(
            TransformerBlock(d_model, n_heads, dropout) for _ in range(n_layers)
        )
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, d_output)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """(B, T, d_input) → (B, T, d_output)

        Raises:
            ValueError: if T exceeds the size of the position-embedding table.
        """
        _, seq_len, _ = x.shape
        limit = self.pos_emb.num_embeddings
        if seq_len > limit:
            raise ValueError(f"Sequence length {seq_len} exceeds max_len {limit}")
        positions = torch.arange(seq_len, device=x.device)
        hidden = self.input_proj(x) + self.pos_emb(positions)
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.head(self.ln_f(hidden))

    @staticmethod
    def extra_kwargs(model_cfg) -> dict:
        """Return the constructor kwargs this model reads from ``model_cfg``.

        ``model_cfg`` is assumed to expose an ``n_heads`` attribute — TODO confirm
        the config type against callers.
        """
        n_heads = model_cfg.n_heads
        return dict(n_heads=n_heads)