Spaces:
Sleeping
Sleeping
integrate saliency maps
Browse files
app.py
CHANGED
|
@@ -440,6 +440,53 @@ def main():
|
|
| 440 |
'districts': selected_districts if not filename_mode else [],
|
| 441 |
'filenames': selected_filenames
|
| 442 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 443 |
|
| 444 |
# Main content area with tabs
|
| 445 |
tab1, tab2 = st.tabs(["π¬ Chat", "π Retrieved Documents"])
|
|
@@ -646,11 +693,57 @@ def main():
|
|
| 646 |
# Use visual display for visual search results
|
| 647 |
if is_visual_search and st.session_state.chatbot_version == "visual":
|
| 648 |
st.markdown("### π¨ Visual Search Results")
|
| 649 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 650 |
display_visual_search_results(
|
| 651 |
sources=sources,
|
| 652 |
show_statistics=True,
|
| 653 |
show_images=True, # Show Cloudinary images
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 654 |
max_display=20
|
| 655 |
)
|
| 656 |
else:
|
|
@@ -958,7 +1051,7 @@ def main():
|
|
| 958 |
st.markdown("---")
|
| 959 |
st.markdown("#### π Retrieval History")
|
| 960 |
|
| 961 |
-
with st.expander(f"View {len(st.session_state.rag_retrieval_history)} retrieval entries", expanded=
|
| 962 |
for idx, entry in enumerate(st.session_state.rag_retrieval_history, 1):
|
| 963 |
st.markdown(f"### **Retrieval #{idx}**")
|
| 964 |
|
|
|
|
| 440 |
'districts': selected_districts if not filename_mode else [],
|
| 441 |
'filenames': selected_filenames
|
| 442 |
}
|
| 443 |
+
|
| 444 |
+
# Saliency settings (only for visual mode)
|
| 445 |
+
if st.session_state.chatbot_version == "visual":
|
| 446 |
+
with st.expander("π₯ Saliency Maps", expanded=False):
|
| 447 |
+
st.caption("Visualize which image regions are relevant to your query")
|
| 448 |
+
|
| 449 |
+
show_saliency = st.checkbox(
|
| 450 |
+
"Enable Saliency Maps",
|
| 451 |
+
value=st.session_state.get('show_saliency', False),
|
| 452 |
+
key="saliency_toggle",
|
| 453 |
+
help="Generate heatmaps showing which parts of each document are most relevant"
|
| 454 |
+
)
|
| 455 |
+
st.session_state.show_saliency = show_saliency
|
| 456 |
+
|
| 457 |
+
if show_saliency:
|
| 458 |
+
# Colormap selection (hot is default)
|
| 459 |
+
colormap_options = ["hot", "jet", "viridis", "plasma", "coolwarm", "RdYlGn"]
|
| 460 |
+
saliency_colormap = st.selectbox(
|
| 461 |
+
"Colormap",
|
| 462 |
+
options=colormap_options,
|
| 463 |
+
index=colormap_options.index(st.session_state.get('saliency_colormap', 'hot')),
|
| 464 |
+
key="saliency_colormap_select",
|
| 465 |
+
help="Color scheme for the heatmap. 'hot' recommended for visibility."
|
| 466 |
+
)
|
| 467 |
+
st.session_state.saliency_colormap = saliency_colormap
|
| 468 |
+
|
| 469 |
+
saliency_alpha = st.slider(
|
| 470 |
+
"Overlay Transparency",
|
| 471 |
+
min_value=0.1,
|
| 472 |
+
max_value=0.8,
|
| 473 |
+
value=st.session_state.get('saliency_alpha', 0.4),
|
| 474 |
+
step=0.1,
|
| 475 |
+
key="saliency_alpha_slider",
|
| 476 |
+
help="0.1 = subtle, 0.8 = intense"
|
| 477 |
+
)
|
| 478 |
+
st.session_state.saliency_alpha = saliency_alpha
|
| 479 |
+
|
| 480 |
+
saliency_threshold = st.slider(
|
| 481 |
+
"Threshold (%)",
|
| 482 |
+
min_value=0,
|
| 483 |
+
max_value=80,
|
| 484 |
+
value=st.session_state.get('saliency_threshold', 50),
|
| 485 |
+
step=10,
|
| 486 |
+
key="saliency_threshold_slider",
|
| 487 |
+
help="Hide patches below this percentile"
|
| 488 |
+
)
|
| 489 |
+
st.session_state.saliency_threshold = saliency_threshold
|
| 490 |
|
| 491 |
# Main content area with tabs
|
| 492 |
tab1, tab2 = st.tabs(["π¬ Chat", "π Retrieved Documents"])
|
|
|
|
| 693 |
# Use visual display for visual search results
|
| 694 |
if is_visual_search and st.session_state.chatbot_version == "visual":
|
| 695 |
st.markdown("### π¨ Visual Search Results")
|
| 696 |
+
|
| 697 |
+
# Get saliency settings from session state
|
| 698 |
+
show_saliency = st.session_state.get('show_saliency', False)
|
| 699 |
+
saliency_alpha = st.session_state.get('saliency_alpha', 0.4)
|
| 700 |
+
saliency_threshold = st.session_state.get('saliency_threshold', 50)
|
| 701 |
+
saliency_colormap = st.session_state.get('saliency_colormap', 'hot')
|
| 702 |
+
|
| 703 |
+
# Get Qdrant client and query embedding for saliency
|
| 704 |
+
qdrant_client = None
|
| 705 |
+
collection_name = None
|
| 706 |
+
query_embedding = None
|
| 707 |
+
|
| 708 |
+
if show_saliency:
|
| 709 |
+
try:
|
| 710 |
+
# Access the visual search adapter from the chatbot
|
| 711 |
+
chatbot = st.session_state.get('chatbot')
|
| 712 |
+
if chatbot and hasattr(chatbot, 'visual_search'):
|
| 713 |
+
visual_search = chatbot.visual_search
|
| 714 |
+
qdrant_client = visual_search.client
|
| 715 |
+
collection_name = visual_search.collection_name
|
| 716 |
+
query_embedding = visual_search.last_query_embedding
|
| 717 |
+
|
| 718 |
+
if query_embedding is None:
|
| 719 |
+
st.warning("β οΈ Query embedding not available for saliency")
|
| 720 |
+
show_saliency = False
|
| 721 |
+
else:
|
| 722 |
+
logger.info(f"β
Saliency enabled: colormap={saliency_colormap}, alpha={saliency_alpha}, threshold={saliency_threshold}")
|
| 723 |
+
except Exception as e:
|
| 724 |
+
logger.error(f"Failed to get saliency requirements: {e}")
|
| 725 |
+
st.warning(f"β οΈ Saliency unavailable: {str(e)[:50]}")
|
| 726 |
+
show_saliency = False
|
| 727 |
+
|
| 728 |
+
# Extract statistics for charts (same as v1)
|
| 729 |
+
stats = extract_chunk_statistics(sources)
|
| 730 |
+
|
| 731 |
+
# Show charts for visual RAG too (like v1)
|
| 732 |
+
if len(sources) >= 5:
|
| 733 |
+
display_chunk_statistics_charts(stats, "Retrieval Statistics")
|
| 734 |
+
st.markdown("---")
|
| 735 |
+
|
| 736 |
display_visual_search_results(
|
| 737 |
sources=sources,
|
| 738 |
show_statistics=True,
|
| 739 |
show_images=True, # Show Cloudinary images
|
| 740 |
+
show_saliency=show_saliency,
|
| 741 |
+
qdrant_client=qdrant_client,
|
| 742 |
+
collection_name=collection_name,
|
| 743 |
+
query_embedding=query_embedding,
|
| 744 |
+
saliency_alpha=saliency_alpha,
|
| 745 |
+
saliency_colormap=saliency_colormap, # Use selected colormap
|
| 746 |
+
saliency_threshold=saliency_threshold,
|
| 747 |
max_display=20
|
| 748 |
)
|
| 749 |
else:
|
|
|
|
| 1051 |
st.markdown("---")
|
| 1052 |
st.markdown("#### π Retrieval History")
|
| 1053 |
|
| 1054 |
+
with st.expander(f"View {len(st.session_state.rag_retrieval_history)} retrieval entries", expanded=False):
|
| 1055 |
for idx, entry in enumerate(st.session_state.rag_retrieval_history, 1):
|
| 1056 |
st.markdown(f"### **Retrieval #{idx}**")
|
| 1057 |
|