# cryogenic22's picture
# Update components/chat.py
# 4534bd6 verified
# components/chat.py
import streamlit as st
from langchain_core.messages import HumanMessage, AIMessage
from utils.database import verify_vector_store
from threading import Lock
from typing import Optional
import traceback
from utils.response_formatter import EnhancedResponseFormatter, display_enhanced_response
# Lock serializing access to the shared QA system: Streamlit may rerun the
# script concurrently, and the QA chain is not assumed to be thread-safe.
qa_lock = Lock()
def display_chat_interface():
    """Render the chat UI: replay history, accept a prompt, show the answer.

    Reads and appends to ``st.session_state.messages`` (a list of
    ``HumanMessage``/``AIMessage``) and invokes
    ``st.session_state.qa_system`` for each new prompt.
    """
    if not st.session_state.get('messages'):
        st.session_state.messages = []

    # Replay the stored conversation on every Streamlit rerun.
    for message in st.session_state.messages:
        if isinstance(message, HumanMessage):
            with st.chat_message("user"):
                st.write(message.content)
        elif isinstance(message, AIMessage):
            with st.chat_message("assistant"):
                display_enhanced_response(
                    message.content,
                    message.additional_kwargs.get('sources')
                )

    # Handle a newly submitted prompt.
    if prompt := st.chat_input("Ask about your documents..."):
        # Don't touch qa_system before it exists (shows a warning instead).
        if not _verify_chat_ready():
            return

        with st.chat_message("user"):
            st.write(prompt)
        st.session_state.messages.append(HumanMessage(content=prompt))

        with st.chat_message("assistant"):
            with st.spinner("Analyzing documents..."):
                try:
                    # Serialize QA access, consistent with _process_user_message.
                    with qa_lock:
                        response = st.session_state.qa_system.invoke({
                            "input": prompt,
                            "chat_history": st.session_state.messages
                        })
                except Exception as e:
                    st.error(f"An error occurred while processing your message: {str(e)}")
                    if st.session_state.get('debug_mode'):
                        st.error(traceback.format_exc())
                    return

            if response:
                # Guard the metadata access: not every response object exposes
                # .metadata (same guard _process_user_message already uses).
                sources = response.metadata.get('sources', []) if hasattr(response, 'metadata') else []
                ai_message = AIMessage(
                    content=str(response),
                    additional_kwargs={'sources': sources}
                )
                st.session_state.messages.append(ai_message)
                display_enhanced_response(
                    str(response),
                    sources
                )
            else:
                st.error("No response received. Please try again.")
def _verify_chat_ready() -> bool:
    """Return True when the QA system is initialized; otherwise warn and return False."""
    if st.session_state.get('qa_system') is None:
        st.warning("Please upload documents first to initialize the chat system.")
        return False
    return True
def _process_user_message(prompt: str):
    """Append *prompt* to the conversation, query the QA system, and render the reply."""
    try:
        with st.spinner("Analyzing..."):
            # Record and echo the user's turn.
            st.session_state.messages.append(HumanMessage(content=prompt))
            with st.chat_message("user"):
                st.write(prompt)

            # Only one rerun may talk to the QA system at a time.
            with qa_lock:
                response = st.session_state.qa_system.invoke({
                    "input": prompt,
                    "chat_history": st.session_state.messages
                })

            if not response:
                st.error("No response received. Please try again.")
                return

            # Pull source citations out of the response metadata when present.
            sources = []
            if hasattr(response, 'metadata'):
                sources = response.metadata.get('sources', [])

            # Store and render the assistant's turn with enhanced formatting.
            answer_text = str(response)
            st.session_state.messages.append(
                AIMessage(content=answer_text, additional_kwargs={'sources': sources})
            )
            with st.chat_message("assistant"):
                display_enhanced_response(answer_text, sources=sources)
            st.rerun()
    except Exception as exc:
        st.error(f"An error occurred while processing your message: {str(exc)}")
        if st.session_state.get('debug_mode'):
            st.error(traceback.format_exc())
def _display_chat_history():
    """Display all messages in the chat history.

    Human turns render as plain text; assistant turns go through
    display_enhanced_response so stored source citations reappear —
    consistent with how display_chat_interface replays history
    (plain st.write would silently drop the sources saved in
    additional_kwargs).
    """
    # .get() tolerates a session where messages was never initialized.
    for message in st.session_state.get('messages', []):
        if isinstance(message, HumanMessage):
            with st.chat_message("user"):
                st.write(message.content)
        elif isinstance(message, AIMessage):
            with st.chat_message("assistant"):
                display_enhanced_response(
                    message.content,
                    message.additional_kwargs.get('sources')
                )