haahha committed · Commit 87620ad · verified · 1 Parent(s): d699e90

fix: exclude prompt tokens in generate() return value


Problem:
- In generate(), one return path slices as `final_x = x[:, :total_length][:, :eos_pos + 1]`, which includes the prompt tokens.
- Another path, `return generated_answer[:, input_ids.shape[1] : input_ids.shape[1] + first_mask_position + 1 ]`, correctly excludes the prompt.
- As a result, the tokens returned by generate() are inconsistent between the two return paths.

Solution:
- Change to `final_x = x[:, :total_length][:, prompt_length : eos_pos + 1]` to exclude prompt tokens.
- Ensures all generate() return paths consistently match the docstring: returned tokens start after the prompt and stop at the first `eos_id` or at `gen_length`. A toy illustration of the slicing change follows below.
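
A minimal sketch of the before/after slicing on toy tensors (the values of `prompt_length`, `eos_pos`, and `total_length` are made up here; inside generate() they are computed from the prompt and the decoded window):

```python
import torch

# Toy layout: 4 prompt tokens followed by generated tokens, with eos at index 7.
prompt_length, eos_pos, total_length = 4, 7, 10
x = torch.arange(10).unsqueeze(0)  # shape (1, 10); token ids double as positions

old = x[:, :total_length][:, : eos_pos + 1]                # previous slice
new = x[:, :total_length][:, prompt_length : eos_pos + 1]  # fixed slice

print(old)  # tensor([[0, 1, 2, 3, 4, 5, 6, 7]]) -> prompt tokens 0..3 leak into the output
print(new)  # tensor([[4, 5, 6, 7]])             -> generated tokens only, as documented
```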

Files changed (1)
  1. modeling_llada2_moe.py +1 -1
modeling_llada2_moe.py CHANGED
@@ -1409,7 +1409,7 @@ class LLaDA2MoeModelLM(LLaDA2MoePreTrainedModel, GenerationMixin):
             if len(eos_pos_in_x[0]) > 0:
                 eos_pos = eos_pos_in_x[0][0].item()
                 if (cur_x[0, prompt_length:eos_pos] != mask_id).all():
-                    final_x = x[:, :total_length][:, : eos_pos + 1]
+                    final_x = x[:, :total_length][:, prompt_length : eos_pos + 1]
                     return final_x
 
         x[:, :current_window_end] = cur_x
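
For callers, the practical effect is that generate() output no longer needs manual prompt stripping. A hedged usage sketch follows; the `model` and `tokenizer` objects and the `gen_length` keyword are illustrative assumptions, not taken verbatim from this repo's documentation:

```python
prompt = "Explain masked diffusion decoding in one sentence."
inputs = tokenizer(prompt, return_tensors="pt")          # assumed Hugging Face-style tokenizer
out = model.generate(inputs.input_ids, gen_length=128)   # gen_length assumed from the docstring

# With this fix, `out` holds only generated tokens on every return path, so
# slicing off the prompt (e.g. out[:, inputs.input_ids.shape[1]:]) is unnecessary.
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
```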