# components/chat.py
"""Streamlit chat interface for asking questions about uploaded documents."""

import traceback
from threading import Lock
from typing import Optional

import streamlit as st
from langchain_core.messages import HumanMessage, AIMessage

from utils.database import verify_vector_store
from utils.response_formatter import EnhancedResponseFormatter, display_enhanced_response

# Create a lock for QA system access. It lives at module level, so every
# caller that goes through this module serializes on the same lock.
qa_lock = Lock()


def display_chat_interface() -> None:
    """Display chat interface with proper formatting.

    Replays the stored conversation, then waits for a new prompt. A prompt is
    only processed once the QA system has been initialized; otherwise a
    warning is shown instead of failing on the missing ``qa_system``.
    """
    if not st.session_state.get('messages'):
        st.session_state.messages = []

    # Display chat messages
    _display_chat_history()

    # Chat input
    if prompt := st.chat_input("Ask about your documents..."):
        if _verify_chat_ready():
            _process_user_message(prompt, spinner_text="Analyzing documents...")


def _verify_chat_ready() -> bool:
    """Check if the chat system is properly initialized.

    Returns:
        True when ``st.session_state.qa_system`` is set to a non-None value;
        False (after showing a warning) otherwise.
    """
    if st.session_state.get('qa_system') is None:
        st.warning("Please upload documents first to initialize the chat system.")
        return False
    return True


def _extract_sources(response) -> list:
    """Return the ``sources`` entry of ``response.metadata``, or an empty list.

    Responses that have no ``metadata`` attribute, or whose ``metadata`` is
    None, yield an empty list instead of raising.
    """
    metadata = getattr(response, 'metadata', None)
    if metadata is None:
        return []
    return metadata.get('sources', [])


def _process_user_message(prompt: str, *, spinner_text: str = "Analyzing...") -> None:
    """Process a new user message and generate AI response.

    Appends the prompt to the session history, echoes it, invokes the QA
    system under ``qa_lock``, stores and renders the answer together with its
    sources, and finally requests a rerun.

    Args:
        prompt: Text the user typed into the chat input.
        spinner_text: Message shown while the QA system is working.
    """
    # Create and display user message
    st.session_state.messages.append(HumanMessage(content=prompt))
    with st.chat_message("user"):
        st.write(prompt)

    try:
        with st.chat_message("assistant"):
            with st.spinner(spinner_text):
                # Hold the lock only for the QA call itself, not for rendering.
                with qa_lock:
                    # NOTE(review): the history passed here already contains
                    # the prompt appended above -- confirm the chain expects
                    # the current turn inside chat_history.
                    response = st.session_state.qa_system.invoke({
                        "input": prompt,
                        "chat_history": st.session_state.messages,
                    })

            if not response:
                st.error("No response received. Please try again.")
                return

            # Extract sources from response metadata
            answer = str(response)
            sources = _extract_sources(response)

            # Create and display AI message with enhanced formatting
            st.session_state.messages.append(
                AIMessage(content=answer, additional_kwargs={'sources': sources})
            )
            display_enhanced_response(answer, sources=sources)
    except Exception as e:
        # UI boundary: report the failure in the page rather than aborting
        # the script run.
        st.error(f"An error occurred while processing your message: {str(e)}")
        if st.session_state.get('debug_mode'):
            st.error(traceback.format_exc())
        return

    # Kept outside the try block so the rerun request can never be mistaken
    # for a processing failure by the handler above.
    st.rerun()


def _display_chat_history() -> None:
    """Display all messages in the chat history.

    Human messages are written as plain text; AI messages go through
    ``display_enhanced_response`` with the sources stored in their
    ``additional_kwargs`` so replayed answers keep their source listing.
    """
    for message in st.session_state.messages:
        if isinstance(message, HumanMessage):
            with st.chat_message("user"):
                st.write(message.content)
        elif isinstance(message, AIMessage):
            with st.chat_message("assistant"):
                display_enhanced_response(
                    message.content,
                    message.additional_kwargs.get('sources'),
                )