2.2c: FlashAttention — IO Analysis and Evolution

Community Article Published February 9, 2026

What We've Established

In the previous two artifacts, we learned:

  • Tiling (2.2a): FlashAttention processes attention in small blocks that fit in SRAM, avoiding the need to materialize the full N×N attention matrix in HBM.
  • Online softmax (2.2b): the mathematical trick that allows softmax to be computed incrementally, making tiling possible while producing exact results.

Now we'll quantify the benefits: How much does FlashAttention actually improve IO? Why is it faster despite doing more FLOPs? And how have subsequent versions (FlashAttention-2, FlashAttention-3) improved further?


IO Complexity Analysis

Let's formally analyze the memory traffic of standard attention versus FlashAttention.

Parameters:

  • N: sequence length
  • d: head dimension (typically 64-128)
  • M: SRAM size (we'll use this to derive optimal block size)
  • B_r, B_c: row and column block sizes

Standard Attention IO:

Step 1 (QK^T):     Read Q (Nd), Read K (Nd), Write S (N²)
Step 2 (softmax):  Read S (N²), Write P (N²)
Step 3 (PV):       Read P (N²), Read V (Nd), Write O (Nd)

Total: 4N² + 4Nd elements (multiply by 2 bytes each for FP16)
     ≈ 4N² elements (since N >> d typically)
     = O(N²)

FlashAttention IO:

Outer loop (T_r iterations): Each Q block read once = Nd total for all Q
Inner loop (T_r × T_c iterations):
  - Each K block read T_r times = Nd × T_r = Nd × (N/B_r) = N²d/B_r
  - Each V block read T_r times = N²d/B_r
Output writes: Nd total

Total: Nd + 2N²d/B_r + Nd
     = 2Nd + 2N²d/B_r
     = O(N²d/B_r)
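As a sanity check, the element counts above can be turned into a few lines of Python. This is a sketch under the FP16 (2 bytes/element) assumption; the function names are mine, not from the paper.

```python
BYTES = 2  # FP16: 2 bytes per element

def standard_io(N, d):
    """Read Q, K (2Nd); write+read S (2N^2); write+read P (2N^2); read V, write O (2Nd)."""
    return (4 * N * N + 4 * N * d) * BYTES

def flash_io(N, d, B_r):
    """Q read once, O written once (2Nd); K and V each re-read once per Q block (N/B_r passes)."""
    return (2 * N * d + 2 * N * N * d // B_r) * BYTES

N, d = 4096, 128
for B_r in (128, 256, 512):
    s, f = standard_io(N, d), flash_io(N, d, B_r)
    print(f"B_r={B_r}: {s / 2**20:.1f} MiB vs {f / 2**20:.1f} MiB "
          f"(ratio {s / f:.1f}x, asymptotically 2*B_r/d = {2 * B_r / d:.0f}x)")
```

Note that the ratio depends on B_r and d but not on N, which matches the roughly constant reduction factors in the table below.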

The key question: What's the optimal B_r?

The constraint is that everything must fit in SRAM of size M:

  • Q block: B_r × d
  • K block: B_c × d
  • V block: B_c × d
  • S block: B_r × B_c
  • O block: B_r × d
  • Statistics: B_r (for m and l)

For simplicity, assume B_r = B_c = B and the dominant terms are Q, K, V, S, O blocks:

SRAM usage ≈ B×d + 2B×d + B² + B×d = 4Bd + B²

Setting this equal to M and solving for B (assuming B² term dominates for large B):

B² ≈ M  →  B ≈ √M
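The quadratic 4Bd + B² = M can also be solved exactly rather than asymptotically. A quick check, treating M as an SRAM capacity in FP16 elements (the function name and that interpretation are my assumptions):

```python
import math

def max_block_size(M_elems, d):
    """Largest integer B satisfying B**2 + 4*B*d <= M_elems."""
    # Positive root of B^2 + 4*d*B - M = 0.
    return int((-4 * d + math.sqrt(16 * d * d + 4 * M_elems)) / 2)

# 192 KB of SRAM holds 98,304 FP16 elements; with d = 128:
print(max_block_size(192 * 1024 // 2, d=128))  # -> 148, close to the practical choice of 128
```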

But if we want to maximize B_r specifically (to minimize K,V re-reads), and B_c can be smaller:

With careful analysis: B_r = O(M/d), B_c = O(d)

This gives FlashAttention IO complexity:

O(N²d/B_r) = O(N²d/(M/d)) = O(N²d²/M)

The FlashAttention paper states the complexity as:

FlashAttention IO = O(N²d²/M)

Comparing IO: Concrete Numbers

Let's calculate actual byte counts for realistic parameters:

Parameters: d = 128, M = 192 KB (A100 SRAM per SM), FP16. Block size: B_r ≈ M/(4d) = 192 KB / (4 × 128 × 2 bytes) ≈ 192 rows (practical choice: 128)

| Seq Len (N) | Standard IO (bytes) | FlashAttention IO (bytes) | Reduction Factor |
|---|---|---|---|
| 1,024 | 8.4 MB | 1.1 MB | 7.6× |
| 2,048 | 33.6 MB | 4.2 MB | 8.0× |
| 4,096 | 134.2 MB | 16.8 MB | 8.0× |
| 8,192 | 536.9 MB | 67.1 MB | 8.0× |
| 16,384 | 2,147.5 MB | 268.4 MB | 8.0× |
| 32,768 | 8,590.0 MB | 1,073.7 MB | 8.0× |
| 65,536 | 34,359.7 MB | 4,295.0 MB | 8.0× |

Calculation for N=4096:

  • Standard: 4 × 4096² × 2 bytes ≈ 134 MB
  • FlashAttention: ≈ 134 MB / 8 ≈ 17 MB (the table assumes an ≈8× reduction, i.e. an effective 2B_r/d ≈ 8 in the formula above)

Note: For large N the reduction factor approaches the constant 2B_r/d. It is set by the block size (and hence by SRAM capacity and head dimension), not by sequence length.

Key insight: The reduction factor is roughly constant in N, but the absolute traffic saved grows quadratically with N. For longer sequences, standard attention's O(N²) traffic becomes prohibitive, so FlashAttention's advantage matters more and more.


Why FlashAttention is Faster Despite More FLOPs

Here's a counterintuitive fact: FlashAttention actually does MORE floating-point operations than standard attention, yet it's faster. How?

FLOPs comparison:

Standard Attention FLOPs:
  - QK^T:     2N²d
  - softmax:  5N² (approx)
  - PV:       2N²d
  - Total:    ~4N²d + 5N² ≈ 4N²d

FlashAttention FLOPs:
  - Same matmuls: 4N²d
  - Same softmax: 5N²
  - ADDITIONAL: rescaling operations in online softmax
    - Each tile: O(B_r × d) ops to rescale the B_r × d output block
    - Number of tiles: O(N²/(B_r × B_c))
    - Total extra: O(N²d/B_c) ≈ O(N²) for B_c ≈ d
  - Total: ~4N²d + 5N² + O(N²) ≈ 4N²d (same order!)

FlashAttention does essentially the same FLOPs, with a small constant-factor increase.
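A rough count of that overhead, sketched in Python (the constants, such as 5 FLOPs per softmax element and the rescaling term, are the approximations used above, not measured values):

```python
def attention_flops(N, d):
    """Matmul FLOPs (QK^T and PV at 2*N^2*d each) plus ~5 FLOPs per score for softmax."""
    return 4 * N * N * d + 5 * N * N

def flash_attention_flops(N, d, B_c):
    """Same work, plus rescaling of the B_r x d output block once per tile."""
    rescale = 2 * N * N * d // B_c  # ~O(N^2 * d / B_c) extra operations
    return attention_flops(N, d) + rescale

N, d = 4096, 128
base, flash = attention_flops(N, d), flash_attention_flops(N, d, B_c=128)
print(f"extra FLOPs from rescaling: {100 * (flash - base) / base:.2f}%")  # well under 1%
```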

So why is it faster?

The answer: for memory-bound operations, performance is determined by arithmetic intensity, not by FLOP count.

Standard Attention

| Metric | Value |
|---|---|
| FLOPs | ~4N²d ≈ 4 × 4096² × 128 = 8.6 GFLOP |
| Bytes | ~4N² × 2 = 134 MB |
| Arithmetic intensity | 8.6 GFLOP / 134 MB = 64 FLOP/byte |
| vs A100 ridge point | Below (156 FLOP/byte) |
| Bound type | MEMORY-BOUND: limited by bandwidth, not compute |
| Time | 134 MB / 2000 GB/s = 0.067 ms |

FlashAttention

| Metric | Value |
|---|---|
| FLOPs | ~4N²d = 8.6 GFLOP (same) |
| Bytes | ~17 MB (8× less) |
| Arithmetic intensity | 8.6 GFLOP / 17 MB = 506 FLOP/byte |
| vs A100 ridge point | Above (156 FLOP/byte) |
| Bound type | COMPUTE-BOUND: limited by compute, using the GPU efficiently |
| Time | 8.6 GFLOP / 312 TFLOP/s = 0.028 ms |

Speedup: 0.067 / 0.028 ≈ 2.4× faster. By reducing bytes transferred, the operation moved from memory-bound to compute-bound, allowing the GPU's full compute capability to be utilized.

This is the fundamental lesson: For memory-bound operations, reducing data movement is more valuable than reducing computation. FlashAttention trades a small increase in FLOPs for a large decrease in memory traffic, resulting in a net speedup.
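The two tables above follow from a one-line roofline model. Here is a sketch with A100-like constants (312 TFLOP/s FP16, 2,000 GB/s HBM); the function name and the rounded byte counts are illustrative:

```python
PEAK_FLOPS = 312e12  # A100 FP16 Tensor Core throughput, FLOP/s
PEAK_BW = 2000e9     # A100 HBM bandwidth, bytes/s

def roofline_time(flops, bytes_moved):
    """An operation takes at least max(compute time, memory time)."""
    t_compute = flops / PEAK_FLOPS
    t_memory = bytes_moved / PEAK_BW
    return max(t_compute, t_memory), ("compute" if t_compute >= t_memory else "memory")

flops = 4 * 4096**2 * 128  # ~8.6 GFLOP for N=4096, d=128
t_std, bound_std = roofline_time(flops, 134e6)  # standard attention: ~134 MB moved
t_fa, bound_fa = roofline_time(flops, 17e6)     # FlashAttention: ~17 MB moved
print(f"standard: {t_std * 1e3:.3f} ms, {bound_std}-bound")
print(f"flash:    {t_fa * 1e3:.3f} ms, {bound_fa}-bound ({t_std / t_fa:.1f}x faster)")
```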


Roofline Visualization

Let's see how FlashAttention moves the operation on the roofline:

Standard Attention (memory-bound)

| Metric | Value |
|---|---|
| Arithmetic intensity | 64 FLOP/byte |
| Position | Below ridge point (156 FLOP/byte) |
| Achieved throughput | ~128 TFLOPS |
| Bottleneck | Memory bandwidth |

FlashAttention (compute-bound)

| Metric | Value |
|---|---|
| Arithmetic intensity | 506 FLOP/byte |
| Position | Above ridge point (156 FLOP/byte) |
| Achieved throughput | ~312 TFLOPS (theoretical max) |
| Bottleneck | Compute |

A100 Roofline: Ridge point = 156 FLOP/byte. Compute ceiling = 312 TFLOPS. Memory bandwidth = 2000 GB/s. FlashAttention moves attention from the sloped (memory-limited) region to the flat (compute-limited) region of the roofline.


Recomputation in the Backward Pass

FlashAttention makes an interesting tradeoff in the backward pass (for training): instead of storing the attention matrix P for use in the backward pass, it recomputes P from Q, K, and the stored statistics (m, l).

Why this makes sense:

Standard approach:
  Forward: Compute P, store it (O(N²) memory)
  Backward: Read stored P, compute gradients

FlashAttention approach:
  Forward: Compute P in blocks, store only m, l (O(N) memory)
  Backward: Recompute P in blocks from Q, K, m, l

The recomputation costs extra FLOPs, but:

  1. Saves O(N²) memory (critical for training with limited GPU memory)
  2. The recomputation is also done in SRAM, so it's fast
  3. The alternative (storing P) would be memory-bound anyway

The tradeoff:

| | A: Store P (Standard) | B: Recompute P (FlashAttention) |
|---|---|---|
| Memory | O(N²) for P | O(N) for statistics |
| Forward FLOPs | 4N²d | 4N²d |
| Backward FLOPs | ~4N²d | ~8N²d (recompute P + gradients) |
| Total FLOPs | ~8N²d | ~12N²d (1.5× more) |
| Consequence | O(N²) memory limits batch size and sequence length | O(N) memory allows larger batches/sequences; often FASTER, because a larger batch improves utilization and the recomputation is compute-bound |

This illustrates a general principle: with modern GPUs, computation is often "free" compared to memory access. Trading memory for compute is frequently a good deal.
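The memory gap is easy to quantify. A sketch for an illustrative training shape (batch 8, 32 heads, N = 8192; FP16 for P and FP32 for the statistics are my assumptions):

```python
def stored_P_bytes(batch, heads, N):
    """Full N x N attention matrix, one per (batch, head), in FP16."""
    return batch * heads * N * N * 2

def stored_stats_bytes(batch, heads, N):
    """Just the softmax statistics: two FP32 values (m, l) per query row."""
    return batch * heads * 2 * N * 4

b, h, N = 8, 32, 8192
print(f"store P:      {stored_P_bytes(b, h, N) / 2**30:.0f} GiB")
print(f"store (m, l): {stored_stats_bytes(b, h, N) / 2**20:.0f} MiB")
```

For this shape, storing P would take 32 GiB (most of a 40 GB A100), while the statistics take 16 MiB.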


FlashAttention-2: Optimizations

FlashAttention-2 (2023) introduced several improvements over the original FlashAttention, achieving 2× speedup:

1. Reduced Non-Matmul FLOPs

The original FlashAttention had significant overhead from non-matrix-multiplication operations (rescaling, statistics updates). FlashAttention-2 optimized these:

FlashAttention-1:
  - Rescaling O_i after every KV block
  - Multiple synchronization points

FlashAttention-2:
  - Delayed rescaling: accumulate first, rescale once at the end
  - Fewer synchronization points
  - Better instruction scheduling

2. Better Parallelization

| | FlashAttention-1 | FlashAttention-2 |
|---|---|---|
| Parallelizes over | Batch and heads only | Batch, heads, and sequence length |
| Grid dimensions | (batch_size, num_heads, 1) | (batch_size, num_heads, num_q_blocks) |
| Example (batch=1, heads=8, seq=4096, block=128) | 8 thread blocks → massive underutilization (A100 has 108 SMs) | 8 × 32 = 256 thread blocks → good GPU utilization |

Key insight: different query blocks are independent and can be processed in parallel.
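The utilization gap can be sketched directly from the grid shapes (the SM count is the A100's; the function names are mine):

```python
SMS = 108  # streaming multiprocessors on an A100

def fa1_blocks(batch, heads):
    """FlashAttention-1 launches one thread block per (batch, head) pair."""
    return batch * heads

def fa2_blocks(batch, heads, seq_len, block_rows):
    """FlashAttention-2 also parallelizes over independent query blocks."""
    return batch * heads * (seq_len // block_rows)

b, h, N, Br = 1, 8, 4096, 128
print(f"FA-1: {fa1_blocks(b, h):3d} thread blocks for {SMS} SMs")
print(f"FA-2: {fa2_blocks(b, h, N, Br):3d} thread blocks for {SMS} SMs")
```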

3. Improved Work Partitioning

FlashAttention-2 swaps the loop order for the backward pass:

FlashAttention-1 backward: For each KV block, iterate over Q blocks
FlashAttention-2 backward: For each Q block, iterate over KV blocks

The second order has better parallelism and memory access patterns.

4. Better Warp-Level Optimizations

More efficient use of warp shuffles and shared memory access patterns, reducing bank conflicts.

Result: FlashAttention-2 achieves ~2× speedup over FlashAttention-1, reaching 50-73% of theoretical max throughput (compared to 25-40% for FlashAttention-1).


FlashAttention-3: Hopper Architecture

FlashAttention-3 (2024) is specifically optimized for NVIDIA H100 (Hopper architecture), exploiting new hardware features:

1. Warp Specialization

Hopper supports "warp specialization" where different warps within a thread block can perform different tasks asynchronously:

Previous approach (all warps do same work):

  • All warps execute Load → Compute → Store sequentially
  • All warps wait for memory at the same time → stalls

FlashAttention-3 (specialized warps, pipelined):

| Time step | Producer warps | Consumer warps |
|---|---|---|
| T1 | Load K₀V₀ | (idle) |
| T2 | Load K₁V₁ | Compute₀ |
| T3 | Load K₂V₂ | Compute₁ |
| T4 | Load K₃V₃ | Compute₂ |

While consumers compute on K₀V₀, producers load K₁V₁ — memory latency is hidden behind computation.
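A toy timing model makes the benefit concrete. Under the illustrative assumption that a block load and a block compute take comparable time, pipelining exposes only the first load:

```python
def sequential_time(n_blocks, t_load, t_compute):
    """Every warp loads, then computes: the load latency is fully exposed."""
    return n_blocks * (t_load + t_compute)

def pipelined_time(n_blocks, t_load, t_compute):
    """Producers prefetch block i+1 while consumers compute on block i."""
    return t_load + n_blocks * max(t_load, t_compute)

n, t_load, t_compute = 32, 1.0, 1.5  # arbitrary time units
print(sequential_time(n, t_load, t_compute))  # 80.0
print(pipelined_time(n, t_load, t_compute))   # 49.0
```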

2. Asynchronous Memory Operations (TMA)

H100's Tensor Memory Accelerator (TMA) enables asynchronous bulk data transfers:

Previous: Load data → wait → compute → wait → store
H100 TMA: Issue load → compute on previous data → load completes in background

This hides memory latency much more effectively.

3. FP8 Support

H100 has native FP8 Tensor Cores, which FlashAttention-3 exploits:

| Precision | Tensor Core TFLOPS | Memory per Element |
|---|---|---|
| FP16 | 990 | 2 bytes |
| FP8 | 1,979 | 1 byte |
| FP8 advantage | 2× compute | 2× less memory |

Combined effect: twice the ops/sec and half the data to transfer. Halving the bytes doubles arithmetic intensity, and the FP8 Tensor Cores double the compute ceiling on top of that. FlashAttention-3 with FP8 achieves close to 1.2 PFLOPS (1,200 TFLOPS) on H100, unprecedented for attention.

4. Block Quantization for Accuracy

Pure FP8 can have accuracy issues. FlashAttention-3 uses "block quantization":

  • Compute attention scores in higher precision
  • Quantize to FP8 for the matmul with V
  • Use per-block scaling factors to maintain accuracy
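The idea can be illustrated without real FP8 hardware. The sketch below uses uniform 8-bit quantization as a stand-in for FP8 (NumPy has no FP8 dtype) and shows why a scale chosen per block preserves small-magnitude blocks that a single global scale would flatten to zero:

```python
import numpy as np

def fake_quant(x, scale, levels=127):
    """Quantize to `levels` signed steps at the given scale, then dequantize."""
    q = np.clip(np.round(x / scale * levels), -levels, levels)
    return q / levels * scale

rng = np.random.default_rng(0)
small = rng.normal(scale=0.01, size=128)  # a block of small attention scores
large = rng.normal(scale=10.0, size=128)  # a block of large attention scores

tensor_scale = np.abs(np.concatenate([small, large])).max()  # one global scale
block_scale = np.abs(small).max()                            # the small block's own scale

err_tensor = np.abs(small - fake_quant(small, tensor_scale)).mean()
err_block = np.abs(small - fake_quant(small, block_scale)).mean()
print(err_block < err_tensor)  # per-block scaling loses far less precision
```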

Result: FlashAttention-3 achieves 1.5-2× speedup over FlashAttention-2 on H100, reaching 75%+ of theoretical throughput.


Performance Summary Across Versions

Version Comparison

| Version | GPU | Peak Efficiency | Key Innovation |
|---|---|---|---|
| Standard Attention | A100 | ~25% | (baseline) |
| FlashAttention-1 | A100 | 25-40% | Tiling + online softmax |
| FlashAttention-2 | A100 | 50-73% | Better parallelism, fewer non-matmul ops |
| FlashAttention-3 | H100 | 75%+ | Warp specialization, TMA, FP8 |

Typical Speedups (forward pass, seq_len=4096)

| Transition | Speedup |
|---|---|
| Standard → FlashAttention-1 | 2-4× faster |
| FlashAttention-1 → 2 | ~2× faster |
| FlashAttention-2 → 3 (H100) | 1.5-2× faster |
| Standard → FlashAttention-3 | up to 16× faster |

Memory Reduction (all versions)

Standard: O(N²) → FlashAttention: O(N)

This enables:

  • 128K+ context windows (impossible with standard attention)
  • Larger batch sizes for training
  • Lower memory requirements for inference

When Does FlashAttention Help Most?

FlashAttention's benefits vary by use case:

Maximum Benefit

  • Long sequences (N > 2048) — O(N²) → O(N) memory matters more; IO reduction factor grows with N
  • Memory-constrained scenarios — Training with limited GPU memory; inference with large KV caches
  • Prefill phase (many queries) — Original FlashAttention parallelizes over queries; high utilization during prompt processing

Moderate Benefit

  • Short sequences (N < 512) — O(N²) memory is manageable; overhead of tiling may not be fully amortized
  • Already compute-bound scenarios — Large batch sizes; already high arithmetic intensity

Special Case — Decode Phase

  • Original FlashAttention has limited benefit for decode
  • Only 1 query token → limited parallelism over queries
  • Need Flash Decoding for efficient decode

Summary: What FlashAttention Achieves

1. IO Complexity Reduction:

  • Standard: O(N²) HBM accesses
  • FlashAttention: O(N²d²/M) HBM accesses (M = SRAM size), i.e. O(N²/M) for a fixed head dimension d
  • Practical reduction: 4-16× less memory traffic for typical sequence lengths

2. Memory Complexity Reduction:

  • Standard: O(N²) storage for attention matrix
  • FlashAttention: O(N) storage (just statistics)
  • This enables 100K+ context windows

3. Speed Improvement:

  • Moves attention from memory-bound to compute-bound regime
  • 2-4× speedup over standard attention (FlashAttention-1)
  • Up to 16× with latest optimizations (FlashAttention-3 on H100)

4. Exactness:

  • FlashAttention computes the exact same result as standard attention
  • Online softmax is mathematically equivalent, not an approximation

5. Key Insight:

  • By reducing IO, FlashAttention increases arithmetic intensity
  • This shifts the operation from memory-bound to compute-bound
  • More FLOPs can mean faster execution when IO is the bottleneck

Check Your Understanding

  1. If SRAM size doubles (from 192KB to 384KB), what happens to FlashAttention's IO complexity? Does this make it faster?
  2. FlashAttention-2 achieves 50-73% of theoretical GPU throughput, while standard attention achieves ~25%. Where does the remaining 27-50% go? (What prevents 100% utilization?)
  3. Why does FlashAttention-3 require Hopper (H100) architecture? Could the same optimizations be implemented on A100?
  4. For training, FlashAttention recomputes the attention matrix in the backward pass instead of storing it. Under what conditions would storing the matrix actually be faster? (Hint: think about memory bandwidth vs compute tradeoffs)
  5. A model uses GQA with 8 KV heads and 64 query heads. How does this affect FlashAttention's performance compared to standard MHA? (Hint: think about what gets re-read in the inner loop)
