Mateo4 committed on
Commit
d02a876
·
verified ·
1 Parent(s): 2d1eba3

Update app.py

Browse files

GeminiRAG added to app.py

Files changed (1) hide show
  1. app.py +258 -47
app.py CHANGED
@@ -1,64 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
 
 
 
 
 
 
 
 
 
 
4
  """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
  ):
18
- messages = [{"role": "system", "content": system_message}]
 
 
19
 
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
 
 
 
25
 
26
- messages.append({"role": "user", "content": message})
 
 
27
 
28
- response = ""
 
 
 
 
 
 
 
 
 
29
 
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
 
 
 
38
 
39
- response += token
40
- yield response
41
 
 
 
 
 
42
 
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- demo = gr.ChatInterface(
47
- respond,
48
- additional_inputs=[
49
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
50
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
51
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
52
- gr.Slider(
53
- minimum=0.1,
54
- maximum=1.0,
55
- value=0.95,
56
- step=0.05,
57
- label="Top-p (nucleus sampling)",
58
- ),
59
- ],
60
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
 
63
  if __name__ == "__main__":
64
- demo.launch()
 
1
+ import os
2
+ import time
3
+ import fitz
4
+ import faiss
5
+ import pickle
6
+ import numpy as np
7
+ from typing import List, Dict
8
+ import re
9
+
10
+ import google.generativeai as genai
11
+ from google.api_core.exceptions import InternalServerError
12
+
13
+ from sentence_transformers import SentenceTransformer
14
+
15
+ # Import gradio for the web interface
16
  import gradio as gr
 
17
 
18
+ # Define the ML_prompt (as it was in your notebook)
19
+ ML_prompt = """
20
+ نقش ات:
21
+ تو دستیار هوش مصنوعی من برای امتحان یادگیری ماشین هستی
22
+ این امتحان تمرکز روی مفاهیم تیوری یادگیری ماشین داره
23
+ منبع درس کتاب بیشاپ هست
24
+
25
+ لحن صحبت کردن ات:
26
+ تو استاد دانشگاه هستی و کسایی که باهات چت می کنن دانشجوهات اند
27
  """
 
 
 
28
 
29
+ class GeminiRAG:
30
+ def __init__(self, api_key: str, model_name: str = "models/gemini-2.0-flash",
31
+ embed_model_name: str = "all-MiniLM-L6-v2", # Using a common SentenceTransformer model
32
+ instruction_prompt: str = ML_prompt,
33
+ vectorstore_dir: str = "vectorstore"): # Use a directory within the app for persistence
34
+
35
+ if not api_key:
36
+ raise ValueError("API key is missing.")
37
+
38
+ self.instruction_prompt = instruction_prompt
39
+ self.vectorstore_dir = vectorstore_dir
40
+ self.vectorstore_faiss_path = os.path.join(self.vectorstore_dir, "faiss_index.index")
41
+ self.vectorstore_data_path = os.path.join(self.vectorstore_dir, "faiss_data.pkl")
42
+
43
+ # Ensure vectorstore directory exists
44
+ os.makedirs(self.vectorstore_dir, exist_ok=True)
45
+
46
+ # Setup Gemini
47
+ genai.configure(api_key=api_key)
48
+ self.model = genai.GenerativeModel(model_name=model_name)
49
+
50
+ # Setup Embedder
51
+ self.embedder = SentenceTransformer(embed_model_name)
52
+
53
+ # FAISS index and storage for sentence chunks and their parent documents
54
+ embedding_dim = self.embedder.get_sentence_embedding_dimension() # Get embedding dimension
55
+ self.index = faiss.IndexFlatL2(embedding_dim)
56
+ self.sentence_chunks: List[str] = []
57
+ self.parent_documents: List[str] = []
58
+ self.sentence_to_parent_map: List[int] = []
59
+
60
+ # Load existing vector store if available
61
+ self.load_vectorstore()
62
+
63
+ def _split_into_sentences(self, text: str) -> List[str]:
64
+ sentences = re.split(r'(?<=[.!?])\s+', text)
65
+ return [s.strip() for s in sentences if s.strip()]
66
+
67
+ def load_document(self, pdf_path: str) -> List[str]:
68
+ doc = fitz.open(pdf_path)
69
+ page_contents = []
70
+ for page_num in range(len(doc)):
71
+ page = doc.load_page(page_num)
72
+ text = page.get_text()
73
+ if text.strip():
74
+ page_contents.append(text.strip())
75
+ doc.close()
76
+ return page_contents
77
+
78
+ def add_document(self, parent_chunks: List[str]):
79
+ new_sentence_chunks = []
80
+ new_sentence_to_parent_map = []
81
+ current_parent_doc_index = len(self.parent_documents)
82
+
83
+ for parent_chunk in parent_chunks:
84
+ self.parent_documents.append(parent_chunk)
85
+ sentences = self._split_into_sentences(parent_chunk)
86
+ for sentence in sentences:
87
+ new_sentence_chunks.append(sentence)
88
+ new_sentence_to_parent_map.append(current_parent_doc_index)
89
+ current_parent_doc_index += 1
90
+
91
+ if new_sentence_chunks:
92
+ embeddings = self.embedder.encode(new_sentence_chunks, batch_size=32, convert_to_numpy=True)
93
+ self.index.add(np.array(embeddings))
94
+ self.sentence_chunks.extend(new_sentence_chunks)
95
+ self.sentence_to_parent_map.extend(new_sentence_to_parent_map)
96
+ print(f"Added {len(new_sentence_chunks)} sentence chunks from {len(parent_chunks)} parent documents.")
97
+ else:
98
+ print("No new sentence chunks to add.")
99
+
100
+ def ask_question(self, query: str, top_k: int = 5) -> str:
101
+ if not self.sentence_chunks or not self.parent_documents:
102
+ return "Knowledge base is empty. Please load documents first."
103
+
104
+ query_emb = self.embedder.encode([query], convert_to_numpy=True)
105
+ D, I = self.index.search(np.array(query_emb), top_k)
106
+
107
+ retrieved_parent_doc_indices = set()
108
+ for idx in I[0]:
109
+ if 0 <= idx < len(self.sentence_chunks):  # FAISS pads with -1 when fewer than top_k hits
110
+ parent_idx = self.sentence_to_parent_map[idx]
111
+ retrieved_parent_doc_indices.add(parent_idx)
112
+
113
+ context_parts = []
114
+ sorted_parent_indices = sorted(list(retrieved_parent_doc_indices))
115
+
116
+ for parent_idx in sorted_parent_indices:
117
+ if parent_idx < len(self.parent_documents):
118
+ context_parts.append(self.parent_documents[parent_idx])
119
+
120
+ context = "\n\n---\n\n".join(context_parts)
121
+
122
+ if not context.strip():
123
+ return "No relevant information found in the knowledge base."
124
+
125
+ prompt = f"""
126
+ ### instruction prompt : (explanation : this text is your guideline don't mention it on response)
127
+ {self.instruction_prompt}
128
+
129
+ Use the following context to answer the question.\n
130
+ Context:\n
131
+ {context}\n
132
+
133
+ Question: {query}\n
134
+
135
+ Answer:"""
136
+
137
+ for attempt in range(3):
138
+ try:
139
+ response = self.model.generate_content(prompt)
140
+ return response.text
141
+ except InternalServerError as e:
142
+ print(f"Error: {e}. Retrying in 5 seconds...")
143
+ time.sleep(5)
144
+ raise Exception("Failed to generate after 3 retries.")
145
+
146
+ def save_vectorstore(self):
147
+ faiss.write_index(self.index, self.vectorstore_faiss_path)
148
+ with open(self.vectorstore_data_path, "wb") as f:
149
+ pickle.dump({
150
+ 'sentence_chunks': self.sentence_chunks,
151
+ 'parent_documents': self.parent_documents,
152
+ 'sentence_to_parent_map': self.sentence_to_parent_map
153
+ }, f)
154
+ print(f"Vectorstore saved to {self.vectorstore_faiss_path} and {self.vectorstore_data_path}")
155
+
156
+ def load_vectorstore(self):
157
+ if os.path.exists(self.vectorstore_faiss_path) and os.path.exists(self.vectorstore_data_path):
158
+ self.index = faiss.read_index(self.vectorstore_faiss_path)
159
+ with open(self.vectorstore_data_path, "rb") as f:
160
+ data = pickle.load(f)
161
+ self.sentence_chunks = data['sentence_chunks']
162
+ self.parent_documents = data['parent_documents']
163
+ self.sentence_to_parent_map = data['sentence_to_parent_map']
164
+ print("📦 Loaded vectorstore.")
165
+ return True
166
+ print("ℹ️ No saved vectorstore found.")
167
+ return False
168
+
169
+ # --- Gradio Interface Setup ---
170
+
171
+ # Get API key from environment variable
172
+ api_key = os.getenv("GEMINI_API_KEY")
173
+ if not api_key:
174
+ raise ValueError("GEMINI_API_KEY environment variable not set. Please set it in Hugging Face Space secrets.")
175
+
176
+ # Initialize the RAG system globally for the Gradio app
177
+ rag_instance = GeminiRAG(api_key=api_key)
178
 
179
  def respond(
180
+ message: str,
181
+ history: list[list[str]], # Gradio Chatbot history format
182
+ system_message: str, # From additional_inputs
183
+ max_tokens: int, # From additional_inputs (not directly used by RAG but kept for interface consistency)
184
+ temperature: float, # From additional_inputs (not directly used by RAG)
185
+ top_p: float, # From additional_inputs (not directly used by RAG)
186
  ):
187
+ # The `system_message` from Gradio can be used to dynamically update the RAG's instruction prompt
188
+ # For this example, we'll keep the ML_prompt fixed, but you could add logic here:
189
+ # rag_instance.instruction_prompt = system_message
190
 
191
+ try:
192
+ # Call your RAG system's ask_question method
193
+ # The top_k parameter can be exposed in Gradio's additional_inputs if needed
194
+ response = rag_instance.ask_question(message)
195
+ # Gradio ChatInterface expects a generator for streaming or a direct string for non-streaming
196
+ yield response # Yield the full response, as ask_question does not stream token by token
197
+ except Exception as e:
198
+ yield f"❌ An error occurred: {e}"
199
 
200
+ def upload_and_process_documents(files):
201
+ if not files:
202
+ return "Please upload PDF files to process."
203
 
204
+ # Re-initialize RAG instance to clear previous data and rebuild with new documents
205
+ # This is a simple approach; for more complex scenarios, you might want to append
206
+ # or manage different knowledge bases.
207
+ print("Rebuilding knowledge base with new documents...")
208
+ try:
209
+ # NOTE(review): __init__ calls load_vectorstore(), which reloads any saved store from disk — previous data is NOT actually cleared; delete the vectorstore files first for a true rebuild
210
+ global rag_instance
211
+ rag_instance = GeminiRAG(api_key=api_key)
212
+ except Exception as e:
213
+ return f"Error re-initializing RAG: {e}"
214
 
215
+ success_count = 0
216
+ error_files = []
217
+ for file_obj in files:
218
+ file_path = file_obj.name # Gradio passes a NamedTemporaryFile object
219
+ print(f"Processing {file_path}")
220
+ try:
221
+ chunks = rag_instance.load_document(file_path)
222
+ rag_instance.add_document(chunks)
223
+ success_count += 1
224
+ except Exception as e:
225
+ error_files.append(f"{os.path.basename(file_path)}: {e}")
226
 
227
+ rag_instance.save_vectorstore()
 
228
 
229
+ status_message = f"Successfully loaded and embedded {success_count} document(s)."
230
+ if error_files:
231
+ status_message += f"\nErrors occurred with: {'; '.join(error_files)}"
232
+ return status_message
233
 
234
+
235
+ # Define the Gradio ChatInterface
236
+ with gr.Blocks() as demo:
237
+ gr.Markdown("# Gemini RAG Chatbot for ML Theory")
238
+ gr.Markdown("Upload your PDF documents, and then ask questions about the content. Ensure your `GEMINI_API_KEY` is set as a Space Secret.")
239
+
240
+ with gr.Row():
241
+ file_output = gr.Textbox(label="Upload Status", interactive=False)
242
+ upload_button = gr.UploadButton(
243
+ label="Upload PDF Documents",
244
+ file_types=[".pdf"],
245
+ file_count="multiple"
246
+ )
247
+ upload_button.upload(upload_and_process_documents, inputs=upload_button, outputs=file_output)
248
+
249
+ # The ChatInterface component simplifies the chat UI setup
250
+ chat_interface_component = gr.ChatInterface(
251
+ respond,
252
+ additional_inputs=[
253
+ gr.Textbox(value=ML_prompt, label="System message", info="This sets the fixed role for the AI."), # Keep ML_prompt fixed
254
+ gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens", info="Not directly used by RAG model."),
255
+ gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature", info="Not directly used by RAG model."),
256
+ gr.Slider(
257
+ minimum=0.1,
258
+ maximum=1.0,
259
+ value=0.95,
260
+ step=0.05,
261
+ label="Top-p (nucleus sampling)",
262
+ info="Not directly used by RAG model."
263
+ ),
264
+ ],
265
+ chatbot=gr.Chatbot(height=400),
266
+ textbox=gr.Textbox(placeholder="Ask me about Machine Learning Theory!", container=False, scale=7),
267
+ clear_btn="Clear Chat",
268
+ submit_btn="Send",
269
+ # Set examples for quick testing
270
+ examples=["درمورد boosting بهم بگو", "انواع رگرسیون را توضیح بده", "شبکه های عصبی چیستند؟"]
271
+ )
272
 
273
 
274
  if __name__ == "__main__":
275
+ demo.launch()