import gradio as gr
import fitz  # PyMuPDF
import os
from sentence_transformers import SentenceTransformer
import numpy as np
import faiss
from groq import Groq

# --- Model clients -----------------------------------------------------------
# Groq chat client; fail fast at import time if the key is missing.
key = os.getenv("GROQ_API_KEY")
if key is None or key == "":
    raise ValueError("No API key found")
groq_client = Groq(api_key=key)
model = "llama3-8b-8192"  # Groq-hosted chat model used for answer generation

# Local sentence-embedding model, used for both document chunks and queries.
embedder = SentenceTransformer('all-MiniLM-L6-v2')

# Global state shared between the upload handler and the chat handler.
#   document_chunks: page-level text chunks, aligned index-for-index with metadata
#   metadata:        {"file", "page"} provenance for each chunk
#   index:           FAISS index over chunk embeddings (None until PDFs processed)
#   embeddings:      raw embedding matrix kept alongside the index
state = dict(
    document_chunks=[],
    metadata=[],
    index=None,
    embeddings=None,
)

# Extract text from PDF using file path
def extract_text_from_pdf(file_path):
    """Extract non-empty page texts from a PDF.

    Args:
        file_path: Path to a PDF file on disk.

    Returns:
        List of dicts ``{"text": <page text>, "page": <1-based page number>}``,
        skipping pages that yield no extractable text.
    """
    texts = []
    # Context manager guarantees the document handle is closed even if
    # extraction raises mid-way (the original leaked the open document).
    with fitz.open(file_path) as doc:
        for page_number, page in enumerate(doc, start=1):
            text = page.get_text().strip()
            if text:
                texts.append({"text": text, "page": page_number})
    return texts

# Process PDFs
def process_pdfs(files):
    """Chunk uploaded PDFs (one chunk per page), embed them, and build a FAISS index.

    Args:
        files: Gradio file objects, each exposing a ``.name`` file path.

    Returns:
        A status string for the UI.
    """
    state["document_chunks"] = []
    state["metadata"] = []

    if not files:
        return "No files provided. Please upload at least one PDF."

    for file in files:
        file_name = os.path.basename(file.name)
        for chunk in extract_text_from_pdf(file.name):
            state["document_chunks"].append(chunk['text'])
            state["metadata"].append({"file": file_name, "page": chunk['page']})

    # Guard: encoding an empty list yields an embeddings array with no second
    # dimension, so `embeddings.shape[1]` below would raise IndexError.
    if not state["document_chunks"]:
        state["index"] = None
        state["embeddings"] = None
        return "No extractable text found in the uploaded PDF(s)."

    embeddings = embedder.encode(state["document_chunks"], show_progress_bar=True)
    dim = embeddings.shape[1]
    index = faiss.IndexFlatL2(dim)
    # FAISS requires float32 input; coerce explicitly instead of relying on
    # the encoder's default dtype.
    index.add(np.array(embeddings, dtype=np.float32))
    state["index"] = index
    state["embeddings"] = embeddings

    return "βœ… Book(s) loaded successfully!"

# Retrieve top chunks
def retrieve_chunks(question, top_k=3):
    """Return up to ``top_k`` (chunk_text, metadata) pairs most similar to the question.

    Args:
        question: Natural-language query string.
        top_k: Maximum number of chunks to retrieve.

    Returns:
        List of ``(chunk_text, metadata)`` tuples; empty when no index exists.
    """
    index = state["index"]
    if index is None:  # explicit None check instead of truthiness on a FAISS object
        return []
    q_embedding = embedder.encode([question])
    # Never ask FAISS for more neighbours than there are stored vectors:
    # it pads missing results with -1, and indexing with -1 would silently
    # return the *last* chunk instead of failing.
    k = min(top_k, index.ntotal)
    if k <= 0:
        return []
    _, ids = index.search(np.array(q_embedding, dtype=np.float32), k)
    return [
        (state["document_chunks"][i], state["metadata"][i])
        for i in ids[0]
        if i != -1  # defensive: skip any padded slots
    ]

# Generate answer with source references
def generate_answer(context, question):
    """Ask the LLM to answer `question`, grounded in the retrieved `context`.

    Args:
        context: Iterable of ``(chunk_text, metadata)`` pairs where metadata
            carries ``"file"`` and ``"page"`` keys.
        question: The user's question.

    Returns:
        The model's answer text (source citations are requested in-prompt).
    """
    # Append a provenance line to every chunk before joining them.
    sourced_chunks = []
    for chunk, meta in context:
        sourced_chunks.append(
            f"{chunk}\n\n[Source: {meta['file']}, Page: {meta['page']}]"
        )
    context_text = "\n\n".join(sourced_chunks)

    prompt = f"""You are a helpful assistant. Use the context below to answer the question. 
Include the source references (file name and page number) in your answer.

Context:
{context_text}

Question:
{question}

Answer (with sources):"""

    messages = [{"role": "user", "content": prompt}]
    completion = groq_client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=0.2
    )
    return completion.choices[0].message.content

# Chat function for ChatInterface
def chatbot_interface_fn(message, history):
    """Gradio ChatInterface callback: answer `message` from the indexed PDFs.

    `history` is required by the ChatInterface signature but unused here,
    since every answer is grounded in freshly retrieved chunks.
    """
    if not state["document_chunks"]:
        return "⚠️ Please upload PDF files first."
    relevant = retrieve_chunks(message)
    return generate_answer(relevant, message)

# Gradio UI
# NOTE: component-creation order defines the on-screen layout, so the
# statements below are intentionally kept in this exact order.
with gr.Blocks(title="RAG Chatbot") as demo:
    gr.Markdown("# πŸ“š Enhanced RAG Chatbot\nUpload books and chat naturally!")

    with gr.Row():
        # Multi-file PDF picker, processing trigger, and read-only status box.
        pdf_input = gr.File(file_types=[".pdf"], file_count="multiple", label="πŸ“‚ Upload PDFs")
        upload_btn = gr.Button("Upload & Process PDFs")
        status = gr.Textbox(label="Status", interactive=False)

    # Wire the button to the indexing pipeline; its status string lands in `status`.
    upload_btn.click(process_pdfs, inputs=[pdf_input], outputs=[status])

    # Chat panel; each user message is answered by chatbot_interface_fn via RAG.
    gr.ChatInterface(
        fn=chatbot_interface_fn,
        chatbot=gr.Chatbot(height=400, type="messages"),
        textbox=gr.Textbox(placeholder="Ask about the PDFs...", scale=7),
        title="πŸ“– PDF Chat",
        description="Ask questions based on uploaded PDF content.",
        submit_btn="Send"
    )

# Launch the web app only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()