HRM-Text-1B-MLX

MLX port of sapientinc/HRM-Text-1B, Sapient Intelligence's Hierarchical Reasoning Model adapted to language modeling. Optimized for Apple Silicon with a recurrent KV cache, mx.fast kernels, and a clean load/generate API.

What this is

A 1 B-parameter "pre-alignment" language model. Instead of the standard recipe ("scale parameters, pretrain on internet-scale text"), it scales compute depth at fixed parameters via recurrence and trains only on instruction-response pairs. Two transformer modules, H (slow / strategic) and L (fast / execution), iterate over the same input embeddings for H_cycles × (L_cycles + 1) = 8 stack passes per token, with additive state injection between modules.

From the paper: 60.7 % MMLU / 81.9 % ARC-C / 82.2 % DROP / 84.5 % GSM8K / 56.2 % MATH, competitive with 2-7 B open models at 100-900× fewer training tokens and 96-432× less compute.

Disclaimer (carries over from upstream): This is a pre-alignment checkpoint, not a chat or instruction-following assistant. It was pre-trained on a PrefixLM objective with condition prefix tokens. No multi-turn SFT, no RLHF, no chat coating. Use it as a research substrate, not a finished assistant.

Install

pip install mlx safetensors transformers

Then drop hrm_mlx.py into your project (the implementation is a single file).

Usage

import mlx.core as mx
from hrm_mlx import load, generate

model, tokenizer = load("RockTalk/HRM-Text-1B-MLX")  # or a local path

# Reasoning / chain-of-thought (default condition)
print(generate(
    model, tokenizer,
    "Janet's ducks lay 16 eggs per day. She eats 3 for breakfast and bakes "
    "muffins with 4 every day. She sells the remainder at $2 per egg. "
    "How much does she make daily?",
    condition="synth,cot",
    max_tokens=300,
))

# Direct answer (few-shot extraction / multi-choice)
print(generate(
    model, tokenizer,
    "Q: What is the capital of Argentina?\nA:",
    condition="direct",
    max_tokens=10,
))

# Streaming
for chunk in generate(
    model, tokenizer, "Explain why the sky is blue.",
    condition="synth", max_tokens=200, stream=True,
):
    print(chunk, end="", flush=True)

# Sampled
print(generate(
    model, tokenizer, "Write a short poem about silicon.",
    condition="synth", temperature=0.8, top_p=0.9, max_tokens=200,
))

Condition modes

HRM-Text routes between training distributions via composite condition prefix tokens. Combine tags in a single comma-separated string (order matters):

| Condition | Use for | Notes |
|---|---|---|
| synth,cot | Math, multi-step reasoning | Most verbose, formal step-by-step with LaTeX |
| synth | Factual explanations | Cleanest factual answers; boxed final answer |
| direct | Few-shot extraction, MCQ | Terse, 1-5 tokens |
| cot (alone) | Free-form chain-of-thought | Moderate verbosity |
| noisy | Web-crawl-style instructions | Often terse; can leak training-data formatting |

Practical guidance:

  • For NLP tasks (classification, extraction, structured output), use direct with 2-8 few-shot examples. Zero-shot direct is noticeably weaker.
  • For math / reasoning, use synth,cot.
  • The model is not a base LM: it cannot continue raw text ("Once upon a time, …" gets interpreted as a question to answer).
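The few-shot advice for direct can be sketched as a small prompt builder. The Q:/A: layout follows the capital-city example in the Usage section; the helper name build_few_shot_prompt and the example pairs are hypothetical and are not part of hrm_mlx.py.

```python
def build_few_shot_prompt(examples, question):
    """Join (question, answer) pairs and leave the final answer open.

    Hypothetical helper: the Q:/A: layout mirrors the Usage example,
    but nothing in the source mandates this exact format.
    """
    shots = [f"Q: {q}\nA: {a}" for q, a in examples]
    shots.append(f"Q: {question}\nA:")  # the model completes this last line
    return "\n\n".join(shots)


# Illustrative shots only; the guidance above suggests 2-8 of them.
examples = [
    ("What is the capital of France?", "Paris"),
    ("What is the capital of Japan?", "Tokyo"),
]
prompt = build_few_shot_prompt(examples, "What is the capital of Argentina?")
print(prompt)
```

The resulting string can then be passed as the prompt to generate(model, tokenizer, prompt, condition="direct", max_tokens=10).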

Architecture

H_cycles, L_cycles = 2, 3

z_H = embed(input_ids) * embedding_scale      # 1 / initializer_range ≈ 39.19
z_L = zeros_like(z_H)                         # L state starts at zero

for h in range(H_cycles):
    for l in range(L_cycles):
        z_L = L_module(z_L + z_H)             # 16-layer transformer stack
    z_H = H_module(z_H + z_L)                 # same architecture, separate weights

logits = lm_head(z_H)

Per forward pass, L_module runs 6 times and H_module runs 2 times: 8 stack invocations × 16 layers = 128 transformer-layer-equivalents of compute, all at 1.18 B parameters.
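The counts above can be checked by enumerating the loop order without running any model. This is a minimal sketch: stack_schedule is a hypothetical helper, and the constants are the values stated in this README.

```python
H_CYCLES = 2
L_CYCLES = 3
LAYERS_PER_STACK = 16


def stack_schedule(h_cycles=H_CYCLES, l_cycles=L_CYCLES):
    """Return the ordered stack invocations for one forward pass."""
    schedule = []
    for h in range(h_cycles):
        for l in range(l_cycles):
            schedule.append(("L", h, l))
        # One H pass closes each outer cycle.
        schedule.append(("H", h, l_cycles))
    return schedule


schedule = stack_schedule()
l_runs = sum(1 for name, _, _ in schedule if name == "L")
h_runs = sum(1 for name, _, _ in schedule if name == "H")
layer_equivalents = len(schedule) * LAYERS_PER_STACK
print(l_runs, h_runs, len(schedule), layer_equivalents)  # prints: 6 2 8 128
```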

| Field | Value |
|---|---|
| Parameters | 1.18 B |
| Hidden size | 1536 |
| Layers per stack | 16 |
| Attention heads | 12 (MHA, head_dim 128) |
| Intermediate size | 4096 |
| H_cycles × L_cycles | 2 × 3 |
| Max sequence | 4096 |
| Vocabulary | 65,536 |
| Position encoding | RoPE (θ = 10,000) |
| Activation | SwiGLU |
| Normalization | Parameterless Pre-RMSNorm |
| Attention | Gated (sigmoid output gate) |
| Objective | Instruction-only PrefixLM (40 B unique tokens) |
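As a rough sanity check, the parameter count can be re-derived from the table. The sketch below rests on assumptions the table does not state: the gate, q, k, v and output projections are each hidden × hidden, there are no bias terms, and the input embedding and lm_head are separate (untied) matrices.

```python
HIDDEN = 1536
INTERMEDIATE = 4096
LAYERS_PER_STACK = 16
NUM_STACKS = 2      # H and L modules have separate weights
VOCAB = 65_536

# Assumption: gate, q, k, v and output projection are each HIDDEN x HIDDEN, no biases.
attn = 5 * HIDDEN * HIDDEN
# SwiGLU MLP: gate and up projections plus a down projection.
mlp = 3 * HIDDEN * INTERMEDIATE
# Parameterless RMSNorm contributes no weights.
per_layer = attn + mlp
stacks = NUM_STACKS * LAYERS_PER_STACK * per_layer
# Assumption: embedding table and lm_head are untied.
embeddings = 2 * VOCAB * HIDDEN

total = stacks + embeddings
print(f"{total:,} parameters (~{total / 1e9:.2f} B)")
```

Under these assumptions the estimate rounds to the 1.18 B figure in the table; a different gate width or tied embeddings would change it.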

MLX-specific implementation notes

  • 128-slot recurrent KV cache: one slot per (H_cycle, L_cycle | trailing_H, layer), indexed as (h * (L_cycles+1) + l) * num_layers_per_stack + layer_idx. Each slot uses chunked-grow allocation (256-token chunks), as in mlx-lm.
  • mx.fast.scaled_dot_product_attention for both prefill and step.
  • mx.fast.rope with offset parameter so cached and new positions stay aligned without a precomputed cos/sin table.
  • mx.fast.rms_norm with weight=None (parameterless RMSNorm).
  • Fused projections: the safetensors store gqkv_proj (gate, q, k, v concatenated) and gate_up_proj (gate, up concatenated). The MLX port keeps the fused layout, splitting on dim 0 inside the forward (one big matmul beats four small ones on the Metal backend).
  • PrefixLM mask: when token_type_ids is all-ones on prefill (the canonical usage), we pass mask=None to the fast SDPA kernel and let the bidirectional attention happen implicitly. Mixed prefix/causal masks are supported via an explicit additive mask.
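The slot formula from the KV cache note can be written out in plain Python. This is a sketch, not the code in hrm_mlx.py: slot_index and chunked_capacity are hypothetical names, and treating l == L_CYCLES as the trailing H pass is one reading of the "(L_cycle | trailing_H)" notation.

```python
H_CYCLES = 2
L_CYCLES = 3
LAYERS_PER_STACK = 16
NUM_SLOTS = H_CYCLES * (L_CYCLES + 1) * LAYERS_PER_STACK  # 128
CHUNK = 256  # tokens per allocation chunk, as stated in the notes


def slot_index(h, l, layer_idx):
    """Flat cache slot for (H cycle, L step or trailing H, layer).

    Assumed convention: l in [0, L_CYCLES) is an L-module pass and
    l == L_CYCLES is the trailing H-module pass of cycle h.
    """
    if not (0 <= h < H_CYCLES and 0 <= l <= L_CYCLES
            and 0 <= layer_idx < LAYERS_PER_STACK):
        raise IndexError("cycle or layer index out of range")
    return (h * (L_CYCLES + 1) + l) * LAYERS_PER_STACK + layer_idx


def chunked_capacity(num_tokens, chunk=CHUNK):
    """Smallest multiple of `chunk` that can hold num_tokens entries."""
    return -(-num_tokens // chunk) * chunk  # ceiling division


# Every (h, l, layer) triple maps to a distinct slot in [0, NUM_SLOTS).
all_slots = {
    slot_index(h, l, layer)
    for h in range(H_CYCLES)
    for l in range(L_CYCLES + 1)
    for layer in range(LAYERS_PER_STACK)
}
print(len(all_slots), min(all_slots), max(all_slots))  # prints: 128 0 127
```

Growing in fixed chunks means a slot's buffers are reallocated only once every 256 generated tokens rather than on every decoding step.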

Benchmarks (M3 Ultra, bf16, greedy)

| Prompt | MLX (this port) | PyTorch MPS (with cache) | Speedup |
|---|---|---|---|
| Sky-blue (~15-tok prompt → 50 tok) | 36.3 tok/s | 19.3 tok/s | 1.89× |
| Janet GSM8K (~75-tok prompt → ~150 tok) | 37.5 tok/s | 14.3 tok/s | 2.62× |
| Train meeting (~50-tok prompt → 400 tok) | 36.5 tok/s | 14.7 tok/s | 2.47× |
| Capitals 8-shot (~115-tok prompt → 10 tok) | 24.5 tok/s | 11.8 tok/s | 2.07× |

Numerical parity: under greedy decoding, the output matches the PyTorch reference token for token for the first ~30 tokens on every prompt. After that, occasional divergence appears on near-tie logits (a bf16 precision limit). Both implementations are correct; they simply pick different tokens out of statistical ties.
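A parity check of this kind can be reproduced by comparing the token ids from the two implementations and reporting where they first differ. The helper below is a hypothetical sketch, and the id lists are made-up illustrations rather than measured outputs.

```python
def first_divergence(reference_ids, candidate_ids):
    """Index of the first differing token, or None if the sequences are equal.

    If one sequence is a strict prefix of the other, the length of the
    shorter one is returned.
    """
    for i, (ref, cand) in enumerate(zip(reference_ids, candidate_ids)):
        if ref != cand:
            return i
    if len(reference_ids) != len(candidate_ids):
        return min(len(reference_ids), len(candidate_ids))
    return None


# Illustrative ids only (not real model output).
pytorch_ids = [17, 204, 9, 33, 512, 8]
mlx_ids = [17, 204, 9, 33, 640, 8]
print(first_divergence(pytorch_ids, mlx_ids))  # prints: 4
```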

Files

config.json                  - model config (unchanged from upstream)
tokenizer.json               - tokenizer (unchanged)
tokenizer_config.json        - special tokens map (EOS = <|box_end|>, id 11)
model_mlx.safetensors        - bf16 weights in MLX-native layout (2.2 GB)
README.md                    - this file
hrm_mlx.py                   - the implementation (~350 lines)
LICENSE                      - Apache 2.0

License

Apache License 2.0, same as upstream.

Citation

The original HRM-Text paper:

@misc{wang2026hrmtextefficientpretrainingscaling,
      title={HRM-Text: Efficient Pretraining Beyond Scaling},
      author={Guan Wang and Changling Liu and Chenyu Wang and Cai Zhou and Yuhao Sun and Yifei Wu and Shuai Zhen and Luca Scimeca and Yasin Abbasi Yadkori},
      year={2026},
      eprint={2605.20613},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2605.20613},
}

Acknowledgments

Upstream model and weights: Sapient Intelligence.

MLX port by @RockTalk.
