# ColGemma4-E2B-IT-Base
See also: ColGemma4-E4B-IT-Base (larger 4.5B-effective variant)
ColGemma4 is a visual document retrieval model built on Google's Gemma 4 E2B (5.1B params). It generates ColBERT-style multi-vector representations of document images and text queries for late-interaction retrieval.
This is the base version - a single-seed LoRA adapter trained with ColbertLoss, no hard negatives, no model merging. It establishes the ColGemma4 baseline on the ViDoRe benchmark.
Built following the ColPali architecture pattern, adapted for Gemma 4's multimodal architecture.
## Model Description
| Property | Value |
|---|---|
| Base model | google/gemma-4-E2B-it (5.1B total, 2.3B effective) |
| Architecture | ColBERT late-interaction over Gemma 4 VLM |
| Embedding dim | 128 |
| Visual tokens | 1120 (max soft tokens) |
| Fine-tuning | LoRA (r=32, alpha=32, dropout=0.1) |
| Trainable params | 50.7M (1.04% of total) |
| Projection | Random-init linear (hidden_size → 128), not trained |
| Training loss | ColbertLoss (temperature=0.02, in-batch negatives only) |
| Precision | BF16 |
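A back-of-envelope consequence of the table above: late interaction stores one 128-dim vector per token, so at the full 1120-token budget each page embedding costs roughly 280 KiB in BF16 (2 bytes per value). A quick sanity calculation:

```python
# Per-page storage cost of the multi-vector index, using the numbers
# from the table above: 1120 visual tokens x 128 dims, BF16 (2 bytes).
tokens, dim, bytes_per_val = 1120, 128, 2
per_page_bytes = tokens * dim * bytes_per_val
print(per_page_bytes)          # 286720 bytes
print(per_page_bytes / 1024)   # 280.0 KiB per page
```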
## Benchmark Results
All scores are nDCG@5 on the ViDoRe benchmark.
### ViDoRe V1
| Task | nDCG@5 | nDCG@10 |
|---|---|---|
| ArxivQA | 78.54 | 80.12 |
| DocVQA | 52.43 | 55.86 |
| InfoVQA | 87.11 | 87.88 |
| ShiftProject | 81.34 | 82.29 |
| SyntheticDocQA - AI | 98.16 | 98.16 |
| SyntheticDocQA - Energy | 92.04 | 92.40 |
| SyntheticDocQA - Government | 94.37 | 94.37 |
| SyntheticDocQA - Healthcare | 93.06 | 93.42 |
| Tabfquad | 84.78 | 86.10 |
| Tatdqa | 69.49 | 72.15 |
| Average | 83.13 | 84.28 |
### ViDoRe V2
| Task | nDCG@5 | nDCG@10 |
|---|---|---|
| BioMedical Lectures | 53.06 | 56.30 |
| ESG Reports - HL | 54.96 | 58.39 |
| ESG Reports | 34.19 | 38.38 |
| Economics Reports | 37.19 | 38.86 |
| Average | 44.85 | 47.98 |
### ViDoRe V3
| Task | nDCG@5 | nDCG@10 |
|---|---|---|
| Computer Science | 55.80 | 59.61 |
| Energy | 56.03 | 59.09 |
| Finance En | 39.61 | 42.12 |
| Finance Fr | 35.82 | 39.03 |
| HR | 38.62 | 42.01 |
| Industrial | 31.98 | 33.73 |
| Pharmaceuticals | 49.03 | 50.68 |
| Physics | 40.82 | 43.89 |
| Average | 43.46 | 46.27 |
## Usage
### Installation

```bash
pip install colpali-engine transformers torch peft
```
### Loading the Model

```python
import torch
from colgemma4 import ColGemma4, ColGemma4Processor

# Load base model + LoRA adapter
model = ColGemma4.from_pretrained(
    "athrael-soju/ColGemma4-E2B-IT-Base",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa",
    ignore_mismatched_sizes=True,  # needed for custom_text_proj
)
processor = ColGemma4Processor.from_pretrained(
    "athrael-soju/ColGemma4-E2B-IT-Base",
    max_num_visual_tokens=1120,
)
```
### Encoding Documents (Images)

```python
from PIL import Image

images = [Image.open("page1.png"), Image.open("page2.png")]
batch_doc = processor.process_images(images)
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}

with torch.no_grad():
    doc_embeddings = model(**batch_doc)  # (batch, seq_len, 128)
```
### Encoding Queries

```python
queries = ["What is the revenue for Q3 2024?"]
batch_query = processor.process_queries(queries)
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}

with torch.no_grad():
    query_embeddings = model(**batch_query)  # (batch, seq_len, 128)
```
### Scoring (MaxSim)

```python
scores = processor.score(query_embeddings, doc_embeddings)
# scores[i][j] = relevance of query i to document j
```
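For reference, the score is the standard ColBERT MaxSim late-interaction formula: for each query token, take the maximum dot product over all document tokens, then sum over query tokens. A minimal pure-Python sketch (toy 2-dim vectors stand in for the model's 128-dim embeddings):

```python
# ColBERT MaxSim: score(q, d) = sum over query tokens of the max dot
# product against all document tokens. Toy vectors, not real embeddings.

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def maxsim(query_emb, doc_emb):
    """query_emb, doc_emb: lists of per-token vectors."""
    return sum(max(dot(q, d) for d in doc_emb) for q in query_emb)

def score_matrix(query_batch, doc_batch):
    # scores[i][j] = relevance of query i to document j
    return [[maxsim(q, d) for d in doc_batch] for q in query_batch]
```

For example, a two-token query whose tokens each find an exact match in a document scores 2.0 against it, but only 1.0 against a document matching one token.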
## Training Configuration

```yaml
Base model: google/gemma-4-E2B-it
Loss: ColbertLoss (temperature=0.02)
Hard negatives: none
Batch size per GPU: 64
GPUs: 7
Gradient accumulation: 1
Effective batch size: 448  # 64 x 7
In-batch negatives: 448
LoRA:
  r: 32
  alpha: 32
  dropout: 0.1
  target_modules: "language_model.*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj)"
  # custom_text_proj is NOT LoRA-targeted (random init, untrained)
Learning rate: 2e-4  # cosine schedule, 8% warmup
Weight decay: 0.02
Epochs: 1
Steps: 1729
Visual tokens: 1120
Attention: Bidirectional (all layers patched)
Gradient checkpointing: enabled
Precision: BF16
```
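The "Bidirectional (all layers patched)" setting means the decoder's causal mask is replaced with a full mask so every token can attend to every other token, as ColBERT-style encoders require. A toy illustration of the two mask patterns (illustrative only, not the actual patch code):

```python
# 1 = may attend, 0 = masked. Causal: token i sees only tokens j <= i.
def causal_mask(n):
    return [[1 if j <= i else 0 for j in range(n)] for i in range(n)]

# Bidirectional: every token sees every token (what the patch enables).
def bidirectional_mask(n):
    return [[1] * n for _ in range(n)]
```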
### Training Data

Trained on ~774K query-document pairs from publicly available datasets:

- `vidore/colpali_train_set`
- `openbmb/VisRAG-Ret-Train-Synthetic-data`
- `openbmb/VisRAG-Ret-Train-In-domain-data`
- `llamaindex/vdr-multilingual-train` (en/de/es/fr/it subsets)
- `vidore/tatdqa_train`
## Troubleshooting & Fine-tuning Guide
If you're building on this model or training your own ColGemma4, here's what we learned along the way.
### Gemma 4 Architecture Gotchas
- **Position embedding memory blow-up** - The vision encoder uses `F.one_hot(positions, num_classes=10240)`, which allocates ~314 GB at batch=32 with 1120 tokens. Replacing it with `F.embedding` is mathematically identical and saves ~106 GB/GPU. Required for batch sizes above 8.

  ```python
  # In Gemma4VisionEncoder._position_embeddings:
  # Replace:
  one_hot = F.one_hot(clamped_positions, num_classes=self.position_embedding_size)
  # With:
  return F.embedding(clamped_positions, self.position_embedding_table)
  ```

- **KV-sharing and `use_cache`** - Gemma 4 E2B has 20 of 35 text layers that reuse K/V from earlier layers via the cache. During training, always set `use_cache=False` so that every layer computes its own K/V and all LoRA weights stay active. At inference time, set `use_cache=True` so the KV-sharing architecture works as designed.
- **Flash Attention is incompatible** - Gemma 4 uses head_dim 256 (sliding) and 512 (global). FA v2 caps at 256, FA v4 at 128. Use `attn_implementation="sdpa"` instead.
- **Gradient checkpointing is mandatory** - At 1120 visual tokens, activations from 35 layers consume ~175 GB. Even with the position patch, disabling gradient checkpointing will OOM.
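The position-embedding fix in the first gotcha works because multiplying a one-hot row by the embedding table just selects one row, and `F.embedding` performs that index lookup directly without materializing the giant one-hot tensor. A tiny pure-Python demonstration of the equivalence (toy 3x2 table; the real table has 10240 rows):

```python
# Toy embedding table: 3 positions, 2 dims (the real one is 10240 rows).
table = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]

def one_hot_lookup(pos, table):
    # What F.one_hot(...) @ table computes: a dense matmul with a one-hot row.
    one_hot = [1.0 if i == pos else 0.0 for i in range(len(table))]
    return [sum(one_hot[i] * table[i][d] for i in range(len(table)))
            for d in range(len(table[0]))]

def embedding_lookup(pos, table):
    # What F.embedding computes: a direct row index, no one-hot tensor.
    return table[pos]
```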
### Training Tips
- **Leave `custom_text_proj` alone** - The 128-dim projection is randomly initialized and works best untrained. Both LoRA-targeting and `modules_to_save` caused regressions in our experiments. The random projection provides a consistent mapping without overfitting.
- **Keep `grad_accum=1` with contrastive losses** - `all_gather` only collects the current micro-batch, so accumulation steps halve your in-batch negatives. Training loss looks deceptively good but eval regresses. Use the largest batch that fits with `grad_accum=1`.
- **Avoid `torch.compile`** - It adds `_orig_mod.` to weight keys, breaking PEFT adapter loading at eval time. Scores drop to near-zero despite healthy training loss.
- **Watch for silent weight randomization** - `ignore_mismatched_sizes=True` will initialize mismatched weights to random without any error. Sanity check: loss at step 0 should be near `log(batch_size)` (~6.1 for batch 448), and grad norms should be 5-20. If grad norms are in the thousands, weights didn't load correctly.
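The step-0 loss check in the last tip follows from softmax cross-entropy over in-batch negatives: a randomly initialized model scores all candidates equally, so the expected loss is log(batch_size):

```python
import math

# With 448 in-batch candidates (64 per GPU x 7 GPUs), a random model's
# expected contrastive cross-entropy loss is ln(448).
effective_batch = 64 * 7
expected_initial_loss = math.log(effective_batch)
print(round(expected_initial_loss, 2))  # ~6.1
```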
## HydraGemma4 (Dual-Head Variant)

This repo also includes the HydraGemma4 architecture: a dual-head model that supports both retrieval (ColBERT embeddings) and generation (text output) from the same base model. See `hydra_gemma4.py` for the implementation.
## Limitations
- Single-seed, no model merging - variance across seeds is not averaged out
- No hard negative mining - relies entirely on in-batch negatives
- English-centric training data (some multilingual from VDR)
- Visual token budget fixed at 1120 - train and eval must match
## Citation

```bibtex
@misc{colgemma4,
  title={ColGemma4: Visual Document Retrieval with Gemma 4},
  author={Athrael Soju},
  year={2026},
  url={https://huggingface.co/athrael-soju/ColGemma4-E2B-IT-Base}
}
```
## Acknowledgements
- ColPali by Faysse et al. for the ColBERT-over-VLM architecture
- Google Gemma 4 for the base model
- colpali-engine for the training framework
- ViDoRe benchmark for evaluation