cryogenic22 committed on
Commit
a010f3e
·
verified ·
1 Parent(s): 75a6a91

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -53
app.py CHANGED
@@ -178,7 +178,7 @@ def get_embeddings_model():
178
 
179
  # QA System Initialization (qa_system.py)
180
  @st.cache_resource
181
- def initialize_qa_system(vector_store):
182
  try:
183
  qa_pipeline = RetrievalQA.from_chain_type(
184
  llm=pipeline(
@@ -186,7 +186,7 @@ def initialize_qa_system(vector_store):
186
  model="gpt-4",
187
  api_key=st.secrets["OPENAI_API_KEY"],
188
  prompt_template="Extract the specific details relevant to the query accurately from the document without adding additional information that is not present in the text. Provide concise, clear responses that stay within the boundaries of the document's content."),
189
- retriever=vector_store.as_retriever()
190
  )
191
  return qa_pipeline
192
  except Exception as e:
@@ -205,6 +205,11 @@ def main():
205
  create_tables(conn)
206
  else:
207
  st.error("Error! Cannot create the database connection.")
 
 
 
 
 
208
 
209
  # Dashboard Overview Tab
210
  st.sidebar.markdown("<h2 style='color: #1E3A8A;'>Dashboard Overview</h2>", unsafe_allow_html=True)
@@ -234,78 +239,39 @@ def main():
234
  documents_in_db = cursor.fetchall()
235
 
236
  if documents_in_db:
237
- selected_doc_ids = st.sidebar.multiselect(
 
238
  "Select documents to include in the search:",
239
  options=[doc[0] for doc in documents_in_db],
240
- format_func=lambda doc_id: next(doc[1] for doc in documents_in_db if doc[0] == doc_id)
 
241
  )
242
 
243
  if selected_doc_ids:
244
  selected_documents = []
 
245
  for doc_id in selected_doc_ids:
246
- cursor.execute("SELECT content FROM documents WHERE id = ?", (doc_id,))
247
- selected_documents.append(cursor.fetchone()[0])
 
 
248
 
249
  # Initialize FAISS and Store Embeddings for Selected Documents
250
  embeddings = get_embeddings_model()
251
  if embeddings:
252
- vector_store = initialize_faiss(embeddings, selected_documents, [doc[1] for doc in documents_in_db if doc[0] in selected_doc_ids])
253
  if vector_store:
254
  st.sidebar.success("Embeddings for selected documents stored successfully.", icon="📁")
255
 
256
  # Initialize QA System for Selected Documents
257
  qa_system = initialize_qa_system(vector_store)
258
  if qa_system:
259
- # Query Input
260
- user_query = st.text_input("Enter your query about the RFPs:", placeholder="e.g., What are the evaluation criteria?", label_visibility='visible')
261
- if user_query:
262
- st.markdown("<p style='color: #1E3A8A;'>Retrieving answer...</p>", unsafe_allow_html=True)
263
- try:
264
- response, source_documents = qa_system.run(user_query, return_source_documents=True)
265
- st.markdown("<h4 style='color: #1E3A8A;'>Answer:</h4>", unsafe_allow_html=True)
266
- st.write(response)
267
-
268
- # Store Query and Response in Database
269
- with conn:
270
- for doc in source_documents:
271
- source_name = doc.metadata["source"]
272
- document_id = conn.execute("SELECT id FROM documents WHERE name = ?", (source_name,)).fetchone()
273
- if document_id:
274
- conn.execute("INSERT INTO queries (query, response, document_id) VALUES (?, ?, ?)", (user_query, response, document_id[0]))
275
-
276
- # Display Source Information
277
- st.markdown("<h4 style='color: #1E3A8A;'>Sources:</h4>", unsafe_allow_html=True)
278
- for doc in source_documents:
279
- source_name = doc.metadata["source"]
280
- matched_text = doc.page_content
281
- st.write(f"- Source Document: {source_name}")
282
- # Display the matching text with highlighting
283
- for idx, page_content in enumerate(document_pages[document_names.index(source_name)]):
284
- if matched_text in page_content:
285
- highlighted_content = re.sub(re.escape(matched_text), f"<mark>{matched_text}</mark>", page_content)
286
- st.write(f" - Page {idx + 1}: {highlighted_content}")
287
 
288
- except Exception as e:
289
- st.error(f"Error generating response: {e}")
290
  except Exception as e:
291
  st.error(f"Error retrieving documents from database: {e}")
292
 
293
- # Document Upload Section
294
- st.markdown("<h2 style='color: #1E3A8A;'>Upload RFP Documents</h2>", unsafe_allow_html=True)
295
- uploaded_documents = st.file_uploader("Upload PDF documents", type="pdf", accept_multiple_files=True)
296
- if uploaded_documents:
297
- st.write(f"Uploaded {len(uploaded_documents)} documents.")
298
- all_texts, document_names, document_pages = upload_and_parse_documents(uploaded_documents)
299
- if all_texts:
300
- # Store Documents in Database
301
- if conn is not None:
302
- try:
303
- with conn:
304
- for doc, doc_name in zip(all_texts, document_names):
305
- conn.execute("INSERT INTO documents (name, content) VALUES (?, ?)", (doc_name, doc))
306
- st.success("Documents uploaded and parsed successfully.", icon="✅")
307
- except Exception as e:
308
- st.error(f"Error saving documents to database: {e}")
309
 
310
  # URL Input Section
311
  st.markdown("<h2 style='color: #1E3A8A;'>Or Provide a URL</h2>", unsafe_allow_html=True)
 
178
 
179
  # QA System Initialization (qa_system.py)
180
  @st.cache_resource
181
+ def initialize_qa_system(_vector_store): # Add a leading underscore to 'vector_store'
182
  try:
183
  qa_pipeline = RetrievalQA.from_chain_type(
184
  llm=pipeline(
 
186
  model="gpt-4",
187
  api_key=st.secrets["OPENAI_API_KEY"],
188
  prompt_template="Extract the specific details relevant to the query accurately from the document without adding additional information that is not present in the text. Provide concise, clear responses that stay within the boundaries of the document's content."),
189
+ retriever=_vector_store.as_retriever() # Use '_vector_store' here as well
190
  )
191
  return qa_pipeline
192
  except Exception as e:
 
205
  create_tables(conn)
206
  else:
207
  st.error("Error! Cannot create the database connection.")
208
+ # ... (other imports and functions) ...
209
+
210
+ # Streamlit App Interface (app.py)
211
+ def main():
212
+ # ... (other code) ...
213
 
214
  # Dashboard Overview Tab
215
  st.sidebar.markdown("<h2 style='color: #1E3A8A;'>Dashboard Overview</h2>", unsafe_allow_html=True)
 
239
  documents_in_db = cursor.fetchall()
240
 
241
  if documents_in_db:
242
+ # Use st.multiselect instead of st.selectbox
243
+ selected_doc_ids = st.sidebar.multiselect(
244
  "Select documents to include in the search:",
245
  options=[doc[0] for doc in documents_in_db],
246
+ format_func=lambda doc_id: next(doc[1] for doc in documents_in_db if doc[0] == doc_id),
247
+ default=[doc[0] for doc in documents_in_db] # Select all documents by default
248
  )
249
 
250
  if selected_doc_ids:
251
  selected_documents = []
252
+ selected_doc_names = [] # Also keep track of the document names
253
  for doc_id in selected_doc_ids:
254
+ cursor.execute("SELECT content, name FROM documents WHERE id = ?", (doc_id,))
255
+ result = cursor.fetchone()
256
+ selected_documents.append(result[0])
257
+ selected_doc_names.append(result[1]) # Add the name to the list
258
 
259
  # Initialize FAISS and Store Embeddings for Selected Documents
260
  embeddings = get_embeddings_model()
261
  if embeddings:
262
+ vector_store = initialize_faiss(embeddings, selected_documents, selected_doc_names) # Use selected_doc_names here
263
  if vector_store:
264
  st.sidebar.success("Embeddings for selected documents stored successfully.", icon="📁")
265
 
266
  # Initialize QA System for Selected Documents
267
  qa_system = initialize_qa_system(vector_store)
268
  if qa_system:
269
+ # ... (rest of your query processing code) ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
 
 
 
271
  except Exception as e:
272
  st.error(f"Error retrieving documents from database: {e}")
273
 
274
+ # ... (rest of your code) ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
 
276
  # URL Input Section
277
  st.markdown("<h2 style='color: #1E3A8A;'>Or Provide a URL</h2>", unsafe_allow_html=True)