dariadaria committed on
Commit
1aec583
·
1 Parent(s): c0d4e65

batched handler

Browse files
Files changed (1) hide show
  1. handler.py +29 -13
handler.py CHANGED
@@ -1,16 +1,17 @@
1
  from typing import Dict, List, Any
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
 
4
 
5
 
6
  class EndpointHandler:
7
def __init__(self, path=""):
    """Initialize the endpoint: load tokenizer and 3-label classifier from *path*.

    Args:
        path: local directory or hub id holding the pretrained artifacts.
    """
    # Both artifacts come from the same pretrained checkpoint directory.
    self.tokenizer = AutoTokenizer.from_pretrained(path)
    self.model = AutoModelForSequenceClassification.from_pretrained(path, num_labels=3)
10
- def tokenize(text, topic):
11
  return self.tokenizer(
12
- topic,
13
- text,
14
  max_length=384, #512
15
  truncation="only_second",
16
  return_offsets_mapping=False,
@@ -22,16 +23,31 @@ class EndpointHandler:
22
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Classify a single (topic, text) pair and return the predicted label name.

    data args:
        topic (:obj:`str`)
        text (:obj:`str`)
    Return:
        A :obj:`list` | `dict`: will be serialized and returned
    """
    # NOTE(review): pop(..., data) falls back to the whole payload dict when
    # the key is absent — presumably intentional; confirm against callers.
    topic = data.pop("topic", data)
    text = data.pop("text", data)

    # Tokenize, run the classifier, and map the argmax class id to its name.
    encoded = self.tokenize(text, topic)
    logits = self.model(**encoded).logits
    predicted_class = torch.argmax(logits, dim=-1).item()
    return self.model.config.id2label[predicted_class]
 
1
  from typing import Dict, List, Any
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
4
+ from datasets import Dataset
5
 
6
 
7
  class EndpointHandler:
8
def __init__(self, path=""):
    """Set up the endpoint by loading the pretrained tokenizer and classifier.

    Args:
        path: local directory or hub id holding the pretrained artifacts.
    """
    # Label count comes from the checkpoint config (no num_labels override).
    self.tokenizer = AutoTokenizer.from_pretrained(path)
    self.model = AutoModelForSequenceClassification.from_pretrained(path)
11
+ def tokenize(batch):
12
  return self.tokenizer(
13
+ batch['topic'],
14
+ batch['text'],
15
  max_length=384, #512
16
  truncation="only_second",
17
  return_offsets_mapping=False,
 
23
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Run batched classification over the cross product of topics and texts.

    Every topic is paired with every text; predictions are returned per pair.

    data args:
        topics (List[str]): topics to pair each text with.
        texts (List[Dict[str, str]]): keys should be "id" and "text".
    Return:
        A :obj:`list` | `dict`: will be serialized and returned
    """
    # NOTE(review): pop(..., data) falls back to the whole payload dict when
    # the key is absent — presumably intentional; confirm against callers.
    topics = data.pop("topics", data)
    texts = data.pop("texts", data)

    # Build parallel columns for the topic x text cross product
    # (same topic-outer / text-inner order as iterating topics then texts).
    pairs = [(topic, text) for topic in topics for text in texts]
    batch = Dataset.from_dict({
        'id': [text['id'] for _, text in pairs],
        'text': [text['text'] for _, text in pairs],
        'topic': [topic for topic, _ in pairs],
    })

    tokenized_inputs = self.tokenize(batch)

    # Pure inference: disable autograd so no gradient buffers are allocated.
    with torch.no_grad():
        output = self.model(**tokenized_inputs)

    # Attach argmax class ids, then map them to label names; drop the raw
    # text and intermediate predictions from the serialized response.
    batch = batch.add_column('predictions', torch.argmax(output.logits, dim=-1).numpy(force=True))
    batch = batch.map(
        lambda b: {'label': [self.model.config.id2label[p] for p in b['predictions']]},
        batched=True,
        remove_columns=['text', 'predictions'],
    )
    return batch.to_dict()