Spaces:
Sleeping
Sleeping
| """ | |
| Visual Document Display Components | |
| UI components for displaying visual search results with enhanced metadata. | |
| Includes saliency map visualization for tile-aware ColPali embeddings. | |
| """ | |
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| import logging | |
| from typing import List, Any, Dict, Optional | |
| from collections import Counter | |
| logger = logging.getLogger(__name__) | |
def display_visual_document_statistics(sources: List[Any]) -> None:
    """
    Display statistics for visual search results in a bordered box with tables.

    Renders four headline metrics (total chunks, unique files, unique years,
    unique sources) followed by count tables for districts (top 10), sources,
    years (descending) and files (top 10).

    Args:
        sources: List of VisualSearchResult objects. Each item is read through
            its ``metadata`` attribute (keys used: ``filename``, ``year``,
            ``source``, ``district``). A missing or ``None`` ``metadata`` is
            treated as empty.

    Returns:
        None. Nothing is rendered when ``sources`` is empty or ``None``.
    """
    if not sources:
        return

    # NOTE(review): the glyphs in the label strings below look mis-encoded
    # (probably emoji); they are kept as-is until the intended characters
    # are confirmed.

    # Extract statistics
    filenames = []
    years = []
    sources_list = []
    districts = []

    for doc in sources:
        # ``or {}`` also covers results whose metadata attribute is None.
        metadata = getattr(doc, 'metadata', None) or {}

        # A None/empty filename would break the length check in the file table.
        filenames.append(metadata.get('filename') or 'Unknown')

        year = metadata.get('year')
        if year:
            years.append(year)

        source = metadata.get('source')
        if source:
            sources_list.append(source)

        district = metadata.get('district')
        # The literal string 'None' is treated the same as a missing district.
        if district and district != 'None':
            districts.append(district)

    # Count unique values
    unique_files = len(set(filenames))
    unique_years = len(set(years))
    unique_sources = len(set(sources_list))

    # st.container(border=True) draws the box around everything rendered
    # inside it. Emitting '<div>' / '</div>' through separate st.markdown
    # calls cannot do that: every call is its own element, so an opening tag
    # in one call never encloses widgets rendered by later calls.
    with st.container(border=True):
        st.markdown("### π Retrieval Statistics")

        # Metrics in columns
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            st.metric("Total Chunks", len(sources))
        with col2:
            st.metric("Unique Files", unique_files)
        with col3:
            st.metric("Unique Years", unique_years if unique_years > 0 else "N/A")
        with col4:
            st.metric("Unique Sources", unique_sources if unique_sources > 0 else "N/A")

        st.markdown("---")

        # Distribution tables in columns
        col1, col2, col3, col4 = st.columns(4)

        with col1:
            # District distribution (top 10)
            if districts:
                st.markdown("**ποΈ Districts**")
                district_df = pd.DataFrame([
                    {"District": dist, "Count": count}
                    for dist, count in Counter(districts).most_common(10)
                ])
                st.dataframe(district_df, hide_index=True, use_container_width=True)

        with col2:
            # Source distribution
            if sources_list:
                st.markdown("**ποΈ Sources**")
                source_df = pd.DataFrame([
                    {"Source": src, "Count": count}
                    for src, count in Counter(sources_list).most_common()
                ])
                st.dataframe(source_df, hide_index=True, use_container_width=True)

        with col3:
            # Year distribution, newest first
            if years:
                year_counts = Counter(years)
                try:
                    ordered_years = sorted(year_counts.items(), reverse=True)
                except TypeError:
                    # Years of mixed types (e.g. 2021 and "2021") are not
                    # mutually comparable; fall back to their string form.
                    ordered_years = sorted(
                        year_counts.items(),
                        key=lambda item: str(item[0]),
                        reverse=True,
                    )
                st.markdown("**π Years**")
                year_df = pd.DataFrame([
                    {"Year": year, "Count": count}
                    for year, count in ordered_years
                ])
                st.dataframe(year_df, hide_index=True, use_container_width=True)

        with col4:
            # File distribution (top 10); long names are cut to 30 characters
            st.markdown("**π Files**")
            file_rows = []
            for filename, count in Counter(filenames).most_common(10):
                label = str(filename)
                if len(label) > 30:
                    label = label[:30] + "..."
                file_rows.append({"File": label, "Count": count})
            file_df = pd.DataFrame(file_rows)
            st.dataframe(file_df, hide_index=True, use_container_width=True)
def _first_present(*candidates: Any) -> Any:
    """Return the first candidate that is neither None nor an empty string."""
    for candidate in candidates:
        if candidate is None or candidate == '':
            continue
        return candidate
    return None


def _render_page_image(image_url: Any, page_number: Any, **image_kwargs: Any) -> None:
    """Show the page image for an http(s) URL, or an info/error message instead."""
    if not (image_url and isinstance(image_url, str) and image_url.startswith('http')):
        st.info("No image URL available")
        return
    try:
        st.image(image_url, caption=f"Page {page_number}", **image_kwargs)
    except Exception as e:
        st.error(f"Failed to load image: {e}")


def display_visual_document_details(
    sources: List[Any],
    show_images: bool = False,
    show_saliency: bool = False,
    qdrant_client: Any = None,
    collection_name: Optional[str] = None,
    query_embedding: Optional[np.ndarray] = None,
    saliency_alpha: float = 0.4,
    saliency_colormap: str = 'hot',
    saliency_threshold: int = 50
) -> None:
    """
    Display detailed information for each visual search result.

    Each result gets its own expander (the first one is expanded) with a
    metadata column on the left and the page image on the right. When both
    ``show_saliency`` and ``show_images`` are set and every saliency
    requirement is available, the original image and the saliency map are
    shown side by side; otherwise only the original image is shown.

    Args:
        sources: List of VisualSearchResult objects. Each item is read through
            its ``metadata``, ``score``, ``id`` and ``page_content``
            attributes; missing or ``None`` values fall back to defaults.
        show_images: Whether to display document images (from Cloudinary)
        show_saliency: Whether to generate and display saliency maps
        qdrant_client: Qdrant client (required for saliency)
        collection_name: Qdrant collection name (required for saliency)
        query_embedding: Query embedding for saliency computation. Objects
            exposing ``.cpu()`` are converted to a float NumPy array, and a
            leading batch dimension of a 3-D array is squeezed out.
        saliency_alpha: Saliency overlay transparency (0.0-1.0)
        saliency_colormap: Matplotlib colormap for saliency (default: 'hot')
        saliency_threshold: Threshold percentile for saliency (default: 50)

    Returns:
        None. Saliency failures are logged and shown as a warning instead of
        being raised.
    """
    # NOTE(review): the glyphs in the label strings below look mis-encoded
    # (probably emoji); they are kept as-is until the intended characters
    # are confirmed.
    st.markdown("### π Document Details")

    # Import saliency functions lazily so the module is only needed when
    # saliency is actually requested.
    if show_saliency:
        from .saliency import generate_tile_aware_saliency, can_generate_saliency

    for i, doc in enumerate(sources or []):
        # ``or {}`` also covers results whose metadata attribute is None.
        metadata = getattr(doc, 'metadata', None) or {}

        # Get basic metadata
        filename = metadata.get('filename') or 'Unknown'
        page_number = metadata.get('page_number', '?')
        year = metadata.get('year', 'Unknown')
        source = metadata.get('source', 'Unknown')
        district = metadata.get('district')

        # A None score would break the ':.3f' formatting below.
        score = getattr(doc, 'score', None)
        if score is None:
            score = 0.0

        # Get visual-specific metadata
        num_tiles = metadata.get('num_tiles')
        tile_rows = metadata.get('tile_rows')
        tile_cols = metadata.get('tile_cols')
        num_visual_tokens = metadata.get('num_visual_tokens')
        original_width = metadata.get('original_width')
        original_height = metadata.get('original_height')
        resized_width = metadata.get('resized_width')
        resized_height = metadata.get('resized_height')

        # Get image URLs
        original_url = metadata.get('original_url')
        resized_url = metadata.get('resized_url')
        page_url = metadata.get('page')  # Fallback

        # Get point_id for saliency (check doc.id first, then metadata).
        # Compared against None rather than by truthiness so that an integer
        # id of 0 is not discarded.
        point_id = _first_present(
            getattr(doc, 'id', None),
            metadata.get('point_id'),
            metadata.get('_id'),
        )

        # Debug logging for saliency
        if show_saliency:
            logger.debug(f"Doc {i+1}: point_id={point_id}, has_tiles={num_tiles is not None}")

        # Build title; the ellipsis is only added when the name is actually cut
        display_name = str(filename)
        if len(display_name) > 50:
            display_name = display_name[:50] + "..."
        score_text = f" (Score: {score:.3f})"
        title = f"π Document {i+1}: {display_name}{score_text}"

        with st.expander(title, expanded=(i == 0)):  # Expand first result
            # Two-column layout: Metadata (left) and Image (right)
            col_meta, col_image = st.columns([1, 2])

            with col_meta:
                st.markdown("### π Metadata")

                # Basic metadata
                st.write(f"π **File:** {filename}")
                st.write(f"ποΈ **Source:** {source}")
                st.write(f"π **Year:** {year}")
                st.write(f"π **Page:** {page_number}")
                if district and district != 'None':
                    st.write(f"π **District:** {district}")

                # Relevance score
                st.markdown("---")
                st.markdown("### π― Relevance")
                score_color = "π’" if score > 0.7 else "π‘" if score > 0.5 else "π΄"
                st.markdown(f"**Score:** {score_color} **{score:.3f}**")

                # Visual metadata (if available)
                if num_tiles or num_visual_tokens:
                    st.markdown("---")
                    st.markdown("### π¨ Visual Metadata")
                    if num_tiles:
                        st.write(f"π² **Tiles:** {num_tiles} ({tile_rows}Γ{tile_cols})")
                    if num_visual_tokens:
                        st.write(f"π’ **Visual Tokens:** {num_visual_tokens}")
                    if original_width and original_height:
                        st.write(f"π **Original Size:** {original_width}Γ{original_height}")
                    if resized_width and resized_height:
                        st.write(f"π **Resized Size:** {resized_width}Γ{resized_height}")
                    processing_version = metadata.get('processing_version')
                    if processing_version:
                        st.write(f"βοΈ **Processing:** {processing_version}")

                # Text content preview (first 500 characters)
                content = getattr(doc, 'page_content', '')
                if content:
                    st.markdown("---")
                    # NOTE(review): this expander sits inside the per-document
                    # expander; confirm the deployed Streamlit version accepts
                    # nested expanders.
                    with st.expander("π Extracted Text", expanded=True):
                        st.text_area(
                            "Content",
                            value=content[:500] + ("..." if len(content) > 500 else ""),
                            height=150,
                            disabled=True,
                            label_visibility="collapsed",
                            key=f"visual_doc_text_{i}"
                        )
                else:
                    st.markdown("---")
                    st.caption("_No text extracted (image-only page)_")

                # Show image URLs under text (only when both are known)
                if original_url and resized_url:
                    with st.expander("π Image URLs", expanded=True):
                        st.markdown(f"**Original:** [{original_url}]({original_url})")
                        st.markdown(f"**Resized (for embeddings):** [{resized_url}]({resized_url})")

            with col_image:
                st.markdown("### πΈ Document Page")

                # Prefer the original image, then the resized one, then 'page'
                image_url = original_url or resized_url or page_url

                if show_saliency and show_images:
                    # Check if we have all requirements for saliency
                    has_client = qdrant_client is not None
                    has_collection = collection_name is not None
                    has_query = query_embedding is not None
                    has_point_id = point_id is not None
                    has_tile_metadata = can_generate_saliency(metadata)
                    can_saliency = (
                        has_client and has_collection and has_query
                        and has_point_id and has_tile_metadata
                    )

                    if not can_saliency:
                        requirements = (
                            ("qdrant_client", has_client),
                            ("collection_name", has_collection),
                            ("query_embedding", has_query),
                            ("point_id", has_point_id),
                            ("tile_metadata", has_tile_metadata),
                        )
                        missing = [name for name, present in requirements if not present]
                        logger.warning(f"Doc {i+1}: Saliency unavailable, missing: {missing}")

                    if can_saliency:
                        # Two columns: Original image | Saliency map
                        img_col1, img_col2 = st.columns(2)

                        # Left column: Original image (ALWAYS shown)
                        with img_col1:
                            st.markdown("**π Original**")
                            _render_page_image(image_url, page_number, use_container_width=True)

                        # Right column: Saliency map
                        with img_col2:
                            st.markdown("**π₯ Saliency Map**")
                            try:
                                with st.spinner("Generating..."):
                                    # Convert query embedding if needed
                                    query_emb = query_embedding
                                    if hasattr(query_emb, 'cpu'):
                                        query_emb = query_emb.cpu().float().numpy()
                                    if query_emb.ndim == 3:
                                        query_emb = query_emb.squeeze(0)  # Remove batch dimension

                                    logger.info(f"π₯ Generating saliency for doc {i+1}: point_id={point_id}, colormap={saliency_colormap}")
                                    saliency_img = generate_tile_aware_saliency(
                                        qdrant_client=qdrant_client,
                                        collection_name=collection_name,
                                        point_id=point_id,
                                        query_embedding=query_emb,
                                        alpha=saliency_alpha,
                                        colormap=saliency_colormap,
                                        threshold_percentile=saliency_threshold
                                    )

                                if saliency_img:
                                    st.image(saliency_img, use_container_width=True, caption="Relevance heatmap")
                                    logger.info(f"β Saliency map displayed for doc {i+1}")
                                else:
                                    logger.warning(f"Saliency generation returned None for doc {i+1}")
                                    st.caption("_Could not generate saliency map_")
                            except Exception as e:
                                # Best effort: a saliency failure must not
                                # prevent the remaining documents from rendering.
                                logger.error(f"Saliency generation failed for doc {i+1}: {e}")
                                import traceback
                                logger.debug(traceback.format_exc())
                                st.warning(f"β οΈ Failed: {str(e)[:80]}")
                    else:
                        # Can't generate saliency - just show original image
                        _render_page_image(image_url, page_number, width=700)
                        if not has_tile_metadata:
                            st.caption("_Saliency unavailable: missing tile metadata_")
                        elif not has_point_id:
                            st.caption("_Saliency unavailable: missing point_id_")

                # Display original image only (no saliency requested)
                elif show_images:
                    _render_page_image(image_url, page_number, width=700)

                else:
                    st.info("Enable image display in settings to view document pages")
def display_visual_search_results(
    sources: List[Any],
    show_statistics: bool = True,
    show_images: bool = False,
    show_saliency: bool = False,
    qdrant_client: Any = None,
    collection_name: str = None,
    query_embedding: Optional[np.ndarray] = None,
    saliency_alpha: float = 0.4,
    saliency_colormap: str = 'hot',
    saliency_threshold: int = 50,
    max_display: int = 20
) -> None:
    """
    Render the full visual search result view: summary, statistics, details.

    Shows a one-line summary of how many chunks and unique filenames were
    retrieved, optionally the statistics panel, then the per-document details
    for the first ``max_display`` results, with notices when the list was
    truncated.

    Args:
        sources: List of VisualSearchResult objects
        show_statistics: Whether to show statistics
        show_images: Whether to show document images
        show_saliency: Whether to generate and display saliency maps
        qdrant_client: Qdrant client (required for saliency)
        collection_name: Qdrant collection name (required for saliency)
        query_embedding: Query embedding for saliency computation
        saliency_alpha: Saliency overlay transparency (0.0-1.0)
        saliency_colormap: Matplotlib colormap for saliency (default: 'hot')
        saliency_threshold: Threshold percentile for saliency (default: 50)
        max_display: Maximum number of documents to display in detail

    Returns:
        None. Only an info message is shown when ``sources`` is empty.
    """
    # Guard clause: nothing retrieved, nothing else to render.
    if not sources:
        st.info("No documents were retrieved for the last query.")
        return

    # Summary line: chunk count versus distinct filenames.
    chunk_total = len(sources)
    distinct_names = {
        getattr(item, 'metadata', {}).get('filename', 'Unknown')
        for item in sources
    }
    file_total = len(distinct_names)

    st.markdown(f"**Found {chunk_total} document chunks from {file_total} unique documents:**")
    if file_total < chunk_total:
        st.info(f"π‘ **Note**: Each document is split into multiple chunks. You're seeing {chunk_total} chunks from {file_total} documents.")

    # Tell the user that saliency overlays are active and which colormap is used.
    if show_saliency:
        st.info(f"π₯ **Saliency Maps Enabled**: Showing which image regions are most relevant to your query (using '{saliency_colormap}' colormap)")

    # Optional statistics panel, separated from the details by a rule.
    if show_statistics:
        display_visual_document_statistics(sources)
        st.markdown("---")

    # Detailed view is capped at max_display entries; the same flag drives
    # both the notice before and the notice after the details.
    is_truncated = chunk_total > max_display
    if is_truncated:
        st.warning(f"β οΈ Showing top {max_display} of {chunk_total} results")

    detail_options = dict(
        show_images=show_images,
        show_saliency=show_saliency,
        qdrant_client=qdrant_client,
        collection_name=collection_name,
        query_embedding=query_embedding,
        saliency_alpha=saliency_alpha,
        saliency_colormap=saliency_colormap,
        saliency_threshold=saliency_threshold,
    )
    display_visual_document_details(sources[:max_display], **detail_options)

    if is_truncated:
        st.info(f"π‘ {chunk_total - max_display} more results not shown")