pdf / app.py
Resham2987's picture
Update app.py
0c1607e verified
import os
import json
import hashlib
import shutil
from typing import List, Tuple
import gradio as gr
import numpy as np
import faiss
import requests
from sentence_transformers import SentenceTransformer
import fitz # PyMuPDF
# ---------------- Config ----------------
# Credentials and model selection. The API key is read from the environment
# and is None when the variable is not set.
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY")
OPENROUTER_MODEL = "nvidia/nemotron-nano-12b-v2-vl:free"
EMBEDDING_MODEL_NAME = "paraphrase-MiniLM-L3-v2"

# Directory holding the cached embeddings (.npz files).
CACHE_DIR = "./cache"

# Chunking / retrieval parameters.
CHUNK_SIZE = 300  # words per chunk
CHUNK_OVERLAP = 50  # overlapping words between chunks
TOP_K = 4  # number of chunks to retrieve

# System message prepended to every chat-completion request.
SYSTEM_PROMPT = " ".join([
    "You are an expert document assistant.",
    "Answer questions using ONLY the provided context from the uploaded PDFs.",
    "Be concise, accurate, and cite which document your answer comes from.",
    "Always respond in plain text. Avoid markdown formatting.",
])

# Make sure the cache directory exists before anything tries to write to it.
os.makedirs(CACHE_DIR, exist_ok=True)
# The embedding model is created on first use rather than at import time
# (lazy loading, to avoid OOM on HF Spaces).
embedder = None


def get_embedder():
    """Return the shared SentenceTransformer, constructing it on the first call."""
    global embedder
    if embedder is not None:
        return embedder
    print("Loading embedder model...")
    embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)
    print("Embedder loaded.")
    return embedder
# Global state
# Module-level index state: rebound by upload_and_index() / build_faiss() and
# read by search() and chat().
# CHUNKS, CHUNK_SOURCES and CHUNK_PAGES are parallel lists: entry i of each
# describes the same chunk.
CHUNKS: List[str] = []  # text of every indexed chunk
CHUNK_SOURCES: List[str] = []  # file name each chunk came from
CHUNK_PAGES: List[int] = []  # page number each chunk came from
EMBEDDINGS: np.ndarray = None  # one embedding row per chunk; None until something is indexed
FAISS_INDEX = None  # index built from EMBEDDINGS by build_faiss(); None when empty
INDEXED_FILES: List[dict] = []  # per-file summary dicts rendered by _render_file_list()
# ---------------- Cache cleanup ----------------
def clear_old_cache():
    """Remove the cache directory, if any, and recreate it empty.

    Any failure is printed rather than raised, so cleanup problems never
    propagate to the caller.
    """
    try:
        cache_present = os.path.exists(CACHE_DIR)
        if cache_present:
            shutil.rmtree(CACHE_DIR)
        os.makedirs(CACHE_DIR, exist_ok=True)
    except Exception as err:
        print(f"[Cache cleanup error] {err}")
# ---------------- PDF extraction with page tracking ----------------
def extract_pages_from_pdf(file_bytes: bytes) -> List[Tuple[int, str]]:
    """Extract the text of every non-empty page of a PDF held in memory.

    Args:
        file_bytes: Raw bytes of the PDF file.

    Returns:
        A list of ``(page_number, page_text)`` tuples with 1-based page
        numbers; pages whose stripped text is empty are omitted. This
        function never raises: on any failure it returns a single sentinel
        entry ``(0, "[PDF extraction error] ...")``.
    """
    try:
        # The context manager closes the document (and releases its buffers)
        # even when text extraction fails part-way through.
        with fitz.open(stream=file_bytes, filetype="pdf") as doc:
            pages = []
            for page_number, page in enumerate(doc, start=1):
                text = page.get_text().strip()
                if text:
                    pages.append((page_number, text))
            return pages
    except Exception as e:
        # Deliberately broad: callers rely on getting the sentinel entry
        # instead of an exception.
        return [(0, f"[PDF extraction error] {e}")]
# ---------------- Chunking strategy ----------------
def chunk_text(text: str, source: str, page: int,
               chunk_size: int = CHUNK_SIZE,
               overlap: int = CHUNK_OVERLAP,
               min_chars: int = 50) -> List[Tuple[str, str, int]]:
    """Split text into overlapping word-level chunks.

    Args:
        text: Text to split; words are separated on any whitespace.
        source: Label copied into every returned tuple.
        page: Page number copied into every returned tuple.
        chunk_size: Maximum number of words per chunk (must be positive).
        overlap: Number of words shared by consecutive chunks; must satisfy
            ``0 <= overlap < chunk_size``.
        min_chars: Chunks whose joined text is this many characters or fewer
            are dropped.

    Returns:
        A list of ``(chunk_text, source, page)`` tuples.

    Raises:
        ValueError: If ``chunk_size`` or ``overlap`` would give a
            non-advancing or gapped window.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        # overlap == chunk_size makes range() raise a cryptic "arg 3 must not
        # be zero"; a larger overlap silently yields no chunks at all.
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size, "
            f"got overlap={overlap}, chunk_size={chunk_size}"
        )

    words = text.split()
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(words), step):
        chunk = " ".join(words[start:start + chunk_size])
        if len(chunk) > min_chars:
            chunks.append((chunk, source, page))
        # This window already reached the end of the text; a further window
        # would only repeat the tail.
        if start + chunk_size >= len(words):
            break
    return chunks
# ---------------- Cache helpers ----------------
def make_cache_key(files: List[Tuple[str, bytes]]) -> str:
    """Derive a SHA-256 hex key from the names and contents of the given files.

    Files are processed in name order, so the key does not depend on the
    order in which differently named files are supplied.
    """
    digest = hashlib.sha256()
    ordered = sorted(files, key=lambda entry: entry[0])
    for filename, payload in ordered:
        content_hash = hashlib.sha256(payload).digest()
        digest.update(filename.encode() + content_hash)
    return digest.hexdigest()
def cache_save(cache_key: str, embeddings: np.ndarray,
               chunks: List[str], sources: List[str], pages: List[int]):
    """Write the embeddings and their chunk metadata to ``<CACHE_DIR>/<cache_key>.npz``.

    The three lists are converted to NumPy arrays and stored, together with
    the embeddings, in one compressed archive.
    """
    target = os.path.join(CACHE_DIR, f"{cache_key}.npz")
    arrays = {
        "embeddings": embeddings,
        "chunks": np.array(chunks),
        "sources": np.array(sources),
        "pages": np.array(pages),
    }
    np.savez_compressed(target, **arrays)
def cache_load(cache_key: str):
    """Load a cached index from ``<CACHE_DIR>/<cache_key>.npz``.

    Args:
        cache_key: Key identifying the cache file.

    Returns:
        ``(embeddings, chunks, sources, pages)`` where ``embeddings`` is a
        NumPy array and the other three are plain lists, or ``None`` when the
        file does not exist or cannot be read.
    """
    path = os.path.join(CACHE_DIR, f"{cache_key}.npz")
    if not os.path.exists(path):
        return None
    try:
        # allow_pickle=False: the archive is expected to hold only plain
        # numeric/string arrays, so refuse pickled objects rather than risk
        # executing code from a tampered cache file. The context manager
        # closes the underlying zip handle.
        with np.load(path, allow_pickle=False) as data:
            return (
                data["embeddings"],
                data["chunks"].tolist(),
                data["sources"].tolist(),
                data["pages"].tolist(),
            )
    except Exception as e:
        # The cache is best-effort: an unreadable file is reported and then
        # treated as a miss so the caller re-indexes.
        print(f"[Cache load error] {e}")
        return None
# ---------------- FAISS ----------------
def build_faiss(emb: np.ndarray):
    """Rebuild the global FAISS index from the given embedding matrix.

    With ``None`` or an empty matrix the global index is reset to ``None``;
    otherwise a flat L2 index over the float32 vectors replaces it.
    """
    global FAISS_INDEX
    has_vectors = emb is not None and len(emb) > 0
    if not has_vectors:
        FAISS_INDEX = None
        return
    vectors = emb.astype("float32")
    flat_index = faiss.IndexFlatL2(vectors.shape[1])
    flat_index.add(vectors)
    FAISS_INDEX = flat_index
def search(query: str, k: int = TOP_K):
    """Return up to ``k`` indexed chunks nearest to ``query``.

    Each hit is a dict with the keys ``text``, ``source``, ``page`` and
    ``distance`` (the distance reported by the FAISS index). An empty list is
    returned when nothing has been indexed yet.
    """
    if FAISS_INDEX is None or not CHUNKS:
        return []
    query_vec = get_embedder().encode([query], convert_to_numpy=True).astype("float32")
    distances, indices = FAISS_INDEX.search(query_vec, k)
    # Keep only positions that refer to a stored chunk.
    return [
        {
            "text": CHUNKS[pos],
            "source": CHUNK_SOURCES[pos],
            "page": CHUNK_PAGES[pos],
            "distance": float(dist),
        }
        for dist, pos in zip(distances[0], indices[0])
        if 0 <= pos < len(CHUNKS)
    ]
# ---------------- OpenRouter API ----------------
def call_openrouter(messages: list) -> str:
    """Send a chat-completion request to OpenRouter and return the reply text.

    The system prompt is prepended to ``messages``. Failures are reported as
    a string (missing key, request/parse error, or unexpected payload)
    instead of being raised.
    """
    if not OPENROUTER_API_KEY:
        return "Error: OPENROUTER_API_KEY is not set. Please add it in HF Space secrets."

    system_message = {"role": "system", "content": SYSTEM_PROMPT}
    request_body = {
        "model": OPENROUTER_MODEL,
        "messages": [system_message] + messages,
    }
    request_headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        "Content-Type": "application/json",
    }
    try:
        response = requests.post(
            "https://openrouter.ai/api/v1/chat/completions",
            headers=request_headers,
            json=request_body,
            timeout=60,
        )
        response.raise_for_status()
        body = response.json()
        if "choices" not in body or not body["choices"]:
            return "[Unexpected response from API]"
        reply = body["choices"][0]["message"]["content"]
        # Strip code-fence markers from the reply text.
        return reply.strip().replace("```", "")
    except Exception as err:
        return f"[OpenRouter error] {err}"
# ---------------- File bytes reader ----------------
def read_file_bytes(f) -> Tuple[str, bytes]:
    """Normalise an uploaded-file object into ``(filename, raw_bytes)``.

    Supported shapes, tried in this order:
      1. a ``(name, bytes)`` tuple;
      2. a dict carrying the payload under ``data``/``content``/``value``/
         ``file`` or a path under ``tmp_path``/``path``/``file``;
      3. a file-like object with ``name`` and ``read``;
      4. an object with ``name`` and ``value`` attributes;
      5. a string path to an existing file.

    Args:
        f: The uploaded-file object.

    Returns:
        The file name and its contents as ``bytes``.

    Raises:
        ValueError: If ``f`` matches none of the supported shapes.
    """
    # 1. Already a (name, bytes) pair.
    if isinstance(f, tuple) and len(f) == 2 and isinstance(f[1], (bytes, bytearray)):
        return f[0], bytes(f[1])

    # 2. Dict payload.
    if isinstance(f, dict):
        name = f.get("name") or f.get("filename") or "uploaded"
        data = f.get("data") or f.get("content") or f.get("value") or f.get("file")
        if isinstance(data, (bytes, bytearray)):
            return name, bytes(data)
        if isinstance(data, str):
            # A string naming an existing file is a path to the upload, not
            # the upload itself: read the file instead of encoding the path.
            if os.path.isfile(data):
                with open(data, "rb") as fh:
                    return os.path.basename(data), fh.read()
            return name, data.encode("utf-8")
        tmp_path = f.get("tmp_path") or f.get("path") or f.get("file")
        if isinstance(tmp_path, str) and os.path.isfile(tmp_path):
            with open(tmp_path, "rb") as fh:
                return os.path.basename(tmp_path), fh.read()

    # 3. File-like object.
    if hasattr(f, "name") and hasattr(f, "read"):
        try:
            name = os.path.basename(f.name) if getattr(f, "name", None) else "uploaded"
            content = f.read()
        except Exception:
            # Best effort: an unreadable handle falls through to the
            # remaining strategies and, failing those, to the ValueError.
            content = None
        if isinstance(content, str):
            # Text-mode handles yield str; the contract is bytes.
            content = content.encode("utf-8")
        if isinstance(content, (bytes, bytearray)):
            return name, bytes(content)

    # 4. Object exposing .name and .value.
    if hasattr(f, "name") and hasattr(f, "value"):
        name = os.path.basename(getattr(f, "name") or "uploaded")
        v = getattr(f, "value")
        if isinstance(v, (bytes, bytearray)):
            return name, bytes(v)
        if isinstance(v, str):
            return name, v.encode("utf-8")

    # 5. Plain path string. isfile() rather than exists(), so a directory is
    #    rejected with the ValueError below instead of failing inside open().
    if isinstance(f, str) and os.path.isfile(f):
        with open(f, "rb") as fh:
            return os.path.basename(f), fh.read()

    raise ValueError(f"Unsupported file object type: {type(f)}")
# ---------------- Upload & Index ----------------
def upload_and_index(files):
    """Read the uploaded PDFs, chunk and embed them, and rebuild the search index.

    Embeddings are cached on disk under a key derived from the file names and
    contents. The cache is consulted first and only wiped on a miss, just
    before the new entry is written, so re-uploading the same files reuses
    the stored embeddings.

    Args:
        files: One uploaded file object, or a list/tuple of them, in any
            shape accepted by ``read_file_bytes``.

    Returns:
        A ``(status_message, file_list_text)`` tuple for the two UI textboxes.
    """
    global CHUNKS, CHUNK_SOURCES, CHUNK_PAGES, EMBEDDINGS, INDEXED_FILES

    if not files:
        return "No files uploaded.", "No files indexed yet."
    if not isinstance(files, (list, tuple)):
        files = [files]

    try:
        processed = [read_file_bytes(f) for f in files]
    except (ValueError, OSError) as e:
        return f"Upload error: {e}", "No files indexed yet."

    cache_key = make_cache_key(processed)
    # Look the key up BEFORE any cleanup; wiping the cache first would make
    # every lookup a miss.
    cached = cache_load(cache_key)
    if cached:
        EMBEDDINGS, CHUNKS, CHUNK_SOURCES, CHUNK_PAGES = cached
        EMBEDDINGS = np.array(EMBEDDINGS)
        build_faiss(EMBEDDINGS)
        INDEXED_FILES = [
            {
                "name": n,
                "size_kb": round(len(b) / 1024, 1),
                "chunks": CHUNK_SOURCES.count(n),
            }
            for n, b in processed
        ]
        return (
            f"Loaded from cache — {len(CHUNKS)} chunks across {len(processed)} PDF(s).",
            _render_file_list(INDEXED_FILES),
        )

    all_chunks, all_sources, all_pages = [], [], []
    failed = []
    INDEXED_FILES = []
    for name, b in processed:
        extracted = extract_pages_from_pdf(b)
        # Page number 0 is the extractor's error sentinel (real pages start
        # at 1); its text is an error message and must not be indexed.
        pages = [(num, text) for num, text in extracted if num != 0]
        if len(pages) != len(extracted):
            failed.append(name)
        file_chunks = 0
        for page_num, page_text in pages:
            for chunk, src, pg in chunk_text(page_text, name, page_num):
                all_chunks.append(chunk)
                all_sources.append(src)
                all_pages.append(pg)
                file_chunks += 1
        INDEXED_FILES.append({
            "name": name,
            "size_kb": round(len(b) / 1024, 1),
            "pages": len(pages),
            "chunks": file_chunks,
        })

    CHUNKS = all_chunks
    CHUNK_SOURCES = all_sources
    CHUNK_PAGES = all_pages
    if not CHUNKS:
        # Drop the previous upload's vectors so no stale index survives.
        EMBEDDINGS = None
        build_faiss(None)
        return "Could not extract any text from the PDFs.", "No files indexed."

    EMBEDDINGS = get_embedder().encode(CHUNKS, convert_to_numpy=True).astype("float32")

    # Keep only the newest entry on disk: purge, then store the fresh one.
    clear_old_cache()
    try:
        cache_save(cache_key, EMBEDDINGS, CHUNKS, CHUNK_SOURCES, CHUNK_PAGES)
    except OSError as e:
        # Caching is an optimisation; failing to write must not fail the upload.
        print(f"[Cache save error] {e}")

    build_faiss(EMBEDDINGS)
    status = f"Indexed {len(processed)} PDF(s) — {len(CHUNKS)} chunks ready."
    if failed:
        status += f" Could not read: {', '.join(failed)}."
    return status, _render_file_list(INDEXED_FILES)
def _render_file_list(files: List[dict]) -> str:
if not files:
return "No files indexed yet."
lines = []
for f in files:
parts = [f"πŸ“„ {f['name']} ({f['size_kb']} KB)"]
if "pages" in f:
parts.append(f"{f['pages']} pages")
if "chunks" in f:
parts.append(f"{f['chunks']} chunks")
lines.append(" | ".join(parts))
return "\n".join(lines)
# ---------------- Chat ----------------
def chat(message: str, history: list):
    """Answer a question from the indexed PDFs and append the exchange to the history.

    Args:
        message: The user's question.
        history: Previous ``(user_message, bot_message)`` pairs; ``None`` is
            treated as an empty history.

    Returns:
        ``("", history)``: an empty string to clear the input box, and the
        history with the new exchange appended (unchanged when ``message`` is
        blank).
    """
    # An untouched chatbot may hand over None instead of a list.
    if history is None:
        history = []
    if not message or not message.strip():
        return "", history
    if not CHUNKS:
        history.append((message, "No PDFs indexed yet. Please upload a PDF first."))
        return "", history

    results = search(message)
    if not results:
        history.append((message, "No relevant content found in the uploaded PDFs."))
        return "", history

    context_parts = []
    sources_used = []
    for r in results:
        context_parts.append(f"[From: {r['source']}, Page {r['page']}]\n{r['text']}")
        source_ref = f"{r['source']} (p.{r['page']})"
        if source_ref not in sources_used:
            sources_used.append(source_ref)
    context = "\n\n---\n\n".join(context_parts)

    # Multi-turn: include the last 4 exchanges, skipping empty halves so a
    # None/blank entry is never sent as message content.
    messages = []
    for user_msg, bot_msg in history[-4:]:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    messages.append({
        "role": "user",
        "content": f"Context from PDFs:\n\n{context}\n\nQuestion: {message}"
    })

    answer = call_openrouter(messages)
    # call_openrouter reports failures as text with these prefixes; citing
    # sources under an error message would be misleading.
    failed = answer.startswith((
        "Error: OPENROUTER_API_KEY",
        "[OpenRouter error]",
        "[Unexpected response from API]",
    ))
    if sources_used and not failed:
        answer += f"\n\nSources: {', '.join(sources_used)}"
    history.append((message, answer))
    return "", history
def clear_chat():
    """Return a fresh empty history, which resets the chatbot component."""
    empty_history = []
    return empty_history
# ---------------- Custom CSS ----------------
# Stylesheet passed to gr.Blocks(css=...). It defines the dark colour palette
# as CSS variables in :root and styles the classes used by the gr.HTML
# fragments in the UI (.app-header, .section-label, .footer-note) as well as
# text inputs.
custom_css = """
@import url('https://fonts.googleapis.com/css2?family=Syne:wght@400;600;700;800&family=DM+Mono:wght@300;400;500&display=swap');
:root {
--bg: #0d0f12;
--surface: #13161b;
--surface2: #1a1e26;
--border: #252a35;
--accent: #4fffb0;
--accent2: #00c2ff;
--text: #e8eaf0;
--muted: #6b7280;
}
body, .gradio-container {
background: var(--bg) !important;
font-family: 'DM Mono', monospace !important;
color: var(--text) !important;
}
.gradio-container {
max-width: 1100px !important;
margin: 0 auto !important;
}
.app-header {
text-align: center;
padding: 36px 0 28px;
border-bottom: 1px solid var(--border);
margin-bottom: 28px;
}
.app-header h1 {
font-family: 'Syne', sans-serif;
font-size: 2.4rem;
font-weight: 800;
background: linear-gradient(135deg, var(--accent), var(--accent2));
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
margin: 0 0 6px;
letter-spacing: -1px;
}
.app-header p {
color: var(--muted);
font-size: 0.85rem;
margin: 0;
font-family: 'DM Mono', monospace;
}
.section-label {
font-family: 'Syne', sans-serif;
font-size: 0.7rem;
font-weight: 700;
letter-spacing: 2.5px;
text-transform: uppercase;
color: var(--accent);
margin-bottom: 10px;
}
textarea, input[type="text"] {
background: var(--surface2) !important;
border: 1px solid var(--border) !important;
border-radius: 8px !important;
color: var(--text) !important;
font-family: 'DM Mono', monospace !important;
font-size: 0.87rem !important;
}
textarea:focus, input[type="text"]:focus {
border-color: var(--accent) !important;
box-shadow: 0 0 0 2px rgba(79,255,176,0.08) !important;
}
.footer-note {
text-align: center;
margin-top: 28px;
color: #2d3340;
font-size: 0.72rem;
font-family: 'DM Mono', monospace;
letter-spacing: 0.5px;
}
"""
# ---------------- Gradio UI ----------------
# Two-column layout: upload/index controls on the left, chat on the right.
# The icons and separators in the labels below were stored as mojibake
# (UTF-8 bytes decoded with a single-byte code page); they are restored to
# the intended characters here.
with gr.Blocks(
    title="PDF RAG Bot",
    css=custom_css,
    theme=gr.themes.Base(
        primary_hue="emerald",
        neutral_hue="slate",
    ),
) as demo:
    gr.HTML("""
    <div class="app-header">
    <h1>⚡ PDF RAG Bot</h1>
    <p>Upload PDFs &nbsp;·&nbsp; Semantic chunking &nbsp;·&nbsp; Ask anything &nbsp;·&nbsp; AI answers with page sources</p>
    </div>
    """)
    with gr.Row(equal_height=False):
        # ── Left: Upload panel ──
        with gr.Column(scale=1, min_width=280):
            gr.HTML('<div class="section-label">📂 Document Upload</div>')
            file_input = gr.File(
                label="Drop PDF files here",
                file_count="multiple",
                file_types=[".pdf"],
            )
            upload_btn = gr.Button("⚡ Upload & Index", variant="primary", size="lg")
            status = gr.Textbox(
                label="Status",
                interactive=False,
                lines=2,
            )
            file_list = gr.Textbox(
                label="Indexed Files",
                interactive=False,
                lines=6,
                placeholder="No files indexed yet...",
            )
        # ── Right: Chat panel ──
        with gr.Column(scale=2):
            gr.HTML('<div class="section-label">💬 Chat with your PDFs</div>')
            chatbot = gr.Chatbot(
                label="",
                height=430,
                bubble_full_width=False,
                show_label=False,
                placeholder="Upload a PDF and start asking questions...",
            )
            with gr.Row():
                question = gr.Textbox(
                    label="",
                    placeholder="Ask something about your documents...",
                    lines=2,
                    scale=5,
                    show_label=False,
                )
                with gr.Column(scale=1, min_width=90):
                    send_btn = gr.Button("Send ➤", variant="primary")
                    clear_btn = gr.Button("Clear", variant="secondary")
    gr.HTML("""
    <div class="footer-note">
    Powered by OpenRouter &nbsp;·&nbsp; nvidia/nemotron-nano-12b &nbsp;·&nbsp;
    sentence-transformers &nbsp;·&nbsp; FAISS vector search
    </div>
    """)

    # Events
    # Index the selected files and show the outcome in the two status boxes.
    upload_btn.click(
        upload_and_index,
        inputs=[file_input],
        outputs=[status, file_list],
    )
    # The Send button and pressing Enter in the textbox run the same handler;
    # it clears the question box and updates the chat history.
    send_btn.click(
        chat,
        inputs=[question, chatbot],
        outputs=[question, chatbot],
    )
    question.submit(
        chat,
        inputs=[question, chatbot],
        outputs=[question, chatbot],
    )
    clear_btn.click(clear_chat, outputs=[chatbot])

if __name__ == "__main__":
    # Listen on all interfaces on port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)