import streamlit as st
from typing import List, Dict
import anthropic
import os
from datetime import datetime
from utils.legal_prompt_generator import LegalPromptGenerator


class ChatInterface:
    """Streamlit chat interface for querying analyzed legal documents.

    Pulls relevant chunks from a vector store, builds prompts with
    LegalPromptGenerator, and answers via the Anthropic Messages API.
    """

    def __init__(self, case_manager, vector_store, document_processor):
        """Initialize ChatInterface with enhanced components.

        BUG FIX: the original file defined ``__init__`` twice; only this
        (second) definition was effective at runtime, so the dead first
        definition (which lacked ``case_manager`` and the prompt generator)
        has been removed.
        """
        self.case_manager = case_manager
        self.vector_store = vector_store
        self.document_processor = document_processor
        self.prompt_generator = LegalPromptGenerator()

        try:
            api_key = os.getenv("ANTHROPIC_API_KEY")
            if not api_key:
                st.error("Please set the ANTHROPIC_API_KEY environment variable.")
                st.stop()
            self.client = anthropic.Anthropic(api_key=api_key)
        except Exception as e:
            st.error(f"Error initializing Anthropic client: {str(e)}")
            st.stop()

        # Initialize session state once per Streamlit session.
        if "messages" not in st.session_state:
            st.session_state.messages = []
        if "analyzed_documents" not in st.session_state:
            st.session_state.analyzed_documents = []
        if "context_chunks" not in st.session_state:
            st.session_state.context_chunks = []
        if "current_case" not in st.session_state:
            st.session_state.current_case = None

    def render(self):
        """Render an improved chat interface with better document context."""
        # NOTE(review): the CSS/HTML payloads of the st.markdown calls below
        # were empty in the source (content appears stripped) — preserved as-is.
        st.markdown("""
""", unsafe_allow_html=True)

        # Sidebar: list the documents currently available as context.
        with st.sidebar:
            st.subheader("📚 Active Documents")
            for doc in st.session_state.analyzed_documents:
                with st.expander(f"📄 {doc['name']}", expanded=False):
                    st.write(f"Type: {doc.get('metadata', {}).get('type', 'Unknown')}")
                    st.write(f"Added: {doc.get('metadata', {}).get('added_at', 'Unknown')}")

        # Display chat history with improved styling.
        for message in st.session_state.messages:
            message_class = "user-message" if message["role"] == "user" else "assistant-message"
            with st.container():
                st.markdown(f"""
""", unsafe_allow_html=True)

        # Chat input with improved context handling.
        if prompt := st.chat_input("Ask about your documents..."):
            self._handle_chat_input(prompt)

    def _handle_chat_input(self, prompt: str):
        """Handle one user turn: retrieve context, generate and store a reply."""
        st.session_state.messages.append({"role": "user", "content": prompt})

        # Restrict retrieval to the document types currently loaded.
        # BUG FIX: use .get() chains (as render() already does) so a document
        # without a "metadata"/"type" key cannot raise KeyError here.
        doc_types = [
            doc.get("metadata", {}).get("type")
            for doc in st.session_state.analyzed_documents
        ]
        context_chunks = self.vector_store.similarity_search(
            prompt,
            k=5,
            filter_criteria={"metadata.type": doc_types},
        )

        with st.spinner("Analyzing documents and generating response..."):
            response_content, references = self.generate_response(prompt, context_chunks)

        # Keep the references alongside the assistant message for display.
        st.session_state.messages.append({
            "role": "assistant",
            "content": response_content,
            "references": references,
        })

        # Store context chunks for future reference.
        st.session_state.context_chunks = context_chunks

    def generate_response(self, prompt: str, context_chunks: List[Dict]) -> tuple[str, str]:
        """Generate a Claude response; returns (answer_text, references_html)."""
        try:
            # BUG FIX: the original called self._generate_system_message(),
            # which is not defined anywhere in this class (AttributeError).
            # Route through the existing _generate_messages() instead.
            messages = self._generate_messages(prompt, context_chunks)

            # BUG FIX: the Anthropic Messages API does not accept a "system"
            # role inside `messages`; the system prompt must be passed as the
            # top-level `system` parameter.
            system_message = "\n".join(
                m["content"] for m in messages if m["role"] == "system"
            )
            chat_messages = [m for m in messages if m["role"] != "system"]

            message = self.client.messages.create(
                model="claude-3-sonnet-20240229",
                max_tokens=2000,
                temperature=0.7,
                system=system_message,
                messages=chat_messages,
            )

            references_html = self._format_references(context_chunks)
            return message.content[0].text, references_html
        except Exception as e:
            st.error(f"Error generating response: {str(e)}")
            return "I apologize, but I encountered an error generating the response.", ""

    def _generate_messages(self, prompt: str, context_chunks: List[Dict]) -> List[Dict]:
        """Generate messages for the Claude API with enhanced prompts."""
        # Attach case metadata when a case is active in this session.
        case_metadata = None
        if st.session_state.current_case:
            case_metadata = self.case_manager.get_case(st.session_state.current_case)

        system_message = self.prompt_generator.generate_system_message(
            context_chunks=context_chunks,
            query=prompt,
            case_metadata=case_metadata,
        )

        # Flatten retrieved chunks into a readable context string.
        context = "\n".join(
            f"Document: {chunk['metadata']['title']}\n"
            f"Section: {chunk['text']}\n"
            f"Type: {chunk['metadata']['type']}\n"
            f"Jurisdiction: {chunk['metadata']['jurisdiction']}\n"
            for chunk in context_chunks
        )

        user_message = self.prompt_generator.generate_user_message(prompt, context)

        # If there is prior conversation, rebuild the user message as a
        # follow-up prompt that carries the most recent exchange.
        if st.session_state.messages:
            previous_query = next(
                (m["content"] for m in reversed(st.session_state.messages) if m["role"] == "user"),
                None,
            )
            previous_response = next(
                (m["content"] for m in reversed(st.session_state.messages) if m["role"] == "assistant"),
                None,
            )
            if previous_query and previous_response:
                user_message = self.prompt_generator.generate_follow_up_prompt(
                    original_query=previous_query,
                    follow_up_query=prompt,
                    previous_response=previous_response,
                    context_chunks=context_chunks,
                )

        return [
            {"role": "system", "content": system_message},
            {"role": "user", "content": user_message},
        ]

    def _format_references(self, chunks: List[Dict]) -> str:
        """Format reference citations in HTML.

        NOTE(review): the original HTML template was truncated in the source
        file (it ended mid f-string); this reconstruction emits one simple
        numbered citation per chunk — confirm against the intended markup.
        """
        references = []
        for i, chunk in enumerate(chunks, 1):
            metadata = chunk.get("metadata", {})
            references.append(
                f'<div class="reference">'
                f'[{i}] {metadata.get("title", "Unknown")} '
                f'({metadata.get("type", "Unknown")}, '
                f'{metadata.get("jurisdiction", "Unknown")})'
                f'</div>'
            )
        return "\n".join(references)