akryldigital commited on
Commit
c0655b8
Β·
verified Β·
1 Parent(s): 7f3ae81

integrate saliency maps

Browse files
Files changed (1) hide show
  1. app.py +95 -2
app.py CHANGED
@@ -440,6 +440,53 @@ def main():
440
  'districts': selected_districts if not filename_mode else [],
441
  'filenames': selected_filenames
442
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
443
 
444
  # Main content area with tabs
445
  tab1, tab2 = st.tabs(["πŸ’¬ Chat", "πŸ“„ Retrieved Documents"])
@@ -646,11 +693,57 @@ def main():
646
  # Use visual display for visual search results
647
  if is_visual_search and st.session_state.chatbot_version == "visual":
648
  st.markdown("### 🎨 Visual Search Results")
649
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
650
  display_visual_search_results(
651
  sources=sources,
652
  show_statistics=True,
653
  show_images=True, # Show Cloudinary images
 
 
 
 
 
 
 
654
  max_display=20
655
  )
656
  else:
@@ -958,7 +1051,7 @@ def main():
958
  st.markdown("---")
959
  st.markdown("#### πŸ“Š Retrieval History")
960
 
961
- with st.expander(f"View {len(st.session_state.rag_retrieval_history)} retrieval entries", expanded=True):
962
  for idx, entry in enumerate(st.session_state.rag_retrieval_history, 1):
963
  st.markdown(f"### **Retrieval #{idx}**")
964
 
 
440
  'districts': selected_districts if not filename_mode else [],
441
  'filenames': selected_filenames
442
  }
443
+
444
+ # Saliency settings (only for visual mode)
445
+ if st.session_state.chatbot_version == "visual":
446
+ with st.expander("πŸ”₯ Saliency Maps", expanded=False):
447
+ st.caption("Visualize which image regions are relevant to your query")
448
+
449
+ show_saliency = st.checkbox(
450
+ "Enable Saliency Maps",
451
+ value=st.session_state.get('show_saliency', False),
452
+ key="saliency_toggle",
453
+ help="Generate heatmaps showing which parts of each document are most relevant"
454
+ )
455
+ st.session_state.show_saliency = show_saliency
456
+
457
+ if show_saliency:
458
+ # Colormap selection (hot is default)
459
+ colormap_options = ["hot", "jet", "viridis", "plasma", "coolwarm", "RdYlGn"]
460
+ saliency_colormap = st.selectbox(
461
+ "Colormap",
462
+ options=colormap_options,
463
+ index=colormap_options.index(st.session_state.get('saliency_colormap', 'hot')),
464
+ key="saliency_colormap_select",
465
+ help="Color scheme for the heatmap. 'hot' recommended for visibility."
466
+ )
467
+ st.session_state.saliency_colormap = saliency_colormap
468
+
469
+ saliency_alpha = st.slider(
470
+ "Overlay Transparency",
471
+ min_value=0.1,
472
+ max_value=0.8,
473
+ value=st.session_state.get('saliency_alpha', 0.4),
474
+ step=0.1,
475
+ key="saliency_alpha_slider",
476
+ help="0.1 = subtle, 0.8 = intense"
477
+ )
478
+ st.session_state.saliency_alpha = saliency_alpha
479
+
480
+ saliency_threshold = st.slider(
481
+ "Threshold (%)",
482
+ min_value=0,
483
+ max_value=80,
484
+ value=st.session_state.get('saliency_threshold', 50),
485
+ step=10,
486
+ key="saliency_threshold_slider",
487
+ help="Hide patches below this percentile"
488
+ )
489
+ st.session_state.saliency_threshold = saliency_threshold
490
 
491
  # Main content area with tabs
492
  tab1, tab2 = st.tabs(["πŸ’¬ Chat", "πŸ“„ Retrieved Documents"])
 
693
  # Use visual display for visual search results
694
  if is_visual_search and st.session_state.chatbot_version == "visual":
695
  st.markdown("### 🎨 Visual Search Results")
696
+
697
+ # Get saliency settings from session state
698
+ show_saliency = st.session_state.get('show_saliency', False)
699
+ saliency_alpha = st.session_state.get('saliency_alpha', 0.4)
700
+ saliency_threshold = st.session_state.get('saliency_threshold', 50)
701
+ saliency_colormap = st.session_state.get('saliency_colormap', 'hot')
702
+
703
+ # Get Qdrant client and query embedding for saliency
704
+ qdrant_client = None
705
+ collection_name = None
706
+ query_embedding = None
707
+
708
+ if show_saliency:
709
+ try:
710
+ # Access the visual search adapter from the chatbot
711
+ chatbot = st.session_state.get('chatbot')
712
+ if chatbot and hasattr(chatbot, 'visual_search'):
713
+ visual_search = chatbot.visual_search
714
+ qdrant_client = visual_search.client
715
+ collection_name = visual_search.collection_name
716
+ query_embedding = visual_search.last_query_embedding
717
+
718
+ if query_embedding is None:
719
+ st.warning("⚠️ Query embedding not available for saliency")
720
+ show_saliency = False
721
+ else:
722
+ logger.info(f"βœ… Saliency enabled: colormap={saliency_colormap}, alpha={saliency_alpha}, threshold={saliency_threshold}")
723
+ except Exception as e:
724
+ logger.error(f"Failed to get saliency requirements: {e}")
725
+ st.warning(f"⚠️ Saliency unavailable: {str(e)[:50]}")
726
+ show_saliency = False
727
+
728
+ # Extract statistics for charts (same as v1)
729
+ stats = extract_chunk_statistics(sources)
730
+
731
+ # Show charts for visual RAG too (like v1)
732
+ if len(sources) >= 5:
733
+ display_chunk_statistics_charts(stats, "Retrieval Statistics")
734
+ st.markdown("---")
735
+
736
  display_visual_search_results(
737
  sources=sources,
738
  show_statistics=True,
739
  show_images=True, # Show Cloudinary images
740
+ show_saliency=show_saliency,
741
+ qdrant_client=qdrant_client,
742
+ collection_name=collection_name,
743
+ query_embedding=query_embedding,
744
+ saliency_alpha=saliency_alpha,
745
+ saliency_colormap=saliency_colormap, # Use selected colormap
746
+ saliency_threshold=saliency_threshold,
747
  max_display=20
748
  )
749
  else:
 
1051
  st.markdown("---")
1052
  st.markdown("#### πŸ“Š Retrieval History")
1053
 
1054
+ with st.expander(f"View {len(st.session_state.rag_retrieval_history)} retrieval entries", expanded=False):
1055
  for idx, entry in enumerate(st.session_state.rag_retrieval_history, 1):
1056
  st.markdown(f"### **Retrieval #{idx}**")
1057