"""Command-line script that exports a tiny-router checkpoint to ONNX.

Writes the ONNX graph, tokenizer files, model config and metadata JSON into
the output directory, and optionally an int8 dynamically-quantized copy of
the graph.
"""

from __future__ import annotations

import argparse
import importlib.util
from pathlib import Path

import torch

from tiny_router.constants import HEAD_LABELS
from tiny_router.io import load_checkpoint, load_temperature_scaling
from tiny_router.runtime import dump_json, get_device

# Positional order of the exported graph inputs; this must stay in step with
# the parameter order of ``OnnxWrapper.forward``.
_ONNX_INPUT_NAMES = (
    "input_ids",
    "attention_mask",
    "previous_action_id",
    "previous_outcome_id",
    "log_recency_seconds",
    "has_interaction",
    "has_recency",
)


class OnnxWrapper(torch.nn.Module):
    """Adapter giving the model a positional-input, tuple-output signature.

    The wrapped model is called with keyword arguments and is expected to
    return a mapping whose ``"logits"`` entry can be indexed by head name.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(
        self,
        input_ids,
        attention_mask,
        previous_action_id,
        previous_outcome_id,
        log_recency_seconds,
        has_interaction,
        has_recency,
    ):
        """Run the wrapped model and return one logits tensor per head.

        The tuple follows the iteration order of ``HEAD_LABELS``.
        """
        model_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "previous_action_id": previous_action_id,
            "previous_outcome_id": previous_outcome_id,
            "log_recency_seconds": log_recency_seconds,
            "has_interaction": has_interaction,
            "has_recency": has_recency,
        }
        logits_by_head = self.model(**model_kwargs)["logits"]
        return tuple(logits_by_head[head] for head in HEAD_LABELS)


def parse_args() -> argparse.Namespace:
    """Parse the export command-line options from ``sys.argv``."""
    cli = argparse.ArgumentParser(description="Export tiny-router to ONNX.")
    cli.add_argument("--model-dir", required=True)
    cli.add_argument("--output-dir", required=True)
    cli.add_argument("--opset", type=int, default=18)
    cli.add_argument("--quantize", action="store_true")
    return cli.parse_args()


def ensure_export_dependencies(quantize: bool) -> None:
    """Fail early when an optional package needed for export is absent.

    ``onnxscript`` is always required; ``onnxruntime`` is only looked up
    when ``quantize`` is true.

    Raises:
        RuntimeError: if any required package cannot be found. The message
            lists the missing packages and how to install them.
    """
    required = ["onnxscript"]
    if quantize:
        required.append("onnxruntime")

    missing = [
        package
        for package in required
        if importlib.util.find_spec(package) is None
    ]
    if missing:
        package_list = " ".join(missing)
        raise RuntimeError(
            "Missing export dependencies: "
            f"{', '.join(missing)}. "
            "Install them with `uv sync` "
            f"or `uv add {package_list}` and rerun export."
        )


def _build_sample_inputs(tokenizer, max_length, device):
    """Build the example tensors used to trace the model during export."""
    encoded = tokenizer(
        "Current: Export is broken on Safari\nPrevious: None\nPrevious action: None\nPrevious outcome: None",
        return_tensors="pt",
        max_length=max_length,
        truncation=True,
    )
    inputs = {name: tensor.to(device) for name, tensor in encoded.items()}

    # One-element placeholder features for a message with no prior interaction.
    # NOTE(review): outcome id 4 is presumably the "none" class - confirm
    # against the label maps.
    placeholder_features = {
        "previous_action_id": (0, torch.long),
        "previous_outcome_id": (4, torch.long),
        "log_recency_seconds": (0.0, torch.float32),
        "has_interaction": (0, torch.long),
        "has_recency": (0, torch.long),
    }
    for name, (value, dtype) in placeholder_features.items():
        inputs[name] = torch.tensor([value], dtype=dtype, device=device)
    return inputs


def main() -> None:
    """Export the checkpoint in ``--model-dir`` into ``--output-dir``."""
    args = parse_args()
    ensure_export_dependencies(args.quantize)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    device = get_device(prefer_mps=False)
    model, tokenizer, config = load_checkpoint(args.model_dir, device=device)
    temperatures = load_temperature_scaling(args.model_dir)

    wrapper = OnnxWrapper(model).to(device)
    wrapper.eval()

    sample = _build_sample_inputs(tokenizer, config.max_length, device)

    output_path = output_dir / "tiny_router.onnx"
    torch.onnx.export(
        wrapper,
        tuple(sample[name] for name in _ONNX_INPUT_NAMES),
        output_path,
        input_names=list(_ONNX_INPUT_NAMES),
        output_names=[f"{head}_logits" for head in HEAD_LABELS],
        # Only the token axis of the two text inputs is declared dynamic.
        dynamic_axes={
            name: {1: "sequence"} for name in ("input_ids", "attention_mask")
        },
        opset_version=args.opset,
        dynamo=False,
    )

    metadata = {
        "model_file": str(output_path.name),
        "feature_mode": config.feature_mode,
        "heads": list(HEAD_LABELS.keys()),
        "max_length": config.max_length,
        "label_maps": config.label_maps,
        "temperature_scaling": temperatures,
    }

    # Ship everything a consumer needs next to the graph itself.
    tokenizer.save_pretrained(output_dir)
    dump_json(output_dir / "model_config.json", config.to_dict())
    dump_json(output_dir / "onnx_metadata.json", metadata)

    if temperatures:
        temperature_payload = {
            "method": "per_head_temperature_scaling",
            "source_split": "validation",
            "per_head": temperatures,
        }
        dump_json(output_dir / "temperature_scaling.json", temperature_payload)

    if not args.quantize:
        return

    # Imported lazily so onnxruntime is only needed when --quantize is given.
    from onnxruntime.quantization import QuantType, quantize_dynamic

    quantized_path = output_dir / "tiny_router.int8.onnx"
    quantize_dynamic(
        model_input=str(output_path),
        model_output=str(quantized_path),
        weight_type=QuantType.QInt8,
    )


if __name__ == "__main__":
    main()