# components/chat.py
import html
import re
from datetime import datetime

import streamlit as st
from langchain_core.messages import HumanMessage, AIMessage

from backend import get_embeddings_model, initialize_qa_system
from utils.persistence import PersistenceManager
def ensure_embeddings_initialized():
    """Make sure an embeddings model is cached in Streamlit session state.

    The model is created with ``get_embeddings_model()`` only on the first
    call; later calls reuse ``st.session_state.embeddings``.

    Returns:
        True when the model is (now) available in session state; False when
        creating it raised, after the error has been shown with ``st.error``.
    """
    # Already cached from an earlier run: nothing to do.
    if 'embeddings' in st.session_state:
        return True
    try:
        st.session_state.embeddings = get_embeddings_model()
    except Exception as e:
        st.error(f"Error initializing embeddings model: {e}")
        return False
    return True
# An 8-digit YYYYMMDD date, optionally followed (directly or after one "_")
# by a 6-digit HHMMSS time, and not embedded in a longer run of digits.
_SESSION_STAMP_RE = re.compile(r"(?<!\d)(\d{8})(?:_?(\d{6}))?(?!\d)")


def format_session_date(session_id: str) -> str:
    """Format a session ID into a human-readable date (and time, if present).

    Recognised timestamps are ``YYYYMMDD``, ``YYYYMMDDHHMMSS`` and
    ``YYYYMMDD_HHMMSS``, anywhere in the ID (e.g. ``session_20240101``,
    ``20240101_123045``, ``session_20240101_123045``).

    Args:
        session_id: Identifier of a saved chat session.

    Returns:
        ``"January 01, 2024"`` for a date-only stamp;
        ``"January 01, 2024 12:30 PM"`` when a time is present;
        the text after the first ``_`` (or the whole ID) when no timestamp is
        found; ``"Session: <id>"`` when the stamp is not a valid calendar
        date/time or ``session_id`` is not a string.
    """
    try:
        match = _SESSION_STAMP_RE.search(session_id)
        if match is None:
            # Unknown format: show the part after the first "_" (or the whole ID).
            return session_id.split('_')[1] if '_' in session_id else session_id
        date_digits, time_digits = match.groups()
        if time_digits:
            stamp = datetime.strptime(date_digits + time_digits, '%Y%m%d%H%M%S')
            return stamp.strftime('%B %d, %Y %I:%M %p')
        return datetime.strptime(date_digits, '%Y%m%d').strftime('%B %d, %Y')
    except (TypeError, ValueError):
        # TypeError: non-string ID; ValueError: digits are not a real date/time.
        return f"Session: {session_id}"
def clean_ai_response(content):
    """Strip message-repr artifacts from an AI response and return plain text.

    Handles three inputs:

    * a message object (anything with a ``.content`` attribute): its content
      is used directly instead of being parsed out of its repr;
    * a stringified message such as ``content='Hi' additional_kwargs={}``,
      in either the single-quoted or the double-quoted repr form that Python
      picks when the text contains an apostrophe;
    * ordinary text, which only has surrounding single quotes removed.

    Literal ``\\n`` / ``\\t`` escape sequences are turned into real newlines
    and tabs.

    Args:
        content: Response text or message object.

    Returns:
        The cleaned response text.
    """
    text = str(getattr(content, "content", content))
    quote = "'"
    if text.startswith('content="'):
        # repr() switches to double quotes when the text contains an apostrophe.
        quote = '"'
        text = text[len('content="'):]
    elif "content='" in text:
        # maxsplit=1: a later "content='" inside the answer must not truncate it.
        text = text.split("content='", 1)[1].replace("\\'", "'")
    if "additional_kwargs" in text:
        text = text.split("additional_kwargs", 1)[0]
    # Trim the space left before "additional_kwargs" first, so the closing
    # quote is at the end and actually gets stripped.
    text = text.strip().strip(quote)
    return text.replace('\\n', '\n').replace('\\t', '\t')
# Bullet marker: "-" or "*" followed by whitespace, or "•" with optional space.
# Requiring whitespace after "-"/"*" keeps "**bold**" lines, "---" rules and
# negative numbers from being treated as list items.
_BULLET_RE = re.compile(r"^(?:[-*]\s+|•\s*)")
# Markdown ATX heading: "# Title" through "###### Title".
_HEADING_RE = re.compile(r"^#{1,6}\s+(.+)$")
# Inline **bold** span (non-greedy so several spans on one line stay separate).
_BOLD_RE = re.compile(r"\*\*(.+?)\*\*")


def _inline_html(text):
    """Return *text* HTML-escaped, with inline ``**bold**`` spans as ``<strong>``."""
    return _BOLD_RE.sub(r"<strong>\1</strong>", html.escape(text))


def _is_bold_header(line):
    """Return True when the whole line is a single ``**...**`` span."""
    return (
        len(line) > 4
        and line.startswith("**")
        and line.endswith("**")
        and line.count("**") == 2
    )


def format_assistant_response(content):
    """Format the assistant's reply as a small HTML fragment.

    The content is first passed through ``clean_ai_response``. Each non-blank
    line is then classified. Header checks run before the bullet check,
    because a ``**Header**`` line also starts with ``*``.

    * A ``**Header**`` line or a Markdown ``#`` heading becomes ``<h4>``.
    * Consecutive bullet lines (``- ``, ``* ``, ``•``) become ``<li>`` items
      inside a single ``<ul>``.
    * Any other line becomes a ``<p>`` paragraph.

    All text is HTML-escaped before markup is added. The result is passed to
    ``st.markdown(..., unsafe_allow_html=True)`` by the chat view, and the
    reply text originates from the model, so it must not be able to inject
    markup.

    Args:
        content: Assistant message content (normally a string).

    Returns:
        The HTML fragment wrapped in a ``<div>``. If formatting raises, the
        error is shown with ``st.error`` and the escaped text is returned.
    """
    text = content
    try:
        text = clean_ai_response(content)
        parts = []
        in_list = False
        for raw_line in text.split("\n"):
            line = raw_line.strip()
            if not line:
                continue
            heading = _HEADING_RE.match(line)
            if heading or _is_bold_header(line):
                title = heading.group(1) if heading else line[2:-2]
                block = f"<h4>{_inline_html(title.strip())}</h4>"
            else:
                bullet = _BULLET_RE.match(line)
                if bullet:
                    if not in_list:
                        parts.append("<ul>")
                        in_list = True
                    parts.append(f"<li>{_inline_html(line[bullet.end():])}</li>")
                    continue
                block = f"<p>{_inline_html(line)}</p>"
            # A header or paragraph ends any open list.
            if in_list:
                parts.append("</ul>")
                in_list = False
            parts.append(block)
        if in_list:
            parts.append("</ul>")
        body = "".join(parts)
        return f"<div>{body}</div>"
    except Exception as e:
        st.error(f"Error formatting response: {str(e)}")
        return html.escape(str(text))
def _response_text(response):
    """Return the reply text from a QA-system ``invoke`` result.

    A dict result is read from its ``"answer"`` key. This assumes the chain
    follows LangChain's ``create_retrieval_chain`` output shape; confirm
    against ``initialize_qa_system``. Message objects contribute their
    ``.content``. Anything else falls back to ``str(response)``.
    """
    if isinstance(response, dict) and "answer" in response:
        answer = response["answer"]
        return "" if answer is None else str(answer)
    content = getattr(response, "content", None)
    return str(response) if content is None else str(content)


def _render_session_loader():
    """Show a picker for saved sessions and load the chosen one on request.

    Returns:
        True when a session was loaded into ``st.session_state`` (the caller
        should rerun the script); False otherwise.
    """
    persistence = st.session_state.persistence
    sessions = persistence.list_available_sessions()
    if not sessions:
        return False
    st.subheader("💾 Load Previous Chat")
    selected_session = st.selectbox(
        "Select a previous conversation:",
        options=[s['session_id'] for s in sessions],
        format_func=lambda x: f"Chat from {format_session_date(x)}",
    )
    if not (selected_session and st.button("Load Selected Chat")):
        return False
    with st.spinner("Loading previous chat..."):
        messages = persistence.load_chat_history(selected_session)
        vector_store = persistence.load_vector_store(selected_session)
        if not vector_store:
            # Without the documents there is nothing to answer from; report it
            # instead of silently leaving the picker unchanged.
            st.error("Could not load the documents for this chat.")
            return False
        # A session with documents but no messages yet is still loadable.
        st.session_state.messages = messages or []
        st.session_state.vector_store = vector_store
        st.session_state.qa_system = initialize_qa_system(vector_store)
        st.session_state.current_session_id = selected_session
    return True


def _render_history():
    """Render all stored messages. User text is HTML-escaped before display."""
    for message in st.session_state.messages:
        if isinstance(message, HumanMessage):
            st.markdown(
                f"\n🧑💼 You:\n{html.escape(str(message.content))}\n",
                unsafe_allow_html=True,
            )
        elif isinstance(message, AIMessage):
            st.markdown(
                f"\n🤖 Assistant:\n{format_assistant_response(message.content)}\n",
                unsafe_allow_html=True,
            )


def _answer_prompt(prompt):
    """Ask the QA system *prompt*, append the exchange to history and save it.

    Returns:
        True when an answer was recorded (the caller should rerun to show it).
        On failure, the messages added by this call are removed again, the
        problem is reported, and False is returned.
    """
    messages = st.session_state.messages
    rollback_len = len(messages)  # history length before this attempt
    try:
        messages.append(HumanMessage(content=prompt))
        # NOTE(review): earlier turns are not forwarded, so every question is
        # answered in isolation. Pass prior messages here if the chain built
        # by initialize_qa_system actually uses "chat_history".
        response = st.session_state.qa_system.invoke({
            "input": prompt,
            "chat_history": [],
        })
        # Store only the answer text, not the repr of the whole result dict.
        answer = _response_text(response) if response else ""
        if not answer.strip():
            st.error("No valid response received. Please try again.")
            del messages[rollback_len:]
            return False
        messages.append(AIMessage(content=answer))
        if 'current_session_id' in st.session_state:
            st.session_state.persistence.save_chat_history(
                messages,
                st.session_state.current_session_id,
                metadata={
                    'last_question': prompt,
                    'last_updated': datetime.now().isoformat(),
                },
            )
        return True
    except Exception as e:
        st.error(f"Error processing response: {str(e)}")
        st.exception(e)
        # Drop everything this attempt added (the question and any answer),
        # rather than blindly popping only the last message.
        del messages[rollback_len:]
        return False


def display_chat_interface():
    """Render the document chat page: session picker, history and input box.

    When no QA system is in session state, only the saved-session picker and
    a hint are shown. Otherwise the current session, the message history and
    the chat input are rendered, and a submitted question is answered.

    ``st.rerun()`` is deliberately called outside the ``try`` blocks. It
    works by raising a control-flow exception, which an ``except Exception``
    handler could otherwise intercept on Streamlit versions where that
    exception derives from ``Exception``.
    """
    if 'persistence' not in st.session_state:
        st.session_state.persistence = PersistenceManager()
    # Custom CSS for the chat layout.
    # NOTE(review): this style block is empty; add rules or remove the call.
    st.markdown("\n", unsafe_allow_html=True)
    needs_rerun = False
    try:
        if st.session_state.get('qa_system') is None:
            needs_rerun = _render_session_loader()
            if not needs_rerun:
                st.warning("Please upload documents or select a previous chat to begin.")
        else:
            if 'messages' not in st.session_state:
                st.session_state.messages = []
            if 'current_session_id' in st.session_state:
                session_date = format_session_date(st.session_state.current_session_id)
                st.markdown(
                    f"\n📅 Current Session: {html.escape(session_date)}\n",
                    unsafe_allow_html=True,
                )
            _render_history()
            if prompt := st.chat_input("Ask about your documents..."):
                with st.spinner("Analyzing..."):
                    if not prompt.strip():
                        st.warning("Please enter a valid question.")
                    else:
                        needs_rerun = _answer_prompt(prompt)
    except Exception as e:
        st.error(f"An unexpected error occurred: {e}")
        st.exception(e)
    if needs_rerun:
        st.rerun()