Spaces:
Build error
Build error
Update utils/database.py
Browse files — utils/database.py (+76 −5)
utils/database.py
CHANGED
|
@@ -14,6 +14,9 @@ from langchain.chat_models import ChatOpenAI
|
|
| 14 |
from langchain.agents import AgentExecutor, Tool, create_openai_tools_agent
|
| 15 |
from langchain.agents.format_scratchpad.tools import format_to_tool_messages
|
| 16 |
from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
import os
|
| 19 |
import streamlit as st
|
|
@@ -76,7 +79,28 @@ def create_tables(conn):
|
|
| 76 |
except Error as e:
|
| 77 |
st.error(f"Error: {e}")
|
| 78 |
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
def get_documents(conn):
|
| 82 |
"""Retrieve all documents from the database.
|
|
@@ -179,10 +203,50 @@ def handle_document_upload(uploaded_files):
|
|
| 179 |
return
|
| 180 |
progress_bar.progress(10)
|
| 181 |
|
| 182 |
-
|
|
|
|
| 183 |
documents = []
|
| 184 |
document_names = []
|
| 185 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
# Calculate progress steps per file
|
| 187 |
progress_per_file = 70 / len(uploaded_files) # 70% of progress for file processing
|
| 188 |
current_progress = 10
|
|
@@ -281,6 +345,7 @@ def handle_document_upload(uploaded_files):
|
|
| 281 |
st.session_state.vector_store = None
|
| 282 |
st.session_state.qa_system = None
|
| 283 |
st.session_state.chat_ready = False
|
|
|
|
| 284 |
|
| 285 |
finally:
|
| 286 |
# Clean up progress display after 5 seconds if successful
|
|
@@ -333,16 +398,22 @@ def display_vector_store_info():
|
|
| 333 |
|
| 334 |
|
| 335 |
def initialize_qa_system(vector_store):
|
| 336 |
-
"""Initialize QA system with
|
| 337 |
try:
|
| 338 |
llm = ChatOpenAI(
|
| 339 |
temperature=0.5,
|
| 340 |
model_name="gpt-4",
|
|
|
|
| 341 |
api_key=os.environ.get("OPENAI_API_KEY")
|
| 342 |
)
|
| 343 |
|
| 344 |
-
#
|
| 345 |
-
retriever = vector_store.as_retriever(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
|
| 347 |
# Create a template that enforces clean formatting
|
| 348 |
prompt = ChatPromptTemplate.from_messages([
|
|
|
|
| 14 |
from langchain.agents import AgentExecutor, Tool, create_openai_tools_agent
|
| 15 |
from langchain.agents.format_scratchpad.tools import format_to_tool_messages
|
| 16 |
from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
|
| 17 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 18 |
+
from langchain_community.document_loaders import PyPDFLoader
|
| 19 |
+
from langchain.vectorstores import FAISS
|
| 20 |
|
| 21 |
import os
|
| 22 |
import streamlit as st
|
|
|
|
| 79 |
except Error as e:
|
| 80 |
st.error(f"Error: {e}")
|
| 81 |
|
| 82 |
+
|
| 83 |
+
def process_document(file_path):
    """Process a PDF document with proper chunking.

    Loads the PDF at *file_path*, splits its pages into overlapping
    chunks suitable for embedding, and also returns the full page text
    for database storage.

    Parameters
    ----------
    file_path : str
        Filesystem path to the PDF to process.

    Returns
    -------
    tuple
        ``(chunks, full_content)`` where ``chunks`` is the list of
        split langchain ``Document`` objects (for the vector store)
        and ``full_content`` is all page text joined with newlines
        (for the SQL ``documents`` table).
    """
    # Load PDF — one langchain Document per page
    loader = PyPDFLoader(file_path)
    documents = loader.load()

    # Create text splitter.
    # Recursive splitter tries the separators in order: paragraph
    # break, line break, space, then hard character cut — so chunks
    # land on natural boundaries where possible.  200-char overlap
    # preserves context across chunk edges.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len,
        separators=["\n\n", "\n", " ", ""]
    )

    # Split documents into chunks
    chunks = text_splitter.split_documents(documents)

    # Extract text content for database storage
    full_content = "\n".join(doc.page_content for doc in documents)

    return chunks, full_content
|
| 104 |
|
| 105 |
def get_documents(conn):
|
| 106 |
"""Retrieve all documents from the database.
|
|
|
|
| 203 |
return
|
| 204 |
progress_bar.progress(10)
|
| 205 |
|
| 206 |
+
# Process documents
|
| 207 |
+
all_chunks = []
|
| 208 |
documents = []
|
| 209 |
document_names = []
|
| 210 |
|
| 211 |
+
progress_per_file = 70 / len(uploaded_files)
|
| 212 |
+
current_progress = 10
|
| 213 |
+
|
| 214 |
+
for idx, uploaded_file in enumerate(uploaded_files):
|
| 215 |
+
file_name = uploaded_file.name
|
| 216 |
+
status_container.info(f"🔄 Processing document {idx + 1}/{len(uploaded_files)}: {file_name}")
|
| 217 |
+
|
| 218 |
+
# Create temporary file
|
| 219 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
|
| 220 |
+
tmp_file.write(uploaded_file.getvalue())
|
| 221 |
+
tmp_file.flush()
|
| 222 |
+
|
| 223 |
+
# Process document with chunking
|
| 224 |
+
chunks, content = process_document(tmp_file.name)
|
| 225 |
+
|
| 226 |
+
# Store in database
|
| 227 |
+
doc_id = insert_document(st.session_state.db_conn, file_name, content)
|
| 228 |
+
if not doc_id:
|
| 229 |
+
status_container.error(f"❌ Failed to store document: {file_name}")
|
| 230 |
+
continue
|
| 231 |
+
|
| 232 |
+
# Add chunks with metadata
|
| 233 |
+
for chunk in chunks:
|
| 234 |
+
chunk.metadata["source"] = file_name
|
| 235 |
+
all_chunks.extend(chunks)
|
| 236 |
+
|
| 237 |
+
documents.append(content)
|
| 238 |
+
document_names.append(file_name)
|
| 239 |
+
|
| 240 |
+
current_progress += progress_per_file
|
| 241 |
+
progress_bar.progress(int(current_progress))
|
| 242 |
+
|
| 243 |
+
# Initialize vector store with chunks instead of full documents
|
| 244 |
+
status_container.info("🔄 Initializing vector store...")
|
| 245 |
+
vector_store = FAISS.from_documents(
|
| 246 |
+
all_chunks,
|
| 247 |
+
embeddings
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
# Calculate progress steps per file
|
| 251 |
progress_per_file = 70 / len(uploaded_files) # 70% of progress for file processing
|
| 252 |
current_progress = 10
|
|
|
|
| 345 |
st.session_state.vector_store = None
|
| 346 |
st.session_state.qa_system = None
|
| 347 |
st.session_state.chat_ready = False
|
| 348 |
+
except Exception as e:
|
| 349 |
|
| 350 |
finally:
|
| 351 |
# Clean up progress display after 5 seconds if successful
|
|
|
|
| 398 |
|
| 399 |
|
| 400 |
def initialize_qa_system(vector_store):
|
| 401 |
+
"""Initialize QA system with optimized retrieval."""
|
| 402 |
try:
|
| 403 |
llm = ChatOpenAI(
|
| 404 |
temperature=0.5,
|
| 405 |
model_name="gpt-4",
|
| 406 |
+
max_tokens=4000, # Explicitly set max tokens
|
| 407 |
api_key=os.environ.get("OPENAI_API_KEY")
|
| 408 |
)
|
| 409 |
|
| 410 |
+
# Optimize retriever settings
|
| 411 |
+
retriever = vector_store.as_retriever(
|
| 412 |
+
search_kwargs={
|
| 413 |
+
"k": 3, # Retrieve fewer, more relevant chunks
|
| 414 |
+
"fetch_k": 5 # Consider more candidates before selecting top k
|
| 415 |
+
}
|
| 416 |
+
)
|
| 417 |
|
| 418 |
# Create a template that enforces clean formatting
|
| 419 |
prompt = ChatPromptTemplate.from_messages([
|