---
license: apache-2.0
language:
  - en
tags:
  - jax
  - flax
  - language-model
  - text-generation
  - retrieval-augmented
  - custom-architecture
  - research
library_name: flax
pipeline_tag: text-generation
model_type: dpsnr
datasets:
  - fineweb
metrics:
  - perplexity
inference: false
widget:
  - text: "The future of artificial intelligence"
    example_title: "AI Future"
  - text: "Once upon a time in a land"
    example_title: "Story"
  - text: "The key insight of this paper is"
    example_title: "Research"
model-index:
  - name: DPSNR-Large
    results: []
---

# DPSNR — Dynamic Parameter Selection Network with Reasoning

> **A JAX/Flax language model that separates *what it knows* from *how it thinks* — so the knowledge can grow to 100B+ vectors while inference stays fast and cheap.**

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1VM64IOZHj5rDvxWPbqktC037LlyOJih3?usp=sharing)

> [!WARNING]
> **Disclaimer**: This repository and checkpoint are provided as a **research proof-of-concept to demonstrate the novel DPSNR architecture**. It is an experimental model trained on a limited compute budget (for ~31,000 steps) to validate theoretical claims (such as $O(1)$ retrieval scaling, Sparse Adam optimizer speedups, and memory-bandwidth properties). **It is NOT a fully-trained competitive model** and is not intended to compete with state-of-the-art open-source text models (like LLaMA or Mistral) on downstream benchmarks.

---

## What Is DPSNR?

Normal large language models (GPT, Llama, etc.) mix logic and facts together inside the same transformer weights. When you want more knowledge, you need more parameters, which means more GPU VRAM, more compute, more cost — the **VRAM Wall**.

DPSNR breaks that wall. It splits the model into two parts:

| Part | Role | Size |
|------|------|------|
| **TinyController** | Does the thinking / reasoning | ~350M params on GPU |
| **CoordinateMassivePool** | Stores world knowledge as vectors | 262K–1T+ vectors, can live on disk |

The controller *queries* the pool each reasoning step instead of storing facts in its weights. Pool size can grow arbitrarily; inference cost stays **O(1)**.

---

## Architecture Overview

![DPSNR Architecture](assets/architecture.png)

The model has **4 components** that work together:

```mermaid
flowchart LR
    Input["Input Tokens"] --> TC

    subgraph TC["① TinyController"]
        direction TB
        E["Token + Position\nEmbedding"] --> TL["12× Transformer\nLayers (768-dim)"] --> H["Hidden States\n(B, T, 768)"]
    end

    TC --> LI

    subgraph LI["② LearnedIndexer"]
        direction TB
        AP["Attention Pooling\n(learn which token to query from)"] --> MH["Multi-Head Dense\n→ μ coordinate\n→ σ bandwidth"]
    end

    LI -->|"μ, σ"| Pool

    subgraph Pool["③ CoordinateMassivePool"]
        direction TB
        PS["262,144 × 768\nlearned vectors"] --> GW["Gaussian window\naround μ ± K vectors\nweighted by σ"] --> AV["Aggregated\nKnowledge Vector\n(B, 768)"]
    end

    Pool --> ACC

    subgraph ACC["④ Adaptive Compute Controller"]
        direction TB
        RI["Integrate knowledge\ninto hidden state"] --> HN["Halt Network\n(should we stop?)"]
        HN -->|"halt < 0.99"| RI
    end

    ACC -->|"Final hidden state"| Out["Output Logits\n(B, T, vocab)"]
    ACC -->|"loop back\n(up to 6 times)"| TC
```

---

## How the Reasoning Loop Works

Instead of doing one pass like most LLMs, DPSNR thinks iteratively — like a human reading and re-reading a hard problem.

![Reasoning Loop](assets/reasoning_loop.png)

Each loop:
1. **TinyController** encodes the input → produces a hidden state
2. **LearnedIndexer** converts the hidden state into a *coordinate* (μ) and *uncertainty* (σ)
3. **CoordinateMassivePool** retrieves K=32 knowledge vectors near μ, weighted by a Gaussian of width σ
4. Retrieved knowledge is fused into the hidden state
5. **ACC** decides: confident enough? → output. Unsure? → loop again

Simple questions finish in 1–2 loops. Hard questions use all 6. Compute is spent where it's needed.
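The loop above can be sketched in plain NumPy. This is a minimal illustration, not the repo's implementation — the indexer, fusion rule, and halting score below are hypothetical stand-ins for the learned modules:

```python
import numpy as np

rng = np.random.default_rng(0)
D, P, K, MAX_LOOPS = 768, 262_144, 32, 6

# Stand-in for the CoordinateMassivePool: a flat array of learned vectors.
pool = rng.standard_normal((P, D)).astype(np.float32)

def retrieve(mu, sigma):
    """Gaussian-weighted window of K pool vectors centred on coordinate mu."""
    center = int(mu * (P - K))            # map mu in (0, 1) onto a pool index
    window = pool[center:center + K]      # direct slice: no search over P
    offsets = np.arange(K) - (K - 1) / 2
    w = np.exp(-0.5 * (offsets / (sigma * K)) ** 2)
    return (w / w.sum()) @ window         # aggregated knowledge vector, shape (D,)

h = rng.standard_normal(D).astype(np.float32)       # hidden state from the controller
for step in range(MAX_LOOPS):
    mu = 1.0 / (1.0 + np.exp(-h.mean()))            # indexer stand-in: mu in (0, 1)
    sigma = 0.5                                     # fixed bandwidth for the sketch
    h = h + 0.1 * retrieve(mu, sigma)               # fuse retrieved knowledge
    halt = 1.0 / (1.0 + np.exp(-np.abs(h).mean()))  # halting score stand-in
    if halt > 0.99:                                 # ACC: stop once confident
        break
```

The key property to notice: `retrieve` touches exactly K vectors per call, so the loop's cost is independent of P.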

---

## Breaking the VRAM Wall

![VRAM Comparison](assets/vram_comparison.png)

A 70B dense model requires 80GB+ of expensive HBM VRAM just to load. Because DPSNR stores knowledge as a flat array of vectors (not entangled with transformer weights), the pool can live in:

- **System RAM** — 64GB of RAM holds ~21M vectors at 768-dim float32 (768 × 4 bytes ≈ 3KB per vector)
- **NVMe SSD** — mmap'd; only the retrieved window is paged in
- **GPU VRAM** — only the TinyController (~1.3GB at bf16) needs the GPU

```
Dense 70B:  [GPU|▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓ 80GB VRAM ▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓]
DPSNR:      [GPU|▓ 4GB]  +  [RAM|▓▓▓▓▓▓ Pool ▓▓▓▓▓▓]  ← no problem
```
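Back-of-envelope sizing for the off-device pool, assuming an uncompressed float32 layout. The `np.memmap` here is only a stand-in for whatever disk-backed storage the repo actually uses, but it shows the mechanism: the OS pages in only the window you touch:

```python
import numpy as np, os, tempfile

DIM = 768
per_vec_f32 = DIM * 4                        # 3,072 bytes per float32 vector
vecs_in_64gb = (64 * 10**9) // per_vec_f32   # ~20.8M vectors fit in 64 GB of RAM

# A disk-backed pool: reading a slice pages in only those rows.
path = os.path.join(tempfile.mkdtemp(), "pool.f32")
pool = np.memmap(path, dtype=np.float32, mode="w+", shape=(4096, DIM))
window = np.asarray(pool[1000:1000 + 32])    # read just K=32 vectors

print(f"{per_vec_f32} B/vector; {vecs_in_64gb / 1e6:.1f}M vectors per 64 GB")
```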

---

## Quick Start (Inference)

### 1. Activate the virtualenv

```bash
cd /path/to/dpsn
source .venv/bin/activate
```

### 2. Verify GPU is available

```bash
python -c "import jax; print(jax.devices())"
# β†’ [CudaDevice(id=0)]
```

### 3. Run inference

```bash
# Single prompt
python infer.py --prompt "The future of artificial intelligence"

# Interactive chat mode
python infer.py

# All options
python infer.py \
    --prompt    "Once upon a time" \
    --max_tokens 200 \
    --temp       0.8 \
    --top_k      50 \
    --penalty    1.3
```

The first run takes ~20–30s to JIT-compile the forward pass. Subsequent prompts in the same session are fast.

---

## Inference Script β€” `infer.py`

The file `infer.py` is **fully self-contained** — it has the entire model architecture, checkpoint loading, and generation logic in one file. No dependency on the `dpsn_r_jax` package.

```
infer.py
├── DPSNRConfig            ← Large model config, hardcoded
├── FlashCausalSelfAttention
├── TinyFFN / TinyTransformerLayer
├── TinyController         ← 12-layer transformer encoder + LM head
├── LearnedIndexer         ← μ, σ coordinate predictor
├── CoordinateMassivePool  ← 1D flat pool (used by large config)
├── CoordinateMassivePool2D ← 2D grid pool (use_2d_pool=True)
├── AdaptiveComputeController ← halt/loop decision
├── DPSNR                  ← full forward pass, reasoning scan
├── TrainState             ← pytree-compatible state for orbax restore
├── load_checkpoint()      ← restores params only (no optimizer bloat)
├── _forward()             ← @jax.jit compiled forward pass
└── generate()             ← autoregressive sampling, fixed-size buffers
```

### CLI arguments

| Argument | Default | Description |
|---|---|---|
| `--prompt` | None | Text prompt. Omit to enter interactive mode |
| `--max_tokens` | 100 | Maximum new tokens to generate |
| `--temp` | 0.7 | Sampling temperature. Lower = more focused |
| `--top_k` | 40 | Only sample from top-K most likely tokens |
| `--penalty` | 1.2 | Repetition penalty. >1 discourages repeats |
| `--checkpoint_dir` | `./checkpoints_dir` | Override checkpoint path |
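How `--temp`, `--top_k`, and `--penalty` interact can be sketched as a single sampling step. This is a hypothetical stand-alone function, not `infer.py`'s actual code; the repetition-penalty rule mirrors the common formulation (positive logits divided by the penalty, negative ones multiplied):

```python
import numpy as np

def sample_next(logits, prev_tokens, temp=0.7, top_k=40, penalty=1.2, seed=0):
    """One sampling step: repetition penalty, then temperature, then top-k."""
    logits = logits.astype(np.float64).copy()
    for t in set(prev_tokens):                      # discourage already-seen tokens
        logits[t] = logits[t] / penalty if logits[t] > 0 else logits[t] * penalty
    logits /= temp                                  # temperature scaling
    kth = np.sort(logits)[-top_k]                   # k-th largest logit
    logits[logits < kth] = -np.inf                  # top-k filter: drop the rest
    p = np.exp(logits - logits.max())               # stable softmax
    p /= p.sum()
    return int(np.random.default_rng(seed).choice(len(logits), p=p))

logits = np.random.default_rng(1).standard_normal(50257)
tok = sample_next(logits, prev_tokens=[42, 7])
```

Lowering `temp` sharpens the softmax toward greedy decoding; raising `top_k` widens the candidate set; `penalty > 1` suppresses tokens already in the context.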

---

## Model Configuration (Large)

The `large` config is hardcoded in `infer.py`:

```python
DPSNRConfig(
    vocab_size            = 50257,   # GPT-Neo tokenizer vocab
    controller_hidden_dim = 768,     # transformer width
    controller_num_layers = 12,      # transformer depth
    controller_num_heads  = 12,      # attention heads
    max_seq_len           = 1024,    # max context window
    pool_total_vectors    = 262144,  # 2^18 knowledge vectors
    pool_hidden_dim       = 768,     # vector dimension
    max_reasoning_loops   = 6,       # max iterations of the loop
)
```
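A quick sanity check on what this config implies for the pool's parameter count (simple arithmetic on the two pool fields above):

```python
pool_total_vectors = 262_144   # 2^18 knowledge vectors
pool_hidden_dim = 768          # vector dimension
pool_params = pool_total_vectors * pool_hidden_dim
print(f"{pool_params:,} pool parameters (~{pool_params / 1e6:.0f}M)")
```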

### Model size breakdown

```mermaid
pie title DPSNR Large — Parameter Distribution (~350M total)
    "CoordinateMassivePool (262K × 768)" : 201
    "TinyController (12L × 768d)" : 85
    "LearnedIndexer" : 3
    "AdaptiveComputeController" : 2
    "Retrieval Integrator" : 9
```

---

## Tokenizer

Uses the **`EleutherAI/gpt-neo-125M`** tokenizer — GPT-2 compatible BPE with 50,257 tokens. Downloaded automatically via HuggingFace on first use.

---


## Key Ideas Explained Simply

### Why the pool doesn't slow things down

Every retrieval fetches exactly `K=32` vectors regardless of pool size. Going from 10K to 100B pool vectors doesn't add a single FLOP — only the storage grows.

```mermaid
xychart-beta
    title "Inference Latency vs Pool Size"
    x-axis ["10K vectors", "100K", "262K", "1M", "1B", "100B"]
    y-axis "Relative Latency" 0 --> 2
    line [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
```

### Why Gaussian retrieval is better than nearest-neighbour

Nearest-neighbour lookup (as in a typical vector database) must compare the query against the pool, or maintain an index over it, so cost grows with pool size. DPSNR uses a **coordinate** approach instead: the pool is arranged in a continuous 1D (or 2D grid) space. The indexer predicts a *position* μ and *width* σ, and we simply slice a window. No search required — it's a direct lookup with `jax.lax.dynamic_slice`.

### Why σ matters

- **Small σ** → sharp, precise retrieval (good for exact facts, code syntax)
- **Large σ** → broad, averaged retrieval (good for general context)

σ is learned per token, per reasoning step — the model naturally figures out how precise to be.
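The effect of σ on a K=32 window can be seen directly. A minimal sketch — the exact normalisation is an assumption here, not taken from the repo:

```python
import numpy as np

K = 32
offsets = np.arange(K) - (K - 1) / 2        # positions relative to window centre

def gaussian_weights(sigma):
    w = np.exp(-0.5 * (offsets / sigma) ** 2)
    return w / w.sum()                      # normalised retrieval weights

sharp = gaussian_weights(1.0)    # small sigma: mass concentrated on a few vectors
broad = gaussian_weights(16.0)   # large sigma: spread almost evenly over the window
```

With σ=1 the centre vectors dominate; with σ=16 every vector in the window contributes roughly equally.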

---

## Performance

| Metric | Value | Notes |
|---|---|---|
| Training platform | TPU v5e-8 | 8-chip pod slice |
| Throughput | **240–250K tokens/sec** | HBM bandwidth bound |
| Sustained compute | **260–270 TFLOPS** | Below 393 TFLOPS peak |
| Bottleneck | Memory bandwidth | Pool gather ops, not MXU |
| Optimizer speedup vs dense | **590×** | Sparse Adam on retrieved indices only |
| Checkpoint step | 31,000 | |
| GPU VRAM (inference) | ~1.3GB (params only, bf16) | Pool can live off-device |
| Inference tested on | NVIDIA RTX 2050 (4GB) | Consumer GPU confirmed |

---

## Dependencies

```
jax + jaxlib   ← Core ML framework (GPU/TPU backend)
flax           ← Neural network layers and module API
optax          ← Optimizers (used for checkpoint structure only)
orbax          ← Checkpoint save/restore
transformers   ← Tokenizer (HuggingFace)
```

Install:
```bash
pip install jax jaxlib flax optax orbax-checkpoint transformers
```

---

## Citation / Reference

```
DPSNR: Dynamic Parameter Selection Network with Reasoning
Architecture: TinyController + CoordinateMassivePool + LearnedIndexer + ACC
Implementation: JAX/Flax
Checkpoint: step 31,000
```