"""Export ``sentence-transformers/all-mpnet-base-v2`` to ONNX.

Loads the Hugging Face encoder, traces it with a dummy batch of shape
``(1, max_seq_length)`` and writes ``model.onnx``. The batch and sequence
dimensions are declared dynamic on the inputs AND on the output, so the
exported graph's metadata matches what callers can actually feed it.
"""
from transformers import AutoModel
import torch

max_seq_length = 384

model = AutoModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")
# Put the module in inference mode (e.g. dropout off) before tracing.
model.eval()

# Dummy tracing inputs: only dtype (int64) and rank matter; the concrete
# sizes are just example values for the dynamic axes declared below.
inputs = {
    "input_ids": torch.ones(1, max_seq_length, dtype=torch.int64),
    "attention_mask": torch.ones(1, max_seq_length, dtype=torch.int64),
}

# Single source of truth for input order: it drives both the positional
# `args` tuple and `input_names`, so the two can never disagree.
input_names = ["input_ids", "attention_mask"]
output_names = ["last_hidden_state"]
symbolic_names = {0: "batch_size", 1: "max_seq_len"}

torch.onnx.export(
    model,
    args=tuple(inputs[name] for name in input_names),
    f="model.onnx",
    export_params=True,
    input_names=input_names,
    output_names=output_names,
    # The output's batch/sequence axes must be dynamic too; otherwise the
    # graph advertises the traced (1, max_seq_length, ...) output shape.
    dynamic_axes={name: symbolic_names for name in input_names + output_names},
)