Asilarknes commited on
Commit
03b7838
·
verified ·
1 Parent(s): fba7873

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +284 -0
model.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Современный decoder-only трансформер для обучения кодинг-модели с нуля.
2
+
3
+ Компоненты (всё — проверенная практика для код-моделей):
4
+ - RoPE (rotary position embeddings): позволяет расширять контекст за пределы
5
+ обученной длины; нет обучаемых позиционных эмбеддингов.
6
+ - RMSNorm: дешевле и стабильнее LayerNorm.
7
+ - SwiGLU MLP: лучше GELU при том же бюджете параметров.
8
+ - Flash attention через F.scaled_dot_product_attention: память O(N) на практике,
9
+ causal-маска бесплатно.
10
+ - Gradient checkpointing (опц.): торгуем счёт за память -> длинный контекст
11
+ на одной карте.
12
+ - Tied embeddings (вход = выход): экономит параметры, обычно не вредит.
13
+
14
+ Конфиг масштабируется от ~120M до ~1B; дефолт ~0.35B комфортно влезает в 96GB
15
+ с длинным контекстом и grad checkpointing.
16
+ """
17
+
18
+ from dataclasses import dataclass
19
+ import math
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+
24
+
25
+ @dataclass
26
+ class ModelConfig:
27
+ vocab_size: int = 49152 # StarCoder2 BPE
28
+ d_model: int = 1024
29
+ n_layers: int = 24
30
+ n_heads: int = 16
31
+ n_kv_heads: int = 4 # GQA: меньше KV-голов -> дешевле память/кэш
32
+ block_size: int = 4096 # тренируемый контекст
33
+ mlp_ratio: float = 8 / 3 # SwiGLU -> hidden ~ 8/3 * d_model, кратно 256
34
+ rope_theta: float = 100_000.0 # большая база -> легче расширять контекст
35
+ dropout: float = 0.0
36
+ grad_checkpoint: bool = True
37
+ # выбор смесителя последовательности:
38
+ # "attn" — обычное внимание во всех слоях (O(N^2), точный recall);
39
+ # "gla" — линейное внимание fla во всех слоях (O(N), но без точного recall);
40
+ # "hybrid" — GLA везде + attention каждый attn_every-й слой (O(N) + recall).
41
+ mixer: str = "attn"
42
+ attn_every: int = 4 # для hybrid: каждый attn_every-й слой = attention
43
+ gla_chunk: int = 64 # размер чанка для fla chunk_gla
44
+
45
+ @property
46
+ def head_dim(self):
47
+ return self.d_model // self.n_heads
48
+
49
+
50
+ class RMSNorm(nn.Module):
51
+ def __init__(self, dim, eps=1e-5):
52
+ super().__init__()
53
+ self.eps = eps
54
+ self.weight = nn.Parameter(torch.ones(dim))
55
+
56
+ def forward(self, x):
57
+ dt = x.dtype
58
+ x = x.float()
59
+ x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
60
+ return (x * self.weight.float()).to(dt)
61
+
62
+
63
+ def build_rope_cache(seq_len, head_dim, theta, device, dtype):
64
+ inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
65
+ t = torch.arange(seq_len, device=device).float()
66
+ freqs = torch.outer(t, inv_freq) # (T, head_dim/2)
67
+ cos = freqs.cos().to(dtype)
68
+ sin = freqs.sin().to(dtype)
69
+ return cos, sin
70
+
71
+
72
+ def apply_rope(x, cos, sin):
73
+ # x: (B, H, T, D). Поворачиваем пары (x1, x2).
74
+ T = x.shape[-2]
75
+ cos, sin = cos[:T], sin[:T]
76
+ x1, x2 = x[..., 0::2], x[..., 1::2]
77
+ cos = cos[None, None]; sin = sin[None, None]
78
+ rx1 = x1 * cos - x2 * sin
79
+ rx2 = x1 * sin + x2 * cos
80
+ out = torch.empty_like(x)
81
+ out[..., 0::2] = rx1
82
+ out[..., 1::2] = rx2
83
+ return out
84
+
85
+
86
+ class Attention(nn.Module):
87
+ """Causal multi-head attention с GQA и RoPE, flash через SDPA."""
88
+
89
+ def __init__(self, cfg: ModelConfig):
90
+ super().__init__()
91
+ self.n_heads = cfg.n_heads
92
+ self.n_kv = cfg.n_kv_heads
93
+ self.hd = cfg.head_dim
94
+ assert cfg.n_heads % cfg.n_kv_heads == 0, "n_heads должно делиться на n_kv_heads"
95
+ self.q_proj = nn.Linear(cfg.d_model, cfg.n_heads * self.hd, bias=False)
96
+ self.k_proj = nn.Linear(cfg.d_model, self.n_kv * self.hd, bias=False)
97
+ self.v_proj = nn.Linear(cfg.d_model, self.n_kv * self.hd, bias=False)
98
+ self.o_proj = nn.Linear(cfg.n_heads * self.hd, cfg.d_model, bias=False)
99
+ self.dropout = cfg.dropout
100
+
101
+ def forward(self, x, cos, sin):
102
+ B, T, _ = x.shape
103
+ q = self.q_proj(x).view(B, T, self.n_heads, self.hd).transpose(1, 2)
104
+ k = self.k_proj(x).view(B, T, self.n_kv, self.hd).transpose(1, 2)
105
+ v = self.v_proj(x).view(B, T, self.n_kv, self.hd).transpose(1, 2)
106
+ q = apply_rope(q, cos, sin)
107
+ k = apply_rope(k, cos, sin)
108
+ if self.n_kv != self.n_heads: # GQA: расширяем KV-головы
109
+ rep = self.n_heads // self.n_kv
110
+ k = k.repeat_interleave(rep, dim=1)
111
+ v = v.repeat_interleave(rep, dim=1)
112
+ y = F.scaled_dot_product_attention(
113
+ q, k, v, is_causal=True,
114
+ dropout_p=self.dropout if self.training else 0.0)
115
+ y = y.transpose(1, 2).contiguous().view(B, T, -1)
116
+ return self.o_proj(y)
117
+
118
+
119
+ # fla (flash-linear-attention): рабочее fused Triton-ядро GLA (fwd+bwd).
120
+ # Проверено на RTX PRO 6000: 4x быстрее flash-attn на 32k, обучается (recall грокнул).
121
+ # Импорт защищён: если fla нет (нет triton/Blackwell), GLAMixer недоступен и train
122
+ # должен откатиться на attention (см. _make_mixer).
123
+ try:
124
+ from fla.ops.gla import chunk_gla as _fla_chunk_gla
125
+ _HAS_FLA = True
126
+ except Exception:
127
+ _fla_chunk_gla = None
128
+ _HAS_FLA = False
129
+
130
+
131
+ class GLAMixer(nn.Module):
132
+ """Gated Linear Attention через fla. O(N) по контексту, без RoPE
133
+ (затухание само кодирует позицию). Обучаемый ВЕКТОРНЫЙ гейт затухания
134
+ g = logsigmoid(W_g x) — каноническая форма GLA (мощнее скалярного gamma).
135
+ Раскладка для fla 0.5.0: (B, T, H, K), без kwargs (откалибровано отдельно).
136
+ GQA: KV-головы расширяются до n_heads (fla ждёт одинаковое число голов)."""
137
+
138
+ def __init__(self, cfg: ModelConfig):
139
+ super().__init__()
140
+ assert _HAS_FLA, "GLAMixer требует flash-linear-attention (pip install)"
141
+ self.n_heads = cfg.n_heads
142
+ self.n_kv = cfg.n_kv_heads
143
+ self.hd = cfg.head_dim
144
+ self.chunk = cfg.gla_chunk
145
+ self.q_proj = nn.Linear(cfg.d_model, cfg.n_heads * self.hd, bias=False)
146
+ self.k_proj = nn.Linear(cfg.d_model, self.n_kv * self.hd, bias=False)
147
+ self.v_proj = nn.Linear(cfg.d_model, self.n_kv * self.hd, bias=False)
148
+ # гейт затухания на каждый канал q-голов (в лог-пространстве через logsigmoid)
149
+ self.g_proj = nn.Linear(cfg.d_model, cfg.n_heads * self.hd, bias=False)
150
+ self.o_proj = nn.Linear(cfg.n_heads * self.hd, cfg.d_model, bias=False)
151
+ # выходной гейт (как в GLA): сигмоида, стабилизирует амплитуду
152
+ self.out_gate = nn.Linear(cfg.d_model, cfg.n_heads * self.hd, bias=False)
153
+
154
+ def forward(self, x, cos=None, sin=None): # cos/sin игнорируем: GLA без RoPE
155
+ B, T, _ = x.shape
156
+ H, KV, Dh = self.n_heads, self.n_kv, self.hd
157
+ # fla ждёт раскладку (B, T, H, Dh)
158
+ q = self.q_proj(x).view(B, T, H, Dh)
159
+ k = self.k_proj(x).view(B, T, KV, Dh)
160
+ v = self.v_proj(x).view(B, T, KV, Dh)
161
+ if KV != H: # GQA -> расширяем KV до H голов
162
+ rep = H // KV
163
+ k = k.repeat_interleave(rep, dim=2)
164
+ v = v.repeat_interleave(rep, dim=2)
165
+ q = F.normalize(q, dim=-1)
166
+ k = F.normalize(k, dim=-1)
167
+ # лог-гейт затухания в (-inf, 0): logsigmoid -> устойчиво, gamma=exp(g) in (0,1)
168
+ g = F.logsigmoid(self.g_proj(x).view(B, T, H, Dh).float())
169
+ # ЕДИНЫЙ dtype для fla: под autocast F.normalize даёт fp32, а v_proj — bf16;
170
+ # fla-ядро падает на смешении типов в tl.dot. Приводим всё к dtype входа.
171
+ dt = x.dtype
172
+ q, k, v, g = q.to(dt), k.to(dt), v.to(dt), g.to(dt)
173
+ out = _fla_chunk_gla(q, k, v, g) # (B, T, H, Dh), layout bthd
174
+ o = out[0] if isinstance(out, (tuple, list)) else out
175
+ o = o.reshape(B, T, H * Dh) * torch.sigmoid(self.out_gate(x))
176
+ return self.o_proj(o)
177
+
178
+
179
+ class SwiGLU(nn.Module):
180
+ def __init__(self, cfg: ModelConfig):
181
+ super().__init__()
182
+ hidden = int(cfg.mlp_ratio * cfg.d_model)
183
+ hidden = 256 * ((hidden + 255) // 256) # кратно 256 для тензорных ядер
184
+ self.gate = nn.Linear(cfg.d_model, hidden, bias=False)
185
+ self.up = nn.Linear(cfg.d_model, hidden, bias=False)
186
+ self.down = nn.Linear(hidden, cfg.d_model, bias=False)
187
+
188
+ def forward(self, x):
189
+ return self.down(F.silu(self.gate(x)) * self.up(x))
190
+
191
+
192
+ def _layer_is_attn(cfg: ModelConfig, layer_idx: int) -> bool:
193
+ """Какой смеситель в слое layer_idx. hybrid: attention каждый attn_every-й слой
194
+ (на индексах attn_every-1, 2*attn_every-1, ...), остальное — GLA."""
195
+ if cfg.mixer == "attn":
196
+ return True
197
+ if cfg.mixer == "gla":
198
+ return False
199
+ # hybrid
200
+ return (layer_idx + 1) % cfg.attn_every == 0
201
+
202
+
203
+ class Block(nn.Module):
204
+ def __init__(self, cfg: ModelConfig, layer_idx: int = 0):
205
+ super().__init__()
206
+ self.is_attn = _layer_is_attn(cfg, layer_idx)
207
+ self.attn_norm = RMSNorm(cfg.d_model)
208
+ self.mixer = Attention(cfg) if self.is_attn else GLAMixer(cfg)
209
+ self.mlp_norm = RMSNorm(cfg.d_model)
210
+ self.mlp = SwiGLU(cfg)
211
+
212
+ def forward(self, x, cos, sin):
213
+ # GLA-слой игнорирует cos/sin (нет RoPE); attention использует.
214
+ x = x + self.mixer(self.attn_norm(x), cos, sin)
215
+ x = x + self.mlp(self.mlp_norm(x))
216
+ return x
217
+
218
+
219
+ class CodeLM(nn.Module):
220
+ def __init__(self, cfg: ModelConfig):
221
+ super().__init__()
222
+ self.cfg = cfg
223
+ self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
224
+ self.drop = nn.Dropout(cfg.dropout)
225
+ self.blocks = nn.ModuleList([Block(cfg, i) for i in range(cfg.n_layers)])
226
+ self.norm_f = RMSNorm(cfg.d_model)
227
+ self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
228
+ self.lm_head.weight = self.tok_emb.weight # tied
229
+ self._rope = None
230
+ self.apply(self._init)
231
+ # масштабирование инициализации остаточных проекций по глубине (GPT-2 трюк)
232
+ for n, p in self.named_parameters():
233
+ if n.endswith("o_proj.weight") or n.endswith("down.weight"):
234
+ nn.init.normal_(p, std=0.02 / math.sqrt(2 * cfg.n_layers))
235
+
236
+ def _init(self, m):
237
+ if isinstance(m, nn.Linear):
238
+ nn.init.normal_(m.weight, std=0.02)
239
+ elif isinstance(m, nn.Embedding):
240
+ nn.init.normal_(m.weight, std=0.02)
241
+
242
+ def _rope_cache(self, T, device, dtype):
243
+ if self._rope is None or self._rope[0].shape[0] < T or self._rope[0].device != device:
244
+ self._rope = build_rope_cache(max(T, self.cfg.block_size),
245
+ self.cfg.head_dim, self.cfg.rope_theta,
246
+ device, dtype)
247
+ return self._rope
248
+
249
+ def forward(self, idx, targets=None):
250
+ B, T = idx.shape
251
+ x = self.drop(self.tok_emb(idx))
252
+ cos, sin = self._rope_cache(T, idx.device, x.dtype)
253
+ for blk in self.blocks:
254
+ if self.cfg.grad_checkpoint and self.training:
255
+ x = torch.utils.checkpoint.checkpoint(blk, x, cos, sin, use_reentrant=False)
256
+ else:
257
+ x = blk(x, cos, sin)
258
+ x = self.norm_f(x)
259
+ if targets is None: # инференс: только последний шаг
260
+ logits = self.lm_head(x[:, -1:])
261
+ return logits, None
262
+ logits = self.lm_head(x)
263
+ loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
264
+ targets.reshape(-1), ignore_index=-100)
265
+ return logits, loss
266
+
267
+ def hidden(self, idx):
268
+ """Состояние ПЕРЕД lm_head (B,T,d). Нужно для MTP-aux голов, которые
269
+ предсказывают токены на горизонте 2..K из того же h."""
270
+ B, T = idx.shape
271
+ x = self.drop(self.tok_emb(idx))
272
+ cos, sin = self._rope_cache(T, idx.device, x.dtype)
273
+ for blk in self.blocks:
274
+ if self.cfg.grad_checkpoint and self.training:
275
+ x = torch.utils.checkpoint.checkpoint(blk, x, cos, sin, use_reentrant=False)
276
+ else:
277
+ x = blk(x, cos, sin)
278
+ return self.norm_f(x)
279
+
280
+ def num_params(self, non_embed=True):
281
+ n = sum(p.numel() for p in self.parameters())
282
+ if non_embed:
283
+ n -= self.tok_emb.weight.numel() # tied -> один раз
284
+ return n