"""Convert a fine-tuned BERT token-classification model to ONNX format."""
import onnx
import onnxruntime
import torch
from transformers import BertForTokenClassification

from .config_train import model_load_path, onnx_path, tokenizer
def convert_to_onnx(model_path, tokenizer):
    """Convert the fine-tuned BERT token classification model to ONNX.

    Args:
        model_path: Directory or hub id of the fine-tuned
            ``BertForTokenClassification`` checkpoint to load.
        tokenizer: Tokenizer used to build a dummy input for tracing.

    Side effects:
        Writes the exported ONNX graph to the module-level ``onnx_path``
        and prints a confirmation message.
    """
    model = BertForTokenClassification.from_pretrained(model_path)
    model.eval()  # disable dropout so the traced graph is deterministic

    # Dummy input is only used to trace the graph; `dynamic_axes` below lets
    # the exported model accept any batch size / sequence length at runtime.
    dummy_sentence = "Tôi muốn đi cắm trại ngắm hoàng hôn trên biển cùng gia đình"
    inputs = tokenizer(dummy_sentence, return_tensors="pt", padding=True, truncation=True)
    dummy_input_ids = inputs["input_ids"]
    dummy_attention_mask = inputs["attention_mask"]

    # Export ONNX model
    torch.onnx.export(
        model,
        (dummy_input_ids, dummy_attention_mask),  # tuple of positional model inputs
        onnx_path,
        export_params=True,
        opset_version=14,  # Use Opset 14 or higher
        input_names=["input_ids", "attention_mask"],
        output_names=["logits"],
        dynamic_axes={
            "input_ids": {0: "batch_size", 1: "sequence_length"},
            "attention_mask": {0: "batch_size", 1: "sequence_length"},
            "logits": {0: "batch_size", 1: "sequence_length"},
        },
    )
    print(f"✅ ONNX model saved to {onnx_path}")
if __name__ == "__main__":
    # Run the conversion only when executed as a script, so importing this
    # module (e.g. to reuse convert_to_onnx) does not trigger a full export.
    convert_to_onnx(model_load_path, tokenizer)