RTPurbo Stage-1 indexer: d_idx=32, Qwen3.5-9B, full 32-131K context

Trained by distilling the full self-attention's top-p=0.9 token set into a 32-dim retrieval index via forward KL (paper Eq.6), using a hybrid SVD initialisation that seeds rows 0-15 from a prior d_idx=16 model and rows 16-31 from SVD principal components of Vq[16:32].
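The distillation target described above can be illustrated for a single query with a small pure-Python sketch. It is not the paper's Eq.6 verbatim: the function names are hypothetical, and it assumes the teacher distribution is renormalised over its top-p set while the indexer's softmax spans all tokens.

```python
import math

def top_p_set(probs, p=0.9):
    """Indices of the smallest token set whose cumulative teacher mass reaches p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= p:
            break
    return kept

def softmax(scores):
    """Numerically stable softmax over a list of raw scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [e / z for e in exps]

def forward_kl_on_top_p(teacher_probs, indexer_scores, p=0.9):
    """KL(teacher || indexer) with the teacher restricted to its top-p set.

    Assumption: the teacher is renormalised over the kept tokens and the
    indexer distribution is a softmax over all tokens.
    """
    kept = top_p_set(teacher_probs, p)
    mass = sum(teacher_probs[i] for i in kept)
    student = softmax(indexer_scores)
    return sum(
        (teacher_probs[i] / mass) * math.log((teacher_probs[i] / mass) / student[i])
        for i in kept
    )

# Toy example: five key tokens for one query position.
teacher = [0.60, 0.25, 0.10, 0.04, 0.01]
kept = top_p_set(teacher, 0.9)
loss = forward_kl_on_top_p(teacher, [2.0, 1.0, 0.0, -1.0, -2.0], 0.9)
```

Forward KL penalises the indexer wherever it assigns low probability to a token the teacher considers important, which matches the recall-oriented goal of a retrieval index.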

Configuration

  • base model: Qwen/Qwen3.5-9B (32 layers; 24 GDN + 8 GA at L3/7/11/15/19/23/27/31)
  • GQA 16:4, head_dim=256, max_pos=262144
  • seq_len=131072, chunked_teacher (cq=4096), teacher_query_sample=32768 (25%)
  • teacher_query_min_pos=8192, reuse_k=5
  • lr=1e-3 cosine, warmup=100, wd=0.01, batch=1, bf16
  • init_from=hybrid_svd_dim32 (rows 0-15 from d_idx=16 ckpt, rows 16-31 SVD of Vq)
  • steps=600, elapsed=341.4 min
  • data: emozilla/pg19 train split, min_chars=120000 (docs roughly 30K-131K tokens)
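For orientation, the learning-rate settings listed above (lr=1e-3 cosine, warmup=100, steps=600) could correspond to a schedule like the following. This is only a sketch: the linear warmup shape and the decay floor of zero are assumptions, since the card states neither, and `lr_at` is a hypothetical helper.

```python
import math

def lr_at(step, base_lr=1e-3, warmup=100, total=600, min_lr=0.0):
    """Linear warmup followed by cosine decay.

    Assumptions (not stated in the card): warmup is linear and the
    cosine decays to min_lr=0 at the final step.
    """
    if step < warmup:
        return base_lr * (step + 1) / warmup
    progress = (step - warmup) / max(1, total - warmup)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

# Learning rate at a few points of the 600-step run.
schedule = {step: lr_at(step) for step in (0, 100, 350, 600)}
```

Under these assumptions the rate peaks at the end of warmup and is halved at step 350, the midpoint of the decay phase.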

Final probe metrics (real source-code text @ seq=131072, block=64)

top_p  teacher_tok  teacher_blk  indexer_blk  idx_recall
0.50   0.001        0.009        0.032        0.487
0.70   0.003        0.023        0.075        0.700
0.80   0.007        0.041        0.117        0.784
0.90   0.018        0.083        0.200        0.874
0.95   0.037        0.140        0.288        0.918
  • final_mean_kl = 1.7587
  • final_max_kl = 4.4431
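One way to read the probe table is to compare how many blocks the indexer keeps against how many the teacher needs at the same top_p. The sketch below computes that ratio from the reported numbers; it assumes teacher_blk and indexer_blk are both fractions of blocks kept, which the card does not spell out.

```python
# Rows copied from the probe table: (top_p, teacher_blk, indexer_blk, idx_recall).
rows = [
    (0.50, 0.009, 0.032, 0.487),
    (0.70, 0.023, 0.075, 0.700),
    (0.80, 0.041, 0.117, 0.784),
    (0.90, 0.083, 0.200, 0.874),
    (0.95, 0.140, 0.288, 0.918),
]

# Ratio of indexer-selected blocks to teacher blocks at each top_p
# (assumed interpretation: both columns are fractions of all blocks).
gaps = {top_p: indexer_blk / teacher_blk for top_p, teacher_blk, indexer_blk, _ in rows}

for top_p, gap in gaps.items():
    print(f"top_p={top_p:.2f}  indexer/teacher block ratio = {gap:.2f}x")
```

Under that reading, the ratio shrinks as top_p grows, from roughly 3.6x at top_p=0.50 to roughly 2.1x at top_p=0.95.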

Compute reasoning (why d_idx=32, not 16)

The RTPurbo paper used head_dim=128 with d_idx=16 (a 1/8 ratio). Qwen3.5-9B has head_dim=256, so the paper-equivalent compression is 256/8 = 32, i.e. d_idx=32. The earlier d_idx=16 run on this model was under-provisioned (equivalent to the paper's d_idx=8). The probe confirms recall is sufficient (87% at top_p=0.9, 92% at top_p=0.95), so the bottleneck for tightening the sparsity gap is the training recipe (teacher-query-sample 25% vs the paper's full-query), not d_idx.
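The ratio argument is simple enough to check in a few lines. The helper names below are hypothetical and the code only restates the arithmetic from the paragraph above.

```python
def paper_equivalent_d_idx(head_dim, ratio=8):
    """Index width that keeps the paper's head_dim : d_idx ratio (8 : 1)."""
    return head_dim // ratio

def d_idx_in_paper_terms(d_idx, head_dim, paper_head_dim=128):
    """Rescale a d_idx chosen for head_dim to the paper's head_dim=128 setting."""
    return d_idx * paper_head_dim // head_dim

paper_ratio = 128 // 16                        # paper: head_dim=128, d_idx=16
this_model = paper_equivalent_d_idx(256)       # Qwen3.5-9B: head_dim=256
earlier_run = d_idx_in_paper_terms(16, 256)    # earlier d_idx=16 run on this model
```

Here `this_model` evaluates to 32 and `earlier_run` to 8, matching the figures quoted in the text.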

Files

  • indexer_final.pt: clean substate, 256 tensors of shape [32, 256], bf16
  • state_dict.pt: same content, raw state_dict format
  • stage1_summary.json: full per-head KL breakdown for all 8 GA layers
  • loss_curve.csv: columns step, loss, lr, distill_kl, agree_kl, entropy, tok/s

Loading

import torch
from train.surgeries._rtpurbo_indexer import RetrievalIndexer

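# state is a flat dict of 256 tensors keyed as q_heads.<layer_idx>.weight
state = torch.load("indexer_final.pt", map_location="cpu")

# Build the indexer with the dimensions used for this checkpoint, then load the weights
indexer = RetrievalIndexer(d_idx=32, head_dim=256)
indexer.load_state_dict(state)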