import argparse
from pathlib import Path

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer


def export_to_onnx(
    model_dir: str,
    output_path: str | None = None,
    max_len: int = 256,
    opset_version: int = 17,
) -> Path:
    model_dir = Path(model_dir)
    if output_path is None:
        output_path = model_dir / "model.onnx"
    else:
        output_path = Path(output_path)

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    model.eval()

    # Dummy batch used only to trace the graph. Padding to max_length fixes the
    # sequence dimension, so only the batch dimension is exported as dynamic.
    dummy = tokenizer(
        ["dummy input", "another dummy"],
        max_length=max_len,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )

    # Mark the batch dimension of both inputs as dynamic for the dynamo-based exporter.
    batch = torch.export.Dim("batch", min=1, max=4096)
    torch.onnx.export(
        model,
        (dummy["input_ids"], dummy["attention_mask"]),
        str(output_path),
        input_names=["input_ids", "attention_mask"],
        output_names=["logits"],
        dynamic_shapes={
            "input_ids": {0: batch},
            "attention_mask": {0: batch},
        },
        opset_version=opset_version,
        dynamo=True,
        external_data=False,  # keep weights embedded in a single .onnx file
    )

    size_mb = output_path.stat().st_size / 1024 / 1024
    print(f"Exported ONNX model to '{output_path}' ({size_mb:.1f} MB)")
    return output_path


def main():
    parser = argparse.ArgumentParser(description="Export model to ONNX")
    parser.add_argument("--model-dir", required=True, help="Path to saved model directory")
    parser.add_argument("--output", default=None, help="Output ONNX path (default: model_dir/model.onnx)")
    parser.add_argument("--max-len", type=int, default=256)
    parser.add_argument("--opset", type=int, default=17)
    args = parser.parse_args()
    export_to_onnx(args.model_dir, args.output, args.max_len, args.opset)


if __name__ == "__main__":
    main()
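

# --- Sanity-check sketch (not part of the CLI above) ---
# A minimal way to confirm the exported graph runs and produces logits, assuming
# onnxruntime is installed. The helper name `verify_onnx` and the sample text are
# illustrative only; kept as a commented reference snippet rather than wired into main().
#
# import numpy as np
# import onnxruntime as ort
#
# def verify_onnx(model_dir: str, onnx_path: str, text: str = "sanity check") -> None:
#     tokenizer = AutoTokenizer.from_pretrained(model_dir)
#     enc = tokenizer(
#         [text],
#         max_length=256,
#         truncation=True,
#         padding="max_length",
#         return_tensors="np",
#     )
#     session = ort.InferenceSession(onnx_path)
#     (logits,) = session.run(
#         ["logits"],
#         {
#             "input_ids": enc["input_ids"].astype(np.int64),
#             "attention_mask": enc["attention_mask"].astype(np.int64),
#         },
#     )
#     print("ONNX logits:", logits)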