# Model Folder: Plain Language Explanation
The `model/` folder builds a **GPT-style decoder-only transformer** from scratch,
piece by piece. Each file is one component. Here's how they stack:
```
          tokens (integers)
                  │
                  ▼
           ┌─────────────┐
           │  Embedding  │        config.py defines the shape of everything
           └──────┬──────┘
                  │
                  ▼  ×N layers
┌────────────────────────────────────────┐
│ TransformerBlock                       │  block.py
│                                        │
│   ┌──────────┐    ┌──────────────┐     │
│   │ RMSNorm  │    │   RMSNorm    │     │  norm.py
│   └────┬─────┘    └──────┬───────┘     │
│        │                 │             │
│   ┌────▼─────┐    ┌──────▼───────┐     │
│   │Attention │    │  SwiGLU MLP  │     │  attention.py / mlp.py
│   │ + RoPE   │    │              │     │  rope.py
│   └────┬─────┘    └──────┬───────┘     │
│        │ (+residual)     │ (+residual) │
└────────┼─────────────────┼─────────────┘
         │                 │
         └────────┬────────┘
                  │
                  ▼
             ┌──────────┐
             │ RMSNorm  │  final norm
             └────┬─────┘
                  │
             ┌────▼─────┐
             │ LM Head  │  Linear → vocab_size logits
             └──────────┘
```
---
## 1. `config.py`: The Blueprint
**What it does:** Stores all the numbers that define the model size.
Nothing is computed here; it's just a settings object.
```python
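from dataclasses import dataclass

@dataclass
class ModelConfig:
    # Type annotations are required: without them these would be plain
    # class attributes rather than dataclass fields.
    vocab_size: int = 32_000      # how many tokens exist
    context_length: int = 1024    # max sequence length
    d_model: int = 1024           # width of every vector throughout the model
    n_heads: int = 16             # how many attention heads
    n_layers: int = 9             # how many transformer blocks stacked
    d_ff: int = 2816              # width of the MLP hidden layer (auto-computed)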
```
**Why these numbers?**
- `d_model` is the "resolution" of the model: bigger = more expressive but more memory
- `n_heads` splits each attention layer into parallel sub-attentions
- `head_dim = d_model / n_heads = 64`: each head sees a 64-dim slice
- `d_ff` for SwiGLU = `round_up_256( 2/3 × 4 × d_model )`, which compensates for having 3 matrices instead of 2
**Presets defined here:**
```
SLLM_100M: d=768,  h=12, l=12 → 109.5M params
SLLM_150M: d=1024, h=16, l=9  → 148.4M params
```
---
## 2. `norm.py`: RMSNorm
**What it does:** Normalizes vectors so they don't explode or vanish during training.
Used before every attention and MLP layer.
**Standard LayerNorm (GPT-2):**
```
1. Compute mean of x
2. Subtract mean (centering)
3. Divide by std
4. Scale by learned weight
5. Add learned bias
```
**RMSNorm (LLaMA / our model):**
```
1. Compute RMS = sqrt( mean(x²) ) ← no mean subtraction!
2. Divide by RMS
3. Scale by learned weight ← no bias!
```
**Why simpler is better:**
- No mean subtraction → cheaper to compute (~40% faster is a rough figure; the real gain depends on hardware and kernel fusion)
- No bias β†’ fewer parameters
- Works just as well in practice
- LLaMA, Mistral, Gemma all use it
```python
# What it computes (mean taken over the last dimension):
output = (x / sqrt(mean(x**2) + 1e-6)) * weight
#         ↑ normalize                    ↑ rescale with learned gain
```
The `weight` starts at all-ones (no change at init) and is learned during training.
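
The formula above can be sketched without any framework. This is a minimal, dependency-free illustration on a single vector, not the code in `norm.py`; the function name `rms_norm` and the `eps` default are assumptions chosen to match the `1e-6` shown above.

```python
import math

def rms_norm(x, weight, eps=1e-6):
    """Divide one vector by its root-mean-square, then apply a per-dimension gain."""
    mean_sq = sum(v * v for v in x) / len(x)      # mean(x^2), no mean subtraction
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)      # eps keeps an all-zero input from dividing by zero
    return [v * inv_rms * w for v, w in zip(x, weight)]

x = [3.0, -4.0, 0.0, 5.0]
weight = [1.0] * len(x)        # all-ones gain, as at initialization
y = rms_norm(x, weight)

# After normalization the mean of the squares is very close to 1,
# while signs and relative magnitudes are preserved.
mean_sq_out = sum(v * v for v in y) / len(y)
print(y, mean_sq_out)
```

A real module would apply the same arithmetic along the last dimension of a tensor and store `weight` as a learnable parameter.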
---
## 3. `rope.py`: Rotary Position Embedding (RoPE)
**The problem it solves:** Transformers have no built-in sense of position.
Without position encoding, `"cat sat on mat"` and `"mat on sat cat"` look identical.
**How older models solved it (GPT-2):**
Added a learned vector per position to each token: `token[i] += position_embedding[i]`
Problem: there is no embedding for positions beyond the training length, so it can't generalize past it.
**What RoPE does instead:**
Instead of adding position info to token vectors, it **rotates** the Query and Key
vectors in attention by an angle that depends on their position.
```
Token at position 3 → rotate Q and K by angle θ₃
Token at position 7 → rotate Q and K by angle θ₇
```
When you compute attention score `Q·K`, the rotation cancels out in a way that
encodes *relative distance* between tokens, not absolute positions.
**Why this is better:**
- No extra parameters (pure math, no learned table)
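- Not tied to a fixed-size position table, so longer sequences are at least representable (quality beyond the training length usually still degrades without extra scaling tricks)
- Used in LLaMA, Mistral, Gemma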
**How the code works:**
```python
# Step 1: precompute a table of cos/sin values for every position
cos, sin = precompute_rope_freqs(head_dim=64, max_seq_len=1024)
# cos/sin shape: (1024, 64)
# Step 2: at forward time, rotate Q and K
q_rotated = q * cos + rotate_half(q) * sin
k_rotated = k * cos + rotate_half(k) * sin
# rotate_half(x): splits x in half, negates the second half, swaps the halves
# [a, b, c, d] → [-c, -d, a, b]
```
V (values) are **not** rotated β€” only Q and K get position encoding.
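
The relative-distance property can be checked numerically. The sketch below is a plain-Python illustration of the `q * cos + rotate_half(q) * sin` recipe on single vectors; it is not the code in `rope.py`. The helper names `rope_angles` and `apply_rope` and the frequency base of `10000.0` are assumptions (the base is a common default, not something the text above states).

```python
import math

def rope_angles(pos, head_dim, base=10000.0):
    """Angles for one position: one frequency per dimension pair, repeated to fill head_dim."""
    half = head_dim // 2
    freqs = [base ** (-2.0 * i / head_dim) for i in range(half)]
    return [pos * f for f in freqs] * 2

def rotate_half(x):
    """[a, b, c, d] -> [-c, -d, a, b]"""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]

def apply_rope(x, pos):
    """x * cos + rotate_half(x) * sin, element by element."""
    angles = rope_angles(pos, len(x))
    rotated = rotate_half(x)
    return [xi * math.cos(a) + ri * math.sin(a) for xi, ri, a in zip(x, rotated, angles)]

def dot(a, b):
    return sum(p * q for p, q in zip(a, b))

q = [0.5, -1.0, 2.0, 0.25]
k = [1.5, 0.3, -0.7, 1.0]

# Same relative distance (4 tokens apart) at two very different absolute positions.
score_early = dot(apply_rope(q, 7), apply_rope(k, 3))
score_late = dot(apply_rope(q, 107), apply_rope(k, 103))
print(score_early, score_late)   # equal up to floating-point error
```

Because each pair of dimensions is rotated rigidly, vector lengths are unchanged; only the angle between Q and K shifts, and that shift depends on the position difference alone.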
---
## 4. `attention.py`: Causal Self-Attention
**What it does:** Lets every token look at all *previous* tokens and decide
which ones are relevant to predict the next token.
**The full flow:**
```
Input x: (Batch, Tokens, d_model)
         e.g. (2, 1024, 1024)
    │
    ▼
QKV projection: one big Linear(d_model → 3×d_model)
    │
    ├─── Q: (2, 1024, 1024)   "what am I looking for?"
    ├─── K: (2, 1024, 1024)   "what do I contain?"
    └─── V: (2, 1024, 1024)   "what do I send if attended to?"
    │
    ▼
Reshape to heads: (2, 16_heads, 1024, 64_head_dim)
    │
    ▼
Apply RoPE to Q and K     ← position encoding happens here
    │
    ▼
Scaled Dot-Product Attention:
    scores  = Q @ K^T / sqrt(64)       # how strongly each token attends to every other token
    mask    = causal mask              # can only look LEFT (past), not right (future)
    weights = softmax(scores + mask)
    out     = weights @ V              # weighted sum of values
    │
    ▼
Reshape back: (2, 1024, 1024)
    │
    ▼
Output projection: Linear(d_model → d_model)
```
**Causal mask:** this is what makes it a *language model* (predicts next token):
```
Position:  0  1  2  3
Token 0:  [✓  ✗  ✗  ✗]   can only see itself
Token 1:  [✓  ✓  ✗  ✗]   can see 0,1
Token 2:  [✓  ✓  ✓  ✗]   can see 0,1,2
Token 3:  [✓  ✓  ✓  ✓]   can see all
```
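
The mask and the scaled dot-product steps can be traced on a toy example. This is a deliberately slow, single-head, plain-Python sketch that materializes every score, which is exactly what a fused kernel avoids; the function names and the tiny inputs are illustrative assumptions, not the code in `attention.py`.

```python
import math

def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]     # exp(-inf) is 0.0, so masked slots get zero weight
    total = sum(exps)
    return [e / total for e in exps]

def causal_attention(queries, keys, values):
    """Single-head attention over lists of vectors; token t only sees tokens 0..t."""
    head_dim = len(queries[0])
    outputs, all_weights = [], []
    for t, q in enumerate(queries):
        scores = []
        for s, k in enumerate(keys):
            if s <= t:
                scores.append(sum(a * b for a, b in zip(q, k)) / math.sqrt(head_dim))
            else:
                scores.append(float("-inf"))  # future position: masked out
        weights = softmax(scores)
        all_weights.append(weights)
        outputs.append([sum(w * val[d] for w, val in zip(weights, values))
                        for d in range(len(values[0]))])
    return outputs, all_weights

queries = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
keys    = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values  = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

outputs, weights = causal_attention(queries, keys, values)
for row in weights:
    print([round(w, 3) for w in row])   # lower-triangular pattern, each row sums to 1
```

Token 0 can only attend to itself, so its output is exactly its own value vector; later tokens get a weighted mix of everything to their left.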
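**Flash Attention:** We use `F.scaled_dot_product_attention(..., is_causal=True)`,
PyTorch 2.0's fused attention op. On supported GPUs and dtypes it dispatches to a
FlashAttention kernel that never materializes the full O(T²) attention matrix in
memory, which is much faster and uses far less VRAM; otherwise it falls back to a
memory-efficient or plain math implementation.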
---
## 5. `mlp.py`: SwiGLU Feed-Forward Network
**What it does:** After attention (which mixes *between* tokens), the MLP
transforms each token *independently*. It's where most of the model's
"knowledge" is thought to be stored.
**Standard MLP (GPT-2):**
```python
out = W2 @ GELU(W1 @ x) # 2 matrices
```
**SwiGLU (LLaMA / our model):**
```python
gate = W_gate @ x # linear
up = W_up @ x # linear
hidden = SiLU(gate) * up # element-wise gate ← the key difference
out = W_down @ hidden # 3 matrices total
```
**What is SiLU?**
```
SiLU(x) = x × sigmoid(x)
```
It's a smooth version of ReLU: instead of clamping negative inputs to zero, it dips slightly below zero (minimum ≈ -0.28 near x ≈ -1.28) and then approaches zero from below.
**Why gating matters:**
- `SiLU(gate)` acts as a learned on/off switch for each hidden dimension
- The model learns to activate only the neurons relevant to each input
- Empirically outperforms GELU at the same parameter count
- Used in LLaMA, PaLM, Mistral
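
The gating can be followed on a tiny example. The sketch below is a plain-Python illustration of the three-matrix recipe on one token vector; the weight values are made up, the shapes (`d_model=2`, `d_ff=3`) are far smaller than the real ones, and the helper names are assumptions rather than the contents of `mlp.py`.

```python
import math

def silu(x):
    """x * sigmoid(x), written as x / (1 + exp(-x)); fine for the small inputs used here."""
    return x / (1.0 + math.exp(-x))

def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def swiglu_mlp(x, W_gate, W_up, W_down):
    gate = matvec(W_gate, x)                           # linear
    up = matvec(W_up, x)                               # linear
    hidden = [silu(g) * u for g, u in zip(gate, up)]   # element-wise gate
    return matvec(W_down, hidden)                      # project back to d_model

# Made-up weights: W_gate and W_up are (d_ff x d_model), W_down is (d_model x d_ff).
W_gate = [[0.5, -0.2], [0.1, 0.4], [-3.0, -3.0]]
W_up   = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
W_down = [[0.2, -0.1, 0.3], [0.0, 0.5, -0.4]]

x = [1.0, 2.0]
gate = matvec(W_gate, x)
hidden = [silu(g) * u for g, u in zip(gate, matvec(W_up, x))]
out = swiglu_mlp(x, W_gate, W_up, W_down)
print(gate, hidden, out)
```

The third hidden unit has a strongly negative gate pre-activation, so `silu` squashes it toward zero and that unit contributes almost nothing to the output, which is the "switch" behaviour described above.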
**The d_ff formula:**
```
d_ff = round_up_256( int(2/3 × 4 × d_model) )
For 150M: round_up_256( int(2/3 × 4 × 1024) ) = round_up_256(2730) = 2816
```
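The `2/3` factor compensates for having 3 matrices instead of 2: it keeps the
total parameter count roughly equal to a standard 4× FFN (3 × d_model × 8/3·d_model = 8·d_model², the same as 2 × d_model × 4·d_model, before rounding up).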
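
The formula is easy to check for both presets. This sketch uses integer arithmetic (`8 * d_model // 3`) in place of `int(2/3 × 4 × d_model)` to avoid floating-point rounding; that substitution and the function names are assumptions, not necessarily how `config.py` computes it.

```python
def round_up(value, multiple=256):
    """Smallest multiple of `multiple` that is >= value."""
    return ((value + multiple - 1) // multiple) * multiple

def swiglu_d_ff(d_model, multiple=256):
    raw = (8 * d_model) // 3          # integer form of int(2/3 * 4 * d_model)
    return round_up(raw, multiple)

for d_model in (768, 1024):
    d_ff = swiglu_d_ff(d_model)
    swiglu_params = 3 * d_model * d_ff          # gate, up, down
    standard_params = 2 * d_model * 4 * d_model # classic 4x FFN with two matrices
    print(d_model, d_ff, swiglu_params, standard_params)
```

For `d_model=1024` the SwiGLU MLP ends up about 3% larger than a standard 4× FFN because of the round-up to a multiple of 256; for `d_model=768` the two counts match exactly.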
---
## 6. `block.py`: TransformerBlock
**What it does:** Wraps attention + MLP into one reusable block.
The model is just N copies of this block stacked.
```python
def forward(x):
    # Attention sub-layer (uses its own RMSNorm)
    x = x + attention(rmsnorm(x))   # pre-norm + residual
    # MLP sub-layer (uses a second, separate RMSNorm)
    x = x + mlp(rmsnorm(x))         # pre-norm + residual
    return x
```
**Two key ideas:**
**1. Pre-norm (normalize BEFORE the sublayer):**
```
Pre-norm  (LLaMA, GPT-2):          x → norm → attention → + original x
Post-norm (original Transformer):  x → attention → + original x → norm
```
Pre-norm is more stable at large scale: the residual path is never normalized, so gradients flow more cleanly.
**2. Residual connections (`x + sublayer(x)`):**
The output of each sublayer is *added* back to the input, not replacing it.
This means:
- Gradients can skip directly to earlier layers during backprop
- The model learns *corrections* to the input, not transformations from scratch
- Allows stacking many layers without vanishing gradients
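
The "corrections, not replacements" idea can be seen with stand-in sublayers. In this sketch the block operates on one plain-Python vector and takes its sublayers as arguments; `zero_sublayer`, `small_update`, and `no_norm` are hypothetical placeholders for the real attention, MLP, and RMSNorm modules.

```python
def transformer_block(x, attention, mlp, norm_1, norm_2):
    """Pre-norm block: each sublayer sees a normalized copy, the residual sum keeps the original."""
    x = [xi + ai for xi, ai in zip(x, attention(norm_1(x)))]
    x = [xi + mi for xi, mi in zip(x, mlp(norm_2(x)))]
    return x

def no_norm(v):
    return list(v)                    # placeholder: leaves the vector unchanged

def zero_sublayer(v):
    return [0.0] * len(v)             # a sublayer that has learned nothing yet

def small_update(v):
    return [0.1 * vi for vi in v]     # a sublayer that proposes a small correction

x = [1.0, -2.0, 0.5]
unchanged = transformer_block(x, zero_sublayer, zero_sublayer, no_norm, no_norm)
nudged = transformer_block(x, small_update, small_update, no_norm, no_norm)
print(unchanged, nudged)
```

When the sublayers output zeros the block is an exact identity, so stacking more blocks can never erase the input by default; each block only has to learn what to add.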
---
## 7. `model.py`: SLLM (The Full Model)
**What it does:** Assembles everything into the complete language model.
```
tokens: (B, T)                 ← integer IDs like [423, 1829, 55, ...]
    │
    ▼
token_emb: Embedding(32000 → 1024)
    │        converts each integer to a 1024-dim vector
    ▼
blocks[0]: TransformerBlock  ─┐
blocks[1]: TransformerBlock   │  9 blocks for 150M
...                           │
blocks[8]: TransformerBlock  ─┘
    │
    ▼
norm: RMSNorm(1024)            ← final stabilization
    │
    ▼
lm_head: Linear(1024 → 32000)
    │        produces a score for each possible next token
    ▼
logits: (B, T, 32000)          ← unnormalized scores
```
**Weight tying:**
The `token_emb` matrix and `lm_head` matrix **share the same weights**.
```python
self.lm_head.weight = self.token_emb.weight
```
- Same matrix used for: embedding lookup (input) AND output projection
- Saves ~32.8M parameters (32000 × 1024)
- Works because: if token X has a similar embedding to the current hidden state,
it should also score highly as the next token prediction
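
The intuition in the last bullet can be made concrete with a toy vocabulary. In this plain-Python sketch a single matrix `E` plays both roles; the 4-token, 3-dimensional values are made up for illustration, and `embed` and `lm_head` are hypothetical stand-ins for the real modules.

```python
# One shared matrix E (vocab_size x d_model), used in both directions.
E = [
    [1.0, 0.0, 0.0],   # token 0
    [0.0, 1.0, 0.0],   # token 1
    [0.0, 0.0, 1.0],   # token 2
    [0.6, 0.8, 0.0],   # token 3
]

def embed(token_id):
    """Input side: look up the row for this token."""
    return E[token_id]

def lm_head(hidden):
    """Output side: hidden @ E^T, one dot product per vocabulary row."""
    return [sum(h * e for h, e in zip(hidden, row)) for row in E]

hidden = [0.1, 0.1, 0.9]     # a hidden state pointing mostly along token 2's embedding
logits = lm_head(hidden)
best = max(range(len(logits)), key=logits.__getitem__)
print(logits, best)
```

Because the same rows are used on both sides, a hidden state that resembles a token's embedding automatically gets a high logit for that token.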
**Loss computation:**
```python
# Cross-entropy: at each position, predict the NEXT token
# Input: [The, cat, sat, on] → predicts [cat, sat, on, mat]
# targets = input shifted by 1
loss = cross_entropy(logits.view(-1, 32000), targets.view(-1))
```
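
The shift-by-one and the loss itself can be reproduced by hand. This is a plain-Python sketch of mean cross-entropy over positions; the token IDs and the 5-token vocabulary are made up, and the function is an illustration rather than the PyTorch call used in `model.py`.

```python
import math

def cross_entropy(logits, targets):
    """Mean of -log softmax(logits[t])[targets[t]] over all positions t."""
    total = 0.0
    for row, target in zip(logits, targets):
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[target]
    return total / len(targets)

tokens = [3, 1, 4, 1, 2]                      # hypothetical token IDs
inputs, targets = tokens[:-1], tokens[1:]     # targets = inputs shifted by one

vocab_size = 5
# A model with no preference: identical logits for every token.
uniform_logits = [[0.0] * vocab_size for _ in inputs]
# A model that puts a large logit on the correct next token.
confident_logits = [[10.0 if i == t else 0.0 for i in range(vocab_size)] for t in targets]

loss_uniform = cross_entropy(uniform_logits, targets)
loss_confident = cross_entropy(confident_logits, targets)
print(loss_uniform, loss_confident)
```

A model that guesses uniformly scores `ln(vocab_size)`, which for the real 32,000-token vocabulary is about 10.4; that is a useful sanity check for the loss at the very start of training.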
**Gradient checkpointing** (`enable_gradient_checkpointing()`):
Normally PyTorch saves all intermediate activations during forward pass to use
in backprop. For 9 layers with batch_size=2 and seq_len=1024, that's ~1.5GB.
With gradient checkpointing:
- Activations inside each checkpointed block are **NOT saved** during the forward pass (only the block's input is kept)
- During backward pass, they are **recomputed on-the-fly**
- Result: ~40% less VRAM, ~30% slower training
- Essential for fitting 150M on a 4GB GPU
**Weight initialization:**
```python
# All Linear and Embedding weights: Normal(mean=0, std=0.02)
# Residual projections (o_proj, mlp.down): scaled down by 1/sqrt(2 × n_layers)
```
The residual scaling prevents the residual stream from growing too large
at initialization when many layers add to it.
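
The factor `2 × n_layers` counts the residual additions: each block adds to the stream twice, once after attention and once after the MLP. A small sketch of the resulting standard deviation, assuming the scaling is applied directly to the `0.02` base std (the text above does not say whether the code scales the std or rescales the weights afterwards):

```python
import math

def residual_proj_std(n_layers, base_std=0.02):
    """Std for o_proj / mlp.down so that 2 * n_layers summed contributions stay at base scale."""
    return base_std / math.sqrt(2 * n_layers)

std_150m = residual_proj_std(9)     # 9 blocks  -> 18 residual additions
std_100m = residual_proj_std(12)    # 12 blocks -> 24 residual additions
print(std_150m, std_100m)
```

The variance of a sum of N independent contributions grows like N, so dividing the std by sqrt(N) keeps the residual stream's overall scale roughly constant regardless of depth.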
---
## How it all fits together: One forward pass
```
"The cat sat" → tokenizer → [423, 1829, 55]
token_emb: [423]  → [0.1, -0.3, ...]  (1024 floats)
           [1829] → [0.8, 0.2, ...]   (1024 floats)
           [55]   → [-0.1, 0.4, ...]  (1024 floats)
Block 0:
  norm → Q,K,V projections → RoPE rotation → Flash Attention → output proj → + residual
  norm → gate,up projections → SiLU(gate)*up → down proj → + residual
Block 1..8: same
Final norm → LM head → 32000 scores per position
softmax → probabilities → sample next token
```
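
The last line of the walkthrough (softmax, then sample) can be sketched in a few lines. The logits below are a made-up 4-token example rather than real model output, and `sample_next_token` is a hypothetical helper; a seeded `random.Random` keeps the run repeatable.

```python
import math
import random

def softmax(logits):
    m = max(logits)                                 # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sample_next_token(logits, rng):
    """Draw one token ID with probability proportional to softmax(logits)."""
    probs = softmax(logits)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

rng = random.Random(0)
logits = [2.0, 1.0, 0.1, -1.0]      # toy scores for a 4-token vocabulary
probs = softmax(logits)
token = sample_next_token(logits, rng)
greedy = max(range(len(logits)), key=logits.__getitem__)   # always pick the top score
print(probs, token, greedy)
```

Sampling gives varied text because lower-probability tokens are sometimes chosen; taking the argmax instead (greedy decoding) is deterministic.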
**Total parameters (150M):**
```
Embedding:   32000 × 1024                             = 32.8M
Per block:   attn(4.19M) + mlp(8.65M) + norms(~0M)    = 12.85M
9 blocks:    9 × 12.85M                               = 115.6M
Final norm:  1024                                     = ~0M
LM head:     TIED to embedding                        = 0M (reuses same weights)
─────────────────────────────────────────
TOTAL:                                                  148.4M params
```
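
The table can be reproduced for both presets with a short function. This sketch assumes no bias terms in any Linear layer and a tied LM head, which is what makes the totals match the figures quoted in this document; `count_params` is a hypothetical helper, not part of the repository.

```python
def count_params(vocab_size, d_model, n_layers, d_ff):
    """Parameter count for the architecture described above (no biases, tied LM head)."""
    embedding = vocab_size * d_model
    attn = 4 * d_model * d_model      # fused QKV (3 * d^2) + output projection (d^2)
    mlp = 3 * d_model * d_ff          # gate, up, down
    norms = 2 * d_model               # two RMSNorm weight vectors per block
    per_block = attn + mlp + norms
    final_norm = d_model
    return embedding + n_layers * per_block + final_norm

sllm_150m = count_params(vocab_size=32_000, d_model=1024, n_layers=9, d_ff=2816)
sllm_100m = count_params(vocab_size=32_000, d_model=768, n_layers=12, d_ff=2048)
print(f"{sllm_150m / 1e6:.1f}M", f"{sllm_100m / 1e6:.1f}M")
```

Note that `n_heads` does not appear: splitting `d_model` into heads changes how the attention matrices are used, not how many weights they contain.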