Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import onnxruntime as ort | |
| import torch | |
| from .config_train import onnx_path, tokenizer | |
| from .DataProcessing import read_input | |
| from .load_data import sorted_tags | |
class Key_Ner_ONNX_Predictor:
    """Predict the set of distinct tags in a sentence with an ONNX token-classification model."""

    def __init__(self, model_path, tokenizer, tag_map, providers=None):
        """
        Initialize the ONNX predictor.

        Args:
            model_path (str): Path to the ONNX model.
            tokenizer (BertTokenizer): Tokenizer to process input sentences.
            tag_map (Dict[int, str]): Mapping of class indices to tag names.
            providers (Optional[List[str]]): onnxruntime execution providers.
                Defaults to ``["CPUExecutionProvider"]``.
        """
        if providers is None:
            providers = ["CPUExecutionProvider"]
        self.session = ort.InferenceSession(model_path, providers=providers)
        self.tokenizer = tokenizer
        self.tag_map = tag_map

    def predict(self, sentence):
        """
        Predict the distinct tags present in a sentence using the ONNX model.

        Args:
            sentence (str): Input sentence; passed through ``read_input``
                before tokenization.

        Returns:
            Tuple[str, List[str]]: The text decoded back from the model's
            input ids (special tokens skipped; may differ from the raw input
            due to tokenizer normalization or truncation), and the distinct
            predicted tags in first-occurrence order, with ``'<pad>'``
            removed and spaces replaced by underscores.
        """
        sentence = read_input(sentence)
        tokens = self.tokenizer(sentence, return_tensors="np", padding=True, truncation=True)

        # Convert to int64 (ONNX requirement)
        input_ids = tokens["input_ids"].astype(np.int64)
        attention_mask = tokens["attention_mask"].astype(np.int64)

        # Run inference; the first output holds the logits.
        logits = self.session.run(None, {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        })[0]

        # Only the first sequence of the batch is used. Positions whose
        # attention mask is 0 are padding added by the tokenizer, so their
        # argmax must not be reported as a prediction.
        class_ids = np.argmax(logits, axis=2)[0]
        real_positions = attention_mask[0].astype(bool)
        class_ids = class_ids[real_positions].tolist()

        # dict.fromkeys deduplicates while keeping first-occurrence order;
        # a set would yield an order that varies between processes because
        # str hashing is randomized.
        unique_tags = dict.fromkeys(self.tag_map[idx] for idx in class_ids)
        unique_tags.pop('<pad>', None)
        predicted_tags = [tag.replace(" ", "_") for tag in unique_tags]

        decoded = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
        return decoded, predicted_tags
# Module-level predictor, constructed at import time (this opens the ONNX
# model at `onnx_path`). Class index i maps to the i-th entry of sorted_tags.
onnx_predictor = Key_Ner_ONNX_Predictor(
    onnx_path,
    tokenizer,
    {index: tag for index, tag in enumerate(sorted_tags)},
)