import streamlit as st from typing import List, Dict import anthropic import os from datetime import datetime from utils.legal_prompt_generator import LegalPromptGenerator class ChatInterface: def __init__(self, case_manager, vector_store, document_processor): """Initialize ChatInterface with all required components.""" self.case_manager = case_manager self.vector_store = vector_store self.document_processor = document_processor self.prompt_generator = LegalPromptGenerator() try: api_key = os.getenv("ANTHROPIC_API_KEY") if not api_key: st.error("Please set the ANTHROPIC_API_KEY environment variable.") st.stop() self.client = anthropic.Anthropic(api_key=api_key) except Exception as e: st.error(f"Error initializing Anthropic client: {str(e)}") st.stop() # Initialize session state if "messages" not in st.session_state: st.session_state.messages = [] if "analyzed_documents" not in st.session_state: st.session_state.analyzed_documents = [] if "context_chunks" not in st.session_state: st.session_state.context_chunks = [] if "current_case" not in st.session_state: st.session_state.current_case = None def render(self): """Render the chat interface with document and context management.""" st.markdown(""" """, unsafe_allow_html=True) # Display active documents in the sidebar with st.sidebar: st.subheader("📚 Active Documents") for doc in st.session_state.analyzed_documents: with st.expander(f"📄 {doc['name']}", expanded=False): st.write(f"Type: {doc.get('metadata', {}).get('type', 'Unknown')}") st.write(f"Added: {doc.get('metadata', {}).get('added_at', 'Unknown')}") # Display chat history for message in st.session_state.messages: message_class = "user-message" if message["role"] == "user" else "assistant-message" with st.container(): st.markdown(f"""
""", unsafe_allow_html=True) # Chat input if prompt := st.chat_input("Ask about your documents..."): self._handle_chat_input(prompt) def _handle_chat_input(self, prompt: str): """Process user input and generate a response.""" st.session_state.messages.append({"role": "user", "content": prompt}) with st.spinner("Analyzing documents and generating a response..."): try: # Retrieve relevant document chunks context_chunks = self.vector_store.similarity_search( query=prompt, k=5, filter_criteria={"metadata.type": [doc["metadata"]["type"] for doc in st.session_state.analyzed_documents]} ) # Generate the response response, references = self.generate_response(prompt, context_chunks) # Add assistant response st.session_state.messages.append({ "role": "assistant", "content": response, "references": references }) # Update context for future queries st.session_state.context_chunks = context_chunks except Exception as e: st.error(f"Error generating response: {str(e)}") def generate_response(self, prompt: str, context_chunks: List[Dict]) -> tuple[str, str]: """Generate a response using the LLM and LegalPromptGenerator.""" try: # Generate structured messages messages = self._generate_messages(prompt, context_chunks) # Call the LLM response = self.client.messages.create( model="claude-3", max_tokens=2000, temperature=0.7, messages=messages ) # Format references references_html = self._format_references(context_chunks) return response.content[0].text, references_html except Exception as e: st.error(f"Error generating response: {str(e)}") return "An error occurred while processing your query.", "" def _generate_messages(self, prompt: str, context_chunks: List[Dict]) -> List[Dict]: """Generate structured messages for LLM input.""" # Get case metadata if available case_metadata = None if st.session_state.current_case: case_metadata = self.case_manager.get_case(st.session_state.current_case) # Generate system message system_message = self.prompt_generator.generate_system_message( context_chunks=context_chunks, query=prompt, case_metadata=case_metadata ) # Generate user message context = "\n".join([ f"Document: {chunk['metadata'].get('title', 'Untitled')}\n" f"Section: {chunk['text']}\n" for chunk in context_chunks ]) user_message = self.prompt_generator.generate_user_message(prompt, context) # Handle follow-up questions if st.session_state.messages: previous_query = next( (m["content"] for m in reversed(st.session_state.messages) if m["role"] == "user"), None ) previous_response = next( (m["content"] for m in reversed(st.session_state.messages) if m["role"] == "assistant"), None ) if previous_query and previous_response: user_message = self.prompt_generator.generate_follow_up_prompt( original_query=previous_query, follow_up_query=prompt, previous_response=previous_response, context_chunks=context_chunks ) return [ {"role": "system", "content": system_message}, {"role": "user", "content": user_message} ] def _format_references(self, chunks: List[Dict]) -> str: """Format references as HTML for display.""" references = [] for i, chunk in enumerate(chunks, 1): references.append(f"""