File size: 1,786 Bytes
eb39c91
 
 
 
 
 
 
 
 
 
 
2c5bd8b
eb39c91
 
 
 
 
 
 
 
 
2c5bd8b
 
 
 
 
 
 
 
eb39c91
 
 
 
 
 
 
 
 
 
 
2c5bd8b
eb39c91
 
 
 
2c5bd8b
 
 
eb39c91
2c5bd8b
 
 
 
 
 
eb39c91
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from transformers import Pipeline
import numpy as np
import torch

def softmax(_outputs):
  """Numerically stable softmax over the last axis of *_outputs*."""
  # Subtracting the row-max before exponentiating avoids overflow in np.exp.
  stabilized = _outputs - np.max(_outputs, axis=-1, keepdims=True)
  exps = np.exp(stabilized)
  return exps / np.sum(exps, axis=-1, keepdims=True)

class BiBert_MultiTaskPipeline(Pipeline):
  """Multi-task classification pipeline for a shared-encoder BiBERT model.

  A ``task_id`` keyword selects which task head the underlying model should
  use; it is injected into the tokenized features as a ``task_ids`` tensor.
  ``top_k`` controls how many label/score dicts ``postprocess`` returns.
  """

  def _sanitize_parameters(self, **kwargs):
    """Split caller kwargs into (preprocess, forward, postprocess) kwargs.

    ``task_id`` is routed to both ``preprocess`` and ``_forward``; passing
    ``top_k`` switches ``postprocess`` out of its legacy single-dict return.
    """
    preprocess_kwargs = {}
    forward_kwargs = {}
    if "task_id" in kwargs:
      preprocess_kwargs["task_id"] = kwargs["task_id"]
      forward_kwargs["task_id"] = kwargs["task_id"]

    postprocess_kwargs = {}
    if "top_k" in kwargs:
      postprocess_kwargs["top_k"] = kwargs["top_k"]
      postprocess_kwargs["_legacy"] = False
    return preprocess_kwargs, forward_kwargs, postprocess_kwargs

  def preprocess(self, inputs, task_id):
    """Tokenize ``inputs`` and attach the ``task_ids`` tensor for the model."""
    return_tensors = self.framework
    feature = self.tokenizer(inputs, padding=True, return_tensors=return_tensors).to(self.device)
    # NOTE(review): the tokenized features were moved to self.device above,
    # but this task_ids tensor is created on CPU — confirm the model accepts
    # mixed-device inputs or moves task_ids itself.
    task_ids = np.full(shape=1, fill_value=task_id, dtype=int)
    feature["task_ids"] = torch.IntTensor(task_ids)
    return feature

  def _forward(self, model_inputs, task_id):
    # task_id is accepted because _sanitize_parameters routes it here, but the
    # model already receives it via model_inputs["task_ids"] from preprocess.
    return self.model(**model_inputs)

  def postprocess(self, model_outputs, top_k=1, _legacy=True):
    """Turn model logits into label/score output.

    Returns a single ``{"label", "score"}`` dict when ``top_k == 1`` in
    legacy mode, otherwise a list of such dicts (sorted and truncated to
    ``top_k`` when legacy mode is off).
    """
    logits = model_outputs["logits"][0]
    # detach().cpu() first: Tensor.numpy() raises on CUDA tensors and on
    # tensors that require grad.
    scores = softmax(logits.detach().cpu().numpy())

    if top_k == 1 and _legacy:
      return {"label": self.model.config.id2label[scores.argmax().item()], "score": scores.max().item()}

    dict_scores = [
        {"label": self.model.config.id2label[i], "score": score.item()}
        for i, score in enumerate(scores)
    ]
    if not _legacy:
      dict_scores.sort(key=lambda x: x["score"], reverse=True)
      if top_k is not None:
        dict_scores = dict_scores[:top_k]
    return dict_scores