BiBert-MultiTask-2 / bibert_multitask_classification.py
Kalaoke's picture
add custom handler and modify pipeline
2c5bd8b
raw
history blame
1.79 kB
from transformers import Pipeline
import numpy as np
import torch
def softmax(_outputs):
    """Numerically stable softmax over the last axis of *_outputs*."""
    # Subtract the per-row max so exp() never overflows for large logits.
    stabilized = _outputs - np.max(_outputs, axis=-1, keepdims=True)
    exps = np.exp(stabilized)
    return exps / np.sum(exps, axis=-1, keepdims=True)
class BiBert_MultiTaskPipeline(Pipeline):
    """Text-classification pipeline for a multi-task BiBert model.

    A ``task_id`` keyword selects which classification head the underlying
    model should use; it is injected into the model inputs as a ``task_ids``
    tensor during preprocessing.
    """

    def _sanitize_parameters(self, **kwargs):
        """Split pipeline kwargs into preprocess / forward / postprocess dicts.

        ``task_id`` is routed to both ``preprocess`` and ``_forward``;
        ``top_k`` limits how many labels ``postprocess`` returns.
        """
        preprocess_kwargs = {}
        forward_kwargs = {}
        if "task_id" in kwargs:
            preprocess_kwargs["task_id"] = kwargs["task_id"]
            forward_kwargs["task_id"] = kwargs["task_id"]
        postprocess_kwargs = {}
        if "top_k" in kwargs:
            postprocess_kwargs["top_k"] = kwargs["top_k"]
        # Always use the modern list-of-dicts output format.
        postprocess_kwargs["_legacy"] = False
        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(self, inputs, task_id):
        """Tokenize *inputs* and attach the task-selector tensor."""
        feature = self.tokenizer(
            inputs, padding=True, return_tensors=self.framework
        ).to(self.device)
        # Fix: build task_ids directly on self.device — the original created a
        # CPU IntTensor next to device-resident features, which breaks GPU runs.
        feature["task_ids"] = torch.full(
            (1,), int(task_id), dtype=torch.int32, device=self.device
        )
        return feature

    def _forward(self, model_inputs, task_id):
        """Run the model; *task_id* already travels inside ``model_inputs``."""
        return self.model(**model_inputs)

    def postprocess(self, model_outputs, top_k=1, _legacy=True):
        """Turn logits into label/score results.

        Returns a single ``{"label", "score"}`` dict when ``top_k == 1`` and
        ``_legacy`` is true; otherwise a list of such dicts, sorted by score
        (and truncated to ``top_k``) when ``_legacy`` is false.
        """
        logits = model_outputs["logits"][0]
        # Fix: .numpy() raises on CUDA tensors and on tensors that still
        # require grad — detach and move to CPU first.
        scores = softmax(logits.detach().cpu().numpy())
        if top_k == 1 and _legacy:
            return {
                "label": self.model.config.id2label[scores.argmax().item()],
                "score": scores.max().item(),
            }
        dict_scores = [
            {"label": self.model.config.id2label[i], "score": score.item()}
            for i, score in enumerate(scores)
        ]
        if not _legacy:
            dict_scores.sort(key=lambda x: x["score"], reverse=True)
            if top_k is not None:
                dict_scores = dict_scores[:top_k]
        return dict_scores