# ScholarLens / app.py — Hugging Face Space by EngrMuhammadBilal
# Commit 5273e75 ("Update app.py", verified)
import os, io, json, math, pickle, textwrap, shutil, re, zipfile, tempfile
from typing import List, Dict, Any, Tuple
import numpy as np, faiss, fitz # pymupdf
from tqdm import tqdm
import torch
from sentence_transformers import SentenceTransformer
import gradio as gr
from groq import Groq
from docx import Document
from docx.shared import Pt
from string import Template
# =========================
# Branding
# =========================
APP_NAME = "ScholarLens"
TAGLINE = "Query your literature, get page-level proof"
# =========================
# Color System (accessible dark theme)
# =========================
# High-contrast palette with clearly separated roles; every key is read by
# build_custom_css() to fill the CSS custom properties.
PALETTE = dict(
    # Surfaces and text
    bg="#0D1224",          # deep slate/navy background
    panel="#121936",       # panel background
    panel_alt="#0F1530",   # secondary panel
    text_light="#EAF0FF",  # default light text on dark
    text_dark="#0B111C",   # text on light surfaces
    # Accents (all readable on dark)
    primary="#22D3EE",     # cyan (primary actions)
    secondary="#A78BFA",   # purple (secondary actions)
    accent="#FBBF24",      # amber (highlights/links)
    success="#34D399",     # green (success state)
    danger="#FB7185",      # rose (errors)
    # Borders and subtle strokes
    stroke="rgba(255,255,255,0.14)",
    stroke_alt="rgba(255,255,255,0.10)",
)
from string import Template
def build_custom_css(palette=None):
    """
    Strong-contrast dark UI, light text everywhere (incl. Dataframe & Examples).

    Args:
        palette: Optional mapping supplying every ``$placeholder`` in the CSS
            template (bg, panel, panel_alt, text_light, text_dark, primary,
            secondary, accent, success, danger, stroke, stroke_alt).
            Defaults to the module-level ``PALETTE``; extra keys are ignored.

    Returns:
        The rendered CSS string (passed to ``gr.Blocks(css=...)``).

    Raises:
        KeyError: if ``palette`` lacks one of the required keys.
    """
    if palette is None:
        palette = PALETTE
    tmpl = Template(r"""
:root{
--bg: $bg; --panel: $panel; --panel-alt: $panel_alt;
--text-light: $text_light; --text-dark: $text_dark;
--primary: $primary; --secondary: $secondary; --accent: $accent;
--success: $success; --danger: $danger;
--stroke: $stroke; --stroke-alt: $stroke_alt;
/* Gradio tokens */
--body-background-fill: var(--bg);
--body-text-color: var(--text-light);
--block-background-fill: var(--panel);
--block-title-text-color: var(--text-light);
--border-color-primary: var(--stroke);
--button-primary-background-fill: var(--primary);
--button-primary-text-color: var(--text-dark);
--button-primary-border-color: color-mix(in srgb, var(--primary) 75%, black 25%);
--button-secondary-background-fill: var(--secondary);
--button-secondary-text-color: var(--text-dark);
--button-secondary-border-color: color-mix(in srgb, var(--secondary) 70%, black 30%);
--link-text-color: var(--accent);
}
/* Global */
html, body, .gradio-container{
background: var(--bg) !important;
color: var(--text-light) !important;
font-size: 16px; line-height: 1.5;
}
/* Panels / Tabs */
.gradio-container .block,
.gradio-container .tabs,
.gradio-container .tabs > .tabitem{
background: var(--panel) !important;
color: var(--text-light) !important;
border: 1px solid var(--stroke);
border-radius: 12px;
}
/* Hero */
#hero{
background:
radial-gradient(900px 350px at 20% -20%, color-mix(in srgb, var(--secondary) 25%, transparent) 0%, transparent 100%),
radial-gradient(900px 350px at 120% 10%, color-mix(in srgb, var(--primary) 25%, transparent) 0%, transparent 100%),
var(--panel-alt);
border: 1px solid var(--stroke);
border-radius: 14px; padding: 16px 18px;
color: var(--text-light);
}
/* KPI */
.kpi{ text-align:center; padding:12px; border-radius:10px; border:1px solid var(--stroke);
background: var(--panel-alt); color: var(--text-light); }
/* Buttons */
.gradio-container .gr-button, .gradio-container button{
border-radius: 10px !important; font-weight: 650 !important; letter-spacing: .2px;
}
.gradio-container .gr-button-primary, .gradio-container button.primary{
background: var(--primary) !important; color: var(--text-dark) !important;
border: 1px solid var(--button-primary-border-color) !important;
box-shadow: 0 8px 20px -8px color-mix(in srgb, var(--primary) 50%, transparent);
}
.gradio-container .gr-button-secondary, .gradio-container button.secondary{
background: var(--secondary) !important; color: var(--text-dark) !important;
border: 1px solid var(--button-secondary-border-color) !important;
}
/* Inputs */
input, textarea, select, .gr-textbox, .gr-text-area, .gr-dropdown, .gr-file, .gr-slider{
background: var(--panel-alt) !important; color: var(--text-light) !important;
border: 1px solid var(--stroke-alt) !important; border-radius: 10px !important;
}
input::placeholder, textarea::placeholder{ color: color-mix(in srgb, var(--text-light) 60%, transparent) !important; }
/* Markdown / labels / links */
label, .label, .prose h1, .prose h2, .prose h3, .prose p, .markdown-body{ color: var(--text-light) !important; }
a, .prose a{ color: var(--accent) !important; text-decoration:none; } a:hover{ text-decoration: underline; }
/* --- CRITICAL FIXES (visibility) --- */
/* Pandas DataFrame table (Top passages) */
.gradio-container table.dataframe,
.gradio-container .dataframe,
.gradio-container .gr-dataframe{
background: var(--panel-alt) !important;
color: var(--text-light) !important;
border: 1px solid var(--stroke) !important;
border-radius: 10px !important;
}
.gradio-container table.dataframe th,
.gradio-container table.dataframe td,
.gradio-container .gr-dataframe th,
.gradio-container .gr-dataframe td{
background: var(--panel-alt) !important;
color: var(--text-light) !important;
border-color: var(--stroke-alt) !important;
}
/* Examples grid (Quick examples) */
.gradio-container .examples,
.gradio-container .examples *{
color: var(--text-light) !important;
}
.gradio-container .examples,
.gradio-container .examples .grid,
.gradio-container .examples .grid .item{
background: var(--panel-alt) !important;
border: 1px solid var(--stroke-alt) !important;
border-radius: 10px !important;
}
/* Code blocks in Markdown (error traces, etc.) */
.markdown-body pre, .markdown-body code{
background: #0B1D3A !important; color: var(--text-light) !important;
border: 1px solid var(--stroke-alt) !important; border-radius: 8px;
}
/* Accordion */
.accordion, .gr-accordion{
background: var(--panel-alt) !important; border: 1px solid var(--stroke) !important; border-radius: 10px !important;
}
/* Tabs active underline */
.gradio-container .tabs .tab-nav button.selected{
box-shadow: inset 0 -3px 0 0 var(--primary) !important; color: var(--text-light) !important;
}
/* Focus outlines for a11y */
:focus-visible{ outline: 3px solid var(--accent) !important; outline-offset: 2px !important; }
/* Page width */
.gradio-container{ max-width: 1120px; margin: 0 auto; }
""")
    # substitute() takes a mapping directly and raises KeyError for any missing
    # placeholder, so a typo in a palette key fails loudly instead of leaking "$x".
    return tmpl.substitute(palette)
# =========================
# Engine config
# =========================
# Sentence-embedding model; embed_query/embed_passages prepend "query: " and
# "passage: " to the texts they encode with it.
EMBED_MODEL_NAME = "intfloat/multilingual-e5-small"
CHUNK_SIZE = 1200           # characters per chunk (chunk_text slices by character)
CHUNK_OVERLAP = 200         # characters shared by consecutive chunks
TOP_K_DEFAULT = 7           # passages retrieved when the caller gives no top_k
MAX_CONTEXT_CHARS = 16000   # character budget for snippets packed into the LLM prompt
INDEX_PATH = "rag_index.faiss"
STORE_PATH = "rag_store.pkl"
# Groq model ids offered in the UI dropdown; the first one is the default.
# NOTE(review): Groq retires older ids over time (mixtral-8x7b-32768 appears to
# be among them) — confirm each id is still served.
MODEL_CHOICES = [
    "llama-3.3-70b-versatile",
    "llama-3.1-8b-instant",
    "mixtral-8x7b-32768",
]
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
# Shared mutable state, populated lazily by the loader/builder functions below.
embedder = None
faiss_index = None
docstore: List[Dict[str, Any]] = []
# =========================
# PDF utils
# =========================
def extract_text_from_pdf(pdf_path: str) -> List[Tuple[int, str]]:
    """Return ``(page_number, text)`` for every page of a PDF, numbered from 1.

    Uses PyMuPDF's plain ``"text"`` extraction. When a page yields only
    whitespace, it retries with ``"blocks"`` extraction and joins the text of
    the text blocks. Pages with no extractable text are still returned, with
    an empty string, so callers can decide how to treat them.
    """
    pages: List[Tuple[int, str]] = []
    with fitz.open(pdf_path) as doc:
        for page_no, page in enumerate(doc, start=1):
            txt = page.get_text("text") or ""
            if not txt.strip():
                blocks = page.get_text("blocks")
                if isinstance(blocks, list):
                    # Block tuples are (x0, y0, x1, y1, text, block_no, block_type).
                    # block_type 1 is an image block whose "text" is only image
                    # metadata; indexing it would make scanned pages look like
                    # they contain text, so keep text blocks (type 0) only.
                    txt = "\n".join(
                        b[4]
                        for b in blocks
                        if isinstance(b, (list, tuple))
                        and len(b) > 4
                        and (len(b) <= 6 or b[6] == 0)
                    )
            pages.append((page_no, txt or ""))
    return pages
def chunk_text(text: str, chunk_size=CHUNK_SIZE, overlap=CHUNK_OVERLAP) -> List[str]:
    """Split *text* into overlapping fixed-width character windows.

    NUL characters become spaces and the text is stripped first. Each window
    is ``chunk_size`` characters (the last may be shorter). The next window
    starts ``chunk_size - overlap`` characters later, and always at least one
    character later, so an overlap >= chunk_size still makes progress.

    Returns an empty list for empty or whitespace-only input, and a single
    chunk when the text fits in one window.

    Raises:
        ValueError: if ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    text = text.replace("\x00", " ").strip()
    if len(text) <= chunk_size:
        return [text] if text else []
    out, start = [], 0
    while start < len(text):
        end = start + chunk_size
        out.append(text[start:end])
        if end >= len(text):
            # This window already reaches the end. Stepping back by `overlap`
            # would only emit a tail chunk fully contained in this one.
            break
        start = max(end - overlap, start + 1)
    return out
# =========================
# Embeddings / FAISS
# =========================
def load_embedder():
    """Return the shared SentenceTransformer, constructing it on first use.

    The instance is cached in the module-level ``embedder`` and placed on
    ``device``.
    """
    global embedder
    if embedder is not None:
        return embedder
    embedder = SentenceTransformer(EMBED_MODEL_NAME, device=device)
    return embedder
def _normalize(vecs: np.ndarray) -> np.ndarray:
norms = np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-12
return (vecs / norms).astype("float32")
def embed_passages(texts: List[str]) -> np.ndarray:
    """Embed document chunks as unit-length float32 rows, one per text.

    Each text gets the "passage: " prefix, the counterpart of the "query: "
    prefix used by embed_query.
    """
    model = load_embedder()
    vectors = model.encode(
        [f"passage: {t}" for t in texts],
        batch_size=64,
        show_progress_bar=False,
        convert_to_numpy=True,
    )
    return _normalize(vectors)
def embed_query(q: str) -> np.ndarray:
    """Embed one search query as a one-row, unit-length float32 array.

    The "query: " prefix pairs with the "passage: " prefix in embed_passages.
    """
    model = load_embedder()
    prompt = f"query: {q}"
    return _normalize(model.encode([prompt], convert_to_numpy=True))
def build_faiss(embs: np.ndarray):
    """Create an exact inner-product FAISS index containing the rows of *embs*.

    The vectors produced by embed_passages are unit-normalised, so inner
    product here equals cosine similarity.
    """
    dim = embs.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(embs)
    return index
def save_index(index, store_list: List[Dict[str, Any]]):
    """Persist the FAISS index to INDEX_PATH and chunk metadata to STORE_PATH.

    The pickle stores the docstore together with EMBED_MODEL_NAME, recording
    which embedder produced the vectors.
    """
    faiss.write_index(index, INDEX_PATH)
    payload = {"docstore": store_list, "embed_model": EMBED_MODEL_NAME}
    with open(STORE_PATH, "wb") as fh:
        pickle.dump(payload, fh)
def load_index() -> bool:
    """Load the saved FAISS index and docstore into the module globals.

    Returns True when both INDEX_PATH and STORE_PATH exist and were loaded,
    and False when either file is missing. Both files are read into locals
    before any global changes, so an error while reading one leaves the
    previously active index/docstore pair intact, rather than pairing a new
    index with an old docstore.

    SECURITY NOTE: pickle.load can execute arbitrary code. Only load
    STORE_PATH files this app wrote itself (see save_index), never uploads.
    """
    global faiss_index, docstore
    if not (os.path.exists(INDEX_PATH) and os.path.exists(STORE_PATH)):
        return False
    new_index = faiss.read_index(INDEX_PATH)
    with open(STORE_PATH, "rb") as f:
        data = pickle.load(f)  # trusted, app-written file only (see note above)
    new_store = data["docstore"]
    saved_model = data.get("embed_model")
    if saved_model and saved_model != EMBED_MODEL_NAME:
        # Query vectors would come from a different model than the stored ones.
        print(
            f"[WARN] Saved index was built with {saved_model!r} but the current "
            f"embedder is {EMBED_MODEL_NAME!r}; retrieval results may be wrong."
        )
    faiss_index, docstore = new_index, new_store
    load_embedder()
    return True
# =========================
# Ingest
# =========================
def _collect_pdf_paths(upload_paths: List[str]) -> List[str]:
"""Accept PDFs and ZIPs of PDFs."""
if not upload_paths:
return []
out = []
for p in upload_paths:
p = str(p)
if p.lower().endswith(".pdf"):
out.append(p)
elif p.lower().endswith(".zip"):
tmpdir = tempfile.mkdtemp(prefix="pdfs_")
with zipfile.ZipFile(p, "r") as z:
for name in z.namelist():
if name.lower().endswith(".pdf"):
z.extract(name, tmpdir)
for root, _, files in os.walk(tmpdir):
for f in files:
if f.lower().endswith(".pdf"):
out.append(os.path.join(root, f))
return out
def ingest_pdfs(paths: List[str]) -> Tuple[Any, List[Dict[str, Any]]]:
    """Parse, chunk and embed PDFs into a fresh FAISS index.

    Returns ``(index, entries)``, where ``entries[i]`` describes vector ``i``:
    its chunk text, the source file's basename, its page number (as both
    page_start and page_end) and a chunk id. Whitespace-only pages are skipped.
    A PDF that raises while being processed is reported and skipped.

    NOTE(review): "source" is the basename only, so two uploads with the same
    file name share citations and chunk ids.

    Raises:
        RuntimeError: if no text was extracted from any file.
    """
    entries: List[Dict[str, Any]] = []
    for pdf in tqdm(paths, total=len(paths), desc="Parsing PDFs"):
        try:
            page_texts = extract_text_from_pdf(pdf)
            source = os.path.basename(pdf)
            for page_no, page_txt in page_texts:
                if not page_txt.strip():
                    continue
                entries.extend(
                    {
                        "text": chunk,
                        "source": source,
                        "page_start": page_no,
                        "page_end": page_no,
                        "chunk_id": f"{source}::p{page_no}::c{n}",
                    }
                    for n, chunk in enumerate(chunk_text(page_txt))
                )
        except Exception as e:
            print(f"[WARN] Failed to parse {pdf}: {e}")
    if not entries:
        raise RuntimeError("No text extracted. If PDFs are scanned images, run OCR before indexing.")
    vectors = embed_passages([entry["text"] for entry in entries])
    return build_faiss(vectors), entries
# =========================
# Retrieval
# =========================
def retrieve(query: str, top_k=5, must_contain: str = ""):
    """Return the top-k chunks for *query* as docstore dict copies with a "score".

    Searches a candidate pool of max(10*k, 200) nearest chunks, capped at the
    docstore size. If *must_contain* lists comma-separated keywords, only
    candidates whose text contains every keyword (case-insensitive substring)
    are kept. If none match, the unfiltered ranking is used instead.

    Raises:
        RuntimeError: if no index has been built or loaded yet.
    """
    if faiss_index is None or not docstore:
        raise RuntimeError("Index not built or loaded. Use 'Build Index' or 'Reload Saved Index' first.")
    # A falsy top_k (None/0) means "use the default". Clamp the rest to >= 1
    # so a negative value cannot turn the final slice into "drop the last N".
    k = max(1, int(top_k)) if top_k else TOP_K_DEFAULT
    pool = min(max(10 * k, 200), len(docstore))
    scores, ids = faiss_index.search(embed_query(query), pool)
    # FAISS pads missing results with id -1.
    pairs = [(int(i), float(s)) for i, s in zip(ids[0], scores[0]) if i >= 0]
    must_words = [w.strip().lower() for w in (must_contain or "").split(",") if w.strip()]
    if must_words:
        filtered = []
        for idx, score in pairs:
            lowered = docstore[idx]["text"].lower()
            if all(w in lowered for w in must_words):
                filtered.append((idx, score))
        if filtered:
            pairs = filtered
    hits = []
    for idx, score in pairs[:k]:
        item = docstore[idx].copy()
        item["score"] = score
        hits.append(item)
    return hits
# =========================
# Groq LLM
# =========================
def groq_answer(query: str, contexts, model_name="llama-3.3-70b-versatile", temperature=0.2, max_tokens=1000):
    """Ask a Groq chat model to answer *query* using only the given *contexts*.

    Contexts are packed in order, each tagged "[source p.N]". Packing stops
    before the joined context (separators included) would exceed
    MAX_CONTEXT_CHARS. This function never raises: a missing API key or any
    API failure comes back as a message string for the UI.
    """
    try:
        api_key = os.environ.get("GROQ_API_KEY")
        if not api_key:
            return "GROQ_API_KEY is not set. Add it in your Space secrets or the key box."
        client = Groq(api_key=api_key)
        sep = "\n---\n"
        packed, used = [], 0
        for c in contexts:
            tag = f"[{c['source']} p.{c['page_start']}]"
            piece = f"{tag}\n{c['text'].strip()}\n"
            # Count the separator that join() will insert before this piece.
            cost = len(piece) + (len(sep) if packed else 0)
            if used + cost > MAX_CONTEXT_CHARS:
                break
            packed.append(piece)
            used += cost
        context_str = sep.join(packed)
        system_prompt = (
            "You are a scholarly assistant. Answer using ONLY the provided context. "
            "If the answer is not present, say so. Always include a 'References' section with sources and page numbers."
        )
        user_prompt = (
            f"Question:\n{query}\n\n"
            f"Context snippets (use these only):\n{context_str}\n\n"
            "Write a precise answer. Keep claims traceable to the snippets."
        )
        resp = client.chat.completions.create(
            model=model_name, temperature=float(temperature), max_tokens=int(max_tokens),
            messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}],
        )
        # message.content is optional in the OpenAI-style schema; avoid
        # AttributeError on None.
        content = resp.choices[0].message.content
        return (content or "").strip()
    except Exception as e:
        import traceback
        return f"Groq API error: {e}\n```\n{traceback.format_exc()}\n```"
# =========================
# Export helpers
# =========================
def export_answer_to_docx(question: str, answer_md: str, rows: List[List[str]]) -> str:
    """Write the question, answer and top-passage table to a .docx file.

    *rows* are (source, page, score, snippet) records. A pandas DataFrame or
    numpy array, which is what a gr.Dataframe input can deliver, is converted
    to a list of lists first. Iterating a DataFrame directly would yield its
    column names. Cells beyond the four table columns are dropped, and short
    rows leave the remaining cells blank.

    Returns the saved file's path. Each call writes into its own temp
    directory, so concurrent users do not overwrite each other's export.
    """
    if hasattr(rows, "to_numpy"):      # pandas DataFrame
        rows = rows.to_numpy().tolist()
    elif hasattr(rows, "tolist"):      # numpy array
        rows = rows.tolist()
    doc = Document()
    try:
        normal = doc.styles["Normal"]
        normal.font.name = "Calibri"
        normal.font.size = Pt(11)
    except Exception:
        pass  # styling is cosmetic; export unstyled rather than fail
    doc.add_heading(f"{APP_NAME} - Answer", level=1)
    doc.add_paragraph(f"Question: {question}")
    doc.add_heading("Answer", level=2)
    for line in answer_md.splitlines():
        doc.add_paragraph(line)
    doc.add_heading("References (Top Passages)", level=2)
    headers = ("Source", "Page", "Score", "Snippet")
    table = doc.add_table(rows=1, cols=len(headers))
    for cell, title in zip(table.rows[0].cells, headers):
        cell.text = title
    for r in rows:
        # zip() stops at the table's 4 cells instead of raising IndexError.
        for cell, val in zip(table.add_row().cells, r):
            cell.text = str(val)
    out_dir = tempfile.mkdtemp(prefix="scholarlens_")
    path = os.path.join(out_dir, "scholarlens_answer.docx")
    doc.save(path)
    return path
# =========================
# UI helpers
# =========================
def build_index_from_uploads(paths: List[str]) -> str:
    """Build, persist and activate a new index from uploaded PDFs/ZIPs.

    Returns a status message for the UI. The new index is saved to disk
    before the module globals change, and both globals are swapped together.
    So if ingesting or saving raises (the error propagates to the caller),
    the previously active index/docstore pair stays untouched and consistent.
    """
    global faiss_index, docstore
    pdfs = _collect_pdf_paths(paths)
    if not pdfs:
        return "Please upload at least one PDF or ZIP of PDFs."
    new_index, entries = ingest_pdfs(pdfs)
    save_index(new_index, entries)
    faiss_index, docstore = new_index, entries
    return f"✅ Index built with {len(entries)} chunks from {len(pdfs)} files. You can start asking questions."
def reload_index() -> str:
    """Reload the saved index from disk and return a status message for the UI."""
    if not load_index():
        return "No saved index found yet."
    return f"🔁 Index reloaded. Chunks ready: {len(docstore)}"
def ask_rag(question: str, top_k, model_name: str, temperature: float, must_contain: str):
    """Answer *question* via retrieval + Groq and shape the results for the UI.

    Returns a 4-tuple matching the click outputs: answer markdown, table rows
    (source, page, score, preview), retrieved-snippet markdown, and a visibility
    update for the export button. Errors are rendered into the answer slot
    (with a traceback) instead of being raised.
    """
    try:
        if not question.strip():
            return "Please enter a question.", [], "", gr.update(visible=False)
        k = int(top_k) if top_k else TOP_K_DEFAULT
        ctx = retrieve(question, top_k=k, must_contain=must_contain)
        answer = groq_answer(question, ctx, model_name=model_name, temperature=temperature)

        def _preview(text):
            # One-line, 200-char table preview with an ellipsis when truncated.
            head = text[:200].replace("\n", " ")
            return head + "..." if len(text) > 200 else head

        rows = [
            [c["source"], str(c["page_start"]), f"{c['score']:.3f}", _preview(c["text"])]
            for c in ctx
        ]
        snippets_md = "\n\n---\n\n".join(
            f"**{c['source']} p.{c['page_start']}**\n> {c['text'].strip()[:1000]}"
            for c in ctx
        )
        return answer, rows, snippets_md, gr.update(visible=True)
    except Exception as e:
        import traceback
        err = f"**Error:** {e}\n```\n{traceback.format_exc()}\n```"
        return err, [], "", gr.update(visible=False)
def set_api_key(k: str):
    """Store a user-supplied Groq API key in GROQ_API_KEY and return a UI message.

    Surrounding whitespace is stripped. Blank or missing input leaves the
    environment unchanged.

    NOTE(review): os.environ is process-wide, and groq_answer reads it on
    every call. On a shared Space the key therefore serves every visitor's
    requests, not just this browser session. Consider per-session gr.State.
    """
    if not (k and k.strip()):
        return "No key provided."
    os.environ["GROQ_API_KEY"] = k.strip()
    return "🔑 API key set for this session."
def download_index_zip():
    """Bundle the saved index and docstore into a zip for download.

    Returns the zip's path, or None when either saved file is missing.
    Members are stored under their own relative paths.
    """
    required = (INDEX_PATH, STORE_PATH)
    if not all(os.path.exists(p) for p in required):
        return None
    bundle = "rag_index_bundle.zip"
    with zipfile.ZipFile(bundle, "w", zipfile.ZIP_DEFLATED) as zf:
        for member in required:
            zf.write(member)
    return bundle
def do_export_docx(question, answer_md, sources_rows):
    """Click handler for "Export Answer to DOCX"; returns a file path or None.

    *sources_rows* comes from the "Top passages" gr.Dataframe, which passes
    handlers a pandas DataFrame by default. ``not df`` raises ValueError, so
    tabular input is converted to a plain list of lists before the emptiness
    check. Export failures are logged and reported as None (no file).
    """
    if hasattr(sources_rows, "to_numpy"):      # pandas DataFrame
        sources_rows = sources_rows.to_numpy().tolist()
    elif hasattr(sources_rows, "tolist"):      # numpy array
        sources_rows = sources_rows.tolist()
    if not answer_md or not sources_rows:
        return None
    try:
        return export_answer_to_docx(question, answer_md, sources_rows)
    except Exception as e:
        print(f"[WARN] DOCX export failed: {e}")
        return None
# =========================
# UI
# =========================
theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="blue", neutral_hue="slate")
with gr.Blocks(title=f"{APP_NAME} | RAG over PDFs", theme=theme, css=build_custom_css()) as demo:
    # Hero
    with gr.Group(elem_id="hero"):
        gr.Markdown(f"""
<div style="display:flex;align-items:center;gap:16px;">
<div style="font-size:36px">📚🔎 <b>{APP_NAME}</b></div>
<div style="opacity:.9;">{TAGLINE}</div>
</div>
<p style="opacity:.85;margin-top:6px;">
Upload your papers, build an index, and ask research questions with verifiable, page-level citations.
</p>""")
    # KPIs
    with gr.Row():
        gr.Markdown("**Meaning-aware retrieval**<br><span class='kpi'>E5 + FAISS</span>", elem_classes=["kpi"])
        gr.Markdown("**Cited answers**<br><span class='kpi'>Page-level proof</span>", elem_classes=["kpi"])
        gr.Markdown("**Runs anywhere**<br><span class='kpi'>HF Spaces or Colab</span>", elem_classes=["kpi"])
    # Key
    with gr.Row():
        api_box = gr.Textbox(label="(Optional) Set GROQ_API_KEY", type="password", placeholder="sk_...")
        set_btn = gr.Button("Set Key")
    set_out = gr.Markdown()
    set_btn.click(set_api_key, inputs=[api_box], outputs=[set_out])
    with gr.Tabs():
        # Build / Load
        with gr.Tab("1) Build or Load Index"):
            gr.Markdown("Upload PDFs or a ZIP of PDFs, then click **Build Index**.")
            file_u = gr.Files(label="Upload PDFs or ZIP", file_types=[".pdf", ".zip"], type="filepath")
            with gr.Row():
                build_btn = gr.Button("Build Index", variant="primary")
                reload_btn = gr.Button("Reload Saved Index", variant="secondary")
                download_btn = gr.Button("Download Index (.zip)")
            build_out = gr.Markdown()

            def on_build(paths, progress=gr.Progress(track_tqdm=True)):
                """Build handler: runs build_index_from_uploads and renders errors as markdown."""
                # `progress` is unused in the body. Declaring it with
                # track_tqdm=True lets Gradio mirror the tqdm bar in ingest_pdfs.
                try:
                    return build_index_from_uploads(paths)
                except Exception as e:
                    import traceback
                    return f"**Error while building index:** {e}\n\n```\n{traceback.format_exc()}\n```"

            build_btn.click(on_build, inputs=[file_u], outputs=[build_out])
            reload_btn.click(fn=reload_index, outputs=[build_out])
            zpath = gr.File(label="Index bundle", interactive=False)
            download_btn.click(fn=download_index_zip, outputs=[zpath])
        # Ask
        with gr.Tab("2) Ask Questions"):
            with gr.Row():
                # Gradio expects integer Column scales; 5:7 keeps the
                # intended 1 : 1.4 width ratio.
                with gr.Column(scale=5):
                    q = gr.Textbox(label="Your question", lines=3, placeholder="e.g., Compare GTAW parameters with citations")
                    must = gr.Textbox(label="Must contain (comma-separated keywords)", placeholder="camera, CMOS, frame rate")
                    with gr.Accordion("Advanced settings", open=False):
                        topk = gr.Slider(1, 20, value=TOP_K_DEFAULT, step=1, label="Top-K passages")
                        model_dd = gr.Dropdown(MODEL_CHOICES, value=MODEL_CHOICES[0], label="Groq model")
                        temp = gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="Temperature")
                    with gr.Row():
                        ask_btn = gr.Button("Answer", variant="primary")
                        clear_btn = gr.Button("Clear", variant="secondary")
                    gr.Examples(
                        examples=[
                            ["List camera model, sensor type, resolution, and FPS across studies. Cite pages.", "camera, fps, resolution"],
                            ["Extract limitations and future work across the corpus, with page references.", ""],
                            ["Compare GTAW setups: current range, travel speed, torch standoff, sensors.", "GTAW, current, speed, torch"],
                            ["Summarize results tables with metrics and page citations.", "table, accuracy, mAP, F1"],
                        ],
                        inputs=[q, must],
                        label="Quick examples",
                    )
                with gr.Column(scale=7):
                    ans = gr.Markdown(label="Answer", show_label=False)
                    src = gr.Dataframe(headers=["Source", "Page", "Score", "Snippet"], wrap=True, label="Top passages")
                    with gr.Accordion("Show retrieved snippets", open=False):
                        snippets_md = gr.Markdown("")
                    with gr.Row():
                        export_btn = gr.Button("Export Answer to DOCX", visible=False)
                        exported = gr.File(label="Download answer", visible=True)
            ask_btn.click(fn=ask_rag, inputs=[q, topk, model_dd, temp, must], outputs=[ans, src, snippets_md, export_btn])
            export_btn.click(fn=do_export_docx, inputs=[q, ans, src], outputs=[exported])
            # Also clear the previously exported file so it cannot be
            # mistaken for the (now cleared) answer.
            clear_btn.click(
                lambda: ("", [], "", gr.update(visible=False), None),
                outputs=[ans, src, snippets_md, export_btn, exported],
            )
        # About
        with gr.Tab("About"):
            gr.Markdown("""
**ScholarLens** helps researchers move from reading to results with answers grounded in the papers you upload.
- Meaning-aware retrieval (E5 + FAISS)
- Answers limited to your corpus, with page-level citations
- Optional keyword filter to stay on topic
- Runs on Hugging Face Spaces or Google Colab
- Powered by Groq models
*Privacy note:* your files stay on this Space. Only the Groq call is external.
""")
# Run: enable Gradio's request queue, then serve when executed as a script.
demo.queue()
if __name__ == "__main__":
    # PORT may be provided by the hosting environment; default to Gradio's 7860.
    port = int(os.environ.get("PORT", 7860))
    demo.launch(server_name="0.0.0.0", server_port=port)