# app.py — uploaded by eaglelandsonce ("Create app.py", commit e03905a, verified)
import os
import re
import hashlib
import threading
from typing import List, Dict, Tuple, Optional
import numpy as np
import torch
import gradio as gr
import chromadb
from pypdf import PdfReader
from sentence_transformers import SentenceTransformer
# -----------------------------
# Config
# -----------------------------
# Directory of the persistent Chroma store; override with CHROMA_DB_DIR.
DB_DIR = os.getenv("CHROMA_DB_DIR", "./chroma_db")
# Name of the Chroma collection that holds the PDF chunks; override with CHROMA_COLLECTION.
COLLECTION_NAME = os.getenv("CHROMA_COLLECTION", "pdf_docs")
# SentenceTransformer model used for embeddings; override with EMBED_MODEL.
EMBED_MODEL_NAME = os.getenv("EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")

# Default chunking parameters, measured in characters.
DEFAULT_CHUNK_SIZE = 1200
DEFAULT_CHUNK_OVERLAP = 200
# Safety cap on the amount of text indexed from a single (huge) PDF, in characters.
MAX_CHARS_PER_PDF = 1_500_000
# -----------------------------
# Utilities
# -----------------------------
def sha1_file(path: str) -> str:
    """Return the hexadecimal SHA-1 digest of the file at *path*.

    The file is read in 1 MiB blocks so large files are never loaded into
    memory in one piece.
    """
    block_size = 1024 * 1024
    digest = hashlib.sha1()
    with open(path, "rb") as stream:
        while True:
            block = stream.read(block_size)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()
def clean_text(t: str) -> str:
    """Normalise extracted text.

    NUL characters become spaces, every run of whitespace collapses to a
    single space, and leading/trailing whitespace is removed.
    """
    without_nuls = t.replace("\x00", " ")
    single_spaced = re.sub(r"\s+", " ", without_nuls)
    return single_spaced.strip()
def chunk_text(text: str, chunk_size: int, overlap: int) -> List[str]:
    """Split *text* into overlapping character windows.

    Args:
        text: The text to split.
        chunk_size: Window length in characters. A value <= 0 disables
            chunking and returns ``[text]`` unchanged.
        overlap: Number of characters shared by consecutive windows. Values
            >= ``chunk_size`` fall back to ``chunk_size // 4``; negative
            values are treated as 0.

    Returns:
        The whitespace-stripped, non-empty chunks in document order.
    """
    if chunk_size <= 0:
        return [text]
    if overlap >= chunk_size:
        overlap = chunk_size // 4
    # A negative overlap would move the next window *past* the end of the
    # previous one and silently drop the characters in between.
    overlap = max(0, overlap)

    step = chunk_size - overlap  # always >= 1 because overlap < chunk_size
    n = len(text)
    chunks = []
    for start in range(0, n, step):
        end = start + chunk_size
        chunk = text[start:end].strip()
        if chunk:
            chunks.append(chunk)
        if end >= n:
            # This window already reached the end of the text; a further
            # window would only repeat the tail.
            break
    return chunks
def extract_pdf_text_by_page(pdf_path: str) -> List[Tuple[int, str]]:
    """Return ``[(page_number, text), ...]`` for the pages of a PDF.

    Page numbers are 1-based. Text is passed through ``clean_text``; pages
    whose extraction raises, or whose cleaned text is empty, are left out.
    """
    pages_with_text = []
    for page_number, page in enumerate(PdfReader(pdf_path).pages, start=1):
        try:
            raw = page.extract_text()
        except Exception:
            # Best effort: a page that cannot be parsed is treated as empty.
            raw = None
        cleaned = clean_text(raw or "")
        if cleaned:
            pages_with_text.append((page_number, cleaned))
    return pages_with_text
# -----------------------------
# Vector DB + Embeddings (PyTorch)
# -----------------------------
# Guards calls on the Chroma collection (add/delete/query/count) and the
# rebinding of `_collection` in clear_db().
_lock = threading.Lock()
# Run the embedding model on the GPU when torch reports one, otherwise on CPU.
_device = "cuda" if torch.cuda.is_available() else "cpu"
# The model is loaded once at import time and shared by every request.
_model = SentenceTransformer(EMBED_MODEL_NAME, device=_device)
_model.eval()
# Persistent client: the store lives on disk under DB_DIR.
_client = chromadb.PersistentClient(path=DB_DIR)
# Use cosine space for more intuitive similarity
_collection = _client.get_or_create_collection(
    name=COLLECTION_NAME,
    metadata={"hnsw:space": "cosine"},
)
def embed_texts(texts: List[str], batch_size: int = 32) -> np.ndarray:
    """Embed *texts* with the module-level SentenceTransformer model.

    Encoding runs under ``torch.inference_mode()`` and asks the model for
    normalised embeddings (suited to cosine similarity).

    Returns:
        A float32 numpy array of shape (N, D), one row per input text.
    """
    encode_options = {
        "batch_size": batch_size,
        "show_progress_bar": False,
        "convert_to_numpy": True,
        "normalize_embeddings": True,  # good for cosine
    }
    with torch.inference_mode():
        vectors = _model.encode(texts, **encode_options)
    return vectors.astype(np.float32)
def add_pdf_to_db(
    pdf_path: str,
    chunk_size: int,
    chunk_overlap: int,
) -> Dict[str, int]:
    """Extract, chunk, embed and store one PDF in the Chroma collection.

    Chunk ids have the form ``<sha1>_p<page>_c<chunk>``, so they are stable
    for a given file content and chunking. Re-indexing a file replaces
    everything previously stored for the same content hash, so chunks from an
    earlier run with different chunk settings are not left behind.

    Args:
        pdf_path: Path of the PDF on disk.
        chunk_size: Chunk length in characters (see ``chunk_text``).
        chunk_overlap: Overlap between consecutive chunks in characters.

    Returns:
        A dict with ``added`` (chunks written), ``pages`` (pages that yielded
        text) and ``skipped_pages`` (pages dropped because the running total
        exceeded ``MAX_CHARS_PER_PDF``).

    Raises:
        Whatever the PDF reader, the embedding model or Chroma raise; the
        caller is expected to report the failure.
    """
    file_hash = sha1_file(pdf_path)
    file_name = os.path.basename(pdf_path)
    pages = extract_pdf_text_by_page(pdf_path)
    if not pages:
        return {"added": 0, "skipped_pages": 0, "pages": 0}

    docs = []
    metadatas = []
    ids = []
    total_chars = 0
    pages_indexed = 0
    for page_num, page_text in pages:
        total_chars += len(page_text)
        if total_chars > MAX_CHARS_PER_PDF:
            # Safety cap reached: this page and all following ones are skipped.
            break
        pages_indexed += 1
        for j, ch in enumerate(chunk_text(page_text, chunk_size, chunk_overlap)):
            # Stable chunk id
            ids.append(f"{file_hash}_p{page_num}_c{j}")
            docs.append(ch)
            metadatas.append(
                {
                    "source_file": file_name,
                    "source_sha1": file_hash,
                    "page": page_num,
                    "chunk": j,
                }
            )

    skipped_pages = len(pages) - pages_indexed
    if not docs:
        return {"added": 0, "skipped_pages": skipped_pages, "pages": len(pages)}

    # Embed before taking the lock so searches are not blocked by model work.
    embeddings = embed_texts(docs).tolist()

    # Write in bounded batches rather than one arbitrarily large request.
    write_batch = 1000
    with _lock:
        # Remove every chunk stored for this file content first; otherwise a
        # re-index with a different chunk size would keep stale chunks.
        _collection.delete(where={"source_sha1": file_hash})
        for lo in range(0, len(ids), write_batch):
            hi = lo + write_batch
            # upsert() does not depend on how add() treats an existing id.
            _collection.upsert(
                ids=ids[lo:hi],
                documents=docs[lo:hi],
                metadatas=metadatas[lo:hi],
                embeddings=embeddings[lo:hi],
            )
    return {"added": len(docs), "pages": len(pages), "skipped_pages": skipped_pages}
def db_stats() -> str:
    """Return a Markdown summary of the collection and embedding setup.

    The summary lists the collection name, the stored chunk count, the
    absolute DB directory, the embedding model name and the device. If
    counting fails for any reason the count is reported as 0.
    """
    try:
        stored = _collection.count()
    except Exception:
        # Best effort: the stats panel should render even if the DB is unavailable.
        stored = 0
    count = stored
    return (
        f"**Collection:** `{COLLECTION_NAME}` \n"
        f"**Stored chunks:** `{count}` \n"
        f"**DB dir:** `{os.path.abspath(DB_DIR)}` \n"
        f"**Embed model:** `{EMBED_MODEL_NAME}` \n"
        f"**Device:** `{_device}`"
    )
def clear_db() -> str:
    """Drop the collection, recreate it empty, and return a status message.

    The module-level ``_collection`` is rebound to the new collection while
    ``_lock`` is held.
    """
    global _collection
    cosine_space = {"hnsw:space": "cosine"}
    with _lock:
        _client.delete_collection(COLLECTION_NAME)
        _collection = _client.get_or_create_collection(
            name=COLLECTION_NAME,
            metadata=cosine_space,
        )
    return "✅ Cleared the vector database."
def search_db(query: str, top_k: int = 5) -> Tuple[str, str]:
    """Run a similarity search and format the hits as Markdown.

    Args:
        query: Free-text question; ``None`` or blank input is rejected with a
            hint message.
        top_k: Maximum number of passages to return. The value is clamped to
            the range ``1 .. number of stored chunks``.

    Returns:
        ``(response, results_md)``: a short summary sentence and a Markdown
        document with one section per hit. ``results_md`` is ``""`` whenever
        there is nothing to show.
    """
    query = (query or "").strip()
    if not query:
        return "Please enter a query.", ""

    empty_msg = "Your database is empty. Upload and index PDFs first."
    with _lock:
        if _collection.count() == 0:
            return empty_msg, ""

    # Embed outside the lock so indexing is not blocked by model work.
    q_emb = embed_texts([query])[0].tolist()

    with _lock:
        # Re-check under the same lock as the query: the collection may have
        # been cleared while the query was being embedded.
        available = _collection.count()
        if available == 0:
            return empty_msg, ""
        res = _collection.query(
            query_embeddings=[q_emb],
            # Never ask for more neighbours than exist, nor for fewer than one.
            n_results=max(1, min(int(top_k), available)),
            include=["documents", "metadatas", "distances"],
        )

    # `or [[]]` also covers keys that are present but set to None.
    docs = (res.get("documents") or [[]])[0]
    metas = (res.get("metadatas") or [[]])[0]
    dists = (res.get("distances") or [[]])[0]
    if not docs:
        return "No results found.", ""

    # Build a "response" plus a detailed results view
    # For cosine: distance ~ (1 - cosine_similarity)
    blocks = []
    for i, (doc, meta, dist) in enumerate(zip(docs, metas, dists), start=1):
        meta = meta or {}  # tolerate entries stored without metadata
        sim = 1.0 - float(dist) if dist is not None else None
        src = meta.get("source_file", "unknown")
        page = meta.get("page", "?")
        chunk = meta.get("chunk", "?")
        sim_str = f"{sim:.3f}" if sim is not None else "?"
        blocks.append(
            f"### Result {i} (similarity: **{sim_str}**)\n"
            f"- **Source:** `{src}` (page {page}, chunk {chunk})\n\n"
            f"{doc}\n"
        )
    results_md = "\n---\n".join(blocks)

    # "Response" field: concise summary of what was found
    top_meta = (metas[0] if metas else None) or {}
    response = (
        f"Found **{len(docs)}** matching passages. The most relevant content appears to be from "
        f"`{top_meta.get('source_file', 'unknown')}` page {top_meta.get('page', '?')}. "
        f"See the results below for the exact extracted passages."
    )
    return response, results_md
# -----------------------------
# Gradio UI
# -----------------------------
def index_pdfs(files: Optional[List[gr.File]], chunk_size: int, chunk_overlap: int) -> Tuple[str, str]:
    """Index every uploaded PDF and report the outcome.

    Entries whose path does not end in ``.pdf`` are skipped. A failure while
    indexing one file is reported in the status text and does not stop the
    remaining files.

    Returns:
        ``(status_markdown, db_stats_markdown)``.
    """
    if not files:
        return "Please upload one or more PDFs.", db_stats()

    report = []
    chunks_added = 0
    for upload in files:
        # Uploads may be objects exposing `.name` or plain path-like values.
        if hasattr(upload, "name"):
            path = upload.name
        else:
            path = str(upload)
        label = os.path.basename(path)

        if not path.lower().endswith(".pdf"):
            report.append(f"⚠️ Skipped non-PDF: {label}")
            continue

        try:
            stats = add_pdf_to_db(path, int(chunk_size), int(chunk_overlap))
            chunks_added += stats["added"]
            if stats["added"] != 0:
                report.append(f"✅ Indexed {label}: added {stats['added']} chunks.")
            else:
                report.append(f"⚠️ No extractable text in: {label} (may be scanned/image-only).")
        except Exception as e:
            report.append(f"❌ Failed {label}: {e}")

    report.append(f"\n**Total chunks added:** `{chunks_added}`")
    return "\n".join(report), db_stats()
# Build the Gradio interface. Components are created in display order and the
# event handlers are wired up at the end of the block.
# NOTE(review): the nesting of the Row/Column containers below was
# reconstructed from statement order — confirm the intended layout.
with gr.Blocks(title="PDF Vector Search (ChromaDB + PyTorch)") as demo:
    gr.Markdown("# 📄🔎 PDF Vector Search (ChromaDB + PyTorch Embeddings)")
    gr.Markdown(
        "Drag PDFs into the uploader, click **Index PDFs**, then ask questions in the **Query** box.\n\n"
        "**Note:** If a PDF is scanned (images only), text extraction may return nothing."
    )
    # Indexing controls (wide column) next to the live DB statistics (narrow column).
    with gr.Row():
        with gr.Column(scale=2):
            uploader = gr.Files(label="Upload PDFs (drag & drop)", file_types=[".pdf"])
            chunk_size = gr.Slider(300, 2500, value=DEFAULT_CHUNK_SIZE, step=50, label="Chunk size (characters)")
            chunk_overlap = gr.Slider(0, 800, value=DEFAULT_CHUNK_OVERLAP, step=25, label="Chunk overlap (characters)")
            with gr.Row():
                btn_index = gr.Button("Index PDFs", variant="primary")
                btn_clear = gr.Button("Clear DB", variant="stop")
            # Receives the status text from both the index and the clear actions.
            index_status = gr.Markdown()
        with gr.Column(scale=1):
            # Initial content is computed once, when the UI is built.
            stats_box = gr.Markdown(db_stats())
    gr.Markdown("## Ask a question")
    with gr.Row():
        query_in = gr.Textbox(label="Query", placeholder="Type your question (e.g., 'What is the main conclusion?')")
        top_k = gr.Slider(1, 12, value=5, step=1, label="Top K results")
    btn_search = gr.Button("Search", variant="primary")
    response_out = gr.Textbox(label="Response", lines=2)
    results_out = gr.Markdown(label="Results")
    # Index: returns (status text, refreshed stats).
    btn_index.click(
        fn=index_pdfs,
        inputs=[uploader, chunk_size, chunk_overlap],
        outputs=[index_status, stats_box],
    )
    # Clear: the tuple is built left to right, so db_stats() runs after
    # clear_db() and shows the emptied collection.
    btn_clear.click(
        fn=lambda: (clear_db(), db_stats()),
        inputs=[],
        outputs=[index_status, stats_box],
    )
    # Search: returns (summary sentence, Markdown list of hits).
    btn_search.click(
        fn=search_db,
        inputs=[query_in, top_k],
        outputs=[response_out, results_out],
    )
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from PORT and defaults to 7860.
    listen_port = int(os.getenv("PORT", 7860))
    demo.launch(server_name="0.0.0.0", server_port=listen_port)