from __future__ import annotations
import argparse
import sys
from pathlib import Path
import torch
# Make the project root importable when this file is executed directly as a
# script (its parent's parent is prepended so the sibling `config` and
# `multitask_runtime` modules below resolve without package installation).
BASE_DIR = Path(__file__).resolve().parent.parent
if str(BASE_DIR) not in sys.path:
    sys.path.insert(0, str(BASE_DIR))
from config import MULTITASK_INTENT_MODEL_DIR # noqa: E402
from multitask_runtime import get_multitask_runtime # noqa: E402
class _OnnxMultiTaskWrapper(torch.nn.Module):
    """Adapter that flattens the wrapped model's dict output into a tuple.

    The underlying model returns a mapping of named logit tensors; this
    wrapper forwards the call and returns the three logit tensors in a
    fixed positional order matching the exporter's ``output_names``.
    """

    # Order must stay in sync with the `output_names` passed to
    # `torch.onnx.export` in `main()`.
    _OUTPUT_KEYS = (
        "intent_type_logits",
        "intent_subtype_logits",
        "decision_phase_logits",
    )

    def __init__(self, runtime):
        super().__init__()
        self.model = runtime.model

    def forward(self, input_ids, attention_mask):
        result = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return tuple(result[key] for key in self._OUTPUT_KEYS)
def main() -> None:
    """Export the multitask intent model to ONNX.

    Parses CLI arguments (``--output-path``, ``--opset``), traces the wrapped
    PyTorch model with one sample tokenized query, and writes an ONNX graph
    whose batch and sequence axes are dynamic. If the optional export
    dependencies (``onnx`` / ``onnxscript``) are not installed, the export is
    skipped with an instruction message on stderr instead of crashing.
    """
    parser = argparse.ArgumentParser(description="Export multitask intent model to ONNX.")
    parser.add_argument(
        "--output-path",
        default=str(MULTITASK_INTENT_MODEL_DIR / "multitask_intent.onnx"),
        help="Output ONNX file path.",
    )
    parser.add_argument("--opset", type=int, default=17, help="ONNX opset version.")
    args = parser.parse_args()

    runtime = get_multitask_runtime()
    tokenizer = runtime.tokenizer
    wrapper = _OnnxMultiTaskWrapper(runtime)
    wrapper.eval()  # inference mode so the traced graph is deterministic

    # One representative input is enough for tracing; the dynamic_axes
    # mapping below lets the exported graph accept any batch size and
    # sequence length at inference time.
    sample = tokenizer(
        ["sample query for intent classification"],
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=int(runtime.metadata.get("max_length", 96)),
    )

    output_path = Path(args.output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    try:
        torch.onnx.export(
            wrapper,
            (sample["input_ids"], sample["attention_mask"]),
            str(output_path),
            input_names=["input_ids", "attention_mask"],
            output_names=[
                "intent_type_logits",
                "intent_subtype_logits",
                "decision_phase_logits",
            ],
            dynamic_axes={
                "input_ids": {0: "batch_size", 1: "seq_len"},
                "attention_mask": {0: "batch_size", 1: "seq_len"},
                "intent_type_logits": {0: "batch_size"},
                "intent_subtype_logits": {0: "batch_size"},
                "decision_phase_logits": {0: "batch_size"},
            },
            opset_version=args.opset,
        )
    except ModuleNotFoundError as e:
        # Newer torch ONNX exporters require both `onnx` and `onnxscript`.
        # BUG FIX: the original branch only recognized a missing `onnxscript`,
        # so a missing `onnx` re-raised and crashed instead of producing the
        # friendly skip message the comment promised. Match either module.
        optional_deps = ("onnx", "onnxscript")
        if e.name in optional_deps or any(d in str(e).lower() for d in optional_deps):
            missing = e.name or "onnx/onnxscript"
            print(
                f"Skipping ONNX export: missing dependency `{missing}`.\n"
                "Install with: `pip install onnx onnxscript`",
                file=sys.stderr,
            )
            return
        raise
    print(f"Exported ONNX model: {output_path}")
# Script entry point: only run the export when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()