jammygrams committed on
Commit
bbefeb7
·
1 Parent(s): 5885b20

Upload handler.py

Browse files

handler for inference endpoint

Files changed (1) hide show
  1. handler.py +43 -0
handler.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List
2
+ import os
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
+
5
+ PRETRAINED_MODEL_NAME = "facebook/bart-large"
6
+ ADAPTER_MODEL_NAME = "jammygrams/bart-qa"
7
+ ADAPTER_NAME = "narrativeqa"
8
+
9
class EndpointHandler:
    """Hugging Face inference-endpoint handler for QA with a BART adapter.

    Loads the tokenizer and seq2seq model from the endpoint's model
    directory at startup, then answers each request by generating text
    with beam search.
    """

    def __init__(
        self,
        path: str,
    ):
        # `from_pretrained` expects the model *directory*, not a file inside
        # it. The previous code passed "tokenizer.json" / "pytorch_model.bin"
        # file paths (and used the nonexistent AutoTokenizer.from_file),
        # which fails at startup.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
        # Activate the QA adapter. NOTE(review): set_active_adapters comes
        # from the adapter-transformers fork, not vanilla transformers —
        # confirm the endpoint installs it.
        self.model.set_active_adapters(ADAPTER_NAME)

    def __call__(self, data: Dict[str, Any]) -> List[str]:
        """
        data args:
            inputs (:obj:`str`): the question/context text to answer.
        Return:
            output :obj:`list`: one-element list with the generated answer;
            will be serialized and returned by the endpoint.
        """
        # Fall back to the raw payload if there is no "inputs" key.
        inputs = data.pop("inputs", data)
        # Truncate to the model's max input length so long passages don't
        # raise at generation time (restores the intent of the previously
        # commented-out call).
        tokenized_input = self.tokenizer(
            [inputs], truncation=True, max_length=1024, return_tensors="pt"
        )
        prediction = self.model.generate(
            tokenized_input.input_ids,
            num_beams=5,
            return_dict_in_generate=True,
            output_scores=True,
            max_length=50,
        )
        output = self.tokenizer.decode(
            prediction["sequences"][0],  # single (best-beam) prediction
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

        return [output]