Spaces:
Running
Running
| import os | |
| import json | |
| import hashlib | |
| import shutil | |
| from typing import List, Tuple | |
| import gradio as gr | |
| import numpy as np | |
| import faiss | |
| import requests | |
| from sentence_transformers import SentenceTransformer | |
| import fitz # PyMuPDF | |
# ---------------- Config ----------------
# API credentials and model selection (key is read from the environment).
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY")
OPENROUTER_MODEL = "nvidia/nemotron-nano-12b-v2-vl:free"
EMBEDDING_MODEL_NAME = "paraphrase-MiniLM-L3-v2"

# Directory holding cached embeddings (.npz files).
CACHE_DIR = "./cache"

# Retrieval tuning knobs.
CHUNK_SIZE = 300  # words per chunk
CHUNK_OVERLAP = 50  # overlapping words between chunks
TOP_K = 4  # number of chunks to retrieve

# System message prepended to every OpenRouter request.
SYSTEM_PROMPT = " ".join(
    (
        "You are an expert document assistant.",
        "Answer questions using ONLY the provided context from the uploaded PDFs.",
        "Be concise, accurate, and cite which document your answer comes from.",
        "Always respond in plain text. Avoid markdown formatting.",
    )
)

# Make sure the cache directory exists as soon as the module is imported.
os.makedirs(CACHE_DIR, exist_ok=True)
# Lazy loaded to avoid OOM on HF Spaces
embedder = None


def get_embedder():
    """Return the shared SentenceTransformer, creating it on first use.

    The model named by EMBEDDING_MODEL_NAME is instantiated once and stored
    in the module-level ``embedder``; later calls return that same object.
    """
    global embedder
    if embedder is not None:
        return embedder
    print("Loading embedder model...")
    embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)
    print("Embedder loaded.")
    return embedder
# Global state
# Module-level index state: rebound by upload_and_index (FAISS_INDEX by
# build_faiss) and read by search and chat. The three CHUNK* lists are
# parallel: entry i of each describes the same chunk.
CHUNKS: List[str] = []  # chunk texts
CHUNK_SOURCES: List[str] = []  # source file name of CHUNKS[i]
CHUNK_PAGES: List[int] = []  # page number of CHUNKS[i]
EMBEDDINGS: np.ndarray = None  # embeddings of CHUNKS; None until something is indexed
FAISS_INDEX = None  # index built by build_faiss; None when nothing is indexed
INDEXED_FILES: List[dict] = []  # per-file summary dicts rendered by _render_file_list
# ---------------- Cache cleanup ----------------
def clear_old_cache():
    """Delete the cache directory (if present) and recreate it empty.

    Best-effort: any exception is printed and otherwise ignored.
    """
    try:
        cache_present = os.path.exists(CACHE_DIR)
        if cache_present:
            shutil.rmtree(CACHE_DIR)
        os.makedirs(CACHE_DIR, exist_ok=True)
    except Exception as err:
        print(f"[Cache cleanup error] {err}")
# ---------------- PDF extraction with page tracking ----------------
def extract_pages_from_pdf(file_bytes: bytes) -> List[Tuple[int, str]]:
    """Extract the text of every non-empty page of a PDF.

    Args:
        file_bytes: Raw bytes of the PDF file.

    Returns:
        A list of ``(page_number, page_text)`` tuples with 1-based page
        numbers; pages whose stripped text is empty are omitted. If the
        document cannot be opened or read, a single placeholder entry
        ``(0, "[PDF extraction error] ...")`` is returned instead of raising.
    """
    try:
        # The context manager closes the document even if a page fails to
        # parse, so the underlying PyMuPDF resources are not leaked.
        with fitz.open(stream=file_bytes, filetype="pdf") as doc:
            pages = []
            for number, page in enumerate(doc, start=1):
                text = page.get_text().strip()
                if text:
                    pages.append((number, text))
            return pages
    except Exception as e:
        return [(0, f"[PDF extraction error] {e}")]
# ---------------- Chunking strategy ----------------
def chunk_text(text: str, source: str, page: int,
               chunk_size: int = CHUNK_SIZE,
               overlap: int = CHUNK_OVERLAP) -> List[Tuple[str, str, int]]:
    """Split text into overlapping word-level chunks.

    Args:
        text: Text to split (whitespace-separated words).
        source: Label stored with every chunk (e.g. the file name).
        page: Page number stored with every chunk.
        chunk_size: Maximum number of words per chunk.
        overlap: Number of words shared between consecutive chunks.

    Returns:
        A list of ``(chunk_text, source, page)`` tuples. Chunks whose text is
        50 characters or shorter are dropped.
    """
    words = text.split()
    if not words or chunk_size <= 0:
        return []
    # A negative overlap would skip words between chunks; treat it as zero.
    overlap = max(0, overlap)
    # Guard against overlap >= chunk_size: a zero step makes range() raise and
    # a negative step would silently yield no chunks at all.
    step = max(1, chunk_size - overlap)
    chunks = []
    for start in range(0, len(words), step):
        chunk = " ".join(words[start:start + chunk_size])
        if len(chunk.strip()) > 50:
            chunks.append((chunk, source, page))
        # This window already reached the last word; a further window would
        # only repeat the tail.
        if start + chunk_size >= len(words):
            break
    return chunks
# ---------------- Cache helpers ----------------
def make_cache_key(files: List[Tuple[str, bytes]]) -> str:
    """Build a SHA-256 hex key identifying a set of uploaded files.

    Files are processed in name order; each contributes its encoded name
    followed by the SHA-256 digest of its content.
    """
    ordered = sorted(files, key=lambda entry: entry[0])
    hasher = hashlib.sha256()
    for filename, payload in ordered:
        hasher.update(filename.encode() + hashlib.sha256(payload).digest())
    return hasher.hexdigest()
def cache_save(cache_key: str, embeddings: np.ndarray,
               chunks: List[str], sources: List[str], pages: List[int]):
    """Write embeddings and chunk metadata to ``<CACHE_DIR>/<cache_key>.npz``.

    The lists are converted to NumPy arrays and stored, compressed, under the
    keys ``embeddings``, ``chunks``, ``sources`` and ``pages``.
    """
    target = os.path.join(CACHE_DIR, f"{cache_key}.npz")
    payload = {
        "embeddings": embeddings,
        "chunks": np.array(chunks),
        "sources": np.array(sources),
        "pages": np.array(pages),
    }
    np.savez_compressed(target, **payload)
def cache_load(cache_key: str):
    """Load a cached index written by cache_save.

    Args:
        cache_key: Key identifying the cache entry (file ``<key>.npz`` inside
            CACHE_DIR).

    Returns:
        ``(embeddings, chunks, sources, pages)`` where ``embeddings`` is a
        NumPy array and the other three are plain lists, or ``None`` when the
        entry does not exist or cannot be read.
    """
    path = os.path.join(CACHE_DIR, f"{cache_key}.npz")
    if not os.path.exists(path):
        return None
    try:
        # cache_save stores only numeric and string arrays, so pickle support
        # is not needed; leaving it off avoids executing pickled payloads from
        # a tampered cache file. The context manager closes the archive handle.
        with np.load(path, allow_pickle=False) as data:
            return (
                data["embeddings"],
                data["chunks"].tolist(),
                data["sources"].tolist(),
                data["pages"].tolist(),
            )
    except Exception as e:
        # A corrupt or incompatible cache entry is treated as a miss so the
        # caller re-indexes instead of failing.
        print(f"[Cache load error] {e}")
        return None
# ---------------- FAISS ----------------
def build_faiss(emb: np.ndarray):
    """Rebuild the module-level FAISS index from an embedding matrix.

    When ``emb`` is ``None`` or has no rows, FAISS_INDEX is reset to ``None``.
    Otherwise the rows are cast to float32 and added to a fresh
    ``IndexFlatL2`` whose dimension is ``emb.shape[1]``.
    """
    global FAISS_INDEX
    has_vectors = emb is not None and len(emb) > 0
    if not has_vectors:
        FAISS_INDEX = None
        return
    vectors = emb.astype("float32")
    flat_index = faiss.IndexFlatL2(vectors.shape[1])
    flat_index.add(vectors)
    FAISS_INDEX = flat_index
def search(query: str, k: int = TOP_K):
    """Return up to ``k`` indexed chunks nearest to ``query``.

    The query is embedded with the shared embedder and looked up in
    FAISS_INDEX. Each hit is a dict with the keys ``text``, ``source``,
    ``page`` and ``distance``. Returns an empty list when nothing is indexed.
    """
    if FAISS_INDEX is None or not CHUNKS:
        return []
    query_vec = get_embedder().encode([query], convert_to_numpy=True).astype("float32")
    distances, indices = FAISS_INDEX.search(query_vec, k)
    # Indices outside the chunk list (e.g. negative ones) are filtered out.
    return [
        {
            "text": CHUNKS[idx],
            "source": CHUNK_SOURCES[idx],
            "page": CHUNK_PAGES[idx],
            "distance": float(dist),
        }
        for dist, idx in zip(distances[0], indices[0])
        if 0 <= idx < len(CHUNKS)
    ]
# ---------------- OpenRouter API ----------------
def call_openrouter(messages: list) -> str:
    """Send a chat completion request to OpenRouter and return the reply text.

    Args:
        messages: Chat messages (dicts with ``role`` and ``content``); the
            system prompt is prepended automatically.

    Returns:
        The assistant reply with surrounding whitespace and triple backticks
        removed. This function never raises: a missing API key, transport or
        HTTP failures, API-reported errors and unusable responses are all
        returned as a descriptive message string.
    """
    if not OPENROUTER_API_KEY:
        return "Error: OPENROUTER_API_KEY is not set. Please add it in HF Space secrets."
    url = "https://openrouter.ai/api/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": OPENROUTER_MODEL,
        "messages": [{"role": "system", "content": SYSTEM_PROMPT}] + messages,
    }
    try:
        r = requests.post(url, headers=headers, json=payload, timeout=60)
        r.raise_for_status()
        obj = r.json()
        if not isinstance(obj, dict):
            return "[Unexpected response from API]"
        # The API can report a failure inside an otherwise successful HTTP
        # response; surface that message instead of a generic one.
        api_error = obj.get("error")
        if api_error:
            detail = api_error.get("message") if isinstance(api_error, dict) else api_error
            return f"[OpenRouter error] {detail}"
        choices = obj.get("choices")
        if choices:
            content = (choices[0].get("message") or {}).get("content")
            # content may be null (no text produced); only strings are usable.
            if isinstance(content, str):
                return content.strip().replace("```", "")
        return "[Unexpected response from API]"
    except Exception as e:
        return f"[OpenRouter error] {e}"
# ---------------- File bytes reader ----------------
def read_file_bytes(f) -> Tuple[str, bytes]:
    """Normalise an uploaded-file object into ``(file_name, content_bytes)``.

    Supported inputs, tried in this order:
      * a ``(name, bytes)`` tuple;
      * a dict carrying the content (``data``/``content``/``value``/``file``)
        or a path (``tmp_path``/``path``/``file``);
      * a file-like object exposing ``name`` and ``read``;
      * an object exposing ``name`` and ``value``;
      * a filesystem path (``str`` or ``os.PathLike``).

    Raises:
        ValueError: If ``f`` matches none of the supported shapes.
    """
    if isinstance(f, tuple) and len(f) == 2 and isinstance(f[1], (bytes, bytearray)):
        return f[0], bytes(f[1])

    if isinstance(f, dict):
        name = f.get("name") or f.get("filename") or "uploaded"
        data = f.get("data") or f.get("content") or f.get("value") or f.get("file")
        if isinstance(data, (bytes, bytearray)):
            return name, bytes(data)
        if isinstance(data, str):
            # A string that names an existing file (typically the "file" key)
            # is a path: read the file rather than returning the path text.
            if os.path.exists(data):
                with open(data, "rb") as fh:
                    return os.path.basename(data), fh.read()
            return name, data.encode("utf-8")
        tmp_path = f.get("tmp_path") or f.get("path") or f.get("file")
        if tmp_path and isinstance(tmp_path, str) and os.path.exists(tmp_path):
            with open(tmp_path, "rb") as fh:
                return os.path.basename(tmp_path), fh.read()

    if hasattr(f, "name") and hasattr(f, "read"):
        try:
            name = os.path.basename(f.name) if getattr(f, "name", None) else "uploaded"
            content = f.read()
            # Text-mode handles return str; the contract promises bytes.
            if isinstance(content, str):
                content = content.encode("utf-8")
            return name, content
        except Exception:
            # Deliberate best-effort: fall through to the remaining strategies.
            pass

    if hasattr(f, "name") and hasattr(f, "value"):
        name = os.path.basename(getattr(f, "name") or "uploaded")
        v = getattr(f, "value")
        if isinstance(v, (bytes, bytearray)):
            return name, bytes(v)
        if isinstance(v, str):
            return name, v.encode("utf-8")

    if isinstance(f, (str, os.PathLike)) and os.path.exists(f):
        with open(f, "rb") as fh:
            return os.path.basename(f), fh.read()

    raise ValueError(f"Unsupported file object type: {type(f)}")
# ---------------- Upload & Index ----------------
def upload_and_index(files):
    """Read uploaded PDFs, chunk and embed them, and rebuild the search index.

    Args:
        files: A single upload object or a list/tuple of them, in any shape
            accepted by read_file_bytes.

    Returns:
        ``(status_message, file_list_text)`` for the two status text boxes.

    Side effects:
        Rebinds the module-level CHUNKS, CHUNK_SOURCES, CHUNK_PAGES,
        EMBEDDINGS and INDEXED_FILES, rebuilds FAISS_INDEX, and refreshes the
        on-disk embedding cache.
    """
    global CHUNKS, CHUNK_SOURCES, CHUNK_PAGES, EMBEDDINGS, INDEXED_FILES
    if not files:
        return "No files uploaded.", "No files indexed yet."
    if not isinstance(files, (list, tuple)):
        files = [files]
    try:
        processed = [read_file_bytes(f) for f in files]
    except (ValueError, OSError) as e:
        return f"Upload error: {e}", "No files indexed yet."

    # Look in the cache BEFORE clearing it; wiping first would make every
    # lookup miss and the cached branch below unreachable.
    cache_key = make_cache_key(processed)
    cached = cache_load(cache_key)
    if cached:
        EMBEDDINGS, CHUNKS, CHUNK_SOURCES, CHUNK_PAGES = cached
        EMBEDDINGS = np.array(EMBEDDINGS)
        build_faiss(EMBEDDINGS)
        INDEXED_FILES = [{"name": n, "size_kb": round(len(b) / 1024, 1)} for n, b in processed]
        return (
            f"Loaded from cache β {len(CHUNKS)} chunks across {len(processed)} PDF(s).",
            _render_file_list(INDEXED_FILES)
        )

    # Cache miss: discard entries for previous uploads before writing a new one.
    clear_old_cache()

    all_chunks, all_sources, all_pages = [], [], []
    file_infos = []
    unreadable = []
    for name, b in processed:
        extracted = extract_pages_from_pdf(b)
        # Page number 0 marks an extraction-error placeholder; it must not be
        # indexed as if it were document text.
        pages = [(num, text) for num, text in extracted if num > 0]
        if len(pages) != len(extracted):
            unreadable.append(name)
        file_chunks = 0
        for page_num, page_text in pages:
            for chunk, src, pg in chunk_text(page_text, name, page_num):
                all_chunks.append(chunk)
                all_sources.append(src)
                all_pages.append(pg)
                file_chunks += 1
        file_infos.append({
            "name": name,
            "size_kb": round(len(b) / 1024, 1),
            "pages": len(pages),
            "chunks": file_chunks,
        })

    if not all_chunks:
        # Nothing usable: reset every piece of state so no stale index remains.
        CHUNKS, CHUNK_SOURCES, CHUNK_PAGES = [], [], []
        EMBEDDINGS = None
        INDEXED_FILES = []
        build_faiss(None)
        return "Could not extract any text from the PDFs.", "No files indexed."

    # Embed first, then commit all globals together, so a failure while
    # embedding cannot leave chunk lists and index out of sync.
    embeddings = get_embedder().encode(all_chunks, convert_to_numpy=True).astype("float32")
    CHUNKS, CHUNK_SOURCES, CHUNK_PAGES = all_chunks, all_sources, all_pages
    EMBEDDINGS = embeddings
    INDEXED_FILES = file_infos
    build_faiss(EMBEDDINGS)
    try:
        cache_save(cache_key, EMBEDDINGS, CHUNKS, CHUNK_SOURCES, CHUNK_PAGES)
    except Exception as e:
        # The cache is an optimisation only; indexing already succeeded.
        print(f"[Cache save error] {e}")

    status = f"Indexed {len(processed)} PDF(s) β {len(CHUNKS)} chunks ready."
    if unreadable:
        status += f" Could not read: {', '.join(unreadable)}."
    return status, _render_file_list(INDEXED_FILES)
def _render_file_list(files: List[dict]) -> str:
    """Format per-file summaries as one line of text per file.

    Each line shows the name and size in KB, followed by the page and chunk
    counts when those keys are present, joined with `` | ``. Returns a
    placeholder message for an empty list.
    """
    if not files:
        return "No files indexed yet."

    def describe(entry: dict) -> str:
        fields = [f"π {entry['name']} ({entry['size_kb']} KB)"]
        fields.extend(
            f"{entry[key]} {key}" for key in ("pages", "chunks") if key in entry
        )
        return " | ".join(fields)

    return "\n".join(describe(entry) for entry in files)
# ---------------- Chat ----------------
def chat(message: str, history: list):
    """Answer a question from the indexed PDFs and append the turn to history.

    Args:
        message: The user's question.
        history: Chat history as ``(user_message, bot_message)`` pairs;
            ``None`` is treated as an empty history.

    Returns:
        ``("", history)``: an empty string to clear the input box and the
        history list (extended in place with the new turn, if any).
    """
    # The UI may hand over None for an empty chat or an empty input box.
    if history is None:
        history = []
    if not message or not message.strip():
        return "", history
    if not CHUNKS:
        history.append((message, "No PDFs indexed yet. Please upload a PDF first."))
        return "", history
    results = search(message)
    if not results:
        history.append((message, "No relevant content found in the uploaded PDFs."))
        return "", history

    context_parts = []
    sources_used = []
    for r in results:
        context_parts.append(f"[From: {r['source']}, Page {r['page']}]\n{r['text']}")
        source_ref = f"{r['source']} (p.{r['page']})"
        if source_ref not in sources_used:
            sources_used.append(source_ref)
    context = "\n\n---\n\n".join(context_parts)

    # Multi-turn: include last 4 exchanges
    messages = []
    for user_msg, bot_msg in history[-4:]:
        # Skip empty sides of a turn: a null content would be rejected upstream.
        if user_msg is not None:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg is not None:
            messages.append({"role": "assistant", "content": bot_msg})
    messages.append({
        "role": "user",
        "content": f"Context from PDFs:\n\n{context}\n\nQuestion: {message}"
    })

    answer = call_openrouter(messages)
    if sources_used:
        answer += f"\n\nSources: {', '.join(sources_used)}"
    history.append((message, answer))
    return "", history
def clear_chat():
    """Return a fresh empty history, used to reset the chatbot component."""
    return list()
# ---------------- Custom CSS ----------------
# Stylesheet passed to gr.Blocks(css=...): defines the colour variables, the
# page fonts, and the .app-header, .section-label and .footer-note classes
# used by the gr.HTML snippets in the UI, plus text-input styling.
custom_css = """
@import url('https://fonts.googleapis.com/css2?family=Syne:wght@400;600;700;800&family=DM+Mono:wght@300;400;500&display=swap');
:root {
--bg: #0d0f12;
--surface: #13161b;
--surface2: #1a1e26;
--border: #252a35;
--accent: #4fffb0;
--accent2: #00c2ff;
--text: #e8eaf0;
--muted: #6b7280;
}
body, .gradio-container {
background: var(--bg) !important;
font-family: 'DM Mono', monospace !important;
color: var(--text) !important;
}
.gradio-container {
max-width: 1100px !important;
margin: 0 auto !important;
}
.app-header {
text-align: center;
padding: 36px 0 28px;
border-bottom: 1px solid var(--border);
margin-bottom: 28px;
}
.app-header h1 {
font-family: 'Syne', sans-serif;
font-size: 2.4rem;
font-weight: 800;
background: linear-gradient(135deg, var(--accent), var(--accent2));
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
margin: 0 0 6px;
letter-spacing: -1px;
}
.app-header p {
color: var(--muted);
font-size: 0.85rem;
margin: 0;
font-family: 'DM Mono', monospace;
}
.section-label {
font-family: 'Syne', sans-serif;
font-size: 0.7rem;
font-weight: 700;
letter-spacing: 2.5px;
text-transform: uppercase;
color: var(--accent);
margin-bottom: 10px;
}
textarea, input[type="text"] {
background: var(--surface2) !important;
border: 1px solid var(--border) !important;
border-radius: 8px !important;
color: var(--text) !important;
font-family: 'DM Mono', monospace !important;
font-size: 0.87rem !important;
}
textarea:focus, input[type="text"]:focus {
border-color: var(--accent) !important;
box-shadow: 0 0 0 2px rgba(79,255,176,0.08) !important;
}
.footer-note {
text-align: center;
margin-top: 28px;
color: #2d3340;
font-size: 0.72rem;
font-family: 'DM Mono', monospace;
letter-spacing: 0.5px;
}
"""
# ---------------- Gradio UI ----------------
# Layout: header banner, then one row with two columns (upload/index panel on
# the left, chat panel on the right), a footer note, and the event wiring.
# NOTE(review): the nesting below was reconstructed from the statement order;
# confirm it against the original layout.
# NOTE(review): several UI strings contain what look like mis-encoded glyphs
# (e.g. in the header, button labels and section labels). They are left
# untouched here; confirm the intended characters against the original file.
with gr.Blocks(
    title="PDF RAG Bot",
    css=custom_css,
    theme=gr.themes.Base(
        primary_hue="emerald",
        neutral_hue="slate",
    )
) as demo:
    gr.HTML("""
<div class="app-header">
<h1>β‘ PDF RAG Bot</h1>
<p>Upload PDFs Β· Semantic chunking Β· Ask anything Β· AI answers with page sources</p>
</div>
""")
    with gr.Row(equal_height=False):
        # -- Left: Upload panel --
        with gr.Column(scale=1, min_width=280):
            gr.HTML('<div class="section-label">π Document Upload</div>')
            file_input = gr.File(
                label="Drop PDF files here",
                file_count="multiple",
                file_types=[".pdf"],
            )
            upload_btn = gr.Button("β‘ Upload & Index", variant="primary", size="lg")
            # Read-only box showing the status message from upload_and_index.
            status = gr.Textbox(
                label="Status",
                interactive=False,
                lines=2,
            )
            # Read-only box showing the per-file summary from upload_and_index.
            file_list = gr.Textbox(
                label="Indexed Files",
                interactive=False,
                lines=6,
                placeholder="No files indexed yet...",
            )
        # -- Right: Chat panel --
        with gr.Column(scale=2):
            gr.HTML('<div class="section-label">π¬ Chat with your PDFs</div>')
            chatbot = gr.Chatbot(
                label="",
                height=430,
                bubble_full_width=False,
                show_label=False,
                placeholder="Upload a PDF and start asking questions...",
            )
            with gr.Row():
                question = gr.Textbox(
                    label="",
                    placeholder="Ask something about your documents...",
                    lines=2,
                    scale=5,
                    show_label=False,
                )
                with gr.Column(scale=1, min_width=90):
                    send_btn = gr.Button("Send β€", variant="primary")
                    clear_btn = gr.Button("Clear", variant="secondary")
    gr.HTML("""
<div class="footer-note">
Powered by OpenRouter Β· nvidia/nemotron-nano-12b Β·
sentence-transformers Β· FAISS vector search
</div>
""")
    # Events
    # Indexing: uploaded files in, status text and file summary out.
    upload_btn.click(
        upload_and_index,
        inputs=[file_input],
        outputs=[status, file_list],
    )
    # Both the Send button and submitting the textbox run chat, which
    # returns the cleared question text and the updated history.
    send_btn.click(
        chat,
        inputs=[question, chatbot],
        outputs=[question, chatbot],
    )
    question.submit(
        chat,
        inputs=[question, chatbot],
        outputs=[question, chatbot],
    )
    # Clear replaces the chatbot history with an empty list.
    clear_btn.click(clear_chat, outputs=[chatbot])
# Script entry point: launch the app on 0.0.0.0:7860 with debug enabled.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)