add custom handler and modify pipeline
Browse files
- __pycache__/handler.cpython-37.pyc +0 -0
- handler.py +17 -10
__pycache__/handler.cpython-37.pyc
CHANGED
|
Binary files a/__pycache__/handler.cpython-37.pyc and b/__pycache__/handler.cpython-37.pyc differ
|
|
|
handler.py
CHANGED
|
@@ -7,6 +7,7 @@ from transformers.pipelines import PIPELINE_REGISTRY
|
|
| 7 |
from bibert_multitask_classification import BiBert_MultiTaskPipeline
|
| 8 |
from bert_for_sequence_classification import BertForSequenceClassification
|
| 9 |
from transformers.utils import logging
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
PIPELINE_REGISTRY.register_pipeline("bibert-multitask-classification", pipeline_class=BiBert_MultiTaskPipeline, pt_model=BertForSequenceClassification)
|
|
@@ -32,12 +33,17 @@ tasks = [
|
|
| 32 |
class EndpointHandler():
|
| 33 |
def __init__(self, path=""):
|
| 34 |
# Preload all the elements you are going to need at inference.
|
| 35 |
-
|
|
|
|
|
|
|
| 36 |
|
|
|
|
| 37 |
model = BertForSequenceClassification.from_pretrained(path, tasks_map=tasks).to(device)
|
| 38 |
-
|
| 39 |
self.classifier_s = pipeline("bibert-multitask-classification", model = model, task_id="0", tokenizer=tokenizer, device = device)
|
| 40 |
self.classifier_p = pipeline("bibert-multitask-classification", model = model, task_id="1", tokenizer=tokenizer, device = device)
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
| 43 |
"""
|
|
@@ -50,20 +56,21 @@ class EndpointHandler():
|
|
| 50 |
|
| 51 |
inputs = data.pop("inputs", data)
|
| 52 |
lang = data.pop("lang", None)
|
| 53 |
-
logger.info(
|
| 54 |
-
logger.info(lang)
|
| 55 |
if isinstance(inputs, str):
|
| 56 |
inputs = [inputs]
|
| 57 |
|
|
|
|
|
|
|
| 58 |
prediction_p = self.classifier_p(inputs)
|
| 59 |
label = prediction_p[0]['label']
|
| 60 |
score = prediction_p[0]['score']
|
| 61 |
-
|
| 62 |
if label == '0' and score >= 0.75:
|
| 63 |
label = 2
|
| 64 |
-
|
| 65 |
else:
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
return
|
|
|
|
| 7 |
from bibert_multitask_classification import BiBert_MultiTaskPipeline
|
| 8 |
from bert_for_sequence_classification import BertForSequenceClassification
|
| 9 |
from transformers.utils import logging
|
| 10 |
+
from time import perf_counter
|
| 11 |
|
| 12 |
|
| 13 |
PIPELINE_REGISTRY.register_pipeline("bibert-multitask-classification", pipeline_class=BiBert_MultiTaskPipeline, pt_model=BertForSequenceClassification)
|
|
|
|
class EndpointHandler():
    def __init__(self, path=""):
        # Preload all the elements you are going to need at inference.
        # NOTE(review): `logger`, `device`, and `tasks` are module-level names
        # defined earlier in handler.py (outside this hunk).
        logger.info("The device is %s.", device)

        load_start = perf_counter()

        tokenizer = AutoTokenizer.from_pretrained(path)
        model = BertForSequenceClassification.from_pretrained(path, tasks_map=tasks).to(device)

        # Two pipelines share the same multi-task model; task_id "0" vs "1"
        # selects which head is used (semantics live in BiBert_MultiTaskPipeline,
        # not visible here — confirm which id maps to which task).
        self.classifier_s = pipeline(
            "bibert-multitask-classification",
            model=model,
            task_id="0",
            tokenizer=tokenizer,
            device=device,
        )
        self.classifier_p = pipeline(
            "bibert-multitask-classification",
            model=model,
            task_id="1",
            tokenizer=tokenizer,
            device=device,
        )

        elapsed = 1000 * (perf_counter() - load_start)
        logger.info("Models and tokenizer Polarity loaded in %d ms.", elapsed)
| 47 |
|
| 48 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
| 49 |
"""
|
|
|
|
| 56 |
|
| 57 |
inputs = data.pop("inputs", data)
|
| 58 |
lang = data.pop("lang", None)
|
| 59 |
+
logger.info("The language of Verbatim is %s.", lang)
|
|
|
|
| 60 |
if isinstance(inputs, str):
|
| 61 |
inputs = [inputs]
|
| 62 |
|
| 63 |
+
t0 = perf_counter()
|
| 64 |
+
prediction_res = []
|
| 65 |
prediction_p = self.classifier_p(inputs)
|
| 66 |
label = prediction_p[0]['label']
|
| 67 |
score = prediction_p[0]['score']
|
| 68 |
+
|
| 69 |
if label == '0' and score >= 0.75:
|
| 70 |
label = 2
|
| 71 |
+
prediction_res = [{"label":label, "score": score}]
|
| 72 |
else:
|
| 73 |
+
prediction_res = self.classifier_s(inputs)
|
| 74 |
+
elapsed = 1000 * (perf_counter() - t0)
|
| 75 |
+
logger.info("Model prediction time: %d ms.", elapsed)
|
| 76 |
+
return prediction_res
|