|
|
import os
import sqlite3
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

import streamlit as st
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain_core.messages import AIMessage, HumanMessage
from typing import List, Dict, Optional
|
|
|
|
|
class DocumentManager:
    """Manage document collections, on-disk file storage, and FAISS embeddings.

    Layout under ``base_path``:
      - ``collections/<collection_id>/`` — uploaded source files
      - ``embeddings/doc_<document_id>.faiss`` — per-document FAISS indexes
      - ``rfp_analysis.db`` — SQLite metadata (collections, documents,
        document_embeddings)
    """

    def __init__(self, base_path: str = "/data"):
        """Initialize document manager with storage paths and database.

        Args:
            base_path: Root directory for all persisted state. Created
                (with parents) if it does not exist.
        """
        self.base_path = Path(base_path)
        self.collections_path = self.base_path / "collections"
        self.db_path = self.base_path / "rfp_analysis.db"

        # parents=True so a completely fresh base_path is created in one call.
        self.collections_path.mkdir(parents=True, exist_ok=True)

        # NOTE(review): connection is created on the constructing thread;
        # Streamlit may rerun on other threads — confirm single-thread use
        # or pass check_same_thread=False deliberately.
        self.conn = self._initialize_database()

        # Local sentence-transformer model: embeddings are computed offline,
        # no API key required.
        self.embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )

        # 200-char overlap preserves context across chunk boundaries.
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
            separators=["\n\n", "\n", " ", ""]
        )

    def _initialize_database(self) -> sqlite3.Connection:
        """Initialize SQLite database with necessary tables.

        Returns:
            An open connection with the schema created (idempotent via
            ``IF NOT EXISTS``).
        """
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        # NOTE(review): foreign keys are declared but SQLite does not enforce
        # them unless `PRAGMA foreign_keys = ON` is set — confirm whether
        # enforcement is desired before enabling.
        cursor.executescript("""
            CREATE TABLE IF NOT EXISTS collections (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL UNIQUE,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            );

            CREATE TABLE IF NOT EXISTS documents (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                collection_id INTEGER,
                name TEXT NOT NULL,
                file_path TEXT NOT NULL,
                upload_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (collection_id) REFERENCES collections (id)
            );

            CREATE TABLE IF NOT EXISTS document_embeddings (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                document_id INTEGER,
                embedding_path TEXT NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (document_id) REFERENCES documents (id)
            );
        """)

        conn.commit()
        return conn

    def create_collection(self, name: str) -> int:
        """Create a new collection directory and database entry.

        Args:
            name: Unique collection name.

        Returns:
            The new collection's row id.

        Raises:
            sqlite3.IntegrityError: If ``name`` already exists (UNIQUE
                constraint).
        """
        cursor = self.conn.cursor()

        cursor.execute(
            "INSERT INTO collections (name) VALUES (?)",
            (name,)
        )
        collection_id = cursor.lastrowid

        # Directory is keyed by the DB id so renames never move files.
        collection_path = self.collections_path / str(collection_id)
        collection_path.mkdir(parents=True, exist_ok=True)

        self.conn.commit()
        return collection_id

    def upload_documents(self, files: List, collection_id: Optional[int] = None) -> List[int]:
        """Upload documents to a collection and process them.

        Args:
            files: Streamlit ``UploadedFile`` objects (must expose ``.name``
                and ``.getvalue()``).
            collection_id: Target collection, or None to store under
                ``uncategorized``.

        Returns:
            List of new document row ids, in upload order.
        """
        uploaded_ids = []

        for file in files:
            # `is not None` (not truthiness) so a legitimate id of 0 would
            # not silently fall through to "uncategorized".
            if collection_id is not None:
                save_dir = self.collections_path / str(collection_id)
            else:
                save_dir = self.collections_path / "uncategorized"

            save_dir.mkdir(parents=True, exist_ok=True)
            file_path = save_dir / file.name

            # Persist the raw bytes exactly as uploaded.
            with open(file_path, "wb") as f:
                f.write(file.getvalue())

            cursor = self.conn.cursor()
            cursor.execute(
                """
                INSERT INTO documents (collection_id, name, file_path)
                VALUES (?, ?, ?)
                """,
                (collection_id, file.name, str(file_path))
            )
            document_id = cursor.lastrowid
            uploaded_ids.append(document_id)

            # Build and store the FAISS index for this document immediately.
            self._process_document_embeddings(document_id, file_path)

        self.conn.commit()

        return uploaded_ids

    def _process_document_embeddings(self, document_id: int, file_path: str):
        """Chunk a PDF, embed it, and persist the FAISS index.

        Args:
            document_id: Row id of the document being indexed.
            file_path: Path to the stored PDF file.
        """
        # PyPDFLoader yields one Document per page; chunks inherit metadata.
        loader = PyPDFLoader(str(file_path))
        pages = loader.load()
        chunks = self.text_splitter.split_documents(pages)

        vector_store = FAISS.from_documents(chunks, self.embeddings)

        embeddings_dir = self.base_path / "embeddings"
        embeddings_dir.mkdir(parents=True, exist_ok=True)
        embedding_path = embeddings_dir / f"doc_{document_id}.faiss"
        vector_store.save_local(str(embedding_path))

        cursor = self.conn.cursor()
        cursor.execute(
            """
            INSERT INTO document_embeddings (document_id, embedding_path)
            VALUES (?, ?)
            """,
            (document_id, str(embedding_path))
        )
        self.conn.commit()

    def get_collections(self) -> List[Dict]:
        """Get all collections with their documents.

        Returns:
            One dict per collection: ``{'id', 'name', 'doc_count'}``.
            ``doc_count`` is 0 for empty collections (LEFT JOIN).
        """
        cursor = self.conn.cursor()
        cursor.execute("""
            SELECT
                c.id,
                c.name,
                COUNT(d.id) as doc_count
            FROM collections c
            LEFT JOIN documents d ON c.id = d.collection_id
            GROUP BY c.id
        """)

        return [
            {
                'id': row[0],
                'name': row[1],
                'doc_count': row[2]
            }
            for row in cursor.fetchall()
        ]

    def get_collection_documents(self, collection_id: Optional[int] = None) -> List[Dict]:
        """Get documents in a collection or all documents if no collection specified.

        Args:
            collection_id: Collection to filter by, or None for all documents.

        Returns:
            One dict per document: ``{'id', 'name', 'file_path',
            'upload_date'}``, newest first.
        """
        cursor = self.conn.cursor()

        # `is not None` so an id of 0 is treated as a real filter.
        if collection_id is not None:
            cursor.execute("""
                SELECT id, name, file_path, upload_date
                FROM documents
                WHERE collection_id = ?
                ORDER BY upload_date DESC
            """, (collection_id,))
        else:
            cursor.execute("""
                SELECT id, name, file_path, upload_date
                FROM documents
                ORDER BY upload_date DESC
            """)

        return [
            {
                'id': row[0],
                'name': row[1],
                'file_path': row[2],
                'upload_date': row[3]
            }
            for row in cursor.fetchall()
        ]

    def initialize_chat(self, document_ids: List[int]) -> Optional[FAISS]:
        """Initialize chat by loading document embeddings.

        Loads each document's saved FAISS index and merges them into one
        searchable store.

        Args:
            document_ids: Documents to include in the chat context.

        Returns:
            A merged FAISS store, or None if no usable embeddings were found.
        """
        embeddings_list = []

        cursor = self.conn.cursor()
        for doc_id in document_ids:
            cursor.execute(
                "SELECT embedding_path FROM document_embeddings WHERE document_id = ?",
                (doc_id,)
            )
            result = cursor.fetchone()
            if result:
                embedding_path = result[0]
                # Stale DB rows may point at deleted files; skip them silently.
                if os.path.exists(embedding_path):
                    # NOTE(review): newer langchain versions require
                    # allow_dangerous_deserialization=True here — confirm the
                    # pinned langchain version before upgrading.
                    embeddings_list.append(FAISS.load_local(embedding_path, self.embeddings))

        if embeddings_list:
            # Merge everything into the first store in place.
            combined_store = embeddings_list[0]
            for store in embeddings_list[1:]:
                combined_store.merge_from(store)
            return combined_store

        return None
|
|
|
|
|
class ChatInterface:
    """Streamlit chat UI that answers questions grounded in a FAISS store."""

    def __init__(self, vector_store: FAISS):
        """Initialize chat interface with vector store.

        Args:
            vector_store: FAISS index used to retrieve context passages for
                each question.
        """
        self.vector_store = vector_store
        self.llm = ChatOpenAI(temperature=0.5, model_name="gpt-4")

        # System role + running history + current question with retrieved
        # context appended.
        self.prompt = ChatPromptTemplate.from_messages([
            ("system", "You are an RFP analysis expert. Answer questions based on the provided context."),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}\n\nContext: {context}")
        ])

        # Persist conversation across Streamlit reruns.
        if "messages" not in st.session_state:
            st.session_state.messages = []

    def display(self):
        """Display chat history and handle one new user question."""
        # Replay the stored conversation on every rerun.
        for message in st.session_state.messages:
            if isinstance(message, HumanMessage):
                with st.chat_message("user"):
                    st.write(message.content)
            elif isinstance(message, AIMessage):
                with st.chat_message("assistant"):
                    st.write(message.content)

        if prompt := st.chat_input("Ask about your documents..."):
            with st.chat_message("user"):
                st.write(prompt)

            # BUG FIX: capture history *before* appending the new question.
            # The question is injected via {input}; appending first made it
            # appear twice in the rendered prompt (history + input).
            chat_history = list(st.session_state.messages)
            st.session_state.messages.append(HumanMessage(content=prompt))

            # Ground the answer in the most similar document chunks.
            docs = self.vector_store.similarity_search(prompt)
            context = "\n\n".join(doc.page_content for doc in docs)

            # BUG FIX: ChatPromptTemplate.format() flattens everything into a
            # single string; a chat model needs the structured message list,
            # so use format_messages().
            response = self.llm(self.prompt.format_messages(
                input=prompt,
                context=context,
                chat_history=chat_history
            ))

            with st.chat_message("assistant"):
                st.write(response.content)
            st.session_state.messages.append(AIMessage(content=response.content))