davanstrien's picture
davanstrien HF Staff
Upload app.py with huggingface_hub
a1a412f verified
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "lancedb>=0.17",
# "fastapi>=0.115",
# "uvicorn[standard]>=0.32",
# "jinja2>=3.1",
# "sentence-transformers",
# "pillow",
# "huggingface-hub",
# ]
# ///
"""
BPL Card Catalog — OCR Search Comparison
FastAPI + HTMX app comparing old (Tesseract) and new (GLM-OCR) search
results side by side. Clean, Tufte-inspired design.
Usage (local dev):
uv run app.py --db-path ../bpl-lance-db
Usage (HF Spaces / Hub dataset):
uv run app.py --from-hub davanstrien/bpl-card-catalog-lance
"""
from __future__ import annotations

import argparse
import io
import itertools
import os
import random
from collections import Counter, OrderedDict
from pathlib import Path

import lancedb
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from huggingface_hub import snapshot_download
from PIL import Image
from sentence_transformers import SentenceTransformer
# -- Data source -------------------------------------------------------------
# Local LanceDB directory that main() tries first.
DB_PATH = "../bpl-lance-db"
# Hub dataset repo used when no local database is available.
DEFAULT_HUB_REPO = "davanstrien/bpl-card-catalog-lance-full"
TABLE_NAME = "cards"
# Query encoder; presumably the same model that produced the stored
# *_ocr_embedding columns — confirm before changing.
EMBEDDING_MODEL = "BAAI/bge-base-en-v1.5"

# -- Display -----------------------------------------------------------------
SPOTLIGHT_COUNT = 3  # number of cards returned by the random spotlight
OLD_OCR_LABEL = "Tesseract"
NEW_OCR_LABEL = "VLM OCR"

# -- Filesystem layout (relative to this file, not the working directory) ----
APP_DIR = Path(__file__).parent
TEMPLATES_DIR = APP_DIR.joinpath("templates")
STATIC_DIR = APP_DIR.joinpath("static")

# Suggested searches offered on the landing page.
EXAMPLE_QUERIES = [
    "abolitionism",
    "Civil War letters",
    "Shakespeare plays",
    "Boston history",
    "illuminated manuscripts",
    "African American history",
    "French literature",
    "music composition",
    "botanical illustrations",
    "theater history",
]

# Columns fetched for every card that is rendered.
SELECT_COLS = ["drawer_id", "card_number", "text", "markdown", "source_url", "image"]
def truncate(text: str, n: int = 800) -> str:
    """Clip *text* to at most *n* characters for display.

    Falsy input (``""`` or ``None``) becomes the placeholder ``"(empty)"``.
    An ellipsis (U+2026) is appended only when characters were dropped.
    """
    if text:
        clipped = text[:n]
        return clipped + "\u2026" if len(text) > n else clipped
    return "(empty)"
def parse_drawer_id(drawer_id: str) -> tuple[str, str]:
    """Split a drawer id into its leading number and a title-cased label.

    Everything before the first ``-`` is the number. In the remainder,
    ``-`` and ``.`` become spaces, and the result is stripped and
    title-cased. An id without a ``-`` yields an empty label.

    >>> parse_drawer_id("145-great-britain-acts")
    ('145', 'Great Britain Acts')
    >>> parse_drawer_id("7")
    ('7', '')
    """
    number, _sep, slug = drawer_id.partition("-")
    if not slug:
        return number, ""
    spaced = slug.translate(str.maketrans("-.", "  "))
    return number, spaced.strip().title()
class _ImageCache:
    """Bounded least-recently-used store of card image bytes keyed by int handles.

    Handles come from a monotonically increasing counter, so a handle is never
    reused for a different image, even after its entry has been evicted. An
    evicted or unknown handle simply misses.
    """

    def __init__(self, max_items: int) -> None:
        self._max_items = max(1, max_items)
        self._items: OrderedDict[int, bytes] = OrderedDict()
        self._handles = itertools.count()

    def put(self, data: object) -> int:
        """Store *data* if it is ``bytes`` and return a fresh handle for it.

        A fresh handle is issued even when *data* is not ``bytes`` (e.g. a
        missing image). Such a card therefore can never alias a neighbouring
        card's image.
        """
        handle = next(self._handles)
        if isinstance(data, bytes):
            self._items[handle] = data
            # Evict least-recently-used entries once over the cap.
            while len(self._items) > self._max_items:
                self._items.popitem(last=False)
        return handle

    def get(self, handle: int) -> bytes | None:
        """Return the bytes stored under *handle*, or ``None`` if absent/evicted."""
        data = self._items.get(handle)
        if data is not None:
            self._items.move_to_end(handle)
        return data


def _pick_score(row: dict) -> float | None:
    """Return the first of ``_distance``, ``_score``, ``_relevance_score`` present on *row*.

    Uses explicit ``None`` checks rather than ``or``. A legitimate score of
    ``0.0`` is therefore reported instead of being skipped as falsy.
    """
    for key in ("_distance", "_score", "_relevance_score"):
        value = row.get(key)
        if value is not None:
            return value
    return None


def _build_drawer_index(table, total_cards: int) -> list[dict]:
    """Count cards per drawer and return one summary dict per drawer.

    Drawers are ordered by their numeric prefix, with non-numeric prefixes
    sorting as 9999, and then by full drawer id.
    """
    if total_cards <= 0:
        return []
    rows = table.search().select(["drawer_id"]).limit(total_cards).to_list()
    counts = Counter(row["drawer_id"] for row in rows)
    drawers = []
    for drawer_id, count in counts.items():
        num, label = parse_drawer_id(drawer_id)
        drawers.append(
            {"drawer_id": drawer_id, "drawer_num": num, "drawer_label": label, "count": count}
        )
    # isdecimal (not isdigit): int() rejects digit-like chars such as "²".
    drawers.sort(
        key=lambda d: (
            int(d["drawer_num"]) if d["drawer_num"].isdecimal() else 9999,
            d["drawer_id"],
        )
    )
    return drawers


def create_app(db_path: str = DB_PATH, *, image_cache_size: int = 4096) -> FastAPI:
    """Build the FastAPI app that serves search, drawer browsing and card images.

    Opens table ``TABLE_NAME`` in the LanceDB database at *db_path* and loads
    ``EMBEDDING_MODEL``. The drawer index is precomputed once at startup.

    Args:
        db_path: Directory of the LanceDB database.
        image_cache_size: Maximum number of card images kept in memory for the
            ``/image/{row_idx}`` route. It is raised to at least one full
            drawer page so that a page's own images are not evicted before the
            browser requests them. Older images are evicted first, and an
            evicted handle returns 404.

    Returns:
        The configured FastAPI application.
    """
    drawer_page_limit = 2000  # max cards rendered on a single drawer page

    app = FastAPI(title="BPL Card Catalog — OCR Search Comparison")
    app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
    templates = Jinja2Templates(directory=str(TEMPLATES_DIR))

    db = lancedb.connect(db_path)
    table = db.open_table(TABLE_NAME)
    model = SentenceTransformer(EMBEDDING_MODEL)

    total_cards = table.count_rows()
    total_cards_display = f"{total_cards:,}"

    # -- Drawer index, built once at startup --
    drawer_list = _build_drawer_index(table, total_cards)
    # drawer_id -> index in drawer_list. Also serves as the allow-list for
    # drawer ids interpolated into filters below.
    drawer_position = {d["drawer_id"]: i for i, d in enumerate(drawer_list)}

    # All routes are `async def` and never await, so under FastAPI they run
    # one at a time on the event loop; the cache is not locked.
    image_cache = _ImageCache(max(image_cache_size, drawer_page_limit))

    old_label = f"Old OCR ({OLD_OCR_LABEL})"
    new_label = f"New OCR ({NEW_OCR_LABEL})"

    # -- Search helpers --
    def encode(query: str) -> list[float]:
        """Embed *query* with the normalized sentence-transformer model."""
        return model.encode(query, normalize_embeddings=True).tolist()

    def vector_search(q_vec: list[float], column: str, limit: int) -> list[dict]:
        """Nearest-neighbour search of *q_vec* against embedding *column*."""
        return (
            table.search(q_vec, vector_column_name=column)
            .select(SELECT_COLS)
            .limit(limit)
            .to_list()
        )

    def fts_search(query: str, column: str, limit: int) -> list[dict]:
        """Full-text search of *query* over text *column*."""
        return (
            table.search(query, query_type="fts", fts_columns=column)
            .select(SELECT_COLS)
            .limit(limit)
            .to_list()
        )

    def format_results(
        results: list[dict],
        ocr_field: str,
        other_field: str,
        ocr_label: str,
        compare_label: str,
    ) -> list[dict]:
        """Turn raw result rows into template dicts, caching each row's image."""
        formatted = []
        for rank, row in enumerate(results, start=1):
            score = _pick_score(row)
            drawer_id = row.get("drawer_id", "?")
            drawer_num, drawer_label = parse_drawer_id(drawer_id)
            formatted.append(
                {
                    "rank": rank,
                    "row_idx": image_cache.put(row.get("image")),
                    "drawer_id": drawer_id,
                    "drawer_num": drawer_num,
                    "drawer_label": drawer_label,
                    "card_number": row.get("card_number", "?"),
                    "ocr_text": truncate(row.get(ocr_field, ""), 800),
                    "other_ocr": truncate(row.get(other_field, ""), 800),
                    "ocr_label": ocr_label,
                    "compare_label": compare_label,
                    "score": f"{score:.4f}" if score is not None else "",
                    "source_url": row.get("source_url", ""),
                }
            )
        return formatted

    # -- Routes --
    @app.get("/", response_class=HTMLResponse)
    async def index(request: Request):
        return templates.TemplateResponse(
            request,
            "index.html",
            {
                "total_cards": total_cards_display,
                "total_drawers": len(drawer_list),
                "examples": EXAMPLE_QUERIES,
                "query": "",
                "mode": "fts",
                "limit": 10,
                "old_ocr_label": OLD_OCR_LABEL,
                "new_ocr_label": NEW_OCR_LABEL,
            },
        )

    @app.get("/search", response_class=HTMLResponse)
    async def search(
        request: Request,
        query: str = "",
        mode: str = "fts",
        limit: int = 5,
    ):
        limit = max(1, min(20, limit))
        if not query.strip():
            return templates.TemplateResponse(
                request,
                "results.html",
                {"query": "", "total_cards": total_cards_display},
            )
        if mode == "fts":
            old_raw = fts_search(query, "text", limit)
            new_raw = fts_search(query, "markdown", limit)
        else:
            q_vec = encode(query)  # encode once, reuse for both columns
            old_raw = vector_search(q_vec, "old_ocr_embedding", limit)
            new_raw = vector_search(q_vec, "new_ocr_embedding", limit)
        old_results = format_results(old_raw, "text", "markdown", old_label, new_label)
        new_results = format_results(new_raw, "markdown", "text", new_label, old_label)
        return templates.TemplateResponse(
            request,
            "results.html",
            {
                "query": query,
                "mode": mode,
                "old_results": old_results,
                "new_results": new_results,
                "total_cards": total_cards_display,
                "old_ocr_label": OLD_OCR_LABEL,
                "new_ocr_label": NEW_OCR_LABEL,
            },
        )

    @app.get("/search-single", response_class=HTMLResponse)
    async def search_single(
        request: Request,
        query: str = "",
        mode: str = "fts",
        limit: int = 10,
    ):
        limit = max(1, min(20, limit))
        if not query.strip():
            return templates.TemplateResponse(
                request,
                "results-search.html",
                {"query": "", "total_cards": total_cards_display},
            )
        if mode == "fts":
            raw = fts_search(query, "markdown", limit)
        else:
            raw = vector_search(encode(query), "new_ocr_embedding", limit)
        results = format_results(raw, "markdown", "text", new_label, old_label)
        return templates.TemplateResponse(
            request,
            "results-search.html",
            {
                "query": query,
                "mode": mode,
                "results": results,
                "total_cards": total_cards_display,
            },
        )

    @app.get("/random-cards", response_class=HTMLResponse)
    async def random_cards(request: Request):
        indices = random.sample(range(total_cards), min(SPOTLIGHT_COUNT, total_cards))
        cards = []
        for idx in indices:
            rows = table.search().select(SELECT_COLS).limit(1).offset(idx).to_list()
            if not rows:
                continue
            row = rows[0]
            drawer_id = row.get("drawer_id", "?")
            drawer_num, drawer_label = parse_drawer_id(drawer_id)
            cards.append(
                {
                    "row_idx": image_cache.put(row.get("image")),
                    "drawer_id": drawer_id,
                    "drawer_num": drawer_num,
                    "drawer_label": drawer_label,
                    "card_number": row.get("card_number", "?"),
                    "ocr_text": truncate(row.get("markdown", ""), 200),
                    "source_url": row.get("source_url", ""),
                }
            )
        return templates.TemplateResponse(request, "spotlight.html", {"cards": cards})

    @app.get("/drawers", response_class=HTMLResponse)
    async def drawers_index(request: Request):
        return templates.TemplateResponse(
            request,
            "drawers.html",
            {
                "drawers": drawer_list,
                "total_drawers": len(drawer_list),
                "total_cards": total_cards_display,
            },
        )

    @app.get("/drawer/{drawer_id}", response_class=HTMLResponse)
    async def drawer_detail(request: Request, drawer_id: str):
        # Only ids seen in the table at startup are accepted.
        position = drawer_position.get(drawer_id)
        if position is None:
            return HTMLResponse("Drawer not found", status_code=404)
        # Also double any single quote so an id containing "'" cannot break
        # out of the SQL string literal.
        escaped_id = drawer_id.replace("'", "''")
        rows = (
            table.search()
            .where(f"drawer_id = '{escaped_id}'", prefilter=True)
            .select(SELECT_COLS)
            .limit(drawer_page_limit)
            .to_list()
        )
        rows.sort(key=lambda r: r.get("card_number", 0))
        cards = [
            {
                "card_number": row.get("card_number", i),
                "row_idx": image_cache.put(row.get("image")),
                "ocr_text": truncate(row.get("markdown", ""), 800),
                "source_url": row.get("source_url", ""),
            }
            for i, row in enumerate(rows)
        ]
        prev_drawer = drawer_list[position - 1] if position > 0 else None
        next_drawer = drawer_list[position + 1] if position < len(drawer_list) - 1 else None
        drawer_num, drawer_label = parse_drawer_id(drawer_id)
        return templates.TemplateResponse(
            request,
            "drawer.html",
            {
                "drawer_id": drawer_id,
                "drawer_num": drawer_num,
                "drawer_label": drawer_label,
                "cards": cards,
                "card_count": len(cards),
                "prev_drawer": prev_drawer,
                "next_drawer": next_drawer,
            },
        )

    @app.get("/image/{row_idx}")
    async def image(row_idx: int):
        img_bytes = image_cache.get(row_idx)
        if img_bytes is None:
            return HTMLResponse("Image not found", status_code=404)
        img = Image.open(io.BytesIO(img_bytes))
        # Pillow's JPEG writer rejects modes such as RGBA, LA and P (it raises
        # OSError), so normalise those to RGB before re-encoding.
        if img.mode not in ("RGB", "L", "CMYK"):
            img = img.convert("RGB")
        buf = io.BytesIO()
        img.save(buf, format="JPEG", quality=85)
        buf.seek(0)
        return StreamingResponse(buf, media_type="image/jpeg")

    return app
def resolve_db_path(args: argparse.Namespace) -> str:
    """Return the LanceDB directory to open, downloading it first if needed.

    Resolution order:

    1. ``args.db_path``: an explicit local directory. If it does not exist,
       the process exits with a non-zero status and a message on stderr.
    2. Otherwise, the Hub dataset repo named by ``args.from_hub``, else the
       ``BPL_HUB_REPO`` environment variable, else ``DEFAULT_HUB_REPO``. It is
       fetched with ``snapshot_download`` under ``$HF_HOME`` (default
       ``/tmp/hf_cache``).

    Each repo is downloaded into its own subdirectory. Switching repos
    therefore never mixes files from two snapshots in one Lance directory.
    """
    # Explicit local path takes priority.
    if args.db_path:
        db_path = Path(args.db_path)
        if not db_path.exists():
            raise SystemExit(
                f"Database not found at {db_path}\n"
                "Run 'uv run bpl-lance-poc.py build' first."
            )
        return str(db_path)

    # Download from HF Hub.
    repo_id = args.from_hub or os.environ.get("BPL_HUB_REPO", DEFAULT_HUB_REPO)
    cache_base = Path(os.environ.get("HF_HOME", "/tmp/hf_cache"))
    # "owner/name" -> "owner--name": one flat, filesystem-safe dir per repo.
    local_dir = cache_base / "bpl-lance" / repo_id.replace("/", "--")
    print(f"Downloading dataset from {repo_id} to {local_dir}...")
    snapshot_download(repo_id, repo_type="dataset", local_dir=str(local_dir))
    print("Download complete.")
    return str(local_dir)
def main() -> None:
    """CLI entry point: choose a database source, build the app, and serve it.

    With neither ``--db-path`` nor ``--from-hub``, the local ``DB_PATH`` is used
    if it exists. Otherwise the dataset is fetched from the Hub repo named by
    the ``BPL_HUB_REPO`` environment variable, falling back to
    ``DEFAULT_HUB_REPO``. A bare ``--from-hub`` uses the same repo.
    """
    # Resolve the env override here: both fallbacks below fill in
    # args.from_hub, so resolve_db_path's own env lookup would never be reached.
    hub_default = os.environ.get("BPL_HUB_REPO", DEFAULT_HUB_REPO)

    parser = argparse.ArgumentParser(
        description="BPL OCR search comparison (FastAPI + HTMX)"
    )
    source = parser.add_mutually_exclusive_group()
    source.add_argument(
        "--db-path",
        default=None,
        help="Path to local LanceDB directory (for local dev)",
    )
    source.add_argument(
        "--from-hub",
        nargs="?",
        const=hub_default,
        default=None,
        help=f"Download Lance DB from HF Hub (default: {hub_default})",
    )
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)
    args = parser.parse_args()

    # If neither --db-path nor --from-hub, try the default local path first.
    if args.db_path is None and args.from_hub is None:
        if Path(DB_PATH).exists():
            args.db_path = DB_PATH
        else:
            args.from_hub = hub_default

    db_path = resolve_db_path(args)
    app = create_app(db_path)
    uvicorn.run(app, host=args.host, port=args.port)


if __name__ == "__main__":
    main()