🧬 NEXUS v4.1: Depth-Recurrent Language Model
A novel architecture that achieves deep reasoning through iteration, not size.
NEXUS applies the same tiny 2-layer block 24 times to a single input (48 effective layers), giving a 40M-parameter model the effective depth of a 2B-parameter transformer while fitting in under 100MB with ternary weights.
```python
x = embed(tokens)                  # [B, T, 512]
z = zeros_like(x)                  # latent state
for sup in range(4):               # deep supervision
    for t in range(6):             # depth recurrence
        z = f(z + x)               # 2-layer block, applied 24 times total
    loss += CE(head(z), labels)
    z = detach(z)                  # 1-step IFT: O(1) memory
```
Why This Architecture?
Most small LMs are weak because they're shallow: a 40M transformer has ~6 layers, giving it only 6 chances to transform each token. NEXUS has 2 physical layers but applies them 6 times per supervision step × 4 supervision steps, i.e. 24 block applications and 48 effective layers, all with O(1) training memory via the Implicit Function Theorem gradient trick.
This builds on key findings from recent research:
- TRM (Samsung, 2025): a 7M-parameter model scored 45% on ARC-AGI-1, beating most billion-parameter LLMs, using depth recurrence rather than model size
- HRM (2025): a 27M-parameter model was competitive with DeepSeek R1 and o3-mini on ARC-AGI, by the same principle
| Approach | Params | ARC-AGI-1 | How |
|---|---|---|---|
| GPT-4o | ~1.8T | ~5% | Massive scale |
| DeepSeek R1 | 671B | ~21% | Chain-of-thought |
| TRM | 7M | 45% | Depth recurrence |
| HRM | 27M | competitive | Depth recurrence |
NEXUS brings this principle to language modeling.
Architecture
Core Block f()
Each physical block is tiny, just a CausalConv1d plus a SwiGLU MLP:
```python
def f(x):
    x = x + CausalConv1d(RMSNorm(x))  # local token mixing (kernel=4)
    x = x + SwiGLU(RMSNorm(x))        # channel mixing (ternary weights)
    return x
```
- CausalConv1d: depthwise causal convolution, O(T·D), mixes nearby tokens without attention
- SwiGLU: gated MLP with SiLU activation, all weights ternary {-1, 0, +1}
- Pre-norm residuals: the standard GPT-2/LLaMA pattern for stable deep networks
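For concreteness, here is a minimal PyTorch sketch of such a block. This is an illustration of the structure described above, not the repo's exact code; it assumes PyTorch ≥ 2.4 (for `nn.RMSNorm`) and omits the ternary quantization of the linear layers:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class NexusBlock(nn.Module):
    def __init__(self, d_model=512, kernel=4, hidden_mult=4):
        super().__init__()
        self.norm1 = nn.RMSNorm(d_model)
        self.norm2 = nn.RMSNorm(d_model)
        # groups=d_model makes the convolution depthwise
        self.conv = nn.Conv1d(d_model, d_model, kernel, groups=d_model)
        self.kernel = kernel
        h = hidden_mult * d_model
        self.w_gate = nn.Linear(d_model, h, bias=False)  # ternary in NEXUS
        self.w_up = nn.Linear(d_model, h, bias=False)
        self.w_down = nn.Linear(h, d_model, bias=False)

    def forward(self, x):                      # x: [B, T, D]
        y = self.norm1(x).transpose(1, 2)      # [B, D, T] for Conv1d
        y = F.pad(y, (self.kernel - 1, 0))     # left padding => causal
        x = x + self.conv(y).transpose(1, 2)   # local token mixing
        y = self.norm2(x)
        x = x + self.w_down(F.silu(self.w_gate(y)) * self.w_up(y))  # SwiGLU
        return x
```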
Depth Recurrence
The block is applied T_inner times (default 6) per supervision step:
```python
z = f(f(f(f(f(f(z + x) + x) + x) + x) + x) + x)  # 6 iterations, same weights
```
With 2 physical layers in f(), that's 12 layer applications per supervision step; repeated across N_sup=4 supervision steps, that gives a total of 48 effective layers.
Deep Supervision (1-step IFT)
From the HRM paper (arxiv:2506.21734): only the last iteration per supervision step carries gradients. All prior iterations run under torch.no_grad():
```python
for sup in range(N_sup):
    with torch.no_grad():
        for _ in range(T_inner - 1):  # 5 iterations, no grad
            z = f(z + x)
    z = f(z + x)                      # 1 iteration WITH grad
    loss += CE(head(z), labels)
    z = z.detach()                    # break gradient chain
```
This gives O(1) training memory regardless of recurrence depth: only the last iteration's activations are stored for the backward pass.
BitNet b1.58 Ternary Quantization
All linear layers (except embeddings) use ternary weights {-1, 0, +1}:
```python
γ = mean(|W|)                           # absmean scale
W_ternary = round(W / γ).clamp(-1, 1)   # only -1, 0, +1
# At inference: matmul becomes pure addition/subtraction
```
Impact: 40M params × 1.58 bits ≈ 8MB model (vs 80MB in FP16).
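The size figure is simple arithmetic (a ternary weight carries log2(3) ≈ 1.58 bits of information):

```python
params = 40e6
bits_per_weight = 1.58                      # ≈ log2(3) for {-1, 0, +1}
print(params * bits_per_weight / 8 / 1e6)   # ≈ 7.9 MB (vs 40e6 * 16 / 8 / 1e6 = 80 MB in FP16)
```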
Model Configurations
| Config | Params | FP16 | Ternary | N_sup | T_inner | Eff. Depth |
|---|---|---|---|---|---|---|
| `nexus_tiny()` | ~10M | 20MB | 2MB | 4 | 4 | 32 |
| `nexus_small()` | ~40M | 80MB | 8MB | 4 | 6 | 48 |
| `nexus_base()` | ~100M | 200MB | 20MB | 6 | 8 | 144 |
| `nexus_medium()` | ~200M | 400MB | 40MB | 8 | 10 | 240 |
All fit in 2-3GB RAM at inference with ternary quantization.
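The Eff. Depth column is consistent with eff_depth = N_sup × T_inner × physical_layers, which matches the table if tiny/small use 2 physical layers and base/medium use 3. The layer counts here are inferred from the numbers above, not stated in the presets:

```python
# (N_sup, T_inner, physical layers) -- layer counts inferred, see note above
configs = {
    "tiny":   (4, 4,  2),
    "small":  (4, 6,  2),
    "base":   (6, 8,  3),
    "medium": (8, 10, 3),
}
for name, (n_sup, t_inner, layers) in configs.items():
    print(name, n_sup * t_inner * layers)   # 32, 48, 144, 240
```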
Quick Start
Train on Colab (Free T4 GPU)
```python
# In a Colab cell:
!pip install -q torch transformers datasets huggingface_hub trackio
!wget -q -O nexus_model.py https://huggingface.co/krystv/nexus-small-v1/resolve/main/nexus_model.py
!wget -q -O train_nexus.py https://huggingface.co/krystv/nexus-small-v1/resolve/main/train_nexus.py

import os
os.environ['MODEL_SIZE'] = 'small'    # tiny/small/base/medium
os.environ['BATCH_SIZE'] = '4'        # 4 for T4, 8+ for A100
os.environ['GRAD_ACCUM'] = '8'        # effective batch = 32
os.environ['SEQ_LEN'] = '256'         # 256 for T4, 512+ for A100
os.environ['TOTAL_STEPS'] = '20000'
os.environ['SAVE_EVERY'] = '1000'     # checkpoint frequency (crash protection)
os.environ['EVAL_EVERY'] = '500'      # generate sample text to see progress

!python train_nexus.py
```
All Environment Variables
| Variable | Default | Description |
|---|---|---|
| `MODEL_SIZE` | `small` | Model preset: `tiny`, `small`, `base`, `medium` |
| `BATCH_SIZE` | `4` | Micro batch size |
| `GRAD_ACCUM` | `8` | Gradient accumulation (effective batch = BS × GA) |
| `SEQ_LEN` | `256` | Sequence length (256 for T4, 512 for A100) |
| `LR` | `3e-4` | Peak learning rate |
| `WARMUP_STEPS` | `500` | LR warmup steps |
| `TOTAL_STEPS` | `20000` | Total training steps |
| `LOG_EVERY` | `10` | Print loss every N steps |
| `SAVE_EVERY` | `1000` | Save checkpoint every N steps |
| `EVAL_EVERY` | `500` | Generate sample text every N steps |
| `N_SUP` | `0` (preset) | Override deep supervision steps |
| `T_INNER` | `0` (preset) | Override depth recurrence iterations |
| `PHASE` | `pretrain` | `pretrain` (FineWeb-Edu) or `sft` (smol-smoltalk) |
| `OUTPUT_DIR` | `./nexus-out` | Where to save checkpoints |
| `HUB_MODEL_ID` | `""` | Push to HF Hub (e.g. `username/nexus-small`) |
| `MAX_SAMPLES` | `0` | Limit dataset size (for testing) |
Use the Model
```python
import torch, json
from nexus_model import NexusModel, NexusConfig
from transformers import AutoTokenizer

# Load
with open('config.json') as f:
    cfg = NexusConfig.from_dict(json.load(f))
model = NexusModel(cfg)
model.load_state_dict(torch.load('model.pt', map_location='cpu'))
model.eval()
tok = AutoTokenizer.from_pretrained('.')

# Generate
ids = tok.encode("The future of AI is", return_tensors='pt')
with torch.no_grad():
    out = model.generate(ids, max_new_tokens=200, temperature=0.8, top_p=0.9)
print(tok.decode(out[0], skip_special_tokens=True))
```
Training Details
Dataset
Phase 1 - Pretraining: FineWeb-Edu (sample-10BT)
- 10B tokens of educationally filtered web content
- Quality filter: `int_score ≥ 3` (top 60%)
- Streaming: no full download needed
Phase 2 - Conversation SFT: smol-smoltalk
- 500K conversations designed for small models (<400M params)
- Multi-turn user/assistant format
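A minimal sketch of how the pretraining stream can be set up with the `datasets` library; this is illustrative, and the actual loading logic lives in `train_nexus.py`:

```python
from datasets import load_dataset

# Stream FineWeb-Edu: no full download of the 10BT sample needed
ds = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-10BT",
                  split="train", streaming=True)
ds = ds.filter(lambda ex: ex["int_score"] >= 3)   # quality filter (top ~60%)

for ex in ds.take(2):
    print(ex["text"][:200])
```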
Hyperparameters
| Setting | Pretrain | SFT |
|---|---|---|
| Learning Rate | 3e-4 | 1e-4 |
| Schedule | Cosine + warmup | Cosine + warmup |
| Warmup | 500 steps | 100 steps |
| Effective Batch | 32 (4×8) | 32 (4×8) |
| Sequence Length | 256 | 256 |
| Optimizer | AdamW (β₁=0.9, β₂=0.95) | AdamW |
| Weight Decay | 0.01 | 0.01 |
| Grad Clip | 1.0 | 1.0 |
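A sketch of an optimizer and schedule matching these hyperparameters. This is an assumed implementation for illustration; the repo's `train_nexus.py` may differ in details:

```python
import math
import torch

def make_optimizer(model, lr=3e-4, warmup=500, total_steps=20000):
    opt = torch.optim.AdamW(model.parameters(), lr=lr,
                            betas=(0.9, 0.95), weight_decay=0.01)
    def lr_lambda(step):                      # linear warmup, then cosine decay
        if step < warmup:
            return step / max(1, warmup)
        t = (step - warmup) / max(1, total_steps - warmup)
        return 0.5 * (1.0 + math.cos(math.pi * t))
    sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
    return opt, sched

# per optimizer step: loss.backward(), then
# torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
# opt.step(); sched.step(); opt.zero_grad()
```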
Training Features
- Periodic sample generation (`EVAL_EVERY`): prints 4 text completions so you can see quality improving
- Frequent checkpoints (`SAVE_EVERY`): crash protection for Colab timeouts
- Checkpoints include optimizer state: enables proper training resume
- GPU memory logging: every log line shows allocated/reserved VRAM
- Trackio integration: automatic experiment tracking with GPU metrics
How It Compares
vs Standard Transformers (same param count)
| Property | 40M Transformer | NEXUS-Small (40M) |
|---|---|---|
| Physical layers | 6 | 2 |
| Effective depth | 6 | 48 |
| Attention | O(T²) | None (O(T)) |
| KV cache at seq 1K | ~50MB | 0 |
| KV cache at seq 100K | ~5GB | 0 |
| Training memory | O(layers × T) | O(1) per depth |
| Weights (ternary) | N/A | 8MB |
vs Other Efficient Architectures
| Model | Type | Context | Memory/tok | Reasoning Depth |
|---|---|---|---|---|
| RWKV-7 | Linear recurrence | ∞ | O(1) state | Fixed (= layers) |
| Mamba-2 | SSM | ∞ | O(1) state | Fixed (= layers) |
| Huginn | Recurrent depth | Sliding window | O(window) | Variable (test-time) |
| NEXUS | Depth recurrence | max_seq_len | O(T) | Variable (T_inner × N_sup) |
NEXUS trades infinite context (like RWKV/Mamba) for variable reasoning depth (like Huginn/TRM). The depth can be increased at inference time for harder problems.
Test-Time Scaling
At inference, you can increase T_inner beyond the training value for deeper reasoning:
```python
# Training: T_inner=6
# Inference on a hard problem: T_inner=20 (3x more thinking)
cfg.T_inner = 20
model = NexusModel(cfg)
# then reload the trained weights as before:
# model.load_state_dict(torch.load('model.pt', map_location='cpu'))
```
This is the same principle behind "thinking tokens" in o1/R1, but built into the architecture rather than as a prompting hack.
Technical Deep Dive
Why CausalConv1d Instead of Attention?
For a depth-recurrent model, the block is applied 24-48 times per input. Attention would cost O(T² × D × 48), which is catastrophically expensive. CausalConv1d costs O(T × D × k × 48) where k=4 (the kernel size), which is orders of magnitude cheaper at realistic sequence lengths.
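Plugging in numbers makes the gap concrete; the ratio is T/k, so it grows linearly with sequence length:

```python
T, D, k, reps = 1024, 512, 4, 48
attn = T * T * D * reps   # attention token-mixing cost
conv = T * D * k * reps   # depthwise causal conv cost
print(attn / conv)        # 256x at T=1024; 25,600x at T=100K
```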
The TRM paper (arxiv:2510.04871) showed that even a simple transposed MLP (no attention at all) outperforms self-attention for reasoning tasks when combined with depth recurrence. The reasoning power comes from iteration depth, not from the mixing mechanism.
Why Ternary Weights Work
BitNet b1.58 (arxiv:2402.17764) showed that training from scratch with ternary weights matches FP16 quality at the same parameter count. The key is:
- Straight-Through Estimator for gradient flow through quantization
- RMSNorm before quantization (SubLN) to preserve variance
- Absmean scaling to adapt the ternary threshold per layer
At inference, matrix multiplication becomes pure addition/subtraction, with no floating-point multiplies needed. This enables efficient deployment on mobile/edge devices and custom hardware.
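A sketch of what such a layer can look like in training code: a generic BitNet-b1.58-style linear layer with absmean scaling and a straight-through estimator, not necessarily the repo's exact implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TernaryLinear(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(d_out, d_in))
        nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)

    def forward(self, x):
        w = self.weight
        gamma = w.abs().mean().clamp(min=1e-8)            # absmean scale
        w_q = (w / gamma).round().clamp(-1, 1) * gamma    # values in {-γ, 0, +γ}
        w_q = w + (w_q - w).detach()                      # STE: forward w_q, backward identity
        return F.linear(x, w_q)
```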
The 1-Step IFT Gradient Trick
The Implicit Function Theorem guarantees that at a fixed point z* = f(z* + x), the gradient can be computed without backpropagating through the iteration history. By running many no-grad iterations first to approach the fixed point, then one gradient step, we get:
- Correct gradients (to the extent the system is near a fixed point)
- O(1) memory (only the last iteration's activations stored)
- Biologically plausible (similar to equilibrium propagation)
This is why NEXUS can have 48 effective layers but use memory equivalent to a 2-layer model during training.
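For readers who want the derivation (standard deep-equilibrium reasoning, not specific to this repo), differentiating the fixed-point condition with respect to the parameters θ shows where the 1-step approximation comes from:

$$
z^* = f(z^*, \theta)
\;\Rightarrow\;
\frac{dz^*}{d\theta}
= \Big(I - \tfrac{\partial f}{\partial z}\big|_{z^*}\Big)^{-1}\tfrac{\partial f}{\partial \theta}\big|_{z^*}
= \sum_{k=0}^{\infty}\Big(\tfrac{\partial f}{\partial z}\big|_{z^*}\Big)^{k}\,\tfrac{\partial f}{\partial \theta}\big|_{z^*}.
$$

Truncating the Neumann series at k=0 keeps only ∂f/∂θ, which is exactly what backpropagating through the single final iteration computes. The truncation is reasonable when the Jacobian ∂f/∂z is contractive, the same condition that makes the iteration converge in the first place.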
Repository Contents
| File | Description |
|---|---|
| `nexus_model.py` | Complete architecture, self-contained, ~300 lines |
| `train_nexus.py` | Training script with eval + checkpointing |
| `train_nexus_colab.ipynb` | Ready-to-run Colab notebook |
| `config.json` | Default NEXUS-Small configuration |
| `ARCHITECTURE.md` | Extended architecture documentation |
Research References
HRM: Hierarchical Reasoning Model (Wang et al., 2025), arxiv:2506.21734. Introduced depth recurrence with dual L/H modules; 27M params competitive with o3-mini on ARC-AGI.
TRM: Tiny Recursive Models (Samsung SAIL Montreal, 2025), arxiv:2510.04871. Simplified HRM to a single shared block; 7M params reach 45% on ARC-AGI-1, and a transposed MLP beat attention for reasoning.
Huginn: Recurrent Depth (Geiping et al., 2025), arxiv:2502.05171. Applied depth recurrence to autoregressive LMs; a 3.5B Huginn matches a ~50B standard transformer given enough iterations.
BitNet b1.58: The Era of 1-bit LLMs (Ma et al., 2024), arxiv:2402.17764. Ternary quantization matching FP16 quality; matmul reduces to addition.
SmolLM2 (Allal et al., 2025), arxiv:2502.02737. Training recipe: FineWeb-Edu + smol-smoltalk for small models.
License
Apache 2.0