|
|
|
|
|
|
|
|
import streamlit as st |
|
|
from langchain_core.messages import HumanMessage, AIMessage |
|
|
from utils.database import verify_vector_store |
|
|
from threading import Lock |
|
|
from typing import Optional |
|
|
import traceback |
|
|
from utils.response_formatter import EnhancedResponseFormatter, display_enhanced_response |
|
|
|
|
|
|
|
|
|
|
|
# Lock held around ``qa_system.invoke`` calls so that only one QA query runs
# at a time within this process. Being module-level, it is presumably shared
# by every Streamlit session served by this process — NOTE(review): confirm
# that serialising queries across sessions is intended.
qa_lock = Lock()
|
|
|
|
|
def display_chat_interface():
    """Render the stored conversation and handle one new chat turn.

    Every message in ``st.session_state.messages`` is replayed: user turns as
    plain text, assistant turns through ``display_enhanced_response`` together
    with the sources saved in their ``additional_kwargs``.

    If the user submits a new prompt, it is shown and stored. The QA system is
    then queried under ``qa_lock`` with the prompt as ``"input"`` and the
    conversation that preceded it as ``"chat_history"``. A non-empty answer is
    stored as an ``AIMessage`` (with its sources) and rendered.

    A missing QA system produces the "upload documents first" warning. QA
    failures and empty responses are reported with ``st.error``, and the full
    traceback is added when ``st.session_state.debug_mode`` is truthy.
    """
    if not st.session_state.get('messages'):
        st.session_state.messages = []

    for message in st.session_state.messages:
        if isinstance(message, HumanMessage):
            with st.chat_message("user"):
                st.write(message.content)
        elif isinstance(message, AIMessage):
            with st.chat_message("assistant"):
                display_enhanced_response(
                    message.content, message.additional_kwargs.get('sources')
                )

    if prompt := st.chat_input("Ask about your documents..."):
        # Without an initialised QA system the invoke below would raise
        # AttributeError; show the user-facing warning instead.
        if not _verify_chat_ready():
            return

        # Snapshot the conversation *before* this prompt. The prompt is sent
        # as "input", so it must not also be the last entry of "chat_history".
        prior_history = list(st.session_state.messages)

        with st.chat_message("user"):
            st.write(prompt)
        st.session_state.messages.append(HumanMessage(content=prompt))

        with st.chat_message("assistant"):
            try:
                with st.spinner("Analyzing documents..."):
                    with qa_lock:
                        response = st.session_state.qa_system.invoke({
                            "input": prompt,
                            "chat_history": prior_history,
                        })
            except Exception as e:
                st.error(f"An error occurred while processing your message: {str(e)}")
                if st.session_state.get('debug_mode'):
                    st.error(traceback.format_exc())
                return

            if not response:
                st.error("No response received. Please try again.")
                return

            # Not every response object exposes ``metadata``; fall back to no
            # sources instead of raising AttributeError.
            metadata = getattr(response, 'metadata', None) or {}
            sources = metadata.get('sources', [])
            answer = str(response)

            st.session_state.messages.append(
                AIMessage(content=answer, additional_kwargs={'sources': sources})
            )
            display_enhanced_response(answer, sources)
|
|
|
|
|
def _verify_chat_ready() -> bool:
    """Report whether a QA system is available in the session.

    Returns:
        True when ``st.session_state`` holds a non-None ``qa_system``.
        Otherwise a warning asking the user to upload documents is shown and
        False is returned.
    """
    # A missing key and an explicit None are treated the same way.
    if st.session_state.get('qa_system') is not None:
        return True

    st.warning("Please upload documents first to initialize the chat system.")
    return False
|
|
|
|
|
|
|
|
def _process_user_message(prompt: str):
    """Record a user prompt, query the QA system and show its answer.

    The prompt is appended to ``st.session_state.messages`` and rendered as a
    user chat message. The QA system is then invoked under ``qa_lock`` with
    the prompt as ``"input"`` and the conversation that preceded it as
    ``"chat_history"``.

    On success the answer, plus any ``sources`` found in ``response.metadata``,
    is stored as an ``AIMessage`` and rendered, and the script is rerun. A
    falsy response or any exception is reported with ``st.error``; the
    traceback is also shown when ``st.session_state.debug_mode`` is truthy.

    Args:
        prompt: Text the user just submitted.
    """
    try:
        # Snapshot before appending: the prompt is passed as "input", so it
        # must not be duplicated at the end of "chat_history".
        prior_history = list(st.session_state.messages)
        st.session_state.messages.append(HumanMessage(content=prompt))
        with st.chat_message("user"):
            st.write(prompt)

        # Only the QA call is slow; keep the spinner scoped to it.
        with st.spinner("Analyzing..."):
            with qa_lock:
                response = st.session_state.qa_system.invoke({
                    "input": prompt,
                    "chat_history": prior_history,
                })

        if not response:
            st.error("No response received. Please try again.")
            return

        sources = response.metadata.get('sources', []) if hasattr(response, 'metadata') else []
        answer = str(response)

        st.session_state.messages.append(
            AIMessage(content=answer, additional_kwargs={'sources': sources})
        )
        with st.chat_message("assistant"):
            display_enhanced_response(answer, sources=sources)

    except Exception as e:
        st.error(f"An error occurred while processing your message: {str(e)}")
        if st.session_state.get('debug_mode'):
            st.error(traceback.format_exc())
        return

    # st.rerun() works by raising a Streamlit control-flow exception. Calling
    # it outside the try guarantees the broad ``except`` above can never
    # intercept that signal.
    st.rerun()
|
|
|
|
|
def _display_chat_history():
    """Render every message stored in ``st.session_state.messages``.

    User turns are written as plain text. Assistant turns are rendered with
    ``display_enhanced_response`` together with the sources stored in the
    message's ``additional_kwargs``, so formatting and sources survive a
    rerun. Renders nothing if no message list has been created yet.
    """
    for message in st.session_state.get('messages', []):
        if isinstance(message, HumanMessage):
            with st.chat_message("user"):
                st.write(message.content)
        elif isinstance(message, AIMessage):
            with st.chat_message("assistant"):
                display_enhanced_response(
                    message.content, message.additional_kwargs.get('sources')
                )