Needle: ONNX export for in-browser inference

Browser-ready ONNX export of Cactus-Compute/needle, a 26M-parameter function-calling model. Designed to run entirely client-side via onnxruntime-web (WASM backend), with no server required.

Files

File                     Size     Description
encoder.onnx             ~55 MB   Needle encoder. Input input_ids: (B, T); output encoder_out: (B, T, 512). Single pass.
decoder_step.onnx        ~85 MB   One decoder step with explicit past-KV inputs and present-KV outputs. Run in a JS loop.
needle.model             125 KB   SentencePiece BPE protobuf (vocab=8192, byte_fallback=True, identity normalization). Loadable by sentencepiece-js / @huggingface/transformers.
tokenizer-specials.json  tiny     {"pad":0,"eos":1,"bos":2,"tool_call":4,"tools":5}
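The special ids matter on the decoding side. Here is a minimal sketch, assuming generated ids arrive as a plain list of ints. It cuts the sequence at the first eos and drops pad/bos before the rest goes to the SentencePiece decoder. The helper `strip_specials` is hypothetical. It leaves tool_call/tools ids in place because how to treat them is an app-level choice.

```python
import json

# Contents of tokenizer-specials.json. In practice you would load the file:
#   with open("tokenizer-specials.json") as f: SPECIALS = json.load(f)
SPECIALS = json.loads('{"pad":0,"eos":1,"bos":2,"tool_call":4,"tools":5}')


def strip_specials(ids: list[int], specials: dict[str, int] = SPECIALS) -> list[int]:
    """Cut a generated id sequence at the first eos and drop pad/bos ids.

    Hypothetical helper: tool_call/tools ids are kept so the caller decides.
    """
    out: list[int] = []
    for tok in ids:
        if tok == specials["eos"]:
            break  # generation ended here
        if tok in (specials["pad"], specials["bos"]):
            continue  # never part of the visible text
        out.append(tok)
    return out


print(strip_specials([2, 310, 4, 57, 1, 0, 0]))  # -> [310, 4, 57]
```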

Origin

The upstream Cactus Needle is implemented in JAX/Flax, not PyTorch, so torch.onnx.export cannot be run against the upstream model directly. This ONNX export was instead produced via a "port-and-copy" pipeline:

  1. Reimplemented the Simple Attention Network in PyTorch (parametric on TransformerConfig)
  2. Copied weights tensor-by-tensor from the upstream Flax checkpoint, handling the Flax (in, out) → PyTorch (out, in) transposition for Linear kernels and the nn.scan layer-stacking convention
  3. Verified Flax↔PyTorch parity at <1e-3 max-abs-diff
  4. Exported encoder + decoder-step to ONNX via the legacy TorchScript-based torch.onnx.export
  5. Verified PyTorch↔ONNX parity at <1e-3 max-abs-diff
  6. Verified end-to-end: Cactus's native generate() and a hand-rolled onnxruntime KV-cache loop produce byte-identical output token sequences
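Step 2 depends on two layout conventions. The sketch below is stdlib-only and assumes Flax params arrive as nested Python lists. A Flax Dense kernel is stored (in, out), while torch.nn.Linear.weight is (out, in), so each kernel is transposed. A layer stack produced by nn.scan carries a leading layer axis, which has to be split into one tensor per layer. The function names and the `layers.{i}.` key format are hypothetical and not taken from convert_weights.py.

```python
Matrix = list[list[float]]


def flax_kernel_to_torch(kernel: Matrix) -> Matrix:
    """Flax Dense kernel (in, out) -> torch Linear weight (out, in)."""
    return [list(col) for col in zip(*kernel)]


def unstack_scanned(name: str, stacked: list[Matrix]) -> dict[str, Matrix]:
    """Split an nn.scan-stacked kernel (L, in, out) into per-layer torch weights.

    The "layers.{i}.{name}" key format is a hypothetical naming scheme.
    """
    return {
        f"layers.{i}.{name}": flax_kernel_to_torch(layer_kernel)
        for i, layer_kernel in enumerate(stacked)
    }


# Two scanned layers, each with an (in=3, out=2) kernel.
stacked = [
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]],
]
state = unstack_scanned("q_proj.weight", stacked)
print(state["layers.0.q_proj.weight"])  # -> [[1.0, 3.0, 5.0], [2.0, 4.0, 6.0]]
```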

Parity numbers (against Cactus's native generate(constrained=False))

Stage                                  max-abs-diff
Flax encoder ↔ PyTorch port            0.000010
Flax decoder step 0 ↔ PyTorch port     0.000029
PyTorch encoder ↔ ONNX                 0.000004
PyTorch decoder step ↔ ONNX (logits)   0.000014
End-to-end token sequence              byte-identical
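Each numeric row above is the largest absolute elementwise difference between two implementations' outputs on the same input, checked against the 1e-3 tolerance from the pipeline steps. The stdlib sketch below computes that metric on flattened float sequences. The function names are illustrative and not the ones in verify_parity.py. Treating any NaN as an automatic failure is a design choice of this sketch.

```python
import math
from collections.abc import Sequence

TOLERANCE = 1e-3  # parity threshold used throughout the pipeline


def max_abs_diff(a: Sequence[float], b: Sequence[float]) -> float:
    """Largest |a_i - b_i| over two equally shaped, flattened outputs."""
    if len(a) != len(b):
        raise ValueError(f"shape mismatch: {len(a)} vs {len(b)}")
    diffs = [abs(x - y) for x, y in zip(a, b)]
    if any(math.isnan(d) for d in diffs):
        return math.inf  # NaNs count as a parity failure
    return max(diffs, default=0.0)


def check_parity(label: str, a: Sequence[float], b: Sequence[float]) -> float:
    diff = max_abs_diff(a, b)
    if diff >= TOLERANCE:
        raise AssertionError(f"{label}: max-abs-diff {diff:.6f} >= {TOLERANCE}")
    print(f"{label}: {diff:.6f}")
    return diff


check_parity("encoder", [0.5, -1.25, 2.0], [0.500004, -1.25, 1.999998])
# prints "encoder: 0.000004"
```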

Example: query="set a 5 min timer" produces ' [{"name":"set_timer","arguments":{"time_human":"5 minutes"}}]' both with native Cactus and in the browser using these artifacts.
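In this example the decoded text is a JSON array of tool calls, so an app can parse it directly. The sketch below assumes that shape, including the leading space. Treating anything that is not a JSON list of {"name", "arguments"} objects as "no tool call" is an assumption about how you would want to handle malformed output. The helper `parse_tool_calls` is hypothetical.

```python
import json
from typing import Any


def parse_tool_calls(text: str) -> list[dict[str, Any]]:
    """Parse decoded output into a list of {"name", "arguments"} dicts.

    Returns [] if the text is not a JSON array of such objects.
    """
    try:
        calls = json.loads(text)  # surrounding whitespace is allowed by JSON
    except json.JSONDecodeError:
        return []
    if not isinstance(calls, list):
        return []
    return [
        c for c in calls
        if isinstance(c, dict)
        and isinstance(c.get("name"), str)
        and isinstance(c.get("arguments", {}), dict)
    ]


out = ' [{"name":"set_timer","arguments":{"time_human":"5 minutes"}}]'
for call in parse_tool_calls(out):
    print(call["name"], call["arguments"])  # -> set_timer {'time_human': '5 minutes'}
```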

Usage in the browser

Load both .onnx files with onnxruntime-web (WASM backend) and needle.model with sentencepiece-js. Run the encoder once per input, then call the decoder step in a JS loop, feeding each step's present-KV outputs back in as the next step's past-KV inputs.
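The loop has a simple shape. The sketch below is written in Python for illustration, with `encoder` and `decoder_step` standing in for the two ONNX sessions. The encoder runs once. Each decoder step then takes the previous token and the past KV cache and returns logits plus the present KV cache, which is fed back in. Several details are assumptions: the callable signatures, starting from bos, greedy argmax, and the `max_new_tokens` cap. The real graph input/output names are whatever decoder_step.onnx declares. The stub model at the bottom exists only so the sketch runs.

```python
from typing import Any, Callable

BOS, EOS = 2, 1  # from tokenizer-specials.json

# Hypothetical signatures standing in for the two ONNX sessions.
Encoder = Callable[[list[int]], Any]
DecoderStep = Callable[[int, Any, Any], tuple[list[float], Any]]


def greedy_generate(encoder: Encoder, decoder_step: DecoderStep,
                    input_ids: list[int], max_new_tokens: int = 64) -> list[int]:
    encoder_out = encoder(input_ids)  # single encoder pass
    past_kv = None                    # empty cache before the first step
    token, generated = BOS, []        # assumption: decoding starts from bos
    for _ in range(max_new_tokens):
        logits, present_kv = decoder_step(token, encoder_out, past_kv)
        token = max(range(len(logits)), key=logits.__getitem__)  # greedy argmax
        if token == EOS:
            break
        generated.append(token)
        past_kv = present_kv          # present KV becomes next step's past KV
    return generated


# Toy stub: the "cache" is the list of tokens fed so far, and the model
# emits tokens 10, 11, 12 and then eos.
def stub_encoder(ids: list[int]) -> list[int]:
    return ids


def stub_step(token: int, enc: Any, past: Any) -> tuple[list[float], Any]:
    present = (past or []) + [token]
    script = [10, 11, 12, EOS]
    nxt = script[min(len(present) - 1, len(script) - 1)]
    logits = [0.0] * 16
    logits[nxt] = 1.0
    return logits, present


print(greedy_generate(stub_encoder, stub_step, [5, 6, 7]))  # -> [10, 11, 12]
```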

Architecture

Per the upstream model card: encoder-decoder "Simple Attention Network", d_model=512, GQA with 8 query heads and 4 KV heads, 12 encoder layers, 8 decoder layers, no FFN, ZCRMSNorm ((1+γ)·x/RMS(x), γ initialized to zero), RoPE on Q and K.
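The ZCRMSNorm formula is small enough to write out. This stdlib sketch works on a single vector and assumes RMS(x) = sqrt(mean(x²) + eps). The small eps for numerical stability and its value are assumptions, so check the upstream config. Because γ starts at zero, a freshly initialized layer reduces to plain x/RMS(x).

```python
import math


def zc_rms_norm(x: list[float], gamma: list[float], eps: float = 1e-6) -> list[float]:
    """ZCRMSNorm: (1 + gamma) * x / RMS(x), over the last (feature) dimension."""
    if len(x) != len(gamma):
        raise ValueError("gamma must match the feature dimension")
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)  # eps is an assumption
    return [(1.0 + g) * v / rms for v, g in zip(x, gamma)]


x = [3.0, -4.0]                     # RMS = sqrt((9 + 16) / 2), about 3.5355
print(zc_rms_norm(x, [0.0, 0.0]))   # gamma = 0 at init: plain x / RMS(x)
print(zc_rms_norm(x, [1.0, -0.5]))  # learned gamma rescales each feature
```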

The decoder is exported as a single step with past/present KV as graph I/O โ€” the JS side calls it in a loop, allowing streaming token output and avoiding ONNX symbolic control flow.
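Because past/present KV are explicit graph I/O, the JS side has to allocate them and carry them between steps, so their size matters. This back-of-the-envelope sketch uses the architecture numbers above and makes several assumptions. It takes head_dim = d_model / query heads = 64. It assumes one K and one V tensor per decoder layer, laid out as (batch, kv_heads, seq, head_dim), and fp32 storage. It counts only decoder self-attention K/V, because whether cross-attention K/V are also part of the cache I/O isn't stated. Check the inputs that decoder_step.onnx actually declares. Under GQA only the 4 KV heads are cached, which is half of what 8 full heads would need.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DecoderShape:
    d_model: int = 512
    n_heads: int = 8        # query heads
    n_kv_heads: int = 4     # GQA: K/V heads shared by groups of query heads
    n_layers: int = 8       # decoder layers

    @property
    def head_dim(self) -> int:
        return self.d_model // self.n_heads  # assumption: 512 / 8 = 64

    def kv_shape(self, batch: int, seq_len: int) -> tuple[int, int, int, int]:
        """Assumed layout of one past/present K (or V) tensor."""
        return (batch, self.n_kv_heads, seq_len, self.head_dim)

    def kv_floats_per_token(self) -> int:
        # K and V, for every decoder layer, for every KV head.
        return 2 * self.n_layers * self.n_kv_heads * self.head_dim


shape = DecoderShape()
print(shape.kv_shape(batch=1, seq_len=10))    # -> (1, 4, 10, 64)
print(shape.kv_floats_per_token())            # -> 4096
print(shape.kv_floats_per_token() * 4 * 128)  # fp32 bytes for 128 tokens -> 2097152
```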

Reproduce / port your own Cactus-trained model

The full pipeline that produced these artifacts is checked in alongside the .onnx files (see PORTING.md for the step-by-step). The scripts are parametric on the source HF repo, so if you've finetuned Needle (or trained a Simple-Attention-Network variant with the upstream Cactus codebase), you can produce a browser-ready ONNX export with the same recipe:

# 1. Convert your Cactus checkpoint → PyTorch state_dict
uv run python convert_weights.py --ckpt-repo YOUR_USER/your-finetune --ckpt-file weights.pkl

# 2. Verify the port matches your upstream model (max-abs-diff < 1e-3)
uv run python verify_port_parity.py

# 3. Export to ONNX (reads config back from step 1's saved JSON; no edits needed)
uv run python export_onnx.py

# 4. Verify ONNX matches PyTorch AND matches native Cactus generate() token-for-token
uv run python verify_parity.py --ckpt-repo YOUR_USER/your-finetune --ckpt-file weights.pkl

# 5. Push your ONNX artifacts to HF
uv run python upload_to_hf.py --repo YOUR_USER/your-finetune-onnx

The PyTorch port (needle_torch/) is parametric on TransformerConfig: it reads the config straight out of your checkpoint's payload, so dimension changes (d_model, layer counts, GQA ratios) are picked up automatically. The same pipeline works for the 26M production Needle, the 1.35M iteration config, and anything in between.
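The "parametric on TransformerConfig" idea can be shown with a small config object. It builds itself from a checkpoint's config dict, derives shapes from a few fields, and rejects inconsistent combinations before any weights are copied. This is a hypothetical sketch. `PortConfig`, its field names, and the payload keys are illustrative, not the actual TransformerConfig or checkpoint schema.

```python
from dataclasses import dataclass, fields
from typing import Any


@dataclass(frozen=True)
class PortConfig:  # hypothetical stand-in for TransformerConfig
    d_model: int
    num_heads: int
    num_kv_heads: int
    num_encoder_layers: int
    num_decoder_layers: int

    @classmethod
    def from_payload(cls, payload: dict[str, Any]) -> "PortConfig":
        """Build from a checkpoint's config dict, ignoring unrelated keys."""
        names = {f.name for f in fields(cls)}
        cfg = cls(**{k: int(v) for k, v in payload.items() if k in names})
        cfg.validate()
        return cfg

    def validate(self) -> None:
        if self.d_model % self.num_heads:
            raise ValueError("d_model must be divisible by num_heads")
        if self.num_heads % self.num_kv_heads:
            raise ValueError("num_heads must be a multiple of num_kv_heads (GQA)")

    @property
    def head_dim(self) -> int:
        return self.d_model // self.num_heads

    @property
    def gqa_group(self) -> int:
        return self.num_heads // self.num_kv_heads  # query heads per KV head


needle_26m = PortConfig.from_payload({
    "d_model": 512, "num_heads": 8, "num_kv_heads": 4,
    "num_encoder_layers": 12, "num_decoder_layers": 8, "vocab_size": 8192,
})
print(needle_26m.head_dim, needle_26m.gqa_group)  # -> 64 2
```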

Files included for reproduction:

needle_torch/         - PyTorch port of the Simple Attention Network
convert_weights.py    - Flax checkpoint → PyTorch state_dict (parametric on --ckpt-repo)
export_onnx.py        - torch.onnx.export of encoder + decoder-step
verify_port_parity.py - Flax ↔ PyTorch parity check (load-bearing)
verify_parity.py      - PyTorch ↔ ONNX + end-to-end vs native generate()
dump_tokenizer.py     - Copy the SentencePiece .model and emit parity goldens for the JS port
upload_to_hf.py       - Push artifacts to the HF Hub
inspect_needle.py     - Dump Flax arch / tokenizer / prompt notes (useful when porting a variant)
pyproject.toml        - uv-managed env spec
PORTING.md            - Full step-by-step guide

License

MIT, matching the upstream Cactus Needle license.
