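"""Functional text transformer: token embedding, attention blocks with rotary
position embeddings and KV caching, and a language-model head.

Weights live in plain ``nn.ModuleDict``/``nn.ModuleList`` containers built by
``build_text_model``; the forward functions below take those containers as
arguments instead of being methods on a module class.
"""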
import torch
import torch.nn as nn
from typing import TYPE_CHECKING
from torch.nn import functional as F

from .layers import layer_norm, mlp
from .config import TextConfig

# Imported only for static type checking (no runtime import of .rope).
if TYPE_CHECKING:
    from .rope import RotaryEmbedding


def text_encoder(input_ids: torch.Tensor, w: nn.Module):
    # Plain lookup into the wte token-embedding table.
    return F.embedding(input_ids, w.wte)

def attn(
    x: torch.Tensor,
    w: nn.Module,
    attn_mask: torch.Tensor,
    n_heads: int,
    rope: "RotaryEmbedding",
    kv_cache: nn.Module,
    pos_ids: torch.Tensor,
):
    bsz, q_len, d_model = x.shape
    head_dim = d_model // n_heads

    # 1. Fused QKV projection: (bsz, q_len, 3 * n_heads * head_dim)
    qkv_out = w.qkv(x)
    qkv_reshaped = qkv_out.view(bsz, q_len, 3, n_heads, head_dim)

    # 2. Permute to bring heads before sequence length and QKV to the front
    # Current: (bsz, q_len, 3, n_heads, head_dim) -> (0, 1, 2, 3, 4)
    # Target:  (3, bsz, n_heads, q_len, head_dim) -> (2, 0, 3, 1, 4)
    qkv_permuted = qkv_reshaped.permute(2, 0, 3, 1, 4)
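    # Worked example with illustrative sizes bsz=1, q_len=4, n_heads=8,
    # head_dim=64: (1, 4, 3, 8, 64) -> (3, 1, 8, 4, 64),
    # so q, k, v below are each (1, 8, 4, 64).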

    # 3. Split along the leading dimension, which now separates Q, K, and V
    q, k, v = qkv_permuted[0], qkv_permuted[1], qkv_permuted[2]

    # Rotate queries and keys by position; values are left untouched.
    q = rope.apply(q, pos_ids)
    k = rope.apply(k, pos_ids)

    # Update the cache at these positions; the returned k/v span the cached history.
    k, v = kv_cache.update(pos_ids, k, v)

    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    # Merge heads: (bsz, n_heads, q_len, head_dim) -> (bsz, q_len, d_model)
    out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
    out = w.proj(out)
    return out

def text_decoder(
    x: torch.Tensor,
    w: nn.Module,
    attn_mask: torch.Tensor,
    config: TextConfig,
    rope: "RotaryEmbedding",
    pos_ids: torch.Tensor,
):
    for block in w.blocks:
        l_in = layer_norm(x, block.ln)
        l_attn = attn(
            l_in,
            block.attn,
            attn_mask=attn_mask,
            n_heads=config.n_heads,
            rope=rope,
            kv_cache=block.kv_cache,
            pos_ids=pos_ids,
        )
        l_mlp = mlp(l_in, block.mlp)
        # Parallel residual: attention and MLP both read the same pre-norm input.
        x = x + l_attn + l_mlp

    return x


def lm_head(hidden_BTC: torch.Tensor, w: nn.Module):
    # Only the final position is needed for next-token prediction
    # (suffixes name the axes: B=batch, T=time, C=channels).
    hidden_BC = hidden_BTC[:, -1, :]
    hidden_BC = layer_norm(hidden_BC, w.post_ln)
    logits = w.lm_head(hidden_BC)
    return logits


def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
    qkv_dim = int(config.dim * 3)  # one fused projection covering Q, K, and V

    text = nn.ModuleDict(
        {
            "blocks": nn.ModuleList(
                [
                    nn.ModuleDict(
                        {
                            "ln": nn.LayerNorm(config.dim, dtype=dtype),
                            "attn": nn.ModuleDict(
                                {
                                    "qkv": nn.Linear(config.dim, qkv_dim, dtype=dtype),
                                    "proj": nn.Linear(
                                        config.dim, config.dim, dtype=dtype
                                    ),
                                }
                            ),
                            "mlp": nn.ModuleDict(
                                {
                                    "fc1": nn.Linear(
                                        config.dim, config.ff_dim, dtype=dtype
                                    ),
                                    "fc2": nn.Linear(
                                        config.ff_dim, config.dim, dtype=dtype
                                    ),
                                }
                            ),
                        }
                    )
                    for _ in range(config.n_layers)
                ]
            ),
            "post_ln": nn.LayerNorm(config.dim, dtype=dtype),
            "lm_head": nn.Linear(config.dim, config.vocab_size, dtype=dtype),
        }
    )
    # Token embedding table; text_encoder looks rows up via F.embedding.
    text.wte = nn.Parameter(torch.empty(config.vocab_size, config.dim, dtype=dtype))
    return text
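

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the model). attn() expects a
# per-block kv_cache that build_text_model does not create, and the real
# RotaryEmbedding lives in .rope, so the stubs below stand in for both just to
# exercise a forward pass. TextConfig's constructor is an assumption here; it
# is only known to expose dim, ff_dim, n_heads, n_layers, and vocab_size.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _IdentityRope:
        """Stand-in for RotaryEmbedding; a real one rotates q/k by position."""

        def apply(self, x: torch.Tensor, pos_ids: torch.Tensor) -> torch.Tensor:
            return x

    class _PassThroughKVCache(nn.Module):
        """Stand-in cache; a real one stores k/v and returns the full history."""

        def update(self, pos_ids, k, v):
            return k, v

    # Hypothetical sizes, assuming TextConfig is constructible this way.
    config = TextConfig(dim=256, ff_dim=1024, n_heads=8, n_layers=2, vocab_size=1000)
    text = build_text_model(config, dtype=torch.float32)
    nn.init.normal_(text.wte, std=0.02)  # wte is allocated with torch.empty
    for block in text.blocks:
        block.kv_cache = _PassThroughKVCache()  # callers must attach caches

    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    pos_ids = torch.arange(8)
    x = text_encoder(input_ids, text)
    x = text_decoder(
        x, text, attn_mask=None, config=config, rope=_IdentityRope(), pos_ids=pos_ids
    )
    print(lm_head(x, text).shape)  # torch.Size([1, 1000])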