Fix the example onnx inference code.
#27
by kkkkkk98 - opened
README.md
CHANGED
|
@@ -196,7 +196,7 @@ for i in range(max_new_tokens):
|
|
| 196 |
|
| 197 |
## Update values for next generation loop
|
| 198 |
input_ids = logits[:, -1].argmax(-1, keepdims=True)
|
| 199 |
-
attention_mask = np.ones_like(input_ids)
|
| 200 |
position_ids = position_ids[:, -1:] + 1
|
| 201 |
for j, key in enumerate(past_key_values):
|
| 202 |
past_key_values[key] = present_key_values[j]
|
|
|
|
| 196 |
|
| 197 |
## Update values for next generation loop
|
| 198 |
input_ids = logits[:, -1].argmax(-1, keepdims=True)
|
| 199 |
+
attention_mask = np.ones_like(np.concatenate((attention_mask, input_ids), axis=-1))
|
| 200 |
position_ids = position_ids[:, -1:] + 1
|
| 201 |
for j, key in enumerate(past_key_values):
|
| 202 |
past_key_values[key] = present_key_values[j]
|