BiBert-MultiTask-2 / handler.py
Kalaoke's picture
add custom handler and modify pipeline
05f6b94
raw
history blame
2.31 kB
from dataclasses import dataclass
from typing import Any, Dict, List, Union

import torch
from transformers import AutoTokenizer
from transformers import pipeline
from transformers.pipelines import PIPELINE_REGISTRY
from transformers.utils import logging

from bibert_multitask_classification import BiBert_MultiTaskPipeline
from bert_for_sequence_classification import BertForSequenceClassification
# Register the custom pipeline class under a task name so that
# `pipeline("bibert-multitask-classification", ...)` can resolve it.
PIPELINE_REGISTRY.register_pipeline(
    "bibert-multitask-classification",
    pipeline_class=BiBert_MultiTaskPipeline,
    pt_model=BertForSequenceClassification,
)

# Switch transformers logging to INFO and reuse its "transformers" logger.
logging.set_verbosity_info()
logger = logging.get_logger("transformers")

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
@dataclass
class Task:
    """Describe one task of the multi-task model.

    Instances are collected in the module-level ``tasks`` list, which is
    handed to the model loader as ``tasks_map``.
    """

    # Numeric identifier of the task.
    id: int
    # Task name, e.g. 'label_classification'.
    name: str
    # Task category string, e.g. 'seq_classification'.
    type: str
    # Number of labels the task distinguishes.
    num_labels: int
# Task descriptors passed to the model loader as `tasks_map`:
# id 0 is the 5-label task, id 1 is the 2-label (binary) task.
tasks = [
    Task(
        id=0,
        name='label_classification',
        type='seq_classification',
        num_labels=5,
    ),
    Task(
        id=1,
        name='binary_classification',
        type='seq_classification',
        num_labels=2,
    ),
]
class EndpointHandler:
    """Custom inference handler that combines two classification pipelines.

    Both pipelines share one model and tokenizer and differ only in the
    ``task_id`` they are created with ("0" for ``classifier_s``, "1" for
    ``classifier_p``). On each request the ``classifier_p`` pipeline runs
    first; depending on its first prediction the handler either answers
    directly or defers to ``classifier_s``.
    """

    # Label from `classifier_p` that triggers the direct (shortcut) answer.
    BINARY_SHORTCUT_LABEL = '0'
    # Minimum score (inclusive) `classifier_p` must reach for the shortcut.
    BINARY_SHORTCUT_MIN_SCORE = 0.75
    # Label reported to the caller when the shortcut is taken.
    SHORTCUT_OUTPUT_LABEL = 2

    def __init__(self, path=""):
        """Load the tokenizer and model from *path* and build both pipelines.

        Args:
            path: Directory (or model id) given to ``from_pretrained``.
        """
        # Preload all the elements you are going to need at inference.
        tokenizer = AutoTokenizer.from_pretrained(path)
        model = BertForSequenceClassification.from_pretrained(
            path, tasks_map=tasks
        ).to(device)
        self.classifier_s = pipeline(
            "bibert-multitask-classification",
            model=model,
            task_id="0",
            tokenizer=tokenizer,
            device=device,
        )
        self.classifier_p = pipeline(
            "bibert-multitask-classification",
            model=model,
            task_id="1",
            tokenizer=tokenizer,
            device=device,
        )

    def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """Classify the request payload.

        Args:
            data: Request payload. ``data["inputs"]`` holds a string or a
                list of strings; when the key is missing the payload itself
                is used as the input. An optional ``data["lang"]`` is read
                and logged but does not influence the result. Both keys are
                removed from *data*.

        Returns:
            * ``[]`` when the input is an empty list or tuple.
            * ``{"label": 2, "score": <score>}`` when the first
              ``classifier_p`` prediction has label ``'0'`` with a score of
              at least 0.75.
            * Otherwise, the unmodified output of ``classifier_s`` for the
              inputs.
        """
        inputs = data.pop("inputs", data)
        lang = data.pop("lang", None)
        logger.info(inputs)
        logger.info(lang)

        if isinstance(inputs, str):
            inputs = [inputs]

        # Nothing to classify: answer with an empty result instead of
        # failing later when indexing the first prediction.
        if isinstance(inputs, (list, tuple)) and not inputs:
            return []

        # Only the first prediction drives the decision, so for a
        # multi-input batch the shortcut below answers with a single dict.
        first = self.classifier_p(inputs)[0]
        score = first['score']
        if (
            first['label'] == self.BINARY_SHORTCUT_LABEL
            and score >= self.BINARY_SHORTCUT_MIN_SCORE
        ):
            return {"label": self.SHORTCUT_OUTPUT_LABEL, "score": score}

        return self.classifier_s(inputs)