File size: 12,319 Bytes
d02a876
 
a57a185
d02a876
 
 
 
 
 
 
 
 
 
 
 
31f9a1d
 
d02a876
a57a185
d02a876
 
 
 
 
 
 
31f9a1d
 
d02a876
 
 
a57a185
d02a876
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a57a185
d02a876
 
 
 
a57a185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d02a876
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a57a185
d02a876
 
 
 
 
 
 
a57a185
d02a876
 
 
 
 
 
 
a57a185
d02a876
a57a185
 
 
 
 
 
 
d02a876
 
 
 
 
 
 
 
a57a185
 
 
 
d02a876
 
a57a185
 
 
 
 
 
 
 
 
 
 
d02a876
 
 
a57a185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d02a876
 
 
 
 
 
121f275
d02a876
a57a185
 
d02a876
 
a57a185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31f9a1d
 
d02a876
 
a57a185
d02a876
 
 
31f9a1d
a57a185
 
 
 
 
 
31f9a1d
d02a876
 
a57a185
d02a876
 
31f9a1d
d02a876
 
 
a57a185
 
 
 
d02a876
 
 
a57a185
d02a876
 
 
 
 
 
 
 
 
 
 
 
 
 
a57a185
aaa61f6
a57a185
 
 
aaa61f6
d02a876
31f9a1d
 
 
d02a876
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
import os
import time
import fitz # PyMuPDF
import faiss
import pickle
import numpy as np
from typing import List, Dict
import re

import google.generativeai as genai
from google.api_core.exceptions import InternalServerError

from sentence_transformers import SentenceTransformer

# Import gradio for the web interface
import gradio as gr

# Hardcoded system instruction for the chatbot; it is baked into GeminiRAG at
# init time and never exposed to (or editable by) the end user.
# The text is Persian; roughly: "Your role: you are my AI assistant for a
# machine-learning exam focused on ML theory concepts; the course book is
# Bishop. Your tone: you are a university professor and the people chatting
# with you are your students."
ML_prompt = """
نقش ات:
تو دستیار هوش مصنوعی من برای امتحان یادگیری ماشین هستی
این امتحان تمرکز روی مفاهیم تیوری یادگیری ماشین داره
منبع درس کتاب بیشاپ هست
لحن صحبت کردن ات:
تو استاد دانشگاه هستی و کسایی که باهات چت می کنن دانشجوهات اند
"""

class GeminiRAG:
    """Retrieval-augmented QA over PDF documents using Gemini.

    Implements "small-to-big" retrieval: each PDF page is kept as a parent
    document and split into sentences; the sentences are embedded with a
    SentenceTransformer and indexed in FAISS. At query time the nearest
    sentences are located and the *whole parent pages* they belong to are
    sent to Gemini as context. The index and its bookkeeping lists are
    persisted under ``vectorstore_dir``.
    """

    def __init__(self, api_key: str, model_name: str = "models/gemini-2.0-flash",
                 embed_model_name: str = "all-MiniLM-L6-v2", # Using a common SentenceTransformer model
                 instruction_prompt: str = ML_prompt, # system-style guideline prepended to every query
                 vectorstore_dir: str = "vectorstore"): # Use a directory within the app for persistence
        """Configure Gemini, load the embedder and any persisted vector store.

        Raises:
            ValueError: if ``api_key`` is falsy.
        """
        if not api_key:
            raise ValueError("API key is missing.")

        self.instruction_prompt = instruction_prompt
        self.vectorstore_dir = vectorstore_dir
        self.vectorstore_faiss_path = os.path.join(self.vectorstore_dir, "faiss_index.index")
        self.vectorstore_data_path = os.path.join(self.vectorstore_dir, "faiss_data.pkl")

        # Ensure vectorstore directory exists
        os.makedirs(self.vectorstore_dir, exist_ok=True)

        # Setup Gemini
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel(model_name=model_name)

        # Setup Embedder
        self.embedder = SentenceTransformer(embed_model_name)

        # FAISS index over sentence embeddings, plus parallel bookkeeping:
        # sentence_chunks[i] was embedded as index row i and belongs to
        # parent_documents[sentence_to_parent_map[i]].
        embedding_dim = self.embedder.get_sentence_embedding_dimension()
        self.index = faiss.IndexFlatL2(embedding_dim)
        self.sentence_chunks: List[str] = []
        self.parent_documents: List[str] = []
        self.sentence_to_parent_map: List[int] = []

        # Load existing vector store if available
        self.load_vectorstore()

    def _split_into_sentences(self, text: str) -> List[str]:
        """Split *text* at sentence-ending punctuation; drop empty pieces."""
        sentences = re.split(r'(?<=[.!?])\s+', text)
        return [s.strip() for s in sentences if s.strip()]

    def load_document(self, pdf_path: str) -> List[str]:
        """Extract per-page text from the PDF at *pdf_path*.

        Returns:
            The stripped text of every non-empty page, in page order.

        Raises:
            Exception: re-raises any PyMuPDF error after logging it.
        """
        print(f"Loading document from: {pdf_path}")
        try:
            page_contents = []
            # Context manager guarantees the document is closed even if a
            # page fails to extract (the original leaked the handle on error).
            with fitz.open(pdf_path) as doc:
                for page_num in range(len(doc)):
                    page = doc.load_page(page_num)
                    text = page.get_text()
                    if text.strip():
                        page_contents.append(text.strip())
            print(f"Successfully extracted {len(page_contents)} pages from {pdf_path}")
            return page_contents
        except Exception as e:
            print(f"Error loading PDF {pdf_path}: {e}")
            raise # Re-raise the exception to be caught higher up

    def add_document(self, parent_chunks: List[str]):
        """Index *parent_chunks* as parent documents and embed their sentences.

        Each chunk is appended to ``parent_documents``; every sentence in it
        is embedded, added to the FAISS index, and mapped back to its parent.
        """
        new_sentence_chunks = []
        new_sentence_to_parent_map = []
        current_parent_doc_index = len(self.parent_documents)

        for parent_chunk in parent_chunks:
            self.parent_documents.append(parent_chunk)
            for sentence in self._split_into_sentences(parent_chunk):
                new_sentence_chunks.append(sentence)
                new_sentence_to_parent_map.append(current_parent_doc_index)
            current_parent_doc_index += 1

        if new_sentence_chunks:
            embeddings = self.embedder.encode(new_sentence_chunks, batch_size=32, convert_to_numpy=True)
            # FAISS expects float32; asarray avoids an extra copy when the
            # dtype already matches (encode returns float32 ndarray here).
            self.index.add(np.asarray(embeddings, dtype=np.float32))
            self.sentence_chunks.extend(new_sentence_chunks)
            self.sentence_to_parent_map.extend(new_sentence_to_parent_map)
            print(f"Added {len(new_sentence_chunks)} sentence chunks from {len(parent_chunks)} parent documents.")
        else:
            print("No new sentence chunks to add.")

    def ask_question(self, query: str, top_k: int = 5) -> str:
        """Answer *query* using the parent pages of the top_k nearest sentences.

        Returns:
            Gemini's answer, or a user-facing message when the knowledge base
            is empty or no relevant context was found.

        Raises:
            RuntimeError: if the Gemini API keeps failing after 3 attempts
                (RuntimeError is an Exception subclass, so callers catching
                the old bare Exception still work).
        """
        if not self.sentence_chunks or not self.parent_documents:
            return "Knowledge base is empty. Please load documents first."

        query_emb = self.embedder.encode([query], convert_to_numpy=True)
        D, I = self.index.search(np.asarray(query_emb, dtype=np.float32), top_k)

        retrieved_parent_doc_indices = set()
        for idx in I[0]:
            # BUGFIX: FAISS pads results with -1 when fewer than top_k vectors
            # exist; with the old `idx < len(...)` check, -1 silently selected
            # the *last* sentence via Python negative indexing.
            if 0 <= idx < len(self.sentence_chunks):
                retrieved_parent_doc_indices.add(self.sentence_to_parent_map[idx])

        # Keep parent pages in document order for a coherent context.
        context_parts = []
        for parent_idx in sorted(retrieved_parent_doc_indices):
            if parent_idx < len(self.parent_documents): # Ensure index is within bounds
                context_parts.append(self.parent_documents[parent_idx])

        # BUGFIX: the separator was double-escaped ("---\\n\\n") and injected
        # literal backslash-n text into the prompt; use real newlines.
        context = "\n\n---\n\n".join(context_parts)

        if not context.strip():
            return "No relevant information found in the knowledge base."

        # The instruction prompt is set once at init and prepended here.
        prompt = f"""
             ### instruction prompt : (explanation : this text is your guideline don't mention it on response)
             {self.instruction_prompt}
             Use the following context to answer the question.\n
             Context:\n
             {context}\n
             Question: {query}\n
             Answer:"""

        for attempt in range(3):
            try:
                response = self.model.generate_content(prompt)
                return response.text
            except InternalServerError as e:
                print(f"Error: {e}. Retrying in 5 seconds...")
            except Exception as e: # Catch other potential errors from API call
                print(f"An unexpected error occurred during API call: {e}. Retrying in 5 seconds...")
            if attempt < 2:
                # No point sleeping after the final failed attempt.
                time.sleep(5)
        raise RuntimeError("Failed to generate after 3 retries due to persistent errors.")

    def save_vectorstore(self):
        """Persist the FAISS index and the bookkeeping lists to disk.

        Failures are logged, not raised, so saving stays best-effort.
        """
        try:
            faiss.write_index(self.index, self.vectorstore_faiss_path)
            with open(self.vectorstore_data_path, "wb") as f:
                pickle.dump({
                    'sentence_chunks': self.sentence_chunks,
                    'parent_documents': self.parent_documents,
                    'sentence_to_parent_map': self.sentence_to_parent_map
                }, f)
            print(f"Vectorstore saved to {self.vectorstore_faiss_path} and {self.vectorstore_data_path}")
        except Exception as e:
            print(f"Error saving vectorstore: {e}")

    def load_vectorstore(self) -> bool:
        """Load a previously saved vector store, if both files exist.

        Returns:
            True on success, False when nothing was found or loading failed
            (in which case a fresh, empty index is installed).
        """
        if os.path.exists(self.vectorstore_faiss_path) and os.path.exists(self.vectorstore_data_path):
            try:
                self.index = faiss.read_index(self.vectorstore_faiss_path)
                with open(self.vectorstore_data_path, "rb") as f:
                    data = pickle.load(f)
                    self.sentence_chunks = data['sentence_chunks']
                    self.parent_documents = data['parent_documents']
                    self.sentence_to_parent_map = data['sentence_to_parent_map']
                print("📦 Loaded vectorstore.")
                return True
            except Exception as e:
                print(f"Error loading vectorstore: {e}")
                # If loading fails, it's better to start fresh
                self.index = faiss.IndexFlatL2(self.embedder.get_sentence_embedding_dimension())
                self.sentence_chunks = []
                self.parent_documents = []
                self.sentence_to_parent_map = []
                print("⚠️ Failed to load vectorstore, initializing a new one.")
                return False
        print("ℹ️ No saved vectorstore found.")
        return False

# --- Gradio Interface Setup ---

# Get API key from environment variable. BUGFIX: the old warning named
# GEMINI_API_KEY while only `google_api_key` was read — accept either secret
# name and mention both in the warning.
api_key = os.getenv("google_api_key") or os.getenv("GEMINI_API_KEY")
if not api_key:
    print("Warning: google_api_key/GEMINI_API_KEY environment variable not set. Please set it in Hugging Face Space secrets.")


# Initialize the RAG system globally for the Gradio app.
# The ML_prompt is passed during initialization and is then part of the rag_instance state
rag_instance = GeminiRAG(api_key=api_key, instruction_prompt=ML_prompt) # Pass the prompt here

# --- Load the predefined PDF at startup ---
PDF_PATH = "MLT.pdf" # Assumes MLT.pdf is in the same directory as this script, or specify full path
VECTORSTORE_BUILT_FLAG = os.path.join(rag_instance.vectorstore_dir, "vectorstore_built_flag.txt")


# GeminiRAG.__init__ already attempted load_vectorstore(); checking the
# in-memory state avoids the redundant second disk load the old code did.
if not rag_instance.parent_documents:
    print(f"Attempting to load and process {PDF_PATH}...")
    if os.path.exists(PDF_PATH):
        try:
            chunks = rag_instance.load_document(PDF_PATH)
            if chunks:
                rag_instance.add_document(chunks)
                rag_instance.save_vectorstore()
                # Marker file signals that the initial build completed once.
                with open(VECTORSTORE_BUILT_FLAG, "w") as f:
                    f.write("Vectorstore built successfully.")
                print("Initial PDF processed and vectorstore saved.")
            else:
                print(f"Warning: No text extracted from {PDF_PATH}. Please check the PDF content.")
        except Exception as e:
            print(f"Fatal Error: Could not process {PDF_PATH} at startup: {e}")
    else:
        print(f"Error: {PDF_PATH} not found. Please ensure the PDF file is in the correct directory.")


def respond(
    message: str,
    history: list[list[str]], # Gradio Chatbot history format
    max_tokens: int, # kept for interface parity with the sliders; unused by the RAG backend
    temperature: float, # kept for interface parity; unused by the RAG backend
    top_p: float, # kept for interface parity; unused by the RAG backend
):
    """Gradio chat callback: answer *message* via the global ``rag_instance``.

    Yields exactly one string — either the RAG answer, an error notice, or a
    hint that the knowledge base was never populated. The instruction prompt
    lives inside ``rag_instance`` and is not configurable from the UI.
    """
    # Guard clause: nothing indexed means nothing to retrieve.
    if not rag_instance.sentence_chunks:
        yield "Knowledge base is empty. Please ensure the PDF was loaded correctly at startup."
        return

    try:
        yield rag_instance.ask_question(message)
    except Exception as exc:
        yield f"❌ An error occurred: {exc}"

# Define the Gradio ChatInterface. Components must be created inside the
# Blocks context, and creation order determines layout — so the sliders are
# built inline in additional_inputs.
with gr.Blocks() as demo:
    gr.Markdown("# Gemini RAG Chatbot for ML Theory")
    gr.Markdown(f"This chatbot is powered by {PDF_PATH}. Ensure your `GEMINI_API_KEY` is set as a Space Secret.")

    # No file upload section: the PDF is preloaded at startup.

    chat_interface_component = gr.ChatInterface(
        respond,
        additional_inputs=[
            # No system-message textbox: the instruction prompt is hardcoded
            # in rag_instance. These sliders are accepted by respond() only
            # for interface consistency; the RAG call ignores them.
            gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens", info="Not directly used by RAG model."),
            gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature", info="Not directly used by RAG model."),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p (nucleus sampling)",
                info="Not directly used by RAG model."
            ),
        ],
        chatbot=gr.Chatbot(height=400),
        textbox=gr.Textbox(placeholder="Ask me about Machine Learning Theory!", container=False, scale=7),
        submit_btn="Send",
        # Each example row is [message, max_tokens, temperature, top_p],
        # matching respond()'s parameters after the history argument.
        examples=[
            ["درمورد boosting بهم بگو", 512, 0.7, 0.95],
            ["انواع رگرسیون را توضیح بده", 512, 0.7, 0.95],
            ["شبکه های عصبی چیستند؟", 512, 0.7, 0.95]
        ]
    )


# Launch the app only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()