fix: exclude prompt tokens in generate() return value

#1
by haahha - opened

Problem:

  • In generate(), one return path slices with final_x = x[:, :total_length][:, :eos_pos + 1], which keeps the prompt tokens in the output.
  • Another path returns generated_answer[:, input_ids.shape[1] : input_ids.shape[1] + first_mask_position + 1], which correctly excludes the prompt.
  • As a result, the two return paths give different outputs for the same generation: one includes the prompt tokens and the other does not.
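
To make the mismatch concrete, here is a toy model of the two slices using plain lists in place of tensors. The token ids are made up. The sketch also assumes that eos_pos is an absolute index into the row (prompt included) and that first_mask_position is relative to the start of generation, so both paths are meant to stop at the same EOS token.

```python
# Toy model of the two return paths in generate(); lists stand in for tensors.
# Assumption: eos_pos is absolute, first_mask_position is relative to the generation start.
EOS_ID = 2  # illustrative token id

prompt = [101, 102, 103]
x = [prompt + [7, 8, EOS_ID, 0, 0]]  # one row: prompt, generated tokens, padding
prompt_length = len(prompt)
total_length = len(x[0])

eos_pos = x[0].index(EOS_ID)                   # absolute position of EOS: 5
first_mask_position = eos_pos - prompt_length  # same token, relative position: 2

# Path 1 (before the fix): x[:, :total_length][:, :eos_pos + 1]
path_1 = [row[:total_length][: eos_pos + 1] for row in x]

# Path 2: generated_answer[:, input_ids.shape[1] : input_ids.shape[1] + first_mask_position + 1]
path_2 = [row[prompt_length : prompt_length + first_mask_position + 1] for row in x]

print(path_1)  # [[101, 102, 103, 7, 8, 2]]  prompt leaks into the output
print(path_2)  # [[7, 8, 2]]
```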

Solution:

  • Change the first path to final_x = x[:, :total_length][:, prompt_length : eos_pos + 1] so that prompt tokens are excluded.
  • This makes every generate() return path match the docstring: returned tokens start after the prompt and stop at the first eos_id or at gen_length.
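
The docstring contract can be expressed as a small standalone helper. This is a hypothetical sketch, not the repository's code: slice_generated and its parameters are illustrative names. It searches for eos_id only inside the generation window, so an EOS token that happens to appear in the prompt is not mistaken for the stop point.

```python
from typing import List


def slice_generated(row: List[int], prompt_length: int, gen_length: int, eos_id: int) -> List[int]:
    """Hypothetical helper: return only generated tokens, starting after the prompt and
    stopping at the first eos_id (inclusive) or after gen_length tokens."""
    end = prompt_length + gen_length  # absolute end of the generation window
    window = row[prompt_length:end]
    if eos_id in window:
        # Absolute position of the first EOS inside the generated region.
        eos_pos = prompt_length + window.index(eos_id)
        end = eos_pos + 1
    # Same shape as the fixed slice: x[:, :total_length][:, prompt_length : eos_pos + 1]
    return row[prompt_length:end]


print(slice_generated([101, 102, 103, 7, 8, 2, 0, 0], 3, 5, 2))  # [7, 8, 2]
print(slice_generated([101, 102, 103, 7, 8, 9], 3, 3, 2))        # [7, 8, 9] (no EOS, stops at gen_length)
print(slice_generated([2, 102, 103, 7, 8], 3, 2, 2))             # [7, 8] (EOS in prompt is ignored)
```

Applying one helper like this on every return path is one way to keep the paths from drifting apart again.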
