RTPurbo Stage-1 indexer: d_idx=32, Qwen3.5-9B, full 32-131K context
This indexer was trained by distilling the top-p=0.9 token set of full self-attention into a 32-dim retrieval index with a forward-KL loss (paper Eq. 6). Training used a hybrid SVD initialisation: rows 0-15 are seeded from a prior d_idx=16 model, and rows 16-31 from the SVD principal components of Vq[16:32].
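The distillation target described above can be sketched in a few lines. The paper's Eq. 6 is not reproduced in this card, so the exact form below is an assumption: the teacher attention row is restricted to its top-p set and renormalised, and the loss is KL(teacher || indexer) against a softmax over the indexer scores. The function names `top_p_set` and `forward_kl_top_p` are hypothetical and are not taken from the training code.

```python
import math

def top_p_set(probs, top_p=0.9):
    """Indices of the smallest set of keys whose total probability reaches top_p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    chosen, mass = [], 0.0
    for i in order:
        chosen.append(i)
        mass += probs[i]
        if mass >= top_p - 1e-12:  # small tolerance for float rounding
            break
    return chosen

def forward_kl_top_p(teacher_probs, indexer_logits, top_p=0.9):
    """KL(teacher || indexer), teacher renormalised over its top-p set (assumed form)."""
    keep = top_p_set(teacher_probs, top_p)
    z = sum(teacher_probs[i] for i in keep)
    # Log-softmax normaliser of the indexer scores over all keys.
    m = max(indexer_logits)
    log_norm = m + math.log(sum(math.exp(s - m) for s in indexer_logits))
    kl = 0.0
    for i in keep:
        p = teacher_probs[i] / z
        log_q = indexer_logits[i] - log_norm
        kl += p * (math.log(p) - log_q)
    return kl
```

Because the KL is taken in the forward direction, the indexer is penalised for assigning low probability to keys the teacher attends to, which pushes it toward covering the teacher's set (high recall) rather than matching its sparsity.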
Configuration
- base model: Qwen/Qwen3.5-9B (32 layers; 24 GDN + 8 GA at L3/7/11/15/19/23/27/31)
- GQA 16:4, head_dim=256, max_pos=262144
- seq_len=131072, chunked_teacher (cq=4096), teacher_query_sample=32768 (25%)
- teacher_query_min_pos=8192, reuse_k=5
- lr=1e-3 cosine, warmup=100, wd=0.01, batch=1, bf16
- init_from=hybrid_svd_dim32 (rows 0-15 from d_idx=16 ckpt, rows 16-31 SVD of Vq)
- steps=600, elapsed=341.4 min
- data: emozilla/pg19 train split, min_chars=120000 (docs roughly 30K-131K tokens)
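For reference, the `lr=1e-3 cosine, warmup=100` entry with `steps=600` could correspond to a schedule like the one below. This is a sketch under two assumptions that the configuration does not state: warmup is linear, and the cosine decays to zero at the final step. `lr_at` is a hypothetical helper, not a function from the training script.

```python
import math

def lr_at(step, base_lr=1e-3, warmup=100, total_steps=600, min_lr=0.0):
    """Learning rate at a 0-indexed step: linear warmup, then cosine decay."""
    if step < warmup:
        # Assumed linear ramp that reaches base_lr on the last warmup step.
        return base_lr * (step + 1) / warmup
    progress = (step - warmup) / max(1, total_steps - warmup)
    progress = min(1.0, progress)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

Under these assumptions the rate peaks at 1e-3 at the end of warmup and is half of that at step 350, midway through the decay phase.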
Final probe metrics (real source-code text @ seq=131072, block=64)
| top_p | teacher_tok | teacher_blk | indexer_blk | idx_recall |
|---|---|---|---|---|
| 0.50 | 0.001 | 0.009 | 0.032 | 0.487 |
| 0.70 | 0.003 | 0.023 | 0.075 | 0.700 |
| 0.80 | 0.007 | 0.041 | 0.117 | 0.784 |
| 0.90 | 0.018 | 0.083 | 0.200 | 0.874 |
| 0.95 | 0.037 | 0.140 | 0.288 | 0.918 |
- final_mean_kl = 1.7587
- final_max_kl = 4.4431
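The table columns are not defined in this card, so the sketch below shows one plausible reading, stated as an assumption: `teacher_tok` is the fraction of keys in the teacher's top-p set, `teacher_blk` the fraction of 64-token blocks containing at least one such key, `indexer_blk` the fraction of blocks the indexer selects, and `idx_recall` the fraction of teacher top-p keys that fall inside indexer-selected blocks. `probe_row` is a hypothetical helper and may differ from the actual probe code.

```python
def probe_row(teacher_tokens, indexer_blocks, seq_len, block=64):
    """Compute one table row from a teacher token set and a set of selected block ids."""
    n_blocks = -(-seq_len // block)  # ceiling division
    teacher_blocks = {t // block for t in teacher_tokens}
    covered = [t for t in teacher_tokens if t // block in indexer_blocks]
    recall = len(covered) / len(teacher_tokens) if teacher_tokens else 1.0
    return {
        "teacher_tok": len(teacher_tokens) / seq_len,
        "teacher_blk": len(teacher_blocks) / n_blocks,
        "indexer_blk": len(indexer_blocks) / n_blocks,
        "idx_recall": recall,
    }

# Toy example with synthetic inputs (not real probe data): 4 blocks of 64 tokens.
row = probe_row(teacher_tokens=[0, 1, 70, 200], indexer_blocks={0, 1}, seq_len=256)
```

In this toy case the teacher touches blocks 0, 1 and 3, the indexer selects blocks 0 and 1, and three of the four teacher tokens are covered.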
Compute reasoning (why d_idx=32, not 16)
The RTPurbo paper used head_dim=128 with d_idx=16, a 1/8 ratio. Qwen3.5-9B has head_dim=256, so the paper-equivalent index size is 256/8 = 32. The earlier d_idx=16 run on this model was therefore under-provisioned (equivalent to d_idx=8 in the paper's setting). The probe confirms that recall is sufficient at d_idx=32 (87% at top_p=0.9, 92% at top_p=0.95), so the bottleneck for tightening the sparsity gap is the training recipe (teacher_query_sample covers 25% of queries, versus the paper's full-query training), not d_idx.
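The ratio argument is simple arithmetic, spelled out here as a small sketch. The helper names are hypothetical; the only inputs are the head_dim and d_idx values quoted above.

```python
def paper_equivalent_d_idx(head_dim, paper_head_dim=128, paper_d_idx=16):
    """Index size that keeps the paper's head_dim / d_idx compression ratio."""
    ratio = paper_head_dim // paper_d_idx  # 128 / 16 = 8
    return head_dim // ratio

def in_paper_terms(d_idx, head_dim, paper_head_dim=128):
    """The d_idx that would give the same compression ratio at the paper's head_dim."""
    return d_idx * paper_head_dim // head_dim

target = paper_equivalent_d_idx(256)  # 256 / 8 = 32
earlier_run = in_paper_terms(16, 256)  # 16 * 128 / 256 = 8
```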
Files
- indexer_final.pt: clean substate, 256 tensors of shape [32, 256], bf16
- state_dict.pt: same content, raw state_dict format
- stage1_summary.json: full per-head KL breakdown for all 8 GA layers
- loss_curve.csv: step,loss,lr,distill_kl,agree_kl,entropy,tok/s
Loading
```python
import torch
from train.surgeries._rtpurbo_indexer import RetrievalIndexer

state = torch.load("indexer_final.pt", map_location="cpu")
# state is a flat dict of 256 tensors: q_heads.<layer_idx>.weight
# pass to RetrievalIndexer(d_idx=32, head_dim=256).load_state_dict(...)
```
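Before passing the dict to `load_state_dict`, it can be useful to confirm that it matches the layout described under Files (256 tensors of shape [32, 256]). The check below is a minimal sketch; `check_indexer_state` is a hypothetical helper that only relies on each value exposing a `.shape`, as `torch.Tensor` does.

```python
def check_indexer_state(state, n_tensors=256, d_idx=32, head_dim=256):
    """Return a list of human-readable problems; an empty list means the layout matches."""
    problems = []
    if len(state) != n_tensors:
        problems.append(f"expected {n_tensors} tensors, found {len(state)}")
    for name, tensor in state.items():
        shape = tuple(tensor.shape)
        if shape != (d_idx, head_dim):
            problems.append(f"{name}: unexpected shape {shape}")
    return problems
```

With the `state` loaded above, `check_indexer_state(state)` should return an empty list if the checkpoint matches the description in this card.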