File size: 13,231 Bytes
6f70a16 2482b7b 6f70a16 b912b67 6f70a16 9356909 6f70a16 9356909 6f70a16 ffdb38f 6f70a16 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 | import os
import gradio as gr
from groq import Groq
import torch # For checking CUDA availability for embedding model
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader # For PDF loading
from langchain_community.embeddings import HuggingFaceEmbeddings # For open-source embeddings
from langchain_community.vectorstores import FAISS # For vector database
# Removed: from google.colab import userdata # This library is specific to Google Colab
import gc # For garbage collection, useful in Colab/Spaces
# --- Configuration & Global Variables ---
# The Groq API key is deliberately NOT stored in this file: it is read from the
# GROQ_API_KEY environment variable (set it in the Space's Repository Secrets).

# Chat model served through the Groq API.
GROQ_MODEL = "llama-3.3-70b-versatile"

# Open-source sentence-embedding model used to vectorise PDF chunks;
# all-MiniLM-L6-v2 trades a little accuracy for a small, fast model.
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

# Run the embedding model on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    EMBEDDING_DEVICE = "cuda"
else:
    EMBEDDING_DEVICE = "cpu"

# --- Mutable module state ---
# Filled in by initialize_groq_client() / initialize_embedding_model();
# None means the corresponding resource is unavailable.
groq_client = None
embedding_model = None
# Most recently built FAISS index; None until a PDF has been processed.
faiss_vector_store = None
# Role/content log of chat turns (user and assistant messages).
llm_chat_history = []
# --- Initialization Functions ---
def initialize_groq_client():
    """Create the shared Groq client from the API key held in the environment.

    The key is read from ``GROQ_API_KEY``, the variable named in the error
    messages below and in the Space setup instructions. ``GROK_API_KEY`` is a
    legacy misspelling that is still accepted as a fallback, so deployments
    that configured that name keep working. Surrounding whitespace is stripped,
    and a blank value counts as missing.

    On success the module-level ``groq_client`` holds a ``Groq`` instance. On
    any failure the problem is printed and ``groq_client`` is set to ``None``;
    nothing is raised to the caller.
    """
    global groq_client
    try:
        # Prefer the documented variable name; the misspelled one is only a
        # backward-compatible fallback. Each candidate is stripped on its own so
        # a blank GROQ_API_KEY does not hide a valid fallback value.
        groq_api_key = ""
        for env_name in ("GROQ_API_KEY", "GROK_API_KEY"):
            groq_api_key = os.environ.get(env_name, "").strip()
            if groq_api_key:
                break
        if not groq_api_key:
            raise ValueError("GROQ_API_KEY environment variable is not set. Please add it to your Hugging Face Space's Repository Secrets.")
        groq_client = Groq(api_key=groq_api_key)
        print("Groq client initialized successfully.")
    except ValueError as ve:
        print(f"ERROR: Groq client initialization failed: {ve}")
        print("ACTION REQUIRED: Ensure 'GROQ_API_KEY' is set correctly in your Hugging Face Space's Repository Secrets.")
        groq_client = None
    except Exception as e:
        # Anything else (network/SDK problems) also leaves the app running
        # without a client; chat_with_rag reports the missing client to users.
        print(f"ERROR: An unexpected error occurred during Groq client initialization: {e}")
        groq_client = None
def initialize_embedding_model():
    """Load the sentence-transformer model used to embed PDF chunks.

    The result is stored in the module-level ``embedding_model``. If loading
    raises for any reason, the error and a troubleshooting hint are printed and
    ``embedding_model`` is set to ``None``; the exception is not re-raised.
    """
    global embedding_model
    try:
        print(f"Loading embedding model: {EMBEDDING_MODEL_NAME} on {EMBEDDING_DEVICE}...")
        model_options = {'device': EMBEDDING_DEVICE}
        # Unit-length output vectors, as recommended for cosine-similarity search.
        encode_options = {'normalize_embeddings': True}
        embedding_model = HuggingFaceEmbeddings(
            model_name=EMBEDDING_MODEL_NAME,
            model_kwargs=model_options,
            encode_kwargs=encode_options,
        )
        print("Embedding model initialized successfully.")
    except Exception as load_error:
        print(f"ERROR: Embedding model initialization failed: {load_error}")
        print("ACTION REQUIRED: Check network connection and ensure 'transformers' library dependencies are met. Consider Space GPU availability if using 'cuda'.")
        embedding_model = None
# Build the shared Groq client and embedding model a single time, at import,
# so the request handlers below reuse them instead of reloading per call.
for _startup_step in (initialize_groq_client, initialize_embedding_model):
    _startup_step()
del _startup_step
# --- PDF Processing and FAISS Indexing Function ---
def process_pdf_and_create_index(pdf_file: gr.File):
    """Load a PDF, split it into chunks, embed them and build a FAISS index.

    Args:
        pdf_file: The value Gradio passes from the ``gr.File`` component. Newer
            Gradio versions pass a file path (``str``); older ones pass a
            temp-file wrapper exposing the path as ``.name``. Both are accepted.

    Returns:
        tuple: ``(status_message, vector_store)``. ``vector_store`` is the new
        ``FAISS`` index on success and ``None`` otherwise.

    Side effects:
        On success, the index is also stored in the module-level
        ``faiss_vector_store`` and ``llm_chat_history`` is cleared. On a
        processing failure, ``faiss_vector_store`` is reset to ``None``.
    """
    global faiss_vector_store
    if embedding_model is None:
        return "Error: Embedding model not loaded. Cannot process PDF.", None
    if pdf_file is None:
        return "Please upload a PDF file to process.", None

    # gr.File(type="filepath") yields a plain path string; reading `.name` on a
    # str raised AttributeError here, outside the try below, so it escaped unhandled.
    if isinstance(pdf_file, (str, os.PathLike)):
        file_path = os.fspath(pdf_file)
    else:
        file_path = pdf_file.name
    print(f"Processing PDF: {file_path}")

    try:
        # 1. Load the PDF (one Document per page).
        documents = PyPDFLoader(file_path).load()
        print(f"Loaded {len(documents)} pages from PDF.")

        # 2. Split into overlapping chunks so retrieval can return focused passages.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
            add_start_index=True,
        )
        chunks = text_splitter.split_documents(documents)
        print(f"Split document into {len(chunks)} chunks.")

        # Image-only or empty PDFs produce no text; FAISS cannot index an empty
        # list and would fail with an unhelpful error, so report the real cause.
        if not chunks:
            faiss_vector_store = None
            return (
                "Error processing PDF: no extractable text was found. "
                "The file may be empty or contain only scanned images.",
                None,
            )

        # 3. Embed the chunks and build the in-memory FAISS index.
        print("Creating embeddings and building FAISS index... This may take a while.")
        vector_store = FAISS.from_documents(chunks, embedding_model)
        faiss_vector_store = vector_store
        print("FAISS index created successfully!")

        # A new document starts a new conversation log.
        llm_chat_history.clear()
        gc.collect()  # Release intermediate objects from loading/splitting.
        return "PDF processed successfully! You can now start chatting.", vector_store
    except Exception as e:
        print(f"ERROR during PDF processing: {e}")
        faiss_vector_store = None  # Never leave a half-built or stale index behind.
        return f"Error processing PDF: {e}. Please ensure it's a valid PDF.", None
# --- Chat Function with RAG ---
def _record_llm_turn(user_query, bot_response):
    """Append one user/assistant exchange to the module-level ``llm_chat_history`` log."""
    llm_chat_history.append({"role": "user", "content": user_query})
    llm_chat_history.append({"role": "assistant", "content": bot_response})


def _build_augmented_query(context_text, user_query):
    """Wrap the user's question in a prompt that asks for an answer grounded in ``context_text``."""
    return (
        "Based on the following context, answer the question. "
        "If the answer is not in the context, state that you don't have enough information.\n\n"
        f"Context:\n{context_text}\n\n"
        f"Question: {user_query}"
    )


def _history_to_groq_messages(chat_history, max_turns=None):
    """Convert Gradio ``[[user, bot], ...]`` pairs to Groq role/content messages.

    Empty sides of a pair are skipped. ``max_turns`` keeps only that many of the
    most recent pairs (``None`` keeps all of them; ``0`` or less keeps none).
    """
    if max_turns is not None:
        # Explicit branch: chat_history[-0:] would return the whole list.
        chat_history = chat_history[-max_turns:] if max_turns > 0 else []
    messages = []
    for human_msg, ai_msg in chat_history:
        if human_msg:
            messages.append({"role": "user", "content": human_msg})
        if ai_msg:
            messages.append({"role": "assistant", "content": ai_msg})
    return messages


def chat_with_rag(user_query: str, chat_history: list, current_vector_store: FAISS, top_k=4, max_history_turns=None):
    """Answer ``user_query`` with Groq, grounded in passages retrieved from the PDF index.

    Args:
        user_query (str): The current message typed by the user.
        chat_history (list): Gradio Chatbot history as ``[user_text, bot_text]`` pairs.
            This history (not ``llm_chat_history``) is what is sent to the LLM as
            conversational context.
        current_vector_store (FAISS): The session's FAISS index, or ``None`` if no
            PDF has been processed yet.
        top_k (int): Number of chunks retrieved as context (default 4).
        max_history_turns (int | None): If set, only the most recent N exchanges of
            ``chat_history`` are sent to the LLM, which keeps the request size bounded
            in long conversations. ``None`` (default) sends the full history.

    Returns:
        tuple: ``(updated_chat_history, "")``. The first element is a new list with
        this turn appended; the empty string clears the input textbox. A blank
        query returns ``chat_history`` unchanged.
    """
    chat_history = chat_history or []

    # Nothing to ask: don't spend a retrieval + LLM call on an empty message.
    if not user_query or not user_query.strip():
        return chat_history, ""

    if current_vector_store is None:
        bot_response = "Please upload and process a PDF document first before asking questions."
        _record_llm_turn(user_query, bot_response)
        return chat_history + [[user_query, bot_response]], ""
    if groq_client is None:
        bot_response = "Groq client not initialized. Cannot generate response. Check API key setup."
        _record_llm_turn(user_query, bot_response)
        return chat_history + [[user_query, bot_response]], ""

    print(f"User Query: {user_query}")
    try:
        # 1. Retrieve the chunks most similar to the question.
        retrieved_docs = current_vector_store.similarity_search(user_query, k=top_k)
        context_text = "\n\n".join(doc.page_content for doc in retrieved_docs)
        print(f"Retrieved Context:\n{context_text[:500]}...")  # First 500 chars only.

        # 2. Prior turns are sent verbatim; only the newest user message
        #    carries the retrieved context.
        groq_messages = _history_to_groq_messages(chat_history, max_history_turns)
        groq_messages.append({"role": "user", "content": _build_augmented_query(context_text, user_query)})

        # 3. Generate the answer.
        chat_completion = groq_client.chat.completions.create(
            messages=groq_messages,
            model=GROQ_MODEL,
            temperature=0.7,
            max_tokens=1024,
            top_p=1,
            stop=None,
            stream=False,
        )
        bot_response = chat_completion.choices[0].message.content
        print(f"Groq Response: {bot_response}")
    except Exception as e:
        # Show the failure in the chat instead of crashing the Gradio handler.
        bot_response = f"An error occurred during RAG process: {e}. Please try again."
        print(f"RAG Chat Error: {e}")

    # The original (un-augmented) query is logged, not the context-laden prompt.
    _record_llm_turn(user_query, bot_response)
    return chat_history + [[user_query, bot_response]], ""
# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), title="RAG PDF Chatbot") as demo:
    gr.Markdown(
        """
        # 📚 Sheikh's -- You can Chat with your PDF (RAG Application) 💬
        Upload a PDF document, wait for it to process, and then ask questions about its content!
        Powered by open-source models.
        """
    )

    # Per-session holder for the FAISS index returned by process_pdf_and_create_index.
    # It starts as None, and chat_with_rag asks the user to upload a PDF until it is set.
    vector_store_state = gr.State(None)

    with gr.Row():
        with gr.Column(scale=1):
            pdf_upload_input = gr.File(
                label="Upload your PDF Document and free IK",
                file_types=[".pdf"],
                file_count="single"
            )
            process_pdf_btn = gr.Button("Process PDF 🚀")
            pdf_process_status = gr.Textbox(
                label="PDF Processing Status",
                interactive=False,
                lines=1
            )
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                label="Conversation History",
                value=[],
                height=400,
                show_copy_button=True
            )
            text_input = gr.Textbox(
                label="Type your question here",
                placeholder="Ask me about the PDF content...",
                lines=3
            )
            with gr.Row():
                submit_btn = gr.Button("Send Message ➡️")
                clear_chat_btn = gr.Button("Clear Chat 🗑️")

    # --- Event Handlers ---
    # 1. "Process PDF": build the index and store it in the session state.
    process_pdf_btn.click(
        fn=process_pdf_and_create_index,
        inputs=[pdf_upload_input],
        outputs=[pdf_process_status, vector_store_state],
        api_name="process_pdf"
    ).then(
        # chat_with_rag rebuilds the LLM context from this chatbot's history, so
        # once a new index exists the conversation about the previous document
        # must be dropped. If processing failed (state is None), keep the chat.
        fn=lambda store, history: [] if store is not None else history,
        inputs=[vector_store_state, chatbot],
        outputs=[chatbot],
        queue=False
    )

    # 2 & 3. "Send Message" button and Enter key share the same RAG handler.
    chat_inputs = [text_input, chatbot, vector_store_state]  # Query, history, index.
    chat_outputs = [chatbot, text_input]  # Updated history, cleared textbox.
    submit_btn.click(
        fn=chat_with_rag,
        inputs=chat_inputs,
        outputs=chat_outputs,
        api_name="send_message_button"
    )
    text_input.submit(
        fn=chat_with_rag,
        inputs=chat_inputs,
        outputs=chat_outputs,
        api_name="send_message_enter"
    )

    # 4. "Clear Chat": empty the chatbot and textbox, then the global LLM log.
    clear_chat_btn.click(
        fn=lambda: ([], ""),
        inputs=[],
        outputs=[chatbot, text_input],
        queue=False
    ).success(
        fn=lambda: llm_chat_history.clear(),
        inputs=[],
        outputs=[]
    )
def _launch_app():
    """Start the Gradio server for ``demo`` with default settings.

    Hugging Face Spaces exposes the app publicly on its own, so ``share=True``
    is not passed.
    """
    demo.launch()


# Only start the server when this file is executed directly, not on import.
if __name__ == "__main__":
    _launch_app()
|