🧬 NEXUS v4.1: Depth-Recurrent Language Model

A novel architecture that achieves deep reasoning through iteration, not size.

NEXUS applies the same tiny two-layer block 24 times to a single input (48 layer applications), giving a 40M parameter model the effective depth of a ~2B parameter transformer while fitting in under 100MB with ternary weights.

x = embed(tokens)                    # [B, T, 512]
z = zeros_like(x)                    # latent state

for sup in range(4):                  # deep supervision
    for t in range(6):                # depth recurrence  
        z = f(z + x)                  # 2-layer block, applied 24 times total
    loss += CE(head(z), labels)
    z = detach(z)                     # 1-step IFT -> O(1) memory

Why This Architecture?

Most small LMs are weak because they're shallow: a 40M transformer has ~6 layers, giving it only 6 chances to transform each token. NEXUS has 2 physical layers but applies them 6 times per supervision step across 4 supervision steps, for 48 effective layers, all with O(1) training memory via the Implicit Function Theorem gradient trick.

This builds on a key finding from recent research:

  • TRM (Samsung, 2025): A 7M parameter model scored 45% on ARC-AGI-1, beating most billion-parameter LLMs, using depth recurrence rather than model size
  • HRM (2025): A 27M parameter model was competitive with DeepSeek R1 and o3-mini on ARC-AGI, using the same principle

| Approach    | Params | ARC-AGI-1   | How              |
|-------------|--------|-------------|------------------|
| GPT-4o      | ~1.8T  | ~5%         | Massive scale    |
| DeepSeek R1 | 671B   | ~21%        | Chain-of-thought |
| TRM         | 7M     | 45%         | Depth recurrence |
| HRM         | 27M    | competitive | Depth recurrence |

NEXUS brings this principle to language modeling.


Architecture

Core Block f()

Each physical block is tiny, just CausalConv1d + SwiGLU MLP:

f(x):
    x = x + CausalConv1d(RMSNorm(x))    # local token mixing (kernel=4)
    x = x + SwiGLU(RMSNorm(x))          # channel mixing (ternary weights)
    return x
  • CausalConv1d: depthwise causal convolution, O(T·D); mixes nearby tokens without attention
  • SwiGLU: gated MLP with SiLU activation, all weights ternary {-1, 0, +1}
  • Pre-norm residual: standard GPT-2/LLaMA pattern for stable deep networks
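
To make this concrete, here is a minimal PyTorch sketch of such a block: pre-norm residuals around a depthwise causal convolution and a SwiGLU MLP. Class and hyperparameter names are illustrative rather than taken from nexus_model.py, ternary quantization is omitted, and nn.RMSNorm needs a recent PyTorch (2.4+); swap in LayerNorm otherwise.

import torch
import torch.nn as nn
import torch.nn.functional as F

class NexusBlock(nn.Module):
    """Sketch of f(): pre-norm residual CausalConv1d + SwiGLU (no ternary quantization)."""
    def __init__(self, d_model: int = 512, kernel: int = 4, mlp_mult: int = 4):
        super().__init__()
        self.norm1 = nn.RMSNorm(d_model)
        self.norm2 = nn.RMSNorm(d_model)
        self.kernel = kernel
        # Depthwise conv; causality comes from left-padding by kernel-1 in forward().
        self.conv = nn.Conv1d(d_model, d_model, kernel, groups=d_model)
        hidden = mlp_mult * d_model
        self.w_gate = nn.Linear(d_model, hidden, bias=False)
        self.w_up   = nn.Linear(d_model, hidden, bias=False)
        self.w_down = nn.Linear(hidden, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:   # x: [B, T, D]
        h = self.norm1(x).transpose(1, 2)                  # [B, D, T] for Conv1d
        h = F.pad(h, (self.kernel - 1, 0))                 # causal left padding
        x = x + self.conv(h).transpose(1, 2)               # local token mixing
        h = self.norm2(x)
        x = x + self.w_down(F.silu(self.w_gate(h)) * self.w_up(h))   # SwiGLU channel mixing
        return x

# z = NexusBlock()(torch.randn(2, 16, 512))   # -> [2, 16, 512]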

Depth Recurrence

The block is applied T_inner times (default 6) per supervision step:

z = f(f(f(f(f(f(z + x) + x) + x) + x) + x) + x)   # 6 iterations, same weights

With 2 physical layers in f(), that's 12 layer applications per supervision step, and this is repeated across N_sup=4 supervision steps for a total of 48 effective layers.
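
As a quick arithmetic check for the small preset:

physical_layers, t_inner, n_sup = 2, 6, 4         # nexus_small: 2-layer block, 6 inner iterations, 4 supervision steps
effective_depth = physical_layers * t_inner * n_sup
print(effective_depth)                             # 48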

Deep Supervision (1-step IFT)

From the HRM paper (arxiv:2506.21734): only the last iteration per supervision step carries gradients; all prior iterations run in torch.no_grad():

for sup in range(N_sup):
    with torch.no_grad():
        for _ in range(T_inner - 1):   # 5 iterations, no grad
            z = f(z + x)
    z = f(z + x)                        # 1 iteration WITH grad
    loss += CE(head(z), labels)
    z = z.detach()                       # break gradient chain

This gives O(1) training memory regardless of recurrence depth; only the last iteration's activations are stored for the backward pass.
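
Fleshed out into a runnable function, as a sketch under the assumption that f and head are ordinary nn.Modules; the actual train_nexus.py may shift labels and normalize the loss differently:

import torch
import torch.nn as nn
import torch.nn.functional as F

def deep_supervision_step(f: nn.Module, head: nn.Module, x: torch.Tensor,
                          labels: torch.Tensor, n_sup: int = 4, t_inner: int = 6) -> torch.Tensor:
    """Sum of CE losses over n_sup supervision steps, with 1-step IFT gradients."""
    z = torch.zeros_like(x)                      # latent state, [B, T, D]
    loss = x.new_zeros(())
    for _ in range(n_sup):
        with torch.no_grad():                    # approach the fixed point without storing activations
            for _ in range(t_inner - 1):
                z = f(z + x)
        z = f(z + x)                             # only this iteration carries gradients
        logits = head(z)                         # [B, T, vocab]
        loss = loss + F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                                      labels.reshape(-1))
        z = z.detach()                           # break the gradient chain between supervision steps
    return loss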

BitNet b1.58 Ternary Quantization

All linear layers (except embeddings) use ternary weights {-1, 0, +1}:

γ = mean(|W|)
W_ternary = round(W / γ).clamp(-1, 1)   # only -1, 0, +1
# At inference: matmul becomes pure addition/subtraction

Impact: 40M params × 1.58 bits ≈ 8MB model (vs 80MB in FP16).
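
A minimal sketch of how such a ternary linear layer is typically trained with a straight-through estimator. This illustrates the BitNet recipe rather than the exact code in nexus_model.py; activation quantization and the SubLN placement are omitted.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TernaryLinear(nn.Module):
    """BitNet b1.58-style linear: latent FP weights, ternary forward, STE backward."""
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.normal_(self.weight, std=in_features ** -0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = self.weight
        gamma = w.abs().mean().clamp(min=1e-8)             # absmean scale
        w_q = (w / gamma).round().clamp(-1, 1) * gamma     # values in {-γ, 0, +γ}
        w_ste = w + (w_q - w).detach()                     # straight-through: forward uses w_q, gradients flow to w
        return F.linear(x, w_ste)

At export time, the ternary values pack into roughly 1.58 bits per weight plus one scale per layer, which is where the ~8MB figure above comes from.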


Model Configurations

| Config         | Params | FP16  | Ternary | N_sup | T_inner | Eff. Depth |
|----------------|--------|-------|---------|-------|---------|------------|
| nexus_tiny()   | ~10M   | 20MB  | 2MB     | 4     | 4       | 32         |
| nexus_small()  | ~40M   | 80MB  | 8MB     | 4     | 6       | 48         |
| nexus_base()   | ~100M  | 200MB | 20MB    | 6     | 8       | 144        |
| nexus_medium() | ~200M  | 400MB | 40MB    | 8     | 10      | 240        |

All fit in 2-3GB RAM at inference with ternary quantization.


Quick Start

Train on Colab (Free T4 GPU)

# In a Colab cell:
!pip install -q torch transformers datasets huggingface_hub trackio
!wget -q -O nexus_model.py https://huggingface.co/krystv/nexus-small-v1/resolve/main/nexus_model.py
!wget -q -O train_nexus.py https://huggingface.co/krystv/nexus-small-v1/resolve/main/train_nexus.py

import os
os.environ['MODEL_SIZE']   = 'small'   # tiny/small/base/medium
os.environ['BATCH_SIZE']   = '4'       # 4 for T4, 8+ for A100
os.environ['GRAD_ACCUM']   = '8'       # Effective batch = 32
os.environ['SEQ_LEN']      = '256'     # 256 for T4, 512+ for A100
os.environ['TOTAL_STEPS']  = '20000'
os.environ['SAVE_EVERY']   = '1000'    # Checkpoint frequency (crash protection)
os.environ['EVAL_EVERY']   = '500'     # Generate sample text to see progress

!python train_nexus.py

All Environment Variables

| Variable     | Default     | Description                                        |
|--------------|-------------|----------------------------------------------------|
| MODEL_SIZE   | small       | Model preset: tiny, small, base, medium            |
| BATCH_SIZE   | 4           | Micro batch size                                   |
| GRAD_ACCUM   | 8           | Gradient accumulation (effective batch = BS × GA)  |
| SEQ_LEN      | 256         | Sequence length (256 for T4, 512 for A100)         |
| LR           | 3e-4        | Peak learning rate                                 |
| WARMUP_STEPS | 500         | LR warmup steps                                    |
| TOTAL_STEPS  | 20000       | Total training steps                               |
| LOG_EVERY    | 10          | Print loss every N steps                           |
| SAVE_EVERY   | 1000        | Save checkpoint every N steps                      |
| EVAL_EVERY   | 500         | Generate sample text every N steps                 |
| N_SUP        | 0 (preset)  | Override deep supervision steps                    |
| T_INNER      | 0 (preset)  | Override depth recurrence iterations               |
| PHASE        | pretrain    | pretrain (FineWeb-Edu) or sft (smol-smoltalk)      |
| OUTPUT_DIR   | ./nexus-out | Where to save checkpoints                          |
| HUB_MODEL_ID | ""          | Push to HF Hub (e.g. username/nexus-small)         |
| MAX_SAMPLES  | 0           | Limit dataset size (for testing)                   |

Use the Model

import torch, json
from nexus_model import NexusModel, NexusConfig
from transformers import AutoTokenizer

# Load
with open('config.json') as f:
    cfg = NexusConfig.from_dict(json.load(f))
model = NexusModel(cfg)
model.load_state_dict(torch.load('model.pt', map_location='cpu'))
model.eval()

tok = AutoTokenizer.from_pretrained('.')

# Generate
ids = tok.encode("The future of AI is", return_tensors='pt')
with torch.no_grad():
    out = model.generate(ids, max_new_tokens=200, temperature=0.8, top_p=0.9)
print(tok.decode(out[0], skip_special_tokens=True))

Training Details

Dataset

Phase 1 - Pretraining: FineWeb-Edu (sample-10BT)

  • 10B tokens of educationally filtered web content
  • Quality filter: int_score ≥ 3 (top 60%)
  • Streaming: no full download needed (see the sketch below)
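
For illustration, streaming this data with the quality filter can look like the following sketch (the dataset name and int_score field are from the FineWeb-Edu dataset card; the actual loading code lives in train_nexus.py):

from itertools import islice
from datasets import load_dataset

# Stream the 10BT sample of FineWeb-Edu; shards are fetched lazily, nothing is fully downloaded.
ds = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-10BT",
                  split="train", streaming=True)

# Keep only documents with educational quality score int_score >= 3.
ds = ds.filter(lambda ex: ex["int_score"] >= 3)

for ex in islice(ds, 3):                 # peek at a few documents
    print(ex["text"][:80].replace("\n", " "))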

Phase 2 - Conversation SFT: smol-smoltalk

  • 500K conversations designed for small models (<400M params)
  • Multi-turn user/assistant format

Hyperparameters

| Setting         | Pretrain                | SFT             |
|-----------------|-------------------------|-----------------|
| Learning Rate   | 3e-4                    | 1e-4            |
| Schedule        | Cosine + warmup         | Cosine + warmup |
| Warmup          | 500 steps               | 100 steps       |
| Effective Batch | 32 (4×8)                | 32 (4×8)        |
| Sequence Length | 256                     | 256             |
| Optimizer       | AdamW (β₁=0.9, β₂=0.95) | AdamW           |
| Weight Decay    | 0.01                    | 0.01            |
| Grad Clip       | 1.0                     | 1.0             |

Training Features

  • Periodic sample generation (EVAL_EVERY): prints 4 text completions so you can see quality improving
  • Frequent checkpoints (SAVE_EVERY): crash protection for Colab timeouts
  • Checkpoints include optimizer state: enables proper training resume
  • GPU memory logging: every log line shows allocated/reserved VRAM
  • Trackio integration: automatic experiment tracking with GPU metrics

How It Compares

vs Standard Transformers (same param count)

| Property              | 40M Transformer | NEXUS-Small (40M) |
|-----------------------|-----------------|-------------------|
| Physical layers       | 6               | 2                 |
| Effective depth       | 6               | 48                |
| Attention             | O(T²)           | None (O(T))       |
| KV cache at seq 1K    | ~50MB           | 0                 |
| KV cache at seq 100K  | ~5GB            | 0                 |
| Training memory       | O(layers × T)   | O(1) per depth    |
| Weights (ternary)     | N/A             | 8MB               |

vs Other Efficient Architectures

| Model   | Type              | Context        | Memory/tok | Reasoning Depth            |
|---------|-------------------|----------------|------------|----------------------------|
| RWKV-7  | Linear recurrence | ∞              | O(1) state | Fixed (= layers)           |
| Mamba-2 | SSM               | ∞              | O(1) state | Fixed (= layers)           |
| Huginn  | Recurrent depth   | Sliding window | O(window)  | Variable (test-time)       |
| NEXUS   | Depth recurrence  | max_seq_len    | O(T)       | Variable (T_inner × N_sup) |

NEXUS trades infinite context (like RWKV/Mamba) for variable reasoning depth (like Huginn/TRM). The depth can be increased at inference time for harder problems.

Test-Time Scaling

At inference, you can increase T_inner beyond the training value for deeper reasoning:

# Training: T_inner=6
# Inference on a hard problem: T_inner=20 (3x more thinking)
cfg.T_inner = 20
model = NexusModel(cfg)
model.load_state_dict(torch.load('model.pt', map_location='cpu'))   # same weights, more iterations

This is the same principle behind "thinking tokens" in o1/R1, but built into the architecture rather than as a prompting hack.


Technical Deep Dive

Why CausalConv1d Instead of Attention?

For a depth-recurrent model, the block is applied 24-48 times per input. Attention would cost O(T² × D × 48), which is catastrophically expensive. CausalConv1d costs O(T × D × k × 48) where k=4 (kernel size), orders of magnitude cheaper for long sequences.

The TRM paper (arxiv:2510.04871) showed that even a simple transposed MLP (no attention at all) outperforms self-attention for reasoning tasks when combined with depth recurrence. The reasoning power comes from iteration depth, not from the mixing mechanism.

Why Ternary Weights Work

BitNet b1.58 (arxiv:2402.17764) showed that training from scratch with ternary weights matches FP16 quality at the same parameter count. The key is:

  1. Straight-Through Estimator for gradient flow through quantization
  2. RMSNorm before quantization (SubLN) to preserve variance
  3. Absmean scaling to adapt the ternary threshold per layer

At inference, matrix multiplication becomes pure addition/subtraction; no floating-point multiply is needed. This enables efficient deployment on mobile/edge devices and custom hardware.

The 1-Step IFT Gradient Trick

The Implicit Function Theorem guarantees that at a fixed point z* = f(z* + x), the gradient can be computed from a single step of the iteration. By running many no-grad iterations first to approach the fixed point, then one gradient step, we get:

  • Correct gradients (to the extent the system is near a fixed point)
  • O(1) memory (only the last iteration's activations stored)
  • Biologically plausible (similar to equilibrium propagation)

This is why NEXUS can have 48 effective layers but use memory equivalent to a 2-layer model during training.
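
In symbols, a brief sketch in our own notation (assuming the no-grad iterations have brought z close to the fixed point):

\begin{aligned}
z^{*} &= f(z^{*} + x) && \text{fixed point of the recurrence} \\
\frac{dz^{*}}{d\theta} &= \left(I - \frac{\partial f}{\partial z}\right)^{-1} \frac{\partial f}{\partial \theta} && \text{exact gradient via the IFT} \\
\frac{dz^{*}}{d\theta} &\approx \frac{\partial f}{\partial \theta} && \text{1-step approximation: drop the inverse Jacobian}
\end{aligned}

Backpropagating through only the final iteration computes exactly this 1-step approximation.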


Repository Contents

| File                    | Description                                       |
|-------------------------|---------------------------------------------------|
| nexus_model.py          | Complete architecture; self-contained, ~300 lines |
| train_nexus.py          | Training script with eval + checkpointing         |
| train_nexus_colab.ipynb | Ready-to-run Colab notebook                       |
| config.json             | Default NEXUS-Small configuration                 |
| ARCHITECTURE.md         | Extended architecture documentation               |

Research References

  1. HRM: Hierarchical Reasoning Model (Wang et al., 2025), arxiv:2506.21734. Introduced depth recurrence with dual L/H modules; 27M params competitive with o3-mini on ARC-AGI.

  2. TRM: Tiny Recursive Models (Samsung SAIL Montreal, 2025), arxiv:2510.04871. Simplified to a single shared block; 7M params reach 45% on ARC-AGI-1. Found a transposed MLP beats attention for reasoning.

  3. Huginn: Recurrent Depth (Tómasson et al., 2025), arxiv:2502.05171. Applied depth recurrence to autoregressive LMs; 3.5B Huginn ≈ 50B standard transformer with enough iterations.

  4. BitNet b1.58: The Era of 1-bit LLMs (Ma et al., 2024), arxiv:2402.17764. Ternary quantization matching FP16 quality; matmul becomes addition.

  5. SmolLM2 (Allal et al., 2025), arxiv:2502.02737. Training recipe: FineWeb-Edu + smol-smoltalk for small models.


License

Apache 2.0
