Ismetdh commited on
Commit
98d0333
·
verified ·
1 Parent(s): 8849b0e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +388 -0
app.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pdfplumber # For PDF extraction
3
+ import docx # For DOCX extraction
4
+ import os
5
+ import re
6
+ import numpy as np
7
+ import google.generativeai as palm # For embedding generation
8
+ from sklearn.metrics.pairwise import cosine_similarity
9
+ import logging
10
+ import time
11
+ import uuid
12
+ import json
13
+ # Firebase integration imports
14
+ import firebase_admin
15
+ from firebase_admin import credentials, firestore
16
+
17
+ # -------------------------
18
+ # Firebase Initialization using Firestore
19
+ # -------------------------
20
def init_firebase():
    """Initialise the Firebase Admin SDK exactly once per process.

    Credentials are read from the FIREBASE_CRED environment variable,
    which must contain the service-account key as a JSON string.

    Raises:
        RuntimeError: if FIREBASE_CRED is unset or empty — a clear
            message instead of the opaque TypeError that
            ``json.loads(None)`` would otherwise raise.
    """
    if not firebase_admin._apps:
        raw = os.getenv("FIREBASE_CRED")
        if not raw:
            raise RuntimeError(
                "FIREBASE_CRED environment variable is not set; "
                "it must hold the Firebase service-account key JSON."
            )
        data = json.loads(raw)
        cred = credentials.Certificate(data)
        # No databaseURL is provided because we're using Firestore.
        firebase_admin.initialize_app(cred)
28
+
29
# Initialise Firebase at import time so the Firestore client below can
# be created immediately and shared by all helper functions.
init_firebase()
# Create a Firestore client
fs_client = firestore.client()
32
+
33
def save_conversation_to_firestore(session_id, user_question, assistant_answer, feedback=None):
    """
    Persist one question/answer exchange (plus optional feedback) as a
    single document under sessions/{session_id}/conversations.

    Returns the auto-generated ID of the new conversation document.
    """
    payload = {
        "user_question": user_question,
        "assistant_answer": assistant_answer,
        "feedback": feedback,
        # Server-side timestamp keeps ordering consistent across clients.
        "timestamp": firestore.SERVER_TIMESTAMP,
    }
    session_doc = fs_client.collection("sessions").document(session_id)
    # add() returns a (write_result, document_reference) pair.
    _, new_doc = session_doc.collection("conversations").add(payload)
    return new_doc.id
48
+
49
+ # -------------------------
50
+ # Firestore Helper Functions
51
+ # -------------------------
52
def save_message_to_firestore(session_id, role, content, feedback=None):
    """
    Store one chat message under sessions/{session_id}/messages and
    return the auto-generated document ID.
    """
    payload = {
        "role": role,
        "content": content,
        "feedback": feedback,
        # Let the server stamp the write time so ordering is reliable.
        "timestamp": firestore.SERVER_TIMESTAMP,
    }
    messages = fs_client.collection("sessions").document(session_id).collection("messages")
    # add() returns a (write_result, document_reference) pair.
    _, message_doc = messages.add(payload)
    return message_doc.id
67
+
68
def handle_feedback(feedback_val):
    """Record like/dislike feedback for the most recent conversation.

    Persists the rating to Firestore first, then mirrors it into the
    local conversation history used by the UI.
    """
    session_id = st.session_state.session_id
    conversation_id = st.session_state.latest_conversation_id
    update_feedback_in_firestore(session_id, conversation_id, feedback_val)
    st.session_state.conversations[-1]["feedback"] = feedback_val
76
+
77
+
78
def fetch_messages_from_firestore(session_id):
    """
    Return every message for *session_id*, oldest first.

    Each returned dict is the Firestore document's fields augmented
    with its document ID under the "id" key.
    """
    collection = (
        fs_client.collection("sessions")
        .document(session_id)
        .collection("messages")
    )
    result = []
    for snapshot in collection.order_by("timestamp").stream():
        record = snapshot.to_dict()
        record["id"] = snapshot.id
        result.append(record)
    return result
90
+
91
def update_feedback_in_firestore(session_id, conversation_id, feedback):
    """
    Set the feedback field on one conversation document.
    """
    (
        fs_client.collection("sessions")
        .document(session_id)
        .collection("conversations")
        .document(conversation_id)
        .update({"feedback": feedback})
    )
97
+
98
+ # -------------------------
99
+ # Configuration
100
+ # -------------------------
101
class Config:
    """Central tuning knobs for chunking, retrieval and generation."""
    # Target maximum number of words per document chunk.
    CHUNK_WORDS = 300
    EMBEDDING_MODEL = "models/text-embedding-004"  # Update as needed.
    # Number of most-similar chunks retrieved per query.
    TOP_N = 3
    SYSTEM_PROMPT = (
        "You are a helpful assistant. Answer the question using the provided context. "
    )
    # Gemini model used to generate the final answer.
    GENERATION_MODEL = "models/gemini-1.5-flash"
109
+
110
+ # -------------------------
111
+ # API Key and Initialization for Generative AI
112
+ # -------------------------
113
API_KEY = os.getenv("GOOGLE_API_KEY")
if not API_KEY:
    # Fail fast in the UI: nothing below works without an API key.
    st.error("Google API key is not configured.")
    st.stop()
palm.configure(api_key=API_KEY)

# -------------------------
# Logging Configuration
# -------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
124
+
125
+ # -------------------------
126
+ # Cached Embedding Function
127
+ # -------------------------
128
@st.cache_data(show_spinner=True)
def generate_embedding_cached(text: str, task_type: str = "retrieval_document") -> list:
    """Generate (and cache via Streamlit) an embedding for *text*.

    Args:
        text: The text to embed.
        task_type: Gemini embedding task type. Defaults to
            "retrieval_document" (the original hard-coded value, so
            existing callers are unchanged). NOTE(review): queries are
            currently embedded with the document task type too; Gemini
            recommends "retrieval_query" for queries — confirm before
            switching callers, since it invalidates the cache.

    Returns:
        The embedding as a flat list of floats, or a zero vector of
        length 768 (text-embedding-004's dimensionality) on failure so
        downstream cosine-similarity code never sees None.
    """
    logger.info("Calling API for embedding generation. Text snippet: %s", text[:50])
    fallback = [0.0] * 768  # single definition of the failure sentinel
    try:
        response = palm.embed_content(
            model=Config.EMBEDDING_MODEL,
            content=text,
            task_type=task_type
        )
        if "embedding" not in response or not response["embedding"]:
            logger.error("No embedding returned from API.")
            st.error("No embedding returned. Please verify your API settings and input text.")
            return fallback
        embedding = np.array(response["embedding"])
        if embedding.ndim == 2:
            # Some responses wrap the vector in an extra batch axis.
            embedding = embedding.flatten()
        elif embedding.ndim > 2:
            logger.error("Embedding has more than 2 dimensions.")
            st.error("Invalid embedding dimensions. Please check the API response.")
            return fallback
        return embedding.tolist()
    except Exception as e:
        logger.error("Embedding generation failed: %s", e)
        st.error(f"Embedding generation failed: {e}")
        return fallback
153
+
154
def generate_embedding(text: str) -> np.ndarray:
    """Return the cached embedding for *text* as a NumPy array."""
    return np.asarray(generate_embedding_cached(text))
157
+
158
+ # -------------------------
159
+ # File Handling
160
+ # -------------------------
161
def extract_text_from_file(uploaded_file) -> str:
    """Extract plain text from an uploaded .txt, .pdf or .docx file.

    Raises:
        ValueError: for any other file extension.
    """
    name = uploaded_file.name.lower()
    if name.endswith(".txt"):
        logger.info("Processing TXT file.")
        return uploaded_file.read().decode("utf-8")
    if name.endswith(".pdf"):
        logger.info("Processing PDF file.")
        with pdfplumber.open(uploaded_file) as pdf:
            pages = [page.extract_text() for page in pdf.pages]
            # Skip pages where extraction yielded nothing.
            text = "\n".join(p for p in pages if p)
        if not text:
            logger.error("PDF extraction returned empty text.")
        return text
    if name.endswith(".docx"):
        logger.info("Processing DOCX file.")
        document = docx.Document(uploaded_file)
        text = "\n".join(paragraph.text for paragraph in document.paragraphs)
        if not text:
            logger.error("DOCX extraction returned empty text.")
        return text
    raise ValueError("Unsupported file type. Please upload a .txt, .pdf, or .docx file.")
182
+
183
+ # -------------------------
184
+ # Chunking the Document
185
+ # -------------------------
186
def chunk_text(text: str, max_words=None) -> list[str]:
    """Split *text* into chunks of roughly at most *max_words* words.

    Paragraphs (separated by blank lines) are packed greedily into
    chunks; a paragraph that alone exceeds the limit is split on
    sentence boundaries instead.

    Args:
        text: The document text.
        max_words: Word budget per chunk. Defaults to
            Config.CHUNK_WORDS, preserving the original behaviour; the
            parameter generalises the previously hard-coded limit.

    Returns:
        A list of non-empty chunk strings (empty list for blank input).
    """
    if max_words is None:
        max_words = Config.CHUNK_WORDS
    paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
    chunks = []
    current_chunk = ""
    current_word_count = 0

    for paragraph in paragraphs:
        para_word_count = len(paragraph.split())
        if para_word_count > max_words:
            # Oversized paragraph: flush the running chunk, then split
            # the paragraph itself on sentence boundaries.
            if current_chunk:
                chunks.append(current_chunk.strip())
                current_chunk = ""
                current_word_count = 0
            sentences = re.split(r'(?<=[.!?])\s+', paragraph)
            temp_chunk = ""
            temp_word_count = 0
            for sentence in sentences:
                sentence_word_count = len(sentence.split())
                if temp_word_count + sentence_word_count > max_words:
                    if temp_chunk:
                        chunks.append(temp_chunk.strip())
                    temp_chunk = sentence + " "
                    temp_word_count = sentence_word_count
                else:
                    temp_chunk += sentence + " "
                    temp_word_count += sentence_word_count
            if temp_chunk:
                chunks.append(temp_chunk.strip())
        else:
            # Normal paragraph: start a new chunk if it would overflow.
            if current_word_count + para_word_count > max_words:
                if current_chunk:
                    chunks.append(current_chunk.strip())
                current_chunk = paragraph + "\n\n"
                current_word_count = para_word_count
            else:
                current_chunk += paragraph + "\n\n"
                current_word_count += para_word_count

    if current_chunk:
        chunks.append(current_chunk.strip())
    return chunks
228
+
229
+ # -------------------------
230
+ # Process Document (Extract, Chunk, Embed)
231
+ # -------------------------
232
def process_document(uploaded_file) -> None:
    """Extract, chunk and embed *uploaded_file*.

    On success, stores document_text / document_chunks /
    document_embeddings in Streamlit session state and marks
    doc_processed. Shows a Streamlit error and returns early on any
    failure.
    """
    try:
        # Drop stale document state from a previous upload.
        for stale_key in ("document_text", "document_chunks", "document_embeddings"):
            st.session_state.pop(stale_key, None)

        raw_text = extract_text_from_file(uploaded_file)
        if not raw_text.strip():
            logger.error("Uploaded file contains no valid text.")
            st.error("The uploaded file contains no valid text.")
            return

        pieces = chunk_text(raw_text)
        if not pieces:
            logger.error("No chunks generated from text.")
            st.error("Failed to split text into chunks.")
            return

        vectors = [generate_embedding(piece) for piece in pieces]
        # Zero vectors are the embedding helper's failure sentinel.
        if all(np.all(vec == 0) for vec in vectors):
            logger.error("All embeddings are zero vectors.")
            st.error("Failed to generate valid embeddings.")
            return

        st.session_state.update({
            "document_text": raw_text,
            "document_chunks": pieces,
            "document_embeddings": vectors,
        })
        if not st.session_state.get("doc_processed", False):
            placeholder = st.empty()
            placeholder.success("Document processing complete! You can now start chatting.")
            st.session_state.doc_processed = True
    except Exception as e:
        logger.error("Document processing failed: %s", e)
        st.error(f"An error occurred while processing the document: {e}")
266
+
267
+ # -------------------------
268
+ # Retrieve Relevant Chunks
269
+ # -------------------------
270
def search_query(query: str) -> list[tuple[str, float]]:
    """Find the document chunks most relevant to *query*.

    Returns up to Config.TOP_N (chunk, cosine-similarity) pairs,
    highest similarity first; empty list on any failure.
    """
    stored = st.session_state.get("document_embeddings")
    if stored is None or len(stored) == 0:
        logger.error("No valid document embeddings found in session state.")
        st.error("No valid document embeddings found. Please upload a valid document.")
        return []

    query_vec = generate_embedding(query)
    if np.all(query_vec == 0):
        # Zero vector is the embedding helper's failure sentinel.
        logger.error("Query embedding is a zero vector.")
        st.error("Failed to generate a valid query embedding.")
        return []

    scores = cosine_similarity(query_vec.reshape(1, -1), np.vstack(stored))[0]
    best = np.argsort(scores)[-Config.TOP_N:][::-1]
    return [(st.session_state["document_chunks"][i], scores[i]) for i in best]
287
+
288
+ # -------------------------
289
+ # Generate Answer from LLM (RAG)
290
+ # -------------------------
291
def generate_answer(user_query: str, context: str) -> str:
    """Ask the Gemini model to answer *user_query* using *context*.

    Args:
        user_query: The user's question.
        context: Retrieved document chunks to ground the answer.

    Returns:
        The model's answer text, or an apology string on failure.
        Always a str — the original returned the raw response object
        when it lacked a ``.text`` attribute, violating the annotated
        return type and breaking downstream Firestore persistence.
    """
    prompt = (
        f"System: {Config.SYSTEM_PROMPT}\n\n"
        f"Context:\n{context}\n\n"
        f"User: {user_query}\nAssistant:"
    )
    try:
        model = palm.GenerativeModel(Config.GENERATION_MODEL)
        response = model.generate_content(prompt)
        # Bug fix: coerce to str so callers (UI + Firestore) always get text.
        return response.text if hasattr(response, "text") else str(response)
    except Exception as e:
        logger.error("Failed to generate answer: %s", e)
        st.error("Failed to generate answer. Please check your input and try again.")
        return "I'm sorry, I encountered an error generating a response."
308
+
309
+ # -------------------------
310
+ # Chat Interface
311
+ # -------------------------
312
def chat_app():
    """Render the chat UI.

    Shows the conversation history, handles a new user message via
    retrieval + generation, persists each exchange to Firestore, and
    offers like/dislike feedback buttons for the latest answer.
    """
    # Initialize conversation history and session ID if not already set.
    if "conversations" not in st.session_state:
        # Each element: {"user_question", "assistant_answer", optional "feedback"}
        st.session_state.conversations = []
    if "session_id" not in st.session_state:
        st.session_state.session_id = str(uuid.uuid4())

    # Display past conversations.
    for conv in st.session_state.conversations:
        with st.chat_message("user"):
            st.write(conv.get("user_question", ""))
        with st.chat_message("assistant"):
            st.write(conv.get("assistant_answer", ""))
        # Optionally, display feedback if available.
        if conv.get("feedback"):
            st.markdown(f"**Feedback:** {conv['feedback']}")

    # Get new user input.
    user_input = st.chat_input("Type your message here")
    if user_input:
        # Display the user input immediately.
        with st.chat_message("user"):
            st.write(user_input)

        # Retrieve relevant document chunks from the processed document.
        results = search_query(user_input)
        context = "\n\n".join([chunk for chunk, score in results]) if results else ""

        # Generate the assistant's answer using the retrieved context.
        answer = generate_answer(user_input, context)
        with st.chat_message("assistant"):
            st.write(answer)

        # Save the whole conversation (question + answer) as one document.
        conversation_id = save_conversation_to_firestore(
            st.session_state.session_id,
            user_question=user_input,
            assistant_answer=answer,
        )
        st.session_state.latest_conversation_id = conversation_id

        # Append the conversation to session state (for UI history).
        st.session_state.conversations.append({
            "user_question": user_input,
            "assistant_answer": answer,
        })

    # Show like/dislike buttons for the latest conversation, but only if
    # one exists (bug fix: the unguarded [-1] raised IndexError before
    # the first message was sent) and it has not been rated yet.
    if st.session_state.conversations and "feedback" not in st.session_state.conversations[-1]:
        # Ten columns keep the two small buttons adjacent on the left.
        cols = st.columns(10)
        cols[0].button(
            "👍",
            key=f"feedback_like_{len(st.session_state.conversations)}",
            on_click=handle_feedback,
            args=("positive",),
        )
        cols[1].button(
            "👎",
            key=f"feedback_dislike_{len(st.session_state.conversations)}",
            on_click=handle_feedback,
            args=("negative",),
        )
369
+
370
+ # -------------------------
371
+ # Main Application (Streamlit)
372
+ # -------------------------
373
def main():
    """Application entry point: sidebar document upload plus chat pane."""
    st.title("Code : Beta")

    st.sidebar.header("Upload Document")
    uploaded_file = st.sidebar.file_uploader(
        "Upload (.txt, .pdf, .docx)", type=["txt", "pdf", "docx"]
    )
    # Process the document only if uploaded and not already processed.
    if uploaded_file and not st.session_state.get("doc_processed", False):
        process_document(uploaded_file)

    if "document_text" in st.session_state:
        chat_app()
    else:
        st.info("Please upload and process a document from the sidebar to start chatting.")


if __name__ == "__main__":
    main()