cryogenic22 committed (verified)
Commit c5654cc · Parent(s): 14f8d88

Update components/chat_interface.py

Files changed (1):
  1. components/chat_interface.py (+64, -102)
components/chat_interface.py CHANGED
@@ -5,11 +5,15 @@ import os
 from datetime import datetime
 from utils.legal_prompt_generator import LegalPromptGenerator
 
+
 class ChatInterface:
-    def __init__(self, vector_store, document_processor):
+    def __init__(self, case_manager, vector_store, document_processor):
+        """Initialize ChatInterface with all required components."""
+        self.case_manager = case_manager
         self.vector_store = vector_store
         self.document_processor = document_processor
-
+        self.prompt_generator = LegalPromptGenerator()
+
         try:
             api_key = os.getenv("ANTHROPIC_API_KEY")
             if not api_key:
@@ -27,9 +31,11 @@
             st.session_state.analyzed_documents = []
         if "context_chunks" not in st.session_state:
             st.session_state.context_chunks = []
+        if "current_case" not in st.session_state:
+            st.session_state.current_case = None
 
     def render(self):
-        """Render an improved chat interface with better document context."""
+        """Render the chat interface with document and context management."""
         st.markdown("""
             <style>
             .chat-message {
@@ -61,7 +67,7 @@
             </style>
         """, unsafe_allow_html=True)
 
-        # Display active documents and context
+        # Display active documents in the sidebar
         with st.sidebar:
             st.subheader("📚 Active Documents")
             for doc in st.session_state.analyzed_documents:
@@ -69,7 +75,7 @@
                 st.write(f"Type: {doc.get('metadata', {}).get('type', 'Unknown')}")
                 st.write(f"Added: {doc.get('metadata', {}).get('added_at', 'Unknown')}")
 
-        # Display chat history with improved styling
+        # Display chat history
        for message in st.session_state.messages:
            message_class = "user-message" if message["role"] == "user" else "assistant-message"
            with st.container():
@@ -80,134 +86,90 @@
                 </div>
                 """, unsafe_allow_html=True)
 
-        # Chat input with improved context handling
+        # Chat input
         if prompt := st.chat_input("Ask about your documents..."):
             self._handle_chat_input(prompt)
 
     def _handle_chat_input(self, prompt: str):
-        """Handle chat input with improved context management."""
-        # Add user message
+        """Process user input and generate a response."""
         st.session_state.messages.append({"role": "user", "content": prompt})
-
-        # Get relevant context chunks
-        context_chunks = self.vector_store.similarity_search(
-            prompt,
-            k=5,
-            filter_criteria={"metadata.type": [doc["metadata"]["type"] for doc in st.session_state.analyzed_documents]}
-        )
-
-        # Generate response
-        with st.spinner("Analyzing documents and generating response..."):
-            response_content, references = self.generate_response(prompt, context_chunks)
-
-        # Add assistant message with references
-        st.session_state.messages.append({
-            "role": "assistant",
-            "content": response_content,
-            "references": references
-        })
-
-        # Store context chunks for future reference
-        st.session_state.context_chunks = context_chunks
+
+        with st.spinner("Analyzing documents and generating a response..."):
+            try:
+                # Retrieve relevant document chunks
+                context_chunks = self.vector_store.similarity_search(
+                    query=prompt,
+                    k=5,
+                    filter_criteria={"metadata.type": [doc["metadata"]["type"] for doc in st.session_state.analyzed_documents]}
+                )
+
+                # Generate the response
+                response, references = self.generate_response(prompt, context_chunks)
+
+                # Add assistant response
+                st.session_state.messages.append({
+                    "role": "assistant",
+                    "content": response,
+                    "references": references
+                })
+
+                # Update context for future queries
+                st.session_state.context_chunks = context_chunks
+            except Exception as e:
+                st.error(f"Error generating response: {str(e)}")
 
     def generate_response(self, prompt: str, context_chunks: List[Dict]) -> tuple[str, str]:
-        """Generate response using Claude with improved context handling."""
+        """Generate a response using the LLM and LegalPromptGenerator."""
         try:
-            # Prepare context from chunks
-            context = "\n".join([
-                f"Document: {chunk['metadata']['title']}\n"
-                f"Section: {chunk['text']}\n"
-                f"Type: {chunk['metadata']['type']}\n"
-                f"Jurisdiction: {chunk['metadata']['jurisdiction']}\n"
-                for chunk in context_chunks
-            ])
-
-            # Generate system message using ontology
-            system_message = self._generate_system_message(prompt, context_chunks)
-
-            # Call Claude API
-            message = self.client.messages.create(
-                model="claude-3-sonnet-20240229",
+            # Generate structured messages
+            messages = self._generate_messages(prompt, context_chunks)
+
+            # Call the LLM
+            response = self.client.messages.create(
+                model="claude-3",
                 max_tokens=2000,
                 temperature=0.7,
-                messages=[
-                    {"role": "system", "content": system_message},
-                    {"role": "user", "content": f"Question: {prompt}\n\nContext:\n{context}"}
-                ]
+                messages=messages
             )
-
+
             # Format references
             references_html = self._format_references(context_chunks)
-
-            return message.content[0].text, references_html
-
-        except Exception as e:
-            st.error(f"Error generating response: {str(e)}")
-            return "I apologize, but I encountered an error generating the response.", ""
+            return response.content[0].text, references_html
 
-    def __init__(self, case_manager, vector_store, document_processor):
-        """Initialize ChatInterface with enhanced components."""
-        self.case_manager = case_manager
-        self.vector_store = vector_store
-        self.document_processor = document_processor
-        self.prompt_generator = LegalPromptGenerator()
-
-        try:
-            api_key = os.getenv("ANTHROPIC_API_KEY")
-            if not api_key:
-                st.error("Please set the ANTHROPIC_API_KEY environment variable.")
-                st.stop()
-            self.client = anthropic.Anthropic(api_key=api_key)
         except Exception as e:
-            st.error(f"Error initializing Anthropic client: {str(e)}")
-            st.stop()
-
-        # Initialize session state
-        if "messages" not in st.session_state:
-            st.session_state.messages = []
-        if "analyzed_documents" not in st.session_state:
-            st.session_state.analyzed_documents = []
-        if "context_chunks" not in st.session_state:
-            st.session_state.context_chunks = []
-        if "current_case" not in st.session_state:
-            st.session_state.current_case = None
+            st.error(f"Error generating response: {str(e)}")
+            return "An error occurred while processing your query.", ""
 
     def _generate_messages(self, prompt: str, context_chunks: List[Dict]) -> List[Dict]:
-        """Generate messages for the Claude API with enhanced prompts."""
+        """Generate structured messages for LLM input."""
         # Get case metadata if available
         case_metadata = None
         if st.session_state.current_case:
             case_metadata = self.case_manager.get_case(st.session_state.current_case)
-
-        # Generate enhanced system message
+
+        # Generate system message
         system_message = self.prompt_generator.generate_system_message(
             context_chunks=context_chunks,
             query=prompt,
             case_metadata=case_metadata
         )
-
-        # Prepare context from chunks
+
+        # Generate user message
         context = "\n".join([
-            f"Document: {chunk['metadata']['title']}\n"
+            f"Document: {chunk['metadata'].get('title', 'Untitled')}\n"
             f"Section: {chunk['text']}\n"
-            f"Type: {chunk['metadata']['type']}\n"
-            f"Jurisdiction: {chunk['metadata']['jurisdiction']}\n"
             for chunk in context_chunks
         ])
-
-        # Generate user message
         user_message = self.prompt_generator.generate_user_message(prompt, context)
-
-        # Check if this is a follow-up question
+
+        # Handle follow-up questions
         if st.session_state.messages:
             previous_query = next(
-                (m["content"] for m in reversed(st.session_state.messages)
-                 if m["role"] == "user"),
+                (m["content"] for m in reversed(st.session_state.messages) if m["role"] == "user"),
                 None
             )
             previous_response = next(
-                (m["content"] for m in reversed(st.session_state.messages)
-                 if m["role"] == "assistant"),
+                (m["content"] for m in reversed(st.session_state.messages) if m["role"] == "assistant"),
                 None
             )
             if previous_query and previous_response:
@@ -217,19 +179,19 @@
                     previous_response=previous_response,
                     context_chunks=context_chunks
                 )
-
+
         return [
             {"role": "system", "content": system_message},
             {"role": "user", "content": user_message}
         ]
 
     def _format_references(self, chunks: List[Dict]) -> str:
-        """Format reference citations in HTML."""
+        """Format references as HTML for display."""
         references = []
         for i, chunk in enumerate(chunks, 1):
             references.append(f"""
                 <div class="document-chunk">
-                    <strong>Reference {i}:</strong> {chunk['metadata']['title']}
+                    <strong>Reference {i}:</strong> {chunk['metadata'].get('title', 'Untitled')}
                     <br/>
                     <em>Section:</em> {chunk['text'][:200]}...
                 </div>
@@ -237,7 +199,7 @@
         return "\n".join(references)
 
     def add_analyzed_document(self, doc: Dict):
-        """Add a document with improved metadata tracking."""
+        """Add a document to session state with metadata tracking."""
         doc['metadata']['added_at'] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
         if doc not in st.session_state.analyzed_documents:
-            st.session_state.analyzed_documents.append(doc)
+            st.session_state.analyzed_documents.append(doc)
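A note on the new `generate_response`: Anthropic's Messages API only accepts `user` and `assistant` roles inside `messages`, so the `{"role": "system", ...}` entry returned by `_generate_messages` belongs in the top-level `system` parameter instead, and `"claude-3"` is not a concrete model identifier (the removed code pinned `claude-3-sonnet-20240229`). A minimal sketch of the expected call shape, using placeholder message content:

```python
import os

import anthropic

client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))

# Placeholder messages in the same shape _generate_messages returns.
raw_messages = [
    {"role": "system", "content": "You are a legal research assistant."},
    {"role": "user", "content": "Question: ...\n\nContext:\n..."},
]

# Hoist the system entry out of the list; the SDK only allows
# "user"/"assistant" roles inside `messages`.
system_message = next((m["content"] for m in raw_messages if m["role"] == "system"), "")
chat_messages = [m for m in raw_messages if m["role"] != "system"]

response = client.messages.create(
    model="claude-3-sonnet-20240229",  # a pinned model id, as in the removed code
    max_tokens=2000,
    temperature=0.7,
    system=system_message,  # system prompt is a top-level parameter
    messages=chat_messages,
)
print(response.content[0].text)
```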
 
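For context, a hedged sketch of how a caller would wire up the class after this commit, given that `__init__` now takes `case_manager` first and constructs the `LegalPromptGenerator` itself. The `CaseManager`, `VectorStore`, and `DocumentProcessor` import paths are assumptions; this diff does not show them:

```python
# Hypothetical wiring: the three component classes and their module paths
# are stand-ins for the app's real ones, which this diff does not show.
from components.chat_interface import ChatInterface
from utils.case_manager import CaseManager                # assumed path
from utils.vector_store import VectorStore                # assumed path
from utils.document_processor import DocumentProcessor    # assumed path

case_manager = CaseManager()
vector_store = VectorStore()
document_processor = DocumentProcessor()

# After this commit, ChatInterface requires all three components and reads
# ANTHROPIC_API_KEY from the environment during construction.
chat = ChatInterface(case_manager, vector_store, document_processor)
chat.render()
```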