1.3: The Two Phases Defined — Prefill and Decode
Why Two Phases Exist
Now that you understand the KV cache, we can properly define the two phases of autoregressive inference. These phases aren't arbitrary divisions—they emerge naturally from how the KV cache is built and used.
Prefill: Process the entire input prompt, building the initial KV cache.
Decode: Generate output tokens one at a time, using and extending the KV cache.
These phases have fundamentally different computational characteristics. Understanding why requires looking at exactly what operations happen in each phase.
Phase 1: Prefill
What Happens
The user submits a prompt (e.g., "Explain quantum computing in simple terms"). This prompt gets tokenized into, say, 500 tokens. The prefill phase processes all 500 tokens to:
- Build the KV cache for all 500 positions
- Compute the logits for the first generated token
The Computation
def prefill(prompt_tokens):
    """
    Process the entire prompt in one forward pass.
    prompt_tokens: list of N token IDs
    Returns: logits for the next token, initialized KV cache
    """
    N = len(prompt_tokens)
    # Embed all tokens at once
    hidden_states = embed(prompt_tokens)  # Shape: [N, hidden_dim]
    kv_cache = {}
    for layer_idx, layer in enumerate(transformer_layers):
        # Compute Q, K, V for ALL N tokens simultaneously
        Q = hidden_states @ W_Q  # Shape: [N, num_heads, head_dim]
        K = hidden_states @ W_K  # Shape: [N, num_heads, head_dim]
        V = hidden_states @ W_V  # Shape: [N, num_heads, head_dim]
        # Store K and V in the cache
        kv_cache[layer_idx] = {'K': K, 'V': V}
        # Full attention: each position attends to itself and all previous
        # positions -- an [N, N] attention matrix (with causal mask), per head
        attention_scores = Q @ K.transpose(-1, -2) / sqrt(head_dim)  # [N, N]
        attention_scores = apply_causal_mask(attention_scores)
        attention_weights = softmax(attention_scores)
        attention_output = attention_weights @ V  # [N, head_dim] per head
        # Simplified block: real transformers use two residual connections
        # (one around attention, one around the FFN) with per-sublayer norms
        hidden_states = layer.ffn(layer.norm(attention_output + hidden_states))
    # Logits for the next token (only the last position is needed)
    next_token_logits = hidden_states[-1] @ W_output
    return next_token_logits, kv_cache
Key Characteristics of Prefill
All tokens processed in parallel: Unlike decode where we process one token at a time, prefill processes all N prompt tokens simultaneously. This enables efficient batched matrix multiplications.
Full attention matrix computed: The attention computation involves an [N, N] matrix, where each position computes attention scores against all positions it can see (respecting the causal mask).
KV cache is populated, not read: This is the phase where we build the cache. We compute K and V for all positions and store them. We don't read from a pre-existing cache.
Output is the first generated token: The logits at position N-1 give us the probability distribution for position N (the first token to generate).
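The [N, N] causal attention at the heart of prefill can be sketched concretely with NumPy. This is a minimal single-head illustration with random values (`prefill_attention` is a hypothetical helper written for this sketch, not a library function); real kernels fuse these steps and never materialize the full weight matrix this naively:

```python
import numpy as np

def prefill_attention(Q, K, V):
    """Single-head causal attention over all N prompt tokens at once.

    Q, K, V: [N, d] arrays. Returns the [N, d] attention output and the
    [N, N] weight matrix (returned only so we can inspect the mask).
    """
    N, d = Q.shape
    scores = Q @ K.T / np.sqrt(d)                    # [N, N] scores
    mask = np.triu(np.ones((N, N), dtype=bool), k=1)
    scores[mask] = -np.inf                           # causal mask: no peeking ahead
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)   # row-wise softmax
    return weights @ V, weights

rng = np.random.default_rng(0)
N, d = 6, 8  # 6-token prompt, toy head dimension
out, w = prefill_attention(rng.normal(size=(N, d)),
                           rng.normal(size=(N, d)),
                           rng.normal(size=(N, d)))
print(out.shape)              # (6, 8) -- one output per prompt position
print(np.triu(w, k=1).sum())  # 0.0 -- masked (future) positions get zero weight
```

Note that all six rows of the weight matrix are computed in one matrix multiply; that is the parallelism prefill exploits.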
Visualizing Prefill
PREFILL PHASE
Input: "The cat sat on the mat" (6 tokens)
┌─────────────────────────────────────────────────────────────────┐
│ │
│ Token embeddings (processed in parallel) │
│ ┌─────┬─────┬─────┬─────┬─────┬─────┐ │
│ │ The │ cat │ sat │ on │ the │ mat │ │
│ └──┬──┴──┬──┴──┬──┴──┬──┴──┬──┴──┬──┘ │
│ │ │ │ │ │ │ │
│ ▼ ▼ ▼ ▼ ▼ ▼ │
│ ┌─────────────────────────────────────┐ │
│ │ Transformer Layers (×32) │ │
│ │ │ │
│ │ For each layer: │ │
│ │ • Compute Q, K, V for all 6 tokens │ │
│ │ • Store K, V in cache │ │
│ │ • Compute [6×6] attention matrix │ │
│ │ • Apply FFN │ │
│ └─────────────────────────────────────┘ │
│ │ │ │ │ │ │ │
│ ▼ ▼ ▼ ▼ ▼ ▼ │
│ ┌─────┬─────┬─────┬─────┬─────┬─────┐ │
│ │ h₀ │ h₁ │ h₂ │ h₃ │ h₄ │ h₅ │ Final hidden states │
│ └─────┴─────┴─────┴─────┴──┬──┴─────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ Logits (h₅→vocab)│ │
│ │ Sample: "." │ │
│ └─────────────────┘ │
│ │
│ KV Cache now contains: │
│ ┌────────────────────────────────────────┐ │
│ │ Layer 0: K₀,K₁,K₂,K₃,K₄,K₅ │ V₀...V₅ │ │
│ │ Layer 1: K₀,K₁,K₂,K₃,K₄,K₅ │ V₀...V₅ │ │
│ │ ... │ │
│ │ Layer 31: K₀...K₅ │ V₀...V₅ │ │
│ └────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
Phase 2: Decode
What Happens
After prefill, we have the first generated token and an initialized KV cache. Now we enter a loop: generate one token, append it to the sequence, repeat until we hit a stop condition (max length, EOS token, etc.).
The Computation
def decode_one_token(new_token, kv_cache):
    """
    Process a single new token, using and extending the KV cache.
    new_token: single token ID
    kv_cache: existing cache from prefill or previous decode steps
    Returns: logits for the next token, updated KV cache
    """
    # Embed just the one new token
    hidden_states = embed([new_token])  # Shape: [1, hidden_dim]
    for layer_idx, layer in enumerate(transformer_layers):
        # Compute Q, K, V for ONLY the new token
        Q_new = hidden_states @ W_Q  # Shape: [1, num_heads, head_dim]
        K_new = hidden_states @ W_K  # Shape: [1, num_heads, head_dim]
        V_new = hidden_states @ W_V  # Shape: [1, num_heads, head_dim]
        # Read the cached K and V
        K_cached = kv_cache[layer_idx]['K']  # Shape: [seq_len, num_heads, head_dim]
        V_cached = kv_cache[layer_idx]['V']  # Shape: [seq_len, num_heads, head_dim]
        # Append the new K, V to the cache
        K_full = concat([K_cached, K_new], dim=0)  # Shape: [seq_len+1, ...]
        V_full = concat([V_cached, V_new], dim=0)  # Shape: [seq_len+1, ...]
        kv_cache[layer_idx] = {'K': K_full, 'V': V_full}
        # Attention: Q_new attends to ALL keys (the full sequence).
        # A [1, seq_len+1] computation -- no causal mask needed, because the
        # newest token is allowed to see every earlier position
        attention_scores = Q_new @ K_full.transpose(-1, -2) / sqrt(head_dim)  # [1, seq_len+1]
        attention_weights = softmax(attention_scores)  # [1, seq_len+1]
        attention_output = attention_weights @ V_full  # [1, head_dim] per head
        hidden_states = layer.ffn(layer.norm(attention_output + hidden_states))
    # Logits for the next token
    next_token_logits = hidden_states[0] @ W_output
    return next_token_logits, kv_cache

def generate(prompt_tokens, max_new_tokens):
    """Full generation loop: one prefill, then up to max_new_tokens - 1 decode steps."""
    # Phase 1: Prefill
    logits, kv_cache = prefill(prompt_tokens)
    generated_token = sample(logits)
    output_tokens = [generated_token]
    # Phase 2: Decode loop
    for _ in range(max_new_tokens - 1):
        if generated_token == EOS_TOKEN:  # stop (also catches EOS from prefill)
            break
        logits, kv_cache = decode_one_token(generated_token, kv_cache)
        generated_token = sample(logits)
        output_tokens.append(generated_token)
    return output_tokens
Key Characteristics of Decode
One token at a time: Each decode step processes exactly one token. This is inherently sequential—you can't generate token 5 until you've generated token 4.
Narrow attention computation: Instead of an [N, N] attention matrix, we compute a single [1, seq_len+1] row of attention scores: one query (for the new token) against all keys, including its own.
KV cache is read and extended: We read K and V for all previous positions from the cache, compute K and V for only the new position, and append to the cache.
Repeated many times: For generating G tokens, we run decode G times (after the single prefill). Each iteration depends on the previous one.
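These characteristics can be sketched for a single head with NumPy. This is an illustrative toy (`decode_attention` is a hypothetical helper for this sketch, with random values standing in for real activations), showing the read-and-append cache pattern and the [1, seq_len+1] score row:

```python
import numpy as np

def decode_attention(q_new, k_new, v_new, K_cache, V_cache):
    """One decode step for a single head.

    q_new, k_new, v_new: [1, d] projections for the new token only.
    K_cache, V_cache: [S, d] arrays built by prefill / earlier decode steps.
    Returns the [1, d] attention output and the extended caches.
    """
    K_full = np.concatenate([K_cache, k_new], axis=0)  # [S+1, d] append
    V_full = np.concatenate([V_cache, v_new], axis=0)  # [S+1, d] append
    d = q_new.shape[-1]
    scores = q_new @ K_full.T / np.sqrt(d)   # [1, S+1] -- one query, all keys;
    weights = np.exp(scores - scores.max())  # no causal mask: the newest token
    weights /= weights.sum()                 # may attend to every position
    return weights @ V_full, K_full, V_full

rng = np.random.default_rng(0)
S, d = 6, 8  # 6 cached positions (from prefill), toy head dimension
K_cache, V_cache = rng.normal(size=(S, d)), rng.normal(size=(S, d))
out, K_full, V_full = decode_attention(rng.normal(size=(1, d)),
                                       rng.normal(size=(1, d)),
                                       rng.normal(size=(1, d)),
                                       K_cache, V_cache)
print(out.shape, K_full.shape)  # (1, 8) (7, 8)
```

Compare the shapes with the prefill case: the arithmetic per step shrinks from [N, N] to [1, S+1], but the whole cache must still be read.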
Visualizing Decode
DECODE PHASE (one step)
KV Cache state: contains K,V for positions 0-5 (from prefill)
New token to process: "." (position 6)
┌─────────────────────────────────────────────────────────────────┐
│ │
│ Input: single token "." │
│ ┌─────┐ │
│ │ . │ │
│ └──┬──┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ Transformer Layers (×32) │ │
│ │ │ │
│ │ For each layer: │ │
│ │ ┌─────────────────────────────────────────────────┐ │ │
│ │ │ 1. Compute Q₆, K₆, V₆ (for new token only) │ │ │
│ │ │ │ │ │
│ │ │ 2. READ from cache: K₀...K₅, V₀...V₅ │ │ │
│ │ │ ┌─────────────────────────────┐ │ │ │
│ │ │ │ K_cache: [K₀,K₁,K₂,K₃,K₄,K₅]│ │ │ │
│ │ │ │ V_cache: [V₀,V₁,V₂,V₃,V₄,V₅]│ │ │ │
│ │ │ └─────────────────────────────┘ │ │ │
│ │ │ │ │ │
│ │ │ 3. Attention: Q₆ @ [K₀...K₆]ᵀ → [1×7] scores │ │ │
│ │ │ │ │ │
│ │ │ 4. APPEND to cache: K₆, V₆ │ │ │
│ │ │ ┌────────────────────────────────┐ │ │ │
│ │ │ │ K_cache: [K₀,K₁,K₂,K₃,K₄,K₅,K₆]│ │ │ │
│ │ │ │ V_cache: [V₀,V₁,V₂,V₃,V₄,V₅,V₆]│ │ │ │
│ │ │ └────────────────────────────────┘ │ │ │
│ │ └─────────────────────────────────────────────────┘ │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────┐ │
│ │ h₆ │ Hidden state for position 6 │
│ └──┬──┘ │
│ │ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ Logits (h₆→vocab)│ │
│ │ Sample: "The" │ ← next token to generate │
│ └─────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
This process repeats for each token we generate.
Side-by-Side Comparison
Let's directly compare what happens in each phase:
┌────────────────────────┬─────────────────────────┬─────────────────────────┐
│ Aspect │ PREFILL │ DECODE │
├────────────────────────┼─────────────────────────┼─────────────────────────┤
│ Tokens processed │ All prompt tokens (N) │ One token per step │
│ per forward pass │ │ │
├────────────────────────┼─────────────────────────┼─────────────────────────┤
│ Q vectors computed │ N vectors │ 1 vector │
├────────────────────────┼─────────────────────────┼─────────────────────────┤
│ K, V vectors computed │ N vectors each │ 1 vector each │
├────────────────────────┼─────────────────────────┼─────────────────────────┤
│ Attention matrix shape │ [N, N]                  │ [1, seq_len+1]          │
├────────────────────────┼─────────────────────────┼─────────────────────────┤
│ KV cache operation │ WRITE (initialize) │ READ + APPEND │
├────────────────────────┼─────────────────────────┼─────────────────────────┤
│ Parallelism │ High (all tokens │ Low (sequential │
│ │ processed together) │ dependency) │
├────────────────────────┼─────────────────────────┼─────────────────────────┤
│ Number of times run │ Once per request │ Once per output token │
├────────────────────────┼─────────────────────────┼─────────────────────────┤
│ Can be parallelized │ Yes (within the pass) │ No (token i needs │
│ across tokens? │ │ token i-1 first) │
└────────────────────────┴─────────────────────────┴─────────────────────────┘
The Fundamental Difference: Parallelism
The most important difference between prefill and decode is parallelism.
Prefill: Embarrassingly Parallel
During prefill, all N tokens are independent in the sense that we can compute their embeddings, Q/K/V projections, and feed-forward outputs all at once. The attention computation has dependencies (position i can only attend to positions 0 to i), but modern GPUs handle this efficiently with a single batched operation plus masking.
The GPU sees large matrices and can utilize thousands of cores simultaneously:
- Embedding lookup: [N tokens] → parallel
- Q/K/V projection: [N, hidden_dim] @ [hidden_dim, head_dim] → large matrix multiply
- Attention: [N, N] matrix computation → large batched operation
- FFN: [N, hidden_dim] through feed-forward → large matrix multiply
Decode: Fundamentally Sequential
During decode, we cannot parallelize across tokens because each token depends on the previous one. We must:
- Generate token 1
- Wait for it to complete
- Use it to generate token 2
- Wait for it to complete
- ...and so on
Within a single decode step, we can parallelize across attention heads and within each matrix multiply, but the layers still execute one after another, and the fundamental unit of work is tiny: processing just 1 token.
The GPU sees small matrices:
- Embedding lookup: [1 token] → trivial
- Q/K/V projection: [1, hidden_dim] @ [hidden_dim, head_dim] → tiny matrix multiply
- Attention: [1, seq_len] → small operation
- FFN: [1, hidden_dim] → small matrix multiply
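The gap between the two lists above is easy to quantify for a single projection matmul. The numbers below are illustrative assumptions (a 4096-dimensional model and a 500-token prompt), not any specific model:

```python
# FLOP count of one [rows, h] @ [h, h] projection matmul in each phase.
# hidden = 4096 and N = 500 are illustrative assumptions.
hidden = 4096
N = 500

def matmul_flops(rows):
    # A [rows, h] @ [h, h] multiply does rows * h * h multiply-adds = 2*rows*h*h FLOPs
    return 2 * rows * hidden * hidden

prefill_flops = matmul_flops(N)  # one pass, all 500 prompt tokens
decode_flops = matmul_flops(1)   # one pass, one token

print(f"prefill matmul: {prefill_flops:.2e} FLOPs")
print(f"decode  matmul: {decode_flops:.2e} FLOPs")
print(f"ratio: {prefill_flops // decode_flops}x")
```

Note that the [h, h] weight matrix must be read from memory in both cases; decode does 500x less arithmetic per pass for the same weight traffic, which previews the compute-bound vs memory-bandwidth-bound distinction analyzed in Artifacts 4 and 5.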
This difference in parallelism is the root cause of why prefill and decode have such different performance characteristics—which we'll analyze in detail in Artifacts 4 and 5.
A Timeline View
Let's see how a typical request unfolds over time:
Request: 500 token prompt, generate 200 tokens
Time ──────────────────────────────────────────────────────────────────────►
│◄─── Prefill ───►│◄──────────────── Decode ─────────────────────────────►│
│ │ │
│ Process 500 │ Gen Gen Gen Gen Gen Gen │
│ tokens in │ tok tok tok tok ... tok tok │
│ ONE forward │ 1 2 3 4 199 200 │
│ pass │ │
│ │ ◄──► ◄──► ◄──► ◄──► ◄──► ◄──► │
│ │ Each decode step is a separate forward pass │
│ │ │
│ ~50ms │ ~2000ms │
│ (example) │ (example: 10ms per token × 200 tokens) │
│ │ │
Total time breakdown:
├─ Prefill: ~50ms (2.4% of total time)
├─ Decode: ~2000ms (97.6% of total time)
└─ Total: ~2050ms
Even though prefill processes 500 tokens and decode processes 200 tokens,
decode takes ~40× longer because it runs 200 sequential forward passes
vs. prefill's single forward pass.
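The timeline arithmetic above is worth checking by hand. Using the example figures from the diagram (50ms prefill, 10ms per decode step, which are illustrative numbers, not measurements):

```python
# Reproducing the example timeline arithmetic: 500-token prompt, 200 generated tokens.
prefill_ms = 50      # one forward pass over the whole prompt (example figure)
per_token_ms = 10    # one decode forward pass (example figure)
gen_tokens = 200

decode_ms = per_token_ms * gen_tokens  # 200 sequential forward passes
total_ms = prefill_ms + decode_ms

print(f"decode: {decode_ms} ms ({decode_ms / total_ms:.1%} of total)")
print(f"decode takes {decode_ms // prefill_ms}x longer than prefill")
```

Even with per-pass numbers this favorable to decode, the sequential loop dominates: 200 small passes outweigh one large one by 40x.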
This is the crucial insight: decode dominates wall-clock time even though it processes fewer tokens, because it cannot be parallelized across tokens.
The Questions This Raises
At this point, you might be wondering:
Why exactly is prefill fast despite processing many tokens? What makes those large matrix operations efficient?
Why exactly is decode slow despite processing one token? Shouldn't less work mean less time?
What determines GPU utilization in each phase? Why is prefill efficient but decode inefficient?
The answers involve understanding the difference between being compute-bound (prefill) versus memory-bandwidth-bound (decode). This is what we'll explore in Artifacts 4 and 5.
Summary
Prefill:
- Processes all prompt tokens in one forward pass
- Builds the initial KV cache
- High parallelism → efficient GPU utilization
- Runs once per request
Decode:
- Processes one token per forward pass
- Reads from and appends to KV cache
- Low parallelism → inefficient GPU utilization
- Runs once per generated token (potentially hundreds of times)
The key insight: Despite decode doing "less work" per step, it dominates total inference time because it runs many sequential steps with poor parallelism.
Check Your Understanding
Before moving on:
If you have a 1000-token prompt and generate 100 tokens, how many forward passes happen during prefill? How many during decode?
During decode step 50, what shape is the attention score matrix? What about during prefill with a 1000-token prompt?
Why can't we parallelize across tokens during decode the way we do during prefill?