---
license: apache-2.0
language:
- en
library_name: onnxruntime
pipeline_tag: text-classification
base_model: microsoft/deberta-v3-small
tags:
- coding-agent
- routing
- multi-head-classifier
- onnx
- deberta-v3
- model-router
metrics:
- accuracy
- f1
---
<!-- This file is the Hugging Face model card: published to
huggingface.co/afterbuild/spawn-router as README.md. Frontmatter must stay
at the very top of the file or HF won't parse the metadata. -->
# spawn-router
A compact, fast, **local-first** multi-head classifier for **coding-agent task
routing**. Given a task prompt at kickoff, it predicts stable task properties; a
downstream policy/config then maps those properties to a model, provider, and
execution behavior. The classifier predicts the *ontology*; your config owns the
*orchestration*.
This is the model component of **spawn** — see
[`Afterbuild/spawn-router`](https://github.com/Afterbuild/spawn-router) for the
training code and [`spawn-gateway`](https://github.com/Afterbuild) for the local
gateway that wraps Claude Code / Codex and routes with these weights.
- **Backbone:** `microsoft/deberta-v3-small` (multi-head)
- **Checkpoint:** v6 (final text-only training run)
- **Inference:** torch-free ONNX path, ~7 ms/prompt CPU, ~140 MB deps
- **Input:** text only (`current_text`), 256-token max
## What it predicts
```
complexity: easy | medium | hard
  + sub-dims (0..1 regression): reasoning_depth, scope_breadth,
    domain_knowledge, spec_completeness (inverted: low spec = harder)
task_type: bugfix | feature | refactor | test | design | docs | migration | exploration
risk: low | medium | high
  + sub-dims (0..1 regression): security_surface, data_sensitivity,
    production_exposure, reversal_cost
+ per-head confidences (post-hoc temperature-scaled) and overall_confidence
```
- `complexity` → capability tier (small / mid / large model)
- `task_type` → model specialty (e.g. design → Claude, systems → GPT, docs → small)
- `risk` → tier bumper (easy + high-risk still routes capable) and confirmation gate
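The mapping above lives in your own config, not in the model. A minimal sketch of such a policy follows. The tier names, the `pick_tier` helper, and the specific rules are hypothetical illustrations, not the spawn-gateway configuration; the input dict is assumed to have the shape returned by the `classify` helper in the usage example.

```python
# Hypothetical policy sketch: tier names and rules are illustrative assumptions.
TIERS = ["small", "mid", "large"]
BASE_TIER = {"easy": 0, "medium": 1, "hard": 2}


def pick_tier(pred: dict) -> str:
    """Map predicted task properties to a capability tier."""
    tier = BASE_TIER[pred["complexity"]["label"]]
    if pred["task_type"]["label"] == "docs":
        tier = 0  # docs work goes to the small tier
    if pred["risk"]["label"] == "high":
        tier = max(tier, 1)  # risk bumps the tier: easy + high-risk still routes capable
    return TIERS[tier]


example = {
    "complexity": {"label": "easy", "confidence": 0.71},
    "task_type": {"label": "refactor", "confidence": 0.64},
    "risk": {"label": "high", "confidence": 0.58},
}
print(pick_tier(example))  # mid
```

Keeping the policy as plain data plus one small function means tiers can be retuned without retraining or re-exporting the classifier.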
The ONNX graphs emit the three classification logits plus all eight regression
sub-dimension scores. Output names are:
- `complexity_logits`, `task_type_logits`, `risk_logits`
- `complexity_sub_reasoning_depth`, `complexity_sub_scope_breadth`,
`complexity_sub_domain_knowledge`, `complexity_sub_spec_completeness`
- `risk_sub_security_surface`, `risk_sub_data_sensitivity`,
`risk_sub_production_exposure`, `risk_sub_reversal_cost`
Routing is **kickoff-only**: classify once at task start and lock the model for
the whole task cycle (no per-turn re-routing → no context thrash).
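One way to enforce the kickoff-only rule is a small cache keyed by task. The sketch below is an assumption about how a caller might do this; `KickoffRouter`, `task_id`, and `finish` are hypothetical names, and any function that maps a prompt to a prediction dict can be passed in.

```python
from typing import Callable, Dict


class KickoffRouter:
    """Classify a task once and reuse that decision for every later turn."""

    def __init__(self, classify: Callable[[str], dict]):
        self._classify = classify
        self._locked: Dict[str, dict] = {}

    def route(self, task_id: str, prompt: str) -> dict:
        # Only the first prompt seen for a task_id is classified; later turns
        # return the locked decision without calling the classifier again.
        if task_id not in self._locked:
            self._locked[task_id] = self._classify(prompt)
        return self._locked[task_id]

    def finish(self, task_id: str) -> None:
        # Release the lock when the task cycle ends.
        self._locked.pop(task_id, None)


calls = []


def fake_classify(prompt: str) -> dict:
    # Stand-in for the real classifier so the sketch runs without model files.
    calls.append(prompt)
    return {"complexity": {"label": "hard"}}


router = KickoffRouter(fake_classify)
first = router.route("task-1", "migrate the billing tables")
second = router.route("task-1", "now fix the failing test")
print(first is second, len(calls))  # True 1
```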
## Files
| File | What |
|---|---|
| `spawn_router.int8.onnx` | int8-quantized graph — **recommended for serving** (~164 MB) |
| `spawn_router.onnx` | fp32 graph (~540 MB) |
| `model.pt` | PyTorch state dict — for fine-tuning / sub-dim outputs (~565 MB) |
| `spm.model` + `*tokenizer*.json` | SentencePiece (DeBERTa-v2/spm) tokenizer |
| `model_config.json` | architecture + label maps |
| `temperature_scaling.json` | per-head calibration temperatures |
| `*_metrics.json`, `battery_results.json` | evaluation results |
## Usage (ONNX, torch-free)
Needs only `onnxruntime`, `numpy`, and `sentencepiece` — no torch, no transformers.
```python
import numpy as np
import onnxruntime as ort
import sentencepiece as spm

MODEL_DIR = "."  # dir containing spawn_router.int8.onnx, spm.model, *.json
MAX_LEN = 256
LABELS = {
    "complexity_logits": ["easy", "medium", "hard"],
    "task_type_logits": ["bugfix", "feature", "refactor", "test",
                         "design", "docs", "migration", "exploration"],
    "risk_logits": ["low", "medium", "high"],
}
TEMPS = {  # from temperature_scaling.json; output name -> head temperature
    "complexity_logits": 0.891251,
    "task_type_logits": 0.707946,
    "risk_logits": 1.059254,
}

sp = spm.SentencePieceProcessor(model_file=f"{MODEL_DIR}/spm.model")
sess = ort.InferenceSession(f"{MODEL_DIR}/spawn_router.int8.onnx",
                            providers=["CPUExecutionProvider"])


def classify(text: str) -> dict:
    # DeBERTa-v3 spm tokenizer: [CLS]=1 + pieces (truncated) + [SEP]=2
    pieces = sp.encode(f"Current: {text}", out_type=int)[: MAX_LEN - 2]
    ids = [1, *pieces, 2]
    feeds = {  # structured inputs are text-only sentinels (no interaction context)
        "input_ids": np.array([ids], dtype=np.int64),
        "attention_mask": np.ones((1, len(ids)), dtype=np.int64),
        "previous_action_id": np.array([0], dtype=np.int64),   # "none"
        "previous_outcome_id": np.array([4], dtype=np.int64),  # "unknown"
        "log_recency_seconds": np.array([0.0], dtype=np.float32),
        "has_interaction": np.array([0], dtype=np.int64),
        "has_recency": np.array([0], dtype=np.int64),
    }
    out = {o.name: v for o, v in zip(sess.get_outputs(), sess.run(None, feeds))}

    result = {}
    for name, labels in LABELS.items():
        # Temperature-scale the logits, then apply a numerically stable softmax.
        logits = out[name][0] / TEMPS[name]
        p = np.exp(logits - logits.max())
        p /= p.sum()
        i = int(p.argmax())
        result[name.replace("_logits", "")] = {
            "label": labels[i], "confidence": round(float(p[i]), 4),
        }
    result["complexity_sub"] = {
        name.replace("complexity_sub_", ""): round(float(out[name][0]), 4)
        for name in out if name.startswith("complexity_sub_")
    }
    result["risk_sub"] = {
        name.replace("risk_sub_", ""): round(float(out[name][0]), 4)
        for name in out if name.startswith("risk_sub_")
    }
    return result


print(classify("refactor JWT key rotation in prod"))
# {'complexity': {'label': ...}, 'task_type': {'label': ...}, 'risk': {'label': ...},
#  'complexity_sub': {'reasoning_depth': ...}, 'risk_sub': {'security_surface': ...}}
```
## Evaluation
Two complementary measures (eval scripts in
[`Afterbuild/spawn-router`](https://github.com/Afterbuild/spawn-router):
`scripts/eval_battery.py`, `eval.py`):
**Locked kickoff battery** (83 hand-labeled probes, never in training — the
canonical cross-version benchmark):
| Metric | v6 |
|---|---|
| Unified kickoff score | **69.5%** |
| Exact match (all 3 heads) | 37.4% |
| complexity | 65.1% |
| task_type | 78.3% |
| risk | 65.1% |
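The rows above can be reproduced from per-probe predictions with a few lines of scoring code. This card does not define the unified kickoff score; the sketch below assumes it is the unweighted mean of the three per-head accuracies, which is consistent with the reported numbers ((65.1 + 78.3 + 65.1) / 3 ≈ 69.5) but should be checked against `scripts/eval_battery.py`. The `score_battery` name and the input layout are hypothetical.

```python
HEADS = ("complexity", "task_type", "risk")


def score_battery(preds: list, golds: list) -> dict:
    """preds/golds: parallel lists of dicts mapping head name -> label string."""
    n = len(golds)
    per_head = {
        h: sum(p[h] == g[h] for p, g in zip(preds, golds)) / n for h in HEADS
    }
    # A probe counts as an exact match only if all three heads are correct.
    exact = sum(all(p[h] == g[h] for h in HEADS) for p, g in zip(preds, golds)) / n
    unified = sum(per_head.values()) / len(HEADS)  # assumed definition, see lead-in
    return {**per_head, "exact_match": exact, "unified": unified}


golds = [
    {"complexity": "easy", "task_type": "bugfix", "risk": "low"},
    {"complexity": "hard", "task_type": "feature", "risk": "high"},
]
preds = [
    {"complexity": "easy", "task_type": "bugfix", "risk": "low"},
    {"complexity": "medium", "task_type": "feature", "risk": "high"},
]
scores = score_battery(preds, golds)
print(scores["complexity"], scores["exact_match"])  # 0.5 0.5
```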
**Held-out test split** (n=174, mirrors the training distribution):
| Head | Accuracy | Macro F1 |
|---|---|---|
| complexity | 67.8% | 68.1% |
| task_type | 86.8% | 87.2% |
| risk | 66.7% | 62.2% |
| **Exact match** | **39.1%** | — |
Sub-dimension regression R² (PyTorch model): reasoning_depth 0.51, scope_breadth
0.47, spec_completeness 0.34, domain_knowledge 0.30; reversal_cost 0.55,
production_exposure 0.51, data_sensitivity 0.25, security_surface 0.19.
Calibration: per-head temperature scaling fit on validation. ECE on the held-out
split is ~0.37 at the 0.8 automation threshold — **confidence is not yet
well-calibrated for aggressive automation**; gate on it conservatively.
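For readers who want to check calibration on their own prompts, the sketch below shows temperature-scaled softmax and the standard binned form of expected calibration error. How `eval.py` computes ECE "at the 0.8 automation threshold" is not stated here, so this is plain ECE over all predictions, not a reproduction of the reported figure; the function names and the 10-bin default are assumptions.

```python
import numpy as np


def calibrated_probs(logits: np.ndarray, temperature: float) -> np.ndarray:
    """Softmax over the last axis after dividing logits by the temperature."""
    z = np.asarray(logits, dtype=float) / temperature
    z = z - z.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def expected_calibration_error(conf, correct, n_bins: int = 10) -> float:
    """Binned ECE: sample-weighted mean of |accuracy - mean confidence| per bin."""
    conf = np.asarray(conf, dtype=float)
    correct = np.asarray(correct, dtype=float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (conf > lo) & (conf <= hi)
        if in_bin.any():
            gap = abs(correct[in_bin].mean() - conf[in_bin].mean())
            ece += in_bin.mean() * gap  # in_bin.mean() is the bin's share of samples
    return float(ece)


logits = np.array([2.0, 1.0, 0.0])
sharp = calibrated_probs(logits, 0.5)  # T < 1 sharpens the distribution
soft = calibrated_probs(logits, 2.0)   # T > 1 flattens it
print(sharp.max() > soft.max())  # True
```

A temperature below 1, as fitted for `complexity` and `task_type` in the usage example, raises the top-class confidence; a temperature above 1, as for `risk`, lowers it.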
## Intended use
- Pick a capability tier / provider for a coding task **at kickoff**, before the
first expensive agent call.
- Drive a confirmation gate for high-blast-radius work (risk/security/prod).
- Spread work across tiers to reduce rate-limit pressure.
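A confirmation gate built on these outputs might look like the sketch below. The 0.8 confidence floor echoes the automation threshold mentioned under Evaluation; the 0.6 sub-dimension threshold and the `needs_confirmation` helper are hypothetical and untuned. The input is assumed to have the shape returned by `classify` in the usage example.

```python
# Thresholds are illustrative assumptions, not tuned values from this card.
SUB_DIM_THRESHOLD = 0.6
MIN_CONFIDENCE = 0.8


def needs_confirmation(pred: dict) -> bool:
    """Return True when a human should approve before the agent starts."""
    risk = pred["risk"]
    if risk["label"] == "high":
        return True
    # A low-confidence "not high" prediction is not trusted on its own.
    if risk["confidence"] < MIN_CONFIDENCE:
        return True
    # Any single high-scoring risk sub-dimension also triggers the gate.
    return any(v >= SUB_DIM_THRESHOLD for v in pred.get("risk_sub", {}).values())


safe = {
    "risk": {"label": "low", "confidence": 0.93},
    "risk_sub": {"security_surface": 0.1, "reversal_cost": 0.2},
}
exposed = {
    "risk": {"label": "medium", "confidence": 0.9},
    "risk_sub": {"security_surface": 0.2, "production_exposure": 0.75},
}
print(needs_confirmation(safe), needs_confirmation(exposed))  # False True
```

Because confidence is not yet well calibrated, the gate errs toward asking: anything short of a confident low or medium prediction with quiet sub-dimensions goes to a human.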
**Out of scope:** per-turn routing; non-coding prompts; high-stakes autonomous
action without a human gate; languages other than English (trained on English).
## Limitations
- **Cold-start ceiling.** Effort and blast radius aren't fully derivable from prompt
text: `complexity=medium` and `risk=high` are the weakest bands, especially on
short imperatives. Production signals (overrides, retries, session duration) are
the intended path past this; this checkpoint predates that loop.
- **Synthetic-label ceiling.** Much training data is LLM-labeled; expect a
~75–80% ceiling per head until real disagreement signals are mixed in.
- **Quantized serving tradeoff.** The fp32 ONNX graph matches the PyTorch model
on the locked battery, including sub-dimension scores. The int8 graph is the
recommended low-dependency serving artifact and preserves the established v6
serving behavior, but dynamic quantization can move borderline labels and
regression values.
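To see how far quantization moves outputs on your own prompts, you can run both graphs on the same inputs and compare. The sketch below assumes you have already collected each graph's outputs into a dict of output name to array (for example by zipping `sess.get_outputs()` with `sess.run(None, feeds)` as in the usage example); `compare_outputs` is a hypothetical helper, and the toy arrays stand in for real model outputs.

```python
import numpy as np


def compare_outputs(fp32: dict, int8: dict) -> dict:
    """fp32/int8: output name -> array with a leading batch axis, one per graph."""
    report = {}
    for name, ref in fp32.items():
        ref = np.asarray(ref)
        other = np.asarray(int8[name])
        if name.endswith("_logits"):
            # Fraction of prompts where both graphs pick the same label.
            report[name] = float((ref.argmax(-1) == other.argmax(-1)).mean())
        else:
            # Largest absolute drift in a regression sub-dimension score.
            report[name] = float(np.abs(ref - other).max())
    return report


# Toy values only: two prompts, one borderline label that flips under int8.
fp32_out = {
    "risk_logits": np.array([[2.0, 0.1, -1.0], [0.30, 0.31, -2.0]]),
    "risk_sub_reversal_cost": np.array([0.40, 0.70]),
}
int8_out = {
    "risk_logits": np.array([[1.9, 0.2, -1.1], [0.32, 0.29, -2.0]]),
    "risk_sub_reversal_cost": np.array([0.42, 0.65]),
}
report = compare_outputs(fp32_out, int8_out)
print(report["risk_logits"])  # 0.5
```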
## Training
- Backbone `microsoft/deberta-v3-small`, attention pooling, head dependencies,
3 softmax heads + 8 regression heads; `current_text_only` feature mode.
- 5 epochs, batch 16, encoder LR 2e-5, head LR 1e-4, weight decay 0.01, warmup
0.1, seed 13; post-hoc per-head temperature scaling on validation.
- Data: v6 mixed set (train 1141 / val 174 / test 174) — a mix of synthetic
coding-task prompts and real coding-agent kickoff prompts. **The merged
training set is not distributed** (it embeds third-party trace text and
personal usage traces); the synthetic seed data and the full data pipeline
are in the code repo. See "Training data provenance" below.
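As a worked example of what these hyperparameters imply for the schedule, the sketch below derives step counts from the train size, batch size, and epoch count. It assumes that "warmup 0.1" is a fraction of total optimizer steps, that the last partial batch is kept, that there is no gradient accumulation, and that the rate decays linearly to zero after warmup; none of these are confirmed by this card, so check the training code before relying on them.

```python
import math

TRAIN_EXAMPLES = 1141
BATCH_SIZE = 16
EPOCHS = 5
WARMUP_RATIO = 0.1  # assumption: fraction of total steps
ENCODER_LR = 2e-5

steps_per_epoch = math.ceil(TRAIN_EXAMPLES / BATCH_SIZE)  # keeps the last partial batch
total_steps = steps_per_epoch * EPOCHS
warmup_steps = round(total_steps * WARMUP_RATIO)


def encoder_lr_at(step: int) -> float:
    """Linear warmup, then linear decay to zero (the decay shape is an assumption)."""
    if step < warmup_steps:
        return ENCODER_LR * step / warmup_steps
    remaining = max(total_steps - step, 0)
    return ENCODER_LR * remaining / (total_steps - warmup_steps)


print(steps_per_epoch, total_steps, warmup_steps)  # 72 360 36
```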
## Training data provenance
Disclosed in full so downstream users can do their own diligence:
- **Synthetic coding-task prompts** (majority of the mix) — written by Claude
sub-agents and hand-labeled; included in the code repo.
- **SWE-bench problem statements** — used only as Claude-paraphrased
short-imperative prompts (no code, patches, or full issue text). The
SWE-bench benchmark code is MIT; the aggregated issue text is owned by its
authors and the HF dataset card carries no license tag.
- **Public coding-agent trace datasets** (`badlogicgames/pi-mono`,
`armand0e/gpt-5.5-agent`, `lewtun/ml-intern-sessions`) — kickoff prompts
extracted and labeled. These carry `license: other` or no license; their raw
text is **not redistributed** here.
- **The author's own local agent traces** — first-task prompts only; not
redistributed.
- Labels and paraphrases were produced with **Anthropic Claude**; per
Anthropic's Commercial Terms, outputs are customer-owned. No other
provider's models were used for generation or labeling.
The model is a non-generative classifier (three softmax heads and eight regression
heads over 256-token inputs); it emits logits and scalar scores, not text, and
cannot reproduce training data.
## Credits & provenance
- Scaffolding began as a fork of **[tiny-router](https://github.com/UdaraJay/tiny-router)
by Udara Jay** (MIT); the ontology, data, heads, and serving path were rebuilt
for coding-agent routing.
- Backbone: **DeBERTa-v3** (He et al.; `microsoft/deberta-v3-small`, MIT).
- Related prior art: Vercel v0 Auto and NVIDIA's
prompt-task-and-complexity-classifier.
- **License: Apache-2.0** (weights), with the training-data provenance
disclosed above; the training/serving code repo is MIT.