the-dev-kumar committed
Commit c7f839a · verified · 1 Parent(s): e32098d

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,12 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/architecture.png filter=lfs diff=lfs merge=lfs -text
+ assets/reasoning_loop.png filter=lfs diff=lfs merge=lfs -text
+ assets/vram_comparison.png filter=lfs diff=lfs merge=lfs -text
+ default/d/40b6413d8d045553235110cf8a3113dc filter=lfs diff=lfs merge=lfs -text
+ default/ocdbt.process_0/d/2655f27744aa28bc57a54732ca8aa17f filter=lfs diff=lfs merge=lfs -text
+ default/ocdbt.process_0/d/309bfd1f96632d6760dd55dea979babf filter=lfs diff=lfs merge=lfs -text
+ default/ocdbt.process_0/d/37b57931b1bb0df657d81dd245946279 filter=lfs diff=lfs merge=lfs -text
+ default/ocdbt.process_0/d/7d3a8dd28172f4fc4fe186eaa73f2843 filter=lfs diff=lfs merge=lfs -text
+ default/ocdbt.process_0/d/cca1e2cb6509e2bae33156603f3ff2de filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,306 @@
- ---
- license: apache-2.0
- ---
+ ---
+ license: apache-2.0
+ language:
+ - en
+ tags:
+ - jax
+ - flax
+ - language-model
+ - text-generation
+ - retrieval-augmented
+ - custom-architecture
+ - research
+ library_name: flax
+ pipeline_tag: text-generation
+ model_type: dpsnr
+ datasets:
+ - fineweb
+ metrics:
+ - perplexity
+ inference: false
+ widget:
+ - text: "The future of artificial intelligence"
+   example_title: "AI Future"
+ - text: "Once upon a time in a land"
+   example_title: "Story"
+ - text: "The key insight of this paper is"
+   example_title: "Research"
+ model-index:
+ - name: DPSNR-Large
+   results: []
+ ---
+
+ # DPSNR: Dynamic Parameter Selection Network with Reasoning
+
+ > **A JAX/Flax language model that separates *what it knows* from *how it thinks*, so the knowledge can grow to 100B+ vectors while inference stays fast and cheap.**
+
+ ---
+
+ ## What Is DPSNR?
+
+ Conventional large language models (GPT, Llama, etc.) store reasoning ability and factual knowledge in the same transformer weights. Adding knowledge means adding parameters, which means more GPU VRAM, more compute, and more cost: the **VRAM Wall**.
+
+ DPSNR breaks that wall by splitting the model into two parts:
+
+ | Part | Role | Size |
+ |------|------|------|
+ | **TinyController** | Does the thinking / reasoning | ~350M params on GPU |
+ | **CoordinateMassivePool** | Stores world knowledge as vectors | 262K–1T+ vectors, can live on disk |
+
+ The controller *queries* the pool at each reasoning step instead of storing facts in its weights. The pool can grow arbitrarily, while the per-step retrieval cost stays **O(1)** in pool size because each query reads a fixed number of vectors.
+
+ ---
+
+ ## Architecture Overview
+
+ ![DPSNR Architecture](assets/architecture.png)
+
+ The model has **4 components** that work together:
+
+ ```mermaid
+ flowchart LR
+     Input["🗒️ Input Tokens"] --> TC
+
+     subgraph TC["① TinyController"]
+         direction TB
+         E["Token + Position\nEmbedding"] --> TL["12× Transformer\nLayers (768-dim)"] --> H["Hidden States\n(B, T, 768)"]
+     end
+
+     TC --> LI
+
+     subgraph LI["② LearnedIndexer"]
+         direction TB
+         AP["Attention Pooling\n(learn which token to query from)"] --> MH["Multi-Head Dense\n→ μ coordinate\n→ σ bandwidth"]
+     end
+
+     LI -->|"μ, σ"| Pool
+
+     subgraph Pool["③ CoordinateMassivePool"]
+         direction TB
+         PS["262,144 × 768\nlearned vectors"] --> GW["Gaussian window\naround μ ± K vectors\nweighted by σ"] --> AV["Aggregated\nKnowledge Vector\n(B, 768)"]
+     end
+
+     Pool --> ACC
+
+     subgraph ACC["④ Adaptive Compute Controller"]
+         direction TB
+         RI["Integrate knowledge\ninto hidden state"] --> HN["Halt Network\n(should we stop?)"]
+         HN -->|"halt < 0.99"| RI
+     end
+
+     ACC -->|"Final hidden state"| Out["📝 Output Logits\n(B, T, vocab)"]
+     ACC -->|"loop back\n(up to 6 times)"| TC
+ ```
+
+ ---
+
+ ## How the Reasoning Loop Works
+
+ Instead of making a single forward pass like most LLMs, DPSNR thinks iteratively, like a human reading and re-reading a hard problem.
+
+ ![Reasoning Loop](assets/reasoning_loop.png)
+
+ Each loop:
+ 1. **TinyController** encodes the input → produces a hidden state
+ 2. **LearnedIndexer** converts the hidden state into a *coordinate* (μ) and *uncertainty* (σ)
+ 3. **CoordinateMassivePool** retrieves K=32 knowledge vectors near μ, weighted by a Gaussian of width σ
+ 4. Retrieved knowledge is fused into the hidden state
+ 5. **ACC** (Adaptive Compute Controller) decides: confident enough? → output. Unsure? → loop again
+
+ Simple questions finish in 1–2 loops. Hard questions use all 6. Compute is spent where it's needed.
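
The five steps above can be sketched as a small control loop. This is a minimal, dependency-free Python illustration rather than the model's code: `reasoning_loop`, `toy_retrieve`, and `toy_halt` are hypothetical names, the fusion step is assumed to be a plain residual add, and the `0.99` threshold and 6-loop cap are taken from the diagram and config in this README. A JAX implementation would more likely run a fixed-length scan with masking, since `jit` prefers static loop counts.

```python
import math
from typing import Callable, List, Tuple

Vector = List[float]


def reasoning_loop(
    hidden: Vector,
    retrieve: Callable[[Vector], Vector],
    halt_prob: Callable[[Vector], float],
    max_loops: int = 6,
    halt_threshold: float = 0.99,
) -> Tuple[Vector, int]:
    """Refine `hidden` until the halt score reaches the threshold or loops run out."""
    loops_used = 0
    for _ in range(max_loops):
        knowledge = retrieve(hidden)  # steps 2-3: predict a coordinate, read the pool
        hidden = [h + k for h, k in zip(hidden, knowledge)]  # step 4: fuse (assumed residual add)
        loops_used += 1
        if halt_prob(hidden) >= halt_threshold:  # step 5: confident enough to stop?
            break
    return hidden, loops_used


# Toy stand-ins so the sketch runs end to end (not the real indexer or halt network).
def toy_retrieve(hidden: Vector) -> Vector:
    return [0.5 for _ in hidden]


def toy_halt(hidden: Vector) -> float:
    # Sigmoid of the mean activation, shifted so it saturates after a few loops.
    mean = sum(hidden) / len(hidden)
    return 1.0 / (1.0 + math.exp(-(10.0 * mean - 10.0)))


final_hidden, loops_used = reasoning_loop([0.0, 0.0, 0.0, 0.0], toy_retrieve, toy_halt)
print(loops_used, final_hidden)
```

With these toy functions the halt score crosses the threshold before the cap is reached; a halt function that never becomes confident would use all `max_loops` iterations.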
+
+ ---
+
+ ## Breaking the VRAM Wall
+
+ ![VRAM Comparison](assets/vram_comparison.png)
+
+ A 70B dense model requires 80GB+ of expensive HBM VRAM just to load. Because DPSNR stores knowledge as a flat array of vectors (not entangled with transformer weights), the pool can live in:
+
+ - **System RAM**: 64GB of RAM holds ~21M vectors × 768-dim at float32 (3,072 bytes per vector)
+ - **NVMe SSD**: mmap'd; only the retrieved window is paged in
+ - **GPU VRAM**: only the TinyController (~1.3GB at bf16) needs the GPU
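
Capacity figures like the ones in this list follow from simple byte arithmetic, which the short sketch below makes checkable. The helper names (`vectors_that_fit`, `pool_bytes`) are hypothetical, and a gigabyte is assumed to be 10^9 bytes; the 768-wide vectors and the 262,144-vector pool come from the configuration described in this README.

```python
BYTES_PER_ELEMENT = {"float32": 4, "bfloat16": 2, "int8": 1}
GB = 10**9  # decimal gigabyte (assumption; use 2**30 for GiB)


def vectors_that_fit(budget_bytes: int, dim: int = 768, dtype: str = "float32") -> int:
    """Number of `dim`-wide vectors that fit in `budget_bytes` at the given dtype."""
    return budget_bytes // (dim * BYTES_PER_ELEMENT[dtype])


def pool_bytes(num_vectors: int, dim: int = 768, dtype: str = "float32") -> int:
    """Storage needed for a flat pool of `num_vectors` vectors."""
    return num_vectors * dim * BYTES_PER_ELEMENT[dtype]


fit_fp32 = vectors_that_fit(64 * GB)                      # 64 GB of RAM at float32
fit_bf16 = vectors_that_fit(64 * GB, dtype="bfloat16")    # same budget at bf16
large_pool_gb = pool_bytes(262_144) / GB                  # the 262,144-vector pool

print(fit_fp32, fit_bf16, round(large_pool_gb, 3))
```

Halving the element width doubles the number of vectors that fit, which is the main lever for holding a larger pool in a fixed amount of RAM.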
+
+ ```
+ Dense 70B: [GPU|▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓ 80GB VRAM ▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓]
+ DPSNR:     [GPU|▓ 4GB] + [RAM|▓▓▓▓▓▓ Pool ▓▓▓▓▓▓] ← no problem
+ ```
+
+ ---
+
+ ## Quick Start (Inference)
+
+ ### 1. Activate the virtualenv
+
+ ```bash
+ cd /path/to/dpsn
+ source .venv/bin/activate
+ ```
+
+ ### 2. Verify GPU is available
+
+ ```bash
+ python -c "import jax; print(jax.devices())"
+ # → [CudaDevice(id=0)]
+ ```
+
+ ### 3. Run inference
+
+ ```bash
+ # Single prompt
+ python infer.py --prompt "The future of artificial intelligence"
+
+ # Interactive chat mode
+ python infer.py
+
+ # All options
+ python infer.py \
+     --prompt "Once upon a time" \
+     --max_tokens 200 \
+     --temp 0.8 \
+     --top_k 50 \
+     --penalty 1.3
+ ```
+
+ The first run takes ~20–30s to JIT-compile the forward pass. Subsequent prompts in the same session are fast.
+
+ ---
+
+ ## Inference Script: `infer.py`
+
+ The file `infer.py` is **fully self-contained**: it holds the entire model architecture, checkpoint loading, and generation logic in one file, with no dependency on the `dpsn_r_jax` package.
+
+ ```
+ infer.py
+ ├── DPSNRConfig                  ← Large model config, hardcoded
+ ├── FlashCausalSelfAttention
+ ├── TinyFFN / TinyTransformerLayer
+ ├── TinyController               ← 12-layer transformer encoder + LM head
+ ├── LearnedIndexer               ← μ, σ coordinate predictor
+ ├── CoordinateMassivePool        ← 1D flat pool (used by large config)
+ ├── CoordinateMassivePool2D      ← 2D grid pool (use_2d_pool=True)
+ ├── AdaptiveComputeController    ← halt/loop decision
+ ├── DPSNR                        ← full forward pass, reasoning scan
+ ├── TrainState                   ← pytree-compatible state for orbax restore
+ ├── load_checkpoint()            ← restores params only (no optimizer bloat)
+ ├── _forward()                   ← @jax.jit compiled forward pass
+ └── generate()                   ← autoregressive sampling, fixed-size buffers
+ ```
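
The `generate()` entry in the tree above is described as autoregressive sampling, and the CLI exposes temperature, top-k, and a repetition penalty. The sketch below shows one common way those three knobs combine when choosing a single next token. It is an assumption about the general technique, not a copy of `infer.py`: `sample_next_token` is a hypothetical name, the penalty follows the widely used "divide positive logits, multiply negative ones" convention, and plain Python lists stand in for JAX arrays.

```python
import math
import random
from typing import Optional, Sequence


def sample_next_token(
    logits: Sequence[float],
    generated: Sequence[int],
    temp: float = 0.7,
    top_k: int = 40,
    penalty: float = 1.2,
    rng: Optional[random.Random] = None,
) -> int:
    """Pick the next token id from raw logits."""
    rng = rng or random.Random()
    scores = list(logits)

    # Repetition penalty: push down every token that was already generated.
    for tok in set(generated):
        scores[tok] = scores[tok] / penalty if scores[tok] > 0 else scores[tok] * penalty

    # Temperature: values below 1 sharpen the distribution, above 1 flatten it.
    scores = [s / temp for s in scores]

    # Top-k: keep only the k highest-scoring token ids.
    k = min(top_k, len(scores))
    kept = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]

    # Softmax over the survivors (shifted by the max for numerical stability), then sample.
    peak = max(scores[i] for i in kept)
    weights = [math.exp(scores[i] - peak) for i in kept]
    return rng.choices(kept, weights=weights, k=1)[0]


toy_logits = [2.0, 1.9, 0.1]
print(sample_next_token(toy_logits, generated=[], top_k=1))   # greedy: no history
print(sample_next_token(toy_logits, generated=[0], top_k=1))  # token 0 is penalised
```

Setting `top_k=1` makes the choice deterministic, which is a convenient way to see the penalty flip the winner from token 0 to token 1.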
+
+ ### CLI arguments
+
+ | Argument | Default | Description |
+ |---|---|---|
+ | `--prompt` | None | Text prompt. Omit to enter interactive mode |
+ | `--max_tokens` | 100 | Maximum new tokens to generate |
+ | `--temp` | 0.7 | Sampling temperature. Lower = more focused |
+ | `--top_k` | 40 | Only sample from top-K most likely tokens |
+ | `--penalty` | 1.2 | Repetition penalty. >1 discourages repeats |
+ | `--checkpoint_dir` | `./checkpoints_dir` | Override checkpoint path |
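
For readers wiring these flags into their own script, the table maps directly onto a standard-library `argparse` parser. The sketch below is an assumption about how such a parser could be declared, using the flag names and defaults from the table; `build_parser` is a hypothetical helper and the real `infer.py` may be organised differently.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Parser mirroring the CLI table (flag names and defaults taken from the README)."""
    parser = argparse.ArgumentParser(description="DPSNR inference (illustrative sketch)")
    parser.add_argument("--prompt", type=str, default=None,
                        help="Text prompt; omit to enter interactive mode")
    parser.add_argument("--max_tokens", type=int, default=100,
                        help="Maximum new tokens to generate")
    parser.add_argument("--temp", type=float, default=0.7,
                        help="Sampling temperature; lower is more focused")
    parser.add_argument("--top_k", type=int, default=40,
                        help="Sample only from the K most likely tokens")
    parser.add_argument("--penalty", type=float, default=1.2,
                        help="Repetition penalty; values above 1 discourage repeats")
    parser.add_argument("--checkpoint_dir", type=str, default="./checkpoints_dir",
                        help="Override checkpoint path")
    return parser


# Parse an explicit argument list so the example does not depend on sys.argv.
args = build_parser().parse_args(["--prompt", "Once upon a time", "--temp", "0.8"])
interactive = args.prompt is None  # no prompt means interactive mode
print(args.temp, args.top_k, interactive)
```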
+
+ ---
+
+ ## Model Configuration (Large)
+
+ The `large` config is hardcoded in `infer.py`:
+
+ ```python
+ DPSNRConfig(
+     vocab_size            = 50257,   # GPT-Neo tokenizer vocab
+     controller_hidden_dim = 768,     # transformer width
+     controller_num_layers = 12,      # transformer depth
+     controller_num_heads  = 12,      # attention heads
+     max_seq_len           = 1024,    # max context window
+     pool_total_vectors    = 262144,  # 2^18 knowledge vectors
+     pool_hidden_dim       = 768,     # vector dimension
+     max_reasoning_loops   = 6,       # max iterations of the loop
+ )
+ ```
+
+ ### Model size breakdown
+
+ ```mermaid
+ pie title DPSNR Large - Parameter Distribution (~350M total)
+     "CoordinateMassivePool (262K × 768)" : 201
+     "TinyController (12L × 768d)" : 85
+     "LearnedIndexer" : 3
+     "AdaptiveComputeController" : 2
+     "Retrieval Integrator" : 9
+ ```
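
The pool slice of the chart can be reproduced from the config values, which gives a quick sanity check on the breakdown. The snippet below is only arithmetic: it assumes the chart values are in millions of parameters, and the variable names are illustrative. The listed slices sum to 300, so the "~350M total" in the chart title presumably includes parameters that are not itemised there; the README does not say which.

```python
# Values from the `large` config.
pool_total_vectors = 262_144
pool_hidden_dim = 768
pool_params = pool_total_vectors * pool_hidden_dim  # one float per element

# Slices as listed in the pie chart (assumed to be millions of parameters).
slices_millions = {
    "CoordinateMassivePool": 201,
    "TinyController": 85,
    "LearnedIndexer": 3,
    "AdaptiveComputeController": 2,
    "Retrieval Integrator": 9,
}
listed_total = sum(slices_millions.values())
pool_share = slices_millions["CoordinateMassivePool"] / listed_total

print(pool_params)                 # exact pool parameter count
print(listed_total)                # sum of the itemised slices, in millions
print(f"{pool_share:.0%}")         # pool share of the itemised parameters
```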
+
+ ---
+
+ ## Tokenizer
+
+ Uses the **`EleutherAI/gpt-neo-125M`** tokenizer: GPT-2 compatible BPE with 50,257 tokens. It is downloaded automatically via HuggingFace on first use.
+
+ ---
+
+
+ ## Key Ideas Explained Simply
+
+ ### Why the pool doesn't slow things down
+
+ Every retrieval fetches exactly `K=32` vectors regardless of pool size. Going from 10K to 100B pool vectors doesn't add a single FLOP; only the storage grows.
+
+ ```mermaid
+ xychart-beta
+     title "Inference Latency vs Pool Size"
+     x-axis ["10K vectors", "100K", "262K", "1M", "1B", "100B"]
+     y-axis "Relative Latency" 0 --> 2
+     line [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
+ ```
+
+ ### Why Gaussian retrieval is better than nearest-neighbour
+
+ Nearest-neighbour lookup (as in a typical vector database) has to search the pool, either exhaustively or through an approximate index. DPSNR uses a **coordinate** approach instead: the pool is laid out in a continuous 1D space (or a 2D grid). The indexer predicts a *position* μ and a *width* σ, and retrieval simply slices a window around μ. No search is required; it is a direct lookup with `jax.lax.dynamic_slice`.
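
A minimal sketch of this window lookup, written in plain Python so it runs without JAX. Several details are assumptions made for illustration: `gaussian_window_retrieve` is a hypothetical name, μ is taken to be a fraction in [0, 1] that is mapped to a pool index, σ is measured in index units, and the window start is clamped at the pool edges (which is also how `jax.lax.dynamic_slice` treats out-of-range starts). The list slice plays the role of the dynamic slice.

```python
import math
from typing import List, Sequence


def gaussian_window_retrieve(
    pool: Sequence[Sequence[float]],
    mu: float,
    sigma: float,
    k: int = 32,
) -> List[float]:
    """Return the Gaussian-weighted average of the k pool vectors around mu."""
    n = len(pool)
    k = min(k, n)
    center = round(mu * (n - 1))                  # coordinate -> pool index
    start = min(max(center - k // 2, 0), n - k)   # clamp the window inside the pool
    window = pool[start:start + k]                # reads k vectors, whatever n is

    # Unnormalised Gaussian weight for each vector, by distance from the centre index.
    raw = [math.exp(-0.5 * ((start + i - center) / sigma) ** 2) for i in range(k)]
    total = sum(raw)

    dim = len(pool[0])
    return [sum(w * vec[d] for w, vec in zip(raw, window)) / total for d in range(dim)]


# Toy pool where vector i is [i, -i], so the result shows which indices were read.
toy_pool = [[float(i), -float(i)] for i in range(1001)]
sharp = gaussian_window_retrieve(toy_pool, mu=0.5, sigma=0.1)
broad = gaussian_window_retrieve(toy_pool, mu=0.5, sigma=8.0)
print(sharp, broad)
```

The work done per call depends only on `k` and the vector width, never on `len(pool)`, which is the property the flat latency curve above is describing.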
+
+ ### Why σ matters
+
+ - **Small σ** → sharp, precise retrieval (good for exact facts, code syntax)
+ - **Large σ** → broad, averaged retrieval (good for general context)
+
+ σ is learned per token, per reasoning step, so the model naturally figures out how precise to be.
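
The effect of σ can be made concrete by looking at the normalised weights over a 32-vector window. The sketch below is illustrative only: `window_weights` and `effective_vectors` are hypothetical helpers, σ is assumed to be in index units, and the "effective number of vectors" is the inverse sum of squared weights, a generic measure of how spread out a weighting is.

```python
import math
from typing import List


def window_weights(sigma: float, k: int = 32) -> List[float]:
    """Normalised Gaussian weights over a k-wide window centred on the predicted index."""
    offsets = [i - k // 2 for i in range(k)]
    raw = [math.exp(-0.5 * (o / sigma) ** 2) for o in offsets]
    total = sum(raw)
    return [r / total for r in raw]


def effective_vectors(weights: List[float]) -> float:
    """Close to 1 when one vector dominates, close to k when the weights are flat."""
    return 1.0 / sum(w * w for w in weights)


sharp = window_weights(sigma=0.5)    # small sigma: almost a single-vector lookup
broad = window_weights(sigma=16.0)   # large sigma: close to a plain average

print(round(max(sharp), 3), round(effective_vectors(sharp), 2))
print(round(max(broad), 3), round(effective_vectors(broad), 2))
```

With a small σ nearly all of the weight sits on the centre vector, while a σ comparable to the window width spreads it across most of the 32 vectors.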
+
+ ---
+
+ ## Performance
+
+ | Metric | Value | Notes |
+ |---|---|---|
+ | Training platform | TPU v5e-8 | 8-chip pod slice |
+ | Throughput | **240–250K tokens/sec** | HBM bandwidth bound |
+ | Sustained compute | **260–270 TFLOPS** | Below 393 TFLOPS peak |
+ | Bottleneck | Memory bandwidth | Pool gather ops, not MXU |
+ | Optimizer speedup vs dense | **590×** | Sparse Adam on retrieved indices only |
+ | Checkpoint step | 31,000 | |
+ | GPU VRAM (inference) | ~1.3GB (params only, bf16) | Pool can live off-device |
+ | Inference tested on | NVIDIA RTX 2050 (4GB) | Consumer GPU confirmed |
+
+ ---
+
+ ## Dependencies
+
+ ```
+ jax + jaxlib   ← Core ML framework (GPU/TPU backend)
+ flax           ← Neural network layers and module API
+ optax          ← Optimizers (used for checkpoint structure only)
+ orbax          ← Checkpoint save/restore
+ transformers   ← Tokenizer (HuggingFace)
+ ```
+
+ Install:
+ ```bash
+ pip install jax jaxlib flax optax orbax-checkpoint transformers
+ ```
+
+ ---
+
+ ## Citation / Reference
+
+ ```
+ DPSNR: Disaggregated Parameter Selection Network with Reasoning
+ Architecture: TinyController + CoordinateMassivePool + LearnedIndexer + ACC
+ Implementation: JAX/Flax
+ Checkpoint: step 31,000
+ ```
_CHECKPOINT_METADATA ADDED
@@ -0,0 +1 @@
+ {"item_handlers": {"default": "orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler"}, "metrics": {}, "performance_metrics": {}, "init_timestamp_nsecs": 1772564676365471524, "commit_timestamp_nsecs": 1772564878381366150, "custom_metadata": {}}
assets/architecture.png ADDED

Git LFS Details

  • SHA256: 128751f593e700c87cc47c9f6bd241152c8b820629fad5909d041dde38937c7d
  • Pointer size: 131 Bytes
  • Size of remote file: 391 kB
assets/reasoning_loop.png ADDED

Git LFS Details

  • SHA256: d06b2746c834c2329907f1ba2037a8949b465c30858fb06bc6926c1f4d8ddaf1
  • Pointer size: 131 Bytes
  • Size of remote file: 442 kB
assets/vram_comparison.png ADDED

Git LFS Details

  • SHA256: 14af82cca6a7d58f755762198324b8cba91e0bad379cd2a9b82b0b8aee7fbae7
  • Pointer size: 131 Bytes
  • Size of remote file: 446 kB
default/_METADATA ADDED
The diff for this file is too large to render. See raw diff
 
default/_sharding ADDED
The diff for this file is too large to render. See raw diff
 
default/array_metadatas/process_0 ADDED
@@ -0,0 +1 @@
+ {"array_metadatas": [{"array_metadata": {"param_name": "step", "write_shape": [], "chunk_shape": [], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.acc.halt_net.layers_0.bias", "write_shape": [24], "chunk_shape": [24], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.acc.halt_net.layers_0.kernel", "write_shape": [96, 192], "chunk_shape": [96, 192], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.acc.halt_net.layers_2.bias", "write_shape": [1], "chunk_shape": [1], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.acc.halt_net.layers_2.kernel", "write_shape": [24, 1], "chunk_shape": [24, 1], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.acc.loop_embed.embedding", "write_shape": [4, 768], "chunk_shape": [4, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.acc.state_gate.layers_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.acc.state_gate.layers_0.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.acc.state_norm.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.acc.state_norm.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.acc.state_transform.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.acc.state_transform.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.embedding.embedding", "write_shape": [50257, 96], "chunk_shape": [50257, 96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.final_norm.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": 
"params.controller.final_norm.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_0.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_0.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_0.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_0.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_0.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_0.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_0.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_0.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_0.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_0.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_1.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_1.FlashCausalSelfAttention_0.Dense_1.kernel", 
"write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_1.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_1.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_1.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_1.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_1.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_1.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_1.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_1.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_10.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_10.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_10.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_10.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": 
{"param_name": "params.controller.layers_10.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_10.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_10.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_10.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_10.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_10.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_11.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_11.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_11.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_11.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_11.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_11.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_11.TinyFFN_0.Dense_0.bias", "write_shape": [192], 
"chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_11.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_11.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_11.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_2.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_2.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_2.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_2.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_2.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_2.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_2.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_2.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_2.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": 
{"param_name": "params.controller.layers_2.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_3.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_3.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_3.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_3.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_3.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_3.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_3.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_3.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_3.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_3.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_4.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": 
"params.controller.layers_4.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_4.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_4.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_4.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_4.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_4.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_4.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_4.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_4.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_5.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_5.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_5.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_5.LayerNorm_0.scale", "write_shape": [96], 
"chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_5.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_5.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_5.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_5.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_5.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_5.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_6.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_6.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_6.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_6.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_6.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_6.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": 
"params.controller.layers_6.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_6.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_6.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_6.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_7.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_7.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_7.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_7.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_7.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_7.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_7.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_7.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_7.TinyFFN_0.Dense_1.bias", "write_shape": [96], 
"chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_7.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_8.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_8.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_8.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_8.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_8.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_8.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_8.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_8.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_8.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_8.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_9.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, 
{"array_metadata": {"param_name": "params.controller.layers_9.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_9.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_9.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_9.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_9.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_9.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_9.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_9.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.layers_9.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.lm_head.kernel", "write_shape": [96, 50257], "chunk_shape": [96, 50257], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.controller.pos_encoding.embedding", "write_shape": [128, 768], "chunk_shape": [128, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.indexer.Dense_0.kernel", "write_shape": [96, 1], "chunk_shape": [96, 1], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.indexer.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, 
{"array_metadata": {"param_name": "params.indexer.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.indexer.Dense_2.bias", "write_shape": [48], "chunk_shape": [48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.indexer.Dense_2.kernel", "write_shape": [96, 384], "chunk_shape": [96, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.indexer.Dense_3.bias", "write_shape": [1], "chunk_shape": [1], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.indexer.Dense_3.kernel", "write_shape": [48, 1], "chunk_shape": [48, 1], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.indexer.Dense_4.bias", "write_shape": [1], "chunk_shape": [1], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.indexer.Dense_4.kernel", "write_shape": [48, 1], "chunk_shape": [48, 1], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.pool.params_storage", "write_shape": [32768, 768], "chunk_shape": [32768, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.retrieval_integrator.layers_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.retrieval_integrator.layers_0.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.retrieval_integrator.layers_2.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.retrieval_integrator.layers_2.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.retrieval_integrator.layers_3.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.retrieval_integrator.layers_3.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, 
{"array_metadata": {"param_name": "opt_state.1.0.count", "write_shape": [], "chunk_shape": [], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.acc.halt_net.layers_0.bias", "write_shape": [24], "chunk_shape": [24], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.acc.halt_net.layers_0.kernel", "write_shape": [96, 192], "chunk_shape": [96, 192], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.acc.halt_net.layers_2.bias", "write_shape": [1], "chunk_shape": [1], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.acc.halt_net.layers_2.kernel", "write_shape": [24, 1], "chunk_shape": [24, 1], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.acc.loop_embed.embedding", "write_shape": [4, 768], "chunk_shape": [4, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.acc.state_gate.layers_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.acc.state_gate.layers_0.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.acc.state_norm.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.acc.state_norm.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.acc.state_transform.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.acc.state_transform.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.embedding.embedding", "write_shape": [50257, 96], "chunk_shape": [50257, 96], "ext_metadata": null}}, {"array_metadata": {"param_name": 
"opt_state.1.0.mu.controller.final_norm.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.final_norm.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_0.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_0.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_0.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_0.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_0.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_0.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_0.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_0.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_0.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_0.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": 
"opt_state.1.0.mu.controller.layers_1.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_1.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_1.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_1.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_1.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_1.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_1.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_1.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_1.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_1.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_10.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_10.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 
768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_10.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_10.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_10.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_10.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_10.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_10.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_10.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_10.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_11.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_11.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_11.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": 
"opt_state.1.0.mu.controller.layers_11.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_11.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_11.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_11.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_11.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_11.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_11.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_2.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_2.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_2.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_2.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_2.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, 
{"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_2.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_2.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_2.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_2.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_2.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_3.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_3.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_3.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_3.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_3.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_3.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_3.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], 
"ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_3.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_3.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_3.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_4.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_4.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_4.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_4.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_4.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_4.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_4.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_4.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_4.TinyFFN_0.Dense_1.bias", 
"write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_4.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_5.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_5.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_5.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_5.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_5.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_5.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_5.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_5.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_5.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_5.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": 
"opt_state.1.0.mu.controller.layers_6.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_6.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_6.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_6.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_6.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_6.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_6.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_6.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_6.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_6.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_7.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_7.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], 
"chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_7.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_7.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_7.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_7.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_7.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_7.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_7.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_7.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_8.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_8.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_8.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_8.LayerNorm_0.scale", 
"write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_8.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_8.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_8.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_8.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_8.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_8.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_9.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_9.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_9.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_9.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_9.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": 
"opt_state.1.0.mu.controller.layers_9.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_9.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_9.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_9.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.layers_9.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.lm_head.kernel", "write_shape": [96, 50257], "chunk_shape": [96, 50257], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.controller.pos_encoding.embedding", "write_shape": [128, 768], "chunk_shape": [128, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.indexer.Dense_0.kernel", "write_shape": [96, 1], "chunk_shape": [96, 1], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.indexer.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.indexer.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.indexer.Dense_2.bias", "write_shape": [48], "chunk_shape": [48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.indexer.Dense_2.kernel", "write_shape": [96, 384], "chunk_shape": [96, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.indexer.Dense_3.bias", "write_shape": [1], "chunk_shape": [1], "ext_metadata": 
null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.indexer.Dense_3.kernel", "write_shape": [48, 1], "chunk_shape": [48, 1], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.indexer.Dense_4.bias", "write_shape": [1], "chunk_shape": [1], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.indexer.Dense_4.kernel", "write_shape": [48, 1], "chunk_shape": [48, 1], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.retrieval_integrator.layers_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.retrieval_integrator.layers_0.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.retrieval_integrator.layers_2.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.retrieval_integrator.layers_2.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.retrieval_integrator.layers_3.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.retrieval_integrator.layers_3.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.acc.halt_net.layers_0.bias", "write_shape": [24], "chunk_shape": [24], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.acc.halt_net.layers_0.kernel", "write_shape": [96, 192], "chunk_shape": [96, 192], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.acc.halt_net.layers_2.bias", "write_shape": [1], "chunk_shape": [1], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.acc.halt_net.layers_2.kernel", "write_shape": [24, 1], "chunk_shape": [24, 1], "ext_metadata": 
null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.acc.loop_embed.embedding", "write_shape": [4, 768], "chunk_shape": [4, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.acc.state_gate.layers_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.acc.state_gate.layers_0.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.acc.state_norm.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.acc.state_norm.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.acc.state_transform.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.acc.state_transform.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.embedding.embedding", "write_shape": [50257, 96], "chunk_shape": [50257, 96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.final_norm.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.final_norm.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_0.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_0.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_0.LayerNorm_0.bias", 
"write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_0.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_0.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_0.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_0.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_0.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_0.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_0.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_1.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_1.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_1.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_1.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": 
"opt_state.1.0.nu.controller.layers_1.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_1.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_1.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_1.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_1.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_1.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_10.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_10.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_10.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_10.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_10.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_10.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, 
{"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_10.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_10.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_10.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_10.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_11.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_11.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_11.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_11.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_11.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_11.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_11.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_11.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 
1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_11.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_11.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_2.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_2.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_2.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_2.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_2.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_2.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_2.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_2.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_2.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": 
"opt_state.1.0.nu.controller.layers_2.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_3.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_3.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_3.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_3.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_3.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_3.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_3.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_3.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_3.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_3.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_4.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": 
[96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_4.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_4.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_4.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_4.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_4.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_4.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_4.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_4.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_4.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_5.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_5.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": 
"opt_state.1.0.nu.controller.layers_5.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_5.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_5.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_5.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_5.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_5.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_5.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_5.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_6.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_6.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_6.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_6.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": 
{"param_name": "opt_state.1.0.nu.controller.layers_6.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_6.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_6.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_6.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_6.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_6.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_7.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_7.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_7.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_7.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_7.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_7.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, 
{"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_7.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_7.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_7.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_7.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_8.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_8.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_8.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_8.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_8.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_8.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_8.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_8.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], 
"chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_8.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_8.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_9.FlashCausalSelfAttention_0.Dense_0.kernel", "write_shape": [96, 2304], "chunk_shape": [96, 2304], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_9.FlashCausalSelfAttention_0.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_9.LayerNorm_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_9.LayerNorm_0.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_9.LayerNorm_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_9.LayerNorm_1.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_9.TinyFFN_0.Dense_0.bias", "write_shape": [192], "chunk_shape": [192], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_9.TinyFFN_0.Dense_0.kernel", "write_shape": [96, 1536], "chunk_shape": [96, 1536], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.layers_9.TinyFFN_0.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": 
"opt_state.1.0.nu.controller.layers_9.TinyFFN_0.Dense_1.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.lm_head.kernel", "write_shape": [96, 50257], "chunk_shape": [96, 50257], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.controller.pos_encoding.embedding", "write_shape": [128, 768], "chunk_shape": [128, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.indexer.Dense_0.kernel", "write_shape": [96, 1], "chunk_shape": [96, 1], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.indexer.Dense_1.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.indexer.Dense_1.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.indexer.Dense_2.bias", "write_shape": [48], "chunk_shape": [48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.indexer.Dense_2.kernel", "write_shape": [96, 384], "chunk_shape": [96, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.indexer.Dense_3.bias", "write_shape": [1], "chunk_shape": [1], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.indexer.Dense_3.kernel", "write_shape": [48, 1], "chunk_shape": [48, 1], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.indexer.Dense_4.bias", "write_shape": [1], "chunk_shape": [1], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.indexer.Dense_4.kernel", "write_shape": [48, 1], "chunk_shape": [48, 1], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.retrieval_integrator.layers_0.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": 
"opt_state.1.0.nu.retrieval_integrator.layers_0.kernel", "write_shape": [192, 768], "chunk_shape": [192, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.retrieval_integrator.layers_2.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.retrieval_integrator.layers_2.kernel", "write_shape": [96, 768], "chunk_shape": [96, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.retrieval_integrator.layers_3.bias", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.retrieval_integrator.layers_3.scale", "write_shape": [96], "chunk_shape": [96], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.2.count", "write_shape": [], "chunk_shape": [], "ext_metadata": null}}, {"array_metadata": {"param_name": "rng", "write_shape": [2], "chunk_shape": [2], "ext_metadata": null}}, {"array_metadata": {"param_name": "pool_m", "write_shape": [32768, 768], "chunk_shape": [32768, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "pool_v", "write_shape": [32768, 768], "chunk_shape": [32768, 768], "ext_metadata": null}}]}
default/d/40b6413d8d045553235110cf8a3113dc ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0ab1acca67161a769f0a2c74c17105b0c0940d21fef602908981624325175d88
+ size 865274
default/manifest.ocdbt ADDED
Binary file (120 Bytes)
 
default/ocdbt.process_0/d/2655f27744aa28bc57a54732ca8aa17f ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f891a2d70042b5fb12f778dcda445e55fdf47afbe900bc41f2ad67ba2cf2c892
+ size 18268160
default/ocdbt.process_0/d/309bfd1f96632d6760dd55dea979babf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8a6225c169f137eff87d4374fb90deafe3a09e4f768669a0e40fcaa0c33085dd
+ size 865238
default/ocdbt.process_0/d/37b57931b1bb0df657d81dd245946279 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c3742f0bdd4d5e21c32bf96f4091d9012ce808d10267013b3866f2a92baea767
+ size 1458647040
default/ocdbt.process_0/d/4319cb1782ea3890b355c019a87bd8c1 ADDED
Binary file (1 kB)
 
default/ocdbt.process_0/d/7d3a8dd28172f4fc4fe186eaa73f2843 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9347314243a2c765d8656ed631af5ff405617703c08190ed52a53bc3a88361a6
+ size 2152075264
default/ocdbt.process_0/d/b5804de1cb04f30d793b9b8393d7f599 ADDED
Binary file (171 Bytes)
 
default/ocdbt.process_0/d/cca1e2cb6509e2bae33156603f3ff2de ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1a887f4933dfab73175adeb17d3037574d5c643c8bab23cfed58f6ca66b8e992
+ size 71282688
default/ocdbt.process_0/d/e6677b886d30c14daa1f696397120498 ADDED
Binary file (1.05 kB)
 
default/ocdbt.process_0/manifest.ocdbt ADDED
Binary file (402 Bytes)
 
infer.py ADDED
@@ -0,0 +1,840 @@
+ #!/usr/bin/env python3
+ """
+ DPSNR Inference - Fully self-contained single-file GPU inference for the Large model.
+
+ This file contains the ENTIRE model architecture, checkpoint loading, and generation
+ logic. It has ZERO dependencies on the dpsn_r_jax package.
+
+ Usage:
+     source .venv/bin/activate
+
+     # Single prompt
+     python infer.py --prompt "Once upon a time"
+
+     # Interactive mode (default)
+     python infer.py
+
+     # Adjust generation parameters
+     python infer.py --prompt "The future of AI" --max_tokens 200 --temp 0.8 --top_k 50
+ """
+
+ import os
+ import sys
+ import time
+ import argparse
+ from dataclasses import dataclass, field
+ from collections import namedtuple
+ from typing import Any, Callable, Optional
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+ import jax
+ import jax.numpy as jnp
+ import flax.linen as nn
+ from jax import lax
+ from flax.training import train_state
+ from flax import struct, traverse_util
+ import optax
+ import orbax.checkpoint
+ from functools import partial
+ from transformers import AutoTokenizer
+
+
+ # ═══════════════════════════════════════════════════════════════════════════════
+ # DEVICE
+ # ═══════════════════════════════════════════════════════════════════════════════
+ DEVICE = jax.devices()[0]
+ PLATFORM = DEVICE.platform
+ print(f"[Device] {DEVICE} (platform: {PLATFORM})")
+
+
+ # ═══════════════════════════════════════════════════════════════════════════════
+ # CONFIG - Large model, hardcoded
+ # ═══════════════════════════════════════════════════════════════════════════════
+ TOKENIZER_NAME = "EleutherAI/gpt-neo-125M"
+ CHECKPOINT_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "checkpoints_dir")
+
+
+ @dataclass
+ class PoolConfig:
+     total_vectors: int
+     hidden_dim: int
+
+
+ @dataclass
+ class DPSNRConfig:
+     vocab_size: int = 50257
+     controller_hidden_dim: int = 768
+     controller_num_layers: int = 12
+     controller_num_heads: int = 12
+     controller_ff_multiplier: float = 2.0
+     max_seq_len: int = 1024
+     dropout: float = 0.0
+     pool_total_vectors: int = 262144
+     pool_hidden_dim: int = 768
+     librarian_hidden_dim: int = 32
+     max_reasoning_loops: int = 6
+     min_reasoning_loops: int = 1
+     halt_threshold: float = 0.99
+     min_k: int = 4
+     max_k: int = 32
+     num_clusters_to_search: int = 4
+     pad_token_id: int = 0
+     learning_rate: float = 3e-4
+     gradient_checkpointing: bool = False
+     use_bf16: bool = False
+     num_indexer_heads: int = 1
+     sigma_min: float = 0.01
+     sigma_max: float = 5.0
+     use_2d_pool: bool = False
+     pool_grid_rows: int = 512
+     pool_grid_cols: int = 512
+     sigma_anneal_steps: int = 0
+     sigma_target: float = 0.05
+     precision_loss_weight: float = 0.0
+     # Fields needed by create_train_state but unused for inference
+     streaming: bool = True
+     hf_dataset_name: Optional[str] = None
+     hf_tokenizer_name: Optional[str] = None
+     max_steps: Optional[int] = None
+     generation_steps: Optional[int] = None
+     generation_max_tokens: int = 20
+     generation_prompts: Optional[list] = None
+     num_workers: int = 4
+     loss_chunk_size: int = 0
+     finetune: Optional[Any] = None
+
+
+ CONFIG = DPSNRConfig()
+
+
+ # ═══════════════════════════════════════════════════════════════════════════════
+ # MODEL LAYERS
+ # ═══════════════════════════════════════════════════════════════════════════════
+
+ class FlashCausalSelfAttention(nn.Module):
+     hidden_dim: int
+     num_heads: int
+     dropout_rate: float = 0.0
+
+     @nn.compact
+     def __call__(self, x, mask=None, deterministic=True):
+         head_dim = self.hidden_dim // self.num_heads
+
+         qkv = nn.Dense(3 * self.hidden_dim, use_bias=False)(x)
+         q, k, v = jnp.split(qkv, 3, axis=-1)
+
+         q = q.reshape(x.shape[0], x.shape[1], self.num_heads, head_dim)
+         k = k.reshape(x.shape[0], x.shape[1], self.num_heads, head_dim)
+         v = v.reshape(x.shape[0], x.shape[1], self.num_heads, head_dim)
+
+         dropout_rng = (
+             self.make_rng("dropout")
+             if not deterministic and self.dropout_rate > 0
+             else None
+         )
+
+         y = nn.dot_product_attention(
+             q, k, v,
+             bias=mask,
+             dropout_rate=self.dropout_rate,
+             deterministic=deterministic,
+             dropout_rng=dropout_rng,
+         )
+
+         y = y.reshape(x.shape[0], x.shape[1], self.hidden_dim)
+         y = nn.Dense(self.hidden_dim, use_bias=False)(y)
+
+         if not deterministic:
+             y = nn.Dropout(self.dropout_rate)(y, deterministic=deterministic)
+
+         return y
+
+
+ class TinyFFN(nn.Module):
+     hidden_dim: int
+     ff_dim: int
+     dropout_rate: float = 0.0
+
+     @nn.compact
+     def __call__(self, x, deterministic=True):
+         x = nn.Dense(self.ff_dim)(x)
+         x = nn.gelu(x)
+         if not deterministic:
+             x = nn.Dropout(self.dropout_rate)(x, deterministic=deterministic)
+         x = nn.Dense(self.hidden_dim)(x)
+         if not deterministic:
+             x = nn.Dropout(self.dropout_rate)(x, deterministic=deterministic)
+         return x
+
+
+ class TinyTransformerLayer(nn.Module):
+     hidden_dim: int
+     num_heads: int
+     ff_dim: int
+     dropout_rate: float = 0.0
+
+     @nn.compact
+     def __call__(self, x, mask=None, deterministic=True):
+         norm1 = nn.LayerNorm()(x)
+         attn_out = FlashCausalSelfAttention(
+             self.hidden_dim, self.num_heads, self.dropout_rate
+         )(norm1, mask=mask, deterministic=deterministic)
+         x = x + attn_out
+
+         norm2 = nn.LayerNorm()(x)
+         ffn_out = TinyFFN(self.hidden_dim, self.ff_dim, self.dropout_rate)(
+             norm2, deterministic=deterministic
+         )
+         x = x + ffn_out
+         return x
+
+
193
+ # ═══════════════════════════════════════════════════════════════════════════════
194
+ # CONTROLLER
195
+ # ═══════════════════════════════════════════════════════════════════════════════
196
+
197
+ class TinyController(nn.Module):
198
+ config: DPSNRConfig
199
+
200
+ def setup(self):
201
+ self.embedding = nn.Embed(
202
+ self.config.vocab_size, self.config.controller_hidden_dim
203
+ )
204
+ self.pos_encoding = nn.Embed(
205
+ self.config.max_seq_len, self.config.controller_hidden_dim
206
+ )
207
+
208
+ ff_dim = int(
209
+ self.config.controller_hidden_dim * self.config.controller_ff_multiplier
210
+ )
211
+ layer_cls = TinyTransformerLayer
212
+ if self.config.gradient_checkpointing:
213
+ layer_cls = nn.remat(TinyTransformerLayer, static_argnums=(3,))
214
+
215
+ self.layers = [
216
+ layer_cls(
217
+ self.config.controller_hidden_dim,
218
+ self.config.controller_num_heads,
219
+ ff_dim,
220
+ self.config.dropout,
221
+ )
222
+ for _ in range(self.config.controller_num_layers)
223
+ ]
224
+
225
+ self.final_norm = nn.LayerNorm()
226
+ self.lm_head = nn.Dense(self.config.vocab_size, use_bias=False)
227
+
228
+ def __call__(self, input_ids, deterministic=True):
229
+ return self.encode(input_ids, deterministic)
230
+
231
+ def encode(self, input_ids, deterministic=True):
232
+ B, T = input_ids.shape
233
+ embed = self.embedding(input_ids)
234
+ pos_ids = jnp.arange(T)[None, :]
235
+ pos_embed = self.pos_encoding(pos_ids)
236
+ x = embed + pos_embed
237
+
238
+ mask = nn.make_causal_mask(input_ids)
239
+ mask = jnp.where(mask, 0, -1e4)
240
+
241
+ for layer in self.layers:
242
+ x = layer(x, mask, deterministic)
243
+
244
+ return x
245
+
246
+ def decode(self, hidden):
247
+ x = self.final_norm(hidden)
248
+ logits = self.lm_head(x)
249
+ return logits
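The additive causal mask built in `encode` can be illustrated without JAX. The following is a minimal plain-Python sketch, not the code used by the model; the helper name `additive_causal_mask` is hypothetical, and only the `0` / `-1e4` convention is taken from the diff above.

```python
def additive_causal_mask(seq_len, neg=-1e4):
    """Return a seq_len x seq_len bias matrix as nested lists.

    Entry [q][k] is 0.0 when query position q may attend to key position k
    (k <= q) and `neg` otherwise, mirroring jnp.where(mask, 0, -1e4).
    """
    return [
        [0.0 if k <= q else neg for k in range(seq_len)]
        for q in range(seq_len)
    ]


if __name__ == "__main__":
    for row in additive_causal_mask(3):
        print(row)
```

Adding this bias to the attention scores before the softmax drives the weight of future positions toward zero while leaving visible positions unchanged.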
+
+
+# ═══════════════════════════════════════════════════════════════════════════════
+# MEMORY - Learned Indexer + 1D/2D Pool
+# ═══════════════════════════════════════════════════════════════════════════════
+
+class LearnedIndexer(nn.Module):
+    hidden_dim: int
+    num_heads: int = 1
+    sigma_min: float = 0.01
+    sigma_max: float = 5.0
+
+    @nn.compact
+    def __call__(self, hidden_states, sigma_max_scale: float = 1.0):
+        attn_logits = nn.Dense(1, use_bias=False)(hidden_states)
+        attn_weights = jax.nn.softmax(attn_logits, axis=1)
+        pooled = jnp.sum(attn_weights * hidden_states, axis=1)
+
+        x = nn.Dense(self.hidden_dim)(pooled)
+        x = nn.gelu(x)
+        x = nn.Dense(self.hidden_dim // 2)(x)
+        x = nn.gelu(x)
+
+        mu_raw = nn.Dense(self.num_heads)(x)
+        sigma_raw = nn.Dense(self.num_heads)(x)
+
+        mu = jax.nn.sigmoid(mu_raw)
+
+        effective_sigma_max = self.sigma_max * sigma_max_scale
+        sigma = (
+            self.sigma_min
+            + (effective_sigma_max - self.sigma_min) * jax.nn.sigmoid(sigma_raw)
+        )
+
+        return mu, sigma
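The last step of `LearnedIndexer` squashes an unbounded logit into a bounded width. A minimal plain-Python sketch of that mapping follows; the helper names `sigmoid` and `bounded_sigma` are hypothetical, and the default bounds are copied from the class fields above.

```python
import math


def sigmoid(z: float) -> float:
    """Logistic function."""
    return 1.0 / (1.0 + math.exp(-z))


def bounded_sigma(sigma_raw, sigma_min=0.01, sigma_max=5.0, sigma_max_scale=1.0):
    """Map a raw logit into [sigma_min, sigma_max * sigma_max_scale]."""
    effective_max = sigma_max * sigma_max_scale
    return sigma_min + (effective_max - sigma_min) * sigmoid(sigma_raw)


if __name__ == "__main__":
    for raw in (-4.0, 0.0, 4.0):
        print(raw, bounded_sigma(raw), bounded_sigma(raw, sigma_max_scale=0.5))
```

Lowering `sigma_max_scale` shrinks the upper bound only, so a caller can narrow the retrieval window over time without touching `sigma_min`.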
+
+
+class CoordinateMassivePool(nn.Module):
+    config: PoolConfig
+    window_size: int
+
+    def setup(self):
+        self.params_storage = self.param(
+            "params_storage",
+            nn.initializers.normal(),
+            (self.config.total_vectors, self.config.hidden_dim),
+        )
+
+    def __call__(self, mu, sigma):
+        B = mu.shape[0]
+        Total = self.config.total_vectors
+        D = self.config.hidden_dim
+        W = self.window_size
+
+        # Map mu in [0, 1] to a fractional row index and clamp the window inside the table.
+        center_idx = mu * (Total - 1)
+        start_indices = jnp.clip(center_idx - W // 2, 0, Total - W).astype(jnp.int32)
+
+        def slice_fn(start):
+            return lax.dynamic_slice(self.params_storage, (start, 0), (W, D))
+
+        selected = jax.vmap(slice_fn)(start_indices)
+        relative_indices = jnp.arange(W)[None, :] + start_indices[:, None]
+        distances = relative_indices - center_idx[:, None]
+        # Gaussian weights over the window; the 1e-6 floor keeps the normalization finite.
+        weights = jnp.exp(-(distances**2) / (2 * (sigma[:, None] + 1e-6) ** 2)) + 1e-6
+        weights = weights / jnp.sum(weights, axis=-1, keepdims=True)
+        aggregated = jnp.einsum("bw,bwd->bd", weights, selected)
+
+        return aggregated, start_indices
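The windowed Gaussian read in `CoordinateMassivePool.__call__` can be traced by hand for a single query. The following is a plain-Python sketch under the assumption of one coordinate and a list-of-lists table; the function name `gaussian_window_read` is hypothetical and the code stands in for the batched `vmap` / `einsum` version rather than reproducing it.

```python
import math


def gaussian_window_read(storage, mu, sigma, window):
    """Read a Gaussian-weighted average of `window` consecutive rows.

    storage: list of equal-length vectors (the pool table)
    mu:      position in [0, 1], scaled to a fractional row index
    sigma:   width of the Gaussian, in rows
    """
    total = len(storage)
    center = mu * (total - 1)
    # Clamp the window start so the slice stays inside the table.
    start = int(min(max(center - window // 2, 0), total - window))

    weights = []
    for idx in range(start, start + window):
        d = idx - center
        weights.append(math.exp(-(d * d) / (2 * (sigma + 1e-6) ** 2)) + 1e-6)
    norm = sum(weights)
    weights = [w / norm for w in weights]

    dim = len(storage[0])
    out = [
        sum(w * storage[start + j][k] for j, w in enumerate(weights))
        for k in range(dim)
    ]
    return out, start, weights


if __name__ == "__main__":
    table = [[float(i)] for i in range(10)]
    print(gaussian_window_read(table, mu=0.5, sigma=1.0, window=4))
```

Note that when `sigma` is very small relative to the distance between the center and every row in the window, the exponentials underflow and the `1e-6` floor makes the read close to a uniform average of the window.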
+
+
+class CoordinateMassivePool2D(nn.Module):
+    rows: int
+    cols: int
+    hidden_dim: int
+    window_size: int
+
+    def setup(self):
+        self.params_storage = self.param(
+            "params_storage",
+            nn.initializers.normal(),
+            (self.rows, self.cols, self.hidden_dim),
+        )
+
+    def __call__(self, mu_row, mu_col, sigma):
+        B = mu_row.shape[0]
+        R = self.rows
+        C = self.cols
+        D = self.hidden_dim
+        W = self.window_size
+
+        r_center = mu_row * (R - 1)
+        r_start = jnp.clip(r_center - W // 2, 0, R - W).astype(jnp.int32)
+        c_center = mu_col * (C - 1)
+        c_start = jnp.clip(c_center - W // 2, 0, C - W).astype(jnp.int32)
+
+        def fetch_window(r_s, c_s):
+            return lax.dynamic_slice(self.params_storage, (r_s, c_s, 0), (W, W, D))
+
+        windows = jax.vmap(fetch_window)(r_start, c_start)
+
+        r_idx = jnp.arange(W)[None, :] + r_start[:, None]
+        c_idx = jnp.arange(W)[None, :] + c_start[:, None]
+        r_dist = r_idx - r_center[:, None]
+        c_dist = c_idx - c_center[:, None]
+
+        sigma_sq = (sigma + 1e-6) ** 2
+        r_w = jnp.exp(-r_dist ** 2 / (2 * sigma_sq[:, None]))
+        c_w = jnp.exp(-c_dist ** 2 / (2 * sigma_sq[:, None]))
+
+        w_2d = jnp.einsum("bi,bj->bij", r_w, c_w) + 1e-6
+        w_2d = w_2d / jnp.sum(w_2d, axis=(-2, -1), keepdims=True)
+
+        aggregated = jnp.einsum("bij,bijd->bd", w_2d, windows)
+        flat_start = r_start * C + c_start
+
+        return aggregated, flat_start
+
+
+# ═══════════════════════════════════════════════════════════════════════════════
+# REASONING - Adaptive Compute Controller
+# ═══════════════════════════════════════════════════════════════════════════════
+
+class AdaptiveComputeController(nn.Module):
+    hidden_dim: int
+    max_loops: int = 8
+    halt_threshold: float = 0.99
+
+    def setup(self):
+        self.halt_net = nn.Sequential(
+            [nn.Dense(self.hidden_dim // 4), nn.gelu, nn.Dense(1), nn.sigmoid]
+        )
+        self.state_gate = nn.Sequential([nn.Dense(self.hidden_dim), nn.sigmoid])
+        self.state_transform = nn.Dense(self.hidden_dim)
+        self.state_norm = nn.LayerNorm()
+        self.loop_embed = nn.Embed(32, self.hidden_dim)
+
+    def __call__(self, state_hidden, step_output, loop_count, current_halt_prob, halted_mask):
+        loop_idx = jnp.array([loop_count], dtype=jnp.int32)
+        emb = self.loop_embed(loop_idx)
+        step_output = step_output + emb
+
+        combined = jnp.concatenate([step_output, state_hidden], axis=-1)
+        g = self.state_gate(combined)
+
+        candidate_state = g * self.state_transform(step_output) + (1 - g) * state_hidden
+        candidate_state = self.state_norm(candidate_state)
+
+        hp = self.halt_net(candidate_state)
+
+        still_running_mask = 1.0 - halted_mask
+        new_halt_prob = current_halt_prob + hp * still_running_mask
+
+        is_halted_now = (new_halt_prob >= self.halt_threshold).astype(jnp.float32)
+        final_halted_mask = jnp.maximum(halted_mask, is_halted_now)
+
+        return candidate_state, new_halt_prob, final_halted_mask
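The halting bookkeeping at the end of `AdaptiveComputeController.__call__` is easier to follow for a single position. The following is a minimal plain-Python sketch; the function name `accumulate_halting` and the hand-picked per-step probabilities are illustrative assumptions, since in the model they come from `halt_net`.

```python
def accumulate_halting(step_probs, halt_threshold=0.99):
    """Accumulate per-step halting probabilities for one position.

    Returns a list of (cumulative_halt_prob, halted_flag) after each step.
    Once the cumulative probability reaches the threshold the flag latches
    to 1.0 and later steps no longer add to the total.
    """
    halt_prob, halted = 0.0, 0.0
    history = []
    for hp in step_probs:
        still_running = 1.0 - halted
        halt_prob = halt_prob + hp * still_running
        halted = max(halted, 1.0 if halt_prob >= halt_threshold else 0.0)
        history.append((halt_prob, halted))
    return history


if __name__ == "__main__":
    print(accumulate_halting([0.5, 0.5, 0.3]))
```

The `max` mirrors `jnp.maximum(halted_mask, is_halted_now)`: a position that has halted stays halted for the remaining loops.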
+
+
+# ═══════════════════════════════════════════════════════════════════════════════
+# DPSNR - Full model
+# ═══════════════════════════════════════════════════════════════════════════════
+
+class DPSNR(nn.Module):
+    config: DPSNRConfig
+
+    def setup(self):
+        self.controller = TinyController(self.config)
+        self.indexer = LearnedIndexer(
+            self.config.controller_hidden_dim,
+            num_heads=self.config.num_indexer_heads,
+            sigma_min=self.config.sigma_min,
+            sigma_max=self.config.sigma_max,
+        )
+
+        if self.config.use_2d_pool:
+            axis_window = max(2, int(self.config.max_k ** 0.5))
+            self.pool = CoordinateMassivePool2D(
+                rows=self.config.pool_grid_rows,
+                cols=self.config.pool_grid_cols,
+                hidden_dim=self.config.controller_hidden_dim,
+                window_size=axis_window,
+            )
+        else:
+            self.pool = CoordinateMassivePool(
+                PoolConfig(
+                    self.config.pool_total_vectors,
+                    self.config.controller_hidden_dim,
+                ),
+                window_size=self.config.max_k,
+            )
+
+        self.acc = AdaptiveComputeController(
+            self.config.controller_hidden_dim,
+            self.config.max_reasoning_loops,
+            self.config.halt_threshold,
+        )
+        self.retrieval_integrator = nn.Sequential(
+            [
+                nn.Dense(self.config.controller_hidden_dim),
+                nn.gelu,
+                nn.Dense(self.config.controller_hidden_dim),
+                nn.LayerNorm(),
+            ]
+        )
+
+    def __call__(self, input_ids, deterministic=True, sigma_max_scale: float = 1.0):
+        state_hidden, all_indices, mean_sigma = self._encode_hidden(
+            input_ids, deterministic, sigma_max_scale
+        )
+        logits = self.controller.decode(state_hidden)
+        return logits, (self.config.max_reasoning_loops, all_indices, mean_sigma)
+
+    def encode_to_hidden(self, input_ids, deterministic=True, sigma_max_scale: float = 1.0):
+        state_hidden, all_indices, mean_sigma = self._encode_hidden(
+            input_ids, deterministic, sigma_max_scale
+        )
+        return state_hidden, (self.config.max_reasoning_loops, all_indices, mean_sigma)
+
+    def _encode_hidden(self, input_ids, deterministic=True, sigma_max_scale: float = 1.0):
+        hidden = self.controller(input_ids, deterministic)
+        state_hidden = hidden
+        B, T, D = hidden.shape
+
+        halt_prob = jnp.zeros((B, T, 1), dtype=hidden.dtype)
+        halted_mask = jnp.zeros((B, T, 1), dtype=hidden.dtype)
+
+        # Warm-up calls: run every sub-module once outside jax.lax.scan so that
+        # Flax sets up their parameters before the scan body is traced.
+        _mu, _sigma = self.indexer(
+            jnp.zeros((B, T, D)), sigma_max_scale=sigma_max_scale
+        )
+        if self.config.use_2d_pool:
+            _ = self.pool(jnp.zeros((B,)), jnp.zeros((B,)), jnp.zeros((B,)))
+        else:
+            _ = self.pool(jnp.zeros((B,)), jnp.zeros((B,)))
+        _ = self.retrieval_integrator(
+            jnp.zeros((B, T, D + self.config.controller_hidden_dim))
+        )
+        _ = self.acc(state_hidden, state_hidden, 0, halt_prob, halted_mask)
+
+        use_2d = self.config.use_2d_pool
+        H = self.config.num_indexer_heads
+
+        def reasoning_step(carry, i):
+            s_hidden, h_prob, h_mask = carry
+            prev_s_hidden = s_hidden
+
+            mu, sigma = self.indexer(s_hidden, sigma_max_scale=sigma_max_scale)
+
+            all_retrieved = []
+            all_start_indices = []
+
+            if use_2d:
+                # Pair head h (row coordinate) with head h + heads_per_dim (column coordinate).
+                heads_per_dim = max(1, H // 2)
+                for h in range(heads_per_dim):
+                    h_row = h
+                    h_col = min(h + heads_per_dim, H - 1)
+                    sigma_h = (sigma[:, h_row] + sigma[:, h_col]) / 2.0
+                    retrieved_h, start_idx_h = self.pool(
+                        mu[:, h_row], mu[:, h_col], sigma_h
+                    )
+                    all_retrieved.append(retrieved_h)
+                    all_start_indices.append(start_idx_h)
+            else:
+                for h in range(H):
+                    retrieved_h, start_idx_h = self.pool(mu[:, h], sigma[:, h])
+                    all_retrieved.append(retrieved_h)
+                    all_start_indices.append(start_idx_h)
+
+            retrieved = jnp.mean(jnp.stack(all_retrieved, axis=1), axis=1)
+            start_indices = jnp.concatenate(all_start_indices, axis=0)
+            mean_sigma_step = jnp.mean(sigma)
+
+            retrieved_expanded = jnp.expand_dims(retrieved, 1).repeat(T, axis=1)
+            combined = jnp.concatenate([s_hidden, retrieved_expanded], axis=-1)
+            integrated = self.retrieval_integrator(combined)
+
+            new_s_hidden, h_prob, new_h_mask = self.acc(
+                s_hidden, s_hidden + integrated, i, h_prob, h_mask,
+            )
+
+            # Positions that had already halted before this step keep their previous state.
+            update_mask = 1.0 - h_mask
+            s_hidden = update_mask * new_s_hidden + h_mask * prev_s_hidden
+
+            # Keep the carry dtype stable across scan iterations.
+            carry_dtype = prev_s_hidden.dtype
+            s_hidden = s_hidden.astype(carry_dtype)
+            h_prob = h_prob.astype(carry_dtype)
+            new_h_mask = new_h_mask.astype(carry_dtype)
+
+            return (s_hidden, h_prob, new_h_mask), (start_indices, mean_sigma_step)
+
+        _scan_fn = reasoning_step
+        if self.config.gradient_checkpointing:
+            _scan_fn = jax.checkpoint(reasoning_step)
+
+        init_carry = (state_hidden, halt_prob, halted_mask)
+        (state_hidden, halt_prob, halted_mask), (all_indices, sigma_per_loop) = (
+            jax.lax.scan(
+                _scan_fn,
+                init_carry,
+                jnp.arange(self.config.max_reasoning_loops),
+            )
+        )
+
+        all_indices = jnp.transpose(all_indices, (1, 0))
+        mean_sigma = jnp.mean(sigma_per_loop)
+
+        return state_hidden, all_indices, mean_sigma
+
+
+# ═══════════════════════════════════════════════════════════════════════════════
+# TRAIN STATE - Minimal, just enough to restore the checkpoint pytree
+# ═══════════════════════════════════════════════════════════════════════════════
+
+class TrainState(train_state.TrainState):
+    rng: Any
+    pool_m: jnp.ndarray
+    pool_v: jnp.ndarray
+    window_size: int = struct.field(pytree_node=False)
+    learning_rate_fn: Callable[[int], float] = struct.field(pytree_node=False)
+    sigma_anneal_fn: Callable[[int], float] = struct.field(pytree_node=False)
+
+
+def _create_dummy_state(rng, config):
+    """Create a dummy TrainState with the correct pytree structure for checkpoint restore."""
+    model = DPSNR(config)
+    dummy_input = jnp.ones((1, config.max_seq_len), dtype=jnp.int32)
+    variables = model.init(rng, dummy_input)
+    params = variables["params"]
+
+    flat_params = traverse_util.flatten_dict(params)
+    pool_key = ("pool", "params_storage")
+    pool_params = flat_params[pool_key]
+
+    dense_flat_params = {k: v for k, v in flat_params.items() if k != pool_key}
+    dense_params = traverse_util.unflatten_dict(dense_flat_params)
+
+    learning_rate_fn = lambda step: config.learning_rate
+
+    tx = optax.chain(
+        optax.clip_by_global_norm(1.0),
+        optax.adamw(learning_rate=learning_rate_fn),
+    )
+    opt_state = tx.init(dense_params)
+
+    pool_m = jnp.zeros_like(pool_params)
+    pool_v = jnp.zeros_like(pool_params)
+
+    sigma_anneal_fn = lambda step: 1.0
+
+    return TrainState(
+        step=jnp.array(0, dtype=jnp.int32),
+        apply_fn=model.apply,
+        params=params,
+        tx=tx,
+        opt_state=opt_state,
+        rng=rng,
+        pool_m=pool_m,
+        pool_v=pool_v,
+        window_size=config.max_k,
+        learning_rate_fn=learning_rate_fn,
+        sigma_anneal_fn=sigma_anneal_fn,
+    )
+
+
+# ═══════════════════════════════════════════════════════════════════════════════
+# INFERENCE CONTAINER
+# ═══════════════════════════════════════════════════════════════════════════════
+InferenceModel = namedtuple("InferenceModel", ["apply_fn", "params", "step"])
+
+
+# ═══════════════════════════════════════════════════════════════════════════════
+# TOKENIZER
+# ═══════════════════════════════════════════════════════════════════════════════
+def load_tokenizer():
+    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+    return tokenizer
+
+
+# ═══════════════════════════════════════════════════════════════════════════════
+# CHECKPOINT LOADING
+# ═══════════════════════════════════════════════════════════════════════════════
+def load_checkpoint():
+    """Load trained weights from checkpoint. Returns an InferenceModel (apply_fn, params, step)."""
+    rng = jax.random.PRNGKey(0)
+
+    cpu = jax.devices("cpu")[0]
+    print("[Init] Creating model skeleton on CPU...")
+    with jax.default_device(cpu):
+        dummy_state = _create_dummy_state(rng, CONFIG)
+        dummy_state = jax.device_put(dummy_state, cpu)
+
+    abs_ckpt = os.path.abspath(CHECKPOINT_DIR)
+    checkpointer = orbax.checkpoint.PyTreeCheckpointer()
+    restore_args = orbax.checkpoint.checkpoint_utils.construct_restore_args(dummy_state)
+
+    mgr = orbax.checkpoint.CheckpointManager(abs_ckpt, checkpointer)
+    latest_step = mgr.latest_step()
+
+    if latest_step is not None:
+        print(f"[Checkpoint] Restoring step {latest_step} from {abs_ckpt}")
+        state = mgr.restore(
+            latest_step,
+            items=dummy_state,
+            restore_kwargs={"restore_args": restore_args},
+        )
+    else:
+        # No numbered step directory: look for a checkpoint saved directly in the folder.
+        target = None
+        for sub in ("default", ""):
+            p = os.path.join(abs_ckpt, sub) if sub else abs_ckpt
+            if os.path.exists(os.path.join(p, "_METADATA")):
+                target = p
+                break
+        if target is None:
+            raise FileNotFoundError(f"No valid checkpoint found in {abs_ckpt}")
+        print(f"[Checkpoint] Restoring directly from {target}")
+        state = checkpointer.restore(target, item=dummy_state, restore_args=restore_args)
+
+    step = int(state.step)
+    apply_fn = state.apply_fn
+    params = state.params
+
+    # Drop the optimizer state and other training-only fields before moving to the accelerator.
+    del dummy_state, state
+
+    if PLATFORM != "cpu":
+        print(f"[Device] Moving model params to {DEVICE}...")
+        params = jax.device_put(params, DEVICE)
+
+    print(f"[Checkpoint] Loaded at training step {step}")
+    return InferenceModel(apply_fn=apply_fn, params=params, step=step)
+
+
+# ═══════════════════════════════════════════════════════════════════════════════
+# JIT FORWARD PASS
+# ═══════════════════════════════════════════════════════════════════════════════
+@partial(jax.jit, static_argnums=(0,))
+def _forward(apply_fn, params, input_ids):
+    logits, _ = apply_fn({"params": params}, input_ids, deterministic=True)
+    return logits
+
+
+# ═══════════════════════════════════════════════════════════════════════════════
+# TEXT GENERATION
+# ═══════════════════════════════════════════════════════════════════════════════
+def generate(
+    model: InferenceModel,
+    prompt: str,
+    tokenizer,
+    rng,
+    max_tokens: int = 100,
+    temperature: float = 0.7,
+    top_k: int = 40,
+    repetition_penalty: float = 1.2,
+):
+    """Autoregressive generation over a fixed-size input buffer, so the jitted forward pass compiles once."""
+    input_ids = tokenizer.encode(prompt, return_tensors="np")
+    eos_id = tokenizer.eos_token_id
+    prompt_len = input_ids.shape[1]
+    max_seq = CONFIG.max_seq_len
+
+    # Truncate prompts that do not fit in the model's context window.
+    if prompt_len > max_seq:
+        input_ids = input_ids[:, :max_seq]
+        prompt_len = max_seq
+
+    buf = jnp.zeros((1, max_seq), dtype=jnp.int32)
+    buf = buf.at[:, :prompt_len].set(input_ids)
+
+    gen_buf = jnp.zeros((max_tokens,), dtype=jnp.int32)
+    n_gen = 0
+
+    for step in range(max_tokens):
+        pos = prompt_len + step
+        if pos >= max_seq:
+            break
+
+        logits = _forward(model.apply_fn, model.params, buf)
+        next_logits = logits[0, pos - 1, :]
+
+        # Repetition penalty
+        if n_gen > 0:
+            prev = gen_buf[:n_gen]
+            vocab = next_logits.shape[-1]
+            mask = jnp.zeros(vocab, dtype=jnp.bool_)
+            mask = mask.at[prev].set(True)
+            penalized = jnp.where(
+                next_logits > 0,
+                next_logits / repetition_penalty,
+                next_logits * repetition_penalty,
+            )
+            next_logits = jnp.where(mask, penalized, next_logits)
+
+        # Top-k filtering
+        k = min(top_k, next_logits.shape[-1])
+        vals, _ = jax.lax.top_k(next_logits, k=k)
+        threshold = vals[-1]
+        next_logits = jnp.where(next_logits < threshold, -1e10, next_logits)
+
+        # Temperature sampling
+        rng, key = jax.random.split(rng)
+        token = jax.random.categorical(key, next_logits / max(temperature, 1e-8))
+        token_int = int(token)
+
+        buf = buf.at[0, pos].set(token_int)
+        gen_buf = gen_buf.at[n_gen].set(token_int)
+        n_gen += 1
+
+        if token_int == eos_id:
+            break
+
+    return tokenizer.decode(
+        buf[0, prompt_len : prompt_len + n_gen].tolist(),
+        skip_special_tokens=True,
+    )
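The two logit adjustments inside the `generate` loop, repetition penalty followed by top-k filtering, can be checked on a toy vocabulary. The following is a minimal plain-Python sketch over a list of floats; the function name `penalize_and_filter` is hypothetical, and it mirrors the order of operations in the diff above rather than replacing the JAX code.

```python
def penalize_and_filter(logits, generated_ids, top_k=40, repetition_penalty=1.2):
    """Apply a repetition penalty, then keep only the top-k logits.

    Tokens already generated have positive logits divided by the penalty and
    non-positive logits multiplied by it, so both move toward "less likely".
    Logits below the k-th largest value are replaced with -1e10.
    """
    out = list(logits)
    for tok in set(generated_ids):
        if out[tok] > 0:
            out[tok] = out[tok] / repetition_penalty
        else:
            out[tok] = out[tok] * repetition_penalty

    k = min(top_k, len(out))
    threshold = sorted(out, reverse=True)[k - 1]
    return [x if x >= threshold else -1e10 for x in out]


if __name__ == "__main__":
    print(penalize_and_filter([2.0, 1.0, -1.0, 0.5], [0, 2], top_k=2,
                              repetition_penalty=2.0))
```

Because the comparison is `>=` against the k-th largest value, ties at the threshold are all kept, so more than `k` tokens can survive the filter.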
+
+
+# ═══════════════════════════════════════════════════════════════════════════════
+# MAIN
+# ═══════════════════════════════════════════════════════════════════════════════
+def main():
+    parser = argparse.ArgumentParser(description="DPSNR Large - Inference")
+    parser.add_argument("--prompt", type=str, default=None, help="Input prompt (omit for interactive mode)")
+    parser.add_argument("--max_tokens", type=int, default=100, help="Max tokens to generate (default: 100)")
+    parser.add_argument("--temp", type=float, default=0.7, help="Sampling temperature (default: 0.7)")
+    parser.add_argument("--top_k", type=int, default=40, help="Top-k sampling (default: 40)")
+    parser.add_argument("--penalty", type=float, default=1.2, help="Repetition penalty (default: 1.2)")
+    parser.add_argument("--checkpoint_dir", type=str, default=None, help="Override checkpoint path")
+    args = parser.parse_args()
+
+    if args.checkpoint_dir:
+        global CHECKPOINT_DIR
+        CHECKPOINT_DIR = args.checkpoint_dir
+
+    print("=" * 60)
+    print(" DPSNR Large - Loading Model")
+    print("=" * 60)
+    tokenizer = load_tokenizer()
+    model = load_checkpoint()
+
+    # Warmup: compile and run the forward pass once
+    print("[Warmup] Compiling forward pass...")
+    t0 = time.time()
+    warmup_ids = jnp.zeros((1, CONFIG.max_seq_len), dtype=jnp.int32)
+    # Block on the result so the timing below covers execution, not only async dispatch.
+    _forward(model.apply_fn, model.params, warmup_ids).block_until_ready()
+    print(f"[Warmup] Done in {time.time() - t0:.1f}s")
+
+    rng = jax.random.PRNGKey(42)
+
+    def run(prompt: str):
+        nonlocal rng
+        rng, key = jax.random.split(rng)
+        t0 = time.time()
+        output = generate(
+            model, prompt, tokenizer, key,
+            max_tokens=args.max_tokens,
+            temperature=args.temp,
+            top_k=args.top_k,
+            repetition_penalty=args.penalty,
+        )
+        elapsed = time.time() - t0
+        print(f"\n{'─' * 50}")
+        print(f"Prompt: {prompt}")
+        print(f"Generated: {output}")
+        print(f"Time: {elapsed:.2f}s")
+        print(f"{'─' * 50}\n")
+
+    if args.prompt:
+        run(args.prompt)
+    else:
+        print("\n╔══════════════════════════════════════════════════╗")
+        print("║ DPSNR Interactive Inference ║")
+        print("║ Type 'exit' or 'quit' to stop ║")
+        print("╚══════════════════════════════════════════════════╝\n")
+        while True:
+            try:
+                user_input = input(">>> ")
+                if user_input.strip().lower() in ("exit", "quit"):
+                    break
+                if not user_input.strip():
+                    continue
+                run(user_input)
+            except (EOFError, KeyboardInterrupt):
+                print("\nExiting...")
+                break
+
+
+if __name__ == "__main__":
+    main()