import os
import time
import fitz  # PyMuPDF
import faiss
import pickle
import numpy as np
from typing import List, Dict
import re
import google.generativeai as genai
from google.api_core.exceptions import InternalServerError
from sentence_transformers import SentenceTransformer

# Gradio powers the web chat interface.
import gradio as gr

# Hard-coded instruction prompt (Persian). It frames the assistant as a
# university professor helping students prepare for a machine-learning
# theory exam based on Bishop's textbook. It is injected into every LLM
# call and is intentionally NOT exposed or editable through the UI.
ML_prompt = """ نقش ات: تو دستیار هوش مصنوعی من برای امتحان یادگیری ماشین هستی این امتحان تمرکز روی مفاهیم تیوری یادگیری ماشین داره منبع درس کتاب بیشاپ هست لحن صحبت کردن ات: تو استاد دانشگاه هستی و کسایی که باهات چت می کنن دانشجوهات اند """


class GeminiRAG:
    """Retrieval-augmented chatbot backed by Gemini and a FAISS index.

    Individual sentences are embedded and indexed, but retrieval expands
    each matching sentence back to its parent document (a PDF page), so
    the LLM receives page-level context rather than isolated sentences.
    The index and its bookkeeping lists are persisted to disk so the
    knowledge base survives restarts.
    """

    def __init__(self,
                 api_key: str,
                 model_name: str = "models/gemini-2.0-flash",
                 embed_model_name: str = "all-MiniLM-L6-v2",  # common SentenceTransformer model
                 instruction_prompt: str = ML_prompt,
                 vectorstore_dir: str = "vectorstore"):
        """Configure the Gemini client, the embedder, and an empty (or
        previously persisted) FAISS vector store.

        Args:
            api_key: Google Generative AI API key. Required.
            model_name: Gemini model identifier.
            embed_model_name: SentenceTransformer model used for embeddings.
            instruction_prompt: System-style prompt prepended to every query.
            vectorstore_dir: Directory used to persist the index and metadata.

        Raises:
            ValueError: If ``api_key`` is falsy.
        """
        if not api_key:
            raise ValueError("API key is missing.")

        self.instruction_prompt = instruction_prompt
        self.vectorstore_dir = vectorstore_dir
        self.vectorstore_faiss_path = os.path.join(self.vectorstore_dir, "faiss_index.index")
        self.vectorstore_data_path = os.path.join(self.vectorstore_dir, "faiss_data.pkl")

        # Ensure the persistence directory exists before any save/load.
        os.makedirs(self.vectorstore_dir, exist_ok=True)

        # Gemini client.
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel(model_name=model_name)

        # Sentence embedder.
        self.embedder = SentenceTransformer(embed_model_name)

        # FAISS index plus parallel bookkeeping:
        #   sentence_chunks[i]        -> i-th indexed sentence
        #   sentence_to_parent_map[i] -> index into parent_documents for that sentence
        embedding_dim = self.embedder.get_sentence_embedding_dimension()
        self.index = faiss.IndexFlatL2(embedding_dim)
        self.sentence_chunks: List[str] = []
        self.parent_documents: List[str] = []
        self.sentence_to_parent_map: List[int] = []

        # Reload a previously persisted vector store if one exists.
        self.load_vectorstore()

    def _split_into_sentences(self, text: str) -> List[str]:
        """Split ``text`` into trimmed, non-empty sentences.

        Splits on whitespace that follows '.', '!' or '?'.
        """
        sentences = re.split(r'(?<=[.!?])\s+', text)
        return [s.strip() for s in sentences if s.strip()]

    def load_document(self, pdf_path: str) -> List[str]:
        """Extract per-page text from a PDF.

        Args:
            pdf_path: Path to the PDF file.

        Returns:
            One stripped text string per non-empty page.

        Raises:
            Re-raises any PyMuPDF error after logging it, so the caller
            can decide how to handle a bad file.
        """
        print(f"Loading document from: {pdf_path}")
        try:
            doc = fitz.open(pdf_path)
            page_contents = []
            for page_num in range(len(doc)):
                page = doc.load_page(page_num)
                text = page.get_text()
                if text.strip():
                    page_contents.append(text.strip())
            doc.close()
            print(f"Successfully extracted {len(page_contents)} pages from {pdf_path}")
            return page_contents
        except Exception as e:
            print(f"Error loading PDF {pdf_path}: {e}")
            raise  # Re-raise so startup code can report the failure.

    def add_document(self, parent_chunks: List[str]):
        """Index a list of parent documents (pages) sentence-by-sentence.

        Each parent chunk is stored whole for retrieval context, while its
        sentences are embedded and added to the FAISS index with a pointer
        back to the parent.
        """
        new_sentence_chunks = []
        new_sentence_to_parent_map = []
        current_parent_doc_index = len(self.parent_documents)

        for parent_chunk in parent_chunks:
            self.parent_documents.append(parent_chunk)
            for sentence in self._split_into_sentences(parent_chunk):
                new_sentence_chunks.append(sentence)
                new_sentence_to_parent_map.append(current_parent_doc_index)
            current_parent_doc_index += 1

        if new_sentence_chunks:
            embeddings = self.embedder.encode(new_sentence_chunks,
                                              batch_size=32,
                                              convert_to_numpy=True)
            self.index.add(np.array(embeddings))
            self.sentence_chunks.extend(new_sentence_chunks)
            self.sentence_to_parent_map.extend(new_sentence_to_parent_map)
            print(f"Added {len(new_sentence_chunks)} sentence chunks from "
                  f"{len(parent_chunks)} parent documents.")
        else:
            print("No new sentence chunks to add.")

    def ask_question(self, query: str, top_k: int = 5) -> str:
        """Answer ``query`` using retrieved context and the Gemini model.

        Args:
            query: The user's question.
            top_k: Number of nearest sentence embeddings to retrieve.

        Returns:
            The model's answer text, or a human-readable message when the
            knowledge base is empty or no relevant context was found.

        Raises:
            Exception: If the Gemini API keeps failing after 3 attempts.
        """
        if not self.sentence_chunks or not self.parent_documents:
            return "Knowledge base is empty. Please load documents first."

        query_emb = self.embedder.encode([query], convert_to_numpy=True)
        D, I = self.index.search(np.array(query_emb), top_k)

        # Map matched sentences back to their (deduplicated) parent pages.
        # FAISS pads missing results with -1, so the lower bound matters:
        # a bare `idx < len(...)` would let -1 through and silently fetch
        # the LAST sentence via negative indexing.
        retrieved_parent_doc_indices = set()
        for idx in I[0]:
            if 0 <= idx < len(self.sentence_chunks):
                retrieved_parent_doc_indices.add(self.sentence_to_parent_map[idx])

        context_parts = []
        for parent_idx in sorted(retrieved_parent_doc_indices):
            if parent_idx < len(self.parent_documents):
                context_parts.append(self.parent_documents[parent_idx])

        # Join with real newlines around the separator. (The original used
        # "\\n" here, leaking literal backslash-n characters into the prompt.)
        context = "\n\n---\n\n".join(context_parts)
        if not context.strip():
            return "No relevant information found in the knowledge base."

        # The instruction prompt was fixed at construction time.
        prompt = f"""
### instruction prompt : (explanation : this text is your guideline don't mention it on response)
{self.instruction_prompt}

Use the following context to answer the question.\n Context:\n {context}\n Question: {query}\n Answer:"""

        # Retry transient API failures up to 3 times with a 5-second backoff.
        for attempt in range(3):
            try:
                response = self.model.generate_content(prompt)
                return response.text
            except InternalServerError as e:
                print(f"Error: {e}. Retrying in 5 seconds...")
                time.sleep(5)
            except Exception as e:
                # Best-effort retry of any other API-side failure as well.
                print(f"An unexpected error occurred during API call: {e}. Retrying in 5 seconds...")
                time.sleep(5)
        raise Exception("Failed to generate after 3 retries due to persistent errors.")

    def save_vectorstore(self):
        """Persist the FAISS index and its bookkeeping lists to disk."""
        try:
            faiss.write_index(self.index, self.vectorstore_faiss_path)
            with open(self.vectorstore_data_path, "wb") as f:
                pickle.dump({
                    'sentence_chunks': self.sentence_chunks,
                    'parent_documents': self.parent_documents,
                    'sentence_to_parent_map': self.sentence_to_parent_map
                }, f)
            print(f"Vectorstore saved to {self.vectorstore_faiss_path} and {self.vectorstore_data_path}")
        except Exception as e:
            print(f"Error saving vectorstore: {e}")

    def load_vectorstore(self) -> bool:
        """Load a previously saved vector store, if one exists.

        Returns:
            True if both the index and metadata were loaded successfully,
            False otherwise (including when loading fails, in which case a
            fresh empty index is initialized).
        """
        if os.path.exists(self.vectorstore_faiss_path) and os.path.exists(self.vectorstore_data_path):
            try:
                self.index = faiss.read_index(self.vectorstore_faiss_path)
                with open(self.vectorstore_data_path, "rb") as f:
                    data = pickle.load(f)
                self.sentence_chunks = data['sentence_chunks']
                self.parent_documents = data['parent_documents']
                self.sentence_to_parent_map = data['sentence_to_parent_map']
                print("📦 Loaded vectorstore.")
                return True
            except Exception as e:
                print(f"Error loading vectorstore: {e}")
                # On a corrupt store, start fresh rather than crash.
                self.index = faiss.IndexFlatL2(self.embedder.get_sentence_embedding_dimension())
                self.sentence_chunks = []
                self.parent_documents = []
                self.sentence_to_parent_map = []
                print("⚠️ Failed to load vectorstore, initializing a new one.")
                return False
        print("ℹ️ No saved vectorstore found.")
        return False


# --- Gradio Interface Setup ---

# Read the API key from the environment. Accept both the historical
# lowercase name and the name advertised in the UI/warning text.
api_key = os.getenv("google_api_key") or os.getenv("GEMINI_API_KEY")
if not api_key:
    print("Warning: neither google_api_key nor GEMINI_API_KEY environment "
          "variable is set. Please set one in Hugging Face Space secrets.")

# Initialize the RAG system globally for the Gradio app. The instruction
# prompt is fixed at construction and becomes part of rag_instance's state.
rag_instance = GeminiRAG(api_key=api_key, instruction_prompt=ML_prompt)

# --- Load the predefined PDF at startup ---
PDF_PATH = "MLT.pdf"  # Assumed to sit next to this script; use a full path otherwise.
VECTORSTORE_BUILT_FLAG = os.path.join(rag_instance.vectorstore_dir, "vectorstore_built_flag.txt")

# GeminiRAG.__init__ already attempted to load a persisted vectorstore, so
# only (re)build from the PDF when nothing was loaded. (Calling
# load_vectorstore() again here would just repeat the work.)
if not rag_instance.parent_documents:
    print(f"Attempting to load and process {PDF_PATH}...")
    if os.path.exists(PDF_PATH):
        try:
            chunks = rag_instance.load_document(PDF_PATH)
            if chunks:
                rag_instance.add_document(chunks)
                rag_instance.save_vectorstore()
                with open(VECTORSTORE_BUILT_FLAG, "w") as f:
                    f.write("Vectorstore built successfully.")
                print("Initial PDF processed and vectorstore saved.")
            else:
                print(f"Warning: No text extracted from {PDF_PATH}. Please check the PDF content.")
        except Exception as e:
            print(f"Fatal Error: Could not process {PDF_PATH} at startup: {e}")
    else:
        print(f"Error: {PDF_PATH} not found. Please ensure the PDF file is in the correct directory.")


def respond(
    message: str,
    history: list[list[str]],  # Gradio Chatbot history format
    max_tokens: int,     # Kept for interface consistency; not used by the RAG call.
    temperature: float,  # Kept for interface consistency; not used by the RAG call.
    top_p: float,        # Kept for interface consistency; not used by the RAG call.
):
    """Gradio chat callback: route the user's message through the RAG system.

    The instruction prompt is handled internally by ``rag_instance``; there
    is no user-configurable system message. Yields a single response string
    (or an error message) as required by gr.ChatInterface.
    """
    if not rag_instance.sentence_chunks:
        yield "Knowledge base is empty. Please ensure the PDF was loaded correctly at startup."
        return
    try:
        yield rag_instance.ask_question(message)
    except Exception as e:
        yield f"❌ An error occurred: {e}"


# --- Define the Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("# Gemini RAG Chatbot for ML Theory")
    gr.Markdown(f"This chatbot is powered by {PDF_PATH}. Ensure your `GEMINI_API_KEY` is set as a Space Secret.")

    # No file-upload section: the knowledge base is fixed at startup.
    chat_interface_component = gr.ChatInterface(
        respond,
        additional_inputs=[
            # The sliders are displayed for parity with typical LLM UIs but
            # are not forwarded to the RAG pipeline (see `respond`).
            gr.Slider(minimum=1, maximum=2048, value=512, step=1,
                      label="Max new tokens", info="Not directly used by RAG model."),
            gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1,
                      label="Temperature", info="Not directly used by RAG model."),
            gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05,
                      label="Top-p (nucleus sampling)", info="Not directly used by RAG model."),
        ],
        chatbot=gr.Chatbot(height=400),
        textbox=gr.Textbox(placeholder="Ask me about Machine Learning Theory!",
                           container=False, scale=7),
        submit_btn="Send",
        # Each example supplies the message plus values for the three sliders.
        examples=[
            ["درمورد boosting بهم بگو", 512, 0.7, 0.95],
            ["انواع رگرسیون را توضیح بده", 512, 0.7, 0.95],
            ["شبکه های عصبی چیستند؟", 512, 0.7, 0.95]
        ]
    )

if __name__ == "__main__":
    demo.launch()