# Source: cryogenic22, commit 9e17cc9 ("Update components/chat.py", verified)
# components/chat.py
import html
from datetime import datetime

import streamlit as st
from langchain_core.messages import HumanMessage, AIMessage

from backend import get_embeddings_model, initialize_qa_system
from utils.persistence import PersistenceManager
def ensure_embeddings_initialized():
    """Make sure ``st.session_state.embeddings`` holds an embeddings model.

    The model is built with ``get_embeddings_model()`` only when the key is
    not already present in session state.

    Returns:
        True when the model is available (cached earlier or just created);
        False when building it raised, in which case the error is shown to
        the user with ``st.error``.
    """
    if 'embeddings' in st.session_state:
        return True
    try:
        st.session_state.embeddings = get_embeddings_model()
    except Exception as err:
        st.error(f"Error initializing embeddings model: {err}")
        return False
    return True
def format_session_date(session_id: str) -> str:
    """Turn a session ID into a human-readable date for the session picker.

    The ID is split on ``_`` and the first all-digit token that looks like a
    timestamp is formatted:

    * ``YYYYMMDDHHMMSS`` (one token)       -> ``"January 05, 2024 02:30 PM"``
    * ``YYYYMMDD`` followed by ``HHMMSS``  -> same, e.g. ``session_20240105_143000``
    * ``YYYYMMDD`` alone                   -> ``"January 05, 2024"``

    If no token looks like a timestamp, the second ``_``-separated part is
    returned (or the whole ID when it contains no ``_``).

    Args:
        session_id: Session identifier as stored by the persistence layer.

    Returns:
        The formatted date, or ``"Session: <session_id>"`` when a
        timestamp-looking token is not a valid date/time or ``session_id``
        is not a string.
    """
    date_time_format = '%B %d, %Y %I:%M %p'
    try:
        parts = session_id.split('_')
        for index, part in enumerate(parts):
            if not part.isdigit():
                continue
            if len(part) == 14:  # YYYYMMDDHHMMSS in a single token
                return datetime.strptime(part, '%Y%m%d%H%M%S').strftime(date_time_format)
            if len(part) == 8:  # YYYYMMDD, possibly followed by an HHMMSS token
                time_part = parts[index + 1] if index + 1 < len(parts) else ''
                if len(time_part) == 6 and time_part.isdigit():
                    stamp = datetime.strptime(part + time_part, '%Y%m%d%H%M%S')
                    return stamp.strftime(date_time_format)
                return datetime.strptime(part, '%Y%m%d').strftime('%B %d, %Y')
        # Unknown format: show the raw ID part instead of failing.
        return parts[1] if len(parts) > 1 else session_id
    except (AttributeError, TypeError, ValueError):
        # Not a string, or digits that do not form a real date/time.
        return f"Session: {session_id}"
def clean_ai_response(content):
    """Return the readable text of an AI reply, stripping message-repr artifacts.

    ``content`` may be plain text or the ``str()`` of a LangChain message,
    e.g. ``content='Hello' additional_kwargs={} ...``; in that case only the
    quoted body is kept. Literal backslash-n / backslash-t sequences are
    turned into real newlines / tabs.

    Args:
        content: Any object; it is converted with ``str()`` first.

    Returns:
        The cleaned text.
    """
    text = str(content)
    quote = "'"
    # repr() quotes the body with ' normally and with " when the body itself
    # contains a single quote, so look for both and take whichever comes first.
    hits = [(text.find(f"content={q}"), q) for q in ("'", '"')]
    hits = [(pos, q) for pos, q in hits if pos != -1]
    if hits:
        pos, quote = min(hits)
        text = text[pos + len("content=") + 1:]
    if "additional_kwargs" in text:
        # The repr puts a space between the closing quote and the next field;
        # drop it so the closing quote can actually be stripped below.
        text = text.split("additional_kwargs")[0].rstrip()
    text = text.strip(quote)
    return text.replace('\\n', '\n').replace('\\t', '\t')
def format_assistant_response(content):
    """Convert an assistant reply into the HTML fragment shown in the chat.

    The reply is cleaned with ``clean_ai_response`` and split into non-blank
    lines, each rendered as:

    * ``<h4 class="section-header">`` when the whole line is wrapped in ``**``;
    * an ``<li>`` inside a ``<ul>`` when it starts with ``•``, ``- `` or ``* ``
      (consecutive bullets share one list; blank lines do not end a list);
    * a ``<p>`` otherwise.

    Line text is HTML-escaped because it comes from the model / retrieved
    documents and the caller renders it with ``unsafe_allow_html=True``.

    Args:
        content: Raw assistant reply (any object; see ``clean_ai_response``).

    Returns:
        The fragment wrapped in ``<div class="response-content">``. If
        formatting fails, the error is shown with ``st.error`` and the text
        is returned unformatted.
    """
    try:
        content = clean_ai_response(content)
        pieces = []
        in_list = False
        for raw_line in content.split('\n'):
            line = raw_line.strip()
            if not line:
                continue
            # Headers are tested before bullets: '**Title**' also starts with '*'.
            is_header = len(line) > 4 and line.startswith('**') and line.endswith('**')
            # Require a space after '-'/'*' so rules ('---'), negatives ('-5')
            # and emphasis ('*word*') are not mistaken for list items.
            is_bullet = not is_header and (
                line.startswith('•') or line.startswith(('- ', '* '))
            )
            if is_bullet:
                if not in_list:
                    pieces.append('<ul>')
                    in_list = True
                pieces.append(f"<li>{html.escape(line.lstrip('•-* '))}</li>")
                continue
            if in_list:
                pieces.append('</ul>')
                in_list = False
            if is_header:
                header = html.escape(line.strip('*').strip())
                pieces.append(f'<h4 class="section-header">{header}</h4>')
            else:
                pieces.append(f'<p>{html.escape(line)}</p>')
        if in_list:
            pieces.append('</ul>')
        # Emitted on one line with no indentation: inside st.markdown, blank
        # lines can end an HTML block and 4-space indents can start a code block.
        body = ''.join(pieces)
        return f'<div class="response-content">{body}</div>'
    except Exception as e:
        st.error(f"Error formatting response: {str(e)}")
        return str(content)
_CHAT_CSS = """
<style>
/* Chat container */
.chat-container {
    max-width: 800px;
    margin: auto;
}
/* User message */
.user-message {
    background-color: #f0f2f6;
    padding: 1rem;
    border-radius: 10px;
    margin: 1rem 0;
    box-shadow: 0 1px 3px rgba(0,0,0,0.1);
}
/* Assistant message */
.assistant-message {
    background-color: #ffffff;
    border: 1px solid #e0e0e0;
    padding: 1.5rem;
    border-radius: 10px;
    margin: 1rem 0;
    box-shadow: 0 1px 3px rgba(0,0,0,0.1);
}
/* Response content */
.response-content {
    line-height: 1.6;
}
/* Section headers */
.section-header {
    color: #0f52ba;
    margin: 1.5rem 0 1rem 0;
    font-size: 1.1rem;
    font-weight: 600;
    border-bottom: 2px solid #e0e0e0;
    padding-bottom: 0.5rem;
}
/* Lists */
.response-content ul {
    margin: 1rem 0;
    padding-left: 1.5rem;
    list-style-type: none;
}
.response-content li {
    margin: 0.5rem 0;
    position: relative;
    padding-left: 1rem;
}
.response-content li:before {
    content: "•";
    position: absolute;
    left: -1rem;
    color: #0f52ba;
}
/* Paragraphs */
.response-content p {
    margin: 1rem 0;
    color: #2c3e50;
}
/* Source citations */
.source-citation {
    font-style: italic;
    color: #666;
    border-top: 1px solid #e0e0e0;
    margin-top: 1rem;
    padding-top: 0.5rem;
    font-size: 0.9rem;
}
/* Session selector */
.session-selector {
    padding: 1rem;
    background-color: #f8f9fa;
    border-radius: 10px;
    margin-bottom: 1rem;
    border: 1px solid #e0e0e0;
}
/* Current session info */
.session-info {
    margin-bottom: 1rem;
    padding: 0.5rem 1rem;
    background-color: #e3f2fd;
    border-radius: 5px;
    font-size: 0.9rem;
    color: #1565c0;
}
</style>
"""


def _render_chat_styles():
    """Inject the CSS used by the chat bubbles, session banner and citations."""
    st.markdown(_CHAT_CSS, unsafe_allow_html=True)


def _extract_response_text(response):
    """Return the answer text from a QA-system reply.

    NOTE(review): the ``{"input": ..., "chat_history": ...}`` payload suggests
    a LangChain retrieval chain, whose result is a dict holding the answer
    under ``"answer"``; ``"output"`` / ``"result"`` cover agent / RetrievalQA
    shapes. Confirm against backend.initialize_qa_system. Objects with a
    ``.content`` attribute (chat messages) contribute that; anything else
    falls back to ``str()``.
    """
    if isinstance(response, dict):
        for key in ("answer", "output", "result"):
            value = response.get(key)
            if value:
                return str(value)
        return str(response)
    content = getattr(response, "content", None)
    return str(content) if content is not None else str(response)


def _load_previous_session(persistence):
    """Show the saved-session picker; return True when a session was loaded.

    A session is adopted only when both its chat history and its vector store
    load, so a half-loaded session never leaves stale messages in state.
    """
    sessions = persistence.list_available_sessions()
    if not sessions:
        return False
    st.subheader("💾 Load Previous Chat")
    selected_session = st.selectbox(
        "Select a previous conversation:",
        options=[s['session_id'] for s in sessions],
        format_func=lambda x: f"Chat from {format_session_date(x)}",
    )
    if not (selected_session and st.button("Load Selected Chat")):
        return False
    with st.spinner("Loading previous chat..."):
        messages = persistence.load_chat_history(selected_session)
        if not messages:
            st.warning("That chat has no saved messages to load.")
            return False
        vector_store = persistence.load_vector_store(selected_session)
        if not vector_store:
            st.warning("The documents for that chat could not be loaded.")
            return False
        qa_system = initialize_qa_system(vector_store)
    st.session_state.messages = messages
    st.session_state.vector_store = vector_store
    st.session_state.qa_system = qa_system
    st.session_state.current_session_id = selected_session
    return True


def _render_message_history(messages):
    """Render each stored HumanMessage / AIMessage as a styled chat bubble."""
    for message in messages:
        if isinstance(message, HumanMessage):
            # Escape the user's text so typed '<' or '&' display literally.
            st.markdown(
                '<div class="user-message">🧑‍💼 <strong>You:</strong><br>'
                f'{html.escape(str(message.content))}</div>',
                unsafe_allow_html=True,
            )
        elif isinstance(message, AIMessage):
            st.markdown(
                '<div class="assistant-message">🤖 <strong>Assistant:</strong>'
                f'{format_assistant_response(message.content)}</div>',
                unsafe_allow_html=True,
            )


def _save_history(prompt):
    """Persist the conversation for the active session, if there is one.

    Returns False (after warning the user) when saving raised, else True.
    """
    if 'current_session_id' not in st.session_state:
        return True
    try:
        st.session_state.persistence.save_chat_history(
            st.session_state.messages,
            st.session_state.current_session_id,
            metadata={
                'last_question': prompt,
                'last_updated': datetime.now().isoformat(),
            },
        )
    except Exception as e:
        st.warning(f"The answer was generated but the chat could not be saved: {e}")
        return False
    return True


def _answer_prompt(prompt):
    """Ask the QA system about ``prompt`` and record the exchange.

    Returns True when the page should rerun to display the new messages.
    If the QA call fails or returns nothing, the whole turn is removed from
    the history again.
    """
    messages = st.session_state.messages
    turn_start = len(messages)
    try:
        messages.append(HumanMessage(content=prompt))
        # NOTE(review): prior turns are not forwarded (chat_history is always
        # empty), so each question is answered without conversation context.
        response = st.session_state.qa_system.invoke({
            "input": prompt,
            "chat_history": [],
        })
        answer = _extract_response_text(response) if response else ""
        if not answer:
            del messages[turn_start:]
            st.error("No valid response received. Please try again.")
            return False
        messages.append(AIMessage(content=answer))
    except Exception as e:
        # Roll back the question and any partial answer from this turn.
        del messages[turn_start:]
        st.error(f"Error processing response: {str(e)}")
        import traceback
        st.error(traceback.format_exc())
        return False
    # Only rerun when the save succeeded, so a save warning stays visible.
    return _save_history(prompt)


def _render_active_chat():
    """Render the session banner, history and chat input; return True to rerun."""
    if 'messages' not in st.session_state:
        st.session_state.messages = []
    if 'current_session_id' in st.session_state:
        session_date = format_session_date(st.session_state.current_session_id)
        st.markdown(
            f'<div class="session-info">📅 Current Session: {html.escape(session_date)}</div>',
            unsafe_allow_html=True,
        )
    _render_message_history(st.session_state.messages)
    prompt = st.chat_input("Ask about your documents...")
    if not prompt:
        return False
    with st.spinner("Analyzing..."):
        if not prompt.strip():
            st.warning("Please enter a valid question.")
            return False
        return _answer_prompt(prompt)


def display_chat_interface():
    """Render the document Q&A chat page.

    Without a QA system in session state, offers to reload a saved session
    and otherwise asks the user to upload documents. With one, shows the
    current session, the message history and a chat input whose questions
    are answered by ``st.session_state.qa_system`` and saved through the
    PersistenceManager.
    """
    if 'persistence' not in st.session_state:
        st.session_state.persistence = PersistenceManager()
    _render_chat_styles()
    should_rerun = False
    try:
        if 'qa_system' not in st.session_state or st.session_state.qa_system is None:
            should_rerun = _load_previous_session(st.session_state.persistence)
            if not should_rerun:
                st.warning("Please upload documents or select a previous chat to begin.")
        else:
            should_rerun = _render_active_chat()
    except Exception as e:
        st.error(f"An unexpected error occurred: {e}")
        import traceback
        st.error(traceback.format_exc())
        return
    if should_rerun:
        # st.rerun() halts the script via an internal control-flow exception;
        # keep it outside the broad `except Exception` handlers above so it can
        # never be reported as an error or trigger a rollback.
        st.rerun()