Melissa Roemmele commited on
Commit
cd1659c
·
1 Parent(s): ac83b16

Updated handler.py

Browse files
Files changed (1) hide show
  1. handler.py +16 -14
handler.py CHANGED
@@ -7,30 +7,32 @@ class EndpointHandler:
7
  def __init__(self, path=""):
8
  # load model and tokenizer from path
9
  self.tokenizer = AutoTokenizer.from_pretrained(path)
10
- self.model = AutoModelForCausalLM.from_pretrained(
11
- path, device_map="auto", torch_dtype=torch.float16, trust_remote_code=True
12
- )
 
13
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
14
 
15
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
16
  # process input
17
  inputs = data.pop("inputs", data)
18
- parameters = data.pop("parameters", None)
 
19
 
20
  # preprocess
21
- inputs = self.tokenizer(inputs, return_tensors="pt")
22
- inputs.pop("token_type_ids")
23
- inputs.pop("attention_mask")
24
  inputs = inputs.to(self.device)
 
25
 
26
- # pass inputs with all kwargs in data
27
- if parameters is not None:
28
- outputs = self.model.generate(**inputs, **parameters)
29
- else:
30
- outputs = self.model.generate(**inputs)
31
 
32
  # postprocess the prediction
33
- prediction = self.tokenizer.decode(
34
- outputs[0], skip_special_tokens=True)
35
 
36
  return [{"generated_text": prediction}]
 
7
  def __init__(self, path=""):
8
  # load model and tokenizer from path
9
  self.tokenizer = AutoTokenizer.from_pretrained(path)
10
+ self.model = AutoModelForCausalLM.from_pretrained(path,
11
+ device_map="auto",
12
+ torch_dtype=torch.float16,
13
+ trust_remote_code=True)
14
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
15
 
16
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
17
  # process input
18
  inputs = data.pop("inputs", data)
19
+ parameters = data.pop("parameters", {})
20
+ return_full_text = parameters.pop("return_full_text", True)
21
 
22
  # preprocess
23
+ inputs = self.tokenizer(inputs,
24
+ return_tensors="pt",
25
+ return_token_type_ids=False)
26
  inputs = inputs.to(self.device)
27
+ input_len = len(inputs[0])
28
 
29
+ outputs = self.model.generate(**inputs, **parameters)[0]
30
+
31
+ if not return_full_text:
32
+ outputs = outputs[input_len:]
 
33
 
34
  # postprocess the prediction
35
+ prediction = self.tokenizer.decode(outputs,
36
+ skip_special_tokens=True)
37
 
38
  return [{"generated_text": prediction}]