add custom handler and modify label to return
Browse files
__pycache__/bert_for_sequence_classification.cpython-37.pyc
CHANGED
|
Binary files a/__pycache__/bert_for_sequence_classification.cpython-37.pyc and b/__pycache__/bert_for_sequence_classification.cpython-37.pyc differ
|
|
|
__pycache__/bibert_multitask_classification.cpython-37.pyc
CHANGED
|
Binary files a/__pycache__/bibert_multitask_classification.cpython-37.pyc and b/__pycache__/bibert_multitask_classification.cpython-37.pyc differ
|
|
|
__pycache__/handler.cpython-37.pyc
CHANGED
|
Binary files a/__pycache__/handler.cpython-37.pyc and b/__pycache__/handler.cpython-37.pyc differ
|
|
|
handler.py
CHANGED
|
@@ -30,6 +30,9 @@ tasks = [
|
|
| 30 |
Task(id=1, name='binary_classification', type='seq_classification', num_labels=2)
|
| 31 |
]
|
| 32 |
|
|
|
|
|
|
|
|
|
|
| 33 |
class EndpointHandler():
|
| 34 |
def __init__(self, path=""):
|
| 35 |
# Preload all the elements you are going to need at inference.
|
|
@@ -62,15 +65,21 @@ class EndpointHandler():
|
|
| 62 |
|
| 63 |
t0 = perf_counter()
|
| 64 |
prediction_res = []
|
| 65 |
-
prediction_p = self.classifier_p(inputs)
|
|
|
|
| 66 |
label = prediction_p[0]['label']
|
| 67 |
score = prediction_p[0]['score']
|
| 68 |
|
| 69 |
if label == '0' and score >= 0.75:
|
| 70 |
-
|
| 71 |
-
prediction_res = [{"label":
|
| 72 |
else:
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
elapsed = 1000 * (perf_counter() - t0)
|
| 75 |
logger.info("Model prediction time: %d ms.", elapsed)
|
| 76 |
return prediction_res
|
|
|
|
| 30 |
Task(id=1, name='binary_classification', type='seq_classification', num_labels=2)
|
| 31 |
]
|
| 32 |
|
| 33 |
+
|
| 34 |
+
idtolabel = {"0":"Negative", "1":"Negative", "2": "Neutral", "3":"Positive", "4": "Positive" }
|
| 35 |
+
|
| 36 |
class EndpointHandler():
|
| 37 |
def __init__(self, path=""):
|
| 38 |
# Preload all the elements you are going to need at inference.
|
|
|
|
| 65 |
|
| 66 |
t0 = perf_counter()
|
| 67 |
prediction_res = []
|
| 68 |
+
prediction_p = self.classifier_p(inputs)
|
| 69 |
+
logger.info(prediction_p)
|
| 70 |
label = prediction_p[0]['label']
|
| 71 |
score = prediction_p[0]['score']
|
| 72 |
|
| 73 |
if label == '0' and score >= 0.75:
|
| 74 |
+
|
| 75 |
+
prediction_res = [{"label":"Neutral"}]
|
| 76 |
else:
|
| 77 |
+
classifier_res = self.classifier_s(inputs)
|
| 78 |
+
logger.info("Prediction %s", classifier_res)
|
| 79 |
+
label = classifier_res[0]['label']
|
| 80 |
+
for key in idtolabel.keys():
|
| 81 |
+
label = label.replace(key, idtolabel[key])
|
| 82 |
+
prediction_res = [{"label":label}]
|
| 83 |
elapsed = 1000 * (perf_counter() - t0)
|
| 84 |
logger.info("Model prediction time: %d ms.", elapsed)
|
| 85 |
return prediction_res
|