# Source metadata (scraped Hugging Face Space page header, kept as a comment so
# the file parses): Yatheshr — "Update app.py" — commit d0da0f6 (verified)
import gradio as gr
import tempfile
import os
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.callbacks.base import BaseCallbackHandler
# Global module-level state shared between the Gradio event handlers below.
kb = None         # NOTE(review): never written anywhere in this file — appears unused
retriever = None  # FAISS retriever; set by create_kb()
qa = None         # RetrievalQA chain; set by create_kb() / ask_question()
class StreamHandler(BaseCallbackHandler):
    """LangChain callback that accumulates streamed LLM tokens.

    On every new token, the full text received so far is pushed to
    ``update_fn`` so the caller can render incremental output.
    """

    def __init__(self, update_fn):
        self.text = ""              # complete response accumulated so far
        self.update_fn = update_fn  # callable invoked with the running text

    def on_llm_new_token(self, token: str, **kwargs):
        """Append one streamed token and forward the updated transcript."""
        self.text += token
        self.update_fn(self.text)
def save_pdfs(pdf_list):
    """Persist uploaded PDFs to disk and return their filesystem paths.

    Depending on the Gradio version/config, ``gr.File`` delivers uploads
    either as filesystem paths (str) or as open file-like objects.  The
    original code assumed file objects only (``pdf.read()``) and crashed
    on path strings — handle both here.

    Args:
        pdf_list: Sequence of uploaded files (paths or file-like), or None.

    Returns:
        List of paths to readable PDF files on disk.
    """
    paths = []
    for pdf in pdf_list or []:
        if isinstance(pdf, (str, os.PathLike)):
            # Modern Gradio already wrote the upload to a temp path — reuse it.
            paths.append(str(pdf))
            continue
        data = pdf.read()
        if isinstance(data, str):
            # Some wrappers open in text mode; PDFs must be written as bytes.
            data = data.encode("utf-8")
        # delete=False: the loader needs the file after this scope closes.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
            tmp.write(data)
            paths.append(tmp.name)
    return paths
def create_kb(api_key, pdf_list):
    """Build a FAISS knowledge base from the uploaded PDFs.

    Args:
        api_key: Gemini API key used for the embedding model.
        pdf_list: Uploaded PDF files from the Gradio File component.

    Returns:
        A status string for the UI (success or error message).
    """
    global retriever, qa
    try:
        if not pdf_list:
            return "❌ Error creating KB: no PDFs uploaded."
        pdf_paths = save_pdfs(pdf_list)
        docs = []
        for path in pdf_paths:
            loader = PyPDFLoader(path)
            docs.extend(loader.load())
        # Overlapping chunks so answers spanning a chunk boundary stay retrievable.
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        chunks = splitter.split_documents(docs)
        embeddings = GoogleGenerativeAIEmbeddings(
            model="models/text-embedding-004", google_api_key=api_key
        )
        db = FAISS.from_documents(chunks, embeddings)
        retriever = db.as_retriever(search_kwargs={"k": 3})
        # BUG FIX: the original built RetrievalQA with llm=None here, which fails
        # chain validation — and even if it didn't, patching `qa.llm` afterwards
        # has no effect because RetrievalQA holds the LLM inside its
        # combine_documents_chain.  The chain is now built in ask_question()
        # once a real LLM (with the streaming callback) exists.
        qa = None
        return "✅ Knowledge base created."
    except Exception as e:
        return f"❌ Error creating KB: {e}"
def ask_question(api_key, question, chat_history, set_stream):
    """Answer `question` against the knowledge base.

    Args:
        api_key: Gemini API key for the chat model.
        question: The user's question text.
        chat_history: Current messages-format history (list of dicts) or None.
        set_stream: Value of the stream-capture State.  NOTE: Gradio passes the
            State's *value* (a string), not a setter — so this is usually not
            callable; streaming falls back to a no-op in that case.

    Returns:
        Tuple of (updated chat_history, text for the question box — empty on
        success, an error message otherwise).
    """
    global retriever, qa
    chat_history = chat_history or []  # BUG FIX: original could return None history
    if retriever is None:
        return chat_history, "❌ Create KB first."
    # BUG FIX: the original called set_stream(txt) unconditionally, raising
    # TypeError because a gr.State delivers its value, not a callable.
    update_fn = set_stream if callable(set_stream) else (lambda _txt: None)
    handler = StreamHandler(update_fn)
    llm = ChatGoogleGenerativeAI(
        model="models/gemini-1.5-pro-latest",
        google_api_key=api_key,
        streaming=True,
        callbacks=[handler],
    )
    # BUG FIX: build the chain with the real LLM.  The original assigned
    # `qa.llm = llm` on a chain constructed with llm=None, which RetrievalQA
    # ignores — the LLM lives inside its combine_documents_chain.
    qa = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
    chat_history.append({"role": "user", "content": question})
    result = qa.invoke({"query": question})
    # Prefer the streamed text; fall back to the chain's final answer so the
    # assistant message is never blank when streaming produced nothing.
    answer = handler.text
    if not answer and isinstance(result, dict):
        answer = result.get("result", "")
    chat_history.append({"role": "assistant", "content": answer})
    return chat_history, ""
# UI layout: one column with KB-creation controls on top and the chat below.
with gr.Blocks() as demo:
    gr.Markdown("# 📚 Multi‑PDF RAG Chat with Gemini")
    with gr.Column():
        api_key = gr.Textbox(show_label=False, placeholder="Enter your Gemini API Key", type="password")
        pdfs = gr.File(file_types=[".pdf"], label="Upload PDFs", file_count="multiple")
        kb_status = gr.Textbox(label="Status")
        create_btn = gr.Button("▶️ Create Knowledge Base")
        create_btn.click(create_kb, inputs=[api_key, pdfs], outputs=kb_status)
        chatbot = gr.Chatbot(label="🧠 Assistant", type="messages")
        question = gr.Textbox(show_label=False, placeholder="Ask a question")
        stream_output = gr.State("")  # to capture stream text
        send = gr.Button("🔍 Ask")
        # BUG FIX: the original used outputs=[chatbot, ""] — "" is not a
        # component and makes Gradio raise when wiring the event.  The second
        # return value now targets the question box: cleared on success, or
        # shows the "create KB first" message on error.
        send.click(fn=ask_question,
                   inputs=[api_key, question, chatbot, stream_output],
                   outputs=[chatbot, question])
demo.launch()