File size: 628 Bytes
399941f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from functools import lru_cache
from typing import Optional

from transformers import pipeline, Pipeline


@lru_cache
def init_model(
    task: str,
    model: Optional[str] = None,
    aggregation_strategy: Optional[str] = None,
) -> Pipeline:
    """Build a ``transformers`` pipeline, memoized per argument combination.

    Args:
        task: Task name forwarded as the first positional argument of
            ``transformers.pipeline``.
        model: Model identifier forwarded as ``model=``; ``None`` is passed
            through unchanged, deferring to ``transformers.pipeline``'s own
            handling of a missing model.
        aggregation_strategy: Forwarded unchanged as ``aggregation_strategy=``.

    Returns:
        The ``Pipeline`` object built by ``transformers.pipeline``. Repeated
        calls with the same arguments return the same cached object.
    """
    # Bare ``@lru_cache`` uses the default maxsize of 128, so up to 128
    # pipelines (and everything they reference) stay alive until evicted.
    # Note: lru_cache keys on how arguments are spelled, so
    # init_model("ner", "m") and init_model("ner", model="m") may occupy
    # separate cache entries and build separate pipelines.
    return pipeline(task, model=model, aggregation_strategy=aggregation_strategy)


def custom_predict(text: str, pipe: "Pipeline") -> dict:
    """Run ``pipe`` over ``text`` and repackage its entity spans as a plain dict.

    ``pipe`` is always invoked as ``pipe(text, aggregation_strategy="simple")``.
    That call-time value takes precedence over whatever strategy the pipeline
    was built with (e.g. via ``init_model``). The code below relies on every
    result item exposing ``start``, ``end`` and ``entity_group`` keys.

    Args:
        text: Input text to tag.
        pipe: A callable pipeline object. It is called with the text, never
            treated as a string itself.

    Returns:
        ``{"text": text, "ents": [...], "title": None}``. Each entry in
        ``ents`` is ``{"start": ..., "end": ..., "label": ...}``, copied from the
        pipeline item's ``start``, ``end`` and ``entity_group`` fields in
        result order. NOTE(review): this layout looks intended for spaCy
        displaCy's manual entity rendering; confirm with the caller.
    """
    spans = pipe(text, aggregation_strategy="simple")
    ents = [
        {"start": span["start"], "end": span["end"], "label": span["entity_group"]}
        for span in spans
    ]
    return {"text": text, "ents": ents, "title": None}