import os
import json
import hashlib
import shutil
from typing import List, Optional, Tuple

import gradio as gr
import numpy as np
import faiss
import requests
from sentence_transformers import SentenceTransformer
import fitz  # PyMuPDF

# ---------------- Config ----------------
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
OPENROUTER_MODEL = "nvidia/nemotron-nano-12b-v2-vl:free"
EMBEDDING_MODEL_NAME = "paraphrase-MiniLM-L3-v2"
CACHE_DIR = "./cache"
CHUNK_SIZE = 300  # words per chunk
CHUNK_OVERLAP = 50  # overlapping words between chunks
TOP_K = 4  # number of chunks to retrieve

SYSTEM_PROMPT = (
    "You are an expert document assistant. "
    "Answer questions using ONLY the provided context from the uploaded PDFs. "
    "Be concise, accurate, and cite which document your answer comes from. "
    "Always respond in plain text. Avoid markdown formatting."
)

os.makedirs(CACHE_DIR, exist_ok=True)

# Lazy loaded to avoid OOM on HF Spaces
embedder: Optional[SentenceTransformer] = None


def get_embedder() -> SentenceTransformer:
    """Return the shared SentenceTransformer, loading it on first call."""
    global embedder
    if embedder is None:
        print("Loading embedder model...")
        embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)
        print("Embedder loaded.")
    return embedder


# Global state: parallel per-chunk lists (same index = same chunk), the
# matching embedding matrix, the FAISS index built from it, and per-file
# metadata shown in the UI.
CHUNKS: List[str] = []
CHUNK_SOURCES: List[str] = []
CHUNK_PAGES: List[int] = []
EMBEDDINGS: Optional[np.ndarray] = None  # None until something is indexed
FAISS_INDEX = None
INDEXED_FILES: List[dict] = []


# ---------------- Cache cleanup ----------------
def clear_old_cache() -> None:
    """Delete everything under CACHE_DIR and leave an empty CACHE_DIR behind.

    Best-effort: errors are printed, never raised, so a cache problem cannot
    block indexing. The directory is recreated even if the removal step
    failed, so later cache writes still have somewhere to go.
    """
    try:
        if os.path.exists(CACHE_DIR):
            shutil.rmtree(CACHE_DIR)
    except OSError as e:
        print(f"[Cache cleanup error] {e}")
    try:
        os.makedirs(CACHE_DIR, exist_ok=True)
    except OSError as e:
        print(f"[Cache cleanup error] {e}")


# ---------------- PDF extraction with page tracking ----------------
def extract_pages_from_pdf(file_bytes: bytes) -> List[Tuple[int, str]]:
    """Extract the text of each non-empty page of an in-memory PDF.

    Args:
        file_bytes: Raw PDF file content.

    Returns:
        A list of ``(page_number, page_text)`` with 1-based page numbers and
        whitespace-stripped text; pages with no text are omitted. If PyMuPDF
        fails at any point, returns a single ``(0, "[PDF extraction error] ...")``
        entry instead. NOTE(review): the caller feeds that entry through the
        normal chunking path, so the error text may end up indexed.
    """
    pages: List[Tuple[int, str]] = []
    try:
        # Context manager closes the document even when a page fails to parse.
        with fitz.open(stream=file_bytes, filetype="pdf") as doc:
            for page_number, page in enumerate(doc, start=1):
                text = page.get_text().strip()
                if text:
                    pages.append((page_number, text))
    except Exception as e:  # malformed PDFs can raise many exception types
        return [(0, f"[PDF extraction error] {e}")]
    return pages
strategy ---------------- def chunk_text(text: str, source: str, page: int, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> List[Tuple[str, str, int]]: """ Splits text into overlapping word-level chunks. Returns list of (chunk_text, source, page) """ words = text.split() chunks = [] step = chunk_size - overlap for i in range(0, len(words), step): chunk = " ".join(words[i: i + chunk_size]) if len(chunk.strip()) > 50: chunks.append((chunk, source, page)) if i + chunk_size >= len(words): break return chunks # ---------------- Cache helpers ---------------- def make_cache_key(files: List[Tuple[str, bytes]]) -> str: h = hashlib.sha256() for name, b in sorted(files, key=lambda x: x[0]): h.update(name.encode()) h.update(hashlib.sha256(b).digest()) return h.hexdigest() def cache_save(cache_key: str, embeddings: np.ndarray, chunks: List[str], sources: List[str], pages: List[int]): np.savez_compressed( os.path.join(CACHE_DIR, f"{cache_key}.npz"), embeddings=embeddings, chunks=np.array(chunks), sources=np.array(sources), pages=np.array(pages), ) def cache_load(cache_key: str): path = os.path.join(CACHE_DIR, f"{cache_key}.npz") if not os.path.exists(path): return None try: data = np.load(path, allow_pickle=True) return ( data["embeddings"], data["chunks"].tolist(), data["sources"].tolist(), data["pages"].tolist(), ) except: return None # ---------------- FAISS ---------------- def build_faiss(emb: np.ndarray): global FAISS_INDEX if emb is None or len(emb) == 0: FAISS_INDEX = None return emb = emb.astype("float32") index = faiss.IndexFlatL2(emb.shape[1]) index.add(emb) FAISS_INDEX = index def search(query: str, k: int = TOP_K): if FAISS_INDEX is None or not CHUNKS: return [] q_emb = get_embedder().encode([query], convert_to_numpy=True).astype("float32") D, I = FAISS_INDEX.search(q_emb, k) results = [] for d, i in zip(D[0], I[0]): if i >= 0 and i < len(CHUNKS): results.append({ "text": CHUNKS[i], "source": CHUNK_SOURCES[i], "page": CHUNK_PAGES[i], "distance": 
float(d), }) return results # ---------------- OpenRouter API ---------------- def call_openrouter(messages: list) -> str: if not OPENROUTER_API_KEY: return "Error: OPENROUTER_API_KEY is not set. Please add it in HF Space secrets." url = "https://openrouter.ai/api/v1/chat/completions" headers = { "Authorization": f"Bearer {OPENROUTER_API_KEY}", "Content-Type": "application/json", } payload = { "model": OPENROUTER_MODEL, "messages": [{"role": "system", "content": SYSTEM_PROMPT}] + messages, } try: r = requests.post(url, headers=headers, json=payload, timeout=60) r.raise_for_status() obj = r.json() if "choices" in obj and obj["choices"]: return obj["choices"][0]["message"]["content"].strip().replace("```", "") return "[Unexpected response from API]" except Exception as e: return f"[OpenRouter error] {e}" # ---------------- File bytes reader ---------------- def read_file_bytes(f) -> Tuple[str, bytes]: if isinstance(f, tuple) and len(f) == 2 and isinstance(f[1], (bytes, bytearray)): return f[0], bytes(f[1]) if isinstance(f, dict): name = f.get("name") or f.get("filename") or "uploaded" data = f.get("data") or f.get("content") or f.get("value") or f.get("file") if isinstance(data, (bytes, bytearray)): return name, bytes(data) if isinstance(data, str): try: return name, data.encode("utf-8") except Exception: pass tmp_path = f.get("tmp_path") or f.get("path") or f.get("file") if tmp_path and isinstance(tmp_path, str) and os.path.exists(tmp_path): with open(tmp_path, "rb") as fh: return os.path.basename(tmp_path), fh.read() if hasattr(f, "name") and hasattr(f, "read"): try: name = os.path.basename(f.name) if getattr(f, "name", None) else "uploaded" return name, f.read() except Exception: pass if hasattr(f, "name") and hasattr(f, "value"): name = os.path.basename(getattr(f, "name") or "uploaded") v = getattr(f, "value") if isinstance(v, (bytes, bytearray)): return name, bytes(v) if isinstance(v, str): return name, v.encode("utf-8") if isinstance(f, str) and 
os.path.exists(f): with open(f, "rb") as fh: return os.path.basename(f), fh.read() raise ValueError(f"Unsupported file object type: {type(f)}") # ---------------- Upload & Index ---------------- def upload_and_index(files): global CHUNKS, CHUNK_SOURCES, CHUNK_PAGES, EMBEDDINGS, INDEXED_FILES if not files: return "No files uploaded.", "No files indexed yet." clear_old_cache() processed = [] if not isinstance(files, (list, tuple)): files = [files] try: for f in files: name, b = read_file_bytes(f) processed.append((name, b)) except ValueError as e: return f"Upload error: {e}", "No files indexed yet." cache_key = make_cache_key(processed) cached = cache_load(cache_key) if cached: EMBEDDINGS, CHUNKS, CHUNK_SOURCES, CHUNK_PAGES = cached EMBEDDINGS = np.array(EMBEDDINGS) build_faiss(EMBEDDINGS) INDEXED_FILES = [{"name": n, "size_kb": round(len(b)/1024, 1)} for n, b in processed] return ( f"Loaded from cache โ {len(CHUNKS)} chunks across {len(processed)} PDF(s).", _render_file_list(INDEXED_FILES) ) all_chunks, all_sources, all_pages = [], [], [] INDEXED_FILES = [] for name, b in processed: pages = extract_pages_from_pdf(b) file_chunks = 0 for page_num, page_text in pages: for chunk, src, pg in chunk_text(page_text, name, page_num): all_chunks.append(chunk) all_sources.append(src) all_pages.append(pg) file_chunks += 1 INDEXED_FILES.append({ "name": name, "size_kb": round(len(b) / 1024, 1), "pages": len(pages), "chunks": file_chunks, }) CHUNKS = all_chunks CHUNK_SOURCES = all_sources CHUNK_PAGES = all_pages if not CHUNKS: return "Could not extract any text from the PDFs.", "No files indexed." EMBEDDINGS = get_embedder().encode(CHUNKS, convert_to_numpy=True).astype("float32") cache_save(cache_key, EMBEDDINGS, CHUNKS, CHUNK_SOURCES, CHUNK_PAGES) build_faiss(EMBEDDINGS) return ( f"Indexed {len(processed)} PDF(s) โ {len(CHUNKS)} chunks ready.", _render_file_list(INDEXED_FILES) ) def _render_file_list(files: List[dict]) -> str: if not files: return "No files indexed yet." 
lines = [] for f in files: parts = [f"๐ {f['name']} ({f['size_kb']} KB)"] if "pages" in f: parts.append(f"{f['pages']} pages") if "chunks" in f: parts.append(f"{f['chunks']} chunks") lines.append(" | ".join(parts)) return "\n".join(lines) # ---------------- Chat ---------------- def chat(message: str, history: list): if not message.strip(): return "", history if not CHUNKS: history.append((message, "No PDFs indexed yet. Please upload a PDF first.")) return "", history results = search(message) if not results: history.append((message, "No relevant content found in the uploaded PDFs.")) return "", history context_parts = [] sources_used = [] for r in results: context_parts.append(f"[From: {r['source']}, Page {r['page']}]\n{r['text']}") source_ref = f"{r['source']} (p.{r['page']})" if source_ref not in sources_used: sources_used.append(source_ref) context = "\n\n---\n\n".join(context_parts) # Multi-turn: include last 4 exchanges messages = [] for user_msg, bot_msg in history[-4:]: messages.append({"role": "user", "content": user_msg}) messages.append({"role": "assistant", "content": bot_msg}) messages.append({ "role": "user", "content": f"Context from PDFs:\n\n{context}\n\nQuestion: {message}" }) answer = call_openrouter(messages) if sources_used: answer += f"\n\nSources: {', '.join(sources_used)}" history.append((message, answer)) return "", history def clear_chat(): return [] # ---------------- Custom CSS ---------------- custom_css = """ @import url('https://fonts.googleapis.com/css2?family=Syne:wght@400;600;700;800&family=DM+Mono:wght@300;400;500&display=swap'); :root { --bg: #0d0f12; --surface: #13161b; --surface2: #1a1e26; --border: #252a35; --accent: #4fffb0; --accent2: #00c2ff; --text: #e8eaf0; --muted: #6b7280; } body, .gradio-container { background: var(--bg) !important; font-family: 'DM Mono', monospace !important; color: var(--text) !important; } .gradio-container { max-width: 1100px !important; margin: 0 auto !important; } .app-header { text-align: 
center; padding: 36px 0 28px; border-bottom: 1px solid var(--border); margin-bottom: 28px; } .app-header h1 { font-family: 'Syne', sans-serif; font-size: 2.4rem; font-weight: 800; background: linear-gradient(135deg, var(--accent), var(--accent2)); -webkit-background-clip: text; -webkit-text-fill-color: transparent; background-clip: text; margin: 0 0 6px; letter-spacing: -1px; } .app-header p { color: var(--muted); font-size: 0.85rem; margin: 0; font-family: 'DM Mono', monospace; } .section-label { font-family: 'Syne', sans-serif; font-size: 0.7rem; font-weight: 700; letter-spacing: 2.5px; text-transform: uppercase; color: var(--accent); margin-bottom: 10px; } textarea, input[type="text"] { background: var(--surface2) !important; border: 1px solid var(--border) !important; border-radius: 8px !important; color: var(--text) !important; font-family: 'DM Mono', monospace !important; font-size: 0.87rem !important; } textarea:focus, input[type="text"]:focus { border-color: var(--accent) !important; box-shadow: 0 0 0 2px rgba(79,255,176,0.08) !important; } .footer-note { text-align: center; margin-top: 28px; color: #2d3340; font-size: 0.72rem; font-family: 'DM Mono', monospace; letter-spacing: 0.5px; } """ # ---------------- Gradio UI ---------------- with gr.Blocks( title="PDF RAG Bot", css=custom_css, theme=gr.themes.Base( primary_hue="emerald", neutral_hue="slate", ) ) as demo: gr.HTML("""
Upload PDFs ยท Semantic chunking ยท Ask anything ยท AI answers with page sources