akryldigital commited on
Commit
865589a
·
verified ·
1 Parent(s): b0fe395

add saliency map components

Browse files
src/agents/__init__.py CHANGED
@@ -1,21 +1,46 @@
1
  """
2
- Agent modules for chatbot implementations
3
- """
4
 
5
- from .gemini_chatbot import get_gemini_chatbot
6
- from .visual_chatbot import get_visual_chatbot
7
- from .multi_agent_chatbot import get_multi_agent_chatbot
8
- from .smart_chatbot import get_chatbot as get_smart_chatbot
9
- from .visual_multi_agent_chatbot import get_visual_multi_agent_chatbot
10
 
11
- # Alias for backward compatibility
12
- get_visual_chatbot_v2 = get_visual_multi_agent_chatbot
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  __all__ = [
15
- "get_smart_chatbot",
16
- "get_multi_agent_chatbot",
17
- "get_gemini_chatbot",
18
- "get_visual_chatbot",
19
- "get_visual_multi_agent_chatbot",
20
- "get_visual_chatbot_v2"
 
 
 
 
 
 
 
 
21
  ]
 
 
 
 
1
  """
2
+ UI Components Module
 
3
 
4
+ This module contains UI-related components including styles, visualizations,
5
+ and utility functions for the Streamlit application.
6
+ """
 
 
7
 
8
+ from .styles import get_custom_css
9
+ from .components import (
10
+ display_chunk_statistics_charts,
11
+ display_chunk_statistics_table
12
+ )
13
+ from .utils import extract_chunk_statistics
14
+ from .visual_documents import (
15
+ display_visual_search_results,
16
+ display_visual_document_statistics,
17
+ display_visual_document_details
18
+ )
19
+ from .saliency import (
20
+ generate_tile_aware_saliency,
21
+ can_generate_saliency,
22
+ get_saliency_metadata_summary,
23
+ DEFAULT_ALPHA,
24
+ DEFAULT_COLORMAP,
25
+ DEFAULT_THRESHOLD_PERCENTILE
26
+ )
27
 
28
  __all__ = [
29
+ "get_custom_css",
30
+ "display_chunk_statistics_charts",
31
+ "display_chunk_statistics_table",
32
+ "extract_chunk_statistics",
33
+ "display_visual_search_results",
34
+ "display_visual_document_statistics",
35
+ "display_visual_document_details",
36
+ # Saliency functions
37
+ "generate_tile_aware_saliency",
38
+ "can_generate_saliency",
39
+ "get_saliency_metadata_summary",
40
+ "DEFAULT_ALPHA",
41
+ "DEFAULT_COLORMAP",
42
+ "DEFAULT_THRESHOLD_PERCENTILE"
43
  ]
44
+
45
+
46
+
src/agents/saliency.py ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Saliency Map Generation for Visual RAG
3
+
4
+ This module provides saliency map generation for visual document search results.
5
+ It implements the tile-aware ColBERT MaxSim strategy for accurate visualization
6
+ of which image regions are relevant to a query.
7
+
8
+ Key features:
9
+ 1. Tile-aware architecture (understands 4Γ—3 grid of 512Γ—512 tiles)
10
+ 2. Excludes global tile for cleaner saliency
11
+ 3. Maps patches to resized image, then scales to original
12
+ 4. Uses "hot" colormap by default for better visibility
13
+ """
14
+
15
+ import logging
16
+ from typing import Any, Optional, Tuple
17
+ from io import BytesIO
18
+ from base64 import b64decode
19
+
20
+ import numpy as np
21
+ import requests
22
+ from PIL import Image
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ # Default saliency configuration
27
+ DEFAULT_ALPHA = 0.4
28
+ DEFAULT_COLORMAP = 'hot' # Better visibility than 'jet'
29
+ DEFAULT_THRESHOLD_PERCENTILE = 50
30
+
31
+
32
def convert_to_numpy(embedding, dtype: np.dtype = np.float32) -> np.ndarray:
    """Coerce *embedding* into a numpy array of the requested dtype.

    Accepts plain Python lists, numpy arrays, and (optionally) PyTorch
    tensors. Torch tensors are moved to CPU first; bfloat16 tensors are
    widened to float32 on the way, since numpy has no bfloat16 type.
    """
    try:
        import torch
    except ImportError:
        # torch is optional; lists and ndarrays still convert fine.
        torch = None

    if torch is not None and isinstance(embedding, torch.Tensor):
        tensor = embedding.cpu()
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.float()
        embedding = tensor.numpy()

    return np.array(embedding, dtype=dtype)
53
+
54
+
55
def validate_embeddings(
    doc_embedding: np.ndarray,
    query_embedding: np.ndarray
) -> Tuple[bool, str]:
    """Check that both embeddings are well-formed 2D arrays of finite values.

    Returns:
        (True, "") when valid, otherwise (False, reason).
    """
    # Both sides must be [num_patches, dim] matrices.
    for name, emb in (("Document", doc_embedding), ("Query", query_embedding)):
        if emb.ndim != 2:
            return False, f"{name} embedding must be 2D, got {emb.ndim}D"

    # Dot products only make sense when the feature dimensions agree.
    if doc_embedding.shape[1] != query_embedding.shape[1]:
        return False, (
            f"Embedding dimensions don't match: "
            f"doc={doc_embedding.shape[1]}, query={query_embedding.shape[1]}"
        )

    # NaN/Inf anywhere would poison the similarity computation.
    for name, emb in (("Document", doc_embedding), ("Query", query_embedding)):
        if not np.all(np.isfinite(emb)):
            return False, f"{name} embedding contains NaN or Inf values"

    return True, ""
76
+
77
+
78
def compute_maxsim_scores(
    doc_embedding: np.ndarray,
    query_embedding: np.ndarray,
    normalize: bool = True
) -> np.ndarray:
    """ColBERT-style MaxSim: per document patch, the best match over the query.

    Args:
        doc_embedding: [num_doc_patches, dim] document patch embeddings.
        query_embedding: [num_query_patches, dim] query token embeddings.
        normalize: L2-normalize rows first so dot products become cosine
            similarities.

    Returns:
        1D array (length num_doc_patches) of per-patch maximum similarities.
    """
    def _unit_rows(matrix: np.ndarray) -> np.ndarray:
        # Small epsilon keeps all-zero rows from dividing by zero.
        return matrix / (np.linalg.norm(matrix, axis=1, keepdims=True) + 1e-8)

    if normalize:
        doc_embedding = _unit_rows(doc_embedding)
        query_embedding = _unit_rows(query_embedding)

    # [num_doc, num_query] similarity matrix, reduced over the query axis.
    return (doc_embedding @ query_embedding.T).max(axis=1)
100
+
101
+
102
def normalize_scores(
    score_grid: np.ndarray,
    threshold_percentile: Optional[int] = None
) -> np.ndarray:
    """
    Normalize score grid to the 0-1 range with optional thresholding.

    Args:
        score_grid: 2D array of raw patch scores.
        threshold_percentile: If given, positions whose raw score falls below
            this percentile of the raw grid are zeroed out after normalization
            (hides weakly-matching patches). Was annotated as a plain ``int``
            with a ``None`` default; fixed to ``Optional[int]`` per PEP 484.

    Returns:
        Array of the same shape scaled to [0, 1]. If all scores are identical
        there is no dynamic range to show, so an all-zero float32 grid is
        returned instead of dividing by ~0.
    """
    score_min = score_grid.min()
    score_max = score_grid.max()

    # Degenerate case: a flat grid carries no saliency signal.
    if score_max - score_min < 1e-8:
        logger.warning("All scores are identical, returning zeros")
        return np.zeros_like(score_grid, dtype=np.float32)

    score_grid_norm = (score_grid - score_min) / (score_max - score_min)

    if threshold_percentile is not None:
        # Percentile is taken on the raw scores; masking the normalized grid
        # at the same positions is equivalent because min-max scaling is
        # monotonic.
        score_threshold = np.percentile(score_grid, threshold_percentile)
        mask = score_grid < score_threshold
        score_grid_norm[mask] = 0.0

        visible_count = np.sum(~mask)
        total_count = score_grid.size
        logger.debug(f"Threshold: {score_threshold:.3f} ({threshold_percentile}th percentile)")
        logger.debug(f"Visible patches: {visible_count} / {total_count}")

    return score_grid_norm
127
+
128
+
129
def download_image(page_url: str) -> Optional[Image.Image]:
    """Load an image from an http(s) URL, a ``data:image`` URI, or a local path.

    The result is always converted to RGB. Returns None (after logging the
    error) when the image cannot be fetched or decoded.
    """
    try:
        if page_url.startswith(("http://", "https://")):
            # Remote image: fetch raw bytes over HTTP with a hard timeout.
            response = requests.get(page_url, timeout=15)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
        elif page_url.startswith("data:image"):
            # Inline data URI: the base64 payload follows the first comma.
            _, encoded = page_url.split(",", 1)
            image = Image.open(BytesIO(b64decode(encoded)))
        else:
            # Anything else is treated as a filesystem path.
            image = Image.open(page_url)

        return image if image.mode == "RGB" else image.convert("RGB")

    except Exception as e:
        logger.error(f"Failed to load image: {e}")
        return None
150
+
151
+
152
def apply_colormap_and_blend(
    score_grid: np.ndarray,
    image: Image.Image,
    alpha: float = DEFAULT_ALPHA,
    colormap: str = DEFAULT_COLORMAP
) -> Image.Image:
    """Render normalized scores as a colored heatmap blended onto *image*.

    Args:
        score_grid: 2D array of scores in [0, 1] (see normalize_scores).
        image: RGB PIL image the heatmap is laid over.
        alpha: Blend weight (0.0 = image only, 1.0 = heatmap only).
        colormap: Matplotlib colormap name.

    Returns:
        A new PIL image, same size as *image*, with the overlay applied.
    """
    # `cm.get_cmap` was deprecated in matplotlib 3.7 and removed in 3.9;
    # use the `matplotlib.colormaps` registry, falling back for old versions.
    try:
        from matplotlib import colormaps
        cmap = colormaps[colormap]
    except ImportError:
        from matplotlib import cm
        cmap = cm.get_cmap(colormap)

    img_width, img_height = image.size

    # Upscale the coarse patch grid to full image resolution.
    heatmap_pil = Image.fromarray((score_grid * 255).astype(np.uint8), mode='L')
    heatmap_resized = heatmap_pil.resize((img_width, img_height), Image.BILINEAR)
    heatmap_array = np.array(heatmap_resized) / 255.0

    # Map scalar scores to RGB via the colormap (drop the alpha channel).
    heatmap_colored = cmap(heatmap_array)[:, :, :3]
    heatmap_colored = (heatmap_colored * 255).astype(np.uint8)
    heatmap_img = Image.fromarray(heatmap_colored, mode='RGB')

    # Blend with original image
    overlay = Image.blend(image, heatmap_img, alpha=alpha)

    return overlay
178
+
179
+
180
def generate_tile_aware_saliency(
    qdrant_client: Any,
    collection_name: str,
    point_id: str,
    query_embedding: np.ndarray,
    alpha: float = DEFAULT_ALPHA,
    colormap: str = DEFAULT_COLORMAP,
    threshold_percentile: int = DEFAULT_THRESHOLD_PERCENTILE
) -> Optional[Image.Image]:
    """
    Generate tile-aware saliency map for a document-query pair.

    This is the main function to call for saliency generation.

    Pipeline: fetch the stored multi-vector + payload from Qdrant, compute
    MaxSim scores per document patch, rebuild the 2D patch grid from the
    tile layout recorded in the payload, normalize/threshold it, and blend
    a heatmap over the resized page image (scaled back to original size).

    Args:
        qdrant_client: Qdrant client instance
        collection_name: Name of the collection
        point_id: ID of the document point
        query_embedding: Query multi-vector embedding [num_query_patches, dim]
        alpha: Overlay transparency (0.0-1.0)
        colormap: Matplotlib colormap name (default: 'hot')
        threshold_percentile: Hide patches below this percentile (default: 50)

    Returns:
        PIL Image with saliency overlay, or None if generation fails
    """
    try:
        # Step 1: Fetch full multi-vector embedding AND payload
        logger.debug(f"Fetching point {point_id} with tile metadata from {collection_name}")
        points = qdrant_client.retrieve(
            collection_name=collection_name,
            ids=[point_id],
            with_vectors=["initial"],
            with_payload=True
        )

        if not points or len(points) == 0:
            logger.error(f"Point {point_id} not found in collection")
            return None

        point = points[0]
        # "initial" is the named multi-vector carrying per-patch embeddings
        # — presumably written by the indexer; confirm against the
        # collection schema.
        doc_vector = point.vector.get("initial")
        payload = point.payload

        if doc_vector is None:
            logger.error("No 'initial' vector found for point")
            return None

        # Step 2: Get tile structure from payload
        num_tiles = payload.get('num_tiles')
        tile_rows = payload.get('tile_rows')
        tile_cols = payload.get('tile_cols')
        # 64 patches = 8x8 per tile — assumed embedder default; TODO confirm.
        patches_per_tile = payload.get('patches_per_tile', 64)

        resized_width = payload.get('resized_width')
        resized_height = payload.get('resized_height')
        # Fall back to the legacy 'page' URL when no resized URL is stored.
        resized_url = payload.get('resized_url') or payload.get('page')

        original_width = payload.get('original_width')
        original_height = payload.get('original_height')

        if not all([num_tiles, tile_rows, tile_cols, resized_width, resized_height]):
            logger.warning("Missing tile metadata - cannot generate saliency")
            return None

        logger.info(f"βœ… Tile structure: {tile_rows}Γ—{tile_cols} tiles, {patches_per_tile} patches/tile")
        logger.info(f"βœ… Resized image: {resized_width}Γ—{resized_height}")
        logger.info(f"βœ… Original image: {original_width}Γ—{original_height}")

        # Step 3: Convert embeddings
        doc_embedding = convert_to_numpy(doc_vector)
        query_emb = convert_to_numpy(query_embedding)

        is_valid, error_msg = validate_embeddings(doc_embedding, query_emb)
        if not is_valid:
            logger.error(f"Embedding validation failed: {error_msg}")
            return None

        logger.info(f"Document embedding: {doc_embedding.shape}")
        logger.info(f"Query embedding: {query_emb.shape}")

        # Step 4: Separate tile embeddings from global tile
        # NOTE(review): num_tiles is assumed to include one trailing global
        # tile (num_tiles == tile_rows*tile_cols + 1), so only the leading
        # spatial patches are kept — confirm with the indexer.
        total_patches = num_tiles * patches_per_tile
        tile_patches = total_patches - patches_per_tile  # Exclude global

        if len(doc_embedding) < total_patches:
            logger.warning(f"Embedding size mismatch: got {len(doc_embedding)}, expected {total_patches}")
            tile_embeddings = doc_embedding[:tile_patches] if len(doc_embedding) > tile_patches else doc_embedding
        else:
            tile_embeddings = doc_embedding[:tile_patches]

        logger.info(f"Using {len(tile_embeddings)} tile patches (excluding global)")

        # Step 5: Compute MaxSim scores
        patch_scores = compute_maxsim_scores(tile_embeddings, query_emb, normalize=True)
        logger.info(f"Computed scores for {len(patch_scores)} patches")

        # Step 6: Reshape patches into tile structure
        patches_per_tile_side = int(np.sqrt(patches_per_tile))  # 8 for 64 patches

        try:
            num_actual_tiles = tile_rows * tile_cols

            if len(patch_scores) != num_actual_tiles * patches_per_tile:
                logger.error(f"Patch count mismatch: {len(patch_scores)} patches")
                return None

            tile_scores = patch_scores.reshape(num_actual_tiles, patches_per_tile)

            # Reshape each tile's patches to 8Γ—8 grid (F-order)
            # NOTE(review): column-major order implies the embedder emits
            # patches column-first within a tile — verify if maps look
            # transposed.
            tile_grids = []
            for tile_idx in range(num_actual_tiles):
                tile_patch_scores = tile_scores[tile_idx]
                tile_grid = tile_patch_scores.reshape(
                    patches_per_tile_side, patches_per_tile_side, order='F'
                )
                tile_grids.append(tile_grid)

            # Arrange tiles into full image grid
            full_grid_rows = []
            for row_idx in range(tile_rows):
                row_tiles = []
                for col_idx in range(tile_cols):
                    # Tiles are stored row-major across the page.
                    tile_idx = row_idx * tile_cols + col_idx
                    row_tiles.append(tile_grids[tile_idx])
                row_grid = np.concatenate(row_tiles, axis=1)
                full_grid_rows.append(row_grid)

            score_grid = np.concatenate(full_grid_rows, axis=0)

            logger.info(f"βœ… Reconstructed grid: {score_grid.shape} (from {tile_rows}Γ—{tile_cols} tiles)")

        except ValueError as e:
            logger.error(f"❌ Failed to reshape patches: {e}")
            return None

        # Step 7: Normalize scores
        score_grid_norm = normalize_scores(score_grid, threshold_percentile=threshold_percentile)

        # Step 8: Download RESIZED image
        # The patch grid aligns with the resized image the embedder saw,
        # so the heatmap must be applied at that resolution first.
        logger.info(f"Downloading resized image from: {resized_url}")
        resized_image = download_image(resized_url)
        if resized_image is None:
            logger.error("Failed to download resized image")
            return None

        # Step 9: Apply heatmap to resized image
        overlay_resized = apply_colormap_and_blend(
            score_grid_norm, resized_image, alpha, colormap
        )

        # Step 10: Resize back to original dimensions
        if original_width and original_height:
            overlay_final = overlay_resized.resize(
                (original_width, original_height), Image.BILINEAR
            )
            logger.info(f"βœ… Resized saliency map to original: {original_width}Γ—{original_height}")
        else:
            overlay_final = overlay_resized

        logger.info(f"βœ… Saliency map generated successfully")
        return overlay_final

    except Exception as e:
        # Broad catch is deliberate: saliency is a best-effort visualization
        # and must never break the caller's UI flow.
        logger.error(f"Saliency generation failed: {e}")
        import traceback
        logger.debug(traceback.format_exc())
        return None
348
+
349
+
350
def can_generate_saliency(metadata: dict) -> bool:
    """
    Report whether a document's metadata carries everything saliency needs.

    Generation requires the tile layout and the resized-image dimensions;
    a field that is absent or None disqualifies the document.

    Args:
        metadata: Document metadata dictionary

    Returns:
        True if all required tile metadata is present
    """
    for field in ('num_tiles', 'tile_rows', 'tile_cols',
                  'resized_width', 'resized_height'):
        if metadata.get(field) is None:
            return False
    return True
362
+
363
+
364
def get_saliency_metadata_summary(metadata: dict) -> str:
    """
    Build a short human-readable description of a document's tile layout.

    Args:
        metadata: Document metadata dictionary

    Returns:
        Summary string such as "3Γ—4 tiles (13 total), 64 patches/tile",
        or a placeholder when the tile fields are missing.
    """
    missing = 'N/A'
    num_tiles = metadata.get('num_tiles', missing)
    tile_rows = metadata.get('tile_rows', missing)
    tile_cols = metadata.get('tile_cols', missing)
    patches_per_tile = metadata.get('patches_per_tile', 64)

    if missing in (num_tiles, tile_rows, tile_cols):
        return "Tile metadata not available"
    return f"{tile_rows}Γ—{tile_cols} tiles ({num_tiles} total), {patches_per_tile} patches/tile"
383
+
src/agents/visual_documents.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Visual Document Display Components
3
+
4
+ UI components for displaying visual search results with enhanced metadata.
5
+ Includes saliency map visualization for tile-aware ColPali embeddings.
6
+ """
7
+
8
+ import streamlit as st
9
+ import pandas as pd
10
+ import numpy as np
11
+ import logging
12
+ from typing import List, Any, Dict, Optional
13
+ from collections import Counter
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
def display_visual_document_statistics(sources: List[Any]) -> None:
    """
    Display statistics for visual search results in a bordered box with tables.

    Renders a metrics row (chunk/file/year/source counts) followed by four
    distribution tables (districts, sources, years, files) side by side.
    No-op when *sources* is empty.

    Args:
        sources: List of VisualSearchResult objects
    """
    if not sources:
        return

    # Extract statistics
    filenames = []
    years = []
    sources_list = []
    districts = []

    for doc in sources:
        metadata = getattr(doc, 'metadata', {})
        filenames.append(metadata.get('filename', 'Unknown'))
        year = metadata.get('year')
        if year:
            years.append(year)
        source = metadata.get('source')
        if source:
            sources_list.append(source)
        district = metadata.get('district')
        # NOTE(review): the literal string 'None' is filtered out —
        # presumably upstream serializes missing districts that way; confirm.
        if district and district != 'None':
            districts.append(district)

    # Count unique values
    unique_files = len(set(filenames))
    unique_years = len(set(years))
    unique_sources = len(set(sources_list))

    # Create bordered container
    with st.container():
        st.markdown("""
        <style>
        .stats-container {
            border: 2px solid #e0e0e0;
            border-radius: 10px;
            padding: 20px;
            margin: 10px 0;
            background-color: #f9f9f9;
        }
        </style>
        """, unsafe_allow_html=True)

        st.markdown('<div class="stats-container">', unsafe_allow_html=True)
        st.markdown("### πŸ“Š Retrieval Statistics")

        # Metrics in columns
        col1, col2, col3, col4 = st.columns(4)

        with col1:
            st.metric("Total Chunks", len(sources))
        with col2:
            st.metric("Unique Files", unique_files)
        with col3:
            st.metric("Unique Years", unique_years if unique_years > 0 else "N/A")
        with col4:
            st.metric("Unique Sources", unique_sources if unique_sources > 0 else "N/A")

        st.markdown("---")

        # Distribution tables in columns
        col1, col2, col3, col4 = st.columns(4)

        with col1:
            # District distribution
            if districts:
                district_counts = Counter(districts)
                st.markdown("**🏘️ Districts**")
                district_df = pd.DataFrame([
                    {"District": dist, "Count": count}
                    for dist, count in district_counts.most_common(10)
                ])
                st.dataframe(district_df, hide_index=True, use_container_width=True)

        with col2:
            # Source distribution
            if sources_list:
                source_counts = Counter(sources_list)
                st.markdown("**πŸ›οΈ Sources**")
                source_df = pd.DataFrame([
                    {"Source": src, "Count": count}
                    for src, count in source_counts.most_common()
                ])
                st.dataframe(source_df, hide_index=True, use_container_width=True)

        with col3:
            # Year distribution
            if years:
                year_counts = Counter(years)
                st.markdown("**πŸ“… Years**")
                year_df = pd.DataFrame([
                    {"Year": year, "Count": count}
                    for year, count in sorted(year_counts.items(), reverse=True)
                ])
                st.dataframe(year_df, hide_index=True, use_container_width=True)

        with col4:
            # File distribution (top 10)
            file_counts = Counter(filenames)
            st.markdown("**πŸ“„ Files**")
            # Long filenames are elided to keep the table column readable.
            file_df = pd.DataFrame([
                {"File": filename[:30] + "..." if len(filename) > 30 else filename, "Count": count}
                for filename, count in file_counts.most_common(10)
            ])
            st.dataframe(file_df, hide_index=True, use_container_width=True)

        st.markdown('</div>', unsafe_allow_html=True)
130
+
131
+
132
def display_visual_document_details(
    sources: List[Any],
    show_images: bool = False,
    show_saliency: bool = False,
    qdrant_client: Any = None,
    collection_name: Optional[str] = None,
    query_embedding: Optional[np.ndarray] = None,
    saliency_alpha: float = 0.4,
    saliency_colormap: str = 'hot',
    saliency_threshold: int = 50
) -> None:
    """
    Display detailed information for each visual search result.

    Renders one expander per result with a metadata column on the left and
    the page image (optionally overlaid with a saliency map) on the right.

    Args:
        sources: List of VisualSearchResult objects
        show_images: Whether to display document images (from Cloudinary)
        show_saliency: Whether to generate and display saliency maps
        qdrant_client: Qdrant client (required for saliency)
        collection_name: Qdrant collection name (required for saliency)
        query_embedding: Query embedding for saliency computation
        saliency_alpha: Saliency overlay transparency (0.0-1.0)
        saliency_colormap: Matplotlib colormap for saliency (default: 'hot')
        saliency_threshold: Threshold percentile for saliency (default: 50)
    """
    st.markdown("### πŸ“„ Document Details")

    # Import saliency functions if needed
    if show_saliency:
        from .saliency import generate_tile_aware_saliency, can_generate_saliency

    for i, doc in enumerate(sources):
        metadata = getattr(doc, 'metadata', {})

        # Get basic metadata
        filename = metadata.get('filename', 'Unknown')
        page_number = metadata.get('page_number', '?')
        year = metadata.get('year', 'Unknown')
        source = metadata.get('source', 'Unknown')
        district = metadata.get('district')
        score = getattr(doc, 'score', 0.0)

        # Get visual-specific metadata
        num_tiles = metadata.get('num_tiles')
        tile_rows = metadata.get('tile_rows')
        tile_cols = metadata.get('tile_cols')
        num_visual_tokens = metadata.get('num_visual_tokens')
        original_width = metadata.get('original_width')
        original_height = metadata.get('original_height')
        resized_width = metadata.get('resized_width')
        resized_height = metadata.get('resized_height')

        # Get image URLs
        original_url = metadata.get('original_url')
        resized_url = metadata.get('resized_url')
        page_url = metadata.get('page')  # Fallback

        # Get point_id for saliency (check doc.id first, then metadata)
        point_id = getattr(doc, 'id', None) or metadata.get('point_id') or metadata.get('_id')

        # Debug logging for saliency
        if show_saliency:
            logger.debug(f"Doc {i+1}: point_id={point_id}, has_tiles={metadata.get('num_tiles') is not None}")

        # Build title
        score_text = f" (Score: {score:.3f})"
        title = f"πŸ“„ Document {i+1}: {filename[:50]}...{score_text}"

        with st.expander(title, expanded=(i == 0)):  # Expand first result
            # Two-column layout: Metadata (left) and Image (right)
            col_meta, col_image = st.columns([1, 2])

            with col_meta:
                st.markdown("### πŸ“‹ Metadata")

                # Basic metadata
                # Fix: the f-string previously contained the literal text
                # "(unknown)" instead of interpolating the filename.
                st.write(f"πŸ“„ **File:** {filename}")
                st.write(f"πŸ›οΈ **Source:** {source}")
                st.write(f"πŸ“… **Year:** {year}")
                st.write(f"πŸ“– **Page:** {page_number}")

                if district and district != 'None':
                    st.write(f"πŸ“ **District:** {district}")

                # Relevance score
                st.markdown("---")
                st.markdown("### 🎯 Relevance")
                score_color = "🟒" if score > 0.7 else "🟑" if score > 0.5 else "πŸ”΄"
                st.markdown(f"**Score:** {score_color} **{score:.3f}**")

                # Visual metadata (if available)
                if num_tiles or num_visual_tokens:
                    st.markdown("---")
                    st.markdown("### 🎨 Visual Metadata")

                    if num_tiles:
                        st.write(f"πŸ”² **Tiles:** {num_tiles} ({tile_rows}Γ—{tile_cols})")
                    if num_visual_tokens:
                        st.write(f"πŸ”’ **Visual Tokens:** {num_visual_tokens}")
                    if original_width and original_height:
                        st.write(f"πŸ“ **Original Size:** {original_width}Γ—{original_height}")
                    if resized_width and resized_height:
                        st.write(f"πŸ“ **Resized Size:** {resized_width}Γ—{resized_height}")

                    processing_version = metadata.get('processing_version')
                    if processing_version:
                        st.write(f"βš™οΈ **Processing:** {processing_version}")

                # Text content preview
                content = getattr(doc, 'page_content', '')
                if content:
                    st.markdown("---")
                    with st.expander("πŸ“ Extracted Text", expanded=True):
                        st.text_area(
                            "Content",
                            value=content[:500] + ("..." if len(content) > 500 else ""),
                            height=150,
                            disabled=True,
                            label_visibility="collapsed",
                            key=f"visual_doc_text_{i}"
                        )
                else:
                    st.markdown("---")
                    st.caption("_No text extracted (image-only page)_")

                # Show image URLs under text
                if original_url and resized_url:
                    with st.expander("πŸ”— Image URLs", expanded=True):
                        st.markdown(f"**Original:** [{original_url}]({original_url})")
                        st.markdown(f"**Resized (for embeddings):** [{resized_url}]({resized_url})")

            with col_image:
                st.markdown("### πŸ“Έ Document Page")

                # Check if we should generate saliency
                saliency_generated = False

                if show_saliency and show_images:
                    # Check if we have all requirements for saliency
                    has_client = qdrant_client is not None
                    has_collection = collection_name is not None
                    has_query = query_embedding is not None
                    has_point_id = point_id is not None
                    has_tile_metadata = can_generate_saliency(metadata)

                    can_saliency = has_client and has_collection and has_query and has_point_id and has_tile_metadata

                    if not can_saliency:
                        missing = []
                        if not has_client: missing.append("qdrant_client")
                        if not has_collection: missing.append("collection_name")
                        if not has_query: missing.append("query_embedding")
                        if not has_point_id: missing.append("point_id")
                        if not has_tile_metadata: missing.append("tile_metadata")
                        logger.warning(f"Doc {i+1}: Saliency unavailable, missing: {missing}")

                    if can_saliency:
                        try:
                            with st.spinner(f"πŸ”₯ Generating saliency map for Doc {i+1}..."):
                                # Convert query embedding if needed
                                query_emb = query_embedding
                                if hasattr(query_emb, 'cpu'):
                                    query_emb = query_emb.cpu().float().numpy()
                                if query_emb.ndim == 3:
                                    query_emb = query_emb.squeeze(0)  # Remove batch dimension

                                logger.info(f"πŸ”₯ Generating saliency for doc {i+1}: point_id={point_id}, colormap={saliency_colormap}")

                                saliency_img = generate_tile_aware_saliency(
                                    qdrant_client=qdrant_client,
                                    collection_name=collection_name,
                                    point_id=point_id,
                                    query_embedding=query_emb,
                                    alpha=saliency_alpha,
                                    colormap=saliency_colormap,
                                    threshold_percentile=saliency_threshold
                                )

                                if saliency_img:
                                    # Display saliency map
                                    st.image(saliency_img, width=700, caption=f"πŸ”₯ Saliency Map - Page {page_number}")
                                    saliency_generated = True
                                    logger.info(f"βœ… Saliency map displayed for doc {i+1}")
                                else:
                                    logger.warning(f"Saliency generation returned None for doc {i+1}")
                                    st.caption("_Saliency map could not be generated_")
                        except Exception as e:
                            logger.error(f"Saliency generation failed for doc {i+1}: {e}")
                            import traceback
                            logger.debug(traceback.format_exc())
                            st.warning(f"⚠️ Saliency generation failed: {str(e)[:100]}")
                    else:
                        if not has_tile_metadata:
                            st.caption("_Saliency unavailable: missing tile metadata_")
                        elif not has_point_id:
                            st.caption("_Saliency unavailable: missing point_id_")

                # Display original image if saliency wasn't generated
                if show_images and not saliency_generated:
                    # Use ORIGINAL image (not resized) for display
                    image_url = original_url or resized_url or page_url

                    if image_url and isinstance(image_url, str) and image_url.startswith('http'):
                        try:
                            # Use width parameter for medium-sized image
                            st.image(image_url, width=700, caption=f"Page {page_number}")
                        except Exception as e:
                            st.error(f"Failed to load image: {e}")
                    else:
                        st.info("No image URL available")
                elif not show_images:
                    st.info("Enable image display in settings to view document pages")
344
+
345
+
346
def display_visual_search_results(
    sources: List[Any],
    show_statistics: bool = True,
    show_images: bool = False,
    show_saliency: bool = False,
    qdrant_client: Any = None,
    collection_name: str = None,
    query_embedding: Optional[np.ndarray] = None,
    saliency_alpha: float = 0.4,
    saliency_colormap: str = 'hot',
    saliency_threshold: int = 50,
    max_display: int = 20
) -> None:
    """
    Display visual search results with statistics and details.

    Top-level entry point: prints a summary line, optionally the statistics
    panel, then delegates per-document rendering (capped at *max_display*)
    to display_visual_document_details.

    Args:
        sources: List of VisualSearchResult objects
        show_statistics: Whether to show statistics
        show_images: Whether to show document images
        show_saliency: Whether to generate and display saliency maps
        qdrant_client: Qdrant client (required for saliency)
        collection_name: Qdrant collection name (required for saliency)
        query_embedding: Query embedding for saliency computation
        saliency_alpha: Saliency overlay transparency (0.0-1.0)
        saliency_colormap: Matplotlib colormap for saliency (default: 'hot')
        saliency_threshold: Threshold percentile for saliency (default: 50)
        max_display: Maximum number of documents to display in detail
    """
    if not sources:
        st.info("No documents were retrieved for the last query.")
        return

    # Count unique filenames
    unique_filenames = set()
    for doc in sources:
        filename = getattr(doc, 'metadata', {}).get('filename', 'Unknown')
        unique_filenames.add(filename)

    st.markdown(f"**Found {len(sources)} document chunks from {len(unique_filenames)} unique documents:**")

    # More chunks than files means at least one document contributed
    # multiple chunks — explain that to the user.
    if len(unique_filenames) < len(sources):
        st.info(f"πŸ’‘ **Note**: Each document is split into multiple chunks. You're seeing {len(sources)} chunks from {len(unique_filenames)} documents.")

    # Show saliency info if enabled
    if show_saliency:
        st.info(f"πŸ”₯ **Saliency Maps Enabled**: Showing which image regions are most relevant to your query (using '{saliency_colormap}' colormap)")

    # Show statistics
    if show_statistics:
        display_visual_document_statistics(sources)
        st.markdown("---")

    # Show detailed results (limit to max_display)
    display_sources = sources[:max_display]
    if len(sources) > max_display:
        st.warning(f"⚠️ Showing top {max_display} of {len(sources)} results")

    display_visual_document_details(
        display_sources,
        show_images=show_images,
        show_saliency=show_saliency,
        qdrant_client=qdrant_client,
        collection_name=collection_name,
        query_embedding=query_embedding,
        saliency_alpha=saliency_alpha,
        saliency_colormap=saliency_colormap,
        saliency_threshold=saliency_threshold
    )

    if len(sources) > max_display:
        st.info(f"πŸ’‘ {len(sources) - max_display} more results not shown")
418
+