Rakib023 commited on
Commit
ca9bab5
·
verified ·
1 Parent(s): 9108c7c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +297 -0
app.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ # Hugging Face Space: PDF Q&A (RAG) with Gemini 2.5 Flash
3
+ # - Upload one or more PDFs, index them with vector search, and ask questions.
4
+ # - Uses Gemini for both embeddings (text-embedding-004) and generation ("gemini-2.5-flash").
5
+ # - Demonstrates document-specific splitting à la LangChain (Markdown/Python/JS) + generic recursive splitting.
6
+ #
7
+ # IMPORTANT: Set your Gemini API key as an environment variable GEMINI_API_KEY
8
+ # in the Space's "Settings" ➜ "Variables and secrets" ➜ Add "GEMINI_API_KEY".
9
+
10
+ import os
11
+ import io
12
+ import numpy as np
13
+ import gradio as gr
14
+
15
+ # PDF parsing
16
+ from pypdf import PdfReader
17
+
18
+ # Text splitters inspired by your reference
19
+ from langchain.text_splitter import (
20
+ RecursiveCharacterTextSplitter,
21
+ MarkdownTextSplitter,
22
+ Language
23
+ )
24
+ from langchain.text_splitter import PythonCodeTextSplitter
25
+
26
+ # Simple FAISS vector store
27
+ from langchain_community.vectorstores import FAISS
28
+
29
+ # We'll create a minimal Embeddings interface wrapper for Gemini
30
class GeminiEmbeddings:
    """Minimal LangChain-compatible embeddings wrapper around Gemini.

    Exposes ``embed_documents`` / ``embed_query`` (the interface FAISS
    expects) backed by the "text-embedding-004" model. Prefers the new
    ``google-genai`` client and falls back to the legacy
    ``google-generativeai`` package.
    """

    def __init__(self, api_key: str):
        self.api_key = api_key
        self._client = None   # new-style client (from google import genai), if available
        self._legacy = None   # legacy module (google.generativeai), if available
        self._init_clients()

    def _init_clients(self):
        """Initialize the new client, then the legacy one; raise if neither is usable."""
        # Preferred: new "from google import genai" client
        try:
            from google import genai
            self._client = genai.Client(api_key=self.api_key)
        except Exception:
            self._client = None

        # Fallback: legacy google-generativeai
        if self._client is None:
            try:
                import google.generativeai as legacy
                legacy.configure(api_key=self.api_key)
                self._legacy = legacy
            except Exception:
                self._legacy = None

        if (self._client is None) and (self._legacy is None):
            raise RuntimeError("No Gemini client available. Install either 'google-genai' or 'google-generativeai'.")

    @staticmethod
    def _extract_values(container):
        """Return the vector under "values" from a dict or response object, else None.

        BUG FIX: ``getattr(d, "values", None)`` on a dict returns the bound
        ``dict.values`` method (truthy and non-indexable), so the original
        ``getattr(...) or d.get("values")`` chain never reached the dict
        branch and ``list(vals)`` raised TypeError. Check ``isinstance(dict)``
        first so both dict and object responses are handled.
        """
        if container is None:
            return None
        if isinstance(container, dict):
            return container.get("values")
        return getattr(container, "values", None)

    def _embed_one(self, text: str) -> list[float]:
        """Embed a single string; returns the embedding as a list of floats.

        Raises RuntimeError when no backend is available or the response
        shape is unrecognized.
        """
        # Try the new client first.
        if self._client is not None:
            try:
                out = self._client.models.embed_content(
                    model="text-embedding-004",
                    content=text
                )
                # Response may be a dict or an object; embedding lives under
                # "embedding" -> "values", or (some versions) directly under "values".
                emb = out.get("embedding") if isinstance(out, dict) else getattr(out, "embedding", None)
                vals = self._extract_values(emb)
                if vals is None:
                    vals = self._extract_values(out)
                if vals is None:
                    raise RuntimeError("Unexpected embed_content response")
                return list(vals)
            except Exception:
                # Fall through to the legacy client.
                pass

        if self._legacy is not None:
            out = self._legacy.embed_content(model="text-embedding-004", content=text)
            if isinstance(out, dict):
                data = out.get("embedding") or out
                vals = self._extract_values(data)
                if vals is None:
                    # Was list(None) -> TypeError before; raise a clear error instead.
                    raise RuntimeError("Unexpected legacy embed_content response")
                return list(vals)
            # Some versions return an object with .embedding
            emb = getattr(out, "embedding", None)
            if emb is not None:
                return list(getattr(emb, "values", []))
            raise RuntimeError("Unexpected legacy embed_content response")

        raise RuntimeError("No embedding backend available.")

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed each document text (LangChain Embeddings interface)."""
        return [self._embed_one(t) for t in texts]

    def embed_query(self, text: str) -> list[float]:
        """Embed a single query string (LangChain Embeddings interface)."""
        return self._embed_one(text)
98
+
99
+
100
class GeminiGenerator:
    """Text-generation wrapper around Gemini (default model "gemini-2.5-flash").

    Prefers the new ``google-genai`` client; falls back to the legacy
    ``google-generativeai`` package.
    """

    def __init__(self, api_key: str, model_name: str = "gemini-2.5-flash"):
        self.api_key = api_key
        self.model_name = model_name
        self._client = None   # new-style client, if available
        self._legacy = None   # legacy module, if available
        self._init_clients()

    def _init_clients(self):
        """Initialize the new client, then the legacy one; raise if neither is usable."""
        try:
            from google import genai
            self._client = genai.Client(api_key=self.api_key)
        except Exception:
            self._client = None
        if self._client is None:
            try:
                import google.generativeai as legacy
                legacy.configure(api_key=self.api_key)
                self._legacy = legacy
            except Exception:
                self._legacy = None
        if (self._client is None) and (self._legacy is None):
            raise RuntimeError("No Gemini client available. Install either 'google-genai' or 'google-generativeai'.")

    def generate(self, prompt: str) -> str:
        """Generate a completion for ``prompt``; returns "" when no text is found."""
        if self._client is not None:
            resp = self._client.models.generate_content(
                model=self.model_name,
                contents=prompt
            )
            # New client usually returns an object with .text
            text = getattr(resp, "text", None)
            if text is None and isinstance(resp, dict):
                text = resp.get("text")
            if text is None:
                # Some versions nest it at candidates[0].content.parts[0].text
                cand = getattr(resp, "candidates", None)
                if cand and getattr(cand[0], "content", None):
                    parts = getattr(cand[0].content, "parts", [])
                    if parts and getattr(parts[0], "text", None):
                        text = parts[0].text
            return text or ""
        # Legacy fallback.
        # BUG FIX: google-generativeai has no module-level generate_content();
        # the supported entry point is GenerativeModel(name).generate_content(prompt).
        model = self._legacy.GenerativeModel(self.model_name)
        resp = model.generate_content(prompt)
        text = getattr(resp, "text", None)
        if text is None and isinstance(resp, dict):
            text = resp.get("text")
        if text is None:
            try:
                text = resp.candidates[0].content.parts[0].text
            except Exception:
                text = ""
        return text
154
+
155
+
156
def extract_text_from_pdfs(files: list[tuple[str, bytes]]) -> str:
    """Concatenate the extracted text of every page of every uploaded PDF.

    ``files`` is a list of (filename, raw_bytes) pairs; pages whose
    extraction fails contribute an empty string. Files (and pages) are
    joined with blank lines.
    """
    def _page_text(page) -> str:
        # Best-effort: a single unparseable page must not abort the file.
        try:
            return page.extract_text() or ""
        except Exception:
            return ""

    per_file = [
        "\n\n".join(_page_text(page) for page in PdfReader(io.BytesIO(payload)).pages)
        for _name, payload in files
    ]
    return "\n\n".join(per_file)
169
+
170
+
171
def choose_splitter(text: str):
    """Pick a LangChain text splitter via crude content heuristics.

    Checks, in order: Markdown markers, Python keywords, JavaScript
    keywords; otherwise returns the generic recursive splitter. All
    splitters use chunk_size=1200 / chunk_overlap=100.
    """
    markdown_markers = ("\n# ", "\n## ", "\n```")
    python_markers = ("def ", "class ", "import ")
    js_markers = ("function ", "const ", "let ", "=>")

    # Looks like Markdown (headings, code fences)?
    if any(marker in text for marker in markdown_markers):
        return MarkdownTextSplitter(chunk_size=1200, chunk_overlap=100)

    # Looks like Python code?
    if any(marker in text for marker in python_markers):
        return PythonCodeTextSplitter(chunk_size=1200, chunk_overlap=100)

    # Looks like JavaScript?
    if any(marker in text for marker in js_markers):
        return RecursiveCharacterTextSplitter.from_language(
            language=Language.JS, chunk_size=1200, chunk_overlap=100
        )

    # Generic fallback.
    return RecursiveCharacterTextSplitter(chunk_size=1200, chunk_overlap=100)
189
+
190
+
191
def build_vectorstore(all_text: str, embeddings: GeminiEmbeddings):
    """Split ``all_text`` into chunks and index them in a FAISS store.

    Returns (vectorstore, number_of_chunks).
    """
    chunks = choose_splitter(all_text).create_documents([all_text])
    store = FAISS.from_documents(chunks, embedding=embeddings)
    return store, len(chunks)
196
+
197
+
198
def make_rag_prompt(question: str, context_chunks: list[str]) -> str:
    """Assemble the grounded-QA prompt: instruction, numbered context, question."""
    instruction = (
        "You are a helpful assistant. Answer the user's question using only the provided CONTEXT. "
        "If the answer cannot be found in the context, say you don't know. Keep the answer concise.\n\n"
    )
    labeled = []
    for number, chunk in enumerate(context_chunks, start=1):
        labeled.append(f"[Chunk {number}]\n{chunk}")
    context = "\n\n".join(labeled)
    return f"{instruction}CONTEXT:\n{context}\n\nQUESTION: {question}\nANSWER:"
205
+
206
+
207
def rag_answer(state, files, question, k):
    """Answer ``question`` over the uploaded PDFs with retrieval-augmented generation.

    Args:
        state: session cache — dict with keys "vs" (FAISS store) and
            "n_chunks", or None on first call.
        files: list of (filename, raw_bytes) PDF payloads.
        question: the user's question string.
        k: number of context chunks to retrieve.

    Returns:
        (new_state, answer_text, context_chunks) — context_chunks is the
        list of retrieved page_content strings shown in the debug panel.
        On any precondition failure the answer is an error message and
        context_chunks is empty.
    """
    api_key = os.environ.get("GEMINI_API_KEY", "").strip()
    if not api_key:
        return state, "❌ Missing GEMINI_API_KEY. Please add it in the Space settings.", []

    llm = GeminiGenerator(api_key=api_key, model_name="gemini-2.5-flash")

    # Reuse the cached vector store when present; otherwise build one from the PDFs.
    if state and isinstance(state, dict) and state.get("vs") is not None:
        vs = state["vs"]
        n_chunks = state.get("n_chunks", 0)
    else:
        if not files:
            return state, "Please upload at least one PDF first.", []
        text = extract_text_from_pdfs(files)
        if not text.strip():
            return state, "No extractable text found in the uploaded PDFs.", []
        # FIX: only construct the embeddings client when actually indexing —
        # the cached path never embeds documents, and building it there was
        # wasted client setup on every question.
        embeds = GeminiEmbeddings(api_key=api_key)
        vs, n_chunks = build_vectorstore(text, embeds)
        state = {"vs": vs, "n_chunks": n_chunks}

    # Retrieve the top-k most similar chunks.
    retriever = vs.as_retriever(search_kwargs={"k": int(k)})
    docs = retriever.get_relevant_documents(question)
    context_chunks = [d.page_content for d in docs]

    # Generate a grounded answer from the retrieved context.
    prompt = make_rag_prompt(question, context_chunks)
    answer = llm.generate(prompt)

    return state, answer, context_chunks
241
+
242
+
243
# --- Gradio UI wiring (module level, runs on import) -----------------------
with gr.Blocks(title="PDF Q&A (Gemini RAG)") as demo:
    gr.Markdown("# PDF Q&A (RAG) with Gemini 2.5 Flash")
    gr.Markdown(
        "Upload PDF(s), then ask questions. Uses **document-specific splitting** with LangChain splitters, "
        "FAISS for vector search, and Gemini for embeddings + generation.\n\n"
        "**Setup:** In this Space, go to **Settings → Variables and secrets** and add `GEMINI_API_KEY`."
    )

    # Per-session cache: rag_answer stores {"vs": FAISS store, "n_chunks": int}
    # here after the first index build so later questions skip re-indexing.
    state = gr.State(value=None)

    with gr.Row():
        file_uploader = gr.File(
            label="Upload PDFs",
            file_count="multiple",
            file_types=[".pdf"]
        )
        top_k = gr.Slider(1, 10, value=4, step=1, label="Top-k context chunks")

    question = gr.Textbox(label="Your question", placeholder="Ask about the uploaded PDFs...")
    ask_btn = gr.Button("Ask")
    answer = gr.Markdown("")
    with gr.Accordion("Retrieved context (debug)", open=False):
        ctx = gr.Markdown("")

    def _convert_files(files):
        """Read Gradio upload objects into (basename, raw_bytes) pairs.

        Unreadable entries are silently dropped (best-effort).
        """
        # Gradio provides file objects; read into (name, bytes)
        if not files:
            return []
        pairs = []
        for f in files:
            try:
                # Typical case: f.name is a temp-file path on disk.
                with open(f.name, "rb") as fh:
                    pairs.append((os.path.basename(f.name), fh.read()))
            except Exception:
                # In some environments .name might already be a temp path ready to read
                try:
                    pairs.append((os.path.basename(getattr(f, 'orig_name', 'file.pdf')), f.read()))
                except Exception:
                    pass
        return pairs

    def on_ask(state_val, files_val, q_val, k_val):
        """Click handler: convert uploads, run the RAG pipeline, format debug context."""
        files_pairs = _convert_files(files_val)
        new_state, ans, chunks = rag_answer(state_val, files_pairs, q_val, k_val)
        # Join retrieved chunks with separators for the debug accordion.
        ctx_text = "----\n\n".join(chunks) if chunks else ""
        return new_state, ans, ctx_text

    ask_btn.click(
        fn=on_ask,
        inputs=[state, file_uploader, question, top_k],
        outputs=[state, answer, ctx]
    )

if __name__ == "__main__":
    # Launch the Gradio app when run directly (Spaces also executes this file).
    demo.launch()