import os, io, json, math, pickle, textwrap, shutil, re, zipfile, tempfile
from typing import List, Dict, Any, Tuple
import numpy as np, faiss, fitz  # pymupdf
from tqdm import tqdm
import torch
from sentence_transformers import SentenceTransformer
import gradio as gr
from groq import Groq
from docx import Document
from docx.shared import Pt
from string import Template

# =========================
# Branding
# =========================
APP_NAME = "ScholarLens"
TAGLINE = "Query your literature, get page-level proof"

# =========================
# Color System (accessible dark theme)
# =========================
# Primary palette chosen for high contrast and clear role separation.
PALETTE = {
    "bg":         "#0D1224",  # deep slate/navy background
    "panel":      "#121936",  # panel background
    "panel_alt":  "#0F1530",  # secondary panel
    "text_light": "#EAF0FF",  # default light text on dark
    "text_dark":  "#0B111C",  # text on light surfaces
    # Accents (all readable on dark):
    "primary":    "#22D3EE",  # cyan (primary actions)
    "secondary":  "#A78BFA",  # purple (secondary actions)
    "accent":     "#FBBF24",  # amber (highlights/links)
    "success":    "#34D399",  # green (success state)
    "danger":     "#FB7185",  # rose (errors)
    # Borders & subtle strokes
    "stroke":     "rgba(255,255,255,0.14)",
    "stroke_alt": "rgba(255,255,255,0.10)",
}

# NOTE: the original file re-imported `Template` here; it is already
# imported at the top of the file, so the duplicate was removed.


def build_custom_css() -> str:
    """Build the app-wide CSS string from PALETTE.

    Strong-contrast dark UI, light text everywhere (incl. Dataframe &
    Examples). Uses string.Template (`$name` placeholders) rather than
    str.format so the literal CSS braces need no escaping.
    """
    tmpl = Template(r"""
:root{
  --bg: $bg; --panel: $panel; --panel-alt: $panel_alt;
  --text-light: $text_light; --text-dark: $text_dark;
  --primary: $primary; --secondary: $secondary; --accent: $accent;
  --success: $success; --danger: $danger;
  --stroke: $stroke; --stroke-alt: $stroke_alt;
  /* Gradio tokens */
  --body-background-fill: var(--bg);
  --body-text-color: var(--text-light);
  --block-background-fill: var(--panel);
  --block-title-text-color: var(--text-light);
  --border-color-primary: var(--stroke);
  --button-primary-background-fill: var(--primary);
  --button-primary-text-color: var(--text-dark);
  --button-primary-border-color: color-mix(in srgb, var(--primary) 75%, black 25%);
  --button-secondary-background-fill: var(--secondary);
  --button-secondary-text-color: var(--text-dark);
  --button-secondary-border-color: color-mix(in srgb, var(--secondary) 70%, black 30%);
  --link-text-color: var(--accent);
}
/* Global */
html, body, .gradio-container{
  background: var(--bg) !important;
  color: var(--text-light) !important;
  font-size: 16px;
  line-height: 1.5;
}
/* Panels / Tabs */
.gradio-container .block, .gradio-container .tabs, .gradio-container .tabs > .tabitem{
  background: var(--panel) !important;
  color: var(--text-light) !important;
  border: 1px solid var(--stroke);
  border-radius: 12px;
}
/* Hero */
#hero{
  background:
    radial-gradient(900px 350px at 20% -20%, color-mix(in srgb, var(--secondary) 25%, transparent) 0%, transparent 100%),
    radial-gradient(900px 350px at 120% 10%, color-mix(in srgb, var(--primary) 25%, transparent) 0%, transparent 100%),
    var(--panel-alt);
  border: 1px solid var(--stroke);
  border-radius: 14px;
  padding: 16px 18px;
  color: var(--text-light);
}
/* KPI */
.kpi{
  text-align:center; padding:12px; border-radius:10px;
  border:1px solid var(--stroke);
  background: var(--panel-alt);
  color: var(--text-light);
}
/* Buttons */
.gradio-container .gr-button, .gradio-container button{
  border-radius: 10px !important;
  font-weight: 650 !important;
  letter-spacing: .2px;
}
.gradio-container .gr-button-primary, .gradio-container button.primary{
  background: var(--primary) !important;
  color: var(--text-dark) !important;
  border: 1px solid var(--button-primary-border-color) !important;
  box-shadow: 0 8px 20px -8px color-mix(in srgb, var(--primary) 50%, transparent);
}
.gradio-container .gr-button-secondary, .gradio-container button.secondary{
  background: var(--secondary) !important;
  color: var(--text-dark) !important;
  border: 1px solid var(--button-secondary-border-color) !important;
}
/* Inputs */
input, textarea, select, .gr-textbox, .gr-text-area, .gr-dropdown, .gr-file, .gr-slider{
  background: var(--panel-alt) !important;
  color: var(--text-light) !important;
  border: 1px solid var(--stroke-alt) !important;
  border-radius: 10px !important;
}
input::placeholder, textarea::placeholder{
  color: color-mix(in srgb, var(--text-light) 60%, transparent) !important;
}
/* Markdown / labels / links */
label, .label, .prose h1, .prose h2, .prose h3, .prose p, .markdown-body{
  color: var(--text-light) !important;
}
a, .prose a{ color: var(--accent) !important; text-decoration:none; }
a:hover{ text-decoration: underline; }
/* --- CRITICAL FIXES (visibility) --- */
/* Pandas DataFrame table (Top passages) */
.gradio-container table.dataframe, .gradio-container .dataframe, .gradio-container .gr-dataframe{
  background: var(--panel-alt) !important;
  color: var(--text-light) !important;
  border: 1px solid var(--stroke) !important;
  border-radius: 10px !important;
}
.gradio-container table.dataframe th, .gradio-container table.dataframe td,
.gradio-container .gr-dataframe th, .gradio-container .gr-dataframe td{
  background: var(--panel-alt) !important;
  color: var(--text-light) !important;
  border-color: var(--stroke-alt) !important;
}
/* Examples grid (Quick examples) */
.gradio-container .examples, .gradio-container .examples *{
  color: var(--text-light) !important;
}
.gradio-container .examples, .gradio-container .examples .grid, .gradio-container .examples .grid .item{
  background: var(--panel-alt) !important;
  border: 1px solid var(--stroke-alt) !important;
  border-radius: 10px !important;
}
/* Code blocks in Markdown (error traces, etc.) */
.markdown-body pre, .markdown-body code{
  background: #0B1D3A !important;
  color: var(--text-light) !important;
  border: 1px solid var(--stroke-alt) !important;
  border-radius: 8px;
}
/* Accordion */
.accordion, .gr-accordion{
  background: var(--panel-alt) !important;
  border: 1px solid var(--stroke) !important;
  border-radius: 10px !important;
}
/* Tabs active underline */
.gradio-container .tabs .tab-nav button.selected{
  box-shadow: inset 0 -3px 0 0 var(--primary) !important;
  color: var(--text-light) !important;
}
/* Focus outlines for a11y */
:focus-visible{
  outline: 3px solid var(--accent) !important;
  outline-offset: 2px !important;
}
/* Page width */
.gradio-container{ max-width: 1120px; margin: 0 auto; }
""")
    return tmpl.substitute(
        bg=PALETTE["bg"],
        panel=PALETTE["panel"],
        panel_alt=PALETTE["panel_alt"],
        text_light=PALETTE["text_light"],
        text_dark=PALETTE["text_dark"],
        primary=PALETTE["primary"],
        secondary=PALETTE["secondary"],
        accent=PALETTE["accent"],
        success=PALETTE["success"],
        danger=PALETTE["danger"],
        stroke=PALETTE["stroke"],
        stroke_alt=PALETTE["stroke_alt"],
    )


# =========================
# Engine config
# =========================
EMBED_MODEL_NAME = "intfloat/multilingual-e5-small"
CHUNK_SIZE = 1200        # characters per chunk
CHUNK_OVERLAP = 200      # characters shared between consecutive chunks
TOP_K_DEFAULT = 7
MAX_CONTEXT_CHARS = 16000  # cap on total context sent to the LLM
INDEX_PATH = "rag_index.faiss"
STORE_PATH = "rag_store.pkl"
MODEL_CHOICES = [
    "llama-3.3-70b-versatile",
    "llama-3.1-8b-instant",
    "mixtral-8x7b-32768",
]

device = "cuda" if torch.cuda.is_available() else "cpu"

# Lazily-initialized globals shared by the UI callbacks below.
embedder = None
faiss_index = None
docstore: List[Dict[str, Any]] = []


# =========================
# PDF utils
# =========================
def extract_text_from_pdf(pdf_path: str) -> List[Tuple[int, str]]:
    """Return (1-based page number, page text) for every page of a PDF.

    Falls back to block-level extraction when plain "text" mode yields
    nothing, since some PDFs only expose text through layout blocks.
    """
    pages = []
    with fitz.open(pdf_path) as doc:
        for i, page in enumerate(doc, start=1):
            txt = page.get_text("text") or ""
            if not txt.strip():
                blocks = page.get_text("blocks")
                if isinstance(blocks, list):
                    # Block tuples are (x0, y0, x1, y1, text, ...): index 4 is the text.
                    txt = "\n".join(
                        b[4] for b in blocks
                        if isinstance(b, (list, tuple)) and len(b) > 4
                    )
            pages.append((i, txt or ""))
    return pages


def chunk_text(text: str, chunk_size=CHUNK_SIZE, overlap=CHUNK_OVERLAP) -> List[str]:
    """Split *text* into overlapping character windows of *chunk_size*.

    Returns [] for empty/whitespace-only input and a single chunk when the
    text already fits. Consecutive chunks share *overlap* characters.
    """
    text = text.replace("\x00", " ").strip()
    if len(text) <= chunk_size:
        return [text] if text else []
    # FIX: an overlap >= chunk_size used to make the window advance one
    # character at a time (quadratic number of chunks); clamp it instead.
    overlap = min(overlap, max(chunk_size - 1, 0))
    out, start = [], 0
    while start < len(text):
        end = start + chunk_size
        out.append(text[start:end])
        # max(..., start + 1) guarantees forward progress / termination.
        start = max(end - overlap, start + 1)
    return out


# =========================
# Embeddings / FAISS
# =========================
def load_embedder():
    """Return the shared SentenceTransformer, creating it on first use."""
    global embedder
    if embedder is None:
        embedder = SentenceTransformer(EMBED_MODEL_NAME, device=device)
    return embedder


def _normalize(vecs: np.ndarray) -> np.ndarray:
    """L2-normalize rows so inner product == cosine similarity."""
    norms = np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-12  # avoid div-by-zero
    return (vecs / norms).astype("float32")


def embed_passages(texts: List[str]) -> np.ndarray:
    """Embed document chunks with the E5 "passage: " prefix convention."""
    model = load_embedder()
    inputs = [f"passage: {t}" for t in texts]
    embs = model.encode(inputs, batch_size=64, show_progress_bar=False, convert_to_numpy=True)
    return _normalize(embs)


def embed_query(q: str) -> np.ndarray:
    """Embed a search query with the E5 "query: " prefix convention."""
    model = load_embedder()
    embs = model.encode([f"query: {q}"], convert_to_numpy=True)
    return _normalize(embs)


def build_faiss(embs: np.ndarray):
    """Build a flat inner-product index (cosine, given normalized vectors)."""
    index = faiss.IndexFlatIP(embs.shape[1])
    index.add(embs)
    return index


def save_index(index, store_list: List[Dict[str, Any]]):
    """Persist the FAISS index and the chunk metadata side-by-side."""
    faiss.write_index(index, INDEX_PATH)
    with open(STORE_PATH, "wb") as f:
        pickle.dump({"docstore": store_list, "embed_model": EMBED_MODEL_NAME}, f)


def load_index() -> bool:
    """Reload a previously saved index + docstore; return True on success.

    SECURITY NOTE: this unpickles STORE_PATH — only load index bundles you
    created yourself; pickle can execute arbitrary code on malicious input.
    """
    global faiss_index, docstore
    if os.path.exists(INDEX_PATH) and os.path.exists(STORE_PATH):
        faiss_index = faiss.read_index(INDEX_PATH)
        with open(STORE_PATH, "rb") as f:
            data = pickle.load(f)
        docstore = data["docstore"]
        load_embedder()  # warm the embedder so first query is fast
        return True
    return False


# =========================
# Ingest
# =========================
def _collect_pdf_paths(upload_paths: List[str]) -> List[str]:
    """Accept PDFs and ZIPs of PDFs; return a flat list of PDF file paths."""
    if not upload_paths:
        return []
    out = []
    for p in upload_paths:
        p = str(p)
        if p.lower().endswith(".pdf"):
            out.append(p)
        elif p.lower().endswith(".zip"):
            tmpdir = tempfile.mkdtemp(prefix="pdfs_")
            with zipfile.ZipFile(p, "r") as z:
                for name in z.namelist():
                    # FIX (zip-slip hardening): never extract members with
                    # absolute paths or ".." components that could escape tmpdir.
                    if name.startswith(("/", "\\")) or ".." in name.replace("\\", "/").split("/"):
                        continue
                    if name.lower().endswith(".pdf"):
                        z.extract(name, tmpdir)
            for root, _, files in os.walk(tmpdir):
                for f in files:
                    if f.lower().endswith(".pdf"):
                        out.append(os.path.join(root, f))
    return out


def ingest_pdfs(paths: List[str]) -> Tuple[Any, List[Dict[str, Any]]]:
    """Parse, chunk and embed PDFs; return (faiss_index, docstore entries).

    Raises RuntimeError when no text at all could be extracted (e.g. the
    PDFs are scanned images with no OCR layer).
    """
    entries: List[Dict[str, Any]] = []
    for pdf in tqdm(paths, total=len(paths), desc="Parsing PDFs"):
        try:
            pages = extract_text_from_pdf(pdf)
            base = os.path.basename(pdf)
            for pno, ptxt in pages:
                if not ptxt.strip():
                    continue
                for ci, ch in enumerate(chunk_text(ptxt)):
                    entries.append({
                        "text": ch,
                        "source": base,
                        "page_start": pno,
                        "page_end": pno,
                        "chunk_id": f"{base}::p{pno}::c{ci}",
                    })
        except Exception as e:
            # Best effort: a single broken PDF must not abort the whole batch.
            print(f"[WARN] Failed to parse {pdf}: {e}")
    if not entries:
        raise RuntimeError("No text extracted. If PDFs are scanned images, run OCR before indexing.")
    texts = [e["text"] for e in entries]
    embs = embed_passages(texts)
    index = build_faiss(embs)
    return index, entries


# =========================
# Retrieval
# =========================
def retrieve(query: str, top_k=5, must_contain: str = ""):
    """Return the top_k docstore chunks for *query* (cosine-ranked).

    *must_contain* is a comma-separated list of words; when every word
    appears in a candidate chunk, candidates are filtered to those matches
    (the filter is skipped entirely if it would leave nothing).
    Raises RuntimeError when no index is available.
    """
    global faiss_index, docstore
    if faiss_index is None or not docstore:
        raise RuntimeError("Index not built or loaded. Use 'Build Index' or 'Reload Saved Index' first.")
    k = int(top_k) if top_k else TOP_K_DEFAULT
    # Over-fetch (10x k, at least 200) so the keyword filter has candidates.
    pool = min(max(10 * k, 200), len(docstore))
    qemb = embed_query(query)
    D, I = faiss_index.search(qemb, pool)
    pairs = [(int(i), float(s)) for i, s in zip(I[0], D[0]) if i >= 0]  # -1 = empty slot
    # FIX: guard against must_contain=None coming from the UI.
    must_words = [w.strip().lower() for w in (must_contain or "").split(",") if w.strip()]
    if must_words:
        filtered = []
        for idx, score in pairs:
            t = docstore[idx]["text"].lower()
            if all(w in t for w in must_words):
                filtered.append((idx, score))
        if filtered:  # keep semantic results if the filter is too strict
            pairs = filtered
    pairs = pairs[:k]
    hits = []
    for idx, score in pairs:
        item = docstore[idx].copy()
        item["score"] = float(score)
        hits.append(item)
    return hits


# =========================
# Groq LLM
# =========================
def groq_answer(query: str, contexts, model_name="llama-3.3-70b-versatile", temperature=0.2, max_tokens=1000):
    """Ask the Groq chat API to answer *query* grounded in *contexts*.

    Context chunks are packed (with [source p.N] tags) until
    MAX_CONTEXT_CHARS is reached. Returns the model's answer text, or a
    human-readable error string on any failure (never raises).
    """
    try:
        if not os.environ.get("GROQ_API_KEY"):
            return "GROQ_API_KEY is not set. Add it in your Space secrets or the key box."
        client = Groq(api_key=os.environ["GROQ_API_KEY"])
        packed, used = [], 0
        for c in contexts:
            tag = f"[{c['source']} p.{c['page_start']}]"
            piece = f"{tag}\n{c['text'].strip()}\n"
            if used + len(piece) > MAX_CONTEXT_CHARS:
                break
            packed.append(piece)
            used += len(piece)
        context_str = "\n---\n".join(packed)
        system_prompt = (
            "You are a scholarly assistant. Answer using ONLY the provided context. "
            "If the answer is not present, say so. Always include a 'References' section with sources and page numbers."
        )
        user_prompt = (
            f"Question:\n{query}\n\n"
            f"Context snippets (use these only):\n{context_str}\n\n"
            "Write a precise answer. Keep claims traceable to the snippets."
        )
        resp = client.chat.completions.create(
            model=model_name,
            temperature=float(temperature),
            max_tokens=int(max_tokens),
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
        )
        # FIX: message.content may be None; avoid AttributeError on .strip().
        return (resp.choices[0].message.content or "").strip()
    except Exception as e:
        import traceback
        return f"Groq API error: {e}\n```\n{traceback.format_exc()}\n```"


# =========================
# Export helpers
# =========================
def export_answer_to_docx(question: str, answer_md: str, rows: List[List[str]]) -> str:
    """Write question, answer and source table to a DOCX; return its path.

    *rows* are [source, page, score, snippet] lists as produced by ask_rag.
    """
    doc = Document()
    try:
        styles = doc.styles
        styles['Normal'].font.name = 'Calibri'
        styles['Normal'].font.size = Pt(11)
    except Exception:
        pass  # styling is cosmetic; never fail the export over it
    doc.add_heading(f"{APP_NAME} - Answer", level=1)
    doc.add_paragraph(f"Question: {question}")
    doc.add_heading("Answer", level=2)
    for line in answer_md.splitlines():
        doc.add_paragraph(line)
    doc.add_heading("References (Top Passages)", level=2)
    table = doc.add_table(rows=1, cols=4)
    hdr = table.rows[0].cells
    hdr[0].text = "Source"
    hdr[1].text = "Page"
    hdr[2].text = "Score"
    hdr[3].text = "Snippet"
    for r in rows:
        row = table.add_row().cells
        for i, val in enumerate(r):
            row[i].text = str(val)
    path = "scholarlens_answer.docx"
    doc.save(path)
    return path


# =========================
# UI helpers
# =========================
def build_index_from_uploads(paths: List[str]) -> str:
    """Gradio callback: build + persist an index from uploaded PDFs/ZIPs."""
    global faiss_index, docstore
    pdfs = _collect_pdf_paths(paths)
    if not pdfs:
        return "Please upload at least one PDF or ZIP of PDFs."
    faiss_index, entries = ingest_pdfs(pdfs)
    save_index(faiss_index, entries)
    docstore = entries
    return f"✅ Index built with {len(entries)} chunks from {len(pdfs)} files. You can start asking questions."


def reload_index() -> str:
    """Gradio callback: reload the previously saved index from disk."""
    ok = load_index()
    return f"🔁 Index reloaded. Chunks ready: {len(docstore)}" if ok else "No saved index found yet."
def ask_rag(question: str, top_k, model_name: str, temperature: float, must_contain: str):
    # Main QA callback: retrieve passages, query the LLM, and shape outputs
    # for the UI. Returns (answer_md, table_rows, snippets_md, button_update).
    try:
        if not question.strip():
            return "Please enter a question.", [], "", gr.update(visible=False)
        ctx = retrieve(question, top_k=int(top_k) if top_k else TOP_K_DEFAULT, must_contain=must_contain)
        ans = groq_answer(question, ctx, model_name=model_name, temperature=temperature)
        rows = []
        for c in ctx:
            # 200-char flattened preview for the sources table.
            preview = c["text"][:200].replace("\n"," ") + ("..." if len(c["text"])>200 else "")
            rows.append([c["source"], str(c["page_start"]), f"{c['score']:.3f}", preview])
        details = []
        for c in ctx:
            # Longer (1000-char) quoted snippets for the evidence panel.
            details.append(f"**{c['source']} p.{c['page_start']}**\n> {c['text'].strip()[:1000]}")
        snippets_md = "\n\n---\n\n".join(details)
        download_btn = gr.update(visible=True)
        return ans, rows, snippets_md, download_btn
    except Exception as e:
        import traceback
        # Render the full traceback in the answer box instead of crashing the UI.
        err = f"**Error:** {e}\n```\n{traceback.format_exc()}\n```"
        return err, [], "", gr.update(visible=False)

def set_api_key(k: str):
    # Store a user-provided Groq key in the process env for this session only.
    if k and k.strip():
        os.environ["GROQ_API_KEY"] = k.strip()
        return "🔑 API key set for this session."
    return "No key provided."

def download_index_zip():
    # Bundle the FAISS index + docstore pickle for download; None when absent.
    if not (os.path.exists(INDEX_PATH) and os.path.exists(STORE_PATH)):
        return None
    zp = "rag_index_bundle.zip"
    with zipfile.ZipFile(zp, "w", zipfile.ZIP_DEFLATED) as z:
        z.write(INDEX_PATH)
        z.write(STORE_PATH)
    return zp

def do_export_docx(question, answer_md, sources_rows):
    # Best-effort DOCX export; returns a file path for gr.File, or None when
    # there is nothing to export or the export fails.
    if not answer_md or not sources_rows:
        return None
    try:
        return export_answer_to_docx(question, answer_md, sources_rows)
    except Exception:
        return None

# =========================
# UI
# =========================
theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="blue", neutral_hue="slate")

with gr.Blocks(title=f"{APP_NAME} | RAG over PDFs", theme=theme, css=build_custom_css()) as demo:
    # Hero
    with gr.Group(elem_id="hero"):
        gr.Markdown(f"""
Upload your papers, build an index, and ask research questions with verifiable, page-level citations.
""") # KPIs with gr.Row(): gr.Markdown("**Meaning-aware retrieval**