| """Cache TrOCR + Claude vision post-correction over the IAM-HistDB GW subset. |
| |
| Writes data/parquet_cache/iam_gw_pipeline.parquet — the training set for the |
| flagger. One row per ground-truth line. Columns include TrOCR confidence |
| features, TrOCR/post-correction agreement features, the ground truth, and the |
| binary label `is_still_wrong = (CER(corrected, gt) > 0)`. |
| |
| Strategy: |
| - Use the FKI dataset's pre-segmented line images for TrOCR (one call/line). |
| This guarantees 1:1 alignment with the dataset's line-level ground truth. |
| - Use the original LoC page scan for the Claude vision post-correction call |
| (one call/page). FKI's release ships line images only; the original page |
| scans live at the Library of Congress (download via scripts/download_gw_pages.py). |
| Sending the real scan keeps the training-time and production-time inputs |
| in the same distribution. |
| - Idempotent: skips pages whose rows are already in the output parquet. |
| |
| Dataset layout assumed: |
| {root}/data/line_images_normalized/{line_id}.png (FKI release) |
| {root}/ground_truth/transcription.txt (FKI release) |
| {loc_pages}/{page_id}.jpg (LoC, downloaded separately) |
| |
| Run: |
| # First, get the LoC page scans (one-shot, ~5 MB): |
| python scripts/download_gw_pages.py |
| # Then cache the pipeline run (~$0.10, ~5 min): |
| python scripts/cache_iam_gw.py |
| python scripts/cache_iam_gw.py --dry-run # just discover & print |
| python scripts/cache_iam_gw.py --limit 2 # process first 2 pages |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import sys |
| from collections import defaultdict |
| from dataclasses import dataclass |
| from pathlib import Path |
|
|
| |
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) |
|
|
| import numpy as np |
| import pandas as pd |
| import pyarrow.parquet as pq |
| from jiwer import cer |
| from PIL import Image |
| from rapidfuzz.distance import Levenshtein |
|
|
| from src.ocr_trocr import Line, transcribe_one |
| from src.postcorrect import post_correct |
|
|
# Default locations used when the matching CLI flag is omitted. All three are
# relative paths, so they resolve against the current working directory.
DEFAULT_GW_ROOT = Path("data", "raw", "iam_gw", "washingtondb-v1.0")
DEFAULT_LOC_PAGES = Path("data", "raw", "iam_gw", "loc_pages")
DEFAULT_OUTPUT = Path("data", "parquet_cache", "iam_gw_pipeline.parquet")
|
|
| |
| |
| |
# Maps the special-token names that appear between "-" separators in the
# transcription file to the literal text they stand for.
GW_SPECIAL_CHARS = {
    "s_pt": ".",
    "s_cm": ",",
    "s_mi": "-",
    "s_qo": "'",
    "s_qc": "'",
    "s_qt": '"',
    "s_sq": ";",
    "s_cl": ":",
    "s_GW": "G.W.",
    "s_lb": "£",
    "s_bsl": "/",
    "s_et": "&",
    "s_br": ")",
    "s_bl": "(",
    "s_pl": "+",
    "s_eq": "=",
    "s_no": "No.",
    "s_pa": ",",
    "s_5s": "5s",
    "s_qu": "?",
    "s_excl": "!",
    "s_s": "s",
}
|
|
| |
| |
# Sub-directory names (under {root}/data) probed, in order, for line images.
LINE_DIR_CANDIDATES = [
    "line_images_normalized",
    "lines",
    "line_images",
]
# File extensions probed, in order, when resolving an image by its stem.
PAGE_EXT_CANDIDATES = [
    ".jpg",
    ".jpeg",
    ".png",
    ".tif",
    ".tiff",
]
|
|
|
|
@dataclass
class GWDoc:
    """One page of the GW subset together with its usable text lines."""

    # Page identifier: the part of a line id before its first "-".
    page_id: str
    # Path of the page scan found in the LoC pages directory.
    page_image: Path
    # (line_id, line image path, decoded ground-truth text) for every line
    # of the page that has a line image on disk.
    lines: list[tuple[str, Path, str]]
|
|
|
|
def _decode_gw_word(word: str) -> str:
    """Decode one GW-encoded word into plain text.

    A word is a "-"-separated sequence of tokens. Each token is either a
    literal character or a special ``s_*`` name.

    Resolution order per token:
      1. an entry in ``GW_SPECIAL_CHARS`` wins;
      2. any other ``s_<text>`` token decodes to ``<text>``, mirroring the
         table's own ``s_s`` -> ``s`` and ``s_5s`` -> ``5s`` entries, so that
         unlisted tokens do not leak their ``s_`` prefix into the ground truth
         (NOTE(review): presumably covers digit/ordinal tokens such as
         ``s_1`` / ``s_1st`` -- confirm against the dataset README);
      3. everything else is kept verbatim.

    Args:
        word: The encoded word, e.g. ``"O-r-d-e-r-s-s_cm"``.

    Returns:
        The decoded word with all tokens concatenated.
    """
    prefix = "s_"
    decoded: list[str] = []
    for token in word.split("-"):
        if token in GW_SPECIAL_CHARS:
            decoded.append(GW_SPECIAL_CHARS[token])
        elif token.startswith(prefix) and len(token) > len(prefix):
            # Unlisted special token: drop the marker, keep the payload.
            decoded.append(token[len(prefix):])
        else:
            decoded.append(token)
    return "".join(decoded)
|
|
|
|
def _decode_gw_transcription(text_part: str) -> str:
    """Decode a full GW line: "|" separates words, which are re-joined by spaces."""
    decoded_words = [_decode_gw_word(encoded) for encoded in text_part.split("|")]
    return " ".join(decoded_words)
|
|
|
|
def _parse_transcription(path: Path) -> dict[str, str]:
    """Parse the ground-truth transcription file into ``{line_id: text}``.

    Each non-empty line that does not start with "#" is expected to hold a
    line id, whitespace, and the encoded transcription. Lines without both
    fields are skipped. If a line id occurs more than once, the last
    occurrence wins.

    Args:
        path: UTF-8 encoded transcription file.

    Returns:
        Mapping from line id to the decoded transcription text.
    """
    decoded: dict[str, str] = {}
    for entry in path.read_text(encoding="utf-8").splitlines():
        entry = entry.strip()
        if not entry or entry.startswith("#"):
            continue
        # Split on any whitespace run, not just a single space: otherwise a
        # tab-separated line would be dropped silently and a double space
        # would leak a leading blank into the decoded text.
        fields = entry.split(None, 1)
        if len(fields) != 2:
            continue
        line_id, encoded_text = fields
        decoded[line_id] = _decode_gw_transcription(encoded_text)
    return decoded
|
|
|
|
| def _resolve_dir(root: Path, candidates: list[str]) -> Path | None: |
| for name in candidates: |
| cand = root / "data" / name |
| if cand.is_dir(): |
| return cand |
| return None |
|
|
|
|
def _resolve_image(directory: Path, stem: str) -> Path | None:
    """Return the first existing ``directory/<stem><ext>``, or ``None``.

    Extensions are tried in the order listed in ``PAGE_EXT_CANDIDATES``.
    """
    probes = (directory / f"{stem}{ext}" for ext in PAGE_EXT_CANDIDATES)
    return next((probe for probe in probes if probe.exists()), None)
|
|
|
|
def _discover_gw(root: Path, loc_pages: Path) -> list[GWDoc]:
    """Locate every page that has a scan and at least one usable line image.

    Exits the process (via ``sys.exit``) when the transcription file, the
    line-image directory, or the page-scan directory is missing. Pages without
    a scan and lines without an image are reported on stderr and skipped.

    Args:
        root: Extracted dataset root (holds ``ground_truth/`` and ``data/``).
        loc_pages: Directory containing one page scan per page id.

    Returns:
        Pages sorted by page id; within a page, lines sorted by line id.
    """
    gt_file = root / "ground_truth" / "transcription.txt"
    if not gt_file.exists():
        sys.exit(
            f"transcription file not found: {gt_file}\n"
            f"Pass --root pointing at your extracted GW dataset, or adjust "
            f"DEFAULT_GW_ROOT in this file."
        )

    line_dir = _resolve_dir(root, LINE_DIR_CANDIDATES)
    if line_dir is None:
        sys.exit(f"No line-image directory found under {root / 'data'} "
                 f"(tried: {LINE_DIR_CANDIDATES})")

    if not loc_pages.is_dir():
        sys.exit(
            f"LoC page directory not found: {loc_pages}\n"
            f"Run `python scripts/download_gw_pages.py` first to fetch the "
            f"original page scans from loc.gov."
        )

    print(f"[discover] line images: {line_dir}", file=sys.stderr)
    print(f"[discover] page scans: {loc_pages}", file=sys.stderr)

    ground_truth = _parse_transcription(gt_file)
    print(f"[discover] {len(ground_truth)} ground-truth lines", file=sys.stderr)

    # Group (line_id, text) pairs under the page they belong to; the page id
    # is everything before the first "-" of the line id.
    by_page: dict[str, list[tuple[str, str]]] = {}
    for line_id, text in ground_truth.items():
        page_key = line_id.partition("-")[0]
        by_page.setdefault(page_key, []).append((line_id, text))

    found: list[GWDoc] = []
    for page_id in sorted(by_page):
        scan = _resolve_image(loc_pages, page_id)
        if scan is None:
            print(f"[skip page {page_id}] no LoC scan in {loc_pages} "
                  f"(run download_gw_pages.py)", file=sys.stderr)
            continue

        usable: list[tuple[str, Path, str]] = []
        for line_id, gt_text in sorted(by_page[page_id], key=lambda pair: pair[0]):
            line_image = _resolve_image(line_dir, line_id)
            if line_image is None:
                print(f"[skip line {line_id}] no image", file=sys.stderr)
            else:
                usable.append((line_id, line_image, gt_text))

        if not usable:
            continue
        found.append(GWDoc(page_id=page_id, page_image=scan, lines=usable))

    print(f"[discover] {len(found)} pages with at least one usable line",
          file=sys.stderr)
    return found
|
|
|
|
| def _logprob_features(token_logprobs: list[float]) -> dict: |
| if not token_logprobs: |
| return { |
| "n_tokens": 0, |
| "mean_logprob": 0.0, |
| "min_logprob": 0.0, |
| "std_logprob": 0.0, |
| "length_normalized_logprob": 0.0, |
| } |
| arr = np.asarray(token_logprobs, dtype=np.float64) |
| return { |
| "n_tokens": int(arr.size), |
| "mean_logprob": float(arr.mean()), |
| "min_logprob": float(arr.min()), |
| "std_logprob": float(arr.std()), |
| "length_normalized_logprob": float(arr.sum() / max(arr.size, 1)), |
| } |
|
|
|
|
def _agreement_features(trocr_text: str, corrected_text: str) -> dict:
    """Measure how much the post-correction changed the TrOCR output.

    Returns the Levenshtein distance between the two strings (reported under
    both ``edit_distance_trocr_vs_corrected`` and ``n_chars_changed``) and that
    distance divided by the TrOCR text length (treated as 1 when empty).
    """
    edits = int(Levenshtein.distance(trocr_text, corrected_text))
    denominator = len(trocr_text) or 1  # avoid dividing by zero on empty OCR output
    features = {
        "edit_distance_trocr_vs_corrected": edits,
        "n_chars_changed": edits,
    }
    features["frac_chars_changed"] = float(edits / denominator)
    return features
|
|
|
|
def _process_doc(doc: GWDoc, *, no_api: bool) -> list[dict]:
    """Run TrOCR and post-correction on one page and build its output rows.

    Every line image is transcribed on its own with ``transcribe_one``; the
    resulting lines are then sent, together with the page scan, through
    ``post_correct`` in a single call. One row is produced per line.

    Args:
        doc: The page to process.
        no_api: Forwarded unchanged to ``post_correct``.

    Returns:
        One dict per line with the TrOCR output, log-prob features, corrected
        text, agreement features, line image size, ground truth, CER of the
        corrected text against the ground truth, and ``is_still_wrong``.

    Raises:
        ValueError: If post-correction does not return exactly one result per
            transcribed line. Pairing them up anyway would silently truncate
            the page and the resume logic would then treat it as complete.
    """
    print(f"[page {doc.page_id}] {len(doc.lines)} lines", file=sys.stderr)

    trocr_lines: list[Line] = []
    for _line_id, line_image_path, _gt in doc.lines:
        # Close the image file deterministically once the pixels are copied.
        with Image.open(line_image_path) as handle:
            img = np.array(handle.convert("RGB"))
        text, logprobs, token_ids = transcribe_one(img)
        trocr_lines.append(
            Line(
                text=text,
                # The line image is the whole region: (0, 0, width, height).
                bbox=(0, 0, img.shape[1], img.shape[0]),
                token_logprobs=logprobs,
                token_ids=token_ids,
            )
        )

    # Materialise so the length can be checked even if an iterator comes back.
    corrected = list(post_correct(doc.page_image, trocr_lines, no_api=no_api))
    if len(corrected) != len(trocr_lines):
        raise ValueError(
            f"post-correction returned {len(corrected)} lines for "
            f"{len(trocr_lines)} TrOCR lines on page {doc.page_id}"
        )

    rows: list[dict] = []
    for (line_id, _path, gt_text), trocr_line, corrected_line in zip(
        doc.lines, trocr_lines, corrected
    ):
        gt_cer = cer(reference=gt_text, hypothesis=corrected_line.corrected)
        rows.append(
            {
                "doc_id": doc.page_id,
                "line_id": line_id,
                "trocr_text": trocr_line.text,
                "trocr_token_logprobs": list(trocr_line.token_logprobs),
                **_logprob_features(trocr_line.token_logprobs),
                "corrected_text": corrected_line.corrected,
                "llm_confidence": corrected_line.llm_confidence,
                "changed": corrected_line.changed,
                **_agreement_features(trocr_line.text, corrected_line.corrected),
                "line_height_px": int(trocr_line.bbox[3]),
                "line_width_px": int(trocr_line.bbox[2]),
                "gt": gt_text,
                "gt_cer": float(gt_cer),
                "is_still_wrong": bool(gt_cer > 0),
            }
        )

    return rows
|
|
|
|
| def _existing_doc_ids(parquet_path: Path) -> set[str]: |
| if not parquet_path.exists(): |
| return set() |
| table = pq.read_table(parquet_path, columns=["doc_id"]) |
| return set(table["doc_id"].to_pylist()) |
|
|
|
|
def main() -> int:
    """Command-line entry point: discover the dataset and cache the pipeline run.

    Pages whose ``doc_id`` is already present in the output parquet are
    skipped. After every successfully processed page the full table (previous
    rows plus new rows) is rewritten, so an interrupted run can be resumed.
    A page whose processing raises is logged and skipped; it stays absent from
    the parquet, so a later run picks it up again.

    Returns:
        The process exit status, always 0. Fatal discovery problems terminate
        through ``sys.exit`` inside ``_discover_gw`` instead.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--root", type=Path, default=DEFAULT_GW_ROOT,
                        help="Path to extracted FKI GW dataset")
    parser.add_argument("--loc-pages", type=Path, default=DEFAULT_LOC_PAGES,
                        help="Directory of LoC page scans (one JPG per page id; "
                             "produced by scripts/download_gw_pages.py)")
    parser.add_argument("--output", type=Path, default=DEFAULT_OUTPUT,
                        help="Parquet path to write")
    parser.add_argument("--no-api", action="store_true",
                        help="Skip Claude post-correction (dev mode; the resulting "
                             "parquet is useless for training the flagger)")
    parser.add_argument("--limit", type=int, default=None,
                        help="Stop after this many pages (testing only)")
    parser.add_argument("--dry-run", action="store_true",
                        help="Discover the dataset and exit without running the pipeline")
    args = parser.parse_args()

    docs = _discover_gw(args.root, args.loc_pages)
    if args.dry_run:
        # Test against None (as the real run below does) so `--limit 0` means
        # "zero pages" rather than falling through to "no limit".
        shown = docs if args.limit is None else docs[: args.limit]
        for doc in shown:
            print(f"  {doc.page_id}: {len(doc.lines)} lines, "
                  f"page={doc.page_image.name}")
        return 0

    args.output.parent.mkdir(parents=True, exist_ok=True)
    already_done = _existing_doc_ids(args.output)
    if already_done:
        print(f"[resume] skipping {len(already_done)} pages already in {args.output}",
              file=sys.stderr)

    pending = [d for d in docs if d.page_id not in already_done]
    if args.limit is not None:
        pending = pending[: args.limit]

    if not pending:
        print("[done] nothing to do", file=sys.stderr)
        return 0

    existing_rows: list[dict] = []
    if args.output.exists():
        existing_rows = pq.read_table(args.output).to_pylist()

    # Write to a sibling temp file and rename it over the target, so that an
    # interruption mid-write cannot corrupt the rows cached so far.
    tmp_output = args.output.with_name(args.output.name + ".tmp")

    new_rows: list[dict] = []
    failed_pages: list[str] = []
    for i, doc in enumerate(pending, start=1):
        try:
            rows = _process_doc(doc, no_api=args.no_api)
        except Exception as exc:
            # Best effort: one bad page must not abort the whole run.
            print(f"[error page {doc.page_id}] {exc!r}", file=sys.stderr)
            failed_pages.append(doc.page_id)
            continue
        new_rows.extend(rows)
        df = pd.DataFrame(existing_rows + new_rows)
        df.to_parquet(tmp_output, index=False)
        tmp_output.replace(args.output)
        print(f"[saved] {i}/{len(pending)} rows total: {len(df)}",
              file=sys.stderr)

    if failed_pages:
        print(f"[warn] {len(failed_pages)} page(s) failed and were not cached: "
              f"{', '.join(failed_pages)}", file=sys.stderr)

    print(f"\n[done] wrote {args.output} ({len(existing_rows) + len(new_rows)} rows)")
    return 0
|
|
|
|
| if __name__ == "__main__": |
| sys.exit(main()) |
|
|