File size: 3,388 Bytes
be80ddf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be0ec28
be80ddf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
##########################################
# This file is copy-pasted into the Hugging Face
# model repository, where it is used for inference.
##########################################

from transformers.pipelines import TokenClassificationPipeline, AggregationStrategy
from typing import Any, Union, List, Optional, Tuple, Dict
from optimum.onnxruntime import ORTModelForTokenClassification
from transformers import AutoTokenizer


class MyTokenClassificationPipeline(TokenClassificationPipeline):
    """Token-classification pipeline that never requests token type IDs.

    Behaves like ``TokenClassificationPipeline`` except that the tokenizer is
    always called with ``return_token_type_ids=False``. (Presumably the
    exported ONNX model has no ``token_type_ids`` input -- confirm against the
    model export.)
    """

    def _sanitize_parameters(
        self,
        ignore_labels=None,
        grouped_entities: Optional[bool] = None,
        ignore_subwords: Optional[bool] = None,
        aggregation_strategy: Optional[AggregationStrategy] = None,
        offset_mapping: Optional[List[Tuple[int, int]]] = None,
        stride: Optional[int] = None,
        **kwargs: Any,
    ):
        """Split pipeline kwargs into (preprocess, forward, postprocess) params.

        Delegates to the parent implementation, then adds
        ``return_token_type_ids=False`` to the preprocess ``tokenizer_params``.
        Tokenizer params the parent already produced (for example, settings
        it may add when ``stride`` is given) are preserved, not overwritten.

        Any extra keyword arguments are forwarded unchanged to the parent so
        options supported by newer ``transformers`` releases keep working.

        Returns:
            The ``(preprocess_params, forward_params, postprocess_params)``
            tuple from the parent, with the tokenizer params adjusted.
        """
        # Keyword arguments keep this robust if the parent signature gains or
        # reorders parameters.
        preprocess_params, forward_params, postprocess_params = super()._sanitize_parameters(
            ignore_labels=ignore_labels,
            grouped_entities=grouped_entities,
            ignore_subwords=ignore_subwords,
            aggregation_strategy=aggregation_strategy,
            offset_mapping=offset_mapping,
            stride=stride,
            **kwargs,
        )
        # Merge rather than assign: replacing the dict would silently drop any
        # tokenizer settings the parent configured.
        tokenizer_params = dict(preprocess_params.get('tokenizer_params') or {})
        tokenizer_params['return_token_type_ids'] = False
        preprocess_params['tokenizer_params'] = tokenizer_params
        return preprocess_params, forward_params, postprocess_params


class EndpointHandler:
    """Inference handler that extracts 'skill' entities from free text.

    The input text is split into chunks of a few "."-separated pieces. Each
    chunk is run through the token-classification pipeline separately, and
    the resulting entity character offsets are mapped back onto the original
    input text.
    """

    def __init__(self, path="") -> None:
        """Load the ONNX model and tokenizer from ``path`` and build the pipeline.

        Args:
            path: Location passed to ``from_pretrained`` for both the model
                and the tokenizer.
        """
        model = ORTModelForTokenClassification.from_pretrained(path)
        tokenizer = AutoTokenizer.from_pretrained(path)
        # With aggregation enabled, each result is an entity-span dict;
        # __call__ reads its 'word', 'score', 'start' and 'end' keys.
        self.pipeline = MyTokenClassificationPipeline(model=model,
                                                      framework='pt',
                                                      task='ner',
                                                      tokenizer=tokenizer,
                                                      aggregation_strategy='simple')

    def combine_sentences(self, text, context_len=2):
        """Split ``text`` on "." and regroup the pieces into chunks.

        Each chunk re-joins ``context_len`` consecutive pieces with ".". The
        "." that separated two consecutive chunks is dropped, so each chunk
        after the first starts one character past the end of the previous
        chunk in ``text``.

        Args:
            text: Input string.
            context_len: Number of "."-separated pieces per chunk; must be >= 1.

        Returns:
            A list of chunk strings. It is always a list, even when ``text``
            contains no "." (then it is ``[text]``).

        Raises:
            ValueError: If ``context_len`` is less than 1.
        """
        if context_len < 1:
            raise ValueError(f"context_len must be >= 1, got {context_len}")
        sentences = text.split(".")
        # A text without "." yields a single piece and hence a one-element
        # list. Returning the bare string here would make callers iterate it
        # character by character.
        return [".".join(sentences[i:i + context_len])
                for i in range(0, len(sentences), context_len)]

    def __call__(self, data: Any) -> List[Dict[str, Any]]:
        """Run entity extraction on the request payload.

        Args:
            data: Request payload. If it is a dict, its "inputs" entry is
                popped and used (the dict itself is used if that key is
                absent). A non-dict payload such as a plain string is used
                directly.

        Returns:
            A flat list of entity dicts, one per detected span, with keys
            'word', 'score' (float), 'class' (always 'skill'), and
            'start'/'end' character offsets into the original input text.
        """
        inputs = data.pop("inputs", data) if isinstance(data, dict) else data

        results = []
        offset = 0  # start position of the current chunk within `inputs`
        for chunk in self.combine_sentences(inputs, context_len=4):
            # Comma -> space is length-preserving, so the pipeline's offsets
            # are still valid positions within `chunk`.
            text = chunk.replace(",", " ")
            # Blank chunks (e.g. after a trailing ".") carry no entities; skip
            # the model call but still advance the offset below.
            if text.strip():
                results.extend(
                    {
                        'word': ent['word'],
                        'score': float(ent['score']),
                        'class': 'skill',
                        'start': offset + ent['start'],
                        'end': offset + ent['end'],
                    }
                    for ent in self.pipeline(text)
                )
            # +1 accounts for the "." dropped between consecutive chunks.
            offset += len(chunk) + 1
        return results