cryogenic22 commited on
Commit
62f0b1d
·
verified ·
1 Parent(s): 9d3f1ff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +193 -155
app.py CHANGED
@@ -2,8 +2,10 @@ from __future__ import annotations
2
 
3
  import streamlit as st
4
  import os
 
 
5
  from pathlib import Path
6
- from typing import Dict, List, Any, Optional
7
  from datetime import datetime
8
  from dataclasses import dataclass
9
  from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -12,6 +14,7 @@ from langchain_community.embeddings import HuggingFaceEmbeddings
12
  from langchain_community.chat_models import ChatOpenAI
13
  from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
14
  from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, BaseMessage
 
15
  from utils.database import (
16
  create_connection,
17
  create_tables,
@@ -22,13 +25,15 @@ from utils.database import (
22
  initialize_qa_system,
23
  get_embeddings_model,
24
  verify_database_tables,
25
- handle_document_upload,
26
  create_collection,
27
  add_document_to_collection,
28
  get_recent_documents,
29
  save_chat_message,
30
  create_new_chat,
31
- get_chat_messages)
 
 
 
32
 
33
 
34
  @dataclass
@@ -44,84 +49,6 @@ class SessionState:
44
  reinitialize_chat: bool = False
45
 
46
 
47
- def initialize_session_state():
48
- """Initialize session state with default values."""
49
- defaults = SessionState()
50
- if 'initialized' not in st.session_state:
51
- # Setup data paths
52
- data_path = Path('/data' if os.path.exists('/data') else 'data')
53
- vector_store_path = data_path / 'vector_stores'
54
-
55
- # Create necessary directories
56
- data_path.mkdir(parents=True, exist_ok=True)
57
- vector_store_path.mkdir(parents=True, exist_ok=True)
58
-
59
- # Initialize session state
60
- st.session_state.update({
61
- 'show_collection_dialog': defaults.show_collection_dialog,
62
- 'selected_collection': defaults.selected_collection,
63
- 'chat_ready': defaults.chat_ready,
64
- 'messages': [] if defaults.messages is None else defaults.messages,
65
- 'current_chat_id': defaults.current_chat_id,
66
- 'vector_store': defaults.vector_store,
67
- 'qa_system': defaults.qa_system,
68
- 'reinitialize_chat': defaults.reinitialize_chat,
69
- 'initialized': True,
70
- 'data_path': data_path,
71
- 'vector_store_path': vector_store_path,
72
- 'show_explorer': False
73
- })
74
-
75
-
76
- def initialize_chat_system(collection_id=None) -> bool:
77
- """Initialize chat system with persistent vector stores."""
78
- try:
79
- # Get documents based on collection or all documents
80
- documents = (get_collection_documents(st.session_state.db_conn, collection_id)
81
- if collection_id else get_all_documents(st.session_state.db_conn))
82
-
83
- if not documents:
84
- st.error("No documents found.")
85
- return False
86
-
87
- # Initialize new vector store
88
- with st.spinner("Processing documents..."):
89
- embeddings = get_embeddings_model()
90
- text_splitter = RecursiveCharacterTextSplitter(
91
- chunk_size=500,
92
- chunk_overlap=50,
93
- length_function=len,
94
- )
95
-
96
- chunks = []
97
- for doc in documents:
98
- doc_chunks = text_splitter.split_text(doc['content'])
99
- chunks.extend([{
100
- 'content': chunk,
101
- 'metadata': {
102
- 'source': doc['name'],
103
- 'document_id': doc['id'],
104
- 'collection_id': collection_id
105
- }
106
- } for chunk in doc_chunks])
107
-
108
- # Create new vector store with allow_dangerous_deserialization
109
- vector_store = FAISS.from_texts(
110
- [chunk['content'] for chunk in chunks],
111
- embeddings,
112
- [chunk['metadata'] for chunk in chunks]
113
- )
114
-
115
- st.session_state.vector_store = vector_store
116
- st.session_state.qa_system = initialize_qa_system(vector_store)
117
- st.session_state.chat_ready = True
118
- return True
119
-
120
- except Exception as e:
121
- st.error(f"Error initializing chat system: {e}")
122
- return False
123
-
124
-
125
  def display_header():
126
  """Display the application header with navigation."""
127
  # Add custom CSS for header styling
@@ -188,67 +115,195 @@ def display_header():
188
  st.divider()
189
 
190
 
191
- def display_welcome_screen():
192
- """Display enhanced welcome screen with quick actions."""
193
- col1, col2 = st.columns([3, 2])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
- with col1:
196
- st.header("Quick Start")
197
 
198
- # Upload new documents with collection linking
199
- st.markdown("### Upload Documents")
200
- collection_id = None
201
- collections = get_collections(st.session_state.db_conn)
 
 
202
 
203
- if collections:
204
- selected_collection = st.selectbox(
205
- "Select Collection (Optional)",
206
- options=[("None", None)] + [(c["name"], c["id"]) for c in collections],
207
- format_func=lambda x: x[0]
208
- )
209
- collection_id = selected_collection[1] if selected_collection[0] != "None" else None
 
 
 
 
 
 
 
 
 
210
 
211
- # Add new collection button
212
- if st.button("Create New Collection", use_container_width=True):
213
- st.session_state.show_collection_dialog = True
214
- st.rerun()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
- uploaded_files = st.file_uploader(
217
- "Upload Documents",
218
- type=['pdf'],
219
- accept_multiple_files=True,
220
- help="Upload PDF documents to start analyzing"
221
- )
222
 
223
- if uploaded_files:
224
- with st.spinner("Processing documents..."):
225
- if handle_document_upload(uploaded_files, collection_id=collection_id):
226
- initialize_chat_system(collection_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
  st.rerun()
228
-
229
  with col2:
230
- # Display existing collections
231
- st.header("Collections")
232
  if collections:
233
- for collection in collections:
234
- with st.expander(f"📁 {collection['name']} ({collection['doc_count']} documents)"):
235
- st.write(collection.get('description', ''))
236
- if st.button("Start Chat", key=f"chat_{collection['id']}", use_container_width=True):
237
- st.session_state.selected_collection = collection
238
- if initialize_chat_system(collection['id']):
239
- st.rerun()
240
-
241
- # Show recent documents
242
- st.header("Recent Documents")
243
- recent_docs = get_recent_documents(st.session_state.db_conn, limit=5)
244
- for doc in recent_docs:
245
- with st.expander(f"📄 {doc['name']}"):
246
- st.caption(f"Upload date: {doc['upload_date']}")
247
- if doc['collections']:
248
- st.caption(f"Collections: {', '.join(doc['collections'])}")
249
- if st.button("Start Chat", key=f"doc_{doc['id']}", use_container_width=True):
250
- if initialize_chat_system():
 
251
  st.rerun()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
 
253
 
254
  def display_chat_interface():
@@ -294,7 +349,6 @@ def display_chat_interface():
294
 
295
  st.rerun()
296
 
297
- # Rest of the code remains the same...
298
 
299
  def main():
300
  """Main application function with improved state management."""
@@ -304,7 +358,7 @@ def main():
304
  initial_sidebar_state="collapsed"
305
  )
306
 
307
- # Initialize session state with paths
308
  initialize_session_state()
309
 
310
  # Initialize database connection
@@ -318,29 +372,13 @@ def main():
318
  # Display header
319
  display_header()
320
 
321
- # Show collection creation dialog if triggered
322
  if st.session_state.show_collection_dialog:
323
- with st.form("create_collection"):
324
- st.subheader("Create New Collection")
325
- name = st.text_input("Collection Name")
326
- description = st.text_area("Description")
327
-
328
- if st.form_submit_button("Create", use_container_width=True):
329
- if name:
330
- if create_collection(st.session_state.db_conn, name, description):
331
- st.success(f"Collection '{name}' created successfully!")
332
- st.session_state.show_collection_dialog = False
333
- st.rerun()
334
- else:
335
- st.error("Failed to create collection.")
336
- else:
337
- st.warning("Please enter a collection name.")
338
-
339
- # Display different views based on application state
340
- if st.session_state.chat_ready:
341
  display_chat_interface()
342
- elif st.session_state.show_explorer:
343
- display_document_chunks()
344
  else:
345
  display_welcome_screen()
346
 
 
2
 
3
  import streamlit as st
4
  import os
5
+ import json
6
+ import time
7
  from pathlib import Path
8
+ from typing import Dict, List, Any, Optional, Tuple
9
  from datetime import datetime
10
  from dataclasses import dataclass
11
  from langchain.text_splitter import RecursiveCharacterTextSplitter
 
14
  from langchain_community.chat_models import ChatOpenAI
15
  from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
16
  from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, BaseMessage
17
+ import tempfile
18
  from utils.database import (
19
  create_connection,
20
  create_tables,
 
25
  initialize_qa_system,
26
  get_embeddings_model,
27
  verify_database_tables,
 
28
  create_collection,
29
  add_document_to_collection,
30
  get_recent_documents,
31
  save_chat_message,
32
  create_new_chat,
33
+ get_chat_messages,
34
+ get_document_tags,
35
+ add_document_tags,
36
+ delete_collection)
37
 
38
 
39
  @dataclass
 
49
  reinitialize_chat: bool = False
50
 
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  def display_header():
53
  """Display the application header with navigation."""
54
  # Add custom CSS for header styling
 
115
  st.divider()
116
 
117
 
118
+ async def process_document(file_path: str, collection_id: Optional[int] = None) -> Tuple[List, str]:
119
+ """Process a document with automatic tagging."""
120
+ try:
121
+ # Load PDF
122
+ loader = PyPDFLoader(file_path)
123
+ documents = loader.load()
124
+
125
+ # Extract full content
126
+ full_content = "\n".join(doc.page_content for doc in documents)
127
+
128
+ # Generate tags
129
+ tags = await generate_document_tags(full_content)
130
+
131
+ # Create text splitter for chunks
132
+ text_splitter = RecursiveCharacterTextSplitter(
133
+ chunk_size=1000,
134
+ chunk_overlap=200,
135
+ length_function=len,
136
+ separators=["\n\n", "\n", " ", ""]
137
+ )
138
+
139
+ # Split documents into chunks
140
+ chunks = text_splitter.split_documents(documents)
141
+
142
+ # Add metadata to chunks
143
+ for chunk in chunks:
144
+ chunk.metadata.update({
145
+ 'collection_id': collection_id,
146
+ 'tags': tags
147
+ })
148
+
149
+ return chunks, full_content, tags
150
+
151
+ except Exception as e:
152
+ st.error(f"Error processing document: {e}")
153
+ return [], "", []
154
 
 
 
155
 
156
+ async def handle_document_upload(uploaded_files: List, collection_id: Optional[int] = None) -> bool:
157
+ """Handle document upload with progress tracking and auto-tagging."""
158
+ try:
159
+ progress_container = st.empty()
160
+ status_container = st.empty()
161
+ progress_bar = progress_container.progress(0)
162
 
163
+ # Initialize embeddings
164
+ embeddings = get_embeddings_model()
165
+ if not embeddings:
166
+ status_container.error("Failed to initialize embeddings model")
167
+ return False
168
+
169
+ progress_bar.progress(10)
170
+ all_chunks = []
171
+ documents = []
172
+
173
+ # Process each document
174
+ progress_per_file = 70 / len(uploaded_files)
175
+ current_progress = 10
176
+
177
+ for idx, uploaded_file in enumerate(uploaded_files):
178
+ status_container.info(f"Processing {uploaded_file.name}...")
179
 
180
+ # Create temporary file
181
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
182
+ tmp_file.write(uploaded_file.getvalue())
183
+ tmp_file.flush()
184
+
185
+ # Process document with tagging
186
+ chunks, content, tags = await process_document(tmp_file.name, collection_id)
187
+
188
+ # Store in database
189
+ doc_id = insert_document(st.session_state.db_conn, uploaded_file.name, content)
190
+ if not doc_id:
191
+ status_container.error(f"Failed to store document: {uploaded_file.name}")
192
+ continue
193
+
194
+ # Add tags
195
+ if tags:
196
+ add_document_tags(st.session_state.db_conn, doc_id, tags)
197
+
198
+ # Add to collection if specified
199
+ if collection_id:
200
+ add_document_to_collection(st.session_state.db_conn, doc_id, collection_id)
201
+
202
+ all_chunks.extend(chunks)
203
+ documents.append(content)
204
+
205
+ current_progress += progress_per_file
206
+ progress_bar.progress(int(current_progress))
207
+
208
+ # Initialize vector store
209
+ status_container.info("Creating document index...")
210
+ vector_store = FAISS.from_documents(all_chunks, embeddings)
211
+
212
+ st.session_state.vector_store = vector_store
213
+ st.session_state.qa_system = initialize_qa_system(vector_store)
214
+ st.session_state.chat_ready = True
215
+
216
+ progress_bar.progress(100)
217
+ status_container.success("Documents processed successfully!")
218
+
219
+ # Clean up progress display
220
+ time.sleep(2)
221
+ progress_container.empty()
222
+ status_container.empty()
223
+
224
+ return True
225
+
226
+ except Exception as e:
227
+ st.error(f"Error uploading documents: {e}")
228
+ return False
229
 
 
 
 
 
 
 
230
 
231
+ def display_collection_management():
232
+ """Display collection management interface."""
233
+ st.header("📁 Collection Management")
234
+
235
+ # Get existing collections
236
+ collections = get_collections(st.session_state.db_conn)
237
+
238
+ col1, col2 = st.columns([2, 1])
239
+
240
+ with col1:
241
+ # Create new collection form
242
+ with st.form("create_collection_form"):
243
+ st.subheader("Create New Collection")
244
+ name = st.text_input("Collection Name")
245
+ description = st.text_area("Description")
246
+ submit = st.form_submit_button("Create Collection", use_container_width=True)
247
+
248
+ if submit and name:
249
+ collection_id = create_collection(st.session_state.db_conn, name, description)
250
+ if collection_id:
251
+ st.success(f"Collection '{name}' created successfully!")
252
  st.rerun()
253
+
254
  with col2:
255
+ # Upload to existing collection
 
256
  if collections:
257
+ st.subheader("Add to Collection")
258
+ selected_collection = st.selectbox(
259
+ "Select Collection",
260
+ options=[(c["name"], c["id"]) for c in collections],
261
+ format_func=lambda x: x[0]
262
+ )
263
+
264
+ uploaded_files = st.file_uploader(
265
+ "Upload Documents",
266
+ type=['pdf'],
267
+ accept_multiple_files=True,
268
+ key="collection_uploader"
269
+ )
270
+
271
+ if uploaded_files:
272
+ collection_id = selected_collection[1]
273
+ with st.spinner("Processing documents..."):
274
+ if await handle_document_upload(uploaded_files, collection_id=collection_id):
275
+ st.success("Documents added to collection successfully!")
276
  st.rerun()
277
+
278
+ # Display existing collections
279
+ if collections:
280
+ st.markdown("### Existing Collections")
281
+ for collection in collections:
282
+ with st.expander(f"📁 {collection['name']} ({collection['doc_count']} documents)"):
283
+ col1, col2 = st.columns([3, 1])
284
+
285
+ with col1:
286
+ st.write(f"**Description:** {collection.get('description', 'No description')}")
287
+ st.write(f"**Created:** {collection['created_at']}")
288
+
289
+ # Display documents in collection
290
+ docs = get_collection_documents(st.session_state.db_conn, collection['id'])
291
+ if docs:
292
+ st.write("**Documents:**")
293
+ for doc in docs:
294
+ st.write(f"- {doc['name']}")
295
+ tags = get_document_tags(st.session_state.db_conn, doc['id'])
296
+ if tags:
297
+ st.write(f" Tags: {', '.join(tags)}")
298
+
299
+ with col2:
300
+ st.button("Start Chat", key=f"chat_{collection['id']}", use_container_width=True)
301
+ st.button("Upload Files", key=f"upload_{collection['id']}", use_container_width=True)
302
+ if st.button("Delete", key=f"delete_{collection['id']}", use_container_width=True):
303
+ if st.warning("Are you sure you want to delete this collection?"):
304
+ if delete_collection(st.session_state.db_conn, collection['id']):
305
+ st.success("Collection deleted successfully!")
306
+ st.rerun()
307
 
308
 
309
  def display_chat_interface():
 
349
 
350
  st.rerun()
351
 
 
352
 
353
  def main():
354
  """Main application function with improved state management."""
 
358
  initial_sidebar_state="collapsed"
359
  )
360
 
361
+ # Initialize session state
362
  initialize_session_state()
363
 
364
  # Initialize database connection
 
372
  # Display header
373
  display_header()
374
 
375
+ # Show collection management if triggered
376
  if st.session_state.show_collection_dialog:
377
+ display_collection_management()
378
+ # Display chat interface if ready
379
+ elif st.session_state.chat_ready:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
380
  display_chat_interface()
381
+ # Show welcome screen
 
382
  else:
383
  display_welcome_screen()
384