File size: 4,316 Bytes
9817125 d6b2b59 a8872f2 e9459fa 9e14c06 d6b2b59 9817125 067ac7d 80c5436 a8872f2 d6b2b59 502f9a1 a8872f2 4534bd6 d6b2b59 9817125 4534bd6 9817125 4534bd6 9817125 4534bd6 9817125 80c5436 2260e72 80c5436 2260e72 80c5436 2260e72 80c5436 2260e72 80c5436 2260e72 80c5436 2260e72 80c5436 2260e72 80c5436 2260e72 11ec342 9817125 d6b2b59 9817125 d6b2b59 9817125 11ec342 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
# components/chat.py
import streamlit as st
from langchain_core.messages import HumanMessage, AIMessage
from utils.database import verify_vector_store
from threading import Lock
from typing import Optional
import traceback
from utils.response_formatter import EnhancedResponseFormatter, display_enhanced_response
# Create a lock for QA system access
qa_lock = Lock()
def display_chat_interface():
    """Render the chat UI: replay history, accept a prompt, show the answer.

    Reads and mutates ``st.session_state.messages`` (a list of
    HumanMessage/AIMessage) and calls ``st.session_state.qa_system.invoke``
    to generate responses. Sources are carried in each AIMessage's
    ``additional_kwargs['sources']``.
    """
    if not st.session_state.get('messages'):
        st.session_state.messages = []

    # Replay the stored conversation so it survives Streamlit reruns.
    for message in st.session_state.messages:
        if isinstance(message, HumanMessage):
            with st.chat_message("user"):
                st.write(message.content)
        elif isinstance(message, AIMessage):
            with st.chat_message("assistant"):
                display_enhanced_response(
                    message.content,
                    message.additional_kwargs.get('sources')
                )

    # Handle a newly submitted prompt.
    if prompt := st.chat_input("Ask about your documents..."):
        # Bail out early if no QA system has been initialized yet —
        # previously this path would raise on a missing/None qa_system.
        if not _verify_chat_ready():
            return

        with st.chat_message("user"):
            st.write(prompt)
        st.session_state.messages.append(HumanMessage(content=prompt))

        with st.chat_message("assistant"):
            with st.spinner("Analyzing documents..."):
                try:
                    # Serialize access to the QA system, matching
                    # _process_user_message. Exclude the just-appended
                    # prompt from chat_history: it is already supplied
                    # separately as "input".
                    with qa_lock:
                        response = st.session_state.qa_system.invoke({
                            "input": prompt,
                            "chat_history": st.session_state.messages[:-1]
                        })
                except Exception as e:
                    st.error(f"An error occurred while processing your message: {str(e)}")
                    if st.session_state.get('debug_mode'):
                        st.error(traceback.format_exc())
                    return

                if response:
                    # Guard the metadata access: response objects may not
                    # expose .metadata (same hasattr check as
                    # _process_user_message).
                    sources = (response.metadata.get('sources', [])
                               if hasattr(response, 'metadata') else [])
                    ai_message = AIMessage(
                        content=str(response),
                        additional_kwargs={'sources': sources}
                    )
                    st.session_state.messages.append(ai_message)
                    display_enhanced_response(str(response), sources)
                else:
                    st.error("No response received. Please try again.")
def _verify_chat_ready() -> bool:
    """Return True when a usable QA system exists in session state.

    Shows a warning prompting the user to upload documents otherwise.
    """
    # .get() covers both the missing-key and explicit-None cases.
    if st.session_state.get('qa_system') is None:
        st.warning("Please upload documents first to initialize the chat system.")
        return False
    return True
def _process_user_message(prompt: str):
    """Process a new user message and generate the AI response.

    Appends the user's HumanMessage and the resulting AIMessage to
    ``st.session_state.messages``, renders both, and triggers a rerun
    on success. Errors are surfaced via ``st.error`` (with a traceback
    when ``debug_mode`` is set) rather than raised.

    Args:
        prompt: Raw text the user entered in the chat input.
    """
    try:
        with st.spinner("Analyzing..."):
            # Record and display the user's message.
            human_message = HumanMessage(content=prompt)
            st.session_state.messages.append(human_message)
            with st.chat_message("user"):
                st.write(prompt)

            # Generate the AI response under the lock so concurrent
            # sessions don't interleave calls into the QA system.
            # chat_history excludes the message just appended — the
            # current prompt is passed separately as "input", and
            # including it twice duplicates it in the model context.
            with qa_lock:
                response = st.session_state.qa_system.invoke({
                    "input": prompt,
                    "chat_history": st.session_state.messages[:-1]
                })

            if response:
                # Not every response object exposes .metadata; fall back
                # to an empty source list.
                sources = response.metadata.get('sources', []) if hasattr(response, 'metadata') else []

                # Persist and display the AI message with its sources.
                ai_message = AIMessage(
                    content=str(response),
                    additional_kwargs={'sources': sources}
                )
                st.session_state.messages.append(ai_message)

                with st.chat_message("assistant"):
                    display_enhanced_response(
                        str(response),
                        sources=sources
                    )
                # Rerun so the refreshed history renders from the top.
                st.rerun()
            else:
                st.error("No response received. Please try again.")

    except Exception as e:
        st.error(f"An error occurred while processing your message: {str(e)}")
        if st.session_state.get('debug_mode'):
            st.error(traceback.format_exc())
def _display_chat_history():
    """Replay every stored message in its appropriate chat bubble."""
    for msg in st.session_state.messages:
        # Map the message type to a Streamlit chat role; anything else
        # (e.g. system messages) is skipped, as before.
        if isinstance(msg, HumanMessage):
            role = "user"
        elif isinstance(msg, AIMessage):
            role = "assistant"
        else:
            continue
        with st.chat_message(role):
            st.write(msg.content)