| """ | |
| Visual Embedder - Generate visual and text embeddings for document retrieval. | |
| This module provides a flexible interface that supports: | |
| - ColPali models (ColSmol, ColPali, ColQwen2) | |
| - Other vision-language models (future) | |
| - Image embedding with tile-aware processing | |
| - Query embedding with special token filtering | |
| The embedder is BACKEND-AGNOSTIC - configure which model to use via the | |
| `backend` parameter or model_name. | |
| """ | |

import gc
import logging
import os
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

logger = logging.getLogger(__name__)


class VisualEmbedder:
    """
    Visual document embedder supporting multiple backends.

    Currently supports:
    - ColPali family (ColSmol-500M, ColPali, ColQwen2)
    - More backends can be added

    Args:
        model_name: HuggingFace model name (e.g., "vidore/colSmol-500M")
        backend: Backend type ("colpali", "auto"). "auto" detects from model_name.
        device: Device to use (auto, cuda, mps, cpu)
        torch_dtype: Data type for model weights
        output_dtype: NumPy dtype for exported arrays (defaults based on torch_dtype)
        batch_size: Batch size for image processing
        filter_special_tokens: Filter special tokens from query embeddings
        processor_speed: Processor variant to load ("fast", "slow", or "auto")

    Example:
        >>> # Auto-detect backend from model name
        >>> embedder = VisualEmbedder(model_name="vidore/colSmol-500M")
        >>>
        >>> # Embed images
        >>> image_embeddings = embedder.embed_images(images)
        >>>
        >>> # Embed query
        >>> query_embedding = embedder.embed_query("What is the budget?")
        >>>
        >>> # Get token info for saliency maps
        >>> embeddings, token_infos = embedder.embed_images(
        ...     images, return_token_info=True
        ... )
    """

    # Known model families and their backends
    MODEL_BACKENDS = {
        "colsmol": "colpali",
        "colpali": "colpali",
        "colqwen": "colpali",
        "colidefics": "colpali",
    }

    def __init__(
        self,
        model_name: str = "vidore/colSmol-500M",
        backend: str = "auto",
        device: Optional[str] = None,
        torch_dtype: Optional[torch.dtype] = None,
        output_dtype: Optional[np.dtype] = None,
        batch_size: int = 4,
        filter_special_tokens: bool = True,
        processor_speed: str = "fast",
    ):
        self.model_name = model_name
        self.batch_size = batch_size
        self.filter_special_tokens = filter_special_tokens

        if processor_speed not in ("fast", "slow", "auto"):
            raise ValueError("processor_speed must be one of: fast, slow, auto")
        self.processor_speed = processor_speed

        if os.getenv("VISUALRAG_INCLUDE_SPECIAL_TOKENS"):
            self.filter_special_tokens = False
            logger.info("Special token filtering disabled via VISUALRAG_INCLUDE_SPECIAL_TOKENS")

        if backend == "auto":
            backend = self._detect_backend(model_name)
        self.backend = backend

        # Resolve device: prefer CUDA, then Apple MPS, then CPU
        if device is None:
            if torch.cuda.is_available():
                device = "cuda"
            elif torch.backends.mps.is_available():
                device = "mps"
            else:
                device = "cpu"
        self.device = device

        # Default to bfloat16 on CUDA, float32 elsewhere
        if torch_dtype is None:
            if device == "cuda":
                torch_dtype = torch.bfloat16
            else:
                torch_dtype = torch.float32
        self.torch_dtype = torch_dtype

        if output_dtype is None:
            if torch_dtype == torch.float16:
                output_dtype = np.float16
            else:
                output_dtype = np.float32
        self.output_dtype = output_dtype

        # Model and processor are lazy-loaded on first use
        self._model = None
        self._processor = None
        self._image_token_id = None

        logger.info("🤖 VisualEmbedder initialized")
        logger.info(f"   Model: {model_name}")
        logger.info(f"   Backend: {backend}")
        logger.info(
            f"   Device: {device}, torch_dtype: {torch_dtype}, output_dtype: {output_dtype}"
        )

    def _detect_backend(self, model_name: str) -> str:
        """Auto-detect backend from model name."""
        model_lower = model_name.lower()
        for key, backend in self.MODEL_BACKENDS.items():
            if key in model_lower:
                logger.debug(f"Detected backend '{backend}' from model name")
                return backend
        # Default to colpali for unknown models
        logger.warning(f"Unknown model '{model_name}', defaulting to 'colpali' backend")
        return "colpali"

    def _load_model(self):
        """Lazy load the model when first needed."""
        if self._model is not None:
            return
        if self.backend == "colpali":
            self._load_colpali_model()
        else:
            raise ValueError(f"Unknown backend: {self.backend}")

    def _load_colpali_model(self):
        """Load a ColPali-family model and its processor."""
        try:
            from colpali_engine.models import (
                ColIdefics3,
                ColIdefics3Processor,
                ColPali,
                ColPaliProcessor,
                ColQwen2,
                ColQwen2Processor,
            )
        except ImportError:
            raise ImportError(
                "colpali_engine not installed. Install with: "
                "pip install visual-rag-toolkit[embedding] or "
                "pip install colpali-engine"
            )

        logger.info(f"🤖 Loading ColPali model: {self.model_name}")
        logger.info(f"   Device: {self.device}, dtype: {self.torch_dtype}")

        def _processor_kwargs():
            if self.processor_speed == "auto":
                return {}
            return {"use_fast": self.processor_speed == "fast"}

        from transformers import AutoConfig

        cfg = AutoConfig.from_pretrained(self.model_name)
        model_type = str(getattr(cfg, "model_type", "") or "").lower()

        if model_type == "colpali" or "colpali" in (self.model_name or "").lower():
            self._model = ColPali.from_pretrained(
                self.model_name,
                torch_dtype=self.torch_dtype,
                device_map=self.device,
            ).eval()
            try:
                self._processor = ColPaliProcessor.from_pretrained(
                    self.model_name, **_processor_kwargs()
                )
            except TypeError:
                # Older processor versions do not accept use_fast
                self._processor = ColPaliProcessor.from_pretrained(self.model_name)
            except Exception:
                if self.processor_speed == "fast":
                    # Fall back to the slow processor if the fast one fails to load
                    self._processor = ColPaliProcessor.from_pretrained(
                        self.model_name, use_fast=False
                    )
                else:
                    raise
            self._image_token_id = self._processor.image_token_id
            logger.info("✅ Loaded ColPali backend")
            return

        if model_type.startswith("qwen2") or "colqwen" in (self.model_name or "").lower():
            self._model = ColQwen2.from_pretrained(
                self.model_name,
                dtype=self.torch_dtype,
                device_map=self.device,
            ).eval()
            try:
                self._processor = ColQwen2Processor.from_pretrained(
                    self.model_name, device_map=self.device, **_processor_kwargs()
                )
            except TypeError:
                self._processor = ColQwen2Processor.from_pretrained(
                    self.model_name, device_map=self.device
                )
            except Exception:
                if self.processor_speed == "fast":
                    self._processor = ColQwen2Processor.from_pretrained(
                        self.model_name, device_map=self.device, use_fast=False
                    )
                else:
                    raise
            self._image_token_id = self._processor.image_token_id
            logger.info("✅ Loaded ColQwen2 backend")
            return

        # Default branch: ColIdefics3 (ColSmol); use FlashAttention2 when available
        attn_implementation = "eager"
        if self.device != "cpu":
            try:
                import flash_attn  # noqa: F401

                attn_implementation = "flash_attention_2"
                logger.info("   Using FlashAttention2")
            except ImportError:
                pass

        self._model = ColIdefics3.from_pretrained(
            self.model_name,
            dtype=self.torch_dtype,
            device_map=self.device,
            attn_implementation=attn_implementation,
        ).eval()
        try:
            self._processor = ColIdefics3Processor.from_pretrained(
                self.model_name, **_processor_kwargs()
            )
        except TypeError:
            self._processor = ColIdefics3Processor.from_pretrained(self.model_name)
        except Exception:
            if self.processor_speed == "fast":
                self._processor = ColIdefics3Processor.from_pretrained(
                    self.model_name, use_fast=False
                )
            else:
                raise
        self._image_token_id = self._processor.image_token_id
        logger.info("✅ Model loaded successfully")

    @property
    def model(self):
        """The underlying model, loaded on first access."""
        self._load_model()
        return self._model

    @property
    def processor(self):
        """The underlying processor, loaded on first access."""
        self._load_model()
        return self._processor

    @property
    def image_token_id(self):
        """Token ID used for image patches, loading the model on first access."""
        self._load_model()
        return self._image_token_id

    def embed_query(
        self,
        query_text: str,
        filter_special_tokens: Optional[bool] = None,
    ) -> torch.Tensor:
        """
        Generate embedding for a text query.

        By default, filters out special tokens (CLS, SEP, PAD) to keep only
        meaningful text tokens for better MaxSim matching.

        Args:
            query_text: Natural language query string
            filter_special_tokens: Override instance-level setting

        Returns:
            Query embedding tensor of shape [num_tokens, embedding_dim]
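
        Example:
            >>> # Illustrative; assumes an initialized embedder
            >>> q = embedder.embed_query("What is the budget?")
            >>> q.dim()  # [num_tokens, embedding_dim]
            2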
| """ | |
| should_filter = ( | |
| filter_special_tokens | |
| if filter_special_tokens is not None | |
| else self.filter_special_tokens | |
| ) | |
| with torch.no_grad(): | |
| processed = self.processor.process_queries([query_text]).to(self.model.device) | |
| embedding = self.model(**processed) | |
| # Remove batch dimension: [1, tokens, dim] -> [tokens, dim] | |
| if embedding.dim() == 3: | |
| embedding = embedding.squeeze(0) | |
| if should_filter: | |
| # Filter special tokens based on attention mask | |
| attention_mask = processed.get("attention_mask") | |
| if attention_mask is not None: | |
| # Keep only tokens with attention_mask = 1 | |
| valid_mask = attention_mask.squeeze(0).bool() | |
| embedding = embedding[valid_mask] | |
| # Additionally filter padding tokens if present | |
| input_ids = processed.get("input_ids") | |
| if input_ids is not None: | |
| input_ids = input_ids.squeeze(0)[valid_mask] | |
| # Filter common special token IDs | |
| # IDs >= 4 are usually real tokens for most tokenizers | |
| non_special_mask = input_ids >= 4 | |
| if non_special_mask.any(): | |
| embedding = embedding[non_special_mask] | |
| logger.debug(f"Query embedding: {embedding.shape[0]} tokens after filtering") | |
| else: | |
| logger.debug(f"Query embedding: {embedding.shape[0]} tokens (unfiltered)") | |
| return embedding | |

    def embed_queries(
        self,
        query_texts: List[str],
        batch_size: Optional[int] = None,
        filter_special_tokens: Optional[bool] = None,
        show_progress: bool = True,
    ) -> List[torch.Tensor]:
        """
        Generate embeddings for a list of text queries.

        Returns a list of tensors, each of shape [num_tokens, embedding_dim].
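
        Example:
            >>> # Illustrative; assumes an initialized embedder
            >>> embs = embedder.embed_queries(["budget?", "timeline?"], show_progress=False)
            >>> len(embs)
            2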
| """ | |
| should_filter = ( | |
| filter_special_tokens | |
| if filter_special_tokens is not None | |
| else self.filter_special_tokens | |
| ) | |
| batch_size = batch_size or self.batch_size | |
| outputs: List[torch.Tensor] = [] | |
| iterator = range(0, len(query_texts), batch_size) | |
| if show_progress: | |
| iterator = tqdm(iterator, desc="📝 Embedding queries", unit="batch") | |
| for i in iterator: | |
| batch = query_texts[i : i + batch_size] | |
| with torch.no_grad(): | |
| processed = self.processor.process_queries(batch).to(self.model.device) | |
| batch_embeddings = self.model(**processed) | |
| if isinstance(batch_embeddings, torch.Tensor) and batch_embeddings.dim() == 3: | |
| attn = processed.get("attention_mask") if should_filter else None | |
| input_ids = processed.get("input_ids") if should_filter else None | |
| for j in range(batch_embeddings.shape[0]): | |
| emb = batch_embeddings[j] | |
| if should_filter and attn is not None: | |
| valid_mask = attn[j].bool() | |
| emb = emb[valid_mask] | |
| if input_ids is not None: | |
| ids = input_ids[j][valid_mask] | |
| non_special_mask = ids >= 4 | |
| if non_special_mask.any(): | |
| emb = emb[non_special_mask] | |
| outputs.append(emb) | |
| else: | |
| outputs.extend(batch_embeddings) | |
| del processed, batch_embeddings | |
| gc.collect() | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| elif torch.backends.mps.is_available(): | |
| torch.mps.empty_cache() | |
| return outputs | |

    def embed_images(
        self,
        images: List[Image.Image],
        batch_size: Optional[int] = None,
        return_token_info: bool = False,
        show_progress: bool = True,
    ) -> Union[List[torch.Tensor], Tuple[List[torch.Tensor], List[Dict[str, Any]]]]:
        """
        Generate embeddings for a list of images.

        Args:
            images: List of PIL Images
            batch_size: Override instance batch size
            return_token_info: Also return token metadata (for saliency maps)
            show_progress: Show progress bar

        Returns:
            If return_token_info=False:
                List of embedding tensors [num_patches, dim]
            If return_token_info=True:
                Tuple of (embeddings, token_infos)

            Token info contains:
            - visual_token_indices: Indices of visual tokens in embedding
            - num_visual_tokens: Count of visual tokens
            - n_rows, n_cols: Tile grid dimensions
            - num_tiles: Total tiles (n_rows × n_cols + 1 global)
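
        Example:
            >>> # Illustrative; `images` is a list of PIL.Image pages
            >>> embs, infos = embedder.embed_images(images, return_token_info=True)
            >>> infos[0]["num_visual_tokens"] == len(infos[0]["visual_token_indices"])
            True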
| """ | |
| batch_size = batch_size or self.batch_size | |
| if ( | |
| self.device == "mps" | |
| and "colpali" in (self.model_name or "").lower() | |
| and int(batch_size) > 1 | |
| ): | |
| batch_size = 1 | |
| embeddings = [] | |
| token_infos = [] if return_token_info else None | |
| iterator = range(0, len(images), batch_size) | |
| if show_progress: | |
| iterator = tqdm(iterator, desc="🎨 Embedding", unit="batch") | |
| for i in iterator: | |
| batch = images[i : i + batch_size] | |
| with torch.no_grad(): | |
| processed = self.processor.process_images(batch).to(self.model.device) | |
| # Extract token info before model forward | |
| if return_token_info: | |
| input_ids = processed["input_ids"] | |
| batch_n_rows = processed.get("n_rows") | |
| batch_n_cols = processed.get("n_cols") | |
| for j in range(input_ids.shape[0]): | |
| # Find visual token indices | |
| image_token_mask = input_ids[j] == self.image_token_id | |
| visual_indices = torch.where(image_token_mask)[0].cpu().numpy().tolist() | |
| n_rows = batch_n_rows[j].item() if batch_n_rows is not None else None | |
| n_cols = batch_n_cols[j].item() if batch_n_cols is not None else None | |
| token_infos.append( | |
| { | |
| "visual_token_indices": visual_indices, | |
| "num_visual_tokens": len(visual_indices), | |
| "n_rows": n_rows, | |
| "n_cols": n_cols, | |
| "num_tiles": (n_rows * n_cols + 1) if n_rows and n_cols else None, | |
| } | |
| ) | |
| # Generate embeddings | |
| batch_embeddings = self.model(**processed) | |
| # Extract per-image embeddings | |
| if isinstance(batch_embeddings, torch.Tensor) and batch_embeddings.dim() == 3: | |
| for j in range(batch_embeddings.shape[0]): | |
| embeddings.append(batch_embeddings[j].cpu()) | |
| else: | |
| embeddings.extend([e.cpu() for e in batch_embeddings]) | |
| # Memory cleanup | |
| del processed, batch_embeddings | |
| gc.collect() | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| elif torch.backends.mps.is_available(): | |
| torch.mps.empty_cache() | |
| if return_token_info: | |
| return embeddings, token_infos | |
| return embeddings | |

    def extract_visual_embedding(
        self,
        full_embedding: torch.Tensor,
        token_info: Dict[str, Any],
    ) -> np.ndarray:
        """
        Extract only visual token embeddings from full embedding.

        Filters out special tokens, keeping only visual patches for MaxSim.

        Args:
            full_embedding: Full embedding [all_tokens, dim]
            token_info: Token info dict from embed_images

        Returns:
            Visual embedding array [num_visual_tokens, dim]
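
        Example:
            >>> # Illustrative; uses embed_images(..., return_token_info=True) output
            >>> visual = embedder.extract_visual_embedding(embs[0], infos[0])
            >>> visual.shape[0] == infos[0]["num_visual_tokens"]
            True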
| """ | |
| visual_indices = token_info["visual_token_indices"] | |
| if isinstance(full_embedding, torch.Tensor): | |
| if full_embedding.dtype == torch.bfloat16: | |
| visual_emb = full_embedding[visual_indices].cpu().float().numpy() | |
| else: | |
| visual_emb = full_embedding[visual_indices].cpu().numpy() | |
| else: | |
| visual_emb = np.array(full_embedding)[visual_indices] | |
| return visual_emb.astype(self.output_dtype) | |

    def mean_pool_visual_embedding(
        self,
        visual_embedding: Union[torch.Tensor, np.ndarray],
        token_info: Optional[Dict[str, Any]] = None,
        *,
        target_vectors: int = 32,
    ) -> np.ndarray:
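        """
        Mean-pool visual token embeddings into a compact set of vectors.

        ColSmol models are pooled per tile; other ColPali-style models are
        pooled per row of an inferred square patch grid, which must match
        `target_vectors`.

        Args:
            visual_embedding: Visual token embeddings [num_visual_tokens, dim]
            token_info: Optional token info dict from embed_images
            target_vectors: Expected number of pooled vectors (grid rows)

        Returns:
            Pooled embedding array [num_pooled_vectors, dim]

        Example:
            >>> # Illustrative; `visual` comes from extract_visual_embedding
            >>> pooled = embedder.mean_pool_visual_embedding(visual, infos[0])
        """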
        from visual_rag.embedding.pooling import colpali_row_mean_pooling, tile_level_mean_pooling

        model_lower = (self.model_name or "").lower()
        is_colsmol = "colsmol" in model_lower

        if isinstance(visual_embedding, torch.Tensor):
            if visual_embedding.dtype == torch.bfloat16:
                visual_np = visual_embedding.cpu().float().numpy()
            else:
                visual_np = visual_embedding.cpu().numpy().astype(np.float32)
        else:
            visual_np = np.array(visual_embedding, dtype=np.float32)

        if is_colsmol:
            # ColSmol: one pooled vector per tile (n_rows * n_cols local + 1 global)
            n_rows = (token_info or {}).get("n_rows")
            n_cols = (token_info or {}).get("n_cols")
            num_tiles = int(n_rows) * int(n_cols) + 1 if n_rows and n_cols else 13
            return tile_level_mean_pooling(
                visual_np, num_tiles=num_tiles, patches_per_tile=64, output_dtype=self.output_dtype
            )

        # ColPali-style: infer a square patch grid and pool per row
        num_tokens = int(visual_np.shape[0])
        grid = int(round(float(num_tokens) ** 0.5))
        if grid * grid != num_tokens:
            raise ValueError(
                f"Cannot infer square grid from num_visual_tokens={num_tokens} for model={self.model_name}"
            )
        if int(target_vectors) != int(grid):
            raise ValueError(
                f"target_vectors={target_vectors} does not match inferred grid_size={grid} for model={self.model_name}"
            )
        return colpali_row_mean_pooling(
            visual_np, grid_size=int(target_vectors), output_dtype=self.output_dtype
        )

    def global_pool_from_mean_pool(self, mean_pool: np.ndarray) -> np.ndarray:
        """Collapse a mean-pooled embedding into a single global vector."""
        if mean_pool.size == 0:
            return np.zeros((128,), dtype=self.output_dtype)
        return mean_pool.mean(axis=0).astype(self.output_dtype)

    def experimental_pool_visual_embedding(
        self,
        visual_embedding: Union[torch.Tensor, np.ndarray],
        token_info: Optional[Dict[str, Any]] = None,
        *,
        target_vectors: int = 32,
        mean_pool: Optional[np.ndarray] = None,
    ) -> np.ndarray:
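        """
        Experimental pooling of visual token embeddings.

        ColSmol embeddings are pooled with colsmol_experimental_pooling; other
        models reuse (or compute) the row-level mean pool and apply
        colpali_experimental_pooling_from_rows.

        Args:
            visual_embedding: Visual token embeddings [num_visual_tokens, dim]
            token_info: Optional token info dict from embed_images
            target_vectors: Expected number of mean-pool rows
            mean_pool: Optional precomputed mean pool to reuse

        Returns:
            Pooled embedding array
        """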
        from visual_rag.embedding.pooling import (
            colpali_experimental_pooling_from_rows,
            colsmol_experimental_pooling,
        )

        model_lower = (self.model_name or "").lower()
        is_colsmol = "colsmol" in model_lower

        if isinstance(visual_embedding, torch.Tensor):
            if visual_embedding.dtype == torch.bfloat16:
                visual_np = visual_embedding.cpu().float().numpy()
            else:
                visual_np = visual_embedding.cpu().numpy().astype(np.float32)
        else:
            visual_np = np.array(visual_embedding, dtype=np.float32)

        if is_colsmol:
            # Determine the tile count: prefer the mean pool's row count, then
            # token_info, then derive it from the number of visual tokens
            if (
                mean_pool is not None
                and getattr(mean_pool, "shape", None) is not None
                and int(mean_pool.shape[0]) > 0
            ):
                num_tiles = int(mean_pool.shape[0])
            else:
                num_tiles = (token_info or {}).get("num_tiles")
                if num_tiles is None:
                    num_visual_tokens = (token_info or {}).get("num_visual_tokens")
                    if num_visual_tokens is None:
                        num_visual_tokens = int(visual_np.shape[0])
                    patches_per_tile = 64
                    num_tiles = int(num_visual_tokens) // patches_per_tile
                    if int(num_tiles) * patches_per_tile != int(num_visual_tokens):
                        num_tiles = int(num_tiles) + 1
                num_tiles = int(num_tiles)
            return colsmol_experimental_pooling(
                visual_np, num_tiles=num_tiles, patches_per_tile=64, output_dtype=self.output_dtype
            )

        rows = (
            mean_pool
            if mean_pool is not None
            else self.mean_pool_visual_embedding(
                visual_np, token_info, target_vectors=target_vectors
            )
        )
        if int(rows.shape[0]) != int(target_vectors):
            raise ValueError(
                f"experimental pooling expects mean_pool to have {target_vectors} rows, got {rows.shape[0]} for model={self.model_name}"
            )
        return colpali_experimental_pooling_from_rows(rows, output_dtype=self.output_dtype)


# Backward compatibility alias
ColPaliEmbedder = VisualEmbedder
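

# Minimal end-to-end sketch (illustrative only, not part of the library API):
# embed one page image, keep just its visual tokens, and mean-pool them for
# compact storage. The input file name below is a hypothetical placeholder.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    embedder = VisualEmbedder(model_name="vidore/colSmol-500M")
    pages = [Image.open("page_001.png")]  # hypothetical input image
    embs, infos = embedder.embed_images(pages, return_token_info=True)

    visual = embedder.extract_visual_embedding(embs[0], infos[0])
    pooled = embedder.mean_pool_visual_embedding(visual, infos[0])
    global_vec = embedder.global_pool_from_mean_pool(pooled)
    print(f"visual: {visual.shape}, pooled: {pooled.shape}, global: {global_vec.shape}")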