# audit_assistant/src/colpali/visualizer.py
"""
ColPali Visualization Module
Generates attention/saliency maps to visualize which parts of the document
are most relevant to a query.
"""
import logging
from typing import List, Dict, Any, Optional
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image, ImageDraw, ImageFont
from matplotlib.colors import LinearSegmentedColormap
logger = logging.getLogger(__name__)
def generate_saliency_maps(
    query_embedding: torch.Tensor,
    image_embeddings: List[torch.Tensor],
    images: List[Image.Image],
    processor,
    model,
    top_k: int = 5,
    threshold: float = 0.5,
) -> List[Image.Image]:
    """
    Generate saliency/attention maps showing which parts of images are most
    relevant to the query.

    Args:
        query_embedding: Query embedding tensor [num_query_patches, embedding_dim]
        image_embeddings: List of image embedding tensors, each [num_patches, embedding_dim]
        images: List of PIL Images corresponding to the embeddings
        processor: ColPali processor used for scoring
        model: ColPali model (kept for API symmetry; not used directly here)
        top_k: Number of top-scoring images to visualize
        threshold: Patch-score threshold for highlighting, in [0, 1]

    Returns:
        List of annotated images with saliency overlays
    """
    logger.info(f"🎨 Generating saliency maps for {len(images)} images")

    # Score every image against the query with the processor's late-interaction scorer
    scores = []
    for img_emb in image_embeddings:
        score = processor.score_multi_vector(query_embedding.unsqueeze(0), img_emb.unsqueeze(0))
        scores.append(score.item() if isinstance(score, torch.Tensor) else score)

    # Indices of the top-k images, best first
    top_indices = np.argsort(scores)[-top_k:][::-1]

    annotated_images = []
    for idx in top_indices:
        image = images[idx]
        embedding = image_embeddings[idx]
        score = scores[idx]

        # ColPali late interaction: each row of the image embedding corresponds
        # to one patch of the page, so patch-level similarities localize the
        # evidence for the query.
        #   Query embedding: [num_query_patches, dim]
        #   Image embedding: [num_image_patches, dim]
        # Cast to float32 first: bfloat16 tensors cannot be converted to numpy.
        query_np = query_embedding.detach().float().cpu().numpy()
        img_np = embedding.detach().float().cpu().numpy()

        # Cosine similarity: L2-normalize both sides, then take the dot product
        query_norm = query_np / (np.linalg.norm(query_np, axis=1, keepdims=True) + 1e-8)
        img_norm = img_np / (np.linalg.norm(img_np, axis=1, keepdims=True) + 1e-8)

        # Similarity matrix: [num_query_patches, num_image_patches]
        similarity_matrix = np.dot(query_norm, img_norm.T)

        # MaxSim per image patch: best match from any query token
        patch_scores = similarity_matrix.max(axis=0)  # [num_image_patches]

        # Min-max normalize the patch scores to [0, 1]
        patch_scores = (patch_scores - patch_scores.min()) / (patch_scores.max() - patch_scores.min() + 1e-8)

        annotated = _create_saliency_overlay(image, patch_scores, score, threshold=threshold)
        annotated_images.append(annotated)

    logger.info(f"✅ Generated {len(annotated_images)} saliency maps")
    return annotated_images
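
# Hedged usage sketch: a minimal end-to-end driver, assuming a ColPali
# checkpoint loaded through the colpali_engine library. The checkpoint name
# ("vidore/colpali-v1.2") and the processing helpers below are assumptions
# about that library's API, not something this module guarantees; device
# placement is omitted for brevity.
def _example_usage(pages: List[Image.Image]) -> List[Image.Image]:
    from colpali_engine.models import ColPali, ColPaliProcessor  # assumed dependency

    model = ColPali.from_pretrained("vidore/colpali-v1.2", torch_dtype=torch.bfloat16).eval()
    processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.2")

    with torch.no_grad():
        query_batch = processor.process_queries(["total audit fees in 2023"])
        query_embedding = model(**query_batch)[0]       # [num_query_tokens, dim]
        image_batch = processor.process_images(pages)
        image_embeddings = list(model(**image_batch))   # one [num_patches, dim] per page

    return generate_saliency_maps(query_embedding, image_embeddings, pages, processor, model, top_k=3)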
def _create_saliency_overlay(
    image: Image.Image,
    patch_scores: np.ndarray,
    overall_score: float,
    threshold: float = 0.5,
) -> Image.Image:
    """
    Create a saliency overlay on an image.

    Args:
        image: Original PIL Image
        patch_scores: Array of scores for each patch [num_patches]
        overall_score: Overall relevance score
        threshold: Minimum normalized patch score to highlight

    Returns:
        Annotated PIL Image
    """
    img_array = np.array(image)
    h, w = img_array.shape[:2]

    # Estimate the patch grid. ColPali lays the page out as a grid of patches;
    # assume a square grid first, then fall back to an aspect-ratio estimate.
    num_patches = len(patch_scores)
    grid_size = int(np.sqrt(num_patches))
    if grid_size * grid_size != num_patches:
        # Non-square grid: estimate columns from the image aspect ratio
        aspect_ratio = w / h
        cols = max(1, int(np.sqrt(num_patches * aspect_ratio)))
        rows = num_patches // cols
        if cols * rows != num_patches:
            # Fallback to a square grid
            rows = cols = grid_size
    else:
        rows = cols = grid_size

    # Pixel dimensions of each patch cell
    patch_h = h // rows
    patch_w = w // cols

    # RGBA overlay; colormap intensity and alpha both track the patch score
    overlay = np.zeros((h, w, 4), dtype=np.uint8)
    cmap = plt.cm.Reds  # red for high relevance

    patch_idx = 0
    for i in range(rows):
        for j in range(cols):
            if patch_idx >= len(patch_scores):
                break
            score = patch_scores[patch_idx]
            if score >= threshold:
                # Patch bounds, clamped to the image
                y1 = i * patch_h
                y2 = min((i + 1) * patch_h, h)
                x1 = j * patch_w
                x2 = min((j + 1) * patch_w, w)
                # Colormap gives RGB in [0, 1]; scale to uint8
                color = cmap(score)[:3]
                color_uint8 = (np.array(color) * 255).astype(np.uint8)
                overlay[y1:y2, x1:x2, :3] = color_uint8
                overlay[y1:y2, x1:x2, 3] = int(score * 128)  # Alpha based on score
            patch_idx += 1

    # Blend the overlay with the original image
    overlay_img = Image.fromarray(overlay, 'RGBA')
    annotated = Image.alpha_composite(image.convert('RGBA'), overlay_img)

    # Annotate with the overall relevance score
    draw = ImageDraw.Draw(annotated)
    try:
        font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 24)
    except OSError:
        # The path above is macOS-specific; fall back to PIL's bundled font
        font = ImageFont.load_default()
    score_text = f"Relevance: {overall_score:.3f}"
    draw.text((10, 10), score_text, fill=(255, 255, 255, 255), font=font, stroke_width=2, stroke_fill=(0, 0, 0, 255))

    return annotated.convert('RGB')
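
# Quick self-check (hedged): a minimal sketch that exercises the overlay on a
# synthetic page with an assumed 32x32 patch grid. The grid size and the score
# pattern are made up for illustration only.
def _demo_overlay() -> Image.Image:
    rng = np.random.default_rng(0)
    page = Image.new("RGB", (512, 512), "white")  # blank stand-in page
    fake_scores = rng.random(32 * 32)             # one score per patch, in [0, 1)
    fake_scores[200:240] = 1.0                    # a contiguous "hot" region
    return _create_saliency_overlay(page, fake_scores, overall_score=0.87, threshold=0.8)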
def visualize_retrieval_results(
    query: str,
    retrieved_docs: List[Dict[str, Any]],
    output_path: Optional[str] = None,
) -> None:
    """
    Visualize retrieval results with images and scores.

    Args:
        query: Original query text
        retrieved_docs: List of retrieved documents with images and scores
        output_path: Optional path to save the visualization; if omitted, the
            figure is shown interactively
    """
    num_docs = len(retrieved_docs)
    if num_docs == 0:
        logger.warning("No retrieved documents to visualize")
        return

    fig, axes = plt.subplots(1, num_docs, figsize=(5 * num_docs, 5))
    if num_docs == 1:
        axes = [axes]

    for idx, (doc, ax) in enumerate(zip(retrieved_docs, axes)):
        if 'image' in doc:
            ax.imshow(doc['image'])
        ax.set_title(f"Rank {idx + 1}\nScore: {doc.get('score', 0):.3f}")
        ax.axis('off')

    plt.suptitle(f"Query: {query}", fontsize=14, fontweight='bold')
    plt.tight_layout()

    if output_path:
        plt.savefig(output_path, dpi=150, bbox_inches='tight')
        logger.info(f"💾 Saved visualization to: {output_path}")
    else:
        plt.show()
    plt.close()
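
# Hedged example: rendering a side-by-side comparison for a query. The doc
# dicts below mirror the keys this function reads ('image', 'score'); the
# images, scores, and the output filename are synthetic placeholders.
def _demo_visualization() -> None:
    docs = [
        {'image': Image.new("RGB", (256, 256), "lightgray"), 'score': 0.91},
        {'image': Image.new("RGB", (256, 256), "lightgray"), 'score': 0.74},
    ]
    visualize_retrieval_results("total audit fees in 2023", docs, output_path="retrieval.png")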