Upgrade to modded-nanoGPT + Muon checkpoint (val 2.65 -> 2.45)
- README.md +33 -37
- config.json +5 -2
- model.py +64 -18
- tinystories-25m.pt +2 -2
README.md
CHANGED
@@ -12,7 +12,8 @@ tags:
 - pytorch
 - rope
 - gqa
--
 - multi-token-prediction
 pipeline_tag: text-generation
 ---
@@ -21,34 +22,40 @@ pipeline_tag: text-generation

 A small (~19.2M parameter) decoder-only GPT trained **from scratch** on
 [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories). It writes
-simple, coherent children's stories and is
-
-

 ## Sample output

-> **Once upon a time,** there was a little girl named Lily. She loved to play
->
->

-> **Lily and Tom went to the park and**
->
-> not judge others. They were happy.

 ## Architecture

-A LLaMA-style decoder-only transformer with several modern techniques wired in:

 | Component | Choice |
 |---|---|
 | Layers / heads / dim | 8 layers, 6 heads, `n_embd` 384 |
 | Context length | 256 tokens |
 | Vocabulary | 16,384 (ByteLevel BPE) |
-| Position encoding | **RoPE**
-| Attention | **Grouped-Query Attention** (2 KV heads) |
-| MLP | **
 | Normalization | **RMSNorm** |
-|
 | Weight tying | token embedding ↔ output head (and MTP heads) |

 ## Training
@@ -57,20 +64,17 @@ A LLaMA-style decoder-only transformer with several modern techniques wired in:
 |---|---|
 | Dataset | TinyStories (~2.1M stories) |
 | Steps | 3,000 |
-| Batch |
-| Optimizer |
-| Precision | fp16 mixed precision |
-| Hardware | 1× RTX 2060 Super (8 GB), ~
-|
-|
-| Validation loss | 2.65 |
-
-This is a lightly trained demo checkpoint; longer training lowers loss further.

 ## Usage

-This is a **custom architecture**, so you need `model.py` from this repo (
-

 ```python
 import torch
@@ -105,22 +109,14 @@ print(tok.decode(out[0].tolist()))

 ## Limitations

-- Trained only on TinyStories —
-
-- Small and lightly trained: it repeats phrases and occasionally drifts or
-contradicts itself (e.g. swapping character names).
 - 256-token context.

-## Source
-
-Trained with the "train a language model from scratch" project — a from-scratch GPT
-with independently configurable modern techniques (RoPE, GQA, SwiGLU, RMSNorm, MTP,
-mHC, BitNet, TurboQuant) plus Muon/AdamW optimizers and speculative decoding.
-
 ## References

 - [TinyStories](https://arxiv.org/abs/2305.07759)
 - [RoFormer / RoPE](https://arxiv.org/abs/2104.09864)
 - [GQA](https://arxiv.org/abs/2305.13245)
-- [GLU Variants / SwiGLU](https://arxiv.org/abs/2002.05202)
 - [DeepSeek-V3 (MTP)](https://arxiv.org/abs/2412.19437)
@@ -12,7 +12,8 @@ tags:
 - pytorch
 - rope
 - gqa
+- qk-norm
+- muon
 - multi-token-prediction
 pipeline_tag: text-generation
 ---
@@ -21,34 +22,40 @@ pipeline_tag: text-generation

 A small (~19.2M parameter) decoder-only GPT trained **from scratch** on
 [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories). It writes
+simple, coherent children's stories and is a compact, hackable reference for modern
+LLM architecture + optimization techniques — trained end-to-end in a few minutes on a
+single consumer GPU (RTX 2060 Super, 8 GB).
+
+This checkpoint uses the **modded-nanoGPT-style recipe**: trained with the **Muon**
+optimizer and **QK-Norm + squared-ReLU MLP + logit soft-capping**, which improved
+validation loss from 2.65 to **2.45** versus a plain AdamW/SwiGLU baseline at the same
+3,000 steps.

 ## Sample output

+> **Once upon a time,** there was a little girl named Lily. She loved to play outside
+> and explore the world around her. One day, she found a long piece of cardboard on the
+> floor. It was a big, white box with a bow on it. She picked it up and opened it. Inside
+> the box, she found a toy car...

+> **Lily and Tom went to the park and** saw a man with a big hat and a big smile. He was
+> very nice... "Sure, you can play with us," Lily said. They played tag and hide and seek.

 ## Architecture

+A LLaMA-/modded-nanoGPT-style decoder-only transformer:

 | Component | Choice |
 |---|---|
 | Layers / heads / dim | 8 layers, 6 heads, `n_embd` 384 |
 | Context length | 256 tokens |
 | Vocabulary | 16,384 (ByteLevel BPE) |
+| Position encoding | **RoPE** |
+| Attention | **Grouped-Query Attention** (2 KV heads) + **QK-Norm** |
+| MLP | **squared-ReLU** (ungated) |
 | Normalization | **RMSNorm** |
+| Logits | **soft-capped** at 15 (`cap·tanh(logits/cap)`) |
+| Extra heads | **Multi-Token Prediction** (2 auxiliary heads) |
 | Weight tying | token embedding ↔ output head (and MTP heads) |
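The attention row of the table above implies some simple shape arithmetic. The sketch below works it out, assuming the usual grouped-query layout in which each KV head is shared by `n_head / n_kv_head` query heads. The helper name `gqa_shapes` is illustrative and is not part of `model.py`.

```python
def gqa_shapes(n_embd: int, n_head: int, n_kv_head: int) -> dict:
    """Derive per-head sizes for grouped-query attention (illustrative helper)."""
    if n_embd % n_head != 0:
        raise ValueError("n_embd must be divisible by n_head")
    if n_head % n_kv_head != 0:
        raise ValueError("n_head must be divisible by n_kv_head")
    head_dim = n_embd // n_head
    return {
        "head_dim": head_dim,
        "q_proj_out": n_head * head_dim,       # width of the query projection
        "kv_proj_out": n_kv_head * head_dim,   # width of each of the K and V projections
        "queries_per_kv_head": n_head // n_kv_head,
    }


# Values taken from the architecture table: 6 heads, n_embd 384, 2 KV heads.
shapes = gqa_shapes(n_embd=384, n_head=6, n_kv_head=2)
print(shapes)
```

With these numbers each head is 64 wide, the K and V projections are 128 wide against 384 for Q, and three query heads share each KV head.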

 ## Training
@@ -57,20 +64,17 @@ A LLaMA-style decoder-only transformer with several modern techniques wired in:
 |---|---|
 | Dataset | TinyStories (~2.1M stories) |
 | Steps | 3,000 |
+| Batch | 40 × 256 tokens |
+| Optimizer | **Muon** (2D weights) + AdamW (embeddings/norms), peak LR 3e-3, cosine schedule |
+| Precision | fp16 mixed precision, `torch.compile` |
+| Hardware | 1× RTX 2060 Super (8 GB), ~11 minutes (~47K tokens/sec) |
+| Train loss | 2.47 (combined next-token + MTP auxiliary) |
+| **Validation loss** | **2.45** (perplexity 11.5) |
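The figures in the training table can be cross-checked with a little arithmetic. This is a rough sketch that assumes no gradient accumulation (one batch of 40 × 256 tokens per step) and that the validation loss is a mean cross-entropy in nats, so perplexity is `exp(loss)`. The variable names are illustrative.

```python
import math

steps = 3_000
batch_sequences = 40
context_len = 256
tokens_per_sec = 47_000   # approximate throughput quoted in the table
val_loss = 2.45           # validation loss quoted in the table

tokens_per_step = batch_sequences * context_len
total_tokens = steps * tokens_per_step
minutes = total_tokens / tokens_per_sec / 60
perplexity = math.exp(val_loss)

print(f"{tokens_per_step:,} tokens/step, {total_tokens:,} tokens total")
print(f"~{minutes:.1f} minutes at {tokens_per_sec:,} tokens/sec")
print(f"perplexity ~ {perplexity:.2f}")
```

Under these assumptions the run covers about 30.7M tokens in roughly 10.9 minutes, in line with the quoted ~11 minutes. `exp(2.45)` is about 11.59, close to the reported 11.5; the small gap is presumably rounding in the quoted loss.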

 ## Usage

+This is a **custom architecture**, so you need `model.py` from this repo (small,
+dependency-light). Download it next to your script, then:

 ```python
 import torch
@@ -105,22 +109,14 @@ print(tok.decode(out[0].tolist()))

 ## Limitations

+- Trained only on TinyStories — simple children's-story English, not a general assistant.
+- Small and lightly trained: occasional repetition, name swaps, or drift.
 - 256-token context.

 ## References

 - [TinyStories](https://arxiv.org/abs/2305.07759)
 - [RoFormer / RoPE](https://arxiv.org/abs/2104.09864)
 - [GQA](https://arxiv.org/abs/2305.13245)
 - [DeepSeek-V3 (MTP)](https://arxiv.org/abs/2412.19437)
+- [Muon optimizer](https://kellerjordan.github.io/posts/muon/) · [modded-nanoGPT](https://github.com/KellerJordan/modded-nanogpt)
config.json
CHANGED
@@ -6,10 +6,13 @@
   "n_layer": 8,
   "use_rope": true,
   "n_kv_head": 2,
-  "use_swiglu":
+  "use_swiglu": false,
   "use_rmsnorm": true,
   "use_mtp": true,
   "mtp_heads": 2,
   "mtp_weight": 0.1,
-  "tie_mtp_lm_head": true
+  "tie_mtp_lm_head": true,
+  "use_relu2": true,
+  "use_qk_norm": true,
+  "logit_cap": 15.0
 }
model.py
CHANGED
@@ -5,6 +5,13 @@ import math
 from torch.utils.checkpoint import checkpoint


 # --- mHC: Manifold-Constrained Hyper-Connections ---

 def sinkhorn(log_alpha, n_iters=5):
@@ -264,28 +271,26 @@ class MTPHead(nn.Module):
         self.future_idx = future_idx
         n_embd = config["n_embd"]
         vocab_size = config["vocab_size"]
         self.proj = nn.Linear(n_embd, n_embd)
         self.ln = nn.LayerNorm(n_embd)
         self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

     def forward(self, hidden, targets=None):
         if targets is not None:
             shift = self.future_idx
-            if targets.size(1)
-
-
-
-
-
-
-
-
-                    targets_shifted.reshape(-1),
-                    ignore_index=-1,
-                )
-                return logits, loss
-        h = self.ln(self.proj(hidden))
-        return self.lm_head(h), None


 # --- RoPE: Rotary Position Embeddings ---
@@ -339,6 +344,23 @@ class SwiGLU(nn.Module):
         return self.down(F.silu(self.gate(x)) * self.up(x))


 # --- Core model ---

 def make_norm(n_embd, use_rmsnorm=False):
@@ -359,6 +381,7 @@ class CausalSelfAttention(nn.Module):
             raise ValueError(f"n_head ({self.n_head}) must be divisible by n_kv_head ({self.n_kv_head})")
         self.head_dim = self.n_embd // self.n_head
         self.use_rope = config.get("use_rope", False)
         use_bitnet = config.get("use_bitnet", False)
         use_fast_bitnet = config.get("use_fast_bitnet", False)

@@ -367,6 +390,11 @@ class CausalSelfAttention(nn.Module):
         self.v_proj = make_linear(self.n_embd, self.n_kv_head * self.head_dim, use_bitnet=use_bitnet, use_fast_bitnet=use_fast_bitnet)
         self.proj = make_linear(self.n_embd, self.n_embd, use_bitnet=use_bitnet, use_fast_bitnet=use_fast_bitnet)

         if self.use_rope:
             self.rope = RotaryEmbedding(self.head_dim, max_seq_len=config.get("block_size", 512))

@@ -376,6 +404,10 @@ class CausalSelfAttention(nn.Module):
         k = self.k_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
         v = self.v_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)

         if self.use_rope:
             cos, sin = self.rope(pos_offset + T)
             cos, sin = cos[pos_offset:pos_offset + T], sin[pos_offset:pos_offset + T]
@@ -416,7 +448,9 @@ class Block(nn.Module):
         self.ln1 = make_norm(config["n_embd"], use_rmsnorm)
         self.attn = CausalSelfAttention(config)
         self.ln2 = make_norm(config["n_embd"], use_rmsnorm)
-        if config.get("
             self.mlp = SwiGLU(config)
         else:
             self.mlp = MLP(config)
@@ -452,6 +486,7 @@ class GPT(nn.Module):
         self.use_turboquant = config.get("use_turboquant", False)
         self.turboquant_bits = config.get("turboquant_bits", 4)
         self.use_activation_checkpointing = config.get("use_activation_checkpointing", False)
         use_rmsnorm = config.get("use_rmsnorm", False)

         self.tok_emb = nn.Embedding(config["vocab_size"], config["n_embd"])
@@ -513,7 +548,7 @@ class GPT(nn.Module):

     def forward(self, idx, targets=None, return_hidden=False):
         hidden = self._compute_hidden(idx)
-        logits = self.lm_head(hidden)
         loss = None
         if targets is not None:
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
@@ -536,7 +571,7 @@ class GPT(nn.Module):
         for block, cache in zip(self.blocks, kv_caches or [None] * len(self.blocks)):
             x = block(x, kv_cache=cache, pos_offset=pos_offset)
         hidden = self.ln_f(x)
-        logits = self.lm_head(hidden)
         if return_hidden:
             return logits, hidden
         return logits
@@ -916,6 +951,16 @@ FAST_2060_MTP_FBITNET_CONFIG = {
     "use_fast_bitnet": True,
 }

 FAST_2060_MTP_TURBO_CONFIG = {
     **FAST_2060_MTP_CONFIG,
     "use_turboquant": True,
@@ -957,6 +1002,7 @@ CONFIGS = {
     "fast_2060": FAST_2060_CONFIG,
     "fast_2060_mtp": FAST_2060_MTP_CONFIG,
     "fast_2060_mtp_fbitnet": FAST_2060_MTP_FBITNET_CONFIG,
     "fast_2060_mtp_turbo": FAST_2060_MTP_TURBO_CONFIG,
     "tiny_fast": TINY_FAST_CONFIG,
     "low_memory_2060": LOW_MEMORY_2060_CONFIG,
@@ -5,6 +5,13 @@ import math
 from torch.utils.checkpoint import checkpoint


+def soft_cap(logits, cap):
+    """Gemma2/modded-nanoGPT logit soft-capping: cap * tanh(logits / cap). No-op if cap falsy."""
+    if cap:
+        return cap * torch.tanh(logits / cap)
+    return logits
+
+
 # --- mHC: Manifold-Constrained Hyper-Connections ---

 def sinkhorn(log_alpha, n_iters=5):
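To see what the `soft_cap` helper in the hunk above does to individual values, here is a scalar sketch using only the standard library. `soft_cap_scalar` is an illustrative stand-in for the tensor version, not a function from `model.py`, and the cap of 15 is the `logit_cap` value from `config.json`.

```python
import math


def soft_cap_scalar(logit: float, cap: float) -> float:
    """Scalar version of cap * tanh(logit / cap); a falsy cap disables capping."""
    if cap:
        return cap * math.tanh(logit / cap)
    return logit


cap = 15.0
samples = [-100.0, -15.0, -1.0, 0.0, 1.0, 15.0, 100.0]
capped = [soft_cap_scalar(x, cap) for x in samples]

for x, y in zip(samples, capped):
    print(f"{x:8.1f} -> {y:8.4f}")
```

Small logits pass through almost unchanged, while large ones are squashed smoothly so that no output exceeds the cap in magnitude. Passing `cap=0` returns the input untouched, matching the no-op branch in the diff.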
@@ -264,28 +271,26 @@ class MTPHead(nn.Module):
         self.future_idx = future_idx
         n_embd = config["n_embd"]
         vocab_size = config["vocab_size"]
+        self.logit_cap = config.get("logit_cap", 0)
         self.proj = nn.Linear(n_embd, n_embd)
         self.ln = nn.LayerNorm(n_embd)
         self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

     def forward(self, hidden, targets=None):
+        h = self.ln(self.proj(hidden))
+        logits = soft_cap(self.lm_head(h), self.logit_cap)
+        loss = None
         if targets is not None:
             shift = self.future_idx
+            if targets.size(1) > shift:
+                logits_shifted = logits[:, :-shift].contiguous()
+                targets_shifted = targets[:, shift:].contiguous()
+                loss = F.cross_entropy(
+                    logits_shifted.view(-1, logits_shifted.size(-1)),
+                    targets_shifted.view(-1),
+                    ignore_index=-1,
+                )
+        return logits, loss


 # --- RoPE: Rotary Position Embeddings ---
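The slicing in `MTPHead.forward` above pairs position `i` of the logits with entry `i + shift` of the targets. The sketch below replays that alignment on a single sequence with plain lists. `mtp_pairs` is a hypothetical helper, and it assumes `shift` (the head's `future_idx`) is at least 1.

```python
def mtp_pairs(targets: list, shift: int) -> list:
    """Pair each prediction position with the target `shift` steps ahead,
    mirroring logits[:, :-shift] against targets[:, shift:] for one sequence.
    Assumes shift >= 1, as a slice ending at -0 would be empty."""
    if len(targets) <= shift:
        return []  # sequence too short: this head contributes no loss
    positions = list(range(len(targets)))[:-shift]
    shifted_targets = targets[shift:]
    return list(zip(positions, shifted_targets))


targets = ["t1", "t2", "t3", "t4", "t5"]
pairs = mtp_pairs(targets, shift=2)
print(pairs)

# A sequence no longer than the shift yields no training pairs.
too_short = mtp_pairs(["t1", "t2"], shift=2)
print(too_short)
```

The last `shift` positions have no target far enough ahead, so they are dropped from the loss. If `targets` already holds next-token labels (an assumption about the training loop, which this diff does not show), a head with `future_idx = k` is effectively trained on the token `k + 1` steps after its input position.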
@@ -339,6 +344,23 @@ class SwiGLU(nn.Module):
         return self.down(F.silu(self.gate(x)) * self.up(x))


+class ReLU2MLP(nn.Module):
+    """Ungated MLP with squared-ReLU activation (modded-nanoGPT). Simpler and a bit
+    faster than SwiGLU; competitive quality at small scale."""
+
+    def __init__(self, config):
+        super().__init__()
+        n_embd = config["n_embd"]
+        hidden = 4 * n_embd
+        use_bitnet = config.get("use_bitnet", False)
+        use_fast_bitnet = config.get("use_fast_bitnet", False)
+        self.fc = make_linear(n_embd, hidden, bias=False, use_bitnet=use_bitnet, use_fast_bitnet=use_fast_bitnet)
+        self.proj = make_linear(hidden, n_embd, bias=False, use_bitnet=use_bitnet, use_fast_bitnet=use_fast_bitnet)
+
+    def forward(self, x):
+        return self.proj(F.relu(self.fc(x)).square())
+
+
 # --- Core model ---

 def make_norm(n_embd, use_rmsnorm=False):
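The activation inside `ReLU2MLP` above is `relu(x) ** 2`. A scalar sketch of the activation and its derivative follows, together with the weight count the class implies for `n_embd = 384`, given two bias-free linear layers with a 4× hidden width. The function names are illustrative only.

```python
def relu2(x: float) -> float:
    """Squared ReLU: max(x, 0) ** 2."""
    return max(x, 0.0) ** 2


def relu2_grad(x: float) -> float:
    """Derivative of squared ReLU: 2 * max(x, 0)."""
    return 2.0 * max(x, 0.0)


inputs = [-2.0, -0.5, 0.0, 0.5, 2.0]
outputs = [relu2(x) for x in inputs]
grads = [relu2_grad(x) for x in inputs]
print(outputs)
print(grads)

# Weight count for one ReLU2MLP block: fc (n_embd -> 4*n_embd) plus
# proj (4*n_embd -> n_embd), both without bias as in the diff.
n_embd = 384
hidden = 4 * n_embd
mlp_params = n_embd * hidden + hidden * n_embd
print(f"{mlp_params:,} weights per MLP block")
```

Unlike plain ReLU, the derivative is continuous at zero and grows linearly with the input, which makes the activation unbounded and quadratic on the positive side.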
@@ -359,6 +381,7 @@ class CausalSelfAttention(nn.Module):
             raise ValueError(f"n_head ({self.n_head}) must be divisible by n_kv_head ({self.n_kv_head})")
         self.head_dim = self.n_embd // self.n_head
         self.use_rope = config.get("use_rope", False)
+        self.use_qk_norm = config.get("use_qk_norm", False)
         use_bitnet = config.get("use_bitnet", False)
         use_fast_bitnet = config.get("use_fast_bitnet", False)

@@ -367,6 +390,11 @@ class CausalSelfAttention(nn.Module):
         self.v_proj = make_linear(self.n_embd, self.n_kv_head * self.head_dim, use_bitnet=use_bitnet, use_fast_bitnet=use_fast_bitnet)
         self.proj = make_linear(self.n_embd, self.n_embd, use_bitnet=use_bitnet, use_fast_bitnet=use_fast_bitnet)

+        # QK-Norm (modded-nanoGPT): RMSNorm Q and K over the head dim before attention.
+        if self.use_qk_norm:
+            self.q_norm = nn.RMSNorm(self.head_dim)
+            self.k_norm = nn.RMSNorm(self.head_dim)
+
         if self.use_rope:
             self.rope = RotaryEmbedding(self.head_dim, max_seq_len=config.get("block_size", 512))

@@ -376,6 +404,10 @@ class CausalSelfAttention(nn.Module):
         k = self.k_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
         v = self.v_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)

+        if self.use_qk_norm:
+            q = self.q_norm(q)
+            k = self.k_norm(k)
+
         if self.use_rope:
             cos, sin = self.rope(pos_offset + T)
             cos, sin = cos[pos_offset:pos_offset + T], sin[pos_offset:pos_offset + T]
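The QK-Norm step above rescales each query and key vector to unit root-mean-square before the attention scores are computed. A plain-Python sketch of the effect on a single query/key pair follows. It omits the learnable gain that `nn.RMSNorm` carries, and the `eps` value and the example vectors are arbitrary choices for illustration.

```python
import math


def rms_norm(vec, eps=1e-6):
    """Scale vec to unit root-mean-square (no learned gain applied)."""
    rms = math.sqrt(sum(v * v for v in vec) / len(vec) + eps)
    return [v / rms for v in vec]


def dot(a, b):
    """Plain dot product of two equal-length lists."""
    return sum(x * y for x, y in zip(a, b))


q = [3.0, -4.0, 12.0, 0.5]
k = [30.0, -40.0, 120.0, 5.0]   # same direction as q, ten times larger

raw_score = dot(q, k)
normed_score = dot(rms_norm(q), rms_norm(k))
print(f"raw score:    {raw_score:.1f}")
print(f"normed score: {normed_score:.4f}")
```

After normalization the score depends only on the angle between the two vectors and is bounded in magnitude by the head dimension (4 in this toy case, 64 in the model), whatever the scale of the raw projections.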
@@ -416,7 +448,9 @@ class Block(nn.Module):
         self.ln1 = make_norm(config["n_embd"], use_rmsnorm)
         self.attn = CausalSelfAttention(config)
         self.ln2 = make_norm(config["n_embd"], use_rmsnorm)
+        if config.get("use_relu2", False):
+            self.mlp = ReLU2MLP(config)
+        elif config.get("use_swiglu", False):
             self.mlp = SwiGLU(config)
         else:
             self.mlp = MLP(config)
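The branch order in `Block.__init__` above means `use_relu2` takes precedence over `use_swiglu`. The sketch below mirrors that logic with a hypothetical `select_mlp` helper that returns class names as strings, so the precedence can be checked without building a model.

```python
def select_mlp(config: dict) -> str:
    """Return the name of the MLP class the branch in Block.__init__ would pick."""
    if config.get("use_relu2", False):
        return "ReLU2MLP"
    elif config.get("use_swiglu", False):
        return "SwiGLU"
    return "MLP"


choices = {
    # Flags as set in the updated config.json
    "modded": select_mlp({"use_swiglu": False, "use_relu2": True}),
    # Both flags on: ReLU2 wins because it is tested first
    "both_flags": select_mlp({"use_swiglu": True, "use_relu2": True}),
    "swiglu_only": select_mlp({"use_swiglu": True}),
    "neither": select_mlp({}),
}
print(choices)
```

Because the flags default to `False`, configs written before this commit, which carry no `use_relu2` key, keep selecting the same MLP as before.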
@@ -452,6 +486,7 @@ class GPT(nn.Module):
         self.use_turboquant = config.get("use_turboquant", False)
         self.turboquant_bits = config.get("turboquant_bits", 4)
         self.use_activation_checkpointing = config.get("use_activation_checkpointing", False)
+        self.logit_cap = config.get("logit_cap", 0)
         use_rmsnorm = config.get("use_rmsnorm", False)

         self.tok_emb = nn.Embedding(config["vocab_size"], config["n_embd"])
@@ -513,7 +548,7 @@ class GPT(nn.Module):

     def forward(self, idx, targets=None, return_hidden=False):
         hidden = self._compute_hidden(idx)
+        logits = soft_cap(self.lm_head(hidden), self.logit_cap)
         loss = None
         if targets is not None:
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
@@ -536,7 +571,7 @@ class GPT(nn.Module):
         for block, cache in zip(self.blocks, kv_caches or [None] * len(self.blocks)):
             x = block(x, kv_cache=cache, pos_offset=pos_offset)
         hidden = self.ln_f(x)
+        logits = soft_cap(self.lm_head(hidden), self.logit_cap)
         if return_hidden:
             return logits, hidden
         return logits
@@ -916,6 +951,16 @@ FAST_2060_MTP_FBITNET_CONFIG = {
     "use_fast_bitnet": True,
 }

+# modded-nanoGPT-style recipe. QK-Norm helps under any optimizer; ReLU2 and
+# logit_cap only pay off paired with Muon's higher LR. Train with --optimizer muon.
+FAST_2060_MODDED_CONFIG = {
+    **FAST_2060_MTP_CONFIG,
+    "use_swiglu": False,  # superseded by ReLU2 below
+    "use_relu2": True,
+    "use_qk_norm": True,
+    "logit_cap": 15.0,
+}
+
 FAST_2060_MTP_TURBO_CONFIG = {
     **FAST_2060_MTP_CONFIG,
     "use_turboquant": True,
@@ -957,6 +1002,7 @@ CONFIGS = {
     "fast_2060": FAST_2060_CONFIG,
     "fast_2060_mtp": FAST_2060_MTP_CONFIG,
     "fast_2060_mtp_fbitnet": FAST_2060_MTP_FBITNET_CONFIG,
+    "fast_2060_modded": FAST_2060_MODDED_CONFIG,
     "fast_2060_mtp_turbo": FAST_2060_MTP_TURBO_CONFIG,
     "tiny_fast": TINY_FAST_CONFIG,
     "low_memory_2060": LOW_MEMORY_2060_CONFIG,
tinystories-25m.pt
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:69375d07a06ef3b325f3189b23b0caf21a7983fc1e87316b0f5651c579331af3
+size 76800459