ash2203 commited on
Commit
0590ae6
·
verified ·
1 Parent(s): 8e6c014

Update app.py

Browse files

Changed the vector store from Pinecone to ChromaDB to avoid Pinecone's usage limits.

Files changed (1) hide show
  1. app.py +109 -103
app.py CHANGED
@@ -12,10 +12,7 @@ from langchain_openai import OpenAIEmbeddings
12
  from langchain_core.runnables import RunnablePassthrough
13
  from langchain_community.retrievers import BM25Retriever
14
  from langchain.retrievers import EnsembleRetriever
15
- from langchain_community.retrievers import PineconeHybridSearchRetriever
16
- from langchain_pinecone import PineconeVectorStore
17
- from pinecone import Pinecone, ServerlessSpec
18
- from pinecone import PineconeApiException, NotFoundException
19
  import shutil
20
  import uuid
21
 
@@ -23,7 +20,7 @@ from dotenv import load_dotenv
23
  load_dotenv()
24
 
25
  # Set page configuration
26
- st.set_page_config(page_title="Document Analyzer", layout="wide", )
27
 
28
  st.title("📚 Document Analyzer")
29
 
@@ -41,46 +38,65 @@ if 'initialized' not in st.session_state:
41
  st.session_state.initialized = False
42
  if 'processing' not in st.session_state:
43
  st.session_state.processing = False
44
- if 'last_processed_files' not in st.session_state:
45
- st.session_state.last_processed_files = set()
46
- if 'chat_history' not in st.session_state:
47
- st.session_state.chat_history = []
48
  if 'chat_enabled' not in st.session_state:
49
  st.session_state.chat_enabled = False
50
  if 'session_id' not in st.session_state:
51
  # Generate a unique session ID using UUID
52
  st.session_state.session_id = str(uuid.uuid4())[:8]
53
 
54
- def get_session_index_name():
55
- """Get unique index name for current session"""
56
- base_name = "docdb" # Using a short base name to leave room for the unique identifier
57
- unique_id = st.session_state.session_id
58
- # Combine base name with unique ID, ensuring total length is under 45 chars
59
- return f"{base_name}-{unique_id}" # This will be like "docdb-12345678"
60
 
61
- def cleanup_pinecone_index():
62
- """Clean up existing Pinecone index for the current session"""
63
  try:
64
- pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
65
- index_name = get_session_index_name()
66
- if index_name in pc.list_indexes().names():
67
- pc.delete_index(index_name)
68
  except Exception as e:
69
- print(f"Error cleaning up index: {str(e)}") # Log error internally
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  if not st.session_state.initialized:
 
 
 
72
  # Clear everything only on first run or page refresh
73
  if os.path.exists("data"):
74
  shutil.rmtree("data")
75
  os.makedirs("data")
 
 
 
 
 
76
  st.session_state.uploaded_files = {}
77
  st.session_state.previous_files = set()
78
- st.session_state.vectorstore = None
79
- st.session_state.retriever = None
80
  st.session_state.initialized = True
81
-
82
- # Clean up any existing index
83
- cleanup_pinecone_index()
84
 
85
  def save_uploaded_file(uploaded_file):
86
  """Save uploaded file to the data directory"""
@@ -105,15 +121,15 @@ def save_uploaded_file(uploaded_file):
105
  return None
106
 
107
  def process_documents(uploaded_files_dict):
108
- """Process documents and store in Pinecone"""
109
  warning_placeholder = st.empty()
110
  warning_placeholder.warning("⚠️ Document processing in progress. Please wait before adding or removing files.")
111
  success_placeholder = st.empty()
112
 
113
  try:
114
  with st.spinner('Processing documents...'):
115
- # Clean up existing index before processing
116
- cleanup_pinecone_index()
117
 
118
  docs = []
119
  # Process each file
@@ -152,27 +168,14 @@ def process_documents(uploaded_files_dict):
152
  # Initialize embeddings
153
  embed_func = OpenAIEmbeddings(model='text-embedding-3-small', dimensions=512)
154
 
155
- # Initialize Pinecone
156
- pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
157
- index_name = get_session_index_name()
158
-
159
  try:
160
- pc.create_index(
161
- name=index_name,
162
- dimension=512,
163
- metric='cosine',
164
- spec=ServerlessSpec(cloud='aws', region='us-east-1')
165
- )
166
-
167
- # Wait for index to be ready
168
- while not pc.describe_index(index_name).status['ready']:
169
- time.sleep(1)
170
-
171
- pc_index = pc.Index(index_name)
172
-
173
  # Create vectorstore and add documents
174
- vectorstore = PineconeVectorStore(index=pc_index, embedding=embed_func)
175
- vectorstore.add_documents(documents=chunks)
 
 
 
 
176
 
177
  st.session_state.chat_enabled = True
178
  success_placeholder.success('Documents processed successfully!')
@@ -180,8 +183,8 @@ def process_documents(uploaded_files_dict):
180
  success_placeholder.empty() # Clear the success message
181
  return True
182
 
183
- except PineconeApiException as e:
184
- print(f"Pinecone API error: {str(e)}") # Log error internally
185
  st.warning("Unable to process documents at the moment. Please try again.")
186
  st.session_state.chat_enabled = False
187
  return False
@@ -195,38 +198,9 @@ def process_documents(uploaded_files_dict):
195
  warning_placeholder.empty()
196
 
197
  def doc2str(docs):
198
- return "\n\n".join(doc for doc in docs)
199
-
200
- def format_reranked_docs(pc, retriever, question):
201
- """Rerank documents using Pinecone's reranking model"""
202
- # Get relevant docs and ensure they're not empty
203
- relevant_docs = [doc.page_content for doc in retriever.invoke(question) if doc.page_content.strip()]
204
-
205
- if not relevant_docs:
206
- return "I don't have enough context to answer this question."
207
-
208
- try:
209
- # Format documents for reranking
210
- formatted_docs = [{"text": doc} for doc in relevant_docs]
211
-
212
- reranked_docs = pc.inference.rerank(
213
- model="pinecone-rerank-v0",
214
- query=question,
215
- documents=formatted_docs,
216
- top_n=3,
217
- return_documents=True
218
- )
219
-
220
- # Extract text from reranked documents
221
- final_docs = [d.document["text"] for d in reranked_docs.data]
222
- context = "\n\n".join(final_docs)
223
- return context
224
- except Exception as e:
225
- print(f"Error during reranking: {str(e)}") # Log error internally
226
- # Fallback to using retrieved docs without reranking
227
- return "\n\n".join(relevant_docs[:3])
228
 
229
- def run_chatbot(retriever, pc, llm):
230
  """Run the chatbot with the given components"""
231
  # Initialize chat prompt
232
  prompt = ChatPromptTemplate.from_template("""
@@ -245,9 +219,9 @@ def run_chatbot(retriever, pc, llm):
245
 
246
  {question}""")
247
 
248
- # Create the QA chain with reranking
249
  qa_chain = (
250
- RunnablePassthrough.assign(context=lambda input: format_reranked_docs(pc, retriever, input["question"]))
251
  | prompt
252
  | llm
253
  | StrOutputParser()
@@ -305,8 +279,14 @@ def process_and_chat():
305
  # Check for removed files
306
  files_to_remove = set(st.session_state.uploaded_files.keys()) - current_uploaded_filenames
307
  if files_to_remove:
308
- # Clean up index when files are removed
309
- cleanup_pinecone_index()
 
 
 
 
 
 
310
  for file_name in files_to_remove:
311
  # Remove file from session state
312
  if file_name in st.session_state.uploaded_files:
@@ -323,6 +303,12 @@ def process_and_chat():
323
  for file in uploaded_files:
324
  # Only process files that haven't been uploaded before
325
  if file.name not in st.session_state.uploaded_files:
 
 
 
 
 
 
326
  file_path = save_uploaded_file(file)
327
  if file_path: # Only add to session state if file was saved successfully
328
  st.session_state.uploaded_files[file.name] = {
@@ -336,45 +322,66 @@ def process_and_chat():
336
 
337
  # If files have changed (added or removed), reset chat and process documents
338
  if current_files != st.session_state.previous_files or files_to_remove:
339
- # Reset chat state
340
- st.session_state.chat_enabled = False
341
- if "messages" in st.session_state:
342
- del st.session_state.messages
343
-
344
  st.session_state.previous_files = current_files
345
 
346
  if current_files:
347
- st.session_state.processing = True
348
  # Process documents and enable chat if successful
349
  if process_documents(st.session_state.uploaded_files):
350
  st.session_state.chat_enabled = True
351
  st.session_state.processing = False
352
  else:
353
  st.warning('Please upload a file to continue')
 
354
 
355
  # If files exist and chat is enabled, show chat interface
356
  if current_files and st.session_state.chat_enabled:
357
  try:
358
  # Initialize components for chat
359
  llm = ChatGroq(temperature=0, model_name="llama-3.3-70b-versatile", groq_api_key=os.getenv("GROQ_API_KEY"), max_tokens=8000)
360
- pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
361
- index_name = get_session_index_name()
362
- pc_index = pc.Index(index_name)
363
 
364
  # Create vectorstore
365
  embed_func = OpenAIEmbeddings(model='text-embedding-3-small', dimensions=512)
366
- vectorstore = PineconeVectorStore(index=pc_index, embedding=embed_func)
 
 
 
 
367
 
368
  # Create retrievers
369
  vectorstore_retriever = vectorstore.as_retriever(
370
- search_type="similarity_score_threshold",
371
- search_kwargs={"k": 5, "score_threshold": 0.6},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
  )
373
 
374
  # Run chatbot with fresh components
375
- run_chatbot(vectorstore_retriever, pc, llm)
376
- except NotFoundException:
377
- st.error("Vector database not found. Please try uploading your documents again.")
 
378
  st.session_state.chat_enabled = False
379
  # Clear the previous files to force reprocessing
380
  st.session_state.previous_files = set()
@@ -382,5 +389,4 @@ def process_and_chat():
382
  del st.session_state.messages
383
 
384
  # Call the main function
385
- process_and_chat()
386
-
 
12
  from langchain_core.runnables import RunnablePassthrough
13
  from langchain_community.retrievers import BM25Retriever
14
  from langchain.retrievers import EnsembleRetriever
15
+ from langchain_chroma import Chroma
 
 
 
16
  import shutil
17
  import uuid
18
 
 
20
  load_dotenv()
21
 
22
  # Set page configuration
23
+ st.set_page_config(page_title="Document Analyzer", layout="wide")
24
 
25
  st.title("📚 Document Analyzer")
26
 
 
38
  st.session_state.initialized = False
39
  if 'processing' not in st.session_state:
40
  st.session_state.processing = False
 
 
 
 
41
  if 'chat_enabled' not in st.session_state:
42
  st.session_state.chat_enabled = False
43
  if 'session_id' not in st.session_state:
44
  # Generate a unique session ID using UUID
45
  st.session_state.session_id = str(uuid.uuid4())[:8]
46
 
47
def get_chroma_directory():
    """Return the persistence directory for the current session's ChromaDB.

    The directory lives under ``vectorstores/`` and is namespaced by the
    Streamlit session id so concurrent sessions never share an index.

    Returns:
        str: path like ``vectorstores/chroma_db_<session_id>`` (the base
        directory is created if missing; the session directory itself is not).
    """
    base_dir = "vectorstores"
    # exist_ok=True avoids the check-then-create (TOCTOU) race when several
    # sessions initialise at the same time.
    os.makedirs(base_dir, exist_ok=True)
    return os.path.join(base_dir, f"chroma_db_{st.session_state.session_id}")
53
 
54
def cleanup_chroma_db():
    """Delete the current session's persisted ChromaDB directory, if present."""
    try:
        target = get_chroma_directory()
        if os.path.exists(target):
            shutil.rmtree(target)
    except Exception as err:
        # Best-effort cleanup: log and carry on rather than crash the app.
        print(f"Error cleaning up ChromaDB: {str(err)}")  # Log error internally
62
+
63
def cleanup_old_vectorstores(max_age_seconds=24 * 60 * 60):
    """Remove session vector-store directories older than *max_age_seconds*.

    A directory's last-modification time is used as a proxy for session
    recency; anything under ``vectorstores/`` older than the threshold is
    deleted. Cleanup is best-effort and never raises.

    Args:
        max_age_seconds: Age threshold in seconds. Defaults to 24 hours,
            matching the original hard-coded behaviour.
    """
    base_dir = "vectorstores"
    if not os.path.isdir(base_dir):
        return

    now = time.time()
    try:
        entries = os.listdir(base_dir)
    except OSError as e:
        print(f"Error cleaning up old vector stores: {str(e)}")  # Log error internally
        return

    for dir_name in entries:
        dir_path = os.path.join(base_dir, dir_name)
        try:
            # Handle each directory independently so one failed removal does
            # not abort the whole sweep (the original try wrapped the loop).
            if os.path.isdir(dir_path) and now - os.path.getmtime(dir_path) > max_age_seconds:
                shutil.rmtree(dir_path)
        except OSError as e:
            print(f"Error cleaning up old vector stores: {str(e)}")  # Log error internally
83
 
84
# One-time per-session setup: runs on the first script execution / page refresh.
if not st.session_state.initialized:
    # Purge vector stores from stale (>24h old) sessions before creating ours.
    cleanup_old_vectorstores()

    # Clear uploaded data only on first run or page refresh.
    if os.path.exists("data"):
        shutil.rmtree("data")
    os.makedirs("data")

    # Ensure the vectorstores base directory exists. (The previous
    # `if os.path.exists("vectorstores"): os.makedirs(..., exist_ok=True)`
    # guard was a no-op — it only ran when the directory already existed and
    # then did nothing — so it is replaced with an unconditional create.)
    os.makedirs("vectorstores", exist_ok=True)

    st.session_state.uploaded_files = {}
    st.session_state.previous_files = set()
    st.session_state.initialized = True
 
 
 
100
 
101
  def save_uploaded_file(uploaded_file):
102
  """Save uploaded file to the data directory"""
 
121
  return None
122
 
123
  def process_documents(uploaded_files_dict):
124
+ """Process documents and store in ChromaDB"""
125
  warning_placeholder = st.empty()
126
  warning_placeholder.warning("⚠️ Document processing in progress. Please wait before adding or removing files.")
127
  success_placeholder = st.empty()
128
 
129
  try:
130
  with st.spinner('Processing documents...'):
131
+ # Clean up existing ChromaDB before processing
132
+ cleanup_chroma_db()
133
 
134
  docs = []
135
  # Process each file
 
168
  # Initialize embeddings
169
  embed_func = OpenAIEmbeddings(model='text-embedding-3-small', dimensions=512)
170
 
 
 
 
 
171
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  # Create vectorstore and add documents
173
+ vectorstore = Chroma.from_documents(
174
+ collection_name="collection",
175
+ documents=chunks,
176
+ embedding=embed_func,
177
+ persist_directory=get_chroma_directory()
178
+ )
179
 
180
  st.session_state.chat_enabled = True
181
  success_placeholder.success('Documents processed successfully!')
 
183
  success_placeholder.empty() # Clear the success message
184
  return True
185
 
186
+ except Exception as e:
187
+ print(f"ChromaDB error: {str(e)}") # Log error internally
188
  st.warning("Unable to process documents at the moment. Please try again.")
189
  st.session_state.chat_enabled = False
190
  return False
 
198
  warning_placeholder.empty()
199
 
200
def doc2str(docs):
    """Concatenate the page contents of *docs* into one context string.

    Args:
        docs: iterable of objects exposing a ``page_content`` attribute.

    Returns:
        str: the contents joined by blank lines.
    """
    contents = [document.page_content for document in docs]
    return "\n\n".join(contents)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
+ def run_chatbot(retriever, llm):
204
  """Run the chatbot with the given components"""
205
  # Initialize chat prompt
206
  prompt = ChatPromptTemplate.from_template("""
 
219
 
220
  {question}""")
221
 
222
+ # Create the QA chain
223
  qa_chain = (
224
+ RunnablePassthrough.assign(context=lambda input: doc2str(retriever.invoke(input["question"])))
225
  | prompt
226
  | llm
227
  | StrOutputParser()
 
279
  # Check for removed files
280
  files_to_remove = set(st.session_state.uploaded_files.keys()) - current_uploaded_filenames
281
  if files_to_remove:
282
+ # Set processing state immediately
283
+ st.session_state.processing = True
284
+ st.session_state.chat_enabled = False
285
+ if "messages" in st.session_state:
286
+ del st.session_state.messages
287
+
288
+ # Clean up ChromaDB when files are removed
289
+ cleanup_chroma_db()
290
  for file_name in files_to_remove:
291
  # Remove file from session state
292
  if file_name in st.session_state.uploaded_files:
 
303
  for file in uploaded_files:
304
  # Only process files that haven't been uploaded before
305
  if file.name not in st.session_state.uploaded_files:
306
+ # Set processing state immediately when new file is detected
307
+ st.session_state.processing = True
308
+ st.session_state.chat_enabled = False
309
+ if "messages" in st.session_state:
310
+ del st.session_state.messages
311
+
312
  file_path = save_uploaded_file(file)
313
  if file_path: # Only add to session state if file was saved successfully
314
  st.session_state.uploaded_files[file.name] = {
 
322
 
323
  # If files have changed (added or removed), reset chat and process documents
324
  if current_files != st.session_state.previous_files or files_to_remove:
 
 
 
 
 
325
  st.session_state.previous_files = current_files
326
 
327
  if current_files:
 
328
  # Process documents and enable chat if successful
329
  if process_documents(st.session_state.uploaded_files):
330
  st.session_state.chat_enabled = True
331
  st.session_state.processing = False
332
  else:
333
  st.warning('Please upload a file to continue')
334
+ st.session_state.processing = False
335
 
336
  # If files exist and chat is enabled, show chat interface
337
  if current_files and st.session_state.chat_enabled:
338
  try:
339
  # Initialize components for chat
340
  llm = ChatGroq(temperature=0, model_name="llama-3.3-70b-versatile", groq_api_key=os.getenv("GROQ_API_KEY"), max_tokens=8000)
 
 
 
341
 
342
  # Create vectorstore
343
  embed_func = OpenAIEmbeddings(model='text-embedding-3-small', dimensions=512)
344
+ vectorstore = Chroma(
345
+ collection_name="collection",
346
+ embedding_function=embed_func,
347
+ persist_directory=get_chroma_directory()
348
+ )
349
 
350
  # Create retrievers
351
  vectorstore_retriever = vectorstore.as_retriever(
352
+ search_kwargs={"k": 3}
353
+ )
354
+
355
+ # Create keyword retriever
356
+ text_splitter = RecursiveCharacterTextSplitter(
357
+ chunk_size=1500,
358
+ chunk_overlap=400,
359
+ length_function=len
360
+ )
361
+ docs = []
362
+ for file_info in st.session_state.uploaded_files.values():
363
+ if file_info["path"].endswith(".pdf"):
364
+ docs.extend(PyMuPDFLoader(file_info["path"]).load())
365
+ elif file_info["path"].endswith(".txt"):
366
+ docs.extend(TextLoader(file_info["path"]).load())
367
+ elif file_info["path"].endswith(".docx"):
368
+ docs.extend(Docx2txtLoader(file_info["path"]).load())
369
+
370
+ chunks = text_splitter.split_documents(docs)
371
+ keyword_retriever = BM25Retriever.from_documents(chunks)
372
+ keyword_retriever.k = 3
373
+
374
+ # Combine retrievers
375
+ ensemble_retriever = EnsembleRetriever(
376
+ retrievers=[vectorstore_retriever, keyword_retriever],
377
+ weights=[0.5, 0.5]
378
  )
379
 
380
  # Run chatbot with fresh components
381
+ run_chatbot(ensemble_retriever, llm)
382
+ except Exception as e:
383
+ print(f"Chat interface error: {str(e)}") # Log error internally
384
+ st.warning("Please try uploading your documents again.")
385
  st.session_state.chat_enabled = False
386
  # Clear the previous files to force reprocessing
387
  st.session_state.previous_files = set()
 
389
  del st.session_state.messages
390
 
391
  # Call the main function
392
+ process_and_chat()