Spaces:
Sleeping
Sleeping
| """ | |
| Saliency Map Generation for Visual RAG | |
| This module provides saliency map generation for visual document search results. | |
| It implements the tile-aware ColBERT MaxSim strategy for accurate visualization | |
| of which image regions are relevant to a query. | |
| Key features: | |
| 1. Tile-aware architecture (understands 4Γ3 grid of 512Γ512 tiles) | |
| 2. Excludes global tile for cleaner saliency | |
| 3. Maps patches to resized image, then scales to original | |
| 4. Uses "hot" colormap by default for better visibility | |
| """ | |
| import logging | |
| from typing import Any, Optional, Tuple | |
| from io import BytesIO | |
| from base64 import b64decode | |
| import numpy as np | |
| import requests | |
| from PIL import Image | |
# Module-level logger; handler/level configuration is left to the application.
logger = logging.getLogger(__name__)

# Default saliency configuration
DEFAULT_ALPHA = 0.4  # Heatmap blend weight (0.0 = image only, 1.0 = heatmap only)
DEFAULT_COLORMAP = 'hot'  # Better visibility than 'jet'
DEFAULT_THRESHOLD_PERCENTILE = 50  # Patches below this raw-score percentile are hidden
def convert_to_numpy(embedding, dtype: np.dtype = np.float32) -> np.ndarray:
    """
    Convert an embedding to a numpy array with the requested dtype.

    Handles:
    - Lists / tuples of numbers
    - PyTorch tensors (including bfloat16 and tensors that require grad)
    - NumPy arrays

    Args:
        embedding: The embedding in any supported container type.
        dtype: Target numpy dtype for the returned array.

    Returns:
        A numpy array copy of the embedding with the given dtype.
    """
    try:
        import torch
        if isinstance(embedding, torch.Tensor):
            # detach() is required for tensors in an autograd graph:
            # .numpy() raises on tensors with requires_grad=True.
            tensor = embedding.detach().cpu()
            if tensor.dtype == torch.bfloat16:
                # bfloat16 has no numpy equivalent; upcast to float32 first.
                tensor = tensor.float()
            embedding = tensor.numpy()
    except ImportError:
        # torch is optional: lists and numpy arrays are handled below.
        pass
    return np.array(embedding, dtype=dtype)
def validate_embeddings(
    doc_embedding: np.ndarray,
    query_embedding: np.ndarray
) -> Tuple[bool, str]:
    """Verify that a doc/query embedding pair can be compared.

    Both arrays must be 2D, share the feature dimension, and contain
    only finite values.

    Returns:
        (True, "") on success, or (False, reason) for the first failed check.
    """
    labeled = ((doc_embedding, "Document"), (query_embedding, "Query"))
    # Rank checks come first so shape[1] below is always safe to read.
    for arr, label in labeled:
        if arr.ndim != 2:
            return False, f"{label} embedding must be 2D, got {arr.ndim}D"
    if doc_embedding.shape[1] != query_embedding.shape[1]:
        return False, f"Embedding dimensions don't match: doc={doc_embedding.shape[1]}, query={query_embedding.shape[1]}"
    for arr, label in labeled:
        # isfinite is False for both NaN and +/-Inf.
        if not np.all(np.isfinite(arr)):
            return False, f"{label} embedding contains NaN or Inf values"
    return True, ""
def compute_maxsim_scores(
    doc_embedding: np.ndarray,
    query_embedding: np.ndarray,
    normalize: bool = True
) -> np.ndarray:
    """
    Score each document patch by its best match among the query patches.

    Implements ColBERT-style MaxSim late interaction: build the full
    doc x query similarity matrix, then reduce with a max over queries.

    Args:
        doc_embedding: [num_doc_patches, dim] array.
        query_embedding: [num_query_patches, dim] array.
        normalize: When True, L2-normalize rows first so the dot product
            becomes cosine similarity.

    Returns:
        1D array of length num_doc_patches with per-patch max scores.
    """
    docs, queries = doc_embedding, query_embedding
    if normalize:
        # Small epsilon guards against division by zero for all-zero rows.
        docs = docs / (np.linalg.norm(docs, axis=1, keepdims=True) + 1e-8)
        queries = queries / (np.linalg.norm(queries, axis=1, keepdims=True) + 1e-8)
    return (docs @ queries.T).max(axis=1)
def normalize_scores(
    score_grid: np.ndarray,
    threshold_percentile: Optional[int] = None
) -> np.ndarray:
    """
    Normalize a score grid to the [0, 1] range with optional thresholding.

    Args:
        score_grid: Array of raw patch scores.
        threshold_percentile: If given, cells whose RAW score falls below
            this percentile are zeroed after normalization.

    Returns:
        Array of the same shape scaled to [0, 1]; all-zeros (float32)
        when the input is (near-)constant.
    """
    score_min = score_grid.min()
    score_max = score_grid.max()
    if score_max - score_min < 1e-8:
        # A flat grid carries no spatial signal; avoid dividing by ~0.
        logger.warning("All scores are identical, returning zeros")
        return np.zeros_like(score_grid, dtype=np.float32)
    score_grid_norm = (score_grid - score_min) / (score_max - score_min)
    if threshold_percentile is not None:
        # Threshold is computed on the raw scores, then applied to the
        # normalized copy so surviving values keep their relative scale.
        score_threshold = np.percentile(score_grid, threshold_percentile)
        mask = score_grid < score_threshold
        score_grid_norm[mask] = 0.0
        # Lazy %-style args avoid formatting cost when DEBUG is off.
        logger.debug("Threshold: %.3f (%sth percentile)",
                     score_threshold, threshold_percentile)
        logger.debug("Visible patches: %s / %s", np.sum(~mask), score_grid.size)
    return score_grid_norm
def download_image(page_url: str) -> Optional[Image.Image]:
    """
    Load an image from an http(s) URL, a data URI, or a local path.

    Args:
        page_url: Remote URL, "data:image/..." URI, or filesystem path.

    Returns:
        RGB PIL image, or None when loading fails (the error is logged).
    """
    try:
        if page_url.startswith(("http://", "https://")):
            response = requests.get(page_url, timeout=15)
            response.raise_for_status()
            source = BytesIO(response.content)
        elif page_url.startswith("data:image"):
            # Payload follows the first comma of the data URI.
            _, b64_payload = page_url.split(",", 1)
            source = BytesIO(b64decode(b64_payload))
        else:
            source = page_url
        loaded = Image.open(source)
        return loaded if loaded.mode == "RGB" else loaded.convert("RGB")
    except Exception as e:
        logger.error(f"Failed to load image: {e}")
        return None
def apply_colormap_and_blend(
    score_grid: np.ndarray,
    image: Image.Image,
    alpha: float = DEFAULT_ALPHA,
    colormap: str = DEFAULT_COLORMAP
) -> Image.Image:
    """
    Render a normalized score grid as a colored heatmap over an image.

    Args:
        score_grid: 2D array of scores in [0, 1].
        image: RGB PIL image to overlay.
        alpha: Blend weight (0.0 = image only, 1.0 = heatmap only).
        colormap: Matplotlib colormap name.

    Returns:
        New PIL image of the same size with the blended heatmap.
    """
    import matplotlib
    img_width, img_height = image.size
    # Upscale the coarse patch grid to the full image resolution.
    heatmap_pil = Image.fromarray((score_grid * 255).astype(np.uint8), mode='L')
    heatmap_resized = heatmap_pil.resize((img_width, img_height), Image.BILINEAR)
    heatmap_array = np.array(heatmap_resized) / 255.0
    # Apply colormap. cm.get_cmap() was removed in matplotlib 3.9;
    # prefer the colormaps registry (3.5+) and fall back for old versions.
    try:
        cmap = matplotlib.colormaps[colormap]
    except AttributeError:
        from matplotlib import cm
        cmap = cm.get_cmap(colormap)
    heatmap_colored = cmap(heatmap_array)[:, :, :3]  # drop the alpha channel
    heatmap_colored = (heatmap_colored * 255).astype(np.uint8)
    heatmap_img = Image.fromarray(heatmap_colored, mode='RGB')
    # Blend with original image
    overlay = Image.blend(image, heatmap_img, alpha=alpha)
    return overlay
def generate_tile_aware_saliency(
    qdrant_client: Any,
    collection_name: str,
    point_id: str,
    query_embedding: np.ndarray,
    alpha: float = DEFAULT_ALPHA,
    colormap: str = DEFAULT_COLORMAP,
    threshold_percentile: int = DEFAULT_THRESHOLD_PERCENTILE
) -> Optional[Image.Image]:
    """
    Generate tile-aware saliency map for a document-query pair.

    This is the main function to call for saliency generation.

    Args:
        qdrant_client: Qdrant client instance
        collection_name: Name of the collection
        point_id: ID of the document point
        query_embedding: Query multi-vector embedding [num_query_patches, dim]
        alpha: Overlay transparency (0.0-1.0)
        colormap: Matplotlib colormap name (default: 'hot')
        threshold_percentile: Hide patches below this percentile (default: 50)

    Returns:
        PIL Image with saliency overlay, or None if generation fails
    """
    try:
        # Step 1: Fetch full multi-vector embedding AND payload
        # NOTE(review): assumes the point stores a named multi-vector
        # called "initial" -- confirm against the ingestion pipeline.
        logger.debug(f"Fetching point {point_id} with tile metadata from {collection_name}")
        points = qdrant_client.retrieve(
            collection_name=collection_name,
            ids=[point_id],
            with_vectors=["initial"],
            with_payload=True
        )
        if not points or len(points) == 0:
            logger.error(f"Point {point_id} not found in collection")
            return None
        point = points[0]
        doc_vector = point.vector.get("initial")
        payload = point.payload
        if doc_vector is None:
            logger.error("No 'initial' vector found for point")
            return None
        # Step 2: Get tile structure from payload
        # Presumably these fields are written at indexing time; missing
        # values mean the document predates tile-aware indexing.
        num_tiles = payload.get('num_tiles')
        tile_rows = payload.get('tile_rows')
        tile_cols = payload.get('tile_cols')
        patches_per_tile = payload.get('patches_per_tile', 64)
        resized_width = payload.get('resized_width')
        resized_height = payload.get('resized_height')
        resized_url = payload.get('resized_url') or payload.get('page')
        original_width = payload.get('original_width')
        original_height = payload.get('original_height')
        if not all([num_tiles, tile_rows, tile_cols, resized_width, resized_height]):
            logger.warning("Missing tile metadata - cannot generate saliency")
            return None
        logger.info(f"β Tile structure: {tile_rows}Γ{tile_cols} tiles, {patches_per_tile} patches/tile")
        logger.info(f"β Resized image: {resized_width}Γ{resized_height}")
        logger.info(f"β Original image: {original_width}Γ{original_height}")
        # Step 3: Convert embeddings to numpy and validate compatibility
        doc_embedding = convert_to_numpy(doc_vector)
        query_emb = convert_to_numpy(query_embedding)
        is_valid, error_msg = validate_embeddings(doc_embedding, query_emb)
        if not is_valid:
            logger.error(f"Embedding validation failed: {error_msg}")
            return None
        logger.info(f"Document embedding: {doc_embedding.shape}")
        logger.info(f"Query embedding: {query_emb.shape}")
        # Step 4: Separate tile embeddings from global tile
        # Drops the trailing patches_per_tile patches; presumably the
        # global tile is stored last -- verify against the indexer.
        total_patches = num_tiles * patches_per_tile
        tile_patches = total_patches - patches_per_tile  # Exclude global
        if len(doc_embedding) < total_patches:
            # Best-effort fallback: use whatever patches are available.
            logger.warning(f"Embedding size mismatch: got {len(doc_embedding)}, expected {total_patches}")
            tile_embeddings = doc_embedding[:tile_patches] if len(doc_embedding) > tile_patches else doc_embedding
        else:
            tile_embeddings = doc_embedding[:tile_patches]
        logger.info(f"Using {len(tile_embeddings)} tile patches (excluding global)")
        # Step 5: Compute MaxSim scores (cosine similarity per patch)
        patch_scores = compute_maxsim_scores(tile_embeddings, query_emb, normalize=True)
        logger.info(f"Computed scores for {len(patch_scores)} patches")
        # Step 6: Reshape patches into tile structure
        patches_per_tile_side = int(np.sqrt(patches_per_tile))  # 8 for 64 patches
        try:
            num_actual_tiles = tile_rows * tile_cols
            if len(patch_scores) != num_actual_tiles * patches_per_tile:
                logger.error(f"Patch count mismatch: {len(patch_scores)} patches")
                return None
            tile_scores = patch_scores.reshape(num_actual_tiles, patches_per_tile)
            # Reshape each tile's patches to an 8x8 grid.
            # NOTE(review): F-order implies patches are emitted
            # column-by-column within a tile -- confirm against the
            # embedding model's patch ordering.
            tile_grids = []
            for tile_idx in range(num_actual_tiles):
                tile_patch_scores = tile_scores[tile_idx]
                tile_grid = tile_patch_scores.reshape(
                    patches_per_tile_side, patches_per_tile_side, order='F'
                )
                tile_grids.append(tile_grid)
            # Arrange tiles into the full image grid, row-major over tiles.
            full_grid_rows = []
            for row_idx in range(tile_rows):
                row_tiles = []
                for col_idx in range(tile_cols):
                    tile_idx = row_idx * tile_cols + col_idx
                    row_tiles.append(tile_grids[tile_idx])
                row_grid = np.concatenate(row_tiles, axis=1)
                full_grid_rows.append(row_grid)
            score_grid = np.concatenate(full_grid_rows, axis=0)
            logger.info(f"β Reconstructed grid: {score_grid.shape} (from {tile_rows}Γ{tile_cols} tiles)")
        except ValueError as e:
            logger.error(f"β Failed to reshape patches: {e}")
            return None
        # Step 7: Normalize scores to [0, 1] and hide low-scoring patches
        score_grid_norm = normalize_scores(score_grid, threshold_percentile=threshold_percentile)
        # Step 8: Download the RESIZED image (the one the grid maps onto)
        logger.info(f"Downloading resized image from: {resized_url}")
        resized_image = download_image(resized_url)
        if resized_image is None:
            logger.error("Failed to download resized image")
            return None
        # Step 9: Apply heatmap to resized image
        overlay_resized = apply_colormap_and_blend(
            score_grid_norm, resized_image, alpha, colormap
        )
        # Step 10: Resize back to original dimensions when known
        if original_width and original_height:
            overlay_final = overlay_resized.resize(
                (original_width, original_height), Image.BILINEAR
            )
            logger.info(f"β Resized saliency map to original: {original_width}Γ{original_height}")
        else:
            overlay_final = overlay_resized
        logger.info(f"β Saliency map generated successfully")
        return overlay_final
    except Exception as e:
        # Broad catch is deliberate: saliency is a best-effort add-on and
        # must never crash the calling search flow.
        logger.error(f"Saliency generation failed: {e}")
        import traceback
        logger.debug(traceback.format_exc())
        return None
def can_generate_saliency(metadata: dict) -> bool:
    """
    Report whether a document's metadata carries the tile structure
    needed for saliency generation.

    Args:
        metadata: Document metadata dictionary

    Returns:
        True when every required tile field is present and non-None
    """
    for key in ('num_tiles', 'tile_rows', 'tile_cols',
                'resized_width', 'resized_height'):
        if metadata.get(key) is None:
            return False
    return True
def get_saliency_metadata_summary(metadata: dict) -> str:
    """
    Get a human-readable summary of saliency-related tile metadata.

    Args:
        metadata: Document metadata dictionary

    Returns:
        Summary string, or a placeholder when tile metadata is missing
    """
    num_tiles = metadata.get('num_tiles')
    tile_rows = metadata.get('tile_rows')
    tile_cols = metadata.get('tile_cols')
    patches_per_tile = metadata.get('patches_per_tile', 64)
    # Check `is not None` rather than a sentinel default: a key stored
    # with an explicit None value must also count as "not available".
    if all(v is not None for v in (num_tiles, tile_rows, tile_cols)):
        return f"{tile_rows}Γ{tile_cols} tiles ({num_tiles} total), {patches_per_tile} patches/tile"
    else:
        return "Tile metadata not available"