# ml_assistant / app.py
# (Hugging Face Space page residue preserved as comments so the file parses:
#  "Mateo4's picture" / "Update app.py" / commit "a57a185 verified")
import os
import time
import fitz # PyMuPDF
import faiss
import pickle
import numpy as np
from typing import List, Dict
import re
import google.generativeai as genai
from google.api_core.exceptions import InternalServerError
from sentence_transformers import SentenceTransformer
# Import gradio for the web interface
import gradio as gr
# Hardcoded instruction ("system") prompt for the assistant. It is passed to
# GeminiRAG at construction time and is NOT exposed to or editable by the user.
# Gloss (the runtime string below is kept verbatim, in Persian): it frames the
# model as a university professor whose chat partners are students preparing
# for a machine-learning theory exam, with Bishop's book as the course source.
ML_prompt = """
نقش ات:
تو دستیار هوش مصنوعی من برای امتحان یادگیری ماشین هستی
این امتحان تمرکز روی مفاهیم تیوری یادگیری ماشین داره
منبع درس کتاب بیشاپ هست
لحن صحبت کردن ات:
تو استاد دانشگاه هستی و کسایی که باهات چت می کنن دانشجوهات اند
"""
class GeminiRAG:
    """Retrieval-augmented QA over PDF pages using Gemini + FAISS.

    Each PDF page is stored as a "parent" document; pages are split into
    sentences that are embedded individually and indexed in FAISS. At query
    time, matching sentences are mapped back to their parent pages, which are
    concatenated into the LLM context (small-to-big retrieval).
    """

    def __init__(self, api_key: str, model_name: str = "models/gemini-2.0-flash",
                 embed_model_name: str = "all-MiniLM-L6-v2",  # common SentenceTransformer model
                 instruction_prompt: str = ML_prompt,  # system-style prompt, fixed at init
                 vectorstore_dir: str = "vectorstore"):  # persistence dir for index + metadata
        """Configure the Gemini client, the embedder, and the FAISS index.

        Raises:
            ValueError: if ``api_key`` is missing/empty.
        """
        if not api_key:
            raise ValueError("API key is missing.")
        self.instruction_prompt = instruction_prompt
        self.vectorstore_dir = vectorstore_dir
        self.vectorstore_faiss_path = os.path.join(self.vectorstore_dir, "faiss_index.index")
        self.vectorstore_data_path = os.path.join(self.vectorstore_dir, "faiss_data.pkl")
        # Ensure the persistence directory exists before any save/load attempt.
        os.makedirs(self.vectorstore_dir, exist_ok=True)
        # Gemini client setup.
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel(model_name=model_name)
        # Sentence embedder; its output dimension sizes the FAISS index.
        self.embedder = SentenceTransformer(embed_model_name)
        embedding_dim = self.embedder.get_sentence_embedding_dimension()
        self.index = faiss.IndexFlatL2(embedding_dim)
        # Parallel stores: sentence_chunks[i] belongs to
        # parent_documents[sentence_to_parent_map[i]].
        self.sentence_chunks: List[str] = []
        self.parent_documents: List[str] = []
        self.sentence_to_parent_map: List[int] = []
        # Reuse a previously saved vector store if one exists.
        self.load_vectorstore()

    def _split_into_sentences(self, text: str) -> List[str]:
        """Split ``text`` after sentence-ending punctuation, dropping blanks."""
        sentences = re.split(r'(?<=[.!?])\s+', text)
        return [s.strip() for s in sentences if s.strip()]

    def load_document(self, pdf_path: str) -> List[str]:
        """Extract and return the non-empty page texts of a PDF.

        Raises:
            Exception: re-raises any PyMuPDF error after logging it, so the
                caller (startup code) can report the failure.
        """
        print(f"Loading document from: {pdf_path}")
        try:
            doc = fitz.open(pdf_path)
            page_contents = []
            for page_num in range(len(doc)):
                page = doc.load_page(page_num)
                text = page.get_text()
                if text.strip():
                    page_contents.append(text.strip())
            doc.close()
            print(f"Successfully extracted {len(page_contents)} pages from {pdf_path}")
            return page_contents
        except Exception as e:
            print(f"Error loading PDF {pdf_path}: {e}")
            raise  # Re-raise so the caller can decide how to handle it.

    def add_document(self, parent_chunks: List[str]):
        """Index ``parent_chunks`` (e.g. PDF pages): split into sentences,
        embed, and add to the FAISS index while recording parent mapping."""
        new_sentence_chunks = []
        new_sentence_to_parent_map = []
        current_parent_doc_index = len(self.parent_documents)
        for parent_chunk in parent_chunks:
            self.parent_documents.append(parent_chunk)
            sentences = self._split_into_sentences(parent_chunk)
            for sentence in sentences:
                new_sentence_chunks.append(sentence)
                new_sentence_to_parent_map.append(current_parent_doc_index)
            current_parent_doc_index += 1
        if new_sentence_chunks:
            embeddings = self.embedder.encode(new_sentence_chunks, batch_size=32, convert_to_numpy=True)
            self.index.add(np.array(embeddings))
            self.sentence_chunks.extend(new_sentence_chunks)
            self.sentence_to_parent_map.extend(new_sentence_to_parent_map)
            print(f"Added {len(new_sentence_chunks)} sentence chunks from {len(parent_chunks)} parent documents.")
        else:
            print("No new sentence chunks to add.")

    def ask_question(self, query: str, top_k: int = 5) -> str:
        """Answer ``query`` using the ``top_k`` nearest sentences' parent pages.

        Returns the model's answer text, or a status string when the knowledge
        base is empty / no relevant context is found.

        Raises:
            Exception: after 3 failed generation attempts.
        """
        if not self.sentence_chunks or not self.parent_documents:
            return "Knowledge base is empty. Please load documents first."
        query_emb = self.embedder.encode([query], convert_to_numpy=True)
        D, I = self.index.search(np.array(query_emb), top_k)
        retrieved_parent_doc_indices = set()
        for idx in I[0]:
            # FAISS pads results with -1 when the index holds fewer than
            # top_k vectors; guard BOTH ends so -1 never aliases the last
            # sentence via negative indexing (bug fix: was `idx < len(...)`).
            if 0 <= idx < len(self.sentence_chunks):
                parent_idx = self.sentence_to_parent_map[idx]
                retrieved_parent_doc_indices.add(parent_idx)
        context_parts = []
        # Sort so context pages appear in document order.
        sorted_parent_indices = sorted(list(retrieved_parent_doc_indices))
        for parent_idx in sorted_parent_indices:
            if parent_idx < len(self.parent_documents):  # Ensure index is within bounds
                context_parts.append(self.parent_documents[parent_idx])
        # Bug fix: the separator previously contained escaped "\\n" sequences,
        # injecting literal backslash-n text into the prompt instead of newlines.
        context = "\n\n---\n\n".join(context_parts)
        if not context.strip():
            return "No relevant information found in the knowledge base."
        # The instruction prompt was fixed at __init__ and is not user-supplied.
        prompt = f"""
### instruction prompt : (explanation : this text is your guideline don't mention it on response)
{self.instruction_prompt}
Use the following context to answer the question.\n
Context:\n
{context}\n
Question: {query}\n
Answer:"""
        # Retry transient API failures up to 3 times with a 5s backoff.
        for attempt in range(3):
            try:
                response = self.model.generate_content(prompt)
                return response.text
            except InternalServerError as e:
                print(f"Error: {e}. Retrying in 5 seconds...")
                time.sleep(5)
            except Exception as e:  # Catch other potential errors from API call
                print(f"An unexpected error occurred during API call: {e}. Retrying in 5 seconds...")
                time.sleep(5)
        raise Exception("Failed to generate after 3 retries due to persistent errors.")

    def save_vectorstore(self):
        """Persist the FAISS index and the parallel metadata lists to disk."""
        try:
            faiss.write_index(self.index, self.vectorstore_faiss_path)
            with open(self.vectorstore_data_path, "wb") as f:
                pickle.dump({
                    'sentence_chunks': self.sentence_chunks,
                    'parent_documents': self.parent_documents,
                    'sentence_to_parent_map': self.sentence_to_parent_map
                }, f)
            print(f"Vectorstore saved to {self.vectorstore_faiss_path} and {self.vectorstore_data_path}")
        except Exception as e:
            # Best-effort persistence: a failed save must not crash the app.
            print(f"Error saving vectorstore: {e}")

    def load_vectorstore(self):
        """Load a persisted index + metadata if both files exist.

        Returns True on success; False (with a fresh, empty index on failure)
        otherwise. NOTE: the pickle file is trusted local state — do not point
        this at untrusted input.
        """
        if os.path.exists(self.vectorstore_faiss_path) and os.path.exists(self.vectorstore_data_path):
            try:
                self.index = faiss.read_index(self.vectorstore_faiss_path)
                with open(self.vectorstore_data_path, "rb") as f:
                    data = pickle.load(f)
                self.sentence_chunks = data['sentence_chunks']
                self.parent_documents = data['parent_documents']
                self.sentence_to_parent_map = data['sentence_to_parent_map']
                print("📦 Loaded vectorstore.")
                return True
            except Exception as e:
                print(f"Error loading vectorstore: {e}")
                # If loading fails, start fresh rather than run with partial state.
                self.index = faiss.IndexFlatL2(self.embedder.get_sentence_embedding_dimension())
                self.sentence_chunks = []
                self.parent_documents = []
                self.sentence_to_parent_map = []
                print("⚠️ Failed to load vectorstore, initializing a new one.")
                return False
        print("ℹ️ No saved vectorstore found.")
        return False
# --- Gradio Interface Setup ---
# Read the Gemini API key from the environment. The code historically read
# "google_api_key" while the warning told users to set GEMINI_API_KEY —
# accept both names so either Space secret works (backward-compatible fix).
api_key = os.getenv("google_api_key") or os.getenv("GEMINI_API_KEY")
if not api_key:
    print("Warning: google_api_key / GEMINI_API_KEY environment variable not set. Please set it in Hugging Face Space secrets.")
# Initialize the RAG system globally for the Gradio app; the hardcoded
# ML_prompt becomes part of the rag_instance state.
rag_instance = GeminiRAG(api_key=api_key, instruction_prompt=ML_prompt)
# --- Load the predefined PDF at startup ---
PDF_PATH = "MLT.pdf"  # assumed to sit next to this script, or specify full path
VECTORSTORE_BUILT_FLAG = os.path.join(rag_instance.vectorstore_dir, "vectorstore_built_flag.txt")
# NOTE: __init__ already attempted a load; this call re-checks and returns
# False when no persisted store exists, triggering a fresh build.
if not rag_instance.load_vectorstore():
    print(f"Attempting to load and process {PDF_PATH}...")
    if os.path.exists(PDF_PATH):
        try:
            chunks = rag_instance.load_document(PDF_PATH)
            if chunks:
                rag_instance.add_document(chunks)
                rag_instance.save_vectorstore()
                # Marker file recording that the initial build completed.
                with open(VECTORSTORE_BUILT_FLAG, "w") as f:
                    f.write("Vectorstore built successfully.")
                print("Initial PDF processed and vectorstore saved.")
            else:
                print(f"Warning: No text extracted from {PDF_PATH}. Please check the PDF content.")
        except Exception as e:
            # Keep the app up even if ingestion fails; respond() reports the
            # empty knowledge base to users.
            print(f"Fatal Error: Could not process {PDF_PATH} at startup: {e}")
    else:
        print(f"Error: {PDF_PATH} not found. Please ensure the PDF file is in the correct directory.")
def respond(
    message: str,
    history: list[list[str]],  # Gradio Chatbot history format (unused here)
    max_tokens: int,  # slider value; not consumed by the RAG pipeline
    temperature: float,  # slider value; not consumed by the RAG pipeline
    top_p: float,  # slider value; not consumed by the RAG pipeline
):
    """Gradio chat handler: answer ``message`` via the global RAG instance.

    The instruction prompt lives inside ``rag_instance`` (set at init), so
    there is no user-configurable system message. Yields exactly one reply.
    """
    # Guard: nothing indexed means the startup ingestion failed or is missing.
    if not rag_instance.sentence_chunks:
        yield "Knowledge base is empty. Please ensure the PDF was loaded correctly at startup."
        return
    try:
        answer = rag_instance.ask_question(message)
    except Exception as e:
        yield f"❌ An error occurred: {e}"
    else:
        yield answer
# Define the Gradio ChatInterface and launch it when run as a script.
with gr.Blocks() as demo:
    gr.Markdown("# Gemini RAG Chatbot for ML Theory")
    gr.Markdown(f"This chatbot is powered by {PDF_PATH}. Ensure your `GEMINI_API_KEY` is set as a Space Secret.")
    # No file upload section: the knowledge base is built once at startup.
    chat_interface_component = gr.ChatInterface(
        respond,
        # These sliders are forwarded to respond() but are not consumed by the
        # RAG pipeline; they are kept for interface consistency.
        additional_inputs=[
            # The system-message textbox was removed; the prompt is hardcoded.
            gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens", info="Not directly used by RAG model."),
            gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature", info="Not directly used by RAG model."),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p (nucleus sampling)",
                info="Not directly used by RAG model."
            ),
        ],
        chatbot=gr.Chatbot(height=400),
        textbox=gr.Textbox(placeholder="Ask me about Machine Learning Theory!", container=False, scale=7),
        submit_btn="Send",
        # Example rows match the input signature: message + the three sliders.
        # (Persian prompts: "tell me about boosting", "explain types of
        # regression", "what are neural networks?")
        examples=[
            ["درمورد boosting بهم بگو", 512, 0.7, 0.95],
            ["انواع رگرسیون را توضیح بده", 512, 0.7, 0.95],
            ["شبکه های عصبی چیستند؟", 512, 0.7, 0.95]
        ]
    )
if __name__ == "__main__":
    demo.launch()