# Model Folder: Plain Language Explanation

The `model/` folder builds a **GPT-style decoder-only transformer** from scratch, piece by piece. Each file is one component. Here's how they stack:

```
tokens (integers)
       │
       ▼
┌─────────────┐
│  Embedding  │   config.py defines the shape of everything
└──────┬──────┘
       │
       ▼  ×N layers
┌──────────────────────────────────────┐
│           TransformerBlock           │  block.py
│                                      │
│   ┌──────────┐    ┌──────────────┐   │
│   │ RMSNorm  │    │   RMSNorm    │   │  norm.py
│   └────┬─────┘    └──────┬───────┘   │
│        │                 │           │
│   ┌────▼─────┐    ┌──────▼───────┐   │
│   │Attention │    │  SwiGLU MLP  │   │  attention.py / mlp.py
│   │  + RoPE  │    │              │   │  rope.py
│   └────┬─────┘    └──────┬───────┘   │
│        │ (+residual)     │(+residual)│
└────────┼─────────────────┼───────────┘
         │                 │
         └────────┬────────┘
                  │
                  ▼
             ┌──────────┐
             │ RMSNorm  │  final norm
             └────┬─────┘
                  │
             ┌────▼─────┐
             │ LM Head  │  Linear → vocab_size logits
             └──────────┘
```

The diagram draws the two sub-layers side by side to save space. Inside a block they run in sequence: attention first, then the MLP.

---

## 1. `config.py`: The Blueprint

**What it does:** Stores all the numbers that define the model size. Nothing is computed here; it's just a settings object.

```python
from dataclasses import dataclass

@dataclass
class ModelConfig:
    vocab_size: int = 32_000      # how many tokens exist
    context_length: int = 1024    # max sequence length
    d_model: int = 1024           # width of every vector throughout the model
    n_heads: int = 16             # how many attention heads
    n_layers: int = 9             # how many transformer blocks stacked
    d_ff: int = 2816              # width of the MLP hidden layer (auto-computed)
```

**Why these numbers?**
- `d_model` is the "resolution" of the model: bigger means more expressive but more memory
- `n_heads` splits each attention layer into parallel sub-attentions
- `head_dim = d_model / n_heads = 64`: each head works on its own 64-dim slice
- `d_ff` for SwiGLU = `round_up_256( 2/3 × 4 × d_model )`, which compensates for having 3 matrices instead of 2

**Presets defined here:**
```
SLLM_100M:  d=768,  h=12, l=12  → 109.5M params
SLLM_150M:  d=1024, h=16, l=9   → 148.4M params
```

The values shown in `ModelConfig` above match the `SLLM_150M` preset.

---

## 2. `norm.py`: RMSNorm

**What it does:** Normalizes vectors so they don't explode or vanish during training.
Used before every attention and MLP layer, and once more before the LM head.

**Standard LayerNorm (GPT-2):**
```
1. Compute mean of x
2. Subtract mean (centering)
3. Divide by std
4. Scale by learned weight
5. Add learned bias
```

**RMSNorm (LLaMA / our model):**
```
1. Compute RMS = sqrt( mean(x²) )   ← no mean subtraction!
2. Divide by RMS
3. Scale by learned weight          ← no bias!
```

**Why simpler is better:**
- No mean subtraction → less work per call (around 40% faster for the norm itself; the exact gain depends on hardware)
- No bias → fewer parameters
- Works about as well in practice
- LLaMA, Mistral, Gemma all use it

```python
import torch

# What it computes (sketch; the function name is illustrative):
def rms_norm(x, weight, eps=1e-6):
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)  # eps sits inside the sqrt
    return (x / rms) * weight  # normalize, then rescale with the learned gain
```

The `weight` starts at all-ones (no change at init) and is learned during training.

---

## 3. `rope.py`: Rotary Position Embedding (RoPE)

**The problem it solves:** Attention has no built-in sense of order. Without position information, `"cat sat on mat"` and `"mat on sat cat"` are just the same set of tokens to it.

**How older models solved it (GPT-2):** Added a learned vector for each position to the token embedding: `token[i] += position_embedding[i]`

Problem: the table has no entries past the training length, so the model can't generalize to longer sequences.

**What RoPE does instead:** Instead of adding position info to token vectors, it **rotates** the Query and Key vectors in attention by an angle that depends on their position.

```
Token at position 3  →  rotate Q and K by angle θ₃
Token at position 7  →  rotate Q and K by angle θ₇
```

When you compute the attention score `Q·K`, the two rotations combine so that only their difference survives. The score therefore encodes the *relative distance* between tokens (here 7 - 3 = 4), not their absolute positions.

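
The cancellation can be checked numerically. The sketch below is a minimal illustration, not the code in `rope.py`: it rotates a single 2-D pair and uses a made-up angle step `theta`, whereas real RoPE applies a different frequency to each pair of dimensions. The helper names `rotate` and `dot` are hypothetical.

```python
import math

def rotate(vec, pos, theta=0.5):
    """Rotate a 2-D vector by an angle proportional to its position."""
    angle = pos * theta
    c, s = math.cos(angle), math.sin(angle)
    x, y = vec
    return (x * c - y * s, x * s + y * c)

def dot(a, b):
    """Plain dot product of two 2-D vectors."""
    return a[0] * b[0] + a[1] * b[1]

q = (1.0, 2.0)   # toy query vector
k = (0.5, -1.0)  # toy key vector

# Same relative distance (4) at two very different absolute positions
score_a = dot(rotate(q, 7), rotate(k, 3))
score_b = dot(rotate(q, 107), rotate(k, 103))

# A different relative distance (1) changes the score
score_c = dot(rotate(q, 7), rotate(k, 6))

print(f"distance 4 near the start: {score_a:.6f}")
print(f"distance 4 far along:      {score_b:.6f}")
print(f"distance 1:                {score_c:.6f}")
```

The first two printed scores agree up to floating-point error, while the third differs, which is the relative-position property described above.
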
**Why this is better:**
- No extra parameters (pure math, no learned table)
- Not tied to a fixed-size table, so it can be evaluated at any position; quality past the training length still tends to degrade unless the frequencies are rescaled
- Used in LLaMA, Mistral, Gemma

**How the code works:**
```python
# Step 1: precompute a table of cos/sin values for every position
cos, sin = precompute_rope_freqs(head_dim=64, max_seq_len=1024)
# cos/sin shape: (1024, 64)

# Step 2: at forward time, rotate Q and K
q_rotated = q * cos + rotate_half(q) * sin
k_rotated = k * cos + rotate_half(k) * sin

# rotate_half(x): splits x in half, negates second half, swaps
# [a, b, c, d] → [-c, -d, a, b]
```

V (values) are **not** rotated; only Q and K get position encoding.

---

## 4. `attention.py`: Causal Self-Attention

**What it does:** Lets every token look at itself and all *previous* tokens and decide which ones are relevant for predicting the next token.

**The full flow:**
```
Input x: (Batch, Tokens, d_model)   e.g. (2, 1024, 1024)
    │
    ▼
QKV projection: one big Linear(d_model → 3×d_model)
    │
    ├─── Q: (2, 1024, 1024)   "what am I looking for?"
    ├─── K: (2, 1024, 1024)   "what do I contain?"
    └─── V: (2, 1024, 1024)   "what do I send if attended to?"
    │
    ▼
Reshape to heads: (2, 16_heads, 1024, 64_head_dim)
    │
    ▼
Apply RoPE to Q and K     ← position encoding happens here
    │
    ▼
Scaled Dot-Product Attention:
    scores  = Q @ K^T / sqrt(64)      # how strongly each token attends to each other token
    mask    = causal mask             # can only look LEFT (past), not right (future)
    weights = softmax(scores + mask)
    out     = weights @ V             # weighted sum of values
    │
    ▼
Reshape back: (2, 1024, 1024)
    │
    ▼
Output projection: Linear(d_model → d_model)
```

**Causal mask:** this is what makes it an autoregressive *language model*. Each position predicts the next token without being able to see it:
```
Position:   0    1    2    3
Token 0:  [ ✓    ✗    ✗    ✗ ]   can only see itself
Token 1:  [ ✓    ✓    ✗    ✗ ]   can see 0,1
Token 2:  [ ✓    ✓    ✓    ✗ ]   can see 0,1,2
Token 3:  [ ✓    ✓    ✓    ✓ ]   can see all
```

**Flash Attention:** We use `F.scaled_dot_product_attention(..., is_causal=True)`, PyTorch 2.0's fused attention function. When the GPU and dtype allow it, the call dispatches to a FlashAttention or memory-efficient kernel that never materializes the full O(T²) attention matrix in memory, which is much faster and uses far less VRAM. Otherwise it falls back to the standard math implementation.

---

## 5. `mlp.py`: SwiGLU Feed-Forward Network

**What it does:** After attention (which mixes *between* tokens), the MLP transforms each token *independently*. Most of each block's parameters live here, and it is often described as where the model's "knowledge" is stored.

**Standard MLP (GPT-2):**
```python
out = W2 @ GELU(W1 @ x)         # 2 matrices
```

**SwiGLU (LLaMA / our model):**
```python
gate   = W_gate @ x             # linear
up     = W_up   @ x             # linear
hidden = SiLU(gate) * up        # element-wise gate ← the key difference
out    = W_down @ hidden        # 3 matrices total
```

**What is SiLU?**
```
SiLU(x) = x × sigmoid(x)
```
It's a smooth version of ReLU. Instead of clipping negative inputs to exactly zero, it lets a small negative value through (the minimum is about -0.28, near x ≈ -1.28) and approaches zero again for very negative inputs.

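
To make the shape of SiLU concrete, here is a small pure-Python sketch (no PyTorch). The helper names `silu` and `relu` are illustrative and not taken from `mlp.py`. The last lines apply the gate element-wise, mirroring the `SiLU(gate) * up` step in the SwiGLU snippet above.

```python
import math

def silu(x: float) -> float:
    """SiLU(x) = x * sigmoid(x), written as x / (1 + e^-x)."""
    return x / (1.0 + math.exp(-x))

def relu(x: float) -> float:
    """ReLU for comparison: clips every negative input to zero."""
    return max(0.0, x)

# Compare the two activations on a few sample inputs
for x in (-4.0, -1.0, 0.0, 1.0, 4.0):
    print(f"x={x:+.1f}  relu={relu(x):.4f}  silu={silu(x):+.4f}")

# Element-wise gating: hidden[i] = SiLU(gate[i]) * up[i]
gate = [-4.0, 0.0, 4.0]
up = [1.0, 1.0, 1.0]
hidden = [silu(g) * u for g, u in zip(gate, up)]
print(hidden)
```

A strongly negative gate value nearly closes its channel, a gate of zero closes it completely, and a large positive gate passes `up` through almost unchanged.
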
**Why gating matters:**
- `SiLU(gate)` acts as a learned soft on/off switch for each hidden dimension
- The model learns to activate only the neurons relevant to each input
- In published comparisons it tends to outperform GELU at the same parameter count
- Used in LLaMA, PaLM, Mistral

**The d_ff formula:**
```
d_ff = round_up_256( int(2/3 × 4 × d_model) )

For 150M: round_up_256( int(2/3 × 4 × 1024) )
        = round_up_256(2730)
        = 2816
```
The `2/3` factor compensates for having 3 matrices instead of 2. It keeps the total parameter count roughly equal to a standard 4× FFN (3 × d × 8d/3 = 2 × d × 4d = 8d²); rounding up to a multiple of 256 adds a little on top.

---

## 6. `block.py`: TransformerBlock

**What it does:** Wraps attention + MLP into one reusable block. The model is just N copies of this block stacked.

```python
def forward(x):
    # Attention sub-layer (uses its own RMSNorm weights)
    x = x + attention( rmsnorm(x) )   # pre-norm + residual
    # MLP sub-layer (uses a second, separate RMSNorm)
    x = x + mlp( rmsnorm(x) )         # pre-norm + residual
    return x
```

**Two key ideas:**

**1. Pre-norm (normalize BEFORE the sublayer):**
```
Pre-norm  (GPT-2, LLaMA):          x → norm → attention → + original x
Post-norm (original Transformer):  x → attention → + original x → norm
```
Pre-norm is more stable at large scale because the residual path is never normalized, so gradients flow through it more cleanly.

**2. Residual connections (`x + sublayer(x)`):**
The output of each sublayer is *added* back to the input instead of replacing it. This means:
- Gradients can skip directly to earlier layers during backprop
- The model learns *corrections* to the input, not transformations from scratch
- Many layers can be stacked without vanishing gradients

---

## 7. `model.py`: SLLM (The Full Model)

**What it does:** Assembles everything into the complete language model.

```
tokens: (B, T)   ← integer IDs like [423, 1829, 55, ...]
    │
    ▼
token_emb: Embedding(32000 → 1024)
    │     converts each integer to a 1024-dim vector
    ▼
blocks[0]: TransformerBlock ─┐
blocks[1]: TransformerBlock  │  9 blocks for 150M
...                          │
blocks[8]: TransformerBlock ─┘
    │
    ▼
norm: RMSNorm(1024)   ← final stabilization
    │
    ▼
lm_head: Linear(1024 → 32000)
    │     produces a score for each possible next token
    ▼
logits: (B, T, 32000)   ← unnormalized scores
```

**Weight tying:**
The `token_emb` matrix and `lm_head` matrix **share the same weights**.
```python
self.lm_head.weight = self.token_emb.weight
```
- Same matrix used for: embedding lookup (input) AND output projection
- Saves ~32.8M parameters (32000 × 1024)
- Works because: if token X's embedding is similar to the current hidden state, X should also score highly as the next-token prediction

**Loss computation:**
```python
# Cross-entropy: at each position, predict the NEXT token
# Input: [The, cat, sat, on] → predicts [cat, sat, on, mat]
# targets = input shifted by 1
loss = cross_entropy(logits.view(-1, 32000), targets.view(-1))
```

**Gradient checkpointing** (`enable_gradient_checkpointing()`):
Normally PyTorch saves all intermediate activations during the forward pass to use in backprop. For 9 layers with batch_size=2 and seq_len=1024, that's ~1.5GB.

With gradient checkpointing:
- Only each block's input is kept; the activations inside the block are **NOT saved** during the forward pass
- During the backward pass, they are **recomputed on the fly**
- Result: ~40% less VRAM, ~30% slower training
- Essential for fitting 150M on a 4GB GPU

**Weight initialization:**
```python
# All Linear and Embedding weights: Normal(mean=0, std=0.02)
# Residual projections (o_proj, mlp.down): scaled down by 1/sqrt(2 × n_layers)
```
The residual scaling prevents the residual stream from growing too large at initialization. The factor is `2 × n_layers` because every block adds to the stream twice, once after attention and once after the MLP.

---

## How it all fits together: One forward pass

```
"The cat sat" → tokenizer → [423, 1829, 55]

token_emb:  [423]  → [ 0.1, -0.3, ...]  (1024 floats)
            [1829] → [ 0.8,  0.2, ...]  (1024 floats)
            [55]   → [-0.1,  0.4, ...]  (1024 floats)

Block 0:
  norm → Q,K,V projections → RoPE rotation → Flash Attention → output proj → + residual
  norm → gate,up projections → SiLU(gate)*up → down proj → + residual

Block 1..8: same

Final norm → LM head → 32000 scores per position
softmax (last position) → probabilities → sample next token
```

**Total parameters (150M):**
```
Embedding:   32000 × 1024                         = 32.8M
Per block:   attn(4.2M) + mlp(8.6M) + norms(~0M)  = 12.85M
9 blocks:    9 × 12.85M                           = 115.6M
Final norm:  1024                                 = ~0M
LM head:     TIED to embedding                    = 0M  (reuses same weights)
─────────────────────────────────────────
TOTAL:                                              148.4M params
```
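
The table can be reproduced with a few lines of arithmetic. This is a sketch under assumptions inferred from the numbers above rather than read from the code: no bias terms in any Linear layer, two RMSNorm weight vectors per block, and a tied LM head. The function names `round_up` and `count_params` are hypothetical.

```python
def round_up(n: int, multiple: int = 256) -> int:
    """Round n up to the next multiple (the round_up_256 step)."""
    return ((n + multiple - 1) // multiple) * multiple

def count_params(vocab_size: int, d_model: int, n_layers: int) -> int:
    # int(2/3 * 4 * d_model) done in integer arithmetic to avoid float rounding
    d_ff = round_up((8 * d_model) // 3)
    attn = 4 * d_model * d_model       # Q, K, V and output projections (assumed bias-free)
    mlp = 3 * d_model * d_ff           # gate, up and down matrices
    norms = 2 * d_model                # one RMSNorm weight per sub-layer
    per_block = attn + mlp + norms
    embedding = vocab_size * d_model   # also serves as the LM head (weight tying)
    final_norm = d_model
    return embedding + n_layers * per_block + final_norm

total_150m = count_params(vocab_size=32_000, d_model=1024, n_layers=9)
total_100m = count_params(vocab_size=32_000, d_model=768, n_layers=12)

print(f"SLLM_150M: {total_150m / 1e6:.1f}M")
print(f"SLLM_100M: {total_100m / 1e6:.1f}M")
```

Under these assumptions the totals come out at 148.4M and 109.5M, matching the two presets listed in `config.py`.
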