##########################################
# This file will be copy pasted in the HuggingFace
# Model repo for doing inference.
##########################################
from transformers.pipelines import TokenClassificationPipeline, AggregationStrategy
from typing import Any, Union, List, Optional, Tuple, Dict
from optimum.onnxruntime import ORTModelForTokenClassification
from transformers import AutoTokenizer


class MyTokenClassificationPipeline(TokenClassificationPipeline):
    """Token-classification pipeline that suppresses ``token_type_ids``.

    Some ONNX-exported models do not declare ``token_type_ids`` as an
    input, so the tokenizer is told not to produce it during preprocessing.
    """

    def _sanitize_parameters(
        self,
        ignore_labels=None,
        grouped_entities: Optional[bool] = None,
        ignore_subwords: Optional[bool] = None,
        aggregation_strategy: Optional[AggregationStrategy] = None,
        offset_mapping: Optional[List[Tuple[int, int]]] = None,
        stride: Optional[int] = None,
    ):
        """Delegate to the parent implementation, then force the tokenizer
        not to return ``token_type_ids``.

        Returns the usual ``(preprocess_params, other, postprocess_params)``
        triple expected by the ``Pipeline`` machinery.
        """
        preprocess_params, other, postprocess_params = super()._sanitize_parameters(
            ignore_labels,
            grouped_entities,
            ignore_subwords,
            aggregation_strategy,
            offset_mapping,
            stride,
        )
        preprocess_params['tokenizer_params'] = {'return_token_type_ids': False}
        return preprocess_params, other, postprocess_params


class EndpointHandler():
    """HuggingFace Inference Endpoints handler for an ONNX NER model.

    Loads the model/tokenizer from ``path`` and exposes ``__call__`` as the
    inference entry point.
    """

    def __init__(self, path="") -> None:
        """Load the ONNX token-classification model and its tokenizer from
        *path* and build the custom pipeline (simple aggregation)."""
        model = ORTModelForTokenClassification.from_pretrained(path)
        tokenizer = AutoTokenizer.from_pretrained(path)
        self.pipeline = MyTokenClassificationPipeline(
            model=model,
            framework='pt',
            task='ner',
            tokenizer=tokenizer,
            aggregation_strategy='simple',
        )

    def combine_sentences(self, text, context_len=2):
        """Split *text* on ``'.'`` and re-join the pieces into chunks of
        *context_len* sentences each.

        Always returns a list of chunk strings.

        BUGFIX: the original returned the bare string when *text* contained
        no ``'.'``; ``__call__`` then iterated over it character by
        character, running the pipeline on single characters with wrong
        offsets. A one-element list is returned instead.
        """
        sentences = text.split(".")
        if len(sentences) == 1:
            # No '.' at all: the whole input is a single chunk.
            return [text]
        combined = []
        for i in range(0, len(sentences), context_len):
            combined.append(".".join(sentences[i:i + context_len]))
        return combined

    def __call__(self, data: Any) -> List[Dict[str, Any]]:
        """Run NER over ``data["inputs"]`` (or *data* itself when the key is
        absent) and return a flat list of entity dicts.

        Each dict carries ``word``, ``score`` (python float), ``class``
        (always ``'skill'``), and ``start``/``end`` character offsets mapped
        back to the full original input text.
        """
        inputs = data.pop("inputs", data)
        inner_len = 0  # offset of the current chunk within the original text
        final_list = []
        final_sents = self.combine_sentences(inputs, context_len=4)
        for m in final_sents:
            # Same-length substitution, so entity offsets stay valid.
            n = m.replace(",", " ")
            res = self.pipeline(n)
            if len(res) > 0:
                final_list.extend(
                    {
                        'word': d['word'],
                        'score': d['score'].item(),
                        'class': 'skill',
                        # Shift chunk-local offsets to whole-text offsets.
                        # (end simplified from (inner_len + start) + (end - start))
                        'start': inner_len + d['start'],
                        'end': inner_len + d['end'],
                    }
                    for d in res
                )
            # +1 accounts for the '.' removed between chunks by split(".").
            inner_len += len(m) + 1
        return final_list