# NOTE: export artifact from the hosting tool ("Spaces:" / "Paused" / "Paused")
# — commented out so the module parses; not part of utils/database.py.
# utils/database.py
# Standard library
import io
import os
import sqlite3
import tempfile
import time
import traceback
from datetime import datetime
from pathlib import Path
from sqlite3 import Error

# Third-party
import streamlit as st
from langchain_core.messages import (
    HumanMessage,
    AIMessage,
    SystemMessage,
    BaseMessage
)
from langchain_core.runnables import RunnablePassthrough
from langchain_community.chat_models import ChatOpenAI  # NOTE: intentionally shadowed by the langchain.chat_models import below (original import order preserved)
from langchain_community.document_loaders import PyPDFLoader
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.agents import AgentExecutor, Tool, create_openai_tools_agent
from langchain.agents.format_scratchpad.tools import format_to_tool_messages
from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser

# Local
from utils.document_chunker import DocumentChunker
def create_connection(db_file):
    """Open a SQLite connection to *db_file*.

    Args:
        db_file: Path (or ":memory:") to connect to.

    Returns:
        sqlite3.Connection on success, or None on failure (a user-facing
        error is shown via Streamlit in that case).
    """
    try:
        return sqlite3.connect(db_file)
    except Error:
        st.error("Failed to connect to database. Please try again or contact support.")
        return None
| # Add this function to your database.py file | |
def get_db_connection():
    """Open a fresh connection to data/rfp_analysis.db, ensuring the schema exists.

    Creates the ``data`` directory on first use and runs the idempotent
    table-creation DDL before handing the connection back.

    Returns:
        sqlite3.Connection on success, or None on any failure (a Streamlit
        error is shown in that case).
    """
    try:
        storage_dir = Path("data")
        storage_dir.mkdir(exist_ok=True)
        # New connection per call; callers are responsible for closing it.
        conn = sqlite3.connect(str(storage_dir / 'rfp_analysis.db'))
        # CREATE TABLE IF NOT EXISTS — safe to run on every connect.
        create_tables(conn)
        return conn
    except Exception as e:
        st.error(f"Database connection error: {str(e)}")
        return None
def create_tables(conn):
    """Ensure the documents, queries, and annotations tables exist.

    All statements use CREATE TABLE IF NOT EXISTS, so the function is
    idempotent. Errors are reported through Streamlit rather than raised.

    Args:
        conn: Open SQLite database connection.
    """
    ddl_statements = (
        '''
        CREATE TABLE IF NOT EXISTS documents (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            name TEXT NOT NULL,
            content TEXT NOT NULL,
            upload_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        );
        ''',
        '''
        CREATE TABLE IF NOT EXISTS queries (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            query TEXT NOT NULL,
            response TEXT NOT NULL,
            document_id INTEGER,
            query_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (document_id) REFERENCES documents (id)
        );
        ''',
        '''
        CREATE TABLE IF NOT EXISTS annotations (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            document_id INTEGER NOT NULL,
            annotation TEXT NOT NULL,
            page_number INTEGER,
            annotation_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (document_id) REFERENCES documents (id)
        );
        ''',
    )
    try:
        for ddl in ddl_statements:
            conn.execute(ddl)
    except Error as e:
        st.error(f"Error: {e}")
def insert_document(name, content):
    """Insert a document row using a fresh connection from get_db_connection().

    NOTE(review): this two-argument definition is shadowed by the later
    ``insert_document(conn, name, content)`` defined further down this
    module; callers in this file use the three-argument version. Kept with
    its original signature for backward compatibility.

    Args:
        name (str): Document name.
        content (str): Full document text.

    Returns:
        int | None: rowid of the inserted document, or None on failure.
    """
    # Pre-bind so the except/finally paths can never hit UnboundLocalError
    # (the original referenced `conn` in `except` even when the assignment
    # itself was what raised).
    conn = None
    try:
        conn = get_db_connection()
        if conn is None:
            return None
        cursor = conn.cursor()
        cursor.execute(
            "INSERT INTO documents (name, content) VALUES (?, ?)",
            (name, content)
        )
        conn.commit()
        return cursor.lastrowid
    except Exception as e:
        st.error(f"Error inserting document: {str(e)}")
        if conn is not None:
            conn.rollback()
        return None
    finally:
        # Close on every path (success, failure, and early return).
        if conn is not None:
            conn.close()
def get_documents(conn):
    """Fetch every stored document.

    Args:
        conn: SQLite database connection

    Returns:
        tuple: (list of document contents, list of document names);
        both lists are empty when there are no rows or on error.
    """
    try:
        rows = conn.execute("SELECT content, name FROM documents").fetchall()
    except Error as e:
        st.error(f"Error retrieving documents: {e}")
        return [], []
    if not rows:
        return [], []
    # Transpose (content, name) rows into two parallel lists.
    contents, names = zip(*rows)
    return list(contents), list(names)
def insert_document(conn, name, content):
    """Store a document row and return its generated id.

    Args:
        conn: SQLite database connection
        name (str): Name of the document
        content (str): Content of the document

    Returns:
        int: ID of the inserted document, or None if insertion failed
    """
    insert_sql = '''INSERT INTO documents (name, content)
                    VALUES (?, ?)'''
    try:
        cur = conn.cursor()
        cur.execute(insert_sql, (name, content))
        conn.commit()
        return cur.lastrowid
    except Error as e:
        st.error(f"Error inserting document: {e}")
        return None
def verify_vector_store(vector_store):
    """Check that the vector store can serve at least one document.

    Args:
        vector_store: FAISS vector store instance

    Returns:
        bool: True if a probe similarity search returns any hit,
        False on zero hits or on any exception.
    """
    try:
        # A k=1 probe query is the cheapest way to prove the index is usable.
        hits = vector_store.similarity_search("test", k=1)
        return len(hits) > 0
    except Exception as e:
        st.error(f"Vector store verification failed: {e}")
        return False
def handle_document_upload(uploaded_files):
    """Handle document upload with improved chunking and progress tracking.

    Pipeline: embeddings → per-file PDF extraction + DB insert → chunking →
    FAISS vector store → QA chain, updating a Streamlit progress bar and
    status text at each stage. Results are stashed in st.session_state
    (vector_store, qa_system, current_session_id, chat_ready).
    """
    # Initialize containers first - before any processing, so error/status
    # messages have somewhere to render even if an early step fails.
    progress_container = st.empty()
    status_container = st.empty()
    details_container = st.empty()
    progress_bar = progress_container.progress(0)
    try:
        # Initialize session state variables
        if 'qa_system' not in st.session_state:
            st.session_state.qa_system = None
        if 'vector_store' not in st.session_state:
            st.session_state.vector_store = None
        # Initialize persistence manager
        # NOTE(review): PersistenceManager is not defined or imported in this
        # module — confirm where it comes from; as written a NameError here
        # is swallowed by the broad `except` at the bottom of this function.
        persistence = PersistenceManager()
        # Generate a session ID based on timestamp and files
        # NOTE(review): requires `datetime` to be imported at module level —
        # confirm it is in the file's import block.
        session_id = f"session_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        # Initialize embeddings (10% progress)
        status_container.info("🔄 Initializing embeddings model...")
        embeddings = get_embeddings_model()
        if not embeddings:
            status_container.error("❌ Failed to initialize embeddings model")
            return
        progress_bar.progress(10)
        # Initialize document chunker
        chunker = DocumentChunker(
            chunk_size=1000,
            chunk_overlap=200,
            max_tokens_per_chunk=2000
        )
        # Process documents
        document_pairs = []  # List to store (content, filename) pairs
        # NOTE(review): an empty uploaded_files list raises ZeroDivisionError
        # here (caught by the broad except below) — confirm callers guard it.
        progress_per_file = 70 / len(uploaded_files)
        current_progress = 10
        for idx, uploaded_file in enumerate(uploaded_files):
            file_name = uploaded_file.name
            status_container.info(f"🔄 Processing document {idx + 1}/{len(uploaded_files)}: {file_name}")
            details_container.text(f"📄 Current file: {file_name}")
            # Create temporary file for PDF processing; delete=False so the
            # loader can reopen it by path, with manual unlink in `finally`.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
                tmp_file.write(uploaded_file.getvalue())
                tmp_file.flush()
                try:
                    # Load PDF content
                    loader = PyPDFLoader(tmp_file.name)
                    pdf_documents = loader.load()
                    content = "\n".join(doc.page_content for doc in pdf_documents)
                    # Store original content in database — this is the
                    # three-argument insert_document(conn, name, content);
                    # assumes st.session_state.db_conn was set elsewhere —
                    # TODO confirm.
                    doc_id = insert_document(st.session_state.db_conn, file_name, content)
                    if not doc_id:
                        status_container.error(f"❌ Failed to store document: {file_name}")
                        continue
                    document_pairs.append((content, file_name))
                finally:
                    # Ensure temporary file is cleaned up
                    try:
                        os.unlink(tmp_file.name)
                    except Exception as e:
                        st.warning(f"Could not delete temporary file: {e}")
            current_progress += progress_per_file
            progress_bar.progress(int(current_progress))
        if not document_pairs:
            status_container.error("❌ No documents were successfully processed")
            return
        # Chunk documents (80% progress)
        status_container.info("🔄 Chunking documents...")
        details_container.text("📑 Splitting documents into manageable chunks...")
        chunks, chunk_metadatas = chunker.process_documents(document_pairs)
        if not chunks:
            status_container.error("❌ Failed to chunk documents")
            return
        progress_bar.progress(80)
        # Save chunks for persistence
        persistence.save_chunks(chunks, chunk_metadatas, session_id)
        # Initialize vector store (90% progress)
        status_container.info("🔄 Initializing vector store...")
        details_container.text("🔍 Creating vector embeddings...")
        # NOTE(review): initialize_faiss's third parameter is named
        # document_names and each entry is wrapped as {"source": name};
        # confirm chunk_metadatas entries are name strings, not dicts.
        vector_store = initialize_faiss(embeddings, chunks, chunk_metadatas)
        if not vector_store:
            status_container.error("❌ Failed to initialize vector store")
            return
        # Save vector store and update session state
        persistence.save_vector_store(vector_store, session_id)
        st.session_state.vector_store = vector_store
        st.session_state.current_session_id = session_id
        progress_bar.progress(90)
        # Initialize QA system (100% progress)
        status_container.info("🔄 Setting up QA system...")
        qa_system = initialize_qa_system(vector_store)
        if not qa_system:
            status_container.error("❌ Failed to initialize QA system")
            return
        st.session_state.qa_system = qa_system
        progress_bar.progress(100)
        # Success message
        status_container.success("✅ Documents processed successfully!")
        details_container.markdown(f"""
        🎉 **Ready to chat!**
        - Documents processed: {len(document_pairs)}
        - Total chunks created: {len(chunks)}
        - Average chunk size: {sum(len(chunk) for chunk in chunks) / len(chunks):.0f} characters
        - Vector store initialized and saved
        - QA system ready
        - Session ID: {session_id}
        You can now start asking questions about your documents!
        """)
        st.balloons()
        st.session_state.chat_ready = True
    except Exception as e:
        # Any failure anywhere in the pipeline lands here: surface it and
        # reset the session flags so the UI doesn't claim a half-built state.
        status_container.error(f"❌ Error processing documents: {str(e)}")
        details_container.error(traceback.format_exc())
        st.session_state.vector_store = None
        st.session_state.qa_system = None
        st.session_state.chat_ready = False
    finally:
        # Clean up progress display after successful processing
        # (qa_system set ⇒ the pipeline reached the end).
        if st.session_state.get('qa_system') is not None:
            time.sleep(5)
            progress_container.empty()
def display_vector_store_info():
    """Display information about the current vector store state."""
    # Nothing to show until documents have been uploaded and indexed.
    if 'vector_store' not in st.session_state:
        st.info("ℹ️ No documents loaded yet.")
        return
    try:
        # Get the vector store from session state
        vector_store = st.session_state.vector_store
        # Get basic stats
        # NOTE(review): this counts the hits of a k=1 probe query (0 or 1),
        # not the total number of indexed documents — confirm the
        # "Documents Loaded" metric is meant to read this way.
        test_query = vector_store.similarity_search("test", k=1)
        doc_count = len(test_query)
        # Create an expander for detailed info
        with st.expander("📊 Knowledge Base Status"):
            col1, col2 = st.columns(2)
            with col1:
                st.metric(
                    label="Documents Loaded",
                    value=doc_count
                )
            with col2:
                st.metric(
                    label="System Status",
                    value="Ready" if verify_vector_store(vector_store) else "Not Ready"
                )
            # Display sample queries
            # NOTE(review): reconstructed inside the expander — the original
            # indentation was ambiguous in the reviewed copy; confirm placement.
            if verify_vector_store(vector_store):
                st.markdown("### 🔍 Sample Document Snippets")
                sample_docs = vector_store.similarity_search("", k=3)
                for i, doc in enumerate(sample_docs, 1):
                    with st.container():
                        st.markdown(f"**Snippet {i}:**")
                        st.text(doc.page_content[:200] + "...")
    except Exception as e:
        st.error(f"Error displaying vector store info: {e}")
        st.error(traceback.format_exc())
def initialize_qa_system(vector_store):
    """Initialize QA system with proper chat handling.

    Builds an LCEL pipeline: {context, chat_history, input} → prompt → llm.
    Returns the runnable chain, or None (with a Streamlit error) on failure.
    """
    try:
        # GPT-4 chat model; reads OPENAI_API_KEY from the environment.
        llm = ChatOpenAI(
            temperature=0.5,
            model_name="gpt-4",
            api_key=os.environ.get("OPENAI_API_KEY")
        )
        # Create retriever function (top-2 chunks per query).
        retriever = vector_store.as_retriever(search_kwargs={"k": 2})
        # Create a template that enforces clean formatting
        prompt = ChatPromptTemplate.from_messages([
            ("system", """You are an expert consultant specializing in analyzing Request for Proposal (RFP) documents. Your goal is to assist users by providing clear, concise, and professional insights based on the content provided. Please adhere to the following guidelines when crafting your responses:
Begin with a summary that highlights the key findings or answers the main query.
Structured Format: Use clear and descriptive section headers to organize the information logically.
Bullet Points: Utilize bullet points for lists or complex information to enhance readability.
Source Attribution: Cite specific sections or page numbers from the RFP document when referencing information.
Professional Formatting: Maintain a clean and professional layout using Markdown formatting where appropriate (e.g., headings, bold, italics).
Focused Content: Keep your responses concise and directly related to the user's query, avoiding unnecessary information.
Scope Awareness: If a query falls outside the provided information or context, politely acknowledge this and suggest consulting the relevant sections or additional sources.
Confidentiality: Respect the confidentiality of the information provided and avoid sharing any sensitive data beyond the scope of the query.
Tone and Language: Use formal and professional language, ensuring clarity and precision in your responses.
Accuracy: Double-check all information for accuracy and completeness before providing it to the user.
"""),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}\n\nContext: {context}")
        ])

        def get_chat_history(inputs):
            # Keep only real LangChain message objects; anything else
            # (or a non-list value) is dropped.
            chat_history = inputs.get("chat_history", [])
            if not isinstance(chat_history, list):
                return []
            return [msg for msg in chat_history if isinstance(msg, BaseMessage)]

        def get_context(inputs):
            # Retrieve the top-k chunks and label each with its source file
            # so the model can cite where a passage came from.
            # NOTE(review): get_relevant_documents is deprecated in newer
            # LangChain releases in favor of retriever.invoke — confirm the
            # pinned version.
            docs = retriever.get_relevant_documents(inputs["input"])
            context_parts = []
            for doc in docs:
                source = doc.metadata.get('source', 'Unknown source')
                context_parts.append(f"\nFrom {source}:\n{doc.page_content}")
            return "\n".join(context_parts)

        # LCEL composition: the dict step fans the single input out into the
        # three prompt variables, then pipes through prompt and model.
        chain = (
            {
                "context": get_context,
                "chat_history": get_chat_history,
                "input": lambda x: x["input"]
            }
            | prompt
            | llm
        )
        return chain
    except Exception as e:
        st.error(f"Error initializing QA system: {e}")
        return None
# FAISS vector store initialization
def initialize_faiss(embeddings, documents, document_names):
    """Initialize FAISS vector store.

    Args:
        embeddings: Embeddings model used to vectorize the texts.
        documents: List of text chunks to index.
        document_names: One source label per chunk, stored as
            {"source": name} metadata.

    Returns:
        FAISS vector store, or None (with a Streamlit error) on failure.
    """
    try:
        from langchain.vectorstores import FAISS
        # NOTE(review): the caller in handle_document_upload passes
        # (embeddings, chunks, chunk_metadatas) — if chunk_metadatas entries
        # are already dicts, each metadata becomes {"source": <dict>};
        # confirm this nesting is intended.
        vector_store = FAISS.from_texts(
            documents,
            embeddings,
            metadatas=[{"source": name} for name in document_names],
        )
        return vector_store
    except Exception as e:
        st.error(f"Error initializing FAISS: {e}")
        return None
# Embeddings model retrieval
def get_embeddings_model():
    """Load the sentence-transformers MiniLM embeddings model.

    Returns:
        HuggingFaceEmbeddings instance, or None (with a Streamlit error)
        if the model cannot be loaded.
    """
    try:
        from langchain.embeddings import HuggingFaceEmbeddings
        return HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
    except Exception as e:
        st.error(f"Error loading embeddings model: {e}")
        return None