# audit_assistant/src/colpali/processor.py
"""
ColPali Query Embedding Processor
Handles query embedding generation using ColSmol-500M model.
This is a standalone implementation for inference only (no PDF processing).
"""
import logging
from typing import Optional

import torch

# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)

# colpali_engine is an optional dependency: record availability so that
# ColPaliProcessor can raise a helpful error at construction time instead
# of failing with an ImportError at module import.
try:
    from colpali_engine.models import ColIdefics3, ColIdefics3Processor
except ImportError:
    COLPALI_AVAILABLE = False
    logger.warning("colpali_engine not installed. Install with: pip install colpali-engine")
else:
    COLPALI_AVAILABLE = True
class ColPaliProcessor:
    """
    Processes queries using ColPali for visual document retrieval.

    This is a lightweight processor focused on query embedding generation
    (inference only; no PDF/image ingestion happens here).
    """

    def __init__(
        self,
        model_name: str = "vidore/colSmol-500M",
        device: str = "cpu",
        torch_dtype: torch.dtype = torch.float32,
        batch_size: int = 4
    ):
        """
        Initialize ColPali processor.

        Args:
            model_name: HuggingFace model name for ColPali. If the
                organization prefix is missing, "vidore/" is prepended.
            device: Device to use ("cuda", "cpu", "mps")
            torch_dtype: Data type for model weights
            batch_size: Batch size used by embed_queries()

        Raises:
            ImportError: If colpali_engine is not installed.
        """
        if not COLPALI_AVAILABLE:
            raise ImportError(
                "colpali_engine not installed. Install with: "
                "pip install colpali-engine"
            )

        # Validate model name (must include organization prefix)
        if '/' not in model_name:
            logger.warning(f"⚠️ Model name '{model_name}' missing organization prefix, adding 'vidore/'")
            model_name = f"vidore/{model_name}"

        self.model_name = model_name
        self.device = device
        self.torch_dtype = torch_dtype
        self.batch_size = batch_size

        logger.info(f"πŸ€– Loading ColPali model: {model_name}")
        logger.info(f" Device: {device}, dtype: {torch_dtype}")

        # Load model and processor
        try:
            # Prefer FlashAttention2 on accelerators when the package is
            # installed; fall back to eager attention for compatibility.
            attn_implementation = "eager"
            if device != "cpu":
                try:
                    import flash_attn  # noqa: F401 -- presence check only
                    attn_implementation = "flash_attention_2"
                    logger.info(" Using FlashAttention2 for faster inference")
                except ImportError:
                    logger.info(" FlashAttention2 not available, using eager attention")

            # NOTE(review): `dtype=` is the current transformers kwarg name;
            # older transformers releases expect `torch_dtype=` instead --
            # confirm against the pinned transformers version.
            self.model = ColIdefics3.from_pretrained(
                model_name,
                dtype=torch_dtype,
                device_map=device,
                attn_implementation=attn_implementation
            ).eval()
            self.processor = ColIdefics3Processor.from_pretrained(model_name)

            logger.info("βœ… ColPali model loaded successfully")
            logger.info(f" Attention implementation: {attn_implementation}")
        except Exception as e:
            logger.error(f"❌ Failed to load ColPali model: {e}")
            raise

    def embed_query(self, query_text: str) -> torch.Tensor:
        """
        Generate embedding for a text query.

        Args:
            query_text: Natural language query string

        Returns:
            The raw model output for a batch of one query. Presumably shaped
            [1, num_tokens, embedding_dim] since the query is wrapped in a
            single-element batch -- confirm against colpali_engine docs.
        """
        with torch.no_grad():
            # Process query using ColPali's query processing
            processed_query = self.processor.process_queries([query_text]).to(self.model.device)
            query_embedding = self.model(**processed_query)
        return query_embedding

    def embed_queries(self, query_texts) -> list:
        """
        Generate embeddings for multiple queries, processed in batches.

        Uses self.batch_size (previously stored but unused) to limit how many
        queries are run through the model per forward pass.

        Args:
            query_texts: Sequence of natural language query strings.

        Returns:
            List of per-query embedding tensors, in input order.
        """
        embeddings: list = []
        with torch.no_grad():
            for start in range(0, len(query_texts), self.batch_size):
                batch = list(query_texts[start:start + self.batch_size])
                processed = self.processor.process_queries(batch).to(self.model.device)
                batch_output = self.model(**processed)
                # Split the batched output into one tensor per query.
                embeddings.extend(torch.unbind(batch_output, dim=0))
        return embeddings

    @property
    def embedding_dim(self) -> int:
        """Get the embedding dimension of the model."""
        return self.model.config.hidden_size

    @property
    def image_token_id(self) -> int:
        """Get the image token ID from the processor."""
        return self.processor.image_token_id