# NOTE: removed Hugging Face Spaces page residue ("Spaces:" / "Runtime error")
# that was accidentally captured above the source; it is not part of the program.
# app.py
# Hugging Face Space: PDF Q&A (RAG) with Gemini 2.5 Flash
# - Upload PDFs, index with FAISS, ask questions answered by Gemini.
# - Uses document-specific splitters (Markdown/Python/JS) + generic fallback.
#
# IMPORTANT: In your Space, set Settings → Variables and secrets:
#     Name: GEMINI_API_KEY    Value: <your real key>

# Standard library
import io
import os

# Third-party
import numpy as np
import gradio as gr

# PDF parsing
from pypdf import PdfReader

# ✅ LangChain 1.x splitters live in a separate package now
from langchain_text_splitters import (
    Language,
    MarkdownTextSplitter,
    PythonCodeTextSplitter,
    RecursiveCharacterTextSplitter,
)

# FAISS vector store (community package in LC 1.x)
from langchain_community.vectorstores import FAISS
# ----------------------------
# Gemini wrappers
# ----------------------------
class GeminiEmbeddings:
    """Minimal embedding wrapper that works with either google-genai (new)
    or google-generativeai (legacy).

    Implements the two-method interface LangChain vector stores expect:
    ``embed_documents`` and ``embed_query``.
    """

    def __init__(self, api_key: str):
        self.api_key = api_key
        self._client = None   # new-style google-genai client, if importable
        self._legacy = None   # legacy google.generativeai module, if importable
        self._init_clients()

    def _init_clients(self):
        """Prefer the new ``google-genai`` client; fall back to the legacy SDK."""
        try:
            from google import genai
            self._client = genai.Client(api_key=self.api_key)
        except Exception:
            self._client = None
        # Fallback: legacy
        if self._client is None:
            try:
                import google.generativeai as legacy
                legacy.configure(api_key=self.api_key)
                self._legacy = legacy
            except Exception:
                self._legacy = None
        if (self._client is None) and (self._legacy is None):
            raise RuntimeError(
                "No Gemini client available. Install 'google-genai' or 'google-generativeai'."
            )

    def _embed_one(self, text: str) -> list[float]:
        """Embed a single string, normalizing the differing response shapes.

        Raises RuntimeError when no backend produced a usable vector.
        """
        # Try the new client first.
        if self._client is not None:
            try:
                # BUG FIX: the new google-genai API takes `contents=` (not
                # `content=`) and returns `.embeddings` (a list of embeddings),
                # not a single `.embedding`. The old call raised TypeError,
                # which was swallowed below and surfaced later as
                # "No embedding backend available".
                out = self._client.models.embed_content(
                    model="text-embedding-004",
                    contents=text,
                )
                embs = getattr(out, "embeddings", None)
                if embs:
                    vals = getattr(embs[0], "values", None)
                    if vals is not None:
                        return list(vals)
                # Tolerate dict-shaped responses as well.
                if isinstance(out, dict):
                    embs = out.get("embeddings") or []
                    if embs and isinstance(embs[0], dict) and embs[0].get("values"):
                        return list(embs[0]["values"])
            except Exception:
                pass  # fall through to legacy
        if self._legacy is not None:
            out = self._legacy.embed_content(model="text-embedding-004", content=text)
            # BUG FIX: legacy embed_content typically returns
            # {"embedding": [floats]} — a plain list with no "values" key;
            # the old code called .get("values") on that list and crashed.
            emb = out.get("embedding") if isinstance(out, dict) else getattr(out, "embedding", None)
            if isinstance(emb, (list, tuple)):
                return list(emb)
            if isinstance(emb, dict) and emb.get("values") is not None:
                return list(emb["values"])
            vals = getattr(emb, "values", None)
            if vals is not None:
                return list(vals)
            raise RuntimeError("Unexpected legacy embed_content response")
        raise RuntimeError("No embedding backend available")

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed a batch of documents (one API call per text)."""
        return [self._embed_one(t) for t in texts]

    def embed_query(self, text: str) -> list[float]:
        """Embed a single query string."""
        return self._embed_one(text)
class GeminiGenerator:
    """Minimal text generation wrapper supporting both clients.

    Prefers the new ``google-genai`` client and falls back to the legacy
    ``google-generativeai`` package when the new one is unavailable.
    """

    def __init__(self, api_key: str, model_name: str = "gemini-2.5-flash"):
        self.api_key = api_key
        self.model_name = model_name
        self._client = None   # new-style google-genai client, if importable
        self._legacy = None   # legacy google.generativeai module, if importable
        self._init_clients()

    def _init_clients(self):
        """Instantiate whichever Gemini SDK is installed; raise if neither is."""
        try:
            from google import genai
            self._client = genai.Client(api_key=self.api_key)
        except Exception:
            self._client = None
        if self._client is None:
            try:
                import google.generativeai as legacy
                legacy.configure(api_key=self.api_key)
                self._legacy = legacy
            except Exception:
                self._legacy = None
        if (self._client is None) and (self._legacy is None):
            raise RuntimeError(
                "No Gemini client available. Install 'google-genai' or 'google-generativeai'."
            )

    def generate(self, prompt: str) -> str:
        """Generate text for *prompt*; returns "" when no text can be extracted."""
        if self._client is not None:
            resp = self._client.models.generate_content(
                model=self.model_name,
                contents=prompt,
            )
            # Try common shapes
            text = getattr(resp, "text", None)
            if text:
                return text
            if isinstance(resp, dict) and resp.get("text"):
                return resp["text"]
            cand = getattr(resp, "candidates", None)
            if cand and getattr(cand[0], "content", None):
                parts = getattr(cand[0].content, "parts", [])
                if parts and getattr(parts[0], "text", None):
                    return parts[0].text
            return ""
        # Legacy path.
        # BUG FIX: google.generativeai exposes no module-level
        # generate_content(prompt, model=...); generation goes through a
        # GenerativeModel instance bound to the model name.
        resp = self._legacy.GenerativeModel(self.model_name).generate_content(prompt)
        text = getattr(resp, "text", None)
        if text:
            return text
        if isinstance(resp, dict) and resp.get("text"):
            return resp["text"]
        try:
            return resp.candidates[0].content.parts[0].text
        except Exception:
            return ""
# ----------------------------
# RAG helpers
# ----------------------------
def extract_text_from_pdfs(files: list[tuple[str, bytes]]) -> str:
    """Concatenate the extractable text of every uploaded PDF.

    Each item in *files* is a ``(filename, raw_bytes)`` pair. Pages whose
    extraction fails contribute an empty string rather than aborting.
    """
    def _page_text(page) -> str:
        # extract_text() can raise on malformed pages; treat those as empty.
        try:
            return page.extract_text() or ""
        except Exception:
            return ""

    per_file = []
    for _name, raw in files:
        reader = PdfReader(io.BytesIO(raw))
        per_file.append("\n\n".join(_page_text(p) for p in reader.pages))
    return "\n\n".join(per_file)
def choose_splitter(text: str):
    """Pick a text splitter via lightweight content heuristics.

    Mirrors the reference code's behavior: Markdown markers win, then
    Python-looking keywords, then JavaScript-looking ones, with a generic
    recursive splitter as the fallback.
    """
    chunk_size, chunk_overlap = 1200, 100
    markdown_markers = ("\n# ", "\n## ", "\n```")
    python_markers = ("def ", "class ", "import ")
    js_markers = ("function ", "const ", "let ", "=>")

    if any(m in text for m in markdown_markers):
        return MarkdownTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    if any(m in text for m in python_markers):
        return PythonCodeTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    if any(m in text for m in js_markers):
        return RecursiveCharacterTextSplitter.from_language(
            language=Language.JS, chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
    return RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
def build_vectorstore(all_text: str, embeddings: GeminiEmbeddings):
    """Split *all_text* into chunks and index them in a FAISS store.

    Returns ``(vectorstore, chunk_count)``.
    """
    documents = choose_splitter(all_text).create_documents([all_text])
    return FAISS.from_documents(documents, embedding=embeddings), len(documents)
def make_rag_prompt(question: str, context_chunks: list[str]) -> str:
    """Assemble the grounded-QA prompt sent to Gemini.

    Each chunk is labeled ``[Chunk N]`` (1-based) so the model can cite it.
    """
    labeled = [f"[Chunk {n}]\n{chunk}" for n, chunk in enumerate(context_chunks, start=1)]
    return (
        "You are a helpful assistant. Answer the user's question using only the provided CONTEXT. "
        "If the answer cannot be found in the context, say you don't know. Keep the answer concise.\n\n"
        "CONTEXT:\n" + "\n\n".join(labeled) + f"\n\nQUESTION: {question}\nANSWER:"
    )
def rag_answer(state, files, question, k):
    """Answer *question* over the uploaded PDFs via retrieval-augmented generation.

    Parameters
    ----------
    state : dict | None
        Session cache; when it already holds a vector store under ``"vs"``
        it is reused so PDFs are parsed/embedded only once.
    files : list[tuple[str, bytes]]
        ``(filename, raw_bytes)`` pairs from the uploader.
    question : str
        The user's question.
    k : int | float
        Number of context chunks to retrieve.

    Returns
    -------
    tuple
        ``(new_state, answer_markdown, context_chunks)``.
    """
    api_key = os.environ.get("GEMINI_API_KEY", "").strip()
    if not api_key:
        return state, "❌ Missing GEMINI_API_KEY. Add it in the Space settings and restart.", []
    # Robustness: an empty question would waste an embedding + generation call.
    if not (question or "").strip():
        return state, "Please enter a question.", []

    # Init tools
    embeds = GeminiEmbeddings(api_key=api_key)
    llm = GeminiGenerator(api_key=api_key, model_name="gemini-2.5-flash")

    # Build the index on first use; afterwards reuse it from session state.
    if state and isinstance(state, dict) and state.get("vs") is not None:
        vs = state["vs"]
    else:
        if not files:
            return state, "Please upload at least one PDF first.", []
        text = extract_text_from_pdfs(files)
        if not text.strip():
            return state, "No extractable text found in the uploaded PDFs.", []
        vs, n_chunks = build_vectorstore(text, embeds)
        state = {"vs": vs, "n_chunks": n_chunks}

    # Retrieve. BUG FIX: retriever.get_relevant_documents() is deprecated and
    # removed in LangChain 1.x; similarity_search() is the stable API.
    docs = vs.similarity_search(question, k=int(k))
    context_chunks = [d.page_content for d in docs]

    # Generate
    answer = llm.generate(make_rag_prompt(question, context_chunks))
    return state, answer, context_chunks
# ----------------------------
# Gradio UI
# ----------------------------
with gr.Blocks(title="PDF Q&A (Gemini RAG)") as demo:
    gr.Markdown("# PDF Q&A (RAG) with Gemini 2.5 Flash")
    gr.Markdown(
        "Upload PDF(s), then ask questions. Uses **document-specific splitting** with LangChain splitters, "
        "FAISS for vector search, and Gemini for embeddings + generation.\n\n"
        "**Setup:** In this Space, go to **Settings → Variables and secrets** and add `GEMINI_API_KEY`."
    )

    # Per-session cache holding the FAISS index once it has been built.
    session = gr.State(value=None)

    with gr.Row():
        pdf_input = gr.File(
            label="Upload PDFs",
            file_count="multiple",
            file_types=[".pdf"],
        )
        k_slider = gr.Slider(1, 10, value=4, step=1, label="Top-k context chunks")

    question_box = gr.Textbox(label="Your question", placeholder="Ask about the uploaded PDFs…")
    ask_button = gr.Button("Ask")
    answer_md = gr.Markdown("")
    with gr.Accordion("Retrieved context (debug)", open=False):
        context_md = gr.Markdown("")

    def _read_uploads(uploads):
        """Turn Gradio upload objects into (basename, bytes) pairs."""
        pairs = []
        for item in uploads or []:
            try:
                # Most Gradio builds return an object whose .name is a temp path.
                with open(item.name, "rb") as fh:
                    pairs.append((os.path.basename(getattr(item, "orig_name", item.name)), fh.read()))
            except Exception:
                try:
                    # Others expose a file-like object with .read().
                    pairs.append((os.path.basename(getattr(item, "orig_name", "file.pdf")), item.read()))
                except Exception:
                    pass
        return pairs

    def _handle_ask(session_val, uploads, query, k):
        """Click handler: run RAG and format the debug-context panel."""
        new_session, answer_text, chunks = rag_answer(session_val, _read_uploads(uploads), query, k)
        return new_session, answer_text, "----\n\n".join(chunks) if chunks else ""

    ask_button.click(
        fn=_handle_ask,
        inputs=[session, pdf_input, question_box, k_slider],
        outputs=[session, answer_md, context_md],
    )
# Standard script entry point: launch the Gradio app when run directly
# (Hugging Face Spaces executes this module as a script).
if __name__ == "__main__":
    demo.launch()