| import os |
| import gradio as gr |
| from groq import Groq |
| import torch |
| from langchain.text_splitter import RecursiveCharacterTextSplitter |
| from langchain_community.document_loaders import PyPDFLoader |
| from langchain_community.embeddings import HuggingFaceEmbeddings |
| from langchain_community.vectorstores import FAISS |
| |
| import gc |
|
|
| |
| |
| |
|
|
| |
# Groq-hosted chat model used to generate answers.
GROQ_MODEL = "llama-3.3-70b-versatile"

# Sentence-transformer used to embed PDF chunks; runs on GPU when available.
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Lazily-initialized singletons; set by the initialize_* helpers at import
# time and left as None if initialization fails.
groq_client = None
embedding_model = None

# FAISS index built from the most recently processed PDF (None until a PDF
# has been processed).
faiss_vector_store = None
# Conversation history in Groq/OpenAI message-dict form
# ({"role": ..., "content": ...}); kept separate from Gradio's pair-based
# chat history.
llm_chat_history = []
|
|
| |
def initialize_groq_client():
    """Initialize the module-level Groq client from an environment variable.

    Reads the API key from ``GROQ_API_KEY`` (falling back to the legacy
    misspelling ``GROK_API_KEY`` for backward compatibility). On any failure
    the global ``groq_client`` is left as None and an actionable diagnostic
    is printed instead of raising, so the app can still start.
    """
    global groq_client
    try:
        # Bug fix: the key was previously read only from "GROK_API_KEY",
        # while every user-facing message told the user to set
        # "GROQ_API_KEY" — so a correctly-configured key was never found.
        groq_api_key = os.environ.get("GROQ_API_KEY") or os.environ.get("GROK_API_KEY")
        if not groq_api_key:
            raise ValueError("GROQ_API_KEY environment variable is not set. Please add it to your Hugging Face Space's Repository Secrets.")
        groq_client = Groq(api_key=groq_api_key)
        print("Groq client initialized successfully.")
    except ValueError as ve:
        print(f"ERROR: Groq client initialization failed: {ve}")
        print("ACTION REQUIRED: Ensure 'GROQ_API_KEY' is set correctly in your Hugging Face Space's Repository Secrets.")
        groq_client = None
    except Exception as e:
        print(f"ERROR: An unexpected error occurred during Groq client initialization: {e}")
        groq_client = None
|
|
def initialize_embedding_model():
    """Load the sentence-transformer embedding model into the module global.

    On failure, leaves ``embedding_model`` as None and prints a diagnostic
    with suggested remediation instead of raising.
    """
    global embedding_model
    try:
        print(f"Loading embedding model: {EMBEDDING_MODEL_NAME} on {EMBEDDING_DEVICE}...")
        # Normalized embeddings make FAISS inner-product/cosine search behave
        # consistently across documents.
        model_config = {
            "model_name": EMBEDDING_MODEL_NAME,
            "model_kwargs": {"device": EMBEDDING_DEVICE},
            "encode_kwargs": {"normalize_embeddings": True},
        }
        embedding_model = HuggingFaceEmbeddings(**model_config)
        print("Embedding model initialized successfully.")
    except Exception as exc:
        print(f"ERROR: Embedding model initialization failed: {exc}")
        print("ACTION REQUIRED: Check network connection and ensure 'transformers' library dependencies are met. Consider Space GPU availability if using 'cuda'.")
        embedding_model = None
|
|
| |
# Initialize shared clients once at import time; on failure the handlers
# below detect the None globals and report the problem to the user.
initialize_groq_client()
initialize_embedding_model()
|
|
| |
def process_pdf_and_create_index(pdf_file: gr.File):
    """
    Build a FAISS index from an uploaded PDF.

    Loads the PDF, splits its pages into overlapping chunks, embeds the
    chunks, and stores them in the module-level FAISS index. Also resets the
    LLM chat history so the conversation starts fresh for the new document.

    Args:
        pdf_file (gr.File): The Gradio File object for the uploaded PDF,
            or None if nothing was uploaded.

    Returns:
        tuple: (status_message, FAISS vector store or None)
    """
    global faiss_vector_store
    global llm_chat_history

    # Guard clauses: embedding model must be loaded and a file supplied.
    if embedding_model is None:
        return "Error: Embedding model not loaded. Cannot process PDF.", None
    if pdf_file is None:
        return "Please upload a PDF file to process.", None

    pdf_path = pdf_file.name
    print(f"Processing PDF: {pdf_path}")

    try:
        # One Document per page.
        pages = PyPDFLoader(pdf_path).load()
        print(f"Loaded {len(pages)} pages from PDF.")

        # Overlapping chunks keep sentences that straddle boundaries
        # retrievable from either side.
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
            add_start_index=True,
        )
        doc_chunks = splitter.split_documents(pages)
        print(f"Split document into {len(doc_chunks)} chunks.")

        print("Creating embeddings and building FAISS index... This may take a while.")
        faiss_vector_store = FAISS.from_documents(doc_chunks, embedding_model)
        print("FAISS index created successfully!")

        # New document => start the conversation from scratch and release
        # any memory held by the previous index/documents.
        llm_chat_history.clear()
        gc.collect()

        return "PDF processed successfully! You can now start chatting.", faiss_vector_store

    except Exception as e:
        print(f"ERROR during PDF processing: {e}")
        faiss_vector_store = None
        return f"Error processing PDF: {e}. Please ensure it's a valid PDF.", None
|
|
| |
def _record_exchange(user_query: str, bot_response: str) -> None:
    """Append one user/assistant turn to the module-level LLM message history."""
    llm_chat_history.append({"role": "user", "content": user_query})
    llm_chat_history.append({"role": "assistant", "content": bot_response})


def chat_with_rag(user_query: str, chat_history: list, current_vector_store: FAISS):
    """
    Generate a response using Groq, augmented by context retrieved from the FAISS vector store.

    Args:
        user_query (str): The current message from the user.
        chat_history (list): Gradio's chat history (list of [user_text, bot_text] pairs).
        current_vector_store (FAISS): FAISS index over the processed PDF, or None
            if no PDF has been processed yet.

    Returns:
        tuple: (updated_gradio_chat_history, ""). The empty string clears the
        textbox component wired to this handler's second output.
    """
    # Guard: no document has been indexed yet.
    if current_vector_store is None:
        bot_response = "Please upload and process a PDF document first before asking questions."
        _record_exchange(user_query, bot_response)
        return chat_history + [[user_query, bot_response]], ""

    # Guard: Groq client failed to initialize (missing/invalid API key).
    if groq_client is None:
        bot_response = "Groq client not initialized. Cannot generate response. Check API key setup."
        _record_exchange(user_query, bot_response)
        return chat_history + [[user_query, bot_response]], ""

    print(f"User Query: {user_query}")

    try:
        # Retrieve the top-k most similar chunks to ground the answer.
        retrieved_docs = current_vector_store.similarity_search(user_query, k=4)
        context_text = "\n\n".join(doc.page_content for doc in retrieved_docs)
        print(f"Retrieved Context:\n{context_text[:500]}...")

        # Prepend the retrieved context to the question so the model answers
        # from the document rather than from its general knowledge.
        augmented_query = (
            f"Based on the following context, answer the question. "
            f"If the answer is not in the context, state that you don't have enough information.\n\n"
            f"Context:\n{context_text}\n\n"
            f"Question: {user_query}"
        )

        # Replay prior Gradio turns so the model keeps conversational
        # context, then append the augmented question as the latest turn.
        groq_messages = []
        for human_msg, ai_msg in chat_history:
            if human_msg:
                groq_messages.append({"role": "user", "content": human_msg})
            if ai_msg:
                groq_messages.append({"role": "assistant", "content": ai_msg})
        groq_messages.append({"role": "user", "content": augmented_query})

        chat_completion = groq_client.chat.completions.create(
            messages=groq_messages,
            model=GROQ_MODEL,
            temperature=0.7,
            max_tokens=1024,
            top_p=1,
            stop=None,
            stream=False,
        )
        bot_response = chat_completion.choices[0].message.content
        print(f"Groq Response: {bot_response}")

        _record_exchange(user_query, bot_response)
        return chat_history + [[user_query, bot_response]], ""

    except Exception as e:
        error_message = f"An error occurred during RAG process: {e}. Please try again."
        print(f"RAG Chat Error: {e}")
        _record_exchange(user_query, error_message)
        return chat_history + [[user_query, error_message]], ""
|
|
| |
with gr.Blocks(theme=gr.themes.Soft(), title="RAG PDF Chatbot") as demo:
    gr.Markdown(
        """
        # 📚 Sheikh's -- You can Chat with your PDF (RAG Application) 💬
        Upload a PDF document, wait for it to process, and then ask questions about its content!
        Powered by open-source models.
        """
    )

    # Per-session holder for the FAISS index so each browser session chats
    # against the document it uploaded (seeded with the module-level value,
    # which is None at startup).
    vector_store_state = gr.State(faiss_vector_store)

    with gr.Row():
        # Left column: PDF upload and processing controls.
        with gr.Column(scale=1):
            pdf_upload_input = gr.File(
                label="Upload your PDF Document and free IK",
                file_types=[".pdf"],
                file_count="single"
            )
            process_pdf_btn = gr.Button("Process PDF 🚀")
            pdf_process_status = gr.Textbox(
                label="PDF Processing Status",
                interactive=False,
                lines=1
            )

        # Right column: chat display and message input.
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                label="Conversation History",
                value=[],
                height=400,
                show_copy_button=True
            )
            text_input = gr.Textbox(
                label="Type your question here",
                placeholder="Ask me about the PDF content...",
                lines=3
            )
            with gr.Row():
                submit_btn = gr.Button("Send Message ➡️")
                clear_chat_btn = gr.Button("Clear Chat 🗑️")

    # "Process PDF" button: updates the status textbox and stores the new
    # FAISS index in the session state.
    process_pdf_btn.click(
        fn=process_pdf_and_create_index,
        inputs=[pdf_upload_input],
        outputs=[pdf_process_status, vector_store_state],
        api_name="process_pdf"
    )

    # Send a message via the button...
    submit_btn.click(
        fn=chat_with_rag,
        inputs=[text_input, chatbot, vector_store_state],
        outputs=[chatbot, text_input],
        api_name="send_message_button"
    )

    # ...or by pressing Enter in the textbox (same handler and wiring).
    text_input.submit(
        fn=chat_with_rag,
        inputs=[text_input, chatbot, vector_store_state],
        outputs=[chatbot, text_input],
        api_name="send_message_enter"
    )

    # Clear the visible chat and the textbox; on success, also wipe the
    # module-level LLM message history so the model forgets prior turns.
    clear_chat_btn.click(
        fn=lambda: ([], ""),
        inputs=[],
        outputs=[chatbot, text_input],
        queue=False
    ).success(
        fn=lambda: llm_chat_history.clear(),
        inputs=[],
        outputs=[]
    )
|
|
| |
if __name__ == "__main__":
    # Launch the Gradio server when the module is run directly.
    demo.launch()
|
|