Asilarknes's picture
Upload model.py with huggingface_hub
03b7838 verified
Raw
History Blame Contribute Delete
13.6 kB
"""Современный decoder-only трансформер для обучения кодинг-модели с нуля.
Компоненты (всё — проверенная практика для код-моделей):
- RoPE (rotary position embeddings): позволяет расширять контекст за пределы
обученной длины; нет обучаемых позиционных эмбеддингов.
- RMSNorm: дешевле и стабильнее LayerNorm.
- SwiGLU MLP: лучше GELU при том же бюджете параметров.
- Flash attention через F.scaled_dot_product_attention: память O(N) на практике,
causal-маска бесплатно.
- Gradient checkpointing (опц.): торгуем счёт за память -> длинный контекст
на одной карте.
- Tied embeddings (вход = выход): экономит параметры, обычно не вредит.
Конфиг масштабируется от ~120M до ~1B; дефолт ~0.35B комфортно влезает в 96GB
с длинным контекстом и grad checkpointing.
"""
from dataclasses import dataclass
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
@dataclass
class ModelConfig:
vocab_size: int = 49152 # StarCoder2 BPE
d_model: int = 1024
n_layers: int = 24
n_heads: int = 16
n_kv_heads: int = 4 # GQA: меньше KV-голов -> дешевле память/кэш
block_size: int = 4096 # тренируемый контекст
mlp_ratio: float = 8 / 3 # SwiGLU -> hidden ~ 8/3 * d_model, кратно 256
rope_theta: float = 100_000.0 # большая база -> легче расширять контекст
dropout: float = 0.0
grad_checkpoint: bool = True
# выбор смесителя последовательности:
# "attn" — обычное внимание во всех слоях (O(N^2), точный recall);
# "gla" — линейное внимание fla во всех слоях (O(N), но без точного recall);
# "hybrid" — GLA везде + attention каждый attn_every-й слой (O(N) + recall).
mixer: str = "attn"
attn_every: int = 4 # для hybrid: каждый attn_every-й слой = attention
gla_chunk: int = 64 # размер чанка для fla chunk_gla
@property
def head_dim(self):
return self.d_model // self.n_heads
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x):
dt = x.dtype
x = x.float()
x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
return (x * self.weight.float()).to(dt)
def build_rope_cache(seq_len, head_dim, theta, device, dtype):
inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
t = torch.arange(seq_len, device=device).float()
freqs = torch.outer(t, inv_freq) # (T, head_dim/2)
cos = freqs.cos().to(dtype)
sin = freqs.sin().to(dtype)
return cos, sin
def apply_rope(x, cos, sin):
# x: (B, H, T, D). Поворачиваем пары (x1, x2).
T = x.shape[-2]
cos, sin = cos[:T], sin[:T]
x1, x2 = x[..., 0::2], x[..., 1::2]
cos = cos[None, None]; sin = sin[None, None]
rx1 = x1 * cos - x2 * sin
rx2 = x1 * sin + x2 * cos
out = torch.empty_like(x)
out[..., 0::2] = rx1
out[..., 1::2] = rx2
return out
class Attention(nn.Module):
"""Causal multi-head attention с GQA и RoPE, flash через SDPA."""
def __init__(self, cfg: ModelConfig):
super().__init__()
self.n_heads = cfg.n_heads
self.n_kv = cfg.n_kv_heads
self.hd = cfg.head_dim
assert cfg.n_heads % cfg.n_kv_heads == 0, "n_heads должно делиться на n_kv_heads"
self.q_proj = nn.Linear(cfg.d_model, cfg.n_heads * self.hd, bias=False)
self.k_proj = nn.Linear(cfg.d_model, self.n_kv * self.hd, bias=False)
self.v_proj = nn.Linear(cfg.d_model, self.n_kv * self.hd, bias=False)
self.o_proj = nn.Linear(cfg.n_heads * self.hd, cfg.d_model, bias=False)
self.dropout = cfg.dropout
def forward(self, x, cos, sin):
B, T, _ = x.shape
q = self.q_proj(x).view(B, T, self.n_heads, self.hd).transpose(1, 2)
k = self.k_proj(x).view(B, T, self.n_kv, self.hd).transpose(1, 2)
v = self.v_proj(x).view(B, T, self.n_kv, self.hd).transpose(1, 2)
q = apply_rope(q, cos, sin)
k = apply_rope(k, cos, sin)
if self.n_kv != self.n_heads: # GQA: расширяем KV-головы
rep = self.n_heads // self.n_kv
k = k.repeat_interleave(rep, dim=1)
v = v.repeat_interleave(rep, dim=1)
y = F.scaled_dot_product_attention(
q, k, v, is_causal=True,
dropout_p=self.dropout if self.training else 0.0)
y = y.transpose(1, 2).contiguous().view(B, T, -1)
return self.o_proj(y)
# fla (flash-linear-attention): рабочее fused Triton-ядро GLA (fwd+bwd).
# Проверено на RTX PRO 6000: 4x быстрее flash-attn на 32k, обучается (recall грокнул).
# Импорт защищён: если fla нет (нет triton/Blackwell), GLAMixer недоступен и train
# должен откатиться на attention (см. _make_mixer).
try:
from fla.ops.gla import chunk_gla as _fla_chunk_gla
_HAS_FLA = True
except Exception:
_fla_chunk_gla = None
_HAS_FLA = False
class GLAMixer(nn.Module):
"""Gated Linear Attention через fla. O(N) по контексту, без RoPE
(затухание само кодирует позицию). Обучаемый ВЕКТОРНЫЙ гейт затухания
g = logsigmoid(W_g x) — каноническая форма GLA (мощнее скалярного gamma).
Раскладка для fla 0.5.0: (B, T, H, K), без kwargs (откалибровано отдельно).
GQA: KV-головы расширяются до n_heads (fla ждёт одинаковое число голов)."""
def __init__(self, cfg: ModelConfig):
super().__init__()
assert _HAS_FLA, "GLAMixer требует flash-linear-attention (pip install)"
self.n_heads = cfg.n_heads
self.n_kv = cfg.n_kv_heads
self.hd = cfg.head_dim
self.chunk = cfg.gla_chunk
self.q_proj = nn.Linear(cfg.d_model, cfg.n_heads * self.hd, bias=False)
self.k_proj = nn.Linear(cfg.d_model, self.n_kv * self.hd, bias=False)
self.v_proj = nn.Linear(cfg.d_model, self.n_kv * self.hd, bias=False)
# гейт затухания на каждый канал q-голов (в лог-пространстве через logsigmoid)
self.g_proj = nn.Linear(cfg.d_model, cfg.n_heads * self.hd, bias=False)
self.o_proj = nn.Linear(cfg.n_heads * self.hd, cfg.d_model, bias=False)
# выходной гейт (как в GLA): сигмоида, стабилизирует амплитуду
self.out_gate = nn.Linear(cfg.d_model, cfg.n_heads * self.hd, bias=False)
def forward(self, x, cos=None, sin=None): # cos/sin игнорируем: GLA без RoPE
B, T, _ = x.shape
H, KV, Dh = self.n_heads, self.n_kv, self.hd
# fla ждёт раскладку (B, T, H, Dh)
q = self.q_proj(x).view(B, T, H, Dh)
k = self.k_proj(x).view(B, T, KV, Dh)
v = self.v_proj(x).view(B, T, KV, Dh)
if KV != H: # GQA -> расширяем KV до H голов
rep = H // KV
k = k.repeat_interleave(rep, dim=2)
v = v.repeat_interleave(rep, dim=2)
q = F.normalize(q, dim=-1)
k = F.normalize(k, dim=-1)
# лог-гейт затухания в (-inf, 0): logsigmoid -> устойчиво, gamma=exp(g) in (0,1)
g = F.logsigmoid(self.g_proj(x).view(B, T, H, Dh).float())
# ЕДИНЫЙ dtype для fla: под autocast F.normalize даёт fp32, а v_proj — bf16;
# fla-ядро падает на смешении типов в tl.dot. Приводим всё к dtype входа.
dt = x.dtype
q, k, v, g = q.to(dt), k.to(dt), v.to(dt), g.to(dt)
out = _fla_chunk_gla(q, k, v, g) # (B, T, H, Dh), layout bthd
o = out[0] if isinstance(out, (tuple, list)) else out
o = o.reshape(B, T, H * Dh) * torch.sigmoid(self.out_gate(x))
return self.o_proj(o)
class SwiGLU(nn.Module):
def __init__(self, cfg: ModelConfig):
super().__init__()
hidden = int(cfg.mlp_ratio * cfg.d_model)
hidden = 256 * ((hidden + 255) // 256) # кратно 256 для тензорных ядер
self.gate = nn.Linear(cfg.d_model, hidden, bias=False)
self.up = nn.Linear(cfg.d_model, hidden, bias=False)
self.down = nn.Linear(hidden, cfg.d_model, bias=False)
def forward(self, x):
return self.down(F.silu(self.gate(x)) * self.up(x))
def _layer_is_attn(cfg: ModelConfig, layer_idx: int) -> bool:
"""Какой смеситель в слое layer_idx. hybrid: attention каждый attn_every-й слой
(на индексах attn_every-1, 2*attn_every-1, ...), остальное — GLA."""
if cfg.mixer == "attn":
return True
if cfg.mixer == "gla":
return False
# hybrid
return (layer_idx + 1) % cfg.attn_every == 0
class Block(nn.Module):
def __init__(self, cfg: ModelConfig, layer_idx: int = 0):
super().__init__()
self.is_attn = _layer_is_attn(cfg, layer_idx)
self.attn_norm = RMSNorm(cfg.d_model)
self.mixer = Attention(cfg) if self.is_attn else GLAMixer(cfg)
self.mlp_norm = RMSNorm(cfg.d_model)
self.mlp = SwiGLU(cfg)
def forward(self, x, cos, sin):
# GLA-слой игнорирует cos/sin (нет RoPE); attention использует.
x = x + self.mixer(self.attn_norm(x), cos, sin)
x = x + self.mlp(self.mlp_norm(x))
return x
class CodeLM(nn.Module):
def __init__(self, cfg: ModelConfig):
super().__init__()
self.cfg = cfg
self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
self.drop = nn.Dropout(cfg.dropout)
self.blocks = nn.ModuleList([Block(cfg, i) for i in range(cfg.n_layers)])
self.norm_f = RMSNorm(cfg.d_model)
self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
self.lm_head.weight = self.tok_emb.weight # tied
self._rope = None
self.apply(self._init)
# масштабирование инициализации остаточных проекций по глубине (GPT-2 трюк)
for n, p in self.named_parameters():
if n.endswith("o_proj.weight") or n.endswith("down.weight"):
nn.init.normal_(p, std=0.02 / math.sqrt(2 * cfg.n_layers))
def _init(self, m):
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=0.02)
elif isinstance(m, nn.Embedding):
nn.init.normal_(m.weight, std=0.02)
def _rope_cache(self, T, device, dtype):
if self._rope is None or self._rope[0].shape[0] < T or self._rope[0].device != device:
self._rope = build_rope_cache(max(T, self.cfg.block_size),
self.cfg.head_dim, self.cfg.rope_theta,
device, dtype)
return self._rope
def forward(self, idx, targets=None):
B, T = idx.shape
x = self.drop(self.tok_emb(idx))
cos, sin = self._rope_cache(T, idx.device, x.dtype)
for blk in self.blocks:
if self.cfg.grad_checkpoint and self.training:
x = torch.utils.checkpoint.checkpoint(blk, x, cos, sin, use_reentrant=False)
else:
x = blk(x, cos, sin)
x = self.norm_f(x)
if targets is None: # инференс: только последний шаг
logits = self.lm_head(x[:, -1:])
return logits, None
logits = self.lm_head(x)
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
targets.reshape(-1), ignore_index=-100)
return logits, loss
def hidden(self, idx):
"""Состояние ПЕРЕД lm_head (B,T,d). Нужно для MTP-aux голов, которые
предсказывают токены на горизонте 2..K из того же h."""
B, T = idx.shape
x = self.drop(self.tok_emb(idx))
cos, sin = self._rope_cache(T, idx.device, x.dtype)
for blk in self.blocks:
if self.cfg.grad_checkpoint and self.training:
x = torch.utils.checkpoint.checkpoint(blk, x, cos, sin, use_reentrant=False)
else:
x = blk(x, cos, sin)
return self.norm_f(x)
def num_params(self, non_embed=True):
n = sum(p.numel() for p in self.parameters())
if non_embed:
n -= self.tok_emb.weight.numel() # tied -> один раз
return n