Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -178,7 +178,7 @@ def get_embeddings_model():
|
|
| 178 |
|
| 179 |
# QA System Initialization (qa_system.py)
|
| 180 |
@st.cache_resource
|
| 181 |
-
def initialize_qa_system(
|
| 182 |
try:
|
| 183 |
qa_pipeline = RetrievalQA.from_chain_type(
|
| 184 |
llm=pipeline(
|
|
@@ -186,7 +186,7 @@ def initialize_qa_system(vector_store):
|
|
| 186 |
model="gpt-4",
|
| 187 |
api_key=st.secrets["OPENAI_API_KEY"],
|
| 188 |
prompt_template="Extract the specific details relevant to the query accurately from the document without adding additional information that is not present in the text. Provide concise, clear responses that stay within the boundaries of the document's content."),
|
| 189 |
-
retriever=
|
| 190 |
)
|
| 191 |
return qa_pipeline
|
| 192 |
except Exception as e:
|
|
@@ -205,6 +205,11 @@ def main():
|
|
| 205 |
create_tables(conn)
|
| 206 |
else:
|
| 207 |
st.error("Error! Cannot create the database connection.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
|
| 209 |
# Dashboard Overview Tab
|
| 210 |
st.sidebar.markdown("<h2 style='color: #1E3A8A;'>Dashboard Overview</h2>", unsafe_allow_html=True)
|
|
@@ -234,78 +239,39 @@ def main():
|
|
| 234 |
documents_in_db = cursor.fetchall()
|
| 235 |
|
| 236 |
if documents_in_db:
|
| 237 |
-
|
|
|
|
| 238 |
"Select documents to include in the search:",
|
| 239 |
options=[doc[0] for doc in documents_in_db],
|
| 240 |
-
format_func=lambda doc_id: next(doc[1] for doc in documents_in_db if doc[0] == doc_id)
|
|
|
|
| 241 |
)
|
| 242 |
|
| 243 |
if selected_doc_ids:
|
| 244 |
selected_documents = []
|
|
|
|
| 245 |
for doc_id in selected_doc_ids:
|
| 246 |
-
cursor.execute("SELECT content FROM documents WHERE id = ?", (doc_id,))
|
| 247 |
-
|
|
|
|
|
|
|
| 248 |
|
| 249 |
# Initialize FAISS and Store Embeddings for Selected Documents
|
| 250 |
embeddings = get_embeddings_model()
|
| 251 |
if embeddings:
|
| 252 |
-
vector_store = initialize_faiss(embeddings, selected_documents,
|
| 253 |
if vector_store:
|
| 254 |
st.sidebar.success("Embeddings for selected documents stored successfully.", icon="📁")
|
| 255 |
|
| 256 |
# Initialize QA System for Selected Documents
|
| 257 |
qa_system = initialize_qa_system(vector_store)
|
| 258 |
if qa_system:
|
| 259 |
-
#
|
| 260 |
-
user_query = st.text_input("Enter your query about the RFPs:", placeholder="e.g., What are the evaluation criteria?", label_visibility='visible')
|
| 261 |
-
if user_query:
|
| 262 |
-
st.markdown("<p style='color: #1E3A8A;'>Retrieving answer...</p>", unsafe_allow_html=True)
|
| 263 |
-
try:
|
| 264 |
-
response, source_documents = qa_system.run(user_query, return_source_documents=True)
|
| 265 |
-
st.markdown("<h4 style='color: #1E3A8A;'>Answer:</h4>", unsafe_allow_html=True)
|
| 266 |
-
st.write(response)
|
| 267 |
-
|
| 268 |
-
# Store Query and Response in Database
|
| 269 |
-
with conn:
|
| 270 |
-
for doc in source_documents:
|
| 271 |
-
source_name = doc.metadata["source"]
|
| 272 |
-
document_id = conn.execute("SELECT id FROM documents WHERE name = ?", (source_name,)).fetchone()
|
| 273 |
-
if document_id:
|
| 274 |
-
conn.execute("INSERT INTO queries (query, response, document_id) VALUES (?, ?, ?)", (user_query, response, document_id[0]))
|
| 275 |
-
|
| 276 |
-
# Display Source Information
|
| 277 |
-
st.markdown("<h4 style='color: #1E3A8A;'>Sources:</h4>", unsafe_allow_html=True)
|
| 278 |
-
for doc in source_documents:
|
| 279 |
-
source_name = doc.metadata["source"]
|
| 280 |
-
matched_text = doc.page_content
|
| 281 |
-
st.write(f"- Source Document: {source_name}")
|
| 282 |
-
# Display the matching text with highlighting
|
| 283 |
-
for idx, page_content in enumerate(document_pages[document_names.index(source_name)]):
|
| 284 |
-
if matched_text in page_content:
|
| 285 |
-
highlighted_content = re.sub(re.escape(matched_text), f"<mark>{matched_text}</mark>", page_content)
|
| 286 |
-
st.write(f" - Page {idx + 1}: {highlighted_content}")
|
| 287 |
|
| 288 |
-
except Exception as e:
|
| 289 |
-
st.error(f"Error generating response: {e}")
|
| 290 |
except Exception as e:
|
| 291 |
st.error(f"Error retrieving documents from database: {e}")
|
| 292 |
|
| 293 |
-
#
|
| 294 |
-
st.markdown("<h2 style='color: #1E3A8A;'>Upload RFP Documents</h2>", unsafe_allow_html=True)
|
| 295 |
-
uploaded_documents = st.file_uploader("Upload PDF documents", type="pdf", accept_multiple_files=True)
|
| 296 |
-
if uploaded_documents:
|
| 297 |
-
st.write(f"Uploaded {len(uploaded_documents)} documents.")
|
| 298 |
-
all_texts, document_names, document_pages = upload_and_parse_documents(uploaded_documents)
|
| 299 |
-
if all_texts:
|
| 300 |
-
# Store Documents in Database
|
| 301 |
-
if conn is not None:
|
| 302 |
-
try:
|
| 303 |
-
with conn:
|
| 304 |
-
for doc, doc_name in zip(all_texts, document_names):
|
| 305 |
-
conn.execute("INSERT INTO documents (name, content) VALUES (?, ?)", (doc_name, doc))
|
| 306 |
-
st.success("Documents uploaded and parsed successfully.", icon="✅")
|
| 307 |
-
except Exception as e:
|
| 308 |
-
st.error(f"Error saving documents to database: {e}")
|
| 309 |
|
| 310 |
# URL Input Section
|
| 311 |
st.markdown("<h2 style='color: #1E3A8A;'>Or Provide a URL</h2>", unsafe_allow_html=True)
|
|
|
|
| 178 |
|
| 179 |
# QA System Initialization (qa_system.py)
|
| 180 |
@st.cache_resource
|
| 181 |
+
def initialize_qa_system(_vector_store): # Add a leading underscore to 'vector_store'
|
| 182 |
try:
|
| 183 |
qa_pipeline = RetrievalQA.from_chain_type(
|
| 184 |
llm=pipeline(
|
|
|
|
| 186 |
model="gpt-4",
|
| 187 |
api_key=st.secrets["OPENAI_API_KEY"],
|
| 188 |
prompt_template="Extract the specific details relevant to the query accurately from the document without adding additional information that is not present in the text. Provide concise, clear responses that stay within the boundaries of the document's content."),
|
| 189 |
+
retriever=_vector_store.as_retriever() # Use '_vector_store' here as well
|
| 190 |
)
|
| 191 |
return qa_pipeline
|
| 192 |
except Exception as e:
|
|
|
|
| 205 |
create_tables(conn)
|
| 206 |
else:
|
| 207 |
st.error("Error! Cannot create the database connection.")
|
| 208 |
+
# ... (other imports and functions) ...
|
| 209 |
+
|
| 210 |
+
# Streamlit App Interface (app.py)
|
| 211 |
+
def main():
|
| 212 |
+
# ... (other code) ...
|
| 213 |
|
| 214 |
# Dashboard Overview Tab
|
| 215 |
st.sidebar.markdown("<h2 style='color: #1E3A8A;'>Dashboard Overview</h2>", unsafe_allow_html=True)
|
|
|
|
| 239 |
documents_in_db = cursor.fetchall()
|
| 240 |
|
| 241 |
if documents_in_db:
|
| 242 |
+
# Use st.multiselect instead of st.selectbox
|
| 243 |
+
selected_doc_ids = st.sidebar.multiselect(
|
| 244 |
"Select documents to include in the search:",
|
| 245 |
options=[doc[0] for doc in documents_in_db],
|
| 246 |
+
format_func=lambda doc_id: next(doc[1] for doc in documents_in_db if doc[0] == doc_id),
|
| 247 |
+
default=[doc[0] for doc in documents_in_db] # Select all documents by default
|
| 248 |
)
|
| 249 |
|
| 250 |
if selected_doc_ids:
|
| 251 |
selected_documents = []
|
| 252 |
+
selected_doc_names = [] # Also keep track of the document names
|
| 253 |
for doc_id in selected_doc_ids:
|
| 254 |
+
cursor.execute("SELECT content, name FROM documents WHERE id = ?", (doc_id,))
|
| 255 |
+
result = cursor.fetchone()
|
| 256 |
+
selected_documents.append(result[0])
|
| 257 |
+
selected_doc_names.append(result[1]) # Add the name to the list
|
| 258 |
|
| 259 |
# Initialize FAISS and Store Embeddings for Selected Documents
|
| 260 |
embeddings = get_embeddings_model()
|
| 261 |
if embeddings:
|
| 262 |
+
vector_store = initialize_faiss(embeddings, selected_documents, selected_doc_names) # Use selected_doc_names here
|
| 263 |
if vector_store:
|
| 264 |
st.sidebar.success("Embeddings for selected documents stored successfully.", icon="📁")
|
| 265 |
|
| 266 |
# Initialize QA System for Selected Documents
|
| 267 |
qa_system = initialize_qa_system(vector_store)
|
| 268 |
if qa_system:
|
| 269 |
+
# ... (rest of your query processing code) ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
|
|
|
|
|
|
|
| 271 |
except Exception as e:
|
| 272 |
st.error(f"Error retrieving documents from database: {e}")
|
| 273 |
|
| 274 |
+
# ... (rest of your code) ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 275 |
|
| 276 |
# URL Input Section
|
| 277 |
st.markdown("<h2 style='color: #1E3A8A;'>Or Provide a URL</h2>", unsafe_allow_html=True)
|