# Legal_AI_Agent/utils/legal_notebook_interface.py
# (Hugging Face listing artifacts removed: uploader "cryogenic22", commit 3ce1d9e)
import streamlit as st
from typing import Dict, List, Optional
from datetime import datetime
from pathlib import Path
import json
import os
from utils.legal_prompt_generator import LegalPromptGenerator
from anthropic import Anthropic
class LegalNotebookInterface:
    """Notebook-style Streamlit interface for chatting over legal case documents."""

    def __init__(self, case_manager, vector_store, document_processor):
        """Initialize LegalNotebookInterface with required components.

        Args:
            case_manager: Case/document registry (get_case, list_documents, ...).
            vector_store: Similarity-search index over document chunks.
            document_processor: Document ingestion helper (stored for later use).
        """
        # Store the components as instance attributes.
        self.case_manager = case_manager
        self.vector_store = vector_store
        self.document_processor = document_processor
        self.prompt_generator = LegalPromptGenerator()
        # Decomposed from the original monolithic constructor: session-state
        # seeding and CSS injection are distinct concerns.
        self._init_session_state()
        self._inject_styles()

    def _init_session_state(self):
        """Seed every st.session_state key this interface reads, if absent."""
        defaults = {
            "chat_history": [],          # list of {"role", "content", ...} messages
            "current_case": None,        # id of the active case
            "selected_documents": [],    # ids of checked documents
            "active_citation": None,     # session-state key of the citation to preview
        }
        for key, value in defaults.items():
            if key not in st.session_state:
                st.session_state[key] = value

    def _inject_styles(self):
        """Inject the CSS used by the notebook-like chat and citation widgets."""
        st.markdown("""
<style>
.main-content {
    max-width: 900px;
    margin: 0 auto;
}
.citation-number {
    display: inline-flex;
    align-items: center;
    justify-content: center;
    width: 18px;
    height: 18px;
    background-color: #E8E8E8;
    color: #666;
    border-radius: 50%;
    font-size: 11px;
    cursor: pointer;
    margin: 0 2px;
    vertical-align: super;
    border: 1px solid #D0D0D0;
}
.citation-number:hover {
    background-color: #0066cc;
    color: white;
}
.source-preview {
    background: white;
    border: 1px solid #E0E0E0;
    border-radius: 8px;
    padding: 15px;
    margin-top: 10px;
    margin-bottom: 10px;
}
.source-header {
    font-size: 14px;
    color: #666;
    margin-bottom: 8px;
    border-bottom: 1px solid #E0E0E0;
    padding-bottom: 8px;
}
.highlighted-text {
    background-color: #FFF3CD;
    padding: 2px 4px;
    border-radius: 3px;
}
.chat-message {
    margin-bottom: 20px;
    padding: 10px;
    border-radius: 8px;
}
.user-message {
    background-color: #F8F9FA;
}
.assistant-message {
    background-color: white;
    border: 1px solid #E9ECEF;
}
</style>
""", unsafe_allow_html=True)
    def _render_chat_interface(self):
        """Render the chat interface with notebook-style citations.

        Reads st.session_state: current_case, selected_documents, chat_history,
        active_citation. Writes per-citation entries ("citation_<msg>_<idx>")
        into session state for the preview panel.
        """
        st.title("💬 Document Analysis")
        # Create two-column layout for chat and preview
        chat_col, preview_col = st.columns([2, 1])
        with chat_col:
            # Show context if needed
            if st.session_state.current_case:
                case = self.case_manager.get_case(st.session_state.current_case)
                documents = self.case_manager.list_documents(case['id'])
                selected_docs = [doc for doc in documents if doc['id'] in st.session_state.selected_documents]
                if selected_docs:
                    st.markdown(
                        f"""
*Analyzing {len(selected_docs)} selected document(s)*
""")
                else:
                    st.info("Please select documents from the sidebar")
            # Display chat history
            for msg_idx, message in enumerate(st.session_state.chat_history):
                if message["role"] == "user":
                    st.markdown(f"""<div class="chat-message user-message">
{message['content']}</div>""", unsafe_allow_html=True)
                else:
                    content = message["content"]
                    citations = message.get("citations", [])
                    # Add citation numbers to content
                    for idx, citation in enumerate(citations, 1):
                        # Create unique identifier for this citation
                        citation_id = f"citation_{msg_idx}_{idx}"
                        # Add clickable number that will trigger preview.
                        # NOTE(review): str.replace substitutes EVERY occurrence of
                        # the first 50 chars of the citation text, and silently does
                        # nothing if that prefix isn't in the response verbatim —
                        # fragile matching; confirm against real LLM output.
                        content = content.replace(
                            citation['text'][:50],
                            f"{citation['text'][:50]}<span class='citation-number' onclick='show_citation(\"{citation_id}\")'>{idx}</span>"
                        )
                    st.markdown(f"""<div class="chat-message assistant-message">
{content}</div>""", unsafe_allow_html=True)
                    # Store citations in session state for preview
                    for idx, citation in enumerate(citations, 1):
                        citation_key = f"citation_{msg_idx}_{idx}"
                        st.session_state[citation_key] = citation
        # Preview panel: shows the citation whose session-state key is stored in
        # active_citation (if any).
        with preview_col:
            if st.session_state.active_citation:
                citation = st.session_state.get(st.session_state.active_citation)
                if citation:
                    st.markdown("""<div class="source-preview">
<div class="source-header">
📄 Source Reference
</div>
<strong>{}</strong>
<div class="highlighted-text">{}</div>
<small>Section: {}</small>
</div>""".format(
                        citation['document'],
                        citation['text'],
                        citation['section']
                    ), unsafe_allow_html=True)
        # Add JavaScript for handling citation clicks.
        # NOTE(review): 'streamlit:set_session_state' is not a message type
        # Streamlit's frontend is documented to handle, so clicking a citation
        # number may be a no-op — verify in the running app.
        st.markdown("""
<script>
function show_citation(citation_id) {
    // Update active citation in session state
    window.parent.postMessage({
        type: 'streamlit:set_session_state',
        data: { active_citation: citation_id }
    }, '*');
}
</script>
""", unsafe_allow_html=True)
        # Chat input; disabled until at least one document is selected.
        if prompt := st.chat_input(
            "Ask about your documents...",
            disabled=not st.session_state.selected_documents
        ):
            self._process_query(prompt)
    def render(self):
        """Main render method with sidebar controls.

        Order matters: the sidebar commits case/document selections to
        st.session_state before the chat interface reads them.
        """
        self._render_sidebar()
        self._render_chat_interface()
def _render_sidebar(self):
"""Render sidebar with case and document selection."""
with st.sidebar:
st.markdown("### 📁 Case Selection")
cases = self.case_manager.get_all_cases()
if not cases:
st.info("No cases found. Please create a case first.")
return
selected_case = st.selectbox(
"Select Case",
cases,
format_func=lambda x: x['title'],
key="case_selector"
)
if selected_case:
st.session_state.current_case = selected_case['id']
documents = self.case_manager.list_documents(selected_case['id'])
if documents:
st.markdown("### 📑 Document Selection")
col1, col2 = st.columns([3, 1])
with col1:
st.markdown("#### Documents")
with col2:
if st.button("Select All"):
st.session_state.selected_documents = [doc['id'] for doc in documents]
for doc in documents:
doc_col, del_col = st.columns([4, 1])
with doc_col:
doc_selected = st.checkbox(
f"📄 {doc['title']}",
value=doc['id'] in st.session_state.selected_documents,
key=f"doc_{doc['id']}"
)
with del_col:
if st.button("🗑️", key=f"del_{doc['id']}", help="Delete document"):
try:
self.vector_store.delete_document(doc['id'])
self.case_manager.remove_document(selected_case['id'], doc['id'])
if doc['id'] in st.session_state.selected_documents:
st.session_state.selected_documents.remove(doc['id'])
st.rerun()
except Exception as e:
st.error(f"Error deleting document: {str(e)}")
if doc_selected and doc['id'] not in st.session_state.selected_documents:
st.session_state.selected_documents.append(doc['id'])
elif not doc_selected and doc['id'] in st.session_state.selected_documents:
st.session_state.selected_documents.remove(doc['id'])
else:
st.info("No documents in this case. Please add documents first.")
def _process_query(self, prompt: str):
"""Process query with legal prompt generation and structured LLM response."""
if not st.session_state.current_case:
st.error("Please select a case first.")
return
if not st.session_state.selected_documents:
st.error("Please select at least one document to analyze.")
return
st.session_state.chat_history.append({
"role": "user",
"content": prompt
})
try:
with st.status("Processing your query...", expanded=True) as status:
status.write("🔍 Finding relevant documents...")
results = self.vector_store.similarity_search(
query=prompt,
k=5,
filter_docs=st.session_state.selected_documents
)
if not results:
st.warning("No relevant information found in the selected documents.")
return
case = self.case_manager.get_case(st.session_state.current_case)
status.write("💭 Analyzing legal context...")
context_chunks = []
context_text = []
for r in results:
doc_id = r['doc_id']
doc = next((d for d in self.case_manager.list_documents(st.session_state.current_case)
if d['id'] == doc_id), None)
doc_title = doc['title'] if doc else 'Unknown Document'
context_chunks.append({
'text': r['text'],
'metadata': {
**r.get('metadata', {}),
'document_title': doc_title
}
})
context_text.append(f"From {doc_title}:\n{r['text']}\n")
system_message = self.prompt_generator.generate_system_message(
context_chunks=context_chunks,
query=prompt,
case_metadata=case
)
user_message = self.prompt_generator.generate_user_message(
query=prompt,
context="\n\n".join(context_text)
)
status.write("✍️ Generating legal analysis...")
response = self._generate_legal_response(
system_message=system_message,
user_message=user_message,
results=results
)
status.update(label="✅ Analysis complete!", state="complete")
st.session_state.chat_history.append(response)
st.rerun()
except Exception as e:
st.error(f"An error occurred: {str(e)}")
st.exception(e)
def _generate_legal_response(self, system_message: str, user_message: str, results: List[Dict]) -> Dict:
"""Generate structured legal response using Anthropic API with proper legal prompting."""
if not results:
return {
"role": "assistant",
"content": "I couldn't find any relevant information in the documents.",
"citations": []
}
try:
citations = []
for idx, result in enumerate(results):
doc_id = result['doc_id']
doc = next((d for d in self.case_manager.list_documents(st.session_state.current_case)
if d['id'] == doc_id), None)
source_doc = doc['title'] if doc else 'Unknown Document'
chunk_idx = result.get('metadata', {}).get('chunk_index', 'unknown')
citations.append({
"document": source_doc,
"section": f"Section {chunk_idx}",
"text": result['text'].strip(),
"score": result.get('score', 0)
})
try:
api_key = os.getenv('ANTHROPIC_API_KEY')
if not api_key:
raise ValueError("ANTHROPIC_API_KEY not found in environment variables")
anthropic = Anthropic(api_key=api_key)
# Create a single message with the system context included
combined_prompt = f"{system_message}\n\nUser Query: {user_message}"
# Make the API call with the correct format
response = anthropic.messages.create(
model="claude-3-sonnet-20240229",
max_tokens=4096,
messages=[{
"role": "user",
"content": combined_prompt
}],
temperature=0.2
)
llm_response = response.content[0].text
return {
"role": "assistant",
"content": llm_response,
"citations": sorted(citations, key=lambda x: x['score'], reverse=True)
}
except Exception as api_error:
st.error(f"Claude API Error: {str(api_error)}")
if "ANTHROPIC_API_KEY" not in os.environ:
st.error("ANTHROPIC_API_KEY not found. Please ensure it's set in HuggingFace Secrets.")
raise api_error
except Exception as e:
st.error(f"Error in response generation: {str(e)}")
return {
"role": "assistant",
"content": f"I apologize, but I encountered an error while analyzing the documents: {str(e)}",
"citations": []
}