---
license: apache-2.0
base_model:
  - google/paligemma-3b-mix-448
pipeline_tag: visual-document-retrieval
library_name: transformers
tags:
  - transformers
  - multimodal_embedding
  - embedding
  - colpali
  - multilingual-embedding
  - causalembed
datasets:
  - vidore/colpali_train_set
---

# Z1zs/CausalPali

CausalPali is an auto-regressive multimodal embedding model based on the PaliGemma-3B architecture, proposed by researchers from HKUST(GZ) and Alibaba Cloud. It redefines visual document retrieval by treating embedding creation as a sequential generation process, mapping text queries and visual document pages into a compact and expressive latent multi-vector space.

The model employs a novel CausalEmbed training paradigm, which fine-tunes the MLLM to synthesize latent representations token by token. This approach distills dense visual information into a small set of highly semantic vectors, overcoming the storage bottleneck of traditional multi-vector models while maintaining strong retrieval accuracy. On benchmarks such as ViDoRe, CausalPali achieves significant performance gains over the original ColPali baseline while using a fraction of the storage.

## Key Features

- **Model size:** 3 billion parameters.
- **Auto-regressive generation:** unlike parallel patch-level encoding, the model generates multi-vector embeddings sequentially in a latent space, enabling better modeling of dependencies between visual elements.
- **High efficiency:** reduces the number of visual tokens by 30-155x compared to standard ColPali (only ~64 tokens vs. thousands), significantly lowering storage and memory overhead.
- **Test-time scaling:** supports dynamic adjustment of the number of generated embedding vectors at inference, letting users trade off retrieval speed against precision without retraining.
- **Superior performance:** achieves a 14.6% uplift on retrieval tasks over the full-resource ColPali baseline, demonstrating that compact, auto-regressively generated embeddings can capture richer semantics.
- **Context-aware compression:** incorporates an iterative margin loss and progressive refinement during training so that each generated vector adds maximal marginal information.
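To make the efficiency claim concrete, here is a back-of-the-envelope storage comparison. The baseline figure of 2048 patch vectors per page is an illustrative assumption (actual patch counts depend on image resolution); the 128-dimensional bfloat16 multi-vectors match the ColPali convention:

```python
# Hypothetical figures for illustration: a page that patch-level ColPali-style
# encoding would store as 2048 patch vectors, vs. the ~64 vectors CausalPali keeps.
dim, bytes_per_elem = 128, 2  # 128-d multi-vectors stored in bfloat16
baseline_tokens, causal_tokens = 2048, 64

baseline_bytes = baseline_tokens * dim * bytes_per_elem
causal_bytes = causal_tokens * dim * bytes_per_elem

print(f"baseline: {baseline_bytes / 1024:.0f} KiB per page")  # 512 KiB
print(f"causal:   {causal_bytes / 1024:.0f} KiB per page")    # 16 KiB
print(f"reduction: {baseline_tokens // causal_tokens}x")      # 32x
```

At corpus scale this difference dominates index size, which is why the token count is the key knob for the test-time scaling feature above.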

## Usage

**Requirements:** see the [CausalEmbed repository](https://github.com/Z1zs/Causal-Embed).

### Basic Usage

```shell
pip install git+https://github.com/Z1zs/Causal-Embed
```
```python
import torch
from PIL import Image
from colpali_engine.models import ColPaliProcessor, CausalPali

device = "cuda" if torch.cuda.is_available() else "cpu"
dtoken_num, qtoken_num = 64, 16  # number of generated doc / query embedding vectors

images = [Image.new("RGB", (32, 32), color="white"), Image.new("RGB", (16, 16), color="black")]
queries = ["Is attention really all you need?", "What is the amount of bananas farmed in Salvador?"]

processor = ColPaliProcessor.from_pretrained(
    "Z1zs/CausalPali",
    max_num_visual_tokens=768,
    fix_mistral_regex=True,
)
processor.tokenizer.pad_token = processor.tokenizer.eos_token
processor.tokenizer.add_tokens("<|latent|>")
latent_id = processor.tokenizer.convert_tokens_to_ids("<|latent|>")

model = CausalPali.from_pretrained(
    "Z1zs/CausalPali",
    torch_dtype=torch.bfloat16,
    doc_token_num=dtoken_num,    # expected doc token num
    query_token_num=qtoken_num,  # expected query token num
    attn_implementation="flash_attention_2",
)
model.latent_token_id = latent_id
model.to(device).eval()

batch_images = processor.process_images(images).to(model.device)
batch_queries = processor.process_queries(queries).to(model.device)

# Forward pass
with torch.no_grad():
    image_embeddings = model(**batch_images)
    query_embeddings = model(**batch_queries)

scores = processor.score_multi_vector(query_embeddings, image_embeddings)
```
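For reference, multi-vector scoring of this kind follows the late-interaction MaxSim scheme popularized by ColBERT and ColPali: each query vector is matched against its best document vector, and the matches are summed. A minimal self-contained sketch (random tensors stand in for real embeddings; this is an illustration, not the library's exact implementation):

```python
import torch

def maxsim_scores(query_embs: torch.Tensor, doc_embs: torch.Tensor) -> torch.Tensor:
    """Late-interaction MaxSim scoring.

    query_embs: (num_queries, q_tokens, dim)
    doc_embs:   (num_docs, d_tokens, dim)
    Returns a (num_queries, num_docs) score matrix.
    """
    # Full similarity grid: (num_queries, num_docs, q_tokens, d_tokens)
    sim = torch.einsum("qnd,cmd->qcnm", query_embs, doc_embs)
    # Max over document tokens, then sum over query tokens
    return sim.max(dim=-1).values.sum(dim=-1)

queries = torch.randn(2, 16, 128)  # 2 queries x 16 query vectors x 128-d
docs = torch.randn(3, 64, 128)     # 3 pages   x 64 doc vectors   x 128-d
scores = maxsim_scores(queries, docs)
print(scores.shape)  # torch.Size([2, 3])
```

The highest-scoring page for each query is then simply `scores.argmax(dim=1)`.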

## Model Performance

### ViDoRe v1 + v2 (NDCG@5)

| Model | Token Count | ViDoRe V1 | ViDoRe V2 |
|---|---|---|---|
| CausalQwen | 64 | 81.1 | 51.6 |
| CausalPali | 64 | 75.0 | 45.4 |

### ViDoRe v3 (NDCG@5)

| Model | Token Count | PUB AVG |
|---|---|---|
| CausalQwen | 64 | 42.6 |
| CausalPali | 64 | 33.6 |

## Citation

If you use this model in your work, please cite:

```bibtex
@misc{huo2026causalembed,
      title={CausalEmbed: Auto-Regressive Multi-Vector Generation in Latent Space for Visual Document Embedding},
      author={Jiahao Huo and Yu Huang and Yibo Yan and Ye Pan and Yi Cao and Mingdong Ou and Philip S. Yu and Xuming Hu},
      year={2026},
      eprint={2601.21262},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2601.21262},
}
```