| | |
| | |
| | |
| | |
| |
|
| | from transformers.pipelines import TokenClassificationPipeline, AggregationStrategy |
| | from typing import Any, Union, List, Optional, Tuple, Dict |
| | from optimum.onnxruntime import ORTModelForTokenClassification |
| | from transformers import AutoTokenizer |
| |
|
| |
|
class MyTokenClassificationPipeline(TokenClassificationPipeline):
    """Token-classification pipeline that never requests ``token_type_ids``.

    The only customisation is in ``_sanitize_parameters``: the tokenizer is
    always called with ``return_token_type_ids=False``.
    NOTE(review): presumably the exported ONNX model does not accept a
    ``token_type_ids`` input -- confirm against the model's input signature.
    """

    def _sanitize_parameters(
        self,
        ignore_labels=None,
        grouped_entities: Optional[bool] = None,
        ignore_subwords: Optional[bool] = None,
        aggregation_strategy: Optional[AggregationStrategy] = None,
        offset_mapping: Optional[List[Tuple[int, int]]] = None,
        stride: Optional[int] = None,
        **kwargs,
    ):
        """Delegate to the parent, then force ``return_token_type_ids=False``.

        Args:
            ignore_labels: Forwarded unchanged to the parent pipeline.
            grouped_entities: Forwarded unchanged to the parent pipeline.
            ignore_subwords: Forwarded unchanged to the parent pipeline.
            aggregation_strategy: Forwarded unchanged to the parent pipeline.
            offset_mapping: Forwarded unchanged to the parent pipeline.
            stride: Forwarded unchanged to the parent pipeline.
            **kwargs: Any additional options, forwarded unchanged so that
                parameters this override does not name still reach the parent.

        Returns:
            The ``(preprocess_params, forward_params, postprocess_params)``
            triple produced by the parent, with
            ``preprocess_params['tokenizer_params']['return_token_type_ids']``
            set to ``False``.
        """
        # Forward by keyword: positional forwarding silently mis-binds values
        # if the parent signature gains or reorders parameters.
        preprocess_params, other, postprocess_params = super()._sanitize_parameters(
            ignore_labels=ignore_labels,
            grouped_entities=grouped_entities,
            ignore_subwords=ignore_subwords,
            aggregation_strategy=aggregation_strategy,
            offset_mapping=offset_mapping,
            stride=stride,
            **kwargs,
        )
        # Merge rather than overwrite: the parent may already have stored
        # tokenizer options here (e.g. when ``stride`` is set), and replacing
        # the dict wholesale would discard them.
        tokenizer_params = dict(preprocess_params.get('tokenizer_params') or {})
        tokenizer_params['return_token_type_ids'] = False
        preprocess_params['tokenizer_params'] = tokenizer_params
        return preprocess_params, other, postprocess_params
| |
|
| |
|
class EndpointHandler():
    """Request handler that runs token classification over chunked text.

    The input text is split on ``"."`` into groups of sentences, each group is
    passed through ``self.pipeline``, and every detected span is reported with
    character offsets relative to the full input text.
    """

    def __init__(self, path="") -> None:
        """Load the ONNX model and tokenizer from ``path`` and build the pipeline.

        Args:
            path: Location passed to both ``from_pretrained`` calls.
        """
        model = ORTModelForTokenClassification.from_pretrained(path)
        tokenizer = AutoTokenizer.from_pretrained(path)
        self.pipeline = MyTokenClassificationPipeline(
            model=model,
            framework='pt',
            task='ner',
            tokenizer=tokenizer,
            aggregation_strategy='simple',
        )

    def combine_sentences(self, text, context_len=2):
        """Split ``text`` on ``"."`` and regroup it into multi-sentence chunks.

        Joining the returned chunks with ``"."`` reproduces ``text`` exactly,
        which is what lets callers map chunk-relative offsets back to the
        original string.

        Args:
            text: The string to split.
            context_len: Number of ``"."``-separated pieces per chunk.

        Returns:
            A list of chunk strings. Text without any ``"."`` yields a
            single-element list (never a bare string, which a caller would
            otherwise iterate character by character).

        Raises:
            ValueError: If ``context_len`` is smaller than 1.
        """
        if context_len < 1:
            raise ValueError("context_len must be a positive integer")
        sentences = text.split(".")
        return [
            ".".join(sentences[i:i + context_len])
            for i in range(0, len(sentences), context_len)
        ]

    def __call__(self, data: Any) -> List[Dict[str, Any]]:
        """Extract entities from ``data["inputs"]``.

        Args:
            data: Request payload. The text is taken from (and removed from)
                the ``"inputs"`` key; if the key is absent, ``data`` itself is
                used as the text.

        Returns:
            A flat list with one dict per detected span, holding ``word``,
            ``score`` (as a Python float), the constant ``class`` value
            ``'skill'``, and ``start``/``end`` character offsets into the
            full input text.
        """
        inputs = data.pop("inputs", data)

        offset = 0  # index in the full text at which the current chunk starts
        entities = []

        for chunk in self.combine_sentences(inputs, context_len=4):
            # Blank chunks (e.g. what follows a trailing ".") carry nothing to
            # classify, but must still advance the offset below.
            if chunk.strip():
                # Swapping "," for " " keeps the length unchanged, so the
                # pipeline's offsets still line up with the original chunk.
                for d in self.pipeline(chunk.replace(",", " ")):
                    entities.append({
                        'word': d['word'],
                        # float() accepts both NumPy scalars and plain floats.
                        'score': float(d['score']),
                        'class': 'skill',
                        'start': offset + d['start'],
                        'end': offset + d['end'],
                    })
            # +1 accounts for the "." that separated this chunk from the next.
            offset += len(chunk) + 1
        return entities