---
license: apache-2.0
base_model:
- google/paligemma-3b-mix-448
pipeline_tag: visual-document-retrieval
library_name: transformers
tags:
- transformers
- multimodal_embedding
- embedding
- colpali
- multilingual-embedding
- causalembed
datasets:
- vidore/colpali_train_set
---

# Z1zs/CausalPali

**CausalPali** is an auto-regressive multimodal embedding model built on the **PaliGemma-3B** architecture, proposed by researchers from HKUST(GZ) and Alibaba Cloud. It reframes visual document retrieval by treating embedding creation as a **sequential generation process**, mapping text queries and visual document pages into a compact, expressive latent multi-vector space.

The model is trained with the novel **CausalEmbed** paradigm, which fine-tunes the MLLM to synthesize latent representations token by token. This distills dense visual information into a small set of highly semantic vectors, overcoming the storage bottleneck of traditional multi-vector models while maintaining strong retrieval accuracy. On benchmarks such as ViDoRe, **CausalPali** achieves significant performance gains over the original ColPali baseline while using a fraction of the storage.

## Key Features

- **Model size**: 3 billion parameters.
- **Auto-regressive generation**: Unlike parallel patch-level encoding, CausalPali generates multi-vector embeddings sequentially in latent space, enabling better modeling of dependencies between visual elements.
- **High efficiency**: Reduces the number of visual tokens by **30-155x** compared to standard ColPali (~64 tokens vs. thousands), significantly lowering storage and memory overhead.
- **Test-time scaling**: The number of generated embedding vectors can be adjusted at inference time, letting users trade off retrieval speed against precision without retraining.
- **Superior performance**: Achieves a **14.6% performance uplift** on retrieval tasks over the full-resource ColPali baseline, demonstrating that compact, auto-regressively generated embeddings can capture richer semantics.
- **Context-aware compression**: Training uses an iterative margin loss with progressive refinement so that each generated vector adds maximal marginal information.

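To make the storage savings concrete, here is a back-of-the-envelope calculation. The 128-dim embedding size and the ~1,024-patch-per-page ColPali count are illustrative assumptions, not figures from this model card; the actual 30-155x reduction depends on image resolution and patching configuration.

```python
def index_size_gb(num_pages, tokens_per_page, dim=128, bytes_per_val=2):
    """Approximate multi-vector index size in GB (bfloat16 storage)."""
    return num_pages * tokens_per_page * dim * bytes_per_val / 1e9

# Illustrative corpus of 1M document pages.
colpali_gb = index_size_gb(1_000_000, 1024)  # ~1,024 patch vectors per page (assumed)
causal_gb = index_size_gb(1_000_000, 64)     # 64 generated vectors per page

print(f"ColPali-style index: {colpali_gb:.1f} GB")  # 262.1 GB
print(f"CausalPali index:    {causal_gb:.1f} GB")   # 16.4 GB
print(f"Reduction:           {colpali_gb / causal_gb:.0f}x")
```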
## Usage

**Requirements**

See the [CausalEmbed repo](https://github.com/Z1zs/Causal-Embed).

**Basic Usage**

```bash
pip install git+https://github.com/Z1zs/Causal-Embed
```

```python
import torch
from PIL import Image

from colpali_engine.models import ColPaliProcessor, CausalPali

device = "cuda" if torch.cuda.is_available() else "cpu"
dtoken_num, qtoken_num = 64, 16  # number of generated doc / query embedding vectors

images = [
    Image.new("RGB", (32, 32), color="white"),
    Image.new("RGB", (16, 16), color="black"),
]
queries = [
    "Is attention really all you need?",
    "What is the amount of bananas farmed in Salvador?",
]

processor = ColPaliProcessor.from_pretrained(
    "Z1zs/CausalPali",
    max_num_visual_tokens=768,
    fix_mistral_regex=True,
)
processor.tokenizer.pad_token = processor.tokenizer.eos_token
processor.tokenizer.add_tokens("<|latent|>")
latent_id = processor.tokenizer.convert_tokens_to_ids("<|latent|>")

model = CausalPali.from_pretrained(
    "Z1zs/CausalPali",
    torch_dtype=torch.bfloat16,
    doc_token_num=dtoken_num,    # expected doc token num
    query_token_num=qtoken_num,  # expected query token num
    attn_implementation="flash_attention_2",
)
model.latent_token_id = latent_id
model.to(device).eval()

batch_images = processor.process_images(images).to(model.device)
batch_queries = processor.process_queries(queries).to(model.device)

# Forward pass
with torch.no_grad():
    image_embeddings = model(**batch_images)
    query_embeddings = model(**batch_queries)

scores = processor.score_multi_vector(query_embeddings, image_embeddings)
```

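For reference, `score_multi_vector` computes ColBERT-style late-interaction (MaxSim) scores: each query vector is matched against its best document vector, and the maxima are summed. A minimal NumPy sketch of that computation (the function name and toy shapes here are illustrative, not the library's actual implementation):

```python
import numpy as np

def maxsim_score(query_emb, doc_emb):
    """Late-interaction (MaxSim) score between one query and one document.

    query_emb: (n_query_tokens, dim); doc_emb: (n_doc_tokens, dim).
    For each query vector, take its max dot product over all doc vectors,
    then sum over query vectors.
    """
    sim = query_emb @ doc_emb.T          # (n_query_tokens, n_doc_tokens)
    return sim.max(axis=1).sum()

# Toy example: 2 query vectors, 3 doc vectors, dim 4.
q = np.array([[1.0, 0.0, 0.0, 0.0],
              [0.0, 1.0, 0.0, 0.0]])
d = np.array([[0.5, 0.0, 0.0, 0.0],
              [0.0, 2.0, 0.0, 0.0],
              [0.0, 0.0, 1.0, 0.0]])
print(maxsim_score(q, d))  # 0.5 + 2.0 = 2.5
```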
## Model Performance

### ViDoRe v1 + v2 (NDCG@5)

| Model          | Token Count | ViDoRe V1 | ViDoRe V2 |
|----------------|-------------|-----------|-----------|
| **CausalQwen** | 64          | 81.1      | 51.6      |
| **CausalPali** | 64          | 75.0      | 45.4      |

### ViDoRe v3 (NDCG@5)

| Model          | Token Count | PUB AVG |
|----------------|-------------|---------|
| **CausalQwen** | 64          | 42.6    |
| **CausalPali** | 64          | 33.6    |

## Citation

If you use this model in your work, please cite:

```bibtex
@misc{huo2026causalembed,
      title={CausalEmbed: Auto-Regressive Multi-Vector Generation in Latent Space for Visual Document Embedding},
      author={Jiahao Huo and Yu Huang and Yibo Yan and Ye Pan and Yi Cao and Mingdong Ou and Philip S. Yu and Xuming Hu},
      year={2026},
      eprint={2601.21262},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2601.21262},
}
```