import gradio as gr

from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_groq import ChatGroq
from langchain.chains import ConversationalRetrievalChain
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory

# Persistent vectorstore directory
PERSIST_DIRECTORY = "./chroma_db"

# Store chat histories
chat_histories = {}

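# Return (creating on first use) the in-memory history for a session;
# histories are lost when the process restarts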
def get_session_history(session_id: str) -> BaseChatMessageHistory:
    if session_id not in chat_histories:
        chat_histories[session_id] = ChatMessageHistory()
    return chat_histories[session_id]

def process_files(api_key, model_name, session_id, files, question):
    if not api_key:
        return "Please enter your Groq API key"
    if not files:
        return "Please upload at least one PDF file"
    
    # Initialize LLM
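    # The key entered in the UI is passed directly, so no GROQ_API_KEY
    # environment variable is required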
    llm = ChatGroq(groq_api_key=api_key, model_name=model_name)
    
    # Process PDFs: gr.File yields local file paths (str) in current
    # Gradio releases, or tempfile wrappers with a .name attribute in
    # older ones, so resolve the path and load each PDF directly
    documents = []
    for file in files:
        path = file if isinstance(file, str) else file.name
        loader = PyPDFLoader(path)
        documents.extend(loader.load())

    # Split and embed
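    # chunk_size/chunk_overlap are measured in characters, not tokens;
    # very large chunks may be truncated by the embedding model below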
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=500)
    splits = text_splitter.split_documents(documents)
    
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
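    # Note: each call appends these chunks to the persisted collection,
    # so re-uploading the same PDF will store duplicate entries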
    vectorstore = Chroma.from_documents(splits, embedding=embeddings, persist_directory=PERSIST_DIRECTORY)
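    # With no arguments, the retriever returns the top 4 most similar chunks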
    retriever = vectorstore.as_retriever()

    # Setup RAG chain
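    # The chain condenses the question and prior turns into a standalone
    # query, retrieves matching chunks, then answers from them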
    rag_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        return_source_documents=True,
    )
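    # Retrieved chunks come back under response["source_documents"]
    # (not displayed by this UI)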

    # Get chat history
    chat_history = get_session_history(session_id)
    
    # Process question; ConversationalRetrievalChain accepts the stored
    # BaseMessage list directly, so no manual string formatting is needed
    response = rag_chain.invoke({
        "question": question,
        "chat_history": chat_history.messages,
    })

    # Update history
    chat_history.add_user_message(question)
    chat_history.add_ai_message(response["answer"])
    
    # Format response
    output = f"**Assistant:** {response['answer']}\n\n---\n**Chat History:**\n"
    for msg in chat_history.messages[-6:]:  # Show last 3 exchanges
        output += f"{msg.type.capitalize()}: {msg.content}\n"
    
    return output

# Gradio Interface
with gr.Blocks(title="RAG PDF Chat") as demo:
    gr.Markdown("## 📚 Conversational RAG with PDF Uploads")
    
    with gr.Row():
        with gr.Column(scale=1):
            api_key = gr.Textbox(
                label="Groq API Key", 
                type="password",
                placeholder="Enter your API key"
            )
            model = gr.Dropdown(
                label="LLM Model",
                choices=[
                    "qwen-2.5-32b",
                    "deepseek-r1-distill-llama-70b",
                    "gemma2-9b-it",
                    "mixtral-8x7b-32768",
                    "llama-3.3-70b-versatile"
                ],
                value="mixtral-8x7b-32768"
            )
            session_id = gr.Textbox(
                label="Session ID",
                value="default_session"
            )
            
        with gr.Column(scale=2):
            file_input = gr.File(
                label="Upload PDFs",
                file_types=[".pdf"],
                file_count="multiple"
            )
            question = gr.Textbox(
                label="Your Question",
                placeholder="Ask about the uploaded documents..."
            )
            submit_btn = gr.Button("Submit")
            output = gr.Markdown()
    
    submit_btn.click(
        fn=process_files,
        inputs=[api_key, model, session_id, file_input, question],
        outputs=output
    )

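# share=True exposes a temporary public Gradio link in addition to localhost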
if __name__ == "__main__":
    demo.launch(share=True)