File size: 3,268 Bytes
daa2f46
2c5bd8b
 
daa2f46
 
2c5bd8b
daa2f46
 
ff0b14b
1c4ad1e
ff0b14b
05f6b94
 
 
ff0b14b
 
 
05f6b94
 
2c5bd8b
 
 
286854a
 
 
 
2c5bd8b
05f6b94
 
 
 
d6df1bf
794d04c
 
4f31466
794d04c
2c5bd8b
286854a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c5bd8b
286854a
 
 
 
 
1c87edc
286854a
 
 
 
 
 
 
 
 
4f31466
286854a
4f31466
 
286854a
4f31466
286854a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from typing import Dict, List, Any
from dataclasses import dataclass
import torch
from transformers import AutoTokenizer
from transformers import pipeline
from transformers.pipelines import PIPELINE_REGISTRY
from bibert_multitask_classification import BiBert_MultiTaskPipeline
from bert_for_sequence_classification import BertForSequenceClassification
from transformers.utils import logging
from time import perf_counter


# Register the custom pipeline class under the task name that
# `pipeline("bibert-multitask-classification", ...)` is called with below.
PIPELINE_REGISTRY.register_pipeline(
    "bibert-multitask-classification",
    pipeline_class=BiBert_MultiTaskPipeline,
    pt_model=BertForSequenceClassification,
)

# Route messages through the transformers logger at INFO verbosity.
logging.set_verbosity_info()
logger = logging.get_logger("transformers")

# Use the GPU when one is visible to torch, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


@dataclass
class Task:
    """Describe one task of the multitask model.

    Instances are collected in the module-level ``tasks`` list, which is
    passed to ``BertForSequenceClassification.from_pretrained`` as
    ``tasks_map``.
    """

    # Numeric task identifier (the entries in ``tasks`` use 0 and 1).
    id: int
    # Task name, e.g. 'label_classification'.
    name: str
    # Task kind; every entry in ``tasks`` uses 'seq_classification'.
    type: str
    # Number of labels for this task (5 and 2 in ``tasks``).
    num_labels: int

# Task descriptions handed to the model loader as `tasks_map`.
tasks = [
    Task(
        id=0,
        name='label_classification',
        type='seq_classification',
        num_labels=5,
    ),
    Task(
        id=1,
        name='binary_classification',
        type='seq_classification',
        num_labels=2,
    ),
]


# Collapse the five class ids (as strings) into three sentiment names.
idtolabel = {
    "0": "Negative",
    "1": "Negative",
    "2": "Neutral",
    "3": "Positive",
    "4": "Positive",
}

# Numeric counterpart of `idtolabel`: -1 negative, 0 neutral, 1 positive.
idtoscore = {
    "0": -1,
    "1": -1,
    "2": 0,
    "3": 1,
    "4": 1,
}

class EndpointHandler:
    """Request handler that combines both task pipelines of the multitask model.

    Two ``bibert-multitask-classification`` pipelines are built over one shared
    model and tokenizer: ``classifier_s`` (``task_id="0"``) and
    ``classifier_p`` (``task_id="1"``).  For every input, the ``task_id="1"``
    prediction acts as a gate: when it is ``NEUTRAL_GATE_LABEL`` with a
    probability of at least ``NEUTRAL_GATE_MIN_PROB`` the result is forced to
    ``Neutral``; otherwise the ``task_id="0"`` label is translated through
    ``idtolabel`` / ``idtoscore``.
    """

    # Label of the task_id="1" pipeline that can force a "Neutral" result.
    NEUTRAL_GATE_LABEL = "0"
    # Minimum probability of that label for the gate to apply.
    NEUTRAL_GATE_MIN_PROB = 0.75

    def __init__(self, path=""):
        """Load the tokenizer and model from ``path`` and build both pipelines.

        Args:
            path: Location passed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        logger.info("The device is %s.", device)

        t0 = perf_counter()

        tokenizer = AutoTokenizer.from_pretrained(path)
        model = BertForSequenceClassification.from_pretrained(path, tasks_map=tasks).to(device)
        # Both pipelines share the same model and tokenizer; only task_id differs.
        self.classifier_s = pipeline(
            "bibert-multitask-classification",
            model=model,
            task_id="0",
            tokenizer=tokenizer,
            device=device,
        )
        self.classifier_p = pipeline(
            "bibert-multitask-classification",
            model=model,
            task_id="1",
            tokenizer=tokenizer,
            device=device,
        )
        elapsed = 1000 * (perf_counter() - t0)
        logger.info("Models and tokenizer Polarity loaded in %d ms.", elapsed)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Classify the request inputs and return one result per input.

        Args:
            data: Request payload.  The value under ``"inputs"`` is removed
                from ``data`` and classified; when the key is absent the
                payload itself is used.  A single string is treated as a
                batch of one.  Presumably a list of strings otherwise --
                confirm against the pipeline implementation.

        Returns:
            A list of ``{"label": ..., "score": ...}`` dicts in input order.
            ``label`` / ``score`` come from ``idtolabel`` / ``idtoscore`` and
            are ``None`` when the predicted label is missing from those
            mappings.  An empty list or tuple of inputs yields ``[]`` without
            invoking the model.
        """
        inputs = data.pop("inputs", data)
        if isinstance(inputs, str):
            inputs = [inputs]
        if isinstance(inputs, (list, tuple)) and not inputs:
            # Nothing to classify: skip both model invocations.
            return []

        t0 = perf_counter()
        classifier_pol = self.classifier_p(inputs)
        classifier_subj = self.classifier_s(inputs)
        logger.info("Prediction polarity %s", classifier_pol)
        logger.info("Prediction subjective %s", classifier_subj)

        prediction_res = []
        for idx, _ in enumerate(inputs):
            gate = classifier_pol[idx]
            if (
                gate['label'] == self.NEUTRAL_GATE_LABEL
                and gate['probability'] >= self.NEUTRAL_GATE_MIN_PROB
            ):
                # Confident gate prediction overrides the five-class result.
                prediction_res.append({"label": "Neutral", "score": 0})
                continue

            fine_label = classifier_subj[idx]['label']
            prediction_res.append(
                {"label": idtolabel.get(fine_label), "score": idtoscore.get(fine_label)}
            )

        elapsed = 1000 * (perf_counter() - t0)
        logger.info("Model prediction time: %d ms.", elapsed)
        return prediction_res