---
license: apache-2.0
base_model:
- google/paligemma-3b-mix-448
pipeline_tag: visual-document-retrieval
library_name: transformers
tags:
- transformers
- multimodal_embedding
- embedding
- colpali
- multilingual-embedding
- causalembed
datasets:
- vidore/colpali_train_set
---
# Z1zs/CausalPali
**CausalPali** is an auto-regressive multimodal embedding model based on the **PaliGemma-3B** architecture, proposed by researchers from HKUST(GZ) and Alibaba Cloud. It redefines visual document retrieval by treating embedding creation as a **sequential generation process**, mapping text queries and visual document pages into a compact and expressive latent multi-vector space.
The model employs a novel **CausalEmbed** training paradigm, which fine-tunes the MLLM to synthesize latent representations token-by-token. This approach effectively distills dense visual information into a small set of highly semantic vectors, overcoming the storage bottlenecks of traditional multi-vector models while maintaining superior retrieval accuracy. On benchmarks like Vidore, **CausalPali** achieves significant performance gains over the original ColPali baseline while using a fraction of the storage.
## Key Features
- **Model size**: 3 billion parameters.
- **Auto-regressive generation**: Unlike parallel patch-level encoding, it generates multi-vector embeddings sequentially in a latent space, enabling better modeling of dependencies between visual elements.
- **High efficiency**: Reduces the number of visual tokens by **30-155x** compared to standard ColPali (using only ~64 tokens vs. thousands), significantly lowering storage and memory overhead.
- **Test-time scaling**: Supports dynamic adjustment of the number of generated embedding vectors during inference, allowing users to trade off between retrieval speed and precision without retraining.
- **Superior performance**: Achieves a **14.6% performance uplift** on retrieval tasks compared to the full-resource ColPali baseline, demonstrating that compact, auto-regressively generated embeddings can capture richer semantics.
- **Context-aware compression**: Incorporates iterative margin loss and progressive refinement during training to ensure each generated vector adds maximal marginal information.
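
The efficiency claim above can be sanity-checked with back-of-envelope arithmetic. The figures below are illustrative assumptions (a 128-dim ColPali-style embedding, a 1024-patch page, bfloat16 storage), not measured values from the paper:

```python
# Illustrative storage comparison: standard ColPali stores one vector per
# image patch, while CausalPali stores a fixed budget of generated vectors.
EMB_DIM = 128          # assumption: ColPali-style 128-dim multi-vectors
BYTES_PER_VALUE = 2    # bfloat16 storage

colpali_tokens = 1024  # assumption: e.g. a 32x32 patch grid for one page
causal_tokens = 64     # CausalPali's default document token budget

colpali_bytes = colpali_tokens * EMB_DIM * BYTES_PER_VALUE  # 262,144 bytes/page
causal_bytes = causal_tokens * EMB_DIM * BYTES_PER_VALUE    # 16,384 bytes/page
ratio = colpali_tokens / causal_tokens                      # 16x fewer vectors
```

With higher-resolution inputs (more patches per page) the per-page ratio grows accordingly, which is where the 30-155x range comes from.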
## Usage
**Requirements**

See the [CausalEmbed Repo](https://github.com/Z1zs/Causal-Embed).

**Basic Usage**
```bash
pip install git+https://github.com/Z1zs/Causal-Embed
```
```python
import torch
from PIL import Image

from colpali_engine.models import ColPaliProcessor, CausalPali

device = "cuda" if torch.cuda.is_available() else "cpu"
dtoken_num, qtoken_num = 64, 16  # number of generated doc/query embedding vectors

images = [Image.new("RGB", (32, 32), color="white"), Image.new("RGB", (16, 16), color="black")]
queries = ["Is attention really all you need?", "What is the amount of bananas farmed in Salvador?"]

processor = ColPaliProcessor.from_pretrained(
    "Z1zs/CausalPali",
    max_num_visual_tokens=768,
    fix_mistral_regex=True,
)
processor.tokenizer.pad_token = processor.tokenizer.eos_token

# Register the latent placeholder token used for auto-regressive embedding generation
processor.tokenizer.add_tokens("<|latent|>")
latent_id = processor.tokenizer.convert_tokens_to_ids("<|latent|>")

model = CausalPali.from_pretrained(
    "Z1zs/CausalPali",
    torch_dtype=torch.bfloat16,
    doc_token_num=dtoken_num,    # expected doc token num
    query_token_num=qtoken_num,  # expected query token num
    attn_implementation="flash_attention_2",  # requires flash-attn; omit if unavailable
)
model.latent_token_id = latent_id
model.to(device).eval()

batch_images = processor.process_images(images).to(model.device)
batch_queries = processor.process_queries(queries).to(model.device)

# Forward pass
with torch.no_grad():
    image_embeddings = model(**batch_images)
    query_embeddings = model(**batch_queries)

scores = processor.score_multi_vector(query_embeddings, image_embeddings)
```
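
The final scoring step uses ColBERT-style late interaction (MaxSim) over the multi-vector embeddings: each query vector is matched to its most similar document vector, and the per-token maxima are summed. A standalone sketch of that computation with dummy tensors (not the library's internal code):

```python
import torch
import torch.nn.functional as F

def maxsim_score(query_embs: torch.Tensor, doc_embs: torch.Tensor) -> float:
    # query_embs: (num_query_tokens, dim), doc_embs: (num_doc_tokens, dim)
    # For each query token, take its best-matching doc token (max dot-product
    # similarity), then sum the maxima over all query tokens.
    sim = query_embs @ doc_embs.T               # (q, d) pairwise similarities
    return sim.max(dim=-1).values.sum().item()  # max over doc tokens, sum over query

# Toy example: 16 query vectors vs. two 64-vector documents
torch.manual_seed(0)
q = F.normalize(torch.randn(16, 128), dim=-1)
d1 = F.normalize(torch.randn(64, 128), dim=-1)            # unrelated document
d2 = F.normalize(torch.cat([q, torch.randn(48, 128)]), dim=-1)  # contains the query vectors
scores = [maxsim_score(q, d) for d in (d1, d2)]
# d2 scores higher: every query token finds an exact (similarity 1.0) match
```

Because `d2` embeds the query vectors verbatim, its score is exactly the number of query tokens (16 here), while the random document scores strictly lower.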
## Model Performance
### Vidore v1 + v2 (NDCG@5)
| Model | Token Count | Vidore V1 | Vidore V2 |
|--------------------------------------------|------|---------|---------|
| **CausalQwen** | 64 | 81.1 | 51.6 |
| **CausalPali** | 64 | 75.0 | 45.4 |
### Vidore v3 (NDCG@5)
| Model | Token Count | PUB AVG |
|--------------------------------------------|------|---------|
| **CausalQwen** | 64 | 42.6 |
| **CausalPali** | 64 | 33.6 |
## Citation
If you use this model in your work, please cite:
```bibtex
@misc{huo2026causalembed,
title={CausalEmbed: Auto-Regressive Multi-Vector Generation in Latent Space for Visual Document Embedding},
author={Jiahao Huo and Yu Huang and Yibo Yan and Ye Pan and Yi Cao and Mingdong Ou and Philip S. Yu and Xuming Hu},
year={2026},
eprint={2601.21262},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2601.21262},
}
```