import gradio as gr
import os
import tempfile
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_groq import ChatGroq
from langchain.chains import ConversationalRetrievalChain
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
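
# These imports assume the following packages are installed (a best guess at
# the full set; adjust to your environment):
#   pip install gradio langchain langchain-community langchain-groq \
#       langchain-text-splitters chromadb sentence-transformers pypdf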

# Persistent vectorstore directory
PERSIST_DIRECTORY = "./chroma_db"

# Store chat histories
chat_histories = {}


def get_session_history(session_id: str) -> BaseChatMessageHistory:
    if session_id not in chat_histories:
        chat_histories[session_id] = ChatMessageHistory()
    return chat_histories[session_id]
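
# Note: histories live only in process memory, so restarting the Space clears
# them. If sessions must survive restarts, a persistent backend such as
# langchain_community's RedisChatMessageHistory could be swapped in here.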

def process_files(api_key, model_name, session_id, files, question):
    if not api_key:
        return "Please enter your Groq API key"
    if not files:
        return "Please upload at least one PDF"
    if not question:
        return "Please enter a question"

    # Initialize LLM
    llm = ChatGroq(groq_api_key=api_key, model_name=model_name)

    # Process PDFs: write each uploaded byte blob to a temp file so PyPDFLoader
    # can read it from disk
    documents = []
    for file in files:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
            tmp.write(file)
            tmp_path = tmp.name
        loader = PyPDFLoader(tmp_path)
        docs = loader.load()
        documents.extend(docs)
        os.unlink(tmp_path)  # Clean up temp file

    # Split and embed. Note: with a persistent directory, each call adds the
    # freshly split chunks to the same Chroma collection, so re-submitting the
    # same PDFs will store duplicate chunks.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=500)
    splits = text_splitter.split_documents(documents)
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    vectorstore = Chroma.from_documents(splits, embedding=embeddings, persist_directory=PERSIST_DIRECTORY)
    retriever = vectorstore.as_retriever()
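    # as_retriever() defaults to similarity search (k=4 for Chroma); pass e.g.
    # search_kwargs={"k": 8} to control how many chunks reach the LLM.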

    # Setup RAG chain
    rag_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        return_source_documents=True,
    )
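    # The chain first condenses the question plus chat history into a
    # standalone query, retrieves matching chunks, then answers from them.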

    # Get chat history for this session
    chat_history = get_session_history(session_id)

    # Run the chain; it accepts the session's BaseMessage list directly as
    # chat_history
    response = rag_chain.invoke({
        "question": question,
        "chat_history": chat_history.messages,
    })
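    # With return_source_documents=True the result also carries
    # response["source_documents"]; the UI below does not surface them, but
    # they are available for citation display.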

    # Update history
    chat_history.add_user_message(question)
    chat_history.add_ai_message(response["answer"])

    # Format response
    output = f"**Assistant:** {response['answer']}\n\n---\n**Chat History:**\n"
    for msg in chat_history.messages[-6:]:  # Show last 3 exchanges
        output += f"{msg.type.capitalize()}: {msg.content}\n"
    return output
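
# Hypothetical smoke test without the UI (assumes GROQ_API_KEY is set and a
# local sample.pdf exists; neither ships with this Space):
#
#   with open("sample.pdf", "rb") as f:
#       print(process_files(os.environ["GROQ_API_KEY"], "llama-3.3-70b-versatile",
#                           "test", [f.read()], "Summarize this document."))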

# Gradio Interface
with gr.Blocks(title="RAG PDF Chat") as demo:
    gr.Markdown("## 📄 Conversational RAG with PDF Uploads")

    with gr.Row():
        with gr.Column(scale=1):
            api_key = gr.Textbox(
                label="Groq API Key",
                type="password",
                placeholder="Enter your API key"
            )
            model = gr.Dropdown(
                label="LLM Model",
                choices=[
                    "qwen-2.5-32b",
                    "deepseek-r1-distill-llama-70b",
                    "gemma2-9b-it",
                    "mixtral-8x7b-32768",
                    "llama-3.3-70b-versatile",
                ],
                value="mixtral-8x7b-32768"
            )
            session_id = gr.Textbox(
                label="Session ID",
                value="default_session"
            )
        with gr.Column(scale=2):
            file_input = gr.File(
                label="Upload PDFs",
                file_types=[".pdf"],
                file_count="multiple",
                type="binary"  # hand the handler raw bytes, matching tmp.write(file)
            )
            question = gr.Textbox(
                label="Your Question",
                placeholder="Ask about the uploaded documents..."
            )
            submit_btn = gr.Button("Submit")
            output = gr.Markdown()
    submit_btn.click(
        fn=process_files,
        inputs=[api_key, model, session_id, file_input, question],
        outputs=output
    )

if __name__ == "__main__":
    demo.launch(share=True)
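
# On Hugging Face Spaces the app is already public, so demo.launch() alone
# would do; share=True mainly matters when running locally and you want a
# temporary public tunnel.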