CoDA-GQA-L: Bounded-Memory Differential Attention with Value-Routed Landmark Banks

Community Article Published February 16, 2026

TL;DR

CoDA-GQA-L replaces the O(L) KV cache in transformer attention with a fixed-size three-segment buffer -- a recent window, an exact landmark bank, and an EMA summary bank -- bounding per-layer memory to O(W + Me + Ms) regardless of sequence length. At 70B scale with 128K context, this reduces KV cache from 160 GB to 136 MB (1,176x). The mechanism uses differential attention via orthogonal query rotation (no second W_q), value-routed semantic matching to sidestep RoPE position dependence, and a two-phase training protocol to bridge the gap between unbounded training and bounded inference.

fig_graphical_abstract

Author's Note: I'm looking for a role where I can work on things like this full time. I want to work in AI on things that matter. I don't have a fancy pedigree from Stanford or FAANG; my public work is my best job application, since traditional routes have proven unproductive. If you are doing things no one else has done with AI, reach out: anthony@making-minds.ai and https://linkedin.com/in/anthony-maio


## 1. The problem: the KV cache memory wall

Every transformer layer caches K and V vectors for every token it has ever seen. The arithmetic is unforgiving.

For a 70B-class model (8 KV heads, 128 head dim, 80 layers) serving 128K context in bf16:

```
2 (K+V) * 8 (KV heads) * 128,000 (tokens) * 128 (head dim) * 2 (bytes/bf16) * 80 (layers)
= 41,943,040,000 bytes
~ 39 GB
```

That is roughly 40 GB of KV cache for a single sequence, and a batch of just four sequences pushes the cache alone past 160 GB -- more than the ~140 GB the bf16 model weights occupy. You cannot serve long-context 70B models on commodity hardware without addressing this.

Existing solutions each sacrifice something important:

| Approach | What it keeps | What it loses |
|---|---|---|
| Sliding window (Mistral) | Last W tokens exactly | Everything beyond the window; hard boundary |
| StreamingLLM (Xiao et al., 2023) | Attention sinks + recent tokens | No selective retention of important mid-context tokens |
| H2O (Zhang et al., 2023) | Heavy-hitter tokens by attention score | Greedy eviction; no semantic clustering, no learned gating |
| Scissorhands (Liu et al., 2023) | Pivotal tokens | No learned write gate, no compressed summaries |
| InfLLM (2024) | Everything (offloaded) | Still O(L) total storage; block-level retrieval latency |
CoDA-GQA-L takes a different approach: keep a small, fixed-size cache and make every slot count through learned gating, semantic routing, and dual-bank compression.


## 2. Architecture overview

The KV buffer for each layer has shape (B, H_kv, W + Me + Ms, D_h) and is divided into three contiguous segments:

Buffer layout:

```
  slot index:  [0 ........... W-1]  [W ...... W+Me-1]  [W+Me ... W+Me+Ms-1]
                  Recent Window        Exact Bank         Summary Bank
                  (ring buffer)     (novelty-filtered    (EMA prototypes)
                  exact FIFO           LRU cache)

  Total slots: W + Me + Ms  (fixed, independent of sequence length L)
```

![fig_bounded_memory](https://cdn-uploads.huggingface.co/production/uploads/64ac64386e9a4384cf6fa5b2/f-oE2LvO4UBt6UznycEO5.jpeg)

**Recent window (W slots)**: A ring buffer holding the exact K/V of the most recent W tokens. Pure FIFO -- the oldest token gets evicted when a new one arrives and the window is full.

**Exact landmark bank (Me slots)**: A novelty-filtered LRU cache. When a token is evicted from the recent window, it is compared to existing bank entries using value-space cosine similarity. If it is semantically novel (below a novelty threshold), it is inserted at full fidelity. If the bank is full, it replaces the least-recently-used entry. This bank is designed for needle retention -- preserving rare, important tokens that appear once and must not be lost.

**Summary landmark bank (Ms slots)**: An EMA prototype bank. Evicted tokens are routed to their best-matching slot and blended via exponential moving average. Each slot becomes a compressed centroid representing a cluster of semantically similar past tokens. This bank provides background context compression.

Both banks are gated: a learned write gate `g = sigmoid(W_write * x + b)` decides whether an evicted token is important enough to update long-term memory at all.

**Memory bound**: Per layer, the cache size is `2 * H_kv * (W + Me + Ms) * D_h * sizeof(dtype)`, completely independent of sequence length.
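As a sanity check, the bound can be computed directly; a minimal sketch (`kv_cache_bytes` is a hypothetical helper, not part of the package):

```python
def kv_cache_bytes(h_kv, window, m_exact, m_summary, head_dim,
                   bytes_per_elem=2, num_layers=1):
    """KV cache size for the bounded buffer: 2x for K and V, and a slot
    count of W + Me + Ms that never grows with sequence length."""
    slots = window + m_exact + m_summary
    return 2 * h_kv * slots * head_dim * bytes_per_elem * num_layers

# Illustrative config: 8 KV heads, W=256, Me=64, Ms=64, D_h=128, bf16
print(kv_cache_bytes(8, 256, 64, 64, 128))   # 1572864 bytes = 1.5 MiB per layer
```

Note that sequence length never appears as an argument; the same call with `num_layers=80` gives a full-model figure.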

---

## 3. Deep dive: constrained orthogonal differential attention (CoDA)

### The idea

Standard differential attention (Ye et al., 2024) sharpens focus by computing two attention distributions and subtracting them:

```
out = Attn(q1, K, V) - lambda * Attn(q2, K, V)
```


`q1` attends to signal, `q2` attends to noise (common-mode patterns), and subtraction cancels the noise floor, sharpening the attention distribution. The catch is that two separate query projections (`W_q1`, `W_q2`) double the query parameter count.

### CoDA: one projection, orthogonal rotation

CoDA achieves the same signal/noise decomposition with a single query projection plus a per-head orthogonal rotation:

```python
# Single query projection
q = W_q @ x                              # (B, H, L, D_h)

# Signal query: apply RoPE for positional encoding
q_signal = RoPE(q, position)              # (B, H, L, D_h)

# Noise query: rotate the signal query orthogonally
q_noise = R(theta) @ q_signal             # (B, H, L, D_h)

# Per-token cancellation gate
lambda_ = sigmoid(W_lambda @ x + b)      # (B, H, L, 1)

# Differential output with normalization
out = HeadwiseRMSNorm(Attn(q_signal, K, V) - lambda_ * Attn(q_noise, K, V))
```

### How R(theta) works

R(theta) is a block-diagonal orthogonal matrix that applies independent 2D rotations to each pair of adjacent feature dimensions:

For feature pair `(x_even, x_odd)` and learned angle `theta_i`:

```
x'_even = x_even * cos(theta_i) - x_odd * sin(theta_i)
x'_odd  = x_even * sin(theta_i) + x_odd * cos(theta_i)
```

Each head has D_h / 2 learned angles, giving H * D_h / 2 total parameters. This is the same mathematical structure as RoPE, but with learned angles instead of position-dependent frequencies.

Why orthogonal? An orthogonal rotation preserves the norm of the query vector (||R(theta) * q|| = ||q||), which means the noise stream has the same energy as the signal stream. The only difference is the direction of attention -- the noise query looks at a different part of the key space, determined by the learned angles.
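Both properties are easy to verify numerically; a minimal sketch (`rotate_pairs` is an illustrative stand-in for the R(theta) described above, not the package's implementation):

```python
import torch

def rotate_pairs(x, theta):
    """Apply independent 2D rotations to interleaved feature pairs.

    x: (..., D_h) with D_h even; theta: (D_h // 2,) learned angles.
    """
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    cos, sin = torch.cos(theta), torch.sin(theta)
    out = torch.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out

q = torch.randn(4, 64)
theta = torch.full((32,), torch.pi / 2)   # the init described below: 90-degree pairs
q_noise = rotate_pairs(q, theta)

# Norm preserved (orthogonal map) ...
assert torch.allclose(q.norm(dim=-1), q_noise.norm(dim=-1), atol=1e-4)
# ... and at theta = pi/2 every pair rotates 90 degrees, so q . q_noise ~ 0
assert torch.allclose((q * q_noise).sum(-1), torch.zeros(4), atol=1e-3)
```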

### Initialization strategy

The initialization is deliberately chosen so that CoDA starts near-transparent and gradually learns to differentiate:

  • theta = pi/2: At initialization, the noise query is exactly orthogonal to the signal query. This maximizes decorrelation between the two streams from the start.
  • lambda_bias = -6.0: sigmoid(-6) = 0.0025. The cancellation gate starts near zero, meaning the output is almost entirely the signal stream. The model begins as near-standard attention and learns to subtract noise only where it helps.

This initialization means you can cold-swap CoDA into a pretrained model and get roughly the same output (though not exactly -- the HeadwiseRMSNorm still reshapes activations, which is why fine-tuning is required; see Section 7).

### Single SDPA call via head-stacking

Rather than making two separate SDPA calls (which doubles kernel launch overhead), CoDA stacks the signal and noise queries along the head dimension:

```python
q_cat = torch.cat([q_signal, q_noise], dim=1)    # (B, 2*H, L, D_h)
k_rep = torch.cat([k_gqa, k_gqa], dim=1)         # (B, 2*H, L_k, D_h)
v_rep = torch.cat([v_gqa, v_gqa], dim=1)         # (B, 2*H, L_k, D_h)

out_cat = F.scaled_dot_product_attention(q_cat, k_rep, v_rep, ...)

# Split and subtract
out_signal = out_cat[:, :H, :, :]
out_noise  = out_cat[:, H:, :, :]
out = HeadwiseRMSNorm(out_signal - lambda_ * out_noise)
```

This is not a truly fused differential kernel (the compute is still 2x), but it reduces kernel launches from 2 to 1 and improves memory locality. The GQA head expansion (repeat_kv) is done before stacking so that each signal head and its noise counterpart attend to the same KV head.

### HeadwiseRMSNorm

After subtraction, activations can have unusual magnitudes (near-zero when lambda is small, doubled when it is large). HeadwiseRMSNorm normalizes each (batch, head, token) independently across the head dimension with a learned scale:

```python
class HeadwiseRMSNorm(nn.Module):
    def __init__(self, head_dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(head_dim))  # (D_h,) learned scale

    def forward(self, x):  # x: (B, H, L, D_h)
        var = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(var + self.eps)
        return x * self.weight
```

## 4. Deep dive: value-routing (V-routing)

### The RoPE problem

CoDA-GQA-L stores keys with RoPE already applied ("RoPE-at-write"). This is a performance optimization: during decode, you avoid re-rotating the entire KV cache at every step (an O(L_buf) operation per layer per token). The key is rotated once at write time and stored in its final form.

But RoPE makes key vectors position-dependent by design. The rotary embedding applies position-dependent 2D rotations to pairs of key dimensions:

```
k_rope[pos] = RoPE(k_raw, pos)
```

Two identical tokens at positions 100 and 5000 will have near-orthogonal key vectors because the high-frequency RoPE components rotate rapidly with position. Their cosine similarity in key space will be low, even though they encode the same content.

If you use key-space similarity to route memory banks, the exact bank fills with duplicates -- the same semantic token at different positions appears "novel" because its key looks different each time. The summary bank EMA blending produces garbage because you are averaging vectors that are rotated to incompatible phases.
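The decorrelation is easy to reproduce with a toy RoPE (a minimal sketch; `rope` here is illustrative, not the package's implementation):

```python
import torch
import torch.nn.functional as F

def rope(x, pos, base=10000.0):
    """Minimal RoPE: rotate interleaved dim pairs by pos * base^(-2i/d)."""
    d = x.shape[-1]
    inv_freq = base ** (-torch.arange(0, d, 2, dtype=torch.float32) / d)
    ang = pos * inv_freq
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * torch.cos(ang) - x2 * torch.sin(ang)
    out[..., 1::2] = x1 * torch.sin(ang) + x2 * torch.cos(ang)
    return out

torch.manual_seed(0)
k = torch.randn(256, 128)   # identical raw key content, many random samples
key_sim = F.cosine_similarity(rope(k, 100), rope(k, 5000), dim=-1).mean()

# Identical content, 4,900 positions apart: key-space similarity collapses,
# so a key-routed bank would misclassify the duplicate as "novel".
assert key_sim < 0.70
# Values carry no RoPE, so the routing signal for identical content is identical.
```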

### Why routing on values fixes this

Values have no RoPE applied. They encode pure semantic content, position-invariant. So we compute all memory bank routing similarity using values instead of keys.

```python
# V-routing for exact bank
v_normalized = F.normalize(v_evicted, dim=-1, eps=1e-6)
mem_normalized = state._exact_v_norm  # cached normalized bank values
similarity = einsum("bhd,bhmd->bhm", v_normalized, mem_normalized)
similarity = similarity.mean(dim=1)  # average across KV heads
```

This makes:

  • Exact bank deduplication correct: The same word at position 100 and position 5000 produces near-identical values, so the bank correctly recognizes it as a hit rather than a novel token.
  • Summary bank EMA coherent: Blending values that represent the same semantic concept produces a meaningful centroid. There is no destructive phase interference.

The implementation maintains normalized value caches (state._exact_v_norm) that are updated incrementally on writes, making the cosine computation efficient.

### Summary bank: LF-K routing and phase-safe EMA

The summary bank uses a more nuanced routing strategy. While exact bank routing uses values directly, the summary bank routes using the low-frequency band of keys (LF-K routing).

RoPE dimensions are ordered by frequency. The first half of the head dimension contains high-frequency (HF) pairs that rotate rapidly with position. The second half contains low-frequency (LF) pairs that are nearly position-invariant -- they change so slowly that tokens thousands of positions apart still have similar LF components.
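The frequency spread is easy to inspect directly (a self-contained sketch using the standard RoPE frequency schedule; the exact schedule in the implementation is an assumption here):

```python
import torch

d = 128   # head dim; pairs are ordered high frequency -> low frequency
inv_freq = 10000.0 ** (-torch.arange(0, d, 2, dtype=torch.float32) / d)

# Radians of rotation each pair accumulates across a 1,000-token gap
delta = 1000.0 * inv_freq
# Highest-frequency pair: ~1,000 rad (wraps ~159 times -> effectively random phase)
# Lowest-frequency pair:  ~0.12 rad  (cos ~ 0.99 -> nearly position-invariant)
assert delta[0] == 1000.0
assert delta[-1] < 0.2
```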

```python
# LF-K routing for summary bank
lf_start = head_dim // 2
k_lf_normalized = F.normalize(k_evicted[..., lf_start:], dim=-1, eps=1e-6)
mem_lf_normalized = state._sum_lf_k_norm  # cached LF key norms
similarity = matmul(k_lf_normalized, mem_lf_normalized.transpose(-1, -2))
```

Phase-Safe EMA: When blending keys via EMA, only the LF band is updated. The HF band is kept at zero. Blending HF key components from different positions would cause destructive interference -- imagine averaging sin(100 * freq) and sin(5000 * freq) -- you get noise, not a meaningful representation. By zeroing the HF band and only blending LF, summary bank keys remain coherent. An energy calibration factor sqrt(D_h / D_h_lf) compensates for the reduced dimensionality so that attention scores are properly scaled.

```python
# Phase-Safe EMA: only blend LF key band
delta_k = torch.zeros_like(mem_k)
delta_k[..., lf_start:] = incoming_lf * energy_scale - mem_k[..., lf_start:]
mem_k = mem_k + eta_eff * delta_k
# HF band stays zero -- no destructive interference

# Values: full EMA (no RoPE, no phase issue)
mem_v = mem_v + eta_eff * (v_evict - mem_v)
```

## 5. Deep dive: dense packing for FlashAttention

### The mask problem

PyTorch's scaled_dot_product_attention selects its backend kernel based on the arguments you pass. FlashAttention and the memory-efficient backend are fast but have restrictions -- notably, they do not support arbitrary boolean attention masks. When you pass a custom boolean mask (like causal_mask & allowed_mask), SDPA falls back to the Math backend, which is significantly slower.

CoDA-GQA-L needs an allowed mask because not all buffer slots are valid. Early in inference, most exact and summary bank slots are empty. Without masking, attention would attend to uninitialized memory.

### Dense packing for B=1

For the common single-sequence inference case (B=1), the implementation gathers only valid slots into a contiguous tensor before calling SDPA:

# Decode path (attend_step)
if B == 1 and not allowed.all():
    valid_idx = allowed[0].nonzero(as_tuple=True)[0]
    k_attend = state.k_buf[:, :, valid_idx, :]  # only valid slots
    v_attend = state.v_buf[:, :, valid_idx, :]
    attn_mask = None  # no mask needed -- all slots are valid
    # -> SDPA can select FlashAttention or MemEfficient backend

# Prefill path (prefill_chunked)
if B == 1 and not allowed_prev.all():
    valid_idx = allowed_prev[0].nonzero(as_tuple=True)[0]
    k_prefix = k_prev[:, :, valid_idx, :]
    v_prefix = v_prev[:, :, valid_idx, :]
    k_all = torch.cat([k_prefix, k_blk], dim=2)
    # Use explicit causal mask (no allowed mask needed)

For batched inference (B > 1), different sequences may have different valid slots, so the implementation falls back to an explicit mask:

```python
else:
    attn_mask = allowed[:, None, None, :]  # (B, 1, 1, L_buf)
```

This is a pure inference optimization. The semantics are identical -- dense packing just eliminates the boolean mask so that faster SDPA backends can be selected.
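That equivalence can be checked end to end: gathering the valid slots and dropping the mask produces the same output as boolean masking (a self-contained sketch with toy shapes):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, Lq, Lbuf, Dh = 1, 4, 1, 16, 32
q = torch.randn(B, H, Lq, Dh)
k = torch.randn(B, H, Lbuf, Dh)
v = torch.randn(B, H, Lbuf, Dh)
allowed = torch.rand(B, Lbuf) > 0.5   # some buffer slots invalid
allowed[0, 0] = True                  # guarantee at least one valid slot

# Masked path: the boolean mask forces a slower backend
out_masked = F.scaled_dot_product_attention(
    q, k, v, attn_mask=allowed[:, None, None, :])

# Dense-packed path: gather valid slots, no mask needed
idx = allowed[0].nonzero(as_tuple=True)[0]
out_packed = F.scaled_dot_product_attention(
    q, k[:, :, idx, :], v[:, :, idx, :])

assert torch.allclose(out_masked, out_packed, atol=1e-5)
```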


## 6. Memory bank update mechanics

### Write gate

Every token gets a write gate score at projection time:

```
g = sigmoid(W_write @ x + b_write)   # scalar per token
```

The gate is stored in the recent window alongside K/V. When a token is evicted, its stored gate determines whether and how strongly it updates the memory banks:

  • Exact bank: Token is considered only if g >= threshold_exact (default 0.10)
  • Summary bank: Token is considered only if g >= threshold_summary (default 0.05), and the EMA learning rate is scaled by g

Gate initialization (b_write = -2.0, W_write = 0) starts with sigmoid(-2) = 0.12, putting most tokens just above the exact threshold. The gate learns during fine-tuning which tokens are worth remembering.

### Exact bank: novelty-filtered LRU

When a token is evicted from the recent window with sufficient gate value:

  1. Compute cosine similarity of its value vector against all existing bank entries (V-routing)
  2. Average similarity across KV heads
  3. Classify:
    • Novel (best similarity < 0.70): This is new information. Insert into the first free slot, or if full, replace the LRU entry.
    • Hit (best similarity >= 0.90): This token is already represented in the bank. Update the LRU timestamp (optionally refresh the stored K/V).
    • Between thresholds: Neither novel enough to insert nor similar enough to refresh. Ignored.
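A toy version of these rules (a hedged sketch; `try_insert`, the slot count, and the timestamp handling are illustrative, not the package's internals):

```python
import torch
import torch.nn.functional as F

NOVEL_BELOW, HIT_ABOVE = 0.70, 0.90   # thresholds from the rules above

def try_insert(bank, stamps, v, t):
    """Route one evicted value vector into a novelty-filtered LRU bank."""
    sims = F.cosine_similarity(v[None, :], bank, dim=-1)   # (num_slots,)
    best = sims.max()
    if best < NOVEL_BELOW:            # novel: overwrite the LRU slot
        slot = stamps.argmin()
        bank[slot], stamps[slot] = v, t
        return "insert"
    if best >= HIT_ABOVE:             # hit: refresh recency only
        stamps[sims.argmax()] = t
        return "hit"
    return "ignore"                   # between thresholds: do nothing

torch.manual_seed(0)
bank = F.normalize(torch.randn(4, 64), dim=-1)   # 4 slots, D_h = 64
stamps = torch.arange(4, dtype=torch.float32)

assert try_insert(bank, stamps, bank[2].clone(), t=10.0) == "hit"
assert try_insert(bank, stamps, -bank[0], t=11.0) == "insert"  # cos = -1: novel
```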

### Summary bank: EMA prototypes

When a token is evicted with sufficient gate value:

  1. Compute cosine similarity of its low-frequency key band against all summary slots (LF-K routing)
  2. Route to the best-matching slot
  3. Blend via EMA:

```python
eta_eff = sigmoid(eta_logit) * gate
mem_v[slot] += eta_eff * (v_evict - mem_v[slot])
mem_k[slot, lf_band] += eta_eff * (k_evict[lf_band] * energy_scale - mem_k[slot, lf_band])
```

For the first write to an empty slot, eta = 1.0 (the slot takes on the token's value directly -- fast warmup). After that, eta is typically 0.05-0.27 (sigmoid(-3) to sigmoid(-1)), producing slow blending that builds stable prototypes.

### Vectorized block updates (prefill path)

During prefill, blocks of tokens are evicted at once. The implementation avoids Python loops entirely:

  1. Select top-T candidates by write gate score (torch.topk)
  2. Compute all pairwise similarities in one batched matmul
  3. Assign tokens to slots using argmax / cumulative novelty ranking
  4. Resolve collisions with winner-take-all per slot: when multiple tokens map to the same slot, the one with the highest gate score wins. Tie-breaking uses float32 arithmetic to survive bf16 precision loss.
  5. Apply updates via scatter_reduce and scatter_add -- one kernel call per operation, no iteration
```python
# Winner-take-all: float32 for tie-breaking precision
score_tok = gate_sel.float() + torch.arange(T) * 1e-6  # tiny tiebreak
max_score.scatter_reduce_(1, idx_tok, score_tok, reduce="amax")
winner = overwrite_tok & (score_tok == max_score.gather(1, idx_tok))
```
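The winner-take-all step can be exercised in isolation (a self-contained toy; names like `idx_tok` and `gate` mirror the snippet above but the shapes and values are illustrative):

```python
import torch

# 5 evicted tokens compete for 3 bank slots
num_slots = 3
idx_tok = torch.tensor([[0, 2, 0, 1, 2]])           # slot each token routes to
gate    = torch.tensor([[0.9, 0.1, 0.4, 0.7, 0.8]]) # write-gate scores

# Tiny index-based tiebreak keeps comparisons exact in float32
score = gate.float() + torch.arange(5) * 1e-6

max_score = torch.full((1, num_slots), float("-inf"))
max_score.scatter_reduce_(1, idx_tok, score, reduce="amax")

winner = score == max_score.gather(1, idx_tok)      # one True per contested slot
# Slot 0 is contested by tokens 0 and 2 -> token 0 (gate 0.9) wins
assert winner.tolist() == [[True, False, False, True, True]]
```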

## 7. Deep dive: two-phase training

### Why two phases are necessary

Training with unbounded attention and then evaluating with bounded memory is catastrophic. The numbers speak for themselves:

| Model | Unbounded PPL | Bounded cold-swap PPL | Gap |
|---|---|---|---|
| Mistral 7B (Phase 1 trained) | 5.62 | 2,464 | 438x worse |
| SmolLM2-135M (Phase 1 only) | 22.0 | 31+ | 41%+ |

The root cause: the memory banks have untrained parameters. The write gate, EMA learning rate, and routing decisions are all at their initialization values. The model has never experienced context loss during training, so it has no mechanism to compensate.

For Mistral 7B: training used full 2048-token context, but evaluation used W=256, meaning 87.5% of context was lost. The per-layer quality degradation compounds across 32 layers, producing nonsensical output.

### Phase 1: unbounded (learn differential attention)

Phase 1 uses CoDAGQA with standard O(L) KV cache:

  • Goal: Teach the model the signal/noise decomposition -- learn theta (rotation angles), lambda (cancellation gate), and HeadwiseRMSNorm parameters
  • Duration: ~2000 steps
  • Learning rate: 1e-3 for CoDA parameters (theta, lambda_proj, head_norm), 5e-5 for projection weights (q_proj, k_proj, v_proj, o_proj)
  • Gradient checkpointing: Enabled (memory efficient, safe with standard attention)

Typical dynamics on SmolLM2-135M: PPL drops from ~70 to ~22 at ~36K tokens/second throughput.

### Phase 2: bounded (learn memory banks)

Phase 2 switches to CoDAGQALandmarkPerf2 with fixed-size KV:

  • Goal: Train the write gate, EMA eta, and routing to compensate for bounded memory
  • Duration: ~2000 steps
  • Learning rate: 0.5x Phase 1 rates (fine-tuning, not relearning)
  • Key setting: detach_evicted=False -- gradients flow through bank updates so the write gate and EMA parameters receive gradient signal
  • Key setting: block_size=128 with window=256 -- the block size is smaller than the window, forcing evictions during every prefill block. This ensures the memory banks are exercised during training.
  • Gradient checkpointing: Disabled (required when detach_evicted=False -- checkpointing replays forward passes, which would trigger in-place mutations on tensors saved for backward, causing version conflicts)

Typical dynamics on SmolLM2-135M:

  • Bounded PPL starts at ~35.75 (immediate improvement over cold-swap)
  • Drops to ~31.12 over ~1200 steps, then plateaus
  • Throughput: ~1.8K tokens/second (20x slower than Phase 1 due to bank update overhead and gradient flow)

fig_bounded_memory

### The remaining gap

After two-phase training on SmolLM2-135M:

  • Unbounded: PPL 22.0
  • Bounded: PPL 31.12
  • Gap: 41.5%

This gap reflects the fundamental information loss from 5.3x context compression (W + Me + Ms << L). The gap is expected to narrow with larger models and longer training, but bounded attention is not free -- it is a quality/memory tradeoff.


## 8. Benchmark results

All benchmarks on H200 NVL, bf16, PyTorch 2.x.

### Memory savings

| Scale | Config | KV cache | vs baseline |
|---|---|---|---|
| 70B (8 KV heads, 128 dim, 80 layers) | Baseline @ 128K | 32.0 MB / layer | 1x |
| 70B | Medium-cache (W=512, Me=128, Ms=128) | 1.7 MB / layer | 18.8x |
| 70B, total across 80 layers @ 128K | Baseline | 160 GB | 1x |
| 70B, total across 80 layers @ 128K | Medium-cache | 136 MB | 1,176x |
| 7B (8 KV heads, 128 dim, 32 layers) | Baseline @ 128K | 64 GB total | 1x |
| 7B | Medium-cache | 54 MB total | 1,185x |

fig_compression

### Throughput (70B-scale config, tokens/second)

| Config | Prefill 2K | Prefill 8K | Decode | Peak VRAM |
|---|---|---|---|---|
| Baseline GQA | 1,336,349 | 966,464 | 4,676 | 1.1 GB |
| CoDA unbounded | 889,203 | 598,314 | 2,914 | 1.6 GB |
| Medium-cache (bounded) | 149,832 | 153,716 | 1,753 | 568.6 MB |
| Window-only (no banks) | 359,417 | 356,026 | 1,773 | 546.0 MB |

### What the numbers show

  1. Bounded prefill throughput is constant: ~150K tok/s regardless of sequence length (2K and 8K are nearly identical). Baseline drops 28% from 2K to 8K due to O(L^2) attention.

  2. Differential overhead: CoDA-unbounded runs at 62-67% of baseline throughput due to 2x SDPA cost. This is the price of differential attention without a fused Triton kernel.

  3. Bank update overhead dominates bounded cost: Medium-cache (with banks) runs at 11-16% of baseline, while window-only (no banks) runs at 27-37%. The memory bank routing, similarity computation, and scatter updates consume significant time.

  4. VRAM savings are real: Bounded configs use 1.9x less peak VRAM than baseline.

  5. Decode throughput: Bounded decode (1,753 tok/s) is 37% of baseline (4,676 tok/s). In a production setting, the memory savings enable serving more concurrent sequences, which can offset the per-sequence throughput loss.

### Weight transfer validation

To verify that the base attention implementation is correct before adding CoDA:

  • BaselineGQA swap on SmolLM2-135M: 100% top-1 logit agreement, mean logit diff 0.156 (bf16 rounding), PPL match 12.31 vs 12.31.
  • Bounded correctness (W >= L, no evictions): max diff 1.8e-7 in fp32. When no evictions occur, bounded reduces to standard causal attention exactly.
  • CoDA cold-swap (no training): 0% agreement, PPL 3,371,482. This is expected -- the HeadwiseRMSNorm and differential subtraction reshape the activation manifold, confirming that fine-tuning is not optional.

## 9. Implementation bugs worth documenting

### Bug 1: RoPE position overflow

Symptom: Perplexity evaluation on WikiText-2 returned PPL of 2,624 for the unbounded model (expected ~5-6).

Root cause: The evaluation loop processed multiple chunks of text, and the RoPE position counter accumulated across chunks. By the final chunks, positions reached ~50,000 -- far beyond the 8,192 max position the model was trained on. RoPE at extreme positions produces near-random rotations, destroying attention patterns.

Fix: Reset the state (and position counter) at each evaluation chunk boundary.

Result: PPL dropped from 2,624 to 87 (still high due to other issues at the time, but the position overflow was the dominant factor).

The broader takeaway: stateful KV caches require careful position management. In standard transformers, position is implicit in the KV cache length. In bounded systems, position is an explicit counter that must be managed across document boundaries.
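A hedged sketch of the fixed evaluation loop, using the package's state API from the quick start (the loop structure itself is illustrative, not the project's actual eval script):

```python
import torch

def eval_chunks(attn, chunks, block_size=256):
    """Prefill each evaluation chunk with a fresh bounded state, so RoPE
    positions restart at 0 and stay inside the model's trained range."""
    outputs = []
    for x in chunks:   # x: (1, L_chunk, embed_dim)
        state = attn.init_state(batch_size=1, device=x.device, dtype=x.dtype)
        y, state = attn.prefill_chunked(x, state, block_size=block_size)
        outputs.append(y)
    return outputs
```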

### Bug 2: in-place mutation with gradient flow

Symptom: `RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation` during Phase 2 training.

Root cause: When detach_evicted=False, gradients flow through the bank update path. The bank update modifies state.k_buf and state.v_buf in-place (via slice assignment), which invalidates tensors that autograd saved during the forward pass for the backward pass.

Fix: Three-layer clone defense:

  1. Clone state.k_buf and state.v_buf at the start of _write_block_fast when detach_evicted=False
  2. Clone bank views (mem_k, mem_v) inside the vectorized update methods to decouple them from the base buffer
  3. Allocate fresh scratch buffers every call (no reuse) when gradients flow, because zero_() + scatter_add_() on graph-connected tensors increments version counters
```python
# From _write_block_fast in memory_banks.py:
if not self.detach_evicted:
    state.k_buf = state.k_buf.clone()
    state.v_buf = state.v_buf.clone()
    state.g_recent = state.g_recent.clone()
```

Result: Training runs without version conflicts. The cloning overhead is acceptable because Phase 2 is already 20x slower than Phase 1.

### Bug 3: gradient checkpointing incompatibility

Symptom: Silent numerical divergence during Phase 2 with gradient checkpointing enabled.

Root cause: Gradient checkpointing replays the forward pass during backward. But the bank update path contains in-place state mutations. On replay, these mutations execute a second time, corrupting the state.

Fix: Auto-disable gradient checkpointing when entering Phase 2 (detach_evicted=False). The implementation detects this configuration and raises an error if checkpointing is manually forced.


## 10. Application: stateful neural databases

fig_neural_db

The bounded state is a fixed-size, serializable tensor bundle. Because its size is constant regardless of how many tokens produced it, you can save, load, and swap these states like database entries -- a pattern we call stateful neural databases:

```python
import torch
from coda_gqa_l import CoDAGQALandmarkPerf2

model = CoDAGQALandmarkPerf2(
    embed_dim=4096, num_heads=32, num_kv_heads=8,
    window=256, num_landmarks_exact=64, num_landmarks_summary=64,
)

# Ingest a document
state = model.init_state(batch_size=1, device="cuda", dtype=torch.bfloat16)
_, state = model.prefill_chunked(document_embeddings, state, block_size=256)

# Save the compressed state (fixed size regardless of document length)
torch.save(state, "document_42_state.pt")

# Later: load and query
state = torch.load("document_42_state.pt")
answer, state = model.step(query_embedding, state)
```

At 7B scale, each state is ~54 MB. A hundred documents cost 5.4 GB of state storage -- compared to 6.4 TB for 100 full 128K-token KV caches.

Agentic RAG pattern:

  1. Ingest: Process documents through the model, saving bounded states
  2. Route: Use a lightweight classifier to select relevant document states
  3. Load: Load the selected state from disk/cache
  4. Generate: Continue generation from the loaded state

The state is a lossy compression of the full document context, but it retains exact needle tokens (via the exact bank) and semantic summaries (via the summary bank).


## 11. Quick start

### Installation

```bash
pip install coda-gqa-l
```

Requires PyTorch >= 2.0. CUDA with bf16 recommended; falls back to fp32 on CPU.

### Basic usage

```python
import torch
from coda_gqa_l import CoDAGQALandmarkPerf2, CoDAGQALandmarkStatePerf2

# Create the attention layer
attn = CoDAGQALandmarkPerf2(
    embed_dim=512,
    num_heads=8,
    num_kv_heads=2,         # GQA: 4 query heads per KV head
    window=256,
    num_landmarks_exact=64,
    num_landmarks_summary=64,
)

# Initialize state
state = attn.init_state(batch_size=1, device="cuda", dtype=torch.bfloat16)

# Prefill a prompt
x_prompt = torch.randn(1, 1024, 512, device="cuda", dtype=torch.bfloat16)
y_prompt, state = attn.prefill_chunked(x_prompt, state, block_size=256)

# Decode tokens autoregressively
x_token = torch.randn(1, 1, 512, device="cuda", dtype=torch.bfloat16)
y_token, state = attn.step(x_token, state)

# Check cache size
print(f"Cache: {attn.cache_bytes(batch_size=1, dtype=torch.bfloat16) / 1024:.1f} KB")
```

### Drop-in adapter for Eve-2

```python
from coda_gqa_l import EveCoDAAdapter

# Replace Eve-2 attention in-place
adapter = EveCoDAAdapter.from_eve_attention(eve_block.attn, bounded=True)
eve_block.attn = adapter  # Block.forward(x, freqs_cis) works unchanged
```

### Metrics

```python
attn = CoDAGQALandmarkPerf2(..., collect_metrics=True)
state = attn.init_state(...)

# After processing...
print(state.metrics)
# {'exact_hits': 42, 'exact_inserts': 18, 'exact_overwrites': 18,
#  'exact_fill_ratio': 0.28, 'summary_updates': 156, 'summary_inserts': 64,
#  'summary_fill_ratio': 1.0, 'tokens_gated_out': 89, 'total_evictions': 768}
```

### Benchmarks

```bash
# Quick single-config timing
python benchmarks/bench.py

# Full 5-config suite with JSON output
python benchmarks/run_suite.py
python benchmarks/render_tables.py

# Eve-2 integration
python benchmarks/eval_eve.py --experiment all
```

## 12. What's next

The biggest bottleneck right now is the fused Triton kernel. The 2x differential SDPA overhead dominates the profile. A fused kernel that computes Attn(q_sig, K, V) - lambda * Attn(q_noise, K, V) in a single pass over K/V would eliminate redundant memory reads and roughly halve attention FLOPS. This is where most of the engineering effort is going next.

After that, 7B+ training. Current results are on SmolLM2-135M. We need to see whether the PPL gap narrows on Mistral 7B or Llama 3 8B (we expect it will -- larger models are more robust to lossy compression, and 135M is a harsh testbed).

Quantized KV storage is straightforward but not done yet. The bounded buffer is a natural target for per-channel INT4/INT8 quantization. At 4-bit, the 54 MB state at 7B scale would shrink to ~13.5 MB.

Distributed cache sharding is needed for tensor-parallel serving. The three-segment layout maps naturally to this: each device holds its own KV heads' segments.

One speculative idea: state composition. Can you merge two document states to create a joint context? If the summary banks use compatible routing, EMA-merging two states might produce a valid combined representation. We haven't tested this yet but the architecture doesn't prevent it.


## 13. Citation

```bibtex
@software{coda_gqa_l_2026,
  title  = {CoDA-GQA-L: Bounded-Memory Differential Attention
            with Value-Routed Landmark Banks},
  url    = {https://github.com/coda-gqa-l/coda-gqa-l},
  year   = {2026},
}
```

### Key equations reference

Differential attention:

```
q_sig   = RoPE(W_q @ x, pos)
q_noise = R(theta) @ q_sig
out     = HeadwiseRMSNorm(SDPA(q_sig, K, V) - sigmoid(W_lambda @ x) * SDPA(q_noise, K, V))
```

Memory bound per layer:

```
cache_size = 2 * H_kv * (W + Me + Ms) * D_h * sizeof(dtype)
```

EMA update (summary bank):

```
eta_eff     = sigmoid(eta_logit) * gate
mem_v[slot] += eta_eff * (v_evict - mem_v[slot])
```

V-routing similarity:

```
sim = cosine(v_evicted, v_bank_slot)
```

Phase-Safe EMA (summary bank keys):

```
delta_k[HF_band] = 0                                                  # no blending
delta_k[LF_band] = k_evict[LF_band] * sqrt(D_h / D_lf) - mem_k[LF_band]
mem_k[slot]      += eta_eff * delta_k
```

Pairwise orthogonal rotation (noise query):

```
x'_even = x_even * cos(theta_i) - x_odd * sin(theta_i)
x'_odd  = x_even * sin(theta_i) + x_odd * cos(theta_i)
```

Paper: Zenodo | Code: github.com/anthony-maio/CoDA-GQA-L | Contact: anthony@making-minds.ai
