# Last commit by Yeroyan: 63fcb87 "Fix BFloat16 numpy conversion" (verified)
"""
Saliency Map Generation for Visual Document Retrieval.
Generates attention/saliency maps to visualize which parts of documents
are most relevant to a query.
"""
import logging
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from PIL import Image
# Module-level logger shared by the saliency and visualization functions below.
logger = logging.getLogger(__name__)
def _embedding_to_numpy(embedding: Any) -> np.ndarray:
    """Convert an embedding (torch tensor, array-like, or ndarray) to a float32 ndarray.

    Torch tensors are detected first: they expose both ``.numpy()`` and ``.cpu()``,
    and calling ``.numpy()`` directly fails for CUDA tensors, tensors that require
    grad, and BFloat16 tensors (NumPy has no bfloat16 dtype).
    """
    if hasattr(embedding, "detach") and hasattr(embedding, "cpu"):
        # .float() upcasts BFloat16/Float16 to float32 before handing off to NumPy.
        embedding = embedding.detach().cpu().float().numpy()
    elif hasattr(embedding, "numpy"):
        embedding = embedding.numpy()
    return np.asarray(embedding, dtype=np.float32)


def _tile_mean_scores(
    scores: np.ndarray, n_rows: int, n_cols: int, patches_per_tile: int
) -> Optional[np.ndarray]:
    """Average per-patch scores into an ``(n_rows, n_cols)`` grid; None if there are too few."""
    num_grid_tiles = n_rows * n_cols
    needed = num_grid_tiles * patches_per_tile
    if len(scores) < needed:
        logger.warning(
            f"Could not reshape to tile grid: need {needed} patch scores "
            f"for {n_rows}x{n_cols} tiles, got {len(scores)}"
        )
        return None
    # Grid tiles are taken from the start of the sequence and trailing patches are
    # ignored. Assumes the global-image tile's patches come after the grid tiles —
    # confirm against the embedder's token layout.
    grid_patches = scores[:needed]
    tile_scores = grid_patches.reshape(num_grid_tiles, patches_per_tile).mean(axis=1)
    return tile_scores.reshape(n_rows, n_cols)


def generate_saliency_map(
    query_embedding: np.ndarray,
    doc_embedding: np.ndarray,
    image: Image.Image,
    token_info: Optional[Dict[str, Any]] = None,
    colormap: str = "Reds",
    alpha: float = 0.5,
    threshold_percentile: float = 50.0,
) -> Tuple[Image.Image, np.ndarray]:
    """
    Generate saliency map showing which parts of the image match the query.

    Each document patch is scored by its highest cosine similarity against any
    query token; the scores are min-max normalized and drawn as a colored overlay.

    Args:
        query_embedding: Query embeddings [num_query_tokens, dim]; a single [dim]
            vector is also accepted. NumPy arrays and torch tensors (any device,
            including BFloat16) are accepted.
        doc_embedding: Document visual embeddings [num_visual_tokens, dim]
        image: Original PIL Image
        token_info: Optional token info with n_rows, n_cols for the tile grid, and
            an optional patches_per_tile (defaults to 64, the ColSmol layout).
        colormap: Matplotlib colormap name (Reds, viridis, jet, etc.)
        alpha: Overlay transparency (0-1)
        threshold_percentile: Only highlight patches at or above this percentile

    Returns:
        Tuple of (annotated_image, patch_scores), where patch_scores are the raw
        (un-normalized) per-patch maximum cosine similarities.

    Raises:
        ValueError: If either embedding contains no tokens.

    Example:
        >>> query = embedder.embed_query("budget allocation")
        >>> doc = visual_embedding  # From embed_images
        >>> annotated, scores = generate_saliency_map(
        ...     query_embedding=query.numpy(),
        ...     doc_embedding=doc,
        ...     image=page_image,
        ...     token_info=token_info,
        ... )
        >>> annotated.save("saliency.png")
    """
    query_np = np.atleast_2d(_embedding_to_numpy(query_embedding))
    doc_np = np.atleast_2d(_embedding_to_numpy(doc_embedding))
    if query_np.shape[0] == 0 or doc_np.shape[0] == 0:
        raise ValueError("query_embedding and doc_embedding must each contain at least one token")

    # L2-normalize rows so the dot product below is cosine similarity.
    query_norm = query_np / (np.linalg.norm(query_np, axis=1, keepdims=True) + 1e-8)
    doc_norm = doc_np / (np.linalg.norm(doc_np, axis=1, keepdims=True) + 1e-8)

    # [num_query, num_doc]; for each document patch keep its best-matching query token.
    similarity_matrix = query_norm @ doc_norm.T
    patch_scores = similarity_matrix.max(axis=0)

    # Min-max normalize to [0, 1]; a constant score vector maps to all zeros.
    score_min, score_max = patch_scores.min(), patch_scores.max()
    if score_max - score_min > 1e-8:
        patch_scores_norm = (patch_scores - score_min) / (score_max - score_min)
    else:
        patch_scores_norm = np.zeros_like(patch_scores)

    tile_scores = None
    n_rows = n_cols = None
    if token_info and token_info.get("n_rows") and token_info.get("n_cols"):
        n_rows = token_info["n_rows"]
        n_cols = token_info["n_cols"]
        patches_per_tile = token_info.get("patches_per_tile") or 64  # ColSmol standard
        tile_scores = _tile_mean_scores(patch_scores_norm, n_rows, n_cols, patches_per_tile)

    annotated = create_saliency_overlay(
        image=image,
        scores=tile_scores if tile_scores is not None else patch_scores_norm,
        colormap=colormap,
        alpha=alpha,
        threshold_percentile=threshold_percentile,
        grid_rows=n_rows,
        grid_cols=n_cols,
    )
    return annotated, patch_scores
def _estimate_overlay_grid(num_cells: int, aspect: float) -> Tuple[int, int]:
    """Pick (rows, cols) with 1 <= rows * cols <= num_cells and cols/rows near ``aspect``.

    Requires ``num_cells >= 1``.
    """
    # Clamping cols to num_cells guarantees rows >= 1 and rows * cols <= num_cells,
    # so the caller's slice-and-reshape always succeeds.
    cols = min(num_cells, max(1, int(np.sqrt(num_cells * aspect))))
    rows = max(1, num_cells // cols)
    return rows, cols


def create_saliency_overlay(
    image: Image.Image,
    scores: np.ndarray,
    colormap: str = "Reds",
    alpha: float = 0.5,
    threshold_percentile: float = 50.0,
    grid_rows: Optional[int] = None,
    grid_cols: Optional[int] = None,
) -> Image.Image:
    """
    Create colored overlay on image based on scores.

    The scores are laid out as a rows x cols grid stretched over the whole image.
    Cells scoring at or above the ``threshold_percentile`` percentile are tinted
    with ``colormap``, and their opacity grows with the score.

    Args:
        image: Base PIL Image
        scores: Score array (or array-like) - 1D [num_patches] or 2D [rows, cols].
            The coloring formula assumes scores lie in [0, 1].
        colormap: Matplotlib colormap name
        alpha: Overlay transparency (maximum opacity, 0-1)
        threshold_percentile: Only color cells at or above this percentile
        grid_rows, grid_cols: Grid dimensions for 1D scores. They are used only when
            grid_rows * grid_cols == len(scores); otherwise a grid is estimated from
            the image aspect ratio and surplus scores are dropped.

    Returns:
        Annotated RGB PIL Image, or the original image unchanged when matplotlib
        is not installed.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        logger.warning("matplotlib not installed, returning original image")
        return image

    w, h = image.size
    scores = np.asarray(scores, dtype=np.float64)
    if scores.size == 0:
        logger.warning("No scores to overlay, returning original image")
        return image.convert("RGB")

    if scores.ndim == 2:
        grid = scores
    else:
        flat = scores.ravel()
        if grid_rows and grid_cols and flat.size == grid_rows * grid_cols:
            grid = flat.reshape(grid_rows, grid_cols)
        else:
            aspect = w / h if h else 1.0
            rows, cols = _estimate_overlay_grid(flat.size, aspect)
            grid = flat[: rows * cols].reshape(rows, cols)
    rows, cols = grid.shape

    # matplotlib.cm.get_cmap (plt.cm.get_cmap) was removed in Matplotlib 3.9;
    # pyplot.get_cmap is available across versions.
    cmap = plt.get_cmap(colormap)
    threshold = np.percentile(grid, threshold_percentile)

    # Integer cell edges spanning the full image, so floor division never leaves an
    # uncolored strip along the right/bottom border.
    y_edges = [i * h // rows for i in range(rows + 1)]
    x_edges = [j * w // cols for j in range(cols + 1)]

    overlay = np.zeros((h, w, 4), dtype=np.uint8)
    for i in range(rows):
        y1, y2 = y_edges[i], y_edges[i + 1]
        for j in range(cols):
            score = grid[i, j]
            if score >= threshold:
                x1, x2 = x_edges[j], x_edges[j + 1]
                # Rescale the above-threshold range to [0, 1] for color and opacity.
                norm_score = (score - threshold) / (1.0 - threshold + 1e-8)
                norm_score = min(1.0, max(0.0, norm_score))
                color_uint8 = (np.array(cmap(norm_score)[:3]) * 255).astype(np.uint8)
                overlay[y1:y2, x1:x2, :3] = color_uint8
                overlay[y1:y2, x1:x2, 3] = int(alpha * 255 * norm_score)

    # Mode is inferred as RGBA from the (h, w, 4) uint8 array.
    overlay_img = Image.fromarray(overlay)
    result = Image.alpha_composite(image.convert("RGBA"), overlay_img)
    return result.convert("RGB")
def _load_result_image(page_data: Any) -> Optional[Image.Image]:
    """Best-effort load of a result page image from a base64 data URI or an http(s) URL.

    Returns None when ``page_data`` is not a string or uses another scheme. Also
    returns None, after a debug log, when the data cannot be decoded or fetched.
    """
    if not isinstance(page_data, str):
        return None
    from io import BytesIO

    if page_data.startswith("data:image"):
        try:
            import base64

            b64_data = page_data.split(",")[1]
            return Image.open(BytesIO(base64.b64decode(b64_data)))
        except Exception as e:
            logger.debug(f"Could not decode base64 image: {e}")
    elif page_data.startswith("http"):
        try:
            import urllib.request

            # NOTE(review): fetches whatever URL the payload holds; only use with trusted index data.
            with urllib.request.urlopen(page_data, timeout=5) as response:
                return Image.open(BytesIO(response.read()))
        except Exception as e:
            logger.debug(f"Could not fetch image URL: {e}")
    return None


def visualize_search_results(
    query: str,
    results: List[Dict[str, Any]],
    query_embedding: Optional[np.ndarray] = None,
    embeddings: Optional[List[np.ndarray]] = None,
    output_path: Optional[str] = None,
    max_results: int = 5,
    show_saliency: bool = False,
) -> Optional[Image.Image]:
    """
    Visualize search results as a grid of images with scores.

    Args:
        query: Original query text
        results: List of search results with 'payload' containing 'page' (image URL/base64)
        query_embedding: Query embedding for saliency (optional)
        embeddings: Document embeddings for saliency (optional), one per result in
            the same order as ``results``
        output_path: Path to save visualization (optional)
        max_results: Maximum results to show
        show_saliency: Generate saliency overlays (requires query_embedding & embeddings)

    Returns:
        Combined visualization image if successful; None if matplotlib is missing
        or there are no results.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        logger.error("matplotlib required for visualization")
        return None
    from io import BytesIO

    results = results[:max_results]
    n = len(results)
    if n == 0:
        logger.warning("No results to visualize")
        return None

    use_saliency = show_saliency and query_embedding is not None and embeddings is not None
    if show_saliency and not use_saliency:
        logger.warning("show_saliency requires query_embedding and embeddings; skipping overlays")

    # squeeze=False always yields a 2D array of Axes, even for a single result.
    fig, axes = plt.subplots(1, n, figsize=(4 * n, 4), squeeze=False)
    try:
        for idx, (result, ax) in enumerate(zip(results, axes[0])):
            payload = result.get("payload") or {}
            score = result.get("score_final")
            if score is None:
                score = result.get("score_stage1") or 0

            image = _load_result_image(payload.get("page"))
            if image is not None and use_saliency and idx < len(embeddings):
                # NOTE(review): assumes embeddings[i] is the visual embedding of results[i].
                try:
                    image, _ = generate_saliency_map(
                        query_embedding=query_embedding,
                        doc_embedding=embeddings[idx],
                        image=image,
                    )
                except Exception as e:
                    logger.warning(f"Could not generate saliency map for rank {idx + 1}: {e}")

            if image is not None:
                ax.imshow(image)
            else:
                ax.text(0.5, 0.5, "No image", ha="center", va="center", fontsize=12, color="gray")

            title = f"Rank {idx + 1}\nScore: {score:.3f}"
            if payload.get("filename"):
                title += f"\n{payload['filename'][:30]}"
            if payload.get("page_number") is not None:
                title += f" p.{payload['page_number'] + 1}"
            ax.set_title(title, fontsize=9)
            ax.axis("off")

        query_display = query[:80] + "..." if len(query) > 80 else query
        fig.suptitle(f"Query: {query_display}", fontsize=11, fontweight="bold")
        fig.tight_layout()

        if output_path:
            fig.savefig(output_path, dpi=150, bbox_inches="tight")
            logger.info(f"💾 Saved visualization to: {output_path}")

        buf = BytesIO()
        fig.savefig(buf, format="png", dpi=100, bbox_inches="tight")
    finally:
        # Always release the figure, even if drawing or saving raised.
        plt.close(fig)

    buf.seek(0)
    result_image = Image.open(buf)
    result_image.load()  # decode now so errors surface here, not at first use
    return result_image