OzTianlu commited on
Commit
39bf691
·
verified ·
1 Parent(s): cd370de

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +2 -2
handler.py CHANGED
@@ -102,9 +102,9 @@ class EndpointHandler:
102
  if attention_mask is not None:
103
  attention_mask = attention_mask.to(self.model.device)
104
 
105
- input_len = input_ids.shape[-1]
106
 
107
  gen_ids = self.model.generate(
 
108
  input_ids=input_ids,
109
  max_new_tokens=max_new_tokens,
110
  do_sample=do_sample,
@@ -118,7 +118,7 @@ class EndpointHandler:
118
  )
119
 
120
  # Only return newly generated tokens
121
- new_tokens = gen_ids[0, input_len:]
122
  text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
123
  return {"generated_text": text}
124
 
 
102
  if attention_mask is not None:
103
  attention_mask = attention_mask.to(self.model.device)
104
 
 
105
 
106
  gen_ids = self.model.generate(
107
+ attention_mask=attention_mask,
108
  input_ids=input_ids,
109
  max_new_tokens=max_new_tokens,
110
  do_sample=do_sample,
 
118
  )
119
 
120
  # Only return newly generated tokens
121
+ new_tokens = gen_ids
122
  text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
123
  return {"generated_text": text}
124