# Upstream commit be0ec28 ("small bug") by shubhasz — scraped repo-page header kept as a comment.
##########################################
# This file will be copy pasted in the HuggingFace
# Model repo for doing inference.
##########################################
from transformers.pipelines import TokenClassificationPipeline, AggregationStrategy
from typing import Any, Union, List, Optional, Tuple, Dict
from optimum.onnxruntime import ORTModelForTokenClassification
from transformers import AutoTokenizer
class MyTokenClassificationPipeline(TokenClassificationPipeline):
    """Token-classification pipeline that stops the tokenizer from emitting
    ``token_type_ids``.

    ONNX-exported models are often traced without a ``token_type_ids`` input,
    and passing one at inference time makes ONNX Runtime reject the feed, so
    the tokenizer is told not to return it.
    """

    def _sanitize_parameters(
        self,
        ignore_labels=None,
        grouped_entities: Optional[bool] = None,
        ignore_subwords: Optional[bool] = None,
        aggregation_strategy: Optional[AggregationStrategy] = None,
        offset_mapping: Optional[List[Tuple[int, int]]] = None,
        stride: Optional[int] = None,
    ):
        """Delegate to the parent sanitizer, then force
        ``return_token_type_ids=False`` on the tokenizer.

        Returns the usual ``(preprocess_params, forward_params,
        postprocess_params)`` triple expected by ``Pipeline``.
        """
        # Forward by keyword so this override stays correct even if the
        # parent's positional parameter order changes between transformers
        # releases (the original passed all six positionally).
        preprocess_params, forward_params, postprocess_params = super()._sanitize_parameters(
            ignore_labels=ignore_labels,
            grouped_entities=grouped_entities,
            ignore_subwords=ignore_subwords,
            aggregation_strategy=aggregation_strategy,
            offset_mapping=offset_mapping,
            stride=stride,
        )
        # Merge rather than overwrite: preserve any tokenizer params the
        # parent may have set alongside the one key we need.
        preprocess_params.setdefault('tokenizer_params', {})['return_token_type_ids'] = False
        return preprocess_params, forward_params, postprocess_params
class EndpointHandler():
    """Custom Hugging Face Inference Endpoint handler.

    Loads an ONNX token-classification model and its tokenizer from *path*
    and exposes ``__call__`` so the endpoint can extract "skill" entities
    from incoming text.
    """

    def __init__(self, path="") -> None:
        """Build the NER pipeline from the ONNX model + tokenizer at *path*."""
        model = ORTModelForTokenClassification.from_pretrained(path)
        tokenizer = AutoTokenizer.from_pretrained(path)
        self.pipeline = MyTokenClassificationPipeline(
            model=model,
            framework='pt',
            task='ner',
            tokenizer=tokenizer,
            aggregation_strategy='simple',
        )

    def combine_sentences(self, text, context_len=2):
        """Split *text* on '.' and regroup into chunks of *context_len*
        sentences, each chunk re-joined with '.'.

        Always returns a list of chunks. (Bug fix: the single-sentence edge
        case used to return the bare string, which made ``__call__`` iterate
        over individual characters and run the pipeline once per character.)
        """
        sentences = text.split(".")
        if len(sentences) == 1:  # no '.' at all — one chunk holding the whole text
            return [text]
        combined = []
        for i in range(0, len(sentences), context_len):
            combined.append(".".join(sentences[i:i + context_len]))
        return combined

    def __call__(self, data: Any) -> List[Dict[str, Any]]:
        """Run NER over the request payload.

        Parameters
        ----------
        data : either a dict with an ``"inputs"`` key holding the text, or
            the raw text itself.

        Returns a flat list of entity dicts
        ``{'word', 'score', 'class', 'start', 'end'}`` whose offsets are
        global positions in the original input text.
        """
        # Accept both {"inputs": text} payloads and a bare string; the
        # original unconditionally called .pop(), which raises on strings.
        inputs = data.pop("inputs", data) if isinstance(data, dict) else data
        offset = 0  # running start position of the current chunk in `inputs`
        entities = []
        for chunk in self.combine_sentences(inputs, context_len=4):
            # Replace commas with spaces before inference; same length, so
            # the character offsets reported by the pipeline stay valid.
            cleaned = chunk.replace(",", " ")
            results = self.pipeline(cleaned)
            entities.extend(
                {
                    'word': d['word'],
                    'score': d['score'].item(),  # numpy scalar -> Python float
                    'class': 'skill',
                    'start': offset + d['start'],
                    'end': offset + d['end'],  # same as (offset+start)+(end-start)
                }
                for d in results
            )
            # +1 for the '.' separator dropped between chunks by combine_sentences
            offset += len(chunk) + 1
        return entities