Add the possibility of making multiple predictions in a single call
Browse files
__pycache__/bert_for_sequence_classification.cpython-37.pyc
CHANGED
|
Binary files a/__pycache__/bert_for_sequence_classification.cpython-37.pyc and b/__pycache__/bert_for_sequence_classification.cpython-37.pyc differ
|
|
|
__pycache__/bibert_multitask_classification.cpython-37.pyc
CHANGED
|
Binary files a/__pycache__/bibert_multitask_classification.cpython-37.pyc and b/__pycache__/bibert_multitask_classification.cpython-37.pyc differ
|
|
|
__pycache__/handler.cpython-37.pyc
CHANGED
|
Binary files a/__pycache__/handler.cpython-37.pyc and b/__pycache__/handler.cpython-37.pyc differ
|
|
|
handler.py
CHANGED
|
@@ -20,10 +20,10 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
| 20 |
|
| 21 |
@dataclass
|
| 22 |
class Task:
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
|
| 28 |
tasks = [
|
| 29 |
Task(id=0, name='label_classification', type='seq_classification', num_labels=5),
|
|
@@ -34,50 +34,50 @@ tasks = [
|
|
| 34 |
idtolabel = {"0":"Negative", "1":"Negative", "2": "Neutral", "3":"Positive", "4": "Positive" }
|
| 35 |
|
| 36 |
class EndpointHandler():
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
inputs = data.pop("inputs", data)
|
| 61 |
-
#lang = data.pop("lang", None)
|
| 62 |
-
#logger.info("The language of Verbatim is %s.", lang)
|
| 63 |
-
if isinstance(inputs, str):
|
| 64 |
-
inputs = [inputs]
|
| 65 |
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
| 69 |
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
@dataclass
|
| 22 |
class Task:
|
| 23 |
+
id: int
|
| 24 |
+
name: str
|
| 25 |
+
type: str
|
| 26 |
+
num_labels: int
|
| 27 |
|
| 28 |
tasks = [
|
| 29 |
Task(id=0, name='label_classification', type='seq_classification', num_labels=5),
|
|
|
|
| 34 |
idtolabel = {"0":"Negative", "1":"Negative", "2": "Neutral", "3":"Positive", "4": "Positive" }
|
| 35 |
|
| 36 |
class EndpointHandler():
|
| 37 |
+
def __init__(self, path=""):
|
| 38 |
+
# Preload all the elements you are going to need at inference.
|
| 39 |
+
logger.info("The device is %s.", device)
|
| 40 |
+
|
| 41 |
+
t0 = perf_counter()
|
| 42 |
+
|
| 43 |
+
tokenizer = AutoTokenizer.from_pretrained(path)
|
| 44 |
+
model = BertForSequenceClassification.from_pretrained(path, tasks_map=tasks).to(device)
|
| 45 |
+
self.classifier_s = pipeline("bibert-multitask-classification", model = model, task_id="0", tokenizer=tokenizer, device = device)
|
| 46 |
+
self.classifier_p = pipeline("bibert-multitask-classification", model = model, task_id="1", tokenizer=tokenizer, device = device)
|
| 47 |
+
elapsed = 1000 * (perf_counter() - t0)
|
| 48 |
+
logger.info("Models and tokenizer Polarity loaded in %d ms.", elapsed)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
| 52 |
+
"""
|
| 53 |
+
data args:
|
| 54 |
+
inputs (:obj: `str` | `PIL.Image` | `np.array`)
|
| 55 |
+
kwargs
|
| 56 |
+
Return:
|
| 57 |
+
A :obj:`list` | `dict`: will be serialized and returned
|
| 58 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
+
inputs = data.pop("inputs", data)
|
| 61 |
+
#lang = data.pop("lang", None)
|
| 62 |
+
#logger.info("The language of Verbatim is %s.", lang)
|
| 63 |
+
if isinstance(inputs, str):
|
| 64 |
+
inputs = [inputs]
|
| 65 |
|
| 66 |
+
t0 = perf_counter()
|
| 67 |
+
prediction_res = []
|
| 68 |
+
classifier_pol = self.classifier_p(inputs)
|
| 69 |
+
classifier_subj = self.classifier_s(inputs)
|
| 70 |
+
logger.info("Prediction polarity %s", classifier_pol)
|
| 71 |
+
logger.info("Prediction subjective %s", classifier_subj)
|
| 72 |
+
|
| 73 |
+
for idx, x in enumerate(inputs):
|
| 74 |
+
label = classifier_pol[idx]['label']
|
| 75 |
+
score = classifier_pol[idx]['score']
|
| 76 |
+
|
| 77 |
+
if label == '0' and score >= 0.75:
|
| 78 |
+
prediction_res.append({"label":"Neutral"})
|
| 79 |
+
else:
|
| 80 |
+
prediction_res.append({"label":idtolabel.get(classifier_subj[idx]['label'])})
|
| 81 |
+
elapsed = 1000 * (perf_counter() - t0)
|
| 82 |
+
logger.info("Model prediction time: %d ms.", elapsed)
|
| 83 |
+
return prediction_res
|