
Model Folder: Plain Language Explanation

The model/ folder builds a GPT-style decoder-only transformer from scratch, piece by piece. Each file is one component. Here's how they stack:

tokens (integers)
    │
    ▼
┌─────────────┐
│  Embedding  │  config.py defines the shape of everything
└──────┬──────┘
       │
       ▼  ×N layers
┌──────────────────────────────────────┐
│         TransformerBlock             │  block.py
│                                      │
│   ┌──────────┐                       │
│   │ RMSNorm  │                       │  norm.py
│   └────┬─────┘                       │
│   ┌────▼─────┐                       │
│   │Attention │                       │  attention.py
│   │  + RoPE  │                       │  rope.py
│   └────┬─────┘                       │
│        │  (+residual)                │
│   ┌────▼─────┐                       │
│   │ RMSNorm  │                       │  norm.py
│   └────┬─────┘                       │
│   ┌────▼─────────┐                   │
│   │  SwiGLU MLP  │                   │  mlp.py
│   └────┬─────────┘                   │
│        │  (+residual)                │
└────────┼─────────────────────────────┘
         │
         ▼
    ┌──────────┐
    │ RMSNorm  │  final norm
    └────┬─────┘
         │
    ┌────▼─────┐
    │  LM Head │  Linear → vocab_size logits
    └──────────┘

1. config.py: The Blueprint

What it does: Stores all the numbers that define the model's size. Nothing is computed here; it's just a settings object.

from dataclasses import dataclass

@dataclass
class ModelConfig:
    vocab_size: int     = 32_000   # how many tokens exist
    context_length: int = 1024     # max sequence length
    d_model: int        = 1024     # width of every vector throughout the model
    n_heads: int        = 16       # how many attention heads
    n_layers: int       = 9        # how many transformer blocks stacked
    d_ff: int           = 2816     # width of the MLP hidden layer (auto-computed)

Why these numbers?

  • d_model is the "resolution" of the model: bigger means more expressive but more memory
  • n_heads splits each attention layer into parallel sub-attentions
  • head_dim = d_model / n_heads = 64, so each head sees a 64-dim slice
  • d_ff for SwiGLU = round_up_256( int(2/3 × 4 × d_model) ), which compensates for having 3 matrices instead of 2
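A minimal sketch of how the derived values could be filled in automatically, assuming a `__post_init__` hook and a `d_ff=0` sentinel meaning "compute it". `ModelConfigSketch` and `round_up` are hypothetical names for illustration; the repo's config.py may do this differently.

```python
from dataclasses import dataclass


def round_up(n: int, multiple: int = 256) -> int:
    """Round n up to the next multiple of `multiple`."""
    return ((n + multiple - 1) // multiple) * multiple


@dataclass
class ModelConfigSketch:
    vocab_size: int = 32_000
    context_length: int = 1024
    d_model: int = 1024
    n_heads: int = 16
    n_layers: int = 9
    d_ff: int = 0  # 0 means "derive from d_model" (a convention assumed for this sketch)

    def __post_init__(self) -> None:
        if self.d_model % self.n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")
        if self.d_ff == 0:
            # SwiGLU sizing: 2/3 of a standard 4x FFN width, rounded up to a multiple of 256
            self.d_ff = round_up(int(2 / 3 * 4 * self.d_model))

    @property
    def head_dim(self) -> int:
        return self.d_model // self.n_heads


sllm_150m = ModelConfigSketch()
sllm_100m = ModelConfigSketch(d_model=768, n_heads=12, n_layers=12)
print(sllm_150m.d_ff, sllm_150m.head_dim)  # 2816 64
print(sllm_100m.d_ff, sllm_100m.head_dim)  # 2048 64
```

Deriving d_ff in one place means the two presets cannot drift out of sync with the formula.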

Presets defined here:

SLLM_100M: d=768,  h=12, l=12  →  109.5M params
SLLM_150M: d=1024, h=16, l=9   →  148.4M params

2. norm.py: RMSNorm

What it does: Normalizes vectors so they don't explode or vanish during training. Used before every attention and MLP layer.

Standard LayerNorm (GPT-2):

1. Compute mean of x
2. Subtract mean  (centering)
3. Divide by std
4. Scale by learned weight
5. Add learned bias

RMSNorm (LLaMA / our model):

1. Compute RMS = sqrt( mean(x²) )   ← no mean subtraction!
2. Divide by RMS
3. Scale by learned weight           ← no bias!

Why simpler is better:

  • No mean subtraction → fewer operations, so it's somewhat faster (the exact gain depends on hardware and implementation)
  • No bias β†’ fewer parameters
  • Works just as well in practice
  • LLaMA, Mistral, Gemma all use it

# What it computes:
output = (x / sqrt(mean(x**2) + 1e-6)) * weight
#         ↑ normalize                     ↑ rescale with learned gain

The weight starts at all-ones (no change at init) and is learned during training.
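The formula above, written out for a single vector in plain Python. This is a framework-free sketch (`rms_norm` is a hypothetical helper); the real norm.py presumably works on PyTorch tensors over the last dimension. The example also shows what "no mean subtraction" means in practice: the output has RMS ≈ 1, but its mean is left alone.

```python
import math


def rms_norm(x: list[float], weight: list[float], eps: float = 1e-6) -> list[float]:
    """RMSNorm of one vector: x / sqrt(mean(x**2) + eps) * weight."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


x = [3.0, 1.0, 2.0, -4.0]
weight = [1.0] * len(x)  # all-ones, as at initialisation
y = rms_norm(x, weight)

# RMS of the output is ~1; the mean is NOT forced to 0 (no centering step)
print(math.sqrt(sum(v * v for v in y) / len(y)))
print(sum(y) / len(y))
```

Scaling the input by a constant leaves the output (almost) unchanged, which is exactly the "don't explode or vanish" property the norm is there for.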


3. rope.py: Rotary Position Embedding (RoPE)

The problem it solves: Transformers have no built-in sense of position. Without position encoding, "cat sat on mat" and "mat on sat cat" look identical.

How older models solved it (GPT-2): Add a learned position vector to each token: token[i] += position_embedding[i]. Problem: the table only has rows up to the training length, so it can't generalize beyond it.

What RoPE does instead: Rather than adding position info to the token vectors, it rotates the Query and Key vectors inside attention by angles that depend on each token's position (each pair of dimensions rotates at its own frequency).

Token at position 3 → rotate Q and K by angle θ₃
Token at position 7 → rotate Q and K by angle θ₇

When you compute the attention score Q·K, the two rotations combine into a single rotation by the difference of their angles, so the score encodes the relative distance between the tokens, not their absolute positions.
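A two-dimensional toy version makes the relative-distance property concrete: rotate q by its position times an angle θ, rotate k the same way, and the dot product comes out the same for any pair of positions with the same gap. The single frequency `theta = 0.1` and the helper names are arbitrary choices for this sketch; the real model uses head_dim / 2 pairs, each with its own frequency.

```python
import math


def rotate2d(vec: tuple[float, float], angle: float) -> tuple[float, float]:
    """Rotate a 2-D vector counter-clockwise by `angle` radians."""
    a, b = vec
    c, s = math.cos(angle), math.sin(angle)
    return (a * c - b * s, a * s + b * c)


def rope_score(q, k, pos_q: int, pos_k: int, theta: float = 0.1) -> float:
    """Dot product after rotating q by pos_q*theta and k by pos_k*theta."""
    qx, qy = rotate2d(q, pos_q * theta)
    kx, ky = rotate2d(k, pos_k * theta)
    return qx * kx + qy * ky


q, k = (1.0, 2.0), (0.5, -1.5)
# same gap (4 tokens), very different absolute positions -> same score
print(rope_score(q, k, 7, 3), rope_score(q, k, 104, 100))
# different gap -> different score
print(rope_score(q, k, 7, 5))
```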

Why this is better:

  • No extra parameters (pure math, no learned table)
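  • Can be computed for any position (no learned table to run out of), though quality usually degrades past the training length without extra techniques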
  • Used in LLaMA, Mistral, GPT-4 (likely), Gemma

How the code works:

# Step 1: precompute a table of cos/sin values for every position
cos, sin = precompute_rope_freqs(head_dim=64, max_seq_len=1024)
# cos/sin shape: (1024, 64)

# Step 2: at forward time, rotate Q and K
q_rotated = q * cos + rotate_half(q) * sin
k_rotated = k * cos + rotate_half(k) * sin

# rotate_half(x): splits x in half, negates the second half, and swaps the halves
# [a, b, c, d] → [-c, -d, a, b]

V (values) are not rotated; only Q and K get position encoding.
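The table-plus-`rotate_half` approach can be sketched in plain Python for one head vector. This assumes the common LLaMA-style "half-split" layout (consistent with the `[a, b, c, d] → [-c, -d, a, b]` example above) and a frequency base of 10000; `precompute_cos_sin` and `apply_rope` are hypothetical stand-ins, not the repo's `precompute_rope_freqs`. Dimension i is paired with i + head_dim/2, and that pair is rotated by pos × inv_freq[i].

```python
import math


def precompute_cos_sin(head_dim: int, max_seq_len: int, base: float = 10000.0):
    """cos/sin tables of shape (max_seq_len, head_dim), half-split layout."""
    half = head_dim // 2
    # one frequency per pair of dimensions (i, i + half); base 10000 is an assumption
    inv_freq = [base ** (-2 * i / head_dim) for i in range(half)]
    cos, sin = [], []
    for pos in range(max_seq_len):
        angles = [pos * f for f in inv_freq] * 2  # duplicate so rows span head_dim
        cos.append([math.cos(a) for a in angles])
        sin.append([math.sin(a) for a in angles])
    return cos, sin


def rotate_half(x: list[float]) -> list[float]:
    """[a, b, c, d] -> [-c, -d, a, b]"""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def apply_rope(x: list[float], cos_row: list[float], sin_row: list[float]) -> list[float]:
    """x * cos + rotate_half(x) * sin, for one head vector at one position."""
    return [a * c + b * s for a, b, c, s in zip(x, rotate_half(x), cos_row, sin_row)]


def dot(a: list[float], b: list[float]) -> float:
    return sum(u * v for u, v in zip(a, b))


cos, sin = precompute_cos_sin(head_dim=8, max_seq_len=16)
q = [0.3, -1.2, 0.8, 0.5, 1.1, -0.4, 0.0, 2.0]
k = [1.0, 0.2, -0.7, 0.4, 0.3, 0.9, -1.1, 0.6]

# positions (5, 2) and (9, 6) are both 3 apart, so the scores match
s1 = dot(apply_rope(q, cos[5], sin[5]), apply_rope(k, cos[2], sin[2]))
s2 = dot(apply_rope(q, cos[9], sin[9]), apply_rope(k, cos[6], sin[6]))
print(s1, s2)
```

Because each pair is only rotated, the vector's length never changes, and position 0 (angle 0) leaves it untouched.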


4. attention.py: Causal Self-Attention

What it does: Lets every token look at itself and all previous tokens and decide which ones are relevant for predicting the next token.

The full flow:

Input x: (Batch, Tokens, d_model)
         e.g. (2, 1024, 1024)
    │
    ▼
QKV projection: one big Linear(d_model → 3×d_model)
    │
    ├─── Q: (2, 1024, 1024)   "what am I looking for?"
    ├─── K: (2, 1024, 1024)   "what do I contain?"
    └─── V: (2, 1024, 1024)   "what do I send if attended to?"
    │
    ▼
Reshape to heads: (2, 16_heads, 1024, 64_head_dim)
    │
    ▼
Apply RoPE to Q and K  ← position encoding happens here
    │
    ▼
Scaled Dot-Product Attention:
    scores = Q @ K^T / sqrt(64)    # how strongly each token attends to each other token
    mask   = causal mask            # can only look LEFT (past), not right (future)
    weights = softmax(scores + mask)
    out    = weights @ V            # weighted sum of values
    │
    ▼
Reshape back: (2, 1024, 1024)
    │
    ▼
Output projection: Linear(d_model → d_model)

Causal mask (this is what makes it a language model that predicts the next token):

Position:  0  1  2  3
Token 0:  [✓  ✗  ✗  ✗]   can only see itself
Token 1:  [✓  ✓  ✗  ✗]   can see 0,1
Token 2:  [✓  ✓  ✓  ✗]   can see 0,1,2
Token 3:  [✓  ✓  ✓  ✓]   can see all
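For a single head, the flow and the causal mask can be sketched in plain Python. Rather than adding −∞ to masked scores, this sketch simply skips positions j > i, which gives the same softmax because exp(−∞) = 0. `causal_attention` is a hypothetical O(T²) reference loop for illustration, not how F.scaled_dot_product_attention is implemented.

```python
import math


def causal_attention(q, k, v):
    """Single-head causal attention; q, k, v are lists of T vectors."""
    scale = 1.0 / math.sqrt(len(q[0]))       # 1/sqrt(head_dim), e.g. 1/sqrt(64)
    out = []
    for i in range(len(q)):
        # causal mask: token i only scores positions j <= i
        scores = [scale * sum(a * b for a, b in zip(q[i], k[j])) for j in range(i + 1)]
        m = max(scores)                      # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]  # softmax over the visible positions
        # weighted sum of the visible value vectors
        out.append([sum(w * v[j][c] for j, w in enumerate(weights))
                    for c in range(len(v[0]))])
    return out


q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out = causal_attention(q, k, v)
print(out[0])  # token 0 can only see itself, so this is exactly v[0]
```

Changing the keys or values of the last token leaves the outputs for earlier tokens untouched, which is the "can only look left" rule in action.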

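Flash Attention: We use F.scaled_dot_product_attention(..., is_causal=True), PyTorch 2.0's built-in attention function. On supported GPUs and dtypes it dispatches to a fused kernel (FlashAttention or a memory-efficient variant) that never materializes the full O(T²) attention matrix in memory, which is much faster and uses far less VRAM. When no fused kernel applies, it falls back to a plain math implementation that does build the full matrix.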


5. mlp.py: SwiGLU Feed-Forward Network

What it does: After attention (which mixes information between tokens), the MLP transforms each token independently. It holds about two-thirds of each block's parameters and is often described as where much of the model's "knowledge" is stored.

Standard MLP (GPT-2):

out = W2 @ GELU(W1 @ x)   # 2 matrices

SwiGLU (LLaMA / our model):

gate   = W_gate @ x         # linear
up     = W_up   @ x         # linear
hidden = SiLU(gate) * up    # element-wise gate  ← the key difference
out    = W_down @ hidden     # 3 matrices total

What is SiLU?

SiLU(x) = x × sigmoid(x)

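It's a smooth version of ReLU: close to x for large positive inputs and close to zero for large negative inputs, with a small dip below zero (minimum about -0.28 near x ≈ -1.28). It is exactly zero only at x = 0.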
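Putting the SwiGLU equations and SiLU together, a plain-Python sketch for one token vector might look like this. Matrices are lists of rows and the weight values are made up; `silu`, `matvec` and `swiglu` are illustrative helpers rather than the repo's mlp.py, which presumably uses three linear layers. The last line checks the SiLU claims numerically.

```python
import math


def silu(x: float) -> float:
    """SiLU(x) = x * sigmoid(x), written to avoid overflow for large |x|."""
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)  # x < 0, so e is in (0, 1)
    return x * e / (1.0 + e)


def matvec(w: list[list[float]], x: list[float]) -> list[float]:
    """(rows x cols) matrix times a length-cols vector."""
    return [sum(a * b for a, b in zip(row, x)) for row in w]


def swiglu(x, w_gate, w_up, w_down):
    """out = W_down @ (SiLU(W_gate @ x) * (W_up @ x)) for one token vector."""
    gate = matvec(w_gate, x)                           # d_ff values
    up = matvec(w_up, x)                               # d_ff values
    hidden = [silu(g) * u for g, u in zip(gate, up)]   # element-wise gate
    return matvec(w_down, hidden)                      # back to d_model values


# toy sizes: d_model = 2, d_ff = 3
x = [1.0, -2.0]
w_gate = [[0.5, 0.1], [-0.3, 0.8], [1.0, 1.0]]
w_up = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
w_down = [[1.0, 0.0, 1.0], [0.0, 1.0, -1.0]]
print(swiglu(x, w_gate, w_up, w_down))

# SiLU is 0 at x = 0 and dips to about -0.28 near x = -1.28
print(silu(0.0), min(silu(i / 1000) for i in range(-5000, 1)))
```

If a gate row outputs zero, SiLU(0) = 0 shuts that hidden unit off entirely, no matter what the `up` branch computes.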

Why gating matters:

  • SiLU(gate) acts as a learned on/off switch for each hidden dimension
  • The model learns to activate only the neurons relevant to each input
  • Empirically outperforms GELU at the same parameter count
  • Used in LLaMA, PaLM, Mistral

The d_ff formula:

d_ff = round_up_256( int(2/3 × 4 × d_model) )

For 150M: round_up_256( int(2/3 × 4 × 1024) ) = round_up_256(2730) = 2816

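The 2/3 factor compensates for having 3 matrices instead of 2: 3 × d_model × (8/3 × d_model) = 8 × d_model², the same as a standard 4× FFN's 2 × d_model × (4 × d_model). Rounding up to a multiple of 256 adds a little on top (2816 instead of about 2731 here), so the two counts are roughly, not exactly, equal.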
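The "roughly equal" claim is easy to check with integer arithmetic (weights only; biases, if any, are ignored). `ffn_params` is a hypothetical helper for this sketch.

```python
def round_up(n: int, multiple: int = 256) -> int:
    """Round n up to the next multiple of `multiple`."""
    return ((n + multiple - 1) // multiple) * multiple


def ffn_params(d_model: int) -> tuple[int, int, int]:
    """(d_ff, standard 4x FFN weights, SwiGLU weights), ignoring biases."""
    standard = 2 * d_model * (4 * d_model)          # W1 and W2
    d_ff = round_up(int(2 / 3 * 4 * d_model))
    swiglu = 3 * d_model * d_ff                     # W_gate, W_up and W_down
    return d_ff, standard, swiglu


d_ff, standard, swiglu = ffn_params(1024)
print(d_ff, standard, swiglu, swiglu / standard)  # 2816 8388608 8650752 1.03125
```

For d_model = 1024 the SwiGLU MLP ends up about 3% larger because of the rounding. For the 100M preset, 2/3 × 4 × 768 = 2048 is already a multiple of 256, so the two counts match exactly.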


6. block.py: TransformerBlock

What it does: Wraps attention + MLP into one reusable block. The model is just N copies of this block stacked.

def forward(x):
    # Attention sub-layer
    x = x + attention( rmsnorm(x) )   # pre-norm + residual
    
    # MLP sub-layer  
    x = x + mlp( rmsnorm(x) )         # pre-norm + residual
    
    return x

Two key ideas:

1. Pre-norm (normalize BEFORE the sublayer):

Pre-norm (GPT-2, LLaMA):           x → norm → attention → + original x
Post-norm (original Transformer):  x → attention → + original x → norm

Pre-norm is more stable at large scale: the residual stream itself is never normalized, so gradients flow back through it more cleanly.
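The difference between the two orderings shows up even with a sublayer that outputs zeros: in pre-norm the residual stream passes through unchanged, while post-norm rescales it at every layer. A toy sketch with hypothetical helper names (`pre_norm_step`, `post_norm_step`, `silent`), using RMSNorm with its all-ones initial weight; it illustrates the clean identity path, not a full training comparison.

```python
import math


def rms_norm(x: list[float], eps: float = 1e-6) -> list[float]:
    """RMSNorm with an all-ones weight (its value at initialisation)."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms for v in x]


def pre_norm_step(x, sublayer):
    """Pre-norm: x + sublayer(norm(x)); the residual path itself is never normalised."""
    return [a + b for a, b in zip(x, sublayer(rms_norm(x)))]


def post_norm_step(x, sublayer):
    """Post-norm: norm(x + sublayer(x)); the norm sits on the residual path."""
    return rms_norm([a + b for a, b in zip(x, sublayer(x))])


def silent(v: list[float]) -> list[float]:
    """A toy sublayer that contributes nothing."""
    return [0.0] * len(v)


x = [3.0, -1.0, 2.0, -4.0]
print(pre_norm_step(x, silent))   # [3.0, -1.0, 2.0, -4.0]: x passes through untouched
print(post_norm_step(x, silent))  # a rescaled copy of x
```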

2. Residual connections (x + sublayer(x)): The output of each sublayer is added back to the input, not replacing it. This means:

  • Gradients can skip directly to earlier layers during backprop
  • The model learns corrections to the input, not transformations from scratch
  • Allows stacking many layers without vanishing gradients

7. model.py: SLLM (The Full Model)

What it does: Assembles everything into the complete language model.

tokens: (B, T)  ← integer IDs like [423, 1829, 55, ...]
    │
    ▼
token_emb: Embedding(32000 → 1024)
    │   converts each integer to a 1024-dim vector
    ▼
blocks[0]: TransformerBlock   ─┐
blocks[1]: TransformerBlock    │  9 blocks for 150M
...                            │
blocks[8]: TransformerBlock   ─┘
    │
    ▼
norm: RMSNorm(1024)   ← final stabilization
    │
    ▼
lm_head: Linear(1024 → 32000)
    │   produces a score for each possible next token
    ▼
logits: (B, T, 32000)   ← unnormalized scores

Weight tying: The token_emb matrix and lm_head matrix share the same weights.

self.lm_head.weight = self.token_emb.weight
  • Same matrix used for: embedding lookup (input) AND output projection
  • Saves ~32.8M parameters (32000 × 1024 = 32,768,000)
  • Works because: if token X has a similar embedding to the current hidden state, it should also score highly as the next token prediction
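In PyTorch, the shapes line up because nn.Linear(1024, 32000) stores its weight as a (32000, 1024) matrix, the same shape as the embedding table. A plain-Python sketch of the same idea with a made-up 3-token vocabulary (`tied_logits` is a hypothetical helper): the embedding row of a token doubles as that token's output scoring vector.

```python
def tied_logits(hidden: list[float], emb: list[list[float]]) -> list[float]:
    """LM head with tied weights: the score for token j is dot(hidden, emb[j])."""
    return [sum(h * e for h, e in zip(hidden, row)) for row in emb]


# toy vocabulary of 3 tokens with d_model = 2; emb has shape (vocab_size, d_model)
emb = [[1.0, 0.0],
       [0.0, 1.0],
       [0.7, 0.7]]

hidden = emb[1]                     # input side: embedding lookup for token 1
logits = tied_logits(hidden, emb)   # output side: the same matrix reused as the head
print(logits)                       # [0.0, 1.0, 0.7]

saved = 32_000 * 1024               # a separate lm_head would add this many weights
print(saved)                        # 32768000
```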

Loss computation:

# Cross-entropy: at each position, predict the NEXT token
# Input:  [The, cat, sat, on]   β†’ predicts [cat, sat, on, mat]
# targets = input shifted by 1
loss = cross_entropy(logits.view(-1, 32000), targets.view(-1))
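The shift and the loss can be spelled out for one sequence in plain Python. `cross_entropy` here is a hypothetical helper mirroring what F.cross_entropy computes (mean negative log-likelihood over positions); the token IDs are made up.

```python
import math


def cross_entropy(logits: list[list[float]], targets: list[int]) -> float:
    """Mean negative log-likelihood of targets under softmax(logits), per position."""
    total = 0.0
    for row, t in zip(logits, targets):
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[t]
    return total / len(targets)


tokens = [5, 2, 7, 1, 3]                   # one training sequence of token IDs
inputs, targets = tokens[:-1], tokens[1:]  # predict token t+1 from tokens <= t

vocab_size = 8
uniform = [[0.0] * vocab_size for _ in inputs]  # a model with no preference
print(cross_entropy(uniform, targets))          # equals log(vocab_size)
```

A handy sanity check: a freshly initialized model's logits are close to uniform, so its first loss values should be near log(vocab_size), about 10.37 for 32,000 tokens.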

Gradient checkpointing (enable_gradient_checkpointing()): Normally PyTorch saves all intermediate activations during forward pass to use in backprop. For 9 layers with batch_size=2 and seq_len=1024, that's ~1.5GB.

With gradient checkpointing:

  • Activations are NOT saved during forward pass
  • During backward pass, they are recomputed on-the-fly
  • Result: ~40% less VRAM, ~30% slower training
  • Essential for fitting 150M on a 4GB GPU

Weight initialization:

# All Linear and Embedding weights: Normal(mean=0, std=0.02)
# Residual projections (o_proj, mlp.down): scaled down by 1/sqrt(2 × n_layers)

The residual scaling prevents the residual stream from growing too large at initialization when many layers add to it.
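The two standard deviations can be computed directly. `init_std` is a hypothetical helper; in PyTorch the weights themselves would typically be drawn with torch.nn.init.normal_(weight, mean=0.0, std=...). The factor 2 × n_layers counts the residual additions: each block adds an attention output and an MLP output to the stream.

```python
import math


def init_std(n_layers: int, residual: bool = False, base_std: float = 0.02) -> float:
    """Std of the Normal(0, std) init; residual output projections are scaled down."""
    return base_std / math.sqrt(2 * n_layers) if residual else base_std


n_layers = 9
plain = init_std(n_layers)                  # most Linear / Embedding weights
scaled = init_std(n_layers, residual=True)  # o_proj and mlp.down
print(plain, scaled)

# with the scaled std, 2 * n_layers contributions of variance scaled**2
# add up to plain**2, the variance of a single unscaled contribution
print(2 * n_layers * scaled ** 2, plain ** 2)
```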


How it all fits together: one forward pass

"The cat sat" → tokenizer → [423, 1829, 55]

token_emb:  [423]  → [0.1, -0.3, ...]  (1024 floats)
            [1829] → [0.8, 0.2, ...]   (1024 floats)
            [55]   → [-0.1, 0.4, ...]  (1024 floats)

Block 0:
  norm → Q,K,V projections → RoPE rotation → Flash Attention → output proj → + residual
  norm → gate,up projections → SiLU(gate)*up → down proj → + residual

Block 1..8: same

Final norm → LM head → 32000 scores per position

softmax over the last position's scores → probabilities → sample next token
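Sampling from the last position's scores can be sketched like this. `softmax` and `sample_next` are illustrative helpers, and the `temperature` knob is a common extra that this document does not describe; temperature = 1.0 gives plain softmax sampling.

```python
import math
import random


def softmax(logits: list[float]) -> list[float]:
    """Convert scores to probabilities (max subtracted for numerical stability)."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_next(logits: list[float], temperature: float = 1.0, rng=random) -> int:
    """Sample a token ID from one position's logits."""
    probs = softmax([v / temperature for v in logits])
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


last_position_logits = [2.0, 0.5, -1.0, 3.0]   # toy vocabulary of 4 tokens
print(softmax(last_position_logits))
print(sample_next(last_position_logits, rng=random.Random(0)))
```

The sampled ID is appended to the input and the forward pass repeats, one token at a time.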

Total parameters (150M):

Embedding:   32000 × 1024          =  32.8M
Per block:   attn(4.2M) + mlp(8.65M) + norms(~0M)  =  12.85M
9 blocks:    9 × 12.85M            = 115.6M
Final norm:  1024                  = ~0M
LM head:     TIED to embedding     =   0M  (reuses same weights)
─────────────────────────────────────────
TOTAL:       148.4M params
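The table's arithmetic, and both presets listed under config.py, can be reproduced in a few lines. This sketch assumes no bias terms in any linear layer and tied embeddings, which matches the per-block figures above (4 × 1024² ≈ 4.19M for attention, 3 × 1024 × 2816 ≈ 8.65M for the MLP); `count_params` is a hypothetical helper.

```python
def round_up(n: int, multiple: int = 256) -> int:
    """Round n up to the next multiple of `multiple`."""
    return ((n + multiple - 1) // multiple) * multiple


def count_params(vocab_size: int, d_model: int, n_layers: int) -> int:
    """Parameter count with tied embeddings and no bias terms (assumed)."""
    d_ff = round_up(int(2 / 3 * 4 * d_model))
    embedding = vocab_size * d_model
    attn = 4 * d_model * d_model        # Q, K, V and output projections
    mlp = 3 * d_model * d_ff            # gate, up and down projections
    norms = 2 * d_model                 # two RMSNorm weight vectors per block
    final_norm = d_model
    lm_head = 0                         # tied to the embedding
    return embedding + n_layers * (attn + mlp + norms) + final_norm + lm_head


print(count_params(32_000, 1024, 9))    # 148392960, i.e. ~148.4M (SLLM_150M)
print(count_params(32_000, 768, 12))    # 109529856, i.e. ~109.5M (SLLM_100M)
```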