| """Современный decoder-only трансформер для обучения кодинг-модели с нуля. |
| |
| Компоненты (всё — проверенная практика для код-моделей): |
| - RoPE (rotary position embeddings): позволяет расширять контекст за пределы |
| обученной длины; нет обучаемых позиционных эмбеддингов. |
| - RMSNorm: дешевле и стабильнее LayerNorm. |
| - SwiGLU MLP: лучше GELU при том же бюджете параметров. |
| - Flash attention через F.scaled_dot_product_attention: память O(N) на практике, |
| causal-маска бесплатно. |
| - Gradient checkpointing (опц.): торгуем счёт за память -> длинный контекст |
| на одной карте. |
| - Tied embeddings (вход = выход): экономит параметры, обычно не вредит. |
| |
| Конфиг масштабируется от ~120M до ~1B; дефолт ~0.35B комфортно влезает в 96GB |
| с длинным контекстом и grad checkpointing. |
| """ |
|
|
| from dataclasses import dataclass |
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
@dataclass
class ModelConfig:
    """Hyper-parameters for CodeLM (field order is part of the constructor API)."""

    vocab_size: int = 49152        # tokenizer vocabulary size
    d_model: int = 1024            # residual-stream width
    n_layers: int = 24             # number of Blocks
    n_heads: int = 16              # query heads
    n_kv_heads: int = 4            # key/value heads (GQA); expected to divide n_heads
    block_size: int = 4096         # training context length; min length of the RoPE cache
    mlp_ratio: float = 8 / 3       # SwiGLU hidden width as a multiple of d_model
    rope_theta: float = 100_000.0  # RoPE frequency base
    dropout: float = 0.0           # embedding / attention dropout probability
    grad_checkpoint: bool = True   # recompute blocks in backward while training
    mixer: str = "attn"            # token mixer: "attn", "gla" or "hybrid"
    attn_every: int = 4            # hybrid: one attention layer per this many layers
    gla_chunk: int = 64            # GLA chunk length (stored on GLAMixer)

    @property
    def head_dim(self):
        """Width of one head: d_model split across the query heads (floor division)."""
        per_head, _remainder = divmod(self.d_model, self.n_heads)
        return per_head
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dim with a learnable gain.

    The statistic is computed in float32; the result is cast back to the
    input dtype.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        in_dtype = x.dtype
        x32 = x.float()
        mean_sq = x32.pow(2).mean(dim=-1, keepdim=True)
        normed = x32 * (mean_sq + self.eps).rsqrt()
        scaled = normed * self.weight.float()
        return scaled.to(in_dtype)
|
|
|
|
def build_rope_cache(seq_len, head_dim, theta, device, dtype):
    """Precompute RoPE rotation tables.

    Returns (cos, sin), each of shape (seq_len, head_dim // 2) (for even
    head_dim), computed in float32 and then cast to ``dtype``. Column i uses
    frequency theta ** (-2i / head_dim); row p is position p.
    """
    even_idx = torch.arange(0, head_dim, 2, device=device).float()
    inv_freq = 1.0 / (theta ** (even_idx / head_dim))
    positions = torch.arange(seq_len, device=device).float()
    angles = positions[:, None] * inv_freq[None, :]
    return angles.cos().to(dtype), angles.sin().to(dtype)
|
|
|
|
def apply_rope(x, cos, sin):
    """Rotate interleaved channel pairs (0,1), (2,3), ... of x by RoPE angles.

    Only the first T = x.shape[-2] rows of cos/sin are used, and they are
    broadcast over two leading dims (x is presumably (B, H, T, head_dim)).
    The result has x's shape and dtype.
    """
    seq = x.shape[-2]
    c = cos[:seq].unsqueeze(0).unsqueeze(0)
    s = sin[:seq].unsqueeze(0).unsqueeze(0)
    even, odd = x[..., 0::2], x[..., 1::2]
    # Stack the rotated halves on a new last axis and flatten it so that
    # output channel 2i is the even result and 2i+1 the odd one.
    rotated = torch.stack((even * c - odd * s, even * s + odd * c), dim=-1)
    return rotated.reshape(x.shape).to(x.dtype)
|
|
|
|
class Attention(nn.Module):
    """Causal multi-head self-attention with GQA and RoPE, computed via SDPA.

    Queries use n_heads heads; keys/values use n_kv_heads heads which are
    repeated up to n_heads before F.scaled_dot_product_attention.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.n_heads = cfg.n_heads
        self.n_kv = cfg.n_kv_heads
        self.hd = cfg.head_dim
        assert cfg.n_heads % cfg.n_kv_heads == 0, "n_heads должно делиться на n_kv_heads"
        q_width = cfg.n_heads * self.hd
        kv_width = self.n_kv * self.hd
        self.q_proj = nn.Linear(cfg.d_model, q_width, bias=False)
        self.k_proj = nn.Linear(cfg.d_model, kv_width, bias=False)
        self.v_proj = nn.Linear(cfg.d_model, kv_width, bias=False)
        self.o_proj = nn.Linear(q_width, cfg.d_model, bias=False)
        self.dropout = cfg.dropout

    def forward(self, x, cos, sin):
        bsz, seqlen, _ = x.shape

        def split_heads(proj, n_head):
            # (B, T, n_head * hd) -> (B, n_head, T, hd)
            return proj(x).view(bsz, seqlen, n_head, self.hd).transpose(1, 2)

        q = split_heads(self.q_proj, self.n_heads)
        k = split_heads(self.k_proj, self.n_kv)
        v = split_heads(self.v_proj, self.n_kv)
        q, k = apply_rope(q, cos, sin), apply_rope(k, cos, sin)
        if self.n_kv != self.n_heads:
            # GQA: each KV head serves a contiguous group of query heads.
            group = self.n_heads // self.n_kv
            k, v = (t.repeat_interleave(group, dim=1) for t in (k, v))
        drop_p = self.dropout if self.training else 0.0
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True, dropout_p=drop_p)
        merged = attn.transpose(1, 2).reshape(bsz, seqlen, -1)
        return self.o_proj(merged)
|
|
|
|
| |
| |
| |
| |
# Optional dependency: flash-linear-attention (fla) provides the GLA kernel.
# Any failure is caught (not just ImportError) because importing fla can also
# fail inside its Triton/CUDA setup; GLA is then simply unavailable.
try:
    from fla.ops.gla import chunk_gla as _fla_chunk_gla
except Exception:
    _fla_chunk_gla = None
_HAS_FLA = _fla_chunk_gla is not None
|
|
|
|
class GLAMixer(nn.Module):
    """Gated Linear Attention token mixer backed by fla (flash-linear-attention).

    Linear in context length and uses no RoPE: the learned decay itself
    encodes position. The decay gate is a learned per-channel VECTOR,
    g = logsigmoid(W_g x) -- the canonical GLA form (more expressive than a
    scalar gamma). Layout targets fla 0.5.0: (B, T, H, K), and chunk_gla is
    called without kwargs (calibrated separately).
    GQA: KV heads are repeated up to n_heads because fla expects the same
    number of heads for q, k, v and g.

    Raises:
        ImportError: fla could not be imported.
        ValueError: n_heads is not divisible by n_kv_heads.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        # Explicit raises instead of `assert`, so the checks survive `python -O`.
        if not _HAS_FLA:
            raise ImportError("GLAMixer требует flash-linear-attention (pip install)")
        if cfg.n_heads % cfg.n_kv_heads != 0:
            # Same invariant Attention enforces; without it the repeated KV
            # heads would not match the query head count.
            raise ValueError("n_heads должно делиться на n_kv_heads")
        self.n_heads = cfg.n_heads
        self.n_kv = cfg.n_kv_heads
        self.hd = cfg.head_dim
        # Stored for reference; chunk_gla is currently called without a chunk kwarg.
        self.chunk = cfg.gla_chunk
        self.q_proj = nn.Linear(cfg.d_model, cfg.n_heads * self.hd, bias=False)
        self.k_proj = nn.Linear(cfg.d_model, self.n_kv * self.hd, bias=False)
        self.v_proj = nn.Linear(cfg.d_model, self.n_kv * self.hd, bias=False)
        # Per-channel decay logits (one value per head and head channel).
        # NOTE(review): when g_proj outputs are near 0, logsigmoid gives about
        # -0.69, i.e. a per-step decay of about 0.5 -- confirm that this
        # short initial memory is intended.
        self.g_proj = nn.Linear(cfg.d_model, cfg.n_heads * self.hd, bias=False)
        self.o_proj = nn.Linear(cfg.n_heads * self.hd, cfg.d_model, bias=False)
        # Sigmoid output gate applied to the mixed heads before o_proj.
        self.out_gate = nn.Linear(cfg.d_model, cfg.n_heads * self.hd, bias=False)

    def forward(self, x, cos=None, sin=None):
        """Mix tokens of x (B, T, d_model); cos/sin are accepted for interface parity and ignored."""
        B, T, _ = x.shape
        H, KV, Dh = self.n_heads, self.n_kv, self.hd

        q = self.q_proj(x).view(B, T, H, Dh)
        k = self.k_proj(x).view(B, T, KV, Dh)
        v = self.v_proj(x).view(B, T, KV, Dh)
        if KV != H:
            rep = H // KV
            k = k.repeat_interleave(rep, dim=2)
            v = v.repeat_interleave(rep, dim=2)
        # L2-normalize q/k per head.
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)

        # Log-space decay gate, computed in float32.
        g = F.logsigmoid(self.g_proj(x).view(B, T, H, Dh).float())

        # Hand the kernel uniform dtypes (the residual-stream dtype of x);
        # presumably the projections may come out in a different precision
        # under autocast -- TODO confirm against the training setup.
        dt = x.dtype
        q, k, v, g = q.to(dt), k.to(dt), v.to(dt), g.to(dt)
        out = _fla_chunk_gla(q, k, v, g)
        # chunk_gla may return (output, final_state); keep only the output.
        o = out[0] if isinstance(out, (tuple, list)) else out
        o = o.reshape(B, T, H * Dh) * torch.sigmoid(self.out_gate(x))
        return self.o_proj(o)
|
|
|
|
class SwiGLU(nn.Module):
    """SwiGLU feed-forward: down(silu(gate(x)) * up(x)).

    The hidden width is int(mlp_ratio * d_model) rounded UP to a multiple of 256.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        raw_width = int(cfg.mlp_ratio * cfg.d_model)
        hidden = -(-raw_width // 256) * 256  # ceil-division, then back to a multiple of 256
        self.gate = nn.Linear(cfg.d_model, hidden, bias=False)
        self.up = nn.Linear(cfg.d_model, hidden, bias=False)
        self.down = nn.Linear(hidden, cfg.d_model, bias=False)

    def forward(self, x):
        activated = F.silu(self.gate(x))
        return self.down(activated * self.up(x))
|
|
|
|
def _layer_is_attn(cfg: "ModelConfig", layer_idx: int) -> bool:
    """Return True if layer ``layer_idx`` uses softmax attention, False for GLA.

    Modes selected by ``cfg.mixer``:
      - "attn":   every layer is attention.
      - "gla":    every layer is GLA.
      - "hybrid": attention on every attn_every-th layer (indices
                  attn_every-1, 2*attn_every-1, ...), GLA everywhere else.

    Raises:
        ValueError: ``cfg.mixer`` is not one of the modes above (previously a
            typo silently selected hybrid), or ``cfg.attn_every < 1`` in
            hybrid mode (previously a bare ZeroDivisionError for 0).
    """
    mixer = cfg.mixer
    if mixer == "attn":
        return True
    if mixer == "gla":
        return False
    if mixer != "hybrid":
        raise ValueError(f"unknown mixer {mixer!r}; expected 'attn', 'gla' or 'hybrid'")
    if cfg.attn_every < 1:
        raise ValueError(f"attn_every must be >= 1 for the hybrid mixer, got {cfg.attn_every}")
    return (layer_idx + 1) % cfg.attn_every == 0
|
|
|
|
class Block(nn.Module):
    """Pre-norm residual block: token mixer (attention or GLA), then SwiGLU MLP.

    Which mixer a layer gets is decided by _layer_is_attn(cfg, layer_idx).
    """

    def __init__(self, cfg: ModelConfig, layer_idx: int = 0):
        super().__init__()
        self.is_attn = _layer_is_attn(cfg, layer_idx)
        self.attn_norm = RMSNorm(cfg.d_model)
        mixer_cls = Attention if self.is_attn else GLAMixer
        self.mixer = mixer_cls(cfg)
        self.mlp_norm = RMSNorm(cfg.d_model)
        self.mlp = SwiGLU(cfg)

    def forward(self, x, cos, sin):
        # cos/sin are always forwarded; GLAMixer accepts and ignores them.
        mixed = self.mixer(self.attn_norm(x), cos, sin)
        x = x + mixed
        return x + self.mlp(self.mlp_norm(x))
|
|
|
|
class CodeLM(nn.Module):
    """Decoder-only LM: token embedding -> stack of Blocks -> RMSNorm -> lm_head.

    The output head shares its weight with the token embedding (tied
    embeddings). RoPE cos/sin tables are cached lazily in a plain attribute
    (not a registered buffer) by ``_rope_cache``.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList([Block(cfg, i) for i in range(cfg.n_layers)])
        self.norm_f = RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        # Tie output projection to the input embedding (one shared Parameter).
        self.lm_head.weight = self.tok_emb.weight
        # Lazily built (cos, sin); model.to() does not move it, so
        # _rope_cache rebuilds on device/dtype mismatch.
        self._rope = None
        self.apply(self._init)
        # Scaled init for projections that write into the residual stream
        # (mixer o_proj and MLP down), as in GPT-2/nanoGPT: std shrinks with
        # depth, the usual rationale being to keep residual variance bounded.
        residual_std = 0.02 / math.sqrt(2 * cfg.n_layers)
        for name, p in self.named_parameters():
            if name.endswith(("o_proj.weight", "down.weight")):
                nn.init.normal_(p, std=residual_std)

    def _init(self, m):
        """Normal(0, 0.02) init for every Linear and Embedding weight."""
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, std=0.02)

    def _rope_cache(self, T, device, dtype):
        """Return (cos, sin) covering at least T positions on ``device`` in ``dtype``.

        Rebuilt when missing, too short, on another device, or in another
        dtype (e.g. the model was cast fp32 <-> bf16 after the first call);
        new tables cover max(T, cfg.block_size) positions.
        """
        cache = self._rope
        if (
            cache is None
            or cache[0].shape[0] < T
            or cache[0].device != device
            or cache[0].dtype != dtype
        ):
            self._rope = build_rope_cache(
                max(T, self.cfg.block_size),
                self.cfg.head_dim,
                self.cfg.rope_theta,
                device,
                dtype,
            )
        return self._rope

    def _trunk(self, idx):
        """Embed idx (B, T) and run all Blocks; returns the hidden state before norm_f."""
        # Import the submodule explicitly: `import torch` alone is not
        # guaranteed to expose torch.utils.checkpoint.
        from torch.utils.checkpoint import checkpoint

        _, T = idx.shape
        x = self.drop(self.tok_emb(idx))
        cos, sin = self._rope_cache(T, idx.device, x.dtype)
        # Trade recompute for activation memory, only while training.
        use_ckpt = self.cfg.grad_checkpoint and self.training
        for blk in self.blocks:
            if use_ckpt:
                x = checkpoint(blk, x, cos, sin, use_reentrant=False)
            else:
                x = blk(x, cos, sin)
        return x

    def forward(self, idx, targets=None):
        """Return (logits, loss).

        Without targets only the last position is projected: logits have
        shape (B, 1, vocab_size) and loss is None. With targets, logits are
        (B, T, vocab_size) and loss is the mean cross-entropy over positions
        whose target is not -100.
        """
        x = self.norm_f(self._trunk(idx))
        if targets is None:
            logits = self.lm_head(x[:, -1:])
            return logits, None
        logits = self.lm_head(x)
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                               targets.reshape(-1), ignore_index=-100)
        return logits, loss

    def hidden(self, idx):
        """State BEFORE lm_head, shape (B, T, d_model) (after norm_f).

        Used by MTP auxiliary heads that predict tokens at horizons 2..K
        from the same hidden state.
        """
        return self.norm_f(self._trunk(idx))

    def num_params(self, non_embed=True):
        """Count unique parameters; optionally exclude the (tied) embedding matrix."""
        n = sum(p.numel() for p in self.parameters())
        if non_embed:
            n -= self.tok_emb.weight.numel()
        return n
|
|