import os
import gradio as gr
import tempfile
from typing import List, Optional
import shutil
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain.chains import RetrievalQA
from langchain.llms.base import LLM
from groq import Groq

# ---- Custom GroqLLM class using LangChain LLM base ----
class GroqLLM(LLM):
    model: str = "llama3-8b-8192"
    api_key: Optional[str] = os.environ.get("GROQ_API_KEY")  # Load from HF Space secrets; may be None if unset
    temperature: float = 0.7

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        client = Groq(api_key=self.api_key)
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}
        ]
        response = client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=self.temperature,
            stop=stop,  # honor LangChain stop sequences (Groq's OpenAI-compatible API accepts them)
        )
        return response.choices[0].message.content

    @property
    def _llm_type(self) -> str:
        return "groq-llm"

# Global cache for vectorstore
rag_context = {"retriever": None}
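# (A dict rather than a bare global, so the Gradio callbacks below can
# mutate it without needing `global` statements.)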

# ---- Step 1: Upload & Embed PDF ----
def process_pdf(file):
    if file is None:
        return "❌ Please upload a PDF."

    # Save uploaded file to a real file path; PyPDFLoader needs one on disk.
    with tempfile.TemporaryDirectory() as temp_dir:
        # Gradio provides the file path directly via file.name
        temp_pdf_path = os.path.join(temp_dir, "uploaded.pdf")
        shutil.copy(file.name, temp_pdf_path)

        # Load and split PDF
        try:
            loader = PyPDFLoader(temp_pdf_path)
            documents = loader.load()

            text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
            chunks = text_splitter.split_documents(documents)

            # Embed the chunks into an in-memory Chroma index. Do NOT persist
            # into temp_dir: that directory is deleted when this block exits,
            # which would leave the retriever pointing at a vanished database.
            embedding = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
            vectorstore = Chroma.from_documents(chunks, embedding)

            rag_context["retriever"] = vectorstore.as_retriever()
            return "✅ PDF processed and ready. Ask your questions!"

        except Exception as e:
            return f"❌ Failed to load PDF: {e}"

# ---- Step 2: Ask questions to the RAG chain ----
def ask_question(query):
    retriever = rag_context.get("retriever")
    if retriever is None:
        return "❌ Please upload and process a PDF first."

    llm = GroqLLM()
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retriever,
        return_source_documents=True
    )
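    # With return_source_documents=True, the result dict also carries
    # "source_documents", useful if citations are added later.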

    result = qa_chain.invoke({"query": query})
    answer = result["result"]
    return f"### Answer:\n{answer}"

# ---- Gradio UI ----
with gr.Blocks() as demo:
    gr.Markdown("# πŸ“š RAG Chatbot with Groq & LangChain\nUpload a PDF, then ask questions about it!")

    with gr.Row():
        pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
        upload_btn = gr.Button("Process PDF")
    upload_status = gr.Textbox(label="Status", interactive=False)

    upload_btn.click(process_pdf, inputs=pdf_input, outputs=upload_status)

    query_input = gr.Textbox(label="Ask a question")
    ask_btn = gr.Button("Get Answer")
    answer_output = gr.Markdown()

    ask_btn.click(ask_question, inputs=query_input, outputs=answer_output)

demo.launch()
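
# Assumed runtime dependencies (e.g. requirements.txt for the Space):
# gradio, langchain, langchain-community, chromadb, sentence-transformers,
# pypdf, groq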