Spaces:
Sleeping
Sleeping
| import os | |
| import time | |
| import fitz # PyMuPDF | |
| import faiss | |
| import pickle | |
| import numpy as np | |
| from typing import List, Dict | |
| import re | |
| import google.generativeai as genai | |
| from google.api_core.exceptions import InternalServerError | |
| from sentence_transformers import SentenceTransformer | |
| # Import gradio for the web interface | |
| import gradio as gr | |
# Define the ML_prompt (as it was in your notebook)
# This prompt will now be hardcoded and not exposed to the user
# NOTE: the prompt below is intentionally in Persian (Farsi) — it is sent to
# the model verbatim, so it must not be translated. Rough English meaning:
#   "Your role: you are my AI assistant for the machine-learning exam.
#    The exam focuses on ML theory concepts; the course reference is
#    Bishop's book. Your tone: you are a university professor and the
#    people chatting with you are your students."
ML_prompt = """
نقش ات:
تو دستیار هوش مصنوعی من برای امتحان یادگیری ماشین هستی
این امتحان تمرکز روی مفاهیم تیوری یادگیری ماشین داره
منبع درس کتاب بیشاپ هست
لحن صحبت کردن ات:
تو استاد دانشگاه هستی و کسایی که باهات چت می کنن دانشجوهات اند
"""
class GeminiRAG:
    """Retrieval-augmented generation over a PDF corpus using Gemini.

    Pages are split into sentences, embedded with a SentenceTransformer and
    indexed in a FAISS L2 index. At query time the top-k matching sentences
    are mapped back to their parent pages ("small-to-big" retrieval) and the
    pages are supplied as context to the Gemini model.
    """

    def __init__(self, api_key: str, model_name: str = "models/gemini-2.0-flash",
                 embed_model_name: str = "all-MiniLM-L6-v2",  # Using a common SentenceTransformer model
                 instruction_prompt: str = ML_prompt,  # Prompt is passed here
                 vectorstore_dir: str = "vectorstore"):  # Use a directory within the app for persistence
        """Configure Gemini, the embedder, and an (optionally persisted) FAISS index.

        Args:
            api_key: Google Generative AI API key.
            model_name: Gemini model identifier.
            embed_model_name: SentenceTransformer model used for embeddings.
            instruction_prompt: System-style instructions prepended to every query.
            vectorstore_dir: Directory where the index and chunk data are persisted.

        Raises:
            ValueError: If ``api_key`` is empty/None.
        """
        if not api_key:
            raise ValueError("API key is missing.")
        self.instruction_prompt = instruction_prompt
        self.vectorstore_dir = vectorstore_dir
        self.vectorstore_faiss_path = os.path.join(self.vectorstore_dir, "faiss_index.index")
        self.vectorstore_data_path = os.path.join(self.vectorstore_dir, "faiss_data.pkl")
        # Ensure vectorstore directory exists
        os.makedirs(self.vectorstore_dir, exist_ok=True)
        # Setup Gemini
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel(model_name=model_name)
        # Setup Embedder
        self.embedder = SentenceTransformer(embed_model_name)
        # FAISS index and storage for sentence chunks and their parent documents
        embedding_dim = self.embedder.get_sentence_embedding_dimension()
        self.index = faiss.IndexFlatL2(embedding_dim)
        self.sentence_chunks: List[str] = []          # one entry per indexed sentence
        self.parent_documents: List[str] = []         # one entry per source page
        self.sentence_to_parent_map: List[int] = []   # sentence index -> parent index
        # Load existing vector store if available
        self.load_vectorstore()

    def _split_into_sentences(self, text: str) -> List[str]:
        """Split ``text`` on sentence-ending punctuation; drop empty pieces."""
        sentences = re.split(r'(?<=[.!?])\s+', text)
        return [s.strip() for s in sentences if s.strip()]

    def load_document(self, pdf_path: str) -> List[str]:
        """Extract per-page text from a PDF; returns one string per non-empty page.

        Raises:
            Exception: Re-raises any PyMuPDF error so callers can react.
        """
        print(f"Loading document from: {pdf_path}")
        try:
            doc = fitz.open(pdf_path)
            try:
                page_contents = []
                for page_num in range(len(doc)):
                    text = doc.load_page(page_num).get_text()
                    if text.strip():
                        page_contents.append(text.strip())
            finally:
                # Fix: close the document even if text extraction raises mid-loop
                # (the original only closed on the success path).
                doc.close()
            print(f"Successfully extracted {len(page_contents)} pages from {pdf_path}")
            return page_contents
        except Exception as e:
            print(f"Error loading PDF {pdf_path}: {e}")
            raise  # Re-raise the exception to be caught higher up

    def add_document(self, parent_chunks: List[str]):
        """Index ``parent_chunks`` (e.g. PDF pages): split each into sentences,
        embed the sentences, and record the sentence->parent mapping."""
        new_sentence_chunks = []
        new_sentence_to_parent_map = []
        current_parent_doc_index = len(self.parent_documents)
        for parent_chunk in parent_chunks:
            self.parent_documents.append(parent_chunk)
            for sentence in self._split_into_sentences(parent_chunk):
                new_sentence_chunks.append(sentence)
                new_sentence_to_parent_map.append(current_parent_doc_index)
            current_parent_doc_index += 1
        if new_sentence_chunks:
            embeddings = self.embedder.encode(new_sentence_chunks, batch_size=32, convert_to_numpy=True)
            self.index.add(np.array(embeddings))
            self.sentence_chunks.extend(new_sentence_chunks)
            self.sentence_to_parent_map.extend(new_sentence_to_parent_map)
            print(f"Added {len(new_sentence_chunks)} sentence chunks from {len(parent_chunks)} parent documents.")
        else:
            print("No new sentence chunks to add.")

    def ask_question(self, query: str, top_k: int = 5) -> str:
        """Retrieve the parent pages of the ``top_k`` closest sentences and ask Gemini.

        Raises:
            RuntimeError: If the Gemini API keeps failing after 3 attempts.
        """
        if not self.sentence_chunks or not self.parent_documents:
            return "Knowledge base is empty. Please load documents first."
        query_emb = self.embedder.encode([query], convert_to_numpy=True)
        D, I = self.index.search(np.array(query_emb), top_k)
        retrieved_parent_doc_indices = set()
        for idx in I[0]:
            # Fix: FAISS pads results with -1 when fewer than top_k vectors
            # exist; the old "idx < len(...)" check let -1 through and indexed
            # from the END of the map. Guard the lower bound too.
            if 0 <= idx < len(self.sentence_chunks):
                retrieved_parent_doc_indices.add(self.sentence_to_parent_map[idx])
        context_parts = []
        for parent_idx in sorted(retrieved_parent_doc_indices):
            if parent_idx < len(self.parent_documents):  # Ensure index is within bounds
                context_parts.append(self.parent_documents[parent_idx])
        # Fix: separator previously was "\n\n---\\n\\n", which injected literal
        # backslash-n text into the prompt; use real newlines consistently.
        context = "\n\n---\n\n".join(context_parts)
        if not context.strip():
            return "No relevant information found in the knowledge base."
        # The instruction prompt is self.instruction_prompt, set at init
        prompt = f"""
### instruction prompt : (explanation : this text is your guideline don't mention it on response)
{self.instruction_prompt}
Use the following context to answer the question.\n
Context:\n
{context}\n
Question: {query}\n
Answer:"""
        last_error = None
        for attempt in range(3):
            try:
                response = self.model.generate_content(prompt)
                return response.text
            except InternalServerError as e:
                last_error = e
                print(f"Error: {e}. Retrying in 5 seconds...")
                time.sleep(5)
            except Exception as e:  # Catch other potential errors from API call
                last_error = e
                print(f"An unexpected error occurred during API call: {e}. Retrying in 5 seconds...")
                time.sleep(5)
        # RuntimeError subclasses Exception, so existing callers still catch it;
        # chaining preserves the root cause for debugging.
        raise RuntimeError("Failed to generate after 3 retries due to persistent errors.") from last_error

    def save_vectorstore(self):
        """Persist the FAISS index and the chunk/parent bookkeeping to disk."""
        try:
            faiss.write_index(self.index, self.vectorstore_faiss_path)
            with open(self.vectorstore_data_path, "wb") as f:
                pickle.dump({
                    'sentence_chunks': self.sentence_chunks,
                    'parent_documents': self.parent_documents,
                    'sentence_to_parent_map': self.sentence_to_parent_map
                }, f)
            print(f"Vectorstore saved to {self.vectorstore_faiss_path} and {self.vectorstore_data_path}")
        except Exception as e:
            # Best-effort persistence: log and continue rather than crash the app.
            print(f"Error saving vectorstore: {e}")

    def load_vectorstore(self) -> bool:
        """Load a previously saved vectorstore; returns True on success.

        On a corrupt/unreadable store the in-memory state is reset to empty
        so the caller can rebuild from scratch.
        """
        if os.path.exists(self.vectorstore_faiss_path) and os.path.exists(self.vectorstore_data_path):
            try:
                self.index = faiss.read_index(self.vectorstore_faiss_path)
                with open(self.vectorstore_data_path, "rb") as f:
                    data = pickle.load(f)
                self.sentence_chunks = data['sentence_chunks']
                self.parent_documents = data['parent_documents']
                self.sentence_to_parent_map = data['sentence_to_parent_map']
                print("📦 Loaded vectorstore.")
                return True
            except Exception as e:
                print(f"Error loading vectorstore: {e}")
                # If loading fails, it's better to start fresh
                self.index = faiss.IndexFlatL2(self.embedder.get_sentence_embedding_dimension())
                self.sentence_chunks = []
                self.parent_documents = []
                self.sentence_to_parent_map = []
                print("⚠️ Failed to load vectorstore, initializing a new one.")
                return False
        print("ℹ️ No saved vectorstore found.")
        return False
# --- Gradio Interface Setup ---
# Get API key from environment variable (note: the name is case-sensitive on Linux).
api_key = os.getenv("google_api_key")
if not api_key:
    # Fix: the old warning named GEMINI_API_KEY, but the code reads
    # "google_api_key" — point operators at the variable actually used.
    print("Warning: google_api_key environment variable not set. Please set it in Hugging Face Space secrets.")
# Initialize the RAG system globally for the Gradio app.
# The ML_prompt is passed during initialization and is then part of the rag_instance state.
rag_instance = GeminiRAG(api_key=api_key, instruction_prompt=ML_prompt)  # Pass the prompt here
# --- Load the predefined PDF at startup ---
PDF_PATH = "MLT.pdf"  # Assumes MLT.pdf is in the same directory as this script, or specify full path
VECTORSTORE_BUILT_FLAG = os.path.join(rag_instance.vectorstore_dir, "vectorstore_built_flag.txt")
# GeminiRAG.__init__ already attempted a load; this second call simply re-checks
# and tells us whether a persisted store exists before rebuilding from the PDF.
if not rag_instance.load_vectorstore():
    print(f"Attempting to load and process {PDF_PATH}...")
    if os.path.exists(PDF_PATH):
        try:
            chunks = rag_instance.load_document(PDF_PATH)
            if chunks:
                rag_instance.add_document(chunks)
                rag_instance.save_vectorstore()
                # Marker file signalling that the initial build completed.
                with open(VECTORSTORE_BUILT_FLAG, "w") as f:
                    f.write("Vectorstore built successfully.")
                print("Initial PDF processed and vectorstore saved.")
            else:
                print(f"Warning: No text extracted from {PDF_PATH}. Please check the PDF content.")
        except Exception as e:
            print(f"Fatal Error: Could not process {PDF_PATH} at startup: {e}")
    else:
        print(f"Error: {PDF_PATH} not found. Please ensure the PDF file is in the correct directory.")
def respond(
    message: str,
    history: list[list[str]],  # Gradio Chatbot history format
    max_tokens: int,       # kept for interface consistency; not consumed by the RAG pipeline
    temperature: float,    # kept for interface consistency; not consumed by the RAG pipeline
    top_p: float,          # kept for interface consistency; not consumed by the RAG pipeline
):
    """Gradio chat handler: answer ``message`` via the global ``rag_instance``.

    Yields exactly one string — either the RAG answer, an error notice, or an
    empty-knowledge-base notice. The instruction prompt lives inside
    ``rag_instance``; there is no user-configurable system message.
    """
    if rag_instance.sentence_chunks:
        try:
            yield rag_instance.ask_question(message)
        except Exception as exc:
            yield f"❌ An error occurred: {exc}"
    else:
        yield "Knowledge base is empty. Please ensure the PDF was loaded correctly at startup."
# Define the Gradio ChatInterface
# Declarative UI: a header plus a ChatInterface wired to respond().
# The three sliders are passed to respond() as additional_inputs but are
# not consumed by the RAG pipeline (see respond's signature comments).
with gr.Blocks() as demo:
    gr.Markdown("# Gemini RAG Chatbot for ML Theory")
    gr.Markdown(f"This chatbot is powered by {PDF_PATH}. Ensure your `GEMINI_API_KEY` is set as a Space Secret.")
    # No file upload section anymore
    chat_interface_component = gr.ChatInterface(
        respond,
        additional_inputs=[
            # Removed the Textbox for system_message
            gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens", info="Not directly used by RAG model."),
            gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature", info="Not directly used by RAG model."),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p (nucleus sampling)",
                info="Not directly used by RAG model."
            ),
        ],
        chatbot=gr.Chatbot(height=400),
        textbox=gr.Textbox(placeholder="Ask me about Machine Learning Theory!", container=False, scale=7),
        submit_btn="Send",
        # Each example row supplies the message plus values for the three
        # additional_inputs sliders, in order.
        examples=[
            ["درمورد boosting بهم بگو", 512, 0.7, 0.95],
            ["انواع رگرسیون را توضیح بده", 512, 0.7, 0.95],
            ["شبکه های عصبی چیستند؟", 512, 0.7, 0.95]
        ]
    )
if __name__ == "__main__":
    demo.launch()