Add eos_token_id to fix infinite generation loop

#3
by chbae624 - opened

TranslateGemma models emit the <end_of_turn> token (ID: 106) after finishing a translation, but the eos_token_id field in generation_config.json does not include this token. Generation therefore does not stop at the end of the turn and keeps going until max_new_tokens is reached, resulting in:

  1. Extremely slow inference: the model generates unnecessary tokens all the way up to max_new_tokens (thousands of them when the limit is large)
  2. Repeated <end_of_turn> tokens in the output
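The stopping behaviour described above can be illustrated with a toy decode loop. This is a minimal sketch, not the transformers implementation: `toy_generate` and `fake_model` are hypothetical, and the translation token IDs are made up. The fake model emits two "translation" tokens and then only 106, mimicking the reported behaviour, so the loop stops early only when 106 is in the stop set.

```python
from typing import Callable, List, Sequence

# Hypothetical prompt and translation token IDs, for illustration only.
PROMPT = [2, 105]
TRANSLATION = [5001, 5002]


def toy_generate(
    next_token: Callable[[List[int]], int],
    prompt: Sequence[int],
    eos_token_ids: Sequence[int],
    max_new_tokens: int,
) -> List[int]:
    """Decode until a token in eos_token_ids appears or max_new_tokens is hit."""
    tokens = list(prompt)
    new_tokens: List[int] = []
    stop_ids = set(eos_token_ids)
    for _ in range(max_new_tokens):
        tok = next_token(tokens)
        tokens.append(tok)
        new_tokens.append(tok)
        if tok in stop_ids:  # the stop token is kept in the output, then we halt
            break
    return new_tokens


def fake_model(tokens: List[int]) -> int:
    """Hypothetical model: emits the translation, then 106 (<end_of_turn>) forever."""
    step = len(tokens) - len(PROMPT)
    return TRANSLATION[step] if step < len(TRANSLATION) else 106


# Only ID 1 as EOS: 106 never stops generation, so all 50 steps are used.
without_fix = toy_generate(fake_model, PROMPT, eos_token_ids=[1], max_new_tokens=50)
# IDs 1 and 106 as EOS: generation stops right after the first <end_of_turn>.
with_fix = toy_generate(fake_model, PROMPT, eos_token_ids=[1, 106], max_new_tokens=50)

print(len(without_fix))  # 50
print(with_fix)          # [5001, 5002, 106]
```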

Reproduction

from transformers import AutoModelForImageTextToText, AutoProcessor
import torch

model_id = "google/translategemma-12b-it"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)

# Translate "Hello" from English (en) to Korean (ko)
messages = [{"role": "user", "content": [{"type": "text", "source_lang_code": "en", "target_lang_code": "ko", "text": "Hello"}]}]
inputs = processor.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt").to(model.device)

# Greedy decoding; with only the default EOS token configured, this runs until max_new_tokens
with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=50, do_sample=False)

# Keep only the newly generated tokens and leave special tokens visible so repeated <end_of_turn> tokens show up
generated = output[0][inputs["input_ids"].shape[-1]:]
print(processor.decode(generated, skip_special_tokens=False))

Current output:
안녕하세요... (repeats until max_new_tokens)
Expected output:
안녕하세요

Solution

Add "eos_token_id": [1, 106] to generation_config.json:

1 = <eos>, the default end-of-sequence token
106 = <end_of_turn>, the turn terminator used by the instruction-tuned model
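For anyone patching a local copy of the model before the fix is released, the edit can be scripted. This is a minimal sketch assuming a local generation_config.json file; `add_eos_token_ids` is a hypothetical helper using only the standard library, not part of transformers. It accepts eos_token_id whether it is missing, a single int, or a list, and always writes a list back.

```python
import json
import tempfile
from pathlib import Path
from typing import Iterable, List, Union


def add_eos_token_ids(config_path: Union[str, Path], token_ids: Iterable[int]) -> List[int]:
    """Hypothetical helper: merge token_ids into eos_token_id of a generation_config.json."""
    path = Path(config_path)
    config = json.loads(path.read_text(encoding="utf-8"))
    current = config.get("eos_token_id")
    # Normalise the existing value (missing, int, or list) to a list.
    if current is None:
        merged: List[int] = []
    elif isinstance(current, int):
        merged = [current]
    else:
        merged = list(current)
    # Append new IDs in order, skipping ones already present.
    for tok in token_ids:
        if tok not in merged:
            merged.append(tok)
    config["eos_token_id"] = merged
    path.write_text(json.dumps(config, indent=2) + "\n", encoding="utf-8")
    return merged


# Demo on a throwaway file that only lists the default EOS token.
with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / "generation_config.json"
    demo_path.write_text(json.dumps({"eos_token_id": 1}), encoding="utf-8")
    print(add_eos_token_ids(demo_path, [1, 106]))  # [1, 106]
```

Alternatively, without editing any file, transformers' generate() accepts eos_token_id as a per-call override, e.g. model.generate(**inputs, max_new_tokens=50, do_sample=False, eos_token_id=[1, 106]).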

Thank you very much!
Setting eos_token_id as you suggested worked for me.

Google org

Thanks!

RyanMullins changed pull request status to merged
