# RFP_Analyzer_Agent / utils / database.py
# Uploaded by cryogenic22 — "Update utils/database.py" (commit 14499b7, verified)
# NOTE(review): the four lines above were raw Hugging Face page text pasted into
# the file; they are commented out here so the module can actually be imported.
# utils/database.py
# Standard library
import io
import os
import sqlite3
import tempfile
import time
import traceback
from datetime import datetime
from pathlib import Path
from sqlite3 import Error

# Third-party
import streamlit as st
from langchain.agents import AgentExecutor, Tool, create_openai_tools_agent
from langchain.agents.format_scratchpad.tools import format_to_tool_messages
from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
from langchain.chains import ConversationalRetrievalChain
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.chat_models import ChatOpenAI
from langchain.chat_models import ChatOpenAI  # NOTE: duplicate of the import above; kept last so it wins, preserving original behavior
from langchain_community.document_loaders import PyPDFLoader
from langchain_core.messages import (
    HumanMessage,
    AIMessage,
    SystemMessage,
    BaseMessage
)
from langchain_core.runnables import RunnablePassthrough

# Local
from utils.document_chunker import DocumentChunker
def create_connection(db_file):
    """Open a connection to the SQLite database at *db_file*.

    Returns the connection on success, or None (after surfacing an
    error message in the Streamlit UI) when opening fails.
    """
    try:
        connection = sqlite3.connect(db_file)
    except Error:
        st.error("Failed to connect to database. Please try again or contact support.")
        return None
    return connection
# Per-call connection factory: each call opens its own sqlite3 connection,
# which keeps usage safe across Streamlit reruns/threads.
def get_db_connection():
    """Get a thread-safe database connection.

    Opens a fresh connection to data/rfp_analysis.db (creating the data
    directory and schema on first use) so each Streamlit thread gets its
    own sqlite3 connection — sqlite3 connections must not cross threads.

    Returns:
        sqlite3.Connection on success, or None on failure (the error is
        shown in the Streamlit UI).
    """
    # Fix: `Path` was referenced here without any module-level import,
    # which raised NameError at runtime. Imported locally so this
    # function is self-contained.
    from pathlib import Path
    try:
        data_dir = Path("data")
        data_dir.mkdir(exist_ok=True)
        db_path = data_dir / 'rfp_analysis.db'
        # Create new connection
        conn = sqlite3.connect(str(db_path))
        # Ensure the schema exists before handing the connection out.
        create_tables(conn)
        return conn
    except Exception as e:
        st.error(f"Database connection error: {str(e)}")
        return None
def create_tables(conn):
    """Create the documents, queries and annotations tables if absent.

    Args:
        conn: Open SQLite database connection.

    Errors are reported to the Streamlit UI rather than raised.
    """
    ddl_statements = (
        # Raw uploaded documents (full text plus upload timestamp).
        '''
        CREATE TABLE IF NOT EXISTS documents (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            name TEXT NOT NULL,
            content TEXT NOT NULL,
            upload_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        );
        ''',
        # Query/response history, optionally tied to a document.
        '''
        CREATE TABLE IF NOT EXISTS queries (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            query TEXT NOT NULL,
            response TEXT NOT NULL,
            document_id INTEGER,
            query_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (document_id) REFERENCES documents (id)
        );
        ''',
        # Free-form annotations attached to a document page.
        '''
        CREATE TABLE IF NOT EXISTS annotations (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            document_id INTEGER NOT NULL,
            annotation TEXT NOT NULL,
            page_number INTEGER,
            annotation_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (document_id) REFERENCES documents (id)
        );
        ''',
    )
    try:
        for ddl in ddl_statements:
            conn.execute(ddl)
    except Error as e:
        st.error(f"Error: {e}")
def insert_document(name, content):
    """Insert a document using its own thread-safe connection.

    NOTE(review): this 2-argument definition is shadowed by the
    3-argument ``insert_document(conn, name, content)`` defined later in
    this module, so it is effectively dead code at import time; kept for
    backward compatibility and fixed anyway.

    Args:
        name: Document name.
        content: Document text content.

    Returns:
        int: rowid of the inserted document, or None on failure.
    """
    # Fix: pre-bind conn so the except path cannot raise NameError when
    # get_db_connection() itself throws before assignment.
    conn = None
    try:
        conn = get_db_connection()
        if conn is None:
            return None
        cursor = conn.cursor()
        cursor.execute(
            "INSERT INTO documents (name, content) VALUES (?, ?)",
            (name, content)
        )
        conn.commit()
        return cursor.lastrowid
    except Exception as e:
        st.error(f"Error inserting document: {str(e)}")
        if conn:
            conn.rollback()
        return None
    finally:
        # Close in all cases (the original leaked the connection when
        # commit succeeded but lastrowid access failed, and vice versa).
        if conn:
            conn.close()
def get_documents(conn):
    """Fetch every stored document.

    Args:
        conn: SQLite database connection.

    Returns:
        tuple: (list of document contents, list of document names);
        two empty lists when nothing is stored or the query fails.
    """
    try:
        rows = conn.cursor().execute("SELECT content, name FROM documents").fetchall()
    except Error as e:
        st.error(f"Error retrieving documents: {e}")
        return [], []
    if not rows:
        return [], []
    # Transpose (content, name) rows into parallel lists.
    contents, names = zip(*rows)
    return list(contents), list(names)
def insert_document(conn, name, content):
    """Store one document row on an existing connection.

    Args:
        conn: SQLite database connection.
        name (str): Name of the document.
        content (str): Content of the document.

    Returns:
        int: rowid of the new row, or None when the insert fails.
    """
    insert_sql = '''INSERT INTO documents (name, content)
                    VALUES (?, ?)'''
    try:
        cur = conn.cursor()
        cur.execute(insert_sql, (name, content))
        conn.commit()
        return cur.lastrowid
    except Error as e:
        st.error(f"Error inserting document: {e}")
        return None
def verify_vector_store(vector_store):
    """Check that the vector store can serve at least one document.

    Args:
        vector_store: FAISS vector store instance.

    Returns:
        bool: True when a probe similarity search yields a result,
        False on empty results or any search failure.
    """
    try:
        # Throwaway probe query; we only care whether anything comes back.
        probe = vector_store.similarity_search("test", k=1)
    except Exception as e:
        st.error(f"Vector store verification failed: {e}")
        return False
    return len(probe) > 0
def handle_document_upload(uploaded_files):
    """Handle document upload with improved chunking and progress tracking.

    Pipeline: store each PDF's text in the DB -> chunk -> build a FAISS
    vector store -> persist chunks + store -> build the QA chain, while
    driving Streamlit progress widgets (0 -> 10 -> 80 -> 90 -> 100%).

    Args:
        uploaded_files: Streamlit UploadedFile objects (treated as PDFs).

    Side effects: inserts rows via insert_document, mutates
    st.session_state (qa_system, vector_store, current_session_id,
    chat_ready) and renders progress/status UI.
    """
    # Initialize containers first - before any processing
    progress_container = st.empty()
    status_container = st.empty()
    details_container = st.empty()
    progress_bar = progress_container.progress(0)
    try:
        # Initialize session state variables
        if 'qa_system' not in st.session_state:
            st.session_state.qa_system = None
        if 'vector_store' not in st.session_state:
            st.session_state.vector_store = None
        # Initialize persistence manager
        # NOTE(review): PersistenceManager is not imported anywhere in this
        # module — this line raises NameError as written; confirm where it
        # is supposed to come from.
        persistence = PersistenceManager()
        # Generate a session ID based on timestamp and files
        # NOTE(review): datetime is also not imported in the original module — confirm.
        session_id = f"session_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        # Initialize embeddings (10% progress)
        status_container.info("🔄 Initializing embeddings model...")
        embeddings = get_embeddings_model()
        if not embeddings:
            status_container.error("❌ Failed to initialize embeddings model")
            return
        progress_bar.progress(10)
        # Initialize document chunker
        chunker = DocumentChunker(
            chunk_size=1000,
            chunk_overlap=200,
            max_tokens_per_chunk=2000
        )
        # Process documents
        document_pairs = []  # List to store (content, filename) pairs
        # 70% of the bar is split evenly across files (10% -> 80%).
        # NOTE(review): ZeroDivisionError if uploaded_files is empty — guard upstream.
        progress_per_file = 70 / len(uploaded_files)
        current_progress = 10
        for idx, uploaded_file in enumerate(uploaded_files):
            file_name = uploaded_file.name
            status_container.info(f"🔄 Processing document {idx + 1}/{len(uploaded_files)}: {file_name}")
            details_container.text(f"📄 Current file: {file_name}")
            # Create temporary file for PDF processing
            # (delete=False so PyPDFLoader can reopen it by path; removed manually below)
            with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
                tmp_file.write(uploaded_file.getvalue())
                tmp_file.flush()
                try:
                    # Load PDF content
                    loader = PyPDFLoader(tmp_file.name)
                    pdf_documents = loader.load()
                    content = "\n".join(doc.page_content for doc in pdf_documents)
                    # Store original content in database
                    # Calls the 3-arg insert_document(conn, name, content) overload;
                    # assumes st.session_state.db_conn was set by the caller — confirm.
                    doc_id = insert_document(st.session_state.db_conn, file_name, content)
                    if not doc_id:
                        status_container.error(f"❌ Failed to store document: {file_name}")
                        continue
                    document_pairs.append((content, file_name))
                finally:
                    # Ensure temporary file is cleaned up
                    # NOTE(review): unlinking while the handle is still open is
                    # fine on POSIX but fails on Windows — hence the warning path.
                    try:
                        os.unlink(tmp_file.name)
                    except Exception as e:
                        st.warning(f"Could not delete temporary file: {e}")
            # Note: `continue` above skips this update, so failed files do
            # not advance the bar (original behavior).
            current_progress += progress_per_file
            progress_bar.progress(int(current_progress))
        if not document_pairs:
            status_container.error("❌ No documents were successfully processed")
            return
        # Chunk documents (80% progress)
        status_container.info("🔄 Chunking documents...")
        details_container.text("📑 Splitting documents into manageable chunks...")
        chunks, chunk_metadatas = chunker.process_documents(document_pairs)
        if not chunks:
            status_container.error("❌ Failed to chunk documents")
            return
        progress_bar.progress(80)
        # Save chunks for persistence
        persistence.save_chunks(chunks, chunk_metadatas, session_id)
        # Initialize vector store (90% progress)
        status_container.info("🔄 Initializing vector store...")
        details_container.text("🔍 Creating vector embeddings...")
        # NOTE(review): initialize_faiss's third parameter is named
        # document_names and is wrapped as {"source": name}; passing
        # chunk_metadatas (presumably dicts) yields {"source": {...}} — verify intended.
        vector_store = initialize_faiss(embeddings, chunks, chunk_metadatas)
        if not vector_store:
            status_container.error("❌ Failed to initialize vector store")
            return
        # Save vector store and update session state
        persistence.save_vector_store(vector_store, session_id)
        st.session_state.vector_store = vector_store
        st.session_state.current_session_id = session_id
        progress_bar.progress(90)
        # Initialize QA system (100% progress)
        status_container.info("🔄 Setting up QA system...")
        qa_system = initialize_qa_system(vector_store)
        if not qa_system:
            status_container.error("❌ Failed to initialize QA system")
            return
        st.session_state.qa_system = qa_system
        progress_bar.progress(100)
        # Success message
        status_container.success("✅ Documents processed successfully!")
        details_container.markdown(f"""
        🎉 **Ready to chat!**
        - Documents processed: {len(document_pairs)}
        - Total chunks created: {len(chunks)}
        - Average chunk size: {sum(len(chunk) for chunk in chunks) / len(chunks):.0f} characters
        - Vector store initialized and saved
        - QA system ready
        - Session ID: {session_id}
        You can now start asking questions about your documents!
        """)
        st.balloons()
        st.session_state.chat_ready = True
    except Exception as e:
        status_container.error(f"❌ Error processing documents: {str(e)}")
        details_container.error(traceback.format_exc())
        # Reset to a clean "nothing loaded" state on any failure.
        st.session_state.vector_store = None
        st.session_state.qa_system = None
        st.session_state.chat_ready = False
    finally:
        # Clean up progress display after successful processing
        # (5 s pause lets the user read the success message first).
        if st.session_state.get('qa_system') is not None:
            time.sleep(5)
            progress_container.empty()
def display_vector_store_info():
    """Render a status panel for the vector store held in session state."""
    if 'vector_store' not in st.session_state:
        st.info("ℹ️ No documents loaded yet.")
        return
    try:
        store = st.session_state.vector_store
        # Probe with a throwaway query to get a rough "documents loaded" count.
        loaded_count = len(store.similarity_search("test", k=1))
        with st.expander("📊 Knowledge Base Status"):
            left_col, right_col = st.columns(2)
            with left_col:
                st.metric(label="Documents Loaded", value=loaded_count)
            with right_col:
                status_text = "Ready" if verify_vector_store(store) else "Not Ready"
                st.metric(label="System Status", value=status_text)
            # Show a few sample snippets when the store is healthy.
            if verify_vector_store(store):
                st.markdown("### 🔍 Sample Document Snippets")
                for i, doc in enumerate(store.similarity_search("", k=3), 1):
                    with st.container():
                        st.markdown(f"**Snippet {i}:**")
                        st.text(doc.page_content[:200] + "...")
    except Exception as e:
        st.error(f"Error displaying vector store info: {e}")
        st.error(traceback.format_exc())
def initialize_qa_system(vector_store):
    """Initialize QA system with proper chat handling.

    Builds an LCEL chain: retrieve top-2 chunks for the user input,
    inject them (plus chat history) into an RFP-analyst system prompt,
    and send the result to GPT-4.

    Args:
        vector_store: FAISS store used as the retriever.

    Returns:
        A runnable chain accepting {"input": str, "chat_history": list},
        or None if setup fails (error shown in the Streamlit UI).
    """
    try:
        llm = ChatOpenAI(
            temperature=0.5,  # moderate creativity; answers stay context-grounded
            model_name="gpt-4",
            api_key=os.environ.get("OPENAI_API_KEY")
        )
        # Create retriever function
        retriever = vector_store.as_retriever(search_kwargs={"k": 2})
        # Create a template that enforces clean formatting
        prompt = ChatPromptTemplate.from_messages([
            ("system", """You are an expert consultant specializing in analyzing Request for Proposal (RFP) documents. Your goal is to assist users by providing clear, concise, and professional insights based on the content provided. Please adhere to the following guidelines when crafting your responses:
Begin with a summary that highlights the key findings or answers the main query.
Structured Format: Use clear and descriptive section headers to organize the information logically.
Bullet Points: Utilize bullet points for lists or complex information to enhance readability.
Source Attribution: Cite specific sections or page numbers from the RFP document when referencing information.
Professional Formatting: Maintain a clean and professional layout using Markdown formatting where appropriate (e.g., headings, bold, italics).
Focused Content: Keep your responses concise and directly related to the user's query, avoiding unnecessary information.
Scope Awareness: If a query falls outside the provided information or context, politely acknowledge this and suggest consulting the relevant sections or additional sources.
Confidentiality: Respect the confidentiality of the information provided and avoid sharing any sensitive data beyond the scope of the query.
Tone and Language: Use formal and professional language, ensuring clarity and precision in your responses.
Accuracy: Double-check all information for accuracy and completeness before providing it to the user.
"""),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}\n\nContext: {context}")
        ])
        def get_chat_history(inputs):
            # Keep only genuine LangChain message objects; anything else
            # (or a non-list value) is dropped so the prompt never breaks.
            chat_history = inputs.get("chat_history", [])
            if not isinstance(chat_history, list):
                return []
            return [msg for msg in chat_history if isinstance(msg, BaseMessage)]
        def get_context(inputs):
            # Fetch the top-k chunks for the raw user input and label each
            # with its source file so the model can cite it.
            # NOTE(review): get_relevant_documents is deprecated in newer
            # LangChain versions (use retriever.invoke) — confirm pinned version.
            docs = retriever.get_relevant_documents(inputs["input"])
            context_parts = []
            for doc in docs:
                source = doc.metadata.get('source', 'Unknown source')
                context_parts.append(f"\nFrom {source}:\n{doc.page_content}")
            return "\n".join(context_parts)
        # LCEL: map the incoming dict to prompt variables, then prompt -> LLM.
        chain = (
            {
                "context": get_context,
                "chat_history": get_chat_history,
                "input": lambda x: x["input"]
            }
            | prompt
            | llm
        )
        return chain
    except Exception as e:
        st.error(f"Error initializing QA system: {e}")
        return None
# FAISS vector store initialization
def initialize_faiss(embeddings, documents, document_names):
    """Build a FAISS store from raw texts, tagging each text with its source.

    Args:
        embeddings: Embeddings model used to vectorize the texts.
        documents: Iterable of text strings.
        document_names: Parallel iterable of source names for metadata.

    Returns:
        The FAISS vector store, or None on failure (error shown in UI).
    """
    try:
        from langchain.vectorstores import FAISS
        source_tags = [{"source": name} for name in document_names]
        return FAISS.from_texts(documents, embeddings, metadatas=source_tags)
    except Exception as e:
        st.error(f"Error initializing FAISS: {e}")
        return None
# Embeddings model retrieval
@st.cache_resource
def get_embeddings_model():
    """Load (and cache across reruns) the MiniLM sentence-transformer model.

    Returns:
        HuggingFaceEmbeddings instance, or None on failure (error shown in UI).
    """
    try:
        from langchain.embeddings import HuggingFaceEmbeddings
        return HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
    except Exception as e:
        st.error(f"Error loading embeddings model: {e}")
        return None