Causal Masking used but Embedding Gemma is supposed to be using bidirectional attention
Hi just to check, I read that Embedding Gemma uses bidirectional attention.
But from what I can see from the transformers code, it seems that a causal mask is used.
This would produce different results from an actual bidirectional attention.
Is this intended/correct?
Hi there 👋 A bidirectional mask is indeed necessary for this model (see https://huggingface.co/google/embeddinggemma-300m/blob/main/config.json#L57).
Make sure you're using the latest version of transformers, where this is indeed taken into account: https://github.com/huggingface/transformers/blob/a7f29523361b2cc12e51c1f5133d95f122f6f45c/src/transformers/models/gemma3/modeling_gemma3.py#L565
Hi
@tltl123
You’ve correctly identified that the model inherits its architecture from a causal decoder. By default, that architecture prevents tokens from looking ahead. To adapt it for embeddings, the config sets a specific flag, use_bidirectional_attention: true (the line Xenova linked in config.json). When it is set, the standard triangular causal mask is replaced with an "all-to-all" attention pattern in which every token can attend to every other token in the sequence.
Additionally, as mentioned by Xenova, please make sure your transformers installation is on version 4.56 or higher, since EmbeddingGemma support and the logic that handles the bidirectional switching are only present in those later versions.
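To make the difference between the two masks concrete, here is a minimal pure-Python sketch. It is not the transformers implementation: `build_mask` and `show` are hypothetical helpers written for illustration, and padding tokens are ignored.

```python
def build_mask(seq_len, bidirectional):
    """Return a seq_len x seq_len mask; mask[i][j] is True if token i may attend to token j."""
    if bidirectional:
        # All-to-all: every token sees every other token.
        return [[True] * seq_len for _ in range(seq_len)]
    # Causal: token i only sees positions 0..i (lower triangle).
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]


def show(mask):
    """Render the mask as rows of 1s and 0s."""
    return "\n".join(" ".join("1" if ok else "0" for ok in row) for row in mask)


causal = build_mask(4, bidirectional=False)
full = build_mask(4, bidirectional=True)
print("causal:\n" + show(causal))
print("bidirectional:\n" + show(full))
```

With the causal mask, the first token can only attend to itself, so its representation never reflects the rest of the sentence. That is why running an embedding model with the wrong mask changes the resulting embeddings.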
Thanks
Got it thanks!
Another question:
From what I understand about sentence embedders, they need to consume the whole input at once. Normally for LLMs, if the input is too long, it would be chunked into smaller lengths, which means you would need to use KV caching. But for sentence embedders, chunking the input means that tokens in chunk 0 cannot attend to tokens in chunk 1 and vice versa. So, since we cannot chunk the input, KV caching would not be needed at all?
How then should you handle cases where your input exceeds the maximum position embedding length?
You are absolutely correct that sentence embedders work very differently from autoregressive LLMs. They behave like encoder-style models: you pass in the entire sequence at once, the model does a single forward pass with full self-attention, and you get an embedding out. There’s no token-by-token generation step. For the same reason, KV caching doesn’t really matter here. KV caching only helps when you’re generating tokens incrementally, which embedding models just don’t do.
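A small sketch of why cached results only stay valid under causal attention. It uses a toy single-head attention in pure Python (the `attention` and `first_token_output` helpers and the toy vectors are made up for illustration, not taken from any library): with a causal mask, appending a token leaves earlier outputs unchanged, while with bidirectional attention every earlier output changes.

```python
import math


def attention(queries, keys, values, causal):
    """Single-head scaled dot-product attention over lists of float vectors."""
    dim = len(queries[0])
    outputs = []
    for i, q in enumerate(queries):
        # Causal: token i sees positions 0..i. Bidirectional: it sees every position.
        visible = range(i + 1) if causal else range(len(keys))
        scores = [sum(a * b for a, b in zip(q, keys[j])) / math.sqrt(dim) for j in visible]
        peak = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - peak) for s in scores]
        total = sum(weights)
        out = [0.0] * len(values[0])
        for w, j in zip(weights, visible):
            for d in range(len(out)):
                out[d] += (w / total) * values[j][d]
        outputs.append(out)
    return outputs


# Toy vectors reused as queries, keys and values.
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]


def first_token_output(n_tokens, causal):
    """Attention output of token 0 when only the first n_tokens are present."""
    prefix = tokens[:n_tokens]
    return attention(prefix, prefix, prefix, causal)[0]


# Does token 0's output survive appending a third token?
causal_stable = first_token_output(2, True) == first_token_output(3, True)
bidir_stable = first_token_output(2, False) == first_token_output(3, False)
print(causal_stable)
print(bidir_stable)
```

This prints True for the causal case and False for the bidirectional one. In a bidirectional model, anything computed for earlier tokens would have to be recomputed once more tokens arrive, so there is nothing useful to cache.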
As for handling input that exceeds the max length:
Models have a hard length limit, so if you exceed it you have to make a tradeoff. A few techniques that work better than hard truncation are:
1. Sliding window + pooling - Split the text into overlapping chunks, embed each chunk independently then pool the chunk embeddings (mean/max/weighted pooling). This defines a document-level embedding, not an approximation of a single forward pass.
2. Hierarchical approaches - Embed chunks first, then aggregate them with another model or learned pooling mechanism (used in long-document retrieval).
3. Use or train a long-context embedder - Models with extended context (e.g. Longformer-style or RoPE-scaled) must be trained or fine-tuned for embeddings; simply extending positional embeddings at inference usually degrades embedding quality.
All of these are tradeoffs; which one to pick depends on your use case, your computational resources, and whether preserving all the information is critical.
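As a rough illustration of option 1, here is a minimal sketch of sliding-window chunking followed by mean pooling. Everything in it is an assumption made for the example: `sliding_windows`, `mean_pool` and `embed_long` are hypothetical helpers, and `toy_embed` is a stand-in for a real model call that returns one vector per chunk.

```python
def sliding_windows(token_ids, window, stride):
    """Split token_ids into overlapping chunks of at most `window` tokens."""
    if stride <= 0 or stride > window:
        raise ValueError("stride must be between 1 and window")
    chunks = []
    start = 0
    while True:
        chunks.append(token_ids[start:start + window])
        if start + window >= len(token_ids):
            break  # this chunk already reaches the end of the input
        start += stride
    return chunks


def mean_pool(vectors):
    """Element-wise mean of equally sized vectors."""
    return [sum(column) / len(vectors) for column in zip(*vectors)]


def embed_long(token_ids, embed_fn, window, stride):
    """Embed each chunk independently, then average the chunk embeddings."""
    chunks = sliding_windows(token_ids, window, stride)
    return mean_pool([embed_fn(chunk) for chunk in chunks])


def toy_embed(chunk):
    # Hypothetical stand-in for a real embedding model.
    return [float(len(chunk)), float(sum(chunk))]


ids = list(range(10))
chunks = sliding_windows(ids, window=4, stride=2)
doc_vec = embed_long(ids, toy_embed, window=4, stride=2)
print(chunks)   # [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9]]
print(doc_vec)  # [4.0, 18.0]
```

In practice `window` would be set to the model's maximum sequence length (minus any special tokens), and the pooled vector is often L2-normalized again if you compare embeddings with cosine similarity.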
Thanks