EngrMuhammadBilal committed on
Commit
b1748d2
·
verified ·
1 Parent(s): 070bfe5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +287 -0
app.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, io, json, math, pickle, textwrap, shutil, re
2
+ from typing import List, Dict, Any, Tuple
3
+ import numpy as np, faiss, fitz # pymupdf
4
+ from tqdm import tqdm
5
+ import torch
6
+ from sentence_transformers import SentenceTransformer
7
+ import gradio as gr
8
+ from groq import Groq
9
+
10
# ---------- Config ----------
EMBED_MODEL_NAME = "intfloat/multilingual-e5-small"  # E5 models expect "query:"/"passage:" input prefixes
CHUNK_SIZE = 1200        # characters per text chunk
CHUNK_OVERLAP = 200      # characters shared between consecutive chunks
TOP_K_DEFAULT = 5        # default number of retrieved passages
MAX_CONTEXT_CHARS = 12000  # character budget for context packed into the LLM prompt

INDEX_PATH = "rag_index.faiss"  # serialized FAISS index on disk
STORE_PATH = "rag_store.pkl"    # pickled chunk metadata ("docstore")

# NOTE(review): Groq model IDs are retired over time — confirm these are still served.
MODEL_CHOICES = [
    "llama-3.1-70b-versatile",
    "llama-3.1-8b-instant",
    "mixtral-8x7b-32768",
]

# Mutable module-level state shared by the UI callbacks below.
device = "cuda" if torch.cuda.is_available() else "cpu"
embedder = None                      # lazily-created SentenceTransformer (see load_embedder)
faiss_index = None                   # FAISS index over chunk embeddings
docstore: List[Dict[str, Any]] = []  # one metadata dict per indexed chunk
30
+
31
+ # ---------- PDF utils ----------
32
def extract_text_from_pdf(pdf_path: str) -> List[Tuple[int, str]]:
    """Extract text per page; returns a list of (1-based page number, text)."""
    result: List[Tuple[int, str]] = []
    with fitz.open(pdf_path) as document:
        for page_no, page in enumerate(document, start=1):
            content = page.get_text("text") or ""
            if not content.strip():
                # Fall back to block-level extraction when plain text is empty.
                raw_blocks = page.get_text("blocks")
                if isinstance(raw_blocks, list):
                    pieces = [
                        blk[4]
                        for blk in raw_blocks
                        if isinstance(blk, (list, tuple)) and len(blk) > 4
                    ]
                    content = "\n".join(pieces)
            result.append((page_no, content or ""))
    return result
43
+
44
def chunk_text(text: str, chunk_size=CHUNK_SIZE, overlap=CHUNK_OVERLAP) -> List[str]:
    """Split text into overlapping character windows of at most chunk_size."""
    cleaned = text.replace("\x00", " ").strip()
    if not cleaned:
        return []
    if len(cleaned) <= chunk_size:
        return [cleaned]
    chunks: List[str] = []
    pos = 0
    total = len(cleaned)
    while pos < total:
        chunks.append(cleaned[pos:pos + chunk_size])
        # Advance by (chunk_size - overlap), guaranteeing forward progress
        # even if overlap >= chunk_size.
        pos = max(pos + chunk_size - overlap, pos + 1)
    return chunks
54
+
55
+ # ---------- Embeddings / FAISS ----------
56
def load_embedder():
    """Create the SentenceTransformer on first use and cache it globally."""
    global embedder
    if embedder is not None:
        return embedder
    embedder = SentenceTransformer(EMBED_MODEL_NAME, device=device)
    return embedder
61
+
62
+ def _normalize(vecs: np.ndarray) -> np.ndarray:
63
+ norms = np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-12
64
+ return (vecs / norms).astype("float32")
65
+
66
def embed_passages(texts: List[str]) -> np.ndarray:
    """Embed document chunks (E5 requires a "passage: " prefix); rows are unit-norm."""
    model = load_embedder()
    prefixed = [f"passage: {chunk}" for chunk in texts]
    raw = model.encode(
        prefixed, batch_size=64, show_progress_bar=False, convert_to_numpy=True
    )
    return _normalize(raw)
71
+
72
def embed_query(q: str) -> np.ndarray:
    """Embed a search query (E5 requires a "query: " prefix); row is unit-norm."""
    model = load_embedder()
    raw = model.encode([f"query: {q}"], convert_to_numpy=True)
    return _normalize(raw)
76
+
77
def build_faiss(embs: np.ndarray):
    """Build a flat inner-product index (cosine similarity, given unit-norm rows)."""
    dim = embs.shape[1]
    idx = faiss.IndexFlatIP(dim)
    idx.add(embs)
    return idx
81
+
82
def save_index(index, store_list: List[Dict[str, Any]]):
    """Persist the FAISS index and the chunk metadata to disk."""
    faiss.write_index(index, INDEX_PATH)
    # The embed model name is saved alongside so a mismatch can be detected later.
    payload = {"docstore": store_list, "embed_model": EMBED_MODEL_NAME}
    with open(STORE_PATH, "wb") as fh:
        pickle.dump(payload, fh)
86
+
87
def load_index() -> bool:
    """Restore a previously saved index from disk; returns True on success."""
    global faiss_index, docstore
    if not (os.path.exists(INDEX_PATH) and os.path.exists(STORE_PATH)):
        return False
    faiss_index = faiss.read_index(INDEX_PATH)
    with open(STORE_PATH, "rb") as fh:
        payload = pickle.load(fh)
    docstore = payload["docstore"]
    # Warm the embedder so the first query does not pay the model-load cost.
    load_embedder()
    return True
97
+
98
+ # ---------- Ingest ----------
99
def ingest_pdfs(paths: List[str]) -> Tuple[Any, List[Dict[str, Any]]]:
    """Parse, chunk, and embed PDFs; returns (faiss index, chunk metadata list).

    Raises RuntimeError when no text could be extracted from any file.
    """
    chunks: List[Dict[str, Any]] = []
    for pdf_path in tqdm(paths, total=len(paths), desc="Parsing PDFs"):
        try:
            filename = os.path.basename(pdf_path)
            for page_no, page_text in extract_text_from_pdf(pdf_path):
                if not page_text.strip():
                    continue  # skip blank / image-only pages
                for chunk_no, chunk in enumerate(chunk_text(page_text)):
                    chunks.append({
                        "text": chunk,
                        "source": filename,
                        "page_start": page_no,
                        "page_end": page_no,
                        "chunk_id": f"{filename}::p{page_no}::c{chunk_no}",
                    })
        except Exception as e:
            # Best-effort: one unreadable PDF should not abort the whole batch.
            print(f"[WARN] Failed to parse {pdf_path}: {e}")
    if not chunks:
        raise RuntimeError("No text extracted. If PDFs are scanned images, run OCR before indexing.")
    vectors = embed_passages([c["text"] for c in chunks])
    return build_faiss(vectors), chunks
124
+
125
+ # ---------- Retrieval (supports required keywords) ----------
126
def retrieve(query: str, top_k=5, must_contain: str = ""):
    """Vector-search the index; optionally keep only chunks containing all keywords.

    must_contain is a comma-separated keyword list; the filter is skipped
    entirely if it would leave zero candidates.
    """
    global faiss_index, docstore
    if faiss_index is None or not docstore:
        raise RuntimeError("Index not built or loaded. Use 'Build Index' or 'Reload Saved Index' first.")
    k = int(top_k) if top_k else TOP_K_DEFAULT

    # Over-fetch candidates so keyword filtering still has material to work with.
    candidate_count = min(max(10 * k, 200), len(docstore))
    scores, ids = faiss_index.search(embed_query(query), candidate_count)
    ranked = [(int(idx), float(sc)) for idx, sc in zip(ids[0], scores[0]) if idx >= 0]

    keywords = [w.strip().lower() for w in must_contain.split(",") if w.strip()]
    if keywords:
        surviving = [
            (idx, sc)
            for idx, sc in ranked
            if all(word in docstore[idx]["text"].lower() for word in keywords)
        ]
        # Fall back to the unfiltered ranking if nothing matched every keyword.
        if surviving:
            ranked = surviving

    results = []
    for idx, sc in ranked[:k]:
        record = docstore[idx].copy()
        record["score"] = float(sc)
        results.append(record)
    return results
154
+
155
+ # ---------- Groq LLM ----------
156
def groq_answer(query: str, contexts, model_name="llama-3.1-70b-versatile", temperature=0.2, max_tokens=1000):
    """Ask a Groq chat model to answer from the retrieved snippets only.

    Returns the model's answer, or a human-readable error string on failure
    (missing API key or any exception from the API call).
    """
    try:
        if not os.environ.get("GROQ_API_KEY"):
            return "GROQ_API_KEY is not set. Add it in your host's environment/secrets."
        client = Groq(api_key=os.environ["GROQ_API_KEY"])

        # Pack snippets (with source/page tags) until the character budget runs out.
        selected = []
        budget_used = 0
        for ctx in contexts:
            tag = f"[{ctx['source']} p.{ctx['page_start']}]"
            snippet = f"{tag}\n{ctx['text'].strip()}\n"
            if budget_used + len(snippet) > MAX_CONTEXT_CHARS:
                break
            selected.append(snippet)
            budget_used += len(snippet)
        context_str = "\n---\n".join(selected)

        system_prompt = (
            "You are a scholarly assistant. Answer using ONLY the provided context. "
            "If the answer is not present, say so. Always include a 'References' section with sources and page numbers."
        )
        user_prompt = (
            f"Question:\n{query}\n\n"
            f"Context snippets (use these only):\n{context_str}\n\n"
            "Write a precise answer. Keep claims traceable to the snippets."
        )

        response = client.chat.completions.create(
            model=model_name,
            temperature=float(temperature),
            max_tokens=int(max_tokens),
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        import traceback
        return f"Groq API error: {e}\n```\n{traceback.format_exc()}\n```"
191
+
192
+ # ---------- Helpers for UI ----------
193
def build_index_from_uploads(paths: List[str]) -> str:
    """Build, persist, and activate an index from uploaded PDF paths."""
    global faiss_index, docstore
    if not paths:
        return "Please upload at least one PDF."
    if len(paths) > 120:
        return "Please limit to ~100 PDFs per build."

    new_index, new_entries = ingest_pdfs(paths)
    faiss_index = new_index
    save_index(new_index, new_entries)
    docstore = new_entries
    return f"Index built with {len(new_entries)} chunks from {len(paths)} PDFs. Saved to disk."
202
+
203
def reload_index() -> str:
    """Load a previously saved index from disk and report the outcome."""
    if load_index():
        return f"Index reloaded. Chunks: {len(docstore)}"
    return "No saved index found."
206
+
207
def ask_rag(query: str, top_k, model_name: str, temperature: float, must_contain: str):
    """UI callback: retrieve context, ask Groq, and format the source table.

    Returns (answer markdown, rows for the sources dataframe); on error the
    first element carries the traceback and the rows are empty.
    """
    try:
        if not query.strip():
            return "Please enter a question.", []
        k = int(top_k) if top_k else TOP_K_DEFAULT
        contexts = retrieve(query, top_k=k, must_contain=must_contain)
        answer = groq_answer(query, contexts, model_name=model_name, temperature=temperature)

        table = []
        for ctx in contexts:
            body = ctx["text"]
            snippet = body[:200].replace("\n", " ")
            if len(body) > 200:
                snippet += "..."
            table.append([ctx["source"], str(ctx["page_start"]), f"{ctx['score']:.3f}", snippet])
        return answer, table
    except Exception as e:
        import traceback
        return f"**Error:** {e}\n```\n{traceback.format_exc()}\n```", []
221
+
222
def set_api_key(k: str):
    """Store the Groq key in this process's environment for later API calls."""
    cleaned = (k or "").strip()
    if not cleaned:
        return "No key provided."
    os.environ["GROQ_API_KEY"] = cleaned
    return "API key set in runtime."
227
+
228
def download_index_zip():
    """Bundle the saved index files into a zip; returns its path, or None.

    Fixes the original implementation, which (a) zipped the entire working
    directory into an unused archive, and (b) tried to use the return value
    of ``shutil.make_archive`` — a str path — as a context manager, raising
    an exception every time the index files existed.
    """
    import zipfile

    if not (os.path.exists(INDEX_PATH) and os.path.exists(STORE_PATH)):
        return None  # nothing has been indexed/saved yet
    bundle_path = "rag_index_bundle.zip"
    with zipfile.ZipFile(bundle_path, "w", zipfile.ZIP_DEFLATED) as zf:
        zf.write(INDEX_PATH)
        zf.write(STORE_PATH)
    return bundle_path
243
+
244
# ---------- Gradio UI ----------
# NOTE(review): indentation was lost in the diff view this file was recovered
# from; widget nesting below is reconstructed — verify the intended layout.
with gr.Blocks(title="RAG over PDFs (Groq)") as demo:
    gr.Markdown("## RAG over your PDFs using Groq\nUpload PDFs, build an index, then ask questions with cited answers.")
    # Session-scoped API key entry (stored only in this process's env vars).
    with gr.Row():
        api_box = gr.Textbox(label="(Optional) Set GROQ_API_KEY for this session", type="password", placeholder="sk_...")
        set_btn = gr.Button("Set Key")
    set_out = gr.Markdown()
    set_btn.click(set_api_key, inputs=[api_box], outputs=[set_out])

    # Tab 1: indexing workflow — build from uploads, reload saved, download zip.
    with gr.Tab("1) Build or Load Index"):
        file_u = gr.Files(label="Upload PDFs", file_types=[".pdf"], type="filepath")
        with gr.Row():
            build_btn = gr.Button("Build Index")
            reload_btn = gr.Button("Reload Saved Index")
            download_btn = gr.Button("Download Index (.zip)")
        build_out = gr.Markdown()

        def on_build(paths, progress=gr.Progress(track_tqdm=True)):
            # Wrapper so indexing errors surface in the UI instead of only the log.
            try:
                return build_index_from_uploads(paths)
            except Exception as e:
                import traceback
                return f"**Error while building index:** {e}\n\n```\n{traceback.format_exc()}\n```"

        build_btn.click(on_build, inputs=[file_u], outputs=[build_out])
        reload_btn.click(fn=reload_index, outputs=[build_out])
        zpath = gr.File(label="Index zip", interactive=False)
        download_btn.click(fn=download_index_zip, outputs=[zpath])

    # Tab 2: question answering over the indexed chunks.
    with gr.Tab("2) Ask Questions"):
        q = gr.Textbox(label="Your question", lines=2, placeholder="Ask something present in the uploaded papers…")
        with gr.Row():
            topk = gr.Slider(1, 15, value=TOP_K_DEFAULT, step=1, label="Top-K passages")
            model_dd = gr.Dropdown(MODEL_CHOICES, value=MODEL_CHOICES[0], label="Groq model")
            temp = gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="Temperature")
        must = gr.Textbox(label="Must contain (comma-separated keywords)", placeholder="camera, CMOS, frame rate")
        ask_btn = gr.Button("Answer")
        ans = gr.Markdown()
        src = gr.Dataframe(headers=["Source","Page","Score","Snippet"], wrap=True)
        ask_btn.click(ask_rag, inputs=[q, topk, model_dd, temp, must], outputs=[ans, src])

demo.queue()  # keep it simple for broad Gradio versions
if __name__ == "__main__":
    # PORT env var allows hosting platforms to override the default 7860.
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))