Commit ab0b8a1 (verified) by epoyraz
1 parent: ba0921a

Add TinyStories GPT (19M) checkpoint, model code, tokenizer, and card

Files changed (5)
  1. README.md +126 -0
  2. config.json +15 -0
  3. model.py +982 -0
  4. tinystories-25m.pt +3 -0
  5. tokenizer.json +0 -0
README.md ADDED
@@ -0,0 +1,126 @@
+ ---
+ license: mit
+ datasets:
+ - roneneldan/TinyStories
+ language:
+ - en
+ tags:
+ - text-generation
+ - gpt
+ - tinystories
+ - from-scratch
+ - pytorch
+ - rope
+ - gqa
+ - swiglu
+ - multi-token-prediction
+ pipeline_tag: text-generation
+ ---
+
+ # TinyStories GPT (19M)
+
+ A small (~19.2M parameter) decoder-only GPT trained **from scratch** on
+ [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories). It writes
+ simple, coherent children's stories and is meant as a compact, hackable reference
+ for modern LLM architecture techniques, small enough to train end-to-end in a few
+ minutes on a consumer GPU (RTX 2060 Super, 8 GB).
+
+ ## Sample output
+
+ > **Once upon a time,** there was a little girl named Lily. She loved to play with
+ > her dolls and sing songs. One day, she went to the park to play with her friends.
+ > She saw a boy playing with a toy car and asked why he played too much...
+
+ > **Lily and Tom went to the park and** played on the swings. They had a lot of fun.
+ > They played with their toys and had a lot of fun. They also learned to be good and
+ > not judge others. They were happy.
+
+ ## Architecture
+
+ A LLaMA-style decoder-only transformer with several modern techniques wired in:
+
+ | Component | Choice |
+ |---|---|
+ | Layers / heads / dim | 8 layers, 6 heads, `n_embd` 384 |
+ | Context length | 256 tokens |
+ | Vocabulary | 16,384 (ByteLevel BPE) |
+ | Position encoding | **RoPE** (rotary embeddings) |
+ | Attention | **Grouped-Query Attention** (2 KV heads) |
+ | MLP | **SwiGLU** |
+ | Normalization | **RMSNorm** |
+ | Extra heads | **Multi-Token Prediction** (2 auxiliary heads) for sample efficiency |
+ | Weight tying | token embedding ↔ output head (and MTP heads) |
+
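The ~19.2M figure can be cross-checked against the table above and the layer shapes in `model.py` from this commit. The sketch below is plain arithmetic rather than the author's tooling. `count_params` and `swiglu_hidden` are hypothetical helper names, and the count assumes what `model.py` appears to do: attention projections carry biases, SwiGLU layers do not, each RMSNorm holds one weight vector, and the MTP heads reuse the tied output matrix.

```python
def swiglu_hidden(n_embd):
    # Same rounding rule as SwiGLU.__init__ in model.py: 2/3 of 4*n_embd,
    # rounded up to a multiple of 64.
    hidden = int(4 * n_embd * 2 / 3)
    return ((hidden + 63) // 64) * 64


def count_params(vocab_size=16384, n_embd=384, n_head=6, n_kv_head=2,
                 n_layer=8, mtp_heads=2):
    head_dim = n_embd // n_head
    kv_dim = n_kv_head * head_dim

    # Token embedding, shared with lm_head and the MTP output heads.
    embedding = vocab_size * n_embd

    # q_proj and the output projection (weights + bias).
    attn = 2 * (n_embd * n_embd + n_embd)
    # k_proj and v_proj map to the smaller KV width (weights + bias).
    attn += 2 * (n_embd * kv_dim + kv_dim)

    # gate, up, down projections, all bias-free.
    mlp = 3 * n_embd * swiglu_hidden(n_embd)

    # Two RMSNorm layers per block, one weight vector each.
    norms = 2 * n_embd

    per_layer = attn + mlp + norms
    final_norm = n_embd

    # Each MTP head: Linear(n_embd, n_embd) with bias + LayerNorm (weight, bias).
    mtp = mtp_heads * ((n_embd * n_embd + n_embd) + 2 * n_embd)

    return embedding + n_layer * per_layer + final_norm + mtp


total = count_params()
print(f"{total:,} parameters (~{total / 1e6:.1f}M)")
```

Under these assumptions the total comes to 19,186,304, of which about 6.3M is the shared embedding matrix.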
+ ## Training
+
+ | Setting | Value |
+ |---|---|
+ | Dataset | TinyStories (~2.1M stories) |
+ | Steps | 3,000 |
+ | Batch | 32 × 256 tokens |
+ | Optimizer | AdamW, cosine schedule, 200-step warmup, peak LR 6e-4 |
+ | Precision | fp16 mixed precision |
+ | Hardware | 1× RTX 2060 Super (8 GB), ~7 minutes |
+ | Throughput | ~57K tokens/sec |
+ | Final loss | 2.62 (combined next-token + MTP auxiliary) |
+ | Validation loss | 2.65 |
+
+ This is a lightly trained demo checkpoint (3,000 steps × 32 × 256 ≈ 24.6M tokens);
+ longer training should lower the loss further.
+
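The training numbers above are mutually consistent, which a few lines of arithmetic can confirm. The sketch also shows one common way to realise "cosine schedule, 200-step warmup, peak LR 6e-4". The exact schedule used for this checkpoint is not shown in the commit, so the linear warmup shape, the `MIN_LR` floor of zero, and the function name `lr_at` are all assumptions.

```python
import math

STEPS = 3000
BATCH = 32
SEQ_LEN = 256
TOKENS_PER_SEC = 57_000   # "~57K tokens/sec" from the table

PEAK_LR = 6e-4
WARMUP_STEPS = 200
MIN_LR = 0.0              # assumption: the card does not state a floor

total_tokens = STEPS * BATCH * SEQ_LEN
minutes = total_tokens / TOKENS_PER_SEC / 60


def lr_at(step):
    """Linear warmup to PEAK_LR, then cosine decay to MIN_LR (assumed shape)."""
    if step < WARMUP_STEPS:
        return PEAK_LR * (step + 1) / WARMUP_STEPS
    progress = min((step - WARMUP_STEPS) / (STEPS - WARMUP_STEPS), 1.0)
    return MIN_LR + 0.5 * (PEAK_LR - MIN_LR) * (1.0 + math.cos(math.pi * progress))


print(f"{total_tokens:,} tokens, about {minutes:.1f} minutes at the quoted throughput")
for step in (0, 199, 200, 1600, 3000):
    print(step, f"{lr_at(step):.2e}")
```

The token budget divided by the quoted throughput lands close to the "~7 minutes" in the table.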
+ ## Usage
+
+ This is a **custom architecture**, so you need `model.py` from this repo (it's small
+ and dependency-light). Download it next to your script, then:
+
+ ```python
+ import torch
+ from huggingface_hub import hf_hub_download
+ from tokenizers import Tokenizer
+ from model import GPT  # model.py downloaded from this repo
+
+ repo = "epoyraz/tinystories-25m"
+ # The checkpoint bundles the config dict ("config") and the state dict ("model").
+ ckpt = torch.load(
+     hf_hub_download(repo, "tinystories-25m.pt"),
+     map_location="cpu", weights_only=True,
+ )
+ model = GPT(ckpt["config"]).eval()
+ model.load_state_dict(ckpt["model"])
+
+ # Encode a prompt, sample a continuation, and decode it back to text.
+ tok = Tokenizer.from_file(hf_hub_download(repo, "tokenizer.json"))
+ ids = tok.encode("Once upon a time,").ids
+ out = model.generate(
+     torch.tensor([ids]), max_new_tokens=120, temperature=0.7, top_k=40,
+ )
+ print(tok.decode(out[0].tolist()))
+ ```
+
+ Install the dependencies with `pip install torch tokenizers huggingface_hub`.
+
+ ## Files
+
+ - `tinystories-25m.pt`: checkpoint (`config` + `model` state dict). The `25m` is
+   only part of the file name; the model has ~19.2M parameters.
+ - `model.py`: model definition (`GPT`, all techniques)
+ - `config.json`: the model config, for reference
+ - `tokenizer.json`: ByteLevel BPE tokenizer (16K vocab)
+
+ ## Limitations
+
+ - Trained only on TinyStories, so vocabulary and style are limited to simple
+   children's-story English. It is not a general-purpose assistant.
+ - Small and lightly trained: it repeats phrases and occasionally drifts or
+   contradicts itself (e.g. swapping character names).
+ - 256-token context.
+
+ ## Source
+
+ Trained with the "train a language model from scratch" project, a from-scratch GPT
+ with independently configurable modern techniques (RoPE, GQA, SwiGLU, RMSNorm, MTP,
+ mHC, BitNet, TurboQuant) plus Muon/AdamW optimizers and speculative decoding.
+ This checkpoint enables only RoPE, GQA, SwiGLU, RMSNorm, and MTP (see `config.json`).
+
+ ## References
+
+ - [TinyStories](https://arxiv.org/abs/2305.07759)
+ - [RoFormer / RoPE](https://arxiv.org/abs/2104.09864)
+ - [GQA](https://arxiv.org/abs/2305.13245)
+ - [GLU Variants / SwiGLU](https://arxiv.org/abs/2002.05202)
+ - [DeepSeek-V3 (MTP)](https://arxiv.org/abs/2412.19437)
config.json ADDED
@@ -0,0 +1,15 @@
+ {
+   "vocab_size": 16384,
+   "block_size": 256,
+   "n_embd": 384,
+   "n_head": 6,
+   "n_layer": 8,
+   "use_rope": true,
+   "n_kv_head": 2,
+   "use_swiglu": true,
+   "use_rmsnorm": true,
+   "use_mtp": true,
+   "mtp_heads": 2,
+   "mtp_weight": 0.1,
+   "tie_mtp_lm_head": true
+ }
model.py ADDED
@@ -0,0 +1,982 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+ from torch.utils.checkpoint import checkpoint
+
+
+ # --- mHC: Manifold-Constrained Hyper-Connections ---
+
+ def sinkhorn(log_alpha, n_iters=5):
+     for _ in range(n_iters):
+         log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=-1, keepdim=True)
+         log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=-2, keepdim=True)
+     return log_alpha.exp()
+
+
+ class MHCResidual(nn.Module):
+     def __init__(self, n_streams):
+         super().__init__()
+         self.n_streams = n_streams
+         self.log_alpha = nn.Parameter(torch.zeros(n_streams, n_streams))
+
+     def forward(self, streams, update):
+         W = sinkhorn(self.log_alpha)
+         mixed = torch.einsum("ij,bjte->bite", W, streams)
+         mixed[:, 0] = mixed[:, 0] + update
+         return mixed
+
+
+ class MHCExpand(nn.Module):
+     def __init__(self, n_streams, n_embd):
+         super().__init__()
+         self.n_streams = n_streams
+         self.proj = nn.Linear(n_embd, n_streams * n_embd) if n_streams > 1 else None
+
+     def forward(self, x):
+         if self.n_streams == 1:
+             return x.unsqueeze(1)
+         B, T, C = x.shape
+         return self.proj(x).view(B, self.n_streams, T, C)
+
+
+ class MHCCollapse(nn.Module):
+     def __init__(self, n_streams, n_embd):
+         super().__init__()
+         self.n_streams = n_streams
+         self.proj = nn.Linear(n_streams * n_embd, n_embd) if n_streams > 1 else None
+
+     def forward(self, streams):
+         if self.n_streams == 1:
+             return streams.squeeze(1)
+         B, S, T, C = streams.shape
+         return self.proj(streams.permute(0, 2, 1, 3).reshape(B, T, S * C))
+
+
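The `sinkhorn` function above alternates row and column normalisation in log space so that the stream-mixing matrix ends up close to doubly stochastic. The following is a minimal pure-Python sketch of the same iteration on nested lists, written only to make the behaviour easy to inspect. The helper `logsumexp` and the sample matrix are illustrative and not part of the commit.

```python
import math


def logsumexp(values):
    # Numerically stable log(sum(exp(v))).
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def sinkhorn(log_alpha, n_iters=5):
    n = len(log_alpha)
    la = [row[:] for row in log_alpha]
    for _ in range(n_iters):
        # Row normalisation (dim=-1 in the tensor version).
        row_lse = [logsumexp(row) for row in la]
        la = [[la[i][j] - row_lse[i] for j in range(n)] for i in range(n)]
        # Column normalisation (dim=-2 in the tensor version).
        col_lse = [logsumexp([la[i][j] for i in range(n)]) for j in range(n)]
        la = [[la[i][j] - col_lse[j] for j in range(n)] for i in range(n)]
    return [[math.exp(v) for v in row] for row in la]


W = sinkhorn([[0.0, 0.5, 0.2], [0.3, 0.0, 0.4], [0.1, 0.2, 0.0]])
row_sums = [sum(row) for row in W]
col_sums = [sum(W[i][j] for i in range(3)) for j in range(3)]
print("rows:", row_sums)
print("cols:", col_sums)

# With the all-zero initialisation used by MHCResidual, every entry is 1/n.
uniform = sinkhorn([[0.0] * 4 for _ in range(4)])
print(uniform[0])
```

Because the column step runs last, column sums are exact up to rounding while row sums are only approximately one after a finite number of iterations.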
+ # --- BitNet: Ternary weight linear layer ---
+
+ class BitLinear(nn.Module):
+     def __init__(self, in_features, out_features, bias=True):
+         super().__init__()
+         self.in_features = in_features
+         self.out_features = out_features
+         self.weight = nn.Parameter(torch.empty(out_features, in_features))
+         self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
+         self.rms_norm = nn.RMSNorm(in_features)
+         nn.init.normal_(self.weight, std=0.02)
+
+     def ternary_quantize(self, w):
+         alpha = w.abs().mean()
+         threshold = alpha * 0.5
+         w_ternary = torch.zeros_like(w)
+         w_ternary[w > threshold] = alpha
+         w_ternary[w < -threshold] = -alpha
+         return w_ternary.detach() + (w - w.detach())
+
+     def activation_quantize(self, x):
+         scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
+         x_scaled = x * scale
+         x_q = x_scaled.round().clamp(-128, 127).detach() + (x_scaled - x_scaled.detach())
+         return x_q / scale
+
+     def forward(self, x):
+         x = self.rms_norm(x)
+         w_q = self.ternary_quantize(self.weight)
+         x_q = self.activation_quantize(x)
+         out = F.linear(x_q, w_q, self.bias)
+         return out
+
+
+ class FastBitLinear(nn.Module):
+     def __init__(self, in_features, out_features, bias=True):
+         super().__init__()
+         self.in_features = in_features
+         self.out_features = out_features
+         self.weight = nn.Parameter(torch.empty(out_features, in_features))
+         self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
+         self.rms_norm = nn.RMSNorm(in_features)
+         nn.init.normal_(self.weight, std=0.02)
+
+     def _int8_forward(self, x):
+         w = self.weight.detach()
+         alpha = w.abs().mean()
+         threshold = alpha * 0.5
+         w_pos = (w > threshold).to(torch.int8)
+         w_neg = (w < -threshold).to(torch.int8)
+
+         x_max = x.detach().abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
+         x_scale = 127.0 / x_max
+         x_q = (x.detach() * x_scale).round().clamp(-128, 127).to(torch.int8)
+
+         shape = x_q.shape
+         x_2d = x_q.reshape(-1, shape[-1])
+
+         rows = x_2d.shape[0]
+         if rows <= 16:
+             pad = 17 - rows
+             x_2d = torch.nn.functional.pad(x_2d, (0, 0, 0, pad))
+             y_pos = torch._int_mm(x_2d, w_pos.T)[:rows]
+             y_neg = torch._int_mm(x_2d, w_neg.T)[:rows]
+         else:
+             y_pos = torch._int_mm(x_2d, w_pos.T)
+             y_neg = torch._int_mm(x_2d, w_neg.T)
+
+         y = (y_pos - y_neg).float().reshape(*shape[:-1], self.out_features)
+         return y * (alpha / x_scale)
+
+     def _ste_forward(self, x):
+         alpha = self.weight.abs().mean()
+         threshold = alpha * 0.5
+         w_ternary = torch.zeros_like(self.weight)
+         w_ternary[self.weight > threshold] = alpha
+         w_ternary[self.weight < -threshold] = -alpha
+         w_q = self.weight + (w_ternary - self.weight).detach()
+
+         x_scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
+         x_scaled = x * x_scale
+         x_q = x_scaled + (x_scaled.round().clamp(-128, 127) - x_scaled).detach()
+         x_q = x_q / x_scale
+
+         return F.linear(x_q, w_q, None)
+
+     def forward(self, x):
+         x = self.rms_norm(x)
+         if self.training:
+             out = self._ste_forward(x)
+         else:
+             out = self._int8_forward(x)
+         if self.bias is not None:
+             out = out + self.bias
+         return out
+
+
+ def make_linear(in_f, out_f, bias=True, use_bitnet=False, use_fast_bitnet=False):
+     if use_fast_bitnet:
+         return FastBitLinear(in_f, out_f, bias=bias)
+     if use_bitnet:
+         return BitLinear(in_f, out_f, bias=bias)
+     return nn.Linear(in_f, out_f, bias=bias)
+
+
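The forward-pass values produced by `ternary_quantize` above can be traced by hand on a small weight list. This is a pure-Python sketch of the value mapping only: it uses the same rule (scale `alpha` is the mean absolute weight, threshold is half of it) but does not reproduce the straight-through gradient trick, which needs autograd. The sample weights are made up for illustration.

```python
def ternary_quantize(weights):
    # alpha: per-tensor scale, the mean absolute weight.
    alpha = sum(abs(w) for w in weights) / len(weights)
    threshold = 0.5 * alpha
    # Codes in {-1, 0, +1}; the layer uses alpha * code as the effective weight.
    codes = [1 if w > threshold else -1 if w < -threshold else 0 for w in weights]
    return alpha, codes


alpha, codes = ternary_quantize([0.4, -0.1, 0.05, -0.6, 0.25])
effective = [alpha * c for c in codes]
print("alpha:", alpha)
print("codes:", codes)
print("effective weights:", effective)
```

Weights whose magnitude is at most half the mean absolute value collapse to zero, and the rest keep only their sign.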
+ # --- TurboQuant: KV-cache compression for inference ---
+
+ class PolarQuantizer:
+     def __init__(self, bits=4):
+         self.bits = bits
+         self.levels = 2 ** bits
+
+     def quantize(self, tensor):
+         norms = tensor.norm(dim=-1, keepdim=True).clamp(min=1e-8)
+         unit = tensor / norms
+         norm_min = norms.min()
+         norm_max = norms.max()
+         norm_scale = (norm_max - norm_min) / (self.levels - 1)
+         q_norms = ((norms - norm_min) / norm_scale.clamp(min=1e-8)).round().clamp(0, self.levels - 1)
+         val_min = unit.min()
+         val_max = unit.max()
+         val_scale = (val_max - val_min) / (self.levels - 1)
+         q_unit = ((unit - val_min) / val_scale.clamp(min=1e-8)).round().clamp(0, self.levels - 1)
+         return q_norms, q_unit, (norm_min, norm_scale, val_min, val_scale)
+
+     def dequantize(self, q_norms, q_unit, params):
+         norm_min, norm_scale, val_min, val_scale = params
+         norms = q_norms * norm_scale + norm_min
+         unit = q_unit * val_scale + val_min
+         unit = unit / unit.norm(dim=-1, keepdim=True).clamp(min=1e-8)
+         return unit * norms
+
+
+ class TurboQuantKVCache:
+     def __init__(self, bits=4):
+         self.quantizer = PolarQuantizer(bits=bits)
+         self.k_cache = []
+         self.v_cache = []
+
+     def update(self, k_new, v_new):
+         qk_norms, qk_unit, k_params = self.quantizer.quantize(k_new)
+         qv_norms, qv_unit, v_params = self.quantizer.quantize(v_new)
+         self.k_cache.append((qk_norms, qk_unit, k_params))
+         self.v_cache.append((qv_norms, qv_unit, v_params))
+
+     def get(self):
+         ks = [self.quantizer.dequantize(*entry) for entry in self.k_cache]
+         vs = [self.quantizer.dequantize(*entry) for entry in self.v_cache]
+         return torch.cat(ks, dim=2), torch.cat(vs, dim=2)
+
+     def clear(self):
+         self.k_cache.clear()
+         self.v_cache.clear()
+
+
+ class KVCache:
+     def __init__(self, max_seq_len):
+         self.max_seq_len = max_seq_len
+         self.k_cache = None
+         self.v_cache = None
+         self.pos = 0
+
+     def _ensure_allocated(self, k_new, v_new):
+         B, H, _, D = k_new.shape
+         needs_alloc = (
+             self.k_cache is None
+             or self.k_cache.shape[0] != B
+             or self.k_cache.shape[1] != H
+             or self.k_cache.shape[3] != D
+             or self.k_cache.device != k_new.device
+             or self.k_cache.dtype != k_new.dtype
+         )
+         if needs_alloc:
+             self.k_cache = torch.empty(
+                 B, H, self.max_seq_len, D,
+                 device=k_new.device,
+                 dtype=k_new.dtype,
+             )
+             self.v_cache = torch.empty(
+                 B, H, self.max_seq_len, D,
+                 device=v_new.device,
+                 dtype=v_new.dtype,
+             )
+             self.pos = 0
+
+     def update(self, k_new, v_new):
+         self._ensure_allocated(k_new, v_new)
+         T = k_new.size(2)
+         if self.pos + T > self.max_seq_len:
+             raise ValueError(f"KV cache length {self.pos + T} exceeds max_seq_len {self.max_seq_len}")
+         self.k_cache[:, :, self.pos:self.pos + T, :].copy_(k_new)
+         self.v_cache[:, :, self.pos:self.pos + T, :].copy_(v_new)
+         self.pos += T
+
+     def get(self):
+         if self.k_cache is None:
+             return None, None
+         return self.k_cache[:, :, :self.pos, :], self.v_cache[:, :, :self.pos, :]
+
+     def clear(self):
+         self.pos = 0
+
+
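`PolarQuantizer` above applies the same min-max uniform quantisation step twice, once to the vector norms and once to the unit-vector components. A hedged pure-Python sketch of that single step on a flat list may help when reasoning about its error. The function names and sample values here are illustrative, not taken from the commit.

```python
def quantize_uniform(values, bits=4):
    levels = 2 ** bits
    lo, hi = min(values), max(values)
    # Guard against a zero range, in the spirit of the clamp(min=1e-8) above.
    scale = max((hi - lo) / (levels - 1), 1e-8)
    codes = [min(max(round((v - lo) / scale), 0), levels - 1) for v in values]
    return codes, lo, scale


def dequantize_uniform(codes, lo, scale):
    return [c * scale + lo for c in codes]


values = [-1.0, -0.2, 0.3, 1.0]
codes, lo, scale = quantize_uniform(values)
restored = dequantize_uniform(codes, lo, scale)
max_error = max(abs(a - b) for a, b in zip(values, restored))
print("codes:", codes)
print("max round-trip error:", max_error, "(bound: scale / 2 =", scale / 2, ")")
```

Each value is recovered to within half a quantisation step. Note that in the class above the codes stay floating-point tensors, so an actual memory saving would need a separate bit-packing step.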
+ # --- MTP: Multi-Token Prediction ---
+
+ class MTPHead(nn.Module):
+     def __init__(self, config, future_idx):
+         super().__init__()
+         self.future_idx = future_idx
+         n_embd = config["n_embd"]
+         vocab_size = config["vocab_size"]
+         self.proj = nn.Linear(n_embd, n_embd)
+         self.ln = nn.LayerNorm(n_embd)
+         self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
+
+     def forward(self, hidden, targets=None):
+         if targets is not None:
+             shift = self.future_idx
+             if targets.size(1) <= shift:
+                 return None, None
+             # Only the first T-shift positions have a future target, so project
+             # just those instead of the full sequence (saves a vocab matmul slice).
+             h = self.ln(self.proj(hidden[:, :-shift]))
+             logits = self.lm_head(h)
+             targets_shifted = targets[:, shift:]
+             loss = F.cross_entropy(
+                 logits.reshape(-1, logits.size(-1)),
+                 targets_shifted.reshape(-1),
+                 ignore_index=-1,
+             )
+             return logits, loss
+         h = self.ln(self.proj(hidden))
+         return self.lm_head(h), None
+
+
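The slicing in `MTPHead.forward` (`hidden[:, :-shift]` paired with `targets[:, shift:]`) is easiest to follow on a concrete token sequence. The sketch below lists which input position is paired with which target for a given `future_idx`. It assumes the usual convention that `targets` is the input shifted left by one token, which the commit does not show explicitly. `mtp_pairs` is a hypothetical helper.

```python
def mtp_pairs(tokens, future_idx):
    inputs = tokens[:-1]
    targets = tokens[1:]  # assumption: standard next-token targets
    if len(targets) <= future_idx:
        return []         # mirrors the early `return None, None`
    n_positions = len(inputs) - future_idx  # hidden[:, :-shift]
    shifted = targets[future_idx:]          # targets[:, shift:]
    return [(inputs[t], shifted[t]) for t in range(n_positions)]


tokens = ["Once", "upon", "a", "time", ",", "there"]
for k in (1, 2):
    print(f"future_idx={k}:", mtp_pairs(tokens, k))
```

Under that assumption, the head with `future_idx=k` learns to predict the token `k + 1` positions after the input token, while the main head covers the ordinary next token.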
+ # --- RoPE: Rotary Position Embeddings ---
+
+ class RotaryEmbedding(nn.Module):
+     def __init__(self, dim, max_seq_len=4096, base=10000.0):
+         super().__init__()
+         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+         self.register_buffer("inv_freq", inv_freq)
+         self._build_cache(max_seq_len)
+
+     def _build_cache(self, seq_len):
+         t = torch.arange(seq_len, dtype=self.inv_freq.dtype)
+         freqs = torch.outer(t, self.inv_freq)
+         emb = torch.cat([freqs, freqs], dim=-1)
+         self.register_buffer("cos_cached", emb.cos(), persistent=False)
+         self.register_buffer("sin_cached", emb.sin(), persistent=False)
+
+     def forward(self, seq_len):
+         return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
+
+
+ def rotate_half(x):
+     x1, x2 = x.chunk(2, dim=-1)
+     return torch.cat([-x2, x1], dim=-1)
+
+
+ def apply_rope(q, k, cos, sin):
+     cos = cos.unsqueeze(0).unsqueeze(0)
+     sin = sin.unsqueeze(0).unsqueeze(0)
+     q = q * cos + rotate_half(q) * sin
+     k = k * cos + rotate_half(k) * sin
+     return q, k
+
+
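The point of the rotation in `apply_rope` is that the query-key dot product ends up depending only on the relative offset between positions. The sketch below re-implements the same half-split rotation in pure Python for a single vector and checks that property numerically. The function `rope` and the sample vectors are illustrative only.

```python
import math


def rope(vec, pos, base=10000.0):
    dim = len(vec)
    half = dim // 2
    out = [0.0] * dim
    for i in range(half):
        # Same frequencies as inv_freq: base ** (-(2 * i) / dim).
        theta = pos * base ** (-(2 * i) / dim)
        c, s = math.cos(theta), math.sin(theta)
        x1, x2 = vec[i], vec[i + half]
        # q * cos + rotate_half(q) * sin, with rotate_half = [-x2, x1].
        out[i] = x1 * c - x2 * s
        out[i + half] = x2 * c + x1 * s
    return out


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


q = [0.3, -1.2, 0.7, 0.5]
k = [1.1, 0.4, -0.6, 0.9]

score_a = dot(rope(q, 3), rope(k, 1))  # offset 2
score_b = dot(rope(q, 7), rope(k, 5))  # offset 2, shifted by 4
print(score_a, score_b)
```

Both scores agree up to floating-point rounding because each pair of coordinates is rotated by an angle proportional to its position.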
+ # --- SwiGLU MLP ---
+
+ class SwiGLU(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         n_embd = config["n_embd"]
+         hidden = int(4 * n_embd * 2 / 3)
+         hidden = ((hidden + 63) // 64) * 64
+         use_bitnet = config.get("use_bitnet", False)
+         use_fast_bitnet = config.get("use_fast_bitnet", False)
+         self.gate = make_linear(n_embd, hidden, bias=False, use_bitnet=use_bitnet, use_fast_bitnet=use_fast_bitnet)
+         self.up = make_linear(n_embd, hidden, bias=False, use_bitnet=use_bitnet, use_fast_bitnet=use_fast_bitnet)
+         self.down = make_linear(hidden, n_embd, bias=False, use_bitnet=use_bitnet, use_fast_bitnet=use_fast_bitnet)
+
+     def forward(self, x):
+         return self.down(F.silu(self.gate(x)) * self.up(x))
+
+
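The `2 / 3` factor in the hidden width above keeps a three-matrix SwiGLU block at roughly the parameter cost of a two-matrix `4 * n_embd` MLP. A small arithmetic sketch, using hypothetical helper names and the bias settings visible in `model.py`, compares the two:

```python
def swiglu_hidden(n_embd, multiple=64):
    hidden = int(4 * n_embd * 2 / 3)
    # Round up to the next multiple of 64.
    return ((hidden + multiple - 1) // multiple) * multiple


def swiglu_params(n_embd):
    # gate, up and down projections, all without bias.
    return 3 * n_embd * swiglu_hidden(n_embd)


def gelu_mlp_params(n_embd):
    # fc and proj from the plain MLP class, both with bias.
    hidden = 4 * n_embd
    return (n_embd * hidden + hidden) + (hidden * n_embd + n_embd)


for n_embd in (384, 512):
    print(n_embd, swiglu_hidden(n_embd), swiglu_params(n_embd), gelu_mlp_params(n_embd))
```

For `n_embd = 384` the two blocks differ only by the bias terms of the plain MLP.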
+ # --- Core model ---
+
+ def make_norm(n_embd, use_rmsnorm=False):
+     if use_rmsnorm:
+         return nn.RMSNorm(n_embd)
+     return nn.LayerNorm(n_embd)
+
+
+ class CausalSelfAttention(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.n_head = config["n_head"]
+         self.n_embd = config["n_embd"]
+         self.n_kv_head = config.get("n_kv_head", self.n_head)
+         if self.n_embd % self.n_head != 0:
+             raise ValueError(f"n_embd ({self.n_embd}) must be divisible by n_head ({self.n_head})")
+         if self.n_head % self.n_kv_head != 0:
+             raise ValueError(f"n_head ({self.n_head}) must be divisible by n_kv_head ({self.n_kv_head})")
+         self.head_dim = self.n_embd // self.n_head
+         self.use_rope = config.get("use_rope", False)
+         use_bitnet = config.get("use_bitnet", False)
+         use_fast_bitnet = config.get("use_fast_bitnet", False)
+
+         self.q_proj = make_linear(self.n_embd, self.n_head * self.head_dim, use_bitnet=use_bitnet, use_fast_bitnet=use_fast_bitnet)
+         self.k_proj = make_linear(self.n_embd, self.n_kv_head * self.head_dim, use_bitnet=use_bitnet, use_fast_bitnet=use_fast_bitnet)
+         self.v_proj = make_linear(self.n_embd, self.n_kv_head * self.head_dim, use_bitnet=use_bitnet, use_fast_bitnet=use_fast_bitnet)
+         self.proj = make_linear(self.n_embd, self.n_embd, use_bitnet=use_bitnet, use_fast_bitnet=use_fast_bitnet)
+
+         if self.use_rope:
+             self.rope = RotaryEmbedding(self.head_dim, max_seq_len=config.get("block_size", 512))
+
+     def forward(self, x, kv_cache=None, pos_offset=0):
+         B, T, C = x.shape
+         q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
+         k = self.k_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
+         v = self.v_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
+
+         if self.use_rope:
+             cos, sin = self.rope(pos_offset + T)
+             cos, sin = cos[pos_offset:pos_offset + T], sin[pos_offset:pos_offset + T]
+             q, k = apply_rope(q, k, cos, sin)
+
+         if self.n_kv_head < self.n_head:
+             repeats = self.n_head // self.n_kv_head
+             k = k.repeat_interleave(repeats, dim=1)
+             v = v.repeat_interleave(repeats, dim=1)
+
+         if kv_cache is not None:
+             kv_cache.update(k, v)
+             k, v = kv_cache.get()
+
+         use_causal = (T > 1)
+         out = F.scaled_dot_product_attention(q, k, v, is_causal=use_causal)
+         out = out.transpose(1, 2).reshape(B, T, C)
+         return self.proj(out)
+
+
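In the attention code above, `repeat_interleave` maps each group of consecutive query heads onto one KV head, and the expansion happens before `kv_cache.update`, so the cache appears to store all `n_head` heads rather than `n_kv_head`. The sketch below works through both points with this checkpoint's sizes. The helper names are hypothetical and `bytes_per_value=4` assumes fp32 inference, which the commit does not specify.

```python
def kv_head_for(query_head, n_head=6, n_kv_head=2):
    # repeat_interleave copies KV head j into query slots j*r .. j*r + r - 1.
    repeats = n_head // n_kv_head
    return query_head // repeats


def kv_cache_bytes(heads, seq_len=256, head_dim=64, n_layer=8,
                   batch=1, bytes_per_value=4):
    # Factor 2 covers the separate K and V buffers in every layer.
    return 2 * n_layer * batch * heads * seq_len * head_dim * bytes_per_value


mapping = [kv_head_for(h) for h in range(6)]
as_written = kv_cache_bytes(heads=6)      # cache filled after the expansion
before_expand = kv_cache_bytes(heads=2)   # hypothetical: cache the 2 KV heads

print("query head -> KV head:", mapping)
print("cache after expansion :", as_written / 2**20, "MiB")
print("cache before expansion:", before_expand / 2**20, "MiB")
```

Caching the unexpanded heads and repeating them after `kv_cache.get()` would be one way to realise the GQA memory saving at inference time. That is a design option, not something this commit does.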
+ class MLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         use_bitnet = config.get("use_bitnet", False)
+         use_fast_bitnet = config.get("use_fast_bitnet", False)
+         self.fc = make_linear(config["n_embd"], 4 * config["n_embd"], use_bitnet=use_bitnet, use_fast_bitnet=use_fast_bitnet)
+         self.proj = make_linear(4 * config["n_embd"], config["n_embd"], use_bitnet=use_bitnet, use_fast_bitnet=use_fast_bitnet)
+
+     def forward(self, x):
+         return self.proj(F.gelu(self.fc(x)))
+
+
+ class Block(nn.Module):
+     def __init__(self, config, layer_idx=0):
+         super().__init__()
+         self.use_mhc = config.get("use_mhc", False)
+         use_rmsnorm = config.get("use_rmsnorm", False)
+         self.ln1 = make_norm(config["n_embd"], use_rmsnorm)
+         self.attn = CausalSelfAttention(config)
+         self.ln2 = make_norm(config["n_embd"], use_rmsnorm)
+         if config.get("use_swiglu", False):
+             self.mlp = SwiGLU(config)
+         else:
+             self.mlp = MLP(config)
+         if self.use_mhc:
+             n_streams = config.get("mhc_streams", 4)
+             self.mhc_attn = MHCResidual(n_streams)
+             self.mhc_mlp = MHCResidual(n_streams)
+
+     def forward(self, x, streams=None, kv_cache=None, pos_offset=0):
+         if self.use_mhc and streams is not None:
+             inp = streams[:, 0]
+             attn_out = self.attn(self.ln1(inp), kv_cache=kv_cache, pos_offset=pos_offset)
+             streams = self.mhc_attn(streams, attn_out)
+             mlp_inp = streams[:, 0]
+             mlp_out = self.mlp(self.ln2(mlp_inp))
+             streams = self.mhc_mlp(streams, mlp_out)
+             return streams
+         else:
+             x = x + self.attn(self.ln1(x), kv_cache=kv_cache, pos_offset=pos_offset)
+             x = x + self.mlp(self.ln2(x))
+             return x
+
+
+ class GPT(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.use_mhc = config.get("use_mhc", False)
+         self.use_mtp = config.get("use_mtp", False)
+         self.use_rope = config.get("use_rope", False)
+         self.mtp_heads_n = config.get("mtp_heads", 4)
+         self.mtp_weight = config.get("mtp_weight", 0.1)
+         self.use_turboquant = config.get("use_turboquant", False)
+         self.turboquant_bits = config.get("turboquant_bits", 4)
+         self.use_activation_checkpointing = config.get("use_activation_checkpointing", False)
+         use_rmsnorm = config.get("use_rmsnorm", False)
+
+         self.tok_emb = nn.Embedding(config["vocab_size"], config["n_embd"])
+         if not self.use_rope:
+             self.pos_emb = nn.Embedding(config["block_size"], config["n_embd"])
+         self.blocks = nn.ModuleList([Block(config, i) for i in range(config["n_layer"])])
+         self.ln_f = make_norm(config["n_embd"], use_rmsnorm)
+         self.lm_head = nn.Linear(config["n_embd"], config["vocab_size"], bias=False)
+         self.tok_emb.weight = self.lm_head.weight
+
+         if self.use_mhc:
+             n_streams = config.get("mhc_streams", 4)
+             self.mhc_expand = MHCExpand(n_streams, config["n_embd"])
+             self.mhc_collapse = MHCCollapse(n_streams, config["n_embd"])
+
+         if self.use_mtp:
+             self.mtp_heads = nn.ModuleList([
+                 MTPHead(config, future_idx=i + 1) for i in range(self.mtp_heads_n)
+             ])
+             if config.get("tie_mtp_lm_head", True):
+                 for head in self.mtp_heads:
+                     head.lm_head.weight = self.lm_head.weight
+
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         if isinstance(module, (nn.Linear, BitLinear)):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     def _compute_hidden(self, idx):
+         B, T = idx.shape
+         if T > self.config["block_size"]:
+             raise ValueError(f"Input length {T} exceeds block_size {self.config['block_size']}")
+         x = self.tok_emb(idx)
+         if not self.use_rope:
+             pos = torch.arange(T, device=idx.device)
+             x = x + self.pos_emb(pos)
+
+         if self.use_mhc:
+             streams = self.mhc_expand(x)
+             for block in self.blocks:
+                 if self.training and self.use_activation_checkpointing:
+                     streams = checkpoint(lambda s, b=block: b(x, streams=s), streams, use_reentrant=False)
+                 else:
+                     streams = block(x, streams=streams)
+             x = self.mhc_collapse(streams)
+         else:
+             for block in self.blocks:
+                 if self.training and self.use_activation_checkpointing:
+                     x = checkpoint(block, x, use_reentrant=False)
+                 else:
+                     x = block(x)
+
+         return self.ln_f(x)
+
+     def forward(self, idx, targets=None, return_hidden=False):
+         hidden = self._compute_hidden(idx)
+         logits = self.lm_head(hidden)
+         loss = None
+         if targets is not None:
+             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+             if self.use_mtp:
+                 for head in self.mtp_heads:
+                     _, mtp_loss = head(hidden, targets)
+                     if mtp_loss is not None:
+                         loss = loss + self.mtp_weight * mtp_loss
+         if return_hidden:
+             return logits, loss, hidden
+         return logits, loss
+
+     def _forward_inference(self, x, kv_caches, pos_offset=0, return_hidden=False):
+         if self.use_mhc:
+             streams = self.mhc_expand(x)
+             for block, cache in zip(self.blocks, kv_caches or [None] * len(self.blocks)):
+                 streams = block(x, streams=streams, kv_cache=cache, pos_offset=pos_offset)
+             x = self.mhc_collapse(streams)
+         else:
+             for block, cache in zip(self.blocks, kv_caches or [None] * len(self.blocks)):
+                 x = block(x, kv_cache=cache, pos_offset=pos_offset)
+         hidden = self.ln_f(x)
+         logits = self.lm_head(hidden)
+         if return_hidden:
+             return logits, hidden
+         return logits
+
+     def _embed(self, tokens, pos_offset=0):
+         x = self.tok_emb(tokens)
+         if not self.use_rope:
+             T = tokens.shape[1]
+             pos = torch.arange(pos_offset, pos_offset + T, device=tokens.device)
+             x = x + self.pos_emb(pos)
+         return x
+
+     def _filter_logits(self, logits, top_k=None, top_p=None, min_p=None):
+         if top_k is not None and top_k > 0:
+             k = min(top_k, logits.size(-1))
+             values, _ = torch.topk(logits, k)
+             logits = logits.masked_fill(logits < values[:, [-1]], -float("inf"))
+
+         if min_p is not None and min_p > 0:
+             probs = F.softmax(logits, dim=-1)
+             max_probs = probs.max(dim=-1, keepdim=True).values
+             remove = probs < (min_p * max_probs)
+             top_token = logits.argmax(dim=-1, keepdim=True)
+             remove.scatter_(dim=-1, index=top_token, value=False)
+             logits = logits.masked_fill(remove, -float("inf"))
+
+         if top_p is not None and 0 < top_p < 1.0:
+             sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
+             sorted_probs = F.softmax(sorted_logits, dim=-1)
+             cumulative_probs = sorted_probs.cumsum(dim=-1)
+             sorted_remove = cumulative_probs > top_p
+             sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
+             sorted_remove[..., 0] = False
+             remove = torch.zeros_like(logits, dtype=torch.bool)
+             remove.scatter_(dim=-1, index=sorted_idx, src=sorted_remove)
+             logits = logits.masked_fill(remove, -float("inf"))
+
+         return logits
+
+     def _distribution(self, logits, temperature=0.8, top_k=40, top_p=None, min_p=None):
+         if temperature <= 0:
+             token = logits.argmax(dim=-1, keepdim=True)
+             probs = torch.zeros_like(logits)
+             probs.scatter_(1, token, 1.0)
+             return token, probs
+         logits = self._filter_logits(logits / temperature, top_k=top_k, top_p=top_p, min_p=min_p)
+         probs = F.softmax(logits, dim=-1)
+         token = torch.multinomial(probs, num_samples=1)
+         return token, probs
+
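The top-k and top-p branches of `_filter_logits` are compact in tensor form, in particular the one-position shift of `sorted_remove`, which keeps the first token whose cumulative probability crosses `top_p`. The pure-Python sketch below restates both filters on a single list of logits under the same rules. The function names and sample logits are illustrative only.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_filter(logits, k):
    # Keep everything at least as large as the k-th largest logit.
    cutoff = sorted(logits, reverse=True)[min(k, len(logits)) - 1]
    return [x if x >= cutoff else -math.inf for x in logits]


def top_p_filter(logits, top_p):
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    probs = softmax([logits[i] for i in order])
    kept, cumulative = set(), 0.0
    for rank, i in enumerate(order):
        # Drop a token only if the mass of the tokens ranked before it already
        # exceeds top_p. This is the effect of shifting the removal mask right.
        if rank > 0 and cumulative > top_p:
            break
        kept.add(i)
        cumulative += probs[rank]
    return [x if i in kept else -math.inf for i, x in enumerate(logits)]


logits = [2.0, 1.0, 0.0, -1.0]
print(top_k_filter(logits, 3))
print(top_p_filter(logits, 0.7))
```

With these logits the two most likely tokens together carry about 88% of the mass, so `top_p=0.7` keeps exactly those two, while `top_k=3` drops only the last.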
+     def _make_kv_caches(self, use_turboquant, use_kv_cache=True):
+         if not use_kv_cache:
+             return None
+         if use_turboquant:
+             return [TurboQuantKVCache(bits=self.turboquant_bits) for _ in self.blocks]
+         return [KVCache(self.config["block_size"]) for _ in self.blocks]
+
+     def _trim_or_seed_prompt(self, idx):
+         block_size = self.config["block_size"]
+         if idx.shape[1] == 0:
+             eos_id = 1
+             idx = torch.tensor([[eos_id]], dtype=idx.dtype, device=idx.device)
+         return idx[:, -block_size:]
+
+     def _prefill_generation(self, idx, use_turboquant=False, use_kv_cache=True):
+         kv_caches = self._make_kv_caches(use_turboquant, use_kv_cache=use_kv_cache)
+         seq_len = idx.shape[1]
+         x = self._embed(idx)
+         logits, hidden = self._forward_inference(x, kv_caches, pos_offset=0, return_hidden=True)
+         return logits, hidden[:, -1:, :], kv_caches, seq_len
+
+     def _advance_generation_state(self, idx, idx_next, kv_caches, seq_len, use_turboquant):
+         block_size = self.config["block_size"]
+         if kv_caches is not None and seq_len < block_size:
+             x = self._embed(idx_next, pos_offset=seq_len)
+             logits, hidden = self._forward_inference(x, kv_caches, pos_offset=seq_len, return_hidden=True)
+             return logits, hidden[:, -1:, :], kv_caches, seq_len + 1
+
+         use_kv_cache = kv_caches is not None
+         if kv_caches:
+             for cache in kv_caches:
+                 cache.clear()
+         idx_cond = idx[:, -block_size:]
+         logits, hidden, kv_caches, seq_len = self._prefill_generation(
+             idx_cond,
+             use_turboquant=use_turboquant,
+             use_kv_cache=use_kv_cache,
+         )
+         return logits, hidden, kv_caches, seq_len
+
+     def _generate_autoregressive(
+         self,
+         idx,
+         max_new_tokens,
+         temperature=0.8,
+         top_k=40,
+         top_p=None,
+         min_p=None,
+         use_turboquant=None,
+         use_kv_cache=True,
+     ):
+         idx = self._trim_or_seed_prompt(idx)
+         use_turboquant = self.use_turboquant if use_turboquant is None else use_turboquant
+         logits, last_hidden, kv_caches, seq_len = self._prefill_generation(
+             idx,
+             use_turboquant=use_turboquant,
+             use_kv_cache=use_kv_cache,
+         )
+
+         for i in range(max_new_tokens):
+             idx_next, _ = self._distribution(
+                 logits[:, -1, :],
+                 temperature=temperature,
+                 top_k=top_k,
+                 top_p=top_p,
+                 min_p=min_p,
+             )
+             idx = torch.cat([idx, idx_next], dim=1)
+
+             if i < max_new_tokens - 1:
+                 logits, last_hidden, kv_caches, seq_len = self._advance_generation_state(
+                     idx, idx_next, kv_caches, seq_len, use_turboquant
+                 )
+
+         return idx
+
+     def _mtp_draft(self, last_hidden, n_tokens, temperature=0.8, top_k=40, top_p=None, min_p=None):
+         draft_tokens = []
+         draft_probs = []
+         for head in self.mtp_heads[:n_tokens]:
+             draft_logits, _ = head(last_hidden)
+             token, probs = self._distribution(
+                 draft_logits[:, -1, :],
+                 temperature=temperature,
+                 top_k=top_k,
+                 top_p=top_p,
+                 min_p=min_p,
+             )
+             draft_tokens.append(token)
+             draft_probs.append(probs)
+         return draft_tokens, draft_probs
+
+     def _resample_on_reject(self, target_token, p_probs, q_probs, temperature):
+         if temperature <= 0:
+             return target_token
+         residual = (p_probs - q_probs).clamp(min=0)
+         denom = residual.sum(dim=-1, keepdim=True)
+         if denom.item() <= 1e-12:
+             return target_token
+         return torch.multinomial(residual / denom, num_samples=1)
+
+     def _mtp_speculative_generate(
+         self,
+         idx,
+         max_new_tokens,
+         temperature=0.8,
+         top_k=40,
+         top_p=None,
+         min_p=None,
+         speculate_tokens=None,
+         use_turboquant=None,
+         use_kv_cache=True,
+     ):
+         use_turboquant = self.use_turboquant if use_turboquant is None else use_turboquant
+         # Batched verification needs a single sequence, MTP draft heads, and the
+         # plain (rollback-able) KV cache. TurboQuant's cache cannot be rolled back
+         # token-by-token, so fall back to autoregressive there.
+         if not self.use_mtp or idx.size(0) != 1 or not use_kv_cache or use_turboquant:
+             return self._generate_autoregressive(
+                 idx,
+                 max_new_tokens,
+                 temperature=temperature,
+                 top_k=top_k,
+                 top_p=top_p,
+                 min_p=min_p,
+                 use_turboquant=use_turboquant,
+                 use_kv_cache=use_kv_cache,
+             )
+
+         idx = self._trim_or_seed_prompt(idx)
+         block_size = self.config["block_size"]
+         draft_width = speculate_tokens or self.mtp_heads_n
+         draft_width = max(1, min(draft_width, self.mtp_heads_n))
+
+         logits, last_hidden, kv_caches, seq_len = self._prefill_generation(
+             idx, use_turboquant=False, use_kv_cache=True
+         )
+         # p0 = main-model logits for the next token (verifies the first draft).
+         p0_logits = logits[:, -1, :]
+         generated = 0
+
+         while generated < max_new_tokens:
+             remaining = max_new_tokens - generated
+             n_draft = min(draft_width, remaining)
+
+             # No room left in the cache window: take one plain step (this slides the
+             # window via re-prefill inside _advance_generation_state) and continue.
+             if seq_len + n_draft > block_size:
+                 idx_next, _ = self._distribution(p0_logits, temperature, top_k, top_p, min_p)
+                 idx = torch.cat([idx, idx_next], dim=1)
+                 generated += 1
+                 if generated < max_new_tokens:
+                     logits, last_hidden, kv_caches, seq_len = self._advance_generation_state(
+                         idx, idx_next, kv_caches, seq_len, False
+                     )
+                     p0_logits = logits[:, -1, :]
+                 continue
+
+             # 1. Draft n tokens cheaply from the MTP heads (no main-model forward).
+             draft_tokens, draft_probs = self._mtp_draft(
+                 last_hidden, n_draft, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p
+             )
+             draft_seq = torch.cat(draft_tokens, dim=1)
+
+             # 2. Verify ALL drafts in a SINGLE main-model forward pass.
+             x = self._embed(draft_seq, pos_offset=seq_len)
+             v_logits, v_hidden = self._forward_inference(
+                 x, kv_caches, pos_offset=seq_len, return_hidden=True
+             )
+
+             # 3. Walk the drafts left-to-right; draft j is checked against the main
+             #    distribution at the previous position (p0 for j=0, else v_logits[j-1]).
+             accepted = 0
+             reject_token = None
+             for j in range(n_draft):
+                 target_logits = p0_logits if j == 0 else v_logits[:, j - 1, :]
+                 target_token, p_probs = self._distribution(
+                     target_logits, temperature, top_k, top_p, min_p
+                 )
+                 if temperature <= 0:
+                     accept = torch.equal(draft_tokens[j], target_token)
+                 else:
+                     proposed = draft_tokens[j].item()
+                     p = p_probs[0, proposed]
+                     q = draft_probs[j][0, proposed].clamp(min=1e-12)
+                     accept = torch.rand((), device=idx.device) <= torch.minimum(torch.ones_like(p), p / q)
+                 if accept:
+                     accepted += 1
+                 else:
+                     reject_token = self._resample_on_reject(
+                         target_token, p_probs, draft_probs[j], temperature
+                     )
+                     break
+
+             if accepted == n_draft:
+                 # Every draft matched the main model: commit them all. The cache
+                 # already holds them and v_hidden/v_logits give the next draft state
+                 # for free (no extra forward, no separate bonus token needed).
+                 idx = torch.cat([idx, draft_seq], dim=1)
+                 generated += n_draft
+                 seq_len += n_draft
+                 last_hidden = v_hidden[:, -1:, :]
+                 p0_logits = v_logits[:, -1, :]
+             else:
+                 # Commit the accepted prefix plus the corrected token, then roll the
+                 # cache back to drop the rejected drafts' (now stale) KV entries.
+                 commit = torch.cat(draft_tokens[:accepted] + [reject_token], dim=1)
+                 idx = torch.cat([idx, commit], dim=1)
+                 generated += accepted + 1
+                 for cache in kv_caches:
+                     cache.pos = seq_len + accepted
+                 seq_len += accepted
+                 if generated < max_new_tokens:
+                     # reject_token's KV/hidden are not cached yet; one short forward rebases.
+                     logits, last_hidden, kv_caches, seq_len = self._advance_generation_state(
+                         idx, reject_token, kv_caches, seq_len, False
+                     )
+                     p0_logits = logits[:, -1, :]
+
+         return idx
+
+     def generate(
+         self,
+         idx,
+         max_new_tokens,
+         temperature=0.8,
+         top_k=40,
+         top_p=None,
+         min_p=None,
+         speculative=False,
+         speculate_tokens=None,
+         use_turboquant=None,
+         use_kv_cache=True,
+     ):
+         if speculative:
+             return self._mtp_speculative_generate(
+                 idx,
+                 max_new_tokens,
+                 temperature=temperature,
+                 top_k=top_k,
+                 top_p=top_p,
+                 min_p=min_p,
+                 speculate_tokens=speculate_tokens,
+                 use_turboquant=use_turboquant,
+                 use_kv_cache=use_kv_cache,
+             )
+         return self._generate_autoregressive(
+             idx,
+             max_new_tokens,
+             temperature=temperature,
+             top_k=top_k,
+             top_p=top_p,
+             min_p=min_p,
+             use_turboquant=use_turboquant,
+             use_kv_cache=use_kv_cache,
+         )
+
+
+ # --- Configs ---
+
+ BASE_CONFIG = {
+     "vocab_size": 16384,
+     "block_size": 512,
+     "n_embd": 512,
+     "n_head": 8,
+     "n_layer": 12,
+ }
+
+ # Individual techniques
+ MHC_CONFIG = {**BASE_CONFIG, "use_mhc": True, "mhc_streams": 4}
+ BITNET_CONFIG = {**BASE_CONFIG, "use_bitnet": True}
+ FAST_BITNET_CONFIG = {**BASE_CONFIG, "use_fast_bitnet": True}
+ MTP_CONFIG = {**BASE_CONFIG, "use_mtp": True, "mtp_heads": 4, "mtp_weight": 0.1}
+ ROPE_CONFIG = {**BASE_CONFIG, "use_rope": True}
+ GQA_CONFIG = {**BASE_CONFIG, "n_kv_head": 2}
+ SWIGLU_CONFIG = {**BASE_CONFIG, "use_swiglu": True}
+ RMSNORM_CONFIG = {**BASE_CONFIG, "use_rmsnorm": True}
+ TURBOQUANT_CONFIG = {**BASE_CONFIG, "use_turboquant": True, "turboquant_bits": 4}
+
+ # Combinations
+ MHC_BITNET_CONFIG = {**BASE_CONFIG, "use_mhc": True, "mhc_streams": 4, "use_bitnet": True}
+ MHC_MTP_CONFIG = {**BASE_CONFIG, "use_mhc": True, "mhc_streams": 4, "use_mtp": True, "mtp_heads": 4, "mtp_weight": 0.1}
+
+ # Modern LLaMA-style (RoPE + GQA + SwiGLU + RMSNorm)
+ MODERN_CONFIG = {**BASE_CONFIG, "use_rope": True, "n_kv_head": 2, "use_swiglu": True, "use_rmsnorm": True}
+
+ # Everything
+ ALL_CONFIG = {
+     **BASE_CONFIG,
+     "use_mhc": True, "mhc_streams": 4,
+     "use_bitnet": True,
+     "use_mtp": True, "mtp_heads": 4, "mtp_weight": 0.1,
+     "use_rope": True, "n_kv_head": 2,
+     "use_swiglu": True, "use_rmsnorm": True,
+     "use_turboquant": True, "turboquant_bits": 4,
+ }
+
+ RECOMMENDED_CONFIG = {
+     **BASE_CONFIG,
+     "use_rope": True, "n_kv_head": 2,
+     "use_swiglu": True, "use_rmsnorm": True,
+     "use_mtp": True, "mtp_heads": 4, "mtp_weight": 0.1,
+ }
+
+ FAST_2060_CONFIG = {
+     **BASE_CONFIG,
+     "block_size": 256,
+     "n_embd": 384,
+     "n_head": 6,
+     "n_layer": 8,
+     "use_rope": True,
+     "n_kv_head": 2,
+     "use_swiglu": True,
+     "use_rmsnorm": True,
+ }
+
+ FAST_2060_MTP_CONFIG = {
+     **FAST_2060_CONFIG,
+     "use_mtp": True,
+     "mtp_heads": 2,
+     "mtp_weight": 0.1,
+     "tie_mtp_lm_head": True,
+ }
+
+ FAST_2060_MTP_FBITNET_CONFIG = {
+     **FAST_2060_MTP_CONFIG,
+     "use_fast_bitnet": True,
+ }
+
+ FAST_2060_MTP_TURBO_CONFIG = {
+     **FAST_2060_MTP_CONFIG,
+     "use_turboquant": True,
+     "turboquant_bits": 4,
+ }
+
+ TINY_FAST_CONFIG = {
+     **BASE_CONFIG,
+     "block_size": 256,
+     "n_embd": 256,
+     "n_head": 4,
+     "n_layer": 6,
+     "use_rope": True,
+     "n_kv_head": 2,
+     "use_swiglu": True,
+     "use_rmsnorm": True,
+ }
+
+ LOW_MEMORY_2060_CONFIG = {
+     **FAST_2060_CONFIG,
+     "use_activation_checkpointing": True,
+ }
+
+ CONFIGS = {
+     "base": BASE_CONFIG,
+     "mhc": MHC_CONFIG,
+     "bitnet": BITNET_CONFIG,
+     "mtp": MTP_CONFIG,
+     "rope": ROPE_CONFIG,
+     "gqa": GQA_CONFIG,
+     "swiglu": SWIGLU_CONFIG,
+     "rmsnorm": RMSNORM_CONFIG,
+     "turboquant": TURBOQUANT_CONFIG,
+     "mhc_bitnet": MHC_BITNET_CONFIG,
+     "mhc_mtp": MHC_MTP_CONFIG,
+     "modern": MODERN_CONFIG,
+     "all": ALL_CONFIG,
+     "recommended": RECOMMENDED_CONFIG,
+     "fast_2060": FAST_2060_CONFIG,
+     "fast_2060_mtp": FAST_2060_MTP_CONFIG,
+     "fast_2060_mtp_fbitnet": FAST_2060_MTP_FBITNET_CONFIG,
+     "fast_2060_mtp_turbo": FAST_2060_MTP_TURBO_CONFIG,
+     "tiny_fast": TINY_FAST_CONFIG,
+     "low_memory_2060": LOW_MEMORY_2060_CONFIG,
+ }
+
+
+ def get_model_config(name="fast_2060", **overrides):
+     if name not in CONFIGS:
+         available = ", ".join(sorted(CONFIGS))
+         raise ValueError(f"Unknown config '{name}'. Available configs: {available}")
+     return {**CONFIGS[name], **{k: v for k, v in overrides.items() if v is not None}}
+
+
+ MODEL_CONFIG = RECOMMENDED_CONFIG
+
+ if __name__ == "__main__":
+     # Smoke test: build every registered config and run one forward pass.
+     configs = CONFIGS
+     for name, cfg in configs.items():
+         model = GPT(cfg)
+         n_params = sum(p.numel() for p in model.parameters())
+         x = torch.randint(0, cfg["vocab_size"], (2, 64))
+         logits, loss = model(x, x)
+         # Width 22 fits the longest config name ("fast_2060_mtp_fbitnet").
+         print(f"{name:<22} | {n_params:>12,} params ({n_params/1e6:.1f}M) | loss: {loss.item():.2f}")
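The accept/reject step in `_mtp_speculative_generate` and `_resample_on_reject` appears to follow the usual speculative-sampling rule: keep a drafted token with probability min(1, p/q), otherwise resample from the normalized residual max(p - q, 0). The sketch below is a standalone, pure-Python illustration of that rule, not code from `model.py`. The helper name `accept_or_resample` and the toy distributions `p` and `q` are hypothetical and chosen only for the demonstration.

```python
import random


def accept_or_resample(p, q, proposed, rng=random):
    """One speculative-sampling step for a single drafted token.

    p: target (main-model) probabilities, a list summing to 1
    q: draft probabilities, same length as p
    proposed: token id that was sampled from q
    Returns (token, accepted).
    """
    ratio = p[proposed] / max(q[proposed], 1e-12)
    if rng.random() <= min(1.0, ratio):
        return proposed, True
    # Rejected: resample from the normalized residual max(p - q, 0).
    residual = [max(pi - qi, 0.0) for pi, qi in zip(p, q)]
    total = sum(residual)
    if total <= 1e-12:
        # Degenerate case (p and q nearly equal): sample from p directly.
        return rng.choices(range(len(p)), weights=p)[0], False
    return rng.choices(range(len(p)), weights=residual)[0], False


# Toy distributions (hypothetical values, for illustration only).
p = [0.6, 0.3, 0.1]
q = [0.2, 0.2, 0.6]

rng = random.Random(0)
trials = 20000
counts = [0] * len(p)
for _ in range(trials):
    proposed = rng.choices(range(len(q)), weights=q)[0]
    token, _ = accept_or_resample(p, q, proposed, rng)
    counts[token] += 1

# The committed tokens should be distributed like p, not like q.
empirical = [c / trials for c in counts]
print(empirical)
```

The point of the residual resampling is that the committed tokens follow the main model's distribution even when the draft distribution is quite different, so a weak draft head mostly costs speed rather than output quality.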
tinystories-25m.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f08fa57d4360cd654e407322bce66695018c5b9b673df8be5f8c9f5631fe3103
+ size 76793291
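The `size` field of the pointer allows a rough sanity check on the parameter count. The sketch below parses the pointer text and divides the byte size by four. It assumes the checkpoint stores essentially only float32 weights with negligible serialization overhead, which the pointer itself does not confirm. `parse_lfs_pointer` is a hypothetical helper written for this example.

```python
def parse_lfs_pointer(text):
    """Parse a Git LFS pointer file into a dict of its key/value lines."""
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields


pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:f08fa57d4360cd654e407322bce66695018c5b9b673df8be5f8c9f5631fe3103
size 76793291
"""

fields = parse_lfs_pointer(pointer)
size_bytes = int(fields["size"])
bytes_per_param = 4  # assumption: float32 weights, no optimizer state
approx_params = size_bytes / bytes_per_param

print(f"{size_bytes / 1e6:.1f} MB, ~{approx_params / 1e6:.1f}M params at fp32")
# prints "76.8 MB, ~19.2M params at fp32"
```

Under that assumption the estimate matches the ~19.2M parameters stated in the model card rather than the `25m` in the checkpoint file name.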
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff