# Source: OnlyBiggg — "refactor code" (commit df37f6e)
from typing import Any

from fastapi import Request
from loguru import logger

from core.conf import settings
class NER:
    """Named-entity recognition service backed by an ONNX token-classification model.

    Loads a tokenizer and an ``ORTModelForTokenClassification`` from
    ``model_dir`` at construction time and exposes :meth:`predict` for
    inference plus :meth:`extract_entities` for grouping BIO-tagged tokens
    into entity strings.
    """

    def __init__(self, model_dir: str = settings.NER_MODEL_DIR):
        """Initialize the service and eagerly load the model.

        Args:
            model_dir (str): Directory containing the exported model and
                tokenizer files. Defaults to ``settings.NER_MODEL_DIR``
                (evaluated once at class-definition time).
        """
        self.model_dir = model_dir
        self.model = None
        self.tokenizer = None
        self.pipeline = None
        self.load_model()

    def load_model(self):
        """Load tokenizer, ONNX model, and inference pipeline from ``self.model_dir``.

        Imports are deferred to keep module import cheap when the heavy
        transformers/optimum stack is not needed.
        """
        from transformers import AutoTokenizer
        from optimum.onnxruntime import ORTModelForTokenClassification
        from optimum.pipelines import pipeline

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_dir, truncation=settings.TRUNCATE, max_length=settings.MAX_LENGTH
        )
        self.model = ORTModelForTokenClassification.from_pretrained(self.model_dir)
        self.pipeline = pipeline(
            task=settings.TASK_NAME,
            model=self.model,
            tokenizer=self.tokenizer,
            device=settings.DEVICE,
        )
        logger.info(f"Model loaded from {self.model_dir}")

    async def predict(self, text: str, entity_tag: str | None = None):
        """Run NER inference on ``text``.

        Args:
            text (str): Input text; falsy input short-circuits to ``None``.
            entity_tag (str | None): If given, only entities of this tag
                (e.g. ``"LOC"``) are extracted and returned as strings;
                otherwise the raw pipeline predictions are returned.

        Returns:
            list | None: Raw pipeline output, extracted entity strings,
            or ``None`` for empty input.

        Raises:
            ValueError: If the pipeline has not been loaded.
        """
        if not text:
            return None
        if self.pipeline is None:
            raise ValueError("Model not loaded. Please call load_model() first.")
        pred = self.pipeline(text)
        if entity_tag:
            return self.extract_entities(pred, entity_tag)
        return pred

    def extract_entities(
        self, result_pred: list[dict[str, Any]], entity: str
    ) -> list[str]:
        """Group BIO-tagged token predictions into entity strings for one tag.

        Walks the prediction list, opening a new entity on ``B-<entity>``,
        extending it on ``I-<entity>``, and flushing it on any other tag.

        Args:
            result_pred (list[dict[str, Any]]): Pipeline output items; each
                must carry ``"word"`` and ``"entity"`` keys.
            entity (str): Bare tag name without the ``B-``/``I-`` prefix.

        Returns:
            list[str]: Combined surface strings of the matched entities.

        Raises:
            ValueError: If the pipeline has not been loaded.
        """
        if self.pipeline is None:
            raise ValueError("Model not loaded. Please call load_model() first.")
        B_ENTITY = f"B-{entity}"
        I_ENTITY = f"I-{entity}"
        extracted_entities = []
        current_entity_tokens = []
        for item in result_pred:
            word = item["word"]
            entity_tag = item["entity"]
            if entity_tag == B_ENTITY:
                # A new entity starts; flush any entity in progress first.
                if current_entity_tokens:
                    extracted_entities.append(
                        self._combine_token(current_entity_tokens)
                    )
                current_entity_tokens = [word]
            elif entity_tag == I_ENTITY and current_entity_tokens:
                current_entity_tokens.append(word)
            else:
                # Tag outside the target entity ends the current span.
                if current_entity_tokens:
                    extracted_entities.append(
                        self._combine_token(current_entity_tokens)
                    )
                current_entity_tokens = []
        if current_entity_tokens:
            extracted_entities.append(self._combine_token(current_entity_tokens))
        return extracted_entities

    def _combine_token(self, tokens: list[str]) -> str:
        """Combine wordpiece tokens into a single space-separated string.

        Tokens containing ``#`` markers (WordPiece continuation pieces such
        as ``##ing``) are stripped of them and glued onto the previous word.

        NOTE(review): ``str.strip("#")`` removes ``#`` from BOTH ends, not
        only the leading ``##`` prefix — presumably fine for WordPiece
        output, but verify if tokens can legitimately contain ``#``.

        Args:
            tokens (list[str]): Tokens to combine.

        Returns:
            str: Combined string, or ``""`` for an empty token list.
        """
        if not tokens:
            return ""
        words = []
        for token in tokens:
            if token.strip("#") != token:
                clean_token = token.strip("#")
                if words:
                    # Continuation piece: append to the previous word.
                    words[-1] += clean_token
                else:
                    # Degenerate case: span starts with a continuation piece.
                    words.append(clean_token)
            else:
                words.append(token)
        return " ".join(words)
def get_ner_model(request: Request) -> NER:
    """FastAPI dependency that yields the application's shared NER service.

    The instance is expected to have been attached to ``app.state.ner``
    during application startup.
    """
    ner_service: NER = request.app.state.ner
    return ner_service