Mistral-7B-v0.3 + CoDA-GQA-L

Mistral 7B with standard attention replaced by CoDA-GQA-L (Constrained Orthogonal Differential Attention with Value-Routed Landmark Banks).

The model uses a fixed-size three-segment KV cache instead of the standard O(L) cache:

| Segment | Size | Function |
|---|---|---|
| Recent window | W=256 | Ring buffer of latest tokens |
| Exact landmark bank | Me=64 | Novelty-filtered LRU of important tokens |
| Summary landmark bank | Ms=64 | EMA prototypes compressing older context |

Total: 384 slots/layer (54 MB across 32 layers) regardless of sequence length. Standard Mistral at 4096 tokens uses 512 MB. At 128K tokens, the standard cache grows to 64 GB while CoDA stays at 54 MB (1,185x compression).
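The arithmetic behind these figures can be checked with a back-of-envelope sketch. It assumes Mistral-7B's published GQA dimensions (32 layers, 8 KV heads, head dim 128) and bf16 storage; these are standard for the base model but not read from this repo's config, and the raw K/V bytes come out slightly below the reported 1.7 MB/54 MB, presumably because the reported figures include per-slot bank metadata.

```python
# Back-of-envelope KV-cache sizing. Assumed dimensions: 32 layers,
# 8 KV heads, head_dim 128 (standard Mistral-7B), bf16 = 2 bytes.
LAYERS, KV_HEADS, HEAD_DIM, BYTES = 32, 8, 128, 2

def kv_bytes_per_layer(slots: int) -> int:
    # K and V tensors: slots x kv_heads x head_dim each
    return slots * KV_HEADS * HEAD_DIM * BYTES * 2

# Standard cache grows with sequence length
standard_4k = kv_bytes_per_layer(4096) * LAYERS
print(f"standard @ 4096 tokens: {standard_4k / 2**20:.0f} MiB")  # 512 MiB

# CoDA is fixed at 384 slots (W=256 + Me=64 + Ms=64) per layer
bounded = kv_bytes_per_layer(256 + 64 + 64) * LAYERS
print(f"CoDA bounded:           {bounded / 2**20:.0f} MiB")      # 48 MiB raw K/V
# The reported 54 MB presumably adds per-slot bank metadata
# (write-gate scores, EMA counts) on top of the raw K/V bytes.
```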

Quick start

Install

pip install coda-gqa-l transformers accelerate

Bounded mode (constant-memory inference)

Bounded mode requires a manual generation loop because CoDA manages its own KV state internally. HF's model.generate() cannot drive this.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download
from coda_gqa_l import LlamaCoDAAdapter

# 1. Load base model + tokenizer
model = AutoModelForCausalLM.from_pretrained(
    'mistralai/Mistral-7B-v0.3',
    torch_dtype=torch.bfloat16,
    device_map='auto',
)
tokenizer = AutoTokenizer.from_pretrained('mistralai/Mistral-7B-v0.3')

# 2. Swap attention layers to bounded CoDA and load trained weights
adapter_path = hf_hub_download(
    'anthonym21/Mistral-7B-v0.3-CoDA-GQA-L', 'coda_adapters.pt'
)
adapters = torch.load(adapter_path, map_location='cpu', weights_only=True)

for i, layer in enumerate(model.model.layers):
    device = next(layer.parameters()).device
    adapter = LlamaCoDAAdapter.from_llama_attention(
        layer.self_attn,
        bounded=True,
        head_norm_mode='identity',
        rope_interleaved=False,
    )
    adapter.load_state_dict(adapters[f'layer_{i}'], strict=False)
    adapter = adapter.to(device=device, dtype=torch.bfloat16)
    layer.self_attn = adapter

# 3. CRITICAL: call eval() AFTER installing adapters
#    New modules default to training=True, which uses a stateless
#    code path (fresh empty state every call). eval() switches to
#    the persistent stateful path needed for generation.
model.eval()

# 4. Manual generation loop
prompt = 'The future of AI is'
input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(model.device)
temperature = 0.7
generated = input_ids[0].tolist()

with torch.no_grad():
    # Prefill: full prompt in one pass (adapters use prefill_chunked)
    outputs = model(input_ids=input_ids, use_cache=False)
    logits = outputs.logits[:, -1, :]

    next_token = torch.multinomial(
        torch.softmax(logits / temperature, dim=-1), 1
    )
    generated.append(next_token.item())

    # Decode: one token at a time (adapters use step());
    # 199 more tokens here, plus the one sampled after prefill = 200 total
    for _ in range(199):
        if next_token.item() == tokenizer.eos_token_id:
            break
        outputs = model(input_ids=next_token, use_cache=False)
        logits = outputs.logits[:, -1, :]
        next_token = torch.multinomial(
            torch.softmax(logits / temperature, dim=-1), 1
        )
        generated.append(next_token.item())

print(tokenizer.decode(generated, skip_special_tokens=True))

Unbounded mode (differential attention over the full causal context)

For standard generation without memory banks, use bounded=False with HF's model.generate():

for i, layer in enumerate(model.model.layers):
    device = next(layer.parameters()).device
    adapter = LlamaCoDAAdapter.from_llama_attention(
        layer.self_attn,
        bounded=False,  # <-- unbounded
        head_norm_mode='identity',
        rope_interleaved=False,
    )
    adapter.load_state_dict(adapters[f'layer_{i}'], strict=False)
    adapter = adapter.to(device=device, dtype=torch.bfloat16)
    layer.self_attn = adapter

model.eval()

# HF generate() works in unbounded mode
input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(model.device)
output = model.generate(input_ids, max_new_tokens=200, use_cache=False)
print(tokenizer.decode(output[0], skip_special_tokens=True))

Common pitfalls

model.eval() must come AFTER adapter installation. New PyTorch modules default to training=True. The bounded forward path branches on self.training: in training mode it allocates a fresh empty state every call (for gradient checkpointing safety), so decode tokens have zero context. This produces deterministic garbage. Always call model.eval() after the adapter swap loop.

Use the base Mistral tokenizer. The tokenizer config in this repo has a broken tokenizer_class field. Load the tokenizer from mistralai/Mistral-7B-v0.3 instead.

Always pass use_cache=False. CoDA manages its own KV cache internally. HF's cache system conflicts with it.

Use strict=False when loading weights. Bounded adapters have extra parameters (write_proj, summary_eta_logit) not present in the unbounded-trained checkpoint. These use their default initialization, which works well for inference.

How bounded mode works

During generation, each LlamaCoDAAdapter manages an internal state machine:

  1. Prefill (prompt processing): The full prompt passes through the model. Each adapter receives hidden_states with L > 1 tokens and routes through prefill_chunked(), which processes the prompt in blocks and populates the bounded KV buffer.

  2. Decode (token generation): Each new token passes through the model individually. Adapters receive L == 1 and route through step(), which:

    • Writes the new token into the recent window ring buffer
    • If the ring buffer is full, evicts the oldest token
    • Evicted tokens pass through a write gate
    • Tokens above threshold are routed to memory banks via value cosine similarity
    • The exact bank stores novel tokens (LRU eviction when full)
    • The summary bank blends similar tokens via EMA

The KV buffer layout is [recent W | exact Me | summary Ms] = 384 slots, constant regardless of how many tokens have been generated.
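The decode-time routing above can be sketched in plain Python. This is a deliberately toy implementation under assumed mechanics: the gate threshold, novelty threshold, EMA rate, class and method names are all illustrative, not coda_gqa_l's actual API.

```python
import math

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return dot / (na * nb + 1e-9)

class BoundedKVSketch:
    """Toy version of the decode-time routing described above.
    Thresholds, EMA rate, and structure are illustrative only."""

    def __init__(self, w, me, ms, gate_thr=0.5, novelty_thr=0.9, eta=0.1):
        self.w, self.me, self.ms = w, me, ms
        self.gate_thr, self.novelty_thr, self.eta = gate_thr, novelty_thr, eta
        self.window, self.exact, self.summary = [], [], []  # [(k, v)] each

    def step(self, k, v, gate_score):
        # 1. write the new token into the recent window ring buffer
        self.window.append((k, v))
        if len(self.window) <= self.w:
            return
        # 2. ring buffer full: evict the oldest token
        k_old, v_old = self.window.pop(0)
        # 3. evicted token passes through the write gate
        if gate_score < self.gate_thr:
            return  # deemed unimportant: dropped entirely
        # 4. route by cosine similarity on VALUES (position-invariant)
        sims = [cosine(v_old, ve) for _, ve in self.exact]
        if not sims or max(sims) < self.novelty_thr:
            # 5. novel token -> exact bank, LRU-style eviction when full
            self.exact.append((k_old, v_old))
            if len(self.exact) > self.me:
                self.exact.pop(0)
        else:
            # 6. redundant token -> blend into nearest summary prototype (EMA)
            if not self.summary:
                self.summary.append((list(k_old), list(v_old)))
                return
            ssims = [cosine(v_old, vp) for _, vp in self.summary]
            j = ssims.index(max(ssims))
            kp, vp = self.summary[j]
            ema = lambda old, new: [(1 - self.eta) * o + self.eta * n
                                    for o, n in zip(old, new)]
            self.summary[j] = (ema(kp, k_old), ema(vp, v_old))

# drive a tiny instance: W=2, Me=2, Ms=2
bkv = BoundedKVSketch(w=2, me=2, ms=2)
for tok in [(1, 0), (0, 1), (1, 1), (2, 0), (3, 3), (9, 9), (4, 4)]:
    bkv.step([float(x) for x in tok], [float(x) for x in tok], gate_score=1.0)
print(len(bkv.window), len(bkv.exact), len(bkv.summary))  # 2 2 1
```

The real adapters operate on batched bf16 tensors per head group; the sketch only shows the control flow (ring buffer, gate, novelty split, EMA blend), not the tensor layout.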

Architecture

CoDA-GQA-L replaces standard attention with constrained orthogonal differential attention:

x -> q_proj -> q_signal -> RoPE(q, pos) -> q_roped
                                        \-> R(theta) -> q_noise

SDPA(q_roped, K_buf, V_buf) -> out_signal
SDPA(q_noise, K_buf, V_buf) -> out_noise

output = RMSNorm(out_signal - lambda * out_noise)

The noise query is produced by rotating the signal query through learnable orthogonal angles (no second Wq projection). Lambda is a learned per-token gate initialized near zero (sigmoid(-6) ~ 0.0025) so the model starts as near-standard attention and the differential mechanism activates gradually during training.

Value-routing is a design decision worth explaining: memory banks match tokens by cosine similarity on Values, not Keys. Keys have RoPE rotation applied, so identical tokens at different positions have different key vectors. Values are RoPE-free, making their similarity position-invariant -- the right property for deduplication (exact bank) and clustering (summary bank).
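A toy 2-D example illustrates the position-invariance argument, using a single rotation as a stand-in for one RoPE frequency pair (the full RoPE applies many such rotations across the head dimension):

```python
import math

def rope_rotate(vec, pos, theta=0.1):
    # one RoPE frequency pair: rotate a 2-D (x, y) slice by pos * theta
    x, y = vec
    c, s = math.cos(pos * theta), math.sin(pos * theta)
    return (x * c - y * s, x * s + y * c)

def cosine(a, b):
    dot = sum(p * q for p, q in zip(a, b))
    na = math.sqrt(sum(p * p for p in a))
    nb = math.sqrt(sum(p * p for p in b))
    return dot / (na * nb)

token = (1.0, 0.0)  # the same token content, appearing twice

# Keys carry RoPE: identical tokens at positions 0 and 50 diverge
k0 = rope_rotate(token, pos=0)
k50 = rope_rotate(token, pos=50)
print(f"key cosine sim:   {cosine(k0, k50):.3f}")        # cos(5.0) ~ 0.284

# Values are RoPE-free: similarity is 1 regardless of position
print(f"value cosine sim: {cosine(token, token):.3f}")   # 1.000
```

Key-based matching would therefore treat the same token at two positions as different entries; value-based matching identifies them, which is what the exact bank's deduplication and the summary bank's clustering need.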

Training details

| | Phase 1 (unbounded) | Phase 2 (bounded) |
|---|---|---|
| Attention | CoDAGQA (full KV) | CoDAGQALandmarkPerf2 (384 slots) |
| Steps | 2,000 | 2,000 |
| Dataset | WikiText-103 | WikiText-103 |
| Sequence length | 2,048 | 2,048 |
| Learning rate (projections) | 5e-5 | 2.5e-5 |
| Learning rate (CoDA params) | 1e-3 | 5e-4 |
| Batch size | 1 x 8 grad accum | 1 x 8 grad accum |
| Trainable params | ~1.3B / 7.2B (attention only) | ~1.3B / 7.2B |
| Best unbounded PPL | 5.94 | -- |
| Gradient checkpointing | Yes | No (incompatible with grad-through-banks) |
| detach_evicted | N/A | False (gradients flow through bank updates) |

Phase 1 teaches the differential attention mechanism with full context available. Phase 2 adapts the model to work with bounded memory by training the write gate and bank parameters. Both phases freeze all non-attention parameters (MLP, embeddings, layer norms).

Memory budget

| Configuration | Per layer | 32 layers total | vs unbounded at 4K |
|---|---|---|---|
| medium-cache (default) | 1.7 MB | 54 MB | 9.5x smaller |
| tiny-cache (W=128, Me=32, Ms=32) | 865 KB | 27 MB | 19x smaller |
| window-only (W=256, Me=0, Ms=0) | 1.0 MB | 32 MB | 16x smaller |

At 128K context, the savings reach 1,185x (54 MB vs 64 GB).

Benchmark numbers (H200, bf16)

From the paper, single-layer throughput at 7B scale:

| Config | Prefill L=2048 | Prefill L=8192 | Per-layer KV |
|---|---|---|---|
| Baseline GQA | 3,096K tok/s | 2,286K tok/s | 32.0 MB |
| CoDA unbounded | 1,852K tok/s | 1,283K tok/s | 32.0 MB |
| CoDA medium-cache | 160K tok/s | 158K tok/s | 1.7 MB |
| CoDA window-only | 392K tok/s | 397K tok/s | 1.0 MB |

Bounded throughput is flat across sequence lengths (bank updates operate on fixed-size buffers). The 2x SDPA cost from differential attention is the constant overhead; bank updates account for the remaining gap.

Stateful Neural Database pattern

The bounded state is a fixed-size serializable artifact. You can ingest a document once, save the compressed state, and query it later without re-processing:

# Ingest: process document into bounded state
for layer in model.model.layers:
    layer.self_attn.reset_state()

with torch.no_grad():
    model(input_ids=document_tokens, use_cache=False)

# Save all layer states (54 MB total, constant regardless of doc length)
states = {}
for i, layer in enumerate(model.model.layers):
    states[i] = layer.self_attn.get_state()
torch.save(states, "document_state.pt")

# Later: load and query without re-reading the document
states = torch.load("document_state.pt")
for i, layer in enumerate(model.model.layers):
    layer.self_attn.set_state(states[i])

with torch.no_grad():
    outputs = model(input_ids=question_tokens, use_cache=False)
# Decode answer from outputs.logits

100 documents at 7B = 5.4 GB of state files. Each query is a decode-phase forward pass with sub-second latency.

Files in this repo

  • coda_adapters.pt -- trained CoDA adapter weights for all 32 layers
  • config.json, generation_config.json -- Mistral model configs
  • model-00001-of-00003.safetensors etc. -- base Mistral weights (identical to mistralai/Mistral-7B-v0.3)
  • tokenizer.model, tokenizer.json, tokenizer_config.json -- tokenizer files (note: tokenizer_config.json has a broken tokenizer_class; use the base Mistral tokenizer instead)
  • special_tokens_map.json -- special token mappings

Requirements

  • PyTorch >= 2.0 (2.5+ recommended for FlashAttention with causal_lower_right)
  • CUDA GPU with bf16 support
  • ~15 GB VRAM for bf16 inference on single GPU, or ~24 GB across 2 GPUs with device_map='auto'

Links

  • Code: github.com/anthony-maio/CoDA-GQA-L
  • Package: pip install coda-gqa-l
  • Paper: CoDA-GQA-L: Bounded-Memory Differential Attention with Value-Routed Landmark Banks (Maio, 2026)

Citation

@article{maio2026coda,
  title={CoDA-GQA-L: Bounded-Memory Differential Attention with Value-Routed Landmark Banks},
  author={Maio, Anthony},
  year={2026}
}