1.3: The Two Phases Defined — Prefill and Decode

Community Article Published January 26, 2026

Why Two Phases Exist

Now that you understand the KV cache, we can properly define the two phases of autoregressive inference. These phases aren't arbitrary divisions—they emerge naturally from how the KV cache is built and used.

Prefill: Process the entire input prompt, building the initial KV cache.

Decode: Generate output tokens one at a time, using and extending the KV cache.

These phases have fundamentally different computational characteristics. Understanding why requires looking at exactly what operations happen in each phase.


Phase 1: Prefill

What Happens

The user submits a prompt (e.g., "Explain quantum computing in simple terms"). This prompt gets tokenized into, say, 500 tokens. The prefill phase processes all 500 tokens to:

  1. Build the KV cache for all 500 positions
  2. Compute the logits for the first generated token

The Computation

def prefill(prompt_tokens):
    """
    Process entire prompt in one forward pass.
    prompt_tokens: list of N token IDs
    Returns: logits for next token, initialized KV cache
    """
    N = len(prompt_tokens)
    
    # Embed all tokens at once
    hidden_states = embed(prompt_tokens)  # Shape: [N, hidden_dim]
    
    kv_cache = {}
    
    for layer_idx, layer in enumerate(transformer_layers):
        # Compute Q, K, V for ALL N tokens simultaneously
        Q = hidden_states @ W_Q  # Shape: [N, num_heads, head_dim]
        K = hidden_states @ W_K  # Shape: [N, num_heads, head_dim]
        V = hidden_states @ W_V  # Shape: [N, num_heads, head_dim]
        
        # Store K and V in cache
        kv_cache[layer_idx] = {'K': K, 'V': V}
        
        # Full attention: each position attends to itself and all previous.
        # Per head, this is an [N, N] attention matrix (with causal mask).
        attention_scores = Q @ K.transpose(-1, -2) / sqrt(d)  # [N, N] per head
        attention_scores = apply_causal_mask(attention_scores)
        attention_weights = softmax(attention_scores)
        attention_output = attention_weights @ V  # [N, num_heads, head_dim]
        
        # Residual connections (simplified; real blocks also normalize before attention)
        hidden_states = hidden_states + attention_output
        hidden_states = hidden_states + layer.ffn(layer.norm(hidden_states))
    
    # Get logits for next token (only need last position)
    next_token_logits = hidden_states[-1] @ W_output
    
    return next_token_logits, kv_cache

Key Characteristics of Prefill

All tokens processed in parallel: Unlike decode where we process one token at a time, prefill processes all N prompt tokens simultaneously. This enables efficient batched matrix multiplications.

Full attention matrix computed: The attention computation involves an [N, N] matrix, where each position computes attention scores against all positions it can see (respecting the causal mask).

KV cache is populated, not read: This is the phase where we build the cache. We compute K and V for all positions and store them. We don't read from a pre-existing cache.

Output is the first generated token: The logits at position N-1 give us the probability distribution for position N (the first token to generate).
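The characteristics above can be made concrete with a minimal single-head sketch in NumPy. The sizes and weight matrices below are made up for illustration (6 tokens, head dimension 8); the point is that one pass over all tokens produces the full [N, N] causal attention matrix, and that each masked row is a valid distribution over the positions it can see.

```python
# Toy single-head prefill attention, illustrating the [N, N] score matrix.
# Sizes and weights are illustrative, not from any real model.
import numpy as np

rng = np.random.default_rng(0)
N, d = 6, 8  # 6 prompt tokens, head_dim 8

hidden = rng.standard_normal((N, d))
W_Q = rng.standard_normal((d, d))
W_K = rng.standard_normal((d, d))
W_V = rng.standard_normal((d, d))

Q, K, V = hidden @ W_Q, hidden @ W_K, hidden @ W_V   # each [N, d]

scores = Q @ K.T / np.sqrt(d)                        # [N, N], one batched matmul
mask = np.triu(np.ones((N, N), dtype=bool), k=1)     # positions j > i
scores[mask] = -np.inf                               # causal mask

weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)       # row-wise softmax
out = weights @ V                                    # [N, d]

print(scores.shape)  # (6, 6) — the full attention matrix from a single pass
```

Note that position 0 can only attend to itself (its attention weight there is exactly 1), while the last row attends to all six positions; it is that last row's output that feeds the first generated token's logits.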

Visualizing Prefill

                              PREFILL PHASE
                              
    Input: "The cat sat on the mat" (6 tokens)
    
    ┌─────────────────────────────────────────────────────────────────┐
    │                                                                 │
    │   Token embeddings (processed in parallel)                      │
    │   ┌─────┬─────┬─────┬─────┬─────┬─────┐                        │
    │   │ The │ cat │ sat │ on  │ the │ mat │                        │
    │   └──┬──┴──┬──┴──┬──┴──┬──┴──┬──┴──┬──┘                        │
    │      │     │     │     │     │     │                            │
    │      ▼     ▼     ▼     ▼     ▼     ▼                            │
    │   ┌─────────────────────────────────────┐                       │
    │   │     Transformer Layers (×32)        │                       │
    │   │                                     │                       │
    │   │  For each layer:                    │                       │
    │   │  • Compute Q, K, V for all 6 tokens │                       │
    │   │  • Store K, V in cache              │                       │
    │   │  • Compute [6×6] attention matrix   │                       │
    │   │  • Apply FFN                        │                       │
    │   └─────────────────────────────────────┘                       │
    │      │     │     │     │     │     │                            │
    │      ▼     ▼     ▼     ▼     ▼     ▼                            │
    │   ┌─────┬─────┬─────┬─────┬─────┬─────┐                        │
    │   │ h₀  │ h₁  │ h₂  │ h₃  │ h₄  │ h₅  │  Final hidden states   │
    │   └─────┴─────┴─────┴─────┴──┬──┴─────┘                        │
    │                              │                                  │
    │                              ▼                                  │
    │                    ┌─────────────────┐                          │
    │                    │ Logits (h₅→vocab)│                         │
    │                    │ Sample: "."      │                         │
    │                    └─────────────────┘                          │
    │                                                                 │
    │   KV Cache now contains:                                        │
    │   ┌────────────────────────────────────────┐                    │
    │   │ Layer 0: K₀,K₁,K₂,K₃,K₄,K₅ │ V₀...V₅  │                    │
    │   │ Layer 1: K₀,K₁,K₂,K₃,K₄,K₅ │ V₀...V₅  │                    │
    │   │ ...                                    │                    │
    │   │ Layer 31: K₀...K₅ │ V₀...V₅            │                    │
    │   └────────────────────────────────────────┘                    │
    │                                                                 │
    └─────────────────────────────────────────────────────────────────┘

Phase 2: Decode

What Happens

After prefill, we have the first generated token and an initialized KV cache. Now we enter a loop: generate one token, append it to the sequence, repeat until we hit a stop condition (max length, EOS token, etc.).

The Computation

def decode_one_token(new_token, kv_cache):
    """
    Process a single new token, using and extending the KV cache.
    new_token: single token ID
    kv_cache: existing cache from prefill or previous decode steps
    Returns: logits for next token, updated KV cache
    """
    # Embed just the one new token
    hidden_states = embed([new_token])  # Shape: [1, hidden_dim]
    
    for layer_idx, layer in enumerate(transformer_layers):
        # Compute Q, K, V for ONLY the new token
        Q_new = hidden_states @ W_Q  # Shape: [1, num_heads, head_dim]
        K_new = hidden_states @ W_K  # Shape: [1, num_heads, head_dim]
        V_new = hidden_states @ W_V  # Shape: [1, num_heads, head_dim]
        
        # Read cached K and V
        K_cached = kv_cache[layer_idx]['K']  # Shape: [seq_len, num_heads, head_dim]
        V_cached = kv_cache[layer_idx]['V']  # Shape: [seq_len, num_heads, head_dim]
        
        # Append new K, V to cache
        K_full = concat([K_cached, K_new], dim=0)  # Shape: [seq_len+1, ...]
        V_full = concat([V_cached, V_new], dim=0)  # Shape: [seq_len+1, ...]
        kv_cache[layer_idx] = {'K': K_full, 'V': V_full}
        
        # Attention: Q_new attends to ALL keys (full sequence).
        # Per head, this is a [1, seq_len+1] attention computation.
        attention_scores = Q_new @ K_full.transpose(-1, -2) / sqrt(d)  # [1, seq_len+1]
        attention_weights = softmax(attention_scores)  # [1, seq_len+1]
        attention_output = attention_weights @ V_full  # [1, num_heads, head_dim]
        
        # Residual connections (simplified; real blocks also normalize before attention)
        hidden_states = hidden_states + attention_output
        hidden_states = hidden_states + layer.ffn(layer.norm(hidden_states))
    
    # Get logits for next token
    next_token_logits = hidden_states[0] @ W_output
    
    return next_token_logits, kv_cache


def generate(prompt_tokens, max_new_tokens):
    """Full generation loop."""
    # Phase 1: Prefill
    logits, kv_cache = prefill(prompt_tokens)
    generated_token = sample(logits)
    output_tokens = [generated_token]
    
    # Phase 2: Decode loop
    for _ in range(max_new_tokens - 1):
        logits, kv_cache = decode_one_token(generated_token, kv_cache)
        generated_token = sample(logits)
        output_tokens.append(generated_token)
        
        if generated_token == EOS_TOKEN:
            break
    
    return output_tokens

Key Characteristics of Decode

One token at a time: Each decode step processes exactly one token. This is inherently sequential—you can't generate token 5 until you've generated token 4.

Narrow attention computation: Instead of an [N, N] attention matrix, we compute a [1, seq_len] attention vector. Only one query (for the new token) against all keys.

KV cache is read and extended: We read K and V for all previous positions from the cache, compute K and V for only the new position, and append to the cache.

Repeated many times: For generating G tokens, we run decode G times (after the single prefill). Each iteration depends on the previous one.
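The "read and extend" behavior is exactly what makes the KV cache safe to use: a decode step over cached K, V must give the same answer for the newest token as recomputing attention over the whole sequence. A minimal single-head NumPy sketch (made-up sizes, no real model weights) checks that equivalence directly:

```python
# Sketch: a decode step with a KV cache produces the same attention output
# for the newest token as full recomputation over the extended sequence.
import numpy as np

rng = np.random.default_rng(1)
d = 8
hidden_full = rng.standard_normal((7, d))  # positions 0..6 (6 cached + 1 new)
W_Q = rng.standard_normal((d, d))
W_K = rng.standard_normal((d, d))
W_V = rng.standard_normal((d, d))

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# "Prefill" path: full [7, 7] causal attention, keep the last position's output
Q, K, V = hidden_full @ W_Q, hidden_full @ W_K, hidden_full @ W_V
scores = Q @ K.T / np.sqrt(d)
scores[np.triu(np.ones((7, 7), dtype=bool), k=1)] = -np.inf
full_out_last = (softmax(scores) @ V)[-1]

# "Decode" path: K, V for positions 0..5 come from the cache; only
# position 6 is processed fresh, then appended.
K_cached, V_cached = K[:6], V[:6]
q_new = hidden_full[6:7] @ W_Q                 # [1, d]
k_new = hidden_full[6:7] @ W_K
v_new = hidden_full[6:7] @ W_V
K_full = np.concatenate([K_cached, k_new])     # [7, d]
V_full = np.concatenate([V_cached, v_new])
decode_out = (softmax(q_new @ K_full.T / np.sqrt(d)) @ V_full)[0]

print(np.allclose(full_out_last, decode_out))  # True
```

No causal mask is needed in the decode path: the newest token is allowed to attend to every position, so the [1, seq_len+1] score row is already fully visible.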

Visualizing Decode

                              DECODE PHASE (one step)
                              
    KV Cache state: contains K,V for positions 0-5 (from prefill)
    New token to process: "." (position 6)
    
    ┌─────────────────────────────────────────────────────────────────┐
    │                                                                 │
    │   Input: single token "."                                       │
    │   ┌─────┐                                                       │
    │   │  .  │                                                       │
    │   └──┬──┘                                                       │
    │      │                                                          │
    │      ▼                                                          │
    │   ┌─────────────────────────────────────────────────────────┐   │
    │   │              Transformer Layers (×32)                   │   │
    │   │                                                         │   │
    │   │  For each layer:                                        │   │
    │   │  ┌─────────────────────────────────────────────────┐   │   │
    │   │  │ 1. Compute Q₆, K₆, V₆ (for new token only)      │   │   │
    │   │  │                                                  │   │   │
    │   │  │ 2. READ from cache: K₀...K₅, V₀...V₅            │   │   │
    │   │  │    ┌─────────────────────────────┐              │   │   │
    │   │  │    │ K_cache: [K₀,K₁,K₂,K₃,K₄,K₅]│              │   │   │
    │   │  │    │ V_cache: [V₀,V₁,V₂,V₃,V₄,V₅]│              │   │   │
    │   │  │    └─────────────────────────────┘              │   │   │
    │   │  │                                                  │   │   │
    │   │  │ 3. Attention: Q₆ @ [K₀...K₆]ᵀ → [1×7] scores   │   │   │
    │   │  │                                                  │   │   │
    │   │  │ 4. APPEND to cache: K₆, V₆                      │   │   │
    │   │  │    ┌────────────────────────────────┐           │   │   │
    │   │  │    │ K_cache: [K₀,K₁,K₂,K₃,K₄,K₅,K₆]│           │   │   │
    │   │  │    │ V_cache: [V₀,V₁,V₂,V₃,V₄,V₅,V₆]│           │   │   │
    │   │  │    └────────────────────────────────┘           │   │   │
    │   │  └─────────────────────────────────────────────────┘   │   │
    │   └─────────────────────────────────────────────────────────┘   │
    │      │                                                          │
    │      ▼                                                          │
    │   ┌─────┐                                                       │
    │   │ h₆  │  Hidden state for position 6                          │
    │   └──┬──┘                                                       │
    │      │                                                          │
    │      ▼                                                          │
    │   ┌─────────────────┐                                           │
    │   │ Logits (h₆→vocab)│                                          │
    │   │ Sample: "The"    │  ← next token to generate                │
    │   └─────────────────┘                                           │
    │                                                                 │
    └─────────────────────────────────────────────────────────────────┘
    
    This process repeats for each token we generate.

Side-by-Side Comparison

Let's directly compare what happens in each phase:

┌────────────────────────┬─────────────────────────┬─────────────────────────┐
│        Aspect          │        PREFILL          │         DECODE          │
├────────────────────────┼─────────────────────────┼─────────────────────────┤
│ Tokens processed       │ All prompt tokens (N)   │ One token per step      │
│ per forward pass       │                         │                         │
├────────────────────────┼─────────────────────────┼─────────────────────────┤
│ Q vectors computed     │ N vectors               │ 1 vector                │
├────────────────────────┼─────────────────────────┼─────────────────────────┤
│ K, V vectors computed  │ N vectors each          │ 1 vector each           │
├────────────────────────┼─────────────────────────┼─────────────────────────┤
│ Attention matrix shape │ [N, N]                  │ [1, seq_len]            │
├────────────────────────┼─────────────────────────┼─────────────────────────┤
│ KV cache operation     │ WRITE (initialize)      │ READ + APPEND           │
├────────────────────────┼─────────────────────────┼─────────────────────────┤
│ Parallelism            │ High (all tokens        │ Low (sequential         │
│                        │ processed together)     │ dependency)             │
├────────────────────────┼─────────────────────────┼─────────────────────────┤
│ Number of times run    │ Once per request        │ Once per output token   │
├────────────────────────┼─────────────────────────┼─────────────────────────┤
│ Can be parallelized    │ Yes (within the pass)   │ No (token i needs       │
│ across tokens?         │                         │ token i-1 first)        │
└────────────────────────┴─────────────────────────┴─────────────────────────┘
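The table's counts can be sanity-checked with some back-of-the-envelope accounting. The helper below (illustrative only, not a FLOPs model) tallies forward passes and attention-score entries for the running example of a 500-token prompt generating 200 tokens; note that decode step i attends over a sequence of length N + i.

```python
# Rough accounting behind the comparison table: forward passes and
# attention-score entries (per layer, per head) for an N-token prompt
# generating G tokens. Back-of-the-envelope, not a performance model.

def phase_counts(N, G):
    prefill_passes = 1                  # one forward pass for the whole prompt
    decode_passes = G                   # one forward pass per generated token
    prefill_scores = N * N              # the [N, N] attention matrix
    # Decode step i computes a [1, N + i] score row:
    decode_scores = sum(N + i for i in range(1, G + 1))
    return prefill_passes, decode_passes, prefill_scores, decode_scores

print(phase_counts(500, 200))  # (1, 200, 250000, 120100)
```

Decode computes fewer attention scores in total here, yet (as the timeline below shows) takes far longer, because those scores are spread across 200 sequential passes instead of one batched operation.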

The Fundamental Difference: Parallelism

The most important difference between prefill and decode is parallelism.

Prefill: Embarrassingly Parallel

During prefill, all N tokens are independent in the sense that we can compute their embeddings, Q/K/V projections, and feed-forward outputs all at once. The attention computation has dependencies (position i can only attend to positions 0 to i), but modern GPUs handle this efficiently with a single batched operation plus masking.

The GPU sees large matrices and can utilize thousands of cores simultaneously:

  • Embedding lookup: [N tokens] → parallel
  • Q/K/V projection: [N, hidden_dim] @ [hidden_dim, head_dim] → large matrix multiply
  • Attention: [N, N] matrix computation → large batched operation
  • FFN: [N, hidden_dim] through feed-forward → large matrix multiply

Decode: Fundamentally Sequential

During decode, we cannot parallelize across tokens because each token depends on the previous one. We must:

  1. Generate token 1
  2. Wait for it to complete
  3. Use it to generate token 2
  4. Wait for it to complete
  5. ...and so on

Within a single decode step, we can parallelize across layers and heads, but the fundamental unit of work is tiny: processing just 1 token.

The GPU sees small matrices:

  • Embedding lookup: [1 token] → trivial
  • Q/K/V projection: [1, hidden_dim] @ [hidden_dim, head_dim] → tiny matrix multiply
  • Attention: [1, seq_len] → small operation
  • FFN: [1, hidden_dim] → small matrix multiply
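The projection matmuls in the two lists differ only in their first dimension, which makes the per-pass work ratio easy to quantify. A quick multiply-accumulate count (sizes illustrative; hidden_dim of 4096 is typical of a 7B-class model but is an assumption here):

```python
# Per-forward-pass work in a single projection matmul: prefill's [N, h] @ [h, h]
# vs decode's [1, h] @ [h, h]. Sizes are illustrative.
N, hidden_dim = 500, 4096

prefill_macs = N * hidden_dim * hidden_dim   # multiply-accumulates, prefill
decode_macs = 1 * hidden_dim * hidden_dim    # multiply-accumulates, one decode step

print(prefill_macs // decode_macs)  # 500 — prefill does N× the work per pass
```

Crucially, both passes still read the same weight matrix from memory, which is why doing 1/N of the arithmetic does not make a decode step 1/N as expensive.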

This difference in parallelism is the root cause of why prefill and decode have such different performance characteristics—which we'll analyze in detail in Artifacts 4 and 5.


A Timeline View

Let's see how a typical request unfolds over time:

Request: 500 token prompt, generate 200 tokens

Time ──────────────────────────────────────────────────────────────────────►

│◄─── Prefill ───►│◄──────────────── Decode ─────────────────────────────►│
│                 │                                                        │
│  Process 500    │  Gen    Gen    Gen    Gen           Gen    Gen        │
│  tokens in      │  tok    tok    tok    tok    ...    tok    tok        │
│  ONE forward    │   1      2      3      4            199    200        │
│  pass           │                                                        │
│                 │  ◄──►  ◄──►  ◄──►  ◄──►            ◄──►  ◄──►         │
│                 │  Each decode step is a separate forward pass          │
│                 │                                                        │
│    ~50ms        │                    ~2000ms                             │
│   (example)     │     (example: 10ms per token × 200 tokens)            │
│                 │                                                        │

Total time breakdown:
├─ Prefill:  ~50ms   (2.4% of total time)
├─ Decode:   ~2000ms (97.6% of total time)
└─ Total:    ~2050ms

Even though prefill processes 500 tokens and decode processes 200 tokens,
decode takes ~40× longer because it runs 200 sequential forward passes
vs. prefill's single forward pass.

This is the crucial insight: decode dominates wall-clock time even though it processes fewer tokens, because it cannot be parallelized across tokens.
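The timeline arithmetic above fits in a two-line latency model: total time is one prefill pass plus G sequential decode passes. The 50 ms and 10 ms figures are the example numbers from the diagram, not measurements.

```python
# Tiny latency model for the timeline above: one prefill pass plus
# G sequential decode passes. Inputs are the diagram's example numbers.

def total_latency_ms(prefill_ms, per_token_ms, num_generated):
    decode_ms = per_token_ms * num_generated
    return prefill_ms + decode_ms, decode_ms

total, decode = total_latency_ms(50, 10, 200)
print(total, decode)                   # 2050 2000
print(round(decode / total * 100, 1))  # 97.6 — decode's share of wall-clock time
```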


The Questions This Raises

At this point, you might be wondering:

  1. Why exactly is prefill fast despite processing many tokens? What makes those large matrix operations efficient?

  2. Why exactly is decode slow despite processing one token? Shouldn't less work mean less time?

  3. What determines GPU utilization in each phase? Why is prefill efficient but decode inefficient?

The answers involve understanding the difference between being compute-bound (prefill) versus memory-bandwidth-bound (decode). This is what we'll explore in Artifacts 4 and 5.


Summary

Prefill:

  • Processes all prompt tokens in one forward pass
  • Builds the initial KV cache
  • High parallelism → efficient GPU utilization
  • Runs once per request

Decode:

  • Processes one token per forward pass
  • Reads from and appends to KV cache
  • Low parallelism → inefficient GPU utilization
  • Runs once per generated token (potentially hundreds of times)

The key insight: Despite decode doing "less work" per step, it dominates total inference time because it runs many sequential steps with poor parallelism.


Check Your Understanding

Before moving on:

  1. If you have a 1000-token prompt and generate 100 tokens, how many forward passes happen during prefill? How many during decode?

  2. During decode step 50, what shape is the attention score matrix? What about during prefill with a 1000-token prompt?

  3. Why can't we parallelize across tokens during decode the way we do during prefill?
