Update handler.py
Browse files- handler.py +2 -2
handler.py
CHANGED
|
@@ -102,9 +102,9 @@ class EndpointHandler:
|
|
| 102 |
if attention_mask is not None:
|
| 103 |
attention_mask = attention_mask.to(self.model.device)
|
| 104 |
|
| 105 |
-
input_len = input_ids.shape[-1]
|
| 106 |
|
| 107 |
gen_ids = self.model.generate(
|
|
|
|
| 108 |
input_ids=input_ids,
|
| 109 |
max_new_tokens=max_new_tokens,
|
| 110 |
do_sample=do_sample,
|
|
@@ -118,7 +118,7 @@ class EndpointHandler:
|
|
| 118 |
)
|
| 119 |
|
| 120 |
# Only return newly generated tokens
|
| 121 |
-
new_tokens = gen_ids
|
| 122 |
text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
|
| 123 |
return {"generated_text": text}
|
| 124 |
|
|
|
|
| 102 |
if attention_mask is not None:
|
| 103 |
attention_mask = attention_mask.to(self.model.device)
|
| 104 |
|
|
|
|
| 105 |
|
| 106 |
gen_ids = self.model.generate(
|
| 107 |
+
attention_mask=attention_mask,
|
| 108 |
input_ids=input_ids,
|
| 109 |
max_new_tokens=max_new_tokens,
|
| 110 |
do_sample=do_sample,
|
|
|
|
| 118 |
)
|
| 119 |
|
| 120 |
# Only return newly generated tokens
|
| 121 |
+
new_tokens = gen_ids
|
| 122 |
text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
|
| 123 |
return {"generated_text": text}
|
| 124 |
|