File size: 13,599 Bytes
03b7838
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
"""Современный decoder-only трансформер для обучения кодинг-модели с нуля.

Компоненты (всё — проверенная практика для код-моделей):
  - RoPE (rotary position embeddings): позволяет расширять контекст за пределы
    обученной длины; нет обучаемых позиционных эмбеддингов.
  - RMSNorm: дешевле и стабильнее LayerNorm.
  - SwiGLU MLP: лучше GELU при том же бюджете параметров.
  - Flash attention через F.scaled_dot_product_attention: память O(N) на практике,
    causal-маска бесплатно.
  - Gradient checkpointing (опц.): торгуем счёт за память -> длинный контекст
    на одной карте.
  - Tied embeddings (вход = выход): экономит параметры, обычно не вредит.

Конфиг масштабируется от ~120M до ~1B; дефолт ~0.35B комфортно влезает в 96GB
с длинным контекстом и grad checkpointing.
"""

from dataclasses import dataclass
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class ModelConfig:
    vocab_size: int = 49152          # StarCoder2 BPE
    d_model: int = 1024
    n_layers: int = 24
    n_heads: int = 16
    n_kv_heads: int = 4              # GQA: меньше KV-голов -> дешевле память/кэш
    block_size: int = 4096          # тренируемый контекст
    mlp_ratio: float = 8 / 3        # SwiGLU -> hidden ~ 8/3 * d_model, кратно 256
    rope_theta: float = 100_000.0   # большая база -> легче расширять контекст
    dropout: float = 0.0
    grad_checkpoint: bool = True
    # выбор смесителя последовательности:
    #   "attn"   — обычное внимание во всех слоях (O(N^2), точный recall);
    #   "gla"    — линейное внимание fla во всех слоях (O(N), но без точного recall);
    #   "hybrid" — GLA везде + attention каждый attn_every-й слой (O(N) + recall).
    mixer: str = "attn"
    attn_every: int = 4             # для hybrid: каждый attn_every-й слой = attention
    gla_chunk: int = 64             # размер чанка для fla chunk_gla

    @property
    def head_dim(self):
        return self.d_model // self.n_heads


class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        dt = x.dtype
        x = x.float()
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return (x * self.weight.float()).to(dt)


def build_rope_cache(seq_len, head_dim, theta, device, dtype):
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
    t = torch.arange(seq_len, device=device).float()
    freqs = torch.outer(t, inv_freq)                     # (T, head_dim/2)
    cos = freqs.cos().to(dtype)
    sin = freqs.sin().to(dtype)
    return cos, sin


def apply_rope(x, cos, sin):
    # x: (B, H, T, D). Поворачиваем пары (x1, x2).
    T = x.shape[-2]
    cos, sin = cos[:T], sin[:T]
    x1, x2 = x[..., 0::2], x[..., 1::2]
    cos = cos[None, None]; sin = sin[None, None]
    rx1 = x1 * cos - x2 * sin
    rx2 = x1 * sin + x2 * cos
    out = torch.empty_like(x)
    out[..., 0::2] = rx1
    out[..., 1::2] = rx2
    return out


class Attention(nn.Module):
    """Causal multi-head attention с GQA и RoPE, flash через SDPA."""

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.n_heads = cfg.n_heads
        self.n_kv = cfg.n_kv_heads
        self.hd = cfg.head_dim
        assert cfg.n_heads % cfg.n_kv_heads == 0, "n_heads должно делиться на n_kv_heads"
        self.q_proj = nn.Linear(cfg.d_model, cfg.n_heads * self.hd, bias=False)
        self.k_proj = nn.Linear(cfg.d_model, self.n_kv * self.hd, bias=False)
        self.v_proj = nn.Linear(cfg.d_model, self.n_kv * self.hd, bias=False)
        self.o_proj = nn.Linear(cfg.n_heads * self.hd, cfg.d_model, bias=False)
        self.dropout = cfg.dropout

    def forward(self, x, cos, sin):
        B, T, _ = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.hd).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv, self.hd).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv, self.hd).transpose(1, 2)
        q = apply_rope(q, cos, sin)
        k = apply_rope(k, cos, sin)
        if self.n_kv != self.n_heads:                    # GQA: расширяем KV-головы
            rep = self.n_heads // self.n_kv
            k = k.repeat_interleave(rep, dim=1)
            v = v.repeat_interleave(rep, dim=1)
        y = F.scaled_dot_product_attention(
            q, k, v, is_causal=True,
            dropout_p=self.dropout if self.training else 0.0)
        y = y.transpose(1, 2).contiguous().view(B, T, -1)
        return self.o_proj(y)


# fla (flash-linear-attention): рабочее fused Triton-ядро GLA (fwd+bwd).
# Проверено на RTX PRO 6000: 4x быстрее flash-attn на 32k, обучается (recall грокнул).
# Импорт защищён: если fla нет (нет triton/Blackwell), GLAMixer недоступен и train
# должен откатиться на attention (см. _make_mixer).
try:
    from fla.ops.gla import chunk_gla as _fla_chunk_gla
    _HAS_FLA = True
except Exception:
    _fla_chunk_gla = None
    _HAS_FLA = False


class GLAMixer(nn.Module):
    """Gated Linear Attention через fla. O(N) по контексту, без RoPE
    (затухание само кодирует позицию). Обучаемый ВЕКТОРНЫЙ гейт затухания
    g = logsigmoid(W_g x) — каноническая форма GLA (мощнее скалярного gamma).
    Раскладка для fla 0.5.0: (B, T, H, K), без kwargs (откалибровано отдельно).
    GQA: KV-головы расширяются до n_heads (fla ждёт одинаковое число голов)."""

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        assert _HAS_FLA, "GLAMixer требует flash-linear-attention (pip install)"
        self.n_heads = cfg.n_heads
        self.n_kv = cfg.n_kv_heads
        self.hd = cfg.head_dim
        self.chunk = cfg.gla_chunk
        self.q_proj = nn.Linear(cfg.d_model, cfg.n_heads * self.hd, bias=False)
        self.k_proj = nn.Linear(cfg.d_model, self.n_kv * self.hd, bias=False)
        self.v_proj = nn.Linear(cfg.d_model, self.n_kv * self.hd, bias=False)
        # гейт затухания на каждый канал q-голов (в лог-пространстве через logsigmoid)
        self.g_proj = nn.Linear(cfg.d_model, cfg.n_heads * self.hd, bias=False)
        self.o_proj = nn.Linear(cfg.n_heads * self.hd, cfg.d_model, bias=False)
        # выходной гейт (как в GLA): сигмоида, стабилизирует амплитуду
        self.out_gate = nn.Linear(cfg.d_model, cfg.n_heads * self.hd, bias=False)

    def forward(self, x, cos=None, sin=None):     # cos/sin игнорируем: GLA без RoPE
        B, T, _ = x.shape
        H, KV, Dh = self.n_heads, self.n_kv, self.hd
        # fla ждёт раскладку (B, T, H, Dh)
        q = self.q_proj(x).view(B, T, H, Dh)
        k = self.k_proj(x).view(B, T, KV, Dh)
        v = self.v_proj(x).view(B, T, KV, Dh)
        if KV != H:                               # GQA -> расширяем KV до H голов
            rep = H // KV
            k = k.repeat_interleave(rep, dim=2)
            v = v.repeat_interleave(rep, dim=2)
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)
        # лог-гейт затухания в (-inf, 0): logsigmoid -> устойчиво, gamma=exp(g) in (0,1)
        g = F.logsigmoid(self.g_proj(x).view(B, T, H, Dh).float())
        # ЕДИНЫЙ dtype для fla: под autocast F.normalize даёт fp32, а v_proj — bf16;
        # fla-ядро падает на смешении типов в tl.dot. Приводим всё к dtype входа.
        dt = x.dtype
        q, k, v, g = q.to(dt), k.to(dt), v.to(dt), g.to(dt)
        out = _fla_chunk_gla(q, k, v, g)          # (B, T, H, Dh), layout bthd
        o = out[0] if isinstance(out, (tuple, list)) else out
        o = o.reshape(B, T, H * Dh) * torch.sigmoid(self.out_gate(x))
        return self.o_proj(o)


class SwiGLU(nn.Module):
    def __init__(self, cfg: ModelConfig):
        super().__init__()
        hidden = int(cfg.mlp_ratio * cfg.d_model)
        hidden = 256 * ((hidden + 255) // 256)           # кратно 256 для тензорных ядер
        self.gate = nn.Linear(cfg.d_model, hidden, bias=False)
        self.up = nn.Linear(cfg.d_model, hidden, bias=False)
        self.down = nn.Linear(hidden, cfg.d_model, bias=False)

    def forward(self, x):
        return self.down(F.silu(self.gate(x)) * self.up(x))


def _layer_is_attn(cfg: ModelConfig, layer_idx: int) -> bool:
    """Какой смеситель в слое layer_idx. hybrid: attention каждый attn_every-й слой
    (на индексах attn_every-1, 2*attn_every-1, ...), остальное — GLA."""
    if cfg.mixer == "attn":
        return True
    if cfg.mixer == "gla":
        return False
    # hybrid
    return (layer_idx + 1) % cfg.attn_every == 0


class Block(nn.Module):
    def __init__(self, cfg: ModelConfig, layer_idx: int = 0):
        super().__init__()
        self.is_attn = _layer_is_attn(cfg, layer_idx)
        self.attn_norm = RMSNorm(cfg.d_model)
        self.mixer = Attention(cfg) if self.is_attn else GLAMixer(cfg)
        self.mlp_norm = RMSNorm(cfg.d_model)
        self.mlp = SwiGLU(cfg)

    def forward(self, x, cos, sin):
        # GLA-слой игнорирует cos/sin (нет RoPE); attention использует.
        x = x + self.mixer(self.attn_norm(x), cos, sin)
        x = x + self.mlp(self.mlp_norm(x))
        return x


class CodeLM(nn.Module):
    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList([Block(cfg, i) for i in range(cfg.n_layers)])
        self.norm_f = RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        self.lm_head.weight = self.tok_emb.weight        # tied
        self._rope = None
        self.apply(self._init)
        # масштабирование инициализации остаточных проекций по глубине (GPT-2 трюк)
        for n, p in self.named_parameters():
            if n.endswith("o_proj.weight") or n.endswith("down.weight"):
                nn.init.normal_(p, std=0.02 / math.sqrt(2 * cfg.n_layers))

    def _init(self, m):
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=0.02)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, std=0.02)

    def _rope_cache(self, T, device, dtype):
        if self._rope is None or self._rope[0].shape[0] < T or self._rope[0].device != device:
            self._rope = build_rope_cache(max(T, self.cfg.block_size),
                                          self.cfg.head_dim, self.cfg.rope_theta,
                                          device, dtype)
        return self._rope

    def forward(self, idx, targets=None):
        B, T = idx.shape
        x = self.drop(self.tok_emb(idx))
        cos, sin = self._rope_cache(T, idx.device, x.dtype)
        for blk in self.blocks:
            if self.cfg.grad_checkpoint and self.training:
                x = torch.utils.checkpoint.checkpoint(blk, x, cos, sin, use_reentrant=False)
            else:
                x = blk(x, cos, sin)
        x = self.norm_f(x)
        if targets is None:                              # инференс: только последний шаг
            logits = self.lm_head(x[:, -1:])
            return logits, None
        logits = self.lm_head(x)
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                               targets.reshape(-1), ignore_index=-100)
        return logits, loss

    def hidden(self, idx):
        """Состояние ПЕРЕД lm_head (B,T,d). Нужно для MTP-aux голов, которые
        предсказывают токены на горизонте 2..K из того же h."""
        B, T = idx.shape
        x = self.drop(self.tok_emb(idx))
        cos, sin = self._rope_cache(T, idx.device, x.dtype)
        for blk in self.blocks:
            if self.cfg.grad_checkpoint and self.training:
                x = torch.utils.checkpoint.checkpoint(blk, x, cos, sin, use_reentrant=False)
            else:
                x = blk(x, cos, sin)
        return self.norm_f(x)

    def num_params(self, non_embed=True):
        n = sum(p.numel() for p in self.parameters())
        if non_embed:
            n -= self.tok_emb.weight.numel()             # tied -> один раз
        return n