import streamlit as st
from typing import List, Dict
import anthropic
import os
from datetime import datetime
from utils.legal_prompt_generator import LegalPromptGenerator


class ChatInterface:
    def __init__(self, case_manager, vector_store, document_processor):
        """Initialize ChatInterface with all required components."""
        self.case_manager = case_manager
        self.vector_store = vector_store
        self.document_processor = document_processor
        self.prompt_generator = LegalPromptGenerator()

        try:
            api_key = os.getenv("ANTHROPIC_API_KEY")
            if not api_key:
                st.error("Please set the ANTHROPIC_API_KEY environment variable.")
                st.stop()
            self.client = anthropic.Anthropic(api_key=api_key)
        except Exception as e:
            st.error(f"Error initializing Anthropic client: {str(e)}")
            st.stop()

        # Initialize session state
        if "messages" not in st.session_state:
            st.session_state.messages = []
        if "analyzed_documents" not in st.session_state:
            st.session_state.analyzed_documents = []
        if "context_chunks" not in st.session_state:
            st.session_state.context_chunks = []
        if "current_case" not in st.session_state:
            st.session_state.current_case = None

    def render(self):
        """Render the chat interface with document and context management."""
        st.markdown("""
            <style>
            .chat-message {
                padding: 1.5rem;
                border-radius: 0.5rem;
                margin-bottom: 1rem;
                box-shadow: 0 2px 4px rgba(0,0,0,0.1);
            }
            .user-message {
                background-color: #f0f7ff;
                border-left: 4px solid #2B547E;
            }
            .assistant-message {
                background-color: #ffffff;
                border-left: 4px solid #4CAF50;
            }
            .reference-box {
                background-color: #f5f5f5;
                padding: 0.8rem;
                border-radius: 0.3rem;
                font-size: 0.9em;
                margin-top: 0.5rem;
            }
            .document-chunk {
                border-left: 3px solid #2196F3;
                padding-left: 1rem;
                margin: 0.5rem 0;
            }
            </style>
        """, unsafe_allow_html=True)

        # Display active documents in the sidebar
        with st.sidebar:
            st.subheader("📚 Active Documents")
            for doc in st.session_state.analyzed_documents:
                with st.expander(f"📄 {doc['name']}", expanded=False):
                    st.write(f"Type: {doc.get('metadata', {}).get('type', 'Unknown')}")
                    st.write(f"Added: {doc.get('metadata', {}).get('added_at', 'Unknown')}")

        # Display chat history
        for message in st.session_state.messages:
            message_class = "user-message" if message["role"] == "user" else "assistant-message"
            with st.container():
                st.markdown(f"""
                    <div class="chat-message {message_class}">
                        {message["content"]}
                        {'<div class="reference-box">' + message.get("references", "") + '</div>' if message.get("references") else ""}
                    </div>
                """, unsafe_allow_html=True)

        # Chat input
        if prompt := st.chat_input("Ask about your documents..."):
            self._handle_chat_input(prompt)

    def _handle_chat_input(self, prompt: str):
        """Process user input and generate a response."""
        st.session_state.messages.append({"role": "user", "content": prompt})

        with st.spinner("Analyzing documents and generating a response..."):
            try:
                # Retrieve relevant document chunks
                context_chunks = self.vector_store.similarity_search(
                    query=prompt,
                    k=5,
                    filter_criteria={
                        "metadata.type": [
                            doc["metadata"]["type"]
                            for doc in st.session_state.analyzed_documents
                        ]
                    }
                )

                # Generate the response
                response, references = self.generate_response(prompt, context_chunks)

                # Add assistant response
                st.session_state.messages.append({
                    "role": "assistant",
                    "content": response,
                    "references": references
                })

                # Update context for future queries
                st.session_state.context_chunks = context_chunks
            except Exception as e:
                st.error(f"Error generating response: {str(e)}")
                return

        # Rerun so the freshly appended messages appear in the chat history,
        # which is rendered before the input widget on each script run
        st.rerun()

    def generate_response(self, prompt: str, context_chunks: List[Dict]) -> tuple[str, str]:
        """Generate a response using the LLM and LegalPromptGenerator."""
        try:
            # Generate the system prompt and structured messages
            system_message, messages = self._generate_messages(prompt, context_chunks)

            # Call the LLM. The Anthropic Messages API takes the system prompt as a
            # top-level parameter (not a message with role "system") and expects a
            # full model identifier rather than the bare "claude-3" family name.
            response = self.client.messages.create(
                model="claude-3-5-sonnet-20241022",
                max_tokens=2000,
                temperature=0.7,
                system=system_message,
                messages=messages
            )

            # Format references
            references_html = self._format_references(context_chunks)
            return response.content[0].text, references_html
        except Exception as e:
            st.error(f"Error generating response: {str(e)}")
            return "An error occurred while processing your query.", ""

    def _generate_messages(self, prompt: str, context_chunks: List[Dict]) -> tuple[str, List[Dict]]:
        """Generate the system prompt and structured messages for LLM input."""
        # Get case metadata if available
        case_metadata = None
        if st.session_state.current_case:
            case_metadata = self.case_manager.get_case(st.session_state.current_case)

        # Generate system message
        system_message = self.prompt_generator.generate_system_message(
            context_chunks=context_chunks,
            query=prompt,
            case_metadata=case_metadata
        )

        # Generate user message
        context = "\n".join([
            f"Document: {chunk['metadata'].get('title', 'Untitled')}\n"
            f"Section: {chunk['text']}\n"
            for chunk in context_chunks
        ])
        user_message = self.prompt_generator.generate_user_message(prompt, context)

        # Handle follow-up questions. The current prompt has already been appended
        # to the session history, so only look at earlier messages.
        history = st.session_state.messages[:-1]
        if history:
            previous_query = next(
                (m["content"] for m in reversed(history) if m["role"] == "user"),
                None
            )
            previous_response = next(
                (m["content"] for m in reversed(history) if m["role"] == "assistant"),
                None
            )
            if previous_query and previous_response:
                user_message = self.prompt_generator.generate_follow_up_prompt(
                    original_query=previous_query,
                    follow_up_query=prompt,
                    previous_response=previous_response,
                    context_chunks=context_chunks
                )

        return system_message, [{"role": "user", "content": user_message}]

    def _format_references(self, chunks: List[Dict]) -> str:
        """Format references as HTML for display."""
        references = []
        for i, chunk in enumerate(chunks, 1):
            references.append(f"""
                <div class="document-chunk">
                    <strong>Reference {i}:</strong> {chunk['metadata'].get('title', 'Untitled')}
                    <br/>
                    <em>Section:</em> {chunk['text'][:200]}...
                </div>
            """)
        return "\n".join(references)

    def add_analyzed_document(self, doc: Dict):
        """Add a document to session state with metadata tracking."""
        # Ensure a metadata dict exists before stamping the timestamp
        doc.setdefault('metadata', {})
        doc['metadata']['added_at'] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        if doc not in st.session_state.analyzed_documents:
            st.session_state.analyzed_documents.append(doc)
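

# --- Usage sketch (illustrative only, not part of this file's deployed code) ---
# A minimal example of how this class might be wired into a Streamlit entry point,
# assuming hypothetical CaseManager, VectorStore, and DocumentProcessor classes that
# expose the interfaces used above (get_case, similarity_search, ...). The module
# paths below are placeholders, not confirmed by this repository.
#
#     from utils.case_manager import CaseManager              # assumed module path
#     from utils.vector_store import VectorStore              # assumed module path
#     from utils.document_processor import DocumentProcessor  # assumed module path
#
#     def main():
#         st.set_page_config(page_title="Legal Document Chat", layout="wide")
#         chat = ChatInterface(CaseManager(), VectorStore(), DocumentProcessor())
#         chat.render()
#
#     if __name__ == "__main__":
#         main()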