2.1: Standard Attention — The IO Problem
Setting the Stage
You already understand the attention mechanism mathematically. You know that attention computes a weighted sum of values, where the weights come from the compatibility between queries and keys. What we need to examine now is not what attention computes, but how a naive implementation computes it—and why that "how" is catastrophically inefficient from a memory perspective.
The Standard Attention Formula
Let's start with the mathematical formulation. For a single attention head:
Attention(Q, K, V) = softmax(QK^T / √d) × V
Where:
- Q (queries): [N, d] — N tokens, each with a d-dimensional query vector
- K (keys): [N, d] — N tokens, each with a d-dimensional key vector
- V (values): [N, d] — N tokens, each with a d-dimensional value vector
- d is the head dimension (typically 64 or 128)
- N is the sequence length
The output O is [N, d] — for each of the N tokens, we get a d-dimensional output vector.
For this artifact, we'll use concrete numbers:
- N = 4096 (sequence length)
- d = 128 (head dimension)
- FP16 precision (2 bytes per number)
Breaking Down the Computation: Three Steps
A standard implementation breaks attention into three distinct operations:
Step 1: S = QK^T / √d (compute attention scores)
Step 2: P = softmax(S) (normalize scores to probabilities)
Step 3: O = PV (compute weighted sum of values)
Let's trace through each step, tracking exactly what data moves between HBM and the compute units.
Step 1: Computing the Attention Scores (S = QK^T / √d)
The operation: Multiply Q [N, d] by K^T [d, N] to get S [N, N], then scale by 1/√d.
Memory traffic analysis:
Read Q: N × d × 2 bytes = 4096 × 128 × 2 = 1,048,576 bytes = 1 MB
Read K: N × d × 2 bytes = 4096 × 128 × 2 = 1,048,576 bytes = 1 MB
Write S: N × N × 2 bytes = 4096 × 4096 × 2 = 33,554,432 bytes = 32 MB
Total HBM traffic for Step 1: 34 MB
Computation:
Each element S[i,j] is a dot product of row i of Q with column j of K^T (which is row j of K):
S[i,j] = (1/√d) × Σ(k=0 to d-1) Q[i,k] × K[j,k]
That's d multiply-adds per element, and there are N² elements:
FLOPs = N² × 2d = 4096² × 2 × 128 = 4,294,967,296 FLOPs ≈ 4.3 GFLOP
Arithmetic intensity for Step 1:
Arithmetic intensity = 4.3 GFLOP / 34 MB ≈ 126 FLOPs/byte
This is actually close to the A100 ridge point (156)! The matrix multiplication itself isn't terribly inefficient. The problem is what comes next.
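These counts can be checked directly. A small NumPy sketch of Step 1 with the traffic tally (FP32 math for simplicity; traffic counted at the 2-byte FP16 element size assumed throughout):

```python
import numpy as np

N, d, elem = 4096, 128, 2           # FP16: 2 bytes per element

rng = np.random.default_rng(0)
Q = rng.standard_normal((N, d)).astype(np.float32)
K = rng.standard_normal((N, d)).astype(np.float32)

S = (Q @ K.T) / np.sqrt(d)          # [N, N] attention scores

# HBM traffic: read Q, read K, write S
step1_bytes = (Q.size + K.size + S.size) * elem
step1_flops = 2 * N * N * d         # one multiply-add per (i, j, k) triple

print(step1_bytes / 2**20)          # 34.0 MB
print(step1_flops / step1_bytes)    # ~120 FLOPs/byte (the text's ~126 counts MB as 10^6 bytes)
```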
Step 2: Applying Softmax (P = softmax(S))
The operation: For each row of S, compute the softmax to get probability weights.
Softmax for row i:
P[i,j] = exp(S[i,j] - max(S[i,:])) / Σ(k=0 to N-1) exp(S[i,k] - max(S[i,:]))
The subtraction of max(S[i,:]) is for numerical stability—without it, exp() would overflow for large values.
For each of the N rows:
- Find max value (N comparisons)
- Subtract max and compute exp (N subtractions + N exp)
- Sum the exponentials (N additions)
- Divide each element by sum (N divisions)
Memory traffic analysis:
Read S: N × N × 2 bytes = 32 MB
Write P: N × N × 2 bytes = 32 MB
Total HBM traffic for Step 2: 64 MB
Computation:
For each of N rows, we do approximately 5N operations (find max, subtract, exp, sum, divide):
FLOPs ≈ N × 5N = 5N² = 5 × 4096² ≈ 84 MFLOP
Arithmetic intensity for Step 2:
Arithmetic intensity = 84 MFLOP / 64 MB ≈ 1.3 FLOPs/byte
This is extremely low! Softmax is a severely memory-bound operation. We're reading and writing 64 MB of data but only doing 84 million operations—the GPU spends almost all its time waiting for memory.
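The four row-wise passes above can be written out directly. A minimal NumPy sketch of the numerically stable softmax (the max subtraction keeps exp() finite, per the formula):

```python
import numpy as np

def softmax_rows(S):
    """Numerically stable row softmax: the four passes listed above."""
    m = S.max(axis=-1, keepdims=True)     # pass 1: row max
    e = np.exp(S - m)                     # pass 2: shifted exponentials (no overflow)
    z = e.sum(axis=-1, keepdims=True)     # pass 3: row sums
    return e / z                          # pass 4: normalize

S = np.array([[1000.0, 1001.0, 1002.0]])  # exp(1000) alone would overflow float64
P = softmax_rows(S)
print(P)                                   # finite values, row sums to 1
```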
Step 3: Computing the Output (O = PV)
The operation: Multiply P [N, N] by V [N, d] to get O [N, d].
Memory traffic analysis:
Read P: N × N × 2 bytes = 32 MB
Read V: N × d × 2 bytes = 1 MB
Write O: N × d × 2 bytes = 1 MB
Total HBM traffic for Step 3: 34 MB
Computation:
Matrix multiplication of [N, N] × [N, d]:
FLOPs = N × d × 2N = 2N²d = 2 × 4096² × 128 ≈ 4.3 GFLOP
Arithmetic intensity for Step 3:
Arithmetic intensity = 4.3 GFLOP / 34 MB ≈ 126 FLOPs/byte
Again, the matrix multiplication itself is reasonably efficient.
The Complete Picture: Total HBM Traffic
Let's sum up all the memory traffic (N = 4096, d = 128, FP16):
| Step | Reads from HBM | Writes to HBM | Subtotal |
|---|---|---|---|
| Step 1: S = QK^T / √d | Q (1 MB) + K (1 MB) | S (32 MB) ← N×N intermediate | 34 MB |
| Step 2: P = softmax(S) | S (32 MB) ← N×N intermediate | P (32 MB) ← N×N intermediate | 64 MB |
| Step 3: O = PV | P (32 MB) + V (1 MB) | O (1 MB) | 34 MB |
| Total | | | 132 MB |
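The per-step subtotals reduce to a few lines of arithmetic (element counts times the 2-byte FP16 size; the helper name is illustrative):

```python
def attention_hbm_traffic(N, d, elem=2):
    """Bytes moved to/from HBM per step of unfused attention, as tallied above."""
    step1 = (2 * N * d + N * N) * elem        # read Q, K; write S
    step2 = (2 * N * N) * elem                # read S; write P
    step3 = (N * N + 2 * N * d) * elem        # read P, V; write O
    return step1, step2, step3

s1, s2, s3 = attention_hbm_traffic(4096, 128)
print([x / 2**20 for x in (s1, s2, s3, s1 + s2 + s3)])   # [34.0, 64.0, 34.0, 132.0]
```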
Visualizing the Data Flow
Here's how data flows through HBM in standard attention:

HBM → [read Q, K] → compute S = QK^T/√d → [write S] → HBM
HBM → [read S] → softmax → [write P] → HBM
HBM → [read P, V] → compute O = PV → [write O] → HBM

Every read/write crosses the slow HBM ↔ compute boundary: three complete round-trips through HBM, two of them carrying N×N intermediate data. The problem isn't the computation — it's that we keep bouncing data back and forth through the slow HBM bottleneck.
How This Scales: The O(N²) Problem
Let's calculate total HBM traffic for different sequence lengths:
| Sequence Length (N) | N×N Matrix Size (FP16) | Total HBM Traffic | Time at 2 TB/s | Scales as |
|---|---|---|---|---|
| 512 | 0.5 MB | 2 MB | 0.001 ms | — |
| 1,024 | 2 MB | 8 MB | 0.004 ms | 4× |
| 2,048 | 8 MB | 33 MB | 0.016 ms | 4× |
| 4,096 | 32 MB | 132 MB | 0.066 ms | 4× |
| 8,192 | 128 MB | 528 MB | 0.264 ms | 4× |
| 16,384 | 512 MB | 2,112 MB | 1.056 ms | 4× |
| 32,768 | 2,048 MB | 8,448 MB | 4.224 ms | 4× |
| 65,536 | 8,192 MB | 33,792 MB | 16.896 ms | 4× |
| 131,072 | 32,768 MB | 135,168 MB | 67.584 ms | 4× |
Key observations:
- Quadratic scaling: Doubling sequence length quadruples memory traffic. This is because the attention matrix is N×N.
- Memory capacity limits: At N=32K, the attention matrix alone is 2GB per head. With 32 heads, that's 64GB just for attention matrices—nearly the entire A100 memory!
- Bandwidth limits: Even if memory capacity isn't an issue, the time spent on memory transfers grows quadratically. At N=128K, we'd spend 67ms just moving attention matrices around, per layer, per head.
This is why early transformers were limited to 512 or 1024 tokens. The O(N²) memory scaling made longer sequences impractical.
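The doubling pattern in the table follows from the traffic formula: once N dominates d, doubling N multiplies traffic by almost exactly 4. A quick check (same element counts as above; helper name illustrative):

```python
def total_traffic(N, d=128, elem=2):
    # 4 N^2 terms (write/read S, write/read P) + 4 Nd terms (read Q, K, V; write O)
    return (4 * N * N + 4 * N * d) * elem

for N in (1024, 4096, 16384, 65536):
    # ratio of traffic at 2N vs N approaches 4.0 as the N^2 terms dominate
    print(N, round(total_traffic(2 * N) / total_traffic(N), 3))
```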
The Fundamental Inefficiency: Why Do We Materialize N×N?
Here's the crucial question: Why do we write the N×N matrices to HBM at all?
The answer has two parts:
Reason 1: Programming Convenience
The natural way to implement attention is as three separate operations:
```python
# Naive PyTorch implementation (Q, K, V: [N, d] tensors; d: head dimension)
import math
import torch

S = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d)  # [N, N]
P = torch.softmax(S, dim=-1)                              # [N, N]
O = torch.matmul(P, V)                                    # [N, d]
```

Each of the three operations is a separate kernel launch. Each kernel reads its inputs from HBM and writes its outputs to HBM. PyTorch (and other frameworks) don't automatically fuse these operations.
This is the simplest, most readable implementation. But it's also the slowest.
Reason 2: Softmax Requires Global Information
The softmax operation is the real culprit. Look at the formula again:
P[i,j] = exp(S[i,j] - max(S[i,:])) / Σ(k=0 to N-1) exp(S[i,k] - max(S[i,:]))
To compute P[i,j], we need:
- The maximum value in row i (for numerical stability)
- The sum of all exponentials in row i (for normalization)
Both of these require seeing all N elements of row i before we can output any element of row i.
This seems to force us to materialize the full S matrix: we need all scores before we can compute any softmax output.
Or does it?
This is exactly the insight that FlashAttention exploits. As we'll see in the next artifacts, there's a clever way to compute softmax incrementally using running statistics, avoiding the need to ever materialize the full N×N matrix.
IO Complexity: The Right Way to Analyze Attention
In traditional algorithm analysis, we count FLOPs. For memory-bound operations, FLOPs don't tell the whole story. We need to count memory operations.
IO complexity measures the number of bytes transferred between memory levels (specifically, between HBM and SRAM/registers).
Standard attention IO complexity:
Reading Q, K, V: O(Nd) — read each once
Writing S: O(N²) — N×N matrix
Reading S: O(N²) — for softmax
Writing P: O(N²) — N×N matrix
Reading P: O(N²) — for matmul with V
Writing O: O(Nd) — written once
Total IO complexity: O(N² + Nd) = O(N²) (since N >> d typically)
The O(N²) IO complexity is the problem. We're moving N² data when we only need N×d output.
What would ideal IO complexity look like?
If we could somehow compute attention without materializing the N×N intermediates:
Reading Q, K, V: O(Nd)
Writing O: O(Nd)
Ideal IO complexity: O(Nd)
This would be a factor of N/d improvement! For N=4096, d=128, that's 32× less memory traffic.
FlashAttention achieves close to this ideal.
Overall Arithmetic Intensity of Standard Attention
Let's compute the overall arithmetic intensity to confirm this operation is memory-bound:
Total FLOPs:
- Step 1 (QK^T): 2N²d = 2 × 4096² × 128 = 4.29 GFLOP
- Step 2 (softmax): 5N² = 5 × 4096² = 0.08 GFLOP
- Step 3 (PV): 2N²d = 2 × 4096² × 128 = 4.29 GFLOP
- Total: ~4N²d = 4 × 4096² × 128 = 8.59 GFLOP
Total bytes transferred: 132 MB (calculated earlier)
Arithmetic intensity = 8.59 GFLOP / 132 MB ≈ 65 FLOPs/byte
This is well below the A100 ridge point of 156 FLOPs/byte. Standard attention is memory-bound, even though it involves matrix multiplications.
The culprit is the O(N²) intermediate data. The matrix multiplications themselves have reasonable arithmetic intensity (~126), but the softmax operation (1.3 FLOPs/byte) and the repeated reads/writes of N×N matrices drag down the overall efficiency.
Comparison: What We Have vs What We Need
What we have: 132 MB of HBM traffic per head (computed above). What we need: just 4 MB, reading Q, K, V (3 MB) and writing O (1 MB).

Overhead: 132 MB / 4 MB = 33× more memory traffic than necessary. This overhead grows roughly as N/d:
- N = 4K: 33× overhead
- N = 8K: ≈66× overhead
- N = 16K: ≈132× overhead
- N = 32K: ≈264× overhead
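This growth is N/d in disguise: the actual traffic is dominated by the 4N² intermediate terms while the minimum is 4Nd, so the ratio simplifies to N/d + 1. An illustrative check:

```python
def overhead(N, d=128):
    actual = 4 * N * N + 4 * N * d    # element counts; the 2-byte factor cancels
    ideal = 4 * N * d                 # read Q, K, V; write O -- nothing else
    return actual / ideal             # simplifies to N/d + 1

for N in (4096, 8192, 16384, 32768):
    print(N, round(overhead(N)))      # 33, 65, 129, 257: roughly doubling each time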
The Memory Capacity Problem
Beyond bandwidth, there's also a capacity problem. The N×N attention matrix must fit in GPU memory.
| Sequence Length | Attention Matrix (1 head) | 32 Heads, 1 Layer | 32 Heads, 32 Layers |
|---|---|---|---|
| 2,048 | 8 MB | 256 MB | 8 GB |
| 4,096 | 32 MB | 1,024 MB | 32 GB |
| 8,192 | 128 MB | 4,096 MB | 128 GB ← OOM! |
| 16,384 | 512 MB | 16,384 MB | 512 GB ← OOM! |
| 32,768 | 2,048 MB | 65,536 MB | 2,048 GB ← OOM! |
Note: In practice, most implementations compute one layer at a time (not all 32 simultaneously), so the "32 layers" column is theoretical. But even single-layer requirements become problematic at long sequence lengths.
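The capacity arithmetic is equally direct (a helper mirroring the table; the "layers" factor is the theoretical worst case just noted):

```python
def score_matrix_bytes(N, heads=1, layers=1, elem=2):
    """FP16 storage if every head (and, hypothetically, every layer) held N x N at once."""
    return N * N * elem * heads * layers

gib = 2**30
print(score_matrix_bytes(8192, heads=32) / gib)             # 4.0 GiB for one layer
print(score_matrix_bytes(8192, heads=32, layers=32) / gib)  # 128.0 GiB: over an 80 GB A100
```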
What FlashAttention Will Do Differently
Now that we understand the problem, we can appreciate what FlashAttention accomplishes. Here's a preview (detailed in the next artifacts):
| Metric | Standard | FlashAttention |
|---|---|---|
| Memory for N×N intermediates | O(N²) — must fit in HBM | O(1)* — only tile in SRAM |
| HBM traffic (IO complexity) | O(N²) | O(N²d²/M), where M = SRAM size |
| HBM traffic (N=4096, d=128) | 132 MB | ~4 MB (close to minimum!) |
| Softmax computation | Requires full N×N matrix | Computed incrementally with running statistics |
| Result | Exact | Exact (not approximate!) |
*Only one tile at a time lives in SRAM; the full N×N matrix is never materialized in HBM.
The key insight: by computing attention in small tiles that fit in SRAM, and using a clever "online softmax" algorithm to accumulate results correctly, FlashAttention avoids ever materializing the N×N matrix in HBM.
Summary: The Problem We Need to Solve
Standard attention has three critical problems:
1. O(N²) Memory Traffic: The intermediate attention matrices require reading and writing O(N²) data to HBM. This dominates the total memory bandwidth usage.
2. O(N²) Memory Capacity: The attention matrices must fit in GPU memory, limiting sequence length to a few thousand tokens on practical hardware.
3. Low Arithmetic Intensity: Despite involving matrix multiplications, the overall operation is memory-bound (~65 FLOPs/byte, below the 156 ridge point) because of the softmax bottleneck and intermediate data movement.
The root cause is that standard implementations materialize the N×N attention matrix, when in principle we only need the N×d output. This happens because:
- It's the natural way to write the code (three separate operations)
- Softmax appears to require global information before producing any output
FlashAttention solves both problems by:
- Tiling the computation to fit in SRAM
- Using online softmax to compute incrementally without seeing all scores at once
This is what we'll study next.
Check Your Understanding
Before moving to FlashAttention:
- For N = 8192, d = 128, calculate the size of the attention matrix in MB (FP16). How does this compare to the size of Q, K, V combined?
- Attention matrix size: N × N × 2 bytes = 8192 × 8192 × 2 = 128 MB
- Q, K, V combined: 3 × N × d × 2 bytes = 3 × 8192 × 128 × 2 = 6 MB
- Ratio: 128 MB / 6 MB ≈ 21× larger. The intermediate attention matrix dwarfs the actual input data.
- This ratio grows with N: at N=32K, it would be 85×.
- If we could magically fuse Steps 1, 2, and 3 into a single kernel that never writes intermediates to HBM, what would the HBM traffic be? What would the arithmetic intensity become?
- HBM traffic: Only read Q, K, V (3 MB) and write O (1 MB) = 4 MB total
- FLOPs: Still ~8.6 GFLOP (computation doesn't change)
- Arithmetic intensity: 8.6 GFLOP / 4 MB = 2,150 FLOPs/byte
- This is 14× above the A100 ridge point (156) — the operation would become compute-bound instead of memory-bound!
- Speedup: From 65 FLOPs/byte to 2,150 FLOPs/byte = 33× improvement in efficiency. This is exactly what FlashAttention achieves (approximately).
- Why can't we simply fuse softmax into the matmul kernels using standard kernel fusion techniques? What makes softmax special?
- Standard fusion works for elementwise ops (ReLU, add, etc.) because each output element depends only on corresponding input elements — you can compute output[i] as soon as you have input[i].
- Softmax requires global information: to compute P[i,j], you need:
  - max(S[i,:]) — the maximum across the entire row
  - Σ exp(S[i,k]) — the sum of exponentials across the entire row
- You can't output any element of row i until you've seen all elements of row i.
- This data dependency prevents naive fusion: you'd need to fully compute row i of S before starting softmax for that row.
- FlashAttention's solution: "Online softmax" — maintain running max and running sum, update them as new tiles arrive, and rescale previously computed partial results. This breaks the global dependency.
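The running-statistics mechanics are easy to see in isolation. A NumPy sketch of online softmax over a single row, processed chunk by chunk (a preview only; the full algorithm, which folds the rescaling into a running output accumulator, comes in the next artifacts):

```python
import numpy as np

def online_softmax(scores, chunk=4):
    """One-pass running max/sum over chunks, then normalize (single-row view)."""
    m, z = -np.inf, 0.0                 # running max, running sum of exp(s - m)
    for start in range(0, len(scores), chunk):
        block = scores[start:start + chunk]
        m_new = max(m, block.max())
        z = z * np.exp(m - m_new) + np.exp(block - m_new).sum()  # rescale old sum
        m = m_new
    # This final pass just emits probabilities; FlashAttention avoids it by
    # rescaling the partial output O instead of revisiting the scores.
    return np.exp(scores - m) / z

s = np.random.default_rng(1).standard_normal(10)
reference = np.exp(s - s.max()) / np.exp(s - s.max()).sum()
print(np.allclose(online_softmax(s), reference))   # True
```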
- A model has 32 heads and 32 layers. For a 4096-token sequence, how much HBM traffic does standard attention generate in total? (Use our 132 MB per head figure.)
- Per head: 132 MB
- Per layer (32 heads): 132 MB × 32 = 4,224 MB ≈ 4.1 GB
- Total (32 layers): 4.1 GB × 32 = 131 GB
- On A100 (2 TB/s bandwidth): 131 GB / 2,000 GB/s = 65.5 ms just for attention memory transfers!
- This is for a single forward pass. And it doesn't include the linear projections, FFN, or any other operations.
- If H100 has 3.35 TB/s bandwidth instead of 2 TB/s, does that solve the standard attention efficiency problem? Why or why not?
- No. H100's 1.7× bandwidth improvement only gives 1.7× speedup for memory-bound operations.
- The fundamental problem is the 33× overhead in memory traffic (132 MB vs 4 MB needed). You're still doing 33× more memory transfers than necessary — just slightly faster.
- The right solution is algorithmic: Reduce the bytes transferred (FlashAttention), not just transfer them faster.
- Analogy: If you're driving 33× further than necessary, a faster car helps a little, but the real fix is taking the shorter route.
- Additionally: H100's higher ridge point (296 vs 156 FLOPs/byte) makes standard attention more memory-bound relative to peak, not less.



