File size: 10,051 Bytes
c878d9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
import os
import sqlite3
import tempfile
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Optional

import streamlit as st
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain_core.messages import HumanMessage, AIMessage

class DocumentManager:
    """Store uploaded PDF documents, their collections, and per-document FAISS indexes.

    Layout under ``base_path``:
      * ``collections/<collection_id>/`` or ``collections/uncategorized/`` -- uploaded files
      * ``embeddings/doc_<document_id>.faiss`` -- one saved FAISS index per document
      * ``rfp_analysis.db`` -- SQLite metadata (collections, documents, embedding paths)
    """

    def __init__(self, base_path: str = "/data"):
        """Initialize document manager with storage paths and database.

        Args:
            base_path: Root directory for files, indexes and the SQLite database.
                Created (with parents) if it does not exist.
        """
        self.base_path = Path(base_path)
        self.collections_path = self.base_path / "collections"
        self.db_path = self.base_path / "rfp_analysis.db"

        # Creating collections_path with parents=True also creates base_path,
        # which must exist before the database file can be opened.
        self.collections_path.mkdir(parents=True, exist_ok=True)

        self.conn = self._initialize_database()

        self.embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )

        # Chunks of up to 1000 characters; neighbouring chunks share up to 200
        # characters. Separators are tried from paragraph down to character level.
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
            separators=["\n\n", "\n", " ", ""]
        )

    def _initialize_database(self) -> sqlite3.Connection:
        """Open the SQLite database and create the schema if it is missing.

        Returns:
            An open connection to ``self.db_path``.
        """
        conn = sqlite3.connect(self.db_path)
        conn.executescript("""
            CREATE TABLE IF NOT EXISTS collections (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL UNIQUE,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            );
            
            CREATE TABLE IF NOT EXISTS documents (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                collection_id INTEGER,
                name TEXT NOT NULL,
                file_path TEXT NOT NULL,
                upload_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (collection_id) REFERENCES collections (id)
            );
            
            CREATE TABLE IF NOT EXISTS document_embeddings (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                document_id INTEGER,
                embedding_path TEXT NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (document_id) REFERENCES documents (id)
            );
        """)
        conn.commit()
        return conn

    def create_collection(self, name: str) -> int:
        """Create a new collection directory and database entry.

        Args:
            name: Collection name; must be unique.

        Returns:
            The id of the new collection.

        Raises:
            sqlite3.IntegrityError: If a collection with ``name`` already exists.
        """
        # The connection context manager commits on success and rolls back on
        # error, so a failing mkdir does not leave a pending collection row.
        with self.conn:
            cursor = self.conn.execute(
                "INSERT INTO collections (name) VALUES (?)",
                (name,)
            )
            collection_id = cursor.lastrowid
            (self.collections_path / str(collection_id)).mkdir(exist_ok=True)
        return collection_id

    def upload_documents(self, files: List, collection_id: Optional[int] = None) -> List[int]:
        """Save uploaded files, register them, and build their embeddings.

        Args:
            files: Objects exposing ``.name`` and ``.getvalue()`` (presumably
                Streamlit ``UploadedFile``s). Each one is parsed with
                ``PyPDFLoader``, so each must be a PDF.
            collection_id: Target collection. ``None`` stores files under
                ``uncategorized`` with a NULL ``collection_id``.

        Returns:
            Ids of the newly inserted ``documents`` rows, in input order.

        Raises:
            ValueError: If a file yields no extractable text.
        """
        if collection_id is not None:
            save_dir = self.collections_path / str(collection_id)
        else:
            save_dir = self.collections_path / "uncategorized"
        save_dir.mkdir(exist_ok=True)

        uploaded_ids = []
        for file in files:
            # The client-supplied name is untrusted. Keep only its final path
            # component so names like "../../x.pdf" cannot escape save_dir.
            file_name = Path(file.name).name
            file_path = save_dir / file_name
            file_path.write_bytes(file.getvalue())

            # Insert the document row and its embedding row in one transaction.
            # If embedding fails, both are rolled back instead of leaving a
            # pending document row that a later commit would persist.
            with self.conn:
                cursor = self.conn.execute(
                    """
                    INSERT INTO documents (collection_id, name, file_path)
                    VALUES (?, ?, ?)
                    """,
                    (collection_id, file_name, str(file_path))
                )
                document_id = cursor.lastrowid
                self._process_document_embeddings(document_id, file_path)
            uploaded_ids.append(document_id)

        return uploaded_ids

    def _process_document_embeddings(self, document_id: int, file_path: str):
        """Chunk a PDF, embed the chunks, and save a FAISS index for it.

        The index is saved to ``<base_path>/embeddings/doc_<document_id>.faiss``
        and that path is recorded in ``document_embeddings``.

        Args:
            document_id: Id of the ``documents`` row the index belongs to.
            file_path: Path to the PDF (``str`` or ``Path``).

        Raises:
            ValueError: If the PDF produces no text chunks (for example an
                image-only scan). ``FAISS.from_documents`` cannot index an
                empty list.
        """
        pages = PyPDFLoader(str(file_path)).load()
        chunks = self.text_splitter.split_documents(pages)
        if not chunks:
            raise ValueError(f"No extractable text found in {file_path}")

        vector_store = FAISS.from_documents(chunks, self.embeddings)

        embeddings_dir = self.base_path / "embeddings"
        embeddings_dir.mkdir(exist_ok=True)
        embedding_path = embeddings_dir / f"doc_{document_id}.faiss"
        vector_store.save_local(str(embedding_path))

        with self.conn:
            self.conn.execute(
                """
                INSERT INTO document_embeddings (document_id, embedding_path)
                VALUES (?, ?)
                """,
                (document_id, str(embedding_path))
            )

    def get_collections(self) -> List[Dict]:
        """Return every collection with its document count.

        Documents with a NULL ``collection_id`` (uncategorized) are not counted.

        Returns:
            Dicts with keys ``id``, ``name`` and ``doc_count``.
        """
        cursor = self.conn.execute("""
            SELECT 
                c.id,
                c.name,
                COUNT(d.id) as doc_count
            FROM collections c
            LEFT JOIN documents d ON c.id = d.collection_id
            GROUP BY c.id
        """)
        return [
            {'id': row[0], 'name': row[1], 'doc_count': row[2]}
            for row in cursor.fetchall()
        ]

    def get_collection_documents(self, collection_id: Optional[int] = None) -> List[Dict]:
        """Get documents in a collection, or all documents if none is given.

        Args:
            collection_id: Collection to filter by. ``None`` returns every document.

        Returns:
            Dicts with keys ``id``, ``name``, ``file_path`` and ``upload_date``,
            newest upload first.
        """
        if collection_id is not None:
            cursor = self.conn.execute("""
                SELECT id, name, file_path, upload_date
                FROM documents
                WHERE collection_id = ?
                ORDER BY upload_date DESC
            """, (collection_id,))
        else:
            cursor = self.conn.execute("""
                SELECT id, name, file_path, upload_date
                FROM documents
                ORDER BY upload_date DESC
            """)

        return [
            {
                'id': row[0],
                'name': row[1],
                'file_path': row[2],
                'upload_date': row[3]
            }
            for row in cursor.fetchall()
        ]

    def initialize_chat(self, document_ids: List[int]) -> Optional[FAISS]:
        """Load the saved FAISS indexes for ``document_ids`` and merge them.

        A document is skipped if it has no recorded index or if its index
        path no longer exists on disk.

        Returns:
            One FAISS store containing all loaded indexes, or ``None`` if
            nothing could be loaded.
        """
        stores = []
        for doc_id in document_ids:
            row = self.conn.execute(
                "SELECT embedding_path FROM document_embeddings WHERE document_id = ?",
                (doc_id,)
            ).fetchone()
            if row is None or not os.path.exists(row[0]):
                continue
            # NOTE(review): load_local unpickles the saved docstore. Recent
            # langchain-community releases refuse to do so unless
            # allow_dangerous_deserialization=True is passed. These files are
            # written by _process_document_embeddings, so they are trusted.
            # Add that flag if the installed version requires it.
            stores.append(FAISS.load_local(row[0], self.embeddings))

        if not stores:
            return None

        # Merge into the first freshly loaded store; nothing else references it.
        combined_store, *rest = stores
        for store in rest:
            combined_store.merge_from(store)
        return combined_store

class ChatInterface:
    """Streamlit chat UI that answers questions from retrieved document chunks."""

    def __init__(self, vector_store: FAISS, model_name: str = "gpt-4", temperature: float = 0.5):
        """Initialize chat interface with vector store.

        Args:
            vector_store: Store searched for context on every question.
            model_name: OpenAI chat model name (defaults to the original "gpt-4").
            temperature: Sampling temperature for the model (default 0.5).
        """
        self.vector_store = vector_store
        self.llm = ChatOpenAI(temperature=temperature, model_name=model_name)

        # The final human turn carries the question plus the retrieved context.
        # Earlier turns are injected through the chat_history placeholder.
        self.prompt = ChatPromptTemplate.from_messages([
            ("system", "You are an RFP analysis expert. Answer questions based on the provided context."),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}\n\nContext: {context}")
        ])

        # Keep the conversation in session state so it survives Streamlit reruns.
        if "messages" not in st.session_state:
            st.session_state.messages = []

    def display(self):
        """Render the chat history and answer a newly submitted question, if any."""
        for message in st.session_state.messages:
            if isinstance(message, HumanMessage):
                role = "user"
            elif isinstance(message, AIMessage):
                role = "assistant"
            else:
                continue
            with st.chat_message(role):
                st.write(message.content)

        prompt = st.chat_input("Ask about your documents...")
        if not prompt:
            return

        # Snapshot the history BEFORE recording the new question. The template
        # already appends the question as its final human message, so keeping
        # it in chat_history would send it to the model twice.
        history = list(st.session_state.messages)

        with st.chat_message("user"):
            st.write(prompt)
        st.session_state.messages.append(HumanMessage(content=prompt))

        docs = self.vector_store.similarity_search(prompt)
        context = "\n\n".join(doc.page_content for doc in docs)

        # format_messages returns the list of chat messages the model expects.
        # prompt.format() would flatten the conversation into one string.
        messages = self.prompt.format_messages(
            input=prompt,
            context=context,
            chat_history=history
        )
        response = self.llm.invoke(messages)

        with st.chat_message("assistant"):
            st.write(response.content)
        st.session_state.messages.append(AIMessage(content=response.content))