# Model Folder: Plain Language Explanation

The `model/` folder builds a **GPT-style decoder-only transformer** from scratch,
piece by piece. Each file is one component. Here's how they stack:

```
tokens (integers)
       │
       ▼
┌──────────────┐
│  Embedding   │   config.py defines the shape of everything
└──────┬───────┘
       │
       ▼  × N layers
┌────────────────────────────┐
│ TransformerBlock           │   block.py
│                            │
│   RMSNorm                  │   norm.py
│      ▼                     │
│   Attention + RoPE         │   attention.py, rope.py
│      ▼  (+ residual)       │
│   RMSNorm                  │   norm.py
│      ▼                     │
│   SwiGLU MLP               │   mlp.py
│      ▼  (+ residual)       │
└──────┬─────────────────────┘
       │
       ▼
┌──────────────┐
│   RMSNorm    │   final norm
└──────┬───────┘
       │
       ▼
┌──────────────┐
│   LM Head    │   Linear → vocab_size logits
└──────────────┘
```

---

## 1. `config.py`: The Blueprint

**What it does:** Stores all the numbers that define the model size.
Nothing here does any computation; it's just a settings object.

```python
from dataclasses import dataclass

# Type annotations are required: @dataclass only turns annotated names into fields
@dataclass
class ModelConfig:
    vocab_size: int = 32_000     # how many tokens exist
    context_length: int = 1024   # max sequence length
    d_model: int = 1024          # width of every vector throughout the model
    n_heads: int = 16            # how many attention heads
    n_layers: int = 9            # how many transformer blocks are stacked
    d_ff: int = 2816             # width of the MLP hidden layer (auto-computed from d_model)
```

**Why these numbers?**
- `d_model` is the "resolution" of the model: bigger = more expressive, but more memory
- `n_heads` splits each attention layer into parallel sub-attentions
- `head_dim = d_model / n_heads = 1024 / 16 = 64`: each head works on a 64-dim slice
- `d_ff` for SwiGLU = `round_up_256(2/3 × 4 × d_model)`: the 2/3 compensates for having 3 matrices instead of 2

**Presets defined here** (d = `d_model`, h = `n_heads`, l = `n_layers`):
```
SLLM_100M:  d=768,  h=12, l=12  →  109.5M params
SLLM_150M:  d=1024, h=16, l=9   →  148.4M params
```

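The config comment says `d_ff` is auto-computed. One way to do that is a `__post_init__` hook that derives `head_dim` and `d_ff` from `d_model`, sketched below. The hook, the hypothetical `round_up` helper, and the way the presets are built are assumptions for illustration, not necessarily how `config.py` does it. Preset fields not listed in the table (`vocab_size`, `context_length`) are assumed to keep their defaults.

```python
from dataclasses import dataclass, field


def round_up(n: int, multiple: int = 256) -> int:
    # Hypothetical helper: smallest multiple of `multiple` that is >= n
    return ((n + multiple - 1) // multiple) * multiple


@dataclass
class ModelConfig:
    vocab_size: int = 32_000
    context_length: int = 1024
    d_model: int = 1024
    n_heads: int = 16
    n_layers: int = 9
    # Derived values: not constructor arguments, filled in by __post_init__ (assumed design)
    head_dim: int = field(init=False)
    d_ff: int = field(init=False)

    def __post_init__(self) -> None:
        if self.d_model % self.n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")
        self.head_dim = self.d_model // self.n_heads
        # Integer version of int(2/3 * 4 * d_model), avoiding float rounding surprises
        self.d_ff = round_up(8 * self.d_model // 3, 256)


# Presets from the table above (unlisted fields assumed to keep their defaults)
SLLM_100M = ModelConfig(d_model=768, n_heads=12, n_layers=12)
SLLM_150M = ModelConfig(d_model=1024, n_heads=16, n_layers=9)

print(SLLM_150M.head_dim, SLLM_150M.d_ff)  # 64 2816
print(SLLM_100M.head_dim, SLLM_100M.d_ff)  # 64 2048
```

Making the derived fields `init=False` means callers cannot pass a `d_ff` that is inconsistent with `d_model`.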

---

## 2. `norm.py`: RMSNorm

**What it does:** Normalizes vectors so they don't explode or vanish during training.
Used before every attention and MLP layer, and once more before the LM head (the final norm).

**Standard LayerNorm (GPT-2):**
```
1. Compute mean of x
2. Subtract mean (centering)
3. Divide by std
4. Scale by learned weight
5. Add learned bias
```

**RMSNorm (LLaMA / our model):**
```
1. Compute RMS = sqrt( mean(x²) )   ← no mean subtraction!
2. Divide by RMS
3. Scale by learned weight          ← no bias!
```

**Why simpler is better:**
- No mean subtraction → fewer operations per call, so it is cheaper than LayerNorm
- No bias → fewer parameters
- Works just as well in practice
- LLaMA, Mistral, Gemma all use it

```python
# What it computes, for each token vector x (mean taken over the d_model dimension):
rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)   # 1e-6 keeps it nonzero
output = (x / rms) * weight   # normalize, then rescale with the learned gain
```

The `weight` starts at all ones (so at initialization the layer only normalizes) and is learned during training.

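A minimal `nn.Module` version of the formula above, sketched in PyTorch. The class name and `eps` default mirror the text; the real `norm.py` may differ in details (for example, computing the statistics in float32). The usage lines check two claims from this section. First, the output has unit RMS at initialization. Second, on an input that already has zero mean, RMSNorm gives the same result as PyTorch's `LayerNorm` without affine terms, because the centering step then does nothing.

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Learned per-channel gain, initialized to ones (pure normalization at init)
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # RMS over the last (feature) dimension; no mean subtraction, no bias
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight


torch.manual_seed(0)
norm = RMSNorm(1024)
x = torch.randn(2, 8, 1024) * 5.0
y = norm(x)
print(y.pow(2).mean(dim=-1).sqrt().mean())  # close to 1.0

# On zero-mean input, RMSNorm matches LayerNorm without affine parameters
x0 = x - x.mean(dim=-1, keepdim=True)
ln = nn.LayerNorm(1024, eps=1e-6, elementwise_affine=False)
print(torch.allclose(norm(x0), ln(x0), atol=1e-5))  # True
```

The only parameters are the 1024 gains, which is why the norms contribute "~0M" to the totals at the end of this document.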

---

## 3. `rope.py`: Rotary Position Embedding (RoPE)

**The problem it solves:** Transformers have no built-in sense of position.
Without position encoding, `"cat sat on mat"` and `"mat on sat cat"` look identical.

**How older models solved it (GPT-2):**
Add a learned vector for each position to the token embedding: `token[i] += position_embedding[i]`.
Problem: the table only has `context_length` rows, so it can't represent positions beyond the training length.

**What RoPE does instead:**
Instead of adding position info to token vectors, it **rotates** the Query and Key
vectors in attention by an angle that depends on their position.

```
Token at position 3 → rotate Q and K by angle 3θ
Token at position 7 → rotate Q and K by angle 7θ
```

Each pair of dimensions has its own frequency θ; the rotation angle is position × θ.

When you compute the attention score `Q·K` for a query at position m and a key at position n,
the two rotations combine into a single rotation by `(m − n)θ`. The score therefore depends on the
*relative distance* between the tokens, not on their absolute positions.

**Why this is better:**
- No extra parameters (pure math, no learned table)
- No table to run out of, so it can be applied past the training length (though plain RoPE's quality typically degrades there without extensions such as position interpolation)
- Used in LLaMA, Mistral, Gemma

**How the code works:**
```python
# Step 1: precompute a table of cos/sin values for every position
cos, sin = precompute_rope_freqs(head_dim=64, max_seq_len=1024)
# cos/sin shape: (1024, 64) = (max_seq_len, head_dim)

# Step 2: at forward time, rotate Q and K
# (q, k: (batch, heads, T, 64); cos/sin sliced to the first T positions)
q_rotated = q * cos[:T] + rotate_half(q) * sin[:T]
k_rotated = k * cos[:T] + rotate_half(k) * sin[:T]

# rotate_half(x): splits x in half, negates the second half, swaps the halves
# [a, b, c, d] → [-c, -d, a, b]
```

V (values) are **not** rotated; only Q and K get position encoding.

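The overview above calls two helpers without showing them. Below is one self-contained way to write them in PyTorch, assuming the common base frequency of 10000 (the base is not stated in this document, so treat it as an assumption). The hypothetical `apply_rope` wraps the rotation formula. The last lines check the relative-distance property: a query at position 5 and a key at position 2 give the same score as the same vectors placed at positions 15 and 12.

```python
import torch


def precompute_rope_freqs(head_dim: int, max_seq_len: int, base: float = 10000.0):
    # One frequency per dimension pair: theta_i = base^(-2i / head_dim)
    inv_freq = 1.0 / base ** (torch.arange(0, head_dim, 2, dtype=torch.float64) / head_dim)
    positions = torch.arange(max_seq_len, dtype=torch.float64)
    angles = torch.outer(positions, inv_freq)       # (max_seq_len, head_dim / 2)
    angles = torch.cat([angles, angles], dim=-1)    # (max_seq_len, head_dim), matches rotate_half
    return angles.cos(), angles.sin()


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # [a, b, c, d] -> [-c, -d, a, b]
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)


def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Rotates dimension i together with dimension i + head_dim/2 by the same angle
    return x * cos + rotate_half(x) * sin


cos, sin = precompute_rope_freqs(head_dim=64, max_seq_len=1024)
print(cos.shape)  # torch.Size([1024, 64])

torch.manual_seed(0)
q = torch.randn(64, dtype=torch.float64)
k = torch.randn(64, dtype=torch.float64)


def score(q_pos: int, k_pos: int) -> torch.Tensor:
    # Attention score between q placed at q_pos and k placed at k_pos
    return apply_rope(q, cos[q_pos], sin[q_pos]) @ apply_rope(k, cos[k_pos], sin[k_pos])


# Same distance (3) at different absolute positions gives the same score
print(torch.allclose(score(5, 2), score(15, 12)))  # True
```

Because each step is a pure rotation, vector lengths are unchanged; only the angles between Q and K carry position information.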

---

## 4. `attention.py`: Causal Self-Attention

**What it does:** Lets every token look at itself and all *previous* tokens and decide
which ones are relevant for predicting the next token.

**The full flow:**

```
Input x: (Batch, Tokens, d_model)
         e.g. (2, 1024, 1024)
    │
    ▼
QKV projection: one big Linear(d_model → 3×d_model), split into
    Q: (2, 1024, 1024)   "what am I looking for?"
    K: (2, 1024, 1024)   "what do I contain?"
    V: (2, 1024, 1024)   "what do I send if attended to?"
    │
    ▼
Reshape to heads: (2, 16 heads, 1024, 64 head_dim)
    │
    ▼
Apply RoPE to Q and K   ← position encoding happens here
    │
    ▼
Scaled Dot-Product Attention:
    scores  = Q @ K^T / sqrt(64)     # how strongly each token attends to every other token
    mask    = causal mask            # can only look LEFT (past), not right (future)
    weights = softmax(scores + mask)
    out     = weights @ V            # weighted sum of values
    │
    ▼
Reshape back: (2, 1024, 1024)
    │
    ▼
Output projection: Linear(d_model → d_model)
```

**Causal mask**: this is what makes it a *language model* (predicts the next token):
```
Position:  0  1  2  3
Token 0:  [✓  ✗  ✗  ✗]   can only see itself
Token 1:  [✓  ✓  ✗  ✗]   can see 0, 1
Token 2:  [✓  ✓  ✓  ✗]   can see 0, 1, 2
Token 3:  [✓  ✓  ✓  ✓]   can see all
```

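The scores/mask/softmax steps from the flow above can be written out by hand and compared against PyTorch's fused `F.scaled_dot_product_attention(..., is_causal=True)`. This sketch uses small made-up shapes (batch 2, 4 heads, 6 tokens, head_dim 8) rather than the model's real sizes. The mask is built the usual way: 0 where attention is allowed and `-inf` above the diagonal, so those entries become exactly 0 after the softmax.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, D = 2, 4, 6, 8        # toy sizes, not the model's real dimensions
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

# Causal mask: 0 on/below the diagonal, -inf above it (future positions)
mask = torch.full((T, T), float("-inf")).triu(diagonal=1)

scores = q @ k.transpose(-2, -1) / math.sqrt(D)   # (B, H, T, T)
weights = torch.softmax(scores + mask, dim=-1)    # each row sums to 1
out_manual = weights @ v                          # (B, H, T, D)

out_fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(weights[0, 0, 0])                                  # tensor([1., 0., 0., 0., 0., 0.])
print(torch.allclose(out_manual, out_fused, atol=1e-5))  # True
```

Row 0 of the weights matches "Token 0: can only see itself" in the table: all of its attention lands on position 0.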
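
**Flash Attention:** We use `F.scaled_dot_product_attention(..., is_causal=True)`,
PyTorch 2.0's built-in fused attention. When the GPU, dtype, and inputs allow it, it dispatches to a
FlashAttention or memory-efficient kernel that never materializes the full O(T²) attention matrix,
which is much faster and uses far less VRAM; otherwise it falls back to a standard (math) implementation.
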
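Putting the flow together, a causal self-attention module might look like the sketch below. It assumes PyTorch 2.0 or later, bias-free projections, and the attribute names `qkv` and `o_proj`. `o_proj` matches the name used in the weight-initialization notes; `qkv` is made up. A compact RoPE (base 10000, an assumption) is included so the block runs on its own; the real `attention.py` may organize this differently.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def rope_tables(head_dim: int, max_seq_len: int, base: float = 10000.0):
    inv_freq = 1.0 / base ** (torch.arange(0, head_dim, 2).float() / head_dim)
    angles = torch.outer(torch.arange(max_seq_len).float(), inv_freq)
    angles = torch.cat([angles, angles], dim=-1)            # (max_seq_len, head_dim)
    return angles.cos(), angles.sin()


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)


class CausalSelfAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, max_seq_len: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)  # one fused Q/K/V projection
        self.o_proj = nn.Linear(d_model, d_model, bias=False)   # output projection
        cos, sin = rope_tables(self.head_dim, max_seq_len)
        # Buffers follow .to(device) but are not trained or saved in the state dict
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        q, k, v = self.qkv(x).split(C, dim=-1)                  # each (B, T, C)
        # (B, T, C) -> (B, n_heads, T, head_dim)
        q, k, v = (t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2) for t in (q, k, v))
        cos, sin = self.cos[:T], self.sin[:T]
        q = q * cos + rotate_half(q) * sin                      # RoPE on Q and K only
        k = k * cos + rotate_half(k) * sin
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        out = out.transpose(1, 2).contiguous().view(B, T, C)    # merge heads back
        return self.o_proj(out)


torch.manual_seed(0)
attn = CausalSelfAttention(d_model=64, n_heads=4, max_seq_len=32)
x = torch.randn(2, 10, 64)
print(attn(x).shape)  # torch.Size([2, 10, 64])
```

Fusing Q, K, and V into one `Linear` gives one large matrix multiply instead of three smaller ones; the parameter count (3·d_model² plus d_model² for `o_proj`) is the same either way.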

---

## 5. `mlp.py`: SwiGLU Feed-Forward Network

**What it does:** After attention (which mixes information *between* tokens), the MLP
transforms each token *independently*. It holds about two-thirds of each block's parameters
and is where most of the model's "knowledge" is stored.

**Standard MLP (GPT-2):**
```python
out = W2 @ GELU(W1 @ x)      # 2 matrices
```

**SwiGLU (LLaMA / our model):**
```python
gate   = W_gate @ x          # linear
up     = W_up @ x            # linear
hidden = SiLU(gate) * up     # element-wise gate: the key difference
out    = W_down @ hidden     # 3 matrices total
```

**What is SiLU?**
```
SiLU(x) = x × sigmoid(x)
```
It's a smooth version of ReLU: close to `x` for large positive inputs and close to 0 for large
negative ones, with a small negative dip (minimum ≈ −0.28 near x ≈ −1.28) instead of a hard cutoff at 0.

**Why gating matters:**
- `SiLU(gate)` acts as a learned, input-dependent soft switch for each hidden dimension
- The model learns to emphasize the hidden units relevant to each input and damp the rest
- Empirically outperforms a GELU MLP at a comparable parameter count
- Used in LLaMA, PaLM, Mistral

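A PyTorch sketch of the SwiGLU block, assuming bias-free layers named `gate`, `up`, and `down`. The weight-initialization notes refer to `mlp.down`; the other two names are assumptions. It uses `F.silu`, PyTorch's built-in SiLU, and checks it against the `x × sigmoid(x)` definition.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLU(nn.Module):
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.gate = nn.Linear(d_model, d_ff, bias=False)   # W_gate
        self.up = nn.Linear(d_model, d_ff, bias=False)     # W_up
        self.down = nn.Linear(d_ff, d_model, bias=False)   # W_down

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SiLU(gate) scales each hidden unit of `up` element-wise, then project back
        return self.down(F.silu(self.gate(x)) * self.up(x))


torch.manual_seed(0)
mlp = SwiGLU(d_model=1024, d_ff=2816)
x = torch.randn(2, 16, 1024)
print(mlp(x).shape)                              # torch.Size([2, 16, 1024])
print(sum(p.numel() for p in mlp.parameters()))  # 8650752 (3 × 1024 × 2816)

t = torch.linspace(-4, 4, 81)
print(torch.allclose(F.silu(t), t * torch.sigmoid(t), atol=1e-6))  # True
```

The parameter count printed here is the "mlp(8.65M)" term in the per-block totals at the end of this document.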

**The d_ff formula:**
```
d_ff = round_up_256( int(2/3 × 4 × d_model) )

For 150M: round_up_256( int(2/3 × 4 × 1024) ) = round_up_256(2730) = 2816
```
The `2/3` factor compensates for having 3 matrices instead of 2: it keeps the
total parameter count roughly equal to a standard 4× FFN (rounding up to a multiple
of 256 adds a few percent).

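Per layer, and assuming bias-free linear layers (which the parameter totals at the end of this document imply), the bookkeeping works out as follows:

```math
\begin{aligned}
\text{standard } 4\times\text{ FFN:}\quad & 2 \cdot d_{\text{model}} \cdot 4d_{\text{model}} = 8\,d_{\text{model}}^2 \\
\text{SwiGLU:}\quad & 3 \cdot d_{\text{model}} \cdot \tfrac{2}{3} \cdot 4d_{\text{model}} = 8\,d_{\text{model}}^2 \\
d_{\text{model}} = 1024,\ d_{\text{ff}} = 2816:\quad & 3 \cdot 1024 \cdot 2816 = 8{,}650{,}752 \;\text{ vs. }\; 8 \cdot 1024^2 = 8{,}388{,}608
\end{aligned}
```

After rounding, the SwiGLU layer is about 3% larger than the 4× FFN it replaces.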

---

## 6. `block.py`: TransformerBlock

**What it does:** Wraps attention + MLP into one reusable block.
The model is just N copies of this block stacked.

```python
def forward(x):
    # Attention sub-layer (has its own RMSNorm)
    x = x + attention(rmsnorm_1(x))   # pre-norm + residual

    # MLP sub-layer (a second, separate RMSNorm)
    x = x + mlp(rmsnorm_2(x))         # pre-norm + residual

    return x
```

**Two key ideas:**

**1. Pre-norm (normalize BEFORE the sublayer):**
```
Pre-norm  (GPT-2, LLaMA):               x → norm → attention → + original x
Post-norm (original Transformer, BERT): x → attention → + original x → norm
```
Pre-norm is more stable at large scale: the residual path from input to output never passes
through a normalization, so gradients flow through it cleanly.

**2. Residual connections (`x + sublayer(x)`):**
The output of each sublayer is *added* back to its input instead of replacing it.
This means:
- Gradients can skip directly to earlier layers during backprop
- The model learns *corrections* to the input, not transformations from scratch
- Many layers can be stacked without vanishing gradients

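A self-contained sketch of the block wiring in PyTorch. To keep it short, the attention and MLP are passed in as arbitrary modules. Here they are plain zero-initialized `nn.Linear` stand-ins, *not* the real sublayers, so the example isolates the pre-norm + residual pattern. A minimal `RMSNorm` is re-declared so the block runs alone, and the names `attn_norm` and `mlp_norm` are illustrative. The check illustrates the "learns corrections" point: when both sublayers output zero, the block is exactly the identity and gradients pass straight through.

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight


class TransformerBlock(nn.Module):
    def __init__(self, d_model: int, attn: nn.Module, mlp: nn.Module):
        super().__init__()
        self.attn_norm = RMSNorm(d_model)   # separate norm (and weights) per sub-layer
        self.mlp_norm = RMSNorm(d_model)
        self.attn = attn
        self.mlp = mlp

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.attn_norm(x))   # pre-norm + residual
        x = x + self.mlp(self.mlp_norm(x))     # pre-norm + residual
        return x


d = 32
# Stand-in sublayers initialized to zero, so each one contributes nothing
attn_stub = nn.Linear(d, d, bias=False)
mlp_stub = nn.Linear(d, d, bias=False)
nn.init.zeros_(attn_stub.weight)
nn.init.zeros_(mlp_stub.weight)

block = TransformerBlock(d, attn_stub, mlp_stub)
torch.manual_seed(0)
x = torch.randn(4, 7, d)
print(torch.equal(block(x), x))  # True: only the residual path carries the signal
```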

---

## 7. `model.py`: SLLM (The Full Model)

**What it does:** Assembles everything into the complete language model.

```
tokens: (B, T)               ← integer IDs like [423, 1829, 55, ...]
   │
   ▼
token_emb: Embedding(32000 → 1024)
   │                           converts each integer to a 1024-dim vector
   ▼
blocks[0]: TransformerBlock ┐
blocks[1]: TransformerBlock │  9 blocks for 150M
   ...                      │
blocks[8]: TransformerBlock ┘
   │
   ▼
norm: RMSNorm(1024)          ← final stabilization
   │
   ▼
lm_head: Linear(1024 → 32000)
   │                           produces a score for each possible next token
   ▼
logits: (B, T, 32000)        ← unnormalized scores
```

**Weight tying:**
The `token_emb` and `lm_head` **share the same weight matrix**.
```python
self.lm_head.weight = self.token_emb.weight
```
- Same matrix used for: embedding lookup (input) AND output projection
- Saves ~32.8M parameters (32000 × 1024 = 32,768,000)
- Intuition: the logit for token X is the dot product of the final hidden state with X's embedding,
  so a token whose embedding is similar to the current hidden state scores highly as the next-token prediction

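A toy-sized check of what the tying line does, in PyTorch. `TinyLM` is a hypothetical, stripped-down model with no transformer blocks in between, and its sizes are made up to keep it fast. `nn.Linear(d, V).weight` has shape `(V, d)`, the same as `nn.Embedding(V, d).weight`, which is why the assignment works. `Module.parameters()` lists the shared tensor only once, so the count drops by exactly one `V × d` matrix.

```python
import torch
import torch.nn as nn


class TinyLM(nn.Module):
    def __init__(self, vocab_size: int, d_model: int, tie: bool):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)           # weight: (V, d)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)    # weight: (V, d)
        if tie:
            self.lm_head.weight = self.token_emb.weight              # one shared Parameter

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.lm_head(self.token_emb(tokens))                  # (B, T, V) logits


V, d = 100, 16
tied, untied = TinyLM(V, d, tie=True), TinyLM(V, d, tie=False)

print(tied.lm_head.weight is tied.token_emb.weight)  # True
print(sum(p.numel() for p in untied.parameters()))   # 3200
print(sum(p.numel() for p in tied.parameters()))     # 1600
```

Because the tensor is shared, gradients from both the input lookup and the output projection accumulate into the same weights.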

**Loss computation:**
```python
import torch.nn.functional as F

# Cross-entropy: at each position, predict the NEXT token
# Input: [The, cat, sat, on] → predicts [cat, sat, on, mat]
# targets = input shifted left by 1
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
# logits: (B, T, vocab_size) → (B·T, vocab_size); targets: (B, T) → (B·T,)
```

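A runnable sketch of the shift in PyTorch. It assumes the training data arrives as one `(B, T+1)` chunk of token IDs that is sliced into inputs and targets; how the real data loader does this isn't shown here, and the token IDs are made up. All-zero logits stand in for the model's output. With such logits every token is equally likely, so the loss equals `ln(vocab_size)` ≈ 10.37 for 32,000 tokens. That value is a handy sanity check for the loss at the start of training.

```python
import math
import torch
import torch.nn.functional as F

vocab_size = 32_000
chunk = torch.tensor([[423, 1829, 55, 17, 902]])   # made-up IDs for "The cat sat on mat"
inputs, targets = chunk[:, :-1], chunk[:, 1:]       # targets = inputs shifted left by 1

print(inputs.tolist(), targets.tolist())  # [[423, 1829, 55, 17]] [[1829, 55, 17, 902]]

# Stand-in for model(inputs): all-zero logits = uniform prediction over the vocabulary
logits = torch.zeros(inputs.shape[0], inputs.shape[1], vocab_size)
# reshape (not view) for targets: a slice of a larger batch is generally not contiguous
loss = F.cross_entropy(logits.view(-1, vocab_size), targets.reshape(-1))
print(round(loss.item(), 4), round(math.log(vocab_size), 4))  # 10.3735 10.3735
```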

**Gradient checkpointing** (`enable_gradient_checkpointing()`):
Normally PyTorch saves all intermediate activations during the forward pass to use
in backprop. For 9 layers with batch_size=2 and seq_len=1024, that's ~1.5GB.

With gradient checkpointing:
- Activations inside each checkpointed section (typically one transformer block) are **NOT saved**; only the section's input is kept
- During the backward pass, they are **recomputed on the fly** by re-running that section's forward
- Result: ~40% less VRAM, ~30% slower training (the extra forward work is the cost)
- Essential for fitting 150M on a 4GB GPU

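One common way to implement this is PyTorch's `torch.utils.checkpoint.checkpoint`, sketched below. The `use_checkpoint` flag and the toy `nn.Sequential` blocks are assumptions, not the actual `enable_gradient_checkpointing()` implementation. The check confirms that checkpointing changes memory use, not results: gradients match those from the plain forward pass.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Stack(nn.Module):
    def __init__(self, d: int, n_layers: int):
        super().__init__()
        # Toy stand-ins for transformer blocks
        self.blocks = nn.ModuleList(
            nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))
            for _ in range(n_layers)
        )
        self.use_checkpoint = False  # hypothetical flag an enable_gradient_checkpointing() could set

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if self.use_checkpoint and self.training:
                # Keeps only `x`; activations inside `block` are recomputed during backward
                x = x + checkpoint(block, x, use_reentrant=False)
            else:
                x = x + block(x)
        return x


torch.manual_seed(0)
model = Stack(d=32, n_layers=3)
x = torch.randn(4, 32)


def grads(use_ckpt: bool):
    model.zero_grad()
    model.use_checkpoint = use_ckpt
    model(x).pow(2).sum().backward()
    return [p.grad.clone() for p in model.parameters()]


plain, ckpt = grads(False), grads(True)
print(all(torch.allclose(a, b) for a, b in zip(plain, ckpt)))  # True
```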

**Weight initialization:**
```python
# All Linear and Embedding weights: Normal(mean=0, std=0.02)
# Residual projections (o_proj, mlp.down): std scaled down by 1/sqrt(2 × n_layers),
#   i.e. std = 0.02 / sqrt(2 × 9) ≈ 0.0047 for the 150M model
```
The residual scaling prevents the residual stream from growing too large at initialization,
when many layers add to it (two additions per block, attention and MLP, hence the 2 in `2 × n_layers`).

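A sketch of that scheme as a single loop over named modules, in PyTorch. It assumes the residual projections are the submodules whose names end in `o_proj` or `down`; that naming is an assumption based on the names above. `init_weights` and `ToyBlock` are hypothetical. Biases, if present, are zeroed, and RMSNorm gains are left at one.

```python
import math
import torch
import torch.nn as nn


def init_weights(model: nn.Module, n_layers: int, std: float = 0.02) -> None:
    resid_std = std / math.sqrt(2 * n_layers)
    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Residual projections get the smaller std; everything else gets `std`
            is_resid = name.endswith(("o_proj", "down"))
            nn.init.normal_(module.weight, mean=0.0, std=resid_std if is_resid else std)
            if getattr(module, "bias", None) is not None:
                nn.init.zeros_(module.bias)


class ToyBlock(nn.Module):
    # Toy module using the assumed attribute names
    def __init__(self, d: int):
        super().__init__()
        self.o_proj = nn.Linear(d, d, bias=False)
        self.up = nn.Linear(d, 4 * d, bias=False)
        self.down = nn.Linear(4 * d, d, bias=False)


torch.manual_seed(0)
block = ToyBlock(512)
init_weights(block, n_layers=9)
print(block.up.weight.std().item())      # about 0.02
print(block.o_proj.weight.std().item())  # about 0.0047
```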

---

## How it all fits together: one forward pass

```
"The cat sat"  →  tokenizer  →  [423, 1829, 55]

token_emb:  [423]  → [ 0.1, -0.3, ...]   (1024 floats)
            [1829] → [ 0.8,  0.2, ...]   (1024 floats)
            [55]   → [-0.1,  0.4, ...]   (1024 floats)

Block 0:
  norm → Q,K,V projections → RoPE rotation → Flash Attention → output proj → + residual
  norm → gate,up projections → SiLU(gate)*up → down proj → + residual

Blocks 1..8: same

Final norm → LM head → 32000 scores per position

softmax (last position's scores) → probabilities → sample next token
```

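The last line, sketched for a single generation step in PyTorch: only the final position's 32,000 scores matter for choosing the next token. Random logits stand in for real model output. The hypothetical `sample_next` and its `temperature` parameter are added for illustration (the text doesn't mention a temperature). `torch.multinomial` draws one token ID from the resulting probabilities.

```python
import torch

torch.manual_seed(0)
vocab_size = 32_000
logits = torch.randn(1, 3, vocab_size)  # stand-in for model output on "The cat sat": (B, T, vocab)


def sample_next(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    last = logits[:, -1, :] / temperature            # scores for the position after the last token
    probs = torch.softmax(last, dim=-1)              # (B, vocab), each row sums to 1
    return torch.multinomial(probs, num_samples=1)   # (B, 1) sampled token IDs


next_id = sample_next(logits)
print(next_id.shape)  # torch.Size([1, 1])

# The new ID is appended to the input and the model is run again for the following token
tokens = torch.cat([torch.tensor([[423, 1829, 55]]), next_id], dim=1)
```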

**Total parameters (150M):**
```
Embedding:   32000 × 1024                           = 32.8M
Per block:   attn(4.2M) + mlp(8.65M) + norms(~0M)   = 12.85M
9 blocks:    9 × 12.85M                             = 115.6M
Final norm:  1024                                   = ~0M
LM head:     TIED to embedding                      = 0M (reuses same weights)
───────────────────────────────────────────────────────────
TOTAL:                                                148.4M params
```

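The table can be reproduced with a few lines of arithmetic. This sketch assumes what the totals imply: bias-free linear layers, a fused QKV plus output projection (4·d_model² per block), two RMSNorm gains per block plus a final one, and a tied LM head. `count_params` is a hypothetical helper. Under those assumptions it reproduces both preset sizes.

```python
def count_params(vocab_size: int, d_model: int, n_layers: int, d_ff: int) -> int:
    embedding = vocab_size * d_model      # shared with the tied LM head, counted once
    attn = 4 * d_model * d_model          # Q, K, V (fused) + output projection
    mlp = 3 * d_model * d_ff              # gate, up, down
    norms = 2 * d_model                   # two RMSNorm gains per block
    per_block = attn + mlp + norms
    final_norm = d_model
    return embedding + n_layers * per_block + final_norm


print(count_params(32_000, 1024, 9, 2816))  # 148392960 (≈ 148.4M, SLLM_150M)
print(count_params(32_000, 768, 12, 2048))  # 109529856 (≈ 109.5M, SLLM_100M)
```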