# BiBert-MultiTask-2 / handler.py
# Kalaoke's picture
# add score to output
# 4f31466
from typing import Dict, List, Any
from dataclasses import dataclass
import torch
from transformers import AutoTokenizer
from transformers import pipeline
from transformers.pipelines import PIPELINE_REGISTRY
from bibert_multitask_classification import BiBert_MultiTaskPipeline
from bert_for_sequence_classification import BertForSequenceClassification
from transformers.utils import logging
from time import perf_counter
# Register the custom pipeline class under its own task name so that
# `pipeline("bibert-multitask-classification", ...)` can resolve it.
PIPELINE_REGISTRY.register_pipeline(
    "bibert-multitask-classification",
    pipeline_class=BiBert_MultiTaskPipeline,
    pt_model=BertForSequenceClassification,
)

# Raise the transformers logging verbosity to INFO and reuse its logger.
logging.set_verbosity_info()
logger = logging.get_logger("transformers")

# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
@dataclass
class Task:
    """Description of one classification task of the multi-task model.

    Instances are collected in the module-level ``tasks`` list, which is
    handed to the model loader as ``tasks_map``.

    NOTE: ``id`` and ``type`` shadow Python builtins inside the class body;
    the names are kept because they are part of the keyword interface used
    to construct tasks.
    """

    # Numeric identifier of the task.
    id: int
    # Human-readable task name.
    name: str
    # Task category string, e.g. 'seq_classification'.
    type: str
    # Number of labels for this task (presumably the size of the output
    # layer of the corresponding head -- confirm against the model code).
    num_labels: int
# Task definitions passed to the model loader as `tasks_map`:
# a 5-label classification task (id 0) and a 2-label one (id 1).
tasks = [
    Task(
        id=0,
        name='label_classification',
        type='seq_classification',
        num_labels=5,
    ),
    Task(
        id=1,
        name='binary_classification',
        type='seq_classification',
        num_labels=2,
    ),
]
# Maps the five class ids (as strings) onto three sentiment names.
idtolabel = {
    "0": "Negative",
    "1": "Negative",
    "2": "Neutral",
    "3": "Positive",
    "4": "Positive",
}

# Numeric counterpart of `idtolabel`: -1 for negative, 0 for neutral,
# 1 for positive.
idtoscore = {
    "0": -1,
    "1": -1,
    "2": 0,
    "3": 1,
    "4": 1,
}
class EndpointHandler():
    """Custom inference handler that merges two heads of one multi-task model.

    Two pipelines are built on the same model and tokenizer: one for task
    id "0" (``self.classifier_s``) and one for task id "1"
    (``self.classifier_p``). For every input, a confident label '0' from the
    task-"1" pipeline forces a neutral result; otherwise the task-"0" label
    is translated through ``idtolabel`` / ``idtoscore``.
    """

    # Label of the task-"1" prediction that overrides the task-"0" result,
    # and the minimum probability required for that override to apply.
    NEUTRAL_OVERRIDE_LABEL = '0'
    NEUTRAL_OVERRIDE_MIN_PROB = 0.75

    def __init__(self, path=""):
        """Load the tokenizer and model from ``path`` and build both pipelines.

        Args:
            path: Directory or identifier forwarded to ``from_pretrained``.
        """
        # Preload all the elements you are going to need at inference.
        logger.info("The device is %s.", device)
        t0 = perf_counter()
        tokenizer = AutoTokenizer.from_pretrained(path)
        model = BertForSequenceClassification.from_pretrained(path, tasks_map=tasks).to(device)
        # Both pipelines share one model instance; only `task_id` differs.
        self.classifier_s = pipeline(
            "bibert-multitask-classification",
            model=model,
            task_id="0",
            tokenizer=tokenizer,
            device=device,
        )
        self.classifier_p = pipeline(
            "bibert-multitask-classification",
            model=model,
            task_id="1",
            tokenizer=tokenizer,
            device=device,
        )
        elapsed = 1000 * (perf_counter() - t0)
        logger.info("Models and tokenizer Polarity loaded in %d ms.", elapsed)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Classify the texts found in ``data`` and return one result per text.

        Args:
            data: Request payload. The value under ``"inputs"`` is removed
                from the dict and used as input; a single string is wrapped
                into a one-element list. When the key is missing, ``data``
                itself is passed on to the pipelines.

        Returns:
            A list of ``{"label": ..., "score": ...}`` dicts, in pipeline
            output order. ``label``/``score`` are ``None`` when the
            task-"0" label is not a key of ``idtolabel`` / ``idtoscore``.
        """
        inputs = data.pop("inputs", data)
        if isinstance(inputs, str):
            inputs = [inputs]
        # Nothing to classify: skip both model calls for an empty batch.
        if isinstance(inputs, list) and not inputs:
            return []

        t0 = perf_counter()
        classifier_pol = self.classifier_p(inputs)
        classifier_subj = self.classifier_s(inputs)
        logger.info("Prediction polarity %s", classifier_pol)
        logger.info("Prediction subjective %s", classifier_subj)

        # Pair the two predictions for each input instead of indexing.
        # NOTE(review): assumes both pipelines return one dict per input,
        # in input order -- confirm against BiBert_MultiTaskPipeline.
        prediction_res = [
            self._merge_predictions(pol, subj)
            for pol, subj in zip(classifier_pol, classifier_subj)
        ]

        elapsed = 1000 * (perf_counter() - t0)
        logger.info("Model prediction time: %d ms.", elapsed)
        return prediction_res

    @classmethod
    def _merge_predictions(cls, pol_pred, subj_pred):
        """Combine the task-"1" and task-"0" predictions for one input."""
        overridden = (
            pol_pred['label'] == cls.NEUTRAL_OVERRIDE_LABEL
            and pol_pred['probability'] >= cls.NEUTRAL_OVERRIDE_MIN_PROB
        )
        if overridden:
            return {"label": "Neutral", "score": 0}
        label_id = subj_pred['label']
        return {"label": idtolabel.get(label_id), "score": idtoscore.get(label_id)}