---
license: apache-2.0
language:
- en
library_name: mlx
pipeline_tag: text-generation
tags:
- mlx
- apple-silicon
- hrm
- hierarchical-reasoning
- prefix-lm
- pre-alignment
- non-chat
- non-instruction-tuned
base_model: sapientinc/HRM-Text-1B
---
# HRM-Text-1B-MLX
MLX port of [sapientinc/HRM-Text-1B](https://huggingface.co/sapientinc/HRM-Text-1B): Sapient Intelligence's Hierarchical Reasoning Model (HRM) adapted to language modeling. Optimized for Apple Silicon with a recurrent KV cache, `mx.fast` kernels, and a clean `load`/`generate` API.
## What this is
A 1B-class (1.18 B parameters) "pre-alignment" language model that replaces the standard "scale parameters, pretrain on internet-scale text" recipe with **scale compute depth at fixed parameters via recurrence, and train only on instruction-response pairs**. Two transformer modules, H (slow, strategic) and L (fast, execution), iterate over the same input embeddings for `H_cycles × (L_cycles + 1) = 2 × 4 = 8` stack passes per token, with additive state injection between the modules.
From the paper: 60.7 % MMLU / 81.9 % ARC-C / 82.2 % DROP / 84.5 % GSM8K / 56.2 % MATH, competitive with 2-7 B open models at **100-900× fewer training tokens** and **96-432× less compute**.
> **Disclaimer (carried over from upstream):** This is a **pre-alignment** checkpoint, not a chat or instruction-following assistant. It was pre-trained on a PrefixLM objective with condition prefix tokens: no multi-turn SFT, no RLHF, no chat tuning. Use it as a research substrate, not a finished assistant.
## Install
```bash
pip install mlx safetensors transformers
```
Then drop `hrm_mlx.py` into your project (the implementation is a single file).
## Usage
```python
import mlx.core as mx
from hrm_mlx import load, generate
model, tokenizer = load("RockTalk/HRM-Text-1B-MLX") # or a local path
# Reasoning / chain-of-thought (default condition)
print(generate(
model, tokenizer,
"Janet's ducks lay 16 eggs per day. She eats 3 for breakfast and bakes "
"muffins with 4 every day. She sells the remainder at $2 per egg. "
"How much does she make daily?",
condition="synth,cot",
max_tokens=300,
))
# Direct answer (few-shot extraction / multi-choice)
print(generate(
model, tokenizer,
"Q: What is the capital of Argentina?\nA:",
condition="direct",
max_tokens=10,
))
# Streaming
for chunk in generate(
model, tokenizer, "Explain why the sky is blue.",
condition="synth", max_tokens=200, stream=True,
):
print(chunk, end="", flush=True)
# Sampled
print(generate(
model, tokenizer, "Write a short poem about silicon.",
condition="synth", temperature=0.8, top_p=0.9, max_tokens=200,
))
```
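`temperature` and `top_p` are the usual sampling knobs. As a generic illustration only (plain Python over a list of logits, not the sampler inside `hrm_mlx.py`, which may differ in detail), temperature scaling followed by nucleus (top-p) filtering can be sketched like this; `sample_top_p` is a hypothetical helper name:

```python
import math
import random


def sample_top_p(logits, temperature=0.8, top_p=0.9, rng=random):
    """Temperature + nucleus (top-p) sampling over raw logits.

    Hypothetical helper for illustration; returns the sampled token index.
    """
    if temperature <= 0:
        # Temperature 0 is conventionally treated as greedy decoding.
        return max(range(len(logits)), key=lambda i: logits[i])

    # Softmax of temperature-scaled logits (subtract the max for stability).
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Keep the smallest set of most-probable tokens whose mass reaches top_p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break

    # Renormalize over the kept tokens and draw one of them.
    kept_mass = sum(probs[i] for i in kept)
    r = rng.random() * kept_mass
    acc = 0.0
    for i in kept:
        acc += probs[i]
        if r < acc:
            return i
    return kept[-1]  # guard against floating-point round-off
```

Lower `temperature` sharpens the distribution before filtering, and lower `top_p` shrinks the candidate set; with a very small `top_p` the sampler collapses to the single most likely token.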
## Condition modes
HRM-Text switches between its training distributions via composite condition prefix tokens. Pass the condition as comma-separated tags (order matters):
| Condition | Use for | Notes |
|---|---|---|
| `synth,cot` | Math, multi-step reasoning | Most verbose, formal step-by-step with LaTeX |
| `synth` | Factual explanations | Cleanest factual answers; boxed final answer |
| `direct` | Few-shot extraction, MCQ | Terse, 1-5 tokens |
| `cot` alone | Free-form chain-of-thought | Moderate verbosity |
| `noisy` | Web-crawl-style instructions | Often terse; can leak training-data formatting |
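Because `condition` is a comma-separated string whose order matters, a small validator can catch typos before they reach the model. This is a hypothetical helper, not part of `hrm_mlx.py`; it only knows the four tags listed in the table above, and the upstream tokenizer may accept others.

```python
KNOWN_TAGS = ("synth", "cot", "direct", "noisy")  # tags from the table above (may be incomplete)


def parse_condition(condition: str) -> list:
    """Split a condition string into ordered tags, rejecting unknown ones."""
    tags = [t.strip() for t in condition.split(",") if t.strip()]
    if not tags:
        raise ValueError("condition must contain at least one tag")
    for tag in tags:
        if tag not in KNOWN_TAGS:
            raise ValueError(f"unknown condition tag {tag!r}; expected one of {KNOWN_TAGS}")
    return tags  # order is preserved because it matters to the model


print(parse_condition("synth,cot"))  # ['synth', 'cot']
```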
**Practical guidance:**
- For NLP tasks (classification, extraction, structured output), use `direct` with 2-8 few-shot examples. Zero-shot `direct` is noticeably weaker.
- For math / reasoning, use `synth,cot`.
- The model is **not** a base LM: it cannot continue raw text ("Once upon a time, …" gets interpreted as a question to answer).
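For `direct` few-shot use, the prompt is just a run of demonstrations followed by the open query. A minimal sketch, assuming the `Q: …\nA:` layout from the Usage example and a blank line between demonstrations (both illustrative choices, not a format the model is documented to require); `build_few_shot_prompt` is a hypothetical helper:

```python
def build_few_shot_prompt(examples, query, q_prefix="Q: ", a_prefix="A:"):
    """Format (question, answer) demonstrations plus an open query for condition="direct"."""
    blocks = [f"{q_prefix}{q}\n{a_prefix} {a}" for q, a in examples]
    blocks.append(f"{q_prefix}{query}\n{a_prefix}")  # left open for the model to complete
    return "\n\n".join(blocks)


demos = [
    ("What is the capital of France?", "Paris"),
    ("What is the capital of Japan?", "Tokyo"),
]
prompt = build_few_shot_prompt(demos, "What is the capital of Argentina?")
# Then, for example: generate(model, tokenizer, prompt, condition="direct", max_tokens=10)
```

Per the guidance above, 2-8 demonstrations tend to work better than zero-shot `direct`.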
## Architecture
```python
import mlx.core as mx

H_CYCLES, L_CYCLES = 2, 3

def hrm_forward(input_ids, embed, L_module, H_module, lm_head, embedding_scale):
    z_H = embed(input_ids) * embedding_scale  # 1 / initializer_range ≈ 39.19
    z_L = mx.zeros_like(z_H)
    for _h in range(H_CYCLES):
        for _l in range(L_CYCLES):
            z_L = L_module(z_L + z_H)  # 16-layer transformer stack
        z_H = H_module(z_H + z_L)      # same architecture, separate weights
    return lm_head(z_H)
```
Per forward pass, `L_module` runs 6 times and `H_module` runs 2 times: **8 stack invocations × 16 layers = 128 transformer-layer-equivalents of compute**, all with 1.18 B parameters.
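To check these counts, here is a minimal sketch that re-creates the H/L loop with stub modules that only count their calls. Plain floats stand in for hidden states and `L_module`/`H_module` are stand-ins, not the real stacks, so this verifies the control flow only.

```python
from collections import Counter

H_CYCLES, L_CYCLES, LAYERS_PER_STACK = 2, 3, 16
calls = Counter()


def L_module(x):
    calls["L"] += 1  # stand-in for the 16-layer L stack
    return x


def H_module(x):
    calls["H"] += 1  # stand-in for the 16-layer H stack
    return x


z_H, z_L = 1.0, 0.0  # scalars stand in for hidden-state tensors
for _h in range(H_CYCLES):
    for _l in range(L_CYCLES):
        z_L = L_module(z_L + z_H)
    z_H = H_module(z_H + z_L)

stack_calls = calls["L"] + calls["H"]
print(calls["L"], calls["H"], stack_calls, stack_calls * LAYERS_PER_STACK)
# 6 2 8 128
```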
| Field | Value |
|---|---|
| Parameters | 1.18 B |
| Hidden size | 1536 |
| Layers per stack | 16 |
| Attention heads | 12 (MHA, head_dim 128) |
| Intermediate size | 4096 |
| H_cycles × L_cycles | 2 × 3 |
| Max sequence length | 4096 tokens |
| Vocabulary | 65,536 |
| Position encoding | RoPE (θ = 10,000) |
| Activation | SwiGLU |
| Normalization | Parameterless Pre-RMSNorm |
| Attention | Gated (sigmoid output gate) |
| Objective | Instruction-only PrefixLM (40 B unique tokens) |
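A few rows of this table can be cross-checked with plain arithmetic. The sketch assumes bf16 storage at 2 bytes per parameter and reads the "2.2 GB" in the file list as GiB; both are assumptions for the check. The embedding scale of ≈ 39.19 happens to equal `sqrt(hidden_size)`, which suggests `initializer_range = 1/sqrt(1536)`; that is an inference, not something upstream states here.

```python
import math

hidden_size, num_heads, vocab = 1536, 12, 65_536
params = 1.18e9

head_dim = hidden_size // num_heads   # should match the "head_dim 128" row
embed_scale = math.sqrt(hidden_size)  # compare with embedding_scale ≈ 39.19
embed_params = vocab * hidden_size    # token-embedding matrix alone
bf16_gib = params * 2 / 2**30         # assumed 2 bytes per bf16 weight

print(head_dim, round(embed_scale, 2), embed_params, round(bf16_gib, 2))
# 128 39.19 100663296 2.2
```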
## MLX-specific implementation notes
- **128-slot recurrent KV cache**: one slot per `(H_cycle, L_cycle | trailing_H, layer)` combination (2 × 4 × 16 = 128), indexed as `(h * (L_cycles+1) + l) * num_layers_per_stack + layer_idx`, where `l == L_cycles` is the trailing H pass. Each slot uses chunked-grow allocation (256-token chunks) à la `mlx-lm`.
- **`mx.fast.scaled_dot_product_attention`** for both prefill and step.
- **`mx.fast.rope` with `offset`** parameter so cached and new positions stay aligned without a precomputed cos/sin table.
- **`mx.fast.rms_norm`** with `weight=None` (parameterless RMSNorm).
- **Fused projections**: the safetensors store `gqkv_proj` (gate, q, k, v concatenated) and `gate_up_proj` (gate, up concatenated). The MLX port keeps the fused layout (concatenated along dim 0, the output dimension of the weight), runs a single matmul per fused projection, and splits the result inside the forward pass; one big matmul beats several small ones on the Metal backend.
- **PrefixLM mask**: when `token_type_ids` is all ones on prefill (the canonical usage), we pass `mask=None` to the fast SDPA kernel, which then attends bidirectionally over the whole prompt. Mixed prefix/causal masks are supported via an explicit additive mask.
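The cache indexing and mask rules above can be sketched in plain Python. Everything here is illustrative: `slot_index`, `chunked_capacity`, and `prefix_lm_mask` are hypothetical names; the growth rule (round the needed length up to a whole number of 256-token chunks) is an assumption modeled on the description; and the mask assumes `token_type_ids` uses 1 for prefix tokens (implied by the all-ones prefill case) and 0 for causal tokens.

```python
H_CYCLES, L_CYCLES, NUM_LAYERS = 2, 3, 16
CHUNK = 256
NEG_INF = float("-inf")


def slot_index(h: int, l: int, layer_idx: int) -> int:
    """Flat KV-cache slot; l == L_CYCLES denotes the trailing H pass of cycle h."""
    return (h * (L_CYCLES + 1) + l) * NUM_LAYERS + layer_idx


def chunked_capacity(num_tokens: int, chunk: int = CHUNK) -> int:
    """Allocated length after caching num_tokens, growing in whole chunks (assumed rule)."""
    return -(-num_tokens // chunk) * chunk  # ceiling division, then back to tokens


def prefix_lm_mask(token_type_ids):
    """Additive mask: 0.0 where query i may attend to key j, -inf otherwise.

    Prefix tokens (type 1) see each other bidirectionally; every other
    query/key pair follows the causal rule j <= i.
    """
    n = len(token_type_ids)
    return [
        [
            0.0 if (j <= i or (token_type_ids[i] == 1 and token_type_ids[j] == 1)) else NEG_INF
            for j in range(n)
        ]
        for i in range(n)
    ]


for row in prefix_lm_mask([1, 1, 0]):
    print(row)
# [0.0, 0.0, -inf]
# [0.0, 0.0, -inf]
# [0.0, 0.0, 0.0]
```

With an all-ones `token_type_ids` every entry is 0.0, which is why `mask=None` suffices in that case. In MLX, a mixed mask like this could be built as an array and passed as the `mask` argument of `mx.fast.scaled_dot_product_attention`; check the MLX documentation for the mask shapes and dtypes your version accepts.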
## Benchmarks (M3 Ultra, bf16, greedy)
| Prompt | MLX (this port) | PyTorch MPS (with cache) | Speedup |
|---|---|---|---|
| Sky-blue (~15-tok prompt → 50 tok) | **36.3 tok/s** | 19.3 tok/s | 1.89× |
| Janet GSM8K (~75-tok prompt → ~150 tok) | **37.5 tok/s** | 14.3 tok/s | 2.62× |
| Train meeting (~50-tok prompt → 400 tok) | **36.5 tok/s** | 14.7 tok/s | 2.47× |
| Capitals 8-shot (~115-tok prompt → 10 tok) | **24.5 tok/s** | 11.8 tok/s | 2.07× |
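To collect throughput numbers of this kind on your own machine, one option is to time a streaming decode and divide the number of items produced by wall-clock time. How the table above was measured (warm-up, whether prefill is included) is not specified, so treat this as one reasonable approach. `measure_throughput` and `fake_decoder` are hypothetical; with `generate(..., stream=True)` each chunk is text that may not map to exactly one token, so count tokens with the tokenizer for precise figures.

```python
import time
from typing import Iterable, Tuple


def measure_throughput(token_stream: Iterable) -> Tuple[int, float, float]:
    """Consume a stream and return (items, seconds, items per second).

    The clock starts before the first item is requested, so the prefill
    latency of a lazy generator is included in the measurement.
    """
    start = time.perf_counter()
    count = 0
    for _ in token_stream:
        count += 1
    elapsed = time.perf_counter() - start
    rate = count / elapsed if elapsed > 0 else float("inf")
    return count, elapsed, rate


def fake_decoder(n=50, delay=0.001):
    """Synthetic stand-in for a decoder: n "tokens", roughly `delay` seconds each."""
    for i in range(n):
        time.sleep(delay)
        yield i


n, secs, tps = measure_throughput(fake_decoder())
print(f"{n} tokens in {secs:.3f}s -> {tps:.1f} tok/s")
```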
Numerical parity: under greedy decoding, the output matches the PyTorch reference **token for token** for the first ~30 tokens on every prompt. After that, occasional divergence appears on near-tie logits (a bf16 precision limit); both implementations are correct and simply pick different tokens when the top logits are effectively tied.
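A parity comparison like this reduces to finding the first position where two greedy token-ID sequences differ, then checking how close the top two logits were at that step. A minimal sketch; `first_divergence` and `top2_margin` are hypothetical helpers, and the ID lists below are made up for illustration, not taken from either implementation.

```python
from typing import Optional, Sequence


def first_divergence(a: Sequence[int], b: Sequence[int]) -> Optional[int]:
    """Index of the first differing token, or None if one sequence is a prefix of the other."""
    for i, (x, y) in enumerate(zip(a, b)):
        if x != y:
            return i
    return None


def top2_margin(logits: Sequence[float]) -> float:
    """Gap between the two largest logits; a small gap marks a near-tie."""
    first, second = sorted(logits, reverse=True)[:2]
    return first - second


mlx_ids = [101, 7, 42, 42, 9, 13]    # illustrative IDs only
torch_ids = [101, 7, 42, 42, 8, 13]
print(first_divergence(mlx_ids, torch_ids))  # 4
```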
## Files
```
config.json            - model config (unchanged from upstream)
tokenizer.json         - tokenizer (unchanged)
tokenizer_config.json  - special tokens map (EOS = <|box_end|>, id 11)
model_mlx.safetensors  - bf16 weights in MLX-native layout (2.2 GB)
README.md              - this file
hrm_mlx.py             - the implementation (~350 lines)
LICENSE                - Apache 2.0
```
## License
[Apache License 2.0](LICENSE), same as upstream.
## Citation
The original HRM-Text paper:
```bibtex
@misc{wang2026hrmtextefficientpretrainingscaling,
title={HRM-Text: Efficient Pretraining Beyond Scaling},
author={Guan Wang and Changling Liu and Chenyu Wang and Cai Zhou and Yuhao Sun and Yifei Wu and Shuai Zhen and Luca Scimeca and Yasin Abbasi Yadkori},
year={2026},
eprint={2605.20613},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2605.20613},
}
```
## Acknowledgments
Upstream model and weights: [Sapient Intelligence](https://huggingface.co/sapientinc).
MLX port by [@RockTalk](https://huggingface.co/RockTalk).