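# Provenance (scraped from the Hugging Face Space page): author forever-sheikh,
# commit 2482b7b "Update app.py" (verified). Kept as a comment so the file parses.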
import os
import gradio as gr
from groq import Groq
import torch # For checking CUDA availability for embedding model
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader # For PDF loading
from langchain_community.embeddings import HuggingFaceEmbeddings # For open-source embeddings
from langchain_community.vectorstores import FAISS # For vector database
# Removed: from google.colab import userdata # This library is specific to Google Colab
import gc # For garbage collection, useful in Colab/Spaces
# --- Configuration & Global Variables ---
# The Groq API key is never hard-coded here: initialize_groq_client() reads it
# from an environment variable (on Hugging Face, set it via Repository Secrets).

# Chat model served through the Groq API.
GROQ_MODEL = "llama-3.3-70b-versatile"

# Open-source sentence-transformers model used to embed PDF chunks and queries.
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
# Run the embedder on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    EMBEDDING_DEVICE = "cuda"
else:
    EMBEDDING_DEVICE = "cpu"

# Process-wide singletons; all start empty and are filled in later.
groq_client = None         # set by initialize_groq_client() on success
embedding_model = None     # set by initialize_embedding_model() on success
faiss_vector_store = None  # last FAISS index built by process_pdf_and_create_index()
# {"role": ..., "content": ...} records appended for every chat turn.
llm_chat_history = []
# --- Initialization Functions ---
def initialize_groq_client():
    """Create the module-level Groq client from the API key in the environment.

    The key is read from ``GROQ_API_KEY`` (the name used throughout this app's
    messages and by the Groq SDK). ``GROK_API_KEY``, the misspelled name this
    function used to read, is still accepted as a fallback so existing Space
    secrets keep working.

    Never raises: on any failure the error is printed and ``groq_client`` is
    left as ``None``.
    """
    global groq_client
    try:
        # Prefer the correctly spelled variable; fall back to the legacy typo.
        groq_api_key = os.environ.get("GROQ_API_KEY") or os.environ.get("GROK_API_KEY")
        if not groq_api_key:
            raise ValueError("GROQ_API_KEY environment variable is not set. Please add it to your Hugging Face Space's Repository Secrets.")
        groq_client = Groq(api_key=groq_api_key)
        print("Groq client initialized successfully.")
    except ValueError as ve:
        print(f"ERROR: Groq client initialization failed: {ve}")
        print("ACTION REQUIRED: Ensure 'GROQ_API_KEY' is set correctly in your Hugging Face Space's Repository Secrets.")
        groq_client = None
    except Exception as e:
        print(f"ERROR: An unexpected error occurred during Groq client initialization: {e}")
        groq_client = None
def initialize_embedding_model():
    """Load the HuggingFace embedding model into the module-level ``embedding_model``.

    The model named by EMBEDDING_MODEL_NAME is placed on EMBEDDING_DEVICE and
    configured to return normalized embeddings. Never raises: any failure is
    printed together with a remediation hint, and ``embedding_model`` is set to None.
    """
    global embedding_model
    try:
        print(f"Loading embedding model: {EMBEDDING_MODEL_NAME} on {EMBEDDING_DEVICE}...")
        device_settings = {'device': EMBEDDING_DEVICE}
        # Normalized vectors, as recommended for cosine-similarity search.
        encoding_settings = {'normalize_embeddings': True}
        embedding_model = HuggingFaceEmbeddings(
            model_name=EMBEDDING_MODEL_NAME,
            model_kwargs=device_settings,
            encode_kwargs=encoding_settings,
        )
        print("Embedding model initialized successfully.")
    except Exception as exc:
        print(f"ERROR: Embedding model initialization failed: {exc}")
        print("ACTION REQUIRED: Check network connection and ensure 'transformers' library dependencies are met. Consider Space GPU availability if using 'cuda'.")
        embedding_model = None
# Build the shared Groq client and embedding model once, at import time, so every
# request reuses them. Runs the Groq initializer first, then the embedding one.
for _initializer in (initialize_groq_client, initialize_embedding_model):
    _initializer()
del _initializer  # don't leave the loop variable behind as a module global
# --- PDF Processing and FAISS Indexing Function ---
def process_pdf_and_create_index(pdf_file: "gr.File"):
    """
    Load a PDF, split it into overlapping chunks, embed them and build a FAISS index.

    Args:
        pdf_file: The value delivered by the ``gr.File`` upload component. Either a
            path string or an object exposing the temporary file path as ``.name``
            is accepted. Which one Gradio passes depends on its version / ``type=``.

    Returns:
        tuple: ``(status_message, FAISS_vector_store_object)``. The second item is
        ``None`` whenever no index was built.

    Side effects:
        Sets the module-level ``faiss_vector_store``, resetting it to ``None`` if
        loading or indexing fails. Clears ``llm_chat_history`` after a successful
        build, so the new document starts with a fresh conversation log.
    """
    global faiss_vector_store
    global llm_chat_history
    if embedding_model is None:
        return "Error: Embedding model not loaded. Cannot process PDF.", None
    if pdf_file is None:
        return "Please upload a PDF file to process.", None
    # A plain str path has no `.name` attribute, so fall back to the value itself.
    # This lookup sits outside the try below, so it must not raise.
    file_path = getattr(pdf_file, "name", pdf_file)
    print(f"Processing PDF: {file_path}")
    try:
        # 1. Load the PDF (the loader's documents are reported as pages).
        loader = PyPDFLoader(file_path)
        documents = loader.load()
        print(f"Loaded {len(documents)} pages from PDF.")
        # 2. Split into ~1000-character chunks. The 200-character overlap means
        #    text cut at a chunk boundary also appears in the neighbouring chunk.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
            add_start_index=True,
        )
        chunks = text_splitter.split_documents(documents)
        print(f"Split document into {len(chunks)} chunks.")
        if not chunks:
            # No extractable text (e.g. a scanned, image-only PDF). FAISS cannot
            # build an index from zero chunks, so report the real cause plainly
            # instead of surfacing an opaque indexing error.
            faiss_vector_store = None
            return (
                "Error processing PDF: no extractable text was found. "
                "Scanned or image-only PDFs are not supported.",
                None,
            )
        # 3. Embed every chunk and build the FAISS index.
        print("Creating embeddings and building FAISS index... This may take a while.")
        faiss_vector_store = FAISS.from_documents(chunks, embedding_model)
        print("FAISS index created successfully!")
        # A new document means old conversation context no longer applies.
        llm_chat_history.clear()
        gc.collect()  # release intermediate loader/splitter objects promptly
        return "PDF processed successfully! You can now start chatting.", faiss_vector_store
    except Exception as e:
        print(f"ERROR during PDF processing: {e}")
        faiss_vector_store = None  # never leave a stale index from a failed build
        return f"Error processing PDF: {e}. Please ensure it's a valid PDF.", None
# --- Chat Function with RAG ---
def _record_and_reply(chat_history: list, user_query: str, bot_response: str):
    """Log one user/assistant turn in llm_chat_history and build the Gradio outputs.

    Returns ``(chat_history + [[user_query, bot_response]], "")``: the new
    chatbot value and an empty string that clears the input textbox.
    """
    llm_chat_history.append({"role": "user", "content": user_query})
    llm_chat_history.append({"role": "assistant", "content": bot_response})
    return chat_history + [[user_query, bot_response]], ""


def chat_with_rag(user_query: str, chat_history: list, current_vector_store: "FAISS"):
    """
    Answer a question with Groq, grounded in passages retrieved from the FAISS index.

    Args:
        user_query (str): The current message from the user. Blank or
            whitespace-only messages are ignored.
        chat_history (list): Gradio chatbot history as ``[user_text, bot_text]``
            pairs. ``None`` is treated as empty. The input list is not mutated.
        current_vector_store (FAISS): Index built from the uploaded PDF, or
            ``None`` if no PDF has been processed yet.

    Returns:
        tuple: ``(updated_gradio_chat_history, "")``. The empty string clears the
        input textbox. Errors are reported as the assistant's reply, not raised.
    """
    chat_history = list(chat_history or [])
    # Don't run retrieval or call the LLM for an empty question; leave the UI unchanged.
    if not user_query or not user_query.strip():
        return chat_history, ""
    if current_vector_store is None:
        return _record_and_reply(
            chat_history, user_query,
            "Please upload and process a PDF document first before asking questions.",
        )
    if groq_client is None:
        return _record_and_reply(
            chat_history, user_query,
            "Groq client not initialized. Cannot generate response. Check API key setup.",
        )
    print(f"User Query: {user_query}")
    try:
        # 1. Retrieve the k most similar chunks; raise k for broader context.
        retrieved_docs = current_vector_store.similarity_search(user_query, k=4)
        context_text = "\n\n".join(doc.page_content for doc in retrieved_docs)
        print(f"Retrieved Context:\n{context_text[:500]}...")  # first 500 chars only
        # 2. Wrap the question in the retrieved context and tell the model to
        #    admit when the answer is not in it.
        augmented_query = (
            f"Based on the following context, answer the question. "
            f"If the answer is not in the context, state that you don't have enough information.\n\n"
            f"Context:\n{context_text}\n\n"
            f"Question: {user_query}"
        )
        # 3. Replay prior turns as plain messages (empty sides skipped); only the
        #    latest user message carries the retrieved context.
        groq_messages = []
        for human_msg, ai_msg in chat_history:
            if human_msg:
                groq_messages.append({"role": "user", "content": human_msg})
            if ai_msg:
                groq_messages.append({"role": "assistant", "content": ai_msg})
        groq_messages.append({"role": "user", "content": augmented_query})
        # 4. Generate the answer.
        chat_completion = groq_client.chat.completions.create(
            messages=groq_messages,
            model=GROQ_MODEL,
            temperature=0.7,
            max_tokens=1024,
            top_p=1,
            stop=None,
            stream=False,
        )
        bot_response = chat_completion.choices[0].message.content
        print(f"Groq Response: {bot_response}")
    except Exception as e:
        print(f"RAG Chat Error: {e}")
        return _record_and_reply(
            chat_history, user_query,
            f"An error occurred during RAG process: {e}. Please try again.",
        )
    # Log the original (un-augmented) query, not the context-wrapped prompt.
    return _record_and_reply(chat_history, user_query, bot_response)
# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), title="RAG PDF Chatbot") as demo:
    # Page header rendered above the upload and chat columns.
    # (The Markdown text is kept flush-left on purpose: indented lines would
    # change the rendered string.)
    gr.Markdown(
        """
# 📚 Sheikh's -- You can Chat with your PDF (RAG Application) 💬
Upload a PDF document, wait for it to process, and then ask questions about its content!
Powered by open-source models.
"""
    )
    # gr.State holding the FAISS index across chat turns. Its initial value is the
    # module-level faiss_vector_store, which is still None when this UI is built.
    # The second output of process_pdf_and_create_index replaces it.
    # NOTE(review): process_pdf_and_create_index also writes the module global,
    # which presumably is shared by all sessions, while this State is per
    # session — confirm only the State is relied upon.
    vector_store_state = gr.State(faiss_vector_store)
    with gr.Row():
        # Left column: PDF upload, "Process PDF" trigger and status text.
        with gr.Column(scale=1):
            pdf_upload_input = gr.File(
                label="Upload your PDF Document and free IK",
                file_types=[".pdf"],
                file_count="single"
            )
            process_pdf_btn = gr.Button("Process PDF 🚀")
            pdf_process_status = gr.Textbox(
                label="PDF Processing Status",
                interactive=False,
                lines=1
            )
            # TODO: no progress indicator is wired up yet; the line below is an
            # unused placeholder idea for long PDFs.
            # progress_bar = gr.Progress(label="Processing Progress")
        # Right column: conversation display, question box and action buttons.
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                label="Conversation History",
                value=[],
                height=400,
                show_copy_button=True
            )
            text_input = gr.Textbox(
                label="Type your question here",
                placeholder="Ask me about the PDF content...",
                lines=3
            )
            with gr.Row():
                submit_btn = gr.Button("Send Message ➡️")
                clear_chat_btn = gr.Button("Clear Chat 🗑️")
    # --- Event Handlers ---
    # 1. "Process PDF" builds the index and stores it in vector_store_state.
    # NOTE(review): the chatbot display is not an output here, so old messages
    # stay visible after a new PDF is processed, even though llm_chat_history
    # is cleared.
    process_pdf_btn.click(
        fn=process_pdf_and_create_index,
        inputs=[pdf_upload_input],
        outputs=[pdf_process_status, vector_store_state], # Update status and the state variable
        # Also exposes this handler as the named API endpoint "process_pdf".
        api_name="process_pdf"
    )
    # 2. "Send Message" button: answer the question, append it to the chatbot
    #    and clear the textbox (chat_with_rag returns "" as its second output).
    submit_btn.click(
        fn=chat_with_rag,
        inputs=[text_input, chatbot, vector_store_state], # Pass query, current chat history, and the vector store state
        outputs=[chatbot, text_input], # Update chat history and clear text input
        api_name="send_message_button"
    )
    # 3. Pressing Enter in the textbox does the same as the Send button,
    #    under its own API name.
    text_input.submit(
        fn=chat_with_rag,
        inputs=[text_input, chatbot, vector_store_state], # Pass query, current chat history, and the vector store state
        outputs=[chatbot, text_input], # Update chat history and clear text input
        api_name="send_message_enter"
    )
    # 4. "Clear Chat" empties the chatbot and textbox (queue=False, so it is not
    #    queued behind other events). Only if that step succeeds is the global
    #    llm_chat_history cleared. vector_store_state is not touched, so the
    #    processed PDF stays loaded.
    clear_chat_btn.click(
        fn=lambda: ([], ""), # Clear chatbot display and text input box
        inputs=[],
        outputs=[chatbot, text_input],
        queue=False
    ).success(
        fn=lambda: llm_chat_history.clear(), # Clear the global LLM history list
        inputs=[],
        outputs=[]
    )
# Launch the Gradio app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    # Hugging Face Spaces already serves the app publicly, so share=True is not
    # needed and the default demo.launch() settings are used.
    demo.launch()