# PDF Vector Search — Gradio app (ChromaDB vector store + PyTorch sentence-transformers embeddings)
| import os | |
| import re | |
| import hashlib | |
| import threading | |
| from typing import List, Dict, Tuple, Optional | |
| import numpy as np | |
| import torch | |
| import gradio as gr | |
| import chromadb | |
| from pypdf import PdfReader | |
| from sentence_transformers import SentenceTransformer | |
# -----------------------------
# Config
# -----------------------------
# All settings are overridable via environment variables.
DB_DIR = os.environ.get("CHROMA_DB_DIR", "./chroma_db")  # Chroma persistence directory
COLLECTION_NAME = os.environ.get("CHROMA_COLLECTION", "pdf_docs")
EMBED_MODEL_NAME = os.environ.get("EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
DEFAULT_CHUNK_SIZE = 1200  # characters per chunk
DEFAULT_CHUNK_OVERLAP = 200  # characters shared between consecutive chunks
MAX_CHARS_PER_PDF = 1_500_000  # safety cap for huge PDFs
| # ----------------------------- | |
| # Utilities | |
| # ----------------------------- | |
def sha1_file(path: str) -> str:
    """Return the hex SHA-1 digest of the file at *path*, streamed in 1 MiB blocks."""
    digest = hashlib.sha1()
    with open(path, "rb") as fh:
        block = fh.read(1024 * 1024)
        while block:
            digest.update(block)
            block = fh.read(1024 * 1024)
    return digest.hexdigest()
def clean_text(t: str) -> str:
    """Normalize extracted text: drop NUL bytes, collapse whitespace runs, trim ends."""
    without_nul = t.replace("\x00", " ")
    collapsed = re.sub(r"\s+", " ", without_nul)
    return collapsed.strip()
def chunk_text(text: str, chunk_size: int, overlap: int) -> List[str]:
    """Split *text* into overlapping character windows of at most *chunk_size*.

    A non-positive chunk_size disables chunking and returns the text whole.
    An overlap as large as the chunk size would stall the scan, so it is
    clamped down to a quarter of the chunk size. Windows that strip to empty
    are dropped.
    """
    if chunk_size <= 0:
        return [text]
    if overlap >= chunk_size:
        overlap = max(0, chunk_size // 4)
    pieces: List[str] = []
    length = len(text)
    pos = 0
    while pos < length:
        stop = min(length, pos + chunk_size)
        piece = text[pos:stop].strip()
        if piece:
            pieces.append(piece)
        if stop == length:
            break
        # Step forward, backing up by the overlap so context spans boundaries.
        pos = max(0, stop - overlap)
    return pieces
def extract_pdf_text_by_page(pdf_path: str) -> List[Tuple[int, str]]:
    """Extract per-page text from a PDF.

    Returns ``[(page_number, text), ...]`` with 1-based page numbers, keeping
    only pages whose text is non-empty after normalization. A failure while
    extracting a single page is treated as an empty page rather than aborting
    the whole document.
    """
    results: List[Tuple[int, str]] = []
    for page_no, page in enumerate(PdfReader(pdf_path).pages, start=1):
        try:
            raw = page.extract_text() or ""
        except Exception:
            raw = ""  # pypdf can fail on malformed pages; skip them
        normalized = clean_text(raw)
        if normalized:
            results.append((page_no, normalized))
    return results
# -----------------------------
# Vector DB + Embeddings (PyTorch)
# -----------------------------
# Guards Chroma reads/writes and the collection swap in clear_db().
_lock = threading.Lock()
# Prefer GPU when available; SentenceTransformer runs on PyTorch.
_device = "cuda" if torch.cuda.is_available() else "cpu"
_model = SentenceTransformer(EMBED_MODEL_NAME, device=_device)
_model.eval()  # inference only
# Persistent on-disk Chroma store (survives restarts).
_client = chromadb.PersistentClient(path=DB_DIR)
# Use cosine space for more intuitive similarity
_collection = _client.get_or_create_collection(
    name=COLLECTION_NAME,
    metadata={"hnsw:space": "cosine"},
)
def embed_texts(texts: List[str], batch_size: int = 32) -> np.ndarray:
    """Embed *texts* with the module-level SentenceTransformer model.

    Returns a float32 numpy array of shape (N, D). Embeddings are
    L2-normalized, which keeps cosine distances well-behaved in the index.
    """
    with torch.inference_mode():  # no autograd bookkeeping during encoding
        vectors = _model.encode(
            texts,
            batch_size=batch_size,
            show_progress_bar=False,
            convert_to_numpy=True,
            normalize_embeddings=True,  # good for cosine
        )
    return vectors.astype(np.float32)
def add_pdf_to_db(
    pdf_path: str,
    chunk_size: int,
    chunk_overlap: int,
) -> Dict[str, int]:
    """Extract, chunk, embed and store one PDF in the Chroma collection.

    Chunk ids are derived from the file's SHA-1 plus page/chunk indices, so
    re-indexing the same file replaces its previous chunks instead of piling
    up duplicates.

    Returns a stats dict with keys:
      added         -- number of chunks written to the collection
      pages         -- number of pages with extractable text
      skipped_pages -- pages not indexed (no text at all, or dropped after
                       the MAX_CHARS_PER_PDF safety cap was hit)
    """
    file_hash = sha1_file(pdf_path)
    file_name = os.path.basename(pdf_path)
    pages = extract_pdf_text_by_page(pdf_path)
    if not pages:
        return {"added": 0, "skipped_pages": 0, "pages": 0}

    docs: List[str] = []
    metadatas: List[Dict] = []
    ids: List[str] = []
    total_chars = 0
    processed_pages = 0
    for page_num, page_text in pages:
        total_chars += len(page_text)
        if total_chars > MAX_CHARS_PER_PDF:
            # Safety cap hit: ignore the remainder of a pathologically large PDF.
            break
        processed_pages += 1
        for j, ch in enumerate(chunk_text(page_text, chunk_size, chunk_overlap)):
            ids.append(f"{file_hash}_p{page_num}_c{j}")  # stable chunk id
            docs.append(ch)
            metadatas.append(
                {
                    "source_file": file_name,
                    "source_sha1": file_hash,
                    "page": page_num,
                    "chunk": j,
                }
            )

    if not docs:
        return {"added": 0, "skipped_pages": len(pages), "pages": len(pages)}

    embs = embed_texts(docs)
    with _lock:
        # Chroma `add` raises on duplicate ids; on failure, delete any existing
        # copies of these ids and retry once — emulating an upsert.
        try:
            _collection.add(
                ids=ids,
                documents=docs,
                metadatas=metadatas,
                embeddings=embs.tolist(),
            )
        except Exception:
            try:
                _collection.delete(ids=ids)
            except Exception:
                pass  # best-effort cleanup; the retry below surfaces real errors
            _collection.add(
                ids=ids,
                documents=docs,
                metadatas=metadatas,
                embeddings=embs.tolist(),
            )
    # BUGFIX: previously this path always reported skipped_pages=0, even when
    # the char cap truncated the document part-way through.
    return {
        "added": len(docs),
        "pages": len(pages),
        "skipped_pages": len(pages) - processed_pages,
    }
def db_stats() -> str:
    """Render a small Markdown summary of the collection's current state."""
    try:
        count = _collection.count()
    except Exception:
        count = 0  # collection may be missing or backend unhappy; show zero
    lines = [
        f"**Collection:** `{COLLECTION_NAME}` ",
        f"**Stored chunks:** `{count}` ",
        f"**DB dir:** `{os.path.abspath(DB_DIR)}` ",
        f"**Embed model:** `{EMBED_MODEL_NAME}` ",
        f"**Device:** `{_device}`",
    ]
    return "\n".join(lines)
def clear_db() -> str:
    """Drop and recreate the collection, emptying the vector store."""
    global _collection
    with _lock:
        _client.delete_collection(COLLECTION_NAME)
        _collection = _client.get_or_create_collection(
            name=COLLECTION_NAME,
            metadata={"hnsw:space": "cosine"},
        )
    return "✅ Cleared the vector database."
def search_db(query: str, top_k: int = 5) -> Tuple[str, str]:
    """Embed *query*, run a top-k similarity search, and format the hits.

    Returns (response_markdown, results_markdown); the second element is
    empty whenever there is nothing to show.
    """
    text = (query or "").strip()
    if not text:
        return "Please enter a query.", ""
    with _lock:
        if _collection.count() == 0:
            return "Your database is empty. Upload and index PDFs first.", ""
    q_emb = embed_texts([text])[0].tolist()
    with _lock:
        res = _collection.query(
            query_embeddings=[q_emb],
            n_results=int(top_k),
            include=["documents", "metadatas", "distances"],
        )
    docs = res.get("documents", [[]])[0]
    metas = res.get("metadatas", [[]])[0]
    dists = res.get("distances", [[]])[0]
    if not docs:
        return "No results found.", ""
    # In cosine space, distance ~ (1 - cosine_similarity).
    blocks = []
    for rank, (doc, meta, dist) in enumerate(zip(docs, metas, dists), start=1):
        similarity = None if dist is None else 1.0 - float(dist)
        sim_str = "?" if similarity is None else f"{similarity:.3f}"
        blocks.append(
            f"### Result {rank} (similarity: **{sim_str}**)\n"
            f"- **Source:** `{meta.get('source_file', 'unknown')}` "
            f"(page {meta.get('page', '?')}, chunk {meta.get('chunk', '?')})\n\n"
            f"{doc}\n"
        )
    results_md = "\n---\n".join(blocks)
    # Short textual summary pointing at the best hit.
    top_meta = metas[0]
    response = (
        f"Found **{len(docs)}** matching passages. The most relevant content appears to be from "
        f"`{top_meta.get('source_file','unknown')}` page {top_meta.get('page','?')}. "
        f"See the results below for the exact extracted passages."
    )
    return response, results_md
# -----------------------------
# Gradio UI
# -----------------------------
def index_pdfs(files: Optional[List[gr.File]], chunk_size: int, chunk_overlap: int) -> Tuple[str, str]:
    """Index every uploaded PDF; return (status_markdown, db_stats_markdown)."""
    if not files:
        return "Please upload one or more PDFs.", db_stats()
    added_total = 0
    messages = []
    for upload in files:
        # Gradio file objects expose .name (a temp path); fall back to str().
        path = upload.name if hasattr(upload, "name") else str(upload)
        base = os.path.basename(path)
        if not path.lower().endswith(".pdf"):
            messages.append(f"⚠️ Skipped non-PDF: {base}")
            continue
        try:
            stats = add_pdf_to_db(path, int(chunk_size), int(chunk_overlap))
        except Exception as e:
            messages.append(f"❌ Failed {base}: {e}")
            continue
        added_total += stats["added"]
        if stats["added"] == 0:
            messages.append(f"⚠️ No extractable text in: {base} (may be scanned/image-only).")
        else:
            messages.append(f"✅ Indexed {base}: added {stats['added']} chunks.")
    messages.append(f"\n**Total chunks added:** `{added_total}`")
    return "\n".join(messages), db_stats()
# Declarative UI layout: two columns (upload/index controls, live stats),
# then a query row, then the three wired-up actions.
with gr.Blocks(title="PDF Vector Search (ChromaDB + PyTorch)") as demo:
    gr.Markdown("# 📄🔎 PDF Vector Search (ChromaDB + PyTorch Embeddings)")
    gr.Markdown(
        "Drag PDFs into the uploader, click **Index PDFs**, then ask questions in the **Query** box.\n\n"
        "**Note:** If a PDF is scanned (images only), text extraction may return nothing."
    )
    with gr.Row():
        # Left column: upload, chunking controls, indexing actions.
        with gr.Column(scale=2):
            uploader = gr.Files(label="Upload PDFs (drag & drop)", file_types=[".pdf"])
            chunk_size = gr.Slider(300, 2500, value=DEFAULT_CHUNK_SIZE, step=50, label="Chunk size (characters)")
            chunk_overlap = gr.Slider(0, 800, value=DEFAULT_CHUNK_OVERLAP, step=25, label="Chunk overlap (characters)")
            with gr.Row():
                btn_index = gr.Button("Index PDFs", variant="primary")
                btn_clear = gr.Button("Clear DB", variant="stop")
            index_status = gr.Markdown()
        # Right column: collection statistics (refreshed after index/clear).
        with gr.Column(scale=1):
            stats_box = gr.Markdown(db_stats())
    gr.Markdown("## Ask a question")
    with gr.Row():
        query_in = gr.Textbox(label="Query", placeholder="Type your question (e.g., 'What is the main conclusion?')")
        top_k = gr.Slider(1, 12, value=5, step=1, label="Top K results")
    btn_search = gr.Button("Search", variant="primary")
    response_out = gr.Textbox(label="Response", lines=2)
    results_out = gr.Markdown(label="Results")
    # Wire buttons to handlers; index/clear also refresh the stats panel.
    btn_index.click(
        fn=index_pdfs,
        inputs=[uploader, chunk_size, chunk_overlap],
        outputs=[index_status, stats_box],
    )
    btn_clear.click(
        fn=lambda: (clear_db(), db_stats()),
        inputs=[],
        outputs=[index_status, stats_box],
    )
    btn_search.click(
        fn=search_db,
        inputs=[query_in, top_k],
        outputs=[response_out, results_out],
    )
if __name__ == "__main__":
    # Bind to all interfaces (container-friendly); PORT env var overrides 7860.
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))