File size: 3,140 Bytes
daa2f46 2c5bd8b daa2f46 2c5bd8b daa2f46 ff0b14b 1c4ad1e ff0b14b 05f6b94 ff0b14b 05f6b94 2c5bd8b 286854a 2c5bd8b 05f6b94 d6df1bf 794d04c 2c5bd8b 286854a 2c5bd8b 286854a 1c87edc 286854a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 | from typing import Dict, List, Any
from dataclasses import dataclass
import torch
from transformers import AutoTokenizer
from transformers import pipeline
from transformers.pipelines import PIPELINE_REGISTRY
from bibert_multitask_classification import BiBert_MultiTaskPipeline
from bert_for_sequence_classification import BertForSequenceClassification
from transformers.utils import logging
from time import perf_counter
# Register the custom multi-task pipeline under a task name so that the
# ``pipeline("bibert-multitask-classification", ...)`` calls in the handler
# can resolve it.
PIPELINE_REGISTRY.register_pipeline(
    "bibert-multitask-classification",
    pipeline_class=BiBert_MultiTaskPipeline,
    pt_model=BertForSequenceClassification,
)

# Raise transformers' logging verbosity to INFO; the handler reports device
# and timing information through ``logger.info`` on this "transformers" logger.
logging.set_verbosity_info()
logger = logging.get_logger("transformers")

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
@dataclass
class Task:
    """Description of one classification task of the multi-task model.

    A list of these is handed to the model's ``from_pretrained`` as
    ``tasks_map``.
    """

    # Numeric identifier of the task.
    id: int
    # Human-readable task name.
    name: str
    # Kind of task head, e.g. 'seq_classification'.
    type: str
    # Number of output labels for this task.
    num_labels: int
# Both heads of the model are sequence classifiers; only the id, name and
# label count differ between them.
tasks = [
    Task(id=task_id, name=task_name, type='seq_classification', num_labels=label_count)
    for task_id, task_name, label_count in (
        (0, 'label_classification', 5),
        (1, 'binary_classification', 2),
    )
]
# Map the five string label ids "0".."4" (presumably emitted by the 5-label
# task) onto three sentiment names.
idtolabel = {
    str(label_id): sentiment
    for label_id, sentiment in enumerate(
        ("Negative", "Negative", "Neutral", "Positive", "Positive")
    )
}
class EndpointHandler:
    """Inference handler that serves one multi-task model through two pipelines.

    The same model instance backs ``classifier_s`` (``task_id="0"``) and
    ``classifier_p`` (``task_id="1"``). Calling the handler merges both
    pipelines' outputs into a single ``{"label": ...}`` result per input text.
    """

    def __init__(self, path=""):
        """Load the tokenizer and model from ``path`` and build both pipelines.

        Args:
            path: Location passed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        logger.info("The device is %s.", device)
        started = perf_counter()

        tok = AutoTokenizer.from_pretrained(path)
        # ``tasks_map=tasks`` hands the model the Task descriptions
        # (presumably one head per task -- confirm in
        # bert_for_sequence_classification).
        shared_model = BertForSequenceClassification.from_pretrained(
            path, tasks_map=tasks
        ).to(device)

        pipeline_kwargs = {"model": shared_model, "tokenizer": tok, "device": device}
        # NOTE(review): task_id presumably selects the Task with that id:
        # "0" -> 5 labels, "1" -> 2 labels according to ``tasks``. Verify in
        # BiBert_MultiTaskPipeline.
        self.classifier_s = pipeline(
            "bibert-multitask-classification", task_id="0", **pipeline_kwargs
        )
        self.classifier_p = pipeline(
            "bibert-multitask-classification", task_id="1", **pipeline_kwargs
        )

        load_ms = (perf_counter() - started) * 1000
        logger.info("Models and tokenizer Polarity loaded in %d ms.", load_ms)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Classify a single text or a batch of texts.

        Args:
            data: Request payload. ``data["inputs"]`` is popped, which removes
                it from the caller's dict. If that key is missing, the whole
                dict is used as the inputs. A bare string is wrapped into a
                one-element list.

        Returns:
            One ``{"label": ...}`` dict per input, in input order. The label
            is "Neutral" when ``classifier_p`` predicts label '0' with a score
            of at least 0.75. Otherwise it is the ``classifier_s`` label
            translated through ``idtolabel``, or ``None`` if that label is not
            in the table.
        """
        inputs = data.pop("inputs", data)
        if isinstance(inputs, str):
            inputs = [inputs]

        started = perf_counter()
        # NOTE(review): ``classifier_p`` (task "1") acts as the gate and
        # ``classifier_s`` (task "0") supplies the sentiment. The
        # "polarity"/"subjective" wording in the log lines may be swapped
        # relative to what each head predicts -- confirm before renaming.
        gate_preds = self.classifier_p(inputs)
        sentiment_preds = self.classifier_s(inputs)
        logger.info("Prediction polarity %s", gate_preds)
        logger.info("Prediction subjective %s", sentiment_preds)

        predictions = []
        for position, _ in enumerate(inputs):
            gate = gate_preds[position]
            gate_label = gate['label']
            gate_score = gate['score']
            # A confident '0' from the gating head overrides the sentiment head.
            if gate_label == '0' and gate_score >= 0.75:
                verdict = "Neutral"
            else:
                verdict = idtolabel.get(sentiment_preds[position]['label'])
            predictions.append({"label": verdict})

        elapsed_ms = (perf_counter() - started) * 1000
        logger.info("Model prediction time: %d ms.", elapsed_ms)
        return predictions
|