Ismetdh committed on
Commit
330a910
·
verified ·
1 Parent(s): 0463792

Update app.py

Browse files

Make the app able to accept multiple files. The user can now also delete the chat history and the attached files.

Files changed (1) hide show
  1. app.py +85 -38
app.py CHANGED
@@ -5,14 +5,15 @@ import os
5
  import re
6
  import numpy as np
7
  import google.generativeai as palm
8
- from sklearn.metrics.pairwise import cosine_similarity
9
  import logging
10
  import time
11
  import uuid
12
  import json
13
  import firebase_admin
14
  from firebase_admin import credentials, firestore
 
15
 
 
16
  def init_firebase():
17
  if not firebase_admin._apps:
18
  data = json.loads(os.getenv("FIREBASE_CRED"))
@@ -68,10 +69,11 @@ def update_feedback_in_firestore(session_id, conversation_id, feedback):
68
 
69
  class Config:
70
  CHUNK_WORDS = 300
71
- EMBEDDING_MODEL = "models/text-embedding-004"
72
  TOP_N = 5
73
  SYSTEM_PROMPT = (
74
- "You are a helpful assistant. Answer the question using the provided context below. Answer based on your knowledge if the context given is not enough."
 
75
  )
76
  GENERATION_MODEL = "models/gemini-1.5-flash"
77
 
@@ -179,9 +181,6 @@ def chunk_text(text: str) -> list[str]:
179
 
180
  def process_document(uploaded_file) -> None:
181
  try:
182
- keys_to_clear = ["document_text", "document_chunks", "document_embeddings"]
183
- for key in keys_to_clear:
184
- st.session_state.pop(key, None)
185
  file_text = extract_text_from_file(uploaded_file)
186
  if not file_text.strip():
187
  logger.error("Uploaded file contains no valid text.")
@@ -197,21 +196,34 @@ def process_document(uploaded_file) -> None:
197
  logger.error("All embeddings are zero vectors.")
198
  st.error("Failed to generate valid embeddings.")
199
  return
200
- st.session_state.update({
 
201
  "document_text": file_text,
202
  "document_chunks": chunks,
203
- "document_embeddings": embeddings
204
- })
205
- if not st.session_state.get("doc_processed", False):
206
- message_placeholder = st.empty()
207
- message_placeholder.success("Document processing complete! You can now start chatting.")
208
- st.session_state.doc_processed = True
 
209
  except Exception as e:
210
  logger.error("Document processing failed: %s", e)
211
  st.error(f"An error occurred while processing the document: {e}")
212
 
 
 
 
 
 
 
 
 
 
 
 
213
  def search_query(query: str) -> list[tuple[str, float]]:
214
- if "document_embeddings" not in st.session_state or len(st.session_state["document_embeddings"]) == 0:
215
  logger.error("No valid document embeddings found in session state.")
216
  st.error("No valid document embeddings found. Please upload a valid document.")
217
  return []
@@ -221,10 +233,15 @@ def search_query(query: str) -> list[tuple[str, float]]:
221
  st.error("Failed to generate a valid query embedding.")
222
  return []
223
  query_embedding = query_embedding.reshape(1, -1)
224
- doc_embeddings = np.vstack(st.session_state["document_embeddings"])
 
 
 
 
 
225
  similarities = cosine_similarity(query_embedding, doc_embeddings)[0]
226
  top_indices = np.argsort(similarities)[-Config.TOP_N:][::-1]
227
- results = [(st.session_state["document_chunks"][i], similarities[i]) for i in top_indices]
228
  return results
229
 
230
  def generate_answer(user_query: str, context: str) -> str:
@@ -276,34 +293,64 @@ def chat_app():
276
  "user_question": user_input,
277
  "assistant_answer": answer,
278
  })
279
- if "feedback" not in st.session_state.conversations[-1]:
280
- col1, col2, col3, col4, col5, col6, col7, col8, col9, col10 = st.columns(10)
281
- col1.button("👍", key=f"feedback_like_{len(st.session_state.conversations)}", on_click=handle_feedback, args=("positive",))
282
- col2.button("👎", key=f"feedback_dislike_{len(st.session_state.conversations)}", on_click=handle_feedback, args=("negative",))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
 
284
  def main():
285
  st.title("Chat with your files")
286
- st.sidebar.header("Upload Document")
287
- uploaded_file = st.sidebar.file_uploader("Upload (.txt, .pdf, .docx)", type=["txt", "pdf", "docx"])
288
- if uploaded_file and not st.session_state.get("doc_processed", False):
289
- process_document(uploaded_file)
290
- if "document_text" in st.session_state:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
  chat_app()
292
  else:
293
- st.info("Please upload and process a document from the sidebar to start chatting.")
 
294
  st.markdown(
295
- """
296
- <div style="position: fixed; right: 10px; bottom: 10px; font-size: 12px; z-index: 9999; text-align: right;">
297
- Made by Danny.<br>
298
- Your questions, our response as well as your feedback will be saved for evaluation purposes.
299
- </div>
300
-
301
-
302
-
303
- """,
304
- unsafe_allow_html=True
305
- )
306
-
307
 
308
  if __name__ == "__main__":
309
  main()
 
5
  import re
6
  import numpy as np
7
  import google.generativeai as palm
 
8
  import logging
9
  import time
10
  import uuid
11
  import json
12
  import firebase_admin
13
  from firebase_admin import credentials, firestore
14
+ from sklearn.metrics.pairwise import cosine_similarity
15
 
16
+ # Initialize Firebase
17
  def init_firebase():
18
  if not firebase_admin._apps:
19
  data = json.loads(os.getenv("FIREBASE_CRED"))
 
69
 
70
  class Config:
71
  CHUNK_WORDS = 300
72
+ EMBEDDING_MODEL = "models/gemini-embedding-exp-03-07"
73
  TOP_N = 5
74
  SYSTEM_PROMPT = (
75
+ "You are a helpful assistant. Answer the question using the provided context below. "
76
+ "Answer based on your knowledge if the context given is not enough."
77
  )
78
  GENERATION_MODEL = "models/gemini-1.5-flash"
79
 
 
181
 
182
  def process_document(uploaded_file) -> None:
183
  try:
 
 
 
184
  file_text = extract_text_from_file(uploaded_file)
185
  if not file_text.strip():
186
  logger.error("Uploaded file contains no valid text.")
 
196
  logger.error("All embeddings are zero vectors.")
197
  st.error("Failed to generate valid embeddings.")
198
  return
199
+ doc_entry = {
200
+ "file_name": uploaded_file.name,
201
  "document_text": file_text,
202
  "document_chunks": chunks,
203
+ "document_embeddings": embeddings,
204
+ }
205
+ if "documents" not in st.session_state:
206
+ st.session_state["documents"] = []
207
+ st.session_state.documents.append(doc_entry)
208
+ st.session_state.doc_processed = True
209
+ st.success(f"Document '{uploaded_file.name}' processing complete! You can now start chatting.")
210
  except Exception as e:
211
  logger.error("Document processing failed: %s", e)
212
  st.error(f"An error occurred while processing the document: {e}")
213
 
214
+ def clear_documents():
215
+ # Clear attached documents and chat messages from session state.
216
+ if "documents" in st.session_state:
217
+ del st.session_state["documents"]
218
+ if "conversations" in st.session_state:
219
+ del st.session_state["conversations"]
220
+ # Update the dynamic key for the file uploader to force reinitialization.
221
+ st.session_state["uploaded_files_key"] = str(uuid.uuid4())
222
+ st.session_state.doc_processed = False
223
+ st.success("All documents and chat messages have been cleared.")
224
+
225
  def search_query(query: str) -> list[tuple[str, float]]:
226
+ if "documents" not in st.session_state or len(st.session_state["documents"]) == 0:
227
  logger.error("No valid document embeddings found in session state.")
228
  st.error("No valid document embeddings found. Please upload a valid document.")
229
  return []
 
233
  st.error("Failed to generate a valid query embedding.")
234
  return []
235
  query_embedding = query_embedding.reshape(1, -1)
236
+ all_chunks = []
237
+ all_embeddings = []
238
+ for doc in st.session_state.documents:
239
+ all_chunks.extend(doc["document_chunks"])
240
+ all_embeddings.extend(doc["document_embeddings"])
241
+ doc_embeddings = np.vstack(all_embeddings)
242
  similarities = cosine_similarity(query_embedding, doc_embeddings)[0]
243
  top_indices = np.argsort(similarities)[-Config.TOP_N:][::-1]
244
+ results = [(all_chunks[i], similarities[i]) for i in top_indices]
245
  return results
246
 
247
  def generate_answer(user_query: str, context: str) -> str:
 
293
  "user_question": user_input,
294
  "assistant_answer": answer,
295
  })
296
+ col1, col2 ,col3,col4,col5= st.columns(5)
297
+ col1.button("👍", key=f"feedback_like_{len(st.session_state.conversations)}", on_click=handle_feedback, args=("positive",))
298
+ col2.button("👎", key=f"feedback_dislike_{len(st.session_state.conversations)}", on_click=handle_feedback, args=("negative",))
299
+
300
+ # Define the clear confirmation dialog using st.dialog decorator.
301
+ @st.dialog("Confirm Clear")
302
+ def clear_confirm_dialog():
303
+ st.write("This will erase all attached documents and chat history. Do you want to proceed?")
304
+ col1, col2 = st.columns(2)
305
+ with col1:
306
+ if st.button("Confirm Clear"):
307
+ clear_documents()
308
+ st.success("Documents and chat history have been cleared.")
309
+ st.rerun()
310
+ with col2:
311
+ if st.button("Cancel"):
312
+ st.write("Operation cancelled.")
313
+ st.rerun()
314
 
315
  def main():
316
  st.title("Chat with your files")
317
+ st.sidebar.header("Upload Documents")
318
+
319
+ # Ensure a dynamic key for the file uploader exists.
320
+ if "uploaded_files_key" not in st.session_state:
321
+ st.session_state["uploaded_files_key"] = str(uuid.uuid4())
322
+
323
+ # File uploader using the dynamic key.
324
+ uploaded_files = st.sidebar.file_uploader(
325
+ "Upload (.txt, .pdf, .docx)",
326
+ type=["txt", "pdf", "docx"],
327
+ accept_multiple_files=True,
328
+ key=st.session_state["uploaded_files_key"]
329
+ )
330
+ if uploaded_files:
331
+ for file in uploaded_files:
332
+ process_document(file)
333
+
334
+ # Show the clear button if either documents, conversations exist or if files are uploaded.
335
+ if (("documents" in st.session_state and st.session_state.documents) or
336
+ ("conversations" in st.session_state and st.session_state.conversations) or
337
+ (uploaded_files is not None and len(uploaded_files) > 0)):
338
+ if st.sidebar.button("Clear Documents & Chat History"):
339
+ clear_confirm_dialog() # Call the dialog function.
340
+
341
+ if st.session_state.get("doc_processed", False):
342
  chat_app()
343
  else:
344
+ st.info("Please upload and process at least one document from the sidebar to start chatting.")
345
+
346
  st.markdown(
347
+ """
348
+ <div style="position: fixed; right: 10px; bottom: 10px; font-size: 12px; z-index: 9999; text-align: right;">
349
+ Your questions, our response as well as your feedback will be saved for evaluation purposes.
350
+ </div>
351
+ """,
352
+ unsafe_allow_html=True
353
+ )
 
 
 
 
 
354
 
355
  if __name__ == "__main__":
356
  main()