from typing import Dict, List, Any, Union
from dataclasses import dataclass

import torch
from transformers import AutoTokenizer
from transformers import pipeline
from transformers.pipelines import PIPELINE_REGISTRY

from bibert_multitask_classification import BiBert_MultiTaskPipeline
from bert_for_sequence_classification import BertForSequenceClassification

from transformers.utils import logging

logging.set_verbosity_info()
logger = logging.get_logger("transformers")


@dataclass
class Task:
    """Description of one classification head of the multi-task model.

    A list of these is passed as ``tasks_map`` to
    ``BertForSequenceClassification.from_pretrained``.
    """

    id: int  # head index; presumably matches the pipelines' ``task_id`` — confirm
    name: str  # human-readable task name
    type: str  # task kind, e.g. 'seq_classification'
    num_labels: int  # number of output classes for this head


class EndpointHandler:
    """Inference Endpoints handler wrapping a two-head BERT classifier.

    Builds two pipelines over a single shared model:
    ``classifier_s`` (task_id "0", 5 labels) and ``classifier_p``
    (task_id "1", 2 labels).
    """

    def __init__(self, path: str = "", threshold: float = 0.75):
        """Load tokenizer and model from ``path`` and build one pipeline per task.

        Args:
            path: Location passed to ``from_pretrained`` for both the tokenizer
                and the model.
            threshold: Minimum score with which the binary head must predict
                label ``'0'`` for ``__call__`` to short-circuit and return
                label ``2``. Defaults to 0.75, the previously hard-coded value.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tokenizer = AutoTokenizer.from_pretrained(path)

        # Register the custom pipeline class under a task name so that
        # ``pipeline("bibert-multitask-classification", ...)`` can resolve it.
        PIPELINE_REGISTRY.register_pipeline(
            "bibert-multitask-classification",
            pipeline_class=BiBert_MultiTaskPipeline,
            pt_model=BertForSequenceClassification,
        )

        tasks = [
            Task(id=0, name='label_classification', type='seq_classification', num_labels=5),
            Task(id=1, name='binary_classification', type='seq_classification', num_labels=2),
        ]
        model = BertForSequenceClassification.from_pretrained(path, tasks_map=tasks).to(device)

        # Both pipelines share the same model instance; ``task_id`` presumably
        # selects the head inside BiBert_MultiTaskPipeline.
        # NOTE(review): task_id is passed as a string ("0"/"1") while Task.id is
        # an int — confirm the custom pipeline handles this conversion.
        self.classifier_s = pipeline(
            "bibert-multitask-classification",
            model=model,
            task_id="0",
            tokenizer=tokenizer,
            device=device,
        )
        self.classifier_p = pipeline(
            "bibert-multitask-classification",
            model=model,
            task_id="1",
            tokenizer=tokenizer,
            device=device,
        )
        self.threshold = threshold

    def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """Classify the request payload.

        The binary head (``classifier_p``) runs first. If its top prediction is
        label ``'0'`` with a score of at least ``self.threshold``, a single dict
        ``{"label": 2, "score": <score>}`` is returned and the 5-label head is
        not run. Otherwise the 5-label head (``classifier_s``) runs and its raw
        pipeline output is returned unchanged.

        Args:
            data: Request payload. ``data["inputs"]`` is forwarded to the
                pipelines; if that key is absent, the ``data`` dict itself is
                forwarded. An optional ``"lang"`` key is removed and currently
                ignored. ``data`` is mutated in place by these ``pop`` calls.

        Returns:
            Either a single dict (short-circuit case) or the ``classifier_s``
            output, which is indexed like a list of ``{"label", "score"}`` dicts.
        """
        logger.info("Received inference request")
        inputs = data.pop("inputs", data)
        # Removed from ``data`` (and therefore from ``inputs`` when ``inputs is
        # data``); the value itself is not used yet.
        data.pop("lang", None)

        prediction_p = self.classifier_p(inputs)
        label = prediction_p[0]['label']
        score = prediction_p[0]['score']
        if label == '0' and score >= self.threshold:
            # NOTE(review): this returns an int label (2) while pipeline labels
            # are compared as strings above, and the shape differs from the
            # list returned below — confirm clients expect both forms.
            return {"label": 2, "score": score}

        return self.classifier_s(inputs)