File size: 4,316 Bytes
9817125
d6b2b59
a8872f2
e9459fa
9e14c06
d6b2b59
9817125
 
067ac7d
80c5436
a8872f2
d6b2b59
 
502f9a1
a8872f2
4534bd6
 
d6b2b59
9817125
4534bd6
 
 
 
 
 
 
 
9817125
4534bd6
9817125
4534bd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9817125
 
 
 
 
 
 
 
80c5436
2260e72
 
 
 
80c5436
2260e72
 
 
 
 
80c5436
2260e72
 
 
 
 
 
 
80c5436
 
 
 
 
 
 
 
2260e72
80c5436
2260e72
80c5436
2260e72
80c5436
2260e72
80c5436
 
 
 
2260e72
11ec342
 
 
 
 
9817125
 
d6b2b59
 
9817125
 
d6b2b59
9817125
11ec342
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# components/chat.py

import streamlit as st
from langchain_core.messages import HumanMessage, AIMessage
from utils.database import verify_vector_store
from threading import Lock
from typing import Optional
import traceback
from utils.response_formatter import EnhancedResponseFormatter, display_enhanced_response


# Lock held around calls to ``st.session_state.qa_system.invoke`` (see
# _process_user_message) so that only one QA invocation runs at a time.
# NOTE(review): this is module-level state; presumably it is shared by every
# Streamlit session served by this process — confirm that global
# serialization of QA calls is intended.
qa_lock = Lock()

def display_chat_interface():
    """Render the chat history and handle a newly submitted question.

    Every stored message is redrawn: user turns as plain text, assistant turns
    through ``display_enhanced_response`` with the sources saved on the
    message. When the user submits a prompt, the QA system must already be
    initialized (otherwise a warning is shown and nothing is sent). The prompt
    is recorded as a ``HumanMessage``, and the QA system is invoked under
    ``qa_lock``. A truthy response is then stored as an ``AIMessage`` (with its
    sources) and rendered. Errors are reported with ``st.error``; the full
    traceback is shown as well when ``debug_mode`` is set in session state.
    """
    if not st.session_state.get('messages'):
        st.session_state.messages = []

    # Redraw the conversation so far.
    for message in st.session_state.messages:
        if isinstance(message, HumanMessage):
            with st.chat_message("user"):
                st.write(message.content)
        elif isinstance(message, AIMessage):
            with st.chat_message("assistant"):
                display_enhanced_response(message.content, message.additional_kwargs.get('sources'))

    prompt = st.chat_input("Ask about your documents...")
    if not prompt:
        return

    # Without this guard, accessing ``st.session_state.qa_system`` raises
    # before documents have been uploaded.
    if not _verify_chat_ready():
        return

    # Snapshot the prior turns before recording the new prompt. The chain then
    # receives the question once (as ``input``), not also in ``chat_history``.
    history = list(st.session_state.messages)
    st.session_state.messages.append(HumanMessage(content=prompt))
    with st.chat_message("user"):
        st.write(prompt)

    with st.chat_message("assistant"):
        try:
            with st.spinner("Analyzing documents..."):
                with qa_lock:
                    response = st.session_state.qa_system.invoke({
                        "input": prompt,
                        "chat_history": history,
                    })
        except Exception as e:
            # UI boundary: report the failure instead of crashing the page.
            st.error(f"An error occurred while processing your message: {str(e)}")
            if st.session_state.get('debug_mode'):
                st.error(traceback.format_exc())
            return

        if not response:
            st.error("No response received. Please try again.")
            return

        answer = str(response)
        # Not every response object is guaranteed to carry ``metadata``.
        sources = response.metadata.get('sources', []) if hasattr(response, 'metadata') else []
        st.session_state.messages.append(
            AIMessage(content=answer, additional_kwargs={'sources': sources})
        )
        display_enhanced_response(answer, sources)

def _verify_chat_ready() -> bool:
    """Tell whether a QA system is available in the session state.

    If the ``qa_system`` entry is absent or set to ``None``, a Streamlit
    warning asks the user to upload documents first.

    Returns:
        True when ``st.session_state.qa_system`` holds a non-None value,
        False otherwise.
    """
    qa_system = st.session_state.get('qa_system')
    if qa_system is not None:
        return True
    st.warning("Please upload documents first to initialize the chat system.")
    return False


def _process_user_message(prompt: str):
    """Record a user prompt, query the QA system, and display the answer.

    The prompt is appended to ``st.session_state.messages`` and rendered as a
    user bubble. The QA system is then invoked under ``qa_lock``, with the
    prompt as ``input`` and the *preceding* turns as ``chat_history``. A truthy
    response is stored as an ``AIMessage`` carrying its sources, rendered, and
    followed by ``st.rerun()``. A falsy response or any exception is reported
    via ``st.error``; the traceback is added when ``debug_mode`` is set.

    Args:
        prompt: The text the user submitted.
    """
    # Initialise the history so a first message cannot fail on a missing key.
    if not st.session_state.get('messages'):
        st.session_state.messages = []

    succeeded = False
    try:
        with st.spinner("Analyzing..."):
            # Snapshot prior turns before recording the new prompt. The chain
            # then receives the question once (as ``input``), not duplicated
            # at the end of ``chat_history``.
            history = list(st.session_state.messages)
            st.session_state.messages.append(HumanMessage(content=prompt))
            with st.chat_message("user"):
                st.write(prompt)

            with qa_lock:
                response = st.session_state.qa_system.invoke({
                    "input": prompt,
                    "chat_history": history,
                })

            if response:
                answer = str(response)
                # Not every response object is guaranteed to carry ``metadata``.
                sources = response.metadata.get('sources', []) if hasattr(response, 'metadata') else []

                st.session_state.messages.append(
                    AIMessage(content=answer, additional_kwargs={'sources': sources})
                )
                with st.chat_message("assistant"):
                    display_enhanced_response(answer, sources=sources)
                succeeded = True
            else:
                st.error("No response received. Please try again.")

    except Exception as e:
        st.error(f"An error occurred while processing your message: {str(e)}")
        if st.session_state.get('debug_mode'):
            st.error(traceback.format_exc())

    if succeeded:
        # Streamlit halts the current run here. Keeping the call outside the
        # broad ``except`` above ensures the halt is never treated as an error.
        st.rerun()

def _display_chat_history():
    """Render every message stored in ``st.session_state.messages``.

    User turns are written as plain text. Assistant turns are rendered with
    ``display_enhanced_response`` together with the sources stored in the
    message's ``additional_kwargs``, so redrawn answers keep their formatting
    and citations. If no history has been created yet, nothing is drawn.
    """
    for message in st.session_state.get('messages') or []:
        if isinstance(message, HumanMessage):
            with st.chat_message("user"):
                st.write(message.content)
        elif isinstance(message, AIMessage):
            with st.chat_message("assistant"):
                display_enhanced_response(
                    message.content,
                    message.additional_kwargs.get('sources'),
                )