""" Saliency Map Generation for Visual Document Retrieval. Generates attention/saliency maps to visualize which parts of documents are most relevant to a query. """ import logging from typing import Any, Dict, List, Optional, Tuple import numpy as np from PIL import Image logger = logging.getLogger(__name__) def generate_saliency_map( query_embedding: np.ndarray, doc_embedding: np.ndarray, image: Image.Image, token_info: Optional[Dict[str, Any]] = None, colormap: str = "Reds", alpha: float = 0.5, threshold_percentile: float = 50.0, ) -> Tuple[Image.Image, np.ndarray]: """ Generate saliency map showing which parts of the image match the query. Computes patch-level relevance scores and overlays them on the image. Args: query_embedding: Query embeddings [num_query_tokens, dim] doc_embedding: Document visual embeddings [num_visual_tokens, dim] image: Original PIL Image token_info: Optional token info with n_rows, n_cols for tile grid colormap: Matplotlib colormap name (Reds, viridis, jet, etc.) alpha: Overlay transparency (0-1) threshold_percentile: Only highlight patches above this percentile Returns: Tuple of (annotated_image, patch_scores) Example: >>> query = embedder.embed_query("budget allocation") >>> doc = visual_embedding # From embed_images >>> annotated, scores = generate_saliency_map( ... query_embedding=query.numpy(), ... doc_embedding=doc, ... image=page_image, ... token_info=token_info, ... ) >>> annotated.save("saliency.png") """ # Ensure numpy arrays if hasattr(query_embedding, "numpy"): query_np = query_embedding.numpy() elif hasattr(query_embedding, "cpu"): query_np = query_embedding.cpu().float().numpy() # .float() for BFloat16 else: query_np = np.array(query_embedding, dtype=np.float32) if hasattr(doc_embedding, "numpy"): doc_np = doc_embedding.numpy() elif hasattr(doc_embedding, "cpu"): doc_np = doc_embedding.cpu().float().numpy() # .float() for BFloat16 else: doc_np = np.array(doc_embedding, dtype=np.float32) # Normalize embeddings query_norm = query_np / (np.linalg.norm(query_np, axis=1, keepdims=True) + 1e-8) doc_norm = doc_np / (np.linalg.norm(doc_np, axis=1, keepdims=True) + 1e-8) # Compute similarity matrix: [num_query, num_doc] similarity_matrix = np.dot(query_norm, doc_norm.T) # Get max similarity per document patch (best match from any query token) patch_scores = similarity_matrix.max(axis=0) # Normalize to [0, 1] score_min, score_max = patch_scores.min(), patch_scores.max() if score_max - score_min > 1e-8: patch_scores_norm = (patch_scores - score_min) / (score_max - score_min) else: patch_scores_norm = np.zeros_like(patch_scores) # Determine grid dimensions if token_info and token_info.get("n_rows") and token_info.get("n_cols"): n_rows = token_info["n_rows"] n_cols = token_info["n_cols"] num_tiles = n_rows * n_cols + 1 # +1 for global tile patches_per_tile = 64 # ColSmol standard # Reshape to tile grid (excluding global tile) try: # Skip global tile patches at the end tile_patches = num_tiles * patches_per_tile if len(patch_scores_norm) >= tile_patches: grid_patches = patch_scores_norm[: n_rows * n_cols * patches_per_tile] else: grid_patches = patch_scores_norm # Reshape: [tiles * patches_per_tile] -> [tiles, patches_per_tile] # Then mean per tile num_grid_tiles = n_rows * n_cols grid_patches = grid_patches[: num_grid_tiles * patches_per_tile] tile_scores = grid_patches.reshape(num_grid_tiles, patches_per_tile).mean(axis=1) tile_scores = tile_scores.reshape(n_rows, n_cols) except Exception as e: logger.warning(f"Could not reshape to tile grid: {e}") tile_scores = None else: tile_scores = None n_rows = n_cols = None # Create overlay annotated = create_saliency_overlay( image=image, scores=tile_scores if tile_scores is not None else patch_scores_norm, colormap=colormap, alpha=alpha, threshold_percentile=threshold_percentile, grid_rows=n_rows, grid_cols=n_cols, ) return annotated, patch_scores def create_saliency_overlay( image: Image.Image, scores: np.ndarray, colormap: str = "Reds", alpha: float = 0.5, threshold_percentile: float = 50.0, grid_rows: Optional[int] = None, grid_cols: Optional[int] = None, ) -> Image.Image: """ Create colored overlay on image based on scores. Args: image: Base PIL Image scores: Score array - 1D [num_patches] or 2D [rows, cols] colormap: Matplotlib colormap name alpha: Overlay transparency threshold_percentile: Only color patches above this percentile grid_rows, grid_cols: Grid dimensions (auto-detected if not provided) Returns: Annotated PIL Image """ try: import matplotlib.pyplot as plt except ImportError: logger.warning("matplotlib not installed, returning original image") return image img_array = np.array(image) h, w = img_array.shape[:2] # Handle 2D scores (tile grid) if scores.ndim == 2: rows, cols = scores.shape elif grid_rows and grid_cols: rows, cols = grid_rows, grid_cols # Reshape if possible if len(scores) == rows * cols: scores = scores.reshape(rows, cols) else: # Fallback: estimate grid from score count num_patches = len(scores) aspect = w / h cols = int(np.sqrt(num_patches * aspect)) rows = max(1, num_patches // cols) scores = scores[: rows * cols].reshape(rows, cols) else: # Auto-estimate grid num_patches = len(scores) if scores.ndim == 1 else scores.size aspect = w / h cols = max(1, int(np.sqrt(num_patches * aspect))) rows = max(1, num_patches // cols) if rows * cols > len(scores) if scores.ndim == 1 else scores.size: cols = max(1, cols - 1) if scores.ndim == 1: scores = scores[: rows * cols].reshape(rows, cols) # Get colormap cmap = plt.cm.get_cmap(colormap) # Calculate threshold threshold = np.percentile(scores, threshold_percentile) # Calculate cell dimensions cell_h = h // rows cell_w = w // cols # Create RGBA overlay overlay = np.zeros((h, w, 4), dtype=np.uint8) for i in range(rows): for j in range(cols): score = scores[i, j] if score >= threshold: y1 = i * cell_h y2 = min((i + 1) * cell_h, h) x1 = j * cell_w x2 = min((j + 1) * cell_w, w) # Normalize score for coloring (above threshold) norm_score = (score - threshold) / (1.0 - threshold + 1e-8) norm_score = min(1.0, max(0.0, norm_score)) # Get color color = cmap(norm_score)[:3] color_uint8 = (np.array(color) * 255).astype(np.uint8) overlay[y1:y2, x1:x2, :3] = color_uint8 overlay[y1:y2, x1:x2, 3] = int(alpha * 255 * norm_score) # Blend with original overlay_img = Image.fromarray(overlay, "RGBA") result = Image.alpha_composite(image.convert("RGBA"), overlay_img) return result.convert("RGB") def visualize_search_results( query: str, results: List[Dict[str, Any]], query_embedding: Optional[np.ndarray] = None, embeddings: Optional[List[np.ndarray]] = None, output_path: Optional[str] = None, max_results: int = 5, show_saliency: bool = False, ) -> Optional[Image.Image]: """ Visualize search results as a grid of images with scores. Args: query: Original query text results: List of search results with 'payload' containing 'page' (image URL/base64) query_embedding: Query embedding for saliency (optional) embeddings: Document embeddings for saliency (optional) output_path: Path to save visualization (optional) max_results: Maximum results to show show_saliency: Generate saliency overlays (requires query_embedding & embeddings) Returns: Combined visualization image if successful """ try: import matplotlib.pyplot as plt except ImportError: logger.error("matplotlib required for visualization") return None results = results[:max_results] n = len(results) if n == 0: logger.warning("No results to visualize") return None fig, axes = plt.subplots(1, n, figsize=(4 * n, 4)) if n == 1: axes = [axes] for idx, (result, ax) in enumerate(zip(results, axes)): payload = result.get("payload", {}) score = result.get("score_final", result.get("score_stage1", 0)) # Try to load image from payload page_data = payload.get("page", "") image = None if page_data.startswith("data:image"): # Base64 encoded try: import base64 from io import BytesIO b64_data = page_data.split(",")[1] image = Image.open(BytesIO(base64.b64decode(b64_data))) except Exception as e: logger.debug(f"Could not decode base64 image: {e}") elif page_data.startswith("http"): # URL - try to fetch try: import urllib.request from io import BytesIO with urllib.request.urlopen(page_data, timeout=5) as response: image = Image.open(BytesIO(response.read())) except Exception as e: logger.debug(f"Could not fetch image URL: {e}") if image: ax.imshow(image) else: # Show placeholder ax.text(0.5, 0.5, "No image", ha="center", va="center", fontsize=12, color="gray") # Add title title = f"Rank {idx + 1}\nScore: {score:.3f}" if payload.get("filename"): title += f"\n{payload['filename'][:30]}" if payload.get("page_number") is not None: title += f" p.{payload['page_number'] + 1}" ax.set_title(title, fontsize=9) ax.axis("off") # Add query as suptitle query_display = query[:80] + "..." if len(query) > 80 else query plt.suptitle(f"Query: {query_display}", fontsize=11, fontweight="bold") plt.tight_layout() if output_path: plt.savefig(output_path, dpi=150, bbox_inches="tight") logger.info(f"💾 Saved visualization to: {output_path}") # Convert to PIL Image for return from io import BytesIO buf = BytesIO() plt.savefig(buf, format="png", dpi=100, bbox_inches="tight") buf.seek(0) result_image = Image.open(buf) plt.close() return result_image