Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import argparse | |
| import importlib.util | |
| from pathlib import Path | |
| import torch | |
| from tiny_router.constants import HEAD_LABELS | |
| from tiny_router.io import load_checkpoint, load_temperature_scaling | |
| from tiny_router.runtime import dump_json, get_device | |
class OnnxWrapper(torch.nn.Module):
    """Adapt the router model to a signature that ``torch.onnx.export`` can trace.

    The wrapped model is called with keyword arguments and returns a mapping
    whose ``"logits"`` entry maps each head name to a tensor. This wrapper
    instead takes the inputs positionally and returns a flat tuple with one
    logits tensor per head, ordered like ``HEAD_LABELS``.
    """

    def __init__(self, model):
        super().__init__()
        # Assigning an nn.Module attribute registers it as a submodule, so
        # calling .to()/.eval() on the wrapper also applies to the model.
        self.model = model

    def forward(
        self,
        input_ids,
        attention_mask,
        previous_action_id,
        previous_outcome_id,
        log_recency_seconds,
        has_interaction,
        has_recency,
    ):
        """Run the model and return its per-head logits as a tuple."""
        model_inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "previous_action_id": previous_action_id,
            "previous_outcome_id": previous_outcome_id,
            "log_recency_seconds": log_recency_seconds,
            "has_interaction": has_interaction,
            "has_recency": has_recency,
        }
        logits_by_head = self.model(**model_inputs)["logits"]
        # Tuple order defines the order of the exported graph's outputs.
        return tuple(logits_by_head[head] for head in HEAD_LABELS)
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the export command line.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.

    Returns:
        Namespace with ``model_dir`` (str), ``output_dir`` (str),
        ``opset`` (int) and ``quantize`` (bool).
    """
    parser = argparse.ArgumentParser(description="Export tiny-router to ONNX.")
    parser.add_argument(
        "--model-dir",
        required=True,
        help="Checkpoint directory to load the model, tokenizer and "
        "temperature scaling from.",
    )
    parser.add_argument(
        "--output-dir",
        required=True,
        help="Directory to write the ONNX model and its artifacts into "
        "(created if missing).",
    )
    parser.add_argument(
        "--opset",
        type=int,
        default=18,
        help="ONNX opset version to export with (default: %(default)s).",
    )
    parser.add_argument(
        "--quantize",
        action="store_true",
        help="Also write an int8 dynamically quantized copy "
        "(requires onnxruntime).",
    )
    return parser.parse_args(argv)
def ensure_export_dependencies(quantize: bool) -> None:
    """Fail fast when optional packages needed by the export are not installed.

    ``onnxscript`` is always checked; ``onnxruntime`` is checked only when
    ``quantize`` is true. Presence is tested with ``importlib.util.find_spec``,
    so nothing is actually imported here.

    Raises:
        RuntimeError: listing every missing package together with the
            install command to fix it.
    """
    required = ["onnxscript"]
    if quantize:
        required.append("onnxruntime")

    missing = [name for name in required if importlib.util.find_spec(name) is None]
    if missing:
        raise RuntimeError(
            "Missing export dependencies: "
            f"{', '.join(missing)}. "
            "Install them with `uv sync` "
            f"or `uv add {' '.join(missing)}` and rerun export."
        )
# Names of the exported graph's inputs, in the positional order that
# ``OnnxWrapper.forward`` takes them. Used both to order the example
# arguments and to label the graph inputs, so the two cannot drift apart.
_EXPORT_INPUT_NAMES = (
    "input_ids",
    "attention_mask",
    "previous_action_id",
    "previous_outcome_id",
    "log_recency_seconds",
    "has_interaction",
    "has_recency",
)

# Example request used only to trace the model during export.
_SAMPLE_TEXT = (
    "Current: Export is broken on Safari\nPrevious: None\n"
    "Previous action: None\nPrevious outcome: None"
)


def _build_sample_inputs(tokenizer, config, device) -> dict[str, torch.Tensor]:
    """Return one example (batch of one) for every model input, on ``device``."""
    encoded = tokenizer(
        _SAMPLE_TEXT,
        return_tensors="pt",
        max_length=config.max_length,
        truncation=True,
    )
    sample = {key: value.to(device) for key, value in encoded.items()}
    # Context features for the example; the sample text has no previous turn.
    # NOTE(review): action id 0 / outcome id 4 presumably encode "None" in the
    # label maps — confirm against config.label_maps.
    sample["previous_action_id"] = torch.tensor([0], dtype=torch.long, device=device)
    sample["previous_outcome_id"] = torch.tensor([4], dtype=torch.long, device=device)
    sample["log_recency_seconds"] = torch.tensor(
        [0.0], dtype=torch.float32, device=device
    )
    sample["has_interaction"] = torch.tensor([0], dtype=torch.long, device=device)
    sample["has_recency"] = torch.tensor([0], dtype=torch.long, device=device)
    return sample


def _export_onnx(wrapper, sample, output_path: Path, opset: int) -> None:
    """Trace ``wrapper`` on ``sample`` and write the ONNX graph to ``output_path``."""
    torch.onnx.export(
        wrapper,
        tuple(sample[name] for name in _EXPORT_INPUT_NAMES),
        output_path,
        input_names=list(_EXPORT_INPUT_NAMES),
        output_names=[f"{head}_logits" for head in HEAD_LABELS],
        # Only the token axis is dynamic; every other axis (including the
        # batch axis) stays fixed at the sample's size.
        dynamic_axes={
            "input_ids": {1: "sequence"},
            "attention_mask": {1: "sequence"},
        },
        opset_version=opset,
        # Use the TorchScript-based exporter rather than the dynamo one.
        dynamo=False,
    )


def _write_artifacts(
    output_dir: Path, model_file: str, tokenizer, config, temperatures
) -> None:
    """Save the tokenizer, model config and ONNX metadata next to the model.

    ``temperature_scaling.json`` is written only when ``temperatures`` is
    truthy.
    """
    metadata = {
        "model_file": model_file,
        "feature_mode": config.feature_mode,
        "heads": list(HEAD_LABELS.keys()),
        "max_length": config.max_length,
        "label_maps": config.label_maps,
        "temperature_scaling": temperatures,
    }
    tokenizer.save_pretrained(output_dir)
    dump_json(output_dir / "model_config.json", config.to_dict())
    dump_json(output_dir / "onnx_metadata.json", metadata)
    if temperatures:
        # NOTE(review): "method" and "source_split" are hard-coded here, not
        # read from the loaded temperature scaling — confirm they still match.
        dump_json(
            output_dir / "temperature_scaling.json",
            {
                "method": "per_head_temperature_scaling",
                "source_split": "validation",
                "per_head": temperatures,
            },
        )


def _quantize_int8(output_path: Path, output_dir: Path) -> None:
    """Write a dynamically quantized (int8 weights) copy of the ONNX model."""
    # Imported lazily: onnxruntime is only required when --quantize is given.
    from onnxruntime.quantization import QuantType, quantize_dynamic

    quantize_dynamic(
        model_input=str(output_path),
        model_output=str(output_dir / "tiny_router.int8.onnx"),
        weight_type=QuantType.QInt8,
    )


def main() -> None:
    """Export a tiny-router checkpoint to ONNX.

    Writes ``tiny_router.onnx``, the tokenizer, ``model_config.json`` and
    ``onnx_metadata.json`` into ``--output-dir``. It also writes
    ``temperature_scaling.json`` when temperature scaling was loaded, and
    ``tiny_router.int8.onnx`` when ``--quantize`` is given.
    """
    args = parse_args()
    ensure_export_dependencies(args.quantize)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): prefer_mps=False presumably keeps export off Apple MPS —
    # confirm in get_device.
    device = get_device(prefer_mps=False)
    model, tokenizer, config = load_checkpoint(args.model_dir, device=device)
    temperatures = load_temperature_scaling(args.model_dir)

    wrapper = OnnxWrapper(model).to(device)
    wrapper.eval()  # trace the inference-mode behaviour of the model

    sample = _build_sample_inputs(tokenizer, config, device)
    output_path = output_dir / "tiny_router.onnx"
    _export_onnx(wrapper, sample, output_path, args.opset)

    _write_artifacts(output_dir, output_path.name, tokenizer, config, temperatures)

    if args.quantize:
        _quantize_int8(output_path, output_dir)


if __name__ == "__main__":
    main()