"""Core pipeline for a RAG-over-PDFs app: parsing, chunking, embedding,
FAISS retrieval, and Groq-backed answering."""

import io
import json
import math
import os
import pickle
import re
import shutil
import textwrap
from typing import Any, Dict, List, Tuple

import faiss
import fitz  # PyMuPDF
import gradio as gr
import numpy as np
import torch
from groq import Groq
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

# ---------- Config ----------
EMBED_MODEL_NAME = "intfloat/multilingual-e5-small"
CHUNK_SIZE = 1200          # characters per chunk
CHUNK_OVERLAP = 200        # characters shared by consecutive chunks
TOP_K_DEFAULT = 5
MAX_CONTEXT_CHARS = 12000  # cap on total context handed to the LLM
INDEX_PATH = "rag_index.faiss"
STORE_PATH = "rag_store.pkl"
MODEL_CHOICES = [
    "llama-3.3-70b-versatile",
    "llama-3.1-8b-instant",
    "mixtral-8x7b-32768",
]

device = "cuda" if torch.cuda.is_available() else "cpu"
embedder = None                       # lazily-loaded SentenceTransformer
faiss_index = None                    # in-memory FAISS index
docstore: List[Dict[str, Any]] = []   # chunk metadata, row-parallel to the index


# ---------- PDF utils ----------
def extract_text_from_pdf(pdf_path: str) -> List[Tuple[int, str]]:
    """Extract text from every page of a PDF.

    Returns a list of (1-based page number, text) tuples. When plain-text
    extraction comes back empty (some layouts), falls back to joining the
    text fields of the page's layout blocks. Scanned/image-only pages
    yield empty strings — OCR is out of scope here.
    """
    pages = []
    with fitz.open(pdf_path) as doc:
        for i, page in enumerate(doc, start=1):
            txt = page.get_text("text") or ""
            if not txt.strip():
                blocks = page.get_text("blocks")
                if isinstance(blocks, list):
                    # block tuples are (x0, y0, x1, y1, text, ...); index 4 is the text
                    txt = "\n".join(
                        b[4] for b in blocks
                        if isinstance(b, (list, tuple)) and len(b) > 4
                    )
            pages.append((i, txt or ""))
    return pages


def chunk_text(text: str, chunk_size: int = CHUNK_SIZE,
               overlap: int = CHUNK_OVERLAP) -> List[str]:
    """Split *text* into overlapping character-window chunks.

    Returns [] for empty/whitespace-only input and a single chunk when the
    text already fits in one window. The `max(..., start + 1)` step
    guarantees forward progress even if overlap >= chunk_size.
    """
    text = text.replace("\x00", " ").strip()
    if len(text) <= chunk_size:
        return [text] if text else []
    out, start = [], 0
    while start < len(text):
        end = start + chunk_size
        out.append(text[start:end])
        start = max(end - overlap, start + 1)
    return out


# ---------- Embeddings / FAISS ----------
def load_embedder() -> SentenceTransformer:
    """Load the sentence-transformer once and cache it in the module global."""
    global embedder
    if embedder is None:
        embedder = SentenceTransformer(EMBED_MODEL_NAME, device=device)
    return embedder


def _normalize(vecs: np.ndarray) -> np.ndarray:
    """L2-normalize rows (float32) so inner product equals cosine similarity."""
    norms = np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-12  # avoid /0
    return (vecs / norms).astype("float32")


def embed_passages(texts: List[str]) -> np.ndarray:
    """Embed passages with the E5-required "passage: " prefix, normalized."""
    model = load_embedder()
    inputs = [f"passage: {t}" for t in texts]
    embs = model.encode(inputs, batch_size=64, show_progress_bar=False,
                        convert_to_numpy=True)
    return _normalize(embs)


def embed_query(q: str) -> np.ndarray:
    """Embed a query with the E5-required "query: " prefix, normalized."""
    model = load_embedder()
    embs = model.encode([f"query: {q}"], convert_to_numpy=True)
    return _normalize(embs)


def build_faiss(embs: np.ndarray):
    """Build a flat inner-product index (cosine, since rows are normalized)."""
    index = faiss.IndexFlatIP(embs.shape[1])
    index.add(embs)
    return index


def save_index(index, store_list: List[Dict[str, Any]]) -> None:
    """Persist the FAISS index and chunk metadata to disk."""
    faiss.write_index(index, INDEX_PATH)
    with open(STORE_PATH, "wb") as f:
        pickle.dump({"docstore": store_list, "embed_model": EMBED_MODEL_NAME}, f)


def load_index() -> bool:
    """Load a previously saved index and docstore; return True on success.

    NOTE(security): unpickles a local file this app wrote itself — never
    point STORE_PATH at untrusted data.
    """
    global faiss_index, docstore
    if os.path.exists(INDEX_PATH) and os.path.exists(STORE_PATH):
        faiss_index = faiss.read_index(INDEX_PATH)
        with open(STORE_PATH, "rb") as f:
            data = pickle.load(f)
        docstore = data["docstore"]
        load_embedder()  # warm the embedder so first query is fast
        return True
    return False


# ---------- Ingest ----------
def ingest_pdfs(paths: List[str]) -> Tuple[Any, List[Dict[str, Any]]]:
    """Parse, chunk, and embed a list of PDF paths.

    Returns (faiss index, chunk entries). Individual unparseable PDFs are
    skipped with a warning; raises RuntimeError when nothing at all could
    be extracted (e.g. every PDF is a scanned image).
    """
    entries: List[Dict[str, Any]] = []
    for pdf in tqdm(paths, total=len(paths), desc="Parsing PDFs"):
        try:
            pages = extract_text_from_pdf(pdf)
            base = os.path.basename(pdf)
            for pno, ptxt in pages:
                if not ptxt.strip():
                    continue
                for ci, ch in enumerate(chunk_text(ptxt)):
                    entries.append({
                        "text": ch,
                        "source": base,
                        "page_start": pno,
                        "page_end": pno,
                        "chunk_id": f"{base}::p{pno}::c{ci}",
                    })
        except Exception as e:
            # best-effort ingest: a single bad PDF must not abort the batch
            print(f"[WARN] Failed to parse {pdf}: {e}")
    if not entries:
        raise RuntimeError("No text extracted. If PDFs are scanned images, run OCR before indexing.")
    texts = [e["text"] for e in entries]
    embs = embed_passages(texts)
    index = build_faiss(embs)
    return index, entries


# ---------- Retrieval (supports required keywords) ----------
def retrieve(query: str, top_k=5, must_contain: str = ""):
    """Return the top-k docstore entries for *query*, each with a 'score'.

    *must_contain* is a comma-separated list of keywords; when given, hits
    are filtered to chunks containing ALL keywords (case-insensitive). The
    filter is soft: if nothing survives it, the unfiltered ranking is kept.
    An over-fetch pool (10*k, min 200) gives the filter room to work.

    Raises RuntimeError when no index is loaded.
    """
    global faiss_index, docstore
    if faiss_index is None or not docstore:
        raise RuntimeError("Index not built or loaded. Use 'Build Index' or 'Reload Saved Index' first.")
    k = int(top_k) if top_k else TOP_K_DEFAULT
    pool = min(max(10 * k, 200), len(docstore))
    qemb = embed_query(query)
    D, I = faiss_index.search(qemb, pool)
    pairs = [(int(i), float(s)) for i, s in zip(I[0], D[0]) if i >= 0]
    must_words = [w.strip().lower() for w in must_contain.split(",") if w.strip()]
    if must_words:
        filtered = []
        for idx, score in pairs:
            t = docstore[idx]["text"].lower()
            if all(w in t for w in must_words):
                filtered.append((idx, score))
        if filtered:  # keep unfiltered results rather than returning nothing
            pairs = filtered
    pairs = pairs[:k]
    hits = []
    for idx, score in pairs:
        item = docstore[idx].copy()  # copy so callers can't mutate the store
        item["score"] = float(score)
        hits.append(item)
    return hits


# ---------- Groq LLM ----------
def groq_answer(query: str, contexts, model_name="llama-3.3-70b-versatile",
                temperature=0.2, max_tokens=1000):
    """Answer *query* from *contexts* via the Groq chat-completions API.

    Contexts are packed newest-first up to MAX_CONTEXT_CHARS, each tagged
    with its source file and page. Returns the answer text, or a readable
    error string (missing key / API failure) instead of raising.

    Fix: the previous default model "llama-3.1-70b-versatile" is not in
    MODEL_CHOICES (decommissioned name); default now matches the first choice.
    """
    try:
        if not os.environ.get("GROQ_API_KEY"):
            return "GROQ_API_KEY is not set. Add it in your host's environment/secrets."
        client = Groq(api_key=os.environ["GROQ_API_KEY"])
        packed, used = [], 0
        for c in contexts:
            tag = f"[{c['source']} p.{c['page_start']}]"
            piece = f"{tag}\n{c['text'].strip()}\n"
            if used + len(piece) > MAX_CONTEXT_CHARS:
                break
            packed.append(piece)
            used += len(piece)
        context_str = "\n---\n".join(packed)
        system_prompt = (
            "You are a scholarly assistant. Answer using ONLY the provided context. "
            "If the answer is not present, say so. Always include a 'References' section with sources and page numbers."
        )
        user_prompt = (
            f"Question:\n{query}\n\n"
            f"Context snippets (use these only):\n{context_str}\n\n"
            "Write a precise answer. Keep claims traceable to the snippets."
        )
        resp = client.chat.completions.create(
            model=model_name,
            temperature=float(temperature),
            max_tokens=int(max_tokens),
            messages=[{"role": "system", "content": system_prompt},
                      {"role": "user", "content": user_prompt}],
        )
        return resp.choices[0].message.content.strip()
    except Exception as e:
        import traceback
        return f"Groq API error: {e}\n```\n{traceback.format_exc()}\n```"


# ---------- Helpers for UI ----------
def build_index_from_uploads(paths: List[str]) -> str:
    """Build, persist, and activate an index from uploaded PDF paths.

    Returns a status string for the UI. Fix: the error message claimed a
    ~100-PDF limit while the code enforces 120; message now matches.
    """
    global faiss_index, docstore
    if not paths:
        return "Please upload at least one PDF."
    if len(paths) > 120:
        return "Please limit to 120 PDFs per build."
    faiss_index, entries = ingest_pdfs(paths)
    save_index(faiss_index, entries)
    docstore = entries
    return f"Index built with {len(entries)} chunks from {len(paths)} PDFs. Saved to disk."


def reload_index() -> str:
    """Reload the saved index from disk; return a status string for the UI."""
    ok = load_index()
    return f"Index reloaded. Chunks: {len(docstore)}" if ok else "No saved index found."


def ask_rag(query: str, top_k, model_name: str, temperature: float, must_contain: str):
    """UI entry point: retrieve context, ask Groq, and format source rows.

    Returns (answer markdown, rows for the sources dataframe). Exceptions
    are rendered into the answer pane rather than raised.
    """
    try:
        if not query.strip():
            return "Please enter a question.", []
        ctx = retrieve(query,
                       top_k=int(top_k) if top_k else TOP_K_DEFAULT,
                       must_contain=must_contain)
        ans = groq_answer(query, ctx, model_name=model_name, temperature=temperature)
        rows = []
        for c in ctx:
            preview = c["text"][:200].replace("\n", " ") + ("..." if len(c["text"]) > 200 else "")
            rows.append([c["source"], str(c["page_start"]), f"{c['score']:.3f}", preview])
        return ans, rows
    except Exception as e:
        import traceback
        return f"**Error:** {e}\n```\n{traceback.format_exc()}\n```", []


def set_api_key(k: str):
    """Store the Groq API key in this process's environment for the session."""
    if k and k.strip():
        os.environ["GROQ_API_KEY"] = k.strip()
        return "API key set in runtime."
    return "No key provided."
def download_index_zip():
    """Bundle the saved FAISS index and docstore into a zip for download.

    Returns the zip path, or None when no index has been saved yet.

    Fix: the previous version first zipped the ENTIRE working directory
    via shutil.make_archive(base, "zip", ".", ".") and then tried to use
    make_archive's return value (a plain str) as a context manager, which
    raised AttributeError and broke the download button. We now build the
    two-file bundle directly with zipfile.
    """
    import zipfile

    if not (os.path.exists(INDEX_PATH) and os.path.exists(STORE_PATH)):
        return None
    zp = "rag_index_bundle.zip"
    with zipfile.ZipFile(zp, "w", zipfile.ZIP_DEFLATED) as z:
        z.write(INDEX_PATH)
        z.write(STORE_PATH)
    return zp


# ---------- Gradio UI ----------
with gr.Blocks(title="RAG over PDFs (Groq)") as demo:
    gr.Markdown("## RAG over your PDFs using Groq\nUpload PDFs, build an index, then ask questions with cited answers.")

    # Session-scoped API key entry (stored in the process environment only).
    with gr.Row():
        api_box = gr.Textbox(label="(Optional) Set GROQ_API_KEY for this session", type="password", placeholder="sk_...")
        set_btn = gr.Button("Set Key")
    set_out = gr.Markdown()
    set_btn.click(set_api_key, inputs=[api_box], outputs=[set_out])

    with gr.Tab("1) Build or Load Index"):
        file_u = gr.Files(label="Upload PDFs", file_types=[".pdf"], type="filepath")
        with gr.Row():
            build_btn = gr.Button("Build Index")
            reload_btn = gr.Button("Reload Saved Index")
            download_btn = gr.Button("Download Index (.zip)")
        build_out = gr.Markdown()

        def on_build(paths, progress=gr.Progress(track_tqdm=True)):
            """Build handler: surface indexing errors in the UI, not the console."""
            try:
                return build_index_from_uploads(paths)
            except Exception as e:
                import traceback
                return f"**Error while building index:** {e}\n\n```\n{traceback.format_exc()}\n```"

        build_btn.click(on_build, inputs=[file_u], outputs=[build_out])
        reload_btn.click(fn=reload_index, outputs=[build_out])
        zpath = gr.File(label="Index zip", interactive=False)
        download_btn.click(fn=download_index_zip, outputs=[zpath])

    with gr.Tab("2) Ask Questions"):
        q = gr.Textbox(label="Your question", lines=2, placeholder="Ask something present in the uploaded papers…")
        with gr.Row():
            topk = gr.Slider(1, 15, value=TOP_K_DEFAULT, step=1, label="Top-K passages")
            model_dd = gr.Dropdown(MODEL_CHOICES, value=MODEL_CHOICES[0], label="Groq model")
            temp = gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="Temperature")
        must = gr.Textbox(label="Must contain (comma-separated keywords)", placeholder="camera, CMOS, frame rate")
        ask_btn = gr.Button("Answer")
        ans = gr.Markdown()
        src = gr.Dataframe(headers=["Source", "Page", "Score", "Snippet"], wrap=True)
        ask_btn.click(ask_rag, inputs=[q, topk, model_dd, temp, must], outputs=[ans, src])

demo.queue()  # keep it simple for broad Gradio versions

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))