File size: 11,517 Bytes
965057d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from typing import Optional
from config import ModelConfig


# RMSNorm
class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learned per-channel gain.

    The mean of squares over the last axis is computed in float32. The
    normalised tensor is cast back to the input dtype before the gain is
    applied.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # Learned gain, initialised to 1 (identity scaling).
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        normed = (x32 * inv_rms).to(x.dtype)
        return normed * self.weight


# RoPE
def precompute_rope_freqs(head_dim: int, max_seq_len: int, theta: float = 10000.0):
    """Build the rotary-embedding cosine and sine tables.

    Returns a ``(cos, sin)`` pair. Row ``p`` holds the cos/sin of
    ``p * theta ** (-i / head_dim)`` for every even channel index
    ``i = 0, 2, 4, ...`` below ``head_dim``.
    """
    even_channels = torch.arange(0, head_dim, 2).float()
    inv_freq = 1.0 / (theta ** (even_channels / head_dim))
    positions = torch.arange(max_seq_len)
    # Outer product: angle[p, j] = p * inv_freq[j]
    angles = positions[:, None] * inv_freq[None, :]
    return torch.cos(angles), torch.sin(angles)


def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate interleaved channel pairs of ``x`` by the rotary angles.

    ``x`` is indexed as (B, T, n_heads, head_dim). Channels (2i, 2i + 1) form
    one rotation pair. Only the first ``T`` rows of ``cos``/``sin`` are used.
    The rotation is computed in float32 and cast back to ``x.dtype``.
    """
    seq_len = x.shape[1]
    # (T, head_dim/2) -> (1, T, 1, head_dim/2) so it broadcasts over batch and heads
    c = cos[:seq_len][None, :, None, :]
    s = sin[:seq_len][None, :, None, :]

    xf = x.float()
    first, second = xf[..., 0::2], xf[..., 1::2]
    rotated = torch.stack((first * c - second * s, first * s + second * c), dim=-1)
    # Re-interleave the rotated pairs back into head_dim channels.
    return rotated.flatten(start_dim=-2).to(x.dtype)


# GQA
class GQAttention(nn.Module):
    """Grouped-query self-attention with rotary embeddings and an optional KV cache.

    ``n_heads`` query heads share ``n_kv_heads`` key/value heads. Every K/V head
    is repeated ``n_rep = n_heads // n_kv_heads`` times before attention. The
    per-head width is ``dim // n_heads``.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        assert cfg.n_heads % cfg.n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"
        self.n_heads    = cfg.n_heads
        self.n_kv_heads = cfg.n_kv_heads
        self.n_rep      = cfg.n_heads // cfg.n_kv_heads
        self.head_dim   = cfg.dim // cfg.n_heads

        self.wq = nn.Linear(cfg.dim, cfg.n_heads    * self.head_dim, bias=False)
        self.wk = nn.Linear(cfg.dim, cfg.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(cfg.dim, cfg.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(cfg.n_heads * self.head_dim, cfg.dim,    bias=False)
        self.dropout_p = cfg.dropout

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        cache_k: Optional[torch.Tensor] = None,
        cache_v: Optional[torch.Tensor] = None,
        return_cache: bool = False,
    ):
        """Attend over the cached keys/values plus the current tokens.

        Args:
            x: input of shape (B, T, dim).
            cos, sin: RoPE tables whose first T rows are applied to the current
                tokens. The caller is expected to slice them so that row 0 is
                the absolute position of x[:, 0].
            cache_k, cache_v: optional keys/values from earlier calls, shaped
                (B, T_past, n_kv_heads, head_dim) with RoPE already applied.
            return_cache: if True, also return the concatenated (past + current)
                keys/values. They are returned before KV-head expansion.

        Returns:
            (output of shape (B, T, dim), new_cache_k or None, new_cache_v or None)
        """
        B, T, _ = x.shape
        q = self.wq(x).view(B, T, self.n_heads,    self.head_dim)
        k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim)
        v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim)

        q = apply_rope(q, cos, sin)
        k = apply_rope(k, cos, sin)

        if cache_k is not None:
            k = torch.cat([cache_k, k], dim=1)
            v = torch.cat([cache_v, v], dim=1)
        new_cache_k, new_cache_v = (k, v) if return_cache else (None, None)
        S = k.shape[1]  # total key length: cached positions + T

        # Masking. Without a cache, plain causal masking is correct. With a
        # cache, the T queries come *after* the S - T cached keys:
        #  * disabling causality lets query i see later new tokens (only valid
        #    when T == 1);
        #  * SDPA's is_causal aligns its triangle to the top-left corner for
        #    non-square masks, which would hide most of the cache.
        # So for T > 1 we build a bottom-right-aligned boolean mask: query i
        # (absolute position S - T + i) may attend key j iff j <= S - T + i.
        attn_mask = None
        is_causal = cache_k is None
        if cache_k is not None and T > 1:
            attn_mask = torch.ones(T, S, dtype=torch.bool, device=x.device).tril(diagonal=S - T)

        # Expand KV heads → Q heads
        k = k.repeat_interleave(self.n_rep, dim=2)
        v = v.repeat_interleave(self.n_rep, dim=2)

        # (B, n_heads, T, head_dim) for SDPA
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=self.dropout_p if self.training else 0.0,
            is_causal=is_causal,
        )

        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.wo(out), new_cache_k, new_cache_v


# SwiGLU FFN
class SwiGLU(nn.Module):
    """Gated feed-forward network: ``dropout(w2(silu(w1 x) * w3 x))``.

    The hidden width is ``int(dim * ffn_dim_multiplier)`` rounded up to the
    next multiple of 256.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        width = int(cfg.dim * cfg.ffn_dim_multiplier)
        width = ((width + 255) // 256) * 256  # round up to a multiple of 256
        # Registration order (w1, w2, w3, dropout) is kept deliberately: it
        # fixes parameter order and the order modules are visited by .apply().
        self.w1 = nn.Linear(cfg.dim, width, bias=False)
        self.w2 = nn.Linear(width, cfg.dim, bias=False)
        self.w3 = nn.Linear(cfg.dim, width, bias=False)
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.dropout(self.w2(gate * up))


# Transformer Block
class TransformerBlock(nn.Module):
    """Pre-norm decoder block: ``x + attn(norm(x))`` followed by ``x + ffn(norm(x))``."""

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.attn_norm = RMSNorm(cfg.dim, cfg.norm_eps)
        self.attn      = GQAttention(cfg)
        self.ffn_norm  = RMSNorm(cfg.dim, cfg.norm_eps)
        self.ffn       = SwiGLU(cfg)

    def _forward(self, x, cos, sin, cache_k, cache_v, return_cache):
        # Shared implementation; returns (x, new_cache_k, new_cache_v).
        attn_out, nck, ncv = self.attn(
            self.attn_norm(x), cos, sin, cache_k, cache_v, return_cache=return_cache
        )
        x = x + attn_out
        x = x + self.ffn(self.ffn_norm(x))
        return x, nck, ncv

    def forward(self, x, cos, sin, cache_k=None, cache_v=None, use_grad_ckpt=False, return_cache=False):
        """Apply the block.

        Returns ``(x, new_cache_k, new_cache_v)``. The cache entries are None
        unless ``return_cache`` is True.

        Gradient checkpointing is only used in training mode and only when no
        KV cache is read or requested. The checkpointed path cannot carry
        cache tensors, so a cached call falls back to the regular path instead
        of silently dropping the cache.
        """
        uses_cache = cache_k is not None or return_cache
        if use_grad_ckpt and self.training and not uses_cache:
            # Recompute activations during backward instead of storing them.
            # The closure runs the block without any cache.
            def _ckpt_fn(x, cos, sin):
                out, _, _ = self._forward(x, cos, sin, None, None, False)
                return out
            x = checkpoint(_ckpt_fn, x, cos, sin, use_reentrant=False)
            return x, None, None
        return self._forward(x, cos, sin, cache_k, cache_v, return_cache)


# LLM Definition
class LLM(nn.Module):
    """Decoder-only language model.

    Token embedding, a stack of ``TransformerBlock``s, a final ``RMSNorm`` and
    a language-model head whose weight matrix is shared with the embedding.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.embed   = nn.Embedding(cfg.vocab_size, cfg.dim)
        self.layers  = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.norm    = RMSNorm(cfg.dim, cfg.norm_eps)
        self.lm_head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False)
        self.embed.weight = self.lm_head.weight  # weight tying

        # RoPE tables cover 2 * max_seq_len positions; forward() refuses to
        # index past them.
        head_dim = cfg.dim // cfg.n_heads
        cos, sin = precompute_rope_freqs(head_dim, cfg.max_seq_len * 2, cfg.rope_theta)
        self.register_buffer("rope_cos", cos)
        self.register_buffer("rope_sin", sin)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # NOTE(review): nn.Module.apply visits children in registration order,
        # so lm_head (a Linear) is initialised after embed. Because both share
        # one weight tensor, the tied matrix ends up with the scaled Linear std,
        # not the 0.02 Embedding std. Every Linear (not only the residual
        # output projections) also gets the 1/sqrt(2 * n_layers) scaling.
        # Confirm this is intended before changing it — it alters training.
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.cfg.n_layers))
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def forward(
        self,
        idx: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
        cache: Optional[list] = None,
        use_grad_ckpt: bool = False,
        return_cache: bool = False,
    ):
        """Run the model on a batch of token ids.

        Args:
            idx: token ids of shape (B, T).
            targets: optional token ids. When given, they are flattened and
                scored against the flattened logits with cross-entropy.
            cache: optional per-layer list of (k, v) tuples from a previous
                call. New tokens are positioned after the cached ones.
            use_grad_ckpt: forwarded to every block (checkpointing in training).
            return_cache: if True, return an updated per-layer cache.

        Returns:
            (logits of shape (B, T, vocab_size), loss or None, new cache or None)

        Raises:
            ValueError: if cached plus new positions exceed the RoPE table.
        """
        _, T = idx.shape
        has_past = bool(cache) and cache[0][0] is not None
        pos_start = cache[0][0].shape[1] if has_past else 0
        if pos_start + T > self.rope_cos.shape[0]:
            raise ValueError(
                f"positions {pos_start}..{pos_start + T - 1} exceed the RoPE table "
                f"of {self.rope_cos.shape[0]} positions"
            )

        x = self.embed(idx)
        cos = self.rope_cos[pos_start: pos_start + T]
        sin = self.rope_sin[pos_start: pos_start + T]

        need_cache = return_cache or (cache is not None)
        new_cache = [] if need_cache else None
        for i, layer in enumerate(self.layers):
            ck, cv = cache[i] if cache else (None, None)
            x, nck, ncv = layer(
                x,
                cos,
                sin,
                ck,
                cv,
                use_grad_ckpt=use_grad_ckpt,
                return_cache=need_cache,
            )
            if need_cache:
                new_cache.append((nck, ncv))

        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets are accepted too
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))

        return logits, loss, new_cache

    def param_count(self) -> int:
        # Tied weights are counted once: parameters() de-duplicates shared tensors.
        return sum(p.numel() for p in self.parameters())

    @torch.no_grad()
    def probe_attention_entropy(self, idx: torch.Tensor, max_probe_len: int = 256) -> float:
        """
        Estimate mean causal attention entropy from layer 0 on a short token window.
        Lower entropy means sharper/more concentrated attention.

        Uses the last ``max_probe_len`` tokens of ``idx``. Returns NaN for an
        empty sequence. Raises ValueError if ``idx`` is not 2-D or if
        ``max_probe_len < 1``.
        """
        if idx.ndim != 2:
            raise ValueError(f"idx must be shape (B, T), got {tuple(idx.shape)}")
        if idx.shape[1] == 0:
            return float("nan")
        max_len = int(max_probe_len)
        if max_len < 1:
            # Without this guard, idx[:, -0:] would silently select the whole sequence.
            raise ValueError(f"max_probe_len must be >= 1, got {max_probe_len}")

        probe_len = min(max_len, int(idx.shape[1]))
        idx = idx[:, -probe_len:]
        B, T = idx.shape

        x = self.embed(idx)
        cos = self.rope_cos[:T]
        sin = self.rope_sin[:T]
        layer0 = self.layers[0]
        attn = layer0.attn

        h = layer0.attn_norm(x)
        q = attn.wq(h).view(B, T, attn.n_heads, attn.head_dim)
        k = attn.wk(h).view(B, T, attn.n_kv_heads, attn.head_dim)
        q = apply_rope(q, cos, sin)
        k = apply_rope(k, cos, sin)

        k = k.repeat_interleave(attn.n_rep, dim=2)

        q = q.transpose(1, 2).float()  # (B, H, T, D)
        k = k.transpose(1, 2).float()  # (B, H, T, D)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(attn.head_dim)
        causal_mask = torch.triu(
            torch.ones((T, T), device=scores.device, dtype=torch.bool), diagonal=1
        )
        scores = scores.masked_fill(causal_mask, float("-inf"))

        probs = torch.softmax(scores, dim=-1)
        # clamp avoids log(0) = -inf for masked positions (0 * -inf = NaN)
        entropy = -(probs * torch.log(probs.clamp_min(1e-12))).sum(dim=-1)
        return float(entropy.mean().item())

    @torch.no_grad()
    def generate(
        self,
        idx: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 0.8,
        top_p: float = 0.9,
        repetition_penalty: float = 1.1,
        no_repeat_ngram_size: int = 3,
    ):
        """Autoregressively append ``max_new_tokens`` tokens to ``idx`` (B, T).

        Uses a KV cache between steps. The attended context is capped at
        ``cfg.max_seq_len`` tokens: once the cache holds that many positions,
        it is discarded and rebuilt from the most recent ``max_seq_len`` tokens.
        That matches the uncached path and keeps positions inside the RoPE
        table.

        ``temperature == 0`` selects greedily; otherwise nucleus (top-p)
        sampling is used. ``repetition_penalty > 1`` dampens logits of tokens
        already in ``idx``. ``no_repeat_ngram_size > 1`` bans tokens that would
        repeat an n-gram already in the sequence.
        """
        cache = None
        for _ in range(max_new_tokens):
            # Without this reset the cache grows without bound: attention
            # exceeds max_seq_len and positions eventually overrun rope_cos.
            if cache and cache[0][0] is not None and cache[0][0].shape[1] >= self.cfg.max_seq_len:
                cache = None
            idx_cond = idx[:, -self.cfg.max_seq_len:] if cache is None else idx[:, -1:]
            logits, _, cache = self(idx_cond, cache=cache, return_cache=True)
            logits = logits[:, -1, :]

            # Discourage copying previously generated tokens.
            if repetition_penalty > 1.0:
                for b in range(idx.size(0)):
                    used = idx[b].unique()
                    used_logits = logits[b, used]
                    logits[b, used] = torch.where(
                        used_logits > 0, used_logits / repetition_penalty, used_logits * repetition_penalty
                    )

            # Block tokens that would create repeated n-grams.
            if no_repeat_ngram_size and no_repeat_ngram_size > 1 and idx.size(1) >= no_repeat_ngram_size - 1:
                n = int(no_repeat_ngram_size)
                for b in range(idx.size(0)):
                    seq = idx[b].tolist()
                    prefix = tuple(seq[-(n - 1) :])
                    banned = set()
                    for i in range(len(seq) - n + 1):
                        if tuple(seq[i : i + n - 1]) == prefix:
                            banned.add(seq[i + n - 1])
                    if banned:
                        logits[b, list(banned)] = float("-inf")

            if temperature == 0.0:
                next_tok = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                probs = F.softmax(logits, dim=-1)
                sorted_probs, sorted_idx = torch.sort(probs, descending=True)
                cumsum = sorted_probs.cumsum(-1)
                # Drop a token once the mass strictly before it already exceeds top_p;
                # the most likely token is therefore always kept.
                sorted_probs[cumsum - sorted_probs > top_p] = 0.0
                sorted_probs /= sorted_probs.sum(-1, keepdim=True)
                next_tok = sorted_idx.gather(-1, torch.multinomial(sorted_probs, 1))

            idx = torch.cat([idx, next_tok], dim=1)
        return idx