File size: 11,305 Bytes
ca9bab5
 
aaae6cb
 
ca9bab5
aaae6cb
 
ca9bab5
 
 
 
 
 
 
 
 
aaae6cb
 
ca9bab5
 
aaae6cb
 
ca9bab5
 
aaae6cb
ca9bab5
 
aaae6cb
 
 
 
ca9bab5
aaae6cb
 
ca9bab5
 
 
 
 
 
 
aaae6cb
ca9bab5
 
 
aaae6cb
ca9bab5
aaae6cb
ca9bab5
 
 
 
 
 
 
 
aaae6cb
 
 
ca9bab5
 
 
 
 
 
 
aaae6cb
ca9bab5
aaae6cb
 
 
 
 
 
 
 
 
 
 
 
 
ca9bab5
 
 
 
 
 
aaae6cb
 
ca9bab5
 
aaae6cb
 
 
ca9bab5
 
aaae6cb
ca9bab5
 
 
 
 
 
 
 
 
aaae6cb
 
ca9bab5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aaae6cb
 
 
ca9bab5
 
 
 
 
aaae6cb
ca9bab5
aaae6cb
ca9bab5
aaae6cb
 
 
 
 
 
 
 
 
 
 
ca9bab5
 
aaae6cb
 
 
 
 
 
 
 
ca9bab5
 
aaae6cb
 
 
ca9bab5
 
 
 
 
aaae6cb
ca9bab5
 
aaae6cb
ca9bab5
aaae6cb
 
ca9bab5
 
 
 
aaae6cb
 
 
ca9bab5
 
aaae6cb
 
ca9bab5
 
aaae6cb
 
ca9bab5
 
 
 
aaae6cb
ca9bab5
 
 
 
 
 
aaae6cb
 
ca9bab5
 
 
 
 
 
 
 
 
 
 
 
 
 
aaae6cb
ca9bab5
aaae6cb
ca9bab5
 
 
aaae6cb
ca9bab5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aaae6cb
 
 
ca9bab5
 
 
 
 
 
 
 
 
 
 
 
 
 
aaae6cb
ca9bab5
 
 
aaae6cb
ca9bab5
 
 
 
 
 
aaae6cb
ca9bab5
 
 
 
 
aaae6cb
ca9bab5
aaae6cb
ca9bab5
 
aaae6cb
 
ca9bab5
 
 
 
 
 
 
 
 
 
 
 
 
aaae6cb
ca9bab5
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
# app.py
# Hugging Face Space: PDF Q&A (RAG) with Gemini 2.5 Flash
# - Upload PDFs, index with FAISS, ask questions answered by Gemini.
# - Uses document-specific splitters (Markdown/Python/JS) + generic fallback.
#
# IMPORTANT: In your Space, set Settings → Variables and secrets:
#   Name: GEMINI_API_KEY   Value: <your real key>

import hashlib
import io
import os

import gradio as gr
import numpy as np

# PDF parsing
from pypdf import PdfReader

# ✅ LangChain 1.x splitters live in a separate package now
from langchain_text_splitters import (
    RecursiveCharacterTextSplitter,
    MarkdownTextSplitter,
    PythonCodeTextSplitter,
    Language,
)

# FAISS vector store (community package in LC 1.x)
from langchain_community.vectorstores import FAISS


# ----------------------------
# Gemini wrappers
# ----------------------------
class GeminiEmbeddings:
    """Minimal embedding wrapper that works with either google-genai (new) or google-generativeai (legacy).

    Exposes ``embed_documents`` / ``embed_query`` for LangChain vector stores and
    is also callable (``obj(text)`` is ``obj.embed_query(text)``). langchain_community's
    FAISS calls ``embedding_function(text)`` to embed queries when the object is
    not a ``langchain_core`` ``Embeddings`` subclass, so ``__call__`` is required
    for similarity search to work.
    """

    def __init__(self, api_key: str, model_name: str = "text-embedding-004"):
        self.api_key = api_key
        # Embedding model id passed to whichever SDK is active.
        self.model_name = model_name
        # google-genai Client, set when `from google import genai` succeeds.
        self._client = None
        # google.generativeai module, only set when the new SDK is unavailable.
        self._legacy = None
        self._init_clients()

    def _init_clients(self):
        """Initialise one SDK backend; raise RuntimeError if neither is importable."""
        # Preferred: new "from google import genai"
        try:
            from google import genai
            self._client = genai.Client(api_key=self.api_key)
        except Exception:
            self._client = None
        # Fallback: legacy
        if self._client is None:
            try:
                import google.generativeai as legacy
                legacy.configure(api_key=self.api_key)
                self._legacy = legacy
            except Exception:
                self._legacy = None
        if (self._client is None) and (self._legacy is None):
            raise RuntimeError(
                "No Gemini client available. Install 'google-genai' or 'google-generativeai'."
            )

    @staticmethod
    def _extract_values(out):
        """Return the embedding vector from an SDK response as a list, or None.

        Accepted shapes:
        - google-genai ``EmbedContentResponse``: ``out.embeddings[0].values``
        - objects exposing ``out.embedding.values``
        - legacy dicts: ``{"embedding": [floats]}`` or ``{"embedding": {"values": [...]}}``
        """
        if isinstance(out, dict):
            emb = out.get("embedding")
            if emb is None:
                embs = out.get("embeddings")
                emb = embs[0] if embs else None
        else:
            emb = getattr(out, "embedding", None)
            if emb is None:
                embs = getattr(out, "embeddings", None)
                emb = embs[0] if embs else None
        if emb is None:
            return None
        if isinstance(emb, dict):
            vals = emb.get("values")
        elif isinstance(emb, (list, tuple)):
            # Legacy SDK returns the vector itself under "embedding".
            vals = emb
        else:
            vals = getattr(emb, "values", None)
        return list(vals) if vals is not None else None

    def _embed_one(self, text: str) -> list[float]:
        """Embed a single string.

        Tries the google-genai client first, then the legacy SDK.

        Raises:
            RuntimeError: if no backend returned a vector. When the new client
                raised, that error is chained as ``__cause__`` instead of being
                silently discarded.
        """
        last_error = None
        if self._client is not None:
            try:
                # google-genai's keyword is `contents` (plural); `content=` is a TypeError.
                out = self._client.models.embed_content(
                    model=self.model_name,
                    contents=text,
                )
            except Exception as exc:  # auth / network / quota: try legacy, else report below
                last_error = exc
            else:
                vals = self._extract_values(out)
                if vals is not None:
                    return vals
                last_error = RuntimeError("Unexpected embed_content response")

        if self._legacy is not None:
            out = self._legacy.embed_content(model=self.model_name, content=text)
            vals = self._extract_values(out)
            if vals is not None:
                return vals
            raise RuntimeError("Unexpected legacy embed_content response")

        if last_error is not None:
            raise RuntimeError(f"Gemini embedding request failed: {last_error}") from last_error
        raise RuntimeError("No embedding backend available")

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed each text in order (one API call per text)."""
        return [self._embed_one(t) for t in texts]

    def embed_query(self, text: str) -> list[float]:
        """Embed a single query string."""
        return self._embed_one(text)

    def __call__(self, text: str) -> list[float]:
        """Alias for ``embed_query`` so FAISS can use this object as a plain callable."""
        return self.embed_query(text)


class GeminiGenerator:
    """Minimal text generation wrapper supporting both clients.

    Uses the google-genai ``Client`` when ``from google import genai`` works,
    otherwise the legacy ``google.generativeai`` module.
    """

    def __init__(self, api_key: str, model_name: str = "gemini-2.5-flash"):
        self.api_key = api_key
        self.model_name = model_name
        # google-genai Client, set when the new SDK imports and initialises.
        self._client = None
        # google.generativeai module, only set when the new SDK is unavailable.
        self._legacy = None
        self._init_clients()

    def _init_clients(self):
        """Initialise one SDK backend; raise RuntimeError if neither is importable."""
        try:
            from google import genai
            self._client = genai.Client(api_key=self.api_key)
        except Exception:
            self._client = None
        if self._client is None:
            try:
                import google.generativeai as legacy
                legacy.configure(api_key=self.api_key)
                self._legacy = legacy
            except Exception:
                self._legacy = None
        if (self._client is None) and (self._legacy is None):
            raise RuntimeError(
                "No Gemini client available. Install 'google-genai' or 'google-generativeai'."
            )

    @staticmethod
    def _text_from_response(resp) -> str:
        """Return the generated text from an SDK response, or "" if there is none.

        The legacy SDK's ``.text`` accessor raises ValueError when the response
        has no text part (e.g. blocked by safety filters). ``getattr(..., None)``
        does not catch that, so it is handled explicitly here.
        """
        if isinstance(resp, dict):
            return resp.get("text") or ""
        try:
            text = resp.text
        except (AttributeError, ValueError):
            text = None
        if text:
            return text
        candidates = getattr(resp, "candidates", None) or []
        if candidates:
            content = getattr(candidates[0], "content", None)
            parts = getattr(content, "parts", None) or []
            # Join every text part, not just the first.
            return "".join(getattr(p, "text", None) or "" for p in parts)
        return ""

    def generate(self, prompt: str) -> str:
        """Generate a completion for *prompt*; returns "" when the model produced no text."""
        if self._client is not None:
            resp = self._client.models.generate_content(
                model=self.model_name,
                contents=prompt,
            )
        else:
            # The legacy SDK has no module-level generate_content; generation is
            # a method of GenerativeModel.
            model = self._legacy.GenerativeModel(self.model_name)
            resp = model.generate_content(prompt)
        return self._text_from_response(resp)


# ----------------------------
# RAG helpers
# ----------------------------
def extract_text_from_pdfs(files: list[tuple[str, bytes]]) -> str:
    """Return the text of all uploaded PDFs as one string.

    Pages within a PDF, and the PDFs themselves, are separated by a blank line.
    A page whose text extraction fails, or yields nothing, contributes "".
    Errors from opening a PDF (``PdfReader``) are not caught.

    Args:
        files: ``(name, raw_bytes)`` pairs; the name is not used here.
    """

    def _page_text(page) -> str:
        # Per-page best effort: one bad page should not lose the whole file.
        try:
            return page.extract_text() or ""
        except Exception:
            return ""

    documents = []
    for _name, payload in files:
        pdf = PdfReader(io.BytesIO(payload))
        documents.append("\n\n".join([_page_text(pg) for pg in pdf.pages]))
    return "\n\n".join(documents)


def choose_splitter(text: str, chunk_size: int = 1200, chunk_overlap: int = 100):
    """Heuristic splitter choice to mirror your reference code behavior.

    Checks in order: Markdown markers, Python keywords, JavaScript keywords,
    then falls back to a generic RecursiveCharacterTextSplitter.

    Args:
        text: Full document text to inspect.
        chunk_size: Maximum characters per chunk (default 1200, the previous
            hard-coded value).
        chunk_overlap: Characters shared by consecutive chunks (default 100).

    Returns:
        A LangChain splitter instance (used via ``create_documents``).

    NOTE: these are plain substring checks over the whole text, so prose that
    merely contains e.g. "class " or "function " is routed to a code splitter.
    """
    sizes = {"chunk_size": chunk_size, "chunk_overlap": chunk_overlap}

    def _contains_any(markers) -> bool:
        return any(m in text for m in markers)

    # Markdown? (headings / code fences)
    if _contains_any(("\n# ", "\n## ", "\n```")):
        return MarkdownTextSplitter(**sizes)

    # Python-ish?
    if _contains_any(("def ", "class ", "import ")):
        return PythonCodeTextSplitter(**sizes)

    # JavaScript-ish?
    if _contains_any(("function ", "const ", "let ", "=>")):
        return RecursiveCharacterTextSplitter.from_language(language=Language.JS, **sizes)

    # Fallback: generic recursive
    return RecursiveCharacterTextSplitter(**sizes)


def build_vectorstore(all_text: str, embeddings: GeminiEmbeddings):
    """Split *all_text* into chunks and index them in an in-memory FAISS store.

    Returns:
        ``(vectorstore, n_chunks)``: the FAISS store and the number of chunks indexed.
    """
    chunks = choose_splitter(all_text).create_documents([all_text])
    return FAISS.from_documents(chunks, embedding=embeddings), len(chunks)


def make_rag_prompt(question: str, context_chunks: list[str]) -> str:
    """Build the grounded-answer prompt sent to the LLM.

    Each context chunk is labelled ``[Chunk N]`` (1-based). The chunks are
    separated by blank lines and followed by the question and an ``ANSWER:`` cue.
    """
    header = (
        "You are a helpful assistant. Answer the user's question using only the provided CONTEXT. "
        "If the answer cannot be found in the context, say you don't know. Keep the answer concise.\n\n"
    )
    numbered = []
    for idx, chunk in enumerate(context_chunks, start=1):
        numbered.append(f"[Chunk {idx}]\n{chunk}")
    body = "\n\n".join(numbered)
    return f"{header}CONTEXT:\n{body}\n\nQUESTION: {question}\nANSWER:"


def _files_fingerprint(files: list[tuple[str, bytes]]) -> str:
    """Return an order-independent digest of the uploaded ``(name, bytes)`` pairs."""
    per_file = sorted(
        f"{name}\0{hashlib.sha256(data).hexdigest()}" for name, data in files
    )
    return hashlib.sha256("\n".join(per_file).encode("utf-8", "replace")).hexdigest()


def rag_answer(state, files, question, k):
    """Answer *question* from the uploaded PDFs with retrieval-augmented generation.

    Args:
        state: Previous session state: ``None`` or a dict with ``"vs"`` (FAISS
            store), ``"n_chunks"`` and ``"fingerprint"`` (digest of the files
            the store was built from).
        files: ``(name, bytes)`` pairs for the uploaded PDFs; may be empty.
        question: The user's question.
        k: Number of chunks to retrieve (coerced with ``int``).

    Returns:
        ``(new_state, answer_text, context_chunks)``. On validation failures
        the incoming state is returned unchanged with a message and ``[]``.
    """
    api_key = os.environ.get("GEMINI_API_KEY", "").strip()
    if not api_key:
        return state, "❌ Missing GEMINI_API_KEY. Add it in the Space settings and restart.", []
    if not question or not question.strip():
        return state, "Please enter a question.", []

    cached_vs = state.get("vs") if isinstance(state, dict) else None
    fingerprint = _files_fingerprint(files) if files else None

    # Rebuild when there is no index yet, or when the uploaded files differ from
    # those the cached index was built from. Previously, new uploads were
    # silently ignored once any index existed.
    if cached_vs is None or (
        fingerprint is not None and fingerprint != state.get("fingerprint")
    ):
        if not files:
            return state, "Please upload at least one PDF first.", []
        text = extract_text_from_pdfs(files)
        if not text.strip():
            return state, "No extractable text found in the uploaded PDFs.", []
        embeds = GeminiEmbeddings(api_key=api_key)
        vs, n_chunks = build_vectorstore(text, embeds)
        state = {"vs": vs, "n_chunks": n_chunks, "fingerprint": fingerprint}
    else:
        vs = cached_vs

    # Retrieve. similarity_search is what the default "similarity" retriever
    # delegates to. BaseRetriever.get_relevant_documents is deprecated in
    # langchain-core (slated for removal in 1.0).
    docs = vs.similarity_search(question, k=int(k))
    context_chunks = [d.page_content for d in docs]

    # Generate
    llm = GeminiGenerator(api_key=api_key, model_name="gemini-2.5-flash")
    answer = llm.generate(make_rag_prompt(question, context_chunks))

    return state, answer, context_chunks


# ----------------------------
# Gradio UI
# ----------------------------
with gr.Blocks(title="PDF Q&A (Gemini RAG)") as demo:
    gr.Markdown("# PDF Q&A (RAG) with Gemini 2.5 Flash")
    gr.Markdown(
        "Upload PDF(s), then ask questions. Uses **document-specific splitting** with LangChain splitters, "
        "FAISS for vector search, and Gemini for embeddings + generation.\n\n"
        "**Setup:** In this Space, go to **Settings → Variables and secrets** and add `GEMINI_API_KEY`."
    )

    # Per-session value threaded through rag_answer (holds the cached index).
    state = gr.State(value=None)

    with gr.Row():
        file_uploader = gr.File(
            label="Upload PDFs",
            file_count="multiple",
            file_types=[".pdf"],
        )
        top_k = gr.Slider(1, 10, value=4, step=1, label="Top-k context chunks")

    question = gr.Textbox(label="Your question", placeholder="Ask about the uploaded PDFs…")
    ask_btn = gr.Button("Ask")
    answer = gr.Markdown("")
    with gr.Accordion("Retrieved context (debug)", open=False):
        ctx = gr.Markdown("")

    def _read_upload(f):
        """Return ``(display_name, bytes)`` for one Gradio upload value.

        Gradio versions differ in what ``gr.File`` yields. It may be a filepath
        string (Gradio 4+ default), a tempfile wrapper whose ``.name`` is the
        temp path (Gradio 3), an object with ``.path``/``.orig_name``, raw bytes,
        or a readable file object.
        """
        if isinstance(f, (bytes, bytearray)):
            return "file.pdf", bytes(f)
        if isinstance(f, (str, os.PathLike)):
            path = f
        else:
            path = getattr(f, "path", None) or getattr(f, "name", None)
        if isinstance(path, (str, os.PathLike)) and os.path.isfile(path):
            with open(path, "rb") as fh:
                data = fh.read()
            return os.path.basename(getattr(f, "orig_name", None) or path), data
        # Last resort: a readable file-like object.
        return os.path.basename(getattr(f, "orig_name", None) or "file.pdf"), f.read()

    def _convert_files(files):
        """Convert Gradio upload values to ``(name, bytes)`` pairs, skipping unreadable entries."""
        if not files:
            return []
        if not isinstance(files, (list, tuple)):
            files = [files]  # tolerate a single-file value
        pairs = []
        for f in files:
            try:
                pairs.append(_read_upload(f))
            except (OSError, AttributeError, TypeError):
                # Best effort: skip this entry rather than failing the whole request.
                continue
        return pairs

    def on_ask(state_val, files_val, q_val, k_val):
        """Click handler: run RAG and format the retrieved chunks for the debug panel."""
        files_pairs = _convert_files(files_val)
        new_state, ans, chunks = rag_answer(state_val, files_pairs, q_val, k_val)
        # Blank lines on both sides make "----" a Markdown horizontal rule. The old
        # "----\n\n" glued the dashes onto the previous chunk's last line.
        ctx_text = "\n\n----\n\n".join(chunks) if chunks else ""
        return new_state, ans, ctx_text

    ask_btn.click(
        fn=on_ask,
        inputs=[state, file_uploader, question, top_k],
        outputs=[state, answer, ctx],
    )

if __name__ == "__main__":
    demo.launch()