# RAG over PDFs (Groq) — Gradio app entrypoint for a Hugging Face Space.
| import os, io, json, math, pickle, textwrap, shutil, re | |
| from typing import List, Dict, Any, Tuple | |
| import numpy as np, faiss, fitz # pymupdf | |
| from tqdm import tqdm | |
| import torch | |
| from sentence_transformers import SentenceTransformer | |
| import gradio as gr | |
| from groq import Groq | |
# ---------- Config ----------
EMBED_MODEL_NAME = "intfloat/multilingual-e5-small"  # E5 model; inputs get "query:"/"passage:" prefixes below
CHUNK_SIZE = 1200          # max characters per text chunk
CHUNK_OVERLAP = 200        # characters shared between consecutive chunks
TOP_K_DEFAULT = 5          # default number of passages retrieved per question
MAX_CONTEXT_CHARS = 12000  # cap on total context characters packed into the LLM prompt
INDEX_PATH = "rag_index.faiss"  # persisted FAISS index file
STORE_PATH = "rag_store.pkl"    # persisted chunk metadata (pickle)
# Groq model ids offered in the UI dropdown.
MODEL_CHOICES = [
    "llama-3.3-70b-versatile",
    "llama-3.1-8b-instant",
    "mixtral-8x7b-32768",
]
device = "cuda" if torch.cuda.is_available() else "cpu"
# Lazily-initialized globals: the embedder is created on first use (load_embedder);
# the FAISS index and docstore are populated by build/reload actions.
embedder = None
faiss_index = None
docstore: List[Dict[str, Any]] = []  # chunk metadata, rows parallel to FAISS index ids
| # ---------- PDF utils ---------- | |
def extract_text_from_pdf(pdf_path: str) -> List[Tuple[int, str]]:
    """Return (1-based page number, page text) pairs for every page of a PDF.

    Falls back to stitching together text blocks when plain extraction yields
    nothing; pages that are pure scanned images still come back empty.
    """
    results: List[Tuple[int, str]] = []
    with fitz.open(pdf_path) as document:
        for page_no, page in enumerate(document, start=1):
            text = page.get_text("text") or ""
            if not text.strip():
                # Plain extraction failed — try block-level extraction instead.
                raw_blocks = page.get_text("blocks")
                if isinstance(raw_blocks, list):
                    pieces = [
                        blk[4]
                        for blk in raw_blocks
                        if isinstance(blk, (list, tuple)) and len(blk) > 4
                    ]
                    text = "\n".join(pieces)
            results.append((page_no, text or ""))
    return results
def chunk_text(text: str, chunk_size=CHUNK_SIZE, overlap=CHUNK_OVERLAP) -> List[str]:
    """Split *text* into overlapping character windows of at most *chunk_size*."""
    cleaned = text.replace("\x00", " ").strip()
    if len(cleaned) <= chunk_size:
        return [cleaned] if cleaned else []
    chunks: List[str] = []
    pos = 0
    total = len(cleaned)
    while pos < total:
        chunks.append(cleaned[pos:pos + chunk_size])
        # Keep `overlap` chars of context between windows; always advance by >= 1
        # so a pathological overlap value cannot loop forever.
        pos = max(pos + chunk_size - overlap, pos + 1)
    return chunks
| # ---------- Embeddings / FAISS ---------- | |
def load_embedder():
    """Return the shared SentenceTransformer, creating it on first use."""
    global embedder
    if embedder is not None:
        return embedder
    embedder = SentenceTransformer(EMBED_MODEL_NAME, device=device)
    return embedder
def _normalize(vecs: np.ndarray) -> np.ndarray:
    """L2-normalize each row (epsilon guards zero rows) and cast to float32."""
    lengths = np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-12
    unit = vecs / lengths
    return unit.astype("float32")
def embed_passages(texts: List[str]) -> np.ndarray:
    """Embed document chunks using the E5 'passage:' prefix; rows are L2-normalized."""
    model = load_embedder()
    prefixed = ["passage: " + t for t in texts]
    raw = model.encode(prefixed, batch_size=64, show_progress_bar=False, convert_to_numpy=True)
    return _normalize(raw)
def embed_query(q: str) -> np.ndarray:
    """Embed a single search query using the E5 'query:' prefix (normalized float32)."""
    model = load_embedder()
    raw = model.encode(["query: " + q], convert_to_numpy=True)
    return _normalize(raw)
def build_faiss(embs: np.ndarray):
    """Build a flat inner-product index (cosine similarity, since rows are normalized)."""
    dim = embs.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(embs)
    return index
def save_index(index, store_list: List[Dict[str, Any]]):
    """Persist the FAISS index and docstore (with the embed model name) to disk."""
    faiss.write_index(index, INDEX_PATH)
    payload = {"docstore": store_list, "embed_model": EMBED_MODEL_NAME}
    with open(STORE_PATH, "wb") as f:
        pickle.dump(payload, f)
def load_index() -> bool:
    """Load a previously saved index and docstore from disk; True on success."""
    global faiss_index, docstore
    if not (os.path.exists(INDEX_PATH) and os.path.exists(STORE_PATH)):
        return False
    faiss_index = faiss.read_index(INDEX_PATH)
    with open(STORE_PATH, "rb") as f:
        # NOTE(review): pickle.load trusts the file; only load indexes you saved.
        saved = pickle.load(f)
    docstore = saved["docstore"]
    load_embedder()  # warm the embedder so the first query has no cold start
    return True
| # ---------- Ingest ---------- | |
def ingest_pdfs(paths: List[str]) -> Tuple[Any, List[Dict[str, Any]]]:
    """Parse, chunk, and embed a list of PDFs; return (faiss_index, chunk_entries).

    Raises RuntimeError when no text could be extracted from any file.
    """
    entries: List[Dict[str, Any]] = []
    for pdf_path in tqdm(paths, total=len(paths), desc="Parsing PDFs"):
        try:
            filename = os.path.basename(pdf_path)
            for page_no, page_text in extract_text_from_pdf(pdf_path):
                if not page_text.strip():
                    continue  # nothing extractable on this page
                for chunk_no, chunk in enumerate(chunk_text(page_text)):
                    entries.append({
                        "text": chunk,
                        "source": filename,
                        "page_start": page_no,
                        "page_end": page_no,
                        "chunk_id": f"{filename}::p{page_no}::c{chunk_no}",
                    })
        except Exception as e:
            # Best effort: one corrupt PDF must not abort the whole build.
            print(f"[WARN] Failed to parse {pdf_path}: {e}")
    if not entries:
        raise RuntimeError("No text extracted. If PDFs are scanned images, run OCR before indexing.")
    vectors = embed_passages([entry["text"] for entry in entries])
    return build_faiss(vectors), entries
| # ---------- Retrieval (supports required keywords) ---------- | |
def retrieve(query: str, top_k=5, must_contain: str = ""):
    """Vector-search the index; optionally require all comma-separated keywords.

    Over-fetches a candidate pool (10x k, at least 200, capped at the docstore
    size) so keyword filtering still leaves enough hits. Falls back to the
    unfiltered ranking when no candidate contains every keyword.
    """
    global faiss_index, docstore
    if faiss_index is None or not docstore:
        raise RuntimeError("Index not built or loaded. Use 'Build Index' or 'Reload Saved Index' first.")
    k = int(top_k) if top_k else TOP_K_DEFAULT
    pool_size = min(max(10 * k, 200), len(docstore))
    scores, ids = faiss_index.search(embed_query(query), pool_size)
    candidates = [(int(i), float(s)) for i, s in zip(ids[0], scores[0]) if i >= 0]
    keywords = [w.strip().lower() for w in must_contain.split(",") if w.strip()]
    if keywords:
        matching = [
            (idx, score)
            for idx, score in candidates
            if all(kw in docstore[idx]["text"].lower() for kw in keywords)
        ]
        if matching:  # keep the unfiltered ranking when nothing matches
            candidates = matching
    hits = []
    for idx, score in candidates[:k]:
        entry = docstore[idx].copy()
        entry["score"] = float(score)
        hits.append(entry)
    return hits
| # ---------- Groq LLM ---------- | |
def groq_answer(query: str, contexts, model_name="llama-3.3-70b-versatile", temperature=0.2, max_tokens=1000):
    """Answer *query* with a Groq chat completion grounded ONLY in *contexts*.

    Fix: the previous default model id "llama-3.1-70b-versatile" was
    decommissioned by Groq and did not appear in MODEL_CHOICES; the default
    now matches MODEL_CHOICES[0].

    Args:
        query: the user's question.
        contexts: retrieved chunk dicts with 'source', 'page_start', 'text'.
        model_name: Groq model id.
        temperature: sampling temperature (coerced to float).
        max_tokens: completion token cap (coerced to int).

    Returns:
        The model's answer string, or a human-readable error message —
        this function never raises.
    """
    try:
        if not os.environ.get("GROQ_API_KEY"):
            return "GROQ_API_KEY is not set. Add it in your host's environment/secrets."
        client = Groq(api_key=os.environ["GROQ_API_KEY"])
        # Pack context snippets in rank order until the character budget is hit.
        packed, used = [], 0
        for c in contexts:
            tag = f"[{c['source']} p.{c['page_start']}]"
            piece = f"{tag}\n{c['text'].strip()}\n"
            if used + len(piece) > MAX_CONTEXT_CHARS:
                break
            packed.append(piece)
            used += len(piece)
        context_str = "\n---\n".join(packed)
        system_prompt = (
            "You are a scholarly assistant. Answer using ONLY the provided context. "
            "If the answer is not present, say so. Always include a 'References' section with sources and page numbers."
        )
        user_prompt = (
            f"Question:\n{query}\n\n"
            f"Context snippets (use these only):\n{context_str}\n\n"
            "Write a precise answer. Keep claims traceable to the snippets."
        )
        resp = client.chat.completions.create(
            model=model_name,
            temperature=float(temperature),
            max_tokens=int(max_tokens),
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
        )
        return resp.choices[0].message.content.strip()
    except Exception as e:
        # Surface the full traceback to the UI instead of crashing the app.
        import traceback
        return f"Groq API error: {e}\n```\n{traceback.format_exc()}\n```"
| # ---------- Helpers for UI ---------- | |
def build_index_from_uploads(paths: List[str]) -> str:
    """Build a fresh index from uploaded PDF paths, persist it, and return a status line."""
    global faiss_index, docstore
    if not paths:
        return "Please upload at least one PDF."
    if len(paths) > 120:
        return "Please limit to ~100 PDFs per build."
    faiss_index, entries = ingest_pdfs(paths)
    save_index(faiss_index, entries)
    docstore = entries
    return f"Index built with {len(entries)} chunks from {len(paths)} PDFs. Saved to disk."
def reload_index() -> str:
    """Reload a previously saved index from disk and report the result."""
    if load_index():
        return f"Index reloaded. Chunks: {len(docstore)}"
    return "No saved index found."
def ask_rag(query: str, top_k, model_name: str, temperature: float, must_contain: str):
    """Retrieve context for *query* and ask the LLM; returns (answer_md, source_rows)."""
    try:
        if not query.strip():
            return "Please enter a question.", []
        k = int(top_k) if top_k else TOP_K_DEFAULT
        contexts = retrieve(query, top_k=k, must_contain=must_contain)
        answer = groq_answer(query, contexts, model_name=model_name, temperature=temperature)
        rows = []
        for c in contexts:
            snippet = c["text"][:200].replace("\n", " ")
            if len(c["text"]) > 200:
                snippet += "..."
            rows.append([c["source"], str(c["page_start"]), f"{c['score']:.3f}", snippet])
        return answer, rows
    except Exception as e:
        # Render the error (with traceback) in the answer pane instead of crashing.
        import traceback
        return f"**Error:** {e}\n```\n{traceback.format_exc()}\n```", []
def set_api_key(k: str):
    """Store the Groq API key in this process's environment for the session."""
    key = (k or "").strip()
    if not key:
        return "No key provided."
    os.environ["GROQ_API_KEY"] = key
    return "API key set in runtime."
def download_index_zip():
    """Zip the two saved index files for download; return the zip path or None.

    Fixes two bugs in the original:
      * `with shutil.make_archive(...)` crashed — make_archive returns a str,
        which is not a context manager;
      * `shutil.make_archive(base, "zip", ".", ".")` archived the ENTIRE
        working directory instead of just the index files.
    Only the explicit two-file zip remains.
    """
    if not (os.path.exists(INDEX_PATH) and os.path.exists(STORE_PATH)):
        return None  # nothing saved yet
    import zipfile
    zip_path = "rag_index_bundle.zip"
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        zf.write(INDEX_PATH)
        zf.write(STORE_PATH)
    return zip_path
| # ---------- Gradio UI ---------- | |
# ---------- Gradio UI ----------
with gr.Blocks(title="RAG over PDFs (Groq)") as demo:
    gr.Markdown("## RAG over your PDFs using Groq\nUpload PDFs, build an index, then ask questions with cited answers.")
    # Optional per-session API key (stored only in this process's environment).
    with gr.Row():
        api_box = gr.Textbox(label="(Optional) Set GROQ_API_KEY for this session", type="password", placeholder="sk_...")
        set_btn = gr.Button("Set Key")
    set_out = gr.Markdown()
    set_btn.click(set_api_key, inputs=[api_box], outputs=[set_out])
    # Tab 1: build a new index from uploads, reload a saved one, or download it.
    with gr.Tab("1) Build or Load Index"):
        file_u = gr.Files(label="Upload PDFs", file_types=[".pdf"], type="filepath")
        with gr.Row():
            build_btn = gr.Button("Build Index")
            reload_btn = gr.Button("Reload Saved Index")
            download_btn = gr.Button("Download Index (.zip)")
        build_out = gr.Markdown()
        # Wrapper surfaces build errors in the UI; track_tqdm mirrors the tqdm
        # bar from ingest_pdfs into Gradio's progress widget.
        def on_build(paths, progress=gr.Progress(track_tqdm=True)):
            try:
                return build_index_from_uploads(paths)
            except Exception as e:
                import traceback
                return f"**Error while building index:** {e}\n\n```\n{traceback.format_exc()}\n```"
        build_btn.click(on_build, inputs=[file_u], outputs=[build_out])
        reload_btn.click(fn=reload_index, outputs=[build_out])
        zpath = gr.File(label="Index zip", interactive=False)
        download_btn.click(fn=download_index_zip, outputs=[zpath])
    # Tab 2: question answering over the built index.
    with gr.Tab("2) Ask Questions"):
        q = gr.Textbox(label="Your question", lines=2, placeholder="Ask something present in the uploaded papers…")
        with gr.Row():
            topk = gr.Slider(1, 15, value=TOP_K_DEFAULT, step=1, label="Top-K passages")
            model_dd = gr.Dropdown(MODEL_CHOICES, value=MODEL_CHOICES[0], label="Groq model")
            temp = gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="Temperature")
        must = gr.Textbox(label="Must contain (comma-separated keywords)", placeholder="camera, CMOS, frame rate")
        ask_btn = gr.Button("Answer")
        ans = gr.Markdown()
        src = gr.Dataframe(headers=["Source","Page","Score","Snippet"], wrap=True)
        ask_btn.click(ask_rag, inputs=[q, topk, model_dd, temp, must], outputs=[ans, src])
demo.queue()  # keep it simple for broad Gradio versions
if __name__ == "__main__":
    # Bind to all interfaces; honor the host's PORT env var (default 7860).
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))