| import argparse |
| from pathlib import Path |
| import torch |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer |
|
|
|
|
def export_to_onnx(
    model_dir: str | Path,
    output_path: str | Path | None = None,
    max_len: int = 256,
    opset_version: int = 17,
) -> Path:
    """Export a saved HF sequence-classification model to ONNX.

    Loads the tokenizer and model from *model_dir*, traces the model with a
    fixed-length dummy batch, and writes a single-file ONNX graph whose batch
    dimension is dynamic (1..4096). The sequence length is baked in at
    *max_len* — callers must pad/truncate inputs to that length at inference.

    Args:
        model_dir: Directory produced by ``save_pretrained`` (model + tokenizer).
        output_path: Destination ``.onnx`` file; defaults to ``model_dir/model.onnx``.
        max_len: Tokenizer max length; becomes the fixed sequence dimension.
        opset_version: ONNX opset to target.

    Returns:
        Path to the written ONNX file.
    """
    model_dir = Path(model_dir)
    if output_path is None:
        output_path = model_dir / "model.onnx"
    else:
        output_path = Path(output_path)

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    model.eval()

    # Batch of 2 (not 1) so the exporter doesn't specialize the batch dim;
    # padding="max_length" fixes the sequence dim at max_len.
    dummy = tokenizer(
        ["dummy input", "another dummy"],
        max_length=max_len,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )

    # Only dim 0 (batch) is dynamic; seq length stays static at max_len.
    batch = torch.export.Dim("batch", min=1, max=4096)
    # eval() alone does not stop autograd from tracking the trace — disable
    # grad explicitly so export doesn't build needless autograd state.
    with torch.no_grad():
        torch.onnx.export(
            model,
            (dummy["input_ids"], dummy["attention_mask"]),
            str(output_path),
            input_names=["input_ids", "attention_mask"],
            output_names=["logits"],
            dynamic_shapes={
                "input_ids": {0: batch},
                "attention_mask": {0: batch},
            },
            opset_version=opset_version,
            dynamo=True,
            external_data=False,  # single self-contained .onnx file
        )

    print(f"Exported ONNX model to '{output_path}' ({output_path.stat().st_size / 1024 / 1024:.1f} MB)")
    return output_path
|
|
|
|
def main():
    """Parse CLI arguments and run the ONNX export."""
    arg_parser = argparse.ArgumentParser(description="Export model to ONNX")
    arg_parser.add_argument("--model-dir", required=True, help="Path to saved model directory")
    arg_parser.add_argument("--output", default=None, help="Output ONNX path (default: model_dir/model.onnx)")
    arg_parser.add_argument("--max-len", type=int, default=256)
    arg_parser.add_argument("--opset", type=int, default=17)
    opts = arg_parser.parse_args()

    export_to_onnx(
        opts.model_dir,
        opts.output,
        opts.max_len,
        opts.opset,
    )
|
|
|
|
# Script entry point: only run the CLI when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|