# Hugging Face Space: Multi-PDF RAG chat with Gemini.
# (The Space was previously failing with a runtime error at startup.)
import os
import tempfile

import gradio as gr
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import RetrievalQA
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import FAISS
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
# Module-level state shared across Gradio callbacks.
kb = None         # FAISS vector store built by create_kb()
retriever = None  # retriever over kb, used by ask_question()
qa = None         # RetrievalQA chain (constructed per question)
class StreamHandler(BaseCallbackHandler):
    """LangChain callback that accumulates streamed LLM tokens.

    On every new token the full text accumulated so far is pushed to the
    caller-supplied ``update_fn``.
    """

    def __init__(self, update_fn):
        # update_fn: callable taking the running answer text (str).
        self.update_fn = update_fn
        self.text = ""

    def on_llm_new_token(self, token: str, **kwargs):
        """Append the streamed token and forward the running text."""
        self.text = self.text + token
        self.update_fn(self.text)
def save_pdfs(pdf_list):
    """Persist uploaded PDFs to disk and return their filesystem paths.

    Depending on the Gradio version, ``gr.File`` with
    ``file_count="multiple"`` hands back plain path strings, tempfile
    wrappers exposing ``.name``, or file-like objects with ``.read()``.
    The original code assumed ``.read()`` only, which raises
    ``AttributeError`` on modern Gradio (files arrive as paths) -- the
    likely cause of the Space's runtime error. Handle all three shapes.

    :param pdf_list: iterable of uploads (or None).
    :return: list of paths to readable ``.pdf`` files on disk.
    """
    paths = []
    for pdf in pdf_list or []:
        if isinstance(pdf, str):
            # Modern Gradio: already a path to a temp copy on disk.
            paths.append(pdf)
        elif hasattr(pdf, "read"):
            # File-like upload: copy its bytes to our own temp file.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
                tmp.write(pdf.read())
                paths.append(tmp.name)
        elif hasattr(pdf, "name"):
            # Older Gradio tempfile wrapper: reuse its on-disk path.
            paths.append(pdf.name)
    return paths
def create_kb(api_key, pdf_list):
    """Build the FAISS knowledge base from the uploaded PDFs.

    Loads every PDF, splits pages into overlapping chunks, embeds them
    with Gemini embeddings and stores them in an in-memory FAISS index.
    Sets the module-level ``kb`` and ``retriever``.

    BUG FIX: the original called ``RetrievalQA.from_chain_type(llm=None,
    ...)`` here; the chain requires a real LLM and raises a validation
    error, crashing KB creation. The QA chain is now built per question
    in ``ask_question``, where a streaming LLM exists.

    :param api_key: Gemini API key.
    :param pdf_list: uploads from the gr.File component.
    :return: human-readable status string (never raises).
    """
    global kb, retriever, qa
    try:
        if not api_key:
            return "❌ Error creating KB: missing API key."
        pdf_paths = save_pdfs(pdf_list)
        if not pdf_paths:
            return "❌ Error creating KB: no PDFs uploaded."
        docs = []
        for path in pdf_paths:
            docs.extend(PyPDFLoader(path).load())
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        chunks = splitter.split_documents(docs)
        embeddings = GoogleGenerativeAIEmbeddings(
            model="models/text-embedding-004", google_api_key=api_key
        )
        kb = FAISS.from_documents(chunks, embeddings)
        retriever = kb.as_retriever(search_kwargs={"k": 3})
        qa = None  # (re)built with a live LLM in ask_question
        return "✅ Knowledge base created."
    except Exception as e:
        # Surface any failure as a status string for the UI.
        return f"❌ Error creating KB: {e}"
def ask_question(api_key, question, chat_history, set_stream):
    """Answer *question* against the knowledge base.

    Builds a streaming Gemini LLM, wires it into a fresh RetrievalQA
    chain over the global ``retriever``, runs the query and appends the
    exchange to ``chat_history`` (Gradio "messages" dict format).

    BUG FIXES vs. the original:
    - ``qa.llm = llm`` has no effect on a RetrievalQA (the LLM lives
      inside the combine-documents chain), and ``qa`` was never validly
      constructed anyway -- the chain is now built here with a real LLM.
    - ``set_stream`` arrives as a gr.State *value* (a str), so calling
      it raised TypeError; it is now used only if it is callable.
    - exceptions are caught and reported instead of crashing the handler.

    :return: (updated chat history, status string; "" on success).
    """
    global retriever, qa
    chat_history = chat_history or []
    if retriever is None:
        return chat_history, "❌ Create KB first."
    if not question or not question.strip():
        return chat_history, "❌ Enter a question."
    try:
        push = set_stream if callable(set_stream) else (lambda _txt: None)
        handler = StreamHandler(push)
        llm = ChatGoogleGenerativeAI(
            model="models/gemini-1.5-pro-latest",
            google_api_key=api_key,
            streaming=True,
            callbacks=[handler],
        )
        qa = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
        result = qa.invoke({"query": question})
        # Prefer the streamed text; fall back to the chain's result field.
        answer = handler.text or result.get("result", "")
        chat_history.append({"role": "user", "content": question})
        chat_history.append({"role": "assistant", "content": answer})
        return chat_history, ""
    except Exception as e:
        return chat_history, f"❌ Error answering question: {e}"
# --- Gradio UI -----------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# 📚 Multi‑PDF RAG Chat with Gemini")
    with gr.Column():
        api_key = gr.Textbox(show_label=False, placeholder="Enter your Gemini API Key", type="password")
        pdfs = gr.File(file_types=[".pdf"], label="Upload PDFs", file_count="multiple")
        kb_status = gr.Textbox(label="Status")
        create_btn = gr.Button("▶️ Create Knowledge Base")
        create_btn.click(create_kb, inputs=[api_key, pdfs], outputs=kb_status)
        chatbot = gr.Chatbot(label="🧠 Assistant", type="messages")
        question = gr.Textbox(show_label=False, placeholder="Ask a question")
        stream_output = gr.State("")  # to capture stream text
        send = gr.Button("🔍 Ask")
        # BUG FIX: the original used outputs=[chatbot, ""] -- outputs must
        # be components, so Gradio raised at wiring time. Route the status
        # string returned by ask_question to the existing kb_status box.
        send.click(fn=ask_question,
                   inputs=[api_key, question, chatbot, stream_output],
                   outputs=[chatbot, kb_status])
demo.launch()