Kalaoke commited on
Commit
286854a
·
1 Parent(s): 753da8f

Add the possibility of multiple predictions at one time

Browse files
__pycache__/bert_for_sequence_classification.cpython-37.pyc CHANGED
Binary files a/__pycache__/bert_for_sequence_classification.cpython-37.pyc and b/__pycache__/bert_for_sequence_classification.cpython-37.pyc differ
 
__pycache__/bibert_multitask_classification.cpython-37.pyc CHANGED
Binary files a/__pycache__/bibert_multitask_classification.cpython-37.pyc and b/__pycache__/bibert_multitask_classification.cpython-37.pyc differ
 
__pycache__/handler.cpython-37.pyc CHANGED
Binary files a/__pycache__/handler.cpython-37.pyc and b/__pycache__/handler.cpython-37.pyc differ
 
handler.py CHANGED
@@ -20,10 +20,10 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
 
21
  @dataclass
22
  class Task:
23
- id: int
24
- name: str
25
- type: str
26
- num_labels: int
27
 
28
  tasks = [
29
  Task(id=0, name='label_classification', type='seq_classification', num_labels=5),
@@ -34,50 +34,50 @@ tasks = [
34
  idtolabel = {"0":"Negative", "1":"Negative", "2": "Neutral", "3":"Positive", "4": "Positive" }
35
 
36
  class EndpointHandler():
37
- def __init__(self, path=""):
38
- # Preload all the elements you are going to need at inference.
39
- logger.info("The device is %s.", device)
40
-
41
- t0 = perf_counter()
42
-
43
- tokenizer = AutoTokenizer.from_pretrained(path)
44
- model = BertForSequenceClassification.from_pretrained(path, tasks_map=tasks).to(device)
45
- self.classifier_s = pipeline("bibert-multitask-classification", model = model, task_id="0", tokenizer=tokenizer, device = device)
46
- self.classifier_p = pipeline("bibert-multitask-classification", model = model, task_id="1", tokenizer=tokenizer, device = device)
47
- elapsed = 1000 * (perf_counter() - t0)
48
- logger.info("Models and tokenizer Polarity loaded in %d ms.", elapsed)
49
-
50
-
51
- def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
52
- """
53
- data args:
54
- inputs (:obj: `str` | `PIL.Image` | `np.array`)
55
- kwargs
56
- Return:
57
- A :obj:`list` | `dict`: will be serialized and returned
58
- """
59
-
60
- inputs = data.pop("inputs", data)
61
- #lang = data.pop("lang", None)
62
- #logger.info("The language of Verbatim is %s.", lang)
63
- if isinstance(inputs, str):
64
- inputs = [inputs]
65
 
66
- t0 = perf_counter()
67
- prediction_res = []
68
- classifier_pol = self.classifier_p(inputs)
 
 
69
 
70
- label = classifier_pol[0]['label']
71
- score = classifier_pol[0]['score']
72
-
73
- if label == '0' and score >= 0.75:
74
- logger.info("Prediction polarity %s", classifier_pol)
75
- prediction_res = [{"label":"Neutral"}]
76
- else:
77
- classifier_subj = self.classifier_s(inputs)
78
- logger.info("Prediction subjective %s", classifier_subj)
79
- label = idtolabel.get(classifier_subj[0]['label'])
80
- prediction_res = [{"label":label}]
81
- elapsed = 1000 * (perf_counter() - t0)
82
- logger.info("Model prediction time: %d ms.", elapsed)
83
- return prediction_res
 
 
 
 
 
20
 
21
  @dataclass
22
  class Task:
23
+ id: int
24
+ name: str
25
+ type: str
26
+ num_labels: int
27
 
28
  tasks = [
29
  Task(id=0, name='label_classification', type='seq_classification', num_labels=5),
 
34
  idtolabel = {"0":"Negative", "1":"Negative", "2": "Neutral", "3":"Positive", "4": "Positive" }
35
 
36
  class EndpointHandler():
37
+ def __init__(self, path=""):
38
+ # Preload all the elements you are going to need at inference.
39
+ logger.info("The device is %s.", device)
40
+
41
+ t0 = perf_counter()
42
+
43
+ tokenizer = AutoTokenizer.from_pretrained(path)
44
+ model = BertForSequenceClassification.from_pretrained(path, tasks_map=tasks).to(device)
45
+ self.classifier_s = pipeline("bibert-multitask-classification", model = model, task_id="0", tokenizer=tokenizer, device = device)
46
+ self.classifier_p = pipeline("bibert-multitask-classification", model = model, task_id="1", tokenizer=tokenizer, device = device)
47
+ elapsed = 1000 * (perf_counter() - t0)
48
+ logger.info("Models and tokenizer Polarity loaded in %d ms.", elapsed)
49
+
50
+
51
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
52
+ """
53
+ data args:
54
+ inputs (:obj: `str` | `PIL.Image` | `np.array`)
55
+ kwargs
56
+ Return:
57
+ A :obj:`list` | `dict`: will be serialized and returned
58
+ """
 
 
 
 
 
 
59
 
60
+ inputs = data.pop("inputs", data)
61
+ #lang = data.pop("lang", None)
62
+ #logger.info("The language of Verbatim is %s.", lang)
63
+ if isinstance(inputs, str):
64
+ inputs = [inputs]
65
 
66
+ t0 = perf_counter()
67
+ prediction_res = []
68
+ classifier_pol = self.classifier_p(inputs)
69
+ classifier_subj = self.classifier_s(inputs)
70
+ logger.info("Prediction polarity %s", classifier_pol)
71
+ logger.info("Prediction subjective %s", classifier_subj)
72
+
73
+ for idx, x in enumerate(inputs):
74
+ label = classifier_pol[idx]['label']
75
+ score = classifier_pol[idx]['score']
76
+
77
+ if label == '0' and score >= 0.75:
78
+ prediction_res.append({"label":"Neutral"})
79
+ else:
80
+ prediction_res.append({"label":idtolabel.get(classifier_subj[idx]['label'])})
81
+ elapsed = 1000 * (perf_counter() - t0)
82
+ logger.info("Model prediction time: %d ms.", elapsed)
83
+ return prediction_res