"""
Intelligent Audit Report Chatbot UI
"""
import os
import sys
import time
import json
import uuid
import logging
from pathlib import Path
import argparse
import streamlit as st
from langchain_core.messages import HumanMessage, AIMessage
from multi_agent_chatbot import get_multi_agent_chatbot
from smart_chatbot import get_chatbot as get_smart_chatbot
from src.reporting.feedback_schema import create_feedback_from_dict
# Configure logging
# Module-level logger shared by the feedback/Snowflake save paths below.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Page config
# Streamlit page-level settings: wide layout, robot icon, sidebar open by default.
st.set_page_config(
layout="wide",
page_icon="🤖",
initial_sidebar_state="expanded",
page_title="Intelligent Audit Report Chatbot"
)
# Custom CSS
# NOTE(review): this markdown block is empty — the custom CSS payload appears
# to have been lost; confirm against the original template.
st.markdown("""
""", unsafe_allow_html=True)
def get_system_type():
    """Return a human-readable label for the active chatbot backend.

    The backend is chosen via the CHATBOT_SYSTEM environment variable;
    any value other than 'smart' (including unset) means multi-agent.
    """
    if os.environ.get('CHATBOT_SYSTEM', 'multi-agent') == 'smart':
        return "Smart Chatbot System"
    return "Multi-Agent System"
def get_chatbot():
    """Construct and return the chatbot selected by CHATBOT_SYSTEM.

    'smart' yields the single-model smart chatbot; any other value
    (including unset) yields the default multi-agent chatbot.
    """
    use_smart = os.environ.get('CHATBOT_SYSTEM', 'multi-agent') == 'smart'
    return get_smart_chatbot() if use_smart else get_multi_agent_chatbot()
def serialize_messages(messages):
    """Convert LangChain message objects into JSON-friendly dicts.

    Each dict records the message class name and its stringified content.
    Items without a ``content`` attribute are silently skipped.
    """
    return [
        {"type": type(msg).__name__, "content": str(msg.content)}
        for msg in messages
        if hasattr(msg, 'content')
    ]
def serialize_documents(sources):
    """Serialize retrieved document objects into JSON-friendly dicts.

    Duplicate chunks (identical text content) are dropped so the feedback
    payload does not repeat the same passage twice.

    Args:
        sources: iterable of document-like objects exposing ``page_content``
            (or, as a fallback, ``content``) and an optional ``metadata``
            mapping.

    Returns:
        list[dict]: one dict per unique chunk, with the full metadata plus
        commonly-used fields flattened to the top level. ``score`` prefers
        the reranked score and falls back to the original retrieval score.
    """
    serialized = []
    seen_content = set()
    for doc in sources:
        content = getattr(doc, 'page_content', getattr(doc, 'content', ''))
        # Skip exact-duplicate chunk text seen earlier in this batch.
        if content in seen_content:
            continue
        seen_content.add(content)
        # Fetch metadata once; the original re-evaluated this getattr per field.
        metadata = getattr(doc, 'metadata', {})
        serialized.append({
            "content": content,
            "metadata": metadata,
            "score": metadata.get('reranked_score', metadata.get('original_score', 0.0)),
            "id": metadata.get('_id', 'unknown'),
            "source": metadata.get('source', 'unknown'),
            "year": metadata.get('year', 'unknown'),
            "district": metadata.get('district', 'unknown'),
            "page": metadata.get('page', 'unknown'),
            "chunk_id": metadata.get('chunk_id', 'unknown'),
            "page_label": metadata.get('page_label', 'unknown'),
            "original_score": metadata.get('original_score', 0.0),
            "reranked_score": metadata.get('reranked_score', None)
        })
    return serialized
@st.cache_data
def load_filter_options():
    """Load the selectable filter values (sources, years, districts, filenames).

    Cached by Streamlit so filter_options.json is read only once per session.
    Falls back to empty option lists when the file is missing — and, unlike
    before, also when it exists but contains invalid JSON (previously a
    corrupt file crashed the app instead of degrading gracefully).
    """
    try:
        with open("filter_options.json", "r") as f:
            return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        # Show which JSON files ARE present to help diagnose the missing/broken file.
        st.info([x for x in os.listdir() if x.endswith('.json')])
        st.error("filter_options.json not found. Please run the metadata analysis script.")
        return {"sources": [], "years": [], "districts": [], 'filenames': []}
def main():
"""Render the Streamlit chatbot UI: session-state setup, sidebar filters, the chat and retrieved-documents tabs, and the feedback dashboard."""
# Initialize session state
if 'messages' not in st.session_state:
st.session_state.messages = []
if 'conversation_id' not in st.session_state:
st.session_state.conversation_id = f"session_{uuid.uuid4().hex[:8]}"
if 'session_start_time' not in st.session_state:
st.session_state.session_start_time = time.time()
if 'active_filters' not in st.session_state:
st.session_state.active_filters = {'sources': [], 'years': [], 'districts': [], 'filenames': []}
# Track RAG retrieval history for feedback
if 'rag_retrieval_history' not in st.session_state:
st.session_state.rag_retrieval_history = []
# Initialize chatbot only once per app session (cached)
if 'chatbot' not in st.session_state:
with st.spinner("🔄 Loading AI models and connecting to database..."):
st.session_state.chatbot = get_chatbot()
st.success("✅ AI system ready!")
# Reset conversation history if needed (but keep chatbot cached)
if 'reset_conversation' in st.session_state and st.session_state.reset_conversation:
st.session_state.messages = []
st.session_state.conversation_id = f"session_{uuid.uuid4().hex[:8]}"
st.session_state.session_start_time = time.time()
st.session_state.rag_retrieval_history = []
st.session_state.feedback_submitted = False
st.session_state.reset_conversation = False
st.rerun()
# Header with system indicator
# NOTE(review): several st.markdown('...') calls below appear to have lost
# their HTML payloads (string literals split across lines) — confirm against
# the original template before editing this function.
col1, col2 = st.columns([3, 1])
with col1:
st.markdown('
🤖 Intelligent Audit Report Chatbot
', unsafe_allow_html=True)
with col2:
system_type = get_system_type()
if "Multi-Agent" in system_type:
st.success(f"🔧 {system_type}")
else:
st.info(f"🔧 {system_type}")
st.markdown('Ask questions about audit reports. Use the sidebar filters to narrow down your search!
', unsafe_allow_html=True)
# Session info
duration = int(time.time() - st.session_state.session_start_time)
duration_str = f"{duration // 60}m {duration % 60}s"
st.markdown(f'''
Session Info: Messages: {len(st.session_state.messages)} | Duration: {duration_str} | Status: Active | ID: {st.session_state.conversation_id}
''', unsafe_allow_html=True)
# Load filter options
filter_options = load_filter_options()
# Sidebar for filters
with st.sidebar:
st.markdown("### 🔍 Search Filters")
st.markdown("Select filters to narrow down your search. Leave empty to search all data.")
st.markdown('', unsafe_allow_html=True)
st.markdown('
📄 Specific Reports (Filename Filter)
', unsafe_allow_html=True)
st.markdown('
⚠️ Selecting specific reports will ignore all other filters
', unsafe_allow_html=True)
selected_filenames = st.multiselect(
"Select specific reports:",
options=filter_options.get('filenames', []),
default=st.session_state.active_filters.get('filenames', []),
key="filenames_filter",
help="Choose specific reports to search. When enabled, all other filters are ignored."
)
st.markdown('
', unsafe_allow_html=True)
# Determine if filename filter is active
# When any filename is selected, the source/year/district filters are disabled.
filename_mode = len(selected_filenames) > 0
# Sources filter
st.markdown('', unsafe_allow_html=True)
st.markdown('
📊 Sources
', unsafe_allow_html=True)
selected_sources = st.multiselect(
"Select sources:",
options=filter_options['sources'],
default=st.session_state.active_filters['sources'],
disabled = filename_mode,
key="sources_filter",
help="Choose which types of reports to search"
)
st.markdown('
', unsafe_allow_html=True)
# Years filter
st.markdown('', unsafe_allow_html=True)
st.markdown('
📅 Years
', unsafe_allow_html=True)
selected_years = st.multiselect(
"Select years:",
options=filter_options['years'],
default=st.session_state.active_filters['years'],
disabled = filename_mode,
key="years_filter",
help="Choose which years to search"
)
st.markdown('
', unsafe_allow_html=True)
# Districts filter
st.markdown('', unsafe_allow_html=True)
st.markdown('
🏘️ Districts
', unsafe_allow_html=True)
selected_districts = st.multiselect(
"Select districts:",
options=filter_options['districts'],
default=st.session_state.active_filters['districts'],
disabled = filename_mode,
key="districts_filter",
help="Choose which districts to search"
)
st.markdown('
', unsafe_allow_html=True)
# Update active filters
st.session_state.active_filters = {
'sources': selected_sources if not filename_mode else [],
'years': selected_years if not filename_mode else [],
'districts': selected_districts if not filename_mode else [],
'filenames': selected_filenames
}
# Clear filters button
if st.button("🗑️ Clear All Filters", key="clear_filters_button"):
st.session_state.active_filters = {'sources': [], 'years': [], 'districts': [], 'filenames': []}
st.rerun()
# Main content area with tabs
tab1, tab2 = st.tabs(["💬 Chat", "📄 Retrieved Documents"])
with tab1:
# Chat container
chat_container = st.container()
with chat_container:
# Display conversation history
for message in st.session_state.messages:
if isinstance(message, HumanMessage):
st.markdown(f'{message.content}
', unsafe_allow_html=True)
elif isinstance(message, AIMessage):
st.markdown(f'{message.content}
', unsafe_allow_html=True)
# Input area
st.markdown("
", unsafe_allow_html=True)
# Create two columns for input and button
col1, col2 = st.columns([4, 1])
with col1:
# Use a counter to force input clearing
# (bumping the counter changes the widget key, which resets the text box)
if 'input_counter' not in st.session_state:
st.session_state.input_counter = 0
user_input = st.text_input(
"Type your message here...",
placeholder="Ask about budget allocations, expenditures, or audit findings...",
key=f"user_input_{st.session_state.input_counter}",
label_visibility="collapsed"
)
with col2:
send_button = st.button("Send", key="send_button", use_container_width=True)
# Clear chat button
if st.button("🗑️ Clear Chat", key="clear_chat_button"):
st.session_state.reset_conversation = True
# Clear all conversation files
import os
conversations_dir = "conversations"
if os.path.exists(conversations_dir):
for file in os.listdir(conversations_dir):
if file.endswith('.json'):
os.remove(os.path.join(conversations_dir, file))
st.rerun()
# Handle user input
if send_button and user_input:
# Construct filter context string
filter_context_str = ""
if selected_filenames:
filter_context_str += "FILTER CONTEXT:\n"
filter_context_str += f"Filenames: {', '.join(selected_filenames)}\n"
filter_context_str += "USER QUERY:\n"
elif selected_sources or selected_years or selected_districts:
filter_context_str += "FILTER CONTEXT:\n"
if selected_sources:
filter_context_str += f"Sources: {', '.join(selected_sources)}\n"
if selected_years:
filter_context_str += f"Years: {', '.join(selected_years)}\n"
if selected_districts:
filter_context_str += f"Districts: {', '.join(selected_districts)}\n"
filter_context_str += "USER QUERY:\n"
# The filter context is prepended as plain text so the downstream agent can parse it.
full_query = filter_context_str + user_input
# Add user message to history
st.session_state.messages.append(HumanMessage(content=user_input))
# Get chatbot response
with st.spinner("🤔 Thinking..."):
try:
# Pass the full query with filter context
chat_result = st.session_state.chatbot.chat(full_query, st.session_state.conversation_id)
# Handle both old format (string) and new format (dict)
if isinstance(chat_result, dict):
response = chat_result['response']
rag_result = chat_result.get('rag_result')
st.session_state.last_rag_result = rag_result
# Track RAG retrieval for feedback
if rag_result:
sources = rag_result.get('sources', []) if isinstance(rag_result, dict) else (rag_result.sources if hasattr(rag_result, 'sources') else [])
# Get the actual RAG query
actual_rag_query = chat_result.get('actual_rag_query', '')
if actual_rag_query:
# Format it like the log message
timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
formatted_query = f"{timestamp} - INFO - 🔍 ACTUAL RAG QUERY: '{actual_rag_query}'"
else:
formatted_query = "No RAG query available"
retrieval_entry = {
"conversation_up_to": serialize_messages(st.session_state.messages),
"rag_query_expansion": formatted_query,
"docs_retrieved": serialize_documents(sources)
}
st.session_state.rag_retrieval_history.append(retrieval_entry)
else:
response = chat_result
st.session_state.last_rag_result = None
# Add bot response to history
st.session_state.messages.append(AIMessage(content=response))
except Exception as e:
error_msg = f"Sorry, I encountered an error: {str(e)}"
st.session_state.messages.append(AIMessage(content=error_msg))
# Clear input and rerun
st.session_state.input_counter += 1 # This will clear the input
st.rerun()
with tab2:
# Document retrieval panel
if hasattr(st.session_state, 'last_rag_result') and st.session_state.last_rag_result:
rag_result = st.session_state.last_rag_result
# Handle both PipelineResult object and dictionary formats
sources = None
if hasattr(rag_result, 'sources'):
# PipelineResult object format
sources = rag_result.sources
elif isinstance(rag_result, dict) and 'sources' in rag_result:
# Dictionary format from multi-agent system
sources = rag_result['sources']
if sources and len(sources) > 0:
# Count unique filenames
unique_filenames = set()
for doc in sources:
filename = getattr(doc, 'metadata', {}).get('filename', 'Unknown')
unique_filenames.add(filename)
st.markdown(f"**Found {len(sources)} document chunks from {len(unique_filenames)} unique documents (showing top 10):**")
if len(unique_filenames) < len(sources):
st.info(f"💡 **Note**: Each document is split into multiple chunks. You're seeing {len(sources)} chunks from {len(unique_filenames)} documents.")
for i, doc in enumerate(sources[:10]): # Show top 10
# Get relevance score and ID if available
metadata = getattr(doc, 'metadata', {})
score = metadata.get('reranked_score', metadata.get('original_score', None))
chunk_id = metadata.get('_id', 'Unknown')
score_text = f" (Score: {score:.3f}, ID: {chunk_id[:8]}...)" if score is not None else f" (ID: {chunk_id[:8]}...)"
with st.expander(f"📄 Document {i+1}: {getattr(doc, 'metadata', {}).get('filename', 'Unknown')[:50]}...{score_text}"):
# Display document metadata with emojis
metadata = getattr(doc, 'metadata', {})
col1, col2, col3, col4 = st.columns([2, 1.5, 1, 1])
with col1:
st.write(f"📄 **File:** {metadata.get('filename', 'Unknown')}")
with col2:
st.write(f"🏛️ **Source:** {metadata.get('source', 'Unknown')}")
with col3:
st.write(f"📅 **Year:** {metadata.get('year', 'Unknown')}")
with col4:
# Display page number and chunk ID
page = metadata.get('page_label', metadata.get('page', 'Unknown'))
chunk_id = metadata.get('_id', 'Unknown')
st.write(f"📖 **Page:** {page}")
st.write(f"🆔 **ID:** {chunk_id}")
# Display full content (no truncation)
content = getattr(doc, 'page_content', 'No content available')
st.write(f"**Full Content:**")
st.text_area("Full Content", value=content, height=300, disabled=True, label_visibility="collapsed", key=f"preview_{i}")
else:
st.info("No documents were retrieved for the last query.")
else:
st.info("No documents have been retrieved yet. Start a conversation to see retrieved documents here.")
# Feedback Dashboard Section
st.markdown("---")
st.markdown("### 💬 Feedback Dashboard")
# Check if there's any conversation to provide feedback on
has_conversation = len(st.session_state.messages) > 0
has_retrievals = len(st.session_state.rag_retrieval_history) > 0
if not has_conversation:
st.info("💡 Start a conversation to provide feedback!")
st.markdown("The feedback dashboard will be enabled once you begin chatting.")
else:
st.markdown("Help us improve by providing feedback on this conversation.")
# Initialize feedback state if not exists
if 'feedback_submitted' not in st.session_state:
st.session_state.feedback_submitted = False
# Feedback form
with st.form("feedback_form", clear_on_submit=False):
col1, col2 = st.columns([1, 1])
with col1:
feedback_score = st.slider(
"Rate this conversation (1-5)",
min_value=1,
max_value=5,
help="How satisfied are you with the conversation?"
)
with col2:
is_feedback_about_last_retrieval = st.checkbox(
"Feedback about last retrieval only",
value=True,
help="If checked, feedback applies to the most recent document retrieval"
)
open_ended_feedback = st.text_area(
"Your feedback (optional)",
placeholder="Tell us what went well or what could be improved...",
height=100
)
# Disable submit if no score selected
# NOTE(review): feedback_score comes from st.slider and is presumably
# never None, so this looks effectively always False — confirm intent.
submit_disabled = feedback_score is None
submitted = st.form_submit_button(
"📤 Submit Feedback",
use_container_width=True,
disabled=submit_disabled
)
if submitted and not st.session_state.feedback_submitted:
# Log the feedback data being submitted
print("=" * 80)
print("🔄 FEEDBACK SUBMISSION: Starting...")
print("=" * 80)
st.write("🔍 **Debug: Feedback Data Being Submitted:**")
# Create feedback data dictionary
feedback_dict = {
"open_ended_feedback": open_ended_feedback,
"score": feedback_score,
"is_feedback_about_last_retrieval": is_feedback_about_last_retrieval,
"retrieved_data": st.session_state.rag_retrieval_history.copy() if st.session_state.rag_retrieval_history else [],
"conversation_id": st.session_state.conversation_id,
"timestamp": time.time(),
"message_count": len(st.session_state.messages),
"has_retrievals": has_retrievals,
"retrieval_count": len(st.session_state.rag_retrieval_history)
}
print(f"📝 FEEDBACK SUBMISSION: Score={feedback_score}, Retrievals={len(st.session_state.rag_retrieval_history) if st.session_state.rag_retrieval_history else 0}")
# Create UserFeedback dataclass instance
feedback_obj = None # Initialize outside try block
try:
feedback_obj = create_feedback_from_dict(feedback_dict)
print(f"✅ FEEDBACK SUBMISSION: Feedback object created - ID={feedback_obj.feedback_id}")
st.write(f"✅ **Feedback Object Created**")
st.write(f"- Feedback ID: {feedback_obj.feedback_id}")
st.write(f"- Score: {feedback_obj.score}/5")
st.write(f"- Has Retrievals: {feedback_obj.has_retrievals}")
# Convert back to dict for JSON serialization
feedback_data = feedback_obj.to_dict()
except Exception as e:
print(f"❌ FEEDBACK SUBMISSION: Failed to create feedback object: {e}")
st.error(f"Failed to create feedback object: {e}")
feedback_data = feedback_dict
# Display the data being submitted
st.json(feedback_data)
# Save feedback to file
feedback_dir = Path("feedback")
feedback_dir.mkdir(exist_ok=True)
feedback_file = feedback_dir / f"feedback_{st.session_state.conversation_id}_{int(time.time())}.json"
try:
# Save to local file
print(f"💾 FEEDBACK SAVE: Saving to local file: {feedback_file}")
with open(feedback_file, 'w') as f:
json.dump(feedback_data, f, indent=2, default=str)
print(f"✅ FEEDBACK SAVE: Local file saved successfully")
st.success("✅ Thank you for your feedback! It has been saved locally.")
st.balloons()
# Save to Snowflake if enabled and credentials available
logger.info("🔄 FEEDBACK SAVE: Starting Snowflake save process...")
logger.info(f"📊 FEEDBACK SAVE: feedback_obj={'exists' if feedback_obj else 'None'}")
try:
import os
snowflake_enabled = os.getenv("SNOWFLAKE_ENABLED", "false").lower() == "true"
logger.info(f"🔍 SNOWFLAKE CHECK: enabled={snowflake_enabled}")
if snowflake_enabled:
if feedback_obj:
try:
from auditqa.reporting.snowflake_connector import save_to_snowflake
logger.info("📤 SNOWFLAKE UI: Attempting to save feedback to Snowflake...")
print("📤 SNOWFLAKE UI: Attempting to save feedback to Snowflake...") # Also print to terminal
if save_to_snowflake(feedback_obj):
logger.info("✅ SNOWFLAKE UI: Successfully saved to Snowflake")
print("✅ SNOWFLAKE UI: Successfully saved to Snowflake") # Also print to terminal
st.success("✅ Feedback also saved to Snowflake!")
else:
logger.warning("⚠️ SNOWFLAKE UI: Save failed")
print("⚠️ SNOWFLAKE UI: Save failed") # Also print to terminal
st.warning("⚠️ Snowflake save failed, but local save succeeded")
except Exception as e:
logger.error(f"❌ SNOWFLAKE UI ERROR: {e}")
print(f"❌ SNOWFLAKE UI ERROR: {e}") # Also print to terminal
import traceback
traceback.print_exc() # Print full traceback to terminal
st.warning(f"⚠️ Could not save to Snowflake: {e}")
else:
logger.warning("⚠️ SNOWFLAKE UI: Skipping (feedback object not created)")
print("⚠️ SNOWFLAKE UI: Skipping (feedback object not created)") # Also print to terminal
st.warning("⚠️ Skipping Snowflake save (feedback object not created)")
else:
logger.info("💡 SNOWFLAKE UI: Integration disabled")
print("💡 SNOWFLAKE UI: Integration disabled") # Also print to terminal
st.info("💡 Snowflake integration disabled (set SNOWFLAKE_ENABLED=true to enable)")
except NameError as e:
import traceback
traceback.print_exc()
logger.error(f"❌ NameError in Snowflake save: {e}")
print(f"❌ NameError in Snowflake save: {e}") # Also print to terminal
st.warning(f"⚠️ Snowflake save error: {e}")
except Exception as e:
logger.error(f"❌ Exception in Snowflake save: {type(e).__name__}: {e}")
print(f"❌ Exception in Snowflake save: {type(e).__name__}: {e}") # Also print to terminal
st.warning(f"⚠️ Snowflake save error: {e}")
# Mark feedback as submitted to prevent resubmission
st.session_state.feedback_submitted = True
print("=" * 80)
print(f"✅ FEEDBACK SUBMISSION: Completed successfully")
print("=" * 80)
# Log file location
st.info(f"📁 Feedback saved to: {feedback_file}")
except Exception as e:
print(f"❌ FEEDBACK SUBMISSION: Error saving feedback: {e}")
print(f"❌ FEEDBACK SUBMISSION: Error type: {type(e).__name__}")
import traceback
traceback.print_exc()
st.error(f"❌ Error saving feedback: {e}")
st.write(f"Debug error: {str(e)}")
elif st.session_state.feedback_submitted:
st.success("✅ Feedback already submitted for this conversation!")
if st.button("🔄 Submit New Feedback", key="new_feedback_button"):
st.session_state.feedback_submitted = False
st.rerun()
# Display retrieval history stats
if st.session_state.rag_retrieval_history:
st.markdown("---")
st.markdown("#### 📊 Retrieval History")
with st.expander(f"View {len(st.session_state.rag_retrieval_history)} retrieval entries", expanded=False):
for idx, entry in enumerate(st.session_state.rag_retrieval_history, 1):
st.markdown(f"**Retrieval #{idx}**")
# Display the actual RAG query
rag_query_expansion = entry.get("rag_query_expansion", "No query available")
st.code(rag_query_expansion, language="text")
# Display summary stats
st.json({
"conversation_length": len(entry.get("conversation_up_to", [])),
"documents_retrieved": len(entry.get("docs_retrieved", []))
})
st.markdown("---")
# Auto-scroll to bottom
# NOTE(review): this markdown block is empty — the auto-scroll script payload
# appears to have been lost; confirm against the original template.
st.markdown("""
""", unsafe_allow_html=True)
# Run the app when executed as a script.
if __name__ == "__main__":
main()