"""
Pooling strategies for multi-vector embeddings.

Provides:
- Tile-level mean pooling: Preserves spatial structure (num_tiles × dim)
- Global mean pooling: Single vector (1 × dim)
- MaxSim scoring for ColBERT-style late interaction
"""
import logging
from typing import Optional, Union

import numpy as np
import torch

logger = logging.getLogger(__name__)
| def _infer_output_dtype( | |
| embedding: Union[torch.Tensor, np.ndarray], | |
| output_dtype: Optional[np.dtype] = None, | |
| ) -> np.dtype: | |
| """Infer output dtype: use provided, else match input (fp16→fp16, bf16→fp32, fp32→fp32).""" | |
| if output_dtype is not None: | |
| return output_dtype | |
| if isinstance(embedding, torch.Tensor): | |
| if embedding.dtype == torch.float16: | |
| return np.float16 | |
| return np.float32 | |
| if isinstance(embedding, np.ndarray) and embedding.dtype == np.float16: | |
| return np.float16 | |
| return np.float32 | |
def tile_level_mean_pooling(
    embedding: Union[torch.Tensor, np.ndarray],
    num_tiles: int,
    patches_per_tile: int = 64,
    output_dtype: Optional[np.dtype] = None,
) -> np.ndarray:
    """
    Compute tile-level mean pooling for multi-vector embeddings.

    Instead of collapsing to 1×dim (global pooling), this preserves spatial
    structure by computing mean per tile → num_tiles × dim.

    This is our NOVEL contribution for scalable visual retrieval:
    - Faster than full MaxSim (fewer vectors to compare)
    - More accurate than global pooling (preserves spatial info)
    - Ideal for two-stage retrieval (prefetch with pooled, rerank with full)

    Args:
        embedding: Visual token embeddings [num_visual_tokens, dim]
        num_tiles: Number of tiles (including global tile)
        patches_per_tile: Patches per tile (64 for ColSmol)
        output_dtype: Output dtype (default: infer from input, fp16→fp16, bf16→fp32)

    Returns:
        Tile-level pooled embeddings [num_tiles, dim]. An empty input
        yields a well-shaped (0, dim) array.

    Example:
        >>> # Image with 4×3 tiles + 1 global = 13 tiles
        >>> # Each tile has 64 patches → 832 visual tokens
        >>> pooled = tile_level_mean_pooling(embedding, num_tiles=13)
        >>> print(pooled.shape)  # (13, 128)
    """
    out_dtype = _infer_output_dtype(embedding, output_dtype)
    # numpy has no bfloat16, so torch bf16 must be upcast before .numpy().
    if isinstance(embedding, torch.Tensor):
        if embedding.dtype == torch.bfloat16:
            emb_np = embedding.cpu().float().numpy()
        else:
            emb_np = embedding.cpu().numpy().astype(np.float32)
    else:
        # asarray avoids a copy when the input is already float32.
        emb_np = np.asarray(embedding, dtype=np.float32)
    num_visual_tokens = emb_np.shape[0]
    # Fix: an empty input previously produced a 1-D (0,) array via
    # np.array([]); return a properly shaped (0, dim) array instead.
    if num_visual_tokens == 0:
        dim = emb_np.shape[1] if emb_np.ndim == 2 else 0
        return np.zeros((0, dim), dtype=out_dtype)
    expected_tokens = num_tiles * patches_per_tile
    if num_visual_tokens != expected_tokens:
        # Lazy %-formatting: no string work unless DEBUG is enabled.
        logger.debug(
            "Token count mismatch: %d vs expected %d", num_visual_tokens, expected_tokens
        )
        # Ceiling division: a trailing partial tile still counts as a tile.
        num_tiles = -(-num_visual_tokens // patches_per_tile)
    tile_embeddings = []
    for tile_idx in range(num_tiles):
        start_idx = tile_idx * patches_per_tile
        if start_idx >= num_visual_tokens:
            break
        end_idx = min(start_idx + patches_per_tile, num_visual_tokens)
        tile_embeddings.append(emb_np[start_idx:end_idx].mean(axis=0))
    return np.array(tile_embeddings, dtype=out_dtype)
def colpali_row_mean_pooling(
    embedding: Union[torch.Tensor, np.ndarray],
    grid_size: int = 32,
    output_dtype: Optional[np.dtype] = None,
) -> np.ndarray:
    """Mean-pool a square grid of visual tokens along each grid row.

    Expects exactly ``grid_size * grid_size`` tokens and returns one mean
    vector per row, shape (grid_size, dim).

    Args:
        embedding: Visual token embeddings [grid_size * grid_size, dim]
        grid_size: Side length of the square token grid
        output_dtype: Output dtype (default: infer from input)

    Raises:
        ValueError: If the token count does not match the grid size.
    """
    out_dtype = _infer_output_dtype(embedding, output_dtype)
    if isinstance(embedding, torch.Tensor):
        host = embedding.cpu()
        # numpy cannot hold bfloat16; upcast before conversion.
        if host.dtype == torch.bfloat16:
            host = host.float()
        emb_np = host.numpy().astype(np.float32)
    else:
        emb_np = np.array(embedding, dtype=np.float32)
    token_count, dim = emb_np.shape
    side = int(grid_size)
    if token_count != side * side:
        raise ValueError(
            f"Expected {side * side} visual tokens for grid_size={grid_size}, got {token_count}"
        )
    pooled = emb_np.reshape(side, side, int(dim)).mean(axis=1)
    return pooled.astype(out_dtype)
def colsmol_experimental_pooling(
    embedding: Union[torch.Tensor, np.ndarray],
    num_tiles: int,
    patches_per_tile: int = 64,
    output_dtype: Optional[np.dtype] = None,
) -> np.ndarray:
    """Mean-pool every tile except the last, which is kept at full resolution.

    The first ``num_tiles - 1`` tiles each collapse to one mean vector,
    while the final tile's patch vectors are passed through unchanged.
    If ``num_tiles`` overshoots the available tokens, the tile count is
    recomputed from the token count (rounding a partial tile up). Tokens
    past ``num_tiles * patches_per_tile`` are not included in the output.

    Args:
        embedding: Visual token embeddings [num_visual_tokens, dim]
        num_tiles: Number of tiles the tokens are grouped into
        patches_per_tile: Patches per tile
        output_dtype: Output dtype (default: infer from input)

    Raises:
        ValueError: If ``num_tiles`` or ``patches_per_tile`` is not
            positive, or there are no tokens to pool.
    """
    out_dtype = _infer_output_dtype(embedding, output_dtype)
    if isinstance(embedding, torch.Tensor):
        host = embedding.cpu()
        # numpy cannot hold bfloat16; upcast before conversion.
        if host.dtype == torch.bfloat16:
            emb_np = host.float().numpy()
        else:
            emb_np = host.numpy().astype(np.float32)
    else:
        emb_np = np.array(embedding, dtype=np.float32)
    token_count, dim = emb_np.shape
    if num_tiles <= 0:
        raise ValueError("num_tiles must be > 0")
    if patches_per_tile <= 0:
        raise ValueError("patches_per_tile must be > 0")
    tiles = int(num_tiles)
    stride = int(patches_per_tile)
    last_start = (tiles - 1) * stride
    if last_start >= token_count:
        # Requested tile count overshoots the tokens we actually have:
        # recompute it via ceiling division (partial tile rounds up).
        tiles = -(-int(token_count) // stride)
        if tiles <= 0:
            raise ValueError(
                f"Not enough tokens for num_tiles={num_tiles}, patches_per_tile={patches_per_tile}: got {token_count}"
            )
        last_start = (tiles - 1) * stride
    head = emb_np[:last_start]
    tail = emb_np[last_start : min(last_start + stride, token_count)]
    if head.size:
        head_means = head.reshape(-1, stride, int(dim)).mean(axis=1)
    else:
        head_means = np.zeros((0, int(dim)), dtype=out_dtype)
    return np.concatenate([head_means.astype(out_dtype), tail.astype(out_dtype)], axis=0)
def colpali_experimental_pooling_from_rows(
    row_vectors: Union[torch.Tensor, np.ndarray],
    output_dtype: Optional[np.dtype] = None,
) -> np.ndarray:
    """
    Experimental "convolution-style" pooling with window size 3.

    For N input rows, produces N + 2 output vectors:
    - Position 0: row[0] alone (1 row)
    - Position 1: mean(rows[0:2]) (2 rows)
    - Position 2: mean(rows[0:3]) (3 rows)
    - Positions 3 to N-1: sliding window of 3 (rows[i-2:i+1])
    - Position N: mean(rows[N-2:N]) (last 2 rows)
    - Position N+1: row[N-1] alone (last row)

    For N=32 rows: produces 34 vectors. Inputs with N=1 or N=2 are
    special-cased (returning N and N+1 vectors respectively).

    Raises:
        ValueError: If ``row_vectors`` is empty.
    """
    out_dtype = _infer_output_dtype(row_vectors, output_dtype)
    if isinstance(row_vectors, torch.Tensor):
        host = row_vectors.cpu()
        # numpy cannot hold bfloat16; upcast before conversion.
        if host.dtype == torch.bfloat16:
            host = host.float()
        rows = host.numpy().astype(np.float32)
    else:
        rows = np.array(row_vectors, dtype=np.float32)
    n = rows.shape[0]
    if n < 1:
        raise ValueError("row_vectors must be non-empty")
    if n == 1:
        return rows.astype(out_dtype)
    if n == 2:
        return np.stack([rows[0], rows.mean(axis=0), rows[1]], axis=0).astype(out_dtype)
    # Left-edge ramp-up: windows of 1, 2, then 3 rows anchored at row 0.
    windows = [rows[0], rows[:2].mean(axis=0), rows[:3].mean(axis=0)]
    # Full-width sliding windows of 3 rows across the interior.
    for end in range(4, n + 1):
        windows.append(rows[end - 3 : end].mean(axis=0))
    # Right-edge ramp-down: windows of 2 rows, then 1.
    windows.append(rows[n - 2 :].mean(axis=0))
    windows.append(rows[n - 1])
    return np.stack(windows, axis=0).astype(out_dtype)
def global_mean_pooling(
    embedding: Union[torch.Tensor, np.ndarray],
    output_dtype: Optional[np.dtype] = None,
) -> np.ndarray:
    """
    Collapse multi-vector embeddings into a single mean vector.

    The simplest pooling strategy: all spatial information is lost, so use
    it only when retrieval speed matters more than accuracy.

    Args:
        embedding: Multi-vector embeddings [num_tokens, dim]
        output_dtype: Output dtype (default: infer from input, fp16→fp16, bf16→fp32)

    Returns:
        Pooled vector [dim]
    """
    out_dtype = _infer_output_dtype(embedding, output_dtype)
    if isinstance(embedding, torch.Tensor):
        host = embedding.cpu()
        # bfloat16 cannot round-trip through numpy; upcast it. Other torch
        # dtypes are converted as-is and the mean runs in the input dtype.
        arr = (host.float() if host.dtype == torch.bfloat16 else host).numpy()
    else:
        arr = np.array(embedding)
    return arr.mean(axis=0).astype(out_dtype)
def compute_maxsim_score(
    query_embedding: np.ndarray,
    doc_embedding: np.ndarray,
    normalize: bool = True,
) -> float:
    """
    Compute the ColBERT-style MaxSim late-interaction score.

    Each query token is matched against its most similar document token,
    and the per-token maxima are summed:

        score = Σ_q max_d sim(q, d)

    This is the standard scoring for ColBERT/ColPali.

    Args:
        query_embedding: Query embeddings [num_query_tokens, dim]
        doc_embedding: Document embeddings [num_doc_tokens, dim]
        normalize: L2 normalize embeddings before scoring (recommended)

    Returns:
        MaxSim score (higher is better)

    Example:
        >>> query = embedder.embed_query("budget allocation")
        >>> doc = embeddings[0]  # From embed_images
        >>> score = compute_maxsim_score(query, doc)
    """
    def _unit_rows(mat: np.ndarray) -> np.ndarray:
        # Epsilon guards against division by zero on all-zero rows.
        return mat / (np.linalg.norm(mat, axis=1, keepdims=True) + 1e-8)

    q = _unit_rows(query_embedding) if normalize else query_embedding
    d = _unit_rows(doc_embedding) if normalize else doc_embedding
    # [num_query, num_doc] similarities → best doc match per query token.
    best_per_query = (q @ d.T).max(axis=1)
    return float(best_per_query.sum())
def compute_maxsim_batch(
    query_embedding: np.ndarray,
    doc_embeddings: list,
    normalize: bool = True,
) -> list:
    """
    Score one query against many documents with MaxSim.

    The query is normalized once up front; each document is normalized
    (when requested) immediately before scoring.

    Args:
        query_embedding: Query embeddings [num_query_tokens, dim]
        doc_embeddings: List of document embeddings
        normalize: L2 normalize embeddings

    Returns:
        List of MaxSim scores, one per document
    """
    def _unit_rows(mat):
        # Epsilon guards against division by zero on all-zero rows.
        return mat / (np.linalg.norm(mat, axis=1, keepdims=True) + 1e-8)

    q = _unit_rows(query_embedding) if normalize else query_embedding
    scores = []
    for doc in doc_embeddings:
        d = _unit_rows(doc) if normalize else doc
        scores.append(float((q @ d.T).max(axis=1).sum()))
    return scores