Unexpected Warning When Loading `google/shieldgemma-2-4b-it` & Low Accuracy on Custom Dataset

#8
by Haulyn5 - opened

Hello,

I recently tried the google/shieldgemma-2-4b-it model, strictly following the example code in the model's README (with only the `token` parameter added). However, I got an unexpected warning while the model was loading:

Some weights of ShieldGemma2ForImageClassification were not initialized from the model checkpoint at google/shieldgemma-2-4b-it and are newly initialized: ['model.lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

Additionally, when evaluating the model on my custom dataset, its accuracy only reached approximately **50%**—a result far below my expectations. Given the warning above, I suspect this performance discrepancy may be linked to the uninitialized weights mentioned.

I would greatly appreciate clarification on the following questions:

  1. Is this warning expected behavior when loading google/shieldgemma-2-4b-it?
  2. Could this uninitialized weight warning be the root cause of the unexpectedly low accuracy on my dataset?
  3. What steps are recommended to debug or verify whether the model weights have been correctly initialized?

Environment Details:

  • Python version: 3.11.2
  • PyTorch version: 2.8.0+cu128
  • Transformers version: 4.56.2
  • OS: Linux 5.4.143-amd64
  • GPU: Tesla V100-SXM2-32GB

Thanks in advance for your time and assistance!

CC: @merve , @BalakrishnaCh , @Renu11 , @RyanMullins

Google org

Hi @Haulyn5,

Thanks for reaching out. Yes, this warning is generally expected when a language model checkpoint is loaded into a class with a task-specific head, but it is important to understand which weights are affected. When you call ShieldGemma2ForImageClassification.from_pretrained(), the Hugging Face library loads the original pre-trained weights and then tries to map them into a model class that includes a specific image classification head. The flagged weight, ['model.lm_head.weight'], is the language model (LM) head from the original base Gemma architecture.
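To make the warning more concrete, here is a minimal sketch of the comparison behind a "newly initialized" message: the parameter names the model class expects are checked against the names stored in the checkpoint. The `diff_keys` helper and the key names below are hypothetical and only for illustration; this is not the library's actual loading code.

```python
def diff_keys(checkpoint_keys, model_keys):
    """Return (missing, unexpected) key lists, sorted for stable output."""
    checkpoint_keys, model_keys = set(checkpoint_keys), set(model_keys)
    # Expected by the model class but absent from the checkpoint:
    # these are the weights reported as "newly initialized".
    missing = sorted(model_keys - checkpoint_keys)
    # Stored in the checkpoint but not used by the model class.
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Hypothetical key names, for illustration only.
checkpoint_keys = ["model.embed_tokens.weight", "model.layers.0.weight"]
model_keys = checkpoint_keys + ["model.lm_head.weight"]

missing, unexpected = diff_keys(checkpoint_keys, model_keys)
print(missing)     # ['model.lm_head.weight']
print(unexpected)  # []
```

With the real model, passing `output_loading_info=True` to `from_pretrained()` should return the model together with a dictionary whose `missing_keys` entry carries the same information as the warning, which may be easier to inspect than the log line.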

The ShieldGemma models are instruction-tuned for a specific safety evaluation task: given an image and a policy text, the model outputs a "Yes" or "No" token.
To use the model for image classification on your custom dataset, you need to fine-tune it. The warning gives exactly that advice: "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
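As an illustration of the "Yes"/"No" output described above, the sketch below turns two token scores into probabilities with a two-way softmax. This is an assumption about how such a decision can be scored, not a description of the model's actual post-processing; the function name and the logit values are made up.

```python
import math


def yes_no_probabilities(yes_logit, no_logit):
    """Two-way softmax over the "Yes" and "No" logits; returns (p_yes, p_no)."""
    # Subtract the max before exponentiating for numerical stability.
    m = max(yes_logit, no_logit)
    e_yes = math.exp(yes_logit - m)
    e_no = math.exp(no_logit - m)
    total = e_yes + e_no
    return e_yes / total, e_no / total


# Made-up logits, for illustration only.
p_yes, p_no = yes_no_probabilities(2.0, 0.0)
answer = "Yes" if p_yes >= p_no else "No"
print(round(p_yes, 3), round(p_no, 3), answer)  # 0.881 0.119 Yes
```

Equal logits give 0.5 for each answer, so inspecting these probabilities on a few examples is one way to see whether the model is making confident decisions or hovering near an even split.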

Recommended action:
The definitive fix for this problem is to fine-tune the model on your custom dataset. This process will train the randomly initialized lm_head (or the actual classification head being used) to learn the mapping from the model's internal features to your specific output labels, leveraging the powerful pre-trained Gemma backbone.

Thanks.
