ColGemma4

ColGemma4-E2B-IT-Base

See also: ColGemma4-E4B-IT-Base (larger 4.5B-effective variant)

ColGemma4 is a visual document retrieval model built on Google's Gemma 4 E2B (5.1B params). It generates ColBERT-style multi-vector representations of document images and text queries for late-interaction retrieval.

This is the base version: a single-seed LoRA adapter trained with ColbertLoss, without hard negatives or model merging. It establishes the ColGemma4 baseline on the ViDoRe benchmark.

It follows the ColPali architecture pattern, adapted to Gemma 4's multimodal stack.

Model Description

Property Value
Base model google/gemma-4-E2B-it (5.1B total, 2.3B effective)
Architecture ColBERT late-interaction over Gemma 4 VLM
Embedding dim 128
Visual tokens 1120 (max soft tokens)
Fine-tuning LoRA (r=32, alpha=32, dropout=0.1)
Trainable params 50.7M (1.04% of total)
Projection Random-init linear (hidden_size → 128), not trained
Training loss ColbertLoss (temperature=0.02, in-batch negatives only)
Precision BF16

Benchmark Results

All scores are nDCG@5 and nDCG@10 on the ViDoRe benchmark.

ViDoRe V1

Task nDCG@5 nDCG@10
ArxivQA 78.54 80.12
DocVQA 52.43 55.86
InfoVQA 87.11 87.88
ShiftProject 81.34 82.29
SyntheticDocQA - AI 98.16 98.16
SyntheticDocQA - Energy 92.04 92.40
SyntheticDocQA - Government 94.37 94.37
SyntheticDocQA - Healthcare 93.06 93.42
Tabfquad 84.78 86.10
Tatdqa 69.49 72.15
Average 83.13 84.28

ViDoRe V2

Task nDCG@5 nDCG@10
BioMedical Lectures 53.06 56.30
ESG Reports - HL 54.96 58.39
ESG Reports 34.19 38.38
Economics Reports 37.19 38.86
Average 44.85 47.98

ViDoRe V3

Task nDCG@5 nDCG@10
Computer Science 55.80 59.61
Energy 56.03 59.09
Finance En 39.61 42.12
Finance Fr 35.82 39.03
HR 38.62 42.01
Industrial 31.98 33.73
Pharmaceuticals 49.03 50.68
Physics 40.82 43.89
Average 43.46 46.27

Usage

Installation

pip install colpali-engine transformers torch peft

Loading the Model

import torch
from colgemma4 import ColGemma4, ColGemma4Processor

# Load base model + LoRA adapter
model = ColGemma4.from_pretrained(
    "athrael-soju/ColGemma4-E2B-IT-Base",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa",
    ignore_mismatched_sizes=True,  # needed for custom_text_proj
)

processor = ColGemma4Processor.from_pretrained(
    "athrael-soju/ColGemma4-E2B-IT-Base",
    max_num_visual_tokens=1120,
)

Encoding Documents (Images)

from PIL import Image

images = [Image.open("page1.png"), Image.open("page2.png")]
batch_doc = processor.process_images(images)
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}

with torch.no_grad():
    doc_embeddings = model(**batch_doc)  # (batch, seq_len, 128)

Encoding Queries

queries = ["What is the revenue for Q3 2024?"]
batch_query = processor.process_queries(queries)
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}

with torch.no_grad():
    query_embeddings = model(**batch_query)  # (batch, seq_len, 128)

Scoring (MaxSim)

scores = processor.score(query_embeddings, doc_embeddings)
# scores[i][j] = relevance of query i to document j
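Conceptually, processor.score computes the standard ColBERT late-interaction MaxSim: each query token is matched against its most similar document token, and those maxima are summed. A minimal plain-PyTorch sketch of that scoring rule (maxsim_scores is an illustrative helper, not the colpali-engine API, and it assumes padded, L2-normalized embeddings without masking):

```python
import torch

def maxsim_scores(query_embs: torch.Tensor, doc_embs: torch.Tensor) -> torch.Tensor:
    """Late-interaction MaxSim.

    query_embs: (num_queries, q_len, dim), L2-normalized per token
    doc_embs:   (num_docs, d_len, dim), L2-normalized per token
    returns:    (num_queries, num_docs) relevance matrix
    """
    # Token-level cosine similarities: (num_queries, num_docs, q_len, d_len)
    sim = torch.einsum("qid,pjd->qpij", query_embs, doc_embs)
    # For each query token, keep its best-matching document token,
    # then sum those maxima over the query tokens.
    return sim.max(dim=-1).values.sum(dim=-1)
```

The library implementation additionally masks padding tokens, so treat this as the scoring idea rather than a drop-in replacement.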

Training Configuration

Base model: google/gemma-4-E2B-it
Loss: ColbertLoss (temperature=0.02)
Hard negatives: none
Batch size per GPU: 64
GPUs: 7
Gradient accumulation: 1
Effective batch size: 448 (64 x 7)
In-batch negatives: 448
LoRA:
  r: 32
  alpha: 32
  dropout: 0.1
  target_modules: "language_model.*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj)"
  # custom_text_proj is NOT LoRA-targeted (random init, untrained)
Learning rate: 2e-4 (cosine schedule, 8% warmup)
Weight decay: 0.02
Epochs: 1
Steps: 1,729
Visual tokens: 1120
Attention: Bidirectional (all layers patched)
Gradient checkpointing: enabled
Precision: BF16
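With in-batch negatives only, ColbertLoss reduces to a contrastive cross-entropy over the MaxSim score matrix: pair i is the positive for query i, and the other documents in the batch are negatives. A minimal sketch of that objective (colbert_inbatch_loss is illustrative; the actual colpali-engine implementation may differ in masking, gathering across GPUs, and reduction):

```python
import torch
import torch.nn.functional as F

def colbert_inbatch_loss(query_embs: torch.Tensor,
                         doc_embs: torch.Tensor,
                         temperature: float = 0.02) -> torch.Tensor:
    """Contrastive cross-entropy over MaxSim scores with in-batch negatives.

    query_embs: (batch, q_len, dim); doc_embs: (batch, d_len, dim).
    Document i is the positive for query i.
    """
    # MaxSim score matrix over all query/document pairs: (batch, batch)
    sim = torch.einsum("qid,pjd->qpij", query_embs, doc_embs)
    scores = sim.max(dim=-1).values.sum(dim=-1)
    # Diagonal entries are the positives
    targets = torch.arange(scores.size(0), device=scores.device)
    return F.cross_entropy(scores / temperature, targets)
```

The low temperature (0.02) sharpens the softmax, which is why score gaps between the positive and in-batch negatives dominate the loss.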

Training Data

Trained on ~774K query-document pairs from publicly available datasets:

  • vidore/colpali_train_set
  • openbmb/VisRAG-Ret-Train-Synthetic-data
  • openbmb/VisRAG-Ret-Train-In-domain-data
  • llamaindex/vdr-multilingual-train (en/de/es/fr/it subsets)
  • vidore/tatdqa_train

Troubleshooting & Fine-tuning Guide

If you're building on this model or training your own ColGemma4, here's what we learned along the way.

Gemma 4 Architecture Gotchas

  • Position embedding memory blow-up - The vision encoder uses F.one_hot(positions, num_classes=10240) which allocates ~314 GB at batch=32 with 1120 tokens. Replacing with F.embedding is mathematically identical and saves ~106 GB/GPU. Required for batch sizes above 8.
    # In Gemma4VisionEncoder._position_embeddings:
    # Replace: one_hot = F.one_hot(clamped_positions, num_classes=self.position_embedding_size)
    # With:    return F.embedding(clamped_positions, self.position_embedding_table)
    
  • KV-sharing and use_cache - Gemma 4 E2B has 20 of 35 text layers that reuse K/V from earlier layers via the cache. During training, always set use_cache=False to ensure every layer computes its own K/V and all LoRA weights are active. At inference time, set use_cache=True so the KV-sharing architecture works as designed.
  • Flash Attention is incompatible - Gemma 4 has head_dim 256 (sliding) and 512 (global). FA v2 caps at 256, FA v4 at 128. Use attn_implementation="sdpa" instead.
  • Gradient checkpointing is mandatory - At 1120 visual tokens, activations from 35 layers consume ~175 GB. Even with the position patch, disabling gradient checkpointing will OOM.
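The position-embedding patch above works because a one-hot matmul selects exactly the same rows of the embedding table as a direct F.embedding lookup; the only difference is the huge intermediate one-hot tensor. A quick self-contained check of the equivalence (table size shrunk from 10240 for the demo):

```python
import torch
import torch.nn.functional as F

num_classes, dim = 64, 16          # demo sizes; the real encoder uses 10240
table = torch.randn(num_classes, dim)
positions = torch.randint(0, num_classes, (2, 7))

# Original path: materialize a (2, 7, num_classes) one-hot tensor, then matmul
via_one_hot = F.one_hot(positions, num_classes=num_classes).to(table.dtype) @ table
# Patched path: direct row lookup, no one-hot allocation
via_embedding = F.embedding(positions, table)

assert torch.allclose(via_one_hot, via_embedding)
```

The memory saving scales with num_classes, which is why the patch matters at 10240 classes and 1120 visual tokens per image.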

Training Tips

  • Leave custom_text_proj alone - The 128-dim projection is randomly initialized and works best untrained. Both LoRA-targeting and modules_to_save caused regressions in our experiments. The random projection provides a consistent mapping without overfitting.
  • Keep grad_accum=1 with contrastive losses - all_gather only collects the current micro-batch, so accumulating over k micro-batches divides your in-batch negatives by k. Training loss looks deceptively good but eval regresses. Use the largest batch that fits with grad_accum=1.
  • Avoid torch.compile - It adds _orig_mod. to weight keys, breaking PEFT adapter loading at eval time. Scores drop to near-zero despite healthy training loss.
  • Watch for silent weight randomization - ignore_mismatched_sizes=True will initialize mismatched weights to random without any error. Sanity check: loss at step 0 should be near log(batch_size) (~6.1 for batch 448), and grad norms should be 5-20. If grad norms are in the thousands, weights didn't load correctly.
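The log(batch_size) rule of thumb in the last tip follows from a uniform softmax: at step 0 the score matrix is roughly random, so each of the 448 in-batch candidates gets probability ~1/448 and cross-entropy is ln(448):

```python
import math

batch_size = 448  # 64 per GPU x 7 GPUs
# Cross-entropy of a uniform distribution over batch_size candidates
expected_step0_loss = math.log(batch_size)
print(f"{expected_step0_loss:.2f}")  # ~6.10
```

A step-0 loss far from this value is a quick signal that weights were silently re-initialized.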

HydraGemma4 (Dual-Head Variant)

This repo also includes the HydraGemma4 architecture - a dual-head model that supports both retrieval (ColBERT embeddings) and generation (text output) from the same base model. See hydra_gemma4.py for the implementation.

Limitations

  • Single-seed, no model merging - variance across seeds is not averaged out
  • No hard negative mining - relies entirely on in-batch negatives
  • English-centric training data (some multilingual from VDR)
  • Visual token budget fixed at 1120 - train and eval must match

Citation

@misc{colgemma4,
  title={ColGemma4: Visual Document Retrieval with Gemma 4},
  author={Athrael Soju},
  year={2026},
  url={https://huggingface.co/athrael-soju/ColGemma4-E2B-IT-Base}
}
