"""PDF vector search app: extract text from PDFs, embed chunks with a
SentenceTransformer (PyTorch), store them in a persistent ChromaDB
collection, and expose indexing + semantic search through a Gradio UI.
"""

import hashlib
import os
import re
import threading
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import gradio as gr
import chromadb
from pypdf import PdfReader
from sentence_transformers import SentenceTransformer

# -----------------------------
# Config
# -----------------------------
DB_DIR = os.environ.get("CHROMA_DB_DIR", "./chroma_db")
COLLECTION_NAME = os.environ.get("CHROMA_COLLECTION", "pdf_docs")
EMBED_MODEL_NAME = os.environ.get("EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")

DEFAULT_CHUNK_SIZE = 1200      # characters
DEFAULT_CHUNK_OVERLAP = 200    # characters
MAX_CHARS_PER_PDF = 1_500_000  # safety cap for huge PDFs


# -----------------------------
# Utilities
# -----------------------------
def sha1_file(path: str) -> str:
    """Return the hex SHA-1 digest of the file at *path*.

    Reads in 1 MiB blocks so arbitrarily large files are hashed with
    constant memory. Used as a stable per-file identity for chunk ids.
    """
    h = hashlib.sha1()
    with open(path, "rb") as f:
        for block in iter(lambda: f.read(1024 * 1024), b""):
            h.update(block)
    return h.hexdigest()


def clean_text(t: str) -> str:
    """Normalize extracted PDF text: drop NULs, collapse whitespace, strip."""
    t = t.replace("\x00", " ")
    t = re.sub(r"\s+", " ", t)
    return t.strip()


def chunk_text(text: str, chunk_size: int, overlap: int) -> List[str]:
    """Split *text* into overlapping character chunks.

    Args:
        text: Input string (already whitespace-normalized).
        chunk_size: Target chunk length in characters; non-positive means
            "no chunking" and the whole text is returned as one piece.
        overlap: Characters of overlap between consecutive chunks. If it
            is >= chunk_size (which would stall the sliding window), it is
            reduced to chunk_size // 4.

    Returns:
        List of non-empty, stripped chunks (may be empty for blank input).
    """
    if chunk_size <= 0:
        return [text]
    if overlap >= chunk_size:
        overlap = max(0, chunk_size // 4)

    chunks: List[str] = []
    start = 0
    n = len(text)
    while start < n:
        end = min(n, start + chunk_size)
        chunk = text[start:end].strip()
        if chunk:
            chunks.append(chunk)
        if end == n:
            break
        # Slide the window back by `overlap` so context straddles chunk edges.
        start = max(0, end - overlap)
    return chunks


def extract_pdf_text_by_page(pdf_path: str) -> List[Tuple[int, str]]:
    """Returns [(page_index_1based, text), ...]

    Pages whose extraction fails (corrupt page) or yields no text
    (scanned/image-only) are silently omitted.
    """
    reader = PdfReader(pdf_path)
    out: List[Tuple[int, str]] = []
    for i, page in enumerate(reader.pages, start=1):
        try:
            txt = page.extract_text() or ""
        except Exception:
            # Best effort: a single bad page should not sink the whole PDF.
            txt = ""
        txt = clean_text(txt)
        if txt:
            out.append((i, txt))
    return out


# -----------------------------
# Vector DB + Embeddings (PyTorch)
# -----------------------------
_lock = threading.Lock()

_device = "cuda" if torch.cuda.is_available() else "cpu"
_model = SentenceTransformer(EMBED_MODEL_NAME, device=_device)
_model.eval()

_client = chromadb.PersistentClient(path=DB_DIR)


def _make_collection():
    """Create-or-open the app collection (cosine space for intuitive sims)."""
    return _client.get_or_create_collection(
        name=COLLECTION_NAME,
        metadata={"hnsw:space": "cosine"},
    )


_collection = _make_collection()


def embed_texts(texts: List[str], batch_size: int = 32) -> np.ndarray:
    """
    Returns embeddings as float32 numpy array of shape (N, D).
    SentenceTransformer runs on PyTorch under the hood.
    """
    with torch.inference_mode():
        emb = _model.encode(
            texts,
            batch_size=batch_size,
            show_progress_bar=False,
            convert_to_numpy=True,
            normalize_embeddings=True,  # good for cosine
        )
    return emb.astype(np.float32)


def add_pdf_to_db(
    pdf_path: str,
    chunk_size: int,
    chunk_overlap: int,
) -> Dict[str, int]:
    """
    Extracts text, chunks it, embeds chunks, and adds to Chroma.

    Chunk ids are deterministic (file SHA-1 + page + chunk index), so
    re-indexing the same file overwrites rather than duplicates.

    Returns stats dict with keys:
        added: number of chunks written to the collection
        pages: number of pages with extractable text
        skipped_pages: pages not indexed (no text, or MAX_CHARS_PER_PDF cap)
    """
    file_hash = sha1_file(pdf_path)
    file_name = os.path.basename(pdf_path)

    pages = extract_pdf_text_by_page(pdf_path)
    if not pages:
        return {"added": 0, "skipped_pages": 0, "pages": 0}

    docs: List[str] = []
    metadatas: List[Dict] = []
    ids: List[str] = []
    total_chars = 0
    pages_indexed = 0

    for page_num, page_text in pages:
        total_chars += len(page_text)
        if total_chars > MAX_CHARS_PER_PDF:
            # Safety cap hit: remaining pages (including this one) are skipped.
            break
        pages_indexed += 1
        for j, ch in enumerate(chunk_text(page_text, chunk_size, chunk_overlap)):
            # Stable chunk id: same file content -> same ids on re-index.
            ids.append(f"{file_hash}_p{page_num}_c{j}")
            docs.append(ch)
            metadatas.append(
                {
                    "source_file": file_name,
                    "source_sha1": file_hash,
                    "page": page_num,
                    "chunk": j,
                }
            )

    if not docs:
        return {"added": 0, "skipped_pages": len(pages), "pages": len(pages)}

    embs = embed_texts(docs)

    with _lock:
        try:
            # Chroma provides true upsert: existing ids are overwritten.
            _collection.upsert(
                ids=ids,
                documents=docs,
                metadatas=metadatas,
                embeddings=embs.tolist(),
            )
        except AttributeError:
            # Very old chromadb client without upsert(): emulate it by
            # deleting any existing ids, then adding fresh.
            try:
                _collection.delete(ids=ids)
            except Exception:
                pass
            _collection.add(
                ids=ids,
                documents=docs,
                metadatas=metadatas,
                embeddings=embs.tolist(),
            )

    return {
        "added": len(docs),
        "pages": len(pages),
        "skipped_pages": len(pages) - pages_indexed,
    }


def db_stats() -> str:
    """Return a Markdown summary of the collection (count, paths, model)."""
    try:
        count = _collection.count()
    except Exception:
        # Collection may be mid-rebuild (clear_db); show 0 rather than crash.
        count = 0
    return f"**Collection:** `{COLLECTION_NAME}` \n**Stored chunks:** `{count}` \n**DB dir:** `{os.path.abspath(DB_DIR)}` \n**Embed model:** `{EMBED_MODEL_NAME}` \n**Device:** `{_device}`"


def clear_db() -> str:
    """Drop and recreate the collection, removing every stored chunk."""
    global _collection
    with _lock:
        _client.delete_collection(COLLECTION_NAME)
        _collection = _make_collection()
    return "✅ Cleared the vector database."


def search_db(query: str, top_k: int = 5) -> Tuple[str, str]:
    """Embed *query* and return (summary, detailed Markdown results).

    Args:
        query: Free-text question; blank input short-circuits with a hint.
        top_k: Maximum number of results (clamped to the collection size).
    """
    query = (query or "").strip()
    if not query:
        return "Please enter a query.", ""

    q_emb = embed_texts([query])[0].tolist()

    with _lock:
        # Count and query under one lock acquisition so the collection
        # cannot be cleared between the emptiness check and the query.
        n = _collection.count()
        if n == 0:
            return "Your database is empty. Upload and index PDFs first.", ""
        res = _collection.query(
            query_embeddings=[q_emb],
            # Clamp: asking for more results than stored chunks is an
            # error/warning in several chromadb versions.
            n_results=min(int(top_k), n),
            include=["documents", "metadatas", "distances"],
        )

    docs = res.get("documents", [[]])[0]
    metas = res.get("metadatas", [[]])[0]
    dists = res.get("distances", [[]])[0]

    if not docs:
        return "No results found.", ""

    # Build a “response” plus a detailed results view
    # For cosine: distance ~ (1 - cosine_similarity)
    blocks = []
    for i, (doc, meta, dist) in enumerate(zip(docs, metas, dists), start=1):
        sim = 1.0 - float(dist) if dist is not None else None
        src = meta.get("source_file", "unknown")
        page = meta.get("page", "?")
        chunk = meta.get("chunk", "?")
        sim_str = f"{sim:.3f}" if sim is not None else "?"
        blocks.append(
            f"### Result {i} (similarity: **{sim_str}**)\n"
            f"- **Source:** `{src}` (page {page}, chunk {chunk})\n\n"
            f"{doc}\n"
        )

    results_md = "\n---\n".join(blocks)

    # “Response” field: concise summary of what was found
    response = (
        f"Found **{len(docs)}** matching passages. The most relevant content appears to be from "
        f"`{metas[0].get('source_file','unknown')}` page {metas[0].get('page','?')}. "
        f"See the results below for the exact extracted passages."
    )
    return response, results_md


# -----------------------------
# Gradio UI
# -----------------------------
def index_pdfs(files: Optional[List[gr.File]], chunk_size: int, chunk_overlap: int) -> Tuple[str, str]:
    """Index each uploaded PDF; return (status Markdown, refreshed DB stats)."""
    if not files:
        return "Please upload one or more PDFs.", db_stats()

    added_total = 0
    msgs = []
    for f in files:
        # Gradio may hand back file objects (with .name) or plain paths.
        path = f.name if hasattr(f, "name") else str(f)
        if not path.lower().endswith(".pdf"):
            msgs.append(f"⚠️ Skipped non-PDF: {os.path.basename(path)}")
            continue
        try:
            stats = add_pdf_to_db(path, int(chunk_size), int(chunk_overlap))
            added_total += stats["added"]
            if stats["added"] == 0:
                msgs.append(f"⚠️ No extractable text in: {os.path.basename(path)} (may be scanned/image-only).")
            else:
                msgs.append(f"✅ Indexed {os.path.basename(path)}: added {stats['added']} chunks.")
        except Exception as e:
            # One bad file must not abort the rest of the batch.
            msgs.append(f"❌ Failed {os.path.basename(path)}: {e}")

    msgs.append(f"\n**Total chunks added:** `{added_total}`")
    return "\n".join(msgs), db_stats()


with gr.Blocks(title="PDF Vector Search (ChromaDB + PyTorch)") as demo:
    gr.Markdown("# 📄🔎 PDF Vector Search (ChromaDB + PyTorch Embeddings)")
    gr.Markdown(
        "Drag PDFs into the uploader, click **Index PDFs**, then ask questions in the **Query** box.\n\n"
        "**Note:** If a PDF is scanned (images only), text extraction may return nothing."
    )

    with gr.Row():
        with gr.Column(scale=2):
            uploader = gr.Files(label="Upload PDFs (drag & drop)", file_types=[".pdf"])
            chunk_size = gr.Slider(300, 2500, value=DEFAULT_CHUNK_SIZE, step=50, label="Chunk size (characters)")
            chunk_overlap = gr.Slider(0, 800, value=DEFAULT_CHUNK_OVERLAP, step=25, label="Chunk overlap (characters)")
            with gr.Row():
                btn_index = gr.Button("Index PDFs", variant="primary")
                btn_clear = gr.Button("Clear DB", variant="stop")
            index_status = gr.Markdown()
        with gr.Column(scale=1):
            stats_box = gr.Markdown(db_stats())

    gr.Markdown("## Ask a question")
    with gr.Row():
        query_in = gr.Textbox(label="Query", placeholder="Type your question (e.g., 'What is the main conclusion?')")
        top_k = gr.Slider(1, 12, value=5, step=1, label="Top K results")
    btn_search = gr.Button("Search", variant="primary")
    response_out = gr.Textbox(label="Response", lines=2)
    results_out = gr.Markdown(label="Results")

    btn_index.click(
        fn=index_pdfs,
        inputs=[uploader, chunk_size, chunk_overlap],
        outputs=[index_status, stats_box],
    )
    btn_clear.click(
        fn=lambda: (clear_db(), db_stats()),
        inputs=[],
        outputs=[index_status, stats_box],
    )
    btn_search.click(
        fn=search_db,
        inputs=[query_in, top_k],
        outputs=[response_out, results_out],
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))