shahzad4894 committed on
Commit
b4021de
·
verified ·
1 Parent(s): 39270a1

add the chatbot code

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +847 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,849 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import os
3
+ from langchain_openai import OpenAIEmbeddings, ChatOpenAI
4
+ from langchain_community.vectorstores import SupabaseVectorStore
5
+ from langchain.chains import RetrievalQA
6
+ from supabase import create_client
7
+ from langchain.prompts import PromptTemplate
8
+ from langchain.agents import Tool, create_react_agent
9
+ from langchain.tools.retriever import create_retriever_tool
10
+ from langchain.memory import ConversationSummaryBufferMemory
11
+ from langchain.agents import AgentExecutor
12
+ from langchain.schema import HumanMessage, AIMessage
13
+ from langchain.cache import InMemoryCache
14
+ from langchain.globals import set_llm_cache
15
+ from langchain.retrievers import ContextualCompressionRetriever
16
+ from langchain.retrievers.document_compressors import LLMChainExtractor
17
+ import uuid
18
+ from datetime import datetime
19
+ import json
20
+ import time
21
+ from collections import defaultdict
22
+ from tenacity import retry, stop_after_attempt, wait_exponential
23
 
24
# Page configuration
st.set_page_config(
    page_title="AI Document Assistant",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded",
)


@st.cache_resource
def _get_llm_cache():
    """Return the single LLM response cache shared by every script run.

    Streamlit re-executes this script from the top on each interaction, so
    building a new ``InMemoryCache`` at module level would discard all cached
    responses on every rerun. ``st.cache_resource`` keeps one instance alive.
    """
    return InMemoryCache()


# Enable LLM caching for faster responses
set_llm_cache(_get_llm_cache())
34
+
35
# Custom CSS for professional design.
# The sidebar rule keeps the legacy generated class name and adds the
# data-testid selector, because generated class names are not stable.
st.markdown("""
<style>
/* Import clean font */
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600&display=swap');

/* Global styles */
* {
    font-family: 'Inter', sans-serif;
}

/* Remove default padding/margins */
.main > div {
    padding-top: 2rem;
}

/* Header styling */
.header-container {
    background: #ffffff;
    border-bottom: 1px solid #e5e7eb;
    padding: 1.5rem 0;
    margin-bottom: 0;
    position: sticky;
    top: 0;
    z-index: 100;
}

.header-title {
    font-size: 1.5rem;
    font-weight: 600;
    color: #111827;
    margin: 0;
}

.header-subtitle {
    color: #6b7280;
    font-size: 0.875rem;
    margin: 0.25rem 0 0 0;
}

/* Sidebar styling */
.css-1d391kg,
section[data-testid="stSidebar"] {
    background-color: #f9fafb;
}

.sidebar-title {
    font-weight: 600;
    color: #374151;
    margin-bottom: 1rem;
}

/* Session buttons */
.session-btn {
    background: white;
    border: 1px solid #e5e7eb;
    border-radius: 8px;
    padding: 12px;
    margin: 6px 0;
    width: 100%;
    text-align: left;
    cursor: pointer;
    transition: all 0.2s;
    color: #374151;
}

.session-btn:hover {
    border-color: #3b82f6;
    background: #f8fafc;
}

.session-btn.active {
    background: #eff6ff;
    border-color: #3b82f6;
    color: #1d4ed8;
}

/* Chat container */
.chat-container {
    background: #ffffff;
    border: 1px solid #e5e7eb;
    border-radius: 12px;
    height: 500px;
    overflow-y: auto;
    padding: 1rem;
    margin-bottom: 1rem;
}

/* Message styling */
.message {
    margin-bottom: 1rem;
    display: flex;
}

.message.user {
    justify-content: flex-end;
}

.message-content {
    max-width: 70%;
    padding: 12px 16px;
    border-radius: 12px;
    line-height: 1.5;
}

.message.user .message-content {
    background: #3b82f6;
    color: white;
    border-bottom-right-radius: 4px;
}

.message.bot .message-content {
    background: #f3f4f6;
    color: #374151;
    border: 1px solid #e5e7eb;
    border-bottom-left-radius: 4px;
}

.message-label {
    font-size: 0.75rem;
    font-weight: 500;
    margin-bottom: 4px;
    opacity: 0.7;
}

.message-tools {
    font-size: 0.75rem;
    opacity: 0.6;
    margin-top: 4px;
}

/* Input area */
.input-container {
    background: white;
    border: 1px solid #e5e7eb;
    border-radius: 12px;
    padding: 1rem;
}

/* Buttons */
.stButton > button {
    background: #3b82f6;
    color: white;
    border: none;
    border-radius: 8px;
    font-weight: 500;
    padding: 0.5rem 1rem;
    transition: background 0.2s;
}

.stButton > button:hover {
    background: #2563eb;
}

/* Status indicators */
.status {
    font-size: 0.875rem;
    padding: 4px 8px;
    border-radius: 6px;
    font-weight: 500;
}

.status.connected {
    background: #dcfce7;
    color: #166534;
}

.status.error {
    background: #fee2e2;
    color: #dc2626;
}

/* Thinking indicator */
.thinking {
    background: #f3f4f6;
    padding: 8px 12px;
    border-radius: 8px;
    color: #6b7280;
    font-size: 0.875rem;
    margin-bottom: 1rem;
    display: inline-block;
}

/* Hide streamlit elements */
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
header {visibility: hidden;}

/* Custom scrollbar */
.chat-container::-webkit-scrollbar {
    width: 6px;
}

.chat-container::-webkit-scrollbar-track {
    background: #f1f5f9;
    border-radius: 3px;
}

.chat-container::-webkit-scrollbar-thumb {
    background: #cbd5e1;
    border-radius: 3px;
}

.chat-container::-webkit-scrollbar-thumb:hover {
    background: #94a3b8;
}
</style>
""", unsafe_allow_html=True)
242
+
243
+ # Rate Limiter Class
244
class RateLimiter:
    """Sliding-window request limiter keyed by session id.

    Each session may make at most ``max_requests`` calls within any
    ``time_window``-second span.
    """

    def __init__(self, max_requests=10, time_window=60):
        # session id -> timestamps of the requests still inside the window
        self.requests = defaultdict(list)
        self.max_requests = max_requests
        self.time_window = time_window

    def check_limit(self, session_id):
        """Record a request for *session_id* if it is within the limit.

        Returns:
            ``(True, "")`` when the request is allowed (and recorded), or
            ``(False, message)`` when the session has exhausted its quota.
        """
        # A monotonic clock keeps the window correct even if the system
        # wall clock is adjusted while the app is running.
        now = time.monotonic()

        # Drop timestamps that have aged out of the window.
        recent = [t for t in self.requests[session_id] if now - t < self.time_window]
        self.requests[session_id] = recent

        if len(recent) >= self.max_requests:
            return False, "Rate limit exceeded. Please wait before sending more messages."

        recent.append(now)
        return True, ""
263
+
264
# Initialize session state (runs once per browser session; later reruns
# find 'initialized' already present and skip this block).
if 'initialized' not in st.session_state:
    st.session_state.initialized = False
    # agent/tools are read when building a per-session executor, so give
    # them explicit defaults alongside the other keys.
    st.session_state.agent = None
    st.session_state.tools = None
    st.session_state.agent_executor = None
    st.session_state.chat_sessions = {}
    st.session_state.current_session_id = None
    st.session_state.connection_status = "Not Connected"
    st.session_state.sidebar_collapsed = False
    st.session_state.rate_limiter = RateLimiter(max_requests=20, time_window=60)
    st.session_state.supabase = None

# Keys configuration (each is None when the environment variable is unset)
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_KEY")
279
+
280
def validate_input(user_input: str) -> tuple:
    """Check that a chat message is acceptable before sending it to the agent.

    Both length limits are applied to the whitespace-stripped text, so
    leading/trailing blanks never count against the 2000-character cap.

    Args:
        user_input: Raw text from the chat form. Non-string values
            (including ``None``) are treated as empty input.

    Returns:
        ``(True, "")`` when the input is valid, otherwise
        ``(False, reason)`` with a user-facing explanation.
    """
    text = user_input.strip() if isinstance(user_input, str) else ""

    if len(text) < 3:
        return False, "Query too short. Please provide more details (at least 3 characters)."

    if len(text) > 2000:
        return False, "Query too long. Please keep it under 2000 characters."

    # Check for potentially dangerous patterns (simple substring screen,
    # case-insensitive; lower-case the text once rather than per pattern).
    dangerous_patterns = ('__import__', 'exec(', 'eval(', 'os.system', 'subprocess')
    lowered = text.lower()
    if any(pattern in lowered for pattern in dangerous_patterns):
        return False, "Invalid input detected. Please rephrase your question."

    return True, ""
294
+
295
def save_session_to_db(session_id, session_data):
    """Upsert one chat session into the ``chat_sessions`` table.

    Does nothing when no Supabase client is stored in session state. Any
    failure is reported with ``st.warning`` instead of being raised.

    Args:
        session_id: Primary key of the session row.
        session_data: Session dict providing ``name``, ``created_at`` and
            ``messages``.
    """

    def _as_json_ready(message):
        # Work on a copy so the in-memory message keeps its datetime object.
        clone = message.copy()
        if 'timestamp' in clone:
            clone['timestamp'] = clone['timestamp'].isoformat()
        return clone

    try:
        client = st.session_state.supabase
        if client is None:
            return

        encoded_messages = [_as_json_ready(entry) for entry in session_data['messages']]

        row = {
            'id': session_id,
            'name': session_data['name'],
            'created_at': session_data['created_at'].isoformat(),
            'messages': json.dumps(encoded_messages),
            'updated_at': datetime.now().isoformat()
        }
        client.table('chat_sessions').upsert(row).execute()
    except Exception as e:
        st.warning(f"Could not save session to database: {str(e)}")
318
+
319
def load_sessions_from_db():
    """Load all stored chat sessions, newest first.

    Returns:
        Dict mapping session id to a session dict with keys ``id``,
        ``name``, ``created_at``, ``messages``, ``session_memory`` and
        ``history``. Returns ``{}`` when no Supabase client is available or
        the query fails. A row that cannot be decoded is skipped with a
        warning rather than discarding every other session.
    """

    def _parse_timestamp(value):
        # datetime.fromisoformat rejects a trailing "Z" before Python 3.11.
        if value.endswith('Z'):
            value = value[:-1] + '+00:00'
        return datetime.fromisoformat(value)

    try:
        if st.session_state.supabase is None:
            return {}

        response = (
            st.session_state.supabase.table('chat_sessions')
            .select('*')
            .order('created_at', desc=True)
            .execute()
        )
    except Exception as e:
        st.warning(f"Could not load sessions from database: {str(e)}")
        return {}

    sessions = {}
    for row in response.data or []:
        try:
            session_id = row['id']

            # The column may hold a JSON string (as written by
            # save_session_to_db) or an already-decoded list.
            raw_messages = row.get('messages')
            if isinstance(raw_messages, str):
                messages = json.loads(raw_messages) if raw_messages else []
            else:
                messages = list(raw_messages or [])

            # Convert timestamp strings back to datetime
            for msg in messages:
                if isinstance(msg.get('timestamp'), str):
                    msg['timestamp'] = _parse_timestamp(msg['timestamp'])

            # Rebuild session memory from messages: anything that is not a
            # user message is treated as an assistant message.
            session_memory = [
                HumanMessage(content=msg.get('content', ''))
                if msg.get('type') == 'user'
                else AIMessage(content=msg.get('content', ''))
                for msg in messages
            ]

            sessions[session_id] = {
                'id': session_id,
                'name': row['name'],
                'created_at': _parse_timestamp(row['created_at']),
                'messages': messages,
                'session_memory': session_memory,
                'history': []
            }
        except Exception as e:
            st.warning(f"Skipping unreadable session {row.get('id')!r}: {str(e)}")

    return sessions
357
+
358
@st.cache_resource
def _build_agent():
    """Build the Supabase client, tools and ReAct agent (cached on success).

    Raises on any failure so that ``st.cache_resource`` does not store a
    failed attempt; only a successfully built agent is cached.

    Returns:
        Tuple ``(agent, tools, supabase_client)``.
    """
    missing = [
        name
        for name, value in (
            ("OPENAI_API_KEY", OPENAI_API_KEY),
            ("SUPABASE_URL", SUPABASE_URL),
            ("SUPABASE_KEY", SUPABASE_KEY),
        )
        if not value
    ]
    if missing:
        raise ValueError(f"Missing environment variable(s): {', '.join(missing)}")

    # Connect to Supabase
    supabase = create_client(SUPABASE_URL, SUPABASE_KEY)
    embeddings = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)

    # Reconnect to existing vector store
    vector_store = SupabaseVectorStore(
        client=supabase,
        embedding=embeddings,
        table_name="documents"
    )

    # LLM setup (streaming disabled)
    llm = ChatOpenAI(
        model="gpt-4o-mini",
        temperature=0,
        openai_api_key=OPENAI_API_KEY,
        streaming=False
    )

    # Base retriever: top-5 similarity search
    base_retriever = vector_store.as_retriever(
        search_type="similarity",
        search_kwargs={"k": 5}
    )

    # NOTE(review): the original also built a ContextualCompressionRetriever
    # here but never passed it to the QA chain; the unused object is dropped.
    # Wire it into `retriever=` below if compression was actually intended.

    # QA Chain for better answers
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=base_retriever,
        return_source_documents=True
    )

    def qa_with_sources(query):
        """Answer *query* from the document store; errors become text."""
        try:
            result = qa_chain.invoke({"query": query})
            return result["result"]
        except Exception as e:
            return f"Error retrieving information: {str(e)}"

    def _ask_llm(prompt):
        """Send *prompt* straight to the LLM; errors become text."""
        try:
            return llm.invoke(prompt).content
        except Exception as e:
            return f"Error: {str(e)}"

    def general_qa(query):
        """General question answering"""
        return _ask_llm(query)

    def summarize_text(text):
        """Summarize text"""
        return _ask_llm(f"Summarize the following concisely:\n\n{text}")

    def explain_concept(concept):
        """Explain concepts"""
        return _ask_llm(f"Explain clearly:\n\n{concept}")

    # Document search tool
    retriever_tool = Tool(
        name="document_search",
        func=qa_with_sources,
        description="Search and answer questions based on uploaded documents. Use this for ANY question about companies, acquisitions, financial data, or specific information that might be in the documents.",
    )

    # General QA tool
    qa_tool = Tool(
        name="general_question",
        func=general_qa,
        description="Answer general knowledge questions NOT related to the uploaded documents.",
    )

    # Summary tool
    summary_tool = Tool(
        name="summarize",
        func=summarize_text,
        description="Summarize text or information.",
    )

    # Explanation tool
    explanation_tool = Tool(
        name="explain",
        func=explain_concept,
        description="Explain concepts or ideas in detail.",
    )

    tools = [retriever_tool, qa_tool, summary_tool, explanation_tool]
    tool_names = ", ".join(tool.name for tool in tools)

    # Custom ReAct prompt
    react_prompt = PromptTemplate.from_template(
        """Answer the following question as best you can. You have access to the following tools:

{tools}

Use this format STRICTLY:

Thought: Think about what needs to be done
Action: The exact tool name from [{tool_names}]
Action Input: The specific input for the tool
Observation: The result of the action
... (repeat Thought/Action/Action Input/Observation as needed)
Thought: I now know the final answer
Final Answer: Provide a clear, complete answer

IMPORTANT:
1. For questions about documents, companies, or data, ALWAYS use "document_search" FIRST
2. Action Input should be the question/text only - no quotes or special formatting
3. Always provide a Final Answer

Previous conversation:
{chat_history}

Question: {input}
{agent_scratchpad}"""
    ).partial(
        tools="\n".join(f"{tool.name}: {tool.description}" for tool in tools),
        tool_names=tool_names
    )

    # Create agent
    custom_agent = create_react_agent(llm=llm, tools=tools, prompt=react_prompt)

    return custom_agent, tools, supabase


def initialize_agent():
    """Initialize the LangChain agent, reusing the cached build when present.

    Returns:
        ``(agent, tools, supabase_client, "Connected Successfully")`` on
        success, or ``(None, None, None, "Connection Error: ...")`` on
        failure. Failures are not cached, so a later call retries.
    """
    try:
        agent, tools, supabase = _build_agent()
    except Exception as e:
        return None, None, None, f"Connection Error: {str(e)}"
    return agent, tools, supabase, "Connected Successfully"


# Keep the cache-clearing hook that the decorated function used to expose.
initialize_agent.clear = _build_agent.clear
505
def create_new_session():
    """Create an empty chat session, make it current, and persist it.

    Returns:
        The new session's id (a UUID4 string).
    """
    new_id = str(uuid.uuid4())
    all_sessions = st.session_state.chat_sessions

    # Name is based on how many sessions exist at creation time.
    record = {
        "id": new_id,
        "name": f"Chat {len(all_sessions) + 1}",
        "created_at": datetime.now(),
        "messages": [],
        "session_memory": [],
        "history": []
    }
    all_sessions[new_id] = record
    st.session_state.current_session_id = new_id

    # Save to database
    save_session_to_db(new_id, record)

    return new_id
526
+
527
def get_recent_context(session_data, max_messages=10):
    """Return the most recent part of a session's memory as a new list.

    Args:
        session_data: Session dict with a ``session_memory`` list.
        max_messages: Number of exchanges to keep; up to
            ``max_messages * 2`` entries are returned. Values <= 0 yield an
            empty list.

    Returns:
        A fresh list (never the stored list itself), so callers can mutate
        it without altering the session's history.
    """
    if max_messages <= 0:
        # A slice of [-0:] would return the entire history, not nothing.
        return []
    return list(session_data["session_memory"][-max_messages * 2:])
531
+
532
def get_agent_executor_for_session(session_id):
    """Build an AgentExecutor whose memory is seeded from one chat session.

    Args:
        session_id: Key into ``st.session_state.chat_sessions``.

    Returns:
        A configured ``AgentExecutor``, or ``None`` when the agent has not
        been initialized yet.

    Raises:
        KeyError: if *session_id* is not a known session.
    """
    if not st.session_state.initialized:
        return None

    session_data = st.session_state.chat_sessions[session_id]

    # Get recent context to avoid overwhelming the model
    recent_memory = get_recent_context(session_data, max_messages=8)

    # Create summary buffer memory for this session
    memory = ConversationSummaryBufferMemory(
        llm=ChatOpenAI(model="gpt-4o-mini", openai_api_key=OPENAI_API_KEY),
        memory_key="chat_history",
        return_messages=True,
        output_key="output",
        max_token_limit=1000
    )

    # Restore recent session memory. Hand the memory its own copy: it
    # modifies its buffer in place, and sharing the list would let those
    # changes leak into the session's stored history.
    memory.chat_memory.messages = list(recent_memory)

    # Create agent executor
    agent_executor = AgentExecutor(
        agent=st.session_state.agent,
        tools=st.session_state.tools,
        memory=memory,
        verbose=True,
        handle_parsing_errors="Check your output and make sure it follows the correct format.",
        return_intermediate_steps=True,
        max_iterations=5,
        max_execution_time=45,
    )

    return agent_executor
567
+
568
+
569
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=2, max=10),
    reraise=True,  # surface the agent's own exception, not tenacity's RetryError
)
def get_agent_response(agent_executor, user_input):
    """Invoke the agent, retrying up to three times with exponential backoff.

    Args:
        agent_executor: Executor to call.
        user_input: The user's question.

    Returns:
        The dict returned by ``agent_executor.invoke``.

    Raises:
        Whatever the final attempt raised, once all attempts are exhausted.
    """
    return agent_executor.invoke({"input": user_input})
573
+
574
def get_response_with_fallback(agent_executor, user_input):
    """Answer *user_input*, degrading through three strategies.

    1. The agent with retries (``get_agent_response``).
    2. The agent once more with a "please answer briefly" prefix.
    3. A direct LLM call with no tools or memory.

    Returns:
        A dict with at least an ``"output"`` key; the direct-LLM fallback
        also supplies an empty ``"intermediate_steps"`` list.

    Raises:
        Exception: when all three strategies fail; the last underlying
        error is attached as ``__cause__``.
    """
    # Primary attempt
    try:
        return get_agent_response(agent_executor, user_input)
    except Exception:
        st.warning("Primary attempt failed, trying simplified approach...")

    # Fallback 1: Try with simpler prompt
    try:
        simplified_input = f"Please answer briefly: {user_input}"
        return agent_executor.invoke({"input": simplified_input})
    except Exception:
        st.warning("Simplified approach failed, using direct LLM...")

    # Fallback 2: Direct LLM call without tools
    try:
        llm = ChatOpenAI(model="gpt-4o-mini", openai_api_key=OPENAI_API_KEY)
        response_content = llm.invoke(user_input).content
        return {"output": response_content, "intermediate_steps": []}
    except Exception as e3:
        raise Exception(f"All attempts failed: {str(e3)}") from e3
594
+
595
def track_metrics(session_data):
    """Compute simple conversation statistics for one session.

    Args:
        session_data: Session dict with a ``messages`` list; each message
            has a ``type`` and normally a ``timestamp`` datetime.

    Returns:
        Dict with ``total_messages``, ``user_messages``, ``bot_messages``
        and ``session_duration`` (whole seconds between the first and last
        timestamped message, 0 when fewer than one timestamp exists).
    """
    messages = session_data["messages"]
    total_messages = len(messages)
    user_messages = sum(1 for m in messages if m.get("type") == "user")
    bot_messages = total_messages - user_messages

    # Calculate session duration. total_seconds() includes whole days,
    # whereas timedelta.seconds silently drops them.
    timestamps = [m["timestamp"] for m in messages if m.get("timestamp") is not None]
    if timestamps:
        duration = max(0, int((timestamps[-1] - timestamps[0]).total_seconds()))
    else:
        duration = 0

    return {
        "total_messages": total_messages,
        "user_messages": user_messages,
        "bot_messages": bot_messages,
        "session_duration": duration
    }
615
+
616
def _escape_html(value):
    """Return *value* as text that is safe to embed in raw HTML markup."""
    import html  # local import keeps this block self-contained

    return html.escape(str(value))


def _ensure_agent_initialized():
    """Initialize the agent once and load stored sessions on success."""
    if st.session_state.initialized:
        return

    with st.spinner("Initializing AI Agent..."):
        agent, tools, supabase, status = initialize_agent()
        st.session_state.connection_status = status
        if agent is None:
            return

        st.session_state.agent = agent
        st.session_state.tools = tools
        st.session_state.supabase = supabase
        st.session_state.initialized = True

        # Load existing sessions from database; the first one becomes current.
        loaded_sessions = load_sessions_from_db()
        if loaded_sessions:
            st.session_state.chat_sessions = loaded_sessions
            st.session_state.current_session_id = next(iter(loaded_sessions))


def _render_sidebar():
    """Draw connection status, the session list and per-session actions."""
    with st.sidebar:
        st.markdown('<p class="sidebar-title">💬 Chat Sessions</p>', unsafe_allow_html=True)

        # Connection status. The status text may contain an exception
        # message, so it is escaped before being placed in HTML.
        connected = st.session_state.connection_status == "Connected Successfully"
        status_class = "connected" if connected else "error"
        if connected:
            status_text = "🟢 Connected"
        else:
            status_text = f"🔴 {_escape_html(st.session_state.connection_status)}"
        st.markdown(f'<div class="status {status_class}">{status_text}</div>', unsafe_allow_html=True)

        st.markdown("---")

        # New session button
        if st.button("+ New Chat", use_container_width=True):
            create_new_session()
            st.rerun()

        # Display sessions
        for session_id, session_data in st.session_state.chat_sessions.items():
            if st.button(
                f"{session_data['name']}\n{len(session_data['messages'])} messages",
                key=f"session_{session_id}",
                use_container_width=True
            ):
                st.session_state.current_session_id = session_id
                st.rerun()

        # Session actions (skipped when the current id is unset or stale)
        current_id = st.session_state.current_session_id
        if not current_id or current_id not in st.session_state.chat_sessions:
            return
        current = st.session_state.chat_sessions[current_id]

        st.markdown("---")

        # Rename session
        new_name = st.text_input("Session Name:", value=current["name"])
        if st.button("Save Name", key="save_name"):
            current["name"] = new_name
            save_session_to_db(current_id, current)
            st.success("Name updated!")
            st.rerun()

        # Show session metrics
        if current["messages"]:
            metrics = track_metrics(current)
            st.markdown("---")
            st.markdown("**Session Stats:**")
            st.text(f"📊 Messages: {metrics['total_messages']}")
            st.text(f"⏱️ Duration: {metrics['session_duration']}s")

        # Delete session (the last remaining session cannot be deleted)
        if len(st.session_state.chat_sessions) > 1:
            st.markdown("---")
            if st.button("🗑️ Delete Chat", key="delete_session"):
                # Delete from database; a failure here is reported but does
                # not block removing the session from this browser session.
                if st.session_state.supabase:
                    try:
                        st.session_state.supabase.table('chat_sessions').delete().eq('id', current_id).execute()
                    except Exception as e:
                        st.warning(f"Could not delete session from database: {str(e)}")

                del st.session_state.chat_sessions[current_id]
                st.session_state.current_session_id = next(iter(st.session_state.chat_sessions))
                st.rerun()


def _render_chat_history(session):
    """Render every stored message of *session*, or a prompt when empty."""
    if not session["messages"]:
        st.markdown(
            '<div style="text-align: center; color: #6b7280; padding: 2rem;">\n'
            '👋 Start a conversation by asking a question about your documents\n'
            '</div>',
            unsafe_allow_html=True,
        )
        return

    for message in session["messages"]:
        # Message text comes from the user or the model, so escape it
        # before embedding it in markup rendered with unsafe_allow_html.
        content = _escape_html(message["content"])

        if message["type"] == "user":
            st.markdown(
                '<div class="message user">\n'
                '<div class="message-content">\n'
                '<div class="message-label">You</div>\n'
                f'{content}\n'
                '</div>\n'
                '</div>',
                unsafe_allow_html=True,
            )
            continue

        tools_info = ""
        if message.get('tools_used'):
            tool_list = ", ".join(_escape_html(name) for name in message["tools_used"])
            tools_info = f'<div class="message-tools">🔧 Tools: {tool_list}</div>'

        sources_info = ""
        if message.get('sources'):
            sources_info = f'<div class="message-tools">📚 Sources: {len(message["sources"])} documents</div>'

        st.markdown(
            '<div class="message bot">\n'
            '<div class="message-content">\n'
            '<div class="message-label">Assistant</div>\n'
            f'{content}\n'
            f'{tools_info}\n'
            f'{sources_info}\n'
            '</div>\n'
            '</div>',
            unsafe_allow_html=True,
        )


def _handle_submission(current_session, user_input):
    """Validate a submitted message, run the agent and record the exchange."""
    # Validate input
    is_valid, error_msg = validate_input(user_input)
    if not is_valid:
        st.error(error_msg)
        return

    # Check rate limit
    can_proceed, rate_limit_msg = st.session_state.rate_limiter.check_limit(st.session_state.current_session_id)
    if not can_proceed:
        st.error(rate_limit_msg)
        return

    # Add user message to session
    current_session["messages"].append({
        "type": "user",
        "content": user_input,
        "timestamp": datetime.now()
    })
    current_session["session_memory"].append(HumanMessage(content=user_input))

    # Show thinking indicator
    thinking_placeholder = st.empty()
    thinking_placeholder.markdown('<div class="thinking">🤔 Thinking...</div>', unsafe_allow_html=True)

    try:
        # Get agent executor for current session
        agent_executor = get_agent_executor_for_session(st.session_state.current_session_id)

        # Get response from agent with fallback
        response = get_response_with_fallback(agent_executor, user_input)
        answer = response["output"]

        # Extract tools used, de-duplicated in first-use order
        tools_used = list(dict.fromkeys(
            step[0].tool
            for step in response.get("intermediate_steps", [])
            if len(step) > 0 and hasattr(step[0], 'tool')
        ))

        # Add bot message to session
        current_session["messages"].append({
            "type": "bot",
            "content": answer,
            "timestamp": datetime.now(),
            "tools_used": tools_used
        })
        current_session["session_memory"].append(AIMessage(content=answer))

        # Save session to database
        save_session_to_db(st.session_state.current_session_id, current_session)

    except Exception as e:
        error_message = {
            "type": "bot",
            "content": f"❌ I encountered an error processing your request. Please try rephrasing your question or try again later.\n\nError: {str(e)}",
            "timestamp": datetime.now()
        }
        current_session["messages"].append(error_message)
        current_session["session_memory"].append(AIMessage(content=error_message["content"]))

        # Persist the failed turn too, so the user's message is not lost
        # if the browser session ends before the next successful turn.
        save_session_to_db(st.session_state.current_session_id, current_session)

    finally:
        thinking_placeholder.empty()
        st.rerun()


def main():
    """Render the chat application: header, sidebar, history and input form."""
    # Header
    st.markdown(
        '<div class="header-container">\n'
        '<h1 class="header-title">🤖 AI Document Assistant</h1>\n'
        '<p class="header-subtitle">Intelligent document analysis powered by LangChain</p>\n'
        '</div>',
        unsafe_allow_html=True,
    )

    # Initialize agent if not done
    _ensure_agent_initialized()

    # Sidebar for session management
    _render_sidebar()

    # Main content
    if not st.session_state.initialized:
        st.error("⚠️ Agent initialization failed. Please check your configuration.")
        return

    # Create default session if none exists
    if not st.session_state.chat_sessions:
        create_new_session()

    # Ensure current session exists
    if st.session_state.current_session_id not in st.session_state.chat_sessions:
        st.session_state.current_session_id = next(iter(st.session_state.chat_sessions))

    current_session = st.session_state.chat_sessions[st.session_state.current_session_id]

    # Chat messages display
    _render_chat_history(current_session)

    # Input area
    with st.form("chat_form", clear_on_submit=True):
        col1, col2 = st.columns([5, 1])
        with col1:
            user_input = st.text_area(
                "Message",
                placeholder="Ask a question about your documents...",
                height=80,
                label_visibility="collapsed"
            )
        with col2:
            st.markdown("<div style='height: 20px;'></div>", unsafe_allow_html=True)
            submit_button = st.form_submit_button("Send", use_container_width=True)

    # Process user input
    if submit_button and user_input.strip():
        _handle_submission(current_session, user_input)


if __name__ == "__main__":
    main()