fix: exclude prompt tokens in generate() return value
#1
by
haahha
- opened
- modeling_llada2_moe.py +1 -1
modeling_llada2_moe.py
CHANGED
|
@@ -1409,7 +1409,7 @@ class LLaDA2MoeModelLM(LLaDA2MoePreTrainedModel, GenerationMixin):
|
|
| 1409 |
if len(eos_pos_in_x[0]) > 0:
|
| 1410 |
eos_pos = eos_pos_in_x[0][0].item()
|
| 1411 |
if (cur_x[0, prompt_length:eos_pos] != mask_id).all():
|
| 1412 |
-
final_x = x[:, :total_length][:, : eos_pos + 1]
|
| 1413 |
return final_x
|
| 1414 |
|
| 1415 |
x[:, :current_window_end] = cur_x
|
|
|
|
| 1409 |
if len(eos_pos_in_x[0]) > 0:
|
| 1410 |
eos_pos = eos_pos_in_x[0][0].item()
|
| 1411 |
if (cur_x[0, prompt_length:eos_pos] != mask_id).all():
|
| 1412 |
+
final_x = x[:, :total_length][:, prompt_length : eos_pos + 1]
|
| 1413 |
return final_x
|
| 1414 |
|
| 1415 |
x[:, :current_window_end] = cur_x
|