File size: 3,511 Bytes
df37f6e
 
cea87ce
c524d8c
df37f6e
c524d8c
765f020
 
c524d8c
 
 
df37f6e
c524d8c
 
5816e57
 
 
c524d8c
9a28506
df37f6e
 
5816e57
df37f6e
 
 
 
 
 
 
 
3661274
df37f6e
 
 
 
c524d8c
 
df37f6e
9a28506
df37f6e
c524d8c
 
 
df37f6e
 
 
 
c524d8c
 
 
 
df37f6e
c524d8c
 
df37f6e
c524d8c
 
 
df37f6e
c524d8c
 
df37f6e
 
 
c524d8c
 
 
 
 
df37f6e
 
 
c524d8c
df37f6e
c524d8c
 
df37f6e
c524d8c
 
05b674b
c524d8c
 
 
df37f6e
c524d8c
 
 
 
 
df37f6e
c524d8c
df37f6e
c524d8c
 
 
 
 
 
 
 
 
df37f6e
c524d8c
df37f6e
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from typing import Any, Optional

from fastapi import Request
from loguru import logger

from core.conf import settings


class NER:
    """Named-entity recognition service wrapping an ONNX token-classification pipeline.

    Loads a tokenizer and an ONNX model from ``model_dir`` at construction
    time and exposes :meth:`predict` for inference plus a BIO-tag
    post-processor (:meth:`extract_entities`).
    """

    def __init__(self, model_dir: str = settings.NER_MODEL_DIR):
        """Create the service and eagerly load the model.

        Args:
            model_dir: Directory holding the ONNX model and tokenizer files.
        """
        self.model_dir = model_dir
        self.model = None
        self.tokenizer = None
        self.pipeline = None
        self.load_model()

    def load_model(self) -> None:
        """Load the tokenizer and ONNX model, then build the inference pipeline."""
        # Imported lazily so the heavy transformers/optimum stack is only
        # paid for when a model is actually loaded.
        from transformers import AutoTokenizer
        from optimum.onnxruntime import ORTModelForTokenClassification
        from optimum.pipelines import pipeline

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_dir, truncation=settings.TRUNCATE, max_length=settings.MAX_LENGTH
        )
        self.model = ORTModelForTokenClassification.from_pretrained(self.model_dir)
        self.pipeline = pipeline(
            task=settings.TASK_NAME,
            model=self.model,
            tokenizer=self.tokenizer,
            device=settings.DEVICE,
        )
        logger.info(f"Model loaded from {self.model_dir}")

    async def predict(self, text: str, entity_tag: Optional[str] = None):
        """Run NER over *text*.

        Args:
            text: Input text. Falsy input (empty string / None) short-circuits.
            entity_tag: Optional entity label (e.g. ``"PER"``). When given,
                only the combined surface strings of that entity type are
                returned instead of the raw pipeline output.

        Returns:
            ``None`` for empty input; otherwise the raw pipeline prediction,
            or a ``list[str]`` of extracted entities when *entity_tag* is set.

        Raises:
            ValueError: If the model/pipeline has not been loaded.
        """
        if not text:
            return None

        if self.pipeline is None:
            raise ValueError("Model not loaded. Please call load_model() first.")

        pred = self.pipeline(text)

        if entity_tag:
            return self.extract_entities(pred, entity_tag)
        return pred

    def extract_entities(
        self, result_pred: list[dict[str, Any]], entity: str
    ) -> list[str]:
        """Collect surface strings of one entity type from BIO-tagged predictions.

        Walks the token-level predictions, starting a new span on ``B-<entity>``,
        extending it on ``I-<entity>``, and flushing it on any other tag.

        Args:
            result_pred: Pipeline output items; each must carry ``"word"`` and
                ``"entity"`` keys (BIO-style tags).
            entity: Entity label without the ``B-``/``I-`` prefix.

        Returns:
            Combined token strings, one per detected entity span.

        Raises:
            ValueError: If the model/pipeline has not been loaded.
        """
        if self.pipeline is None:
            raise ValueError("Model not loaded. Please call load_model() first.")
        B_ENTITY = f"B-{entity}"
        I_ENTITY = f"I-{entity}"

        extracted_entities = []
        current_entity_tokens = []

        for item in result_pred:
            word = item["word"]
            entity_tag = item["entity"]

            if entity_tag == B_ENTITY:
                # A new span begins; flush any span in progress first.
                if current_entity_tokens:
                    extracted_entities.append(
                        self._combine_token(current_entity_tokens)
                    )
                current_entity_tokens = [word]
            elif entity_tag == I_ENTITY and current_entity_tokens:
                current_entity_tokens.append(word)
            else:
                # Tag outside the target span (or dangling I- with no B-): flush.
                if current_entity_tokens:
                    extracted_entities.append(
                        self._combine_token(current_entity_tokens)
                    )
                    current_entity_tokens = []

        if current_entity_tokens:
            extracted_entities.append(self._combine_token(current_entity_tokens))

        return extracted_entities

    def _combine_token(self, tokens: list[str]) -> str:
        """Combine wordpiece tokens into a single string.

        Tokens carrying ``#`` markers (wordpiece continuations such as
        ``##ing``) are stripped of hashes and glued onto the previous word;
        plain tokens are joined with spaces.

        Args:
            tokens: List of tokens to combine.

        Returns:
            str: Combined string of tokens; ``""`` for an empty list.
        """
        if not tokens:
            return ""

        words = []

        for token in tokens:
            if token.strip("#") != token:
                clean_token = token.strip("#")
                if words:
                    # Continuation piece: append to the previous word.
                    words[-1] += clean_token
                else:
                    # Span started mid-word; keep the cleaned piece alone.
                    words.append(clean_token)
            else:
                words.append(token)

        return " ".join(words)


def get_ner_model(request: Request) -> NER:
    """FastAPI dependency returning the shared NER model.

    The model instance is placed on ``app.state`` at startup; this helper
    fetches it so endpoints can receive it via dependency injection.
    """
    ner_model = request.app.state.ner
    return ner_model