# audit_assistant/src/ui_components/visual_documents.py
# Last change: commit 7f3ae81 "add saliency ui" (akryldigital)
"""
Visual Document Display Components
UI components for displaying visual search results with enhanced metadata.
Includes saliency map visualization for tile-aware ColPali embeddings.
"""
import streamlit as st
import pandas as pd
import numpy as np
import logging
from typing import List, Any, Dict, Optional
from collections import Counter
# Module-level logger; the display functions below report saliency progress,
# missing saliency prerequisites and failures through it.
logger = logging.getLogger(__name__)
def display_visual_document_statistics(sources: List[Any]) -> None:
    """
    Display statistics for visual search results in a bordered box with tables.

    Shows four headline metrics (total chunks, unique files, unique years,
    unique sources) followed by four frequency tables (top-10 districts, all
    sources, years newest-first, top-10 files). Renders nothing for an empty
    (or ``None``) input.

    Args:
        sources: List of VisualSearchResult objects. Only their ``metadata``
            mapping is read (``filename``, ``year``, ``source``, ``district``);
            a missing or ``None`` metadata attribute is treated as empty.
    """
    if not sources:
        return

    # Extract the per-result fields the panel aggregates.
    filenames: List[str] = []
    years: List[Any] = []
    sources_list: List[Any] = []
    districts: List[Any] = []
    for doc in sources:
        # `or {}` also covers a `metadata` attribute that exists but is None.
        metadata = getattr(doc, 'metadata', None) or {}
        # An explicit None filename would crash len()/slicing in the Files
        # table below, so any falsy filename falls back to 'Unknown'.
        filenames.append(str(metadata.get('filename') or 'Unknown'))
        year = metadata.get('year')
        if year:
            years.append(year)
        source = metadata.get('source')
        if source:
            sources_list.append(source)
        district = metadata.get('district')
        # The literal string 'None' is treated as a missing district too.
        if district and district != 'None':
            districts.append(district)

    unique_files = len(set(filenames))
    unique_years = len(set(years))
    unique_sources = len(set(sources_list))

    year_counts = Counter(years)
    try:
        year_rows = sorted(year_counts.items(), reverse=True)
    except TypeError:
        # Mixed year types (e.g. 2023 and "2022") are not mutually orderable;
        # compare their string forms instead of crashing the whole panel.
        year_rows = sorted(year_counts.items(), key=lambda item: str(item[0]), reverse=True)

    # A raw '<div>' opened in one st.markdown call and closed in another does
    # not wrap the elements rendered in between (each call is its own element),
    # so Streamlit's native bordered container is used for the box instead.
    with st.container(border=True):
        st.markdown("### πŸ“Š Retrieval Statistics")

        # Headline metrics.
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            st.metric("Total Chunks", len(sources))
        with col2:
            st.metric("Unique Files", unique_files)
        with col3:
            st.metric("Unique Years", unique_years if unique_years > 0 else "N/A")
        with col4:
            st.metric("Unique Sources", unique_sources if unique_sources > 0 else "N/A")

        st.markdown("---")

        # Distribution tables.
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            if districts:
                st.markdown("**🏘️ Districts**")
                district_df = pd.DataFrame([
                    {"District": dist, "Count": count}
                    for dist, count in Counter(districts).most_common(10)
                ])
                st.dataframe(district_df, hide_index=True, use_container_width=True)
        with col2:
            if sources_list:
                st.markdown("**πŸ›οΈ Sources**")
                source_df = pd.DataFrame([
                    {"Source": src, "Count": count}
                    for src, count in Counter(sources_list).most_common()
                ])
                st.dataframe(source_df, hide_index=True, use_container_width=True)
        with col3:
            if years:
                st.markdown("**πŸ“… Years**")
                year_df = pd.DataFrame([
                    {"Year": year, "Count": count}
                    for year, count in year_rows
                ])
                st.dataframe(year_df, hide_index=True, use_container_width=True)
        with col4:
            # Top 10 files, with long names truncated to 30 characters.
            st.markdown("**πŸ“„ Files**")
            file_df = pd.DataFrame([
                {"File": filename[:30] + "..." if len(filename) > 30 else filename, "Count": count}
                for filename, count in Counter(filenames).most_common(10)
            ])
            st.dataframe(file_df, hide_index=True, use_container_width=True)
def _is_http_url(value: Any) -> bool:
    """Return True if *value* is a string starting with ``'http'`` (https included)."""
    return isinstance(value, str) and value.startswith('http')


def _show_page_image(image_url: Any, page_number: Any, **image_kwargs: Any) -> None:
    """Render a document page image, or an info box when there is no usable URL.

    Extra keyword arguments (e.g. ``width`` or ``use_container_width``) are
    forwarded to ``st.image``. Load failures are shown inline rather than
    raised, so one broken image does not abort the remaining results.
    """
    if not _is_http_url(image_url):
        st.info("No image URL available")
        return
    try:
        st.image(image_url, caption=f"Page {page_number}", **image_kwargs)
    except Exception as e:  # best-effort display; report any loader error in the UI
        st.error(f"Failed to load image: {e}")


def _prepare_query_embedding(query_embedding: Any) -> Any:
    """Convert the query embedding into a numpy array for saliency computation.

    Objects exposing ``.cpu()`` (presumably torch tensors) are moved to CPU,
    cast with ``.float()`` and converted with ``.numpy()``. A 3-D result has
    its leading axis squeezed away (assumed to be a batch axis of size 1;
    numpy raises if that axis is larger).
    """
    query_emb = query_embedding
    if hasattr(query_emb, 'cpu'):
        query_emb = query_emb.cpu().float().numpy()
    if query_emb.ndim == 3:
        query_emb = query_emb.squeeze(0)  # Remove batch dimension
    return query_emb


def _resolve_point_id(doc: Any, metadata: Dict[str, Any]) -> Any:
    """Return the point id for *doc*: ``doc.id``, else ``metadata['point_id']``, else ``metadata['_id']``.

    Unlike an ``or`` chain this accepts an id of ``0`` (a valid unsigned
    integer point id); only ``None`` and ``''`` count as missing.
    """
    for candidate in (getattr(doc, 'id', None), metadata.get('point_id'), metadata.get('_id')):
        if candidate is not None and candidate != '':
            return candidate
    return None


def _render_visual_doc_metadata(
    doc: Any,
    metadata: Dict[str, Any],
    doc_index: int,
    filename: str,
    page_number: Any,
    score: float,
) -> None:
    """Render the left-hand metadata column (basic fields, score, visual info, text, URLs)."""
    source = metadata.get('source', 'Unknown')
    year = metadata.get('year', 'Unknown')
    district = metadata.get('district')

    st.markdown("### πŸ“‹ Metadata")
    st.write(f"πŸ“„ **File:** {filename}")
    st.write(f"πŸ›οΈ **Source:** {source}")
    st.write(f"πŸ“… **Year:** {year}")
    st.write(f"πŸ“– **Page:** {page_number}")
    if district and district != 'None':
        st.write(f"πŸ“ **District:** {district}")

    # Relevance score with a traffic-light marker (> 0.7, > 0.5, otherwise).
    st.markdown("---")
    st.markdown("### 🎯 Relevance")
    score_color = "🟒" if score > 0.7 else "🟑" if score > 0.5 else "πŸ”΄"
    st.markdown(f"**Score:** {score_color} **{score:.3f}**")

    # Tile / token / size details, shown only when tile or token info exists.
    num_tiles = metadata.get('num_tiles')
    num_visual_tokens = metadata.get('num_visual_tokens')
    if num_tiles or num_visual_tokens:
        st.markdown("---")
        st.markdown("### 🎨 Visual Metadata")
        if num_tiles:
            tile_rows = metadata.get('tile_rows')
            tile_cols = metadata.get('tile_cols')
            st.write(f"πŸ”² **Tiles:** {num_tiles} ({tile_rows}Γ—{tile_cols})")
        if num_visual_tokens:
            st.write(f"πŸ”’ **Visual Tokens:** {num_visual_tokens}")
        original_width = metadata.get('original_width')
        original_height = metadata.get('original_height')
        if original_width and original_height:
            st.write(f"πŸ“ **Original Size:** {original_width}Γ—{original_height}")
        resized_width = metadata.get('resized_width')
        resized_height = metadata.get('resized_height')
        if resized_width and resized_height:
            st.write(f"πŸ“ **Resized Size:** {resized_width}Γ—{resized_height}")
        processing_version = metadata.get('processing_version')
        if processing_version:
            st.write(f"βš™οΈ **Processing:** {processing_version}")

    # Text content preview (first 500 characters).
    content = getattr(doc, 'page_content', '')
    st.markdown("---")
    if content:
        # Streamlit does not allow an st.expander inside another expander, and
        # this column already lives inside the per-document expander, so the
        # text is shown under a bold label instead of a nested expander.
        st.markdown("**πŸ“ Extracted Text**")
        st.text_area(
            "Content",
            value=content[:500] + ("..." if len(content) > 500 else ""),
            height=150,
            disabled=True,
            label_visibility="collapsed",
            key=f"visual_doc_text_{doc_index}",
        )
    else:
        st.caption("_No text extracted (image-only page)_")

    # Image URLs under the text (again no nested expander, see above).
    original_url = metadata.get('original_url')
    resized_url = metadata.get('resized_url')
    if original_url and resized_url:
        st.markdown("**πŸ”— Image URLs**")
        st.markdown(f"**Original:** [{original_url}]({original_url})")
        st.markdown(f"**Resized (for embeddings):** [{resized_url}]({resized_url})")


def _render_visual_doc_saliency(
    doc_index: int,
    metadata: Dict[str, Any],
    point_id: Any,
    image_url: Any,
    page_number: Any,
    qdrant_client: Any,
    collection_name: Optional[str],
    query_embedding: Any,
    saliency_alpha: float,
    saliency_colormap: str,
    saliency_threshold: int,
) -> None:
    """Show the page image beside its saliency heatmap, or the image alone if saliency is impossible."""
    # Imported lazily so the saliency module is only loaded when saliency is requested.
    from .saliency import generate_tile_aware_saliency, can_generate_saliency

    doc_no = doc_index + 1
    requirements = (
        ("qdrant_client", qdrant_client is not None),
        ("collection_name", collection_name is not None),
        ("query_embedding", query_embedding is not None),
        ("point_id", point_id is not None),
        ("tile_metadata", bool(can_generate_saliency(metadata))),
    )
    missing = [name for name, satisfied in requirements if not satisfied]

    if missing:
        # Can't generate saliency - just show the original image.
        logger.warning("Doc %d: Saliency unavailable, missing: %s", doc_no, missing)
        _show_page_image(image_url, page_number, width=700)
        if "tile_metadata" in missing:
            st.caption("_Saliency unavailable: missing tile metadata_")
        elif "point_id" in missing:
            st.caption("_Saliency unavailable: missing point_id_")
        return

    # Two columns: original image | saliency map.
    img_col1, img_col2 = st.columns(2)
    with img_col1:
        st.markdown("**πŸ“„ Original**")
        _show_page_image(image_url, page_number, use_container_width=True)
    with img_col2:
        st.markdown("**πŸ”₯ Saliency Map**")
        try:
            with st.spinner("Generating..."):
                query_emb = _prepare_query_embedding(query_embedding)
                logger.info(
                    "πŸ”₯ Generating saliency for doc %d: point_id=%s, colormap=%s",
                    doc_no, point_id, saliency_colormap,
                )
                saliency_img = generate_tile_aware_saliency(
                    qdrant_client=qdrant_client,
                    collection_name=collection_name,
                    point_id=point_id,
                    query_embedding=query_emb,
                    alpha=saliency_alpha,
                    colormap=saliency_colormap,
                    threshold_percentile=saliency_threshold,
                )
            # `is not None` instead of truthiness: None is the failure signal,
            # and bool() on a numpy-array image (if ever returned) would raise.
            if saliency_img is not None:
                st.image(saliency_img, use_container_width=True, caption="Relevance heatmap")
                logger.info("βœ… Saliency map displayed for doc %d", doc_no)
            else:
                logger.warning("Saliency generation returned None for doc %d", doc_no)
                st.caption("_Could not generate saliency map_")
        except Exception as e:  # best-effort: a saliency failure must not hide the document
            logger.error("Saliency generation failed for doc %d: %s", doc_no, e)
            logger.debug("Saliency traceback for doc %d", doc_no, exc_info=True)
            st.warning(f"⚠️ Failed: {str(e)[:80]}")


def display_visual_document_details(
    sources: List[Any],
    show_images: bool = False,
    show_saliency: bool = False,
    qdrant_client: Any = None,
    collection_name: Optional[str] = None,
    query_embedding: Optional[np.ndarray] = None,
    saliency_alpha: float = 0.4,
    saliency_colormap: str = 'hot',
    saliency_threshold: int = 50
) -> None:
    """
    Display detailed information for each visual search result.

    Each result gets an expander (only the first is expanded) with a metadata
    column on the left and the page image - optionally next to a saliency
    heatmap - on the right.

    Args:
        sources: List of VisualSearchResult objects (``metadata``, ``score``,
            ``page_content`` and optionally ``id`` are read; a missing or
            ``None`` metadata/score is tolerated)
        show_images: Whether to display document images (from Cloudinary)
        show_saliency: Whether to generate and display saliency maps
            (only takes effect together with ``show_images``)
        qdrant_client: Qdrant client (required for saliency)
        collection_name: Qdrant collection name (required for saliency)
        query_embedding: Query embedding for saliency computation
        saliency_alpha: Saliency overlay transparency (0.0-1.0)
        saliency_colormap: Matplotlib colormap for saliency (default: 'hot')
        saliency_threshold: Threshold percentile for saliency (default: 50)
    """
    st.markdown("### πŸ“„ Document Details")

    for i, doc in enumerate(sources):
        metadata = getattr(doc, 'metadata', None) or {}
        filename = str(metadata.get('filename') or 'Unknown')
        page_number = metadata.get('page_number', '?')
        score = getattr(doc, 'score', None)
        if score is None:
            score = 0.0
        point_id = _resolve_point_id(doc, metadata)

        if show_saliency:
            logger.debug(
                "Doc %d: point_id=%s, has_tiles=%s",
                i + 1, point_id, metadata.get('num_tiles') is not None,
            )

        # Only add an ellipsis when the filename was actually truncated.
        short_name = filename if len(filename) <= 50 else f"{filename[:50]}..."
        title = f"πŸ“„ Document {i+1}: {short_name} (Score: {score:.3f})"

        with st.expander(title, expanded=(i == 0)):  # Expand first result
            col_meta, col_image = st.columns([1, 2])
            with col_meta:
                _render_visual_doc_metadata(doc, metadata, i, filename, page_number, score)
            with col_image:
                st.markdown("### πŸ“Έ Document Page")
                image_url = (
                    metadata.get('original_url')
                    or metadata.get('resized_url')
                    or metadata.get('page')  # Fallback
                )
                if not show_images:
                    st.info("Enable image display in settings to view document pages")
                elif show_saliency:
                    _render_visual_doc_saliency(
                        doc_index=i,
                        metadata=metadata,
                        point_id=point_id,
                        image_url=image_url,
                        page_number=page_number,
                        qdrant_client=qdrant_client,
                        collection_name=collection_name,
                        query_embedding=query_embedding,
                        saliency_alpha=saliency_alpha,
                        saliency_colormap=saliency_colormap,
                        saliency_threshold=saliency_threshold,
                    )
                else:
                    _show_page_image(image_url, page_number, width=700)
def display_visual_search_results(
    sources: List[Any],
    show_statistics: bool = True,
    show_images: bool = False,
    show_saliency: bool = False,
    qdrant_client: Any = None,
    collection_name: Optional[str] = None,
    query_embedding: Optional[np.ndarray] = None,
    saliency_alpha: float = 0.4,
    saliency_colormap: str = 'hot',
    saliency_threshold: int = 50,
    max_display: Optional[int] = 20
) -> None:
    """
    Display visual search results with statistics and details.

    Args:
        sources: List of VisualSearchResult objects
        show_statistics: Whether to show statistics
        show_images: Whether to show document images
        show_saliency: Whether to generate and display saliency maps
        qdrant_client: Qdrant client (required for saliency)
        collection_name: Qdrant collection name (required for saliency)
        query_embedding: Query embedding for saliency computation
        saliency_alpha: Saliency overlay transparency (0.0-1.0)
        saliency_colormap: Matplotlib colormap for saliency (default: 'hot')
        saliency_threshold: Threshold percentile for saliency (default: 50)
        max_display: Maximum number of documents to display in detail.
            ``None`` shows all of them; negative values are treated as 0.
    """
    if not sources:
        st.info("No documents were retrieved for the last query.")
        return

    total = len(sources)
    # `or {}` tolerates a None metadata attribute; falsy filenames count as 'Unknown'.
    unique_filenames = {
        (getattr(doc, 'metadata', None) or {}).get('filename') or 'Unknown'
        for doc in sources
    }
    file_count = len(unique_filenames)

    st.markdown(f"**Found {total} document chunks from {file_count} unique documents:**")
    if file_count < total:
        st.info(f"πŸ’‘ **Note**: Each document is split into multiple chunks. You're seeing {total} chunks from {file_count} documents.")

    if show_saliency:
        st.info(f"πŸ”₯ **Saliency Maps Enabled**: Showing which image regions are most relevant to your query (using '{saliency_colormap}' colormap)")

    if show_statistics:
        display_visual_document_statistics(sources)
        st.markdown("---")

    # A negative limit would otherwise slice from the end of the list
    # (showing "all but the last N") and print a nonsensical count.
    limit = total if max_display is None else max(0, max_display)
    display_sources = sources[:limit]
    hidden = total - len(display_sources)

    if hidden > 0:
        st.warning(f"⚠️ Showing top {limit} of {total} results")

    display_visual_document_details(
        display_sources,
        show_images=show_images,
        show_saliency=show_saliency,
        qdrant_client=qdrant_client,
        collection_name=collection_name,
        query_embedding=query_embedding,
        saliency_alpha=saliency_alpha,
        saliency_colormap=saliency_colormap,
        saliency_threshold=saliency_threshold
    )

    if hidden > 0:
        st.info(f"πŸ’‘ {hidden} more results not shown")