"""Современный decoder-only трансформер для обучения кодинг-модели с нуля. Компоненты (всё — проверенная практика для код-моделей): - RoPE (rotary position embeddings): позволяет расширять контекст за пределы обученной длины; нет обучаемых позиционных эмбеддингов. - RMSNorm: дешевле и стабильнее LayerNorm. - SwiGLU MLP: лучше GELU при том же бюджете параметров. - Flash attention через F.scaled_dot_product_attention: память O(N) на практике, causal-маска бесплатно. - Gradient checkpointing (опц.): торгуем счёт за память -> длинный контекст на одной карте. - Tied embeddings (вход = выход): экономит параметры, обычно не вредит. Конфиг масштабируется от ~120M до ~1B; дефолт ~0.35B комфортно влезает в 96GB с длинным контекстом и grad checkpointing. """ from dataclasses import dataclass import math import torch import torch.nn as nn import torch.nn.functional as F @dataclass class ModelConfig: vocab_size: int = 49152 # StarCoder2 BPE d_model: int = 1024 n_layers: int = 24 n_heads: int = 16 n_kv_heads: int = 4 # GQA: меньше KV-голов -> дешевле память/кэш block_size: int = 4096 # тренируемый контекст mlp_ratio: float = 8 / 3 # SwiGLU -> hidden ~ 8/3 * d_model, кратно 256 rope_theta: float = 100_000.0 # большая база -> легче расширять контекст dropout: float = 0.0 grad_checkpoint: bool = True # выбор смесителя последовательности: # "attn" — обычное внимание во всех слоях (O(N^2), точный recall); # "gla" — линейное внимание fla во всех слоях (O(N), но без точного recall); # "hybrid" — GLA везде + attention каждый attn_every-й слой (O(N) + recall). mixer: str = "attn" attn_every: int = 4 # для hybrid: каждый attn_every-й слой = attention gla_chunk: int = 64 # размер чанка для fla chunk_gla @property def head_dim(self): return self.d_model // self.n_heads class RMSNorm(nn.Module): def __init__(self, dim, eps=1e-5): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def forward(self, x): dt = x.dtype x = x.float() x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) return (x * self.weight.float()).to(dt) def build_rope_cache(seq_len, head_dim, theta, device, dtype): inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim)) t = torch.arange(seq_len, device=device).float() freqs = torch.outer(t, inv_freq) # (T, head_dim/2) cos = freqs.cos().to(dtype) sin = freqs.sin().to(dtype) return cos, sin def apply_rope(x, cos, sin): # x: (B, H, T, D). Поворачиваем пары (x1, x2). T = x.shape[-2] cos, sin = cos[:T], sin[:T] x1, x2 = x[..., 0::2], x[..., 1::2] cos = cos[None, None]; sin = sin[None, None] rx1 = x1 * cos - x2 * sin rx2 = x1 * sin + x2 * cos out = torch.empty_like(x) out[..., 0::2] = rx1 out[..., 1::2] = rx2 return out class Attention(nn.Module): """Causal multi-head attention с GQA и RoPE, flash через SDPA.""" def __init__(self, cfg: ModelConfig): super().__init__() self.n_heads = cfg.n_heads self.n_kv = cfg.n_kv_heads self.hd = cfg.head_dim assert cfg.n_heads % cfg.n_kv_heads == 0, "n_heads должно делиться на n_kv_heads" self.q_proj = nn.Linear(cfg.d_model, cfg.n_heads * self.hd, bias=False) self.k_proj = nn.Linear(cfg.d_model, self.n_kv * self.hd, bias=False) self.v_proj = nn.Linear(cfg.d_model, self.n_kv * self.hd, bias=False) self.o_proj = nn.Linear(cfg.n_heads * self.hd, cfg.d_model, bias=False) self.dropout = cfg.dropout def forward(self, x, cos, sin): B, T, _ = x.shape q = self.q_proj(x).view(B, T, self.n_heads, self.hd).transpose(1, 2) k = self.k_proj(x).view(B, T, self.n_kv, self.hd).transpose(1, 2) v = self.v_proj(x).view(B, T, self.n_kv, self.hd).transpose(1, 2) q = apply_rope(q, cos, sin) k = apply_rope(k, cos, sin) if self.n_kv != self.n_heads: # GQA: расширяем KV-головы rep = self.n_heads // self.n_kv k = k.repeat_interleave(rep, dim=1) v = v.repeat_interleave(rep, dim=1) y = F.scaled_dot_product_attention( q, k, v, is_causal=True, dropout_p=self.dropout if self.training else 0.0) y = y.transpose(1, 2).contiguous().view(B, T, -1) return self.o_proj(y) # fla (flash-linear-attention): рабочее fused Triton-ядро GLA (fwd+bwd). # Проверено на RTX PRO 6000: 4x быстрее flash-attn на 32k, обучается (recall грокнул). # Импорт защищён: если fla нет (нет triton/Blackwell), GLAMixer недоступен и train # должен откатиться на attention (см. _make_mixer). try: from fla.ops.gla import chunk_gla as _fla_chunk_gla _HAS_FLA = True except Exception: _fla_chunk_gla = None _HAS_FLA = False class GLAMixer(nn.Module): """Gated Linear Attention через fla. O(N) по контексту, без RoPE (затухание само кодирует позицию). Обучаемый ВЕКТОРНЫЙ гейт затухания g = logsigmoid(W_g x) — каноническая форма GLA (мощнее скалярного gamma). Раскладка для fla 0.5.0: (B, T, H, K), без kwargs (откалибровано отдельно). GQA: KV-головы расширяются до n_heads (fla ждёт одинаковое число голов).""" def __init__(self, cfg: ModelConfig): super().__init__() assert _HAS_FLA, "GLAMixer требует flash-linear-attention (pip install)" self.n_heads = cfg.n_heads self.n_kv = cfg.n_kv_heads self.hd = cfg.head_dim self.chunk = cfg.gla_chunk self.q_proj = nn.Linear(cfg.d_model, cfg.n_heads * self.hd, bias=False) self.k_proj = nn.Linear(cfg.d_model, self.n_kv * self.hd, bias=False) self.v_proj = nn.Linear(cfg.d_model, self.n_kv * self.hd, bias=False) # гейт затухания на каждый канал q-голов (в лог-пространстве через logsigmoid) self.g_proj = nn.Linear(cfg.d_model, cfg.n_heads * self.hd, bias=False) self.o_proj = nn.Linear(cfg.n_heads * self.hd, cfg.d_model, bias=False) # выходной гейт (как в GLA): сигмоида, стабилизирует амплитуду self.out_gate = nn.Linear(cfg.d_model, cfg.n_heads * self.hd, bias=False) def forward(self, x, cos=None, sin=None): # cos/sin игнорируем: GLA без RoPE B, T, _ = x.shape H, KV, Dh = self.n_heads, self.n_kv, self.hd # fla ждёт раскладку (B, T, H, Dh) q = self.q_proj(x).view(B, T, H, Dh) k = self.k_proj(x).view(B, T, KV, Dh) v = self.v_proj(x).view(B, T, KV, Dh) if KV != H: # GQA -> расширяем KV до H голов rep = H // KV k = k.repeat_interleave(rep, dim=2) v = v.repeat_interleave(rep, dim=2) q = F.normalize(q, dim=-1) k = F.normalize(k, dim=-1) # лог-гейт затухания в (-inf, 0): logsigmoid -> устойчиво, gamma=exp(g) in (0,1) g = F.logsigmoid(self.g_proj(x).view(B, T, H, Dh).float()) # ЕДИНЫЙ dtype для fla: под autocast F.normalize даёт fp32, а v_proj — bf16; # fla-ядро падает на смешении типов в tl.dot. Приводим всё к dtype входа. dt = x.dtype q, k, v, g = q.to(dt), k.to(dt), v.to(dt), g.to(dt) out = _fla_chunk_gla(q, k, v, g) # (B, T, H, Dh), layout bthd o = out[0] if isinstance(out, (tuple, list)) else out o = o.reshape(B, T, H * Dh) * torch.sigmoid(self.out_gate(x)) return self.o_proj(o) class SwiGLU(nn.Module): def __init__(self, cfg: ModelConfig): super().__init__() hidden = int(cfg.mlp_ratio * cfg.d_model) hidden = 256 * ((hidden + 255) // 256) # кратно 256 для тензорных ядер self.gate = nn.Linear(cfg.d_model, hidden, bias=False) self.up = nn.Linear(cfg.d_model, hidden, bias=False) self.down = nn.Linear(hidden, cfg.d_model, bias=False) def forward(self, x): return self.down(F.silu(self.gate(x)) * self.up(x)) def _layer_is_attn(cfg: ModelConfig, layer_idx: int) -> bool: """Какой смеситель в слое layer_idx. hybrid: attention каждый attn_every-й слой (на индексах attn_every-1, 2*attn_every-1, ...), остальное — GLA.""" if cfg.mixer == "attn": return True if cfg.mixer == "gla": return False # hybrid return (layer_idx + 1) % cfg.attn_every == 0 class Block(nn.Module): def __init__(self, cfg: ModelConfig, layer_idx: int = 0): super().__init__() self.is_attn = _layer_is_attn(cfg, layer_idx) self.attn_norm = RMSNorm(cfg.d_model) self.mixer = Attention(cfg) if self.is_attn else GLAMixer(cfg) self.mlp_norm = RMSNorm(cfg.d_model) self.mlp = SwiGLU(cfg) def forward(self, x, cos, sin): # GLA-слой игнорирует cos/sin (нет RoPE); attention использует. x = x + self.mixer(self.attn_norm(x), cos, sin) x = x + self.mlp(self.mlp_norm(x)) return x class CodeLM(nn.Module): def __init__(self, cfg: ModelConfig): super().__init__() self.cfg = cfg self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model) self.drop = nn.Dropout(cfg.dropout) self.blocks = nn.ModuleList([Block(cfg, i) for i in range(cfg.n_layers)]) self.norm_f = RMSNorm(cfg.d_model) self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False) self.lm_head.weight = self.tok_emb.weight # tied self._rope = None self.apply(self._init) # масштабирование инициализации остаточных проекций по глубине (GPT-2 трюк) for n, p in self.named_parameters(): if n.endswith("o_proj.weight") or n.endswith("down.weight"): nn.init.normal_(p, std=0.02 / math.sqrt(2 * cfg.n_layers)) def _init(self, m): if isinstance(m, nn.Linear): nn.init.normal_(m.weight, std=0.02) elif isinstance(m, nn.Embedding): nn.init.normal_(m.weight, std=0.02) def _rope_cache(self, T, device, dtype): if self._rope is None or self._rope[0].shape[0] < T or self._rope[0].device != device: self._rope = build_rope_cache(max(T, self.cfg.block_size), self.cfg.head_dim, self.cfg.rope_theta, device, dtype) return self._rope def forward(self, idx, targets=None): B, T = idx.shape x = self.drop(self.tok_emb(idx)) cos, sin = self._rope_cache(T, idx.device, x.dtype) for blk in self.blocks: if self.cfg.grad_checkpoint and self.training: x = torch.utils.checkpoint.checkpoint(blk, x, cos, sin, use_reentrant=False) else: x = blk(x, cos, sin) x = self.norm_f(x) if targets is None: # инференс: только последний шаг logits = self.lm_head(x[:, -1:]) return logits, None logits = self.lm_head(x) loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1), ignore_index=-100) return logits, loss def hidden(self, idx): """Состояние ПЕРЕД lm_head (B,T,d). Нужно для MTP-aux голов, которые предсказывают токены на горизонте 2..K из того же h.""" B, T = idx.shape x = self.drop(self.tok_emb(idx)) cos, sin = self._rope_cache(T, idx.device, x.dtype) for blk in self.blocks: if self.cfg.grad_checkpoint and self.training: x = torch.utils.checkpoint.checkpoint(blk, x, cos, sin, use_reentrant=False) else: x = blk(x, cos, sin) return self.norm_f(x) def num_params(self, non_embed=True): n = sum(p.numel() for p in self.parameters()) if non_embed: n -= self.tok_emb.weight.numel() # tied -> один раз return n