2.2c: FlashAttention — IO Analysis and Evolution
What We've Established
In the previous two artifacts, we learned:
- Tiling 2.2a: FlashAttention processes attention in small blocks that fit in SRAM, avoiding the need to materialize the full N×N attention matrix in HBM.
- Online Softmax 2.2b: The mathematical trick that allows softmax to be computed incrementally, making tiling possible while producing exact results.
Now we'll quantify the benefits: How much does FlashAttention actually improve IO? Why is it faster despite doing more FLOPs? And how have subsequent versions (FlashAttention-2, FlashAttention-3) improved further?
IO Complexity Analysis
Let's formally analyze the memory traffic of standard attention versus FlashAttention.
Parameters:
- N: sequence length
- d: head dimension (typically 64-128)
- M: SRAM size (we'll use this to derive optimal block size)
- B_r, B_c: row and column block sizes
Standard Attention IO:
Step 1 (QK^T): Read Q (Nd), Read K (Nd), Write S (N²)
Step 2 (softmax): Read S (N²), Write P (N²)
Step 3 (PV): Read P (N²), Read V (Nd), Write O (Nd)
Total: 4N² + 4Nd elements (multiply by 2 for bytes in FP16)
≈ 4N² elements (since N >> d typically)
= O(N²)
FlashAttention IO:
Outer loop (T_r iterations): Each Q block read once = Nd total for all Q
Inner loop (T_r × T_c iterations):
- Each K block read T_r times = Nd × T_r = Nd × (N/B_r) = N²d/B_r
- Each V block read T_r times = N²d/B_r
Output writes: Nd total
Total: Nd + 2N²d/B_r + Nd
= 2Nd + 2N²d/B_r
= O(N²d/B_r)
The key question: What's the optimal B_r?
The constraint is that everything must fit in SRAM of size M:
- Q block: B_r × d
- K block: B_c × d
- V block: B_c × d
- S block: B_r × B_c
- O block: B_r × d
- Statistics: B_r (for m and l)
For simplicity, assume B_r = B_c = B and the dominant terms are Q, K, V, S, O blocks:
SRAM usage ≈ B×d + 2B×d + B² + B×d = 4Bd + B²
Setting this equal to M and solving for B (assuming B² term dominates for large B):
B² ≈ M → B ≈ √M
But if we want to maximize B_r specifically (to minimize K,V re-reads), and B_c can be smaller:
With careful analysis: B_r = O(M/d), B_c = O(d)
This gives FlashAttention IO complexity:
O(N²d/B_r) = O(N²d/(M/d)) = O(N²d²/M)
The FlashAttention paper states the complexity as:
FlashAttention IO = O(N²d²/M)
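The two traffic formulas derived above can be sketched numerically. This is a check of the element-count formulas, not a measurement; the block size B_r is left as a free parameter:

```python
def standard_attention_io(N: int, d: int) -> int:
    """HBM traffic of standard attention, in elements: 4N^2 + 4Nd."""
    qk = 2 * N * d + N * N       # read Q and K, write S
    softmax = 2 * N * N          # read S, write P
    pv = N * N + 2 * N * d       # read P and V, write O
    return qk + softmax + pv

def flash_attention_io(N: int, d: int, B_r: int) -> int:
    """HBM traffic of FlashAttention, in elements: 2Nd + 2N^2 d / B_r."""
    q_and_o = 2 * N * d          # each Q block read once, O written once
    kv = 2 * N * N * d // B_r    # K and V fully re-read once per outer iteration
    return q_and_o + kv

N, d = 4096, 128
for B_r in (128, 256, 512):
    ratio = standard_attention_io(N, d) / flash_attention_io(N, d, B_r)
    print(B_r, round(ratio, 1))  # reduction approaches 2*B_r/d for large N
```

The loop makes the dependence on block size explicit: doubling B_r roughly doubles the IO reduction, because K and V are re-read half as often.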
Comparing IO: Concrete Numbers
Let's calculate actual byte counts for realistic parameters:
Parameters: d = 128, M = 192 KB (A100 SRAM per SM), FP16. B_r ≈ M/(4d) = 192KB/(4×128×2B) ≈ 192 elements (practical choice: 128)
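The back-of-envelope block size above works out as follows (FP16 at 2 bytes/element assumed):

```python
M_bytes = 192 * 1024     # A100 on-chip SRAM per SM
d = 128                  # head dimension
elem_bytes = 2           # FP16

# The M/(4d) budget from the SRAM-usage analysis, in elements
B_r = M_bytes // (4 * d * elem_bytes)
print(B_r)  # 192; kernels typically round down to a power of two such as 128
```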
| Seq Len (N) | Standard IO (bytes) | FlashAttention IO (bytes) | Reduction Factor |
|---|---|---|---|
| 1,024 | 8.4 MB | 1.1 MB | 7.6× |
| 2,048 | 33.6 MB | 4.2 MB | 8.0× |
| 4,096 | 134.2 MB | 16.8 MB | 8.0× |
| 8,192 | 536.9 MB | 67.1 MB | 8.0× |
| 16,384 | 2,147.5 MB | 268.4 MB | 8.0× |
| 32,768 | 8,590.0 MB | 1,073.7 MB | 8.0× |
| 65,536 | 34,359.7 MB | 4,295.0 MB | 8.0× |
Calculation for N=4096:
- Standard: 4 × 4096² × 2 bytes ≈ 134 MB
- FlashAttention: 2 × 4096² × 128/B_r × 2 bytes ≈ 17 MB (the table's numbers correspond to an effective B_r ≈ 4d = 512, within the B_r = O(M/d) scaling derived above)
Note: for large N the reduction factor is 4N² / (2N²d/B_r) = 2B_r/d — a constant set by the block size and head dimension, which is why the table settles at ≈8× once N >> d.
Key insight: The IO reduction is a constant factor, but the absolute traffic saved grows quadratically with N — and the O(N²) → O(N) memory-footprint advantage grows with N as well — so longer sequences benefit most in practice.
Why FlashAttention is Faster Despite More FLOPs
Here's a counterintuitive fact: FlashAttention actually does MORE floating-point operations than standard attention, yet it's faster. How?
FLOPs comparison:
Standard Attention FLOPs:
- QK^T: 2N²d
- softmax: 5N² (approx)
- PV: 2N²d
- Total: ~4N²d + 5N² ≈ 4N²d
FlashAttention FLOPs:
- Same matmuls: 4N²d
- Same softmax: 5N²
- ADDITIONAL: rescaling operations in online softmax
- Each block: O(B_r) rescaling ops
- Total blocks: O(N²/(B_r·B_c))
- Total extra: O(B_r) × O(N²/(B_r·B_c)) = O(N²/B_c) ≈ O(N²/d), since B_c = O(d)
- Total: ~4N²d + 5N² + O(N²) ≈ 4N²d (same order!)
FlashAttention does ~same FLOPs, with small constant factor increase.
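A quick count makes "same order, small constant" concrete. This sketch uses the rough estimates from above (5N² for softmax, O(N²/B_c) total rescaling ops):

```python
def base_flops(N: int, d: int) -> int:
    # QK^T and PV matmuls, plus the rough 5N^2 softmax estimate
    return 2 * N * N * d + 5 * N * N + 2 * N * N * d

def rescale_overhead(N: int, B_c: int) -> int:
    # O(B_r) rescaling ops per block, N^2/(B_r*B_c) blocks -> N^2/B_c total
    return N * N // B_c

N, d, B_c = 4096, 128, 128
frac = rescale_overhead(N, B_c) / base_flops(N, d)
print(f"extra FLOPs from rescaling: {frac:.6%}")  # well under 0.01% of the total
```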
So why is it faster?
The answer: arithmetic intensity determines whether an operation is memory-bound or compute-bound, and FlashAttention changes which side of that line attention falls on.
Standard Attention
| Metric | Value |
|---|---|
| FLOPs | ~4N²d ≈ 4 × 4096² × 128 = 8.6 GFLOP |
| Bytes | ~4N² × 2 = 134 MB |
| Arithmetic intensity | 8.6 GFLOP / 134 MB = 64 FLOP/byte |
| vs A100 ridge point | Below (156 FLOP/byte) |
| Bound type | MEMORY-BOUND — limited by bandwidth, not compute |
| Time | 134 MB / 2000 GB/s = 0.067 ms |
FlashAttention
| Metric | Value |
|---|---|
| FLOPs | ~4N²d = 8.6 GFLOP (same) |
| Bytes | ~17 MB (8× less) |
| Arithmetic intensity | 8.6 GFLOP / 17 MB = 506 FLOP/byte |
| vs A100 ridge point | Above (156 FLOP/byte) |
| Bound type | COMPUTE-BOUND — limited by compute, using GPU efficiently |
| Time | 8.6 GFLOP / 312 TFLOP/s = 0.028 ms |
Speedup: 0.067 / 0.028 ≈ 2.4× faster. By reducing bytes transferred, the operation moved from memory-bound to compute-bound, allowing the GPU's full compute capability to be utilized.
This is the fundamental lesson: For memory-bound operations, reducing data movement is more valuable than reducing computation. FlashAttention trades a small increase in FLOPs for a large decrease in memory traffic, resulting in a net speedup.
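The two tables above fall out of a two-line roofline time model, using the A100 figures quoted throughout this section (312 TFLOP/s FP16 peak, 2000 GB/s HBM bandwidth):

```python
PEAK_FLOPS = 312e12   # A100 FP16 Tensor Core peak, FLOP/s
BANDWIDTH = 2000e9    # HBM bandwidth, bytes/s

def roofline_time_s(flops: float, bytes_moved: float) -> float:
    # An op can finish no faster than its compute time or its memory time.
    return max(flops / PEAK_FLOPS, bytes_moved / BANDWIDTH)

flops = 4 * 4096**2 * 128              # ~8.6 GFLOP, identical for both variants
t_std = roofline_time_s(flops, 134e6)  # standard attention: 134 MB of traffic
t_fa = roofline_time_s(flops, 17e6)    # FlashAttention: ~17 MB of traffic
print(f"{t_std*1e3:.3f} ms vs {t_fa*1e3:.3f} ms, "
      f"speedup {t_std/t_fa:.1f}x")    # 0.067 ms vs 0.028 ms, speedup 2.4x
```

For standard attention the memory term dominates (memory-bound); for FlashAttention the compute term dominates (compute-bound), reproducing the 2.4× figure above.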
Roofline Visualization
Let's see how FlashAttention moves the operation on the roofline:
Standard Attention (memory-bound)
| Metric | Value |
|---|---|
| Arithmetic intensity | 64 FLOP/byte |
| Position | Below ridge point (156 FLOP/byte) |
| Achieved throughput | ~128 TFLOPS |
| Bottleneck | Memory bandwidth |
FlashAttention (compute-bound)
| Metric | Value |
|---|---|
| Arithmetic intensity | 506 FLOP/byte |
| Position | Above ridge point (156 FLOP/byte) |
| Achieved throughput | ~312 TFLOPS (theoretical max) |
| Bottleneck | Compute |
A100 Roofline: Ridge point = 156 FLOP/byte. Compute ceiling = 312 TFLOPS. Memory bandwidth = 2000 GB/s. FlashAttention moves attention from the sloped (memory-limited) region to the flat (compute-limited) region of the roofline.
Recomputation in the Backward Pass
FlashAttention makes an interesting tradeoff in the backward pass (for training): instead of storing the attention matrix P for use in the backward pass, it recomputes P from Q, K, and the stored statistics (m, l).
Why this makes sense:
Standard approach:
Forward: Compute P, store it (O(N²) memory)
Backward: Read stored P, compute gradients
FlashAttention approach:
Forward: Compute P in blocks, store only m, l (O(N) memory)
Backward: Recompute P in blocks from Q, K, m, l
The recomputation costs extra FLOPs, but:
- Saves O(N²) memory (critical for training with limited GPU memory)
- The recomputation is also done in SRAM, so it's fast
- The alternative (storing P) would be memory-bound anyway
The tradeoff:
| | A: Store P (Standard) | B: Recompute P (FlashAttention) |
|---|---|---|
| Memory | O(N²) for P | O(N) for statistics |
| Forward FLOPs | 4N²d | 4N²d |
| Backward FLOPs | ~4N²d | ~8N²d (recompute P + gradients) |
| Total FLOPs | ~8N²d | ~12N²d |
| Tradeoff | O(N²) memory limits batch size and sequence length | O(N) memory allows larger batches/sequences. Often FASTER because larger batch → better utilization, and recomputation is compute-bound |
This illustrates a general principle: with modern GPUs, computation is often "free" compared to memory access. Trading memory for compute is frequently a good deal.
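The memory side of this tradeoff is easy to quantify. A sketch, per attention head, assuming FP16 for P and FP32 for the m and l statistics:

```python
def stored_bytes_standard(N: int, p_bytes: int = 2) -> int:
    # Store the full N x N attention matrix P for the backward pass.
    return N * N * p_bytes

def stored_bytes_flash(N: int, stat_bytes: int = 4) -> int:
    # Store only the per-row statistics m and l.
    return 2 * N * stat_bytes

for N in (4096, 32768):
    saving = stored_bytes_standard(N) // stored_bytes_flash(N)
    print(N, f"{saving}x less activation memory")  # 1024x at N=4096, 8192x at N=32768
```

Unlike the IO reduction, this saving grows linearly with N, which is exactly why recomputation unlocks long contexts.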
FlashAttention-2: Optimizations
FlashAttention-2 (2023) introduced several improvements over the original FlashAttention, achieving 2× speedup:
1. Reduced Non-Matmul FLOPs
The original FlashAttention had significant overhead from non-matrix-multiplication operations (rescaling, statistics updates). FlashAttention-2 optimized these:
FlashAttention-1:
- Rescaling O_i after every KV block
- Multiple synchronization points
FlashAttention-2:
- Delayed rescaling: accumulate first, rescale once at the end
- Fewer synchronization points
- Better instruction scheduling
2. Better Parallelization
| | FlashAttention-1 | FlashAttention-2 |
|---|---|---|
| Parallelizes over | Batch and heads only | Batch, heads, and sequence length |
| Grid dimensions | (batch_size, num_heads, 1) | (batch_size, num_heads, num_q_blocks) |
| Example (batch=1, heads=8, seq=4096, block=128) | 8 thread blocks → massive underutilization (A100 has 108 SMs) | 8 × 32 = 256 thread blocks → good GPU utilization |
| Key insight | — | Different query blocks are independent and can be processed in parallel |
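The parallelization difference in the table is just a grid-size computation. A sketch (real kernels also handle partial blocks and other grid dimensions):

```python
import math

def grid_fa1(batch: int, heads: int) -> int:
    # FlashAttention-1: one thread block per (batch, head) pair
    return batch * heads

def grid_fa2(batch: int, heads: int, seq_len: int, block_q: int) -> int:
    # FlashAttention-2: additionally one thread block per query tile
    return batch * heads * math.ceil(seq_len / block_q)

print(grid_fa1(1, 8), grid_fa2(1, 8, 4096, 128))  # 8 vs 256 thread blocks
```

With only 8 thread blocks, most of an A100's 108 SMs sit idle; 256 blocks give every SM work.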
3. Improved Work Partitioning
FlashAttention-2 swaps the loop order for the backward pass:
FlashAttention-1 backward: For each KV block, iterate over Q blocks
FlashAttention-2 backward: For each Q block, iterate over KV blocks
The second order has better parallelism and memory access patterns.
4. Better Warp-Level Optimizations
More efficient use of warp shuffles and shared memory access patterns, reducing bank conflicts.
Result: FlashAttention-2 achieves ~2× speedup over FlashAttention-1, reaching 50-73% of theoretical max throughput (compared to 25-40% for FlashAttention-1).
FlashAttention-3: Hopper Architecture
FlashAttention-3 (2024) is specifically optimized for NVIDIA H100 (Hopper architecture), exploiting new hardware features:
1. Warp Specialization
Hopper supports "warp specialization" where different warps within a thread block can perform different tasks asynchronously:
Previous approach (all warps do same work):
- All warps execute Load → Compute → Store sequentially
- All warps wait for memory at the same time → stalls
FlashAttention-3 (specialized warps, pipelined):
| Time Step | Producer Warps | Consumer Warps |
|---|---|---|
| T1 | Load K₀V₀ | — |
| T2 | Load K₁V₁ | Compute₀ |
| T3 | Load K₂V₂ | Compute₁ |
| T4 | Load K₃V₃ | Compute₂ |
While consumers compute on K₀V₀, producers load K₁V₁ — memory latency is hidden behind computation.
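A toy timing model shows why the producer/consumer split helps. This assumes one load can be in flight ahead of compute; times are arbitrary units, not H100 measurements:

```python
def sequential_time(n_blocks: int, t_load: float, t_compute: float) -> float:
    # All warps load, then compute, for every block: nothing overlaps.
    return n_blocks * (t_load + t_compute)

def pipelined_time(n_blocks: int, t_load: float, t_compute: float) -> float:
    # Only the first load and last compute are exposed; in between,
    # each load hides behind the previous block's compute (or vice versa).
    return t_load + (n_blocks - 1) * max(t_load, t_compute) + t_compute

n = 4  # the four KV blocks from the table above
print(sequential_time(n, 1.0, 1.0), pipelined_time(n, 1.0, 1.0))  # 8.0 vs 5.0
```

When load and compute times are balanced, pipelining approaches a 2× reduction as the number of blocks grows.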
2. Asynchronous Memory Operations (TMA)
H100's Tensor Memory Accelerator (TMA) enables asynchronous bulk data transfers:
Previous: Load data → wait → compute → wait → store
H100 TMA: Issue load → compute on previous data → load completes in background
This hides memory latency much more effectively.
3. FP8 Support
H100 has native FP8 Tensor Cores, which FlashAttention-3 exploits:
| Precision | Tensor Core TFLOPS | Memory per Element |
|---|---|---|
| FP16 | 990 | 2 bytes |
| FP8 | 1,979 | 1 byte |
| FP8 advantage | 2× compute | 2× less memory |
Combined effect: FP8 halves the bytes per element while the operation count is unchanged, so arithmetic intensity doubles; combined with the 2× Tensor Core throughput, the potential speedup is up to 4×. FlashAttention-3 with FP8 achieves close to 1.2 PFLOPS (1,200 TFLOPS) on H100 — unprecedented for attention.
4. Block Quantization for Accuracy
Pure FP8 can have accuracy issues. FlashAttention-3 uses "block quantization":
- Compute attention scores in higher precision
- Quantize to FP8 for the matmul with V
- Use per-block scaling factors to maintain accuracy
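A toy version of per-block scaling — illustrative only, not the FlashAttention-3 kernel; 448 is the largest finite value in the FP8 E4M3 format:

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite value representable in FP8 E4M3

def quantize_block(x: np.ndarray):
    # Scale the block so its largest magnitude maps onto the FP8 range,
    # keeping the per-block scale factor for later dequantization.
    amax = float(np.abs(x).max())
    scale = amax / FP8_E4M3_MAX if amax > 0 else 1.0
    q = x / scale             # on H100 this would now be cast to FP8
    return q, scale

def dequantize_block(q: np.ndarray, scale: float) -> np.ndarray:
    return q * scale

block = np.random.randn(64, 64).astype(np.float32)
q, s = quantize_block(block)
assert np.allclose(dequantize_block(q, s), block, rtol=1e-6)
```

The per-block scale is what preserves accuracy: one outlier only distorts its own block, not the whole tensor.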
Result: FlashAttention-3 achieves 1.5-2× speedup over FlashAttention-2 on H100, reaching 75%+ of theoretical throughput.
Performance Summary Across Versions
Version Comparison
| Version | GPU | Peak Efficiency | Key Innovation |
|---|---|---|---|
| Standard Attention | A100 | ~25% | (baseline) |
| FlashAttention-1 | A100 | 25-40% | Tiling + online softmax |
| FlashAttention-2 | A100 | 50-73% | Better parallelism, fewer non-matmul ops |
| FlashAttention-3 | H100 | 75%+ | Warp specialization, TMA, FP8 |
Typical Speedups (forward pass, seq_len=4096)
| Transition | Speedup |
|---|---|
| Standard → FlashAttention-1 | 2-4× faster |
| FlashAttention-1 → 2 | ~2× faster |
| FlashAttention-2 → 3 (H100) | 1.5-2× faster |
| Standard → FlashAttention-3 | up to 16× faster |
Memory Reduction (all versions)
Standard: O(N²) → FlashAttention: O(N)
This enables:
- 128K+ context windows (impossible with standard attention)
- Larger batch sizes for training
- Lower memory requirements for inference
When Does FlashAttention Help Most?
FlashAttention's benefits vary by use case:
Maximum Benefit
- Long sequences (N > 2048) — O(N²) → O(N) memory matters more; IO reduction factor grows with N
- Memory-constrained scenarios — Training with limited GPU memory; inference with large KV caches
- Prefill phase (many queries) — Original FlashAttention parallelizes over queries; high utilization during prompt processing
Moderate Benefit
- Short sequences (N < 512) — O(N²) memory is manageable; overhead of tiling may not be fully amortized
- Already compute-bound scenarios — Large batch sizes; already high arithmetic intensity
Special Case — Decode Phase
- Original FlashAttention has limited benefit for decode
- Only 1 query token → limited parallelism over queries
- Need Flash Decoding for efficient decode
Summary: What FlashAttention Achieves
1. IO Complexity Reduction:
- Standard: O(N²) HBM accesses
- FlashAttention: O(N²d²/M) HBM accesses (M = SRAM size) — a constant factor of M/d² fewer than standard attention
- Practical reduction: 4-16× less memory traffic for typical sequence lengths
2. Memory Complexity Reduction:
- Standard: O(N²) storage for attention matrix
- FlashAttention: O(N) storage (just statistics)
- This enables 100K+ context windows
3. Speed Improvement:
- Moves attention from memory-bound to compute-bound regime
- 2-4× speedup over standard attention (FlashAttention-1)
- Up to 16× with latest optimizations (FlashAttention-3 on H100)
4. Exactness:
- FlashAttention computes the exact same result as standard attention
- Online softmax is mathematically equivalent, not an approximation
5. Key Insight:
- By reducing IO, FlashAttention increases arithmetic intensity
- This shifts the operation from memory-bound to compute-bound
- More FLOPs can mean faster execution when IO is the bottleneck
Check Your Understanding
- If SRAM size doubles (from 192KB to 384KB), what happens to FlashAttention's IO complexity? Does this make it faster?
- FlashAttention-2 achieves 50-73% of theoretical GPU throughput, while standard attention achieves ~25%. Where does the remaining 27-50% go? (What prevents 100% utilization?)
- Why does FlashAttention-3 require Hopper (H100) architecture? Could the same optimizations be implemented on A100?
- For training, FlashAttention recomputes the attention matrix in the backward pass instead of storing it. Under what conditions would storing the matrix actually be faster? (Hint: think about memory bandwidth vs compute tradeoffs)
- A model uses GQA with 8 KV heads and 64 query heads. How does this affect FlashAttention's performance compared to standard MHA? (Hint: think about what gets re-read in the inner loop)

