# /// script # requires-python = ">=3.11" # dependencies = [ # "lancedb>=0.17", # "fastapi>=0.115", # "uvicorn[standard]>=0.32", # "jinja2>=3.1", # "sentence-transformers", # "pillow", # "huggingface-hub", # ] # /// """ BPL Card Catalog — OCR Search Comparison FastAPI + HTMX app comparing old (Tesseract) and new (GLM-OCR) search results side by side. Clean, Tufte-inspired design. Usage (local dev): uv run app.py --db-path ../bpl-lance-db Usage (HF Spaces / Hub dataset): uv run app.py --from-hub davanstrien/bpl-card-catalog-lance """ from __future__ import annotations import argparse import io import os import random from pathlib import Path import lancedb import uvicorn from fastapi import FastAPI, Request from fastapi.responses import HTMLResponse, StreamingResponse from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates from huggingface_hub import snapshot_download from PIL import Image from sentence_transformers import SentenceTransformer DEFAULT_HUB_REPO = "davanstrien/bpl-card-catalog-lance-full" DB_PATH = "../bpl-lance-db" TABLE_NAME = "cards" EMBEDDING_MODEL = "BAAI/bge-base-en-v1.5" SPOTLIGHT_COUNT = 3 OLD_OCR_LABEL = "Tesseract" NEW_OCR_LABEL = "VLM OCR" APP_DIR = Path(__file__).parent TEMPLATES_DIR = APP_DIR / "templates" STATIC_DIR = APP_DIR / "static" EXAMPLE_QUERIES = [ "abolitionism", "Civil War letters", "Shakespeare plays", "Boston history", "illuminated manuscripts", "African American history", "French literature", "music composition", "botanical illustrations", "theater history", ] SELECT_COLS = [ "drawer_id", "card_number", "text", "markdown", "source_url", "image", ] def truncate(text: str, n: int = 800) -> str: if not text: return "(empty)" return text[:n] + ("\u2026" if len(text) > n else "") def parse_drawer_id(drawer_id: str) -> tuple[str, str]: """'145-great-britain-acts' -> ('145', 'Great Britain Acts')""" parts = drawer_id.split("-", 1) num = parts[0] label = ( parts[1].replace("-", " ").replace(".", " ").strip().title() if len(parts) > 1 else "" ) return num, label def create_app(db_path: str = DB_PATH) -> FastAPI: app = FastAPI(title="BPL Card Catalog — OCR Search Comparison") app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static") templates = Jinja2Templates(directory=str(TEMPLATES_DIR)) db = lancedb.connect(db_path) table = db.open_table(TABLE_NAME) model = SentenceTransformer(EMBEDDING_MODEL) total_cards = table.count_rows() # -- Build drawer index at startup -- all_rows = table.search().select(["drawer_id"]).limit(total_cards).to_list() drawer_counts: dict[str, int] = {} for row in all_rows: did = row["drawer_id"] drawer_counts[did] = drawer_counts.get(did, 0) + 1 drawer_list: list[dict] = [] for did, count in drawer_counts.items(): num, label = parse_drawer_id(did) drawer_list.append( {"drawer_id": did, "drawer_num": num, "drawer_label": label, "count": count} ) drawer_list.sort( key=lambda d: ( int(d["drawer_num"]) if d["drawer_num"].isdigit() else 9999, d["drawer_id"], ) ) known_drawer_ids = {d["drawer_id"] for d in drawer_list} # -- Image cache (row_idx -> JPEG bytes) -- image_cache: dict[int, bytes] = {} def _get_image_bytes(row_idx: int) -> bytes | None: if row_idx in image_cache: return image_cache[row_idx] rows = ( table.search() .where(f"_rowid = {row_idx}") .select(["image"]) .limit(1) .to_list() ) if not rows: return None image_cache[row_idx] = rows[0]["image"] return rows[0]["image"] # -- Search functions -- def search_old_vector(query: str, limit: int) -> list[dict]: q_vec = model.encode(query, normalize_embeddings=True).tolist() return ( table.search(q_vec, vector_column_name="old_ocr_embedding") .select(SELECT_COLS) .limit(limit) .to_list() ) def search_new_vector(query: str, limit: int) -> list[dict]: q_vec = model.encode(query, normalize_embeddings=True).tolist() return ( table.search(q_vec, vector_column_name="new_ocr_embedding") .select(SELECT_COLS) .limit(limit) .to_list() ) def search_old_fts(query: str, limit: int) -> list[dict]: return ( table.search(query, query_type="fts", fts_columns="text") .select(SELECT_COLS) .limit(limit) .to_list() ) def search_new_fts(query: str, limit: int) -> list[dict]: return ( table.search(query, query_type="fts", fts_columns="markdown") .select(SELECT_COLS) .limit(limit) .to_list() ) def format_results( results: list[dict], ocr_field: str, other_field: str, ocr_label: str, compare_label: str, ) -> list[dict]: formatted = [] for i, row in enumerate(results): score = ( row.get("_distance") or row.get("_score") or row.get("_relevance_score") ) # Cache image bytes using a simple incrementing key row_idx = len(image_cache) if isinstance(row.get("image"), bytes): image_cache[row_idx] = row["image"] drawer_id = row.get("drawer_id", "?") drawer_num, drawer_label = parse_drawer_id(drawer_id) formatted.append( { "rank": i + 1, "row_idx": row_idx, "drawer_id": drawer_id, "drawer_num": drawer_num, "drawer_label": drawer_label, "card_number": row.get("card_number", "?"), "ocr_text": truncate(row.get(ocr_field, ""), 800), "other_ocr": truncate(row.get(other_field, ""), 800), "ocr_label": ocr_label, "compare_label": compare_label, "score": f"{score:.4f}" if score is not None else "", "source_url": row.get("source_url", ""), } ) return formatted # -- Routes -- @app.get("/", response_class=HTMLResponse) async def index(request: Request): return templates.TemplateResponse( request, "index.html", { "total_cards": f"{total_cards:,}", "total_drawers": len(drawer_list), "examples": EXAMPLE_QUERIES, "query": "", "mode": "fts", "limit": 10, "old_ocr_label": OLD_OCR_LABEL, "new_ocr_label": NEW_OCR_LABEL, }, ) @app.get("/search", response_class=HTMLResponse) async def search( request: Request, query: str = "", mode: str = "fts", limit: int = 5, ): limit = max(1, min(20, limit)) if not query.strip(): return templates.TemplateResponse( request, "results.html", {"query": "", "total_cards": f"{total_cards:,}"}, ) if mode == "fts": old_raw = search_old_fts(query, limit) new_raw = search_new_fts(query, limit) else: old_raw = search_old_vector(query, limit) new_raw = search_new_vector(query, limit) old_label = f"Old OCR ({OLD_OCR_LABEL})" new_label = f"New OCR ({NEW_OCR_LABEL})" old_results = format_results( old_raw, "text", "markdown", old_label, new_label, ) new_results = format_results( new_raw, "markdown", "text", new_label, old_label, ) return templates.TemplateResponse( request, "results.html", { "query": query, "mode": mode, "old_results": old_results, "new_results": new_results, "total_cards": f"{total_cards:,}", "old_ocr_label": OLD_OCR_LABEL, "new_ocr_label": NEW_OCR_LABEL, }, ) @app.get("/search-single", response_class=HTMLResponse) async def search_single( request: Request, query: str = "", mode: str = "fts", limit: int = 10, ): limit = max(1, min(20, limit)) if not query.strip(): return templates.TemplateResponse( request, "results-search.html", {"query": "", "total_cards": f"{total_cards:,}"}, ) if mode == "fts": raw = search_new_fts(query, limit) else: raw = search_new_vector(query, limit) new_label = f"New OCR ({NEW_OCR_LABEL})" old_label = f"Old OCR ({OLD_OCR_LABEL})" results = format_results(raw, "markdown", "text", new_label, old_label) return templates.TemplateResponse( request, "results-search.html", { "query": query, "mode": mode, "results": results, "total_cards": f"{total_cards:,}", }, ) @app.get("/random-cards", response_class=HTMLResponse) async def random_cards(request: Request): indices = random.sample( range(total_cards), min(SPOTLIGHT_COUNT, total_cards) ) cards = [] for idx in indices: rows = ( table.search().select(SELECT_COLS).limit(1).offset(idx).to_list() ) if not rows: continue row = rows[0] row_idx = len(image_cache) if isinstance(row.get("image"), bytes): image_cache[row_idx] = row["image"] drawer_id = row.get("drawer_id", "?") drawer_num, drawer_label = parse_drawer_id(drawer_id) cards.append( { "row_idx": row_idx, "drawer_id": drawer_id, "drawer_num": drawer_num, "drawer_label": drawer_label, "card_number": row.get("card_number", "?"), "ocr_text": truncate(row.get("markdown", ""), 200), "source_url": row.get("source_url", ""), } ) return templates.TemplateResponse( request, "spotlight.html", {"cards": cards} ) @app.get("/drawers", response_class=HTMLResponse) async def drawers_index(request: Request): return templates.TemplateResponse( request, "drawers.html", { "drawers": drawer_list, "total_drawers": len(drawer_list), "total_cards": f"{total_cards:,}", }, ) @app.get("/drawer/{drawer_id}", response_class=HTMLResponse) async def drawer_detail(request: Request, drawer_id: str): # Validate against known drawer IDs to prevent injection if drawer_id not in known_drawer_ids: return HTMLResponse("Drawer not found", status_code=404) rows = ( table.search() .where(f"drawer_id = '{drawer_id}'", prefilter=True) .select(SELECT_COLS) .limit(2000) .to_list() ) rows.sort(key=lambda r: r.get("card_number", 0)) cards = [] for i, row in enumerate(rows): row_idx = len(image_cache) if isinstance(row.get("image"), bytes): image_cache[row_idx] = row["image"] cards.append( { "card_number": row.get("card_number", i), "row_idx": row_idx, "ocr_text": truncate(row.get("markdown", ""), 800), "source_url": row.get("source_url", ""), } ) # Prev/next drawer navigation idx = next( (i for i, d in enumerate(drawer_list) if d["drawer_id"] == drawer_id), -1 ) prev_drawer = drawer_list[idx - 1] if idx > 0 else None next_drawer = drawer_list[idx + 1] if idx < len(drawer_list) - 1 else None drawer_num, drawer_label = parse_drawer_id(drawer_id) return templates.TemplateResponse( request, "drawer.html", { "drawer_id": drawer_id, "drawer_num": drawer_num, "drawer_label": drawer_label, "cards": cards, "card_count": len(cards), "prev_drawer": prev_drawer, "next_drawer": next_drawer, }, ) @app.get("/image/{row_idx}") async def image(row_idx: int): img_bytes = image_cache.get(row_idx) if img_bytes is None: return HTMLResponse("Image not found", status_code=404) buf = io.BytesIO() img = Image.open(io.BytesIO(img_bytes)) img.save(buf, format="JPEG", quality=85) buf.seek(0) return StreamingResponse(buf, media_type="image/jpeg") return app def resolve_db_path(args) -> str: """Resolve database path from CLI args, env var, or HF Hub download.""" # Explicit local path takes priority if args.db_path: db_path = Path(args.db_path) if not db_path.exists(): print(f"Database not found at {db_path}") print("Run 'uv run bpl-lance-poc.py build' first.") raise SystemExit(1) return str(db_path) # Download from HF Hub repo_id = args.from_hub or os.environ.get("BPL_HUB_REPO", DEFAULT_HUB_REPO) cache_base = os.environ.get("HF_HOME", "/tmp/hf_cache") local_dir = f"{cache_base}/bpl-lance" print(f"Downloading dataset from {repo_id} to {local_dir}...") snapshot_download(repo_id, repo_type="dataset", local_dir=local_dir) print("Download complete.") return local_dir def main(): parser = argparse.ArgumentParser( description="BPL OCR search comparison (FastAPI + HTMX)" ) source = parser.add_mutually_exclusive_group() source.add_argument( "--db-path", default=None, help="Path to local LanceDB directory (for local dev)", ) source.add_argument( "--from-hub", nargs="?", const=DEFAULT_HUB_REPO, default=None, help=f"Download Lance DB from HF Hub (default: {DEFAULT_HUB_REPO})", ) parser.add_argument("--host", default="127.0.0.1") parser.add_argument("--port", type=int, default=8000) args = parser.parse_args() # If neither --db-path nor --from-hub, try default local path if args.db_path is None and args.from_hub is None: default = Path(DB_PATH) if default.exists(): args.db_path = DB_PATH else: args.from_hub = DEFAULT_HUB_REPO db_path = resolve_db_path(args) app = create_app(db_path) uvicorn.run(app, host=args.host, port=args.port) if __name__ == "__main__": main()