Ismetdh committed on
Commit
f5c61b7
·
verified ·
1 Parent(s): 2ae4087

Update app.py

Browse files

Remove comments

Files changed (1) hide show
  1. app.py +11 -103
app.py CHANGED
@@ -1,39 +1,28 @@
1
  import streamlit as st
2
- import pdfplumber # For PDF extraction
3
- import docx # For DOCX extraction
4
  import os
5
  import re
6
  import numpy as np
7
- import google.generativeai as palm # For embedding generation
8
  from sklearn.metrics.pairwise import cosine_similarity
9
  import logging
10
  import time
11
  import uuid
12
  import json
13
- # Firebase integration imports
14
  import firebase_admin
15
  from firebase_admin import credentials, firestore
16
 
17
- # -------------------------
18
- # Firebase Initialization using Firestore
19
- # -------------------------
20
  def init_firebase():
21
  if not firebase_admin._apps:
22
- # Replace with the path to your Firebase service account key JSON file.
23
  data = json.loads(os.getenv("FIREBASE_CRED"))
24
-
25
  cred = credentials.Certificate(data)
26
- # No databaseURL is provided because we're using Firestore.
27
  firebase_admin.initialize_app(cred)
28
 
29
  init_firebase()
30
- # Create a Firestore client
31
  fs_client = firestore.client()
32
 
33
  def save_conversation_to_firestore(session_id, user_question, assistant_answer, feedback=None):
34
- """
35
- Save a complete conversation (user question + assistant answer + feedback) as a single document.
36
- """
37
  conv_ref = fs_client.collection("sessions").document(session_id).collection("conversations")
38
  data = {
39
  "user_question": user_question,
@@ -41,32 +30,21 @@ def save_conversation_to_firestore(session_id, user_question, assistant_answer,
41
  "feedback": feedback,
42
  "timestamp": firestore.SERVER_TIMESTAMP
43
  }
44
- # Add a new document with an auto-generated ID.
45
  doc_ref = conv_ref.add(data)
46
- # doc_ref returns a tuple (write_result, document_reference)
47
  return doc_ref[1].id
48
 
49
- # -------------------------
50
- # Firestore Helper Functions
51
- # -------------------------
52
  def save_message_to_firestore(session_id, role, content, feedback=None):
53
- """
54
- Save a message to Firestore under sessions/{session_id}/messages.
55
- """
56
  messages_ref = fs_client.collection("sessions").document(session_id).collection("messages")
57
  data = {
58
  "role": role,
59
  "content": content,
60
  "feedback": feedback,
61
- "timestamp": firestore.SERVER_TIMESTAMP # Server will set the timestamp
62
  }
63
- # Add a new document with an auto-generated ID.
64
  doc_ref = messages_ref.add(data)
65
- # doc_ref returns a tuple (write_result, document_reference)
66
  return doc_ref[1].id
67
 
68
  def handle_feedback(feedback_val):
69
- # Update Firestore and update local conversation history
70
  update_feedback_in_firestore(
71
  st.session_state.session_id,
72
  st.session_state.latest_conversation_id,
@@ -74,11 +52,7 @@ def handle_feedback(feedback_val):
74
  )
75
  st.session_state.conversations[-1]["feedback"] = feedback_val
76
 
77
-
78
  def fetch_messages_from_firestore(session_id):
79
- """
80
- Fetch all messages for the given session from Firestore, ordered by timestamp.
81
- """
82
  messages_ref = fs_client.collection("sessions").document(session_id).collection("messages")
83
  docs = messages_ref.order_by("timestamp").stream()
84
  messages = []
@@ -89,42 +63,27 @@ def fetch_messages_from_firestore(session_id):
89
  return messages
90
 
91
  def update_feedback_in_firestore(session_id, conversation_id, feedback):
92
- """
93
- Update the feedback field for a conversation document.
94
- """
95
  conv_doc = fs_client.collection("sessions").document(session_id).collection("conversations").document(conversation_id)
96
  conv_doc.update({"feedback": feedback})
97
 
98
- # -------------------------
99
- # Configuration
100
- # -------------------------
101
  class Config:
102
  CHUNK_WORDS = 300
103
- EMBEDDING_MODEL = "models/text-embedding-004" # Update as needed.
104
  TOP_N = 3
105
  SYSTEM_PROMPT = (
106
  "You are a helpful assistant. Answer the question using the provided context. "
107
  )
108
- GENERATION_MODEL = "models/gemini-1.5-flash"
109
 
110
- # -------------------------
111
- # API Key and Initialization for Generative AI
112
- # -------------------------
113
  API_KEY = os.getenv("GOOGLE_API_KEY")
114
  if not API_KEY:
115
  st.error("Google API key is not configured.")
116
  st.stop()
117
  palm.configure(api_key=API_KEY)
118
 
119
- # -------------------------
120
- # Logging Configuration
121
- # -------------------------
122
  logging.basicConfig(level=logging.INFO)
123
  logger = logging.getLogger(__name__)
124
 
125
- # -------------------------
126
- # Cached Embedding Function
127
- # -------------------------
128
  @st.cache_data(show_spinner=True)
129
  def generate_embedding_cached(text: str) -> list:
130
  logger.info("Calling API for embedding generation. Text snippet: %s", text[:50])
@@ -137,7 +96,7 @@ def generate_embedding_cached(text: str) -> list:
137
  if "embedding" not in response or not response["embedding"]:
138
  logger.error("No embedding returned from API.")
139
  st.error("No embedding returned. Please verify your API settings and input text.")
140
- return [0.0] * 768 # Fallback: list of zeros
141
  embedding = np.array(response["embedding"])
142
  if embedding.ndim == 2:
143
  embedding = embedding.flatten()
@@ -155,9 +114,6 @@ def generate_embedding(text: str) -> np.ndarray:
155
  embedding_list = generate_embedding_cached(text)
156
  return np.array(embedding_list)
157
 
158
- # -------------------------
159
- # File Handling
160
- # -------------------------
161
  def extract_text_from_file(uploaded_file) -> str:
162
  file_name = uploaded_file.name.lower()
163
  if file_name.endswith(".txt"):
@@ -180,16 +136,12 @@ def extract_text_from_file(uploaded_file) -> str:
180
  else:
181
  raise ValueError("Unsupported file type. Please upload a .txt, .pdf, or .docx file.")
182
 
183
- # -------------------------
184
- # Chunking the Document
185
- # -------------------------
186
  def chunk_text(text: str) -> list[str]:
187
  max_words = Config.CHUNK_WORDS
188
  paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
189
  chunks = []
190
  current_chunk = ""
191
  current_word_count = 0
192
-
193
  for paragraph in paragraphs:
194
  para_word_count = len(paragraph.split())
195
  if para_word_count > max_words:
@@ -221,21 +173,15 @@ def chunk_text(text: str) -> list[str]:
221
  else:
222
  current_chunk += paragraph + "\n\n"
223
  current_word_count += para_word_count
224
-
225
  if current_chunk:
226
  chunks.append(current_chunk.strip())
227
  return chunks
228
 
229
- # -------------------------
230
- # Process Document (Extract, Chunk, Embed)
231
- # -------------------------
232
  def process_document(uploaded_file) -> None:
233
  try:
234
- # Clear only document-related keys.
235
  keys_to_clear = ["document_text", "document_chunks", "document_embeddings"]
236
  for key in keys_to_clear:
237
  st.session_state.pop(key, None)
238
-
239
  file_text = extract_text_from_file(uploaded_file)
240
  if not file_text.strip():
241
  logger.error("Uploaded file contains no valid text.")
@@ -264,15 +210,11 @@ def process_document(uploaded_file) -> None:
264
  logger.error("Document processing failed: %s", e)
265
  st.error(f"An error occurred while processing the document: {e}")
266
 
267
- # -------------------------
268
- # Retrieve Relevant Chunks
269
- # -------------------------
270
  def search_query(query: str) -> list[tuple[str, float]]:
271
  if "document_embeddings" not in st.session_state or len(st.session_state["document_embeddings"]) == 0:
272
  logger.error("No valid document embeddings found in session state.")
273
  st.error("No valid document embeddings found. Please upload a valid document.")
274
  return []
275
-
276
  query_embedding = generate_embedding(query)
277
  if np.all(query_embedding == 0):
278
  logger.error("Query embedding is a zero vector.")
@@ -285,9 +227,6 @@ def search_query(query: str) -> list[tuple[str, float]]:
285
  results = [(st.session_state["document_chunks"][i], similarities[i]) for i in top_indices]
286
  return results
287
 
288
- # -------------------------
289
- # Generate Answer from LLM (RAG)
290
- # -------------------------
291
  def generate_answer(user_query: str, context: str) -> str:
292
  prompt = (
293
  f"System: {Config.SYSTEM_PROMPT}\n\n"
@@ -306,79 +245,48 @@ def generate_answer(user_query: str, context: str) -> str:
306
  st.error("Failed to generate answer. Please check your input and try again.")
307
  return "I'm sorry, I encountered an error generating a response."
308
 
309
- # -------------------------
310
- # Chat Interface
311
- # -------------------------
312
  def chat_app():
313
- # Initialize conversation history and session ID if not already set.
314
  if "conversations" not in st.session_state:
315
- st.session_state.conversations = [] # Each element is a dict with keys: user_question, assistant_answer, (optionally) feedback
316
  if "session_id" not in st.session_state:
317
  st.session_state.session_id = str(uuid.uuid4())
318
-
319
- # Display past conversations
320
  for conv in st.session_state.conversations:
321
- # Display the user's question
322
  with st.chat_message("user"):
323
  st.write(conv.get("user_question", ""))
324
- # Display the assistant's answer
325
  with st.chat_message("assistant"):
326
  st.write(conv.get("assistant_answer", ""))
327
- # Optionally, display feedback if available
328
  if conv.get("feedback"):
329
  st.markdown(f"**Feedback:** {conv['feedback']}")
330
-
331
- # Get new user input
332
  user_input = st.chat_input("Type your message here")
333
  if user_input:
334
- # Display the user input immediately.
335
  with st.chat_message("user"):
336
  st.write(user_input)
337
-
338
- # Retrieve relevant document chunks from the processed document.
339
  results = search_query(user_input)
340
  context = "\n\n".join([chunk for chunk, score in results]) if results else ""
341
-
342
- # Generate the assistant's answer using the retrieved context.
343
  answer = generate_answer(user_input, context)
344
  with st.chat_message("assistant"):
345
  st.write(answer)
346
-
347
- # Save the whole conversation (user question + assistant answer) as one document.
348
  conversation_id = save_conversation_to_firestore(
349
  st.session_state.session_id,
350
  user_question=user_input,
351
  assistant_answer=answer
352
  )
353
  st.session_state.latest_conversation_id = conversation_id
354
-
355
- # Append the conversation to session state (for UI history)
356
  st.session_state.conversations.append({
357
  "user_question": user_input,
358
  "assistant_answer": answer,
359
  })
360
-
361
- # Instead of a radio button, show two buttons for like/dislike.
362
- # Only show these buttons if the latest conversation has not yet been rated.
363
  if "feedback" not in st.session_state.conversations[-1]:
364
- col1, col2,col3,col4,col5,col6,col7,col8,col9,col10 = st.columns(10)
365
- col1.button("👍", key=f"feedback_like_{len(st.session_state.conversations)}",
366
- on_click=handle_feedback, args=("positive",))
367
- col2.button("👎", key=f"feedback_dislike_{len(st.session_state.conversations)}",
368
- on_click=handle_feedback, args=("negative",))
369
 
370
- # -------------------------
371
- # Main Application (Streamlit)
372
- # -------------------------
373
  def main():
374
  st.title("Code : Beta")
375
-
376
  st.sidebar.header("Upload Document")
377
  uploaded_file = st.sidebar.file_uploader("Upload (.txt, .pdf, .docx)", type=["txt", "pdf", "docx"])
378
- # Process the document only if uploaded and not already processed.
379
  if uploaded_file and not st.session_state.get("doc_processed", False):
380
  process_document(uploaded_file)
381
-
382
  if "document_text" in st.session_state:
383
  chat_app()
384
  else:
 
1
  import streamlit as st
2
+ import pdfplumber
3
+ import docx
4
  import os
5
  import re
6
  import numpy as np
7
+ import google.generativeai as palm
8
  from sklearn.metrics.pairwise import cosine_similarity
9
  import logging
10
  import time
11
  import uuid
12
  import json
 
13
  import firebase_admin
14
  from firebase_admin import credentials, firestore
15
 
 
 
 
16
def init_firebase():
    """Initialize the Firebase Admin SDK exactly once per process.

    Credentials are read from the FIREBASE_CRED environment variable, which
    must contain the service-account key as a JSON string.

    Raises:
        RuntimeError: if FIREBASE_CRED is not set. (The original code passed
            None straight into json.loads, producing an opaque TypeError.)
    """
    # firebase_admin raises if initialize_app is called twice; guard on the
    # _apps registry the same way the original code did.
    if not firebase_admin._apps:
        raw_cred = os.getenv("FIREBASE_CRED")
        if raw_cred is None:
            raise RuntimeError("FIREBASE_CRED environment variable is not set.")
        cred = credentials.Certificate(json.loads(raw_cred))
        firebase_admin.initialize_app(cred)
21
 
22
# Initialize Firebase at import time and create one module-wide Firestore
# client shared by all persistence helpers below.
init_firebase()
fs_client = firestore.client()
24
 
25
  def save_conversation_to_firestore(session_id, user_question, assistant_answer, feedback=None):
 
 
 
26
  conv_ref = fs_client.collection("sessions").document(session_id).collection("conversations")
27
  data = {
28
  "user_question": user_question,
 
30
  "feedback": feedback,
31
  "timestamp": firestore.SERVER_TIMESTAMP
32
  }
 
33
  doc_ref = conv_ref.add(data)
 
34
  return doc_ref[1].id
35
 
 
 
 
36
def save_message_to_firestore(session_id, role, content, feedback=None):
    """Persist one chat message under sessions/{session_id}/messages.

    Args:
        session_id: Identifier of the chat session document.
        role: Message author role (e.g. "user" or "assistant").
        content: Message text.
        feedback: Optional feedback value stored alongside the message.

    Returns:
        The auto-generated ID of the newly created message document.
    """
    session_doc = fs_client.collection("sessions").document(session_id)
    payload = {
        "role": role,
        "content": content,
        "feedback": feedback,
        # Let Firestore assign the timestamp so ordering is server-consistent.
        "timestamp": firestore.SERVER_TIMESTAMP,
    }
    # CollectionReference.add returns a (write_result, document_ref) tuple.
    _, message_doc = session_doc.collection("messages").add(payload)
    return message_doc.id
46
 
47
  def handle_feedback(feedback_val):
 
48
  update_feedback_in_firestore(
49
  st.session_state.session_id,
50
  st.session_state.latest_conversation_id,
 
52
  )
53
  st.session_state.conversations[-1]["feedback"] = feedback_val
54
 
 
55
  def fetch_messages_from_firestore(session_id):
 
 
 
56
  messages_ref = fs_client.collection("sessions").document(session_id).collection("messages")
57
  docs = messages_ref.order_by("timestamp").stream()
58
  messages = []
 
63
  return messages
64
 
65
def update_feedback_in_firestore(session_id, conversation_id, feedback):
    """Set the feedback field on an existing conversation document."""
    (
        fs_client.collection("sessions")
        .document(session_id)
        .collection("conversations")
        .document(conversation_id)
        .update({"feedback": feedback})
    )
68
 
 
 
 
69
class Config:
    """Application-wide tunables for chunking, retrieval and generation."""

    # Maximum number of words per document chunk.
    CHUNK_WORDS = 300
    # Gemini embedding model used for both documents and queries.
    EMBEDDING_MODEL = "models/text-embedding-004"
    # Number of top-scoring chunks retrieved as answer context.
    TOP_N = 3
    # System prompt prepended to every generation request.
    SYSTEM_PROMPT = (
        "You are a helpful assistant. Answer the question using the provided context. "
    )
    # Gemini model used to generate the final answer.
    GENERATION_MODEL = "models/gemini-1.5-flash"
77
 
 
 
 
78
# Configure the Generative AI client; halt the Streamlit app early when no
# API key is available, since nothing downstream can work without it.
API_KEY = os.getenv("GOOGLE_API_KEY")
if not API_KEY:
    st.error("Google API key is not configured.")
    st.stop()
palm.configure(api_key=API_KEY)
83
 
 
 
 
84
# Module-level logger; basicConfig is a no-op if handlers are already set.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
86
 
 
 
 
87
  @st.cache_data(show_spinner=True)
88
  def generate_embedding_cached(text: str) -> list:
89
  logger.info("Calling API for embedding generation. Text snippet: %s", text[:50])
 
96
  if "embedding" not in response or not response["embedding"]:
97
  logger.error("No embedding returned from API.")
98
  st.error("No embedding returned. Please verify your API settings and input text.")
99
+ return [0.0] * 768
100
  embedding = np.array(response["embedding"])
101
  if embedding.ndim == 2:
102
  embedding = embedding.flatten()
 
114
  embedding_list = generate_embedding_cached(text)
115
  return np.array(embedding_list)
116
 
 
 
 
117
  def extract_text_from_file(uploaded_file) -> str:
118
  file_name = uploaded_file.name.lower()
119
  if file_name.endswith(".txt"):
 
136
  else:
137
  raise ValueError("Unsupported file type. Please upload a .txt, .pdf, or .docx file.")
138
 
 
 
 
139
  def chunk_text(text: str) -> list[str]:
140
  max_words = Config.CHUNK_WORDS
141
  paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
142
  chunks = []
143
  current_chunk = ""
144
  current_word_count = 0
 
145
  for paragraph in paragraphs:
146
  para_word_count = len(paragraph.split())
147
  if para_word_count > max_words:
 
173
  else:
174
  current_chunk += paragraph + "\n\n"
175
  current_word_count += para_word_count
 
176
  if current_chunk:
177
  chunks.append(current_chunk.strip())
178
  return chunks
179
 
 
 
 
180
  def process_document(uploaded_file) -> None:
181
  try:
 
182
  keys_to_clear = ["document_text", "document_chunks", "document_embeddings"]
183
  for key in keys_to_clear:
184
  st.session_state.pop(key, None)
 
185
  file_text = extract_text_from_file(uploaded_file)
186
  if not file_text.strip():
187
  logger.error("Uploaded file contains no valid text.")
 
210
  logger.error("Document processing failed: %s", e)
211
  st.error(f"An error occurred while processing the document: {e}")
212
 
 
 
 
213
  def search_query(query: str) -> list[tuple[str, float]]:
214
  if "document_embeddings" not in st.session_state or len(st.session_state["document_embeddings"]) == 0:
215
  logger.error("No valid document embeddings found in session state.")
216
  st.error("No valid document embeddings found. Please upload a valid document.")
217
  return []
 
218
  query_embedding = generate_embedding(query)
219
  if np.all(query_embedding == 0):
220
  logger.error("Query embedding is a zero vector.")
 
227
  results = [(st.session_state["document_chunks"][i], similarities[i]) for i in top_indices]
228
  return results
229
 
 
 
 
230
  def generate_answer(user_query: str, context: str) -> str:
231
  prompt = (
232
  f"System: {Config.SYSTEM_PROMPT}\n\n"
 
245
  st.error("Failed to generate answer. Please check your input and try again.")
246
  return "I'm sorry, I encountered an error generating a response."
247
 
 
 
 
248
def chat_app():
    """Render the chat UI: replay history, handle new input, collect feedback."""
    # Per-session state: conversation history and a stable session identifier.
    if "conversations" not in st.session_state:
        st.session_state.conversations = []
    if "session_id" not in st.session_state:
        st.session_state.session_id = str(uuid.uuid4())

    # Replay every stored exchange, including any recorded feedback.
    for past in st.session_state.conversations:
        with st.chat_message("user"):
            st.write(past.get("user_question", ""))
        with st.chat_message("assistant"):
            st.write(past.get("assistant_answer", ""))
        if past.get("feedback"):
            st.markdown(f"**Feedback:** {past['feedback']}")

    user_input = st.chat_input("Type your message here")
    if user_input:
        # Echo the user's message immediately.
        with st.chat_message("user"):
            st.write(user_input)

        # Retrieve relevant chunks and generate the assistant's reply (RAG).
        hits = search_query(user_input)
        context = "\n\n".join(chunk for chunk, _score in hits) if hits else ""
        answer = generate_answer(user_input, context)
        with st.chat_message("assistant"):
            st.write(answer)

        # Persist the exchange and remember its ID for later feedback updates.
        st.session_state.latest_conversation_id = save_conversation_to_firestore(
            st.session_state.session_id,
            user_question=user_input,
            assistant_answer=answer,
        )
        st.session_state.conversations.append(
            {"user_question": user_input, "assistant_answer": answer}
        )

        # Offer like/dislike buttons until the latest exchange has been rated.
        if "feedback" not in st.session_state.conversations[-1]:
            columns = st.columns(10)
            columns[0].button(
                "👍",
                key=f"feedback_like_{len(st.session_state.conversations)}",
                on_click=handle_feedback,
                args=("positive",),
            )
            columns[1].button(
                "👎",
                key=f"feedback_dislike_{len(st.session_state.conversations)}",
                on_click=handle_feedback,
                args=("negative",),
            )
 
 
283
 
 
 
 
284
  def main():
285
  st.title("Code : Beta")
 
286
  st.sidebar.header("Upload Document")
287
  uploaded_file = st.sidebar.file_uploader("Upload (.txt, .pdf, .docx)", type=["txt", "pdf", "docx"])
 
288
  if uploaded_file and not st.session_state.get("doc_processed", False):
289
  process_document(uploaded_file)
 
290
  if "document_text" in st.session_state:
291
  chat_app()
292
  else: