cryogenic22 committed on
Commit
10c877a
·
verified ·
1 Parent(s): 3ad6907

Update utils/database.py

Browse files
Files changed (1) hide show
  1. utils/database.py +76 -5
utils/database.py CHANGED
@@ -14,6 +14,9 @@ from langchain.chat_models import ChatOpenAI
14
  from langchain.agents import AgentExecutor, Tool, create_openai_tools_agent
15
  from langchain.agents.format_scratchpad.tools import format_to_tool_messages
16
  from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
 
 
 
17
 
18
  import os
19
  import streamlit as st
@@ -76,7 +79,28 @@ def create_tables(conn):
76
  except Error as e:
77
  st.error(f"Error: {e}")
78
 
79
- # Add this function to your database.py file
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  def get_documents(conn):
82
  """Retrieve all documents from the database.
@@ -179,10 +203,50 @@ def handle_document_upload(uploaded_files):
179
  return
180
  progress_bar.progress(10)
181
 
182
- # Process documents
 
183
  documents = []
184
  document_names = []
185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  # Calculate progress steps per file
187
  progress_per_file = 70 / len(uploaded_files) # 70% of progress for file processing
188
  current_progress = 10
@@ -281,6 +345,7 @@ def handle_document_upload(uploaded_files):
281
  st.session_state.vector_store = None
282
  st.session_state.qa_system = None
283
  st.session_state.chat_ready = False
 
284
 
285
  finally:
286
  # Clean up progress display after 5 seconds if successful
@@ -333,16 +398,22 @@ def display_vector_store_info():
333
 
334
 
335
  def initialize_qa_system(vector_store):
336
- """Initialize QA system with proper chat handling."""
337
  try:
338
  llm = ChatOpenAI(
339
  temperature=0.5,
340
  model_name="gpt-4",
 
341
  api_key=os.environ.get("OPENAI_API_KEY")
342
  )
343
 
344
- # Create retriever function
345
- retriever = vector_store.as_retriever(search_kwargs={"k": 2})
 
 
 
 
 
346
 
347
  # Create a template that enforces clean formatting
348
  prompt = ChatPromptTemplate.from_messages([
 
14
  from langchain.agents import AgentExecutor, Tool, create_openai_tools_agent
15
  from langchain.agents.format_scratchpad.tools import format_to_tool_messages
16
  from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
17
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
18
+ from langchain_community.document_loaders import PyPDFLoader
19
+ from langchain.vectorstores import FAISS
20
 
21
  import os
22
  import streamlit as st
 
79
  except Error as e:
80
  st.error(f"Error: {e}")
81
 
82
+
83
def process_document(file_path, chunk_size=1000, chunk_overlap=200):
    """Load a PDF and return its chunks together with the full extracted text.

    Args:
        file_path: Path of the PDF file passed to ``PyPDFLoader``.
        chunk_size: Maximum length of each chunk, measured with ``len``.
            Defaults to 1000, the value that used to be hard-coded.
        chunk_overlap: Number of characters shared between consecutive
            chunks. Defaults to 200, the value that used to be hard-coded.

    Returns:
        A ``(chunks, full_content)`` tuple. ``chunks`` is the result of
        splitting the loaded documents with the text splitter;
        ``full_content`` is the ``page_content`` of every loaded document
        joined with newlines, intended for database storage.
    """
    # Load PDF
    loader = PyPDFLoader(file_path)
    documents = loader.load()

    # Create text splitter; separators are tried in order, from paragraph
    # breaks down to single characters.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        separators=["\n\n", "\n", " ", ""],
    )

    # Split documents into chunks
    chunks = text_splitter.split_documents(documents)

    # Extract text content for database storage (built from the unsplit
    # documents so overlapping chunk text is not duplicated).
    full_content = "\n".join(doc.page_content for doc in documents)

    return chunks, full_content
104
 
105
  def get_documents(conn):
106
  """Retrieve all documents from the database.
 
203
  return
204
  progress_bar.progress(10)
205
 
206
+ # Process documents
207
+ all_chunks = []
208
  documents = []
209
  document_names = []
210
 
211
+ progress_per_file = 70 / len(uploaded_files)
212
+ current_progress = 10
213
+
214
+ for idx, uploaded_file in enumerate(uploaded_files):
215
+ file_name = uploaded_file.name
216
+ status_container.info(f"🔄 Processing document {idx + 1}/{len(uploaded_files)}: {file_name}")
217
+
218
+ # Create temporary file
219
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
220
+ tmp_file.write(uploaded_file.getvalue())
221
+ tmp_file.flush()
222
+
223
+ # Process document with chunking
224
+ chunks, content = process_document(tmp_file.name)
225
+
226
+ # Store in database
227
+ doc_id = insert_document(st.session_state.db_conn, file_name, content)
228
+ if not doc_id:
229
+ status_container.error(f"❌ Failed to store document: {file_name}")
230
+ continue
231
+
232
+ # Add chunks with metadata
233
+ for chunk in chunks:
234
+ chunk.metadata["source"] = file_name
235
+ all_chunks.extend(chunks)
236
+
237
+ documents.append(content)
238
+ document_names.append(file_name)
239
+
240
+ current_progress += progress_per_file
241
+ progress_bar.progress(int(current_progress))
242
+
243
+ # Initialize vector store with chunks instead of full documents
244
+ status_container.info("🔄 Initializing vector store...")
245
+ vector_store = FAISS.from_documents(
246
+ all_chunks,
247
+ embeddings
248
+ )
249
+
250
  # Calculate progress steps per file
251
  progress_per_file = 70 / len(uploaded_files) # 70% of progress for file processing
252
  current_progress = 10
 
345
  st.session_state.vector_store = None
346
  st.session_state.qa_system = None
347
  st.session_state.chat_ready = False
348
+ except Exception as e:
349
 
350
  finally:
351
  # Clean up progress display after 5 seconds if successful
 
398
 
399
 
400
  def initialize_qa_system(vector_store):
401
+ """Initialize QA system with optimized retrieval."""
402
  try:
403
  llm = ChatOpenAI(
404
  temperature=0.5,
405
  model_name="gpt-4",
406
+ max_tokens=4000, # Explicitly set max tokens
407
  api_key=os.environ.get("OPENAI_API_KEY")
408
  )
409
 
410
+ # Optimize retriever settings
411
+ retriever = vector_store.as_retriever(
412
+ search_kwargs={
413
+ "k": 3, # Retrieve fewer, more relevant chunks
414
+ "fetch_k": 5 # Consider more candidates before selecting top k
415
+ }
416
+ )
417
 
418
  # Create a template that enforces clean formatting
419
  prompt = ChatPromptTemplate.from_messages([