| | from typing import Dict, List, Any |
| | from dataclasses import dataclass |
| | import torch |
| | from transformers import AutoTokenizer |
| | from transformers import pipeline |
| | from transformers.pipelines import PIPELINE_REGISTRY |
| | from bibert_multitask_classification import BiBert_MultiTaskPipeline |
| | from bert_for_sequence_classification import BertForSequenceClassification |
| | from transformers.utils import logging |
| | from time import perf_counter |
| |
|
| |
|
# Make the custom multi-task pipeline available to ``transformers.pipeline``
# under the task name "bibert-multitask-classification".
PIPELINE_REGISTRY.register_pipeline(
    "bibert-multitask-classification",
    pipeline_class=BiBert_MultiTaskPipeline,
    pt_model=BertForSequenceClassification,
)

# Raise transformers' log verbosity to INFO and reuse its logger for this module.
logging.set_verbosity_info()
logger = logging.get_logger("transformers")

# Prefer the GPU whenever torch can see one.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
| |
|
| |
|
@dataclass()
class Task:
    """Plain record describing one task head of the multi-task model.

    A list of these is handed to ``BertForSequenceClassification.from_pretrained``
    through its ``tasks_map`` argument.
    """

    # Numeric task index (this file uses 0 and 1).
    id: int
    # Human-readable task name.
    name: str
    # Task kind string (every task in this file uses "seq_classification").
    type: str
    # Number of labels for the task (5 and 2 in this file).
    num_labels: int
| |
|
# Task heads of the multi-task model, listed in task-id order.
tasks = [
    Task(id=task_id, name=task_name, type="seq_classification", num_labels=label_count)
    for task_id, (task_name, label_count) in enumerate(
        (
            ("label_classification", 5),
            ("binary_classification", 2),
        )
    )
]
| |
|
| |
|
# Maps the string label ids "0".."4" produced by the task-0 pipeline to coarse
# Negative/Neutral/Positive names: ids 0-1 -> Negative, 2 -> Neutral, 3-4 -> Positive.
idtolabel = {
    str(label_id): coarse_name
    for label_id, coarse_name in enumerate(
        ("Negative", "Negative", "Neutral", "Positive", "Positive")
    )
}
| |
|
class EndpointHandler:
    """Inference handler that combines two task heads of one multi-task BERT model.

    Both pipelines share a single ``BertForSequenceClassification`` instance
    loaded from ``path``; they differ only in the ``task_id`` they select.
    """

    # Minimum score the task-1 pipeline must assign to NEUTRAL_GATE_LABEL for an
    # input to be reported as "Neutral" regardless of the task-0 prediction.
    NEUTRAL_THRESHOLD = 0.75
    # Task-1 label that triggers the "Neutral" override.
    NEUTRAL_GATE_LABEL = "0"

    def __init__(self, path=""):
        """Load the tokenizer and model from ``path`` and build both pipelines.

        Args:
            path: Model location passed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        logger.info("The device is %s.", device)

        t0 = perf_counter()
        tokenizer = AutoTokenizer.from_pretrained(path)
        model = BertForSequenceClassification.from_pretrained(
            path, tasks_map=tasks
        ).to(device)
        # NOTE(review): the log messages in __call__ call the task-1 output
        # "polarity" and the task-0 output "subjective", yet task-0 labels are
        # the ones mapped to Negative/Neutral/Positive via idtolabel — the
        # _s/_p suffixes may be swapped; confirm before renaming anything.
        self.classifier_s = pipeline(
            "bibert-multitask-classification",
            model=model,
            task_id="0",
            tokenizer=tokenizer,
            device=device,
        )
        self.classifier_p = pipeline(
            "bibert-multitask-classification",
            model=model,
            task_id="1",
            tokenizer=tokenizer,
            device=device,
        )
        elapsed = 1000 * (perf_counter() - t0)
        logger.info("Models and tokenizer Polarity loaded in %d ms.", elapsed)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Classify one text or a batch of texts.

        Args:
            data: Request payload. Its ``"inputs"`` entry is popped (so ``data``
                is modified) and may be a single string or a list of strings.
                If the key is absent, ``data`` itself is used as the input.

        Returns:
            One ``{"label": ...}`` dict per input, in input order. The label is
            "Neutral" when the task-1 pipeline predicts ``NEUTRAL_GATE_LABEL``
            with a score of at least ``NEUTRAL_THRESHOLD``; otherwise it is the
            task-0 label mapped through ``idtolabel`` (``None`` if unmapped).

        Raises:
            ValueError: if either pipeline returns a different number of
                predictions than there are inputs.
        """
        inputs = data.pop("inputs", data)

        # Treat a single string as a one-element batch so predictions can be
        # paired with inputs positionally.
        if isinstance(inputs, str):
            inputs = [inputs]

        t0 = perf_counter()
        classifier_pol = self.classifier_p(inputs)
        classifier_subj = self.classifier_s(inputs)
        logger.info("Prediction polarity %s", classifier_pol)
        logger.info("Prediction subjective %s", classifier_subj)

        # Fail loudly instead of raising a bare IndexError (too few results)
        # or silently dropping extra results (too many).
        expected = len(inputs)
        if len(classifier_pol) != expected or len(classifier_subj) != expected:
            raise ValueError(
                f"Expected {expected} predictions per task, got "
                f"{len(classifier_pol)} and {len(classifier_subj)}."
            )

        prediction_res = self._combine_predictions(
            classifier_pol,
            classifier_subj,
            idtolabel,
            threshold=self.NEUTRAL_THRESHOLD,
            gate_label=self.NEUTRAL_GATE_LABEL,
        )
        elapsed = 1000 * (perf_counter() - t0)
        logger.info("Model prediction time: %d ms.", elapsed)
        return prediction_res

    @staticmethod
    def _combine_predictions(gate_preds, label_preds, label_map, *, threshold, gate_label):
        """Merge paired per-input predictions of the two heads into final labels.

        Args:
            gate_preds: Task-1 predictions, each a dict with "label" and "score".
            label_preds: Task-0 predictions, each a dict with a "label" key.
            label_map: Maps task-0 labels to output names; unknown labels give None.
            threshold: Minimum gate score for the "Neutral" override.
            gate_label: Gate label that triggers the "Neutral" override.

        Returns:
            A list of ``{"label": ...}`` dicts, one per prediction pair.
        """
        return [
            {
                "label": "Neutral"
                if gate["label"] == gate_label and gate["score"] >= threshold
                else label_map.get(fine["label"])
            }
            for gate, fine in zip(gate_preds, label_preds)
        ]
| |
|