# app.py — RAG over PDFs with Groq (Hugging Face Space by EngrMuhammadBilal, commit 6506759)
import os, io, json, math, pickle, textwrap, shutil, re
from typing import List, Dict, Any, Tuple
import numpy as np, faiss, fitz # pymupdf
from tqdm import tqdm
import torch
from sentence_transformers import SentenceTransformer
import gradio as gr
from groq import Groq
# ---------- Config ----------
EMBED_MODEL_NAME = "intfloat/multilingual-e5-small"  # E5 family: inputs must carry "query:"/"passage:" prefixes
CHUNK_SIZE = 1200  # max characters per text chunk
CHUNK_OVERLAP = 200  # characters shared between consecutive chunks
TOP_K_DEFAULT = 5  # default number of passages to retrieve
MAX_CONTEXT_CHARS = 12000  # cap on total context characters packed into the LLM prompt
INDEX_PATH = "rag_index.faiss"  # on-disk FAISS index file
STORE_PATH = "rag_store.pkl"  # pickled chunk metadata (docstore)
# Groq chat models selectable in the UI dropdown.
MODEL_CHOICES = [
"llama-3.3-70b-versatile",
"llama-3.1-8b-instant",
"mixtral-8x7b-32768",
]
device = "cuda" if torch.cuda.is_available() else "cpu"  # embed on GPU when available
# Lazily initialized process-wide state shared by all UI handlers.
embedder = None  # SentenceTransformer, created on first use by load_embedder()
faiss_index = None  # active FAISS index (row i corresponds to docstore[i])
docstore: List[Dict[str, Any]] = []  # chunk metadata parallel to index rows
# ---------- PDF utils ----------
def extract_text_from_pdf(pdf_path: str) -> List[Tuple[int, str]]:
    """Return a list of (1-based page number, text) for every page of the PDF.

    Tries plain text extraction first; when a page yields nothing, falls back
    to block-level extraction (text is the 5th element of each block tuple).
    """
    result: List[Tuple[int, str]] = []
    with fitz.open(pdf_path) as doc:
        for page_no, page in enumerate(doc, start=1):
            content = page.get_text("text") or ""
            if not content.strip():
                blocks = page.get_text("blocks")
                if isinstance(blocks, list):
                    parts = [blk[4] for blk in blocks if isinstance(blk, (list, tuple)) and len(blk) > 4]
                    content = "\n".join(parts)
            result.append((page_no, content or ""))
    return result
def chunk_text(text: str, chunk_size=CHUNK_SIZE, overlap=CHUNK_OVERLAP) -> List[str]:
    """Split text into fixed-size character windows with overlap between neighbors."""
    cleaned = text.replace("\x00", " ").strip()
    if len(cleaned) <= chunk_size:
        return [cleaned] if cleaned else []
    chunks: List[str] = []
    pos, total = 0, len(cleaned)
    while pos < total:
        stop = pos + chunk_size
        chunks.append(cleaned[pos:stop])
        # Step forward; max() guarantees progress even if overlap >= chunk_size.
        pos = max(stop - overlap, pos + 1)
    return chunks
# ---------- Embeddings / FAISS ----------
def load_embedder():
    """Return the shared SentenceTransformer, instantiating it on first use."""
    global embedder
    if embedder is not None:
        return embedder
    embedder = SentenceTransformer(EMBED_MODEL_NAME, device=device)
    return embedder
def _normalize(vecs: np.ndarray) -> np.ndarray:
norms = np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-12
return (vecs / norms).astype("float32")
def embed_passages(texts: List[str]) -> np.ndarray:
    """Encode passages (with the E5 'passage:' prefix) into normalized vectors."""
    model = load_embedder()
    prefixed = [f"passage: {t}" for t in texts]
    vectors = model.encode(prefixed, batch_size=64, show_progress_bar=False, convert_to_numpy=True)
    return _normalize(vectors)
def embed_query(q: str) -> np.ndarray:
    """Encode a single query (with the E5 'query:' prefix) into a normalized vector."""
    vectors = load_embedder().encode([f"query: {q}"], convert_to_numpy=True)
    return _normalize(vectors)
def build_faiss(embs: np.ndarray):
    """Build a flat inner-product index (cosine similarity, since rows are normalized)."""
    dim = embs.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(embs)
    return index
def save_index(index, store_list: List[Dict[str, Any]]):
    """Persist the FAISS index and its docstore (plus embed-model name) to disk."""
    faiss.write_index(index, INDEX_PATH)
    payload = {"docstore": store_list, "embed_model": EMBED_MODEL_NAME}
    with open(STORE_PATH, "wb") as f:
        pickle.dump(payload, f)
def load_index() -> bool:
    """Restore the saved FAISS index and docstore; return True on success."""
    global faiss_index, docstore
    if not (os.path.exists(INDEX_PATH) and os.path.exists(STORE_PATH)):
        return False
    faiss_index = faiss.read_index(INDEX_PATH)
    with open(STORE_PATH, "rb") as f:
        # NOTE: pickle is only safe here because the file is written locally by save_index.
        payload = pickle.load(f)
    docstore = payload["docstore"]
    load_embedder()
    return True
# ---------- Ingest ----------
def ingest_pdfs(paths: List[str]) -> Tuple[Any, List[Dict[str, Any]]]:
    """Parse, chunk, embed, and index a list of PDF paths.

    Returns (faiss_index, docstore_entries). PDFs that fail to parse are
    skipped with a console warning; raises RuntimeError if no text at all
    could be extracted (e.g. scanned images without OCR).
    """
    records: List[Dict[str, Any]] = []
    for pdf_path in tqdm(paths, total=len(paths), desc="Parsing PDFs"):
        try:
            name = os.path.basename(pdf_path)
            for page_no, page_text in extract_text_from_pdf(pdf_path):
                if not page_text.strip():
                    continue
                for chunk_no, chunk in enumerate(chunk_text(page_text)):
                    records.append({
                        "text": chunk,
                        "source": name,
                        "page_start": page_no,
                        "page_end": page_no,
                        "chunk_id": f"{name}::p{page_no}::c{chunk_no}",
                    })
        except Exception as e:
            print(f"[WARN] Failed to parse {pdf_path}: {e}")
    if not records:
        raise RuntimeError("No text extracted. If PDFs are scanned images, run OCR before indexing.")
    embeddings = embed_passages([r["text"] for r in records])
    return build_faiss(embeddings), records
# ---------- Retrieval (supports required keywords) ----------
def retrieve(query: str, top_k=5, must_contain: str = ""):
    """Vector-search the index and return up to top_k docstore entries with scores.

    must_contain: optional comma-separated keywords; when given, only chunks
    containing ALL keywords (case-insensitive) are kept — unless that filter
    empties the pool, in which case the unfiltered ranking is used.
    """
    global faiss_index, docstore
    if faiss_index is None or not docstore:
        raise RuntimeError("Index not built or loaded. Use 'Build Index' or 'Reload Saved Index' first.")
    k = int(top_k) if top_k else TOP_K_DEFAULT
    # Over-fetch a candidate pool so keyword filtering can still yield k hits.
    pool = min(max(10 * k, 200), len(docstore))
    scores, ids = faiss_index.search(embed_query(query), pool)
    candidates = [(int(i), float(s)) for i, s in zip(ids[0], scores[0]) if i >= 0]
    keywords = [w.strip().lower() for w in must_contain.split(",") if w.strip()]
    if keywords:
        kept = [
            (idx, score)
            for idx, score in candidates
            if all(word in docstore[idx]["text"].lower() for word in keywords)
        ]
        if kept:
            candidates = kept
    results = []
    for idx, score in candidates[:k]:
        entry = docstore[idx].copy()
        entry["score"] = float(score)
        results.append(entry)
    return results
# ---------- Groq LLM ----------
def groq_answer(query: str, contexts, model_name="llama-3.3-70b-versatile", temperature=0.2, max_tokens=1000):
    """Answer `query` with a Groq chat model grounded in retrieved `contexts`.

    contexts: docstore entries ({"text", "source", "page_start", ...}); they are
    packed into the prompt, each tagged with source file and page number, until
    MAX_CONTEXT_CHARS is reached. Returns the model's answer string, or a
    human-readable error message (this function never raises).

    Fix: the previous default "llama-3.1-70b-versatile" was decommissioned by
    Groq and is not in MODEL_CHOICES; the default now matches MODEL_CHOICES[0].
    """
    try:
        if not os.environ.get("GROQ_API_KEY"):
            return "GROQ_API_KEY is not set. Add it in your host's environment/secrets."
        client = Groq(api_key=os.environ["GROQ_API_KEY"])
        # Pack tagged snippets up to the character budget.
        packed, used = [], 0
        for c in contexts:
            tag = f"[{c['source']} p.{c['page_start']}]"
            piece = f"{tag}\n{c['text'].strip()}\n"
            if used + len(piece) > MAX_CONTEXT_CHARS:
                break
            packed.append(piece)
            used += len(piece)
        context_str = "\n---\n".join(packed)
        system_prompt = (
            "You are a scholarly assistant. Answer using ONLY the provided context. "
            "If the answer is not present, say so. Always include a 'References' section with sources and page numbers."
        )
        user_prompt = (
            f"Question:\n{query}\n\n"
            f"Context snippets (use these only):\n{context_str}\n\n"
            "Write a precise answer. Keep claims traceable to the snippets."
        )
        resp = client.chat.completions.create(
            model=model_name,
            temperature=float(temperature),
            max_tokens=int(max_tokens),
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
        )
        return resp.choices[0].message.content.strip()
    except Exception as e:
        # Surface API/SDK failures to the UI instead of crashing the handler.
        import traceback
        return f"Groq API error: {e}\n```\n{traceback.format_exc()}\n```"
# ---------- Helpers for UI ----------
def build_index_from_uploads(paths: List[str]) -> str:
    """Build, persist, and activate a fresh index from uploaded PDF paths.

    Returns a status message for the UI.
    """
    global faiss_index, docstore
    if not paths:
        return "Please upload at least one PDF."
    if len(paths) > 120:
        return "Please limit to ~100 PDFs per build."
    faiss_index, entries = ingest_pdfs(paths)
    save_index(faiss_index, entries)
    docstore = entries
    return f"Index built with {len(entries)} chunks from {len(paths)} PDFs. Saved to disk."
def reload_index() -> str:
    """Load the previously saved index from disk; report the outcome."""
    if load_index():
        return f"Index reloaded. Chunks: {len(docstore)}"
    return "No saved index found."
def ask_rag(query: str, top_k, model_name: str, temperature: float, must_contain: str):
    """UI handler: retrieve context, ask Groq, and format source rows.

    Returns (answer_markdown, rows), rows being [source, page, score, snippet].
    Errors are rendered into the answer field rather than raised.
    """
    try:
        if not query.strip():
            return "Please enter a question.", []
        k = int(top_k) if top_k else TOP_K_DEFAULT
        hits = retrieve(query, top_k=k, must_contain=must_contain)
        answer = groq_answer(query, hits, model_name=model_name, temperature=temperature)
        table = []
        for hit in hits:
            snippet = hit["text"][:200].replace("\n", " ")
            if len(hit["text"]) > 200:
                snippet += "..."
            table.append([hit["source"], str(hit["page_start"]), f"{hit['score']:.3f}", snippet])
        return answer, table
    except Exception as e:
        import traceback
        return f"**Error:** {e}\n```\n{traceback.format_exc()}\n```", []
def set_api_key(k: str):
    """Store a Groq API key in the process environment for this session."""
    key = (k or "").strip()
    if not key:
        return "No key provided."
    os.environ["GROQ_API_KEY"] = key
    return "API key set in runtime."
def download_index_zip():
    """Bundle the saved index files into a zip for download.

    Returns the zip path, or None when no index has been saved yet.

    Fix: the old implementation first zipped the ENTIRE working directory via
    shutil.make_archive, then did `with shutil.make_archive(...)` — which
    raises TypeError because make_archive returns a str, not a context
    manager. Only the explicit two-file zip was ever intended; keep that.
    """
    if not (os.path.exists(INDEX_PATH) and os.path.exists(STORE_PATH)):
        return None
    import zipfile
    zp = "rag_index_bundle.zip"
    with zipfile.ZipFile(zp, "w", zipfile.ZIP_DEFLATED) as z:
        z.write(INDEX_PATH)
        z.write(STORE_PATH)
    return zp
# ---------- Gradio UI ----------
# Two-tab app: (1) build/load/download the index, (2) ask questions.
with gr.Blocks(title="RAG over PDFs (Groq)") as demo:
    gr.Markdown("## RAG over your PDFs using Groq\nUpload PDFs, build an index, then ask questions with cited answers.")
    # Session-scoped API key entry (alternative to host-level secrets).
    with gr.Row():
        api_box = gr.Textbox(label="(Optional) Set GROQ_API_KEY for this session", type="password", placeholder="sk_...")
        set_btn = gr.Button("Set Key")
    set_out = gr.Markdown()
    set_btn.click(set_api_key, inputs=[api_box], outputs=[set_out])
    with gr.Tab("1) Build or Load Index"):
        file_u = gr.Files(label="Upload PDFs", file_types=[".pdf"], type="filepath")
        with gr.Row():
            build_btn = gr.Button("Build Index")
            reload_btn = gr.Button("Reload Saved Index")
            download_btn = gr.Button("Download Index (.zip)")
        build_out = gr.Markdown()
        def on_build(paths, progress=gr.Progress(track_tqdm=True)):
            # Wrapper so indexing errors surface in the UI instead of only the console;
            # gr.Progress(track_tqdm=True) mirrors the tqdm bar from ingest_pdfs.
            try:
                return build_index_from_uploads(paths)
            except Exception as e:
                import traceback
                return f"**Error while building index:** {e}\n\n```\n{traceback.format_exc()}\n```"
        build_btn.click(on_build, inputs=[file_u], outputs=[build_out])
        reload_btn.click(fn=reload_index, outputs=[build_out])
        zpath = gr.File(label="Index zip", interactive=False)
        download_btn.click(fn=download_index_zip, outputs=[zpath])
    with gr.Tab("2) Ask Questions"):
        q = gr.Textbox(label="Your question", lines=2, placeholder="Ask something present in the uploaded papers…")
        with gr.Row():
            topk = gr.Slider(1, 15, value=TOP_K_DEFAULT, step=1, label="Top-K passages")
            model_dd = gr.Dropdown(MODEL_CHOICES, value=MODEL_CHOICES[0], label="Groq model")
            temp = gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="Temperature")
        must = gr.Textbox(label="Must contain (comma-separated keywords)", placeholder="camera, CMOS, frame rate")
        ask_btn = gr.Button("Answer")
        ans = gr.Markdown()
        src = gr.Dataframe(headers=["Source","Page","Score","Snippet"], wrap=True)
        ask_btn.click(ask_rag, inputs=[q, topk, model_dd, temp, must], outputs=[ans, src])
demo.queue() # keep it simple for broad Gradio versions
if __name__ == "__main__":
    # PORT env var lets the host (e.g. a container platform) pick the port.
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))