Spaces:
Runtime error
Runtime error
File size: 4,366 Bytes
b6e8184 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import gradio as gr
import os
import tempfile
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_groq import ChatGroq
from langchain.chains import ConversationalRetrievalChain
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
# Directory where the Chroma vector store is persisted between runs.
PERSIST_DIRECTORY = "./chroma_db"

# In-memory map of session_id -> ChatMessageHistory; lives for the process
# lifetime only (not persisted to disk).
chat_histories = {}
def get_session_history(session_id: str) -> BaseChatMessageHistory:
    """Return the chat history for *session_id*, creating an empty one on first use.

    Histories are kept in the module-level ``chat_histories`` dict, so a
    session's history survives across calls for the lifetime of the process.
    """
    history = chat_histories.get(session_id)
    if history is None:
        history = ChatMessageHistory()
        chat_histories[session_id] = history
    return history
def process_files(api_key, model_name, session_id, files, question):
    """Answer *question* over the uploaded PDFs with a conversational RAG chain.

    Parameters
    ----------
    api_key : str
        Groq API key; required.
    model_name : str
        Groq model identifier passed to ``ChatGroq``.
    session_id : str
        Key selecting the per-session chat history.
    files : list
        Uploaded PDFs from ``gr.File`` — raw ``bytes`` or on-disk paths,
        depending on the component's ``type``; both are handled.
    question : str
        The user's question about the documents.

    Returns
    -------
    str
        Markdown with the assistant's answer followed by the last three
        exchanges of the session history.
    """
    # Guard clauses: fail fast with a user-visible message instead of a traceback.
    if not api_key:
        return "Please enter your Groq API key"
    if not files:
        return "Please upload at least one PDF"
    if not question:
        return "Please enter a question"

    # Initialize the LLM for this request.
    llm = ChatGroq(groq_api_key=api_key, model_name=model_name)

    # Load every uploaded PDF. gr.File may hand us raw bytes or an existing
    # file path depending on its configuration — support both, and always
    # remove any temp file we create, even if PDF parsing raises.
    documents = []
    for file in files:
        if isinstance(file, (bytes, bytearray)):
            tmp_path = None
            try:
                with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
                    tmp.write(file)
                    tmp_path = tmp.name
                documents.extend(PyPDFLoader(tmp_path).load())
            finally:
                if tmp_path:
                    os.unlink(tmp_path)  # clean up even on loader failure
        else:
            # Already on disk: a str path, or a file-like wrapper with .name.
            path = getattr(file, "name", file)
            documents.extend(PyPDFLoader(path).load())

    # Chunk the documents and index them in the persistent Chroma store.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=500)
    splits = text_splitter.split_documents(documents)
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    vectorstore = Chroma.from_documents(splits, embedding=embeddings, persist_directory=PERSIST_DIRECTORY)
    retriever = vectorstore.as_retriever()

    # Build the retrieval-augmented conversational chain.
    rag_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        return_source_documents=True,
    )

    chat_history = get_session_history(session_id)

    # ConversationalRetrievalChain expects chat_history as (human, ai) tuple
    # pairs, not a flat list of strings. Messages are appended in user/ai
    # pairs below, so pairing adjacent messages reconstructs the exchanges.
    msgs = chat_history.messages
    history_pairs = [
        (msgs[i].content, msgs[i + 1].content)
        for i in range(0, len(msgs) - 1, 2)
    ]
    response = rag_chain({"question": question, "chat_history": history_pairs})

    # Record this exchange in the session history.
    chat_history.add_user_message(question)
    chat_history.add_ai_message(response["answer"])

    # Format the answer plus the last three exchanges (6 messages) as markdown.
    output = f"**Assistant:** {response['answer']}\n\n---\n**Chat History:**\n"
    for msg in chat_history.messages[-6:]:
        output += f"{msg.type.capitalize()}: {msg.content}\n"
    return output
# Gradio Interface: API/session settings on the left, documents and the
# question box on the right, answer rendered as markdown below.
with gr.Blocks(title="RAG PDF Chat") as demo:
    gr.Markdown("## 📚 Conversational RAG with PDF Uploads")
    with gr.Row():
        # Left column: credentials, model choice, and session configuration.
        with gr.Column(scale=1):
            api_key = gr.Textbox(
                label="Groq API Key",
                type="password",
                placeholder="Enter your API key",
            )
            model = gr.Dropdown(
                label="LLM Model",
                # NOTE(review): removed "Gemma2-9b-It" — it duplicated
                # "gemma2-9b-it" differing only in letter case.
                choices=[
                    "qwen-2.5-32b",
                    "deepseek-r1-distill-llama-70b",
                    "gemma2-9b-it",
                    "mixtral-8x7b-32768",
                    "llama-3.3-70b-versatile",
                ],
                value="mixtral-8x7b-32768",
            )
            session_id = gr.Textbox(
                label="Session ID",
                value="default_session",
            )
        # Right column: PDF upload, question input, and the submit button.
        with gr.Column(scale=2):
            file_input = gr.File(
                label="Upload PDFs",
                file_types=[".pdf"],
                file_count="multiple",
            )
            question = gr.Textbox(
                label="Your Question",
                placeholder="Ask about the uploaded documents...",
            )
            submit_btn = gr.Button("Submit")
    # Answer area spans the full width beneath the two columns.
    output = gr.Markdown()
    submit_btn.click(
        fn=process_files,
        inputs=[api_key, model, session_id, file_input, question],
        outputs=output,
    )
# Run the app when executed as a script; share=True exposes a public
# Gradio tunnel URL in addition to the local server.
if __name__ == "__main__":
    demo.launch(share=True)