# Mistral-7B-v0.3 + CoDA-GQA-L
Mistral 7B with standard attention replaced by CoDA-GQA-L (Constrained Orthogonal Differential Attention with Value-Routed Landmark Banks).
The model uses a fixed-size three-segment KV cache instead of the standard O(L) cache:
| Segment | Size | Function |
|---|---|---|
| Recent window | W=256 | Ring buffer of latest tokens |
| Exact landmark bank | Me=64 | Novelty-filtered LRU of important tokens |
| Summary landmark bank | Ms=64 | EMA prototypes compressing older context |
Total: 384 slots/layer (54 MB across 32 layers) regardless of sequence length. Standard Mistral at 4096 tokens uses 512 MB. At 128K tokens, the standard cache grows to 64 GB while CoDA stays at 54 MB (1,185x compression).
## Quick start

### Install

```bash
pip install coda-gqa-l transformers accelerate
```

### Bounded mode (constant-memory inference)

Bounded mode requires a manual generation loop because CoDA manages its own KV state internally; HF's `model.generate()` cannot drive it.
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download
from coda_gqa_l import LlamaCoDAAdapter

# 1. Load base model + tokenizer
model = AutoModelForCausalLM.from_pretrained(
    'mistralai/Mistral-7B-v0.3',
    torch_dtype=torch.bfloat16,
    device_map='auto',
)
tokenizer = AutoTokenizer.from_pretrained('mistralai/Mistral-7B-v0.3')

# 2. Swap attention layers to bounded CoDA and load trained weights
adapter_path = hf_hub_download(
    'anthonym21/Mistral-7B-v0.3-CoDA-GQA-L', 'coda_adapters.pt'
)
adapters = torch.load(adapter_path, map_location='cpu', weights_only=True)
for i, layer in enumerate(model.model.layers):
    device = next(layer.parameters()).device
    adapter = LlamaCoDAAdapter.from_llama_attention(
        layer.self_attn,
        bounded=True,
        head_norm_mode='identity',
        rope_interleaved=False,
    )
    adapter.load_state_dict(adapters[f'layer_{i}'], strict=False)
    adapter = adapter.to(device=device, dtype=torch.bfloat16)
    layer.self_attn = adapter

# 3. CRITICAL: call eval() AFTER installing adapters.
# New modules default to training=True, which uses a stateless
# code path (fresh empty state every call). eval() switches to
# the persistent stateful path needed for generation.
model.eval()

# 4. Manual generation loop
prompt = 'The future of AI is'
input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(model.device)
temperature = 0.7
generated = input_ids[0].tolist()

with torch.no_grad():
    # Prefill: full prompt in one pass (adapters use prefill_chunked)
    outputs = model(input_ids=input_ids, use_cache=False)
    logits = outputs.logits[:, -1, :]
    next_token = torch.multinomial(
        torch.softmax(logits / temperature, dim=-1), 1
    )
    generated.append(next_token.item())

    # Decode: one token at a time (adapters use step())
    for _ in range(199):  # 200 new tokens total, counting the prefill sample
        if next_token.item() == tokenizer.eos_token_id:
            break
        outputs = model(input_ids=next_token, use_cache=False)
        logits = outputs.logits[:, -1, :]
        next_token = torch.multinomial(
            torch.softmax(logits / temperature, dim=-1), 1
        )
        generated.append(next_token.item())

print(tokenizer.decode(generated, skip_special_tokens=True))
```
### Unbounded mode (differential attention, full KV cache)

For standard causal generation without memory banks, use `bounded=False` with HF's `model.generate()`:

```python
for i, layer in enumerate(model.model.layers):
    device = next(layer.parameters()).device
    adapter = LlamaCoDAAdapter.from_llama_attention(
        layer.self_attn,
        bounded=False,  # <-- unbounded
        head_norm_mode='identity',
        rope_interleaved=False,
    )
    adapter.load_state_dict(adapters[f'layer_{i}'], strict=False)
    adapter = adapter.to(device=device, dtype=torch.bfloat16)
    layer.self_attn = adapter
model.eval()

# HF generate() works in unbounded mode
input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(model.device)
output = model.generate(input_ids, max_new_tokens=200, use_cache=False)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
## Common pitfalls

- **`model.eval()` must come AFTER adapter installation.** New PyTorch modules default to `training=True`. The bounded forward path branches on `self.training`: in training mode it allocates a fresh empty state on every call (for gradient-checkpointing safety), so decode tokens have zero context. This produces deterministic garbage. Always call `model.eval()` after the adapter swap loop.
- **Use the base Mistral tokenizer.** The tokenizer config in this repo has a broken `tokenizer_class` field. Load the tokenizer from `mistralai/Mistral-7B-v0.3` instead.
- **Always pass `use_cache=False`.** CoDA manages its own KV cache internally; HF's cache system conflicts with it.
- **Use `strict=False` when loading weights.** Bounded adapters have extra parameters (`write_proj`, `summary_eta_logit`) not present in the unbounded-trained checkpoint. These keep their default initialization, which works well for inference.
## How bounded mode works

During generation, each `LlamaCoDAAdapter` manages an internal state machine:

1. **Prefill (prompt processing):** The full prompt passes through the model. Each adapter receives `hidden_states` with `L > 1` tokens and routes through `prefill_chunked()`, which processes the prompt in blocks and populates the bounded KV buffer.
2. **Decode (token generation):** Each new token passes through the model individually. Adapters receive `L == 1` and route through `step()`, which:
   - writes the new token into the recent-window ring buffer;
   - if the ring buffer is full, evicts the oldest token;
   - passes evicted tokens through a write gate;
   - routes tokens above threshold to the memory banks via value cosine similarity: the exact bank stores novel tokens (LRU eviction when full), and the summary bank blends similar tokens via EMA.

The KV buffer layout is `[recent W | exact Me | summary Ms]` = 384 slots, constant regardless of how many tokens have been generated.
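The eviction-and-routing step can be sketched as below. This is an illustrative reconstruction, not the package's API: `route_evicted_token`, `tau`, and `eta` are invented names, the write gate is omitted, and overwriting slot 0 stands in for the real LRU eviction policy.

```python
import torch
import torch.nn.functional as F

def route_evicted_token(v, exact_bank, summary_bank, eta=0.1, tau=0.8):
    """Illustrative sketch of value-routed bank updates (not the real API).

    v:            (d,) value vector of the token leaving the recent window
    exact_bank:   (Me, d) novelty-filtered landmark slots
    summary_bank: (Ms, d) EMA prototypes compressing older context
    """
    # Novelty test by cosine similarity on Values (RoPE-free, so the
    # comparison is position-invariant).
    sim_exact = F.cosine_similarity(v[None, :], exact_bank, dim=-1)
    if sim_exact.max() < tau:
        # Novel token: store verbatim in the exact bank.
        # (Slot 0 stands in for LRU eviction.)
        exact_bank[0] = v
        return 'exact'
    # Seen before: fold into the closest summary prototype via EMA.
    sim_sum = F.cosine_similarity(v[None, :], summary_bank, dim=-1)
    j = sim_sum.argmax()
    summary_bank[j] = (1 - eta) * summary_bank[j] + eta * v
    return 'summary'
```

The first arrival of a value lands in the exact bank; a near-duplicate later gets blended into a summary prototype instead of consuming a second slot.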
## Architecture

CoDA-GQA-L replaces standard attention with constrained orthogonal differential attention:

```text
x -> q_proj -> q_signal -> RoPE(q, pos) -> q_roped
                       \-> R(theta)     -> q_noise

SDPA(q_roped, K_buf, V_buf) -> out_signal
SDPA(q_noise, K_buf, V_buf) -> out_noise

output = RMSNorm(out_signal - lambda * out_noise)
```
The noise query is produced by rotating the signal query through learnable orthogonal angles (no second Wq projection). Lambda is a learned per-token gate initialized near zero (sigmoid(-6) ~ 0.0025) so the model starts as near-standard attention and the differential mechanism activates gradually during training.
Value-routing is a design decision worth explaining: memory banks match tokens by cosine similarity on Values, not Keys. Keys have RoPE rotation applied, so identical tokens at different positions have different key vectors. Values are RoPE-free, making their similarity position-invariant -- the right property for deduplication (exact bank) and clustering (summary bank).
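A minimal sketch of the combine step, leaving out GQA, the learned rotation R(theta), and the exact norm placement (`coda_combine` and its shapes are illustrative assumptions, not the package's API):

```python
import torch
import torch.nn.functional as F

def coda_combine(q_signal, q_noise, k, v, lam):
    """Sketch of differential attention: two SDPA passes over the same
    bounded K/V buffer, subtracted under a learned gate, then RMS-normalized.
    Shapes: q_* (B, H, Lq, D); k, v (B, H, Lkv, D); lam scalar or broadcastable.
    """
    out_signal = F.scaled_dot_product_attention(q_signal, k, v)
    out_noise = F.scaled_dot_product_attention(q_noise, k, v)
    diff = out_signal - lam * out_noise
    # RMSNorm over the head dimension (illustrative; the real module may
    # normalize per group and carry a learned scale).
    rms = diff.pow(2).mean(dim=-1, keepdim=True).add(1e-6).sqrt()
    return diff / rms
```

With `lam` near its initialization (sigmoid(-6) ~ 0.0025), the output is essentially normalized standard attention, which is what lets training activate the differential term gradually.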
## Training details

| | Phase 1 (unbounded) | Phase 2 (bounded) |
|---|---|---|
| Attention | CoDAGQA (full KV) | CoDAGQALandmarkPerf2 (384 slots) |
| Steps | 2,000 | 2,000 |
| Dataset | WikiText-103 | WikiText-103 |
| Sequence length | 2,048 | 2,048 |
| Learning rate (projections) | 5e-5 | 2.5e-5 |
| Learning rate (CoDA params) | 1e-3 | 5e-4 |
| Batch size | 1 x 8 grad accum | 1 x 8 grad accum |
| Trainable params | ~1.3B / 7.2B (attention only) | ~1.3B / 7.2B |
| Best unbounded PPL | 5.94 | -- |
| Gradient checkpointing | Yes | No (incompatible with grad-through-banks) |
| detach_evicted | N/A | False (gradients flow through bank updates) |
Phase 1 teaches the differential attention mechanism with full context available. Phase 2 adapts the model to work with bounded memory by training the write gate and bank parameters. Both phases freeze all non-attention parameters (MLP, embeddings, layer norms).
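The parameter freeze described above can be sketched as follows, assuming the adapters are installed at `layer.self_attn` as in the quick start (`freeze_non_attention` is an illustrative helper, not part of the package):

```python
def freeze_non_attention(model):
    """Freeze everything, then re-enable gradients only for the attention
    adapters; MLP, embeddings, and layer norms stay frozen."""
    for p in model.parameters():
        p.requires_grad = False
    for layer in model.model.layers:
        for p in layer.self_attn.parameters():
            p.requires_grad = True
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total
```

On the real 7B model this yields roughly the ~1.3B-of-7.2B trainable split shown in the table.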
## Memory budget
| Configuration | Per layer | 32 layers total | vs unbounded at 4K |
|---|---|---|---|
| medium-cache (default) | 1.7 MB | 54 MB | 9.5x smaller |
| tiny-cache (W=128, Me=32, Ms=32) | 865 KB | 27 MB | 19x smaller |
| window-only (W=256, Me=0, Ms=0) | 1.0 MB | 32 MB | 16x smaller |
At 128K context, the savings reach 1,185x (54 MB vs 64 GB).
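As a sanity check on the table, the raw K/V storage implied by the slot counts can be computed directly (assuming Mistral-7B-v0.3's 8 KV heads and head dim 128 in bf16); the slightly larger per-layer figures above presumably include bank metadata on top of the raw tensors:

```python
def kv_bytes_per_layer(slots, kv_heads=8, head_dim=128, bytes_per_elem=2):
    """Raw K+V tensor bytes for a bounded buffer of `slots` positions.
    kv_heads/head_dim reflect Mistral-7B-v0.3's GQA shape (an assumption here)."""
    return 2 * slots * kv_heads * head_dim * bytes_per_elem  # 2 = K and V

medium = kv_bytes_per_layer(256 + 64 + 64)  # W + Me + Ms = 384 slots
print(medium / 2**20)       # 1.5 (MiB raw K/V per layer)
print(medium * 32 / 2**20)  # 48.0 (MiB across 32 layers)
```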
## Benchmark numbers (H200, bf16)
From the paper, single-layer throughput at 7B scale:
| Config | Prefill L=2048 | Prefill L=8192 | Per-layer KV |
|---|---|---|---|
| Baseline GQA | 3,096K tok/s | 2,286K tok/s | 32.0 MB |
| CoDA unbounded | 1,852K tok/s | 1,283K tok/s | 32.0 MB |
| CoDA medium-cache | 160K tok/s | 158K tok/s | 1.7 MB |
| CoDA window-only | 392K tok/s | 397K tok/s | 1.0 MB |
Bounded throughput is flat across sequence lengths (bank updates operate on fixed-size buffers). The 2x SDPA cost from differential attention is the constant overhead; bank updates account for the remaining gap.
## Stateful Neural Database pattern
The bounded state is a fixed-size serializable artifact. You can ingest a document once, save the compressed state, and query it later without re-processing:
```python
# Ingest: process document into bounded state
for layer in model.model.layers:
    layer.self_attn.reset_state()
with torch.no_grad():
    model(input_ids=document_tokens, use_cache=False)

# Save all layer states (54 MB total, constant regardless of doc length)
states = {}
for i, layer in enumerate(model.model.layers):
    states[i] = layer.self_attn.get_state()
torch.save(states, "document_state.pt")

# Later: load and query without re-reading the document
states = torch.load("document_state.pt")
for i, layer in enumerate(model.model.layers):
    layer.self_attn.set_state(states[i])
with torch.no_grad():
    outputs = model(input_ids=question_tokens, use_cache=False)
# Decode answer from outputs.logits
```
100 documents at 7B = 5.4 GB of state files. Each query is a decode-phase forward pass with sub-second latency.
## Files in this repo

- `coda_adapters.pt` -- trained CoDA adapter weights for all 32 layers
- `config.json`, `generation_config.json` -- Mistral model configs
- `model-00001-of-00003.safetensors` etc. -- base Mistral weights (identical to `mistralai/Mistral-7B-v0.3`)
- `tokenizer.model`, `tokenizer.json`, `tokenizer_config.json` -- tokenizer files (note: `tokenizer_config.json` has a broken `tokenizer_class`; use the base Mistral tokenizer instead)
- `special_tokens_map.json` -- special token mappings
## Requirements
- PyTorch >= 2.0 (2.5+ recommended for FlashAttention with causal_lower_right)
- CUDA GPU with bf16 support
- ~15 GB VRAM for bf16 inference on single GPU, or ~24 GB across 2 GPUs with device_map='auto'
## Links

- Code: github.com/anthony-maio/CoDA-GQA-L
- Package: `pip install coda-gqa-l`
- Paper: CoDA-GQA-L: Bounded-Memory Differential Attention with Value-Routed Landmark Banks (Maio, 2026)
## Citation

```bibtex
@article{maio2026coda,
  title={CoDA-GQA-L: Bounded-Memory Differential Attention with Value-Routed Landmark Banks},
  author={Maio, Anthony},
  year={2026}
}
```