| """ | |
| ColPali Visualization Module | |
| Generates attention/saliency maps to visualize which parts of the document | |
| are most relevant to a query. | |
| """ | |
| import logging | |
| from typing import List, Dict, Any, Optional | |
| import torch | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as patches | |
| from PIL import Image, ImageDraw, ImageFont | |
| from matplotlib.colors import LinearSegmentedColormap | |
| logger = logging.getLogger(__name__) | |
def generate_saliency_maps(
    query_embedding: torch.Tensor,
    image_embeddings: List[torch.Tensor],
    images: List[Image.Image],
    processor,
    model,
    top_k: int = 5,
    threshold: float = 0.5,
) -> List[Image.Image]:
    """
    Generate saliency/attention maps showing which parts of images are most relevant.

    Args:
        query_embedding: Query embedding tensor [num_query_patches, embedding_dim]
        image_embeddings: List of image embedding tensors, each [num_patches, embedding_dim]
        images: List of PIL Images corresponding to embeddings
        processor: ColPali processor for scoring
        model: ColPali model (currently unused; scoring goes through the processor)
        top_k: Number of top images to visualize
        threshold: Threshold in [0, 1] above which a patch is highlighted

    Returns:
        List of annotated images with saliency overlays
    """
| logger.info(f"🎨 Generating saliency maps for {len(images)} images") | |
| # Calculate scores for all images | |
| scores = [] | |
| for img_emb in image_embeddings: | |
| # Use processor's scoring method | |
| score = processor.score_multi_vector(query_embedding.unsqueeze(0), img_emb.unsqueeze(0)) | |
| scores.append(score.item() if isinstance(score, torch.Tensor) else score) | |
| # Get top-k images | |
| top_indices = np.argsort(scores)[-top_k:][::-1] | |
    annotated_images = []
    for idx in top_indices:
        image = images[idx]
        embedding = image_embeddings[idx]
        score = scores[idx]

        # For ColPali, each row of an image embedding corresponds to one patch
        # of the page, so patch-level relevance can be visualized directly.
        # detach() and float() first: tensors that require grad or are stored
        # in bfloat16 cannot be converted to numpy arrays directly.
        query_np = query_embedding.detach().float().cpu().numpy()
        img_np = embedding.detach().float().cpu().numpy()

        # L2-normalize so the dot product below is a cosine similarity
        query_norm = query_np / (np.linalg.norm(query_np, axis=1, keepdims=True) + 1e-8)
        img_norm = img_np / (np.linalg.norm(img_np, axis=1, keepdims=True) + 1e-8)

        # Similarity matrix: [num_query_patches, num_image_patches]
        similarity_matrix = np.dot(query_norm, img_norm.T)

        # Max similarity per image patch (best match from any query patch)
        patch_scores = similarity_matrix.max(axis=0)  # [num_image_patches]

        # Min-max normalize patch scores to [0, 1]
        patch_scores = (patch_scores - patch_scores.min()) / (patch_scores.max() - patch_scores.min() + 1e-8)

        annotated = _create_saliency_overlay(
            image,
            patch_scores,
            score,
            threshold=threshold,
        )
        annotated_images.append(annotated)

    logger.info(f"✅ Generated {len(annotated_images)} saliency maps")
    return annotated_images


def _create_saliency_overlay(
    image: Image.Image,
    patch_scores: np.ndarray,
    overall_score: float,
    threshold: float = 0.5,
) -> Image.Image:
    """
    Create a saliency overlay on an image.

    Args:
        image: Original PIL Image
        patch_scores: Array of scores for each patch [num_patches]
        overall_score: Overall relevance score
        threshold: Threshold above which a patch is highlighted

    Returns:
        Annotated PIL Image
    """
    # Convert to numpy array to get pixel dimensions
    img_array = np.array(image)
    h, w = img_array.shape[:2]

    # Estimate the patch grid. ColPali lays patches out on a grid; if the patch
    # count is a perfect square, assume a square grid, otherwise estimate the
    # number of rows and columns from the image aspect ratio.
    num_patches = len(patch_scores)
    grid_size = int(np.sqrt(num_patches))
    if grid_size * grid_size == num_patches:
        rows = cols = grid_size
    else:
        aspect_ratio = w / h
        cols = max(1, int(np.sqrt(num_patches * aspect_ratio)))
        rows = max(1, num_patches // cols)
        if cols * rows != num_patches:
            # Fallback to a square grid; any trailing patches are dropped
            rows = cols = max(1, grid_size)

    # Pixel dimensions of each patch cell
    patch_h = h // rows
    patch_w = w // cols

    # Transparent RGBA overlay
    overlay = np.zeros((h, w, 4), dtype=np.uint8)

    # Red colormap: more relevant patches are drawn in deeper red
    cmap = plt.cm.Reds

    patch_idx = 0
    for i in range(rows):
        for j in range(cols):
            if patch_idx >= len(patch_scores):
                break
            score = patch_scores[patch_idx]
            if score >= threshold:
                # Patch bounds in pixels, clipped to the image
                y1 = i * patch_h
                y2 = min((i + 1) * patch_h, h)
                x1 = j * patch_w
                x2 = min((j + 1) * patch_w, w)
                # Colormap returns RGBA floats in [0, 1]; keep RGB as uint8
                color = cmap(score)[:3]
                color_uint8 = (np.array(color) * 255).astype(np.uint8)
                overlay[y1:y2, x1:x2, :3] = color_uint8
                overlay[y1:y2, x1:x2, 3] = int(score * 128)  # Alpha scales with score
            patch_idx += 1

    # Blend the overlay with the original image
    overlay_img = Image.fromarray(overlay, 'RGBA')
    annotated = Image.alpha_composite(image.convert('RGBA'), overlay_img)

    # Stamp the overall relevance score onto the image
    draw = ImageDraw.Draw(annotated)
    try:
        font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 24)
    except OSError:
        # The font path above is macOS-specific; fall back to PIL's built-in font
        font = ImageFont.load_default()
    score_text = f"Relevance: {overall_score:.3f}"
    draw.text((10, 10), score_text, fill=(255, 255, 255, 255), font=font,
              stroke_width=2, stroke_fill=(0, 0, 0, 255))

    return annotated.convert('RGB')


def visualize_retrieval_results(
    query: str,
    retrieved_docs: List[Dict[str, Any]],
    output_path: Optional[str] = None,
) -> None:
    """
    Visualize retrieval results with images and scores.

    Args:
        query: Original query text
        retrieved_docs: List of retrieved documents with images and scores
        output_path: Optional path to save the visualization; if omitted, the
            figure is shown interactively instead
    """
    num_docs = len(retrieved_docs)
    if num_docs == 0:
        logger.warning("No retrieved documents to visualize")
        return

    fig, axes = plt.subplots(1, num_docs, figsize=(5 * num_docs, 5))
    if num_docs == 1:
        axes = [axes]  # plt.subplots returns a bare Axes object for a single plot

    for idx, (doc, ax) in enumerate(zip(retrieved_docs, axes)):
        if 'image' in doc:
            ax.imshow(doc['image'])
            ax.set_title(f"Rank {idx + 1}\nScore: {doc.get('score', 0):.3f}")
        ax.axis('off')

    plt.suptitle(f"Query: {query}", fontsize=14, fontweight='bold')
    plt.tight_layout()

    if output_path:
        plt.savefig(output_path, dpi=150, bbox_inches='tight')
        logger.info(f"💾 Saved visualization to: {output_path}")
    else:
        plt.show()
    plt.close()
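

# Minimal usage sketch (assumptions: embeddings come from the colpali-engine
# package's ColPali/ColPaliProcessor classes; the "vidore/colpali-v1.2"
# checkpoint and the page file names below are illustrative placeholders,
# not part of this module).
if __name__ == "__main__":
    from colpali_engine.models import ColPali, ColPaliProcessor

    model_name = "vidore/colpali-v1.2"  # hypothetical checkpoint choice
    model = ColPali.from_pretrained(model_name, torch_dtype=torch.bfloat16)
    processor = ColPaliProcessor.from_pretrained(model_name)

    pages = [Image.open("page_1.png"), Image.open("page_2.png")]  # placeholder files
    query = "What was the quarterly revenue?"

    with torch.no_grad():
        # Embed pages and query; the model returns one multi-vector embedding
        # per input: [batch, num_patches, embedding_dim]
        image_batch = processor.process_images(pages).to(model.device)
        query_batch = processor.process_queries([query]).to(model.device)
        image_embs = list(model(**image_batch).to(torch.float32).cpu())  # one [num_patches, dim] tensor per page
        query_emb = model(**query_batch)[0].to(torch.float32).cpu()      # [num_query_patches, dim]

    maps = generate_saliency_maps(query_emb, image_embs, pages, processor, model, top_k=2)
    docs = [{"image": m} for m in maps]  # relevance scores are already stamped on the maps
    visualize_retrieval_results(query, docs, output_path="saliency_results.png")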