1.4: Prefill — Computational Deep Dive
The Question We Need to Answer
In Artifact 3, we saw that prefill processes hundreds or thousands of tokens yet completes quickly (e.g., ~50ms for 500 tokens). Meanwhile, decode processes just one token per step yet each step takes ~10ms.
This seems backwards. More tokens should mean more time, right?
To understand why prefill is efficient, we need to understand how GPUs work and what makes some operations fast versus slow. This requires introducing a crucial concept: arithmetic intensity.
How GPUs Actually Work (The Essential Mental Model)
A GPU is fundamentally two things:
1. A massive parallel computer — Thousands of cores that can do arithmetic operations simultaneously. An NVIDIA A100 can perform ~312 trillion floating-point operations per second (312 TFLOPS for FP16).
2. A memory system with limited bandwidth — GPU memory (HBM) can transfer data at ~2 TB/s on an A100. This sounds fast, but it's the bottleneck for many workloads.
Here's the key insight: the GPU can compute much faster than it can read data.
Let's put numbers to this:
A100 GPU Specs:
• Compute throughput: 312 TFLOPS (FP16) = 312 × 10¹² operations/second
• Memory bandwidth: 2 TB/s = 2 × 10¹² bytes/second
Ratio: 312 TFLOPS ÷ 2 TB/s = 156 FLOPs per byte
This means: for every byte you read from memory, you need to perform 156 floating-point operations to keep the compute units fully busy.
If your operation does fewer than 156 FLOPs per byte of data read, the compute units will be sitting idle waiting for data. If it does more than 156 FLOPs per byte, you're fully utilizing the compute capability.
Arithmetic Intensity: The Key Metric
Arithmetic intensity measures how much computation you do per byte of data moved:
Arithmetic Intensity = FLOPs performed / Bytes transferred
= (computation) / (memory traffic)
This single number tells you whether an operation is:
Compute-bound (high arithmetic intensity): The GPU spends most of its time computing, not waiting for data. Compute units are fully utilized. This is the good case.
Memory-bandwidth-bound (low arithmetic intensity): The GPU spends most of its time waiting for data from memory. Compute units are mostly idle. This is the bad case.
The threshold depends on the specific GPU. For an A100:
- Arithmetic intensity > 156 FLOPs/byte → compute-bound
- Arithmetic intensity < 156 FLOPs/byte → memory-bound
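This threshold is just the ratio of the two peak numbers, a quantity sometimes called the "ridge point" in roofline analysis. A minimal sketch, using the A100 figures quoted above:

```python
# Roofline-style classification sketch, using the A100 numbers above.
PEAK_FLOPS = 312e12      # FP16 tensor-core throughput, FLOPs/second
PEAK_BANDWIDTH = 2e12    # HBM bandwidth, bytes/second

# FLOPs/byte needed to keep the compute units saturated
RIDGE_POINT = PEAK_FLOPS / PEAK_BANDWIDTH

def classify(arithmetic_intensity: float) -> str:
    """Label an operation as compute- or memory-bound for this GPU."""
    return "compute-bound" if arithmetic_intensity > RIDGE_POINT else "memory-bound"

print(RIDGE_POINT)        # 156.0
print(classify(1.0))      # memory-bound  (one token through a projection)
print(classify(410.0))    # compute-bound (512 tokens through a projection)
```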
Matrix Multiplication: The Workhorse Operation
Almost everything in a transformer is matrix multiplication:
- Embedding lookup (can be viewed as matmul)
- Q, K, V projections: hidden_states @ W_Q, etc.
- Attention scores: Q @ K.T
- Attention output: attention_weights @ V
- FFN layers: two large matrix multiplications
- Output projection: hidden_states @ W_output
So let's analyze the arithmetic intensity of matrix multiplication.
Case Study: Multiplying Two Matrices
Consider multiplying matrix A [M, K] by matrix B [K, N] to get C [M, N]:
A [M, K] × B [K, N] = C [M, N]
FLOPs:
Each element of C requires K multiplications and K-1 additions ≈ 2K FLOPs
C has M × N elements
Total FLOPs = 2 × M × N × K
Bytes transferred (naive analysis):
Read A: M × K elements × 2 bytes (FP16) = 2MK bytes
Read B: K × N elements × 2 bytes = 2KN bytes
Write C: M × N elements × 2 bytes = 2MN bytes
Total bytes = 2(MK + KN + MN)
Arithmetic Intensity = 2MNK / 2(MK + KN + MN) = MNK / (MK + KN + MN)
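This formula is easy to sketch in code (FP16, 2 bytes per element, the same naive traffic model as above):

```python
def matmul_arithmetic_intensity(M: int, K: int, N: int, bytes_per_elem: int = 2) -> float:
    """Arithmetic intensity of A[M,K] @ B[K,N] under the naive traffic model."""
    flops = 2 * M * N * K                               # one multiply-add per (m, n, k)
    traffic = bytes_per_elem * (M * K + K * N + M * N)  # read A, read B, write C
    return flops / traffic

# One token vs. 512 tokens through a [4096, 4096] projection:
print(round(matmul_arithmetic_intensity(1, 4096, 4096), 2))    # 1.0
print(round(matmul_arithmetic_intensity(512, 4096, 4096), 1))  # 409.6
```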
Let's see how this scales with matrix size:
Small Matrices (Bad)
A [1, 4096] × B [4096, 4096] = C [1, 4096]
(This is like: one token through a projection layer)
FLOPs = 2 × 1 × 4096 × 4096 = 33.5 million
Bytes = 2 × (1×4096 + 4096×4096 + 1×4096) = 33.6 MB
Arithmetic Intensity = 33.5M / 33.6M ≈ 1.0 FLOPs/byte
This is WAY below the 156 FLOPs/byte threshold!
The GPU will be ~99% idle waiting for memory.
Large Matrices (Good)
A [512, 4096] × B [4096, 4096] = C [512, 4096]
(This is like: 512 tokens through a projection layer)
FLOPs = 2 × 512 × 4096 × 4096 = 17.2 billion
Bytes = 2 × (512×4096 + 4096×4096 + 512×4096) = 41.9 MB
Arithmetic Intensity = 17.2B / 41.9M ≈ 410 FLOPs/byte
This is ABOVE the 156 FLOPs/byte threshold!
The GPU compute units will be fully utilized.
The Pattern
┌─────────────────────────────────────────────────────────────────────────┐
│ ARITHMETIC INTENSITY VS BATCH SIZE │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ Arithmetic │
│ Intensity × │
│ (FLOPs/byte) × │
│ ▲ × × × × × × × (saturates) │
│ │ │
│ 156 ├──────────────────────────────────────────────── threshold ───── │
│ │ × (A100) │
│ │ × │
│ │ × │
│ │ × │
│ │× │
│ └──────────────────────────────────────────────────────────► │
│ 1 8 32 64 128 256 512 1024 2048 │
│ Batch size (tokens) │
│ │
│ With 1 token: ~1 FLOP/byte → memory-bound, GPU mostly idle │
│ With 512 tokens: ~410 FLOPs/byte → compute-bound, GPU fully utilized │
│ │
└─────────────────────────────────────────────────────────────────────────┘
This is why batch size matters so much. The same weight matrix B [4096, 4096] gets read from memory either way. With 512 tokens, you amortize that memory read across 512× more computation.
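The saturating curve above can be reproduced with the same naive traffic model. For a [tokens, d] @ [d, d] projection the intensity simplifies to tokens × d / (2 × tokens + d), which saturates toward d/2 = 2048 FLOPs/byte once the weight read is fully amortized:

```python
def projection_intensity(tokens: int, d: int = 4096) -> float:
    """Arithmetic intensity of [tokens, d] @ [d, d] (FP16, naive traffic model)."""
    flops = 2 * tokens * d * d
    traffic = 2 * (tokens * d + d * d + tokens * d)
    return flops / traffic

for n in (1, 8, 32, 64, 128, 256, 512, 2048):
    print(n, round(projection_intensity(n), 1))
# Intensity crosses the A100's 156 FLOPs/byte threshold somewhere between
# 128 and 256 tokens, then saturates toward d/2 = 2048 as the batch grows.
```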
Prefill: Why It's Compute-Bound
During prefill, we process all N prompt tokens together. Let's trace through the operations:
1. Q, K, V Projections
hidden_states: [N, hidden_dim] e.g., [512, 4096]
W_Q: [hidden_dim, head_dim × num_heads] e.g., [4096, 4096]
Q = hidden_states @ W_Q → [512, 4096]
FLOPs: 2 × 512 × 4096 × 4096 = 17.2 billion
Bytes: ~42 MB (dominated by reading W_Q)
Arithmetic Intensity: ~410 FLOPs/byte ✓ COMPUTE-BOUND
Same analysis applies to K and V projections. All compute-bound.
2. Attention Score Computation
Q: [N, num_heads, head_dim] e.g., [512, 32, 128]
K: [N, num_heads, head_dim] e.g., [512, 32, 128]
For each head, compute Q @ K.T:
Q_head: [512, 128]
K_head: [512, 128]
scores = Q_head @ K_head.T → [512, 512]
FLOPs per head: 2 × 512 × 512 × 128 = 67 million
Total FLOPs (32 heads): 2.1 billion
Bytes: small — with a fused attention kernel, Q, K, and the score matrix stay in on-chip memory (registers/shared memory) rather than round-tripping through HBM
Arithmetic Intensity: Very high ✓ COMPUTE-BOUND
3. Attention Output
attention_weights: [N, N] e.g., [512, 512] per head
V: [N, head_dim] e.g., [512, 128] per head
output = attention_weights @ V → [512, 128] per head
Similar analysis: compute-bound due to large batch dimension.
4. FFN Layers
The feed-forward network typically has two linear layers with an expansion factor of 4:
FFN(x) = GELU(x @ W1) @ W2
x: [N, hidden_dim] e.g., [512, 4096]
W1: [hidden_dim, 4×hidden_dim] e.g., [4096, 16384]
W2: [4×hidden_dim, hidden_dim] e.g., [16384, 4096]
First matmul: [512, 4096] @ [4096, 16384]
FLOPs: 2 × 512 × 4096 × 16384 = 68.7 billion
Bytes: ~134 MB (dominated by W1)
Arithmetic Intensity: ~512 FLOPs/byte ✓ COMPUTE-BOUND
Second matmul: similar, also compute-bound.
Summary: All Major Prefill Operations Are Compute-Bound
┌─────────────────────────────────────────────────────────────────────────┐
│ PREFILL OPERATIONS (N = 512 tokens) │
├──────────────────────┬──────────────┬─────────────┬────────────────────┤
│ Operation │ FLOPs │ Bytes │ Arithmetic │
│ │ │ │ Intensity │
├──────────────────────┼──────────────┼─────────────┼────────────────────┤
│ Q projection │ 17.2 B │ 42 MB │ 410 FLOPs/byte ✓ │
│ K projection │ 17.2 B │ 42 MB │ 410 FLOPs/byte ✓ │
│ V projection │ 17.2 B │ 42 MB │ 410 FLOPs/byte ✓ │
│ Q @ Kᵀ (attention) │ 2.1 B │ small │ very high ✓ │
│ softmax @ V │ 2.1 B │ small │ very high ✓ │
│ Output projection │ 17.2 B │ 42 MB │ 410 FLOPs/byte ✓ │
│ FFN layer 1 │ 68.7 B │ 134 MB │ 512 FLOPs/byte ✓ │
│ FFN layer 2 │ 68.7 B │ 134 MB │ 512 FLOPs/byte ✓ │
├──────────────────────┼──────────────┼─────────────┼────────────────────┤
│ Total per layer │ ~210 B │ ~440 MB │ ~480 FLOPs/byte ✓ │
│ All 32 layers │ ~6.7 T │ ~14 GB │ │
└──────────────────────┴──────────────┴─────────────┴────────────────────┘
All operations exceed the 156 FLOPs/byte threshold → GPU is compute-bound
→ High utilization → Efficient!
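The per-layer tally can be sanity-checked with a few lines of arithmetic. A sketch, using the same shapes (N = 512, d = 4096, FP16) — note this traffic model counts inputs, weights, and outputs for every weight matmul, so the byte total lands slightly above the table's rounded figures (which count only W1/W2 for the FFN), but the conclusion is unchanged:

```python
N, D, LAYERS = 512, 4096, 32
BYTES = 2  # FP16 elements

# FLOPs per layer: four [N,D]@[D,D] projections (Q, K, V, output),
# two N x N attention matmuls across all heads, and two FFN matmuls.
proj_flops = 4 * (2 * N * D * D)
attn_flops = 2 * (2 * N * N * D)
ffn_flops  = 2 * (2 * N * D * 4 * D)
layer_flops = proj_flops + attn_flops + ffn_flops

# HBM traffic per layer: inputs + weights + outputs for the weight matmuls;
# the attention matmuls are charged nothing (operands stay on-chip).
proj_bytes = 4 * BYTES * (N * D + D * D + N * D)
ffn_bytes  = 2 * BYTES * (N * D + D * 4 * D + N * 4 * D)
layer_bytes = proj_bytes + ffn_bytes

print(f"per layer: {layer_flops / 1e9:.0f} GFLOPs, {layer_bytes / 1e6:.0f} MB, "
      f"{layer_flops / layer_bytes:.0f} FLOPs/byte")
print(f"all {LAYERS} layers: {LAYERS * layer_flops / 1e12:.1f} TFLOPs")
# per layer: 210 GFLOPs, 478 MB, 440 FLOPs/byte
# all 32 layers: 6.7 TFLOPs
```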
GPU Utilization During Prefill
Because prefill operations are compute-bound, the GPU achieves high utilization:
Prefill GPU Utilization (typical):
┌────────────────────────────────────────────────────────────────┐
│████████████████████████████████████████████████████████████████│
│████████████████████████████████████████████████████████████████│
│██████████████████████████████████████████████████░░░░░░░░░░░░░░│
│████████████████████████████████████████████████████████████████│
└────────────────────────────────────────────────────────────────┘
▲ ▲
0% ~70-85% 100%
The compute units are mostly busy doing useful arithmetic.
Some inefficiency comes from:
• Memory transfers between operations
• Attention masking overhead
• Layer normalization (memory-bound, but small)
• Softmax (memory-bound, but small)
Overall: EFFICIENT
Time Complexity of Prefill
Let's derive how prefill time scales with sequence length.
FLOPs Analysis
For a transformer with:
- N = sequence length (prompt tokens)
- d = hidden dimension
- L = number of layers
Per layer:
Q, K, V projections: 3 × 2Nd² = 6Nd²
Attention (Q @ Kᵀ): 2N²d (for all heads combined, simplified)
Attention (weights @ V): 2N²d
Output projection: 2Nd²
FFN: 2 × 2Nd × 4d = 16Nd²
Total per layer ≈ 24Nd² + 4N²d
For L layers:
Total FLOPs ≈ L × (24Nd² + 4N²d)
= 24LNd² + 4LN²d
For typical models where d is large (4096+), the 24LNd² term dominates for moderate N:
Prefill FLOPs ≈ O(N × d² × L) [linear in N for moderate sequences]
But for very long sequences, the N² attention term matters:
Prefill FLOPs ≈ O(N² × d × L) [quadratic in N for long sequences]
Time Calculation
Prefill time ≈ Total FLOPs / GPU compute throughput
Example: 512 tokens on LLaMA-7B (32 layers, d=4096) on A100
FLOPs ≈ 32 × (24 × 512 × 4096² + 4 × 512² × 4096)
≈ 32 × (206B + 4.3B)
≈ 6.7 trillion FLOPs
A100 throughput: 312 TFLOPS
Time ≈ 6.7T / 312T = ~21 ms
(Real-world: ~30-50ms due to overhead, memory transfers, etc.)
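The derivation can be packaged as a rough estimator. This is an idealized lower bound that assumes 100% of peak FP16 throughput, which is why measured latencies come out higher:

```python
def prefill_flops(n_tokens: int, d_model: int, n_layers: int) -> float:
    """Total prefill FLOPs ~= L * (24*N*d^2 + 4*N^2*d), per the derivation above."""
    return n_layers * (24 * n_tokens * d_model**2 + 4 * n_tokens**2 * d_model)

def prefill_time_ms(n_tokens, d_model=4096, n_layers=32, peak_flops=312e12):
    """Idealized prefill latency, assuming 100% of peak FP16 throughput."""
    return prefill_flops(n_tokens, d_model, n_layers) / peak_flops * 1e3

print(f"{prefill_time_ms(512):.1f} ms")   # 21.6 ms (real-world: ~30-50 ms)
print(f"{prefill_time_ms(4096):.0f} ms")  # the quadratic term starts to matter
```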
Scaling Behavior
┌─────────────────────────────────────────────────────────────────────────┐
│ PREFILL TIME VS SEQUENCE LENGTH │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ Time (ms) │
│ ▲ │
│ │ × │
│ 800 ┤ × │
│ │ × │
│ 600 ┤ × │
│ │ × │
│ 400 ┤ × │
│ │ × │
│ 200 ┤ × │
│ │ × │
│ 50 ┤ × │
│ └──────────────────────────────────────────────────────────► │
│ 512 1K 2K 4K 8K 16K 32K │
│ Sequence length (tokens) │
│ │
│ Roughly linear for short sequences, becomes quadratic for long │
│ sequences (due to N² attention computation) │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Why Parallelism Enables Efficiency
Let's revisit why processing N tokens together is fundamentally different from processing them one at a time.
The Weight Reuse Insight
Consider the Q projection: Q = hidden_states @ W_Q
With N tokens (prefill):
hidden_states: [N, 4096]
W_Q: [4096, 4096]
We read W_Q once (~34 MB in FP16) and use it for N tokens.
Cost per token: 34 MB / N
For N=512: ~66 KB per token
With 1 token (decode):
hidden_states: [1, 4096]
W_Q: [4096, 4096]
We read W_Q once (~34 MB) and use it for 1 token.
Cost per token: 34 MB / 1 = 34 MB per token
The same weight matrix gets read either way, but with batching, we amortize that memory cost across many tokens. This is why arithmetic intensity increases with batch size.
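The amortization arithmetic in two lines (a [4096, 4096] weight matrix is ~34 MB in FP16):

```python
D = 4096
weight_bytes = D * D * 2  # one [4096, 4096] FP16 weight matrix: ~34 MB

def weight_bytes_per_token(batch_tokens: int) -> float:
    """HBM traffic for reading the weight, amortized across the batch."""
    return weight_bytes / batch_tokens

print(f"decode  (1 token):    {weight_bytes_per_token(1) / 1e6:.1f} MB/token")
print(f"prefill (512 tokens): {weight_bytes_per_token(512) / 1e3:.1f} KB/token")
# decode  (1 token):    33.6 MB/token
# prefill (512 tokens): 65.5 KB/token
```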
The Matrix Math View
WEIGHT REUSE VISUALIZATION
Prefill (N=512 tokens):                  Decode (1 token):
hidden_states       W_Q                  hidden_states     W_Q
 [512, 4096]    [4096, 4096]               [1, 4096]   [4096, 4096]
 ┌─────────┐    ┌─────────┐                  ┌─┐       ┌─────────┐
 │░░░░░░░░░│    │█████████│                  │░│       │█████████│
 │░░░░░░░░░│    │█████████│                  └─┘       │█████████│
 │░░░░░░░░░│  @ │█████████│                   @        │█████████│
 │░░░░░░░░░│    │█████████│                            │█████████│
 │░░░░░░░░░│    │█████████│                            │█████████│
 │░░░░░░░░░│    └─────────┘                            └─────────┘
 │░░░░░░░░░│         │                                      │
 │░░░░░░░░░│         ▼                                      ▼
 └─────────┘   Read W_Q once                          Read W_Q once
                 (~34 MB),                              (~34 MB),
               amortized over                         used for only
                 512 tokens                               1 token
               ≈ 66 KB/token                          = 34 MB/token

                     512× better memory efficiency!
Summary: Why Prefill Is Efficient
Large batch size (N tokens): Processing all prompt tokens together creates large matrices.
High arithmetic intensity: Large matrix multiplications do many FLOPs per byte of memory read, exceeding the GPU's compute-to-bandwidth ratio.
Compute-bound operations: The GPU spends its time computing, not waiting for memory. Compute units are fully utilized.
Weight reuse: Each weight matrix is read once and used for N tokens, amortizing memory bandwidth cost.
Parallelism: All N tokens flow through the network together, enabling efficient batched operations.
The result: Despite processing hundreds or thousands of tokens, prefill completes quickly because the GPU is doing what it's designed for—massively parallel matrix math.
Preview: Why Decode Is Different
In Artifact 5, we'll see that decode has the opposite characteristics:
- Batch size of 1 (one new token)
- Low arithmetic intensity
- Memory-bandwidth-bound operations
- Weight matrices read but barely used
- No parallelism across tokens
This explains why decode is slow despite doing "less work" per step—and why most inference optimization research focuses on the decode phase.
Check Your Understanding
Before moving on:
If arithmetic intensity is 10 FLOPs/byte and your GPU's threshold is 156 FLOPs/byte, what percentage of compute capacity are you utilizing?
Why does processing 512 tokens together have higher arithmetic intensity than processing 1 token, even though the weight matrices are the same size?
For a 4096-token prompt vs a 512-token prompt, roughly how much longer should prefill take? (Consider both the linear and quadratic terms.)