"""Custom inference handler combining the two heads of a multitask BERT model."""
from typing import Dict, List, Any
from dataclasses import dataclass
from time import perf_counter

import torch
from transformers import AutoTokenizer
from transformers import pipeline
from transformers.pipelines import PIPELINE_REGISTRY
from transformers.utils import logging

from bibert_multitask_classification import BiBert_MultiTaskPipeline
from bert_for_sequence_classification import BertForSequenceClassification

# Register the project's multitask pipeline so it can be built with
# transformers.pipeline("bibert-multitask-classification", ...).
PIPELINE_REGISTRY.register_pipeline(
    "bibert-multitask-classification",
    pipeline_class=BiBert_MultiTaskPipeline,
    pt_model=BertForSequenceClassification,
)

logging.set_verbosity_info()
logger = logging.get_logger("transformers")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@dataclass
class Task:
    """One task head of the multitask model; a list of these is passed as ``tasks_map``."""

    id: int
    name: str
    type: str
    num_labels: int


tasks = [
    Task(id=0, name='label_classification', type='seq_classification', num_labels=5),
    Task(id=1, name='binary_classification', type='seq_classification', num_labels=2),
]

# Map the 5-label head's predicted label id (a string) to a coarse label and a
# numeric score. Label ids 0/1 collapse to Negative and 3/4 collapse to Positive.
idtolabel = {"0": "Negative", "1": "Negative", "2": "Neutral", "3": "Positive", "4": "Positive"}
idtoscore = {"0": -1, "1": -1, "2": 0, "3": 1, "4": 1}


class EndpointHandler():
    """Inference endpoint that runs both task heads and merges their outputs.

    For each input text, the binary head (task_id "1") acts as a gate. When it
    predicts label '0' with probability >= ``neutral_threshold``, the text is
    reported as Neutral. Otherwise the 5-label head's (task_id "0") prediction
    is mapped through ``idtolabel`` / ``idtoscore``.
    """

    # Minimum probability for the binary head's label '0' to force a Neutral result.
    neutral_threshold: float = 0.75

    def __init__(self, path=""):
        """Load the tokenizer and model once, then build one pipeline per task head.

        Args:
            path: Model directory or identifier passed to ``from_pretrained``.
        """
        # Preload all the elements you are going to need at inference.
        logger.info("The device is %s.", device)
        t0 = perf_counter()
        tokenizer = AutoTokenizer.from_pretrained(path)
        model = BertForSequenceClassification.from_pretrained(path, tasks_map=tasks).to(device)
        # Both pipelines share one model and tokenizer. task_id presumably
        # selects which head's output the pipeline returns (see Task.id) — confirm
        # against BiBert_MultiTaskPipeline.
        self.classifier_s = pipeline(
            "bibert-multitask-classification",
            model=model,
            task_id="0",
            tokenizer=tokenizer,
            device=device,
        )
        self.classifier_p = pipeline(
            "bibert-multitask-classification",
            model=model,
            task_id="1",
            tokenizer=tokenizer,
            device=device,
        )
        elapsed = 1000 * (perf_counter() - t0)
        logger.info("Models and tokenizer Polarity loaded in %d ms.", elapsed)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Classify one or more texts.

        Args:
            data: Request payload. Its ``"inputs"`` entry (a ``str`` or a list of
                ``str``) is popped and classified. If the key is absent, the whole
                payload is passed through as the inputs.

        Returns:
            One ``{"label": ..., "score": ...}`` dict per input, in input order.
            Gated inputs get ``{"label": "Neutral", "score": 0}``. Other inputs
            get the 5-label head's id mapped via ``idtolabel`` / ``idtoscore``;
            both values are ``None`` if that id is not in the maps. An empty
            batch returns ``[]``.

        Raises:
            ValueError: If either pipeline does not return exactly one result
                per input, so results cannot be aligned with inputs.
        """
        inputs = data.pop("inputs", data)
        if isinstance(inputs, str):
            inputs = [inputs]
        if not inputs:
            # Nothing to classify; skip running the model on an empty batch.
            return []

        t0 = perf_counter()
        # NOTE(review): despite the "p"/"polarity" naming, classifier_p runs the
        # binary head (task_id "1") that gates Neutral. classifier_s runs the
        # 5-label head whose id is mapped to Negative/Neutral/Positive. The names
        # may be swapped; confirm before renaming (attributes are public).
        classifier_pol = self.classifier_p(inputs)
        classifier_subj = self.classifier_s(inputs)
        logger.info("Prediction polarity %s", classifier_pol)
        logger.info("Prediction subjective %s", classifier_subj)

        n_inputs = len(inputs)
        if len(classifier_pol) != n_inputs or len(classifier_subj) != n_inputs:
            raise ValueError(
                f"Expected {n_inputs} predictions per task, got "
                f"{len(classifier_pol)} (task_id 1) and {len(classifier_subj)} (task_id 0)."
            )

        prediction_res = []
        # Each pipeline result is read as a dict with 'label' (compared to the
        # string '0') and 'probability' keys.
        for pol, subj in zip(classifier_pol, classifier_subj):
            if pol['label'] == '0' and pol['probability'] >= self.neutral_threshold:
                prediction_res.append({"label": "Neutral", "score": 0})
            else:
                subj_label = subj['label']
                prediction_res.append(
                    {"label": idtolabel.get(subj_label), "score": idtoscore.get(subj_label)}
                )

        elapsed = 1000 * (perf_counter() - t0)
        logger.info("Model prediction time: %d ms.", elapsed)
        return prediction_res