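"""Multi-PDF RAG chat demo: Gradio UI + LangChain + Gemini.

Upload one or more PDFs, build a FAISS knowledge base over their chunks
using Gemini embeddings, then ask questions answered by a Gemini chat
model over the retrieved context. (Docstring added for orientation; the
behavior described is that of the code below.)
"""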
import gradio as gr

from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.callbacks.base import BaseCallbackHandler

# Global state (fine for a single-user demo; not safe for concurrent users)
retriever = None

class StreamHandler(BaseCallbackHandler):
    """Collects streamed tokens as they arrive.

    Pushing tokens to the UI mid-call would require a Gradio generator
    function; here the handler simply accumulates the text.
    """
    def __init__(self):
        self.text = ""

    def on_llm_new_token(self, token: str, **kwargs):
        self.text += token

def save_pdfs(pdf_list):
    # gr.File already stores uploads on disk: recent Gradio versions pass
    # plain path strings, older ones pass objects with a .name path, so no
    # copying into new temp files is needed.
    return [pdf if isinstance(pdf, str) else pdf.name for pdf in pdf_list]

def create_kb(api_key, pdf_list):
    global retriever
    if not api_key or not pdf_list:
        return "❌ Provide an API key and at least one PDF."
    try:
        docs = []
        for path in save_pdfs(pdf_list):
            docs.extend(PyPDFLoader(path).load())
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        chunks = splitter.split_documents(docs)
        embeddings = GoogleGenerativeAIEmbeddings(model="models/text-embedding-004",
                                                  google_api_key=api_key)
        db = FAISS.from_documents(chunks, embeddings)
        retriever = db.as_retriever(search_kwargs={"k": 3})
        # The QA chain is built per question in ask_question, once the LLM
        # exists; RetrievalQA cannot be constructed with llm=None.
        return "✅ Knowledge base created."
    except Exception as e:
        return f"❌ Error creating KB: {e}"

def ask_question(api_key, question, chat_history):
    if retriever is None:
        return chat_history, "❌ Create the knowledge base first."
    handler = StreamHandler()
    llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro",
                                 google_api_key=api_key,
                                 streaming=True,
                                 callbacks=[handler])
    # RetrievalQA embeds the LLM in its internal combine-documents chain,
    # so the chain must be created after the LLM exists (assigning qa.llm
    # on an already-built chain has no effect).
    qa = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
    chat_history = chat_history or []
    chat_history.append({"role": "user", "content": question})
    result = qa.invoke({"query": question})
    chat_history.append({"role": "assistant", "content": result["result"]})
    return chat_history, ""  # empty string clears the question box

with gr.Blocks() as demo:
    gr.Markdown("# 📚 Multi‑PDF RAG Chat with Gemini")

    with gr.Column():
        api_key = gr.Textbox(show_label=False, placeholder="Enter your Gemini API Key", type="password")
        pdfs = gr.File(file_types=[".pdf"], label="Upload PDFs", file_count="multiple")
        kb_status = gr.Textbox(label="Status")
        create_btn = gr.Button("▶️ Create Knowledge Base")

    create_btn.click(create_kb, inputs=[api_key, pdfs], outputs=kb_status)

    chatbot = gr.Chatbot(label="🧠 Assistant", type="messages")
    question = gr.Textbox(show_label=False, placeholder="Ask a question")
    send = gr.Button("🔍 Ask")

    # The second output must be a component (the question box, cleared on
    # success), not a bare string.
    send.click(fn=ask_question,
               inputs=[api_key, question, chatbot],
               outputs=[chatbot, question])
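    # Assumed convenience wiring (not in the original): pressing Enter in
    # the question box behaves like clicking Ask.
    question.submit(fn=ask_question,
                    inputs=[api_key, question, chatbot],
                    outputs=[chatbot, question])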

demo.launch()