from typing import Any, Optional

from fastapi import Request
from loguru import logger

from core.conf import settings


class NER:
    """Named-entity-recognition service backed by an ONNX token-classification model.

    Loads a tokenizer + ONNX model from a local directory at construction time
    and exposes an async ``predict`` for inference.
    """

    def __init__(self, model_dir: str = settings.NER_MODEL_DIR):
        # Directory containing the pretrained tokenizer and ONNX model weights.
        # NOTE: the default is evaluated once, at class-definition time.
        self.model_dir = model_dir
        self.model = None
        self.tokenizer = None
        self.pipeline = None
        self.load_model()

    def load_model(self) -> None:
        """Load tokenizer, ONNX model and inference pipeline from ``self.model_dir``."""
        # Imported lazily so the heavy ML stack is only pulled in when a model
        # is actually constructed, not on module import.
        from transformers import AutoTokenizer
        from optimum.onnxruntime import ORTModelForTokenClassification
        from optimum.pipelines import pipeline

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_dir,
            truncation=settings.TRUNCATE,
            max_length=settings.MAX_LENGTH,
        )
        self.model = ORTModelForTokenClassification.from_pretrained(self.model_dir)
        self.pipeline = pipeline(
            task=settings.TASK_NAME,
            model=self.model,
            tokenizer=self.tokenizer,
            device=settings.DEVICE,
        )
        logger.info(f"Model loaded from {self.model_dir}")

    async def predict(self, text: str, entity_tag: Optional[str] = None):
        """Run NER over *text*.

        Args:
            text: Input text; falsy input short-circuits to ``None``.
            entity_tag: Optional entity label (e.g. ``"PER"``). When given,
                only the surface strings for that label are returned.

        Returns:
            ``None`` for empty input; otherwise the raw pipeline output, or a
            list of extracted strings when ``entity_tag`` is set.

        Raises:
            ValueError: If the model pipeline has not been loaded.
        """
        if not text:
            return None
        if self.pipeline is None:
            raise ValueError("Model not loaded. Please call load_model() first.")
        pred = self.pipeline(text)
        if entity_tag:
            return self.extract_entities(pred, entity_tag)
        return pred

    def extract_entities(
        self, result_pred: list[dict[str, Any]], entity: str
    ) -> list[str]:
        """Collect surface strings for one entity label from pipeline output.

        Performs standard BIO decoding: a ``B-<entity>`` token opens a span,
        consecutive ``I-<entity>`` tokens extend it, anything else closes it.

        Args:
            result_pred: Token-classification pipeline output; each item is
                expected to carry ``"word"`` and ``"entity"`` keys.
            entity: Bare entity label, without the ``B-``/``I-`` prefix.

        Returns:
            The detokenized string for each contiguous span of *entity*.

        Raises:
            ValueError: If the model pipeline has not been loaded.
        """
        if self.pipeline is None:
            raise ValueError("Model not loaded. Please call load_model() first.")
        b_tag = f"B-{entity}"
        i_tag = f"I-{entity}"
        extracted_entities: list[str] = []
        current_entity_tokens: list[str] = []
        for item in result_pred:
            word = item["word"]
            entity_tag = item["entity"]
            if entity_tag == b_tag:
                # A new span starts; flush any span already in progress.
                if current_entity_tokens:
                    extracted_entities.append(
                        self._combine_token(current_entity_tokens)
                    )
                current_entity_tokens = [word]
            elif entity_tag == i_tag and current_entity_tokens:
                current_entity_tokens.append(word)
            else:
                # Any other tag (O, or a different entity) closes the span.
                if current_entity_tokens:
                    extracted_entities.append(
                        self._combine_token(current_entity_tokens)
                    )
                current_entity_tokens = []
        # Flush a span that runs to the end of the prediction list.
        if current_entity_tokens:
            extracted_entities.append(self._combine_token(current_entity_tokens))
        return extracted_entities

    def _combine_token(self, tokens: list[str]) -> str:
        """Combine WordPiece tokens into a single readable string.

        Tokens with a leading ``##`` are subword continuations and are glued
        onto the previous word; all other tokens are joined with spaces.

        Args:
            tokens (list[str]): List of tokens to combine.

        Returns:
            str: Combined string of tokens.
        """
        if not tokens:
            return ""
        words: list[str] = []
        for token in tokens:
            # lstrip, not strip: only a *leading* '#' marks a subword
            # continuation; a trailing '#' (e.g. "C#") is real content.
            clean_token = token.lstrip("#")
            if clean_token != token:
                if words:
                    words[-1] += clean_token
                else:
                    # Span started mid-word (e.g. first token is "##x");
                    # keep the cleaned fragment rather than dropping it.
                    words.append(clean_token)
            else:
                words.append(token)
        return " ".join(words)


def get_ner_model(request: Request) -> NER:
    """
    Dependency to get the NER model.
    This can be used to inject the NER model into the endpoint.
    """
    # The NER instance is stored on app state at startup (request.app.state.ner).
    return request.app.state.ner