Upload 6 files
Browse files- app.py +118 -0
- chat_history.txt +0 -0
- embedding__model.txt +15 -0
- gemini_wrapper.py +135 -0
- rag.py +322 -0
- requirements.txt +10 -0
app.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import os
|
| 3 |
+
from rag import RAGSystem
|
| 4 |
+
from dotenv import load_dotenv
|
| 5 |
+
from gtts import gTTS
|
| 6 |
+
from docx2pdf import convert
|
| 7 |
+
|
| 8 |
+
# Load environment variables from .env file (expects GEMINI_API_KEY).
load_dotenv()

# Initialize the RAG system once at import time with default locations.
pdf_dir = "material"      # directory scanned for PDF documents
db_dir = "chroma_db"      # persistent ChromaDB storage location
gemini_api_key = os.getenv("GEMINI_API_KEY")
rag_system = RAGSystem(pdf_dir=pdf_dir, gemini_api_key=gemini_api_key, db_directory=db_dir)
| 17 |
+
# Function to handle file upload and process the uploaded file
def upload_and_process(file):
    """Index an uploaded PDF or DOCX document for retrieval.

    Args:
        file: Gradio file object (exposes the temp path as ``.name``) or None.

    Returns:
        A human-readable status message.
    """
    if file is None:
        return "No file uploaded."
    uploaded_file_path = file.name
    root, ext = os.path.splitext(uploaded_file_path)
    if ext.lower() == ".docx":
        # Build the target path from splitext rather than str.replace:
        # replace(".docx", ".pdf") misses uppercase extensions (which the
        # lowercased check above accepts) and could mangle a path that
        # contains ".docx" elsewhere.
        pdf_path = root + ".pdf"
        convert(uploaded_file_path, pdf_path)
        rag_system.pdf_dir = os.path.dirname(pdf_path)
    else:
        rag_system.pdf_dir = os.path.dirname(uploaded_file_path)
    rag_system.process_documents()
    return "File uploaded and processed successfully."
| 32 |
+
|
| 33 |
+
# Updated function to handle user queries
|
| 34 |
+
def ask_query(query):
|
| 35 |
+
if query.strip():
|
| 36 |
+
response = rag_system.generate_response(query)
|
| 37 |
+
# Append the query and response to the conversation history
|
| 38 |
+
rag_system.conversation_history.append({"user": query, "system": response})
|
| 39 |
+
audio_path = text_to_speech(response)
|
| 40 |
+
return response, audio_path
|
| 41 |
+
return "Please enter a valid query.", None
|
| 42 |
+
|
| 43 |
+
# Function to convert text to speech and return audio file path
|
| 44 |
+
def text_to_speech(response):
|
| 45 |
+
tts = gTTS(response)
|
| 46 |
+
audio_path = "response_audio.mp3"
|
| 47 |
+
tts.save(audio_path)
|
| 48 |
+
return audio_path
|
| 49 |
+
|
| 50 |
+
# Function to clear the chat history
|
| 51 |
+
def clear_chat():
|
| 52 |
+
rag_system.conversation_history = []
|
| 53 |
+
return "Chat history cleared."
|
| 54 |
+
|
| 55 |
+
# Updated function to download the chat history
|
| 56 |
+
def download_chat():
|
| 57 |
+
chat_history = "\n".join([f"User: {entry['user']}\nSystem: {entry['system']}" for entry in rag_system.conversation_history])
|
| 58 |
+
file_path = "chat_history.txt"
|
| 59 |
+
with open(file_path, "w") as file:
|
| 60 |
+
file.write(chat_history)
|
| 61 |
+
return file_path
|
| 62 |
+
|
| 63 |
+
# --- Custom CSS for modern look ---
|
| 64 |
+
custom_css = '''
|
| 65 |
+
body { font-family: 'Roboto', 'Open Sans', Arial, sans-serif; }
|
| 66 |
+
.gradio-container { background: linear-gradient(135deg, #f8fafc 0%, #e0f7fa 100%); }
|
| 67 |
+
#rag-title { font-size: 2.2rem; font-weight: 700; color: #1e293b; letter-spacing: 1px; display: flex; align-items: center; gap: 0.5em; }
|
| 68 |
+
#rag-title img { height: 2.2rem; vertical-align: middle; }
|
| 69 |
+
.gr-box { border-radius: 12px !important; box-shadow: 0 2px 12px 0 rgba(16, 42, 67, 0.06); border: 1px solid #e0e7ef; }
|
| 70 |
+
.gr-button { background: #14b8a6; color: #fff; border-radius: 8px; font-weight: 600; font-size: 1rem; padding: 0.7em 1.5em; transition: background 0.2s, transform 0.2s; }
|
| 71 |
+
.gr-button:hover, .gr-button:focus { background: #0d9488; transform: scale(1.04); }
|
| 72 |
+
.gr-text-input, .gr-textbox { border-radius: 8px; border: 1.5px solid #cbd5e1; background: #fff; font-size: 1.05rem; }
|
| 73 |
+
.gr-text-input:focus, .gr-textbox:focus { border-color: #14b8a6; box-shadow: 0 0 0 2px #99f6e4; }
|
| 74 |
+
.gr-audio { border-radius: 8px; background: #f1f5f9; }
|
| 75 |
+
.gr-file { border-radius: 8px; border: 1.5px dashed #14b8a6; background: #f0fdfa; }
|
| 76 |
+
.gr-markdown { color: #334155; }
|
| 77 |
+
.fade-in { animation: fadeIn 0.7s; }
|
| 78 |
+
@keyframes fadeIn { from { opacity: 0; } to { opacity: 1; } }
|
| 79 |
+
@media (max-width: 700px) {
|
| 80 |
+
#rag-title { font-size: 1.3rem; }
|
| 81 |
+
.gradio-container { padding: 0.5em; }
|
| 82 |
+
}
|
| 83 |
+
'''
|
| 84 |
+
|
| 85 |
+
# --- Gradio UI with modern design ---
|
| 86 |
+
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as ui:
|
| 87 |
+
# Branding row with logo and title
|
| 88 |
+
with gr.Row():
|
| 89 |
+
gr.Markdown("""
|
| 90 |
+
<div id='rag-title'>
|
| 91 |
+
<img src='https://img.icons8.com/color/48/000000/artificial-intelligence.png' alt='RAG Logo' />
|
| 92 |
+
RAG System UI
|
| 93 |
+
</div>
|
| 94 |
+
""", elem_id="rag-title")
|
| 95 |
+
|
| 96 |
+
with gr.Row(equal_height=True):
|
| 97 |
+
with gr.Column(scale=1, min_width=320):
|
| 98 |
+
file_input = gr.File(label="Upload PDF or DOCX", file_types=[".pdf", ".docx"], elem_classes=["gr-file"])
|
| 99 |
+
upload_button = gr.Button("Upload & Process File", elem_classes=["gr-button"])
|
| 100 |
+
with gr.Column(scale=2, min_width=400):
|
| 101 |
+
query_input = gr.Textbox(label="Ask a Question", placeholder="Type your question here...", lines=1, elem_classes=["gr-text-input"])
|
| 102 |
+
response_output = gr.Textbox(label="RAG Response", lines=6, interactive=False, elem_classes=["gr-textbox", "fade-in"])
|
| 103 |
+
play_button = gr.Audio(label="Play Response", interactive=False, elem_classes=["gr-audio"])
|
| 104 |
+
|
| 105 |
+
with gr.Row():
|
| 106 |
+
clear_button = gr.Button("Clear Chat History", elem_classes=["gr-button"])
|
| 107 |
+
download_button = gr.Button("Download Chat History", elem_classes=["gr-button"])
|
| 108 |
+
download_file = gr.File(label="Download Chat File", elem_classes=["gr-file"])
|
| 109 |
+
|
| 110 |
+
# Bind functions to UI components
|
| 111 |
+
upload_button.click(upload_and_process, inputs=file_input, outputs=None)
|
| 112 |
+
query_input.submit(ask_query, inputs=query_input, outputs=[response_output, play_button])
|
| 113 |
+
clear_button.click(clear_chat, inputs=None, outputs=None)
|
| 114 |
+
download_button.click(download_chat, inputs=None, outputs=download_file)
|
| 115 |
+
|
| 116 |
+
# Launch the Gradio app
|
| 117 |
+
if __name__ == "__main__":
|
| 118 |
+
ui.launch()
|
chat_history.txt
ADDED
|
File without changes
|
embedding__model.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sentence_transformers import SentenceTransformer
|
| 2 |
+
|
| 3 |
+
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
|
| 4 |
+
|
| 5 |
+
sentences = [
|
| 6 |
+
"That is a happy person",
|
| 7 |
+
"That is a happy dog",
|
| 8 |
+
"That is a very happy person",
|
| 9 |
+
"Today is a sunny day"
|
| 10 |
+
]
|
| 11 |
+
embeddings = model.encode(sentences)
|
| 12 |
+
|
| 13 |
+
similarities = model.similarity(embeddings, embeddings)
|
| 14 |
+
print(similarities.shape)
|
| 15 |
+
# [4, 4]
|
gemini_wrapper.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import google.generativeai as genai
|
| 2 |
+
|
| 3 |
+
class GoogleGeminiWrapper:
    """Thin convenience wrapper around the google-generativeai SDK.

    Supports single-turn prompts (ask), multi-turn chat with SDK-managed
    history (chat), and model discovery (list_available_models).
    """

    def __init__(self, api_key: str):
        """
        Initialize the GoogleGeminiWrapper with the API key.

        :param api_key: Your Google Gemini API key.
        """
        self.api_key = api_key
        genai.configure(api_key=self.api_key)
        self.conversation_history = []  # For the new chat method
        self.chat_session = None  # To store the chat session for gemini

    def ask(self, prompt: str, model: str = "gemini-2.0-flash", max_tokens: int = 150, temperature: float = 0.7) -> str:
        """
        Send a prompt to the Google Gemini model and get a response (single turn).

        :param prompt: The input prompt to send to the model.
        :param model: The model to use (default is "gemini-2.0-flash").
        :param max_tokens: The maximum number of tokens to include in the response.
        :param temperature: Sampling temperature (higher values mean more randomness).
        :return: The response from the model as a string.
        """
        try:
            generation_config = {
                "temperature": temperature,
                "max_output_tokens": max_tokens,
            }
            model_instance = genai.GenerativeModel(model_name=model, generation_config=generation_config)
            response = model_instance.generate_content(prompt)
            return response.text.strip()
        # NOTE(review): broad except returns the error as the "answer",
        # so callers cannot distinguish a model reply from a failure.
        except Exception as e:
            return f"An error occurred: {e}"

    def start_chat_session(self, model: str = "gemini-2.0-flash", temperature: float = 0.7, max_tokens: int = 150):
        """
        Starts a new chat session or continues an existing one.
        """
        generation_config = {
            "temperature": temperature,
            "max_output_tokens": max_tokens,
        }
        model_instance = genai.GenerativeModel(model_name=model, generation_config=generation_config)
        # For Gemini, conversation history is managed by the chat object itself.
        # We re-initialize the chat session if one doesn't exist or if we want to start fresh.
        # If you want to persist history across calls to `chat` without explicitly calling reset,
        # you might initialize `self.chat_session` in `__init__` or when `chat` is first called.
        self.chat_session = model_instance.start_chat(history=self.conversation_history)

    def chat(self, prompt: str, model: str = "gemini-2.0-flash", max_tokens: int = 150, temperature: float = 0.7) -> str:
        """
        Send a prompt to the Google Gemini model, maintaining conversation history for context.

        :param prompt: The input prompt to send to the model.
        :param model: The model to use (default is "gemini-2.0-flash").
        :param max_tokens: The maximum number of tokens to include in the response.
        :param temperature: Sampling temperature (higher values mean more randomness).
        :return: The response from the model as a string.
        """
        try:
            # Lazily create the session so the first chat() call "just works".
            if self.chat_session is None:
                self.start_chat_session(model=model, temperature=temperature, max_tokens=max_tokens)

            response = self.chat_session.send_message(prompt)
            assistant_response = response.text.strip()

            # Gemini's chat session object updates its history internally.
            # We can optionally also store it in our self.conversation_history if needed for other purposes
            # or if we want to be able to reconstruct the chat session later.
            # For simplicity here, we rely on the chat_session's internal history.
            # To manually track:
            # self.conversation_history.append({"role": "user", "parts": [prompt]})
            # self.conversation_history.append({"role": "model", "parts": [assistant_response]})

            return assistant_response
        except Exception as e:
            # Reset chat session on error to avoid issues with subsequent calls
            self.chat_session = None
            return f"An error occurred: {e}"

    def reset_conversation(self):
        """
        Reset the conversation history and the chat session.
        """
        self.conversation_history = []
        self.chat_session = None  # Crucial for Gemini to start a fresh chat

    def list_available_models(self):
        """
        Lists available Gemini models.
        :return: A list of available models, or an error string on failure.
        """
        try:
            print("Available Gemini Models:")
            for m in genai.list_models():
                if 'generateContent' in m.supported_generation_methods:
                    print(m.name)
            # NOTE(review): list_models() is called twice (once to print,
            # once to build the result) — two API round-trips.
            return [m.name for m in genai.list_models() if 'generateContent' in m.supported_generation_methods]
        except Exception as e:
            return f"An error occurred while listing models: {e}"
| 103 |
+
|
| 104 |
+
# Example usage (uncomment to test):
if __name__ == "__main__":
    import os

    # SECURITY: never hard-code API keys in source control. The key that
    # was previously committed here must be considered compromised and
    # revoked in the Google AI console.
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        raise SystemExit(
            "GEMINI_API_KEY environment variable is not set. "
            "Export it before running this example."
        )
    wrapper = GoogleGeminiWrapper(api_key)

    # Example 0: List available models
    # print("\nListing available models...")
    # available_models = wrapper.list_available_models()
    # The function already prints, but you can use the returned list if needed
    # print(available_models)

    # Example 1: Simple one-off question
    response_ask = wrapper.ask("What is the largest planet in our solar system?")
    print(f"Ask response: {response_ask}")

    # Example 2: Conversation with history
    # print("\nStarting chat conversation...")
    # response1 = wrapper.chat("Hi, my name is Alex.")
    # print(f"Chat response 1: {response1}")
    # response2 = wrapper.chat("What is my name?")
    # print(f"Chat response 2: {response2}")  # Should remember "Alex"

    # Reset conversation history
    # wrapper.reset_conversation()
    # print("\nConversation reset.")
    # response4 = wrapper.chat("Do you remember my name?")
    # print(f"Chat response 4 (after reset): {response4}")  # Should not remember "Alex"
rag.py
ADDED
|
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
RAG (Retrieval Augmented Generation) System
|
| 4 |
+
-------------------------------------------
|
| 5 |
+
This module implements a RAG system that processes PDF documents,
|
| 6 |
+
uses ChromaDB as a vector database, sentence-transformers for embeddings,
|
| 7 |
+
and Google's Gemini as the main LLM. The system follows a conversational pattern.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import logging
|
| 12 |
+
from typing import List, Dict, Any, Optional
|
| 13 |
+
|
| 14 |
+
# Document processing
|
| 15 |
+
from langchain_community.document_loaders import PyPDFLoader
|
| 16 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 17 |
+
|
| 18 |
+
# Embeddings
|
| 19 |
+
from sentence_transformers import SentenceTransformer
|
| 20 |
+
|
| 21 |
+
# Vector database
|
| 22 |
+
import chromadb
|
| 23 |
+
from chromadb.utils import embedding_functions
|
| 24 |
+
|
| 25 |
+
# For Gemini LLM integration
|
| 26 |
+
from gemini_wrapper import GoogleGeminiWrapper
|
| 27 |
+
|
| 28 |
+
from gtts import gTTS
|
| 29 |
+
|
| 30 |
+
# Configure logging
# Module-wide logger: INFO level with timestamps so document-processing
# progress is visible when running as a script.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
| 33 |
+
|
| 34 |
+
class RAGSystem:
    """
    A Retrieval Augmented Generation system that processes PDF documents,
    stores their embeddings in a vector database, and generates responses
    using the Google Gemini model.
    """

    def __init__(
        self,
        pdf_dir: str,
        gemini_api_key: str,
        embedding_model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        db_directory: str = "./chroma_db"
    ):
        """
        Initialize the RAG system.

        Args:
            pdf_dir: Directory containing PDF documents
            gemini_api_key: API key for Google Gemini
            embedding_model_name: Name of the sentence-transformers model
            chunk_size: Size of text chunks for splitting documents
            chunk_overlap: Overlap between consecutive chunks
            db_directory: Directory to store the ChromaDB database
        """
        self.pdf_dir = pdf_dir
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.db_directory = db_directory

        # Initialize the embedding model
        # NOTE: the model is loaded here AND separately inside the Chroma
        # embedding function below, so it is effectively loaded twice.
        logger.info(f"Loading embedding model: {embedding_model_name}")
        self.embedding_model = SentenceTransformer(embedding_model_name)

        # Initialize the text splitter
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
        )

        # Initialize ChromaDB (persistent on disk, survives restarts)
        logger.info(f"Initializing ChromaDB at {db_directory}")
        self.client = chromadb.PersistentClient(path=db_directory)

        # Create a custom embedding function that uses sentence-transformers
        self.sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name=embedding_model_name
        )

        # Create or get the collection
        self.collection = self.client.get_or_create_collection(
            name="pdf_documents",
            embedding_function=self.sentence_transformer_ef
        )

        # Initialize the Gemini LLM
        logger.info("Initializing Google Gemini")
        self.llm = GoogleGeminiWrapper(api_key=gemini_api_key)

        # Load conversation history (list of {"user": ..., "system": ...})
        self.conversation_history = []

    def process_documents(self) -> None:
        """
        Process all PDF documents in the specified directory,
        split them into chunks, generate embeddings, and store in ChromaDB.
        """
        logger.info(f"Processing documents from: {self.pdf_dir}")

        # Check if documents are already processed
        # NOTE(review): this early return means that once the collection is
        # non-empty, newly uploaded files (see app.py upload flow) are never
        # indexed. A real incremental-update path would be needed; it also
        # can't simply be removed because chunk ids restart at chunk_0 and
        # would collide with existing entries.
        if self.collection.count() > 0:
            logger.info(f"Found {self.collection.count()} existing document chunks in the database")
            return

        # Process each PDF file in the directory
        # NOTE(review): endswith('.pdf') is case-sensitive; ".PDF" files
        # are silently skipped.
        pdf_files = [f for f in os.listdir(self.pdf_dir) if f.endswith('.pdf')]
        if not pdf_files:
            logger.warning(f"No PDF files found in {self.pdf_dir}")
            return

        logger.info(f"Found {len(pdf_files)} PDF files")

        doc_chunks = []
        metadatas = []
        ids = []
        chunk_idx = 0  # global counter so ids are unique across all files

        for pdf_file in pdf_files:
            pdf_path = os.path.join(self.pdf_dir, pdf_file)
            logger.info(f"Processing: {pdf_path}")

            # Load PDF
            loader = PyPDFLoader(pdf_path)
            documents = loader.load()

            # Split documents into chunks
            chunks = self.text_splitter.split_documents(documents)
            logger.info(f"Split {pdf_file} into {len(chunks)} chunks")

            # Prepare data for ChromaDB
            for chunk in chunks:
                doc_chunks.append(chunk.page_content)
                metadatas.append({
                    "source": pdf_file,
                    "page": chunk.metadata.get("page", 0),
                })
                ids.append(f"chunk_{chunk_idx}")
                chunk_idx += 1

        # Add documents to ChromaDB (embeddings computed by the collection's
        # embedding function at insert time)
        if doc_chunks:
            logger.info(f"Adding {len(doc_chunks)} chunks to ChromaDB")
            self.collection.add(
                documents=doc_chunks,
                metadatas=metadatas,
                ids=ids
            )
            logger.info("Documents successfully processed and stored")
        else:
            logger.warning("No document chunks were generated")

    def retrieve_relevant_chunks(self, query: str, k: int = 3) -> List[Dict[str, Any]]:
        """
        Retrieve the k most relevant document chunks for a given query.

        Args:
            query: The query text
            k: Number of relevant chunks to retrieve

        Returns:
            List of relevant document chunks with their metadata
        """
        logger.info(f"Retrieving {k} relevant chunks for query: {query}")
        results = self.collection.query(
            query_texts=[query],
            n_results=k
        )

        # Chroma returns parallel per-query lists; we issued one query, so
        # index [0] everywhere.
        relevant_chunks = []
        if results and results["documents"] and results["documents"][0]:
            for i, doc in enumerate(results["documents"][0]):
                relevant_chunks.append({
                    "content": doc,
                    "metadata": results["metadatas"][0][i] if results["metadatas"] and results["metadatas"][0] else {},
                    "id": results["ids"][0][i] if results["ids"] and results["ids"][0] else f"unknown_{i}"
                })

        return relevant_chunks

    def generate_response(self, query: str, k: int = 3) -> str:
        """
        Generate a response for a user query using RAG.

        Args:
            query: User query
            k: Number of relevant chunks to retrieve

        Returns:
            Generated response from the LLM
        """
        # Retrieve relevant document chunks
        relevant_chunks = self.retrieve_relevant_chunks(query, k=k)

        if not relevant_chunks:
            logger.warning("No relevant chunks found for the query")
            return "I couldn't find relevant information to answer your question."

        # Format context from retrieved chunks, labelling source/page so the
        # model can cite them.
        context = "\n\n".join([f"Document {i+1} (from {chunk['metadata'].get('source', 'unknown')}, page {chunk['metadata'].get('page', 'unknown')}):\n{chunk['content']}"
                               for i, chunk in enumerate(relevant_chunks)])

        # Create prompt for the LLM
        prompt = f"""
        You are a helpful assistant that answers questions based on the provided context.

        CONTEXT:
        {context}

        QUESTION:
        {query}

        Please provide a comprehensive and accurate answer based only on the information in the provided context.
        If the context doesn't contain enough information to answer the question, please say so.
        """

        # Generate response using Gemini (low temperature for factual answers)
        response = self.llm.ask(prompt, max_tokens=500, temperature=0.3)
        return response

    def chat(self, user_input: str = None) -> Optional[str]:
        """
        Conduct a conversation with the user using the RAG system.

        Args:
            user_input: User's input. If None, starts a new conversation.

        Returns:
            System's response or None to exit
        """
        if user_input is None:
            # Initialize conversation (blocks on stdin)
            print("RAG System Initialized. Type 'exit' or 'quit' to end the conversation.")
            user_input = input("You: ")

        if user_input.lower() in ['exit', 'quit']:
            print("Ending conversation. Goodbye!")
            return None

        # Generate response using RAG
        response = self.generate_response(user_input)

        # Update conversation history
        self.conversation_history.append({"user": user_input, "system": response})

        return response

    def interactive_session(self) -> None:
        """
        Start an interactive chat session with the RAG system.

        NOTE: unlike chat(), this loop does not append to
        conversation_history.
        """
        print("Welcome to the RAG System!")
        print("Type 'exit' or 'quit' to end the conversation.")

        while True:
            user_input = input("\nYou: ")

            if user_input.lower() in ['exit', 'quit']:
                print("Ending conversation. Goodbye!")
                break

            response = self.generate_response(user_input)
            print(f"\nRAG System: {response}")
| 268 |
+
|
| 269 |
+
# Function to convert text to speech
|
| 270 |
+
def text_to_speech(response):
|
| 271 |
+
tts = gTTS(response)
|
| 272 |
+
audio_path = "response_audio.mp3"
|
| 273 |
+
tts.save(audio_path)
|
| 274 |
+
return audio_path
|
| 275 |
+
|
| 276 |
+
def main():
|
| 277 |
+
"""
|
| 278 |
+
Main function to demonstrate the RAG system.
|
| 279 |
+
"""
|
| 280 |
+
# Attempt to get the Gemini API key from environment variable
|
| 281 |
+
gemini_api_key = os.getenv("GEMINI_API_KEY")
|
| 282 |
+
|
| 283 |
+
if not gemini_api_key:
|
| 284 |
+
# If environment variable is not set or is empty, fallback to the hardcoded key
|
| 285 |
+
hardcoded_api_key = "AIzaSyBisxoehBz8UF0i9kX42f1V3jp-9RNq04g" # Your hardcoded key
|
| 286 |
+
# Check if the environment variable was truly not set (vs. set to an empty string)
|
| 287 |
+
# to decide if we should print the INFO message.
|
| 288 |
+
if os.getenv("GEMINI_API_KEY") is None: # More specific check for unset env variable
|
| 289 |
+
print("INFO: GEMINI_API_KEY environment variable not found. Using hardcoded API key from rag.py.")
|
| 290 |
+
gemini_api_key = hardcoded_api_key
|
| 291 |
+
|
| 292 |
+
# Final check: if the key is still not set (e.g. if hardcoded key was also empty or None)
|
| 293 |
+
if not gemini_api_key:
|
| 294 |
+
print("Error: Gemini API key is not set.")
|
| 295 |
+
print("Please set the GEMINI_API_KEY environment variable, or ensure it's correctly hardcoded in rag.py.")
|
| 296 |
+
print("To set as environment variable:")
|
| 297 |
+
print(" export GEMINI_API_KEY='your_api_key' # For Linux/macOS")
|
| 298 |
+
print(" set GEMINI_API_KEY=your_api_key # For Windows CMD")
|
| 299 |
+
print(" $env:GEMINI_API_KEY='your_api_key' # For Windows PowerShell")
|
| 300 |
+
return
|
| 301 |
+
|
| 302 |
+
# Set paths
|
| 303 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 304 |
+
pdf_dir = os.path.join(current_dir, "material")
|
| 305 |
+
db_dir = os.path.join(current_dir, "chroma_db")
|
| 306 |
+
|
| 307 |
+
# Initialize the RAG system
|
| 308 |
+
rag = RAGSystem(
|
| 309 |
+
pdf_dir=pdf_dir,
|
| 310 |
+
gemini_api_key=gemini_api_key,
|
| 311 |
+
db_directory=db_dir
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
# Process documents
|
| 315 |
+
rag.process_documents()
|
| 316 |
+
|
| 317 |
+
# Start interactive session
|
| 318 |
+
rag.interactive_session()
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
if __name__ == "__main__":
|
| 322 |
+
main()
|
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio
|
| 2 |
+
python-dotenv
|
| 3 |
+
gtts
|
| 4 |
+
docx2pdf
|
| 5 |
+
langchain
|
| 6 |
+
langchain-community
|
| 7 |
+
sentence-transformers
|
| 8 |
+
chromadb
|
| 9 |
+
pypdf
|
| 10 |
+
google-generativeai
|