# Provenance: HuggingFace Space commit 7f3ae81 ("add saliency ui", akryldigital).
# The raw page header was captured by the export; kept here as a comment so the
# module remains valid Python.
"""
Saliency Map Generation for Visual RAG
This module provides saliency map generation for visual document search results.
It implements the tile-aware ColBERT MaxSim strategy for accurate visualization
of which image regions are relevant to a query.
Key features:
1. Tile-aware architecture (understands 4Γ—3 grid of 512Γ—512 tiles)
2. Excludes global tile for cleaner saliency
3. Maps patches to resized image, then scales to original
4. Uses "hot" colormap by default for better visibility
"""
import logging
from typing import Any, Optional, Tuple
from io import BytesIO
from base64 import b64decode
import numpy as np
import requests
from PIL import Image
logger = logging.getLogger(__name__)
# Default saliency configuration
DEFAULT_ALPHA = 0.4
DEFAULT_COLORMAP = 'hot' # Better visibility than 'jet'
DEFAULT_THRESHOLD_PERCENTILE = 50
def convert_to_numpy(embedding, dtype: np.dtype = np.float32) -> np.ndarray:
    """
    Coerce an embedding into a numpy array of the requested dtype.

    Accepts plain Python lists, numpy arrays, and PyTorch tensors.
    bfloat16 tensors are upcast to float32 before conversion, since
    numpy has no native bfloat16 representation. When torch is not
    installed, the input is handed to ``np.array`` as-is.
    """
    try:
        import torch
    except ImportError:
        torch = None
    if torch is not None and isinstance(embedding, torch.Tensor):
        tensor = embedding.cpu()
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.float()
        embedding = tensor.numpy()
    return np.array(embedding, dtype=dtype)
def validate_embeddings(
    doc_embedding: np.ndarray,
    query_embedding: np.ndarray
) -> Tuple[bool, str]:
    """Check that both embeddings are finite 2D arrays with matching dims.

    Returns (True, "") on success, otherwise (False, reason).
    """
    labeled = (("Document", doc_embedding), ("Query", query_embedding))
    for label, arr in labeled:
        if arr.ndim != 2:
            return False, f"{label} embedding must be 2D, got {arr.ndim}D"
    doc_dim = doc_embedding.shape[1]
    query_dim = query_embedding.shape[1]
    if doc_dim != query_dim:
        return False, f"Embedding dimensions don't match: doc={doc_dim}, query={query_dim}"
    for label, arr in labeled:
        # isfinite is False for both NaN and +/-Inf.
        if not np.all(np.isfinite(arr)):
            return False, f"{label} embedding contains NaN or Inf values"
    return True, ""
def compute_maxsim_scores(
    doc_embedding: np.ndarray,
    query_embedding: np.ndarray,
    normalize: bool = True
) -> np.ndarray:
    """
    Score each document patch by its best match over the query patches.

    This is the ColBERT late-interaction "MaxSim" reduction: build the
    patch-by-patch similarity matrix and take the max along the query
    axis. With normalize=True, rows are L2-normalized first (with a
    small epsilon for stability), so scores are cosine similarities.
    """
    docs = doc_embedding
    queries = query_embedding
    if normalize:
        docs = docs / (np.linalg.norm(docs, axis=1, keepdims=True) + 1e-8)
        queries = queries / (np.linalg.norm(queries, axis=1, keepdims=True) + 1e-8)
    # Shape: [num_doc_patches, num_query_patches] -> [num_doc_patches]
    return (docs @ queries.T).max(axis=1)
def normalize_scores(
    score_grid: np.ndarray,
    threshold_percentile: Optional[int] = None
) -> np.ndarray:
    """Normalize the score grid to the 0-1 range, optionally hiding low scores.

    Args:
        score_grid: Array of raw patch scores (any shape).
        threshold_percentile: If given, scores below this percentile of the
            RAW grid are zeroed after normalization, hiding weak patches.

    Returns:
        Array min-max scaled to [0, 1]. When the input is (near-)constant,
        an all-zero float32 array is returned instead, since min-max
        scaling would divide by ~0.
    """
    score_min = score_grid.min()
    score_max = score_grid.max()
    if score_max - score_min < 1e-8:
        logger.warning("All scores are identical, returning zeros")
        return np.zeros_like(score_grid, dtype=np.float32)
    score_grid_norm = (score_grid - score_min) / (score_max - score_min)
    if threshold_percentile is not None:
        # Threshold on the raw scores, then zero the normalized copy so the
        # surviving patches keep their 0-1 scale.
        score_threshold = np.percentile(score_grid, threshold_percentile)
        mask = score_grid < score_threshold
        score_grid_norm[mask] = 0.0
        visible_count = np.sum(~mask)
        total_count = score_grid.size
        logger.debug(f"Threshold: {score_threshold:.3f} ({threshold_percentile}th percentile)")
        logger.debug(f"Visible patches: {visible_count} / {total_count}")
    return score_grid_norm
def download_image(page_url: str) -> Optional[Image.Image]:
    """Load an image from an HTTP(S) URL, a data URI, or a local file path.

    Always returns an RGB-mode PIL image, or None on any failure
    (network error, bad base64, unreadable file, ...).
    """
    try:
        if page_url.startswith(("http://", "https://")):
            response = requests.get(page_url, timeout=15)
            response.raise_for_status()
            source = BytesIO(response.content)
        elif page_url.startswith("data:image"):
            # Drop the "data:image/...;base64," prefix before decoding.
            encoded = page_url.split(",", 1)[1]
            source = BytesIO(b64decode(encoded))
        else:
            # Treat anything else as a local filesystem path.
            source = page_url
        image = Image.open(source)
        return image if image.mode == "RGB" else image.convert("RGB")
    except Exception as e:
        logger.error(f"Failed to load image: {e}")
        return None
def apply_colormap_and_blend(
    score_grid: np.ndarray,
    image: Image.Image,
    alpha: float = DEFAULT_ALPHA,
    colormap: str = DEFAULT_COLORMAP
) -> Image.Image:
    """Render the score grid as a colored heatmap blended over the image.

    Args:
        score_grid: 2D array of scores already normalized to [0, 1].
        image: RGB PIL image the heatmap is overlaid on.
        alpha: Blend weight (0.0 = image only, 1.0 = heatmap only).
        colormap: Matplotlib colormap name.

    Returns:
        New PIL image, same size as `image`, with the overlay applied.
    """
    img_width, img_height = image.size
    # Upsample the coarse patch grid to full image resolution.
    heatmap_pil = Image.fromarray((score_grid * 255).astype(np.uint8), mode='L')
    heatmap_resized = heatmap_pil.resize((img_width, img_height), Image.BILINEAR)
    heatmap_array = np.array(heatmap_resized) / 255.0
    # Look up the colormap via the registry: cm.get_cmap() was deprecated in
    # matplotlib 3.7 and removed in 3.9, so it crashes on current releases.
    # Fall back to the old API only for matplotlib < 3.5.
    try:
        from matplotlib import colormaps
        cmap = colormaps[colormap]
    except ImportError:
        from matplotlib import cm
        cmap = cm.get_cmap(colormap)
    heatmap_colored = cmap(heatmap_array)[:, :, :3]  # drop the alpha channel
    heatmap_colored = (heatmap_colored * 255).astype(np.uint8)
    heatmap_img = Image.fromarray(heatmap_colored, mode='RGB')
    # Blend with the original image
    overlay = Image.blend(image, heatmap_img, alpha=alpha)
    return overlay
def generate_tile_aware_saliency(
    qdrant_client: Any,
    collection_name: str,
    point_id: str,
    query_embedding: np.ndarray,
    alpha: float = DEFAULT_ALPHA,
    colormap: str = DEFAULT_COLORMAP,
    threshold_percentile: int = DEFAULT_THRESHOLD_PERCENTILE
) -> Optional[Image.Image]:
    """
    Generate a tile-aware saliency map for a document-query pair.

    This is the main entry point for saliency generation. It fetches the
    document's multi-vector embedding and tile metadata from Qdrant,
    computes per-patch MaxSim scores against the query, reconstructs the
    patch grid from the tile layout, and blends the resulting heatmap
    over the page image.

    Args:
        qdrant_client: Qdrant client instance (must support ``retrieve``)
        collection_name: Name of the collection
        point_id: ID of the document point
        query_embedding: Query multi-vector embedding [num_query_patches, dim]
        alpha: Overlay transparency (0.0-1.0)
        colormap: Matplotlib colormap name (default: 'hot')
        threshold_percentile: Hide patches below this percentile (default: 50)

    Returns:
        PIL Image with saliency overlay, or None if any step fails
        (missing point, missing tile metadata, shape mismatch, download
        failure). All failures are logged, never raised.
    """
    try:
        # Step 1: Fetch full multi-vector embedding AND payload.
        # Only the "initial" named vector is requested — that is the full
        # pre-pooling multi-vector needed for patch-level scores.
        logger.debug(f"Fetching point {point_id} with tile metadata from {collection_name}")
        points = qdrant_client.retrieve(
            collection_name=collection_name,
            ids=[point_id],
            with_vectors=["initial"],
            with_payload=True
        )
        if not points or len(points) == 0:
            logger.error(f"Point {point_id} not found in collection")
            return None
        point = points[0]
        doc_vector = point.vector.get("initial")
        payload = point.payload
        if doc_vector is None:
            logger.error("No 'initial' vector found for point")
            return None
        # Step 2: Get tile structure from payload.
        # NOTE(review): patches_per_tile defaults to 64 (an 8x8 patch grid
        # per 512x512 tile) when the indexer did not store it — confirm this
        # matches the embedding model's actual patch count.
        num_tiles = payload.get('num_tiles')
        tile_rows = payload.get('tile_rows')
        tile_cols = payload.get('tile_cols')
        patches_per_tile = payload.get('patches_per_tile', 64)
        resized_width = payload.get('resized_width')
        resized_height = payload.get('resized_height')
        resized_url = payload.get('resized_url') or payload.get('page')
        original_width = payload.get('original_width')
        original_height = payload.get('original_height')
        # Truthiness check: a 0 value is treated the same as a missing field.
        if not all([num_tiles, tile_rows, tile_cols, resized_width, resized_height]):
            logger.warning("Missing tile metadata - cannot generate saliency")
            return None
        logger.info(f"βœ… Tile structure: {tile_rows}Γ—{tile_cols} tiles, {patches_per_tile} patches/tile")
        logger.info(f"βœ… Resized image: {resized_width}Γ—{resized_height}")
        logger.info(f"βœ… Original image: {original_width}Γ—{original_height}")
        # Step 3: Convert embeddings to float32 numpy and validate shapes.
        doc_embedding = convert_to_numpy(doc_vector)
        query_emb = convert_to_numpy(query_embedding)
        is_valid, error_msg = validate_embeddings(doc_embedding, query_emb)
        if not is_valid:
            logger.error(f"Embedding validation failed: {error_msg}")
            return None
        logger.info(f"Document embedding: {doc_embedding.shape}")
        logger.info(f"Query embedding: {query_emb.shape}")
        # Step 4: Separate tile embeddings from the global tile.
        # The slice keeps the FIRST (num_tiles - 1) * patches_per_tile rows,
        # i.e. it assumes the global tile's patches are stored LAST in the
        # vector — TODO confirm against the indexing pipeline.
        total_patches = num_tiles * patches_per_tile
        tile_patches = total_patches - patches_per_tile  # Exclude global
        if len(doc_embedding) < total_patches:
            logger.warning(f"Embedding size mismatch: got {len(doc_embedding)}, expected {total_patches}")
            tile_embeddings = doc_embedding[:tile_patches] if len(doc_embedding) > tile_patches else doc_embedding
        else:
            tile_embeddings = doc_embedding[:tile_patches]
        logger.info(f"Using {len(tile_embeddings)} tile patches (excluding global)")
        # Step 5: Compute per-patch MaxSim scores against the query.
        patch_scores = compute_maxsim_scores(tile_embeddings, query_emb, normalize=True)
        logger.info(f"Computed scores for {len(patch_scores)} patches")
        # Step 6: Reshape the flat patch scores into the tile grid layout.
        patches_per_tile_side = int(np.sqrt(patches_per_tile))  # 8 for 64 patches
        try:
            num_actual_tiles = tile_rows * tile_cols
            if len(patch_scores) != num_actual_tiles * patches_per_tile:
                logger.error(f"Patch count mismatch: {len(patch_scores)} patches")
                return None
            tile_scores = patch_scores.reshape(num_actual_tiles, patches_per_tile)
            # Reshape each tile's patches to an 8x8 grid. NOTE(review):
            # order='F' assumes patches within a tile are stored
            # column-major — confirm against the embedding model's layout.
            tile_grids = []
            for tile_idx in range(num_actual_tiles):
                tile_patch_scores = tile_scores[tile_idx]
                tile_grid = tile_patch_scores.reshape(
                    patches_per_tile_side, patches_per_tile_side, order='F'
                )
                tile_grids.append(tile_grid)
            # Arrange tiles row by row (tiles assumed stored in row-major
            # order across the page) into the full-page score grid.
            full_grid_rows = []
            for row_idx in range(tile_rows):
                row_tiles = []
                for col_idx in range(tile_cols):
                    tile_idx = row_idx * tile_cols + col_idx
                    row_tiles.append(tile_grids[tile_idx])
                row_grid = np.concatenate(row_tiles, axis=1)
                full_grid_rows.append(row_grid)
            score_grid = np.concatenate(full_grid_rows, axis=0)
            logger.info(f"βœ… Reconstructed grid: {score_grid.shape} (from {tile_rows}Γ—{tile_cols} tiles)")
        except ValueError as e:
            logger.error(f"❌ Failed to reshape patches: {e}")
            return None
        # Step 7: Min-max normalize and optionally zero out weak patches.
        score_grid_norm = normalize_scores(score_grid, threshold_percentile=threshold_percentile)
        # Step 8: Download the RESIZED image — the patch grid maps onto the
        # resized dimensions, not the original page.
        logger.info(f"Downloading resized image from: {resized_url}")
        resized_image = download_image(resized_url)
        if resized_image is None:
            logger.error("Failed to download resized image")
            return None
        # Step 9: Apply the heatmap to the resized image.
        overlay_resized = apply_colormap_and_blend(
            score_grid_norm, resized_image, alpha, colormap
        )
        # Step 10: Scale the overlay back up to the original page dimensions
        # when they are known; otherwise return it at resized dimensions.
        if original_width and original_height:
            overlay_final = overlay_resized.resize(
                (original_width, original_height), Image.BILINEAR
            )
            logger.info(f"βœ… Resized saliency map to original: {original_width}Γ—{original_height}")
        else:
            overlay_final = overlay_resized
        logger.info(f"βœ… Saliency map generated successfully")
        return overlay_final
    except Exception as e:
        # Broad catch by design: saliency is a best-effort UI feature and
        # must never crash the caller. Failure details go to the debug log.
        logger.error(f"Saliency generation failed: {e}")
        import traceback
        logger.debug(traceback.format_exc())
        return None
def can_generate_saliency(metadata: dict) -> bool:
    """
    Check whether a document's metadata carries everything saliency needs.

    Args:
        metadata: Document metadata dictionary

    Returns:
        True only when the tile grid shape and resized page dimensions
        are all present and non-None.
    """
    needed = ('num_tiles', 'tile_rows', 'tile_cols', 'resized_width', 'resized_height')
    return not any(metadata.get(key) is None for key in needed)
def get_saliency_metadata_summary(metadata: dict) -> str:
    """
    Get a summary of saliency-related metadata for display.

    Args:
        metadata: Document metadata dictionary

    Returns:
        Human-readable summary string, or a fallback message when any
        of the tile fields is missing or None.
    """
    num_tiles = metadata.get('num_tiles')
    tile_rows = metadata.get('tile_rows')
    tile_cols = metadata.get('tile_cols')
    patches_per_tile = metadata.get('patches_per_tile', 64)
    # Check `is not None` rather than a sentinel string: previously a key
    # stored with an explicit None value slipped past the 'N/A' check and
    # rendered as e.g. "NoneΓ—None tiles (None total)".
    if all(v is not None for v in (num_tiles, tile_rows, tile_cols)):
        return f"{tile_rows}Γ—{tile_cols} tiles ({num_tiles} total), {patches_per_tile} patches/tile"
    else:
        return "Tile metadata not available"