# sv-task/src/models/serialize.py
# hf upload/download and onnx export (commit 9f3aa4a, author: lamossta)
import argparse
from pathlib import Path
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
def export_to_onnx(
    model_dir: str | Path,
    output_path: str | Path | None = None,
    max_len: int = 256,
    opset_version: int = 17,
) -> Path:
    """Export a saved HF sequence-classification checkpoint to a single ONNX file.

    Args:
        model_dir: Directory containing the saved model and tokenizer.
        output_path: Destination ``.onnx`` path; defaults to ``model_dir/model.onnx``.
        max_len: Sequence length the dummy inputs are padded/truncated to; the
            exported graph has this fixed sequence dimension.
        opset_version: ONNX opset to target.

    Returns:
        Path of the written ONNX file.
    """
    model_dir = Path(model_dir)
    output_path = model_dir / "model.onnx" if output_path is None else Path(output_path)

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    model.eval()  # disable dropout etc. so the exported graph is deterministic

    # Two example sentences give a batch dimension > 1 so the exporter does
    # not specialize on batch size; padding="max_length" fixes the seq dim.
    dummy = tokenizer(
        ["dummy input", "another dummy"],
        max_length=max_len,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )

    # Only the batch dimension (dim 0) is dynamic; sequence length stays max_len.
    batch = torch.export.Dim("batch", min=1, max=4096)
    torch.onnx.export(
        model,
        (dummy["input_ids"], dummy["attention_mask"]),
        str(output_path),
        input_names=["input_ids", "attention_mask"],
        output_names=["logits"],
        dynamic_shapes={
            "input_ids": {0: batch},
            "attention_mask": {0: batch},
        },
        opset_version=opset_version,
        dynamo=True,  # torch.export-based exporter; required for dynamic_shapes
        external_data=False,  # keep weights embedded in the single .onnx file
    )
    print(f"Exported ONNX model to '{output_path}' ({output_path.stat().st_size / 1024 / 1024:.1f} MB)")
    return output_path
def main():
    """Command-line wrapper around :func:`export_to_onnx`."""
    cli = argparse.ArgumentParser(description="Export model to ONNX")
    cli.add_argument("--model-dir", required=True, help="Path to saved model directory")
    cli.add_argument("--output", default=None, help="Output ONNX path (default: model_dir/model.onnx)")
    cli.add_argument("--max-len", type=int, default=256)
    cli.add_argument("--opset", type=int, default=17)
    opts = cli.parse_args()
    export_to_onnx(
        opts.model_dir,
        opts.output,
        opts.max_len,
        opts.opset,
    )


if __name__ == "__main__":
    main()