akryldigital committed on
Commit
7f3ae81
·
verified ·
1 Parent(s): 215981c

add saliency ui

Browse files
src/ui_components/__init__.py CHANGED
@@ -11,11 +11,36 @@ from .components import (
11
  display_chunk_statistics_table
12
  )
13
  from .utils import extract_chunk_statistics
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  __all__ = [
16
  "get_custom_css",
17
  "display_chunk_statistics_charts",
18
  "display_chunk_statistics_table",
19
- "extract_chunk_statistics"
 
 
 
 
 
 
 
 
 
 
20
  ]
21
 
 
 
 
11
  display_chunk_statistics_table
12
  )
13
  from .utils import extract_chunk_statistics
14
+ from .visual_documents import (
15
+ display_visual_search_results,
16
+ display_visual_document_statistics,
17
+ display_visual_document_details
18
+ )
19
+ from .saliency import (
20
+ generate_tile_aware_saliency,
21
+ can_generate_saliency,
22
+ get_saliency_metadata_summary,
23
+ DEFAULT_ALPHA,
24
+ DEFAULT_COLORMAP,
25
+ DEFAULT_THRESHOLD_PERCENTILE
26
+ )
27
 
28
  __all__ = [
29
  "get_custom_css",
30
  "display_chunk_statistics_charts",
31
  "display_chunk_statistics_table",
32
+ "extract_chunk_statistics",
33
+ "display_visual_search_results",
34
+ "display_visual_document_statistics",
35
+ "display_visual_document_details",
36
+ # Saliency functions
37
+ "generate_tile_aware_saliency",
38
+ "can_generate_saliency",
39
+ "get_saliency_metadata_summary",
40
+ "DEFAULT_ALPHA",
41
+ "DEFAULT_COLORMAP",
42
+ "DEFAULT_THRESHOLD_PERCENTILE"
43
  ]
44
 
45
+
46
+
src/ui_components/saliency.py ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Saliency Map Generation for Visual RAG
3
+
4
+ This module provides saliency map generation for visual document search results.
5
+ It implements the tile-aware ColBERT MaxSim strategy for accurate visualization
6
+ of which image regions are relevant to a query.
7
+
8
+ Key features:
9
+ 1. Tile-aware architecture (understands 4Γ—3 grid of 512Γ—512 tiles)
10
+ 2. Excludes global tile for cleaner saliency
11
+ 3. Maps patches to resized image, then scales to original
12
+ 4. Uses "hot" colormap by default for better visibility
13
+ """
14
+
15
+ import logging
16
+ from typing import Any, Optional, Tuple
17
+ from io import BytesIO
18
+ from base64 import b64decode
19
+
20
+ import numpy as np
21
+ import requests
22
+ from PIL import Image
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ # Default saliency configuration
27
+ DEFAULT_ALPHA = 0.4
28
+ DEFAULT_COLORMAP = 'hot' # Better visibility than 'jet'
29
+ DEFAULT_THRESHOLD_PERCENTILE = 50
30
+
31
+
32
def convert_to_numpy(embedding, dtype: np.dtype = np.float32) -> np.ndarray:
    """
    Coerce an embedding into a numpy array of the requested dtype.

    Accepts plain Python lists, numpy arrays, and PyTorch tensors.
    bfloat16 tensors are upcast to float32 before conversion because
    numpy has no native bfloat16 representation; torch is imported
    lazily so the function still works when torch is not installed.
    """
    try:
        import torch
    except ImportError:
        torch = None

    if torch is not None and isinstance(embedding, torch.Tensor):
        tensor = embedding.cpu()
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.float()
        embedding = tensor.numpy()

    return np.array(embedding, dtype=dtype)
53
+
54
+
55
def validate_embeddings(
    doc_embedding: np.ndarray,
    query_embedding: np.ndarray
) -> Tuple[bool, str]:
    """
    Sanity-check a document/query embedding pair before MaxSim scoring.

    Returns (True, "") when both embeddings are 2-D, share the same
    feature dimension, and contain only finite values; otherwise
    returns (False, <human-readable reason>).
    """
    if doc_embedding.ndim != 2:
        return False, f"Document embedding must be 2D, got {doc_embedding.ndim}D"

    if query_embedding.ndim != 2:
        return False, f"Query embedding must be 2D, got {query_embedding.ndim}D"

    if doc_embedding.shape[1] != query_embedding.shape[1]:
        return False, f"Embedding dimensions don't match: doc={doc_embedding.shape[1]}, query={query_embedding.shape[1]}"

    # isfinite covers both the NaN and Inf checks in a single pass.
    for label, matrix in (("Document", doc_embedding), ("Query", query_embedding)):
        if not np.all(np.isfinite(matrix)):
            return False, f"{label} embedding contains NaN or Inf values"

    return True, ""
76
+
77
+
78
def compute_maxsim_scores(
    doc_embedding: np.ndarray,
    query_embedding: np.ndarray,
    normalize: bool = True
) -> np.ndarray:
    """
    ColBERT-style MaxSim scoring.

    For every document patch, take the highest similarity across all
    query patches (cosine similarity when normalize=True, raw dot
    product otherwise).

    Returns:
        1-D array with one score per document patch.
    """
    def _unit_rows(matrix: np.ndarray) -> np.ndarray:
        # Row-normalize; the epsilon guards against zero-length rows.
        return matrix / (np.linalg.norm(matrix, axis=1, keepdims=True) + 1e-8)

    if normalize:
        doc_mat = _unit_rows(doc_embedding)
        query_mat = _unit_rows(query_embedding)
    else:
        doc_mat = doc_embedding
        query_mat = query_embedding

    # [num_doc_patches, num_query_patches] similarity, reduced by max
    # over the query axis.
    return (doc_mat @ query_mat.T).max(axis=1)
100
+
101
+
102
def normalize_scores(
    score_grid: np.ndarray,
    threshold_percentile: Optional[int] = None
) -> np.ndarray:
    """
    Min-max normalize a score grid to the [0, 1] range.

    Args:
        score_grid: Array of raw patch scores.
        threshold_percentile: If given, scores below this percentile of
            the RAW grid are zeroed out after normalization, hiding
            weak activations in the rendered heatmap.

    Returns:
        Normalized grid; an all-zeros float32 grid when the input has
        no dynamic range (nothing meaningful to visualize).
    """
    # Fix: the default was annotated `int = None`; `Optional[int]` is the
    # correct type for a nullable threshold.
    score_min = score_grid.min()
    score_max = score_grid.max()

    # Degenerate case: a flat grid has no contrast to normalize.
    if score_max - score_min < 1e-8:
        logger.warning("All scores are identical, returning zeros")
        return np.zeros_like(score_grid, dtype=np.float32)

    score_grid_norm = (score_grid - score_min) / (score_max - score_min)

    if threshold_percentile is not None:
        # The percentile is computed on the raw scores, then applied as
        # a mask on the already-normalized grid.
        score_threshold = np.percentile(score_grid, threshold_percentile)
        mask = score_grid < score_threshold
        score_grid_norm[mask] = 0.0

        visible_count = np.sum(~mask)
        total_count = score_grid.size
        logger.debug(f"Threshold: {score_threshold:.3f} ({threshold_percentile}th percentile)")
        logger.debug(f"Visible patches: {visible_count} / {total_count}")

    return score_grid_norm
127
+
128
+
129
def download_image(page_url: str) -> Optional[Image.Image]:
    """
    Load a page image from an http(s) URL, a data: URI, or a local path.

    Returns an RGB-converted PIL image, or None on any failure
    (network error, bad base64, unreadable file, ...).
    """
    try:
        if page_url.startswith(("http://", "https://")):
            response = requests.get(page_url, timeout=15)
            response.raise_for_status()
            loaded = Image.open(BytesIO(response.content))
        elif page_url.startswith("data:image"):
            # Strip the "data:image/...;base64," header before decoding.
            payload = page_url.split(",", 1)[1]
            loaded = Image.open(BytesIO(b64decode(payload)))
        else:
            # Treat anything else as a filesystem path.
            loaded = Image.open(page_url)

        return loaded if loaded.mode == "RGB" else loaded.convert("RGB")

    except Exception as e:
        logger.error(f"Failed to load image: {e}")
        return None
150
+
151
+
152
def apply_colormap_and_blend(
    score_grid: np.ndarray,
    image: Image.Image,
    alpha: float = DEFAULT_ALPHA,
    colormap: str = DEFAULT_COLORMAP
) -> Image.Image:
    """
    Render a normalized score grid as a colored heatmap over an image.

    Args:
        score_grid: 2-D array of scores in [0, 1].
        image: RGB PIL image to overlay.
        alpha: Blend weight (0 = image only, 1 = heatmap only).
        colormap: Matplotlib colormap name.

    Returns:
        New PIL image: `image` alpha-blended with the colorized heatmap.
    """
    # Fix: `cm.get_cmap` was deprecated in matplotlib 3.7 and removed in
    # 3.9 — use the colormap registry, falling back for older versions.
    try:
        from matplotlib import colormaps
        cmap = colormaps[colormap]
    except ImportError:
        from matplotlib import cm
        cmap = cm.get_cmap(colormap)

    img_width, img_height = image.size

    # Upsample the coarse patch grid to the full image resolution.
    heatmap_pil = Image.fromarray((score_grid * 255).astype(np.uint8), mode='L')
    heatmap_resized = heatmap_pil.resize((img_width, img_height), Image.BILINEAR)
    heatmap_array = np.array(heatmap_resized) / 255.0

    # Map scores to RGB via the colormap (dropping the alpha channel).
    heatmap_colored = cmap(heatmap_array)[:, :, :3]
    heatmap_colored = (heatmap_colored * 255).astype(np.uint8)
    heatmap_img = Image.fromarray(heatmap_colored, mode='RGB')

    # Alpha-blend the heatmap over the original page image.
    overlay = Image.blend(image, heatmap_img, alpha=alpha)

    return overlay
178
+
179
+
180
def generate_tile_aware_saliency(
    qdrant_client: Any,
    collection_name: str,
    point_id: str,
    query_embedding: np.ndarray,
    alpha: float = DEFAULT_ALPHA,
    colormap: str = DEFAULT_COLORMAP,
    threshold_percentile: int = DEFAULT_THRESHOLD_PERCENTILE
) -> Optional[Image.Image]:
    """
    Generate tile-aware saliency map for a document-query pair.

    This is the main function to call for saliency generation. The
    pipeline: fetch the stored multi-vector embedding and tile layout
    from Qdrant, score each patch against the query (MaxSim), rebuild
    the per-tile patch grids into a full-image score grid, and blend a
    colormapped heatmap over the page image.

    Args:
        qdrant_client: Qdrant client instance
        collection_name: Name of the collection
        point_id: ID of the document point
        query_embedding: Query multi-vector embedding [num_query_patches, dim]
        alpha: Overlay transparency (0.0-1.0)
        colormap: Matplotlib colormap name (default: 'hot')
        threshold_percentile: Hide patches below this percentile (default: 50)

    Returns:
        PIL Image with saliency overlay, or None if generation fails
        (missing point/vector/metadata, invalid embeddings, shape
        mismatch, or image download failure).
    """
    try:
        # Step 1: Fetch full multi-vector embedding AND payload
        logger.debug(f"Fetching point {point_id} with tile metadata from {collection_name}")
        points = qdrant_client.retrieve(
            collection_name=collection_name,
            ids=[point_id],
            with_vectors=["initial"],
            with_payload=True
        )

        if not points or len(points) == 0:
            logger.error(f"Point {point_id} not found in collection")
            return None

        point = points[0]
        # "initial" is the named multi-vector holding per-patch embeddings.
        doc_vector = point.vector.get("initial")
        payload = point.payload

        if doc_vector is None:
            logger.error("No 'initial' vector found for point")
            return None

        # Step 2: Get tile structure from payload
        num_tiles = payload.get('num_tiles')
        tile_rows = payload.get('tile_rows')
        tile_cols = payload.get('tile_cols')
        patches_per_tile = payload.get('patches_per_tile', 64)

        resized_width = payload.get('resized_width')
        resized_height = payload.get('resized_height')
        resized_url = payload.get('resized_url') or payload.get('page')

        original_width = payload.get('original_width')
        original_height = payload.get('original_height')

        # NOTE: any of these being 0 would also abort here (falsy); in
        # practice a 0-tile / 0-pixel layout is equally unusable.
        if not all([num_tiles, tile_rows, tile_cols, resized_width, resized_height]):
            logger.warning("Missing tile metadata - cannot generate saliency")
            return None

        logger.info(f"✅ Tile structure: {tile_rows}×{tile_cols} tiles, {patches_per_tile} patches/tile")
        logger.info(f"✅ Resized image: {resized_width}×{resized_height}")
        logger.info(f"✅ Original image: {original_width}×{original_height}")

        # Step 3: Convert embeddings
        doc_embedding = convert_to_numpy(doc_vector)
        query_emb = convert_to_numpy(query_embedding)

        is_valid, error_msg = validate_embeddings(doc_embedding, query_emb)
        if not is_valid:
            logger.error(f"Embedding validation failed: {error_msg}")
            return None

        logger.info(f"Document embedding: {doc_embedding.shape}")
        logger.info(f"Query embedding: {query_emb.shape}")

        # Step 4: Separate tile embeddings from global tile
        # NOTE(review): assumes num_tiles COUNTS the global tile and that
        # the global tile's patches are the LAST patches_per_tile rows —
        # confirm against the embedder that produced these vectors.
        total_patches = num_tiles * patches_per_tile
        tile_patches = total_patches - patches_per_tile  # Exclude global

        if len(doc_embedding) < total_patches:
            # Best effort: still try to use whatever patches are present.
            logger.warning(f"Embedding size mismatch: got {len(doc_embedding)}, expected {total_patches}")
            tile_embeddings = doc_embedding[:tile_patches] if len(doc_embedding) > tile_patches else doc_embedding
        else:
            tile_embeddings = doc_embedding[:tile_patches]

        logger.info(f"Using {len(tile_embeddings)} tile patches (excluding global)")

        # Step 5: Compute MaxSim scores
        patch_scores = compute_maxsim_scores(tile_embeddings, query_emb, normalize=True)
        logger.info(f"Computed scores for {len(patch_scores)} patches")

        # Step 6: Reshape patches into tile structure
        # Assumes patches_per_tile is a perfect square (e.g. 64 -> 8x8).
        patches_per_tile_side = int(np.sqrt(patches_per_tile))  # 8 for 64 patches

        try:
            num_actual_tiles = tile_rows * tile_cols

            # Strict check: after excluding the global tile, the patch
            # count must exactly fill the rows x cols tile grid.
            if len(patch_scores) != num_actual_tiles * patches_per_tile:
                logger.error(f"Patch count mismatch: {len(patch_scores)} patches")
                return None

            tile_scores = patch_scores.reshape(num_actual_tiles, patches_per_tile)

            # Reshape each tile's patches to 8×8 grid (F-order)
            # NOTE(review): F-order implies patches are stored column-major
            # within a tile — verify against the patch ordering upstream.
            tile_grids = []
            for tile_idx in range(num_actual_tiles):
                tile_patch_scores = tile_scores[tile_idx]
                tile_grid = tile_patch_scores.reshape(
                    patches_per_tile_side, patches_per_tile_side, order='F'
                )
                tile_grids.append(tile_grid)

            # Arrange tiles into full image grid
            # (tiles are assumed row-major: row_idx * tile_cols + col_idx)
            full_grid_rows = []
            for row_idx in range(tile_rows):
                row_tiles = []
                for col_idx in range(tile_cols):
                    tile_idx = row_idx * tile_cols + col_idx
                    row_tiles.append(tile_grids[tile_idx])
                row_grid = np.concatenate(row_tiles, axis=1)
                full_grid_rows.append(row_grid)

            score_grid = np.concatenate(full_grid_rows, axis=0)

            logger.info(f"✅ Reconstructed grid: {score_grid.shape} (from {tile_rows}×{tile_cols} tiles)")

        except ValueError as e:
            logger.error(f"❌ Failed to reshape patches: {e}")
            return None

        # Step 7: Normalize scores
        score_grid_norm = normalize_scores(score_grid, threshold_percentile=threshold_percentile)

        # Step 8: Download RESIZED image
        # The score grid is laid out in the resized image's coordinate
        # space, so the heatmap must be applied to the resized image.
        logger.info(f"Downloading resized image from: {resized_url}")
        resized_image = download_image(resized_url)
        if resized_image is None:
            logger.error("Failed to download resized image")
            return None

        # Step 9: Apply heatmap to resized image
        overlay_resized = apply_colormap_and_blend(
            score_grid_norm, resized_image, alpha, colormap
        )

        # Step 10: Resize back to original dimensions
        if original_width and original_height:
            overlay_final = overlay_resized.resize(
                (original_width, original_height), Image.BILINEAR
            )
            logger.info(f"✅ Resized saliency map to original: {original_width}×{original_height}")
        else:
            overlay_final = overlay_resized

        logger.info(f"✅ Saliency map generated successfully")
        return overlay_final

    except Exception as e:
        # Broad catch is deliberate: saliency is a best-effort overlay and
        # must never crash the UI; failures are logged and None returned.
        logger.error(f"Saliency generation failed: {e}")
        import traceback
        logger.debug(traceback.format_exc())
        return None
348
+
349
+
350
def can_generate_saliency(metadata: dict) -> bool:
    """
    Check whether a document's metadata carries everything saliency needs.

    Args:
        metadata: Document metadata dictionary

    Returns:
        True only when every required tile-layout field is present and
        not None (num_tiles, tile_rows, tile_cols, resized_width,
        resized_height).
    """
    required = ('num_tiles', 'tile_rows', 'tile_cols', 'resized_width', 'resized_height')
    missing = [name for name in required if metadata.get(name) is None]
    return not missing
362
+
363
+
364
def get_saliency_metadata_summary(metadata: dict) -> str:
    """
    Build a human-readable one-liner describing a document's tile layout.

    Args:
        metadata: Document metadata dictionary

    Returns:
        e.g. "4×3 tiles (13 total), 64 patches/tile", or a fallback
        message when any tile field is missing.
    """
    num_tiles = metadata.get('num_tiles', 'N/A')
    tile_rows = metadata.get('tile_rows', 'N/A')
    tile_cols = metadata.get('tile_cols', 'N/A')
    patches_per_tile = metadata.get('patches_per_tile', 64)

    if 'N/A' in (num_tiles, tile_rows, tile_cols):
        return "Tile metadata not available"

    return f"{tile_rows}×{tile_cols} tiles ({num_tiles} total), {patches_per_tile} patches/tile"
383
+
src/ui_components/visual_documents.py CHANGED
@@ -2,13 +2,18 @@
2
  Visual Document Display Components
3
 
4
  UI components for displaying visual search results with enhanced metadata.
 
5
  """
6
 
7
  import streamlit as st
8
  import pandas as pd
9
- from typing import List, Any, Dict
 
 
10
  from collections import Counter
11
 
 
 
12
 
13
  def display_visual_document_statistics(sources: List[Any]) -> None:
14
  """
@@ -124,16 +129,37 @@ def display_visual_document_statistics(sources: List[Any]) -> None:
124
  st.markdown('</div>', unsafe_allow_html=True)
125
 
126
 
127
- def display_visual_document_details(sources: List[Any], show_images: bool = False) -> None:
 
 
 
 
 
 
 
 
 
 
128
  """
129
  Display detailed information for each visual search result.
130
 
131
  Args:
132
  sources: List of VisualSearchResult objects
133
  show_images: Whether to display document images (from Cloudinary)
 
 
 
 
 
 
 
134
  """
135
  st.markdown("### πŸ“„ Document Details")
136
 
 
 
 
 
137
  for i, doc in enumerate(sources):
138
  metadata = getattr(doc, 'metadata', {})
139
 
@@ -160,6 +186,13 @@ def display_visual_document_details(sources: List[Any], show_images: bool = Fals
160
  resized_url = metadata.get('resized_url')
161
  page_url = metadata.get('page') # Fallback
162
 
 
 
 
 
 
 
 
163
  # Build title
164
  score_text = f" (Score: {score:.3f})"
165
  title = f"πŸ“„ Document {i+1}: {filename[:50]}...{score_text}"
@@ -228,22 +261,104 @@ def display_visual_document_details(sources: List[Any], show_images: bool = Fals
228
  st.markdown(f"**Resized (for embeddings):** [{resized_url}]({resized_url})")
229
 
230
  with col_image:
231
- st.markdown("### πŸ“„ Document Page")
232
 
233
- # Display image (if available and requested)
234
- if show_images:
235
- # Use ORIGINAL image (not resized) for display
236
- image_url = original_url or resized_url or page_url
 
 
 
 
 
 
 
 
 
237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  if image_url and isinstance(image_url, str) and image_url.startswith('http'):
239
  try:
240
- # Use width parameter for medium-sized image
241
- st.image(image_url, width=750, caption=f"Page {page_number}")
242
  except Exception as e:
243
  st.error(f"Failed to load image: {e}")
244
  else:
245
  st.info("No image URL available")
246
- else:
247
  st.info("Enable image display in settings to view document pages")
248
 
249
 
@@ -251,6 +366,13 @@ def display_visual_search_results(
251
  sources: List[Any],
252
  show_statistics: bool = True,
253
  show_images: bool = False,
 
 
 
 
 
 
 
254
  max_display: int = 20
255
  ) -> None:
256
  """
@@ -260,6 +382,13 @@ def display_visual_search_results(
260
  sources: List of VisualSearchResult objects
261
  show_statistics: Whether to show statistics
262
  show_images: Whether to show document images
 
 
 
 
 
 
 
263
  max_display: Maximum number of documents to display in detail
264
  """
265
  if not sources:
@@ -277,6 +406,10 @@ def display_visual_search_results(
277
  if len(unique_filenames) < len(sources):
278
  st.info(f"πŸ’‘ **Note**: Each document is split into multiple chunks. You're seeing {len(sources)} chunks from {len(unique_filenames)} documents.")
279
 
 
 
 
 
280
  # Show statistics
281
  if show_statistics:
282
  display_visual_document_statistics(sources)
@@ -287,7 +420,17 @@ def display_visual_search_results(
287
  if len(sources) > max_display:
288
  st.warning(f"⚠️ Showing top {max_display} of {len(sources)} results")
289
 
290
- display_visual_document_details(display_sources, show_images=show_images)
 
 
 
 
 
 
 
 
 
 
291
 
292
  if len(sources) > max_display:
293
  st.info(f"πŸ’‘ {len(sources) - max_display} more results not shown")
 
2
  Visual Document Display Components
3
 
4
  UI components for displaying visual search results with enhanced metadata.
5
+ Includes saliency map visualization for tile-aware ColPali embeddings.
6
  """
7
 
8
  import streamlit as st
9
  import pandas as pd
10
+ import numpy as np
11
+ import logging
12
+ from typing import List, Any, Dict, Optional
13
  from collections import Counter
14
 
15
+ logger = logging.getLogger(__name__)
16
+
17
 
18
  def display_visual_document_statistics(sources: List[Any]) -> None:
19
  """
 
129
  st.markdown('</div>', unsafe_allow_html=True)
130
 
131
 
132
+ def display_visual_document_details(
133
+ sources: List[Any],
134
+ show_images: bool = False,
135
+ show_saliency: bool = False,
136
+ qdrant_client: Any = None,
137
+ collection_name: str = None,
138
+ query_embedding: Optional[np.ndarray] = None,
139
+ saliency_alpha: float = 0.4,
140
+ saliency_colormap: str = 'hot',
141
+ saliency_threshold: int = 50
142
+ ) -> None:
143
  """
144
  Display detailed information for each visual search result.
145
 
146
  Args:
147
  sources: List of VisualSearchResult objects
148
  show_images: Whether to display document images (from Cloudinary)
149
+ show_saliency: Whether to generate and display saliency maps
150
+ qdrant_client: Qdrant client (required for saliency)
151
+ collection_name: Qdrant collection name (required for saliency)
152
+ query_embedding: Query embedding for saliency computation
153
+ saliency_alpha: Saliency overlay transparency (0.0-1.0)
154
+ saliency_colormap: Matplotlib colormap for saliency (default: 'hot')
155
+ saliency_threshold: Threshold percentile for saliency (default: 50)
156
  """
157
  st.markdown("### πŸ“„ Document Details")
158
 
159
+ # Import saliency functions if needed
160
+ if show_saliency:
161
+ from .saliency import generate_tile_aware_saliency, can_generate_saliency
162
+
163
  for i, doc in enumerate(sources):
164
  metadata = getattr(doc, 'metadata', {})
165
 
 
186
  resized_url = metadata.get('resized_url')
187
  page_url = metadata.get('page') # Fallback
188
 
189
+ # Get point_id for saliency (check doc.id first, then metadata)
190
+ point_id = getattr(doc, 'id', None) or metadata.get('point_id') or metadata.get('_id')
191
+
192
+ # Debug logging for saliency
193
+ if show_saliency:
194
+ logger.debug(f"Doc {i+1}: point_id={point_id}, has_tiles={metadata.get('num_tiles') is not None}")
195
+
196
  # Build title
197
  score_text = f" (Score: {score:.3f})"
198
  title = f"πŸ“„ Document {i+1}: {filename[:50]}...{score_text}"
 
261
  st.markdown(f"**Resized (for embeddings):** [{resized_url}]({resized_url})")
262
 
263
  with col_image:
264
+ st.markdown("### πŸ“Έ Document Page")
265
 
266
+ # Get original image URL
267
+ image_url = original_url or resized_url or page_url
268
+
269
+ # Check if we should generate saliency (show BOTH original and saliency side by side)
270
+ if show_saliency and show_images:
271
+ # Check if we have all requirements for saliency
272
+ has_client = qdrant_client is not None
273
+ has_collection = collection_name is not None
274
+ has_query = query_embedding is not None
275
+ has_point_id = point_id is not None
276
+ has_tile_metadata = can_generate_saliency(metadata)
277
+
278
+ can_saliency = has_client and has_collection and has_query and has_point_id and has_tile_metadata
279
 
280
+ if not can_saliency:
281
+ missing = []
282
+ if not has_client: missing.append("qdrant_client")
283
+ if not has_collection: missing.append("collection_name")
284
+ if not has_query: missing.append("query_embedding")
285
+ if not has_point_id: missing.append("point_id")
286
+ if not has_tile_metadata: missing.append("tile_metadata")
287
+ logger.warning(f"Doc {i+1}: Saliency unavailable, missing: {missing}")
288
+
289
+ if can_saliency:
290
+ # Create two columns: Original image | Saliency map
291
+ img_col1, img_col2 = st.columns(2)
292
+
293
+ # Left column: Original image (ALWAYS show)
294
+ with img_col1:
295
+ st.markdown("**πŸ“„ Original**")
296
+ if image_url and isinstance(image_url, str) and image_url.startswith('http'):
297
+ try:
298
+ st.image(image_url, use_container_width=True, caption=f"Page {page_number}")
299
+ except Exception as e:
300
+ st.error(f"Failed to load image: {e}")
301
+ else:
302
+ st.info("No image URL available")
303
+
304
+ # Right column: Saliency map
305
+ with img_col2:
306
+ st.markdown("**πŸ”₯ Saliency Map**")
307
+ try:
308
+ with st.spinner(f"Generating..."):
309
+ # Convert query embedding if needed
310
+ query_emb = query_embedding
311
+ if hasattr(query_emb, 'cpu'):
312
+ query_emb = query_emb.cpu().float().numpy()
313
+ if query_emb.ndim == 3:
314
+ query_emb = query_emb.squeeze(0) # Remove batch dimension
315
+
316
+ logger.info(f"πŸ”₯ Generating saliency for doc {i+1}: point_id={point_id}, colormap={saliency_colormap}")
317
+
318
+ saliency_img = generate_tile_aware_saliency(
319
+ qdrant_client=qdrant_client,
320
+ collection_name=collection_name,
321
+ point_id=point_id,
322
+ query_embedding=query_emb,
323
+ alpha=saliency_alpha,
324
+ colormap=saliency_colormap,
325
+ threshold_percentile=saliency_threshold
326
+ )
327
+
328
+ if saliency_img:
329
+ st.image(saliency_img, use_container_width=True, caption=f"Relevance heatmap")
330
+ logger.info(f"βœ… Saliency map displayed for doc {i+1}")
331
+ else:
332
+ logger.warning(f"Saliency generation returned None for doc {i+1}")
333
+ st.caption("_Could not generate saliency map_")
334
+ except Exception as e:
335
+ logger.error(f"Saliency generation failed for doc {i+1}: {e}")
336
+ import traceback
337
+ logger.debug(traceback.format_exc())
338
+ st.warning(f"⚠️ Failed: {str(e)[:80]}")
339
+ else:
340
+ # Can't generate saliency - just show original image
341
+ if image_url and isinstance(image_url, str) and image_url.startswith('http'):
342
+ try:
343
+ st.image(image_url, width=700, caption=f"Page {page_number}")
344
+ except Exception as e:
345
+ st.error(f"Failed to load image: {e}")
346
+
347
+ if not has_tile_metadata:
348
+ st.caption("_Saliency unavailable: missing tile metadata_")
349
+ elif not has_point_id:
350
+ st.caption("_Saliency unavailable: missing point_id_")
351
+
352
+ # Display original image only (no saliency requested)
353
+ elif show_images:
354
  if image_url and isinstance(image_url, str) and image_url.startswith('http'):
355
  try:
356
+ st.image(image_url, width=700, caption=f"Page {page_number}")
 
357
  except Exception as e:
358
  st.error(f"Failed to load image: {e}")
359
  else:
360
  st.info("No image URL available")
361
+ elif not show_images:
362
  st.info("Enable image display in settings to view document pages")
363
 
364
 
 
366
  sources: List[Any],
367
  show_statistics: bool = True,
368
  show_images: bool = False,
369
+ show_saliency: bool = False,
370
+ qdrant_client: Any = None,
371
+ collection_name: str = None,
372
+ query_embedding: Optional[np.ndarray] = None,
373
+ saliency_alpha: float = 0.4,
374
+ saliency_colormap: str = 'hot',
375
+ saliency_threshold: int = 50,
376
  max_display: int = 20
377
  ) -> None:
378
  """
 
382
  sources: List of VisualSearchResult objects
383
  show_statistics: Whether to show statistics
384
  show_images: Whether to show document images
385
+ show_saliency: Whether to generate and display saliency maps
386
+ qdrant_client: Qdrant client (required for saliency)
387
+ collection_name: Qdrant collection name (required for saliency)
388
+ query_embedding: Query embedding for saliency computation
389
+ saliency_alpha: Saliency overlay transparency (0.0-1.0)
390
+ saliency_colormap: Matplotlib colormap for saliency (default: 'hot')
391
+ saliency_threshold: Threshold percentile for saliency (default: 50)
392
  max_display: Maximum number of documents to display in detail
393
  """
394
  if not sources:
 
406
  if len(unique_filenames) < len(sources):
407
  st.info(f"πŸ’‘ **Note**: Each document is split into multiple chunks. You're seeing {len(sources)} chunks from {len(unique_filenames)} documents.")
408
 
409
+ # Show saliency info if enabled
410
+ if show_saliency:
411
+ st.info(f"πŸ”₯ **Saliency Maps Enabled**: Showing which image regions are most relevant to your query (using '{saliency_colormap}' colormap)")
412
+
413
  # Show statistics
414
  if show_statistics:
415
  display_visual_document_statistics(sources)
 
420
  if len(sources) > max_display:
421
  st.warning(f"⚠️ Showing top {max_display} of {len(sources)} results")
422
 
423
+ display_visual_document_details(
424
+ display_sources,
425
+ show_images=show_images,
426
+ show_saliency=show_saliency,
427
+ qdrant_client=qdrant_client,
428
+ collection_name=collection_name,
429
+ query_embedding=query_embedding,
430
+ saliency_alpha=saliency_alpha,
431
+ saliency_colormap=saliency_colormap,
432
+ saliency_threshold=saliency_threshold
433
+ )
434
 
435
  if len(sources) > max_display:
436
  st.info(f"πŸ’‘ {len(sources) - max_display} more results not shown")