cryogenic22 committed on
Commit
3ab44aa
·
verified ·
1 Parent(s): b74d28a

Update components/chat_interface.py

Browse files
Files changed (1) hide show
  1. components/chat_interface.py +142 -59
components/chat_interface.py CHANGED
@@ -2,98 +2,181 @@ import streamlit as st
2
  from typing import List, Dict
3
  import anthropic
4
  import os
 
5
 
6
  class ChatInterface:
7
- def __init__(self, vector_store):
8
  self.vector_store = vector_store
9
-
10
- # Initialize Anthropic client using environment variable
11
  try:
12
  api_key = os.getenv("ANTHROPIC_API_KEY")
13
  if not api_key:
14
- st.error("Please set the ANTHROPIC_API_KEY in your environment variables.")
15
  st.stop()
16
  self.client = anthropic.Anthropic(api_key=api_key)
17
  except Exception as e:
18
  st.error(f"Error initializing Anthropic client: {str(e)}")
19
  st.stop()
20
 
21
- # Initialize chat history and analyzed documents
22
  if "messages" not in st.session_state:
23
  st.session_state.messages = []
24
  if "analyzed_documents" not in st.session_state:
25
  st.session_state.analyzed_documents = []
 
 
26
 
27
  def render(self):
28
- """Render chat interface with analyzed documents and history."""
29
- st.subheader("Chat with Documents")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
- # Display analyzed documents
32
- with st.expander("Analyzed Documents", expanded=True):
33
- if st.session_state.analyzed_documents:
34
- for doc in st.session_state.analyzed_documents:
35
- st.markdown(f"- **{doc['name']}**")
36
- else:
37
- st.info("No documents analyzed yet.")
38
 
39
- # Display chat history
40
  for message in st.session_state.messages:
41
- with st.chat_message(message["role"]):
42
- st.markdown(message["content"])
 
 
 
 
 
 
 
 
 
 
43
 
44
- # Chat input
45
- if prompt := st.chat_input("Ask a question about your documents"):
46
- # Add user message to history
47
- st.session_state.messages.append({"role": "user", "content": prompt})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
- with st.chat_message("user"):
50
- st.markdown(prompt)
51
 
52
- # Get relevant context from vector store
53
- results = self.vector_store.similarity_search(prompt, k=3)
54
- context = "\n\n".join([r["metadata"]["text"] for r in results])
 
 
 
 
 
 
 
 
55
 
56
- # Generate response with references
57
- with st.chat_message("assistant"):
58
- with st.spinner("Thinking..."):
59
- response = self.generate_response(prompt, context, results)
60
- st.markdown(response)
61
- st.session_state.messages.append({"role": "assistant", "content": response})
62
 
63
- def generate_response(self, prompt: str, context: str, results: List[Dict]) -> str:
64
- """Generate response using Claude with document references."""
65
- try:
66
- # Call the Claude API for response generation
67
  message = self.client.messages.create(
68
  model="claude-3-sonnet-20240229",
69
  max_tokens=2000,
70
  temperature=0.7,
71
- messages=[{
72
- "role": "user",
73
- "content": f"""Based on the following context from legal documents, please answer the question.
74
-
75
- Context:
76
- {context}
77
-
78
- Question: {prompt}
79
-
80
- Please provide a detailed response with references to specific parts of the documents when relevant."""
81
- }]
82
  )
83
- response_content = message.content[0].text
84
-
85
- # Append document references to the response
86
- references = [
87
- f"{idx + 1}. {res['metadata']['text'][:200]}... (Reference: {res['metadata'].get('reference', 'N/A')})"
88
- for idx, res in enumerate(results)
89
- ]
90
- references_text = "\n\nReferences:\n" + "\n".join(references)
91
- return response_content + references_text
92
  except Exception as e:
93
  st.error(f"Error generating response: {str(e)}")
94
- return "I apologize, but I encountered an error generating the response."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
  def add_analyzed_document(self, doc: Dict):
97
- """Add a document to the list of analyzed documents."""
 
98
  if doc not in st.session_state.analyzed_documents:
99
- st.session_state.analyzed_documents.append(doc)
 
import html
import os
from datetime import datetime
from typing import Dict, List

import anthropic
6
 
7
class ChatInterface:
    """Streamlit chat UI for querying analyzed documents via Claude.

    Retrieves relevant chunks from the vector store, builds a system prompt
    from chunk metadata and ontology concepts, and renders the styled
    conversation with HTML source references.
    """

    def __init__(self, vector_store, document_processor):
        self.vector_store = vector_store
        self.document_processor = document_processor

        # Fail fast when the API key is absent or the client cannot be
        # constructed; st.stop() halts this Streamlit script run.
        try:
            api_key = os.getenv("ANTHROPIC_API_KEY")
            if not api_key:
                st.error("Please set the ANTHROPIC_API_KEY environment variable.")
                st.stop()
            self.client = anthropic.Anthropic(api_key=api_key)
        except Exception as e:
            st.error(f"Error initializing Anthropic client: {str(e)}")
            st.stop()

        # Initialize session state used across Streamlit reruns.
        if "messages" not in st.session_state:
            st.session_state.messages = []
        if "analyzed_documents" not in st.session_state:
            st.session_state.analyzed_documents = []
        if "context_chunks" not in st.session_state:
            st.session_state.context_chunks = []

    def render(self):
        """Render the chat interface: styles, sidebar documents, history, input."""
        st.markdown("""
        <style>
        .chat-message {
            padding: 1.5rem;
            border-radius: 0.5rem;
            margin-bottom: 1rem;
            box-shadow: 0 2px 4px rgba(0,0,0,0.1);
        }
        .user-message {
            background-color: #f0f7ff;
            border-left: 4px solid #2B547E;
        }
        .assistant-message {
            background-color: #ffffff;
            border-left: 4px solid #4CAF50;
        }
        .reference-box {
            background-color: #f5f5f5;
            padding: 0.8rem;
            border-radius: 0.3rem;
            font-size: 0.9em;
            margin-top: 0.5rem;
        }
        .document-chunk {
            border-left: 3px solid #2196F3;
            padding-left: 1rem;
            margin: 0.5rem 0;
        }
        </style>
        """, unsafe_allow_html=True)

        # Sidebar: documents currently available as chat context.
        with st.sidebar:
            st.subheader("📚 Active Documents")
            for doc in st.session_state.analyzed_documents:
                with st.expander(f"📄 {doc['name']}", expanded=False):
                    st.write(f"Type: {doc.get('metadata', {}).get('type', 'Unknown')}")
                    st.write(f"Added: {doc.get('metadata', {}).get('added_at', 'Unknown')}")

        # Chat history with custom styling. Message text is user/model
        # controlled and is rendered with unsafe_allow_html, so it must be
        # escaped to prevent HTML/script injection.
        for message in st.session_state.messages:
            message_class = "user-message" if message["role"] == "user" else "assistant-message"
            content = html.escape(message["content"])
            references = message.get("references", "")
            with st.container():
                st.markdown(f"""
                <div class="chat-message {message_class}">
                    {content}
                    {'<div class="reference-box">' + references + '</div>' if references else ""}
                </div>
                """, unsafe_allow_html=True)

        # Chat input
        if prompt := st.chat_input("Ask about your documents..."):
            self._handle_chat_input(prompt)

    def _handle_chat_input(self, prompt: str):
        """Append the user message, retrieve context, and generate a reply."""
        st.session_state.messages.append({"role": "user", "content": prompt})

        # Restrict retrieval to the document types that have been analyzed.
        # .get(...) tolerates documents ingested without full metadata.
        doc_types = [
            doc.get("metadata", {}).get("type")
            for doc in st.session_state.analyzed_documents
        ]
        context_chunks = self.vector_store.similarity_search(
            prompt,
            k=5,
            filter_criteria={"metadata.type": doc_types},
        )

        with st.spinner("Analyzing documents and generating response..."):
            response_content, references = self.generate_response(prompt, context_chunks)

        st.session_state.messages.append({
            "role": "assistant",
            "content": response_content,
            "references": references,
        })

        # Keep the chunks around so later reruns/features can reuse them.
        st.session_state.context_chunks = context_chunks

    def generate_response(self, prompt: str, context_chunks: List[Dict]) -> tuple[str, str]:
        """Return (response_text, references_html) for *prompt*.

        Falls back to an apology message and empty references on any error.
        """
        try:
            # Flatten retrieved chunks into a textual context. Metadata keys
            # may be absent depending on the ingestion path, so default them.
            context = "\n".join(
                f"Document: {chunk.get('metadata', {}).get('title', 'Unknown')}\n"
                f"Section: {chunk.get('text', '')}\n"
                f"Type: {chunk.get('metadata', {}).get('type', 'Unknown')}\n"
                f"Jurisdiction: {chunk.get('metadata', {}).get('jurisdiction', 'Unknown')}\n"
                for chunk in context_chunks
            )

            system_message = self._generate_system_message(prompt, context_chunks)

            # BUGFIX: the Anthropic Messages API takes the system prompt as
            # the top-level `system` parameter; "system" is not a valid role
            # inside `messages` and the previous form raised an API error on
            # every request.
            message = self.client.messages.create(
                model="claude-3-sonnet-20240229",
                max_tokens=2000,
                temperature=0.7,
                system=system_message,
                messages=[
                    {"role": "user", "content": f"Question: {prompt}\n\nContext:\n{context}"}
                ],
            )

            references_html = self._format_references(context_chunks)
            return message.content[0].text, references_html

        except Exception as e:
            st.error(f"Error generating response: {str(e)}")
            return "I apologize, but I encountered an error generating the response.", ""

    def _generate_system_message(self, prompt: str, context_chunks: List[Dict]) -> str:
        """Build a system prompt from chunk metadata and linked ontology concepts."""
        # NOTE(review): relies on a private helper of document_processor;
        # consider promoting _link_to_ontology to a public method.
        ontology_concepts = self.document_processor._link_to_ontology(prompt)

        doc_types = {chunk.get('metadata', {}).get('type', 'Unknown') for chunk in context_chunks}
        jurisdictions = {chunk.get('metadata', {}).get('jurisdiction', 'Unknown') for chunk in context_chunks}
        concepts = [concept['concept'] for concept in ontology_concepts]

        return f"""You are a legal AI assistant analyzing documents with the following context:

        Document Types Present: {', '.join(doc_types)}
        Jurisdictions: {', '.join(jurisdictions)}
        Relevant Legal Concepts: {', '.join(concepts)}

        Please provide detailed analysis while:
        1. Citing specific sections from the provided context
        2. Incorporating relevant legal concepts and terminology
        3. Maintaining appropriate legal language and tone
        4. Providing clear references to source documents
        """

    def _format_references(self, chunks: List[Dict]) -> str:
        """Return reference citations as HTML, escaping document-derived text.

        Escaping matters because the result is injected into the page with
        unsafe_allow_html=True and chunk text is untrusted document content.
        """
        references = []
        for i, chunk in enumerate(chunks, 1):
            title = html.escape(str(chunk.get('metadata', {}).get('title', 'Unknown')))
            excerpt = html.escape(chunk.get('text', '')[:200])
            references.append(f"""
            <div class="document-chunk">
                <strong>Reference {i}:</strong> {title}
                <br/>
                <em>Section:</em> {excerpt}...
            </div>
            """)
        return "\n".join(references)

    def add_analyzed_document(self, doc: Dict):
        """Record *doc* (stamped with an added-at time) in session state, once."""
        # setdefault tolerates documents ingested without a metadata dict.
        doc.setdefault('metadata', {})['added_at'] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        if doc not in st.session_state.analyzed_documents:
            st.session_state.analyzed_documents.append(doc)