Spaces:
Sleeping
Sleeping
Add BMO RDT-MoE: Recurrent-Depth Transformer with Chain-of-Experts latent simmering engine
Browse files
Papers: RDT (2603.21676) + CoE (2506.18945) + PonderNet (2107.05407) + TRC² (2602.22479)
- GRU-gated recurrence with identity bias (b_z=-2.0, 88% retain)
- Chain-of-Experts with per-iteration routers and shared experts
- PonderNet dynamic halting with geometric prior KL regularization
- Thalamic query modulation from limbic state (TRC²)
- LayerScale initialization (1e-4) for stable deep recurrence
- Depth embeddings for loop-step awareness
- Full BMO integration: limbic modulation, entropy noise, PFC grit"
- project_bmo/bmo_rdt_moe.py +1200 -0
project_bmo/bmo_rdt_moe.py
ADDED
|
@@ -0,0 +1,1200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
BMO Recurrent-Depth MoE — Latent Simmering Logic Engine
|
| 3 |
+
==========================================================
|
| 4 |
+
Implements a Recurrent Depth Transformer with Chain-of-Experts
|
| 5 |
+
for iterative latent reasoning in BMO's cognitive loop.
|
| 6 |
+
|
| 7 |
+
Paper foundations (every equation cited):
|
| 8 |
+
|
| 9 |
+
1. Depth-Recurrent Transformer (arxiv:2603.21676)
|
| 10 |
+
- Shared-weight transformer block applied T times in latent space
|
| 11 |
+
- Identity-biased GRU gating (Eq. 2-3, b_z = -2.0 → 88% retain)
|
| 12 |
+
- LayerScale vectors Λ_init = 1e-4 (Eq. 6-8)
|
| 13 |
+
- Depth embeddings e_t added before each loop (Appendix B)
|
| 14 |
+
- Silent thinking: loss only on FINAL output, not intermediates
|
| 15 |
+
|
| 16 |
+
2. Chain-of-Experts (arxiv:2506.18945)
|
| 17 |
+
- Sequential expert chaining with per-iteration routers (Eq. 5-7)
|
| 18 |
+
- x^(t) = Σ g_{t,i} · E_i(x^{t-1}) + x^{t-1} (inner residual)
|
| 19 |
+
- Iteration-specific gating: different TopK selection per step
|
| 20 |
+
- Shared experts always active + routed experts selected per step
|
| 21 |
+
|
| 22 |
+
3. PonderNet (arxiv:2107.05407)
|
| 23 |
+
- Dynamic halting: λ_n per step, geometric distribution (Eq. 1-2)
|
| 24 |
+
- p_n = λ_n · Π_{j<n}(1-λ_j) — generalized geometric
|
| 25 |
+
- KL regularization against geometric prior (Eq. 3)
|
| 26 |
+
- Evaluation: sample Bernoulli(λ_n) at each step, halt on 1
|
| 27 |
+
|
| 28 |
+
4. Coconut (arxiv:2412.06769)
|
| 29 |
+
- Hidden state fed directly as next input (latent mode)
|
| 30 |
+
- No token generation during thinking — pure state refinement
|
| 31 |
+
- Final layer norm keeps magnitudes reasonable
|
| 32 |
+
|
| 33 |
+
Architecture:
|
| 34 |
+
|
| 35 |
+
ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 36 |
+
β BMO RDT-MoE Architecture β
|
| 37 |
+
ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ€
|
| 38 |
+
β β
|
| 39 |
+
β INPUT: h β R^{LΓd} (from model hidden states) β
|
| 40 |
+
β β
|
| 41 |
+
β ββββ PRELUDE (2 unique layers) ββββββββββββββββββββββββββββ β
|
| 42 |
+
β β Pre-LayerNorm β MHSA β FFN (no weight sharing) β β
|
| 43 |
+
β β Converts raw embeddings to "thinking-ready" latents β β
|
| 44 |
+
β βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ β
|
| 45 |
+
β β β
|
| 46 |
+
β βΌ β
|
| 47 |
+
β ββββ RECURRENT LOOP (1-T iterations) ββββββββββββββββββββββ β
|
| 48 |
+
β β β β
|
| 49 |
+
β β For t = 1..T: β β
|
| 50 |
+
β β 1. Add depth embedding: Δ€ = H + e_t β β
|
| 51 |
+
β β 2. Thalamic modulation (from limbic state) β β
|
| 52 |
+
β β 3. Shared MHSA: H' = Δ€ + Ξ_attn β MHSA(LN(Δ€)) β β
|
| 53 |
+
β β 4. Chain-of-Experts FFN: β β
|
| 54 |
+
β β - Shared experts: always active β β
|
| 55 |
+
β β - Routed experts: TopK per iteration router β β
|
| 56 |
+
β β H'' = H' + Ξ_ffn β CoE(LN(H')) β β
|
| 57 |
+
β β 5. GRU gate (identity-biased): β β
|
| 58 |
+
β β z = Ο([HΜ;H^{t-1}]Β·W_z + b_z) [b_z = -2.0] β β
|
| 59 |
+
β β H^{t} = z β HΜ + (1-z) β H^{t-1} β β
|
| 60 |
+
β β 6. Halt head: Ξ»_t = Ο(MLP(mean(H^{t}))) β β
|
| 61 |
+
β β If Bernoulli(Ξ»_t) = 1 β break (eval only) β β
|
| 62 |
+
β β β β
|
| 63 |
+
β βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ β
|
| 64 |
+
β β β
|
| 65 |
+
β βΌ β
|
| 66 |
+
β ββββ CODA (2 unique layers) βββββββββββββββββββββββββββββββ β
|
| 67 |
+
β β Post-thinking refinement β output projection β β
|
| 68 |
+
β β Unique weights (not shared with prelude or loop) β β
|
| 69 |
+
β βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ β
|
| 70 |
+
β β
|
| 71 |
+
β OUTPUT: h_out β R^{LΓd} β
|
| 72 |
+
ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 73 |
+
|
| 74 |
+
Integration points with BMO:
|
| 75 |
+
- Limbic state → thalamic query modulation inside the loop
|
| 76 |
+
- Entropy layer → noise on expert routing weights
|
| 77 |
+
- Probabilistic gating → halt decision is STOCHASTIC
|
| 78 |
+
- PFC grit → affects maximum loop count under stress
|
| 79 |
+
|
| 80 |
+
HONESTY: This is a neural network module with real gradient flow.
|
| 81 |
+
The "simmering" metaphor describes iterative state refinement.
|
| 82 |
+
It is NOT consciousness. It IS genuine multi-step computation
|
| 83 |
+
where the model can allocate more compute to harder problems.
|
| 84 |
+
"""
|
| 85 |
+
|
| 86 |
+
from __future__ import annotations
|
| 87 |
+
|
| 88 |
+
import math
|
| 89 |
+
import random
|
| 90 |
+
from typing import Optional, Tuple, Dict, List
|
| 91 |
+
|
| 92 |
+
import torch
|
| 93 |
+
import torch.nn as nn
|
| 94 |
+
import torch.nn.functional as F
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 98 |
+
# §1 — EXPERT MODULES (Fine-Grained MoE building blocks)
|
| 99 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 100 |
+
|
| 101 |
+
class Expert(nn.Module):
    """
    A single gated feed-forward expert (SwiGLU-style).

    Computes ``down_proj(silu(gate_proj(x)) * up_proj(x))``: the SiLU
    non-linearity is applied to the *gate* branch, whose output then scales
    the linear *up* branch element-wise before the result is projected back
    to ``d_model``. None of the three projections carries a bias term.

    Args:
        d_model: size of the last dimension of the input and of the output.
        d_expert: width of the expert's hidden (intermediate) layer.
    """

    def __init__(self, d_model: int, d_expert: int):
        super().__init__()
        # Creation order (gate, up, down) is kept so seeded initialisation
        # and state-dict key order stay the same.
        self.gate_proj = nn.Linear(d_model, d_expert, bias=False)
        self.up_proj = nn.Linear(d_model, d_expert, bias=False)
        self.down_proj = nn.Linear(d_expert, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the expert to ``x`` (``[..., d_model]``); the shape is preserved."""
        gated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class SharedExpert(nn.Module):
    """
    Always-active expert: a thin wrapper around a single :class:`Expert`.

    The wrapper adds no computation of its own; it exists so that shared
    experts are a distinct module type from routed experts. The file's notes
    cite the Chain-of-Experts paper (Appendix B) for the claim that shared
    experts improve stability — that claim is not something this code checks.

    Args:
        d_model: size of the last dimension of the input and of the output.
        d_expert: hidden width handed to the wrapped :class:`Expert`.
    """

    def __init__(self, d_model: int, d_expert: int):
        super().__init__()
        self.expert = Expert(d_model, d_expert)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Delegate to the wrapped expert; the output has the same shape as ``x``."""
        return self.expert(x)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 136 |
+
# §2 — CHAIN-OF-EXPERTS LAYER (per-iteration routing)
|
| 137 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 138 |
+
|
| 139 |
+
class ChainOfExpertsFFN(nn.Module):
    """
    Chain-of-Experts feed-forward layer with one router per loop iteration.

    Each call returns the sum of three terms:

    * the mean of the ``n_shared`` always-active shared experts,
    * a weighted sum of the ``top_k`` routed experts chosen by the router
      that belongs to the current ``iteration``,
    * the input itself (inner residual).

    Routing is per *sequence*, not per token: the router scores the
    mean-pooled sequence, so every token of a batch element is sent to the
    same experts.

    The file's notes attribute the formulation to Chain-of-Experts
    (arxiv:2506.18945, Eq. 5-7):
        x^(t)   = sum_i g_{t,i} * E_i(x^(t-1)) + x^(t-1)
        g_{t,i} = TopK(Softmax(e_{t,i}^T x^(t-1)))

    Args:
        d_model: feature size of the input and output.
        d_expert: hidden width of every expert.
        n_shared: number of always-active shared experts (0 disables them).
        n_routed: number of routed experts the router chooses from.
        top_k: how many routed experts are used per batch element.
        max_iterations: number of per-iteration routers to allocate.
        entropy_noise: base std-dev of Gaussian noise added to the router
            logits while the module is in training mode.

    Raises:
        ValueError: if ``max_iterations < 1`` or ``top_k`` is not in
            ``[0, n_routed]``. Both would otherwise only fail later, inside
            ``forward``, with a less helpful indexing/top-k error.
    """

    def __init__(
        self,
        d_model: int,
        d_expert: int,
        n_shared: int = 2,
        n_routed: int = 8,
        top_k: int = 2,
        max_iterations: int = 16,
        entropy_noise: float = 0.05,
    ):
        super().__init__()
        if max_iterations < 1:
            raise ValueError(f"max_iterations must be >= 1, got {max_iterations}")
        if not 0 <= top_k <= n_routed:
            raise ValueError(
                f"top_k must be in [0, n_routed={n_routed}], got {top_k}"
            )

        self.d_model = d_model
        self.n_shared = n_shared
        self.n_routed = n_routed
        self.top_k = top_k
        self.max_iterations = max_iterations
        self.entropy_noise = entropy_noise

        # Shared experts (always active)
        self.shared_experts = nn.ModuleList([
            SharedExpert(d_model, d_expert) for _ in range(n_shared)
        ])

        # Routed experts (selected via TopK)
        self.routed_experts = nn.ModuleList([
            Expert(d_model, d_expert) for _ in range(n_routed)
        ])

        # One bias-free linear router (d_model -> n_routed) per iteration, so
        # the expert selection can differ from one loop step to the next.
        self.routers = nn.ModuleList([
            nn.Linear(d_model, n_routed, bias=False)
            for _ in range(max_iterations)
        ])

    def forward(
        self,
        x: torch.Tensor,
        iteration: int,
        limbic_entropy_sigma: float = 0.0,
    ) -> Tuple[torch.Tensor, dict]:
        """
        Run one Chain-of-Experts step.

        Args:
            x: ``[batch, seq_len, d_model]`` input.
            iteration: current loop iteration (0-indexed). Values at or
                beyond ``max_iterations`` reuse the last router.
            limbic_entropy_sigma: extra router-noise std-dev, added to
                ``entropy_noise``; only applied in training mode.

        Returns:
            output: ``[batch, seq_len, d_model]``.
            diagnostics: dict with keys ``iteration``, ``router_idx``,
                ``top_experts`` and ``expert_weights`` (nested lists,
                ``[batch][top_k]``) and ``routing_entropy`` (float, mean
                entropy of the full softmax over routed experts).
        """
        # -- Shared experts (always on), averaged --
        shared_out = torch.zeros_like(x)
        for expert in self.shared_experts:
            shared_out = shared_out + expert(x)
        shared_out = shared_out / max(1, self.n_shared)

        # -- Router for this iteration (clamped to the last one) --
        router_idx = min(iteration, self.max_iterations - 1)
        router = self.routers[router_idx]

        # Sequence-level routing: mean-pool the tokens, then score experts.
        router_input = x.mean(dim=1)                      # [B, D]
        logits = router(router_input)                     # [B, n_routed]

        # Gaussian noise on the logits, training mode only.
        total_noise = self.entropy_noise + limbic_entropy_sigma
        if total_noise > 0 and self.training:
            logits = logits + torch.randn_like(logits) * total_noise

        # Softmax over all routed experts, keep the top-k, renormalise them.
        scores = F.softmax(logits, dim=-1)                # [B, n_routed]
        topk_vals, topk_idx = torch.topk(scores, self.top_k, dim=-1)  # [B, top_k]
        topk_weights = topk_vals / (topk_vals.sum(dim=-1, keepdim=True) + 1e-8)

        # -- Routed experts --
        # Group batch rows by selected expert so each expert runs once on all
        # of its rows, instead of once per (row, k) pair with a host sync each.
        routed_out = torch.zeros_like(x)                  # [B, L, D]
        for expert_id in torch.unique(topk_idx).tolist():
            chosen = topk_idx == expert_id                # [B, top_k] bool
            rows = chosen.any(dim=-1).nonzero(as_tuple=True)[0]
            # top-k indices are distinct within a row, so at most one slot
            # matches and the sum just picks that slot's weight.
            row_weight = (topk_weights * chosen).sum(dim=-1)[rows]
            expert_out = self.routed_experts[expert_id](x[rows])   # [n_rows, L, D]
            contribution = expert_out * row_weight.view(-1, 1, 1)
            routed_out = routed_out.index_add(
                0, rows, contribution.to(routed_out.dtype)
            )

        # Combine: shared + routed + inner residual
        output = shared_out + routed_out + x

        diagnostics = {
            "iteration": iteration,
            "router_idx": router_idx,
            "top_experts": topk_idx.detach().cpu().tolist(),
            "expert_weights": topk_weights.detach().cpu().tolist(),
            "routing_entropy": -(scores * (scores + 1e-10).log()).sum(-1).mean().item(),
        }

        return output, diagnostics
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 271 |
+
# §3 — LAYERSCALE (from arxiv:2603.21676, Appendix A)
|
| 272 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 273 |
+
|
| 274 |
+
class LayerScale(nn.Module):
    """
    Per-channel learnable scaling of a sub-layer's output.

    Holds one learnable factor per feature (``gamma``, shape ``[d_model]``)
    and multiplies the input by it along the last dimension. With the
    default ``init_value`` of 1e-4 the scaled branch starts out almost
    switched off, so a residual block using it begins close to identity.

    The file's notes attribute the technique to Touvron et al. (2021) and
    its use here to the RDT paper.

    Args:
        d_model: number of channels (size of the last input dimension).
        init_value: initial value of every entry of ``gamma``.
    """

    def __init__(self, d_model: int, init_value: float = 1e-4):
        super().__init__()
        self.gamma = nn.Parameter(torch.full((d_model,), init_value))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x * gamma``, broadcasting ``gamma`` over leading dimensions."""
        return x * self.gamma
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 294 |
+
# §4 — IDENTITY-BIASED GRU GATE (from arxiv:2603.21676, Eq. 2-3)
|
| 295 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 296 |
+
|
| 297 |
+
class IdentityBiasedGate(nn.Module):
    """
    GRU-style update gate that blends a new candidate state with the old one.

    Computes (per the file's notes, RDT paper Eq. 2-3):
        z     = sigmoid([h_new ; h_prev] @ W_z + b_z)
        h_out = z * h_new + (1 - z) * h_prev

    The gate bias ``b_z`` is initialised to ``bias_init`` (default -2.0).
    sigmoid(-2.0) is about 0.12, so whenever the weight term is close to
    zero the output keeps roughly 88% of ``h_prev``. The weight matrix
    itself keeps PyTorch's default ``nn.Linear`` initialisation, so the
    actual gate value at initialisation also depends on the inputs.

    Args:
        d_model: feature size of both states.
        bias_init: constant the gate bias is initialised to.
    """

    def __init__(self, d_model: int, bias_init: float = -2.0):
        super().__init__()
        self.gate = nn.Linear(2 * d_model, d_model)
        # Negative bias -> small z -> output stays near the previous state.
        nn.init.constant_(self.gate.bias, bias_init)

    def forward(
        self,
        h_new: torch.Tensor,
        h_prev: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Blend the candidate state with the previous state.

        Args:
            h_new: candidate state, ``[..., d_model]``.
            h_prev: previous state, same shape as ``h_new``.

        Returns:
            ``(h_out, z)`` — the blended state and the gate values, both with
            the same shape as the inputs.
        """
        # Gate sees both states, concatenated along the feature dimension.
        combined = torch.cat([h_new, h_prev], dim=-1)     # [..., 2D]
        z = torch.sigmoid(self.gate(combined))            # [..., D]
        h_out = z * h_new + (1 - z) * h_prev
        return h_out, z
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 340 |
+
# §5 — DYNAMIC HALTING HEAD (from PonderNet, arxiv:2107.05407)
|
| 341 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 342 |
+
|
| 343 |
+
class DynamicHaltingHead(nn.Module):
    """
    PonderNet-style halting head.

    For the current recurrence step it predicts ``lambda_n``, the
    probability of halting at this step given that no earlier step halted.
    The file's notes cite PonderNet (arxiv:2107.05407, Eq. 1-2):
        P(Lambda_n = 1 | Lambda_{n-1} = 0) = lambda_n
        p_n = lambda_n * prod_{j<n} (1 - lambda_j)

    The head only *returns* the probability. How it is sampled or
    thresholded is up to the caller (not part of this class).

    Two scalar modulators shift the logit before the sigmoid, both making
    halting less likely as they grow: ``limbic_arousal`` (x0.5) and
    ``pfc_grit`` (x0.3).

    Args:
        d_model: feature size of the hidden state.
        bias_init: initial bias of the final linear layer, i.e. the logit
            offset of the halting probability.
    """

    def __init__(self, d_model: int, bias_init: float = 1.0):
        super().__init__()
        # Floor the hidden width at 1 so a tiny d_model (< 4) does not
        # produce a zero-width, input-independent MLP.
        hidden = max(1, d_model // 4)
        self.halt_proj = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, 1),
        )
        # NOTE(review): lambda = sigmoid(logit) is the HALT probability, so
        # the default bias of +1.0 starts near sigmoid(1.0) ~ 0.73, i.e.
        # biased TOWARD halting. The earlier comment here claimed the
        # opposite ("initially unlikely to halt"). Pass a negative bias_init
        # for a think-longer start; the default is kept for compatibility.
        nn.init.constant_(self.halt_proj[-1].bias, bias_init)

    def forward(
        self,
        h: torch.Tensor,
        limbic_arousal: float = 0.0,
        pfc_grit: float = 0.5,
    ) -> torch.Tensor:
        """
        Compute the halting probability for the current recurrence step.

        Args:
            h: ``[B, L, D]`` hidden state.
            limbic_arousal: larger values lower the halting probability
                (documented range 0-1; not enforced).
            pfc_grit: larger values lower the halting probability
                (documented range 0-1; not enforced).

        Returns:
            ``lambda_n``: ``[B]`` halting probability per batch element.
        """
        # Pool across the sequence dimension, then score with the MLP.
        h_pooled = h.mean(dim=1)                              # [B, D]
        raw_logit = self.halt_proj(h_pooled).squeeze(-1)      # [B]

        # Both modulators are subtracted from the logit: more arousal or
        # grit -> lower logit -> lower probability of halting.
        arousal_shift = -limbic_arousal * 0.5
        grit_shift = -pfc_grit * 0.3
        modulated_logit = raw_logit + arousal_shift + grit_shift

        lambda_n = torch.sigmoid(modulated_logit)             # [B]
        return lambda_n
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 404 |
+
# §6 — DEPTH EMBEDDINGS (from arxiv:2603.21676, Appendix B)
|
| 405 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 406 |
+
|
| 407 |
+
class DepthEmbedding(nn.Module):
    """
    Learned per-step embeddings: a lookup table indexed by loop step.

    The returned vector is meant to be added to the hidden state before a
    recurrence step so that a shared-weight block can tell iterations apart
    (per the file's notes, RDT Appendix B: ``H_hat^(t) = H^(t-1) + e_t``).
    The addition itself is done by the caller.

    The table has ``max_steps * d_model`` parameters, initialised from a
    normal distribution with std 0.02.

    Args:
        max_steps: number of distinct steps that get their own embedding.
        d_model: size of each embedding vector.
    """

    def __init__(self, max_steps: int, d_model: int):
        super().__init__()
        self.embeddings = nn.Embedding(max_steps, d_model)
        # Small initial values so the added embedding starts as a minor
        # perturbation of the hidden state.
        nn.init.normal_(self.embeddings.weight, std=0.02)

    def forward(self, step: int) -> torch.Tensor:
        """
        Return the embedding for loop step ``step`` as a ``[d_model]`` tensor.

        Steps at or beyond ``max_steps`` reuse the last embedding instead of
        raising an index error. This mirrors how the per-iteration routers
        in this file clamp an over-long iteration index.
        """
        last_step = self.embeddings.num_embeddings - 1
        idx = torch.tensor(
            min(step, last_step), device=self.embeddings.weight.device
        )
        return self.embeddings(idx)
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 435 |
+
# §7 — THALAMIC QUERY MODULATION (from TRC², arxiv:2602.22479)
|
| 436 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 437 |
+
|
| 438 |
+
class ThalamicModulation(nn.Module):
    """
    Adds a limbic-state-dependent bias to attention queries.

    The limbic vector is mapped to model space, passed through a dedicated
    query projection, scaled by a learnable scalar gate (initialised to
    0.1) and added to every query position:

        Q_modulated = Q + gate * W_Q_thal(limbic_to_thalamic(limbic_vector))

    Only the tensor passed as ``Q`` is changed; this module never sees
    keys or values. The file's notes describe this as an adaptation of
    TRC² (arxiv:2602.22479), Eq. 4: ``Q = U · W_Q + Z · W_Q_thal``.

    Args:
        d_model: feature size of the queries.
        limbic_dim: size of the limbic state vector.
    """

    def __init__(self, d_model: int, limbic_dim: int = 8):
        super().__init__()
        # Project limbic state to the thalamic signal Z (model space).
        self.limbic_to_thalamic = nn.Linear(limbic_dim, d_model, bias=False)
        # Thalamic query projection (separate from the block's main W_Q).
        self.W_Q_thal = nn.Linear(d_model, d_model, bias=False)
        # Scalar gate: how strongly the thalamic signal shifts the queries.
        self.gate = nn.Parameter(torch.tensor(0.1))  # start small

    def forward(
        self,
        Q: torch.Tensor,
        limbic_vector: torch.Tensor,
    ) -> torch.Tensor:
        """
        Modulate queries with the thalamic signal.

        Args:
            Q: ``[B, L, D]`` query vectors.
            limbic_vector: ``[B, limbic_dim]`` limbic state, or a single
                ``[limbic_dim]`` vector applied to every batch element.

        Returns:
            ``Q_modulated``: ``[B, L, D]``.
        """
        if limbic_vector.dim() == 1:
            limbic_vector = limbic_vector.unsqueeze(0)        # [1, limbic_dim]

        # The signal is the same for every sequence position, so project it
        # once per batch element and broadcast over L, rather than expanding
        # to [B, L, D] first and repeating the D x D projection L times.
        Z = self.limbic_to_thalamic(limbic_vector)            # [B, D]
        Q_thal = self.W_Q_thal(Z).unsqueeze(1)                # [B, 1, D]

        return Q + self.gate * Q_thal
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 494 |
+
# §8 — RECURRENT REASONING BLOCK (single shared-weight block)
|
| 495 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 496 |
+
|
| 497 |
+
class RecurrentReasoningBlock(nn.Module):
|
| 498 |
+
"""
|
| 499 |
+
The shared-weight transformer block applied at each recurrence step.
|
| 500 |
+
|
| 501 |
+
From RDT (arxiv:2603.21676), Appendix A:
|
| 502 |
+
Sub-layer 1: H' = Δ€ + Ξ_attn β MHSA(LN(Δ€))
|
| 503 |
+
Sub-layer 2: H'' = H' + Ξ_ffn β CoE(LN(H'))
|
| 504 |
+
|
| 505 |
+
Key design: weights are SHARED across all iterations.
|
| 506 |
+
Only the depth embedding, LayerNorm stats, and router weights
|
| 507 |
+
differ per iteration β this is "adaptive weight sharing"
|
| 508 |
+
per the user's spec.
|
| 509 |
+
|
| 510 |
+
We replace the standard FFN with Chain-of-Experts.
|
| 511 |
+
"""
|
| 512 |
+
|
| 513 |
+
def __init__(
|
| 514 |
+
self,
|
| 515 |
+
d_model: int,
|
| 516 |
+
n_heads: int,
|
| 517 |
+
d_expert: int,
|
| 518 |
+
n_shared_experts: int = 2,
|
| 519 |
+
n_routed_experts: int = 8,
|
| 520 |
+
top_k_experts: int = 2,
|
| 521 |
+
max_iterations: int = 16,
|
| 522 |
+
dropout: float = 0.0,
|
| 523 |
+
limbic_dim: int = 8,
|
| 524 |
+
):
|
| 525 |
+
super().__init__()
|
| 526 |
+
self.d_model = d_model
|
| 527 |
+
self.n_heads = n_heads
|
| 528 |
+
assert d_model % n_heads == 0
|
| 529 |
+
self.d_k = d_model // n_heads
|
| 530 |
+
|
| 531 |
+
# ββ Sub-layer 1: Multi-Head Self-Attention ββ
|
| 532 |
+
self.ln_attn = nn.LayerNorm(d_model)
|
| 533 |
+
self.W_Q = nn.Linear(d_model, d_model, bias=False)
|
| 534 |
+
self.W_K = nn.Linear(d_model, d_model, bias=False)
|
| 535 |
+
self.W_V = nn.Linear(d_model, d_model, bias=False)
|
| 536 |
+
self.W_O = nn.Linear(d_model, d_model, bias=False)
|
| 537 |
+
self.attn_scale = LayerScale(d_model)
|
| 538 |
+
|
| 539 |
+
# ββ Thalamic modulation (on queries only) ββ
|
| 540 |
+
self.thalamic = ThalamicModulation(d_model, limbic_dim)
|
| 541 |
+
|
| 542 |
+
# ββ Sub-layer 2: Chain-of-Experts FFN ββ
|
| 543 |
+
self.ln_ffn = nn.LayerNorm(d_model)
|
| 544 |
+
self.coe = ChainOfExpertsFFN(
|
| 545 |
+
d_model=d_model,
|
| 546 |
+
d_expert=d_expert,
|
| 547 |
+
n_shared=n_shared_experts,
|
| 548 |
+
n_routed=n_routed_experts,
|
| 549 |
+
top_k=top_k_experts,
|
| 550 |
+
max_iterations=max_iterations,
|
| 551 |
+
)
|
| 552 |
+
self.ffn_scale = LayerScale(d_model)
|
| 553 |
+
|
| 554 |
+
self.dropout = nn.Dropout(dropout)
|
| 555 |
+
|
| 556 |
+
def forward(
    self,
    h: torch.Tensor,
    iteration: int,
    limbic_vector: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    limbic_entropy_sigma: float = 0.0,
) -> Tuple[torch.Tensor, dict]:
    """
    One step of the recurrent reasoning block.

    Args:
        h: [B, L, D] hidden state
        iteration: current recurrence step (forwarded to the CoE layer)
        limbic_vector: [B, limbic_dim] for thalamic modulation; skipped if None
        attention_mask: [B, L] padding mask (0 = masked key), or a tensor of
            any other rank that is added to the attention scores as-is
        limbic_entropy_sigma: noise level forwarded to the CoE layer

    Returns:
        h_out: [B, L, D] processed hidden state (candidate H̃)
        diagnostics: dict with "coe" (CoE diagnostics) and "attn_entropy"
            (mean entropy of the attention distribution, as a float)
    """
    B, L, D = h.shape
    diagnostics = {}

    # ── Sub-layer 1: MHSA with thalamic modulation ──
    h_norm = self.ln_attn(h)

    Q = self.W_Q(h_norm)  # [B, L, D]
    K = self.W_K(h_norm)
    V = self.W_V(h_norm)

    # Thalamic modulation on queries only (TRC² integration)
    if limbic_vector is not None:
        Q = self.thalamic(Q, limbic_vector)

    # Split the model dimension into heads: [B, L, D] -> [B, H, L, D/H]
    Q = Q.view(B, L, self.n_heads, self.d_k).transpose(1, 2)
    K = K.view(B, L, self.n_heads, self.d_k).transpose(1, 2)
    V = V.view(B, L, self.n_heads, self.d_k).transpose(1, 2)

    # Scaled dot-product attention scores
    attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

    if attention_mask is not None:
        if attention_mask.dim() == 2:
            # [B, L] -> [B, 1, 1, L]; broadcast over heads and query positions.
            key_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            # Use the dtype's finite minimum rather than -inf so a row whose
            # keys are all masked yields a uniform distribution, not NaN.
            attn_scores = attn_scores.masked_fill(
                key_mask == 0, torch.finfo(attn_scores.dtype).min
            )
        else:
            # Any other rank is treated as an additive mask.
            attn_scores = attn_scores + attention_mask

    attn_probs = F.softmax(attn_scores, dim=-1)
    # Dropout acts on the weights used for mixing values; the undropped
    # probabilities are kept for the entropy diagnostic below.
    attn_weights = self.dropout(attn_probs)

    attn_output = torch.matmul(attn_weights, V)  # [B, H, L, D/H]
    attn_output = attn_output.transpose(1, 2).contiguous().view(B, L, D)
    attn_output = self.W_O(attn_output)

    # Residual + LayerScale (RDT Eq. 6)
    h = h + self.attn_scale(attn_output)

    # ── Sub-layer 2: Chain-of-Experts FFN ──
    h_norm = self.ln_ffn(h)
    coe_output, coe_diag = self.coe(h_norm, iteration, limbic_entropy_sigma)

    # Residual + LayerScale (RDT Eq. 8)
    h = h + self.ffn_scale(coe_output)

    diagnostics["coe"] = coe_diag
    # Entropy of the true attention distribution (pre-dropout), so the
    # diagnostic is not distorted by dropped/rescaled weights in training.
    diagnostics["attn_entropy"] = -(
        attn_probs * (attn_probs + 1e-10).log()
    ).sum(-1).mean().item()

    return h, diagnostics
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 634 |
+
# §9 — PRELUDE / CODA BLOCKS (unique weights, no sharing)
|
| 635 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 636 |
+
|
| 637 |
+
class PreludeBlock(nn.Module):
    """
    Pre-norm transformer block with its own (non-shared) weights.

    Self-attention followed by a plain two-layer GELU feed-forward network,
    each wrapped in a residual connection. No MoE and no recurrence.
    """

    def __init__(self, d_model: int, n_heads: int, d_ffn: int, dropout: float = 0.0):
        super().__init__()
        # Attention sub-layer
        self.ln_attn = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(
            d_model,
            n_heads,
            dropout=dropout,
            batch_first=True,
        )
        # Feed-forward sub-layer: d_model -> d_ffn -> d_model
        self.ln_ffn = nn.LayerNorm(d_model)
        expand = nn.Linear(d_model, d_ffn)
        activation = nn.GELU()
        contract = nn.Linear(d_ffn, d_model)
        self.ffn = nn.Sequential(expand, activation, contract)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Apply attention and feed-forward sub-layers to x.

        Args:
            x: input passed batch-first to nn.MultiheadAttention
            attention_mask: optional mask; only a 2-D mask is used, where
                entries equal to 0 mark keys to ignore. Masks of any other
                rank are not applied.

        Returns:
            Tensor with the same shape as x.
        """
        # Only a 2-D mask is converted into a key-padding mask.
        has_padding_mask = attention_mask is not None and attention_mask.dim() == 2
        padding = (attention_mask == 0) if has_padding_mask else None

        # Pre-norm self-attention with residual
        normed = self.ln_attn(x)
        attended = self.attn(normed, normed, normed, key_padding_mask=padding)[0]
        after_attn = x + self.dropout(attended)

        # Pre-norm feed-forward with residual
        ffn_out = self.ffn(self.ln_ffn(after_attn))
        return after_attn + self.dropout(ffn_out)
|
| 670 |
+
|
| 671 |
+
|
| 672 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 673 |
+
# §10 — THE FULL RDT-MoE ENGINE
|
| 674 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 675 |
+
|
| 676 |
+
class BMORDTMoE(nn.Module):
    """
    BMO's Recurrent-Depth MoE — the complete latent simmering engine.

    Architecture: 2 Prelude + 1 Looped MoE Block (1-T iterations) + 2 Coda

    Combines:
    - RDT's gated recurrence + silent thinking (arxiv:2603.21676)
    - CoE's chain-of-experts with per-iteration routing (arxiv:2506.18945)
    - PonderNet's dynamic halting (arxiv:2107.05407)
    - TRC² thalamic query modulation (arxiv:2602.22479)
    - BMO's limbic modulation + probabilistic gating + entropy

    Intended behaviour at inference (once trained):
    - Easy inputs (1+1) → halts at step 2-3 (fast)
    - Hard reasoning → runs 12-16 steps (thoughtful)
    - Limbic arousal → extends thinking (excited BMO explores more)
    - PFC grit → persists through cognitive difficulty

    HONESTY: This is a neural network that can dynamically allocate
    compute. "Thinking deeper" = more iterations of the same block.
    Not consciousness. Real computation with genuine adaptive depth.
    """

    def __init__(
        self,
        d_model: int = 256,
        n_heads: int = 8,
        d_expert: int = 512,
        d_ffn: int = 1024,
        n_shared_experts: int = 2,
        n_routed_experts: int = 8,
        top_k_experts: int = 2,
        max_thinking_steps: int = 16,
        n_prelude: int = 2,
        n_coda: int = 2,
        limbic_dim: int = 8,
        dropout: float = 0.0,
        halt_prior_lambda: float = 0.2,
        halt_kl_beta: float = 0.01,
    ):
        """
        Assemble prelude, shared reasoning block, gate, halting head and coda.

        Args:
            d_model: internal hidden width
            n_heads: attention heads in every block
            d_expert: expert width of the reasoning block's CoE layer
            d_ffn: feed-forward width of prelude/coda blocks
            n_shared_experts / n_routed_experts / top_k_experts: CoE settings
            max_thinking_steps: upper bound on recurrent iterations
            n_prelude / n_coda: number of unique blocks before/after the loop
            limbic_dim: size of the limbic vector for thalamic modulation
            dropout: dropout probability used by all blocks
            halt_prior_lambda: parameter of the geometric halting prior
            halt_kl_beta: weight of the halting KL regulariser
        """
        super().__init__()

        self.d_model = d_model
        self.max_thinking_steps = max_thinking_steps
        self.halt_prior_lambda = halt_prior_lambda
        self.halt_kl_beta = halt_kl_beta

        # ── Prelude (unique layers) ──
        self.prelude = nn.ModuleList([
            PreludeBlock(d_model, n_heads, d_ffn, dropout)
            for _ in range(n_prelude)
        ])

        # ── Recurrent reasoning block (shared weights, applied T times) ──
        self.reasoning_block = RecurrentReasoningBlock(
            d_model=d_model,
            n_heads=n_heads,
            d_expert=d_expert,
            n_shared_experts=n_shared_experts,
            n_routed_experts=n_routed_experts,
            top_k_experts=top_k_experts,
            max_iterations=max_thinking_steps,
            dropout=dropout,
            limbic_dim=limbic_dim,
        )

        # ── Depth embeddings (one per thinking step) ──
        self.depth_embeddings = DepthEmbedding(max_thinking_steps, d_model)

        # ── Identity-biased GRU gate ──
        self.gate = IdentityBiasedGate(d_model, bias_init=-2.0)

        # ── Dynamic halting head ──
        self.halt_head = DynamicHaltingHead(d_model, bias_init=1.0)

        # ── Coda (unique layers) ──
        self.coda = nn.ModuleList([
            PreludeBlock(d_model, n_heads, d_ffn, dropout)
            for _ in range(n_coda)
        ])

        # ── Final layer norm ──
        self.final_norm = nn.LayerNorm(d_model)

        # ── Optional projections between an external width and d_model ──
        self.input_proj = None   # set by set_projection() if needed
        self.output_proj = None  # set by set_projection() if needed

    def set_projection(self, external_dim: int):
        """
        Set up projections if external model dim != our d_model.

        When the dims match, any previously installed projections are
        removed so repeated calls never leave stale layers behind. New layers
        are created on the same device/dtype as the existing parameters.
        """
        if external_dim == self.d_model:
            self.input_proj = None
            self.output_proj = None
            return

        # Follow the placement of an existing parameter so callers do not
        # have to remember a second `.to(device)` after this call.
        ref = self.final_norm.weight
        self.input_proj = nn.Linear(external_dim, self.d_model).to(
            device=ref.device, dtype=ref.dtype
        )
        self.output_proj = nn.Linear(self.d_model, external_dim).to(
            device=ref.device, dtype=ref.dtype
        )

    def forward(
        self,
        h: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        limbic_vector: Optional[torch.Tensor] = None,
        limbic_arousal: float = 0.0,
        pfc_grit: float = 0.5,
        limbic_entropy_sigma: float = 0.0,
        force_steps: Optional[int] = None,
    ) -> Tuple[torch.Tensor, dict]:
        """
        Full forward pass of the RDT-MoE engine.

        Args:
            h: [B, L, D_ext] hidden states from base model
            attention_mask: [B, L] attention mask
            limbic_vector: [B, limbic_dim] limbic state for thalamic modulation
            limbic_arousal: 0-1, passed to the halting head
            pfc_grit: 0-1, passed to the halting head
            limbic_entropy_sigma: noise injection forwarded to the reasoning block
            force_steps: if set, run exactly this many thinking steps
                (dynamic halting is disabled)

        Returns:
            h_out: [B, L, D_ext] refined hidden states
            report: diagnostics with keys "thinking_steps",
                "halt_probabilities", "gate_retain_ratios",
                "per_step_diagnostics", "halted_early" and, in training
                mode, "ponder_loss"
        """
        report = {
            "thinking_steps": 0,
            "halt_probabilities": [],
            "gate_retain_ratios": [],
            "per_step_diagnostics": [],
            "halted_early": False,
        }

        # ── Projection ──
        if self.input_proj is not None:
            h = self.input_proj(h)

        # ── Prelude ──
        for block in self.prelude:
            h = block(h, attention_mask)

        # ── Choose the number of loop iterations ──
        if force_steps is not None:
            T = force_steps
        elif self.training:
            # Random depth per training step (RDT: T ~ U(1, max))
            T = random.randint(1, self.max_thinking_steps)
        else:
            T = self.max_thinking_steps  # upper bound; halting decides

        # Dynamic halting is eval-only; training uses silent thinking.
        dynamic_halting = (not self.training) and force_steps is None

        h_prev = h.clone()
        batch = h_prev.shape[0]
        # Per-sample halting flags: once a sample halts it stays halted.
        halted = torch.zeros(batch, dtype=torch.bool, device=h_prev.device)

        # Halting probabilities per step, for the PonderNet loss
        all_halt_probs = []

        for t in range(T):
            # Step 1: Add depth embedding (RDT Appendix B)
            depth_emb = self.depth_embeddings(t)  # [D]
            h_input = h_prev + depth_emb.unsqueeze(0).unsqueeze(0)

            # Step 2-4: Reasoning block (MHSA + CoE)
            h_candidate, step_diag = self.reasoning_block(
                h_input, t, limbic_vector, attention_mask, limbic_entropy_sigma
            )

            # Step 5: GRU gate (identity-biased, RDT Eq. 2-3)
            h_new, gate_values = self.gate(h_candidate, h_prev)

            # Step 6: Halting probability (PonderNet)
            lambda_n = self.halt_head(h_new, limbic_arousal, pfc_grit)  # [B]
            all_halt_probs.append(lambda_n)

            # Track
            report["gate_retain_ratios"].append((1 - gate_values).mean().item())
            report["halt_probabilities"].append(lambda_n.mean().item())
            report["per_step_diagnostics"].append(step_diag)
            report["thinking_steps"] = t + 1

            if not dynamic_halting:
                h_prev = h_new
                continue

            # Samples that halted on an earlier step keep the state they
            # halted with; only still-active samples take the new state.
            h_prev = torch.where(halted.view(-1, 1, 1), h_prev, h_new)

            # Sample halt decision (PonderNet: Bernoulli(λ_n)) and make it
            # sticky so the loop ends once every sample has halted, rather
            # than requiring all of them to halt on the very same step.
            halt_now = torch.bernoulli(lambda_n).bool().reshape(batch, -1).all(dim=1)
            halted = halted | halt_now
            if halted.all():
                report["halted_early"] = True
                break

        # ── Compute PonderNet regularization loss ──
        if self.training and all_halt_probs:
            report["ponder_loss"] = self._compute_ponder_loss(all_halt_probs)

        h_out = h_prev

        # ── Coda ──
        for block in self.coda:
            h_out = block(h_out, attention_mask)

        # ── Final norm ──
        h_out = self.final_norm(h_out)

        # ── Output projection ──
        if self.output_proj is not None:
            h_out = self.output_proj(h_out)

        return h_out, report

    def _compute_ponder_loss(
        self,
        halt_probs: List[torch.Tensor],
    ) -> torch.Tensor:
        """
        PonderNet KL regularization loss (arxiv:2107.05407, Eq. 3).

        L_reg = β · KL(p_n || p_G(λ_p))

        Where p_G is geometric distribution with parameter λ_p, truncated to
        the number of executed steps and renormalised.

        Args:
            halt_probs: per-step halting probabilities λ_n, one tensor per step

        Returns:
            Scalar tensor: halt_kl_beta times the KL divergence, summed over
            steps and divided by the size of the leading (batch) dimension.
        """
        N = len(halt_probs)
        eps = 1e-10

        # p_n = λ_n · Π_{j<n} (1 - λ_j)   (PonderNet Eq. 2)
        p_n_list = []
        running_continue = torch.ones_like(halt_probs[0])
        for lam in halt_probs:
            p_n_list.append(lam * running_continue)
            running_continue = running_continue * (1 - lam)

        # Probability mass of "never halted" is assigned to the last step.
        p_n_list[-1] = p_n_list[-1] + running_continue

        p_dist = torch.stack(p_n_list, dim=-1)  # [B, N]
        p_dist = p_dist / (p_dist.sum(dim=-1, keepdim=True) + 1e-8)

        # Truncated geometric prior: λ_p (1 - λ_p)^n, renormalised over N steps
        steps = torch.arange(N, device=p_dist.device, dtype=p_dist.dtype)
        prior = self.halt_prior_lambda * (1 - self.halt_prior_lambda) ** steps
        prior = prior / (prior.sum() + 1e-8)

        # KL(p || prior), written out explicitly. F.kl_div(input, target)
        # computes KL(target || exp(input)), so passing log(p) as `input`
        # would yield the reverse direction of the documented regulariser.
        pointwise = p_dist * ((p_dist + eps).log() - (prior + eps).log())
        kl = pointwise.sum() / p_dist.shape[0]  # same scaling as 'batchmean'

        return self.halt_kl_beta * kl

    def get_parameter_summary(self) -> dict:
        """Report parameter counts by component, plus shared/unique split."""
        def count(module):
            return sum(p.numel() for p in module.parameters())

        total = count(self)
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)

        summary = {
            "total_params": total,
            "trainable_params": trainable,
            "prelude": sum(count(b) for b in self.prelude),
            "reasoning_block": count(self.reasoning_block),
            "depth_embeddings": count(self.depth_embeddings),
            "gate": count(self.gate),
            "halt_head": count(self.halt_head),
            "coda": sum(count(b) for b in self.coda),
            "final_norm": count(self.final_norm),
        }

        # Explicit None checks: the projections are optional sub-modules.
        if self.input_proj is not None:
            summary["input_proj"] = count(self.input_proj)
        if self.output_proj is not None:
            summary["output_proj"] = count(self.output_proj)

        # The key metric: how much is SHARED (looped) vs UNIQUE
        shared = summary["reasoning_block"]
        unique = total - shared
        summary["shared_pct"] = f"{100 * shared / total:.1f}%"
        summary["unique_pct"] = f"{100 * unique / total:.1f}%"
        summary["effective_depth"] = f"{self.max_thinking_steps} loops × 1 shared block = {self.max_thinking_steps}× effective depth"

        return summary
|
| 970 |
+
|
| 971 |
+
|
| 972 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 973 |
+
# §11 — INTEGRATION HELPERS (bridge to BMO's existing systems)
|
| 974 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 975 |
+
|
| 976 |
+
def limbic_state_to_vector(limbic_state: dict, device: torch.device = None) -> torch.Tensor:
    """
    Pack BMO's limbic state dict into a fixed 8-element float32 tensor.

    Slots, in order (with the fallback used when a key is absent):
    valence (0.0), arousal (0.5), fear (0.0), seeking (0.2), care (0.0),
    panic (0.0), surprise (0.0), stress (0.0). The last two are spare slots
    for future modalities. Keys not listed here are ignored.

    Args:
        limbic_state: mapping from limbic channel name to its value
        device: target device; CPU when None

    Returns:
        1-D tensor of length 8, dtype float32, on the requested device.
    """
    # (key, fallback) pairs; the order defines the layout of the vector.
    slots = (
        ("valence", 0.0),
        ("arousal", 0.5),
        ("fear", 0.0),
        ("seeking", 0.2),
        ("care", 0.0),
        ("panic", 0.0),
        ("surprise", 0.0),
        ("stress", 0.0),
    )
    target = torch.device("cpu") if device is None else device
    values = [limbic_state.get(key, fallback) for key, fallback in slots]
    return torch.tensor(values, dtype=torch.float32, device=target)
|
| 999 |
+
|
| 1000 |
+
|
| 1001 |
+
def create_bmo_rdt_moe(
    d_model: Optional[int] = None,
    config: str = "tiny",
) -> BMORDTMoE:
    """
    Factory function for creating BMO RDT-MoE with preset configs.

    Configs:
    tiny: d=256, 4 heads, 4 experts — for testing/sandbox
    small: d=512, 8 heads, 8 experts — for Qwen3-1.7B
    medium: d=1024, 8 heads, 8 experts — for Qwen3-4B
    large: d=2048, 16 heads, 16 experts — for Qwen3-8B

    Args:
        d_model: optional override of the preset's hidden width. When left
            as None the preset's own d_model is used; previously the default
            of 256 silently replaced the width of every preset.
        config: name of the preset ("tiny", "small", "medium", "large")

    Returns:
        A freshly constructed BMORDTMoE.

    Raises:
        ValueError: if config is unknown, or the resulting d_model is not
            divisible by the preset's number of heads.

    NOTE: For integration with a pretrained model, call
    set_projection(model_dim) after creation if d_model differs.
    """
    configs = {
        "tiny": dict(
            d_model=256, n_heads=4, d_expert=512, d_ffn=512,
            n_shared_experts=1, n_routed_experts=4, top_k_experts=2,
            max_thinking_steps=8,
        ),
        "small": dict(
            d_model=512, n_heads=8, d_expert=1024, d_ffn=1024,
            n_shared_experts=2, n_routed_experts=8, top_k_experts=2,
            max_thinking_steps=12,
        ),
        "medium": dict(
            d_model=1024, n_heads=8, d_expert=2048, d_ffn=2048,
            n_shared_experts=2, n_routed_experts=8, top_k_experts=2,
            max_thinking_steps=16,
        ),
        "large": dict(
            d_model=2048, n_heads=16, d_expert=4096, d_ffn=4096,
            n_shared_experts=2, n_routed_experts=16, top_k_experts=2,
            max_thinking_steps=16,
        ),
    }

    if config not in configs:
        raise ValueError(f"Unknown config '{config}'. Available: {list(configs.keys())}")

    # Work on a copy so the preset table is never mutated.
    params = dict(configs[config])
    if d_model is not None:
        params["d_model"] = d_model

    # Fail early with a readable message instead of deep inside the blocks.
    if params["d_model"] % params["n_heads"] != 0:
        raise ValueError(
            f"d_model ({params['d_model']}) must be divisible by "
            f"n_heads ({params['n_heads']}) of config '{config}'"
        )

    return BMORDTMoE(**params)
|
| 1048 |
+
|
| 1049 |
+
|
| 1050 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1051 |
+
# §12 — SELF-TEST (run this file directly to verify)
|
| 1052 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1053 |
+
|
| 1054 |
+
def _self_test():
    """
    Comprehensive self-test of the RDT-MoE engine.

    Builds the "tiny" preset and runs seven smoke tests: training forward
    pass, gradient flow, eval-mode dynamic halting, routing stochasticity,
    arousal/grit influence on halting, identity-biased gate retention, and
    external-dimension projection. Results are printed; hard failures are
    reported through assertions.
    """
    print("=" * 70)
    print(" BMO RDT-MoE Self-Test")
    print(" Papers: RDT (2603.21676) + CoE (2506.18945) +")
    print(" PonderNet (2107.05407) + TRC² (2602.22479)")
    print("=" * 70)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nDevice: {device}")

    # ── Create tiny model for testing ──
    model = create_bmo_rdt_moe(d_model=256, config="tiny").to(device)

    # Parameter summary
    summary = model.get_parameter_summary()
    print("\n📊 Parameter Summary:")
    for k, v in summary.items():
        if isinstance(v, int):
            print(f" {k}: {v:,}")
        else:
            print(f" {k}: {v}")

    # ── Test 1: Forward pass (training mode) ──
    print("\n🧪 Test 1: Training forward pass")
    model.train()
    B, L, D = 2, 16, 256
    h = torch.randn(B, L, D, device=device, requires_grad=True)
    limbic = limbic_state_to_vector(
        {"valence": 0.3, "arousal": 0.7, "seeking": 0.8, "fear": 0.1}
    ).to(device).unsqueeze(0).expand(B, -1)

    h_out, report = model(
        h, limbic_vector=limbic,
        limbic_arousal=0.7, pfc_grit=0.6,
        limbic_entropy_sigma=0.03,
    )

    print(f" Input: {h.shape}")
    print(f" Output: {h_out.shape}")
    print(f" Thinking steps: {report['thinking_steps']}")
    print(f" Gate retain ratios: {[f'{r:.3f}' for r in report['gate_retain_ratios']]}")
    print(f" Halt probs: {[f'{p:.3f}' for p in report['halt_probabilities']]}")
    if 'ponder_loss' in report:
        print(f" PonderNet loss: {report['ponder_loss']:.6f}")
    print(" ✓ Training forward OK")

    # ── Test 2: Backward pass (gradient flow) ──
    print("\n🧪 Test 2: Gradient flow")
    loss = h_out.sum() + report.get('ponder_loss', 0)
    loss.backward()

    n_grad = sum(1 for p in model.parameters() if p.grad is not None and p.grad.abs().sum() > 0)
    n_total = sum(1 for p in model.parameters() if p.requires_grad)
    print(f" {n_grad}/{n_total} parameters have non-zero gradients")
    assert n_grad > 0, "No gradients!"
    print(" ✓ Gradients flow correctly")

    # ── Test 3: Eval mode (dynamic halting) ──
    print("\n🧪 Test 3: Eval mode (dynamic halting)")
    model.eval()
    with torch.no_grad():
        h_eval = torch.randn(1, 8, 256, device=device)
        h_out_eval, report_eval = model(
            h_eval, limbic_arousal=0.2, pfc_grit=0.3,
        )
    print(f" Thinking steps: {report_eval['thinking_steps']}")
    print(f" Halted early: {report_eval['halted_early']}")
    print(f" Halt probs: {[f'{p:.3f}' for p in report_eval['halt_probabilities']]}")
    print(" ✓ Dynamic halting OK")

    # ── Test 4: Stochastic expert routing ──
    print("\n🧪 Test 4: Expert routing stochasticity")
    model.train()
    routes = []
    for _ in range(5):
        with torch.no_grad():
            _, r = model(h[:1], force_steps=3, limbic_entropy_sigma=0.1)
        route = [d["coe"]["top_experts"] for d in r["per_step_diagnostics"]]
        routes.append(route)

    # Check that not all routing decisions are identical
    all_same = all(r == routes[0] for r in routes)
    print(f" 5 routing runs: {'IDENTICAL ✗' if all_same else 'VARIED ✓'}")
    for i, r in enumerate(routes):
        print(f" Run {i+1}: {r}")
    if not all_same:
        print(" ✓ Expert routing is stochastic")
    else:
        print(" ⚠ Routes identical (noise may be too small)")

    # ── Test 5: Limbic arousal affects thinking depth ──
    print("\n🧪 Test 5: Limbic arousal affects halting")
    model.eval()
    steps_low = []
    steps_high = []
    for _ in range(10):
        with torch.no_grad():
            _, r_low = model(h_eval, limbic_arousal=0.1, pfc_grit=0.2)
            _, r_high = model(h_eval, limbic_arousal=0.9, pfc_grit=0.9)
        steps_low.append(r_low["thinking_steps"])
        steps_high.append(r_high["thinking_steps"])

    avg_low = sum(steps_low) / len(steps_low)
    avg_high = sum(steps_high) / len(steps_high)
    print(f" Low arousal/grit → avg {avg_low:.1f} steps {steps_low}")
    print(f" High arousal/grit → avg {avg_high:.1f} steps {steps_high}")
    if avg_high >= avg_low:
        print(" ✓ Higher arousal/grit → deeper thinking")
    else:
        # An untrained halting head is not expected to show the effect yet.
        print(" ⚠ Effect not yet learned (expected — needs training)")

    # ── Test 6: Identity-biased gate ──
    print("\n🧪 Test 6: Identity-biased gate (88% retention at init)")
    gate = IdentityBiasedGate(256).to(device)
    h1 = torch.randn(1, 8, 256, device=device)
    h2 = torch.randn(1, 8, 256, device=device)
    h_blended, z = gate(h1, h2)
    mean_z = z.mean().item()
    retain = 1 - mean_z
    print(f" Mean gate value z: {mean_z:.4f}")
    print(f" Retention ratio: {retain:.4f} (expected ~0.88)")
    assert 0.7 < retain < 0.98, f"Gate retention {retain} outside expected range"
    print(" ✓ Identity bias working (σ(-2.0) ≈ 0.12)")

    # ── Test 7: External dimension projection ──
    print("\n🧪 Test 7: External model dimension projection")
    model_ext = create_bmo_rdt_moe(d_model=256, config="tiny").to(device)
    model_ext.set_projection(external_dim=4096)
    # Move again so the projection layers created above share the device.
    model_ext = model_ext.to(device)
    h_ext = torch.randn(1, 8, 4096, device=device)
    model_ext.eval()
    with torch.no_grad():
        h_ext_out, _ = model_ext(h_ext)
    print(f" Input: {h_ext.shape} → Output: {h_ext_out.shape}")
    assert h_ext_out.shape == h_ext.shape
    print(" ✓ Projection works (4096 → 256 → loop → 256 → 4096)")

    print(f"\n{'='*70}")
    print(" ALL TESTS PASSED ✓")
    print(" BMO's thinking engine is ready for latent simmering")
    print(f"{'='*70}")
|
| 1197 |
+
|
| 1198 |
+
|
| 1199 |
+
# Executing this file as a script runs the built-in self-test.
if __name__ == "__main__":
    _self_test()
|