# Needle → ONNX export for in-browser inference
Browser-ready ONNX export of Cactus-Compute/needle, a 26M-parameter function-calling model. Designed to run entirely client-side via onnxruntime-web (WASM backend); no server required.
## Files

| File | Description | Size |
|---|---|---|
| `encoder.onnx` | Needle encoder. Input `input_ids: (B, T)`, output `encoder_out: (B, T, 512)`. Single-pass. | ~55 MB |
| `decoder_step.onnx` | One decoder step with explicit past-KV in / present-KV out. Run in a JS loop. | ~85 MB |
| `needle.model` | SentencePiece BPE protobuf (vocab=8192, `byte_fallback=True`, identity normalization). Loadable by sentencepiece-js / @huggingface/transformers. | 125 KB |
| `tokenizer-specials.json` | `{"pad": 0, "eos": 1, "bos": 2, "tool_call": 4, "tools": 5}` | tiny |
## Origin
The upstream Cactus Needle is implemented in JAX/Flax, not PyTorch, so `torch.onnx.export` cannot run against the upstream model directly. This ONNX export was produced via a "port-and-copy" pipeline:

- Reimplemented the Simple Attention Network in PyTorch (parametric on `TransformerConfig`)
- Copied weights tensor-by-tensor from the upstream Flax checkpoint, handling the Flax `(in, out)` → PyTorch `(out, in)` transposition for Linear kernels and the `nn.scan` layer-stacking convention (sketched after this list)
- Verified Flax → PyTorch parity at `< 1e-3` max-abs-diff
- Exported encoder + decoder-step to ONNX via the legacy TorchScript-based `torch.onnx.export`
- Verified PyTorch → ONNX parity at `< 1e-3`
- Verified end-to-end: Cactus's native `generate()` and a hand-rolled `onnxruntime` KV-cache loop produce byte-identical output token sequences
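For concreteness, the two weight-conversion rules amount to roughly the following (a minimal sketch; the actual parameter traversal and naming live in `convert_weights.py`, and the function names here are illustrative):

```python
# Sketch of the two conversion rules above; assumes raw numpy arrays pulled
# out of the Flax checkpoint. Function names are illustrative, not the repo's.
import numpy as np
import torch

def flax_linear_to_torch(kernel: np.ndarray) -> torch.Tensor:
    # Flax nn.Dense stores kernels as (in_features, out_features);
    # torch.nn.Linear.weight is (out_features, in_features), so transpose.
    return torch.from_numpy(np.ascontiguousarray(kernel.T))

def unstack_scanned_layers(stacked: np.ndarray) -> list[torch.Tensor]:
    # nn.scan stacks per-layer parameters along a leading (num_layers, ...)
    # axis; split it into one tensor per PyTorch layer module.
    return [torch.from_numpy(np.ascontiguousarray(stacked[i]))
            for i in range(stacked.shape[0])]
```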
## Parity numbers (against Cactus's native `generate(constrained=False)`)
| Stage | max-abs-diff |
|---|---|
| Flax encoder → PyTorch port | 0.000010 |
| Flax decoder step-0 → PyTorch port | 0.000029 |
| PyTorch encoder → ONNX | 0.000004 |
| PyTorch decoder step → ONNX | 0.000014 (logits) |
| End-to-end token sequence | byte-identical |
Example: the query `set a 5 min timer` produces `' [{"name":"set_timer","arguments":{"time_human":"5 minutes"}}]'` both in native Cactus and in the browser via these artifacts.
## Usage in the browser

Load both `.onnx` files via onnxruntime-web (WASM backend), load `needle.model` via sentencepiece-js, then run the encoder once and call the decoder step in a JS loop, threading the KV cache through each iteration (sketched below).
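The decode loop, sketched here in Python with `onnxruntime` (the onnxruntime-web loop has the same shape; mirror it with `InferenceSession.create` and `session.run`). Only `input_ids` and `encoder_out` are documented tensor names; the decoder's token input, `logits`, the `past_*`/`present_*` names, and the step-0 KV shape are all assumptions, so read the real graph I/O from `get_inputs()`/`get_outputs()`:

```python
# Greedy KV-cache decode sketch. Every tensor name except input_ids/encoder_out,
# and the zero-length step-0 KV layout, is an ASSUMPTION -- inspect the graph.
import numpy as np
import onnxruntime as ort

enc = ort.InferenceSession("encoder.onnx")
dec = ort.InferenceSession("decoder_step.onnx")
BOS, EOS = 2, 1  # from tokenizer-specials.json

prompt = np.array([[BOS, 100, 101]], dtype=np.int64)    # (1, T); placeholder ids
(encoder_out,) = enc.run(None, {"input_ids": prompt})   # (1, T, 512)

# Assumed step-0 cache: empty seq axis, 4 KV heads, head_dim 64 (= 512 / 8).
kv = {i.name: np.zeros((1, 4, 0, 64), dtype=np.float32)
      for i in dec.get_inputs() if i.name.startswith("past_")}
out_names = [o.name for o in dec.get_outputs()]

token, generated = np.array([[BOS]], dtype=np.int64), []
for _ in range(64):
    outs = dict(zip(out_names, dec.run(None, {"input_ids": token,
                                              "encoder_out": encoder_out, **kv})))
    kv = {n.replace("present_", "past_"): v
          for n, v in outs.items() if n.startswith("present_")}
    next_id = int(outs["logits"][0, -1].argmax())       # greedy decode
    if next_id == EOS:
        break
    generated.append(next_id)
    token = np.array([[next_id]], dtype=np.int64)
```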
## Architecture
Per the upstream model card: encoder-decoder "Simple Attention Network", d_model=512, GQA with 8 query / 4 KV heads, 12 encoder layers, 8 decoder layers, no FFN, ZCRMSNorm ((1+γ)·x / RMS(x), γ initialized to zero), RoPE on Q and K.
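Read literally, that norm is small enough to write down. A sketch as a PyTorch module, assuming an `eps` stabilizer the card doesn't specify (the real implementation is in `needle_torch/`):

```python
import torch
from torch import nn

class ZCRMSNorm(nn.Module):
    """(1 + gamma) * x / RMS(x); gamma is zero-initialized, so the layer
    starts out as plain RMS normalization. eps is an assumed stabilizer."""
    def __init__(self, d_model: int = 512, eps: float = 1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        return (1 + self.gamma) * x / rms
```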
The decoder is exported as a single step with past/present KV as graph I/O; the JS side calls it in a loop, allowing streaming token output and avoiding ONNX symbolic control flow.
## Reproduce / port your own Cactus-trained model
The full pipeline that produced these artifacts is checked in alongside the `.onnx` files (see `PORTING.md` for the step-by-step). The scripts are parametric on the source HF repo, so if you've finetuned Needle (or trained a Simple-Attention-Network variant with the upstream Cactus codebase), you can produce a browser-ready ONNX export with the same recipe:
```bash
# 1. Convert your Cactus checkpoint → PyTorch state_dict
uv run python convert_weights.py --ckpt-repo YOUR_USER/your-finetune --ckpt-file weights.pkl

# 2. Verify the port matches your upstream model (< 1e-3 max-abs-diff)
uv run python verify_port_parity.py

# 3. Export to ONNX (reads config back from step 1's saved JSON; no edits needed)
uv run python export_onnx.py

# 4. Verify ONNX matches PyTorch AND matches native Cactus generate() token-for-token
uv run python verify_parity.py --ckpt-repo YOUR_USER/your-finetune --ckpt-file weights.pkl

# 5. Push your ONNX artifacts to HF
uv run python upload_to_hf.py --repo YOUR_USER/your-finetune-onnx
```
The PyTorch port (`needle_torch/`) is parametric on `TransformerConfig`: it reads the config straight out of your checkpoint's payload, so dimension changes (d_model, layer counts, GQA ratios) are picked up automatically. The same pipeline works for the 26M production Needle, the 1.35M iteration config, and anything in between.
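In other words, the config that flows through is just a handful of shape knobs. Illustratively (field names are hypothetical; the real dataclass lives in `needle_torch/`, and the defaults below are the 26M Needle's values from the Architecture section):

```python
from dataclasses import dataclass

@dataclass
class TransformerConfig:
    # Hypothetical field names; see needle_torch/ for the real definition.
    d_model: int = 512
    num_query_heads: int = 8
    num_kv_heads: int = 4       # GQA 8/4
    encoder_layers: int = 12
    decoder_layers: int = 8
    vocab_size: int = 8192
```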
Files included for reproduction:

- `needle_torch/` – PyTorch port of the Simple Attention Network
- `convert_weights.py` – Flax checkpoint → PyTorch state_dict (parametric on `--ckpt-repo`)
- `export_onnx.py` – `torch.onnx.export` of encoder + decoder-step
- `verify_port_parity.py` – Flax → PyTorch parity check (load-bearing)
- `verify_parity.py` – PyTorch → ONNX parity + end-to-end check vs native `generate()`
- `dump_tokenizer.py` – copy the SentencePiece `.model` + emit parity goldens for the JS port
- `upload_to_hf.py` – push artifacts to HF Hub
- `inspect_needle.py` – dump Flax arch / tokenizer / prompt notes (useful when porting a variant)
- `pyproject.toml` – uv-managed env spec
- `PORTING.md` – full step-by-step guide
## License
MIT, matching the upstream Cactus Needle license.