"""
Saliency Map Generation for Visual RAG

This module provides saliency map generation for visual document search results.
It implements the tile-aware ColBERT MaxSim strategy for accurate visualization
of which image regions are relevant to a query.

Key features:
1. Tile-aware architecture: patch scores are laid out on the page's grid of
   512×512 tiles (e.g. 4×3), as recorded in the point payload
2. Excludes global tile for cleaner saliency
3. Maps patches to resized image, then scales to original
4. Uses "hot" colormap by default for better visibility
"""

import logging
from typing import Any, Optional, Tuple
from io import BytesIO
from base64 import b64decode

import numpy as np
import requests
from PIL import Image

logger = logging.getLogger(__name__)

# Default saliency configuration
DEFAULT_ALPHA = 0.4
DEFAULT_COLORMAP = 'hot'  # Better visibility than 'jet'
DEFAULT_THRESHOLD_PERCENTILE = 50

# Payload fields that must be present and non-zero/non-empty before a
# saliency map can be generated.
_REQUIRED_TILE_FIELDS = ('num_tiles', 'tile_rows', 'tile_cols', 'resized_width', 'resized_height')

# Patches per tile used when the payload does not record (or records 0 for) it.
_DEFAULT_PATCHES_PER_TILE = 64


def _has_tile_metadata(metadata: dict) -> bool:
    """Return True if every required tile field is present and truthy."""
    return all(metadata.get(field) for field in _REQUIRED_TILE_FIELDS)


def convert_to_numpy(embedding, dtype: np.dtype = np.float32) -> np.ndarray:
    """
    Convert an embedding to a NumPy array with the requested dtype.

    Handles:
    - Lists / nested sequences
    - PyTorch tensors on any device, including bfloat16 tensors and tensors
      that require grad
    - NumPy arrays

    Args:
        embedding: Embedding as a list, NumPy array or torch.Tensor.
        dtype: Target NumPy dtype (default: float32).

    Returns:
        A new np.ndarray of ``dtype`` (always a copy).
    """
    try:
        import torch
    except ImportError:
        torch = None

    if torch is not None and isinstance(embedding, torch.Tensor):
        # detach(): Tensor.numpy() raises on tensors that require grad.
        embedding = embedding.detach().cpu()
        if embedding.dtype == torch.bfloat16:
            # NumPy has no bfloat16; upcast before conversion.
            embedding = embedding.float()
        embedding = embedding.numpy()

    # np.array (not asarray) so callers always get an independent copy.
    return np.array(embedding, dtype=dtype)


def validate_embeddings(
    doc_embedding: np.ndarray,
    query_embedding: np.ndarray
) -> Tuple[bool, str]:
    """
    Validate embedding shapes and values before scoring.

    Args:
        doc_embedding: Document multi-vector, expected [num_doc_patches, dim].
        query_embedding: Query multi-vector, expected [num_query_patches, dim].

    Returns:
        (True, "") if valid, otherwise (False, human-readable reason).
    """
    if doc_embedding.ndim != 2:
        return False, f"Document embedding must be 2D, got {doc_embedding.ndim}D"
    if query_embedding.ndim != 2:
        return False, f"Query embedding must be 2D, got {query_embedding.ndim}D"
    # Empty inputs would otherwise fail later inside np.max with an opaque error.
    if doc_embedding.shape[0] == 0:
        return False, "Document embedding is empty"
    if query_embedding.shape[0] == 0:
        return False, "Query embedding is empty"
    if doc_embedding.shape[1] != query_embedding.shape[1]:
        return False, (
            f"Embedding dimensions don't match: "
            f"doc={doc_embedding.shape[1]}, query={query_embedding.shape[1]}"
        )
    if not np.all(np.isfinite(doc_embedding)):
        return False, "Document embedding contains NaN or Inf values"
    if not np.all(np.isfinite(query_embedding)):
        return False, "Query embedding contains NaN or Inf values"
    return True, ""


def compute_maxsim_scores(
    doc_embedding: np.ndarray,
    query_embedding: np.ndarray,
    normalize: bool = True
) -> np.ndarray:
    """
    Compute per-document-patch MaxSim scores (ColBERT-style late interaction).

    For each document patch, take the maximum similarity across all query
    patches. With ``normalize=True`` rows are L2-normalized first, so the
    similarity is cosine similarity; otherwise it is the raw dot product.

    Args:
        doc_embedding: [num_doc_patches, dim]
        query_embedding: [num_query_patches, dim]
        normalize: L2-normalize rows before the dot product.

    Returns:
        1-D array with one score per document patch.
    """
    if normalize:
        # +1e-8 guards against division by zero for all-zero rows.
        doc_norm = doc_embedding / (np.linalg.norm(doc_embedding, axis=1, keepdims=True) + 1e-8)
        query_norm = query_embedding / (np.linalg.norm(query_embedding, axis=1, keepdims=True) + 1e-8)
    else:
        doc_norm = doc_embedding
        query_norm = query_embedding

    similarity_matrix = doc_norm @ query_norm.T  # [num_doc_patches, num_query_patches]
    return np.max(similarity_matrix, axis=1)


def normalize_scores(
    score_grid: np.ndarray,
    threshold_percentile: Optional[int] = None
) -> np.ndarray:
    """
    Min-max normalize a score grid to [0, 1] with optional thresholding.

    Args:
        score_grid: Array of raw scores (any shape).
        threshold_percentile: If given, entries whose raw score is below this
            percentile of ``score_grid`` are set to 0 after normalization.

    Returns:
        Normalized array of the same shape. A float32 all-zero array is
        returned when the grid is empty or all scores are (near-)identical.
    """
    if score_grid.size == 0:
        logger.warning("Empty score grid, returning zeros")
        return np.zeros_like(score_grid, dtype=np.float32)

    score_min = score_grid.min()
    score_max = score_grid.max()
    score_range = score_max - score_min

    if score_range < 1e-8:
        logger.warning("All scores are identical, returning zeros")
        return np.zeros_like(score_grid, dtype=np.float32)

    # Arithmetic creates a new array, so the caller's grid is never mutated.
    score_grid_norm = (score_grid - score_min) / score_range

    if threshold_percentile is not None:
        score_threshold = np.percentile(score_grid, threshold_percentile)
        mask = score_grid < score_threshold
        score_grid_norm[mask] = 0.0

        visible_count = np.sum(~mask)
        total_count = score_grid.size
        logger.debug(f"Threshold: {score_threshold:.3f} ({threshold_percentile}th percentile)")
        logger.debug(f"Visible patches: {visible_count} / {total_count}")

    return score_grid_norm


def download_image(page_url: str) -> Optional[Image.Image]:
    """
    Load an image from an HTTP(S) URL, a ``data:image`` URI, or a local path.

    Args:
        page_url: http(s) URL, base64 data URI, or filesystem path.

    Returns:
        The decoded image in RGB mode, or None if it could not be loaded
        (the failure is logged).
    """
    if not page_url:
        logger.error("Failed to load image: no image URL provided")
        return None

    try:
        if page_url.startswith(("http://", "https://")):
            resp = requests.get(page_url, timeout=15)
            resp.raise_for_status()
            image = Image.open(BytesIO(resp.content))
        elif page_url.startswith("data:image"):
            b64_data = page_url.split(",", 1)[1]
            image = Image.open(BytesIO(b64decode(b64_data)))
        else:
            # NOTE(review): any other string is opened as a local file path.
            # If URLs can come from untrusted payloads, this permits reading
            # arbitrary local image files — consider restricting.
            image = Image.open(page_url)

        # Image.open is lazy: decode now so corrupt data fails inside this
        # try block, and so Pillow releases a file it opened from a path.
        image.load()
        if image.mode != "RGB":
            image = image.convert("RGB")
        return image
    except Exception as e:
        logger.error(f"Failed to load image: {e}")
        return None


def _get_colormap(name: str):
    """Look up a Matplotlib colormap by name across Matplotlib versions."""
    try:
        # Registry API (Matplotlib >= 3.5); cm.get_cmap was removed in 3.9.
        from matplotlib import colormaps
    except ImportError:
        from matplotlib import cm
        return cm.get_cmap(name)
    return colormaps[name]


def apply_colormap_and_blend(
    score_grid: np.ndarray,
    image: Image.Image,
    alpha: float = DEFAULT_ALPHA,
    colormap: str = DEFAULT_COLORMAP
) -> Image.Image:
    """
    Colorize a [0, 1] score grid and alpha-blend it over an image.

    Args:
        score_grid: 2-D array of scores in [0, 1] (values outside are clipped).
        image: Base image; converted to RGB if needed.
        alpha: Weight of the heatmap in the blend (0.0 = image only).
        colormap: Matplotlib colormap name.

    Returns:
        RGB image of the same size as ``image`` with the heatmap overlaid.
    """
    cmap = _get_colormap(colormap)

    # Image.blend requires both inputs to share mode and size.
    if image.mode != "RGB":
        image = image.convert("RGB")
    img_width, img_height = image.size

    # Resize heatmap to image size (a 2-D uint8 array maps to mode 'L').
    scores_u8 = (np.clip(score_grid, 0.0, 1.0) * 255).astype(np.uint8)
    heatmap_resized = Image.fromarray(scores_u8).resize((img_width, img_height), Image.BILINEAR)
    heatmap_array = np.asarray(heatmap_resized) / 255.0

    # Apply colormap; drop alpha channel (an HxWx3 uint8 array maps to 'RGB').
    heatmap_colored = (cmap(heatmap_array)[:, :, :3] * 255).astype(np.uint8)
    heatmap_img = Image.fromarray(heatmap_colored)

    # Blend with original image
    return Image.blend(image, heatmap_img, alpha=alpha)


def _reconstruct_score_grid(
    patch_scores: np.ndarray,
    tile_rows: int,
    tile_cols: int,
    patches_per_tile: int,
) -> np.ndarray:
    """
    Arrange tile-major patch scores into one 2-D grid covering all tiles.

    ``patch_scores`` holds ``tile_rows * tile_cols`` consecutive runs of
    ``patches_per_tile`` scores, tiles ordered row by row. Each run becomes a
    square ``side × side`` block (side = sqrt(patches_per_tile)), and the
    blocks are tiled into a ``(tile_rows * side, tile_cols * side)`` grid.

    Raises:
        ValueError: if ``patches_per_tile`` is not a perfect square or the
            number of scores does not fit the tile layout.
    """
    side = int(round(np.sqrt(patches_per_tile)))  # 8 for 64 patches
    if side * side != patches_per_tile:
        raise ValueError(f"patches_per_tile={patches_per_tile} is not a perfect square")

    num_local_tiles = tile_rows * tile_cols
    # Per-tile layout is column-major: identical to
    # tile_vector.reshape(side, side, order='F') for each tile.
    # NOTE(review): confirm this matches the vision encoder's patch order.
    tiles = patch_scores.reshape(num_local_tiles, side, side).transpose(0, 2, 1)

    # (tile_row, tile_col, y, x) -> (tile_row, y, tile_col, x) -> full grid;
    # equivalent to concatenating tiles along x per row, then rows along y.
    return (
        tiles.reshape(tile_rows, tile_cols, side, side)
        .transpose(0, 2, 1, 3)
        .reshape(tile_rows * side, tile_cols * side)
    )


def generate_tile_aware_saliency(
    qdrant_client: Any,
    collection_name: str,
    point_id: str,
    query_embedding: np.ndarray,
    alpha: float = DEFAULT_ALPHA,
    colormap: str = DEFAULT_COLORMAP,
    threshold_percentile: int = DEFAULT_THRESHOLD_PERCENTILE
) -> Optional[Image.Image]:
    """
    Generate tile-aware saliency map for a document-query pair.

    This is the main function to call for saliency generation. It fetches the
    document's "initial" multi-vector and payload from Qdrant, scores each
    local-tile patch against the query with MaxSim, rebuilds the tile layout
    as a 2-D grid, overlays it on the resized page image and finally rescales
    the overlay to the original page size when that size is known.

    Args:
        qdrant_client: Qdrant client instance
        collection_name: Name of the collection
        point_id: ID of the document point
        query_embedding: Query multi-vector embedding [num_query_patches, dim]
            (list, NumPy array or torch.Tensor)
        alpha: Overlay transparency (0.0-1.0)
        colormap: Matplotlib colormap name (default: 'hot')
        threshold_percentile: Hide patches below this percentile (default: 50);
            None disables thresholding

    Returns:
        PIL Image with saliency overlay, or None if generation fails (all
        failures are logged, never raised).
    """
    try:
        # Step 1: Fetch full multi-vector embedding AND payload
        logger.debug(f"Fetching point {point_id} with tile metadata from {collection_name}")
        points = qdrant_client.retrieve(
            collection_name=collection_name,
            ids=[point_id],
            with_vectors=["initial"],
            with_payload=True
        )

        if not points:
            logger.error(f"Point {point_id} not found in collection")
            return None

        point = points[0]
        # Named vectors are returned as a mapping; an unnamed (list) vector
        # or None has no "initial" entry.
        vectors = point.vector
        doc_vector = vectors.get("initial") if hasattr(vectors, "get") else None
        payload = point.payload or {}

        if doc_vector is None:
            logger.error("No 'initial' vector found for point")
            return None

        # Step 2: Get tile structure from payload
        if not _has_tile_metadata(payload):
            logger.warning("Missing tile metadata - cannot generate saliency")
            return None

        # Cast once: payload numbers are used for slicing and reshaping.
        num_tiles = int(payload['num_tiles'])
        tile_rows = int(payload['tile_rows'])
        tile_cols = int(payload['tile_cols'])
        patches_per_tile = int(payload.get('patches_per_tile') or _DEFAULT_PATCHES_PER_TILE)
        resized_width = payload.get('resized_width')
        resized_height = payload.get('resized_height')
        resized_url = payload.get('resized_url') or payload.get('page')
        original_width = payload.get('original_width')
        original_height = payload.get('original_height')

        logger.info(f"✅ Tile structure: {tile_rows}×{tile_cols} tiles, {patches_per_tile} patches/tile")
        logger.info(f"✅ Resized image: {resized_width}×{resized_height}")
        logger.info(f"✅ Original image: {original_width}×{original_height}")

        # Step 3: Convert embeddings
        doc_embedding = convert_to_numpy(doc_vector)
        query_emb = convert_to_numpy(query_embedding)

        is_valid, error_msg = validate_embeddings(doc_embedding, query_emb)
        if not is_valid:
            logger.error(f"Embedding validation failed: {error_msg}")
            return None

        logger.info(f"Document embedding: {doc_embedding.shape}")
        logger.info(f"Query embedding: {query_emb.shape}")

        # Step 4: Separate tile embeddings from global tile.
        # Keeps the first (num_tiles - 1) tiles' patches, i.e. assumes the
        # global tile is the last block of patches_per_tile rows. Step 6 then
        # requires exactly tile_rows * tile_cols tiles' worth of scores.
        total_patches = num_tiles * patches_per_tile
        tile_patches = total_patches - patches_per_tile  # Exclude global
        if len(doc_embedding) < total_patches:
            logger.warning(f"Embedding size mismatch: got {len(doc_embedding)}, expected {total_patches}")
        tile_embeddings = doc_embedding[:tile_patches]

        logger.info(f"Using {len(tile_embeddings)} tile patches (excluding global)")

        # Step 5: Compute MaxSim scores
        patch_scores = compute_maxsim_scores(tile_embeddings, query_emb, normalize=True)
        logger.info(f"Computed scores for {len(patch_scores)} patches")

        # Step 6: Reshape patches into tile structure
        if len(patch_scores) != tile_rows * tile_cols * patches_per_tile:
            logger.error(f"Patch count mismatch: {len(patch_scores)} patches")
            return None
        try:
            score_grid = _reconstruct_score_grid(patch_scores, tile_rows, tile_cols, patches_per_tile)
        except ValueError as e:
            logger.error(f"❌ Failed to reshape patches: {e}")
            return None
        logger.info(f"✅ Reconstructed grid: {score_grid.shape} (from {tile_rows}×{tile_cols} tiles)")

        # Step 7: Normalize scores
        score_grid_norm = normalize_scores(score_grid, threshold_percentile=threshold_percentile)

        # Step 8: Download RESIZED image
        logger.info(f"Downloading resized image from: {resized_url}")
        resized_image = download_image(resized_url)
        if resized_image is None:
            logger.error("Failed to download resized image")
            return None

        # Step 9: Apply heatmap to resized image
        overlay_resized = apply_colormap_and_blend(score_grid_norm, resized_image, alpha, colormap)

        # Step 10: Resize back to original dimensions
        if original_width and original_height:
            original_size = (int(original_width), int(original_height))
            overlay_final = overlay_resized.resize(original_size, Image.BILINEAR)
            logger.info(f"✅ Resized saliency map to original: {original_width}×{original_height}")
        else:
            overlay_final = overlay_resized

        logger.info("✅ Saliency map generated successfully")
        return overlay_final

    except Exception as e:
        # Best-effort boundary: saliency is optional, so report failure as None.
        logger.error(f"Saliency generation failed: {e}")
        logger.debug("Saliency generation traceback:", exc_info=True)
        return None


def can_generate_saliency(metadata: dict) -> bool:
    """
    Check if saliency can be generated for a document based on its metadata.

    Applies the same rule as generate_tile_aware_saliency: every required
    tile field must be present and non-zero/non-empty.

    Args:
        metadata: Document metadata dictionary (None is treated as empty)

    Returns:
        True if all required tile metadata is present
    """
    if not metadata:
        return False
    return _has_tile_metadata(metadata)


def get_saliency_metadata_summary(metadata: dict) -> str:
    """
    Get a summary of saliency-related metadata for display.

    Args:
        metadata: Document metadata dictionary (None is treated as empty)

    Returns:
        Human-readable summary string
    """
    metadata = metadata or {}
    num_tiles = metadata.get('num_tiles')
    tile_rows = metadata.get('tile_rows')
    tile_cols = metadata.get('tile_cols')
    patches_per_tile = metadata.get('patches_per_tile') or _DEFAULT_PATCHES_PER_TILE

    # Treat keys that are present but None the same as missing keys.
    if num_tiles is None or tile_rows is None or tile_cols is None:
        return "Tile metadata not available"
    return f"{tile_rows}×{tile_cols} tiles ({num_tiles} total), {patches_per_tile} patches/tile"