Spaces:
Running
Running
| # /// script | |
| # requires-python = ">=3.11" | |
| # dependencies = [ | |
| # "lancedb>=0.17", | |
| # "fastapi>=0.115", | |
| # "uvicorn[standard]>=0.32", | |
| # "jinja2>=3.1", | |
| # "sentence-transformers", | |
| # "pillow", | |
| # "huggingface-hub", | |
| # ] | |
| # /// | |
| """ | |
| BPL Card Catalog — OCR Search Comparison | |
| FastAPI + HTMX app comparing old (Tesseract) and new (GLM-OCR) search | |
| results side by side. Clean, Tufte-inspired design. | |
| Usage (local dev): | |
| uv run app.py --db-path ../bpl-lance-db | |
| Usage (HF Spaces / Hub dataset): | |
| uv run app.py --from-hub davanstrien/bpl-card-catalog-lance | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import io | |
| import os | |
| import random | |
| from pathlib import Path | |
| import lancedb | |
| import uvicorn | |
| from fastapi import FastAPI, Request | |
| from fastapi.responses import HTMLResponse, StreamingResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.templating import Jinja2Templates | |
| from huggingface_hub import snapshot_download | |
| from PIL import Image | |
| from sentence_transformers import SentenceTransformer | |
# Hugging Face dataset repo downloaded when no local DB path is supplied.
DEFAULT_HUB_REPO = "davanstrien/bpl-card-catalog-lance-full"
# Default local LanceDB directory, relative to the working directory.
DB_PATH = "../bpl-lance-db"
# Lance table holding one row per catalog card.
TABLE_NAME = "cards"
# Sentence-transformers model used to embed search queries.
EMBEDDING_MODEL = "BAAI/bge-base-en-v1.5"
# Number of random cards shown by the spotlight route.
SPOTLIGHT_COUNT = 3
# Display labels for the two OCR pipelines being compared.
OLD_OCR_LABEL = "Tesseract"
NEW_OCR_LABEL = "VLM OCR"
# Template/static asset directories bundled next to this file.
APP_DIR = Path(__file__).parent
TEMPLATES_DIR = APP_DIR / "templates"
STATIC_DIR = APP_DIR / "static"
# Suggested queries rendered on the landing page.
EXAMPLE_QUERIES = [
    "abolitionism",
    "Civil War letters",
    "Shakespeare plays",
    "Boston history",
    "illuminated manuscripts",
    "African American history",
    "French literature",
    "music composition",
    "botanical illustrations",
    "theater history",
]
# Columns fetched for every search/browse result.
SELECT_COLS = [
    "drawer_id",
    "card_number",
    "text",  # old (Tesseract) OCR text — FTS column for the "old" side
    "markdown",  # new (VLM) OCR text — FTS column for the "new" side
    "source_url",
    "image",
]
def truncate(text: str, n: int = 800) -> str:
    """Clip *text* to at most *n* characters, appending an ellipsis when cut.

    Falsy input (``None`` or ``""``) renders as the placeholder ``(empty)``.
    """
    if not text:
        return "(empty)"
    clipped = text[:n]
    if len(text) > n:
        clipped += "\u2026"
    return clipped
def parse_drawer_id(drawer_id: str) -> tuple[str, str]:
    """'145-great-britain-acts' -> ('145', 'Great Britain Acts')"""
    num, sep, rest = drawer_id.partition("-")
    if not sep:
        # No slug after the drawer number: label is empty.
        return num, ""
    # Turn the slug into a human-readable Title Case label.
    label = rest.replace("-", " ").replace(".", " ").strip().title()
    return num, label
def create_app(db_path: str = DB_PATH) -> FastAPI:
    """Construct the FastAPI application.

    Opens the Lance table, loads the query-embedding model, precomputes the
    drawer index, and defines the request handlers as closures over those
    shared handles.

    NOTE(review): the async handlers below are defined but no route
    registration (``@app.get`` / ``app.add_api_route``) is visible in this
    chunk — confirm the decorators were not lost.
    """
    app = FastAPI(title="BPL Card Catalog — OCR Search Comparison")
    app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
    templates = Jinja2Templates(directory=str(TEMPLATES_DIR))
    # Shared handles captured by the route closures below.
    db = lancedb.connect(db_path)
    table = db.open_table(TABLE_NAME)
    model = SentenceTransformer(EMBEDDING_MODEL)
    total_cards = table.count_rows()

    # -- Build drawer index at startup --
    # One full scan over drawer_id to compute per-drawer card counts.
    all_rows = table.search().select(["drawer_id"]).limit(total_cards).to_list()
    drawer_counts: dict[str, int] = {}
    for row in all_rows:
        did = row["drawer_id"]
        drawer_counts[did] = drawer_counts.get(did, 0) + 1
    drawer_list: list[dict] = []
    for did, count in drawer_counts.items():
        num, label = parse_drawer_id(did)
        drawer_list.append(
            {"drawer_id": did, "drawer_num": num, "drawer_label": label, "count": count}
        )
    # Numeric drawer numbers sort first (ascending); non-numeric ones are
    # pushed to the end via the 9999 sentinel, tie-broken by drawer_id.
    drawer_list.sort(
        key=lambda d: (
            int(d["drawer_num"]) if d["drawer_num"].isdigit() else 9999,
            d["drawer_id"],
        )
    )
    known_drawer_ids = {d["drawer_id"] for d in drawer_list}

    # -- Image cache (row_idx -> JPEG bytes) --
    # NOTE(review): grows without bound for the process lifetime — every
    # search/browse inserts new entries and nothing evicts them.
    image_cache: dict[int, bytes] = {}

    def _get_image_bytes(row_idx: int) -> bytes | None:
        """Fetch image bytes by Lance ``_rowid``, memoizing in image_cache.

        NOTE(review): appears unused by the visible handlers — they fill
        image_cache directly with ``len()``-based keys, which do NOT
        correspond to ``_rowid`` values. Confirm before relying on it.
        """
        if row_idx in image_cache:
            return image_cache[row_idx]
        rows = (
            table.search()
            .where(f"_rowid = {row_idx}")
            .select(["image"])
            .limit(1)
            .to_list()
        )
        if not rows:
            return None
        image_cache[row_idx] = rows[0]["image"]
        return rows[0]["image"]

    # -- Search functions --
    def search_old_vector(query: str, limit: int) -> list[dict]:
        """Vector search against embeddings of the old (Tesseract) OCR."""
        q_vec = model.encode(query, normalize_embeddings=True).tolist()
        return (
            table.search(q_vec, vector_column_name="old_ocr_embedding")
            .select(SELECT_COLS)
            .limit(limit)
            .to_list()
        )

    def search_new_vector(query: str, limit: int) -> list[dict]:
        """Vector search against embeddings of the new (VLM) OCR."""
        q_vec = model.encode(query, normalize_embeddings=True).tolist()
        return (
            table.search(q_vec, vector_column_name="new_ocr_embedding")
            .select(SELECT_COLS)
            .limit(limit)
            .to_list()
        )

    def search_old_fts(query: str, limit: int) -> list[dict]:
        """Full-text search over the 'text' (old OCR) column."""
        return (
            table.search(query, query_type="fts", fts_columns="text")
            .select(SELECT_COLS)
            .limit(limit)
            .to_list()
        )

    def search_new_fts(query: str, limit: int) -> list[dict]:
        """Full-text search over the 'markdown' (new OCR) column."""
        return (
            table.search(query, query_type="fts", fts_columns="markdown")
            .select(SELECT_COLS)
            .limit(limit)
            .to_list()
        )

    def format_results(
        results: list[dict],
        ocr_field: str,
        other_field: str,
        ocr_label: str,
        compare_label: str,
    ) -> list[dict]:
        """Shape raw search rows into template-ready dicts, caching images.

        ``ocr_field`` is the column shown as the primary OCR text and
        ``other_field`` the column shown for comparison; the two labels
        caption them in the template.
        """
        formatted = []
        for i, row in enumerate(results):
            # NOTE(review): an exact 0.0 score is falsy and falls through to
            # the next key — use `is not None` checks if zeros can occur.
            score = (
                row.get("_distance") or row.get("_score") or row.get("_relevance_score")
            )
            # Cache image bytes using a simple incrementing key
            # NOTE(review): when row["image"] is not bytes, nothing is
            # inserted, so this same row_idx is reused by the next row that
            # DOES carry bytes — the earlier result then serves the wrong
            # image from /image/{row_idx}.
            row_idx = len(image_cache)
            if isinstance(row.get("image"), bytes):
                image_cache[row_idx] = row["image"]
            drawer_id = row.get("drawer_id", "?")
            drawer_num, drawer_label = parse_drawer_id(drawer_id)
            formatted.append(
                {
                    "rank": i + 1,
                    "row_idx": row_idx,
                    "drawer_id": drawer_id,
                    "drawer_num": drawer_num,
                    "drawer_label": drawer_label,
                    "card_number": row.get("card_number", "?"),
                    "ocr_text": truncate(row.get(ocr_field, ""), 800),
                    "other_ocr": truncate(row.get(other_field, ""), 800),
                    "ocr_label": ocr_label,
                    "compare_label": compare_label,
                    "score": f"{score:.4f}" if score is not None else "",
                    "source_url": row.get("source_url", ""),
                }
            )
        return formatted

    # -- Routes --
    async def index(request: Request):
        """Landing page: corpus stats, example queries, default form state."""
        return templates.TemplateResponse(
            request,
            "index.html",
            {
                "total_cards": f"{total_cards:,}",
                "total_drawers": len(drawer_list),
                "examples": EXAMPLE_QUERIES,
                "query": "",
                "mode": "fts",
                "limit": 10,
                "old_ocr_label": OLD_OCR_LABEL,
                "new_ocr_label": NEW_OCR_LABEL,
            },
        )

    async def search(
        request: Request,
        query: str = "",
        mode: str = "fts",
        limit: int = 5,
    ):
        """Side-by-side comparison: run the query against both OCR versions."""
        # Clamp requested result count to [1, 20].
        limit = max(1, min(20, limit))
        if not query.strip():
            # Blank query renders an empty results panel.
            return templates.TemplateResponse(
                request,
                "results.html",
                {"query": "", "total_cards": f"{total_cards:,}"},
            )
        if mode == "fts":
            old_raw = search_old_fts(query, limit)
            new_raw = search_new_fts(query, limit)
        else:
            # Any non-"fts" mode falls back to vector search.
            old_raw = search_old_vector(query, limit)
            new_raw = search_new_vector(query, limit)
        old_label = f"Old OCR ({OLD_OCR_LABEL})"
        new_label = f"New OCR ({NEW_OCR_LABEL})"
        old_results = format_results(
            old_raw,
            "text",
            "markdown",
            old_label,
            new_label,
        )
        new_results = format_results(
            new_raw,
            "markdown",
            "text",
            new_label,
            old_label,
        )
        return templates.TemplateResponse(
            request,
            "results.html",
            {
                "query": query,
                "mode": mode,
                "old_results": old_results,
                "new_results": new_results,
                "total_cards": f"{total_cards:,}",
                "old_ocr_label": OLD_OCR_LABEL,
                "new_ocr_label": NEW_OCR_LABEL,
            },
        )

    async def search_single(
        request: Request,
        query: str = "",
        mode: str = "fts",
        limit: int = 10,
    ):
        """Single-column search against the new OCR only."""
        limit = max(1, min(20, limit))
        if not query.strip():
            return templates.TemplateResponse(
                request,
                "results-search.html",
                {"query": "", "total_cards": f"{total_cards:,}"},
            )
        if mode == "fts":
            raw = search_new_fts(query, limit)
        else:
            raw = search_new_vector(query, limit)
        new_label = f"New OCR ({NEW_OCR_LABEL})"
        old_label = f"Old OCR ({OLD_OCR_LABEL})"
        results = format_results(raw, "markdown", "text", new_label, old_label)
        return templates.TemplateResponse(
            request,
            "results-search.html",
            {
                "query": query,
                "mode": mode,
                "results": results,
                "total_cards": f"{total_cards:,}",
            },
        )

    async def random_cards(request: Request):
        """Spotlight panel: sample SPOTLIGHT_COUNT random cards."""
        indices = random.sample(
            range(total_cards), min(SPOTLIGHT_COUNT, total_cards)
        )
        cards = []
        for idx in indices:
            # One offset-scan per sampled index.
            rows = (
                table.search().select(SELECT_COLS).limit(1).offset(idx).to_list()
            )
            if not rows:
                continue
            row = rows[0]
            # Same len()-based cache-key scheme as format_results (and the
            # same collision caveat when the row has no image bytes).
            row_idx = len(image_cache)
            if isinstance(row.get("image"), bytes):
                image_cache[row_idx] = row["image"]
            drawer_id = row.get("drawer_id", "?")
            drawer_num, drawer_label = parse_drawer_id(drawer_id)
            cards.append(
                {
                    "row_idx": row_idx,
                    "drawer_id": drawer_id,
                    "drawer_num": drawer_num,
                    "drawer_label": drawer_label,
                    "card_number": row.get("card_number", "?"),
                    "ocr_text": truncate(row.get("markdown", ""), 200),
                    "source_url": row.get("source_url", ""),
                }
            )
        return templates.TemplateResponse(
            request, "spotlight.html", {"cards": cards}
        )

    async def drawers_index(request: Request):
        """List every drawer with its card count."""
        return templates.TemplateResponse(
            request,
            "drawers.html",
            {
                "drawers": drawer_list,
                "total_drawers": len(drawer_list),
                "total_cards": f"{total_cards:,}",
            },
        )

    async def drawer_detail(request: Request, drawer_id: str):
        """Browse all cards in one drawer, with prev/next navigation."""
        # Validate against known drawer IDs to prevent injection
        if drawer_id not in known_drawer_ids:
            return HTMLResponse("Drawer not found", status_code=404)
        rows = (
            table.search()
            .where(f"drawer_id = '{drawer_id}'", prefilter=True)
            .select(SELECT_COLS)
            .limit(2000)
            .to_list()
        )
        # NOTE(review): the default 0 assumes card_number is comparable with
        # int — confirm the column type (a str column would raise TypeError
        # on rows missing the value).
        rows.sort(key=lambda r: r.get("card_number", 0))
        cards = []
        for i, row in enumerate(rows):
            row_idx = len(image_cache)
            if isinstance(row.get("image"), bytes):
                image_cache[row_idx] = row["image"]
            cards.append(
                {
                    "card_number": row.get("card_number", i),
                    "row_idx": row_idx,
                    "ocr_text": truncate(row.get("markdown", ""), 800),
                    "source_url": row.get("source_url", ""),
                }
            )
        # Prev/next drawer navigation
        idx = next(
            (i for i, d in enumerate(drawer_list) if d["drawer_id"] == drawer_id), -1
        )
        prev_drawer = drawer_list[idx - 1] if idx > 0 else None
        next_drawer = drawer_list[idx + 1] if idx < len(drawer_list) - 1 else None
        drawer_num, drawer_label = parse_drawer_id(drawer_id)
        return templates.TemplateResponse(
            request,
            "drawer.html",
            {
                "drawer_id": drawer_id,
                "drawer_num": drawer_num,
                "drawer_label": drawer_label,
                "cards": cards,
                "card_count": len(cards),
                "prev_drawer": prev_drawer,
                "next_drawer": next_drawer,
            },
        )

    async def image(row_idx: int):
        """Serve a cached card image, re-encoded as JPEG (quality 85)."""
        img_bytes = image_cache.get(row_idx)
        if img_bytes is None:
            return HTMLResponse("Image not found", status_code=404)
        buf = io.BytesIO()
        img = Image.open(io.BytesIO(img_bytes))
        img.save(buf, format="JPEG", quality=85)
        buf.seek(0)
        return StreamingResponse(buf, media_type="image/jpeg")

    return app
def resolve_db_path(args) -> str:
    """Resolve database path from CLI args, env var, or HF Hub download."""
    # An explicit local --db-path always wins over any Hub source.
    if args.db_path:
        local = Path(args.db_path)
        if local.exists():
            return str(local)
        print(f"Database not found at {local}")
        print("Run 'uv run bpl-lance-poc.py build' first.")
        raise SystemExit(1)
    # Otherwise pull a dataset snapshot from the Hugging Face Hub.
    repo_id = args.from_hub or os.environ.get("BPL_HUB_REPO", DEFAULT_HUB_REPO)
    cache_base = os.environ.get("HF_HOME", "/tmp/hf_cache")
    local_dir = f"{cache_base}/bpl-lance"
    print(f"Downloading dataset from {repo_id} to {local_dir}...")
    snapshot_download(repo_id, repo_type="dataset", local_dir=local_dir)
    print("Download complete.")
    return local_dir
def main():
    """Parse CLI options, locate the Lance DB, and launch uvicorn."""
    cli = argparse.ArgumentParser(
        description="BPL OCR search comparison (FastAPI + HTMX)"
    )
    # --db-path and --from-hub are mutually exclusive data sources.
    origin = cli.add_mutually_exclusive_group()
    origin.add_argument(
        "--db-path",
        default=None,
        help="Path to local LanceDB directory (for local dev)",
    )
    origin.add_argument(
        "--from-hub",
        nargs="?",
        const=DEFAULT_HUB_REPO,
        default=None,
        help=f"Download Lance DB from HF Hub (default: {DEFAULT_HUB_REPO})",
    )
    cli.add_argument("--host", default="127.0.0.1")
    cli.add_argument("--port", type=int, default=8000)
    args = cli.parse_args()
    # Neither flag given: prefer an existing local DB, else fall back to Hub.
    if args.db_path is None and args.from_hub is None:
        if Path(DB_PATH).exists():
            args.db_path = DB_PATH
        else:
            args.from_hub = DEFAULT_HUB_REPO
    uvicorn.run(create_app(resolve_db_path(args)), host=args.host, port=args.port)


if __name__ == "__main__":
    main()