"""Cache TrOCR + Claude vision post-correction over the IAM-HistDB GW subset. Writes data/parquet_cache/iam_gw_pipeline.parquet — the training set for the flagger. One row per ground-truth line. Columns include TrOCR confidence features, TrOCR/post-correction agreement features, the ground truth, and the binary label `is_still_wrong = (CER(corrected, gt) > 0)`. Strategy: - Use the FKI dataset's pre-segmented line images for TrOCR (one call/line). This guarantees 1:1 alignment with the dataset's line-level ground truth. - Use the original LoC page scan for the Claude vision post-correction call (one call/page). FKI's release ships line images only; the original page scans live at the Library of Congress (download via scripts/download_gw_pages.py). Sending the real scan keeps the training-time and production-time inputs in the same distribution. - Idempotent: skips pages whose rows are already in the output parquet. Dataset layout assumed: {root}/data/line_images_normalized/{line_id}.png (FKI release) {root}/ground_truth/transcription.txt (FKI release) {loc_pages}/{page_id}.jpg (LoC, downloaded separately) Run: # First, get the LoC page scans (one-shot, ~5 MB): python scripts/download_gw_pages.py # Then cache the pipeline run (~$0.10, ~5 min): python scripts/cache_iam_gw.py python scripts/cache_iam_gw.py --dry-run # just discover & print python scripts/cache_iam_gw.py --limit 2 # process first 2 pages """ from __future__ import annotations import argparse import sys from collections import defaultdict from dataclasses import dataclass from pathlib import Path # Allow running as `python scripts/cache_iam_gw.py` from the repo root. sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) import numpy as np import pandas as pd import pyarrow.parquet as pq from jiwer import cer from PIL import Image from rapidfuzz.distance import Levenshtein from src.ocr_trocr import Line, transcribe_one from src.postcorrect import post_correct DEFAULT_GW_ROOT = Path("data/raw/iam_gw/washingtondb-v1.0") DEFAULT_LOC_PAGES = Path("data/raw/iam_gw/loc_pages") DEFAULT_OUTPUT = Path("data/parquet_cache/iam_gw_pipeline.parquet") # GW transcriptions encode special characters as "s_xx" tokens. This map covers # the common ones; extend if you hit unmapped tokens (you'll see them verbatim # in the gt column with an `s_` prefix). GW_SPECIAL_CHARS = { "s_pt": ".", "s_cm": ",", "s_mi": "-", "s_qo": "'", "s_qc": "'", "s_qt": '"', "s_sq": ";", "s_cl": ":", "s_GW": "G.W.", "s_lb": "£", "s_bsl": "/", "s_et": "&", "s_br": ")", "s_bl": "(", "s_pl": "+", "s_eq": "=", "s_no": "No.", "s_pa": ",", "s_5s": "5s", "s_qu": "?", "s_excl": "!", "s_s": "s", } # Candidate line-image directories (different versions of the GW release name # this differently). First hit wins. LINE_DIR_CANDIDATES = ["line_images_normalized", "lines", "line_images"] PAGE_EXT_CANDIDATES = [".jpg", ".jpeg", ".png", ".tif", ".tiff"] @dataclass class GWDoc: page_id: str page_image: Path lines: list[tuple[str, Path, str]] # (line_id, line_image_path, gt_text) def _decode_gw_word(word: str) -> str: chars = word.split("-") return "".join(GW_SPECIAL_CHARS.get(c, c) for c in chars) def _decode_gw_transcription(text_part: str) -> str: return " ".join(_decode_gw_word(w) for w in text_part.split("|")) def _parse_transcription(path: Path) -> dict[str, str]: out: dict[str, str] = {} for raw in path.read_text(encoding="utf-8").splitlines(): raw = raw.strip() if not raw or raw.startswith("#"): continue parts = raw.split(" ", 1) if len(parts) != 2: continue line_id, text_part = parts out[line_id] = _decode_gw_transcription(text_part) return out def _resolve_dir(root: Path, candidates: list[str]) -> Path | None: for name in candidates: cand = root / "data" / name if cand.is_dir(): return cand return None def _resolve_image(directory: Path, stem: str) -> Path | None: for ext in PAGE_EXT_CANDIDATES: cand = directory / f"{stem}{ext}" if cand.exists(): return cand return None def _discover_gw(root: Path, loc_pages: Path) -> list[GWDoc]: transcription = root / "ground_truth" / "transcription.txt" if not transcription.exists(): sys.exit( f"transcription file not found: {transcription}\n" f"Pass --root pointing at your extracted GW dataset, or adjust " f"DEFAULT_GW_ROOT in this file." ) line_dir = _resolve_dir(root, LINE_DIR_CANDIDATES) if line_dir is None: sys.exit(f"No line-image directory found under {root / 'data'} " f"(tried: {LINE_DIR_CANDIDATES})") if not loc_pages.is_dir(): sys.exit( f"LoC page directory not found: {loc_pages}\n" f"Run `python scripts/download_gw_pages.py` first to fetch the " f"original page scans from loc.gov." ) print(f"[discover] line images: {line_dir}", file=sys.stderr) print(f"[discover] page scans: {loc_pages}", file=sys.stderr) gt = _parse_transcription(transcription) print(f"[discover] {len(gt)} ground-truth lines", file=sys.stderr) pages: dict[str, list[tuple[str, str]]] = defaultdict(list) for line_id, text in gt.items(): page_id = line_id.split("-")[0] pages[page_id].append((line_id, text)) docs: list[GWDoc] = [] for page_id in sorted(pages.keys()): page_image = _resolve_image(loc_pages, page_id) if page_image is None: print(f"[skip page {page_id}] no LoC scan in {loc_pages} " f"(run download_gw_pages.py)", file=sys.stderr) continue lines = sorted(pages[page_id], key=lambda t: t[0]) line_paths: list[tuple[str, Path, str]] = [] for line_id, gt_text in lines: line_img = _resolve_image(line_dir, line_id) if line_img is None: print(f"[skip line {line_id}] no image", file=sys.stderr) continue line_paths.append((line_id, line_img, gt_text)) if line_paths: docs.append(GWDoc(page_id=page_id, page_image=page_image, lines=line_paths)) print(f"[discover] {len(docs)} pages with at least one usable line", file=sys.stderr) return docs def _logprob_features(token_logprobs: list[float]) -> dict: if not token_logprobs: return { "n_tokens": 0, "mean_logprob": 0.0, "min_logprob": 0.0, "std_logprob": 0.0, "length_normalized_logprob": 0.0, } arr = np.asarray(token_logprobs, dtype=np.float64) return { "n_tokens": int(arr.size), "mean_logprob": float(arr.mean()), "min_logprob": float(arr.min()), "std_logprob": float(arr.std()), "length_normalized_logprob": float(arr.sum() / max(arr.size, 1)), } def _agreement_features(trocr_text: str, corrected_text: str) -> dict: distance = Levenshtein.distance(trocr_text, corrected_text) base_len = max(len(trocr_text), 1) return { "edit_distance_trocr_vs_corrected": int(distance), "n_chars_changed": int(distance), "frac_chars_changed": float(distance / base_len), } def _process_doc(doc: GWDoc, *, no_api: bool) -> list[dict]: print(f"[page {doc.page_id}] {len(doc.lines)} lines", file=sys.stderr) trocr_lines: list[Line] = [] for line_id, line_image_path, _gt in doc.lines: img = np.array(Image.open(line_image_path).convert("RGB")) text, logprobs, token_ids = transcribe_one(img) trocr_lines.append( Line( text=text, bbox=(0, 0, img.shape[1], img.shape[0]), token_logprobs=logprobs, token_ids=token_ids, ) ) corrected = post_correct(doc.page_image, trocr_lines, no_api=no_api) rows: list[dict] = [] for (line_id, _path, gt_text), trocr_line, corrected_line in zip( doc.lines, trocr_lines, corrected ): gt_cer = cer(reference=gt_text, hypothesis=corrected_line.corrected) row = { "doc_id": doc.page_id, "line_id": line_id, "trocr_text": trocr_line.text, "trocr_token_logprobs": list(trocr_line.token_logprobs), **_logprob_features(trocr_line.token_logprobs), "corrected_text": corrected_line.corrected, "llm_confidence": corrected_line.llm_confidence, "changed": corrected_line.changed, **_agreement_features(trocr_line.text, corrected_line.corrected), "line_height_px": int(trocr_line.bbox[3]), "line_width_px": int(trocr_line.bbox[2]), "gt": gt_text, "gt_cer": float(gt_cer), "is_still_wrong": bool(gt_cer > 0), } rows.append(row) return rows def _existing_doc_ids(parquet_path: Path) -> set[str]: if not parquet_path.exists(): return set() table = pq.read_table(parquet_path, columns=["doc_id"]) return set(table["doc_id"].to_pylist()) def main() -> int: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--root", type=Path, default=DEFAULT_GW_ROOT, help="Path to extracted FKI GW dataset") parser.add_argument("--loc-pages", type=Path, default=DEFAULT_LOC_PAGES, help="Directory of LoC page scans (one JPG per page id; " "produced by scripts/download_gw_pages.py)") parser.add_argument("--output", type=Path, default=DEFAULT_OUTPUT, help="Parquet path to write") parser.add_argument("--no-api", action="store_true", help="Skip Claude post-correction (dev mode; the resulting " "parquet is useless for training the flagger)") parser.add_argument("--limit", type=int, default=None, help="Stop after this many pages (testing only)") parser.add_argument("--dry-run", action="store_true", help="Discover the dataset and exit without running the pipeline") args = parser.parse_args() docs = _discover_gw(args.root, args.loc_pages) if args.dry_run: for doc in docs[: (args.limit or len(docs))]: print(f" {doc.page_id}: {len(doc.lines)} lines, " f"page={doc.page_image.name}") return 0 args.output.parent.mkdir(parents=True, exist_ok=True) already_done = _existing_doc_ids(args.output) if already_done: print(f"[resume] skipping {len(already_done)} pages already in {args.output}", file=sys.stderr) pending = [d for d in docs if d.page_id not in already_done] if args.limit is not None: pending = pending[: args.limit] if not pending: print("[done] nothing to do", file=sys.stderr) return 0 existing_rows: list[dict] = [] if args.output.exists(): existing_rows = pq.read_table(args.output).to_pylist() new_rows: list[dict] = [] for i, doc in enumerate(pending, start=1): try: rows = _process_doc(doc, no_api=args.no_api) except Exception as exc: print(f"[error page {doc.page_id}] {exc!r}", file=sys.stderr) continue new_rows.extend(rows) df = pd.DataFrame(existing_rows + new_rows) df.to_parquet(args.output, index=False) print(f"[saved] {i}/{len(pending)} rows total: {len(df)}", file=sys.stderr) print(f"\n[done] wrote {args.output} ({len(existing_rows) + len(new_rows)} rows)") return 0 if __name__ == "__main__": sys.exit(main())