2.1: Standard Attention — The IO Problem

Community Article Published February 3, 2026

Setting the Stage

You already understand the attention mechanism mathematically. You know that attention computes a weighted sum of values, where the weights come from the compatibility between queries and keys. What we need to examine now is not what attention computes, but how a naive implementation computes it—and why that "how" is catastrophically inefficient from a memory perspective.


The Standard Attention Formula

Let's start with the mathematical formulation. For a single attention head:

Attention(Q, K, V) = softmax(QK^T / √d) × V

Where:

  • Q (queries): [N, d] — N tokens, each with a d-dimensional query vector
  • K (keys): [N, d] — N tokens, each with a d-dimensional key vector
  • V (values): [N, d] — N tokens, each with a d-dimensional value vector
  • d is the head dimension (typically 64 or 128)
  • N is the sequence length

The output O is [N, d] — for each of the N tokens, we get a d-dimensional output vector.

For this artifact, we'll use concrete numbers:

  • N = 4096 (sequence length)
  • d = 128 (head dimension)
  • FP16 precision (2 bytes per number)

Breaking Down the Computation: Three Steps

A standard implementation breaks attention into three distinct operations:

Step 1: S = QK^T / √d        Compute attention scores
Step 2: P = softmax(S)       Normalize scores to probabilities
Step 3: O = PV               Compute weighted sum of values

Let's trace through each step, tracking exactly what data moves between HBM and the compute units.


Step 1: Computing the Attention Scores (S = QK^T / √d)

The operation: Multiply Q [N, d] by K^T [d, N] to get S [N, N], then scale by 1/√d.
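As a quick shape check, this step is a single matrix multiply. Here is a tiny NumPy sketch (toy sizes, not the article's N = 4096, d = 128):

```python
import numpy as np

N, d = 8, 4  # tiny sizes for illustration only

rng = np.random.default_rng(0)
Q = rng.standard_normal((N, d))
K = rng.standard_normal((N, d))

# S = QK^T / sqrt(d): S[i, j] is the scaled dot product of query i with key j
S = (Q @ K.T) / np.sqrt(d)

assert S.shape == (N, N)  # the N x N score matrix — the source of the IO problem
```

Note that the output is N×N even though the inputs are only N×d each; this asymmetry is the whole story of this section.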


Memory traffic analysis:

Read Q:    N × d × 2 bytes = 4096 × 128 × 2 = 1,048,576 bytes = 1 MB
Read K:    N × d × 2 bytes = 4096 × 128 × 2 = 1,048,576 bytes = 1 MB
Write S:   N × N × 2 bytes = 4096 × 4096 × 2 = 33,554,432 bytes = 32 MB

Total HBM traffic for Step 1: 34 MB

Computation:

Each element S[i,j] is a dot product of row i of Q with column j of K^T (which is row j of K):

S[i,j] = (1/√d) × Σ(k=0 to d-1) Q[i,k] × K[j,k]

That's d multiply-adds per element, and there are N² elements:

FLOPs = N² × 2d = 4096² × 2 × 128 = 4,294,967,296 FLOPs ≈ 4.3 GFLOP

Arithmetic intensity for Step 1:

Arithmetic intensity = 4.3 GFLOP / 34 MB ≈ 126 FLOPs/byte

This is actually close to the A100 ridge point (156)! The matrix multiplication itself isn't terribly inefficient. The problem is what comes next.


Step 2: Applying Softmax (P = softmax(S))

The operation: For each row of S, compute the softmax to get probability weights.

Softmax for row i:

P[i,j] = exp(S[i,j] - max(S[i,:])) / Σ(k=0 to N-1) exp(S[i,k] - max(S[i,:]))

The subtraction of max(S[i,:]) is for numerical stability—without it, exp() would overflow for large values.


For each of the N rows:

  1. Find max value (N comparisons)
  2. Subtract max and compute exp (N subtractions + N exp)
  3. Sum the exponentials (N additions)
  4. Divide each element by sum (N divisions)
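The four passes translate directly into code. A minimal NumPy sketch of the numerically stable row-wise softmax:

```python
import numpy as np

def softmax_rows(S):
    """Row-wise softmax, following the four passes described above."""
    m = S.max(axis=1, keepdims=True)     # 1. find the row max (for stability)
    E = np.exp(S - m)                    # 2. subtract max, exponentiate
    Z = E.sum(axis=1, keepdims=True)     # 3. sum the exponentials
    return E / Z                         # 4. divide each element by the sum

S = np.array([[1.0, 2.0, 3.0], [1000.0, 1001.0, 1002.0]])
P = softmax_rows(S)

assert np.allclose(P.sum(axis=1), 1.0)  # each row is a probability distribution
assert np.allclose(P[0], P[1])          # softmax is invariant to a per-row shift
```

Without the max subtraction, the second row would overflow: np.exp(1002.0) is inf in FP64, let alone FP16.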

Memory traffic analysis:

Read S:    N × N × 2 bytes = 32 MB
Write P:   N × N × 2 bytes = 32 MB

Total HBM traffic for Step 2: 64 MB

Computation:

For each of N rows, we do approximately 5N operations (find max, subtract, exp, sum, divide):

FLOPs ≈ N × 5N = 5N² = 5 × 4096² ≈ 84 MFLOP

Arithmetic intensity for Step 2:

Arithmetic intensity = 84 MFLOP / 64 MB ≈ 1.3 FLOPs/byte

This is extremely low! Softmax is a severely memory-bound operation. We're reading and writing 64 MB of data but only doing 84 million operations—the GPU spends almost all its time waiting for memory.


Step 3: Computing the Output (O = PV)

The operation: Multiply P [N, N] by V [N, d] to get O [N, d].


Memory traffic analysis:

Read P:    N × N × 2 bytes = 32 MB
Read V:    N × d × 2 bytes = 1 MB
Write O:   N × d × 2 bytes = 1 MB

Total HBM traffic for Step 3: 34 MB

Computation:

Matrix multiplication of [N, N] × [N, d]:

FLOPs = N × d × 2N = 2N²d = 2 × 4096² × 128 ≈ 4.3 GFLOP

Arithmetic intensity for Step 3:

Arithmetic intensity = 4.3 GFLOP / 34 MB ≈ 126 FLOPs/byte

Again, the matrix multiplication itself is reasonably efficient.


The Complete Picture: Total HBM Traffic

Let's sum up all the memory traffic (N = 4096, d = 128, FP16):

Step                      Reads from HBM         Writes to HBM    Subtotal
Step 1: S = QK^T / √d     Q (1 MB) + K (1 MB)    S (32 MB)        34 MB
Step 2: P = softmax(S)    S (32 MB)              P (32 MB)        64 MB
Step 3: O = PV            P (32 MB) + V (1 MB)   O (1 MB)         34 MB
Total                                                             132 MB

(S and P are the N×N intermediates; moving them accounts for 128 MB of the 132 MB total.)
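These totals can be reproduced with a small traffic model (a sketch that counts only the reads and writes listed above, per attention head; the function name is mine):

```python
def attention_hbm_traffic(N, d, bytes_per_el=2):
    """Bytes moved to/from HBM by the three-step implementation (per head)."""
    Nd = N * d * bytes_per_el  # size of Q, K, V, or O
    NN = N * N * bytes_per_el  # size of S or P
    step1 = 2 * Nd + NN        # read Q, K; write S
    step2 = 2 * NN             # read S; write P
    step3 = NN + 2 * Nd        # read P, V; write O
    return step1 + step2 + step3

MB = 2**20
assert attention_hbm_traffic(4096, 128) == 132 * MB  # 34 + 64 + 34 = 132 MB
```

Note that the N² terms (4·N²·2 bytes in total) completely dominate the N·d terms once N is much larger than d.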

Visualizing the Data Flow

Here's how data flows through HBM in standard attention: Q and K come in, S goes out; S comes back in, P goes out; P and V come in, O goes out. That's three complete round-trips through HBM for intermediate data, and every crossing of the HBM ↔ compute boundary is a slow memory access. The problem isn't the computation — it's that we keep bouncing data back and forth through the slow HBM bottleneck.


How This Scales: The O(N²) Problem

Let's calculate total HBM traffic for different sequence lengths:

Sequence Length (N)   N×N Matrix (FP16)   Total HBM Traffic   Time at 2 TB/s
512                   0.5 MB              2 MB                0.001 ms
1,024                 2 MB                8 MB                0.004 ms
2,048                 8 MB                33 MB               0.016 ms
4,096                 32 MB               132 MB              0.066 ms
8,192                 128 MB              528 MB              0.264 ms
16,384                512 MB              2,112 MB            1.056 ms
32,768                2,048 MB            8,448 MB            4.224 ms
65,536                8,192 MB            33,792 MB           16.896 ms
131,072               32,768 MB           135,168 MB          67.584 ms

(Traffic totals here scale as 132 MB × (N/4096)²; the small O(Nd) terms are neglected.)

Key observations:

  1. Quadratic scaling: Doubling sequence length quadruples memory traffic. This is because the attention matrix is N×N.
  2. Memory capacity limits: At N=32K, the attention matrix alone is 2GB per head. With 32 heads, that's 64GB just for attention matrices—nearly the entire A100 memory!
  3. Bandwidth limits: Even if memory capacity isn't an issue, the time spent on memory transfers grows quadratically. At N=128K, we'd spend 67ms just moving attention matrices around, per layer, per head.

This is why early transformers were limited to 512 or 1024 tokens. The O(N²) memory scaling made longer sequences impractical.
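The quadratic growth is easy to verify directly from the dominant N² term (a quick check; the helper name is mine):

```python
def nn_matrix_mb(N, bytes_per_el=2):
    """Size of the N×N attention score matrix in (binary) MB."""
    return N * N * bytes_per_el / 2**20

# Doubling the sequence length quadruples the intermediate matrix
assert nn_matrix_mb(4096) == 32
assert nn_matrix_mb(8192) == 128
assert nn_matrix_mb(8192) / nn_matrix_mb(4096) == 4
```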


The Fundamental Inefficiency: Why Do We Materialize N×N?

Here's the crucial question: Why do we write the N×N matrices to HBM at all?

The answer has two parts:

Reason 1: Programming Convenience

The natural way to implement attention is as three separate operations:

# Naive PyTorch implementation
import math
import torch

S = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d)  # [N, N]
P = torch.softmax(S, dim=-1)                              # [N, N]
O = torch.matmul(P, V)                                    # [N, d]

Each line is a separate kernel launch. Each kernel reads its inputs from HBM and writes its outputs to HBM. PyTorch (and other frameworks) don't automatically fuse these operations in eager mode.

This is the simplest, most readable implementation. But it's also the slowest.

Reason 2: Softmax Requires Global Information

The softmax operation is the real culprit. Look at the formula again:

P[i,j] = exp(S[i,j] - max(S[i,:])) / Σ(k=0 to N-1) exp(S[i,k] - max(S[i,:]))

To compute P[i,j], we need:

  1. The maximum value in row i (for numerical stability)
  2. The sum of all exponentials in row i (for normalization)

Both of these require seeing all N elements of row i before we can output any element of row i.

This seems to force us to materialize the full S matrix: we need all scores before we can compute any softmax output.

Or does it?

This is exactly the insight that FlashAttention exploits. As we'll see in the next artifacts, there's a clever way to compute softmax incrementally using running statistics, avoiding the need to ever materialize the full N×N matrix.
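To make the teaser concrete, here is a minimal sketch of the incremental idea: computing softmax over one row in fixed-size chunks while keeping only a running max and running sum. (The full FlashAttention algorithm also rescales a partial output accumulator as it goes; this sketch tracks only the statistics.)

```python
import numpy as np

def online_softmax(row, chunk=4):
    """Softmax of a 1-D array, seen one chunk at a time, using only a
    running max (m) and a running sum of exponentials (s)."""
    m, s = -np.inf, 0.0
    for start in range(0, len(row), chunk):
        block = row[start:start + chunk]
        m_new = max(m, block.max())
        # rescale the old sum to the new max, then fold in the new block
        s = s * np.exp(m - m_new) + np.exp(block - m_new).sum()
        m = m_new
    return np.exp(row - m) / s  # final result from the accumulated statistics

x = np.random.default_rng(1).standard_normal(10)
reference = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
assert np.allclose(online_softmax(x), reference)
```

The key trick is the rescaling line: when a new chunk raises the running max, the old sum is multiplied by exp(m − m_new) so that all accumulated terms remain expressed relative to the same maximum. No chunk ever needs to see the whole row.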


IO Complexity: The Right Way to Analyze Attention

In traditional algorithm analysis, we count FLOPs. For memory-bound operations, FLOPs don't tell the whole story. We need to count memory operations.

IO complexity measures the number of bytes transferred between memory levels (specifically, between HBM and SRAM/registers).

Standard attention IO complexity:

Reading Q, K, V:           O(Nd)     — read each once
Writing S:                 O(N²)     — N×N matrix
Reading S:                 O(N²)     — for softmax
Writing P:                 O(N²)     — N×N matrix
Reading P:                 O(N²)     — for matmul with V
Writing O:                 O(Nd)     — written once

Total IO complexity:       O(N² + Nd) = O(N²)   (since N >> d typically)

The O(N²) IO complexity is the problem. We're moving N² data when we only need N×d output.

What would ideal IO complexity look like?

If we could somehow compute attention without materializing the N×N intermediates:

Reading Q, K, V:           O(Nd)
Writing O:                 O(Nd)

Ideal IO complexity:       O(Nd)

This would be a factor of N/d improvement! For N=4096, d=128, that's 32× less memory traffic.

FlashAttention achieves close to this ideal.


Overall Arithmetic Intensity of Standard Attention

Let's compute the overall arithmetic intensity to confirm this operation is memory-bound:

Total FLOPs:
  - Step 1 (QK^T):     2N²d  = 2 × 4096² × 128 = 4.29 GFLOP
  - Step 2 (softmax):  5N²   = 5 × 4096²       = 0.08 GFLOP
  - Step 3 (PV):       2N²d  = 2 × 4096² × 128 = 4.29 GFLOP
  - Total:             ~4N²d = 4 × 4096² × 128 = 8.59 GFLOP

Total bytes transferred:   132 MB (calculated earlier)

Arithmetic intensity = 8.59 GFLOP / 132 MB ≈ 65 FLOPs/byte

This is well below the A100 ridge point of 156 FLOPs/byte. Standard attention is memory-bound, even though it involves matrix multiplications.

The culprit is the O(N²) intermediate data. The matrix multiplications themselves have reasonable arithmetic intensity (~126), but the softmax operation (1.3 FLOPs/byte) and the repeated reads/writes of N×N matrices drag down the overall efficiency.
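That overall figure can be double-checked in a few lines (using the decimal-megabyte convention behind the ≈65 FLOPs/byte number above):

```python
N, d = 4096, 128

# Steps 1 and 3 are matmuls (2N²d FLOPs each); Step 2 is ~5N² softmax ops
flops = 2 * (2 * N * N * d) + 5 * N * N
traffic = 132 * 10**6  # total HBM bytes from the traffic table (132 MB)

intensity = flops / traffic
assert 60 < intensity < 70  # well below the A100 ridge point of ~156 FLOPs/byte
```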


Comparison: What We Have vs What We Need

What we actually need to move is just the inputs and the output: read Q, K, V (3 MB) and write O (1 MB), for 4 MB total. What we have is 132 MB. Overhead: 132 MB / 4 MB = 33× more memory traffic than necessary. This grows with sequence length:

  • N = 4K: 33× overhead
  • N = 8K: 66× overhead
  • N = 16K: 132× overhead
  • N = 32K: 264× overhead

The Memory Capacity Problem

Beyond bandwidth, there's also a capacity problem. The N×N attention matrix must fit in GPU memory.

Sequence Length   Attention Matrix (1 head)   32 Heads, 1 Layer   32 Heads, 32 Layers
2,048             8 MB                        256 MB              8 GB
4,096             32 MB                       1,024 MB            32 GB
8,192             128 MB                      4,096 MB            128 GB ← OOM!
16,384            512 MB                      16,384 MB           512 GB ← OOM!
32,768            2,048 MB                    65,536 MB           2,048 GB ← OOM!

Note: In practice, most implementations compute one layer at a time (not all 32 simultaneously), so the "32 layers" column is theoretical. But even single-layer requirements become problematic at long sequence lengths.


What FlashAttention Will Do Differently

Now that we understand the problem, we can appreciate what FlashAttention accomplishes. Here's a preview (detailed in the next artifacts):

Metric                          Standard                    FlashAttention
Memory for N×N intermediates    O(N²) — must fit in HBM     O(1)* — only a tile in SRAM
HBM traffic (IO complexity)     O(N²)                       O(N²d²/M), where M = SRAM size
HBM traffic (N=4096, d=128)     132 MB                      ~4 MB (close to minimum!)
Softmax computation             Requires full N×N matrix    Incremental, with running statistics
Result                          Exact                       Exact (not approximate!)

*Per tile — the full N×N matrix is never materialized in HBM.

The key insight: by computing attention in small tiles that fit in SRAM, and using a clever "online softmax" algorithm to accumulate results correctly, FlashAttention avoids ever materializing the N×N matrix in HBM.


Summary: The Problem We Need to Solve

Standard attention has three critical problems:

1. O(N²) Memory Traffic: The intermediate attention matrices require reading and writing O(N²) data to HBM. This dominates the total memory bandwidth usage.

2. O(N²) Memory Capacity: The attention matrices must fit in GPU memory, limiting sequence length to a few thousand tokens on practical hardware.

3. Low Arithmetic Intensity: Despite involving matrix multiplications, the overall operation is memory-bound (~65 FLOPs/byte, below the 156 ridge point) because of the softmax bottleneck and intermediate data movement.

The root cause is that standard implementations materialize the N×N attention matrix, when in principle we only need the N×d output. This happens because:

  • It's the natural way to write the code (three separate operations)
  • Softmax appears to require global information before producing any output

FlashAttention solves both problems by:

  • Tiling the computation to fit in SRAM
  • Using online softmax to compute incrementally without seeing all scores at once

This is what we'll study next.


Check Your Understanding

Before moving to FlashAttention:

  1. For N = 8192, d = 128, calculate the size of the attention matrix in MB (FP16). How does this compare to the size of Q, K, V combined?
    • Attention matrix size: N × N × 2 bytes = 8192 × 8192 × 2 = 128 MB
    • Q, K, V combined: 3 × N × d × 2 bytes = 3 × 8192 × 128 × 2 = 6 MB
    • Ratio: 128 MB / 6 MB ≈ 21× larger. The intermediate attention matrix dwarfs the actual input data.
    • This ratio grows with N: at N=32K, it would be 85×.
  2. If we could magically fuse Steps 1, 2, and 3 into a single kernel that never writes intermediates to HBM, what would the HBM traffic be? What would the arithmetic intensity become?
    • HBM traffic: Only read Q, K, V (3 MB) and write O (1 MB) = 4 MB total
    • FLOPs: Still ~8.6 GFLOP (computation doesn't change)
    • Arithmetic intensity: 8.6 GFLOP / 4 MB = 2,150 FLOPs/byte
    • This is 14× above the A100 ridge point (156) — the operation would become compute-bound instead of memory-bound!
    • Speedup: From 65 FLOPs/byte to 2,150 FLOPs/byte = 33× improvement in efficiency. This is exactly what FlashAttention achieves (approximately).
  3. Why can't we simply fuse softmax into the matmul kernels using standard kernel fusion techniques? What makes softmax special?
    • Standard fusion works for elementwise ops (ReLU, add, etc.) because each output element depends only on corresponding input elements — you can compute output[i] as soon as you have input[i].
    • Softmax requires global information: To compute P[i,j], you need:
      • max(S[i,:]) — the maximum across the entire row
      • Σ exp(S[i,k]) — the sum of exponentials across the entire row
    • You can't output any element of row i until you've seen all elements of row i.
    • This data dependency prevents naive fusion: you'd need to fully compute row i of S before starting softmax for that row.
    • FlashAttention's solution: "Online softmax" — maintain running max and running sum, update them as new tiles arrive, and rescale previously computed partial results. This breaks the global dependency.
  4. A model has 32 heads and 32 layers. For a 4096-token sequence, how much HBM traffic does standard attention generate in total? (Use our 132 MB per head figure.)
    • Per head: 132 MB
    • Per layer (32 heads): 132 MB × 32 = 4,224 MB ≈ 4.1 GB
    • Total (32 layers): 4.1 GB × 32 = 131 GB
    • On A100 (2 TB/s bandwidth): 131 GB / 2,000 GB/s = 65.5 ms just for attention memory transfers!
    • This is for a single forward pass. And it doesn't include the linear projections, FFN, or any other operations.
  5. If H100 has 3.35 TB/s bandwidth instead of 2 TB/s, does that solve the standard attention efficiency problem? Why or why not?
    • No. H100's 1.7× bandwidth improvement only gives 1.7× speedup for memory-bound operations.
    • The fundamental problem is the 33× overhead in memory traffic (132 MB vs 4 MB needed). You're still doing 33× more memory transfers than necessary — just slightly faster.
    • The right solution is algorithmic: Reduce the bytes transferred (FlashAttention), not just transfer them faster.
    • Analogy: If you're driving 33× further than necessary, a faster car helps a little, but the real fix is taking the shorter route.
    • Additionally: H100's higher ridge point (296 vs 156 FLOPs/byte) makes standard attention more memory-bound relative to peak, not less.
