2.2a: FlashAttention — The Tiling Strategy
The Core Insight
In Artifact 2.1, we identified the problem: standard attention writes O(N²) intermediate data to HBM because it materializes the full N×N attention matrix. We also noted that we only need the final N×d output—the N×N matrix is just an intermediate step.
FlashAttention's core insight is simple to state but profound in its implications:
What if we never materialize the N×N matrix at all?
Instead of computing the entire attention matrix, then applying softmax, then multiplying by V, we compute attention in small tiles that fit entirely in SRAM. Each tile produces a partial result, and we accumulate these partial results to get the final output.
The challenge is softmax: it requires global information (the sum over all elements). How can we compute softmax correctly if we only see a few elements at a time? The answer is the "online softmax" algorithm, which we'll cover in Artifact 2.2b. For now, let's focus on the tiling strategy itself.
Tiling: The High-Level Idea
Tiling means dividing the computation into small blocks that fit in fast memory (SRAM). Instead of:
1. Compute ALL of QK^T → write N×N to HBM
2. Compute softmax on ALL of S → write N×N to HBM
3. Compute ALL of PV → write N×d to HBM
We do:
For each block of queries (Q_block):
    For each block of keys/values (K_block, V_block):
        Load Q_block, K_block, V_block into SRAM
        Compute partial attention scores (small block, stays in SRAM)
        Update running output using partial scores (stays in SRAM/registers)
    Write final output for this Q_block to HBM
The key difference: intermediate attention scores never leave SRAM. We only write the final output to HBM.
Visualizing Standard vs Tiled Computation
Let's visualize what happens with N=16 tokens, split into blocks of 4:
Standard Attention (N=16)
Step 1: Compute FULL attention matrix S = QK^T
|         | K₀-K₃ | K₄-K₇ | K₈-K₁₁ | K₁₂-K₁₅ |
|---------|-------|-------|--------|---------|
| Q₀-Q₃   | ■■■■  | ■■■■  | ■■■■   | ■■■■    |
| Q₄-Q₇   | ■■■■  | ■■■■  | ■■■■   | ■■■■    |
| Q₈-Q₁₁  | ■■■■  | ■■■■  | ■■■■   | ■■■■    |
| Q₁₂-Q₁₅ | ■■■■  | ■■■■  | ■■■■   | ■■■■    |
ALL 256 elements computed and stored in HBM — 16×16 matrix materialized in memory
FlashAttention Tiled (N=16, block=4)
Iteration 1: Q_block_0 × K_block_0
|         | K₀-K₃    | K₄-K₇   | K₈-K₁₁  | K₁₂-K₁₅ |
|---------|----------|---------|---------|---------|
| Q₀-Q₃   | **■■■■** | (later) | (later) | (later) |
| Q₄-Q₇   | (later)  | (later) | (later) | (later) |
| Q₈-Q₁₁  | (later)  | (later) | (later) | (later) |
| Q₁₂-Q₁₅ | (later)  | (later) | (later) | (later) |
Only this 4×4 block computed, stays in SRAM!
Iteration 2: Q_block_0 × K_block_1
|         | K₀-K₃            | K₄-K₇    | K₈-K₁₁  | K₁₂-K₁₅ |
|---------|------------------|----------|---------|---------|
| Q₀-Q₃   | (done, combined) | **■■■■** | (later) | (later) |
| Q₄-Q₇   | (later)          | (later)  | (later) | (later) |
| Q₈-Q₁₁  | (later)          | (later)  | (later) | (later) |
| Q₁₂-Q₁₅ | (later)          | (later)  | (later) | (later) |
This 4×4 block computed in SRAM, combined with previous result
... continue for all K blocks, then move to next Q block ...
The crucial difference: In standard attention, the full 16×16 matrix exists in HBM. In FlashAttention, only a 4×4 block ever exists, and it stays in SRAM.
The Block Size: Fitting in SRAM
How big should the blocks be? The constraint is SRAM capacity.
What needs to fit in SRAM simultaneously:
For a single tile computation, we need:
- Q_block: [B_q, d] — block of queries
- K_block: [B_k, d] — block of keys
- V_block: [B_k, d] — block of values
- S_block: [B_q, B_k] — partial attention scores
- O_block: [B_q, d] — partial output accumulator
- Running statistics: [B_q] for max and sum (for online softmax)
Memory calculation:
Q_block: B_q × d × 2 bytes
K_block: B_k × d × 2 bytes
V_block: B_k × d × 2 bytes
S_block: B_q × B_k × 2 bytes
O_block: B_q × d × 2 bytes
Stats: B_q × 2 × 4 bytes (max and sum in FP32)
Total ≈ 2 × (B_q × d + 2 × B_k × d + B_q × B_k + B_q × d) bytes
≈ 2 × (2 × B_q × d + 2 × B_k × d + B_q × B_k) bytes
For typical values (d=128, SRAM=192KB per SM):
If B_q = B_k = 128:
Q_block: 128 × 128 × 2 = 32 KB
K_block: 128 × 128 × 2 = 32 KB
V_block: 128 × 128 × 2 = 32 KB
S_block: 128 × 128 × 2 = 32 KB
O_block: 128 × 128 × 2 = 32 KB
Stats: 128 × 8 ≈ 1 KB
Total ≈ 161 KB
This fits comfortably in 192KB SRAM! Block sizes of 64-256 are typical.
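The budget above is easy to reproduce. Here is a small sketch of that calculation; `tile_sram_bytes` is a hypothetical helper name, not part of any library, and the byte sizes follow the FP16-data / FP32-statistics assumption used above.

```python
# Sketch: estimate SRAM needed for one FlashAttention tile.
# Assumes FP16 tile data and FP32 running statistics, as in the text.

def tile_sram_bytes(B_q: int, B_k: int, d: int) -> int:
    fp16 = 2  # bytes per element
    fp32 = 4
    q_block = B_q * d * fp16     # query tile
    k_block = B_k * d * fp16     # key tile
    v_block = B_k * d * fp16     # value tile
    s_block = B_q * B_k * fp16   # partial attention scores
    o_block = B_q * d * fp16     # output accumulator
    stats = B_q * 2 * fp32       # running max and running sum per query row
    return q_block + k_block + v_block + s_block + o_block + stats

usage = tile_sram_bytes(128, 128, 128)
print(f"{usage / 1024:.0f} KB")  # ~161 KB, under the 192 KB budget
```

Plugging in other candidate block sizes is a quick way to see why B=256 with d=128 would blow past 192 KB.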
The FlashAttention Algorithm: Step by Step
Let's trace through the complete algorithm. We have N queries, N keys, N values, and we want to compute the output O.
# FLASHATTENTION ALGORITHM
# INPUT: Q, K, V in HBM, each [N, d]
# OUTPUT: O in HBM, [N, d]
# PARAMETERS:
# - B_q: query block size (e.g., 128)
# - B_k: key block size (e.g., 128)
# Divide Q, K, V into blocks
T_q = ceil(N / B_q) # number of query blocks
T_k = ceil(N / B_k) # number of key blocks
FOR i = 0 TO T_q - 1:                      # Loop over query blocks
    # Load query block from HBM to SRAM
    Q_i = Q[i*B_q : (i+1)*B_q, :]          # [B_q, d]

    # Initialize output accumulator and statistics
    O_i = zeros(B_q, d)                    # in SRAM
    l_i = zeros(B_q)                       # running sum of exp
    m_i = -infinity(B_q)                   # running max

    FOR j = 0 TO T_k - 1:                  # Loop over key/value blocks
        # Load key and value blocks from HBM to SRAM
        K_j = K[j*B_k : (j+1)*B_k, :]      # [B_k, d]
        V_j = V[j*B_k : (j+1)*B_k, :]      # [B_k, d]

        # Compute attention scores for this block (IN SRAM)
        S_ij = Q_i @ K_j.T / sqrt(d)       # [B_q, B_k]

        # === ONLINE SOFTMAX UPDATE (covered in 2.2b) ===
        # Update running max
        m_new = max(m_i, rowmax(S_ij))

        # Rescale previous output and sum
        scale = exp(m_i - m_new)
        O_i = O_i * scale
        l_i = l_i * scale

        # Compute local softmax contribution
        P_ij = exp(S_ij - m_new)           # [B_q, B_k]

        # Update running sum
        l_i = l_i + rowsum(P_ij)

        # Accumulate weighted values
        O_i = O_i + P_ij @ V_j             # [B_q, d]

        # Update max for next iteration
        m_i = m_new
    END FOR (j)

    # Final normalization
    O_i = O_i / l_i

    # Write output block to HBM
    O[i*B_q : (i+1)*B_q, :] = O_i
END FOR (i)
The magic: The attention score matrix S_ij is only [B_q, B_k], not [N, N]. It's computed, used, and discarded—all within SRAM. The full N×N matrix never exists anywhere.
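To make the loop concrete, here is a minimal NumPy sketch of the algorithm, checked against naive attention. It is a CPU illustration, not the fused GPU kernel (real FlashAttention runs this entire loop inside a single CUDA kernel with the tiles in SRAM), it assumes block sizes divide N evenly, and the function names are ours.

```python
# Minimal NumPy sketch of tiled attention with the online softmax update.
import numpy as np

def flash_attention(Q, K, V, B_q=4, B_k=4):
    N, d = Q.shape
    O = np.zeros((N, d))
    for i in range(0, N, B_q):                # outer loop: query blocks
        Q_i = Q[i:i+B_q]
        O_i = np.zeros((B_q, d))              # output accumulator (in "SRAM")
        l_i = np.zeros(B_q)                   # running sum of exponentials
        m_i = np.full(B_q, -np.inf)           # running row max
        for j in range(0, N, B_k):            # inner loop: key/value blocks
            K_j, V_j = K[j:j+B_k], V[j:j+B_k]
            S_ij = Q_i @ K_j.T / np.sqrt(d)   # [B_q, B_k], never stored globally
            m_new = np.maximum(m_i, S_ij.max(axis=1))
            scale = np.exp(m_i - m_new)       # rescale past contributions
            O_i *= scale[:, None]
            l_i *= scale
            P_ij = np.exp(S_ij - m_new[:, None])
            l_i += P_ij.sum(axis=1)
            O_i += P_ij @ V_j
            m_i = m_new
        O[i:i+B_q] = O_i / l_i[:, None]       # normalize, write block to "HBM"
    return O

def naive_attention(Q, K, V):
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((16, 8)) for _ in range(3))
assert np.allclose(flash_attention(Q, K, V), naive_attention(Q, K, V))
```

The assertion at the end verifies the key claim: the tiled computation is mathematically exact, not an approximation.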
Memory Access Pattern: Standard vs FlashAttention
Let's trace exactly what gets read from and written to HBM:
Memory Access Comparison (N=4096, B=128, d=128)
Standard Attention:
| Operation | Reads | Writes |
|---|---|---|
| Q matrix | 1 MB (once) | — |
| K matrix | 1 MB (once) | — |
| S matrix (scores) | 32 MB (for softmax) | 32 MB |
| P matrix (attention) | 32 MB (for matmul) | 32 MB |
| V matrix | 1 MB (once) | — |
| O matrix (output) | — | 1 MB |
| Total | 67 MB | 65 MB |
| TOTAL HBM TRAFFIC | 132 MB |
FlashAttention:
Number of blocks: T_q = T_k = 4096/128 = 32
| Operation | Reads | Writes |
|---|---|---|
| Q blocks | 1 MB (each read once) | — |
| K blocks | 32 MB (read T_q times) | — |
| V blocks | 32 MB (read T_q times) | — |
| O blocks (output) | — | 1 MB |
| Total | 65 MB | 1 MB |
| TOTAL HBM TRAFFIC | 66 MB |
Key observations:
- NO N×N matrices ever written/read
- K, V re-reads are sequential and cache-friendly
Scaling Comparison:
| N | Standard Attention | FlashAttention | Reduction |
|---|---|---|---|
| 4,096 | 132 MB | 66 MB | 2.0× |
| 8,192 | 520 MB | 260 MB | 2.0× |
| 16,384 | 2,064 MB | 1,032 MB | 2.0× |
Why the reduction factor is constant here: both totals are dominated by O(N²) terms. Standard attention pays for four passes over the N×N matrices (~8N² bytes); FlashAttention pays for the K/V re-reads (~4N²d/B_q bytes). The ratio is ≈ 2B_q/d, so with B_q = d = 128 the bandwidth saving is about 2×, independent of N. The saving grows with block size rather than with N: more SRAM allows larger blocks, and the FlashAttention paper bounds total traffic at Θ(N²d²/M) for SRAM size M. The advantage that does grow with N is capacity: standard attention must also store the O(N²) matrices, which FlashAttention avoids entirely.
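The traffic model behind these tables can be sketched in a few lines. The function names are ours; the model follows the per-head FP16 accounting above (S and P each written once and read once for standard attention; K and V re-read once per query block for FlashAttention) and ignores caching effects.

```python
# Sketch of the per-head HBM traffic model used in the tables above.

def standard_traffic_bytes(N, d, fp=2):
    qkv_reads = 3 * N * d * fp
    s_write_read = 2 * N * N * fp     # S written by QK^T, re-read for softmax
    p_write_read = 2 * N * N * fp     # P written by softmax, re-read for P @ V
    o_write = N * d * fp
    return qkv_reads + s_write_read + p_write_read + o_write

def flash_traffic_bytes(N, d, B_q=128, fp=2):
    T_q = N // B_q
    q_reads = N * d * fp              # each Q block read once
    kv_reads = 2 * N * d * fp * T_q   # K and V re-read once per query block
    o_write = N * d * fp
    return q_reads + kv_reads + o_write

for N in (4096, 8192, 16384):
    s, f = standard_traffic_bytes(N, 128), flash_traffic_bytes(N, 128)
    print(f"N={N}: standard {s / 2**20:.0f} MB, flash {f / 2**20:.0f} MB, {s / f:.1f}x")
```

Varying `B_q` in this sketch shows directly how larger blocks shrink the K/V re-read term.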
The Iteration Pattern Visualized
Let's see how the algorithm progresses through the computation (N=8, B=2):
Full Attention Matrix (Never Materialized)
|      | K₀K₁ | K₂K₃ | K₄K₅ | K₆K₇ | Outputs |
|---|---|---|---|---|---|
| Q₀Q₁ | 1 | 2 | 3 | 4 | → Rows 0-1 |
| Q₂Q₃ | 5 | 6 | 7 | 8 | → Rows 2-3 |
| Q₄Q₅ | 9 | 10 | 11 | 12 | → Rows 4-5 |
| Q₆Q₇ | 13 | 14 | 15 | 16 | → Rows 6-7 |
Numbers show execution order
Execution Order
OUTER LOOP i=0 (Q₀Q₁):
- j=0: Load Q₀₁, K₀₁, V₀₁ → compute block 1 → update O₀₁
- j=1: Load K₂₃, V₂₃ → compute block 2 → update O₀₁
- j=2: Load K₄₅, V₄₅ → compute block 3 → update O₀₁
- j=3: Load K₆₇, V₆₇ → compute block 4 → update O₀₁
- → Write final O₀₁ to HBM
OUTER LOOP i=1 (Q₂Q₃):
- j=0: Load Q₂₃, K₀₁, V₀₁ → compute block 5 → update O₂₃
- j=1: Load K₂₃, V₂₃ → compute block 6 → update O₂₃
- j=2: Load K₄₅, V₄₅ → compute block 7 → update O₂₃
- j=3: Load K₆₇, V₆₇ → compute block 8 → update O₂₃
- → Write final O₂₃ to HBM
... and so on for i=2 (Q₄Q₅) and i=3 (Q₆Q₇) ...
Key Observations
- K and V blocks are loaded multiple times (once per query block)
- But each load is only B×d elements, not N×d
- No block larger than B×B (here 2×2) is ever in memory
Why Doesn't Re-reading K and V Hurt Performance?
You might worry: "We're reading K and V multiple times! Doesn't that waste bandwidth?"
The answer involves understanding what we're trading off:
Standard attention:
- Reads K once: O(Nd)
- Reads V once: O(Nd)
- But writes/reads S: O(N²)
- And writes/reads P: O(N²)
- Total: O(N²)
FlashAttention:
- Reads K multiple times: O(Nd × T_q) = O(Nd × N/B_q) = O(N²d/B_q)
- Reads V multiple times: O(N²d/B_q)
- No S or P traffic!
- Total: O(N²d/B_q), which shrinks as the query block size B_q grows
For typical values (d=128, B_q=128):
- FlashAttention traffic: O(N²d/B_q) = O(N²)
Wait, that's still O(N²)! But the constant factor is very different:
- Standard: ~8N² bytes (two N×N FP16 matrices, each written once and read once)
- FlashAttention: ~4Nd bytes per pass (K and V in FP16) × (N/B_q) passes = 4N²d/B_q bytes
For N=4096, d=128, B_q=128:
- Standard: 4 × 4096² × 2 = 134 MB for intermediates alone
- FlashAttention: 2 × 4096 × 128 × 2 × (4096/128) = 67 MB for K,V re-reads
Same order of magnitude, but there's a crucial difference: the re-reads of K and V are small, sequential blocks that benefit from L2 cache, while the N×N intermediates are far too large to cache, so every access pays full HBM cost.
More importantly, as N grows, standard attention needs to store O(N²), which eventually exceeds memory capacity. FlashAttention never needs more than O(Bd) storage.
The Storage Advantage: O(N) vs O(N²)
Even if bandwidth were infinite, standard attention has a fundamental memory capacity problem:
Memory Capacity Comparison
Standard Attention - Storage Required:
| Component | Size | Complexity | Notes |
|---|---|---|---|
| Q, K, V | 3 × N × d × 2 bytes | O(Nd) | |
| S (scores matrix) | N × N × 2 bytes | O(N²) | ← Problem! |
| P (attention matrix) | N × N × 2 bytes | O(N²) | ← Problem! |
| O (output) | N × d × 2 bytes | O(Nd) | |
| TOTAL | — | O(N²) | Dominated by intermediates |
Examples:
- N = 32K: S alone = 32K × 32K × 2 = 2 GB per head!
- N = 128K: S alone = 128K × 128K × 2 = 32 GB per head!
FlashAttention - Storage Required:
| Component | Size | Complexity | Notes |
|---|---|---|---|
| Q, K, V | 3 × N × d × 2 bytes | O(Nd) | |
| O (output) | N × d × 2 bytes | O(Nd) | |
| Tile buffers | O(Bd) | O(1) | Relative to N |
| Statistics | O(N) | O(N) | Max and sum per row |
| TOTAL | — | O(N) | Since d is constant |
Examples:
- N = 32K: Q,K,V,O = 4 × 32K × 128 × 2 = 32 MB
- N = 128K: Q,K,V,O = 4 × 128K × 128 × 2 = 128 MB
This is why 128K+ context windows are possible!
This O(N) vs O(N²) memory difference is arguably more important than the bandwidth savings. It's what enables modern LLMs to have 100K+ context windows.
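The storage tables above can be checked with a short sketch. The function names are ours, and the accounting matches the tables: per head, FP16 data, block size B for the tile buffers, FP32 running statistics.

```python
# Sketch comparing peak intermediate storage (per head, FP16).

def standard_storage_bytes(N, d, fp=2):
    qkvo = 4 * N * d * fp
    s_and_p = 2 * N * N * fp          # the O(N²) intermediates: S and P
    return qkvo + s_and_p

def flash_storage_bytes(N, d, B=128, fp=2):
    qkvo = 4 * N * d * fp
    tiles = (4 * B * d + B * B) * fp  # Q/K/V/O tiles + score tile, O(1) in N
    stats = 2 * N * 4                 # FP32 running max and sum per row
    return qkvo + tiles + stats

N, d = 32 * 1024, 128
print(f"standard: {standard_storage_bytes(N, d) / 2**30:.2f} GiB")  # S and P dominate
print(f"flash:    {flash_storage_bytes(N, d) / 2**20:.0f} MiB")     # Q,K,V,O dominate
```

At N = 32K the standard figure is a few GiB per head while the FlashAttention figure stays in the tens of MiB, which is the whole point of the section.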
Handling Causal Masking
In autoregressive models, we use causal (triangular) masking: token i can only attend to tokens 0, 1, ..., i. This means the upper triangle of the attention matrix should be zero (or negative infinity before softmax).
FlashAttention handles this elegantly:
Full Causal Attention Matrix
|    | K₀ | K₁ | K₂ | K₃ | K₄ | K₅ | K₆ | K₇ |
|---|---|---|---|---|---|---|---|---|
| Q₀ | ■ | ✗ | ✗ | ✗ | ✗ | ✗ | ✗ | ✗ |
| Q₁ | ■ | ■ | ✗ | ✗ | ✗ | ✗ | ✗ | ✗ |
| Q₂ | ■ | ■ | ■ | ✗ | ✗ | ✗ | ✗ | ✗ |
| Q₃ | ■ | ■ | ■ | ■ | ✗ | ✗ | ✗ | ✗ |
| Q₄ | ■ | ■ | ■ | ■ | ■ | ✗ | ✗ | ✗ |
| Q₅ | ■ | ■ | ■ | ■ | ■ | ■ | ✗ | ✗ |
| Q₆ | ■ | ■ | ■ | ■ | ■ | ■ | ■ | ✗ |
| Q₇ | ■ | ■ | ■ | ■ | ■ | ■ | ■ | ■ |
■ = computed | ✗ = masked (skipped)
With Tiling (B=2)
|      | K₀K₁ | K₂K₃ | K₄K₅ | K₆K₇ |
|---|---|---|---|---|
| Q₀Q₁ | ▣ | skip | skip | skip |
| Q₂Q₃ | ■ | ▣ | skip | skip |
| Q₄Q₅ | ■ | ■ | ▣ | skip |
| Q₆Q₇ | ■ | ■ | ■ | ▣ |
Legend:
- ▣ = partial block (apply mask within)
- ■ = full block (no mask needed)
- skip = entirely masked, skip computation
Optimization: Skip Entirely Masked Blocks
- Blocks where min(key_indices) > max(query_indices) → skip (every key lies in the future of every query)
- Reduces computation by ~50% for causal attention
For causal masking, FlashAttention can skip entire blocks that would be completely masked. This gives nearly 2× speedup for causal attention compared to bidirectional attention.
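The block-level decision can be sketched as a tiny classifier. `block_mask_type` is a hypothetical helper, not from any library; it reproduces the B=2 diagram above: diagonal tiles are partial, below-diagonal tiles are full, above-diagonal tiles are skipped.

```python
# Sketch of block-level causal logic: for each (Q block, K block) pair,
# decide fully visible, fully masked (skip), or partial (mask in-tile).

def block_mask_type(i, j, B):
    """Classify tile (i, j) for causal attention with block size B."""
    q_min, q_max = i * B, (i + 1) * B - 1   # query indices in this tile
    k_min, k_max = j * B, (j + 1) * B - 1   # key indices in this tile
    if k_min > q_max:
        return "skip"      # every key is in the future of every query
    if k_max <= q_min:
        return "full"      # every key is visible to every query
    return "partial"       # diagonal tile: apply the mask within the block

# N=8, B=2: matches the tiled causal diagram above.
for i in range(4):
    print([block_mask_type(i, j, 2) for j in range(4)])
```

Counting the "skip" entries for large grids confirms the ~50% compute reduction claimed above.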
FlashAttention and the KV Cache: What It Does and Doesn't Solve
A common misconception is that FlashAttention solves all memory problems. It doesn't. Understanding what FlashAttention addresses—and what it leaves unsolved—is crucial.
The Distinction: Input Data vs. Computed Intermediates
| Data | Type | Can avoid storing in HBM? | Why? |
|---|---|---|---|
| KV Cache | Input/Source data | ❌ No | Must persist—it's the actual K, V vectors from previous tokens |
| Attention Matrix (S, P) | Computed intermediate | ✅ Yes | Derived from Q, K—can recompute any tile on demand |
Why the attention matrix doesn't need to fully exist:
The attention matrix S = QK^T is derived from Q and K. At any moment, you can recompute any tile: S_ij = Q_i @ K_j^T. FlashAttention exploits this—instead of computing all of S, storing it, then using it, you compute one tile, use it immediately, discard it, and move on.
Why the KV cache MUST exist:
The KV cache contains the actual K and V vectors computed during prefill or accumulated during decode. You cannot "recompute K_j on demand" without re-running the entire model on the original input tokens—defeating the purpose of caching.
Analogy:
- KV cache = A book you're referencing (must exist on your shelf)
- Attention matrix = Notes you derive from the book (can re-derive anytime by looking at the book again)
FlashAttention says: "Don't write out all your notes at once—just look at the book page by page and write your final summary directly." But the book itself must still exist.
What FlashAttention Helps With
| Problem | FlashAttention helps? | Explanation |
|---|---|---|
| O(N²) attention matrix storage | ✅ Yes | Never materializes full matrix |
| O(N²) attention matrix bandwidth | ✅ Yes | Tiles stay in SRAM |
| KV cache storage in HBM | ❌ No | Full cache must still exist in HBM |
| KV cache bandwidth during attention | ✅ Partially | Streams blocks efficiently, L2 cache helps |
When KV Cache Becomes the Bottleneck
For very long sequences, even with FlashAttention, KV cache storage can exhaust GPU memory:
KV Cache Size = 2 × layers × heads × seq_len × head_dim × bytes_per_element
Example: Llama 2 70B with 100K context
- 2 × 80 layers × 8 KV heads × 100,000 × 128 × 2 bytes
- = ~32 GB just for KV cache (per request!)
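The formula is worth having as a one-liner. The shape values below are the Llama 2 70B example from the text (80 layers, 8 KV heads via GQA, head_dim 128, FP16); the function name is ours.

```python
# Sketch of the KV cache size formula above.

def kv_cache_bytes(layers, kv_heads, seq_len, head_dim, bytes_per_elem=2):
    # factor of 2: one K vector and one V vector per token, per layer, per head
    return 2 * layers * kv_heads * seq_len * head_dim * bytes_per_elem

size = kv_cache_bytes(layers=80, kv_heads=8, seq_len=100_000, head_dim=128)
print(f"{size / 1e9:.1f} GB per request")  # ~32.8 GB
```

Swapping in `kv_heads=64` (no GQA) shows the 8× blowup that grouped-query attention avoids.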
Solutions for KV Cache Growth
Since FlashAttention doesn't reduce KV cache storage, other techniques are needed:
1. Reduce KV cache size per token:
| Technique | How it helps | Reduction |
|---|---|---|
| GQA/MQA | Share K/V heads across query heads | 4-8× |
| KV cache quantization | INT8/INT4 instead of FP16 | 2-4× |
| Low-rank compression | Compress K/V representations | Variable |
2. Reduce tokens stored:
| Technique | How it helps |
|---|---|
| Sliding window attention | Only cache last W tokens (Mistral) |
| Eviction policies | Drop low-attention tokens (H2O, Scissorhands) |
| Landmark attention | Keep important "landmark" tokens |
3. Memory management (same size, better utilization):
| Technique | How it helps |
|---|---|
| PagedAttention (vLLM) | Virtual memory for KV cache—no fragmentation, efficient batching |
| Chunked prefill | Interleave prefill/decode to balance memory |
4. Distribute across devices:
| Technique | How it helps |
|---|---|
| Ring Attention | Shard KV cache across GPUs, stream blocks |
| Tensor parallelism | Split heads across GPUs |
| Offloading | Spill KV cache to CPU/NVMe (slow but works) |
The practical stack for production 100K+ context:
GQA (smaller cache per token)
+ PagedAttention (efficient memory management)
+ FlashAttention (efficient attention compute)
+ KV cache quantization (optional, 2× more capacity)
KV Cache Construction During Decode
A key distinction: FlashAttention's "blocks" are a compute-time concept, not a storage concept. The KV cache is appended one token at a time, not in blocks.
How KV Cache Grows During Decode
Decode step 101:
1. New token embedding → passes through layers
2. At each attention layer:
- Compute k_new, v_new for this ONE token [1, d]
- Append to cache: K_cache[101] = k_new, V_cache[101] = v_new
- Now cache has 102 entries [0:101]
3. Compute attention over entire cache (FlashAttention streams in blocks)
4. Continue to next layer...
The Separation of Concerns
| Concern | Granularity | What handles it |
|---|---|---|
| KV cache storage/append | Token-by-token | Memory management (continuous or paged) |
| KV cache reading for attention | Block-by-block | FlashAttention |
Visualization
Decode step: generating token 257
KV Cache state in HBM:
┌─────────────────────────────────────────────────────────────┐
│ K: [k₀, k₁, k₂, ..., k₂₅₅, k₂₅₆] ← 257 vectors │
│ V: [v₀, v₁, v₂, ..., v₂₅₅, v₂₅₆] │
└─────────────────────────────────────────────────────────────┘
↑
Just appended k₂₅₆, v₂₅₆
FlashAttention reads this as blocks (B=128):
┌──────────────┬──────────────┬──────────┐
│ Block 0 │ Block 1 │ Block 2 │
│ k₀...k₁₂₇ │ k₁₂₈...k₂₅₅ │ k₂₅₆ │ (partial block)
└──────────────┴──────────────┴──────────┘
Why Not Append in Blocks?
You generate one token at a time. You can't wait for 128 tokens before appending—each output depends on all previous tokens. The block structure is purely for efficient reading, not writing.
PagedAttention Nuance
With PagedAttention (vLLM), the KV cache IS stored in fixed-size "pages" (blocks), but:
- You still append one token at a time
- Pages are just the memory allocation unit
- A page fills up gradually (token by token) until full, then a new page is allocated
This solves memory fragmentation, not the append granularity.
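The token-granularity append versus block-granularity read split can be shown with a toy model. This is purely illustrative: `kv_cache` is a plain Python list standing in for HBM storage, and the function names are ours.

```python
# Toy sketch of the separation of concerns: the cache grows one token per
# decode step, while the attention kernel reads it back in fixed blocks.
import math

kv_cache = []                         # stands in for the cache in HBM

def decode_step(k_new):
    kv_cache.append(k_new)            # append granularity: ONE token

def attention_read_blocks(B=128):
    """Yield (start, end) index ranges the kernel would stream from HBM."""
    n = len(kv_cache)
    for b in range(math.ceil(n / B)):
        yield b * B, min((b + 1) * B, n)   # last block may be partial

for t in range(257):                  # simulate 257 decode steps
    decode_step(f"k{t}")

print(list(attention_read_blocks()))  # [(0, 128), (128, 256), (256, 257)]
```

Note the final partial block `(256, 257)`, matching the visualization above: block structure exists only on the read path.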
Summary: What Tiling Achieves
The tiling strategy gives us three major benefits:
1. O(N) Memory Instead of O(N²)
By never materializing the full attention matrix, we only need storage proportional to sequence length, not its square. This is what enables 100K+ context windows.
2. Reduced HBM Traffic
By keeping intermediate computations in SRAM, we avoid the O(N²) reads and writes of the attention matrix. The traffic reduction factor grows with sequence length.
3. Better Cache Utilization
The K and V re-reads are sequential and can benefit from L2 cache. The computation pattern has good locality.
The Remaining Challenge:
The tiling strategy requires computing softmax incrementally. Each tile only sees a portion of the attention scores, but softmax needs global information (the sum over all scores). How do we handle this?
The answer is the online softmax algorithm, which we'll cover in the next artifact. It's the mathematical trick that makes the whole approach work.
Check Your Understanding
Before moving to online softmax:
1. If SRAM is 192KB and we need to store Q_block, K_block, V_block, S_block, and O_block (all in FP16) with d=128, what's the maximum block size B we can use? (Assume B_q = B_k = B)
Answer: We need: Q_block (B×128×2) + K_block (B×128×2) + V_block (B×128×2) + S_block (B×B×2) + O_block (B×128×2) + stats (~1KB)
= 4×B×128×2 + 2B² = 1024B + 2B² bytes
For 192KB = 196,608 bytes: 2B² + 1024B ≤ 196,608
Solving the quadratic: B ≈ 148. In practice, powers of 2 are used: B=128 needs ~161KB and fits, while B=256 would need ~384KB and exceed capacity.
2. For N=8192 and B=128, how many times is each K block loaded from HBM? How does this compare to standard attention's single load of K?
Answer: T_q = 8192/128 = 64 query blocks. Each K block is loaded 64 times (once per query block iteration).
Standard attention loads K once. But standard attention also writes/reads the 8192×8192 attention matrix (~134MB per head), which dwarfs the K reload cost. The trade-off favors FlashAttention.
3. Why can't we use very small blocks (e.g., B=16) to minimize re-reads of K and V? What's the tradeoff?
Answer: Smaller blocks mean more iterations (T_q = N/B increases), so K/V get reloaded more times. Also, very small blocks underutilize GPU parallelism—matrix multiplications are most efficient with larger tiles. There's also fixed overhead per block (kernel launch, memory management). The sweet spot balances SRAM capacity, compute efficiency, and reload cost.
4. In causal attention, approximately what fraction of the blocks can be skipped entirely? (Hint: think about the triangular structure)
Answer: Roughly half. The causal mask is lower-triangular, so approximately half the blocks (upper triangle) are entirely masked and can be skipped. This gives ~2× speedup for causal vs bidirectional attention.
5. If FlashAttention has O(N) memory complexity, why do we still see "out of memory" errors with very long sequences? What else consumes memory? (Hint: think about what else a transformer has besides attention)
Answer: The KV cache grows as O(N × layers × heads × d)—linear in N but multiplied by model depth. Model weights are constant but large. Activations for other layers (FFN, LayerNorm) also consume memory. For very long sequences, KV cache alone can exhaust GPU memory even without the N² attention bottleneck.
6. Why are K and V blocks reloaded multiple times while Q blocks are loaded only once?
Answer: Each output row i depends on Q_i attending to ALL keys/values: O_i = f(Q_i, K_all, V_all). Query blocks are independent—Q_0 doesn't need Q_1 to compute its output. But every query block needs every K/V block. The outer loop iterates over Q blocks, inner loop streams through K/V blocks. This asymmetry means K/V are reloaded T_q times while each Q block is loaded once.
7. When is the output written to HBM? Is each tile's output written, or only the final result?
Answer: Only the final result per query block. The accumulator O_i stays in SRAM throughout the entire inner loop, being updated as each K/V tile is processed. Only after ALL T_k key/value blocks are processed does O_i get written to HBM at rows [i×B_q : (i+1)×B_q]. Tile outputs are never written—they're accumulated in-place.
8. Is FlashAttention used in decode phase? When is it most beneficial?
Answer: FlashAttention is most beneficial during prefill where you have N queries × N keys creating the N² bottleneck. In single-token decode, attention is only [1, N]—no N² problem exists, and decode is memory-bound anyway (dominated by KV cache reads). FlashAttention is still used in decode (unified codepath) but provides minimal speedup. Batched decode benefits more—B concurrent requests means B queries, increasing arithmetic intensity.
9. What's the difference between reloading "the K matrix" vs "a K block"?
Answer: The entire K matrix [N, d] is never in SRAM at once. Each inner loop iteration loads one block K_j of size [B_k, d], uses it, then discards it before loading the next. The "reload cost" is the sum of many small [B_k, d] transfers, not repeated [N, d] transfers. This is why the computation fits in limited SRAM.
10. Why must the KV cache be fully stored in HBM, while the attention matrix doesn't need to be?
Answer: The attention matrix S = QK^T is a computed intermediate—you can recompute any tile on demand from Q and K. FlashAttention exploits this by computing tiles, using them, and discarding them. The KV cache, however, is source data—the actual K and V vectors from previous tokens. You cannot "recompute" the KV cache without re-running the entire model on the original inputs, which defeats the purpose of caching.
11. Does FlashAttention help with KV cache memory problems? What solutions exist for KV cache growth?
Answer: FlashAttention does NOT reduce KV cache storage—the full cache must still exist in HBM. It only helps with efficient reading (streaming in blocks). For KV cache growth, use: (1) Size reduction: GQA/MQA, KV quantization; (2) Token reduction: sliding window, eviction policies; (3) Memory management: PagedAttention; (4) Distribution: Ring Attention, tensor parallelism. Production systems typically combine GQA + PagedAttention + FlashAttention.
12. During decode, is the KV cache appended in blocks or token-by-token? How does FlashAttention's blocking relate to this?
Answer: The KV cache is appended token-by-token. Each decode step computes one k, v pair and appends it to the cache. FlashAttention's "blocks" are a compute-time concept for efficiently reading the cache, not a storage concept. You can't wait for 128 tokens before appending because each output depends on all previous tokens. With PagedAttention, pages are the memory allocation unit, but you still append one token at a time within each page.
13. If FlashAttention streams K/V in blocks rather than loading the entire cache, isn't it helping with KV cache memory?
Answer: It helps with bandwidth (efficient reading), not storage. The full KV cache must still exist somewhere in HBM—FlashAttention just doesn't need it all in SRAM simultaneously. Think of it like reading a book: you can read one page at a time (block streaming), but the whole book must still exist on your shelf (HBM). FlashAttention prevents needing a giant notepad (N×N attention matrix), but doesn't shrink the book itself.