---
license: apache-2.0
language:
- en
library_name: mlx
pipeline_tag: text-generation
tags:
- mlx
- apple-silicon
- hrm
- hierarchical-reasoning
- prefix-lm
- pre-alignment
- non-chat
- non-instruction-tuned
base_model: sapientinc/HRM-Text-1B
---

# HRM-Text-1B-MLX

MLX port of [sapientinc/HRM-Text-1B](https://huggingface.co/sapientinc/HRM-Text-1B), Sapient Intelligence's Hierarchical Reasoning Model adapted to language modeling. Optimized for Apple Silicon with a recurrent KV cache, `mx.fast` kernels, and a clean `load`/`generate` API.

## What this is

A 1 B-parameter "pre-alignment" language model that replaces the standard "scale parameters, pretrain on internet-scale text" recipe with **scale compute depth at fixed parameters via recurrence, train only on instruction-response pairs**. Two transformer modules, H (slow / strategic) and L (fast / execution), iterate over the same input embeddings for `H_cycles × (L_cycles + 1) = 8` stack passes per token, with additive state injection between modules.

From the paper: 60.7 % MMLU / 81.9 % ARC-C / 82.2 % DROP / 84.5 % GSM8K / 56.2 % MATH, competitive with 2-7 B open models at **100-900× fewer training tokens** and **96-432× less compute**.

> **Disclaimer (carries over from upstream):** This is a **pre-alignment** checkpoint, not a chat or instruction-following assistant. It was pre-trained on a PrefixLM objective with condition prefix tokens. No multi-turn SFT, no RLHF, no chat coating. Use it as a research substrate, not a finished assistant.

## Install

```bash
pip install mlx safetensors transformers
```

Then drop `hrm_mlx.py` into your project (the implementation is a single file).

## Usage

```python
import mlx.core as mx
from hrm_mlx import load, generate

model, tokenizer = load("RockTalk/HRM-Text-1B-MLX")  # or a local path

# Reasoning / chain-of-thought (default condition)
print(generate(
    model, tokenizer,
    "Janet's ducks lay 16 eggs per day. She eats 3 for breakfast and bakes "
    "muffins with 4 every day. She sells the remainder at $2 per egg. "
    "How much does she make daily?",
    condition="synth,cot",
    max_tokens=300,
))

# Direct answer (few-shot extraction / multi-choice)
print(generate(
    model, tokenizer,
    "Q: What is the capital of Argentina?\nA:",
    condition="direct",
    max_tokens=10,
))

# Streaming
for chunk in generate(
    model, tokenizer,
    "Explain why the sky is blue.",
    condition="synth",
    max_tokens=200,
    stream=True,
):
    print(chunk, end="", flush=True)

# Sampled
print(generate(
    model, tokenizer,
    "Write a short poem about silicon.",
    condition="synth",
    temperature=0.8,
    top_p=0.9,
    max_tokens=200,
))
```

## Condition modes

HRM-Text routes between training distributions via composite condition prefix tokens. Combine with comma-separated tags (order matters):

| Condition | Use for | Notes |
|---|---|---|
| `synth,cot` | Math, multi-step reasoning | Most verbose, formal step-by-step with LaTeX |
| `synth` | Factual explanations | Cleanest factual answers; boxed final answer |
| `direct` | Few-shot extraction, MCQ | Terse, 1-5 tokens |
| `cot` alone | Free-form chain-of-thought | Moderate verbosity |
| `noisy` | Web-crawl-style instructions | Often terse; can leak training-data formatting |

**Practical guidance:**

- For NLP tasks (classification, extraction, structured output), use `direct` with 2-8 few-shot examples. Zero-shot `direct` is noticeably weaker.
- For math / reasoning, use `synth,cot`.
- The model is **not** a base LM; it cannot continue raw text ("Once upon a time, …" gets interpreted as a question to answer).

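The few-shot guidance above can be sketched as a small prompt builder. This is a minimal illustration, not part of `hrm_mlx.py`: the helper name `build_few_shot_prompt` is hypothetical, and the blank-line separator between examples is an assumption. Only the `Q: ...\nA:` layout is taken from the Usage example.

```python
from typing import Sequence, Tuple


def build_few_shot_prompt(examples: Sequence[Tuple[str, str]], query: str) -> str:
    # Hypothetical helper (not shipped with hrm_mlx.py).
    # Each solved example becomes "Q: <question>\nA: <answer>".
    shots = [f"Q: {question}\nA: {answer}" for question, answer in examples]
    # The final question is left open so the model fills in the answer.
    shots.append(f"Q: {query}\nA:")
    # Assumption: examples are separated by one blank line.
    return "\n\n".join(shots)


prompt = build_few_shot_prompt(
    [
        ("What is the capital of France?", "Paris"),
        ("What is the capital of Japan?", "Tokyo"),
    ],
    "What is the capital of Argentina?",
)
print(prompt)
```

The resulting string can be passed as the prompt to `generate(model, tokenizer, prompt, condition="direct", max_tokens=10)`, matching the direct-answer call in the Usage section. Two examples are shown for brevity; the guidance above suggests 2-8.
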
## Architecture

```
z_H = embed(input_ids) * embedding_scale   # 1 / initializer_range ≈ 39.19
z_L = zeros
for h in range(H_cycles):                  # H_cycles = 2
    for l in range(L_cycles):              # L_cycles = 3
        z_L = L_module(z_L + z_H)          # 16-layer transformer stack
    z_H = H_module(z_H + z_L)              # same architecture, separate weights
logits = lm_head(z_H)
```

Per forward pass: `L_module` runs 6 times and `H_module` runs 2 times, giving **8 stack invocations × 16 layers = 128 transformer-layer-equivalents of compute**, all at 1.18 B parameters.

| Field | Value |
|---|---|
| Parameters | 1.18 B |
| Hidden size | 1536 |
| Layers per stack | 16 |
| Attention heads | 12 (MHA, head_dim 128) |
| Intermediate size | 4096 |
| H_cycles × L_cycles | 2 × 3 |
| Max sequence | 4096 |
| Vocabulary | 65,536 |
| Position encoding | RoPE (θ = 10,000) |
| Activation | SwiGLU |
| Normalization | Parameterless Pre-RMSNorm |
| Attention | Gated (sigmoid output gate) |
| Objective | Instruction-only PrefixLM (40 B unique tokens) |

## MLX-specific implementation notes

- **128-slot recurrent KV cache**: one slot per `(H_cycle, L_cycle | trailing_H, layer)`, indexed as `(h * (L_cycles+1) + l) * num_layers_per_stack + layer_idx`. Each slot uses chunked-grow allocation (256-token chunks) à la `mlx-lm`.
- **`mx.fast.scaled_dot_product_attention`** for both prefill and step.
- **`mx.fast.rope` with `offset`** parameter so cached and new positions stay aligned without a precomputed cos/sin table.
- **`mx.fast.rms_norm`** with `weight=None` (parameterless RMSNorm).
- **Fused projections**: the safetensors store `gqkv_proj` (gate, q, k, v concatenated) and `gate_up_proj` (gate, up concatenated). The MLX port keeps the fused layout, splitting on dim 0 inside the forward (one big matmul beats four small ones on the Metal backend).
- **PrefixLM mask**: when `token_type_ids` is all-ones on prefill (the canonical usage), we pass `mask=None` to the fast SDPA kernel and let the bidirectional attention happen implicitly. Mixed prefix/causal masks are supported via an explicit additive mask.

## Benchmarks (M3 Ultra, bf16, greedy)

| Prompt | MLX (this port) | PyTorch MPS (with cache) | Speedup |
|---|---|---|---|
| Sky-blue (~15-tok prompt → 50 tok) | **36.3 tok/s** | 19.3 tok/s | 1.89× |
| Janet GSM8K (~75-tok prompt → ~150 tok) | **37.5 tok/s** | 14.3 tok/s | 2.62× |
| Train meeting (~50-tok prompt → 400 tok) | **36.5 tok/s** | 14.7 tok/s | 2.47× |
| Capitals 8-shot (~115-tok prompt → 10 tok) | **24.5 tok/s** | 11.8 tok/s | 2.07× |

Numerical parity: **token-for-token match** with the PyTorch reference under greedy decoding for the first ~30 tokens on every prompt. After that, the outputs occasionally diverge on near-tie logits (the bf16 precision limit). Both implementations are correct; they simply pick different tokens out of statistical ties.

## Files

```
config.json            - model config (unchanged from upstream)
tokenizer.json         - tokenizer (unchanged)
tokenizer_config.json  - special tokens map (EOS = <|box_end|>, id 11)
model_mlx.safetensors  - bf16 weights in MLX-native layout (2.2 GB)
README.md              - this file
hrm_mlx.py             - the implementation (~350 lines)
LICENSE                - Apache 2.0
```

## License

[Apache License 2.0](LICENSE), same as upstream.

## Citation

The original HRM-Text paper:

```bibtex
@misc{wang2026hrmtextefficientpretrainingscaling,
  title={HRM-Text: Efficient Pretraining Beyond Scaling},
  author={Guan Wang and Changling Liu and Chenyu Wang and Cai Zhou and Yuhao Sun and Yifei Wu and Shuai Zhen and Luca Scimeca and Yasin Abbasi Yadkori},
  year={2026},
  eprint={2605.20613},
  archivePrefix={arXiv},
  primaryClass={cs.CL},
  url={https://arxiv.org/abs/2605.20613},
}
```

## Acknowledgments

Upstream model and weights: [Sapient Intelligence](https://huggingface.co/sapientinc). MLX port by [@RockTalk](https://huggingface.co/RockTalk).