# Needle Transit: ONNX
ONNX export of rasaboun/needle-transit for
in-browser / on-device inference via onnxruntime and onnxruntime-web. Same 26M-parameter
transit extractor, runtime-portable.
## Files
| File | Purpose |
|---|---|
| `encoder.onnx` | Encoder. Input `input_ids: (B, T)` → output `encoder_out: (B, T, 512)`. Single pass. |
| `decoder_step.onnx` | One decoder step with explicit past-KV in / present-KV out. Run in a loop. |
| `needle.model` | SentencePiece BPE tokenizer (vocab 8192, `byte_fallback=True`, identity normalization). Loadable by sentencepiece-js / `@huggingface/transformers`. |
| `tokenizer-specials.json` | Special-token ids: `pad=0`, `eos=1`, `bos=2`, `tool_call=4`, `tools=5`. |
| `needle_torch.config.json` | Model dims (`d_model=512`, heads 8 / 4 KV, encoder × 12, decoder × 8, `max_seq_len=1024`). |
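The config dims above are enough to estimate how large the KV cache gets during decoding, which matters for on-device memory budgets. The sketch below is a back-of-the-envelope calculation. It assumes three things not stated in this card: `head_dim = d_model / heads`, the "4 KV" heads are grouped-query heads shared by the 8 query heads, and only decoder self-attention K/V are counted. Any cached cross-attention K/V would add to the total.

```python
# Back-of-the-envelope KV-cache sizing from needle_torch.config.json dims.
# ASSUMPTIONS: head_dim = d_model // n_heads; 4 grouped-query KV heads;
# only decoder self-attention K/V are counted (cross-attention K/V excluded).
D_MODEL = 512
N_HEADS = 8
N_KV_HEADS = 4
N_DEC_LAYERS = 8
MAX_SEQ_LEN = 1024


def head_dim(d_model: int = D_MODEL, n_heads: int = N_HEADS) -> int:
    """Per-head width under the standard even split of d_model across heads."""
    assert d_model % n_heads == 0, "d_model must split evenly across heads"
    return d_model // n_heads


def self_attn_kv_bytes(seq_len: int, batch: int = 1, bytes_per_elem: int = 4) -> int:
    """Bytes held by past K and V across all decoder layers at a given length."""
    per_tensor = batch * N_KV_HEADS * seq_len * head_dim()  # one K (or V), one layer
    return 2 * per_tensor * N_DEC_LAYERS * bytes_per_elem   # K and V, every layer


if __name__ == "__main__":
    print(head_dim())                                      # 64
    print(self_attn_kv_bytes(MAX_SEQ_LEN) / 2**20, "MiB")  # 16.0 MiB in fp32
```

At the full 1024-token context this comes to about 16 MiB in fp32 under the stated assumptions. Halving `bytes_per_elem` gives the fp16 figure.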
## Inference flow
- Tokenize query + tools with `needle.model`, framing with the special tokens above.
- Run `encoder.onnx` once → `encoder_out`.
- Loop `decoder_step.onnx` greedily from `bos`, passing the KV cache (present-KV → past-KV) and feeding back each emitted token, until `eos`.
- Decode the token stream → JSON tool call (or empty = refusal).
The decoder is exported as a single step with past/present KV as graph inputs/outputs. The host calls it in a loop, which enables streaming and avoids ONNX symbolic control flow.
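The host-side loop can be sketched in plain Python, independent of any runtime. This is a minimal sketch, not the upstream decode loop. `step` is a hypothetical callable standing in for one `decoder_step.onnx` call: it takes the last token and the past KV, and returns logits plus the present KV. `make_scripted_step` is a hypothetical test double. Only `bos=2` and `eos=1` come from this card.

```python
from typing import Any, Callable, List, Sequence, Tuple

# (last token id, past KV) -> (logits over the vocab, present KV)
StepFn = Callable[[int, Any], Tuple[Sequence[float], Any]]

BOS_ID = 2  # from tokenizer-specials.json
EOS_ID = 1


def greedy_decode(step: StepFn, init_past: Any, max_new_tokens: int = 256) -> List[int]:
    """Start from bos, take the argmax each step, feed it back together with
    the returned present KV as the next past KV; stop at eos or the cap."""
    token, past, out = BOS_ID, init_past, []
    for _ in range(max_new_tokens):
        logits, past = step(token, past)  # present KV becomes the next past KV
        token = max(range(len(logits)), key=logits.__getitem__)  # greedy argmax
        if token == EOS_ID:
            break
        out.append(token)
    return out


def make_scripted_step(script: Sequence[int], vocab_size: int = 8192) -> StepFn:
    """Hypothetical test double: emits script[i] at step i.
    Its 'KV cache' is just a step counter."""
    def step(token: int, past: int):
        logits = [0.0] * vocab_size
        logits[script[past]] = 1.0
        return logits, past + 1
    return step
```

With the real decoder, `step` would close over `encoder_out` and wrap `dec.run(...)`, routing each present-KV output to its matching past-KV input. Those tensor names are not listed in this card, so read them from the exported graph.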
See the JAX inference path in the official
cactus-compute/needle package for the exact
tokenization framing and decode loop.
## Usage (Python / onnxruntime)
```python
import onnxruntime as ort
import sentencepiece as spm

# One session per exported graph, plus the SentencePiece tokenizer.
enc = ort.InferenceSession("encoder.onnx")
dec = ort.InferenceSession("decoder_step.onnx")
sp = spm.SentencePieceProcessor(model_file="needle.model")

# ... tokenize query+tools, run enc once, loop dec until eos ...
```
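This card does not list the decoder's past-KV / present-KV tensor names. Inspecting the session is therefore the reliable way to wire up the loop. The helper below is a small sketch that relies only on `InferenceSession.get_inputs()` / `get_outputs()` and the `name`, `shape` and `type` attributes of the returned `NodeArg` objects. `describe_io` is a hypothetical name.

```python
from typing import Any, Dict, List, Tuple

IOList = List[Tuple[str, list, str]]


def describe_io(session: Any) -> Dict[str, IOList]:
    """(name, shape, type) triples for every graph input and output of an
    onnxruntime.InferenceSession. Dynamic dims appear as strings or None."""
    def triples(args) -> IOList:
        return [(a.name, list(a.shape), a.type) for a in args]
    return {
        "inputs": triples(session.get_inputs()),
        "outputs": triples(session.get_outputs()),
    }
```

For example, `describe_io(dec)["inputs"]` on the `dec` session created above should list `decoder_step.onnx`'s inputs, including its past-KV tensors. Compare them with the `"outputs"` entries to pair each present-KV output with its past-KV input.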
For the browser, swap onnxruntime for onnxruntime-web
(WASM/WebGPU backend); the file layout is identical.
## Examples
Example queries and the tool calls they produce (verified: the ONNX output is token-identical to the JAX checkpoint on each of these):
| Query | Output |
|---|---|
| Itinéraire de Bastille à Nation | `[{"name":"search_itinerary","arguments":{"origin":"Bastille","destination":"Nation"}}]` |
| De Issy à Charles-de-Gaulle, départ 14h | `[{"name":"search_itinerary","arguments":{"origin":"Issy","destination":"Charles-de-Gaulle","time_human":"départ 14h","time_mode":"depart_at"}}]` |
| How do I get from Gare du Nord to La Défense? | `[{"name":"search_itinerary","arguments":{"origin":"Gare du Nord","destination":"La Défense"}}]` |
| Prochain métro à Bastille ligne 1 ? | `[{"name":"get_next_arrivals","arguments":{"station":"Bastille","line":"1"}}]` |
| prochains passages à Châtelet | `[{"name":"get_next_arrivals","arguments":{"station":"Châtelet"}}]` |
| cmt aller a chatelet depuis nation | `[{"name":"search_itinerary","arguments":{"origin":"nation","destination":"chatelet"}}]` |
| Quel temps fait-il ? | `[]` |
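Every output in the table is a JSON list of `{"name", "arguments"}` objects, and `[]` marks a refusal. A consumer can validate the decoded text along the lines of the sketch below. It assumes the detokenized text is exactly the JSON shown in the Output column, and it also treats an empty string as a refusal. `parse_tool_calls` is a hypothetical helper, not part of this repo.

```python
import json
from typing import Any, Dict, List


def parse_tool_calls(text: str) -> List[Dict[str, Any]]:
    """Parse decoded model text into a list of tool calls.
    Empty text or "[]" is a refusal and yields []."""
    text = text.strip()
    if not text:
        return []
    calls = json.loads(text)
    if not isinstance(calls, list):
        raise ValueError(f"expected a JSON list, got {type(calls).__name__}")
    for call in calls:
        # Each call needs a string name and an object of arguments.
        if not (isinstance(call, dict)
                and isinstance(call.get("name"), str)
                and isinstance(call.get("arguments"), dict)):
            raise ValueError(f"malformed tool call: {call!r}")
    return calls
```

Optional fields such as `time_human` / `time_mode` simply appear in `arguments` when present, so callers should not assume a fixed argument set per tool.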
## Results
On the examples above, the ONNX export produces tool calls token-identical to the JAX checkpoint (verified end-to-end), so the two are functionally equivalent on those inputs. Dataset and evaluation: coming soon.
## Finetuning
Finetune on your own tools with the customized scripts in github.com/Rasaboun/needle-transit (tunable LR/Muon-LR, per-field loss weighting, metrics logging), then re-export to ONNX with the toolkit below.
## Provenance & reproduction
Upstream Needle is JAX/Flax, so torch.onnx.export can't run against it directly. These artifacts
were produced via a port-and-copy pipeline: reimplement the Simple Attention Network in PyTorch,
copy weights tensor-by-tensor from the Flax checkpoint, verify Flax↔PyTorch↔ONNX parity, then export
encoder + decoder-step.
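The "copy weights tensor-by-tensor" step mostly comes down to layout conventions. A Flax `Dense` kernel is stored as `(in, out)`, whereas a PyTorch `nn.Linear` weight is `(out, in)`. Flax norms call their gain `scale`, PyTorch calls it `weight`. Flax `Embed` stores `embedding` with the same `(vocab, dim)` layout as `nn.Embedding.weight`. The sketch below applies only those generic rules. It is not the actual `convert_weights.py`: a real port also needs a module-name map between the Flax and PyTorch trees, and multi-dimensional `DenseGeneral` kernels are passed through untouched here.

```python
from typing import Any, Dict

import numpy as np


def flatten(tree: Dict[str, Any], prefix: str = "") -> Dict[str, np.ndarray]:
    """Flatten a nested Flax params dict into dotted keys."""
    flat: Dict[str, np.ndarray] = {}
    for key, value in tree.items():
        path = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten(value, path))
        else:
            flat[path] = np.asarray(value)
    return flat


def _rename(module: str, leaf: str) -> str:
    return f"{module}.{leaf}" if module else leaf


def flax_to_torch(params: Dict[str, Any]) -> Dict[str, np.ndarray]:
    """Rename and re-layout Flax leaves to PyTorch conventions (names unmapped)."""
    out: Dict[str, np.ndarray] = {}
    for path, arr in flatten(params).items():
        module, _, leaf = path.rpartition(".")
        if leaf == "kernel" and arr.ndim == 2:
            out[_rename(module, "weight")] = arr.T.copy()  # (in, out) -> (out, in)
        elif leaf in ("scale", "embedding"):
            out[_rename(module, "weight")] = arr            # same layout, new name
        else:
            out[path] = arr                                 # bias etc. copied as-is
    return out
```

The resulting arrays can be wrapped with `torch.from_numpy` before `load_state_dict`, once the module names have been mapped onto the PyTorch port.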
The conversion scripts live in onnx-community/needle-onnx
(`convert_weights.py`, `verify_port_parity.py`, `export_onnx.py`, `verify_parity.py`, the PyTorch port, and `PORTING.md`).
They are parametric on the source checkpoint, so any Needle finetune can be re-exported with the same recipe:
```bash
# get the conversion toolkit
hf download onnx-community/needle-onnx --local-dir needle-onnx && cd needle-onnx

# 1. Flax checkpoint → PyTorch state_dict
uv run python convert_weights.py --ckpt-repo rasaboun/needle-transit --ckpt-file needle-transit.pkl

# 2. verify the port matches upstream (< 1e-3)
uv run python verify_port_parity.py

# 3. export encoder + decoder-step to ONNX
uv run python export_onnx.py

# 4. verify ONNX ↔ PyTorch ↔ native generate()
uv run python verify_parity.py --ckpt-repo rasaboun/needle-transit --ckpt-file needle-transit.pkl
```
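Steps 2 and 4 boil down to two kinds of check. Numeric outputs must agree within a tolerance (the `< 1e-3` above), and generated token ids must be identical. The sketch below shows both. It is not the logic of `verify_port_parity.py` or `verify_parity.py`. In particular, using max absolute difference as the metric is an assumption, and `check_parity` is a hypothetical helper.

```python
from typing import Optional, Sequence

import numpy as np


def max_abs_diff(reference, candidate) -> float:
    """Largest element-wise |reference - candidate|, e.g. logits from two backends."""
    ref = np.asarray(reference, dtype=np.float64)
    cand = np.asarray(candidate, dtype=np.float64)
    if ref.shape != cand.shape:
        raise ValueError(f"shape mismatch: {ref.shape} vs {cand.shape}")
    return float(np.max(np.abs(ref - cand)))


def first_token_mismatch(reference: Sequence[int], candidate: Sequence[int]) -> Optional[int]:
    """Index of the first differing token, or None if the sequences are identical."""
    for i, (a, b) in enumerate(zip(reference, candidate)):
        if a != b:
            return i
    if len(reference) != len(candidate):
        return min(len(reference), len(candidate))  # one is a prefix of the other
    return None


def check_parity(ref_logits, test_logits, ref_ids, test_ids, atol: float = 1e-3) -> bool:
    """Numeric agreement within atol AND token-identical generations."""
    return (max_abs_diff(ref_logits, test_logits) < atol
            and first_token_mismatch(ref_ids, test_ids) is None)
```

Reporting the first mismatching index, rather than a plain boolean, makes it easier to see where a greedy decode starts to drift.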
## License & attribution
MIT. Exported from rasaboun/needle-transit, itself fine-tuned from Cactus-Compute/needle (© Cactus Compute, MIT).