2.2a: FlashAttention — The Tiling Strategy

Community Article Published February 3, 2026

The Core Insight

In Artifact 2.1, we identified the problem: standard attention writes O(N²) intermediate data to HBM because it materializes the full N×N attention matrix. We also noted that we only need the final N×d output—the N×N matrix is just an intermediate step.

FlashAttention's core insight is simple to state but profound in its implications:

What if we never materialize the N×N matrix at all?

Instead of computing the entire attention matrix and then applying softmax and then multiplying by V, we compute attention in small tiles that fit entirely in SRAM. Each tile produces a partial result, and we accumulate these partial results to get the final output.

The challenge is softmax: it requires global information (the sum over all elements). How can we compute softmax correctly if we only see a few elements at a time? The answer is the "online softmax" algorithm, which we'll cover in Artifact 2.2b. For now, let's focus on the tiling strategy itself.


Tiling: The High-Level Idea

Tiling means dividing the computation into small blocks that fit in fast memory (SRAM). Instead of:

1. Compute ALL of QK^T → write N×N to HBM
2. Compute softmax on ALL of S → write N×N to HBM
3. Compute ALL of PV → write N×d to HBM

We do:

For each block of queries (Q_block):
    For each block of keys/values (K_block, V_block):
        Load Q_block, K_block, V_block into SRAM
        Compute partial attention scores (small block, stays in SRAM)
        Update running output using partial scores (stays in SRAM/registers)
    Write final output for this Q_block to HBM

The key difference: intermediate attention scores never leave SRAM. We only write the final output to HBM.


Visualizing Standard vs Tiled Computation

Let's visualize what happens with N=16 tokens, split into blocks of 4:

Standard Attention (N=16)

Step 1: Compute FULL attention matrix S = QK^T

          K₀-K₃   K₄-K₇   K₈-K₁₁  K₁₂-K₁₅
Q₀-Q₃     ■■■■    ■■■■    ■■■■    ■■■■
Q₄-Q₇     ■■■■    ■■■■    ■■■■    ■■■■
Q₈-Q₁₁    ■■■■    ■■■■    ■■■■    ■■■■
Q₁₂-Q₁₅   ■■■■    ■■■■    ■■■■    ■■■■

(each ■■■■ cell represents a 4×4 block of scores)

ALL 256 elements computed and stored in HBM — 16×16 matrix materialized in memory


FlashAttention Tiled (N=16, block=4)

Iteration 1: Q_block_0 × K_block_0

          K₀-K₃     K₄-K₇    K₈-K₁₁   K₁₂-K₁₅
Q₀-Q₃     ■■■■      (later)  (later)  (later)
Q₄-Q₇     (later)   (later)  (later)  (later)
Q₈-Q₁₁    (later)   (later)  (later)  (later)
Q₁₂-Q₁₅   (later)   (later)  (later)  (later)

Only this 4×4 block computed, stays in SRAM!

Iteration 2: Q_block_0 × K_block_1

          K₀-K₃              K₄-K₇    K₈-K₁₁   K₁₂-K₁₅
Q₀-Q₃     (done, combined)   ■■■■     (later)  (later)
Q₄-Q₇     (later)            (later)  (later)  (later)
Q₈-Q₁₁    (later)            (later)  (later)  (later)
Q₁₂-Q₁₅   (later)            (later)  (later)  (later)

This 4×4 block computed in SRAM, combined with previous result

... continue for all K blocks, then move to next Q block ...


The crucial difference: In standard attention, the full 16×16 matrix exists in HBM. In FlashAttention, only a 4×4 block ever exists, and it stays in SRAM.


The Block Size: Fitting in SRAM

How big should the blocks be? The constraint is SRAM capacity.

What needs to fit in SRAM simultaneously:

For a single tile computation, we need:

  • Q_block: [B_q, d] — block of queries
  • K_block: [B_k, d] — block of keys
  • V_block: [B_k, d] — block of values
  • S_block: [B_q, B_k] — partial attention scores
  • O_block: [B_q, d] — partial output accumulator
  • Running statistics: [B_q] for max and sum (for online softmax)

Memory calculation:

Q_block:  B_q × d × 2 bytes
K_block:  B_k × d × 2 bytes
V_block:  B_k × d × 2 bytes
S_block:  B_q × B_k × 2 bytes
O_block:  B_q × d × 2 bytes
Stats:    B_q × 2 × 4 bytes (max and sum in FP32)

Total ≈ 2 × (B_q × d + 2 × B_k × d + B_q × B_k + B_q × d) bytes
      ≈ 2 × (2 × B_q × d + 2 × B_k × d + B_q × B_k) bytes

For typical values (d=128, SRAM=192KB per SM):

If B_q = B_k = 128:

Q_block:  128 × 128 × 2 = 32 KB
K_block:  128 × 128 × 2 = 32 KB
V_block:  128 × 128 × 2 = 32 KB
S_block:  128 × 128 × 2 = 32 KB
O_block:  128 × 128 × 2 = 32 KB
Stats:    128 × 8       ≈ 1 KB

Total ≈ 161 KB

This fits comfortably in 192KB SRAM! Block sizes of 64-256 are typical.
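The budget above is easy to script. Here is a small helper (the function name is ours) that reproduces the arithmetic for any block size:

```python
def sram_bytes(B_q: int, B_k: int, d: int, elem: int = 2) -> int:
    """Per-tile SRAM footprint in bytes (FP16 tiles, FP32 statistics)."""
    q = B_q * d * elem        # Q_block
    k = B_k * d * elem        # K_block
    v = B_k * d * elem        # V_block
    s = B_q * B_k * elem      # S_block (partial scores)
    o = B_q * d * elem        # O_block (output accumulator)
    stats = B_q * 2 * 4       # running max and sum, FP32
    return q + k + v + s + o + stats

print(sram_bytes(128, 128, 128) / 1024)  # → 161.0 KB, under the 192 KB budget
```

Trying B_q = B_k = 256 with the same helper gives ≈386 KB, which is why blocks that large need either a smaller d or more SRAM.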


The FlashAttention Algorithm: Step by Step

Let's trace through the complete algorithm. We have N queries, N keys, N values, and we want to compute the output O.

# FLASHATTENTION ALGORITHM

# INPUT:  Q, K, V in HBM, each [N, d]
# OUTPUT: O in HBM, [N, d]

# PARAMETERS:
# - B_q: query block size (e.g., 128)
# - B_k: key block size (e.g., 128)

# Divide Q, K, V into blocks
T_q = ceil(N / B_q)    # number of query blocks
T_k = ceil(N / B_k)    # number of key blocks

FOR i = 0 TO T_q - 1:    # Loop over query blocks
    
    # Load query block from HBM to SRAM
    Q_i = Q[i*B_q : (i+1)*B_q, :]        # [B_q, d]
    
    # Initialize output accumulator and statistics
    O_i = zeros(B_q, d)                  # in SRAM
    l_i = zeros(B_q)                     # running sum of exp
    m_i = -infinity(B_q)                 # running max
    
    FOR j = 0 TO T_k - 1:    # Loop over key/value blocks
        
        # Load key and value blocks from HBM to SRAM
        K_j = K[j*B_k : (j+1)*B_k, :]    # [B_k, d]
        V_j = V[j*B_k : (j+1)*B_k, :]    # [B_k, d]
        
        # Compute attention scores for this block (IN SRAM)
        S_ij = Q_i @ K_j.T / sqrt(d)     # [B_q, B_k]
        
        # === ONLINE SOFTMAX UPDATE (covered in 2.2b) ===
        # Update running max
        m_new = max(m_i, rowmax(S_ij))
        
        # Rescale previous output and sum
        scale = exp(m_i - m_new)
        O_i = O_i * scale
        l_i = l_i * scale
        
        # Compute local softmax contribution
        P_ij = exp(S_ij - m_new)         # [B_q, B_k]
        
        # Update running sum
        l_i = l_i + rowsum(P_ij)
        
        # Accumulate weighted values
        O_i = O_i + P_ij @ V_j           # [B_q, d]
        
        # Update max for next iteration
        m_i = m_new
        
    END FOR (j)
    
    # Final normalization
    O_i = O_i / l_i
    
    # Write output block to HBM
    O[i*B_q : (i+1)*B_q, :] = O_i
    
END FOR (i)

The magic: The attention score matrix S_ij is only [B_q, B_k], not [N, N]. It's computed, used, and discarded—all within SRAM. The full N×N matrix never exists anywhere.
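The pseudocode above translates almost line-for-line into NumPy. The following is an illustrative sketch (Python loops stand in for on-chip tiles; a real kernel keeps each tile in SRAM), checked against naive attention:

```python
import numpy as np

def flash_attention(Q, K, V, B_q=128, B_k=128):
    """Tiled attention: the full N x N score matrix never exists."""
    N, d = Q.shape
    O = np.zeros((N, d))
    for i0 in range(0, N, B_q):
        Q_i = Q[i0:i0 + B_q]                      # load query block
        O_i = np.zeros((Q_i.shape[0], d))         # output accumulator
        l_i = np.zeros(Q_i.shape[0])              # running sum of exp
        m_i = np.full(Q_i.shape[0], -np.inf)      # running max
        for j0 in range(0, N, B_k):
            K_j, V_j = K[j0:j0 + B_k], V[j0:j0 + B_k]
            S_ij = Q_i @ K_j.T / np.sqrt(d)       # [B_q, B_k] tile only
            m_new = np.maximum(m_i, S_ij.max(axis=1))
            scale = np.exp(m_i - m_new)           # rescale previous state
            O_i *= scale[:, None]
            l_i *= scale
            P_ij = np.exp(S_ij - m_new[:, None])  # local softmax numerator
            l_i += P_ij.sum(axis=1)
            O_i += P_ij @ V_j                     # accumulate weighted values
            m_i = m_new
        O[i0:i0 + B_q] = O_i / l_i[:, None]       # final normalization
    return O

# Sanity check against the naive reference implementation
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((256, 64)) for _ in range(3))
S = Q @ K.T / np.sqrt(64)
P = np.exp(S - S.max(axis=1, keepdims=True))
ref = (P / P.sum(axis=1, keepdims=True)) @ V
assert np.allclose(flash_attention(Q, K, V, B_q=64, B_k=64), ref)
```

The rescaling lines are the online-softmax update previewed above; Artifact 2.2b explains why they give exactly the same result as the global softmax.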


Memory Access Pattern: Standard vs FlashAttention

Let's trace exactly what gets read from and written to HBM:

Memory Access Comparison (N=4096, B=128, d=128)

Standard Attention:

Operation             Reads                 Writes
Q matrix              1 MB (once)           —
K matrix              1 MB (once)           —
S matrix (scores)     32 MB (for softmax)   32 MB
P matrix (attention)  32 MB (for matmul)    32 MB
V matrix              1 MB (once)           —
O matrix (output)     —                     1 MB
Total                 67 MB                 65 MB

TOTAL HBM TRAFFIC: 132 MB

FlashAttention:

Number of blocks: T_q = T_k = 4096/128 = 32

Operation          Reads                    Writes
Q blocks           1 MB (each read once)    —
K blocks           32 MB (read T_q times)   —
V blocks           32 MB (read T_q times)   —
O blocks (output)  —                        1 MB
Total              65 MB                    1 MB

TOTAL HBM TRAFFIC: 66 MB

Key observations:

  • NO N×N matrices ever written/read
  • K, V re-reads are sequential and cache-friendly

Scaling Comparison:

N        Standard Attention   FlashAttention   Reduction
4,096    132 MB               66 MB            2.0×
8,192    520 MB               260 MB           2.0×
16,384   2,064 MB             1,032 MB         2.0×

Why FlashAttention wins: Standard attention's traffic is dominated by the O(N²) intermediate matrices (~8N² bytes: two FP16 N×N matrices, each written and then read back). FlashAttention replaces that entirely with K/V re-reads costing ~4N²d/B_q bytes—a reduction factor of 2B_q/d, which is 2× for B_q = d = 128. Larger blocks (i.e., more SRAM) shrink FlashAttention's traffic further; the general bound is O(N²d²/M) HBM accesses for SRAM size M. And because both the avoided traffic and the avoided O(N²) storage grow quadratically, the absolute savings grow rapidly with N.
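The accounting behind these tables can be reproduced with a small model (the function name is ours; "MB" here means 2²⁰ bytes, matching the figures above):

```python
def hbm_traffic_mb(N, d=128, B=128, elem=2):
    """Model HBM traffic per head, following the per-operation tables above.
    Standard: Q, K, V read once; S and P each written and read; O written.
    Flash: Q read once; K and V re-read once per query block; O written."""
    mat = N * d * elem / 2**20        # one [N, d] matrix
    score = N * N * elem / 2**20      # one [N, N] matrix
    standard = 4 * mat + 4 * score    # Q,K,V reads + O write; S,P write+read
    T_q = N // B                      # number of query blocks
    flash = 2 * mat + 2 * mat * T_q   # Q read + O write; K,V re-reads
    return standard, flash

print(hbm_traffic_mb(4096))  # → (132.0, 66.0)
```

This is a simplified model (single head, no L2 cache effects), but it reproduces the N=4096 totals exactly.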


The Iteration Pattern Visualized

Let's see how the algorithm progresses through the computation (N=8, B=2):

Full Attention Matrix (Never Materialized)

        K₀K₁   K₂K₃   K₄K₅   K₆K₇    Outputs
Q₀Q₁      1      2      3      4     → Rows 0-1
Q₂Q₃      5      6      7      8     → Rows 2-3
Q₄Q₅      9     10     11     12     → Rows 4-5
Q₆Q₇     13     14     15     16     → Rows 6-7

Numbers show execution order

Execution Order

OUTER LOOP i=0 (Q₀Q₁):

  • j=0: Load Q₀₁, K₀₁, V₀₁ → compute block 1 → update O₀₁
  • j=1: Load K₂₃, V₂₃ → compute block 2 → update O₀₁
  • j=2: Load K₄₅, V₄₅ → compute block 3 → update O₀₁
  • j=3: Load K₆₇, V₆₇ → compute block 4 → update O₀₁
  • → Write final O₀₁ to HBM

OUTER LOOP i=1 (Q₂Q₃):

  • j=0: Load Q₂₃, K₀₁, V₀₁ → compute block 5 → update O₂₃
  • j=1: Load K₂₃, V₂₃ → compute block 6 → update O₂₃
  • j=2: Load K₄₅, V₄₅ → compute block 7 → update O₂₃
  • j=3: Load K₆₇, V₆₇ → compute block 8 → update O₂₃
  • → Write final O₂₃ to HBM

... and so on for i=2 (Q₄Q₅) and i=3 (Q₆Q₇) ...

Key Observations

  • K and V blocks are loaded multiple times (once per query block)
  • But each load is only B×d elements, not N×d
  • No block larger than B×B (here 2×2) is ever in memory

Why Doesn't Re-reading K and V Hurt Performance?

You might worry: "We're reading K and V multiple times! Doesn't that waste bandwidth?"

The answer involves understanding what we're trading off:

Standard attention:

  • Reads K once: O(Nd)
  • Reads V once: O(Nd)
  • But writes/reads S: O(N²)
  • And writes/reads P: O(N²)
  • Total: O(N²)

FlashAttention:

  • Reads K multiple times: O(Nd × T_q) = O(Nd × N/B_q) = O(N²d/B_q)
  • Reads V multiple times: O(N²d/B_q)
  • No S or P traffic!
  • Total: O(N²d/B_q) = O(N²d/M) where M ≈ B_q

For typical values (d=128, B_q=128, so M ≈ d):

  • FlashAttention traffic: O(N²d/d) = O(N²)

Wait, that's still O(N²)! But the constant factor is very different:

  • Standard: ~8N² bytes (two N×N FP16 matrices, each written and then read back)
  • FlashAttention: ~4Nd bytes per pass × (N/B_q) passes = 4N²d/B_q bytes

For N=4096, d=128, B_q=128:

  • Standard: 4 × 4096² × 2 = 134 MB for intermediates alone
  • FlashAttention: 4 × 4096 × 128 × (4096/128) = 67 MB for K,V re-reads

Same order of magnitude, but already 2× less. And there's a further difference: the re-reads of K and V are sequential and benefit from L2 cache, while the N×N intermediate accesses are scattered and cache-unfriendly.

More importantly, as N grows, standard attention needs to store O(N²), which eventually exceeds memory capacity. FlashAttention never needs more than O(Bd) storage.


The Storage Advantage: O(N) vs O(N²)

Even if bandwidth were infinite, standard attention has a fundamental memory capacity problem:

Memory Capacity Comparison

Standard Attention - Storage Required:

Component             Size                  Complexity   Notes
Q, K, V               3 × N × d × 2 bytes   O(Nd)
S (scores matrix)     N × N × 2 bytes       O(N²)        ← Problem!
P (attention matrix)  N × N × 2 bytes       O(N²)        ← Problem!
O (output)            N × d × 2 bytes       O(Nd)
TOTAL                                       O(N²)        Dominated by intermediates

Examples:

  • N = 32K: S alone = 32K × 32K × 2 = 2 GB per head!
  • N = 128K: S alone = 128K × 128K × 2 = 32 GB per head!

FlashAttention - Storage Required:

Component      Size                  Complexity   Notes
Q, K, V        3 × N × d × 2 bytes   O(Nd)
O (output)     N × d × 2 bytes       O(Nd)
Tile buffers   O(Bd)                 O(1)         Constant relative to N
Statistics     O(N)                  O(N)         Max and sum per row
TOTAL                                O(N)         Since d is constant

Examples:

  • N = 32K: Q,K,V,O = 4 × 32K × 128 × 2 = 32 MB
  • N = 128K: Q,K,V,O = 4 × 128K × 128 × 2 = 128 MB

This is why 128K+ context windows are possible!
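A quick calculator (names ours) makes the gap concrete:

```python
def storage_gib(N: int, d: int = 128, elem: int = 2):
    """Per-head storage in GiB: the O(N²) score matrix vs the O(N) tensors."""
    s_matrix = N * N * elem / 2**30       # standard attention: S alone
    qkvo = 4 * N * d * elem / 2**30       # Q, K, V, O together
    return s_matrix, qkvo

print(storage_gib(32 * 1024))  # → (2.0, 0.03125): 2 GiB for S vs 32 MiB for Q,K,V,O
```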


This O(N) vs O(N²) memory difference is arguably more important than the bandwidth savings. It's what enables modern LLMs to have 100K+ context windows.


Handling Causal Masking

In autoregressive models, we use causal (triangular) masking: token i can only attend to tokens 0, 1, ..., i. This means the upper triangle of the attention matrix should be zero (or negative infinity before softmax).

FlashAttention handles this elegantly:

Full Causal Attention Matrix

    K₀ K₁ K₂ K₃ K₄ K₅ K₆ K₇
Q₀  ■  ✗  ✗  ✗  ✗  ✗  ✗  ✗
Q₁  ■  ■  ✗  ✗  ✗  ✗  ✗  ✗
Q₂  ■  ■  ■  ✗  ✗  ✗  ✗  ✗
Q₃  ■  ■  ■  ■  ✗  ✗  ✗  ✗
Q₄  ■  ■  ■  ■  ■  ✗  ✗  ✗
Q₅  ■  ■  ■  ■  ■  ■  ✗  ✗
Q₆  ■  ■  ■  ■  ■  ■  ■  ✗
Q₇  ■  ■  ■  ■  ■  ■  ■  ■

■ = computed | ✗ = masked (skipped)

With Tiling (B=2)

        K₀K₁   K₂K₃   K₄K₅   K₆K₇
Q₀Q₁     ▣     skip   skip   skip
Q₂Q₃     ■      ▣     skip   skip
Q₄Q₅     ■      ■      ▣     skip
Q₆Q₇     ■      ■      ■      ▣

Legend:

  • ▣ = partial block (apply mask within)
  • ■ = full block (no mask needed)
  • skip = entirely masked, skip computation

Optimization: Skip Entirely Masked Blocks

  • Blocks where min(key_indices) > max(query_indices) → skip
  • Reduces computation by ~50% for causal attention

For causal masking, FlashAttention can skip entire blocks that would be completely masked. This gives nearly 2× speedup for causal attention compared to bidirectional attention.
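The block classification reduces to a tiny predicate. A sketch (the function name is ours) that reproduces the B=2, N=8 picture above:

```python
def block_kind(i: int, j: int, B: int) -> str:
    """Classify tile (i, j) of a causal mask with block size B.
    Queries span rows [i*B, (i+1)*B), keys span cols [j*B, (j+1)*B)."""
    q_min, q_max = i * B, (i + 1) * B - 1
    k_min, k_max = j * B, (j + 1) * B - 1
    if k_min > q_max:
        return "skip"     # every key is in the future: never computed
    if k_max <= q_min:
        return "full"     # every key is in the past: no mask needed
    return "partial"      # diagonal tile: apply the mask element-wise

grid = [[block_kind(i, j, 2) for j in range(4)] for i in range(4)]
print(grid[0])  # → ['partial', 'skip', 'skip', 'skip']
```

Counting "skip" entries over the grid shows why causal attention does roughly half the work of bidirectional attention.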


FlashAttention and the KV Cache: What It Does and Doesn't Solve

A common misconception is that FlashAttention solves all memory problems. It doesn't. Understanding what FlashAttention addresses—and what it leaves unsolved—is crucial.

The Distinction: Input Data vs. Computed Intermediates

Data                     Type                    Can avoid storing in HBM?   Why?
KV cache                 Input/source data       ❌ No                        Must persist—it's the actual K, V vectors from previous tokens
Attention matrix (S, P)  Computed intermediate   ✅ Yes                       Derived from Q, K—can recompute any tile on demand

Why the attention matrix doesn't need to fully exist:

The attention matrix S = QK^T is derived from Q and K. At any moment, you can recompute any tile: S_ij = Q_i @ K_j^T. FlashAttention exploits this—instead of computing all of S, storing it, then using it, you compute one tile, use it immediately, discard it, and move on.

Why the KV cache MUST exist:

The KV cache contains the actual K and V vectors computed during prefill or accumulated during decode. You cannot "recompute K_j on demand" without re-running the entire model on the original input tokens—defeating the purpose of caching.

Analogy:

  • KV cache = A book you're referencing (must exist on your shelf)
  • Attention matrix = Notes you derive from the book (can re-derive anytime by looking at the book again)

FlashAttention says: "Don't write out all your notes at once—just look at the book page by page and write your final summary directly." But the book itself must still exist.

What FlashAttention Helps With

Problem                               FlashAttention helps?   Explanation
O(N²) attention matrix storage        ✅ Yes                   Never materializes the full matrix
O(N²) attention matrix bandwidth      ✅ Yes                   Tiles stay in SRAM
KV cache storage in HBM               ❌ No                    Full cache must still exist in HBM
KV cache bandwidth during attention   ✅ Partially             Streams blocks efficiently; L2 cache helps

When KV Cache Becomes the Bottleneck

For very long sequences, even with FlashAttention, KV cache storage can exhaust GPU memory:

KV Cache Size = 2 × layers × heads × seq_len × head_dim × bytes_per_element

Example: Llama 2 70B with 100K context
- 2 × 80 layers × 8 KV heads × 100,000 × 128 × 2 bytes
- = ~32 GB just for KV cache (per request!)
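The formula is directly scriptable (the function name is ours; the config values follow the example above):

```python
def kv_cache_gb(layers: int, kv_heads: int, seq_len: int,
                head_dim: int, elem: int = 2) -> float:
    """KV cache size in GB: K and V, per layer, per KV head, per token."""
    return 2 * layers * kv_heads * seq_len * head_dim * elem / 1e9

# Llama 2 70B-style config: 80 layers, 8 KV heads (GQA), head_dim 128
print(kv_cache_gb(80, 8, 100_000, 128))  # → 32.768 GB per request
```

Note that without GQA (64 KV heads instead of 8) the same context would need 8× more—over 260 GB—which is why the techniques below matter.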

Solutions for KV Cache Growth

Since FlashAttention doesn't reduce KV cache storage, other techniques are needed:

1. Reduce KV cache size per token:

Technique               How it helps                         Reduction
GQA/MQA                 Share K/V heads across query heads   4-8×
KV cache quantization   INT8/INT4 instead of FP16            2-4×
Low-rank compression    Compress K/V representations         Variable

2. Reduce tokens stored:

Technique                  How it helps
Sliding window attention   Only cache last W tokens (Mistral)
Eviction policies          Drop low-attention tokens (H2O, Scissorhands)
Landmark attention         Keep important "landmark" tokens

3. Memory management (same size, better utilization):

Technique               How it helps
PagedAttention (vLLM)   Virtual memory for the KV cache—no fragmentation, efficient batching
Chunked prefill         Interleave prefill/decode to balance memory

4. Distribute across devices:

Technique            How it helps
Ring Attention       Shard KV cache across GPUs, stream blocks
Tensor parallelism   Split heads across GPUs
Offloading           Spill KV cache to CPU/NVMe (slow but works)

The practical stack for production 100K+ context:

GQA (smaller cache per token)
  + PagedAttention (efficient memory management)
  + FlashAttention (efficient attention compute)
  + KV cache quantization (optional, 2× more capacity)

KV Cache Construction During Decode

A key distinction: FlashAttention's "blocks" are a compute-time concept, not a storage concept. The KV cache is appended one token at a time, not in blocks.

How KV Cache Grows During Decode

Decode step 101:
  1. New token embedding → passes through layers
  2. At each attention layer:
     - Compute k_new, v_new for this ONE token  [1, d]
     - Append to cache: K_cache[101] = k_new, V_cache[101] = v_new
     - Now cache has 102 entries [0:101]
  3. Compute attention over entire cache (FlashAttention streams in blocks)
  4. Continue to next layer...
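The append step can be sketched in a few lines of NumPy (concatenation is illustrative—real servers preallocate or use paged storage; the function name is ours):

```python
import numpy as np

def append_kv(K_cache, V_cache, k_new, v_new):
    """One decode step at one attention layer: append a single token's K/V.
    FlashAttention then reads the grown cache block-by-block."""
    K_cache = np.concatenate([K_cache, k_new[None, :]])  # [t+1, d]
    V_cache = np.concatenate([V_cache, v_new[None, :]])
    return K_cache, V_cache

d = 64
K, V = np.zeros((101, d)), np.zeros((101, d))   # cache after 101 tokens
K, V = append_kv(K, V, np.ones(d), np.ones(d))
print(K.shape)  # → (102, 64): grown by exactly one token
```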

The Separation of Concerns

Concern                          Granularity      What handles it
KV cache storage/append          Token-by-token   Memory management (contiguous or paged)
KV cache reading for attention   Block-by-block   FlashAttention

Visualization

Decode step: generating token 257

KV Cache state in HBM:
┌─────────────────────────────────────────────────────────────┐
│ K: [k₀, k₁, k₂, ..., k₂₅₅, k₂₅₆]  ← 257 vectors            │
│ V: [v₀, v₁, v₂, ..., v₂₅₅, v₂₅₆]                           │
└─────────────────────────────────────────────────────────────┘
                                      ↑
                              Just appended k₂₅₆, v₂₅₆

FlashAttention reads this as blocks (B=128):
┌──────────────┬──────────────┬──────────┐
│ Block 0      │ Block 1      │ Block 2  │
│ k₀...k₁₂₇   │ k₁₂₈...k₂₅₅ │ k₂₅₆     │ (partial block)
└──────────────┴──────────────┴──────────┘

Why Not Append in Blocks?

You generate one token at a time. You can't wait for 128 tokens before appending—each output depends on all previous tokens. The block structure is purely for efficient reading, not writing.

PagedAttention Nuance

With PagedAttention (vLLM), the KV cache IS stored in fixed-size "pages" (blocks), but:

  • You still append one token at a time
  • Pages are just the memory allocation unit
  • A page fills up gradually (token by token) until full, then a new page is allocated

This solves memory fragmentation, not the append granularity.


Summary: What Tiling Achieves

The tiling strategy gives us three major benefits:

1. O(N) Memory Instead of O(N²)

By never materializing the full attention matrix, we only need storage proportional to sequence length, not its square. This is what enables 100K+ context windows.

2. Reduced HBM Traffic

By keeping intermediate computations in SRAM, we avoid the O(N²) reads and writes of the attention matrix. The traffic reduction factor grows with sequence length.

3. Better Cache Utilization

The K and V re-reads are sequential and can benefit from L2 cache. The computation pattern has good locality.

The Remaining Challenge:

The tiling strategy requires computing softmax incrementally. Each tile only sees a portion of the attention scores, but softmax needs global information (the sum over all scores). How do we handle this?

The answer is the online softmax algorithm, which we'll cover in the next artifact. It's the mathematical trick that makes the whole approach work.


Check Your Understanding

Before moving to online softmax:

1. If SRAM is 192KB and we need to store Q_block, K_block, V_block, S_block, and O_block (all in FP16) with d=128, what's the maximum block size B we can use? (Assume B_q = B_k = B)

Answer: We need: Q_block (B×128×2) + K_block (B×128×2) + V_block (B×128×2) + S_block (B×B×2) + O_block (B×128×2) + stats (~1KB)

= 4×B×128×2 + 2B² = 1024B + 2B² bytes

For 192KB = 196,608 bytes: 2B² + 1024B ≤ 196,608

Solving the quadratic: B ≈ 148. In practice, powers of 2 are used: B=128 fits (≈161 KB), while B=256 (≈384 KB) would far exceed capacity.

2. For N=8192 and B=128, how many times is each K block loaded from HBM? How does this compare to standard attention's single load of K?

Answer: T_q = 8192/128 = 64 query blocks. Each K block is loaded 64 times (once per query block iteration).

Standard attention loads K once. But standard attention also writes/reads the 8192×8192 attention matrix (~134MB per head), which dwarfs the K reload cost. The trade-off favors FlashAttention.

3. Why can't we use very small blocks (e.g., B=16) to minimize re-reads of K and V? What's the tradeoff?

Answer: Smaller blocks mean more iterations (T_q = N/B increases), so K/V get reloaded more times. Also, very small blocks underutilize GPU parallelism—matrix multiplications are most efficient with larger tiles. There's also fixed overhead per block (kernel launch, memory management). The sweet spot balances SRAM capacity, compute efficiency, and reload cost.

4. In causal attention, approximately what fraction of the blocks can be skipped entirely? (Hint: think about the triangular structure)

Answer: Roughly half. The causal mask is lower-triangular, so approximately half the blocks (upper triangle) are entirely masked and can be skipped. This gives ~2× speedup for causal vs bidirectional attention.

5. If FlashAttention has O(N) memory complexity, why do we still see "out of memory" errors with very long sequences? What else consumes memory? (Hint: think about what else a transformer has besides attention)

Answer: The KV cache grows as O(N × layers × heads × d)—linear in N but multiplied by model depth. Model weights are constant but large. Activations for other layers (FFN, LayerNorm) also consume memory. For very long sequences, KV cache alone can exhaust GPU memory even without the N² attention bottleneck.

6. Why are K and V blocks reloaded multiple times while Q blocks are loaded only once?

Answer: Each output row i depends on Q_i attending to ALL keys/values: O_i = f(Q_i, K_all, V_all). Query blocks are independent—Q_0 doesn't need Q_1 to compute its output. But every query block needs every K/V block. The outer loop iterates over Q blocks, inner loop streams through K/V blocks. This asymmetry means K/V are reloaded T_q times while each Q block is loaded once.

7. When is the output written to HBM? Is each tile's output written, or only the final result?

Answer: Only the final result per query block. The accumulator O_i stays in SRAM throughout the entire inner loop, being updated as each K/V tile is processed. Only after ALL T_k key/value blocks are processed does O_i get written to HBM at rows [i×B_q : (i+1)×B_q]. Tile outputs are never written—they're accumulated in-place.

8. Is FlashAttention used in decode phase? When is it most beneficial?

Answer: FlashAttention is most beneficial during prefill where you have N queries × N keys creating the N² bottleneck. In single-token decode, attention is only [1, N]—no N² problem exists, and decode is memory-bound anyway (dominated by KV cache reads). FlashAttention is still used in decode (unified codepath) but provides minimal speedup. Batched decode benefits more—B concurrent requests means B queries, increasing arithmetic intensity.

9. What's the difference between reloading "the K matrix" vs "a K block"?

Answer: The entire K matrix [N, d] is never in SRAM at once. Each inner loop iteration loads one block K_j of size [B_k, d], uses it, then discards it before loading the next. The "reload cost" is the sum of many small [B_k, d] transfers, not repeated [N, d] transfers. This is why the computation fits in limited SRAM.

10. Why must the KV cache be fully stored in HBM, while the attention matrix doesn't need to be?

Answer: The attention matrix S = QK^T is a computed intermediate—you can recompute any tile on demand from Q and K. FlashAttention exploits this by computing tiles, using them, and discarding them. The KV cache, however, is source data—the actual K and V vectors from previous tokens. You cannot "recompute" the KV cache without re-running the entire model on the original inputs, which defeats the purpose of caching.

11. Does FlashAttention help with KV cache memory problems? What solutions exist for KV cache growth?

Answer: FlashAttention does NOT reduce KV cache storage—the full cache must still exist in HBM. It only helps with efficient reading (streaming in blocks). For KV cache growth, use: (1) Size reduction: GQA/MQA, KV quantization; (2) Token reduction: sliding window, eviction policies; (3) Memory management: PagedAttention; (4) Distribution: Ring Attention, tensor parallelism. Production systems typically combine GQA + PagedAttention + FlashAttention.

12. During decode, is the KV cache appended in blocks or token-by-token? How does FlashAttention's blocking relate to this?

Answer: The KV cache is appended token-by-token. Each decode step computes one k, v pair and appends it to the cache. FlashAttention's "blocks" are a compute-time concept for efficiently reading the cache, not a storage concept. You can't wait for 128 tokens before appending because each output depends on all previous tokens. With PagedAttention, pages are the memory allocation unit, but you still append one token at a time within each page.

13. If FlashAttention streams K/V in blocks rather than loading the entire cache, isn't it helping with KV cache memory?

Answer: It helps with bandwidth (efficient reading), not storage. The full KV cache must still exist somewhere in HBM—FlashAttention just doesn't need it all in SRAM simultaneously. Think of it like reading a book: you can read one page at a time (block streaming), but the whole book must still exist on your shelf (HBM). FlashAttention prevents needing a giant notepad (N×N attention matrix), but doesn't shrink the book itself.
