DPSNR - Dynamic Parameter Selection Network with Reasoning
A JAX/Flax language model that separates what it knows from how it thinks, so the knowledge can grow to 100B+ vectors while inference stays fast and cheap.
Disclaimer: This repository and checkpoint are provided as a research proof-of-concept to demonstrate the novel DPSNR architecture. The checkpoint is an experimental model trained on a limited compute budget (for ~31,000 steps) to validate theoretical claims (such as $O(1)$ retrieval scaling, Sparse Adam optimizer speedups, and memory-bandwidth properties). It is NOT a fully trained, competitive model and is not intended to compete with state-of-the-art open-source text models (such as LLaMA or Mistral) on downstream benchmarks.
What Is DPSNR?
Normal large language models (GPT, Llama, etc.) mix logic and facts together inside the same transformer weights. When you want more knowledge, you need more parameters, which means more GPU VRAM, more compute, more cost: the VRAM Wall.
DPSNR breaks that wall. It splits the model into two parts:
| Part | Role | Size |
|---|---|---|
| TinyController | Does the thinking / reasoning | ~350M params on GPU |
| CoordinateMassivePool | Stores world knowledge as vectors | 262K to 1T+ vectors, can live on disk |
The controller queries the pool each reasoning step instead of storing facts in its weights. Pool size can grow arbitrarily; inference cost stays O(1).
Architecture Overview
The model has 4 components that work together:
```mermaid
flowchart LR
    Input["Input Tokens"] --> TC
    subgraph TC["① TinyController"]
        direction TB
        E["Token + Position\nEmbedding"] --> TL["12× Transformer\nLayers (768-dim)"] --> H["Hidden States\n(B, T, 768)"]
    end
    TC --> LI
    subgraph LI["② LearnedIndexer"]
        direction TB
        AP["Attention Pooling\n(learn which token to query from)"] --> MH["Multi-Head Dense\n→ μ coordinate\n→ σ bandwidth"]
    end
    LI -->|"μ, σ"| Pool
    subgraph Pool["③ CoordinateMassivePool"]
        direction TB
        PS["262,144 × 768\nlearned vectors"] --> GW["Gaussian window\naround μ ± K vectors\nweighted by σ"] --> AV["Aggregated\nKnowledge Vector\n(B, 768)"]
    end
    Pool --> ACC
    subgraph ACC["④ Adaptive Compute Controller"]
        direction TB
        RI["Integrate knowledge\ninto hidden state"] --> HN["Halt Network\n(should we stop?)"]
        HN -->|"halt < 0.99"| RI
    end
    ACC -->|"Final hidden state"| Out["Output Logits\n(B, T, vocab)"]
    ACC -->|"loop back\n(up to 6 times)"| TC
```
How the Reasoning Loop Works
Instead of doing one pass like most LLMs, DPSNR thinks iteratively, like a human reading and re-reading a hard problem.
Each loop:
- TinyController encodes the input and produces a hidden state
- LearnedIndexer converts the hidden state into a coordinate (μ) and an uncertainty (σ)
- CoordinateMassivePool retrieves K=32 knowledge vectors near μ, weighted by a Gaussian of width σ
- Retrieved knowledge is fused into the hidden state
- ACC decides: confident enough? → output. Unsure? → loop again
Simple questions finish in 1β2 loops. Hard questions use all 6. Compute is spent where it's needed.
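The loop above can be sketched in a few lines of plain Python. Here `retrieve` and `halt_prob` are hypothetical stand-ins for the LearnedIndexer + pool lookup and the ACC's halt network, not the repository's actual functions:

```python
import numpy as np

def reasoning_loop(hidden, retrieve, halt_prob, max_loops=6, threshold=0.99):
    """Hedged sketch of DPSNR's adaptive loop: retrieve, fuse, and stop
    early once the halt network is confident enough."""
    for step in range(1, max_loops + 1):
        knowledge = retrieve(hidden)        # fixed-K lookup, O(1) in pool size
        hidden = hidden + knowledge         # simple residual fusion
        if halt_prob(hidden) >= threshold:  # confident -> emit output early
            break
    return hidden, step

# Toy run: halt confidence grows with the hidden-state norm but never
# clears the 0.99 threshold, so all 6 loops are used.
h, steps = reasoning_loop(
    np.ones(4),
    retrieve=lambda h: 0.5 * h,
    halt_prob=lambda h: 1 - 1 / (1 + np.linalg.norm(h)),
)
```

A confident halt network would exit after one or two iterations, which is exactly how simple prompts get cheaper than hard ones.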
Breaking the VRAM Wall
A 70B dense model requires 80GB+ of expensive HBM VRAM just to load. Because DPSNR stores knowledge as a flat array of vectors (not entangled with transformer weights), the pool can live in:
- System RAM: 64GB of RAM holds ~20M vectors × 768-dim at float32 (768 × 4 bytes ≈ 3KB per vector)
- NVMe SSD: mmap'd; only the retrieved window is paged in
- GPU VRAM: only the TinyController (~1.3GB at bf16) needs the GPU
```
Dense 70B: [GPU|████████████████ 80GB VRAM ████████████████]
DPSNR:     [GPU|~4GB] + [RAM|██████ Pool ██████]  ← no problem
```
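The off-device idea can be sketched with a memory-mapped numpy array (a small toy pool here; the real config uses 262,144 vectors). Retrieval only pages the K-row window into RAM, never the whole file:

```python
import numpy as np, os, tempfile

DIM, POOL, K = 768, 8_192, 32  # toy pool size for the sketch

# Create a toy on-disk pool (float32), then reopen it memory-mapped.
path = os.path.join(tempfile.mkdtemp(), "pool.f32")
np.zeros((POOL, DIM), dtype=np.float32).tofile(path)
pool = np.memmap(path, dtype=np.float32, mode="r", shape=(POOL, DIM))

# Retrieval touches only a K-row window around the predicted coordinate mu,
# so the OS pages in roughly K * DIM * 4 bytes, not the whole pool.
mu = 1_234
start = min(max(mu - K // 2, 0), POOL - K)
window = np.asarray(pool[start:start + K])  # (K, DIM) copy in RAM
```

This is why the pool's footprint is a storage problem rather than a VRAM problem: the per-query working set is constant regardless of how large the file grows.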
Quick Start (Inference)
1. Activate the virtualenv

```shell
cd /path/to/dpsn
source .venv/bin/activate
```

2. Verify GPU is available

```shell
python -c "import jax; print(jax.devices())"
# → [CudaDevice(id=0)]
```

3. Run inference

```shell
# Single prompt
python infer.py --prompt "The future of artificial intelligence"

# Interactive chat mode
python infer.py

# All options
python infer.py \
  --prompt "Once upon a time" \
  --max_tokens 200 \
  --temp 0.8 \
  --top_k 50 \
  --penalty 1.3
```
The first run takes ~20β30s to JIT-compile the forward pass. Subsequent prompts in the same session are fast.
Inference Script: infer.py
The file infer.py is fully self-contained: it holds the entire model architecture, checkpoint loading, and generation logic in one file, with no dependency on the dpsn_r_jax package.
```
infer.py
├── DPSNRConfig               - Large model config, hardcoded
├── FlashCausalSelfAttention
├── TinyFFN / TinyTransformerLayer
├── TinyController            - 12-layer transformer encoder + LM head
├── LearnedIndexer            - μ, σ coordinate predictor
├── CoordinateMassivePool     - 1D flat pool (used by large config)
├── CoordinateMassivePool2D   - 2D grid pool (use_2d_pool=True)
├── AdaptiveComputeController - halt/loop decision
├── DPSNR                     - full forward pass, reasoning scan
├── TrainState                - pytree-compatible state for orbax restore
├── load_checkpoint()         - restores params only (no optimizer bloat)
├── _forward()                - @jax.jit compiled forward pass
└── generate()                - autoregressive sampling, fixed-size buffers
```
CLI arguments
| Argument | Default | Description |
|---|---|---|
| `--prompt` | None | Text prompt. Omit to enter interactive mode |
| `--max_tokens` | 100 | Maximum new tokens to generate |
| `--temp` | 0.7 | Sampling temperature. Lower = more focused |
| `--top_k` | 40 | Only sample from the top-K most likely tokens |
| `--penalty` | 1.2 | Repetition penalty. Values > 1 discourage repeats |
| `--checkpoint_dir` | `./checkpoints_dir` | Override checkpoint path |
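A hedged sketch of how these three sampling knobs typically shape next-token selection; infer.py's exact implementation may differ:

```python
import numpy as np

def next_token_probs(logits, generated, temp=0.7, top_k=40, penalty=1.2):
    """Illustrative sampling pipeline: repetition penalty, then temperature,
    then top-k filtering, then softmax. Not infer.py's exact code."""
    logits = logits.astype(np.float64).copy()
    # Repetition penalty: push down logits of already-generated tokens.
    for t in set(generated):
        logits[t] = logits[t] / penalty if logits[t] > 0 else logits[t] * penalty
    logits = logits / temp                 # temperature: < 1 sharpens the distribution
    kth = np.sort(logits)[-top_k]          # top-k: keep only the k largest logits
    logits[logits < kth] = -np.inf
    probs = np.exp(logits - logits.max())  # stable softmax over the survivors
    return probs / probs.sum()

# Token 99 had the highest logit, but the penalty knocks it out of the top 5.
probs = next_token_probs(np.linspace(0.0, 1.0, 100), generated=[99], top_k=5)
```

The order of operations (penalty before temperature before top-k) is a common convention, stated here as an assumption rather than a fact about this repository.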
Model Configuration (Large)
The large config is hardcoded in infer.py:
```python
DPSNRConfig(
    vocab_size = 50257,             # GPT-Neo tokenizer vocab
    controller_hidden_dim = 768,    # transformer width
    controller_num_layers = 12,     # transformer depth
    controller_num_heads = 12,      # attention heads
    max_seq_len = 1024,             # max context window
    pool_total_vectors = 262144,    # 2^18 knowledge vectors
    pool_hidden_dim = 768,          # vector dimension
    max_reasoning_loops = 6,        # max iterations of the loop
)
```
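A quick arithmetic sanity check on the pool implied by this config:

```python
# 262,144 vectors x 768 dims: the pool's parameter count and its raw
# memory footprint at float32 and bf16.
pool_params = 262_144 * 768
pool_f32_mib = pool_params * 4 / 2**20   # 4 bytes per float32
pool_bf16_mib = pool_params * 2 / 2**20  # 2 bytes per bf16

print(pool_params)    # -> 201326592 (~201M, matching the pie chart below)
print(pool_f32_mib)   # -> 768.0 MiB
print(pool_bf16_mib)  # -> 384.0 MiB
```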
Model size breakdown
```mermaid
pie title DPSNR Large - Parameter Distribution (~350M total)
    "CoordinateMassivePool (262K × 768)" : 201
    "TinyController (12L × 768d)" : 85
    "LearnedIndexer" : 3
    "AdaptiveComputeController" : 2
    "Retrieval Integrator" : 9
```
Tokenizer
Uses the EleutherAI/gpt-neo-125M tokenizer, a GPT-2-compatible BPE with 50,257 tokens. Downloaded automatically via HuggingFace on first use.
Key Ideas Explained Simply
Why the pool doesn't slow things down
Every retrieval fetches exactly K=32 vectors regardless of pool size. Going from 10K to 100B pool vectors doesn't add a single FLOP; only the storage grows.
```mermaid
xychart-beta
    title "Inference Latency vs Pool Size"
    x-axis ["10K vectors", "100K", "262K", "1M", "1B", "100B"]
    y-axis "Relative Latency" 0 --> 2
    line [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
```
Why Gaussian retrieval is better than nearest-neighbour
Nearest-neighbour lookup (like a typical vector database) must search the entire pool. DPSNR uses a coordinate approach: the pool is arranged in a continuous 1D (or 2D grid) space. The indexer predicts a position ΞΌ and width Ο, and we simply slice a window. No search required β it's a direct lookup with jax.lax.dynamic_slice.
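The whole retrieval step can be sketched in a few lines of JAX. A toy pool size is used, and the function name is illustrative rather than the repository's API:

```python
import jax
import jax.numpy as jnp

K, DIM, POOL = 32, 768, 4_096  # toy pool; the large config uses 262,144

def gaussian_window_retrieve(pool, mu, sigma):
    """Sketch of coordinate retrieval: slice K contiguous vectors around mu
    and average them under a Gaussian of width sigma. No search over the
    pool happens; it is a direct jax.lax.dynamic_slice."""
    start = jnp.clip(jnp.round(mu - K // 2).astype(jnp.int32), 0, POOL - K)
    window = jax.lax.dynamic_slice(pool, (start, 0), (K, DIM))  # (K, DIM)
    positions = start + jnp.arange(K)
    w = jnp.exp(-0.5 * ((positions - mu) / sigma) ** 2)         # Gaussian weights
    w = w / w.sum()
    return (w[:, None] * window).sum(axis=0)                    # (DIM,) aggregate

pool = jnp.ones((POOL, DIM))
vec = gaussian_window_retrieve(pool, mu=1234.5, sigma=4.0)
```

Because `dynamic_slice` has a static `(K, DIM)` slice shape, the cost is fixed no matter how many vectors the pool holds, which is the O(1) claim in concrete form.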
Why σ matters
- Small σ → sharp, precise retrieval (good for exact facts, code syntax)
- Large σ → broad, averaged retrieval (good for general context)
σ is learned per token, per reasoning step; the model naturally figures out how precise to be.
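A tiny numeric illustration of the effect of σ over a K=32 retrieval window:

```python
import numpy as np

# Offsets of the K=32 window positions relative to the predicted coordinate mu.
offsets = np.arange(-16, 16)

def gaussian_weights(sigma):
    w = np.exp(-0.5 * (offsets / sigma) ** 2)
    return w / w.sum()

sharp = gaussian_weights(1.0)  # small sigma: mass concentrated on a few vectors
broad = gaussian_weights(8.0)  # large sigma: mass spread across the window

# The sharp window puts several times more mass on its single best vector.
ratio = sharp.max() / broad.max()
```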
Performance
| Metric | Value | Notes |
|---|---|---|
| Training platform | TPU v5e-8 | 8-chip pod slice |
| Throughput | 240–250K tokens/sec | HBM bandwidth bound |
| Sustained compute | 260–270 TFLOPS | Below 393 TFLOPS peak |
| Bottleneck | Memory bandwidth | Pool gather ops, not MXU |
| Optimizer speedup vs dense | 590× | Sparse Adam on retrieved indices only |
| Checkpoint step | 31,000 | |
| GPU VRAM (inference) | ~1.3GB (params only, bf16) | Pool can live off-device |
| Inference tested on | NVIDIA RTX 2050 (4GB) | Consumer GPU confirmed |
Dependencies
- jax + jaxlib - Core ML framework (GPU/TPU backend)
- flax - Neural network layers and module API
- optax - Optimizers (used for checkpoint structure only)
- orbax - Checkpoint save/restore
- transformers - Tokenizer (HuggingFace)
Install:

```shell
pip install jax jaxlib flax optax orbax-checkpoint transformers
```
Citation / Reference
DPSNR: Dynamic Parameter Selection Network with Reasoning
Architecture: TinyController + CoordinateMassivePool + LearnedIndexer + ACC
Implementation: JAX/Flax
Checkpoint: step 31,000


