"""Streamlit RFP-analysis app.

DocumentManager handles collection/document storage in SQLite plus FAISS
embedding persistence; ChatInterface drives a retrieval-augmented chat
over the selected documents.
"""

import streamlit as st
import sqlite3
from pathlib import Path
from typing import List, Dict, Optional
from datetime import datetime
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
# Fix: PyPDFLoader was used in _process_document_embeddings but never
# imported, causing a NameError on the first document upload.
from langchain.document_loaders import PyPDFLoader
from langchain_core.messages import HumanMessage, AIMessage
import tempfile
import os


class DocumentManager:
    """Stores uploaded documents on disk, tracks them in SQLite, and
    builds/persists FAISS embeddings for retrieval."""

    def __init__(self, base_path: str = "/data"):
        """Initialize document manager with storage paths and database.

        Args:
            base_path: Root directory for collections, embeddings, and the
                SQLite database file.
        """
        self.base_path = Path(base_path)
        self.collections_path = self.base_path / "collections"
        self.db_path = self.base_path / "rfp_analysis.db"

        # Create necessary directories
        self.collections_path.mkdir(parents=True, exist_ok=True)

        # Initialize database
        self.conn = self._initialize_database()

        # Initialize embedding model
        self.embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )

        # Text splitter for document processing
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
            separators=["\n\n", "\n", " ", ""],
        )

    def _initialize_database(self) -> sqlite3.Connection:
        """Initialize SQLite database with necessary tables.

        Returns:
            An open connection with the schema created (idempotent).
        """
        # NOTE(review): Streamlit may service different sessions from
        # different threads; the default check_same_thread=True would then
        # raise ProgrammingError on a connection created at startup.
        # Relaxed here — callers must still avoid concurrent writes.
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        cursor = conn.cursor()

        # Create tables
        cursor.executescript("""
            CREATE TABLE IF NOT EXISTS collections (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL UNIQUE,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            );

            CREATE TABLE IF NOT EXISTS documents (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                collection_id INTEGER,
                name TEXT NOT NULL,
                file_path TEXT NOT NULL,
                upload_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (collection_id) REFERENCES collections (id)
            );

            CREATE TABLE IF NOT EXISTS document_embeddings (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                document_id INTEGER,
                embedding_path TEXT NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (document_id) REFERENCES documents (id)
            );
        """)
        conn.commit()
        return conn

    def close(self) -> None:
        """Release the underlying SQLite connection (safe to call once)."""
        self.conn.close()

    def create_collection(self, name: str) -> int:
        """Create a new collection directory and database entry.

        Args:
            name: Collection name; must be unique.

        Returns:
            The new collection's row id.

        Raises:
            sqlite3.IntegrityError: If a collection with this name exists
                (the `name` column is UNIQUE).
        """
        cursor = self.conn.cursor()

        # Create collection in database
        cursor.execute(
            "INSERT INTO collections (name) VALUES (?)",
            (name,),
        )
        collection_id = cursor.lastrowid

        # Create collection directory
        collection_path = self.collections_path / str(collection_id)
        collection_path.mkdir(exist_ok=True)

        self.conn.commit()
        return collection_id

    def upload_documents(self, files: List, collection_id: Optional[int] = None) -> List[int]:
        """Upload documents to a collection and process them.

        Args:
            files: Streamlit UploadedFile-like objects (need `.name` and
                `.getvalue()`).
            collection_id: Target collection, or None for "uncategorized".

        Returns:
            Database ids of the inserted document rows, in input order.
        """
        uploaded_ids = []

        for file in files:
            # Save file to collection directory
            if collection_id:
                save_dir = self.collections_path / str(collection_id)
            else:
                save_dir = self.collections_path / "uncategorized"

            save_dir.mkdir(exist_ok=True)
            file_path = save_dir / file.name

            # Save file
            with open(file_path, "wb") as f:
                f.write(file.getvalue())

            # Add to database
            cursor = self.conn.cursor()
            cursor.execute(
                """
                INSERT INTO documents (collection_id, name, file_path)
                VALUES (?, ?, ?)
                """,
                (collection_id, file.name, str(file_path)),
            )
            document_id = cursor.lastrowid
            uploaded_ids.append(document_id)

            # Process document embeddings
            self._process_document_embeddings(document_id, file_path)

        self.conn.commit()
        return uploaded_ids

    def _process_document_embeddings(self, document_id: int, file_path: Path):
        """Chunk a PDF, embed the chunks, and persist the FAISS index.

        Args:
            document_id: Row id of the document in the `documents` table.
            file_path: Path to the saved PDF on disk.
        """
        # Load and chunk document (assumes PDF input — TODO confirm other
        # formats are rejected upstream).
        loader = PyPDFLoader(str(file_path))
        pages = loader.load()
        chunks = self.text_splitter.split_documents(pages)

        # Create embeddings
        vector_store = FAISS.from_documents(chunks, self.embeddings)

        # Save embeddings
        embeddings_dir = self.base_path / "embeddings"
        embeddings_dir.mkdir(exist_ok=True)
        embedding_path = embeddings_dir / f"doc_{document_id}.faiss"
        vector_store.save_local(str(embedding_path))

        # Store embedding path in database
        cursor = self.conn.cursor()
        cursor.execute(
            """
            INSERT INTO document_embeddings (document_id, embedding_path)
            VALUES (?, ?)
            """,
            (document_id, str(embedding_path)),
        )
        self.conn.commit()

    def get_collections(self) -> List[Dict]:
        """Get all collections with their document counts."""
        cursor = self.conn.cursor()
        cursor.execute("""
            SELECT c.id, c.name, COUNT(d.id) as doc_count
            FROM collections c
            LEFT JOIN documents d ON c.id = d.collection_id
            GROUP BY c.id
        """)

        return [
            {
                'id': row[0],
                'name': row[1],
                'doc_count': row[2],
            }
            for row in cursor.fetchall()
        ]

    def get_collection_documents(self, collection_id: Optional[int] = None) -> List[Dict]:
        """Get documents in a collection or all documents if no collection specified.

        Args:
            collection_id: Collection filter, or None for all documents.

        Returns:
            Dicts with id, name, file_path, upload_date, newest first.
        """
        cursor = self.conn.cursor()

        if collection_id:
            cursor.execute("""
                SELECT id, name, file_path, upload_date
                FROM documents
                WHERE collection_id = ?
                ORDER BY upload_date DESC
            """, (collection_id,))
        else:
            cursor.execute("""
                SELECT id, name, file_path, upload_date
                FROM documents
                ORDER BY upload_date DESC
            """)

        return [
            {
                'id': row[0],
                'name': row[1],
                'file_path': row[2],
                'upload_date': row[3],
            }
            for row in cursor.fetchall()
        ]

    def initialize_chat(self, document_ids: List[int]) -> Optional[FAISS]:
        """Initialize chat by loading and merging document embeddings.

        Args:
            document_ids: Documents whose FAISS indexes should be combined.

        Returns:
            A merged FAISS store, or None if no embeddings were found.
        """
        embeddings_list = []
        cursor = self.conn.cursor()

        for doc_id in document_ids:
            cursor.execute(
                "SELECT embedding_path FROM document_embeddings WHERE document_id = ?",
                (doc_id,),
            )
            result = cursor.fetchone()

            # Silently skip documents with no stored/on-disk index so one
            # missing file doesn't break the whole chat session.
            if result:
                embedding_path = result[0]
                if os.path.exists(embedding_path):
                    embeddings_list.append(FAISS.load_local(embedding_path, self.embeddings))

        if embeddings_list:
            # Merge all embeddings into one vector store
            combined_store = embeddings_list[0]
            for store in embeddings_list[1:]:
                combined_store.merge_from(store)
            return combined_store

        return None


class ChatInterface:
    """Streamlit chat UI doing retrieval-augmented QA over a FAISS store."""

    def __init__(self, vector_store: FAISS):
        """Initialize chat interface with vector store.

        Args:
            vector_store: Pre-built FAISS index used for context retrieval.
        """
        self.vector_store = vector_store
        self.llm = ChatOpenAI(temperature=0.5, model_name="gpt-4")

        # Initialize prompt template
        self.prompt = ChatPromptTemplate.from_messages([
            ("system", "You are an RFP analysis expert. Answer questions based on the provided context."),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}\n\nContext: {context}"),
        ])

        # Initialize chat history
        if "messages" not in st.session_state:
            st.session_state.messages = []

    def display(self):
        """Render chat history and handle one round of user input."""
        # Display chat history
        for message in st.session_state.messages:
            if isinstance(message, HumanMessage):
                with st.chat_message("user"):
                    st.write(message.content)
            elif isinstance(message, AIMessage):
                with st.chat_message("assistant"):
                    st.write(message.content)

        # Chat input
        if prompt := st.chat_input("Ask about your documents..."):
            with st.chat_message("user"):
                st.write(prompt)
            st.session_state.messages.append(HumanMessage(content=prompt))

            # Get context from vector store
            docs = self.vector_store.similarity_search(prompt)
            context = "\n\n".join(doc.page_content for doc in docs)

            # Fix: ChatOpenAI expects a list of messages; `.format()` returns
            # a single flattened string, so use format_messages() instead.
            response = self.llm(self.prompt.format_messages(
                input=prompt,
                context=context,
                chat_history=st.session_state.messages,
            ))

            with st.chat_message("assistant"):
                st.write(response.content)
            st.session_state.messages.append(AIMessage(content=response.content))