Spaces:
Sleeping
Sleeping
from typing import Any, Optional

from fastapi import Request
from loguru import logger

from core.conf import settings
class NER:
    """Wrapper around an ONNX token-classification pipeline for named-entity recognition."""

    def __init__(self, model_dir: Optional[str] = None):
        """Create the wrapper and eagerly load the model.

        Args:
            model_dir: Directory holding the exported model. Defaults to
                ``settings.NER_MODEL_DIR``, resolved at call time rather than
                at class-definition time so config changes after import are
                honored.
        """
        self.model_dir = model_dir if model_dir is not None else settings.NER_MODEL_DIR
        self.model = None
        self.tokenizer = None
        self.pipeline = None
        self.load_model()

    def load_model(self):
        """Load tokenizer, ONNX model, and inference pipeline from ``self.model_dir``."""
        # Imported lazily so merely importing this module does not pull in the
        # heavy transformers/optimum dependencies.
        from transformers import AutoTokenizer
        from optimum.onnxruntime import ORTModelForTokenClassification
        from optimum.pipelines import pipeline

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_dir, truncation=settings.TRUNCATE, max_length=settings.MAX_LENGTH
        )
        self.model = ORTModelForTokenClassification.from_pretrained(self.model_dir)
        self.pipeline = pipeline(
            task=settings.TASK_NAME,
            model=self.model,
            tokenizer=self.tokenizer,
            device=settings.DEVICE,
        )
        logger.info(f"Model loaded from {self.model_dir}")

    async def predict(self, text: str, entity_tag: Optional[str] = None):
        """Run NER over *text*.

        Args:
            text: Input text; falsy input short-circuits to ``None``.
            entity_tag: If given (e.g. ``"PER"``), return only the surface
                strings for that entity type instead of raw predictions.

        Returns:
            Raw pipeline predictions, a list of extracted entity strings,
            or ``None`` for empty input.

        Raises:
            ValueError: If the model has not been loaded.
        """
        if not text:
            return None
        if self.pipeline is None:
            raise ValueError("Model not loaded. Please call load_model() first.")
        # NOTE(review): the pipeline call is synchronous and will block the
        # event loop for long inputs — consider run_in_executor. TODO confirm.
        pred = self.pipeline(text)
        if entity_tag:
            return self.extract_entities(pred, entity_tag)
        return pred

    def extract_entities(
        self, result_pred: list[dict[str, Any]], entity: str
    ) -> list[str]:
        """Collect surface strings for one entity type from token predictions.

        Groups consecutive ``B-<entity>`` / ``I-<entity>`` tokens into spans
        and merges WordPiece sub-tokens. This is pure post-processing — it
        never touches the model, so no loaded-pipeline check is needed.

        Args:
            result_pred: Pipeline output items with ``word`` and ``entity`` keys.
            entity: Entity type without the B-/I- prefix (e.g. ``"LOC"``).

        Returns:
            One string per contiguous entity span, in input order.
        """
        B_ENTITY = f"B-{entity}"
        I_ENTITY = f"I-{entity}"
        extracted_entities: list[str] = []
        current_entity_tokens: list[str] = []
        for item in result_pred:
            word = item["word"]
            entity_tag = item["entity"]
            if entity_tag == B_ENTITY:
                # A new span starts; flush any span in progress first.
                if current_entity_tokens:
                    extracted_entities.append(
                        self._combine_token(current_entity_tokens)
                    )
                current_entity_tokens = [word]
            elif entity_tag == I_ENTITY and current_entity_tokens:
                current_entity_tokens.append(word)
            else:
                # Any other tag (O, other entity types, dangling I-) ends the span.
                if current_entity_tokens:
                    extracted_entities.append(
                        self._combine_token(current_entity_tokens)
                    )
                current_entity_tokens = []
        if current_entity_tokens:
            extracted_entities.append(self._combine_token(current_entity_tokens))
        return extracted_entities

    def _combine_token(self, tokens: list[str]) -> str:
        """Join tokens into one string, merging WordPiece continuations.

        Only a leading ``##`` marks a sub-token (WordPiece convention). The
        previous implementation used ``token.strip("#")``, which strips ``#``
        from BOTH ends and corrupted legitimate words such as ``"C#"``.

        Args:
            tokens: Tokens belonging to one entity span.

        Returns:
            Space-joined words, with ``##`` continuations glued onto their
            predecessors.
        """
        if not tokens:
            return ""
        words: list[str] = []
        for token in tokens:
            if token.startswith("##"):
                clean_token = token[2:]
                if words:
                    words[-1] += clean_token
                else:
                    # Span began mid-word (dangling continuation): keep it as-is.
                    words.append(clean_token)
            else:
                words.append(token)
        return " ".join(words)
def get_ner_model(request: Request) -> NER:
    """FastAPI dependency returning the application-wide NER instance.

    The model is expected to have been attached to ``app.state.ner`` at
    startup; this simply reads it back off the incoming request so endpoints
    can inject it.
    """
    app_state = request.app.state
    return app_state.ner