| from sentence_transformers import SentenceTransformer
|
| import torch
|
| from transformers import AutoTokenizer
|
def convert_onnx(model_name='model/distiluse-base-multilingual-cased-v2',
                 pt_model='distiluse-base-multilingual-cased-v2',
                 save_path='tensorRT/models/distiluse-base-multilingual-cased-v2.onnx'):
    """Export a SentenceTransformer model without token_type_ids (e.g. a
    DistilBERT-based model) to ONNX with a dynamic batch dimension.

    The defaults reproduce the original hard-coded behavior exactly; the
    parameters generalize the function to other models/paths.

    Args:
        model_name: Local path or hub id of the SentenceTransformer model.
        pt_model: Base name for the intermediate ``model/<pt_model>.pt`` file.
        save_path: Destination path of the exported ``.onnx`` graph.
    """
    model = SentenceTransformer(model_name)

    # Round-trip through a .pt checkpoint, as in the original flow: it both
    # persists the model and proves the checkpoint can be re-loaded.
    # NOTE(review): torch.load of a fully pickled model requires
    # weights_only=False on torch >= 2.6 — confirm the runtime torch version.
    torch.save(model, f"model/{pt_model}.pt")
    model = torch.load(f"model/{pt_model}.pt")

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # One sample sentence suffices to trace the graph; batch size is dynamic.
    sample = ["Pham Minh Chinh is Vietnam's Prime Minister"]
    enc = tokenizer(sample, padding="max_length", truncation=True)
    print(enc)

    torch.onnx.export(
        model,
        (torch.tensor(enc['input_ids'], dtype=torch.long),
         torch.tensor(enc['attention_mask'], dtype=torch.long)),
        save_path,
        export_params=True,
        opset_version=13,
        do_constant_folding=True,
        input_names=['input_ids', 'attention_mask'],
        output_names=['output'],
        dynamic_axes={'input_ids': {0: 'batch_size'},
                      'attention_mask': {0: 'batch_size'},
                      'output': {0: 'batch_size'}},
    )
|
|
|
|
|
def convert_onnx_(model_name="model/model-sup-simcse-vn",
                  pt_model='model-sup-simcse-vn',
                  max_length=256,
                  save_path="tensorRT/models/model-sup-simcse-vn.onnx"):
    """Export a SentenceTransformer model that uses token_type_ids (BERT-style,
    e.g. SimCSE-VN) to ONNX with a dynamic batch dimension.

    Bug fixes vs. the original:
      * ``max_length`` was accepted but ignored — the tokenizer call
        hard-coded ``max_length=256``; it now uses the parameter.
      * the default ``save_path`` duplicated the directory prefix
        (``tensorRT/models/tensorRT/models/...``).

    Args:
        model_name: Local path or hub id of the SentenceTransformer model.
        pt_model: Base name for the intermediate ``model/<pt_model>.pt`` file.
        max_length: Padding/truncation length used to build the trace input.
        save_path: Destination path of the exported ``.onnx`` graph.
    """
    model = SentenceTransformer(model_name)

    # Round-trip through a .pt checkpoint, as in the original flow.
    # NOTE(review): torch.load of a fully pickled model requires
    # weights_only=False on torch >= 2.6 — confirm the runtime torch version.
    torch.save(model, f"model/{pt_model}.pt")
    model = torch.load(f"model/{pt_model}.pt")

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # One sample sentence suffices to trace the graph; batch size is dynamic.
    sample = ["Pham Minh Chinh is Vietnam's Prime Minister"]
    enc = tokenizer(sample, padding="max_length", truncation=True,
                    max_length=max_length)
    print(enc)

    torch.onnx.export(
        model,
        (torch.tensor(enc['input_ids'], dtype=torch.long),
         torch.tensor(enc['attention_mask'], dtype=torch.long),
         torch.tensor(enc['token_type_ids'], dtype=torch.long)),
        save_path,
        export_params=True,
        opset_version=13,
        do_constant_folding=True,
        input_names=['input_ids', 'attention_mask', 'token_type_ids'],
        output_names=['output'],
        dynamic_axes={'input_ids': {0: 'batch_size'},
                      'attention_mask': {0: 'batch_size'},
                      'output': {0: 'batch_size'}},
    )
|
if __name__ == '__main__':
    # Script entry point: export the SimCSE-VN model using the default paths.
    convert_onnx_()