File size: 3,299 Bytes
df37f6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from typing import Any

from loguru import logger

from ...core.conf import settings


class ElectraModel:
    """Wrapper around an ONNX-exported ELECTRA token-classification model.

    Loads a tokenizer and ONNX model from ``model_dir`` at construction
    time and exposes an async ``predict`` entry point plus helpers that
    post-process raw pipeline predictions into plain entity strings.
    """

    def __init__(self, model_dir: str = settings.NER_MODEL_DIR):
        """Initialize the wrapper and eagerly load the model artifacts.

        Args:
            model_dir: Directory containing the tokenizer and ONNX model
                files. Defaults to ``settings.NER_MODEL_DIR``.
        """
        self.model_dir = model_dir
        self.model = None
        self.tokenizer = None
        self.pipeline = None
        self.load_model()

    def load_model(self):
        """Load the tokenizer and ONNX model, then build the pipeline.

        Imports are deferred to call time so that importing this module
        does not pull in the heavy transformers/optimum dependencies.
        """
        from transformers import AutoTokenizer
        from optimum.onnxruntime import ORTModelForTokenClassification
        from optimum.pipelines import pipeline

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_dir, truncation=settings.TRUNCATE, max_length=settings.MAX_LENGTH
        )
        self.model = ORTModelForTokenClassification.from_pretrained(self.model_dir)
        self.pipeline = pipeline(
            task=settings.TASK_NAME,
            model=self.model,
            tokenizer=self.tokenizer,
            device=settings.DEVICE,
        )
        logger.info(f"Model loaded from {self.model_dir}")

    async def predict(self, text: str, entity_tag: str | None = None):
        """Run token classification on *text*.

        Args:
            text: Input text. Falsy input short-circuits to ``None``.
            entity_tag: Optional entity label (e.g. ``"PER"``). When given,
                only the combined surface strings for that label are
                returned instead of the raw pipeline predictions.

        Returns:
            ``None`` for empty input; a list of extracted entity strings
            when ``entity_tag`` is set; otherwise the raw pipeline output.

        Raises:
            ValueError: If the pipeline has not been loaded.
        """
        if not text:
            return None

        if self.pipeline is None:
            raise ValueError("Model not loaded. Please call load_model() first.")

        pred = self.pipeline(text)

        if entity_tag:
            return self.extract_entities(pred, entity_tag)
        return pred

    def extract_entities(
        self, result_pred: list[dict[str, Any]], entity: str
    ) -> list[str]:
        """Collect surface strings for one entity label from raw predictions.

        Walks the per-token predictions and stitches together BIO spans:
        a ``B-<entity>`` token starts a new span, following ``I-<entity>``
        tokens extend it, and anything else closes the current span.

        NOTE: this is pure post-processing over ``result_pred`` — it does
        not touch the pipeline, so it works on any prediction list whose
        items carry ``"word"`` and ``"entity"`` keys.

        Args:
            result_pred: Raw pipeline output (one dict per token).
            entity: Bare entity label, without the ``B-``/``I-`` prefix.

        Returns:
            One combined string per extracted entity span, in order.
        """
        B_ENTITY = f"B-{entity}"
        I_ENTITY = f"I-{entity}"

        extracted_entities = []
        current_entity_tokens = []

        for item in result_pred:
            word = item["word"]
            entity_tag = item["entity"]

            if entity_tag == B_ENTITY:
                # A new span begins; flush any span still in progress.
                if current_entity_tokens:
                    extracted_entities.append(
                        self._combine_token(current_entity_tokens)
                    )
                current_entity_tokens = [word]
            elif entity_tag == I_ENTITY and current_entity_tokens:
                current_entity_tokens.append(word)
            else:
                # Any other tag (including a dangling I- with no open span)
                # terminates the current span.
                if current_entity_tokens:
                    extracted_entities.append(
                        self._combine_token(current_entity_tokens)
                    )
                    current_entity_tokens = []

        # Flush a span that runs to the end of the prediction list.
        if current_entity_tokens:
            extracted_entities.append(self._combine_token(current_entity_tokens))

        return extracted_entities

    def _combine_token(self, tokens: list[str]) -> str:
        """Combine WordPiece tokens into a single readable string.

        Tokens prefixed with ``"##"`` are WordPiece continuations and are
        glued onto the previous token without a space; all other tokens
        are joined with single spaces. Only the exact ``"##"`` prefix is
        treated as a continuation marker, so genuine words containing a
        hash (e.g. ``"C#"``, ``"#hashtag"``) are preserved as-is.

        Args:
            tokens: Tokens produced by the pipeline, in order.

        Returns:
            The combined string, or ``""`` for an empty token list.
        """
        if not tokens:
            return ""

        words: list[str] = []

        for token in tokens:
            if token.startswith("##"):
                fragment = token.removeprefix("##")
                if words:
                    words[-1] += fragment
                else:
                    # Defensive: a continuation token with no predecessor.
                    words.append(fragment)
            else:
                words.append(token)

        return " ".join(words)