File size: 11,314 Bytes
c4ef1cf
 
 
 
 
 
 
 
9513cca
 
 
 
c4ef1cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9513cca
c4ef1cf
9513cca
c4ef1cf
 
 
 
 
 
 
 
9513cca
c4ef1cf
 
9513cca
c4ef1cf
 
 
 
 
 
 
 
 
 
 
 
9513cca
c4ef1cf
9513cca
63fcb87
c4ef1cf
 
9513cca
 
c4ef1cf
9513cca
63fcb87
c4ef1cf
 
9513cca
c4ef1cf
 
 
9513cca
c4ef1cf
 
9513cca
c4ef1cf
 
9513cca
c4ef1cf
 
 
 
 
 
9513cca
c4ef1cf
 
 
 
 
 
9513cca
c4ef1cf
 
 
 
 
9513cca
c4ef1cf
 
9513cca
c4ef1cf
 
 
9513cca
c4ef1cf
 
 
 
 
 
 
 
9513cca
c4ef1cf
 
 
 
 
 
 
 
 
 
9513cca
c4ef1cf
 
 
 
 
 
 
 
 
 
 
 
 
 
9513cca
c4ef1cf
 
 
 
 
 
 
9513cca
c4ef1cf
 
 
 
 
 
 
 
9513cca
c4ef1cf
 
9513cca
c4ef1cf
 
 
 
 
 
 
 
 
 
 
 
 
 
9513cca
c4ef1cf
 
 
 
 
 
9513cca
c4ef1cf
 
9513cca
c4ef1cf
9513cca
 
c4ef1cf
 
9513cca
c4ef1cf
 
9513cca
c4ef1cf
 
 
9513cca
c4ef1cf
 
9513cca
c4ef1cf
 
 
9513cca
c4ef1cf
 
 
 
 
9513cca
c4ef1cf
 
 
9513cca
c4ef1cf
 
 
9513cca
c4ef1cf
 
9513cca
c4ef1cf
 
 
9513cca
c4ef1cf
 
 
 
 
 
 
 
 
 
 
 
 
 
9513cca
c4ef1cf
 
 
 
 
 
 
 
9513cca
c4ef1cf
 
 
 
 
 
 
 
9513cca
c4ef1cf
 
9513cca
c4ef1cf
 
 
9513cca
c4ef1cf
 
 
9513cca
c4ef1cf
 
 
9513cca
c4ef1cf
 
 
9513cca
c4ef1cf
 
 
 
 
9513cca
c4ef1cf
 
 
 
 
 
 
 
 
9513cca
c4ef1cf
 
 
 
9513cca
c4ef1cf
 
 
 
9513cca
 
c4ef1cf
 
 
 
 
 
9513cca
c4ef1cf
 
9513cca
c4ef1cf
 
 
 
9513cca
c4ef1cf
 
 
9513cca
c4ef1cf
 
9513cca
c4ef1cf
 
 
 
 
9513cca
c4ef1cf
9513cca
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
"""
Saliency Map Generation for Visual Document Retrieval.

Generates attention/saliency maps to visualize which parts of documents
are most relevant to a query.
"""

import logging
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from PIL import Image

logger = logging.getLogger(__name__)


def _to_numpy(embedding: Any) -> np.ndarray:
    """Convert a tensor-like embedding to a float numpy array.

    The ``cpu`` check intentionally comes *before* the ``numpy`` check:
    torch tensors expose both attributes, but calling ``.numpy()`` directly
    raises for CUDA-resident or BFloat16 tensors. Routing torch tensors
    through ``.cpu().float()`` first handles both cases safely.
    """
    if hasattr(embedding, "cpu"):
        return embedding.cpu().float().numpy()
    if hasattr(embedding, "numpy"):
        return embedding.numpy()
    return np.array(embedding, dtype=np.float32)


def generate_saliency_map(
    query_embedding: np.ndarray,
    doc_embedding: np.ndarray,
    image: Image.Image,
    token_info: Optional[Dict[str, Any]] = None,
    colormap: str = "Reds",
    alpha: float = 0.5,
    threshold_percentile: float = 50.0,
) -> Tuple[Image.Image, np.ndarray]:
    """
    Generate saliency map showing which parts of the image match the query.

    Computes patch-level relevance scores (max cosine similarity between any
    query token and each visual patch, i.e. a per-patch MaxSim) and overlays
    them on the image.

    Args:
        query_embedding: Query embeddings [num_query_tokens, dim]
        doc_embedding: Document visual embeddings [num_visual_tokens, dim]
        image: Original PIL Image
        token_info: Optional token info with n_rows, n_cols for tile grid
        colormap: Matplotlib colormap name (Reds, viridis, jet, etc.)
        alpha: Overlay transparency (0-1)
        threshold_percentile: Only highlight patches above this percentile

    Returns:
        Tuple of (annotated_image, patch_scores). ``patch_scores`` are the
        raw (un-normalized) per-patch max similarities.

    Example:
        >>> query = embedder.embed_query("budget allocation")
        >>> doc = visual_embedding  # From embed_images
        >>> annotated, scores = generate_saliency_map(
        ...     query_embedding=query.numpy(),
        ...     doc_embedding=doc,
        ...     image=page_image,
        ...     token_info=token_info,
        ... )
        >>> annotated.save("saliency.png")
    """
    query_np = _to_numpy(query_embedding)
    doc_np = _to_numpy(doc_embedding)

    # L2-normalize each token embedding; epsilon guards zero-norm rows.
    query_norm = query_np / (np.linalg.norm(query_np, axis=1, keepdims=True) + 1e-8)
    doc_norm = doc_np / (np.linalg.norm(doc_np, axis=1, keepdims=True) + 1e-8)

    # Cosine similarity matrix: [num_query, num_doc]
    similarity_matrix = np.dot(query_norm, doc_norm.T)

    # Best match from any query token, per document patch.
    patch_scores = similarity_matrix.max(axis=0)

    # Min-max normalize to [0, 1]; constant score maps to all-zeros.
    score_min, score_max = patch_scores.min(), patch_scores.max()
    if score_max - score_min > 1e-8:
        patch_scores_norm = (patch_scores - score_min) / (score_max - score_min)
    else:
        patch_scores_norm = np.zeros_like(patch_scores)

    # Collapse patch scores to a tile grid when layout info is available.
    if token_info and token_info.get("n_rows") and token_info.get("n_cols"):
        n_rows = token_info["n_rows"]
        n_cols = token_info["n_cols"]
        patches_per_tile = 64  # ColSmol standard
        num_grid_tiles = n_rows * n_cols

        try:
            # Keep only the grid-tile patches; the extra "global" tile's
            # patches (if present) sit at the end and are excluded.
            grid_patches = patch_scores_norm[: num_grid_tiles * patches_per_tile]
            # [tiles * patches_per_tile] -> [tiles, patches_per_tile] -> mean per tile
            tile_scores = grid_patches.reshape(num_grid_tiles, patches_per_tile).mean(axis=1)
            tile_scores = tile_scores.reshape(n_rows, n_cols)
        except Exception as e:
            # Patch count didn't match the expected grid; fall back to raw scores.
            logger.warning(f"Could not reshape to tile grid: {e}")
            tile_scores = None
    else:
        tile_scores = None
        n_rows = n_cols = None

    # Overlay tile grid when available, otherwise raw per-patch scores.
    annotated = create_saliency_overlay(
        image=image,
        scores=tile_scores if tile_scores is not None else patch_scores_norm,
        colormap=colormap,
        alpha=alpha,
        threshold_percentile=threshold_percentile,
        grid_rows=n_rows,
        grid_cols=n_cols,
    )

    return annotated, patch_scores


def create_saliency_overlay(
    image: Image.Image,
    scores: np.ndarray,
    colormap: str = "Reds",
    alpha: float = 0.5,
    threshold_percentile: float = 50.0,
    grid_rows: Optional[int] = None,
    grid_cols: Optional[int] = None,
) -> Image.Image:
    """
    Create colored overlay on image based on scores.

    Args:
        image: Base PIL Image
        scores: Score array - 1D [num_patches] or 2D [rows, cols],
            expected to be normalized to [0, 1]
        colormap: Matplotlib colormap name
        alpha: Overlay transparency (0-1)
        threshold_percentile: Only color patches above this percentile
        grid_rows, grid_cols: Grid dimensions (auto-estimated if not provided
            or inconsistent with the score count)

    Returns:
        Annotated PIL Image (the original image is returned unchanged when
        matplotlib is not installed)
    """
    try:
        import matplotlib
        import matplotlib.pyplot as plt
    except ImportError:
        logger.warning("matplotlib not installed, returning original image")
        return image

    img_array = np.array(image)
    h, w = img_array.shape[:2]

    # Arrange scores as a [rows, cols] tile grid.
    if scores.ndim == 2:
        rows, cols = scores.shape
    else:
        if grid_rows and grid_cols and len(scores) == grid_rows * grid_cols:
            rows, cols = grid_rows, grid_cols
        else:
            # Estimate a grid roughly matching the image aspect ratio.
            # Clamp cols to [1, num_patches] so that rows * cols can never
            # exceed the number of scores (the original estimate could
            # produce cols == 0 or a grid larger than the score count).
            num_patches = len(scores)
            aspect = w / h
            cols = max(1, min(num_patches, int(np.sqrt(num_patches * aspect))))
            rows = max(1, num_patches // cols)
        scores = scores[: rows * cols].reshape(rows, cols)

    # plt.cm.get_cmap was deprecated in matplotlib 3.7 and removed in 3.9;
    # prefer the colormap registry, falling back for older versions.
    try:
        cmap = matplotlib.colormaps[colormap]
    except AttributeError:  # matplotlib < 3.5
        cmap = plt.cm.get_cmap(colormap)

    # Patches below this percentile stay uncolored.
    threshold = np.percentile(scores, threshold_percentile)

    # Pixel dimensions of each grid cell.
    cell_h = h // rows
    cell_w = w // cols

    # RGBA overlay, fully transparent by default.
    overlay = np.zeros((h, w, 4), dtype=np.uint8)

    for i in range(rows):
        for j in range(cols):
            score = scores[i, j]

            if score >= threshold:
                y1 = i * cell_h
                y2 = min((i + 1) * cell_h, h)
                x1 = j * cell_w
                x2 = min((j + 1) * cell_w, w)

                # Re-normalize the above-threshold range to [0, 1] for coloring.
                norm_score = (score - threshold) / (1.0 - threshold + 1e-8)
                norm_score = min(1.0, max(0.0, norm_score))

                # RGB from the colormap; alpha scales with relevance.
                color = cmap(norm_score)[:3]
                color_uint8 = (np.array(color) * 255).astype(np.uint8)

                overlay[y1:y2, x1:x2, :3] = color_uint8
                overlay[y1:y2, x1:x2, 3] = int(alpha * 255 * norm_score)

    # Alpha-composite the overlay onto the original image.
    overlay_img = Image.fromarray(overlay, "RGBA")
    result = Image.alpha_composite(image.convert("RGBA"), overlay_img)

    return result.convert("RGB")


def _load_result_image(page_data: Any) -> Optional["Image.Image"]:
    """Best-effort decode of a result payload's 'page' field into a PIL Image.

    Accepts a base64 data URI (``data:image...``) or an http(s) URL; returns
    None for any other value (including non-strings) or on any failure.
    """
    if not isinstance(page_data, str):
        return None

    if page_data.startswith("data:image"):
        # Base64 encoded
        try:
            import base64
            from io import BytesIO

            b64_data = page_data.split(",")[1]
            return Image.open(BytesIO(base64.b64decode(b64_data)))
        except Exception as e:
            logger.debug(f"Could not decode base64 image: {e}")
    elif page_data.startswith("http"):
        # URL - try to fetch
        try:
            import urllib.request
            from io import BytesIO

            with urllib.request.urlopen(page_data, timeout=5) as response:
                return Image.open(BytesIO(response.read()))
        except Exception as e:
            logger.debug(f"Could not fetch image URL: {e}")

    return None


def visualize_search_results(
    query: str,
    results: List[Dict[str, Any]],
    query_embedding: Optional[np.ndarray] = None,
    embeddings: Optional[List[np.ndarray]] = None,
    output_path: Optional[str] = None,
    max_results: int = 5,
    show_saliency: bool = False,
) -> Optional[Image.Image]:
    """
    Visualize search results as a grid of images with scores.

    Args:
        query: Original query text
        results: List of search results with 'payload' containing 'page' (image URL/base64)
        query_embedding: Query embedding for saliency (optional)
        embeddings: Document embeddings for saliency (optional)
        output_path: Path to save visualization (optional)
        max_results: Maximum results to show
        show_saliency: Generate saliency overlays (requires query_embedding & embeddings)

    Returns:
        Combined visualization image if successful, else None
    """
    # TODO(review): show_saliency / query_embedding / embeddings are accepted
    # but never used in this function body — saliency overlays are not wired
    # in yet. Parameters kept for interface compatibility.
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        logger.error("matplotlib required for visualization")
        return None

    results = results[:max_results]
    n = len(results)

    if n == 0:
        logger.warning("No results to visualize")
        return None

    fig, axes = plt.subplots(1, n, figsize=(4 * n, 4))
    if n == 1:
        # plt.subplots returns a bare Axes for n == 1; normalize to a list.
        axes = [axes]

    for idx, (result, ax) in enumerate(zip(results, axes)):
        payload = result.get("payload", {})
        # Prefer the reranked score when present, else the first-stage score.
        score = result.get("score_final", result.get("score_stage1", 0))

        image = _load_result_image(payload.get("page", ""))

        if image:
            ax.imshow(image)
        else:
            # Show placeholder
            ax.text(0.5, 0.5, "No image", ha="center", va="center", fontsize=12, color="gray")

        # Add title
        title = f"Rank {idx + 1}\nScore: {score:.3f}"
        if payload.get("filename"):
            title += f"\n{payload['filename'][:30]}"
        if payload.get("page_number") is not None:
            # page_number is 0-based; display 1-based.
            title += f" p.{payload['page_number'] + 1}"

        ax.set_title(title, fontsize=9)
        ax.axis("off")

    # Add query as suptitle (truncated to keep the title readable)
    query_display = query[:80] + "..." if len(query) > 80 else query
    plt.suptitle(f"Query: {query_display}", fontsize=11, fontweight="bold")
    plt.tight_layout()

    if output_path:
        plt.savefig(output_path, dpi=150, bbox_inches="tight")
        logger.info(f"💾 Saved visualization to: {output_path}")

    # Convert to PIL Image for return
    from io import BytesIO

    buf = BytesIO()
    plt.savefig(buf, format="png", dpi=100, bbox_inches="tight")
    buf.seek(0)
    result_image = Image.open(buf)
    # PIL opens lazily; force a full read before the figure is closed so the
    # returned image is independent of this function's local state.
    result_image.load()

    # Close the specific figure we created (bare plt.close() would only close
    # whatever figure happens to be "current").
    plt.close(fig)

    return result_image