import os
import asyncio
import json
import hashlib
import shutil
import threading
from typing import List, Optional, Tuple

import gradio as gr
import numpy as np
import faiss
import requests
from sentence_transformers import SentenceTransformer
import fitz  # PyMuPDF

# ---------------- Config ----------------
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
OPENROUTER_MODEL = "nvidia/nemotron-nano-12b-v2-vl:free"
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
CACHE_DIR = "./cache"
SYSTEM_PROMPT = "You are a helpful assistant."

os.makedirs(CACHE_DIR, exist_ok=True)
embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)

# Module-level state for the currently indexed corpus.
DOCS: List[str] = []
FILENAMES: List[str] = []
EMBEDDINGS: Optional[np.ndarray] = None
FAISS_INDEX = None
CURRENT_CACHE_KEY: str = ""

# ---------------- Periodic cache cleanup ----------------
async def clear_cache_every_5min():
    while True:
        await asyncio.sleep(300)
        try:
            if os.path.exists(CACHE_DIR):
                shutil.rmtree(CACHE_DIR)
            os.makedirs(CACHE_DIR, exist_ok=True)
            print("🧹 Cache cleared.")
        except Exception as e:
            print(f"[Cache cleanup error] {e}")

# Run the cleanup coroutine on its own event loop in a daemon thread; creating
# the task on asyncio.get_event_loop() at import time never executes it,
# because that loop is never started (Gradio runs its own server loop).
threading.Thread(
    target=lambda: asyncio.run(clear_cache_every_5min()),
    daemon=True,
).start()

# ---------------- PDF extraction ----------------
def extract_text_from_pdf(file_bytes: bytes) -> str:
    try:
        # Use a context manager so the document handle is closed after extraction.
        with fitz.open(stream=file_bytes, filetype="pdf") as doc:
            return "\n".join(page.get_text() for page in doc)
    except Exception as e:
        return f"[PDF extraction error] {e}"

# ---------------- Cache + FAISS helpers ----------------
def make_cache_key(files: List[Tuple[str, bytes]]) -> str:
    h = hashlib.sha256()
    for name, b in sorted(files, key=lambda x: x[0]):
        h.update(name.encode())
        h.update(str(len(b)).encode())
        h.update(hashlib.sha256(b).digest())
    return h.hexdigest()

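# Note: files are hashed in name-sorted order, so the key is order-insensitive;
# re-uploading the same PDFs in any order should map to the same cache entry,
# e.g. (illustrative values only):
#   make_cache_key([("a.pdf", b"x"), ("b.pdf", b"y")]) ==
#   make_cache_key([("b.pdf", b"y"), ("a.pdf", b"x")])
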
def cache_save(cache_key: str, embeddings: np.ndarray, filenames: List[str]):
    np.savez_compressed(
        os.path.join(CACHE_DIR, f"{cache_key}.npz"),
        embeddings=embeddings,
        filenames=np.array(filenames),
    )

def cache_load(cache_key: str):
    path = os.path.join(CACHE_DIR, f"{cache_key}.npz")
    if not os.path.exists(path):
        return None
    try:
        data = np.load(path, allow_pickle=True)
        return data["embeddings"], data["filenames"].tolist()
    except Exception:
        return None

def build_faiss(emb: np.ndarray):
    global FAISS_INDEX
    if emb is None or len(emb) == 0:
        FAISS_INDEX = None
        return None
    emb = emb.astype("float32")
    index = faiss.IndexFlatL2(emb.shape[1])
    index.add(emb)
    FAISS_INDEX = index
    return index

def search(query: str, k: int = 3):
    if FAISS_INDEX is None:
        return []
    q_emb = embedder.encode([query], convert_to_numpy=True).astype("float32")
    D, I = FAISS_INDEX.search(q_emb, k)
    return [
        {"index": int(i), "distance": float(d), "text": DOCS[i], "source": FILENAMES[i]}
        for d, i in zip(D[0], I[0]) if i >= 0
    ]

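# A hit list from search() looks roughly like (values illustrative):
#   [{"index": 2, "distance": 0.84, "text": "...", "source": "report.pdf"}, ...]
# IndexFlatL2 reports squared-L2 distances, so smaller values mean closer
# matches; at most k hits come back, and FAISS's -1 padding indices are dropped.
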
# ---------------- OpenRouter API ----------------
def call_openrouter(prompt: str):
    if not OPENROUTER_API_KEY:
        return "[OpenRouter error] Missing OPENROUTER_API_KEY."
    url = "https://openrouter.ai/api/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": OPENROUTER_MODEL,
        "messages": [
            {
                "role": "system",
                "content": SYSTEM_PROMPT + " Always respond in plain text. Avoid markdown.",
            },
            {"role": "user", "content": prompt},
        ],
    }
    try:
        r = requests.post(url, headers=headers, json=payload, timeout=60)
        r.raise_for_status()
        obj = r.json()
        if "choices" in obj and obj["choices"]:
            text = obj["choices"][0]["message"]["content"]
            return text.strip().replace("```", "")
        return "[Unexpected OpenRouter response]"
    except Exception as e:
        return f"[OpenRouter request error] {e}"

# ---------- Helper to read bytes from various Gradio file shapes ----------
def read_file_bytes(f) -> Tuple[str, bytes]:
    """
    Accepts the variety of file objects Gradio may pass:
      - file-like objects with .name and .read()
      - objects with .name and .value (NamedString)
      - tuples like (name, bytes)
      - dicts that may contain 'name' and 'data' or temporary path keys
      - string filesystem paths
    Returns (filename, bytes).
    Raises ValueError for unsupported shapes.
    """
    # tuple (name, bytes)
    if isinstance(f, tuple) and len(f) == 2 and isinstance(f[1], (bytes, bytearray)):
        return f[0], bytes(f[1])
    # dict-like (from some frontends)
    if isinstance(f, dict):
        name = f.get("name") or f.get("filename") or "uploaded"
        # raw bytes / content
        data = f.get("data") or f.get("content") or f.get("value") or f.get("file")
        if isinstance(data, (bytes, bytearray)):
            return name, bytes(data)
        if isinstance(data, str):
            # data could be text content
            try:
                return name, data.encode("utf-8")
            except Exception:
                pass
        # maybe a temp file path
        tmp_path = f.get("tmp_path") or f.get("path") or f.get("file")
        if tmp_path and isinstance(tmp_path, str) and os.path.exists(tmp_path):
            with open(tmp_path, "rb") as fh:
                return os.path.basename(tmp_path), fh.read()
    # file-like object with read()
    if hasattr(f, "name") and hasattr(f, "read"):
        try:
            name = os.path.basename(f.name) if getattr(f, "name", None) else "uploaded"
            return name, f.read()
        except Exception:
            pass
    # NamedString-like: has .name and .value
    if hasattr(f, "name") and hasattr(f, "value"):
        name = os.path.basename(getattr(f, "name") or "uploaded")
        v = getattr(f, "value")
        if isinstance(v, (bytes, bytearray)):
            return name, bytes(v)
        if isinstance(v, str):
            return name, v.encode("utf-8")
    # string path
    if isinstance(f, str) and os.path.exists(f):
        with open(f, "rb") as fh:
            return os.path.basename(f), fh.read()
    raise ValueError(f"Unsupported file object type: {type(f)}")

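# Illustrative calls with hypothetical inputs (not produced by this app itself):
#   read_file_bytes(("doc.pdf", b"%PDF-1.7 ..."))  ->  ("doc.pdf", b"%PDF-1.7 ...")
#   read_file_bytes("/tmp/gradio/doc.pdf")         ->  ("doc.pdf", <bytes of that file>)
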
# ---------------- PDF Upload & Index (fixed) ----------------
def upload_and_index(files):
    global DOCS, FILENAMES, EMBEDDINGS, CURRENT_CACHE_KEY
    if not files:
        return "No PDF uploaded.", ""
    processed = []
    # files may be a single object or a list; normalize
    if not isinstance(files, (list, tuple)):
        files = [files]
    try:
        for f in files:
            name, b = read_file_bytes(f)
            processed.append((name, b))
    except ValueError as e:
        # return a clear message to the UI so the user can debug what Gradio passed
        return f"Upload error: {e}", ""
    # sort by filename so DOCS / FILENAMES / EMBEDDINGS stay aligned with the
    # order-insensitive cache key, regardless of upload order
    processed.sort(key=lambda x: x[0])
    # preview for UI
    preview = [{"name": n, "size": len(b)} for n, b in processed]
    # cache key
    cache_key = make_cache_key(processed)
    CURRENT_CACHE_KEY = cache_key
    cached = cache_load(cache_key)
    if cached:
        EMBEDDINGS, FILENAMES = cached
        EMBEDDINGS = np.array(EMBEDDINGS)
        DOCS = [extract_text_from_pdf(b) for _, b in processed]
        build_faiss(EMBEDDINGS)
        return f"Loaded cached embeddings ({len(FILENAMES)} PDFs).", json.dumps(preview)
    # extract text and index
    DOCS = [extract_text_from_pdf(b) for _, b in processed]
    FILENAMES = [n for n, _ in processed]
    EMBEDDINGS = embedder.encode(DOCS, convert_to_numpy=True).astype("float32")
    cache_save(cache_key, EMBEDDINGS, FILENAMES)
    build_faiss(EMBEDDINGS)
    return f"Uploaded + indexed {len(DOCS)} PDFs.", json.dumps(preview)

# ---------------- Question Answering ----------------
def ask(question: str):
    if not question:
        return "Please enter a question."
    if not DOCS:
        return "No PDFs indexed."
    results = search(question)
    if not results:
        return "No relevant text found."
    context = "\n".join(
        f"Source: {r['source']}\n\n{r['text'][:15000]}\n---\n"
        for r in results
    )
    prompt = f"Use this context to answer briefly:\n\n{context}\nQuestion: {question}\nAnswer:"
    return call_openrouter(prompt)

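# The prompt assembled above is plain text along these lines (one block per hit):
#   Use this context to answer briefly:
#
#   Source: report.pdf
#
#   <retrieved text, truncated to 15000 chars>
#   ---
#   Question: <user question>
#   Answer:
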
# ---------------- Gradio UI ----------------
with gr.Blocks(title="PDF RAG Bot") as demo:
    gr.Markdown("# 📄 PDF-Only RAG Bot\nUpload PDFs → Ask Questions → AI Answers from PDF content.")
    file_input = gr.File(label="Upload PDF files", file_count="multiple", file_types=[".pdf"])
    upload_btn = gr.Button("Upload & Index")
    status = gr.Textbox(label="Status", interactive=False)
    preview = gr.Textbox(label="Upload preview (JSON)", interactive=False)
    upload_btn.click(upload_and_index, inputs=[file_input], outputs=[status, preview])

    gr.Markdown("### Ask a Question")
    q = gr.Textbox(label="Your question", lines=3)
    ask_btn = gr.Button("Ask PDF Bot")
    answer = gr.Textbox(label="Answer", lines=15)
    ask_btn.click(ask, inputs=[q], outputs=[answer])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)