""" ColPali Visualization Module Generates attention/saliency maps to visualize which parts of the document are most relevant to a query. """ import logging from typing import List, Dict, Any, Optional import torch import numpy as np import matplotlib.pyplot as plt import matplotlib.patches as patches from PIL import Image, ImageDraw, ImageFont from matplotlib.colors import LinearSegmentedColormap logger = logging.getLogger(__name__) def generate_saliency_maps( query_embedding: torch.Tensor, image_embeddings: List[torch.Tensor], images: List[Image.Image], processor, model, top_k: int = 5, threshold: float = 0.5 ) -> List[Image.Image]: """ Generate saliency/attention maps showing which parts of images are most relevant. Args: query_embedding: Query embedding tensor [num_query_patches, embedding_dim] image_embeddings: List of image embedding tensors, each [num_patches, embedding_dim] images: List of PIL Images corresponding to embeddings processor: ColPali processor for scoring model: ColPali model top_k: Number of top images to visualize threshold: Threshold for highlighting (0-1) Returns: List of annotated images with saliency overlays """ logger.info(f"🎨 Generating saliency maps for {len(images)} images") # Calculate scores for all images scores = [] for img_emb in image_embeddings: # Use processor's scoring method score = processor.score_multi_vector(query_embedding.unsqueeze(0), img_emb.unsqueeze(0)) scores.append(score.item() if isinstance(score, torch.Tensor) else score) # Get top-k images top_indices = np.argsort(scores)[-top_k:][::-1] annotated_images = [] for idx in top_indices: image = images[idx] embedding = image_embeddings[idx] score = scores[idx] # Create saliency map # For ColPali, we can visualize patch-level relevance # Each patch in the embedding corresponds to a region in the image # Calculate patch-level scores # Query embedding: [num_query_patches, dim] # Image embedding: [num_image_patches, dim] # Compute similarity for each patch pair query_np = query_embedding.cpu().numpy() img_np = embedding.cpu().numpy() # Compute cosine similarity for each patch # Normalize query_norm = query_np / (np.linalg.norm(query_np, axis=1, keepdims=True) + 1e-8) img_norm = img_np / (np.linalg.norm(img_np, axis=1, keepdims=True) + 1e-8) # Compute similarity matrix: [num_query_patches, num_image_patches] similarity_matrix = np.dot(query_norm, img_norm.T) # Get max similarity per image patch (best match from any query patch) patch_scores = similarity_matrix.max(axis=0) # [num_image_patches] # Normalize scores to [0, 1] patch_scores = (patch_scores - patch_scores.min()) / (patch_scores.max() - patch_scores.min() + 1e-8) # Create overlay image annotated = _create_saliency_overlay( image, patch_scores, score, threshold=threshold ) annotated_images.append(annotated) logger.info(f"✅ Generated {len(annotated_images)} saliency maps") return annotated_images def _create_saliency_overlay( image: Image.Image, patch_scores: np.ndarray, overall_score: float, threshold: float = 0.5, patch_size: int = 16 # Approximate patch size in pixels ) -> Image.Image: """ Create saliency overlay on image. 

    Args:
        image: Original PIL Image
        patch_scores: Array of scores for each patch [num_patches]
        overall_score: Overall relevance score
        threshold: Threshold for highlighting
        patch_size: Approximate size of each patch in pixels (currently unused)

    Returns:
        Annotated PIL Image
    """
    # Convert to numpy array to get pixel dimensions
    img_array = np.array(image)
    h, w = img_array.shape[:2]

    # Estimate grid dimensions. ColPali produces one embedding per image patch;
    # for simplicity, first assume a square grid of patches.
    num_patches = len(patch_scores)
    grid_size = int(np.sqrt(num_patches))

    if grid_size * grid_size != num_patches:
        # Non-square grid: estimate rows/cols from the image aspect ratio
        aspect_ratio = w / h
        cols = max(int(np.sqrt(num_patches * aspect_ratio)), 1)
        rows = num_patches // cols

        if cols * rows != num_patches:
            # Fallback to a square grid
            grid_size = int(np.sqrt(num_patches))
            rows = cols = grid_size
    else:
        rows = cols = grid_size

    # Calculate patch dimensions in pixels
    patch_h = h // rows
    patch_w = w // cols

    # Create RGBA overlay
    overlay = np.zeros((h, w, 4), dtype=np.uint8)

    # Colormap (red for high relevance)
    cmap = plt.cm.Reds

    patch_idx = 0
    for i in range(rows):
        for j in range(cols):
            if patch_idx >= len(patch_scores):
                break
            score = patch_scores[patch_idx]
            if score >= threshold:
                # Calculate patch bounds
                y1 = i * patch_h
                y2 = min((i + 1) * patch_h, h)
                x1 = j * patch_w
                x2 = min((j + 1) * patch_w, w)

                # Get RGB color from colormap
                color = cmap(score)[:3]
                color_uint8 = (np.array(color) * 255).astype(np.uint8)

                # Fill the patch region; alpha scales with the score
                overlay[y1:y2, x1:x2, :3] = color_uint8
                overlay[y1:y2, x1:x2, 3] = int(score * 128)
            patch_idx += 1

    # Blend overlay with original image
    overlay_img = Image.fromarray(overlay, 'RGBA')
    annotated = Image.alpha_composite(image.convert('RGBA'), overlay_img)

    # Add text annotation with the overall score
    draw = ImageDraw.Draw(annotated)
    try:
        font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 24)
    except OSError:
        font = ImageFont.load_default()

    score_text = f"Relevance: {overall_score:.3f}"
    draw.text((10, 10), score_text, fill=(255, 255, 255, 255), font=font,
              stroke_width=2, stroke_fill=(0, 0, 0, 255))

    return annotated.convert('RGB')


def visualize_retrieval_results(
    query: str,
    retrieved_docs: List[Dict[str, Any]],
    output_path: Optional[str] = None
) -> None:
    """
    Visualize retrieval results with images and scores.

    Args:
        query: Original query text
        retrieved_docs: List of retrieved documents with images and scores
        output_path: Optional path to save visualization
    """
    num_docs = len(retrieved_docs)
    if num_docs == 0:
        logger.warning("No retrieved documents to visualize")
        return

    fig, axes = plt.subplots(1, num_docs, figsize=(5 * num_docs, 5))
    if num_docs == 1:
        axes = [axes]

    for idx, (doc, ax) in enumerate(zip(retrieved_docs, axes)):
        if 'image' in doc:
            ax.imshow(doc['image'])
        ax.set_title(f"Rank {idx + 1}\nScore: {doc.get('score', 0):.3f}")
        ax.axis('off')

    plt.suptitle(f"Query: {query}", fontsize=14, fontweight='bold')
    plt.tight_layout()

    if output_path:
        plt.savefig(output_path, dpi=150, bbox_inches='tight')
        logger.info(f"💾 Saved visualization to: {output_path}")
    else:
        plt.show()

    plt.close()
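

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The checkpoint name, the colpali-engine
# import path, and the embedding shapes below are assumptions, not part of
# this module; adapt them to however the embeddings are produced upstream.
#
#   from colpali_engine.models import ColPali, ColPaliProcessor  # assumed dependency
#
#   model = ColPali.from_pretrained("vidore/colpali-v1.2")              # hypothetical checkpoint
#   processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.2")
#
#   # query_embedding:  [num_query_tokens, dim] tensor for one query
#   # image_embeddings: list of [num_patches, dim] tensors, one per page image
#   # images:           list of PIL.Image pages aligned with image_embeddings
#   saliency_maps = generate_saliency_maps(
#       query_embedding, image_embeddings, images,
#       processor, model, top_k=3, threshold=0.6,
#   )
#   saliency_maps[0].save("top_hit_saliency.png")
#
#   visualize_retrieval_results(
#       "example query",
#       [{"image": images[0], "score": 0.91}],
#       output_path="retrieval_results.png",
#   )
# ---------------------------------------------------------------------------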