Spaces:
Sleeping
Sleeping
File size: 3,511 Bytes
df37f6e cea87ce c524d8c df37f6e c524d8c 765f020 c524d8c df37f6e c524d8c 5816e57 c524d8c 9a28506 df37f6e 5816e57 df37f6e 3661274 df37f6e c524d8c df37f6e 9a28506 df37f6e c524d8c df37f6e c524d8c df37f6e c524d8c df37f6e c524d8c df37f6e c524d8c df37f6e c524d8c df37f6e c524d8c df37f6e c524d8c df37f6e c524d8c 05b674b c524d8c df37f6e c524d8c df37f6e c524d8c df37f6e c524d8c df37f6e c524d8c df37f6e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
from typing import Any, Optional

from fastapi import Request
from loguru import logger

from core.conf import settings
class NER:
    """Named-entity recognition wrapper around an ONNX token-classification pipeline.

    Loads a tokenizer and an ONNX Runtime model from ``model_dir`` at
    construction time and exposes async prediction plus BIO-tag entity
    extraction helpers.
    """

    def __init__(self, model_dir: str = settings.NER_MODEL_DIR):
        """Initialize the NER service and eagerly load the model.

        Args:
            model_dir (str): Directory containing the tokenizer and ONNX model.
        """
        self.model_dir = model_dir
        self.model = None
        self.tokenizer = None
        self.pipeline = None
        # Eager load: the instance is ready to predict as soon as it exists.
        self.load_model()

    def load_model(self) -> None:
        """Load tokenizer, ONNX model, and inference pipeline from ``self.model_dir``.

        Imports are local so the heavy transformers/optimum dependencies are
        only paid when a model is actually loaded.
        """
        from transformers import AutoTokenizer
        from optimum.onnxruntime import ORTModelForTokenClassification
        from optimum.pipelines import pipeline

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_dir, truncation=settings.TRUNCATE, max_length=settings.MAX_LENGTH
        )
        self.model = ORTModelForTokenClassification.from_pretrained(self.model_dir)
        self.pipeline = pipeline(
            task=settings.TASK_NAME,
            model=self.model,
            tokenizer=self.tokenizer,
            device=settings.DEVICE,
        )
        logger.info(f"Model loaded from {self.model_dir}")

    async def predict(self, text: str, entity_tag: Optional[str] = None):
        """Run NER inference on ``text``.

        Args:
            text (str): Input text; falsy input short-circuits to ``None``.
            entity_tag (Optional[str]): If given, return only the surface
                strings of entities with this tag (e.g. ``"PER"``) instead of
                the raw pipeline output.

        Returns:
            The raw pipeline prediction list, a list of extracted entity
            strings when ``entity_tag`` is set, or ``None`` for empty input.

        Raises:
            ValueError: If the model has not been loaded.
        """
        if not text:
            return None
        if self.pipeline is None:
            raise ValueError("Model not loaded. Please call load_model() first.")
        pred = self.pipeline(text)
        if entity_tag:
            return self.extract_entities(pred, entity_tag)
        return pred

    def extract_entities(
        self, result_pred: list[dict[str, Any]], entity: str
    ) -> list[str]:
        """Collect entities of one type from BIO-tagged pipeline output.

        Groups consecutive ``B-<entity>`` / ``I-<entity>`` tokens into surface
        strings; any other tag terminates the current entity span.

        Args:
            result_pred (list[dict[str, Any]]): Pipeline items with at least
                ``"word"`` and ``"entity"`` keys.
            entity (str): Entity type without BIO prefix, e.g. ``"LOC"``.

        Returns:
            list[str]: Combined entity strings in order of appearance.

        Raises:
            ValueError: If the model has not been loaded.
        """
        if self.pipeline is None:
            raise ValueError("Model not loaded. Please call load_model() first.")
        B_ENTITY = f"B-{entity}"
        I_ENTITY = f"I-{entity}"
        extracted_entities = []
        current_entity_tokens = []
        for item in result_pred:
            word = item["word"]
            entity_tag = item["entity"]
            if entity_tag == B_ENTITY:
                # A new entity begins; flush any span still in progress.
                if current_entity_tokens:
                    extracted_entities.append(
                        self._combine_token(current_entity_tokens)
                    )
                current_entity_tokens = [word]
            elif entity_tag == I_ENTITY and current_entity_tokens:
                current_entity_tokens.append(word)
            else:
                # Different tag (or orphan I- with no open span): close out.
                if current_entity_tokens:
                    extracted_entities.append(
                        self._combine_token(current_entity_tokens)
                    )
                current_entity_tokens = []
        # Flush a span that runs to the end of the prediction list.
        if current_entity_tokens:
            extracted_entities.append(self._combine_token(current_entity_tokens))
        return extracted_entities

    def _combine_token(self, tokens: list[str]) -> str:
        """Combines tokens into a single string, removing leading hashtags from the first token if present.

        WordPiece subword tokens are ``##``-prefixed; a stripped subword is
        glued onto the previous word rather than space-separated.

        Args:
            tokens (list[str]): List of tokens to combine.

        Returns:
            str: Combined string of tokens.
        """
        if not tokens:
            return ""
        words = []
        for token in tokens:
            # NOTE(review): strip("#") removes '#' from BOTH ends, not just
            # leading — preserved as-is since only '##'-prefixed subwords are
            # expected from the tokenizer; confirm if literal '#' can occur.
            if token.strip("#") != token:
                clean_token = token.strip("#")
                if words:
                    words[-1] += clean_token
                else:
                    words.append(clean_token)
            else:
                words.append(token)
        return " ".join(words)
def get_ner_model(request: Request) -> NER:
    """FastAPI dependency returning the application's shared NER model.

    Assumes the model was attached to ``app.state.ner`` during startup;
    inject into endpoints via ``Depends(get_ner_model)``.
    """
    app_state = request.app.state
    return app_state.ner
|