hf gemma 3 pt generate bug

#15
by dglasscortex - opened

One can reproduce it by running the following code:

import torch
from transformers import AutoTokenizer, Gemma3ForCausalLM

ckpt = "google/gemma-3-1b-pt"
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = Gemma3ForCausalLM.from_pretrained(
    ckpt,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

prompt = "Eiffel tower is located in"
model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

input_len = model_inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**model_inputs, max_new_tokens=50, do_sample=False)
    generation = generation[0][input_len:]

decoded = tokenizer.decode(generation, skip_special_tokens=True)
print(decoded)

Expected: text without unusual spacing around periods and without repetitions.
Actual:  " the heart of Paris, France.The Eiffel Tower is a symbol of Paris and France.The Eiffel Tower is a symbol of Paris and France.The Eiffel Tower is a symbol of Paris and France.The Eiffel Tower is a symbol"

Notice the unusual spacing between "France." and "The Eiffel" that occurs multiple times within 50 tokens. Also notice the repetitions of "The Eiffel Tower is a symbol of Paris and France".

Notes:

  1. This repros for both gemma-3-1b-pt and gemma-3-27b-pt
  2. I think it repros on CPU and on accelerators with slightly different text, but similar problems.
  3. gemma-2-9b (also a pt model) output for the same prompt, also using greedy decoding, looks free of the above issues: " Paris, France. It is the most visited monument in the world. It is 324 meters tall and was built in 1889. It is made of iron and has 1,665 steps. It is a".
Google org

Hi @dglasscortex,
Thanks for reaching out. The repetition and spacing issues you're seeing are typical of a small 1B model combined with greedy decoding on a base pretraining checkpoint. To mitigate them, I recommend setting do_sample=True and adding a repetition_penalty. For even better results, you could switch to an instruction-tuned checkpoint.
I've updated the generate call to reflect these changes:
generation = model.generate(**model_inputs, max_new_tokens=50, do_sample=True, repetition_penalty=1.1, no_repeat_ngram_size=4)
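For intuition on no_repeat_ngram_size=4: it bans any token that would complete a 4-gram already present in the generated sequence. A minimal, framework-free sketch of that check (a toy illustration, not the transformers implementation):

```python
def banned_next_tokens(ids, n=4):
    """Return the token ids that would complete an n-gram already in `ids`."""
    if len(ids) < n - 1:
        return set()
    prefix = tuple(ids[-(n - 1):])  # the last n-1 generated tokens
    banned = set()
    for i in range(len(ids) - n + 1):
        # every historical n-gram whose first n-1 tokens match the current
        # prefix contributes its final token to the banned set
        if tuple(ids[i:i + n - 1]) == prefix:
            banned.add(ids[i + n - 1])
    return banned

# After generating [1, 2, 3, 4, 1, 2, 3], the 4-gram (1, 2, 3, 4) already
# exists, so token 4 is banned as the next choice.
print(banned_next_tokens([1, 2, 3, 4, 1, 2, 3]))  # -> {4}
```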

You can also fix the spacing issue by adding this at the end (with import re at the top; note the dot must be escaped so that only literal periods match): decoded = re.sub(r"\.(\S)", r". \1", decoded)

This is the output I got:
"Paris, and you need around 10-20 minutes to drive there. It will take one hour if you want to walk from Eiffel Tower station directly to the site. I think it depends on your purpose of visit. If it is"

Please implement these changes and let me know if they resolve the issue.
Thanks

Hi @pannaga10,

Thanks for the quick reply.

  1. I believe these issues are not caused by the model being small, given the same symptoms repro for gemma 3 27B pt.
  2. I also believe these issues are not caused by the sampling being greedy, given:
    (a) Producing a token without a space ("France.The") should be very unlikely. Under greedy sampling this would require, for example, that P(no-space-The | "Eiffel tower is located in the heart of Paris, France.") is the highest-likelihood next token, which is unexpected. Digging a bit deeper, adding a tokenizer.batch_decode(generation) at the end of the snippet I provided yields the following:
    [' the', ' heart', ' of', ' Paris', ',', ' France', '.', '<start_of_image>', 'The', ' Eiffel', ' Tower', ' is' (...)]

This means the most likely next token is, unexpectedly, <start_of_image> (which I assume decodes to nothing visible). I verified the same is happening with gemma 3 27b pt.

    (b) These symptoms do not repro using greedy sampling with gemma 2 2b.
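The fused "France.The" text is exactly what decoding with skip_special_tokens=True produces when the stray token is dropped; a tiny illustration of that effect:

```python
# the token sequence from batch_decode above (abridged)
tokens = [' France', '.', '<start_of_image>', 'The', ' Eiffel']
# skip_special_tokens=True effectively removes the special token, and no
# space is left behind where it used to be
visible = ''.join(t for t in tokens if t != '<start_of_image>')
print(visible)  # -> ' France.The Eiffel'
```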

I agree that greedy sampling and a small model size can lead to repetitions, but the current symptoms exceed what I'd expect. The sampling parameters you suggested may help mask the issue, but I think it's important to fix the underlying problem to make sure the probability distribution itself is bug-free. A key initial question is: why is P(x | "Eiffel tower is located in the heart of Paris, France.") highest for x = <start_of_image>? You can verify this is currently the case with the following snippet:

(...)
>>> model_inputs = tokenizer("Eiffel tower is located in the heart of Paris, France.", return_tensors="pt").to(model.device)
>>> logits_BLV = model(**model_inputs).logits
>>> tokenizer.decode(torch.argmax(logits_BLV[0, -1, :]))
'<start_of_image>'

Thanks for your help so far!


Hi @dglasscortex,
You are right that this behavior can't be explained by greedy decoding or model size alone. The earlier suggestion to enable sampling and repetition penalties can improve surface-level output quality, but it does not address the underlying cause you identified. In Gemma 3 PT checkpoints, <start_of_image> is a valid vocab token and is not currently masked for text-only generation. I agree this could be addressed at the model level. I will escalate this for further review.
Thanks
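Until the token is masked at the model level, one possible interim workaround (my suggestion, not an official fix) is to ban it at generation time via generate's bad_words_ids argument, which amounts to setting the banned logits to -inf before the next token is chosen. A minimal, framework-free sketch of that masking:

```python
import math

def suppress_tokens(logits, banned_ids):
    """Return a copy of logits with banned token ids set to -inf,
    so greedy decoding (or sampling) can never pick them."""
    out = list(logits)
    for i in banned_ids:
        out[i] = -math.inf
    return out

# Toy vocab: id 2 stands in for <start_of_image> and holds the top logit.
logits = [1.0, 0.5, 3.2, 2.1]
masked = suppress_tokens(logits, banned_ids=[2])
best = max(range(len(masked)), key=masked.__getitem__)
print(best)  # -> 3: greedy decoding falls back to the best remaining token
```

With transformers this would look something like model.generate(**model_inputs, bad_words_ids=[[tokenizer.convert_tokens_to_ids("<start_of_image>")]]).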

Thank you, @pannaga10! Looking forward to your team's investigation. My rough understanding is that Gemma 3 models in general take image(s) as input but do not generate them, so I would assume that during training there is never prediction loss on <start_of_image> as a target token, which would mean the token should never receive high probability, even though it is in the vocabulary.
