# Legal_AI_Agent / components/chat_interface.py
# cryogenic22 — commit c5654cc: "Update components/chat_interface.py"
import html
import os
from datetime import datetime
from typing import Dict, List

import anthropic
import streamlit as st

from utils.legal_prompt_generator import LegalPromptGenerator

class ChatInterface:
    """Streamlit chat UI for asking questions about analyzed legal documents.

    Each query retrieves context chunks from ``vector_store``, builds prompts
    with ``LegalPromptGenerator`` and sends them to the Anthropic Messages API.
    Conversation and document state live in ``st.session_state`` so they
    survive Streamlit reruns.
    """

    # The Messages API requires a full model id; a bare family name such as
    # "claude-3" is rejected. Override per instance via ``model=``.
    DEFAULT_MODEL = "claude-3-5-sonnet-20241022"

    # Number of characters of each retrieved chunk shown in a reference preview.
    REFERENCE_PREVIEW_CHARS = 200

    # Styles for the HTML chat cards. Kept flush-left so Markdown never treats
    # the block as indented code.
    _CHAT_CSS = """
<style>
.chat-message {
padding: 1.5rem;
border-radius: 0.5rem;
margin-bottom: 1rem;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.user-message {
background-color: #f0f7ff;
border-left: 4px solid #2B547E;
}
.assistant-message {
background-color: #ffffff;
border-left: 4px solid #4CAF50;
}
.reference-box {
background-color: #f5f5f5;
padding: 0.8rem;
border-radius: 0.3rem;
font-size: 0.9em;
margin-top: 0.5rem;
}
.document-chunk {
border-left: 3px solid #2196F3;
padding-left: 1rem;
margin: 0.5rem 0;
}
</style>
"""

    def __init__(
        self,
        case_manager,
        vector_store,
        document_processor,
        model: str = DEFAULT_MODEL,
        max_tokens: int = 2000,
        temperature: float = 0.7,
    ):
        """Initialize ChatInterface with all required components.

        Args:
            case_manager: Object providing ``get_case(case_id)``; used to fetch
                metadata for ``st.session_state.current_case``.
            vector_store: Object providing
                ``similarity_search(query=, k=, filter_criteria=)``.
            document_processor: Stored on the instance; not used by this class.
            model: Anthropic model id passed to ``messages.create``.
            max_tokens: Maximum number of tokens generated per reply.
            temperature: Sampling temperature for replies.

        Shows an error and calls ``st.stop()`` if ANTHROPIC_API_KEY is unset
        or the Anthropic client cannot be created.
        """
        self.case_manager = case_manager
        self.vector_store = vector_store
        self.document_processor = document_processor
        self.prompt_generator = LegalPromptGenerator()
        self.model = model
        self.max_tokens = max_tokens
        self.temperature = temperature

        # The key check sits outside the try so the control-flow exception
        # raised by st.stop() is never caught as a client-initialization error.
        api_key = os.getenv("ANTHROPIC_API_KEY")
        if not api_key:
            st.error("Please set the ANTHROPIC_API_KEY environment variable.")
            st.stop()
        try:
            self.client = anthropic.Anthropic(api_key=api_key)
        except Exception as e:
            st.error(f"Error initializing Anthropic client: {str(e)}")
            st.stop()

        # Seed per-session state only once: Streamlit reruns the script on every
        # interaction, and existing values must survive those reruns.
        defaults = {
            "messages": [],
            "analyzed_documents": [],
            "context_chunks": [],
            "current_case": None,
        }
        for key, value in defaults.items():
            if key not in st.session_state:
                st.session_state[key] = value

    def render(self):
        """Render styles, the active-document sidebar, chat history and input."""
        st.markdown(self._CHAT_CSS, unsafe_allow_html=True)

        # Display active documents in the sidebar
        with st.sidebar:
            st.subheader("📚 Active Documents")
            for doc in st.session_state.analyzed_documents:
                metadata = doc.get("metadata") or {}
                with st.expander(f"📄 {doc.get('name', 'Untitled')}", expanded=False):
                    st.write(f"Type: {metadata.get('type', 'Unknown')}")
                    st.write(f"Added: {metadata.get('added_at', 'Unknown')}")

        # Display chat history
        for message in st.session_state.messages:
            self._render_message(message)

        # Chat input
        if prompt := st.chat_input("Ask about your documents..."):
            self._handle_chat_input(prompt)

    def _render_message(self, message: Dict):
        """Render one chat message as a styled HTML card.

        ``content`` holds user input or model output, so it is HTML-escaped
        before being placed in markup rendered with ``unsafe_allow_html``.
        ``references`` is HTML produced by ``_format_references``, which
        escapes the document text it embeds, so it is inserted as-is.
        """
        message_class = "user-message" if message["role"] == "user" else "assistant-message"
        references = message.get("references")
        references_html = f'<div class="reference-box">{references}</div>' if references else ""
        card = "\n".join([
            f'<div class="chat-message {message_class}">',
            html.escape(str(message["content"]), quote=False),
            references_html,
            "</div>",
        ])
        with st.container():
            st.markdown(card, unsafe_allow_html=True)

    def _handle_chat_input(self, prompt: str):
        """Process user input, generate a response and display both messages.

        ``render`` draws the history before ``st.chat_input`` returns, so the
        new user/assistant messages are rendered here directly; otherwise they
        would only appear on the next rerun.
        """
        user_message = {"role": "user", "content": prompt}
        st.session_state.messages.append(user_message)
        self._render_message(user_message)

        with st.spinner("Analyzing documents and generating a response..."):
            try:
                # Restrict retrieval to the types of the analyzed documents;
                # documents without a recorded type are skipped.
                doc_types = [
                    (doc.get("metadata") or {}).get("type")
                    for doc in st.session_state.analyzed_documents
                ]
                context_chunks = self.vector_store.similarity_search(
                    query=prompt,
                    k=5,
                    filter_criteria={"metadata.type": [t for t in doc_types if t is not None]},
                )
                response, references = self.generate_response(prompt, context_chunks)
                assistant_message = {
                    "role": "assistant",
                    "content": response,
                    "references": references,
                }
                st.session_state.messages.append(assistant_message)
                # Update context for future queries
                st.session_state.context_chunks = context_chunks
            except Exception as e:
                st.error(f"Error generating response: {str(e)}")
                return
        self._render_message(assistant_message)

    def generate_response(self, prompt: str, context_chunks: List[Dict]) -> tuple[str, str]:
        """Generate a response using the LLM and LegalPromptGenerator.

        Returns:
            ``(answer_text, references_html)``. On any error the exception is
            shown with ``st.error`` and a fixed error message with empty
            references is returned instead of raising.
        """
        try:
            messages = self._generate_messages(prompt, context_chunks)
            # The Messages API only accepts user/assistant roles in
            # ``messages``; the system prompt goes in the ``system`` parameter.
            system_prompt, conversation = self._split_system(messages)
            request = {
                "model": self.model,
                "max_tokens": self.max_tokens,
                "temperature": self.temperature,
                "messages": conversation,
            }
            if system_prompt:
                request["system"] = system_prompt
            response = self.client.messages.create(**request)
            references_html = self._format_references(context_chunks)
            return self._response_text(response), references_html
        except Exception as e:
            st.error(f"Error generating response: {str(e)}")
            return "An error occurred while processing your query.", ""

    def _generate_messages(self, prompt: str, context_chunks: List[Dict]) -> List[Dict]:
        """Build the system and user messages for the LLM call.

        Returns a ``system`` entry followed by a ``user`` entry. When an earlier
        question/answer exchange exists in the history, the user entry is a
        follow-up prompt built from that exchange; otherwise it is a plain
        user message carrying the retrieved context.
        """
        # Get case metadata if available
        case_metadata = None
        if st.session_state.current_case:
            case_metadata = self.case_manager.get_case(st.session_state.current_case)

        system_message = self.prompt_generator.generate_system_message(
            context_chunks=context_chunks,
            query=prompt,
            case_metadata=case_metadata,
        )

        previous_query, previous_response = self._previous_exchange(
            st.session_state.messages, prompt
        )
        if previous_query and previous_response:
            user_message = self.prompt_generator.generate_follow_up_prompt(
                original_query=previous_query,
                follow_up_query=prompt,
                previous_response=previous_response,
                context_chunks=context_chunks,
            )
        else:
            context = "\n".join(
                f"Document: {(chunk.get('metadata') or {}).get('title', 'Untitled')}\n"
                f"Section: {chunk.get('text', '')}\n"
                for chunk in context_chunks
            )
            user_message = self.prompt_generator.generate_user_message(prompt, context)

        return [
            {"role": "system", "content": system_message},
            {"role": "user", "content": user_message},
        ]

    @staticmethod
    def _previous_exchange(history: List[Dict], prompt: str) -> tuple:
        """Return ``(query, response)`` of the most recent answered exchange.

        The current ``prompt`` is normally already the last history entry
        (``_handle_chat_input`` appends it first), so that entry is skipped.
        Otherwise it would be mistaken for the earlier question. Returns
        ``(None, None)`` when no user message has been followed by an
        assistant reply.
        """
        if history and history[-1].get("role") == "user" and history[-1].get("content") == prompt:
            history = history[:-1]
        last_reply = max(
            (i for i, m in enumerate(history) if m.get("role") == "assistant"),
            default=None,
        )
        if last_reply is None:
            return None, None
        query = next(
            (m["content"] for m in reversed(history[:last_reply]) if m.get("role") == "user"),
            None,
        )
        if query is None:
            return None, None
        return query, history[last_reply]["content"]

    @staticmethod
    def _split_system(messages: List[Dict]) -> tuple:
        """Split ``messages`` into ``(system_text, non_system_messages)``.

        Contents of all ``role == "system"`` entries are joined with blank
        lines; ``system_text`` is ``""`` when there are none.
        """
        system_parts = [m["content"] for m in messages if m.get("role") == "system"]
        conversation = [m for m in messages if m.get("role") != "system"]
        return "\n\n".join(system_parts), conversation

    @staticmethod
    def _response_text(response) -> str:
        """Concatenate the text of all ``type == "text"`` blocks in ``response.content``."""
        return "".join(
            block.text for block in response.content if getattr(block, "type", None) == "text"
        )

    def _format_references(self, chunks: List[Dict]) -> str:
        """Format retrieved chunks as HTML reference entries, one per line.

        Titles and text come from document content and are HTML-escaped,
        because the result is rendered with ``unsafe_allow_html``. Text longer
        than ``REFERENCE_PREVIEW_CHARS`` is truncated and marked with "...".
        """
        limit = self.REFERENCE_PREVIEW_CHARS
        references = []
        for i, chunk in enumerate(chunks, 1):
            title = (chunk.get("metadata") or {}).get("title", "Untitled")
            text = str(chunk.get("text", ""))
            preview = text[:limit] + ("..." if len(text) > limit else "")
            # Single-line markup: indented lines could be parsed by Markdown
            # as a code block instead of HTML.
            references.append(
                '<div class="document-chunk">'
                f"<strong>Reference {i}:</strong> {html.escape(str(title))}"
                "<br/>"
                f"<em>Section:</em> {html.escape(preview)}"
                "</div>"
            )
        return "\n".join(references)

    @staticmethod
    def _comparable(doc: Dict) -> Dict:
        """Return a shallow copy of ``doc`` whose metadata omits ``added_at``."""
        metadata = {k: v for k, v in (doc.get("metadata") or {}).items() if k != "added_at"}
        return {**doc, "metadata": metadata}

    def add_analyzed_document(self, doc: Dict):
        """Add a document to session state, stamping ``metadata['added_at']``.

        A document equal to one already stored (ignoring ``added_at``) is not
        added again. ``doc`` is mutated in place (its timestamp is set) only
        when it is actually added.
        """
        # Compare without timestamps: stamping first would make a resubmitted
        # copy of an existing document never compare equal.
        candidate = self._comparable(doc)
        if any(self._comparable(existing) == candidate
               for existing in st.session_state.analyzed_documents):
            return
        doc.setdefault("metadata", {})["added_at"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        st.session_state.analyzed_documents.append(doc)