Kalaoke committed on
Commit
1c4ad1e
·
1 Parent(s): a79dc31

add custom handler and modify pipeline

Browse files
Files changed (2) hide show
  1. __pycache__/handler.cpython-37.pyc +0 -0
  2. handler.py +17 -10
__pycache__/handler.cpython-37.pyc CHANGED
Binary files a/__pycache__/handler.cpython-37.pyc and b/__pycache__/handler.cpython-37.pyc differ
 
handler.py CHANGED
@@ -7,6 +7,7 @@ from transformers.pipelines import PIPELINE_REGISTRY
7
  from bibert_multitask_classification import BiBert_MultiTaskPipeline
8
  from bert_for_sequence_classification import BertForSequenceClassification
9
  from transformers.utils import logging
 
10
 
11
 
12
  PIPELINE_REGISTRY.register_pipeline("bibert-multitask-classification", pipeline_class=BiBert_MultiTaskPipeline, pt_model=BertForSequenceClassification)
@@ -32,12 +33,17 @@ tasks = [
32
  class EndpointHandler():
33
  def __init__(self, path=""):
34
  # Preload all the elements you are going to need at inference.
35
- tokenizer = AutoTokenizer.from_pretrained(path)
 
 
36
 
 
37
  model = BertForSequenceClassification.from_pretrained(path, tasks_map=tasks).to(device)
38
-
39
  self.classifier_s = pipeline("bibert-multitask-classification", model = model, task_id="0", tokenizer=tokenizer, device = device)
40
  self.classifier_p = pipeline("bibert-multitask-classification", model = model, task_id="1", tokenizer=tokenizer, device = device)
 
 
 
41
 
42
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
43
  """
@@ -50,20 +56,21 @@ class EndpointHandler():
50
 
51
  inputs = data.pop("inputs", data)
52
  lang = data.pop("lang", None)
53
- logger.info(inputs)
54
- logger.info(lang)
55
  if isinstance(inputs, str):
56
  inputs = [inputs]
57
 
 
 
58
  prediction_p = self.classifier_p(inputs)
59
  label = prediction_p[0]['label']
60
  score = prediction_p[0]['score']
61
-
62
  if label == '0' and score >= 0.75:
63
  label = 2
64
- return [{"label":label, "score": score}]
65
  else:
66
- prediction_s = self.classifier_s(inputs)
67
- label = prediction_s[0]['label']
68
- score = prediction_s[0]['score']
69
- return prediction_s
 
7
  from bibert_multitask_classification import BiBert_MultiTaskPipeline
8
  from bert_for_sequence_classification import BertForSequenceClassification
9
  from transformers.utils import logging
10
+ from time import perf_counter
11
 
12
 
13
  PIPELINE_REGISTRY.register_pipeline("bibert-multitask-classification", pipeline_class=BiBert_MultiTaskPipeline, pt_model=BertForSequenceClassification)
 
class EndpointHandler():
    """Custom Hugging Face inference-endpoint handler for a shared BiBERT
    multi-task model.

    Loads one ``BertForSequenceClassification`` model and wraps it in two
    "bibert-multitask-classification" pipelines: ``classifier_p`` (task "1",
    run first as a gating/polarity pass) and ``classifier_s`` (task "0",
    run only when the polarity pass is not a confident '0').
    """

    def __init__(self, path=""):
        # Preload all the elements you are going to need at inference.
        logger.info("The device is %s.", device)

        t0 = perf_counter()

        tokenizer = AutoTokenizer.from_pretrained(path)
        # NOTE(review): tasks_map comes from the module-level `tasks` list;
        # `device` is a module-level constant — both defined earlier in the file.
        model = BertForSequenceClassification.from_pretrained(path, tasks_map=tasks).to(device)
        self.classifier_s = pipeline("bibert-multitask-classification", model=model, task_id="0", tokenizer=tokenizer, device=device)
        self.classifier_p = pipeline("bibert-multitask-classification", model=model, task_id="1", tokenizer=tokenizer, device=device)

        elapsed = 1000 * (perf_counter() - t0)
        logger.info("Models and tokenizer Polarity loaded in %d ms.", elapsed)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run inference on an endpoint request payload.

        Args:
            data: request dict; ``"inputs"`` is a string or list of strings
                to classify, ``"lang"`` is optional and only logged.
                # assumes the standard HF endpoint payload shape — TODO confirm
        Returns:
            A list of ``{"label": ..., "score": ...}`` prediction dicts —
            either the remapped polarity result or the task-"0" pipeline output.
        """
        inputs = data.pop("inputs", data)
        lang = data.pop("lang", None)
        logger.info("The language of Verbatim is %s.", lang)

        if isinstance(inputs, str):
            # Pipelines expect a batch; wrap a single string.
            inputs = [inputs]

        t0 = perf_counter()
        prediction_p = self.classifier_p(inputs)
        label = prediction_p[0]['label']
        score = prediction_p[0]['score']

        if label == '0' and score >= 0.75:
            # Confident '0' from the polarity pass: remap it to label 2 and
            # skip the second classifier entirely.
            # (Fixed: the original first set `prediction_res = []` and a
            # `label = 2` temporary — both dead; the dict is built directly.)
            prediction_res = [{"label": 2, "score": score}]
        else:
            prediction_res = self.classifier_s(inputs)

        elapsed = 1000 * (perf_counter() - t0)
        logger.info("Model prediction time: %d ms.", elapsed)
        return prediction_res