fix: exclude prompt tokens from generate() return value
#1 by haahha, opened
Problem:
- In `generate()`, one return path slices as `final_x = x[:, :total_length][:, :eos_pos + 1]`, which includes the prompt tokens.
- Another path returns `generated_answer[:, input_ids.shape[1] : input_ids.shape[1] + first_mask_position + 1]`, which correctly excludes the prompt.
- As a result, the two return paths give inconsistent outputs: one includes the prompt tokens and the other does not.
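A minimal sketch of the mismatch, using plain Python lists in place of a batch of token IDs (one row per sequence). The helper `slice_cols` is hypothetical and mimics tensor column slicing `x[:, start:stop]`. The token values, the EOS value `2`, and the assumption that `eos_pos` is an absolute index into `x` while `first_mask_position` is relative to the end of the prompt are illustrative only, not taken from the actual `generate()` code.

```python
def slice_cols(x, start, stop):
    """Mimic tensor column slicing x[:, start:stop] on a list of rows."""
    return [row[start:stop] for row in x]

# Hypothetical batch of one sequence: 3 prompt tokens, then generated tokens.
# Token 2 plays the role of eos_id; trailing 0s are padding.
prompt = [101, 102, 103]
generated = [7, 8, 2, 0, 0]
x = [prompt + generated]

prompt_length = len(prompt)                   # plays the role of input_ids.shape[1]
total_length = len(x[0])
first_mask_position = generated.index(2)      # assumed: EOS index relative to prompt end
eos_pos = prompt_length + first_mask_position # assumed: EOS index absolute within x

# Path 1 (before the fix): slicing from column 0 keeps the prompt.
path1 = slice_cols(slice_cols(x, 0, total_length), 0, eos_pos + 1)

# Path 2: slicing from the prompt length drops the prompt.
path2 = slice_cols(x, prompt_length, prompt_length + first_mask_position + 1)

print(path1)  # [[101, 102, 103, 7, 8, 2]]
print(path2)  # [[7, 8, 2]]
```

Both slices end at the same EOS token; they differ only in whether the prompt columns are kept.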
Solution:
- Change the first slice to `final_x = x[:, :total_length][:, prompt_length : eos_pos + 1]` to exclude the prompt tokens.
- This makes all `generate()` return paths consistently match the docstring: the returned tokens start after the prompt and stop at the first `eos_id` or at `gen_length`.
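A quick consistency check of the patched slice against the other return path, again as a sketch on plain lists. `slice_cols`, `fixed_path`, and `other_path` are hypothetical helpers, the EOS value `2` is made up, and the check assumes `prompt_length == input_ids.shape[1]` and `eos_pos == prompt_length + first_mask_position`. Under those assumptions the two paths agree for every EOS position tried.

```python
def slice_cols(x, start, stop):
    """Mimic tensor column slicing x[:, start:stop] on a list of rows."""
    return [row[start:stop] for row in x]

def fixed_path(x, total_length, prompt_length, eos_pos):
    # Patched slice: x[:, :total_length][:, prompt_length : eos_pos + 1]
    return slice_cols(slice_cols(x, 0, total_length), prompt_length, eos_pos + 1)

def other_path(generated_answer, input_len, first_mask_position):
    # Existing slice: generated_answer[:, input_len : input_len + first_mask_position + 1]
    return slice_cols(generated_answer, input_len, input_len + first_mask_position + 1)

prompt = [101, 102, 103]
prompt_length = len(prompt)
EOS = 2  # hypothetical eos_id

# Place the EOS token at each of 4 generated positions and compare both paths.
for k in range(4):
    generated = [5] * k + [EOS] + [0] * (3 - k)
    x = [prompt + generated]
    eos_pos = prompt_length + k  # assumed absolute EOS index in x
    a = fixed_path(x, len(x[0]), prompt_length, eos_pos)
    b = other_path(x, prompt_length, k)
    assert a == b, (a, b)                      # both paths return the same tokens
    assert a[0][-1] == EOS                     # output stops at the first EOS
    assert not any(t in prompt for t in a[0])  # no prompt tokens returned

print("both return paths agree")
```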