# dex-neural-bake / scripts/export_onnx.py
# dexifried
# Replace with tiny-router trainer (ZeroGPU/H200)
# 3bfff54
from __future__ import annotations
import argparse
import importlib.util
from pathlib import Path
import torch
from tiny_router.constants import HEAD_LABELS
from tiny_router.io import load_checkpoint, load_temperature_scaling
from tiny_router.runtime import dump_json, get_device
class OnnxWrapper(torch.nn.Module):
    """Expose the tiny-router model through a positional, tuple-returning API.

    The wrapped model is called with keyword arguments and returns a mapping
    whose ``"logits"`` entry is indexed by head name. The export in ``main``
    passes inputs as a positional tuple and names one ONNX output per head, so
    this adapter forwards each positional input under its keyword name and
    flattens the per-head logits into a tuple ordered by ``HEAD_LABELS``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(
        self,
        input_ids,
        attention_mask,
        previous_action_id,
        previous_outcome_id,
        log_recency_seconds,
        has_interaction,
        has_recency,
    ):
        """Run the wrapped model and return one logits tensor per head.

        Returns:
            A tuple holding ``outputs["logits"][head]`` for every ``head`` in
            ``HEAD_LABELS``, in that iteration order.
        """
        model_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            previous_action_id=previous_action_id,
            previous_outcome_id=previous_outcome_id,
            log_recency_seconds=log_recency_seconds,
            has_interaction=has_interaction,
            has_recency=has_recency,
        )
        logits_by_head = self.model(**model_kwargs)["logits"]
        return tuple(logits_by_head[head_name] for head_name in HEAD_LABELS)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Export tiny-router to ONNX.")
parser.add_argument("--model-dir", required=True)
parser.add_argument("--output-dir", required=True)
parser.add_argument("--opset", type=int, default=18)
parser.add_argument("--quantize", action="store_true")
return parser.parse_args()
def ensure_export_dependencies(quantize: bool) -> None:
    """Fail fast when packages needed for the export are not installed.

    ``onnxscript`` is always checked; ``onnxruntime`` is checked only when
    ``quantize`` is true. Availability is probed with
    ``importlib.util.find_spec`` rather than by importing the packages.

    Raises:
        RuntimeError: naming every missing package, with ``uv`` install hints.
    """
    required = ["onnxscript"]
    if quantize:
        required.append("onnxruntime")
    absent = [name for name in required if importlib.util.find_spec(name) is None]
    if absent:
        raise RuntimeError(
            f"Missing export dependencies: {', '.join(absent)}. "
            f"Install them with `uv sync` or `uv add {' '.join(absent)}` and rerun export."
        )
def _build_sample_inputs(tokenizer, max_length, device) -> dict[str, torch.Tensor]:
    """Build the single (batch size 1) example used to trace the ONNX export.

    Args:
        tokenizer: Tokenizer returned by ``load_checkpoint``; called with
            ``return_tensors="pt"``.
        max_length: Truncation length for the tokenized text.
        device: Device every returned tensor is moved to.

    Returns:
        The tokenizer outputs plus the five context-feature tensors, keyed by
        the names ``OnnxWrapper.forward`` accepts.
    """
    encoded = tokenizer(
        "Current: Export is broken on Safari\nPrevious: None\nPrevious action: None\nPrevious outcome: None",
        return_tensors="pt",
        max_length=max_length,
        truncation=True,
    )
    sample = {key: value.to(device) for key, value in encoded.items()}
    # Context features for a request with no prior interaction.
    # NOTE(review): ids 0 (action) and 4 (outcome) presumably encode "none"
    # in the training label maps — confirm against config.label_maps.
    sample["previous_action_id"] = torch.tensor([0], dtype=torch.long, device=device)
    sample["previous_outcome_id"] = torch.tensor([4], dtype=torch.long, device=device)
    sample["log_recency_seconds"] = torch.tensor([0.0], dtype=torch.float32, device=device)
    sample["has_interaction"] = torch.tensor([0], dtype=torch.long, device=device)
    sample["has_recency"] = torch.tensor([0], dtype=torch.long, device=device)
    return sample


def main() -> None:
    """Export a tiny-router checkpoint to ONNX together with its companion files.

    Steps, in order:

    1. Parse CLI arguments and check that export dependencies are installed.
    2. Load model, tokenizer and config from ``--model-dir``, plus any saved
       temperature-scaling values.
    3. Trace ``OnnxWrapper`` on one sample and write ``tiny_router.onnx``.
    4. Save the tokenizer, ``model_config.json``, ``onnx_metadata.json`` and,
       when temperatures were found, ``temperature_scaling.json``.
    5. With ``--quantize``, additionally write ``tiny_router.int8.onnx`` using
       onnxruntime dynamic quantization with int8 weights.
    """
    args = parse_args()
    ensure_export_dependencies(args.quantize)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): prefer_mps=False presumably keeps tracing off Apple MPS —
    # confirm against tiny_router.runtime.get_device.
    device = get_device(prefer_mps=False)
    model, tokenizer, config = load_checkpoint(args.model_dir, device=device)
    temperatures = load_temperature_scaling(args.model_dir)
    wrapper = OnnxWrapper(model).to(device)
    wrapper.eval()

    # Single source of truth for the graph inputs: this order must match the
    # positional parameters of OnnxWrapper.forward, and the same names label
    # the ONNX graph inputs, so the two can no longer drift apart.
    input_names = [
        "input_ids",
        "attention_mask",
        "previous_action_id",
        "previous_outcome_id",
        "log_recency_seconds",
        "has_interaction",
        "has_recency",
    ]
    sample = _build_sample_inputs(tokenizer, config.max_length, device)

    output_path = output_dir / "tiny_router.onnx"
    torch.onnx.export(
        wrapper,
        tuple(sample[name] for name in input_names),
        output_path,
        input_names=input_names,
        output_names=[f"{head}_logits" for head in HEAD_LABELS],
        # Only the token axis is dynamic; every other axis (including batch)
        # keeps the traced sample's size, i.e. the graph expects batch size 1.
        dynamic_axes={
            "input_ids": {1: "sequence"},
            "attention_mask": {1: "sequence"},
        },
        opset_version=args.opset,
        dynamo=False,
    )

    metadata = {
        "model_file": output_path.name,
        "feature_mode": config.feature_mode,
        "heads": list(HEAD_LABELS.keys()),
        "max_length": config.max_length,
        "label_maps": config.label_maps,
        "temperature_scaling": temperatures,
    }
    tokenizer.save_pretrained(output_dir)
    dump_json(output_dir / "model_config.json", config.to_dict())
    dump_json(output_dir / "onnx_metadata.json", metadata)
    if temperatures:
        dump_json(
            output_dir / "temperature_scaling.json",
            {
                "method": "per_head_temperature_scaling",
                "source_split": "validation",
                "per_head": temperatures,
            },
        )

    if args.quantize:
        # Imported lazily: onnxruntime is only required when quantizing
        # (ensure_export_dependencies already verified it is installed).
        from onnxruntime.quantization import QuantType, quantize_dynamic

        quantized_path = output_dir / "tiny_router.int8.onnx"
        quantize_dynamic(
            model_input=str(output_path),
            model_output=str(quantized_path),
            weight_type=QuantType.QInt8,
        )
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()