Mateo4 committed on
Commit
a57a185
·
verified ·
1 Parent(s): 3c813e8

Update app.py

Browse files

Prompt becomes static; PDF is loaded by default.

Files changed (1) hide show
  1. app.py +107 -106
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
2
  import time
3
- import fitz
4
  import faiss
5
  import pickle
6
  import numpy as np
@@ -16,21 +16,20 @@ from sentence_transformers import SentenceTransformer
16
  import gradio as gr
17
 
18
  # Define the ML_prompt (as it was in your notebook)
 
19
  ML_prompt = """
20
  نقش ات:
21
  تو دستیار هوش مصنوعی من برای امتحان یادگیری ماشین هستی
22
  این امتحان تمرکز روی مفاهیم تیوری یادگیری ماشین داره
23
  منبع درس کتاب بیشاپ هست
24
-
25
  لحن صحبت کردن ات:
26
  تو استاد دانشگاه هستی و کسایی که باهات چت می کنن دانشجوهات اند
27
  """
28
- # api_key = os.getenv("google_api_key")
29
 
30
  class GeminiRAG:
31
  def __init__(self, api_key: str, model_name: str = "models/gemini-2.0-flash",
32
  embed_model_name: str = "all-MiniLM-L6-v2", # Using a common SentenceTransformer model
33
- instruction_prompt: str = ML_prompt,
34
  vectorstore_dir: str = "vectorstore"): # Use a directory within the app for persistence
35
 
36
  if not api_key:
@@ -62,19 +61,26 @@ class GeminiRAG:
62
  self.load_vectorstore()
63
 
64
  def _split_into_sentences(self, text: str) -> List[str]:
 
65
  sentences = re.split(r'(?<=[.!?])\s+', text)
66
  return [s.strip() for s in sentences if s.strip()]
67
 
68
  def load_document(self, pdf_path: str) -> List[str]:
69
- doc = fitz.open(pdf_path)
70
- page_contents = []
71
- for page_num in range(len(doc)):
72
- page = doc.load_page(page_num)
73
- text = page.get_text()
74
- if text.strip():
75
- page_contents.append(text.strip())
76
- doc.close()
77
- return page_contents
 
 
 
 
 
 
78
 
79
  def add_document(self, parent_chunks: List[str]):
80
  new_sentence_chunks = []
@@ -107,7 +113,7 @@ class GeminiRAG:
107
 
108
  retrieved_parent_doc_indices = set()
109
  for idx in I[0]:
110
- if idx < len(self.sentence_chunks):
111
  parent_idx = self.sentence_to_parent_map[idx]
112
  retrieved_parent_doc_indices.add(parent_idx)
113
 
@@ -115,7 +121,7 @@ class GeminiRAG:
115
  sorted_parent_indices = sorted(list(retrieved_parent_doc_indices))
116
 
117
  for parent_idx in sorted_parent_indices:
118
- if parent_idx < len(self.parent_documents):
119
  context_parts.append(self.parent_documents[parent_idx])
120
 
121
  context = "\n\n---\\n\\n".join(context_parts)
@@ -123,17 +129,15 @@ class GeminiRAG:
123
  if not context.strip():
124
  return "No relevant information found in the knowledge base."
125
 
 
126
  prompt = f"""
127
- ### instruction prompt : (explanation : this text is your guideline don't mention it on response)
128
- {self.instruction_prompt}
129
-
130
- Use the following context to answer the question.\n
131
- Context:\n
132
- {context}\n
133
-
134
- Question: {query}\n
135
-
136
- Answer:"""
137
 
138
  for attempt in range(3):
139
  try:
@@ -142,28 +146,44 @@ class GeminiRAG:
142
  except InternalServerError as e:
143
  print(f"Error: {e}. Retrying in 5 seconds...")
144
  time.sleep(5)
145
- raise Exception("Failed to generate after 3 retries.")
 
 
 
146
 
147
  def save_vectorstore(self):
148
- faiss.write_index(self.index, self.vectorstore_faiss_path)
149
- with open(self.vectorstore_data_path, "wb") as f:
150
- pickle.dump({
151
- 'sentence_chunks': self.sentence_chunks,
152
- 'parent_documents': self.parent_documents,
153
- 'sentence_to_parent_map': self.sentence_to_parent_map
154
- }, f)
155
- print(f"Vectorstore saved to {self.vectorstore_faiss_path} and {self.vectorstore_data_path}")
 
 
 
156
 
157
  def load_vectorstore(self):
158
  if os.path.exists(self.vectorstore_faiss_path) and os.path.exists(self.vectorstore_data_path):
159
- self.index = faiss.read_index(self.vectorstore_faiss_path)
160
- with open(self.vectorstore_data_path, "rb") as f:
161
- data = pickle.load(f)
162
- self.sentence_chunks = data['sentence_chunks']
163
- self.parent_documents = data['parent_documents']
164
- self.sentence_to_parent_map = data['sentence_to_parent_map']
165
- print("📦 Loaded vectorstore.")
166
- return True
 
 
 
 
 
 
 
 
 
 
167
  print("ℹ️ No saved vectorstore found.")
168
  return False
169
 
@@ -172,86 +192,69 @@ class GeminiRAG:
172
  # Get API key from environment variable
173
  api_key = os.getenv("google_api_key")
174
  if not api_key:
175
- raise ValueError("GEMINI_API_KEY environment variable not set. Please set it in Hugging Face Space secrets.")
 
176
 
177
  # Initialize the RAG system globally for the Gradio app
178
- rag_instance = GeminiRAG(api_key=api_key)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
  def respond(
181
  message: str,
182
  history: list[list[str]], # Gradio Chatbot history format
183
- system_message: str, # From additional_inputs
184
  max_tokens: int, # From additional_inputs (not directly used by RAG but kept for interface consistency)
185
  temperature: float, # From additional_inputs (not directly used by RAG)
186
  top_p: float, # From additional_inputs (not directly used by RAG)
187
  ):
188
- # The `system_message` from Gradio can be used to dynamically update the RAG's instruction prompt
189
- # For this example, we'll keep the ML_prompt fixed, but you could add logic here:
190
- # rag_instance.instruction_prompt = system_message
 
 
 
191
 
192
  try:
193
- # Call your RAG system's ask_question method
194
- # The top_k parameter can be exposed in Gradio's additional_inputs if needed
195
  response = rag_instance.ask_question(message)
196
- # Gradio ChatInterface expects a generator for streaming or a direct string for non-streaming
197
- yield response # Yield the full response, as ask_question does not stream token by token
198
  except Exception as e:
199
  yield f"❌ An error occurred: {e}"
200
 
201
- def upload_and_process_documents(files):
202
- if not files:
203
- return "Please upload PDF files to process."
204
-
205
- # Re-initialize RAG instance to clear previous data and rebuild with new documents
206
- # This is a simple approach; for more complex scenarios, you might want to append
207
- # or manage different knowledge bases.
208
- print("Rebuilding knowledge base with new documents...")
209
- try:
210
- # Re-initialize to clear previous data
211
- global rag_instance
212
- rag_instance = GeminiRAG(api_key=api_key)
213
- except Exception as e:
214
- return f"Error re-initializing RAG: {e}"
215
-
216
- success_count = 0
217
- error_files = []
218
- for file_obj in files:
219
- file_path = file_obj.name # Gradio passes a NamedTemporaryFile object
220
- print(f"Processing {file_path}")
221
- try:
222
- chunks = rag_instance.load_document(file_path)
223
- rag_instance.add_document(chunks)
224
- success_count += 1
225
- except Exception as e:
226
- error_files.append(f"{os.path.basename(file_path)}: {e}")
227
-
228
- rag_instance.save_vectorstore()
229
-
230
- status_message = f"Successfully loaded and embedded {success_count} document(s)."
231
- if error_files:
232
- status_message += f"\nErrors occurred with: {'; '.join(error_files)}"
233
- return status_message
234
-
235
-
236
  # Define the Gradio ChatInterface
237
  with gr.Blocks() as demo:
238
  gr.Markdown("# Gemini RAG Chatbot for ML Theory")
239
- gr.Markdown("Upload your PDF documents, and then ask questions about the content. Ensure your `GEMINI_API_KEY` is set as a Space Secret.")
240
-
241
- with gr.Row():
242
- file_output = gr.Textbox(label="Upload Status", interactive=False)
243
- upload_button = gr.UploadButton(
244
- label="Upload PDF Documents",
245
- file_types=["pdf"],
246
- file_count="multiple"
247
- )
248
- upload_button.upload(upload_and_process_documents, inputs=upload_button, outputs=file_output)
249
-
250
- # The ChatInterface component simplifies the chat UI setup
251
  chat_interface_component = gr.ChatInterface(
252
  respond,
253
  additional_inputs=[
254
- gr.Textbox(value=ML_prompt, label="System message", info="This sets the fixed role for the AI."),
255
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens", info="Not directly used by RAG model."),
256
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature", info="Not directly used by RAG model."),
257
  gr.Slider(
@@ -265,16 +268,14 @@ with gr.Blocks() as demo:
265
  ],
266
  chatbot=gr.Chatbot(height=400),
267
  textbox=gr.Textbox(placeholder="Ask me about Machine Learning Theory!", container=False, scale=7),
268
- # clear_btn="Clear Chat",
269
  submit_btn="Send",
270
- # Set examples for quick testing
271
  examples=[
272
- ["درمورد boosting بهم بگو", ML_prompt, 512, 0.7, 0.95],
273
- ["انواع رگرسیون را توضیح بده", ML_prompt, 512, 0.7, 0.95],
274
- ["شبکه های عصبی چیستند؟", ML_prompt, 512, 0.7, 0.95]
275
  ]
276
  )
277
-
278
 
279
 
280
  if __name__ == "__main__":
 
1
  import os
2
  import time
3
+ import fitz # PyMuPDF
4
  import faiss
5
  import pickle
6
  import numpy as np
 
16
  import gradio as gr
17
 
18
  # Define the ML_prompt (as it was in your notebook)
19
+ # This prompt will now be hardcoded and not exposed to the user
20
  ML_prompt = """
21
  نقش ات:
22
  تو دستیار هوش مصنوعی من برای امتحان یادگیری ماشین هستی
23
  این امتحان تمرکز روی مفاهیم تیوری یادگیری ماشین داره
24
  منبع درس کتاب بیشاپ هست
 
25
  لحن صحبت کردن ات:
26
  تو استاد دانشگاه هستی و کسایی که باهات چت می کنن دانشجوهات اند
27
  """
 
28
 
29
  class GeminiRAG:
30
  def __init__(self, api_key: str, model_name: str = "models/gemini-2.0-flash",
31
  embed_model_name: str = "all-MiniLM-L6-v2", # Using a common SentenceTransformer model
32
+ instruction_prompt: str = ML_prompt, # Prompt is passed here
33
  vectorstore_dir: str = "vectorstore"): # Use a directory within the app for persistence
34
 
35
  if not api_key:
 
61
  self.load_vectorstore()
62
 
63
  def _split_into_sentences(self, text: str) -> List[str]:
64
+ # Improved sentence splitting for better chunking
65
  sentences = re.split(r'(?<=[.!?])\s+', text)
66
  return [s.strip() for s in sentences if s.strip()]
67
 
68
  def load_document(self, pdf_path: str) -> List[str]:
69
+ print(f"Loading document from: {pdf_path}")
70
+ try:
71
+ doc = fitz.open(pdf_path)
72
+ page_contents = []
73
+ for page_num in range(len(doc)):
74
+ page = doc.load_page(page_num)
75
+ text = page.get_text()
76
+ if text.strip():
77
+ page_contents.append(text.strip())
78
+ doc.close()
79
+ print(f"Successfully extracted {len(page_contents)} pages from {pdf_path}")
80
+ return page_contents
81
+ except Exception as e:
82
+ print(f"Error loading PDF {pdf_path}: {e}")
83
+ raise # Re-raise the exception to be caught higher up
84
 
85
  def add_document(self, parent_chunks: List[str]):
86
  new_sentence_chunks = []
 
113
 
114
  retrieved_parent_doc_indices = set()
115
  for idx in I[0]:
116
+ if idx < len(self.sentence_chunks): # Ensure index is within bounds
117
  parent_idx = self.sentence_to_parent_map[idx]
118
  retrieved_parent_doc_indices.add(parent_idx)
119
 
 
121
  sorted_parent_indices = sorted(list(retrieved_parent_doc_indices))
122
 
123
  for parent_idx in sorted_parent_indices:
124
+ if parent_idx < len(self.parent_documents): # Ensure index is within bounds
125
  context_parts.append(self.parent_documents[parent_idx])
126
 
127
  context = "\n\n---\\n\\n".join(context_parts)
 
129
  if not context.strip():
130
  return "No relevant information found in the knowledge base."
131
 
132
+ # The instruction prompt is now self.instruction_prompt which is set at init
133
  prompt = f"""
134
+ ### instruction prompt : (explanation : this text is your guideline don't mention it on response)
135
+ {self.instruction_prompt}
136
+ Use the following context to answer the question.\n
137
+ Context:\n
138
+ {context}\n
139
+ Question: {query}\n
140
+ Answer:"""
 
 
 
141
 
142
  for attempt in range(3):
143
  try:
 
146
  except InternalServerError as e:
147
  print(f"Error: {e}. Retrying in 5 seconds...")
148
  time.sleep(5)
149
+ except Exception as e: # Catch other potential errors from API call
150
+ print(f"An unexpected error occurred during API call: {e}. Retrying in 5 seconds...")
151
+ time.sleep(5)
152
+ raise Exception("Failed to generate after 3 retries due to persistent errors.")
153
 
154
  def save_vectorstore(self):
155
+ try:
156
+ faiss.write_index(self.index, self.vectorstore_faiss_path)
157
+ with open(self.vectorstore_data_path, "wb") as f:
158
+ pickle.dump({
159
+ 'sentence_chunks': self.sentence_chunks,
160
+ 'parent_documents': self.parent_documents,
161
+ 'sentence_to_parent_map': self.sentence_to_parent_map
162
+ }, f)
163
+ print(f"Vectorstore saved to {self.vectorstore_faiss_path} and {self.vectorstore_data_path}")
164
+ except Exception as e:
165
+ print(f"Error saving vectorstore: {e}")
166
 
167
  def load_vectorstore(self):
168
  if os.path.exists(self.vectorstore_faiss_path) and os.path.exists(self.vectorstore_data_path):
169
+ try:
170
+ self.index = faiss.read_index(self.vectorstore_faiss_path)
171
+ with open(self.vectorstore_data_path, "rb") as f:
172
+ data = pickle.load(f)
173
+ self.sentence_chunks = data['sentence_chunks']
174
+ self.parent_documents = data['parent_documents']
175
+ self.sentence_to_parent_map = data['sentence_to_parent_map']
176
+ print("📦 Loaded vectorstore.")
177
+ return True
178
+ except Exception as e:
179
+ print(f"Error loading vectorstore: {e}")
180
+ # If loading fails, it's better to start fresh
181
+ self.index = faiss.IndexFlatL2(self.embedder.get_sentence_embedding_dimension())
182
+ self.sentence_chunks = []
183
+ self.parent_documents = []
184
+ self.sentence_to_parent_map = []
185
+ print("⚠️ Failed to load vectorstore, initializing a new one.")
186
+ return False
187
  print("ℹ️ No saved vectorstore found.")
188
  return False
189
 
 
192
# Get API key from environment variable
api_key = os.getenv("google_api_key")
if not api_key:
    print("Warning: GEMINI_API_KEY environment variable not set. Please set it in Hugging Face Space secrets.")


# Initialize the RAG system globally for the Gradio app.
# The ML_prompt is baked in at construction and is not user-configurable.
rag_instance = GeminiRAG(api_key=api_key, instruction_prompt=ML_prompt)

# --- Load the predefined PDF at startup ---
# Assumes MLT.pdf sits next to this script; adjust to a full path otherwise.
PDF_PATH = "MLT.pdf"
VECTORSTORE_BUILT_FLAG = os.path.join(rag_instance.vectorstore_dir, "vectorstore_built_flag.txt")

# Only ingest the PDF when no persisted vectorstore could be restored.
if not rag_instance.load_vectorstore():
    print(f"Attempting to load and process {PDF_PATH}...")
    if not os.path.exists(PDF_PATH):
        print(f"Error: {PDF_PATH} not found. Please ensure the PDF file is in the correct directory.")
    else:
        try:
            chunks = rag_instance.load_document(PDF_PATH)
            if not chunks:
                print(f"Warning: No text extracted from {PDF_PATH}. Please check the PDF content.")
            else:
                rag_instance.add_document(chunks)
                rag_instance.save_vectorstore()
                # Marker file recording that the store was built once.
                with open(VECTORSTORE_BUILT_FLAG, "w") as f:
                    f.write("Vectorstore built successfully.")
                print("Initial PDF processed and vectorstore saved.")
        except Exception as e:
            print(f"Fatal Error: Could not process {PDF_PATH} at startup: {e}")
224
+
225
 
226
def respond(
    message: str,
    history: list[list[str]],  # Gradio Chatbot history format (not consumed here)
    max_tokens: int,       # kept for interface consistency; not used by the RAG call
    temperature: float,    # kept for interface consistency; not used by the RAG call
    top_p: float,          # kept for interface consistency; not used by the RAG call
):
    """Gradio chat handler: answer *message* from the RAG knowledge base.

    Yields a single full response (ask_question does not stream tokens).
    The instruction prompt is fixed inside rag_instance, so no system
    message input is accepted.
    """
    # Guard: startup ingestion may have produced no chunks.
    if not rag_instance.sentence_chunks:
        yield "Knowledge base is empty. Please ensure the PDF was loaded correctly at startup."
        return

    try:
        yield rag_instance.ask_question(message)
    except Exception as e:
        yield f"❌ An error occurred: {e}"
246
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  # Define the Gradio ChatInterface
248
  with gr.Blocks() as demo:
249
  gr.Markdown("# Gemini RAG Chatbot for ML Theory")
250
+ gr.Markdown(f"This chatbot is powered by {PDF_PATH}. Ensure your `GEMINI_API_KEY` is set as a Space Secret.")
251
+
252
+ # No file upload section anymore
253
+
 
 
 
 
 
 
 
 
254
  chat_interface_component = gr.ChatInterface(
255
  respond,
256
  additional_inputs=[
257
+ # Removed the Textbox for system_message
258
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens", info="Not directly used by RAG model."),
259
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature", info="Not directly used by RAG model."),
260
  gr.Slider(
 
268
  ],
269
  chatbot=gr.Chatbot(height=400),
270
  textbox=gr.Textbox(placeholder="Ask me about Machine Learning Theory!", container=False, scale=7),
 
271
  submit_btn="Send",
272
+ # Update examples as the system_message input is no longer present
273
  examples=[
274
+ ["درمورد boosting بهم بگو", 512, 0.7, 0.95],
275
+ ["انواع رگرسیون را توضیح بده", 512, 0.7, 0.95],
276
+ ["شبکه های عصبی چیستند؟", 512, 0.7, 0.95]
277
  ]
278
  )
 
279
 
280
 
281
  if __name__ == "__main__":