akryldigital committed on
Commit 600f4a4 · verified · 1 Parent(s): 865589a

delete bug

Files changed (1)
  1. src/agents/saliency.py +0 -383
src/agents/saliency.py DELETED
@@ -1,383 +0,0 @@
-"""
-Saliency Map Generation for Visual RAG
-
-This module provides saliency map generation for visual document search results.
-It implements the tile-aware ColBERT MaxSim strategy for accurate visualization
-of which image regions are relevant to a query.
-
-Key features:
-1. Tile-aware architecture (understands 4×3 grid of 512×512 tiles)
-2. Excludes global tile for cleaner saliency
-3. Maps patches to resized image, then scales to original
-4. Uses "hot" colormap by default for better visibility
-"""
-
-import logging
-from typing import Any, Optional, Tuple
-from io import BytesIO
-from base64 import b64decode
-
-import numpy as np
-import requests
-from PIL import Image
-
-logger = logging.getLogger(__name__)
-
-# Default saliency configuration
-DEFAULT_ALPHA = 0.4
-DEFAULT_COLORMAP = 'hot'  # Better visibility than 'jet'
-DEFAULT_THRESHOLD_PERCENTILE = 50
-
-
-def convert_to_numpy(embedding, dtype: np.dtype = np.float32) -> np.ndarray:
-    """
-    Convert embedding to numpy array with proper dtype.
-
-    Handles:
-    - Lists
-    - PyTorch tensors (including bfloat16)
-    - NumPy arrays
-    """
-    try:
-        import torch
-        if isinstance(embedding, torch.Tensor):
-            if embedding.dtype == torch.bfloat16:
-                embedding = embedding.cpu().float()
-            else:
-                embedding = embedding.cpu()
-            embedding = embedding.numpy()
-    except ImportError:
-        pass
-
-    return np.array(embedding, dtype=dtype)
-
-
-def validate_embeddings(
-    doc_embedding: np.ndarray,
-    query_embedding: np.ndarray
-) -> Tuple[bool, str]:
-    """Validate embedding shapes and types."""
-    if doc_embedding.ndim != 2:
-        return False, f"Document embedding must be 2D, got {doc_embedding.ndim}D"
-
-    if query_embedding.ndim != 2:
-        return False, f"Query embedding must be 2D, got {query_embedding.ndim}D"
-
-    if doc_embedding.shape[1] != query_embedding.shape[1]:
-        return False, f"Embedding dimensions don't match: doc={doc_embedding.shape[1]}, query={query_embedding.shape[1]}"
-
-    if np.any(np.isnan(doc_embedding)) or np.any(np.isinf(doc_embedding)):
-        return False, "Document embedding contains NaN or Inf values"
-
-    if np.any(np.isnan(query_embedding)) or np.any(np.isinf(query_embedding)):
-        return False, "Query embedding contains NaN or Inf values"
-
-    return True, ""
-
-
-def compute_maxsim_scores(
-    doc_embedding: np.ndarray,
-    query_embedding: np.ndarray,
-    normalize: bool = True
-) -> np.ndarray:
-    """
-    Compute MaxSim scores for ColBERT-style late interaction.
-
-    MaxSim: For each document patch, find the maximum similarity
-    across all query patches.
-    """
-    if normalize:
-        doc_norm = doc_embedding / (np.linalg.norm(doc_embedding, axis=1, keepdims=True) + 1e-8)
-        query_norm = query_embedding / (np.linalg.norm(query_embedding, axis=1, keepdims=True) + 1e-8)
-    else:
-        doc_norm = doc_embedding
-        query_norm = query_embedding
-
-    similarity_matrix = np.dot(doc_norm, query_norm.T)
-    patch_scores = np.max(similarity_matrix, axis=1)
-
-    return patch_scores
-
-
-def normalize_scores(
-    score_grid: np.ndarray,
-    threshold_percentile: int = None
-) -> np.ndarray:
-    """Normalize score grid to 0-1 range with optional thresholding."""
-    score_min = score_grid.min()
-    score_max = score_grid.max()
-
-    if score_max - score_min < 1e-8:
-        logger.warning("All scores are identical, returning zeros")
-        return np.zeros_like(score_grid, dtype=np.float32)
-
-    score_grid_norm = (score_grid - score_min) / (score_max - score_min)
-
-    if threshold_percentile is not None:
-        score_threshold = np.percentile(score_grid, threshold_percentile)
-        mask = score_grid < score_threshold
-        score_grid_norm[mask] = 0.0
-
-        visible_count = np.sum(~mask)
-        total_count = score_grid.size
-        logger.debug(f"Threshold: {score_threshold:.3f} ({threshold_percentile}th percentile)")
-        logger.debug(f"Visible patches: {visible_count} / {total_count}")
-
-    return score_grid_norm
-
-
-def download_image(page_url: str) -> Optional[Image.Image]:
-    """Download image from URL or decode from data URI."""
-    try:
-        if page_url.startswith(("http://", "https://")):
-            resp = requests.get(page_url, timeout=15)
-            resp.raise_for_status()
-            image = Image.open(BytesIO(resp.content))
-        elif page_url.startswith("data:image"):
-            b64_data = page_url.split(",", 1)[1]
-            image = Image.open(BytesIO(b64decode(b64_data)))
-        else:
-            image = Image.open(page_url)
-
-        if image.mode != "RGB":
-            image = image.convert("RGB")
-
-        return image
-
-    except Exception as e:
-        logger.error(f"Failed to load image: {e}")
-        return None
-
-
-def apply_colormap_and_blend(
-    score_grid: np.ndarray,
-    image: Image.Image,
-    alpha: float = DEFAULT_ALPHA,
-    colormap: str = DEFAULT_COLORMAP
-) -> Image.Image:
-    """Apply colormap to scores and blend with original image."""
-    from matplotlib import cm
-
-    img_width, img_height = image.size
-
-    # Resize heatmap to image size
-    heatmap_pil = Image.fromarray((score_grid * 255).astype(np.uint8), mode='L')
-    heatmap_resized = heatmap_pil.resize((img_width, img_height), Image.BILINEAR)
-    heatmap_array = np.array(heatmap_resized) / 255.0
-
-    # Apply colormap
-    cmap = cm.get_cmap(colormap)
-    heatmap_colored = cmap(heatmap_array)[:, :, :3]
-    heatmap_colored = (heatmap_colored * 255).astype(np.uint8)
-    heatmap_img = Image.fromarray(heatmap_colored, mode='RGB')
-
-    # Blend with original image
-    overlay = Image.blend(image, heatmap_img, alpha=alpha)
-
-    return overlay
-
-
-def generate_tile_aware_saliency(
-    qdrant_client: Any,
-    collection_name: str,
-    point_id: str,
-    query_embedding: np.ndarray,
-    alpha: float = DEFAULT_ALPHA,
-    colormap: str = DEFAULT_COLORMAP,
-    threshold_percentile: int = DEFAULT_THRESHOLD_PERCENTILE
-) -> Optional[Image.Image]:
-    """
-    Generate tile-aware saliency map for a document-query pair.
-
-    This is the main function to call for saliency generation.
-
-    Args:
-        qdrant_client: Qdrant client instance
-        collection_name: Name of the collection
-        point_id: ID of the document point
-        query_embedding: Query multi-vector embedding [num_query_patches, dim]
-        alpha: Overlay transparency (0.0-1.0)
-        colormap: Matplotlib colormap name (default: 'hot')
-        threshold_percentile: Hide patches below this percentile (default: 50)
-
-    Returns:
-        PIL Image with saliency overlay, or None if generation fails
-    """
-    try:
-        # Step 1: Fetch full multi-vector embedding AND payload
-        logger.debug(f"Fetching point {point_id} with tile metadata from {collection_name}")
-        points = qdrant_client.retrieve(
-            collection_name=collection_name,
-            ids=[point_id],
-            with_vectors=["initial"],
-            with_payload=True
-        )
-
-        if not points or len(points) == 0:
-            logger.error(f"Point {point_id} not found in collection")
-            return None
-
-        point = points[0]
-        doc_vector = point.vector.get("initial")
-        payload = point.payload
-
-        if doc_vector is None:
-            logger.error("No 'initial' vector found for point")
-            return None
-
-        # Step 2: Get tile structure from payload
-        num_tiles = payload.get('num_tiles')
-        tile_rows = payload.get('tile_rows')
-        tile_cols = payload.get('tile_cols')
-        patches_per_tile = payload.get('patches_per_tile', 64)
-
-        resized_width = payload.get('resized_width')
-        resized_height = payload.get('resized_height')
-        resized_url = payload.get('resized_url') or payload.get('page')
-
-        original_width = payload.get('original_width')
-        original_height = payload.get('original_height')
-
-        if not all([num_tiles, tile_rows, tile_cols, resized_width, resized_height]):
-            logger.warning("Missing tile metadata - cannot generate saliency")
-            return None
-
-        logger.info(f"✅ Tile structure: {tile_rows}×{tile_cols} tiles, {patches_per_tile} patches/tile")
-        logger.info(f"✅ Resized image: {resized_width}×{resized_height}")
-        logger.info(f"✅ Original image: {original_width}×{original_height}")
-
-        # Step 3: Convert embeddings
-        doc_embedding = convert_to_numpy(doc_vector)
-        query_emb = convert_to_numpy(query_embedding)
-
-        is_valid, error_msg = validate_embeddings(doc_embedding, query_emb)
-        if not is_valid:
-            logger.error(f"Embedding validation failed: {error_msg}")
-            return None
-
-        logger.info(f"Document embedding: {doc_embedding.shape}")
-        logger.info(f"Query embedding: {query_emb.shape}")
-
-        # Step 4: Separate tile embeddings from global tile
-        total_patches = num_tiles * patches_per_tile
-        tile_patches = total_patches - patches_per_tile  # Exclude global
-
-        if len(doc_embedding) < total_patches:
-            logger.warning(f"Embedding size mismatch: got {len(doc_embedding)}, expected {total_patches}")
-            tile_embeddings = doc_embedding[:tile_patches] if len(doc_embedding) > tile_patches else doc_embedding
-        else:
-            tile_embeddings = doc_embedding[:tile_patches]
-
-        logger.info(f"Using {len(tile_embeddings)} tile patches (excluding global)")
-
-        # Step 5: Compute MaxSim scores
-        patch_scores = compute_maxsim_scores(tile_embeddings, query_emb, normalize=True)
-        logger.info(f"Computed scores for {len(patch_scores)} patches")
-
-        # Step 6: Reshape patches into tile structure
-        patches_per_tile_side = int(np.sqrt(patches_per_tile))  # 8 for 64 patches
-
-        try:
-            num_actual_tiles = tile_rows * tile_cols
-
-            if len(patch_scores) != num_actual_tiles * patches_per_tile:
-                logger.error(f"Patch count mismatch: {len(patch_scores)} patches")
-                return None
-
-            tile_scores = patch_scores.reshape(num_actual_tiles, patches_per_tile)
-
-            # Reshape each tile's patches to 8×8 grid (F-order)
-            tile_grids = []
-            for tile_idx in range(num_actual_tiles):
-                tile_patch_scores = tile_scores[tile_idx]
-                tile_grid = tile_patch_scores.reshape(
-                    patches_per_tile_side, patches_per_tile_side, order='F'
-                )
-                tile_grids.append(tile_grid)
-
-            # Arrange tiles into full image grid
-            full_grid_rows = []
-            for row_idx in range(tile_rows):
-                row_tiles = []
-                for col_idx in range(tile_cols):
-                    tile_idx = row_idx * tile_cols + col_idx
-                    row_tiles.append(tile_grids[tile_idx])
-                row_grid = np.concatenate(row_tiles, axis=1)
-                full_grid_rows.append(row_grid)
-
-            score_grid = np.concatenate(full_grid_rows, axis=0)
-
-            logger.info(f"✅ Reconstructed grid: {score_grid.shape} (from {tile_rows}×{tile_cols} tiles)")
-
-        except ValueError as e:
-            logger.error(f"❌ Failed to reshape patches: {e}")
-            return None
-
-        # Step 7: Normalize scores
-        score_grid_norm = normalize_scores(score_grid, threshold_percentile=threshold_percentile)
-
-        # Step 8: Download RESIZED image
-        logger.info(f"Downloading resized image from: {resized_url}")
-        resized_image = download_image(resized_url)
-        if resized_image is None:
-            logger.error("Failed to download resized image")
-            return None
-
-        # Step 9: Apply heatmap to resized image
-        overlay_resized = apply_colormap_and_blend(
-            score_grid_norm, resized_image, alpha, colormap
-        )
-
-        # Step 10: Resize back to original dimensions
-        if original_width and original_height:
-            overlay_final = overlay_resized.resize(
-                (original_width, original_height), Image.BILINEAR
-            )
-            logger.info(f"✅ Resized saliency map to original: {original_width}×{original_height}")
-        else:
-            overlay_final = overlay_resized
-
-        logger.info(f"✅ Saliency map generated successfully")
-        return overlay_final
-
-    except Exception as e:
-        logger.error(f"Saliency generation failed: {e}")
-        import traceback
-        logger.debug(traceback.format_exc())
-        return None
-
-
-def can_generate_saliency(metadata: dict) -> bool:
-    """
-    Check if saliency can be generated for a document based on its metadata.
-
-    Args:
-        metadata: Document metadata dictionary
-
-    Returns:
-        True if all required tile metadata is present
-    """
-    required_fields = ['num_tiles', 'tile_rows', 'tile_cols', 'resized_width', 'resized_height']
-    return all(metadata.get(field) is not None for field in required_fields)
-
-
-def get_saliency_metadata_summary(metadata: dict) -> str:
-    """
-    Get a summary of saliency-related metadata for display.
-
-    Args:
-        metadata: Document metadata dictionary
-
-    Returns:
-        Human-readable summary string
-    """
-    num_tiles = metadata.get('num_tiles', 'N/A')
-    tile_rows = metadata.get('tile_rows', 'N/A')
-    tile_cols = metadata.get('tile_cols', 'N/A')
-    patches_per_tile = metadata.get('patches_per_tile', 64)
-
-    if all(v != 'N/A' for v in [num_tiles, tile_rows, tile_cols]):
-        return f"{tile_rows}×{tile_cols} tiles ({num_tiles} total), {patches_per_tile} patches/tile"
-    else:
-        return "Tile metadata not available"
-
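For context, the deleted entry point expected a Qdrant point whose "initial" vector holds the page's multi-vector embedding and whose payload carries the tile metadata. With the 4×3 grid of 512×512 tiles described in the docstring, that works out to 12 positional tiles plus one global tile: 13 × 64 = 832 patch vectors, of which the trailing 64 global-tile vectors are dropped before the scores are arranged into a 32×24 patch grid. Below is a minimal, hypothetical sketch of how generate_tile_aware_saliency could have been invoked before the file was removed; the import path, client URL, collection name, point ID, and query-embedding file are illustrative placeholders, not values taken from this repository.

# Hypothetical usage sketch (illustrative only; names and paths are placeholders).
import numpy as np
from qdrant_client import QdrantClient

from agents.saliency import can_generate_saliency, generate_tile_aware_saliency  # assumed import path

client = QdrantClient(url="http://localhost:6333")    # placeholder Qdrant endpoint
collection = "visual_rag_pages"                       # placeholder collection name
point_id = "00000000-0000-0000-0000-000000000000"     # placeholder point ID from a prior search
query_embedding = np.load("query_embedding.npy")      # [num_query_patches, dim] query multi-vector

# The point's payload must contain the tile metadata the module checks for.
payload = client.retrieve(collection_name=collection, ids=[point_id], with_payload=True)[0].payload

if can_generate_saliency(payload):
    overlay = generate_tile_aware_saliency(
        qdrant_client=client,
        collection_name=collection,
        point_id=point_id,
        query_embedding=query_embedding,
        alpha=0.4,                  # overlay transparency
        colormap="hot",             # matplotlib colormap
        threshold_percentile=50,    # hide patches below the median score
    )
    if overlay is not None:
        overlay.save("saliency_overlay.png")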