# RFP_Agent / doc_manager.py
# Hugging Face upload metadata (preserved as comments so the module parses):
#   uploaded by cryogenic22 — "Create doc_manager.py" — commit c878d9c (verified)
import os
import sqlite3
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

import streamlit as st
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain_core.messages import AIMessage, HumanMessage
class DocumentManager:
    """Manage document collections on disk, their SQLite metadata, and FAISS embeddings.

    Layout under ``base_path``:
      - ``collections/<id>/``   uploaded source files per collection
      - ``embeddings/doc_<id>.faiss``  saved FAISS index per document
      - ``rfp_analysis.db``     SQLite metadata (collections, documents, embeddings)
    """

    def __init__(self, base_path: str = "/data"):
        """Initialize document manager with storage paths and database.

        Args:
            base_path: Root directory for collections, embeddings, and the DB file.
        """
        self.base_path = Path(base_path)
        self.collections_path = self.base_path / "collections"
        self.db_path = self.base_path / "rfp_analysis.db"
        # Create necessary directories
        self.collections_path.mkdir(parents=True, exist_ok=True)
        # Initialize database (connection is kept open for the manager's lifetime)
        self.conn = self._initialize_database()
        # Embedding model shared by all document processing and chat retrieval
        self.embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
        # Text splitter for document processing; overlap preserves context across chunks
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
            separators=["\n\n", "\n", " ", ""],
        )

    def _initialize_database(self) -> sqlite3.Connection:
        """Create (if needed) and return the SQLite connection with the schema applied."""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        # Schema is idempotent: CREATE TABLE IF NOT EXISTS on every startup
        cursor.executescript("""
            CREATE TABLE IF NOT EXISTS collections (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL UNIQUE,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            );
            CREATE TABLE IF NOT EXISTS documents (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                collection_id INTEGER,
                name TEXT NOT NULL,
                file_path TEXT NOT NULL,
                upload_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (collection_id) REFERENCES collections (id)
            );
            CREATE TABLE IF NOT EXISTS document_embeddings (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                document_id INTEGER,
                embedding_path TEXT NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (document_id) REFERENCES documents (id)
            );
        """)
        conn.commit()
        return conn

    def create_collection(self, name: str) -> int:
        """Create a new collection directory and database entry.

        Args:
            name: Collection name; must be unique (enforced by the DB;
                a duplicate raises sqlite3.IntegrityError).

        Returns:
            The new collection's row id.
        """
        cursor = self.conn.cursor()
        cursor.execute(
            "INSERT INTO collections (name) VALUES (?)",
            (name,),
        )
        collection_id = cursor.lastrowid
        # On-disk directory is named after the DB id so the mapping is trivial
        collection_path = self.collections_path / str(collection_id)
        collection_path.mkdir(exist_ok=True)
        self.conn.commit()
        return collection_id

    def upload_documents(self, files: List, collection_id: Optional[int] = None) -> List[int]:
        """Upload documents to a collection and process them.

        Args:
            files: Streamlit UploadedFile-like objects (must expose ``.name``
                and ``.getvalue()``).
            collection_id: Target collection id; files go to ``uncategorized``
                when omitted.

        Returns:
            List of new document row ids, in upload order.
        """
        uploaded_ids = []
        for file in files:
            # Route the file into its collection directory (or the catch-all).
            # Explicit None check: a falsy-but-valid id must not be misrouted.
            if collection_id is not None:
                save_dir = self.collections_path / str(collection_id)
            else:
                save_dir = self.collections_path / "uncategorized"
            save_dir.mkdir(exist_ok=True)
            file_path = save_dir / file.name
            # Persist raw bytes to disk
            with open(file_path, "wb") as f:
                f.write(file.getvalue())
            # Record metadata
            cursor = self.conn.cursor()
            cursor.execute(
                """
                INSERT INTO documents (collection_id, name, file_path)
                VALUES (?, ?, ?)
                """,
                (collection_id, file.name, str(file_path)),
            )
            document_id = cursor.lastrowid
            uploaded_ids.append(document_id)
            # Chunk, embed, and persist the FAISS index for this document
            self._process_document_embeddings(document_id, file_path)
        self.conn.commit()
        return uploaded_ids

    def _process_document_embeddings(self, document_id: int, file_path: str):
        """Load a PDF, chunk it, embed the chunks, and save the FAISS index.

        NOTE(review): assumes every uploaded file is a PDF (PyPDFLoader);
        non-PDF uploads will fail here — confirm upstream filtering.
        """
        # Load and chunk document
        loader = PyPDFLoader(str(file_path))
        pages = loader.load()
        chunks = self.text_splitter.split_documents(pages)
        # Create embeddings
        vector_store = FAISS.from_documents(chunks, self.embeddings)
        # Save index under a per-document path
        embeddings_dir = self.base_path / "embeddings"
        embeddings_dir.mkdir(exist_ok=True)
        embedding_path = embeddings_dir / f"doc_{document_id}.faiss"
        vector_store.save_local(str(embedding_path))
        # Store embedding path in database
        cursor = self.conn.cursor()
        cursor.execute(
            """
            INSERT INTO document_embeddings (document_id, embedding_path)
            VALUES (?, ?)
            """,
            (document_id, str(embedding_path)),
        )
        self.conn.commit()

    def get_collections(self) -> List[Dict]:
        """Get all collections with their document counts.

        Returns:
            List of dicts with keys ``id``, ``name``, ``doc_count``.
        """
        cursor = self.conn.cursor()
        cursor.execute("""
            SELECT
                c.id,
                c.name,
                COUNT(d.id) as doc_count
            FROM collections c
            LEFT JOIN documents d ON c.id = d.collection_id
            GROUP BY c.id
        """)
        return [
            {
                'id': row[0],
                'name': row[1],
                'doc_count': row[2],
            }
            for row in cursor.fetchall()
        ]

    def get_collection_documents(self, collection_id: Optional[int] = None) -> List[Dict]:
        """Get documents in a collection, or all documents if no collection specified.

        Returns:
            List of dicts with keys ``id``, ``name``, ``file_path``,
            ``upload_date``, newest first.
        """
        cursor = self.conn.cursor()
        # Explicit None check so a valid-but-falsy id is not treated as "all"
        if collection_id is not None:
            cursor.execute("""
                SELECT id, name, file_path, upload_date
                FROM documents
                WHERE collection_id = ?
                ORDER BY upload_date DESC
            """, (collection_id,))
        else:
            cursor.execute("""
                SELECT id, name, file_path, upload_date
                FROM documents
                ORDER BY upload_date DESC
            """)
        return [
            {
                'id': row[0],
                'name': row[1],
                'file_path': row[2],
                'upload_date': row[3],
            }
            for row in cursor.fetchall()
        ]

    def initialize_chat(self, document_ids: List[int]) -> Optional["FAISS"]:
        """Load and merge the FAISS stores for the given documents.

        Args:
            document_ids: Document row ids whose embeddings should be loaded.

        Returns:
            A single merged FAISS store, or None if no stored index was found
            on disk for any of the ids.
        """
        embeddings_list = []
        cursor = self.conn.cursor()
        for doc_id in document_ids:
            cursor.execute(
                "SELECT embedding_path FROM document_embeddings WHERE document_id = ?",
                (doc_id,),
            )
            result = cursor.fetchone()
            if result:
                embedding_path = result[0]
                # Skip rows whose index folder was deleted out from under us
                if os.path.exists(embedding_path):
                    embeddings_list.append(FAISS.load_local(embedding_path, self.embeddings))
        if embeddings_list:
            # Merge all embeddings into one searchable vector store
            combined_store = embeddings_list[0]
            for store in embeddings_list[1:]:
                combined_store.merge_from(store)
            return combined_store
        return None
class ChatInterface:
    """Streamlit chat UI answering questions from a FAISS store via an OpenAI chat model."""

    def __init__(self, vector_store: FAISS):
        """Initialize chat interface with vector store.

        Args:
            vector_store: Merged FAISS index to retrieve context chunks from.
        """
        self.vector_store = vector_store
        self.llm = ChatOpenAI(temperature=0.5, model_name="gpt-4")
        # Prompt = system role + prior turns + current question with retrieved context
        self.prompt = ChatPromptTemplate.from_messages([
            ("system", "You are an RFP analysis expert. Answer questions based on the provided context."),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}\n\nContext: {context}"),
        ])
        # History lives in session_state so it survives Streamlit reruns
        if "messages" not in st.session_state:
            st.session_state.messages = []

    def display(self):
        """Render the chat history and handle one new user turn, if any."""
        # Replay prior turns with the appropriate chat role
        for message in st.session_state.messages:
            if isinstance(message, HumanMessage):
                with st.chat_message("user"):
                    st.write(message.content)
            elif isinstance(message, AIMessage):
                with st.chat_message("assistant"):
                    st.write(message.content)
        # New user input (walrus: only run the turn when something was typed)
        if prompt := st.chat_input("Ask about your documents..."):
            with st.chat_message("user"):
                st.write(prompt)
            st.session_state.messages.append(HumanMessage(content=prompt))
            # Retrieve supporting chunks for the question
            docs = self.vector_store.similarity_search(prompt)
            context = "\n\n".join(doc.page_content for doc in docs)
            # BUG FIX: a chat model expects a list of role-tagged messages.
            # ChatPromptTemplate.format() flattens everything into one string,
            # losing the system prompt and history roles; format_messages()
            # returns the structured message list the model needs.
            response = self.llm(self.prompt.format_messages(
                input=prompt,
                context=context,
                chat_history=st.session_state.messages,
            ))
            with st.chat_message("assistant"):
                st.write(response.content)
            st.session_state.messages.append(AIMessage(content=response.content))