---
license: apache-2.0
language:
  - en
tags:
  - jax
  - flax
  - language-model
  - text-generation
  - retrieval-augmented
  - custom-architecture
  - research
library_name: flax
pipeline_tag: text-generation
model_type: dpsnr
datasets:
  - fineweb
metrics:
  - perplexity
inference: false
widget:
  - text: The future of artificial intelligence
    example_title: AI Future
  - text: Once upon a time in a land
    example_title: Story
  - text: The key insight of this paper is
    example_title: Research
model-index:
  - name: DPSNR-Large
    results: []
---

# DPSNR: Dynamic Parameter Selection Network with Reasoning

A JAX/Flax language model that separates what it knows from how it thinks, so the knowledge can grow to 100B+ vectors while inference stays fast and cheap.


## What Is DPSNR?

Normal large language models (GPT, Llama, etc.) mix logic and facts together inside the same transformer weights. When you want more knowledge, you need more parameters, which means more GPU VRAM, more compute, more cost: the VRAM Wall.

DPSNR breaks that wall. It splits the model into two parts:

| Part | Role | Size |
|---|---|---|
| TinyController | Does the thinking / reasoning | ~350M params on GPU |
| CoordinateMassivePool | Stores world knowledge as vectors | 262K–1T+ vectors, can live on disk |

The controller queries the pool each reasoning step instead of storing facts in its weights. Pool size can grow arbitrarily; inference cost stays O(1).


## Architecture Overview

DPSNR Architecture

The model has 4 components that work together:

```mermaid
flowchart LR
    Input["Input Tokens"] --> TC

    subgraph TC["① TinyController"]
        direction TB
        E["Token + Position\nEmbedding"] --> TL["12× Transformer\nLayers (768-dim)"] --> H["Hidden States\n(B, T, 768)"]
    end

    TC --> LI

    subgraph LI["② LearnedIndexer"]
        direction TB
        AP["Attention Pooling\n(learn which token to query from)"] --> MH["Multi-Head Dense\n→ μ coordinate\n→ σ bandwidth"]
    end

    LI -->|"μ, σ"| Pool

    subgraph Pool["③ CoordinateMassivePool"]
        direction TB
        PS["262,144 × 768\nlearned vectors"] --> GW["Gaussian window\naround μ ± K vectors\nweighted by σ"] --> AV["Aggregated\nKnowledge Vector\n(B, 768)"]
    end

    Pool --> ACC

    subgraph ACC["④ Adaptive Compute Controller"]
        direction TB
        RI["Integrate knowledge\ninto hidden state"] --> HN["Halt Network\n(should we stop?)"]
        HN -->|"halt < 0.99"| RI
    end

    ACC -->|"Final hidden state"| Out["Output Logits\n(B, T, vocab)"]
    ACC -->|"loop back\n(up to 6 times)"| TC
```

## How the Reasoning Loop Works

Instead of doing one pass like most LLMs, DPSNR thinks iteratively, like a human reading and re-reading a hard problem.

Reasoning Loop

Each loop:

1. **TinyController** encodes the input → produces a hidden state
2. **LearnedIndexer** converts the hidden state into a coordinate (μ) and uncertainty (σ)
3. **CoordinateMassivePool** retrieves K=32 knowledge vectors near μ, weighted by a Gaussian of width σ
4. Retrieved knowledge is fused into the hidden state
5. **ACC** decides: confident enough? → output. Unsure? → loop again

Simple questions finish in 1–2 loops. Hard questions use all 6. Compute is spent where it's needed.
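The loop above can be sketched in plain numpy (the real model runs jax.numpy inside a scan). Everything here is an illustrative stand-in: the `indexer` and halt rule are hypothetical, and the pool is shrunk from 262,144 to 4,096 vectors to keep the demo light.

```python
import numpy as np

rng = np.random.default_rng(0)

HIDDEN, POOL_SIZE, K, MAX_LOOPS = 768, 4_096, 32, 6
pool = rng.standard_normal((POOL_SIZE, HIDDEN)).astype(np.float32)

def indexer(h):
    # Hypothetical stand-in for LearnedIndexer: map a hidden state to (mu, sigma).
    mu = int(abs(float(h.sum())) * 1000) % (POOL_SIZE - K)
    sigma = 1.0 + float(h.std())
    return mu, sigma

def retrieve(mu, sigma):
    # Gaussian-weighted window of K pool vectors around mu: a direct slice, no search.
    window = pool[mu : mu + K]
    offsets = np.arange(K) - (K - 1) / 2
    w = np.exp(-0.5 * (offsets / sigma) ** 2)
    w /= w.sum()
    return (w[:, None] * window).sum(axis=0)

h = rng.standard_normal(HIDDEN).astype(np.float32)
loops_used = 0
for step in range(MAX_LOOPS):
    mu, sigma = indexer(h)
    h = 0.9 * h + 0.1 * retrieve(mu, sigma)  # fuse knowledge into the hidden state
    loops_used += 1
    halt = 1.0 / (1.0 + np.exp(-(step + abs(float(h.mean())))))  # stand-in halt net
    if halt > 0.99:
        break
```

Note how pool size only appears in the slice bound, never in the per-step arithmetic: that is the source of the O(1) inference cost.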


## Breaking the VRAM Wall

VRAM Comparison

A 70B dense model requires 80GB+ of expensive HBM VRAM just to load. Because DPSNR stores knowledge as a flat array of vectors (not entangled with transformer weights), the pool can live in:

- **System RAM**: 64GB of RAM holds ~21M vectors × 768-dim at float32
- **NVMe SSD**: mmap'd; only the retrieved window is paged in
- **GPU VRAM**: only the TinyController (~1.3GB at bf16) needs the GPU

```
Dense 70B:  [GPU|▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓ 80GB VRAM ▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓]
DPSNR:      [GPU|▓ 4GB]  +  [RAM|▓▓▓▓▓▓ Pool ▓▓▓▓▓▓]  ← no problem
```
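As a minimal sketch of the SSD option, assuming the pool is stored as a flat float32 file (numpy stands in for the JAX side; the file name and pool size here are illustrative, shrunk from 262,144 vectors):

```python
import os
import tempfile
import numpy as np

HIDDEN, POOL_SIZE, K = 768, 4_096, 32  # demo sizes, not the real config

# Write a disk-backed pool file (standing in for a trained pool checkpoint).
path = os.path.join(tempfile.mkdtemp(), "pool.f32")
np.memmap(path, dtype=np.float32, mode="w+", shape=(POOL_SIZE, HIDDEN)).flush()

# At inference time, mmap the pool read-only: the OS pages in only what is touched.
pool = np.memmap(path, dtype=np.float32, mode="r", shape=(POOL_SIZE, HIDDEN))
mu = 1_234
window = np.asarray(pool[mu : mu + K])  # copies just K x HIDDEN floats into RAM
```

Because retrieval is a contiguous window rather than a scattered gather, each step touches at most a couple of OS pages regardless of how large the file grows.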

## Quick Start (Inference)

**1. Activate the virtualenv**

```bash
cd /path/to/dpsn
source .venv/bin/activate
```

**2. Verify GPU is available**

```bash
python -c "import jax; print(jax.devices())"
# → [CudaDevice(id=0)]
```

**3. Run inference**

```bash
# Single prompt
python infer.py --prompt "The future of artificial intelligence"

# Interactive chat mode
python infer.py

# All options
python infer.py \
    --prompt    "Once upon a time" \
    --max_tokens 200 \
    --temp       0.8 \
    --top_k      50 \
    --penalty    1.3
```

The first run takes ~20–30s to JIT-compile the forward pass. Subsequent prompts in the same session are fast.


## Inference Script: `infer.py`

The file `infer.py` is fully self-contained: the entire model architecture, checkpoint loading, and generation logic live in one file, with no dependency on the `dpsn_r_jax` package.

```
infer.py
├── DPSNRConfig               ← Large model config, hardcoded
├── FlashCausalSelfAttention
├── TinyFFN / TinyTransformerLayer
├── TinyController            ← 12-layer transformer encoder + LM head
├── LearnedIndexer            ← μ, σ coordinate predictor
├── CoordinateMassivePool     ← 1D flat pool (used by large config)
├── CoordinateMassivePool2D   ← 2D grid pool (use_2d_pool=True)
├── AdaptiveComputeController ← halt/loop decision
├── DPSNR                     ← full forward pass, reasoning scan
├── TrainState                ← pytree-compatible state for orbax restore
├── load_checkpoint()         ← restores params only (no optimizer bloat)
├── _forward()                ← @jax.jit compiled forward pass
└── generate()                ← autoregressive sampling, fixed-size buffers
```

### CLI arguments

| Argument | Default | Description |
|---|---|---|
| `--prompt` | None | Text prompt. Omit to enter interactive mode |
| `--max_tokens` | 100 | Maximum new tokens to generate |
| `--temp` | 0.7 | Sampling temperature. Lower = more focused |
| `--top_k` | 40 | Only sample from the top-K most likely tokens |
| `--penalty` | 1.2 | Repetition penalty. Values >1 discourage repeats |
| `--checkpoint_dir` | `./checkpoints_dir` | Override checkpoint path |

## Model Configuration (Large)

The large config is hardcoded in `infer.py`:

```python
DPSNRConfig(
    vocab_size            = 50257,   # GPT-Neo tokenizer vocab
    controller_hidden_dim = 768,     # transformer width
    controller_num_layers = 12,      # transformer depth
    controller_num_heads  = 12,      # attention heads
    max_seq_len           = 1024,    # max context window
    pool_total_vectors    = 262144,  # 2^18 knowledge vectors
    pool_hidden_dim       = 768,     # vector dimension
    max_reasoning_loops   = 6,       # max iterations of the loop
)
```

### Model size breakdown

```mermaid
pie title DPSNR Large - Parameter Distribution (~350M total)
    "CoordinateMassivePool (262K × 768)" : 201
    "TinyController (12L × 768d)" : 85
    "LearnedIndexer" : 3
    "AdaptiveComputeController" : 2
    "Retrieval Integrator" : 9
```
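The pool's share of the breakdown is just its shape multiplied out:

```python
# Parameter count of the pool alone: 262,144 vectors x 768 dims each.
pool_params = 262_144 * 768
assert pool_params == 201_326_592  # ~201M learned pool parameters
```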

## Tokenizer

Uses the EleutherAI/gpt-neo-125M tokenizer, a GPT-2-compatible BPE with 50,257 tokens. Downloaded automatically via HuggingFace on first use.


## Key Ideas Explained Simply

### Why the pool doesn't slow things down

Every retrieval fetches exactly K=32 vectors regardless of pool size. Going from 10K to 100B pool vectors doesn't add a single FLOP; only the storage grows.

```mermaid
xychart-beta
    title "Inference Latency vs Pool Size"
    x-axis ["10K vectors", "100K", "262K", "1M", "1B", "100B"]
    y-axis "Relative Latency" 0 --> 2
    line [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
```
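A tiny numpy demonstration of the constant cost (pool sizes here are illustrative): slicing a K-vector window moves the same number of bytes whether the pool is small or large.

```python
import numpy as np

K, HIDDEN = 32, 768

small = np.zeros((10_000, HIDDEN), dtype=np.float32)
large = np.zeros((100_000, HIDDEN), dtype=np.float32)

# A retrieval touches exactly K rows in either pool: same bytes moved, same FLOPs.
w_small = small[5_000 : 5_000 + K]
w_large = large[50_000 : 50_000 + K]
assert w_small.nbytes == w_large.nbytes == K * HIDDEN * 4
```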

### Why Gaussian retrieval is better than nearest-neighbour

Nearest-neighbour lookup (as in a typical vector database) must search the entire pool. DPSNR uses a coordinate approach instead: the pool is arranged in a continuous 1D (or 2D grid) space, the indexer predicts a position μ and width σ, and we simply slice a window. No search required: it's a direct lookup with `jax.lax.dynamic_slice`.

### Why σ matters

- **Small σ** → sharp, precise retrieval (good for exact facts, code syntax)
- **Large σ** → broad, averaged retrieval (good for general context)

σ is learned per token, per reasoning step; the model naturally figures out how precise to be.
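The effect of σ on the retrieval window can be seen directly in the normalized Gaussian weights (a small numpy sketch; the weighting shape matches the description above, not the exact model code):

```python
import numpy as np

def gaussian_window_weights(K, sigma):
    # Normalized Gaussian weights over a window of K consecutive pool vectors.
    offsets = np.arange(K) - (K - 1) / 2
    w = np.exp(-0.5 * (offsets / sigma) ** 2)
    return w / w.sum()

sharp = gaussian_window_weights(32, sigma=1.0)   # precise: mass on the centre vectors
broad = gaussian_window_weights(32, sigma=16.0)  # broad: close to uniform averaging

assert sharp.max() > broad.max()  # small sigma concentrates the retrieval
```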


## Performance

| Metric | Value | Notes |
|---|---|---|
| Training platform | TPU v5e-8 | 8-chip pod slice |
| Throughput | 240–250K tokens/sec | HBM bandwidth bound |
| Sustained compute | 260–270 TFLOPS | Below 393 TFLOPS peak |
| Bottleneck | Memory bandwidth | Pool gather ops, not MXU |
| Optimizer speedup vs dense | 590× | Sparse Adam on retrieved indices only |
| Checkpoint step | 31,000 | |
| GPU VRAM (inference) | ~1.3GB (params only, bf16) | Pool can live off-device |
| Inference tested on | NVIDIA RTX 2050 (4GB) | Consumer GPU confirmed |

## Dependencies

```
jax + jaxlib   ← Core ML framework (GPU/TPU backend)
flax           ← Neural network layers and module API
optax          ← Optimizers (used for checkpoint structure only)
orbax          ← Checkpoint save/restore
transformers   ← Tokenizer (HuggingFace)
```

Install:

```bash
pip install jax jaxlib flax optax orbax-checkpoint transformers
```

## Citation / Reference

**DPSNR: Dynamic Parameter Selection Network with Reasoning**

- Architecture: TinyController + CoordinateMassivePool + LearnedIndexer + ACC
- Implementation: JAX/Flax
- Checkpoint: step 31,000