File size: 2,998 Bytes
daa2f46 2c5bd8b daa2f46 2c5bd8b daa2f46 ff0b14b 1c4ad1e ff0b14b 05f6b94 ff0b14b 05f6b94 2c5bd8b 05f6b94 d6df1bf 794d04c 2c5bd8b 1c4ad1e 2c5bd8b 1c4ad1e 2c5bd8b 1c4ad1e 2c5bd8b c268aa4 58c3574 1c4ad1e c268aa4 2c5bd8b 1c4ad1e 1c87edc 1c4ad1e 2c5bd8b 1c87edc 794d04c 2c5bd8b 1c87edc 794d04c 1c4ad1e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 | from typing import Dict, List, Any
from dataclasses import dataclass
import torch
from transformers import AutoTokenizer
from transformers import pipeline
from transformers.pipelines import PIPELINE_REGISTRY
from bibert_multitask_classification import BiBert_MultiTaskPipeline
from bert_for_sequence_classification import BertForSequenceClassification
from transformers.utils import logging
from time import perf_counter
# Register the custom multi-task pipeline under its task name so that
# `pipeline("bibert-multitask-classification", ...)` can build it later on.
PIPELINE_REGISTRY.register_pipeline(
    "bibert-multitask-classification",
    pipeline_class=BiBert_MultiTaskPipeline,
    pt_model=BertForSequenceClassification,
)

# Set transformers' logging verbosity to INFO and log through the
# "transformers" logger so this module's messages follow that setting.
logging.set_verbosity_info()
logger = logging.get_logger("transformers")

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
@dataclass
class Task:
    """Description of one task (head) of the multi-task classifier.

    A list of these is handed to the model loader as ``tasks_map``.
    """

    id: int  # numeric task identifier (0 and 1 are used in this file)
    name: str  # human-readable task name
    type: str  # task kind; every task here is "seq_classification"
    num_labels: int  # number of classes for this task
# Tasks of the multi-task model, passed as ``tasks_map`` when it is loaded.
# Positional order: id, name, type, num_labels.
tasks = [
    Task(0, "label_classification", "seq_classification", 5),  # 5 classes
    Task(1, "binary_classification", "seq_classification", 2),  # 2 classes
]
# Collapse the class ids returned by the task "0" (5-way) pipeline into three
# coarse names: ids 0-1 -> Negative, 2 -> Neutral, 3-4 -> Positive.
idtolabel = {
    str(class_id): sentiment
    for class_id, sentiment in enumerate(
        ("Negative", "Negative", "Neutral", "Positive", "Positive")
    )
}
class EndpointHandler:
    """Inference handler that maps input text to a coarse sentiment label.

    Two pipelines wrap the same model. ``classifier_p`` (task_id "1") runs
    first. Inputs it labels "0" with a score of at least ``neutral_threshold``
    are reported as "Neutral". Every other input is sent to ``classifier_s``
    (task_id "0"), whose label id is translated through ``idtolabel``.
    """

    # Class-level default so an instance always exposes a threshold; the
    # constructor overrides it per instance.
    neutral_threshold = 0.75

    def __init__(self, path="", neutral_threshold=0.75):
        """Load the tokenizer and model from *path* and build both pipelines.

        Args:
            path: Location passed to ``from_pretrained`` for both the
                tokenizer and the model.
            neutral_threshold: Minimum score the task "1" pipeline must give
                label "0" for an input to be returned as "Neutral" without
                running the task "0" pipeline. Defaults to the previously
                hard-coded 0.75.
        """
        logger.info("The device is %s.", device)
        t0 = perf_counter()
        self.neutral_threshold = neutral_threshold
        tokenizer = AutoTokenizer.from_pretrained(path)
        model = BertForSequenceClassification.from_pretrained(
            path, tasks_map=tasks
        ).to(device)
        # Both pipelines share the same model object; task_id presumably
        # selects which task head the pipeline reads.
        self.classifier_s = pipeline(
            "bibert-multitask-classification",
            model=model,
            task_id="0",
            tokenizer=tokenizer,
            device=device,
        )
        self.classifier_p = pipeline(
            "bibert-multitask-classification",
            model=model,
            task_id="1",
            tokenizer=tokenizer,
            device=device,
        )
        elapsed = 1000 * (perf_counter() - t0)
        logger.info("Models and tokenizer Polarity loaded in %d ms.", elapsed)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Predict one sentiment label per input text.

        Args:
            data: Request payload. ``"inputs"`` holds a string or a list of
                strings; when it is absent, the remaining payload itself is
                used as the input. ``"lang"`` is optional and only logged.
                Both keys are popped, so *data* is modified in place.

        Returns:
            A list with one ``{"label": <name>}`` dict per input text, in
            input order, or ``[]`` when there is nothing to classify.
        """
        inputs = data.pop("inputs", data)
        lang = data.pop("lang", None)
        logger.info("The language of Verbatim is %s.", lang)
        if isinstance(inputs, str):
            inputs = [inputs]
        if not inputs:
            # Nothing to classify; indexing an empty pipeline result would fail.
            return []

        t0 = perf_counter()
        polarity = self.classifier_p(inputs)
        logger.info("Prediction polarity %s", polarity)

        # Default every input to "Neutral"; overwrite those that the task "1"
        # pipeline did not confidently mark as label "0".
        labels = ["Neutral"] * len(polarity)
        undecided = [
            idx
            for idx, pred in enumerate(polarity)
            if not (pred["label"] == "0" and pred["score"] >= self.neutral_threshold)
        ]
        if undecided:
            # Run the second pipeline only on the inputs that still need it.
            subjective = self.classifier_s([inputs[idx] for idx in undecided])
            logger.info("Prediction subjective %s", subjective)
            for idx, pred in zip(undecided, subjective):
                # Exact lookup: the former chained str.replace could rewrite
                # digits inside longer label strings. Unknown labels pass through.
                labels[idx] = idtolabel.get(pred["label"], pred["label"])

        elapsed = 1000 * (perf_counter() - t0)
        logger.info("Model prediction time: %d ms.", elapsed)
        return [{"label": label} for label in labels]
|