from functools import lru_cache
from typing import Any, Dict, List, Optional

from transformers import Pipeline, pipeline


@lru_cache
def init_model(
    task: str,
    model: Optional[str] = None,
    aggregation_strategy: Optional[str] = None,
) -> Pipeline:
    """Build a ``transformers`` pipeline, memoized on its arguments.

    ``lru_cache`` keys on the exact call arguments, so repeated calls with the
    same values return the same pipeline object instead of rebuilding it.
    Note that positional and keyword spellings of the same call produce
    distinct cache entries, and each cached entry keeps its pipeline alive.

    Args:
        task: Task name passed to ``transformers.pipeline`` (e.g. an NER task).
        model: Optional model identifier; ``None`` defers to the library's
            default for ``task``.
        aggregation_strategy: Optional aggregation strategy forwarded to
            ``pipeline``; omitted entirely when ``None``.

    Returns:
        The constructed pipeline.
    """
    extra_kwargs = {}
    # Only forward aggregation_strategy when set: for token classification,
    # None is already the default, and other tasks may not accept the option.
    if aggregation_strategy is not None:
        extra_kwargs["aggregation_strategy"] = aggregation_strategy
    task_pipeline = pipeline(task, model=model, **extra_kwargs)
    return task_pipeline


def custom_predict(
    text: str,
    pipe: Pipeline,
    aggregation_strategy: str = "simple",
) -> Dict[str, Any]:
    """Run ``pipe`` on ``text`` and reshape its entities into a flat dict.

    The output layout (``text`` / ``ents`` / ``title``) matches what spaCy's
    ``displacy.render(..., style="ent", manual=True)`` accepts, which is
    presumably its intended consumer — confirm against callers.

    Args:
        text: Input text to annotate.
        pipe: A callable pipeline (e.g. one returned by ``init_model``) that
            yields dicts with ``start``, ``end`` and an entity label.
        aggregation_strategy: Passed to ``pipe`` at call time; defaults to
            ``"simple"`` as before.

    Returns:
        ``{"text": text, "ents": [...], "title": None}`` where each entry in
        ``ents`` is ``{"start": int, "end": int, "label": str}``.
    """
    result = pipe(text, aggregation_strategy=aggregation_strategy)
    ents: List[Dict[str, Any]] = [
        {
            "start": entity["start"],
            "end": entity["end"],
            # Aggregated output labels with "entity_group"; the
            # non-aggregated ("none") output uses "entity" instead.
            "label": (
                entity["entity_group"]
                if "entity_group" in entity
                else entity["entity"]
            ),
        }
        for entity in result
    ]
    return {"text": text, "ents": ents, "title": None}