"""
Visual Document Display Components
UI components for displaying visual search results with enhanced metadata.
Includes saliency map visualization for tile-aware ColPali embeddings.
"""
import streamlit as st
import pandas as pd
import numpy as np
import logging
from typing import List, Any, Dict, Optional
from collections import Counter
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
def _render_count_table(heading: str, label: str, rows: List[Any]) -> None:
    """Render a markdown heading followed by a ``label`` / "Count" table.

    Args:
        heading: Markdown text shown above the table.
        label: Name of the first table column.
        rows: ``(value, count)`` pairs, already in display order.
    """
    st.markdown(heading)
    table = pd.DataFrame([{label: value, "Count": count} for value, count in rows])
    st.dataframe(table, hide_index=True, use_container_width=True)


def display_visual_document_statistics(sources: List[Any]) -> None:
    """
    Display statistics for visual search results in a bordered box with tables.

    Shows four headline metrics (chunks, unique files, years, sources) and up
    to four distribution tables (districts, sources, years, files).

    Args:
        sources: List of VisualSearchResult objects. Each is read through its
            ``metadata`` attribute (a dict); a missing or ``None`` metadata is
            treated as empty. Nothing is rendered when the list is empty.
    """
    if not sources:
        return

    # Extract statistics
    filenames = []
    years = []
    sources_list = []
    districts = []
    for doc in sources:
        # ``metadata`` may be absent or explicitly None; fall back to {}.
        metadata = getattr(doc, 'metadata', None) or {}
        # A None/empty filename would break len()/slicing further down.
        filenames.append(str(metadata.get('filename') or 'Unknown'))
        year = metadata.get('year')
        if year:
            years.append(year)
        source = metadata.get('source')
        if source:
            sources_list.append(source)
        district = metadata.get('district')
        # Skip the literal string 'None' as well as falsy values.
        if district and district != 'None':
            districts.append(district)

    # Count unique values
    unique_files = len(set(filenames))
    unique_years = len(set(years))
    unique_sources = len(set(sources_list))

    # ``border=True`` draws the box natively. HTML wrappers emitted through
    # separate ``st.markdown`` calls do not enclose the elements between them.
    with st.container(border=True):
        st.markdown("### 📊 Retrieval Statistics")

        # Metrics in columns
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            st.metric("Total Chunks", len(sources))
        with col2:
            st.metric("Unique Files", unique_files)
        with col3:
            st.metric("Unique Years", unique_years if unique_years > 0 else "N/A")
        with col4:
            st.metric("Unique Sources", unique_sources if unique_sources > 0 else "N/A")

        st.markdown("---")

        # Distribution tables in columns
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            # District distribution (top 10)
            if districts:
                _render_count_table(
                    "**🏘️ Districts**", "District", Counter(districts).most_common(10)
                )
        with col2:
            # Source distribution
            if sources_list:
                _render_count_table(
                    "**🏛️ Sources**", "Source", Counter(sources_list).most_common()
                )
        with col3:
            # Year distribution, newest first
            if years:
                year_counts = Counter(years)
                try:
                    ordered_years = sorted(year_counts.items(), reverse=True)
                except TypeError:
                    # Mixed int/str years are not mutually comparable; order
                    # them by their text form instead of failing.
                    ordered_years = sorted(
                        year_counts.items(), key=lambda item: str(item[0]), reverse=True
                    )
                _render_count_table("**📅 Years**", "Year", ordered_years)
        with col4:
            # File distribution (top 10), long names shortened for display
            file_rows = [
                (name[:30] + "..." if len(name) > 30 else name, count)
                for name, count in Counter(filenames).most_common(10)
            ]
            _render_count_table("**📄 Files**", "File", file_rows)
def _resolve_point_id(doc: Any, metadata: Dict[str, Any]) -> Any:
    """Return the first usable point id from ``doc.id``, ``point_id`` or ``_id``.

    Only ``None`` and the empty string are skipped, so an integer id of ``0``
    is kept (a plain ``or`` chain would discard it as falsy).
    """
    candidates = (
        getattr(doc, 'id', None),
        metadata.get('point_id'),
        metadata.get('_id'),
    )
    for candidate in candidates:
        if candidate is not None and candidate != '':
            return candidate
    return None


def _to_query_matrix(query_embedding: Any) -> np.ndarray:
    """Convert the query embedding to a NumPy array without a batch axis.

    Objects exposing ``.cpu()`` are converted via ``.cpu().float().numpy()``;
    a 3-D result has its leading axis squeezed.
    """
    if hasattr(query_embedding, 'cpu'):
        query_embedding = query_embedding.cpu().float().numpy()
    query_matrix = np.asarray(query_embedding)
    if query_matrix.ndim == 3:
        query_matrix = query_matrix.squeeze(0)  # Remove batch dimension
    return query_matrix


def _show_page_image(image_url: Any, page_number: Any, *, fit_container: bool = False) -> None:
    """Render the page image from an http(s) URL, or say that none is available.

    Args:
        image_url: Candidate URL; only strings starting with ``http`` are shown.
        page_number: Page number used in the image caption.
        fit_container: Stretch to the container width instead of fixed 700px.
    """
    if image_url and isinstance(image_url, str) and image_url.startswith('http'):
        try:
            if fit_container:
                st.image(image_url, use_container_width=True, caption=f"Page {page_number}")
            else:
                st.image(image_url, width=700, caption=f"Page {page_number}")
        except Exception as e:
            # Best effort: a broken image must not abort the whole result list.
            st.error(f"Failed to load image: {e}")
    else:
        st.info("No image URL available")


def display_visual_document_details(
    sources: List[Any],
    show_images: bool = False,
    show_saliency: bool = False,
    qdrant_client: Any = None,
    collection_name: Optional[str] = None,
    query_embedding: Optional[np.ndarray] = None,
    saliency_alpha: float = 0.4,
    saliency_colormap: str = 'hot',
    saliency_threshold: int = 50
) -> None:
    """
    Display detailed information for each visual search result.

    Each result gets its own expander (the first one is opened) with metadata
    and extracted text on the left and the page image on the right. When both
    ``show_images`` and ``show_saliency`` are set and every saliency
    requirement is met, the original image and the saliency map are shown side
    by side; otherwise only the original image is shown.

    Args:
        sources: List of VisualSearchResult objects
        show_images: Whether to display document images (from Cloudinary)
        show_saliency: Whether to generate and display saliency maps
        qdrant_client: Qdrant client (required for saliency)
        collection_name: Qdrant collection name (required for saliency)
        query_embedding: Query embedding for saliency computation
        saliency_alpha: Saliency overlay transparency (0.0-1.0)
        saliency_colormap: Matplotlib colormap for saliency (default: 'hot')
        saliency_threshold: Threshold percentile for saliency (default: 50)
    """
    st.markdown("### 📄 Document Details")

    # Import saliency functions only if needed
    if show_saliency:
        from .saliency import generate_tile_aware_saliency, can_generate_saliency

    # Converted lazily on first use and then reused for every document.
    query_matrix = None

    for i, doc in enumerate(sources):
        # ``metadata`` may be absent or explicitly None; fall back to {}.
        metadata = getattr(doc, 'metadata', None) or {}

        # Get basic metadata
        filename = str(metadata.get('filename') or 'Unknown')
        page_number = metadata.get('page_number', '?')
        year = metadata.get('year', 'Unknown')
        source = metadata.get('source', 'Unknown')
        district = metadata.get('district')
        score = getattr(doc, 'score', None)
        if score is None:
            # A None score would break the ``:.3f`` formatting below.
            score = 0.0

        # Get visual-specific metadata
        num_tiles = metadata.get('num_tiles')
        tile_rows = metadata.get('tile_rows')
        tile_cols = metadata.get('tile_cols')
        num_visual_tokens = metadata.get('num_visual_tokens')
        original_width = metadata.get('original_width')
        original_height = metadata.get('original_height')
        resized_width = metadata.get('resized_width')
        resized_height = metadata.get('resized_height')

        # Get image URLs
        original_url = metadata.get('original_url')
        resized_url = metadata.get('resized_url')
        page_url = metadata.get('page')  # Fallback

        # Get point_id for saliency (check doc.id first, then metadata)
        point_id = _resolve_point_id(doc, metadata)

        # Debug logging for saliency
        if show_saliency:
            logger.debug(f"Doc {i+1}: point_id={point_id}, has_tiles={num_tiles is not None}")

        # Build title; add the ellipsis only when the name was actually cut.
        short_name = f"{filename[:50]}..." if len(filename) > 50 else filename
        title = f"📄 Document {i+1}: {short_name} (Score: {score:.3f})"

        with st.expander(title, expanded=(i == 0)):  # Expand first result
            # Two-column layout: Metadata (left) and Image (right)
            col_meta, col_image = st.columns([1, 2])

            with col_meta:
                st.markdown("### 📋 Metadata")

                # Basic metadata
                st.write(f"📄 **File:** {filename}")
                st.write(f"🏛️ **Source:** {source}")
                st.write(f"📅 **Year:** {year}")
                st.write(f"📖 **Page:** {page_number}")
                if district and district != 'None':
                    st.write(f"📍 **District:** {district}")

                # Relevance score
                st.markdown("---")
                st.markdown("### 🎯 Relevance")
                score_color = "🟢" if score > 0.7 else "🟡" if score > 0.5 else "🔴"
                st.markdown(f"**Score:** {score_color} **{score:.3f}**")

                # Visual metadata (if available)
                if num_tiles or num_visual_tokens:
                    st.markdown("---")
                    st.markdown("### 🎨 Visual Metadata")
                    if num_tiles:
                        if tile_rows and tile_cols:
                            st.write(f"🔲 **Tiles:** {num_tiles} ({tile_rows}×{tile_cols})")
                        else:
                            # Avoid printing "None×None" when the grid is unknown.
                            st.write(f"🔲 **Tiles:** {num_tiles}")
                    if num_visual_tokens:
                        st.write(f"🔢 **Visual Tokens:** {num_visual_tokens}")
                    if original_width and original_height:
                        st.write(f"📐 **Original Size:** {original_width}×{original_height}")
                    if resized_width and resized_height:
                        st.write(f"📐 **Resized Size:** {resized_width}×{resized_height}")
                    processing_version = metadata.get('processing_version')
                    if processing_version:
                        st.write(f"⚙️ **Processing:** {processing_version}")

                # Text content preview (first 500 characters)
                content = getattr(doc, 'page_content', '')
                st.markdown("---")
                if content:
                    # NOTE(review): this expander sits inside the per-document
                    # expander; older Streamlit releases reject nested
                    # expanders - confirm the deployed version allows it.
                    with st.expander("📝 Extracted Text", expanded=True):
                        st.text_area(
                            "Content",
                            value=content[:500] + ("..." if len(content) > 500 else ""),
                            height=150,
                            disabled=True,
                            label_visibility="collapsed",
                            key=f"visual_doc_text_{i}"  # unique widget key per result
                        )
                else:
                    st.caption("_No text extracted (image-only page)_")

                # Show image URLs under text
                if original_url and resized_url:
                    with st.expander("🔗 Image URLs", expanded=True):
                        st.markdown(f"**Original:** [{original_url}]({original_url})")
                        st.markdown(f"**Resized (for embeddings):** [{resized_url}]({resized_url})")

            with col_image:
                st.markdown("### 📸 Document Page")

                # Prefer the original image, then the resized one, then the fallback
                image_url = original_url or resized_url or page_url

                if show_saliency and show_images:
                    # Every requirement must hold before saliency is attempted.
                    requirements = {
                        "qdrant_client": qdrant_client is not None,
                        "collection_name": collection_name is not None,
                        "query_embedding": query_embedding is not None,
                        "point_id": point_id is not None,
                        "tile_metadata": bool(can_generate_saliency(metadata)),
                    }
                    missing = [name for name, met in requirements.items() if not met]

                    if missing:
                        logger.warning(f"Doc {i+1}: Saliency unavailable, missing: {missing}")
                        # Can't generate saliency - just show original image
                        _show_page_image(image_url, page_number)
                        if "tile_metadata" in missing:
                            st.caption("_Saliency unavailable: missing tile metadata_")
                        elif "point_id" in missing:
                            st.caption("_Saliency unavailable: missing point_id_")
                    else:
                        # Two columns: Original image | Saliency map
                        img_col1, img_col2 = st.columns(2)

                        with img_col1:
                            st.markdown("**📄 Original**")
                            _show_page_image(image_url, page_number, fit_container=True)

                        with img_col2:
                            st.markdown("**🔥 Saliency Map**")
                            try:
                                with st.spinner("Generating..."):
                                    if query_matrix is None:
                                        query_matrix = _to_query_matrix(query_embedding)
                                    logger.info(f"🔥 Generating saliency for doc {i+1}: point_id={point_id}, colormap={saliency_colormap}")
                                    saliency_img = generate_tile_aware_saliency(
                                        qdrant_client=qdrant_client,
                                        collection_name=collection_name,
                                        point_id=point_id,
                                        query_embedding=query_matrix,
                                        alpha=saliency_alpha,
                                        colormap=saliency_colormap,
                                        threshold_percentile=saliency_threshold
                                    )
                                # ``is not None`` rather than truthiness: a NumPy
                                # array result would raise on ``if saliency_img``.
                                if saliency_img is not None:
                                    st.image(saliency_img, use_container_width=True, caption="Relevance heatmap")
                                    logger.info(f"✅ Saliency map displayed for doc {i+1}")
                                else:
                                    logger.warning(f"Saliency generation returned None for doc {i+1}")
                                    st.caption("_Could not generate saliency map_")
                            except Exception as e:
                                # UI boundary: report the failure and keep
                                # rendering the remaining documents.
                                logger.error(f"Saliency generation failed for doc {i+1}: {e}")
                                logger.debug(f"Saliency traceback for doc {i+1}", exc_info=True)
                                st.warning(f"⚠️ Failed: {str(e)[:80]}")

                elif show_images:
                    # Display original image only (no saliency requested)
                    _show_page_image(image_url, page_number)
                else:
                    st.info("Enable image display in settings to view document pages")
def display_visual_search_results(
    sources: List[Any],
    show_statistics: bool = True,
    show_images: bool = False,
    show_saliency: bool = False,
    qdrant_client: Any = None,
    collection_name: Optional[str] = None,
    query_embedding: Optional[np.ndarray] = None,
    saliency_alpha: float = 0.4,
    saliency_colormap: str = 'hot',
    saliency_threshold: int = 50,
    max_display: int = 20
) -> None:
    """
    Display visual search results with statistics and details.

    Renders a summary line, an optional statistics box computed over all
    results, and per-document details for at most ``max_display`` results.

    Args:
        sources: List of VisualSearchResult objects
        show_statistics: Whether to show statistics
        show_images: Whether to show document images
        show_saliency: Whether to generate and display saliency maps
        qdrant_client: Qdrant client (required for saliency)
        collection_name: Qdrant collection name (required for saliency)
        query_embedding: Query embedding for saliency computation
        saliency_alpha: Saliency overlay transparency (0.0-1.0)
        saliency_colormap: Matplotlib colormap for saliency (default: 'hot')
        saliency_threshold: Threshold percentile for saliency (default: 50)
        max_display: Maximum number of documents to display in detail;
            negative values are treated as 0
    """
    if not sources:
        st.info("No documents were retrieved for the last query.")
        return

    total = len(sources)

    # Count unique filenames; ``metadata`` may be absent or None on a result.
    unique_filenames = {
        (getattr(doc, 'metadata', None) or {}).get('filename', 'Unknown')
        for doc in sources
    }
    unique_total = len(unique_filenames)

    st.markdown(f"**Found {total} document chunks from {unique_total} unique documents:**")
    if unique_total < total:
        st.info(f"💡 **Note**: Each document is split into multiple chunks. You're seeing {total} chunks from {unique_total} documents.")

    # Show saliency info if enabled
    if show_saliency:
        st.info(f"🔥 **Saliency Maps Enabled**: Showing which image regions are most relevant to your query (using '{saliency_colormap}' colormap)")

    # Show statistics (over all results, not just the displayed ones)
    if show_statistics:
        display_visual_document_statistics(sources)
        st.markdown("---")

    # A negative limit would slice from the end of the list and produce
    # nonsensical "top -N" / "more results" counts, so clamp it at zero.
    limit = max(max_display, 0)
    truncated = total > limit

    # Show detailed results (limit to max_display)
    if truncated:
        st.warning(f"⚠️ Showing top {limit} of {total} results")

    display_visual_document_details(
        sources[:limit],
        show_images=show_images,
        show_saliency=show_saliency,
        qdrant_client=qdrant_client,
        collection_name=collection_name,
        query_embedding=query_embedding,
        saliency_alpha=saliency_alpha,
        saliency_colormap=saliency_colormap,
        saliency_threshold=saliency_threshold
    )

    if truncated:
        st.info(f"💡 {total - limit} more results not shown")