# sync v0.1.3 (commit 9513cca, verified) — Yeroyan
"""
Pooling strategies for multi-vector embeddings.
Provides:
- Tile-level mean pooling: Preserves spatial structure (num_tiles × dim)
- Global mean pooling: Single vector (1 × dim)
- MaxSim scoring for ColBERT-style late interaction
"""
import logging
from typing import Optional, Union
import numpy as np
import torch
logger = logging.getLogger(__name__)
def _infer_output_dtype(
embedding: Union[torch.Tensor, np.ndarray],
output_dtype: Optional[np.dtype] = None,
) -> np.dtype:
"""Infer output dtype: use provided, else match input (fp16→fp16, bf16→fp32, fp32→fp32)."""
if output_dtype is not None:
return output_dtype
if isinstance(embedding, torch.Tensor):
if embedding.dtype == torch.float16:
return np.float16
return np.float32
if isinstance(embedding, np.ndarray) and embedding.dtype == np.float16:
return np.float16
return np.float32
def tile_level_mean_pooling(
    embedding: Union[torch.Tensor, np.ndarray],
    num_tiles: int,
    patches_per_tile: int = 64,
    output_dtype: Optional[np.dtype] = None,
) -> np.ndarray:
    """
    Compute tile-level mean pooling for multi-vector embeddings.

    Instead of collapsing to 1×dim (global pooling), this preserves spatial
    structure by computing mean per tile → num_tiles × dim.

    This is our NOVEL contribution for scalable visual retrieval:
    - Faster than full MaxSim (fewer vectors to compare)
    - More accurate than global pooling (preserves spatial info)
    - Ideal for two-stage retrieval (prefetch with pooled, rerank with full)

    Args:
        embedding: Visual token embeddings [num_visual_tokens, dim]
        num_tiles: Number of tiles (including global tile); must be > 0
        patches_per_tile: Patches per tile (64 for ColSmol); must be > 0
        output_dtype: Output dtype (default: infer from input, fp16→fp16, bf16→fp32)

    Returns:
        Tile-level pooled embeddings [num_tiles, dim]

    Raises:
        ValueError: If num_tiles or patches_per_tile is not positive
            (consistent with colsmol_experimental_pooling).

    Example:
        >>> # Image with 4×3 tiles + 1 global = 13 tiles
        >>> # Each tile has 64 patches → 832 visual tokens
        >>> pooled = tile_level_mean_pooling(embedding, num_tiles=13)
        >>> print(pooled.shape)  # (13, 128)
    """
    # Validate up front: patches_per_tile == 0 would otherwise raise
    # ZeroDivisionError in the mismatch-recovery path below.
    if num_tiles <= 0:
        raise ValueError("num_tiles must be > 0")
    if patches_per_tile <= 0:
        raise ValueError("patches_per_tile must be > 0")
    out_dtype = _infer_output_dtype(embedding, output_dtype)
    # Convert to a float32 numpy array; bf16 has no numpy equivalent so it
    # must go through torch's float() first.
    if isinstance(embedding, torch.Tensor):
        if embedding.dtype == torch.bfloat16:
            emb_np = embedding.cpu().float().numpy()
        else:
            emb_np = embedding.cpu().numpy().astype(np.float32)
    else:
        emb_np = np.array(embedding, dtype=np.float32)
    num_visual_tokens = emb_np.shape[0]
    expected_tokens = num_tiles * patches_per_tile
    if num_visual_tokens != expected_tokens:
        # Recover from a wrong num_tiles by deriving it from the actual
        # token count, rounding up so a trailing partial tile is kept.
        logger.debug(
            "Token count mismatch: %s vs expected %s", num_visual_tokens, expected_tokens
        )
        num_tiles = -(-num_visual_tokens // patches_per_tile)  # ceil division
    # Mean over each tile's patch span; the final tile may be partial.
    tile_means = [
        emb_np[start : min(start + patches_per_tile, num_visual_tokens)].mean(axis=0)
        for start in range(
            0, min(num_tiles * patches_per_tile, num_visual_tokens), patches_per_tile
        )
    ]
    return np.array(tile_means, dtype=out_dtype)
def colpali_row_mean_pooling(
    embedding: Union[torch.Tensor, np.ndarray],
    grid_size: int = 32,
    output_dtype: Optional[np.dtype] = None,
) -> np.ndarray:
    """Mean-pool a square token grid along each row.

    Expects exactly grid_size × grid_size visual tokens and returns one
    vector per grid row → [grid_size, dim].

    Args:
        embedding: Visual token embeddings [grid_size * grid_size, dim]
        grid_size: Side length of the square token grid
        output_dtype: Output dtype (default: infer from input)

    Returns:
        Row-pooled embeddings [grid_size, dim]

    Raises:
        ValueError: If the token count does not match grid_size squared.
    """
    out_dtype = _infer_output_dtype(embedding, output_dtype)
    if isinstance(embedding, torch.Tensor):
        cpu = embedding.cpu()
        # bf16 has no numpy equivalent; everything else is cast to fp32.
        emb_np = (
            cpu.float().numpy()
            if cpu.dtype == torch.bfloat16
            else cpu.numpy().astype(np.float32)
        )
    else:
        emb_np = np.array(embedding, dtype=np.float32)
    num_tokens, dim = emb_np.shape
    side = int(grid_size)
    expected = side * side
    if num_tokens != expected:
        raise ValueError(
            f"Expected {expected} visual tokens for grid_size={grid_size}, got {num_tokens}"
        )
    return emb_np.reshape(side, side, int(dim)).mean(axis=1).astype(out_dtype)
def colsmol_experimental_pooling(
    embedding: Union[torch.Tensor, np.ndarray],
    num_tiles: int,
    patches_per_tile: int = 64,
    output_dtype: Optional[np.dtype] = None,
) -> np.ndarray:
    """Mean-pool every tile except the last, which keeps its raw patches.

    Produces (num_tiles - 1) pooled vectors followed by the final tile's
    unpooled patch embeddings. If num_tiles overshoots the token count,
    it is rederived from the tokens actually present (rounding up so a
    trailing partial tile still counts).

    Args:
        embedding: Visual token embeddings [num_visual_tokens, dim]
        num_tiles: Number of tiles; must be > 0
        patches_per_tile: Patches per tile; must be > 0
        output_dtype: Output dtype (default: infer from input)

    Returns:
        [(num_tiles - 1) + patches_in_last_tile, dim] array.

    Raises:
        ValueError: On non-positive arguments or an empty embedding.
    """
    out_dtype = _infer_output_dtype(embedding, output_dtype)
    if isinstance(embedding, torch.Tensor):
        cpu = embedding.cpu()
        emb_np = (
            cpu.float().numpy()
            if cpu.dtype == torch.bfloat16
            else cpu.numpy().astype(np.float32)
        )
    else:
        emb_np = np.array(embedding, dtype=np.float32)
    num_visual_tokens, dim = emb_np.shape
    if num_tiles <= 0:
        raise ValueError("num_tiles must be > 0")
    if patches_per_tile <= 0:
        raise ValueError("patches_per_tile must be > 0")
    tiles = int(num_tiles)
    ppt = int(patches_per_tile)
    start_of_last = (tiles - 1) * ppt
    if start_of_last >= num_visual_tokens:
        # num_tiles was too large for this embedding: recompute it from the
        # real token count (ceil so a partial trailing tile is included).
        tiles, remainder = divmod(int(num_visual_tokens), ppt)
        if remainder:
            tiles += 1
        if tiles <= 0:
            raise ValueError(
                f"Not enough tokens for num_tiles={num_tiles}, patches_per_tile={patches_per_tile}: got {num_visual_tokens}"
            )
        start_of_last = (tiles - 1) * ppt
    head = emb_np[:start_of_last]
    tail = emb_np[start_of_last : min(start_of_last + ppt, num_visual_tokens)]
    if head.size:
        head_means = head.reshape(-1, ppt, int(dim)).mean(axis=1)
    else:
        # No full leading tiles — only the (possibly partial) last tile.
        head_means = np.zeros((0, int(dim)), dtype=out_dtype)
    return np.concatenate([head_means.astype(out_dtype), tail.astype(out_dtype)], axis=0)
def colpali_experimental_pooling_from_rows(
    row_vectors: Union[torch.Tensor, np.ndarray],
    output_dtype: Optional[np.dtype] = None,
) -> np.ndarray:
    """
    Experimental "convolution-style" pooling with window size 3.

    For N >= 3 input rows, produces N + 2 output vectors where output k is
    the mean of rows[max(0, k-2) : min(k+1, N)]:
    - Position 0: row[0] alone (1 row)
    - Position 1: mean(rows[0:2]) (2 rows)
    - Position 2: mean(rows[0:3]) (3 rows)
    - Positions 3 to N-1: sliding window of 3 (rows[i-2:i+1])
    - Position N: mean(rows[N-2:N]) (last 2 rows)
    - Position N+1: row[N-1] alone (last row)
    For N=32 rows: produces 34 vectors.

    N=1 returns the single row; N=2 returns 3 vectors
    [row0, mean(row0, row1), row1].

    Raises:
        ValueError: If row_vectors is empty.
    """
    out_dtype = _infer_output_dtype(row_vectors, output_dtype)
    if isinstance(row_vectors, torch.Tensor):
        cpu = row_vectors.cpu()
        mat = (
            cpu.float().numpy()
            if cpu.dtype == torch.bfloat16
            else cpu.numpy().astype(np.float32)
        )
    else:
        mat = np.array(row_vectors, dtype=np.float32)
    count, _dim = mat.shape
    if count < 1:
        raise ValueError("row_vectors must be non-empty")
    if count == 1:
        return mat.astype(out_dtype)
    if count == 2:
        return np.stack([mat[0], mat[:2].mean(axis=0), mat[1]], axis=0).astype(out_dtype)
    # One formula covers the ramp-up, steady window-of-3, and ramp-down:
    # output k averages rows[max(0, k-2) : min(k+1, count)].
    pooled = [
        mat[max(0, k - 2) : min(k + 1, count)].mean(axis=0) for k in range(count + 2)
    ]
    return np.stack(pooled, axis=0).astype(out_dtype)
def global_mean_pooling(
    embedding: Union[torch.Tensor, np.ndarray],
    output_dtype: Optional[np.dtype] = None,
) -> np.ndarray:
    """
    Compute global mean pooling → single vector.

    This is the simplest pooling but loses all spatial information.
    Use for fastest retrieval when accuracy can be sacrificed.

    Args:
        embedding: Multi-vector embeddings [num_tokens, dim]
        output_dtype: Output dtype (default: infer from input, fp16→fp16, bf16→fp32)

    Returns:
        Pooled vector [dim]
    """
    out_dtype = _infer_output_dtype(embedding, output_dtype)
    if isinstance(embedding, torch.Tensor):
        cpu = embedding.cpu()
        # Only bf16 needs an explicit cast; other torch dtypes map to numpy.
        emb_np = cpu.float().numpy() if cpu.dtype == torch.bfloat16 else cpu.numpy()
    else:
        emb_np = np.array(embedding)
    return emb_np.mean(axis=0).astype(out_dtype)
def compute_maxsim_score(
    query_embedding: np.ndarray,
    doc_embedding: np.ndarray,
    normalize: bool = True,
) -> float:
    """
    Compute ColBERT-style MaxSim late interaction score.

    For each query token, finds max similarity with any document token,
    then sums across query tokens:

        score = Σ_q max_d (sim(q, d))

    Args:
        query_embedding: Query embeddings [num_query_tokens, dim]
        doc_embedding: Document embeddings [num_doc_tokens, dim]
        normalize: L2 normalize embeddings before scoring (recommended)

    Returns:
        MaxSim score (higher is better)

    Example:
        >>> query = embedder.embed_query("budget allocation")
        >>> doc = embeddings[0]  # From embed_images
        >>> score = compute_maxsim_score(query, doc)
    """
    q = query_embedding
    d = doc_embedding
    if normalize:
        # Per-token L2 normalization; epsilon guards against zero rows.
        q = q / (np.linalg.norm(q, axis=1, keepdims=True) + 1e-8)
        d = d / (np.linalg.norm(d, axis=1, keepdims=True) + 1e-8)
    # [num_query, num_doc] similarities → best doc token per query token → sum.
    return float((q @ d.T).max(axis=1).sum())
def compute_maxsim_batch(
    query_embedding: np.ndarray,
    doc_embeddings: list,
    normalize: bool = True,
) -> list:
    """
    Compute MaxSim scores for multiple documents efficiently.

    The query is normalized once up front; each document is normalized
    (if requested) and scored independently.

    Args:
        query_embedding: Query embeddings [num_query_tokens, dim]
        doc_embeddings: List of document embeddings
        normalize: L2 normalize embeddings

    Returns:
        List of MaxSim scores
    """
    if normalize:
        # Normalize the query a single time — it is shared by every document.
        q = query_embedding / (
            np.linalg.norm(query_embedding, axis=1, keepdims=True) + 1e-8
        )
    else:
        q = query_embedding

    def _score(doc: np.ndarray) -> float:
        d = (
            doc / (np.linalg.norm(doc, axis=1, keepdims=True) + 1e-8)
            if normalize
            else doc
        )
        return float(np.dot(q, d.T).max(axis=1).sum())

    return [_score(doc) for doc in doc_embeddings]