historical-doc-extractor / src /postcorrect.py
narayananv10
HF Space deploy snapshot
5e4028d
"""Claude vision post-correction.
One API call per document: the full scan image plus all TrOCR-transcribed lines
(prefixed with [LINE_N]) are sent together so the model has cross-line context.
The model returns per-line corrections and a self-reported confidence score.
Pass no_api=True to skip the API call and return raw TrOCR text unchanged
(useful for the --no-api CLI flag and offline testing).
"""
from __future__ import annotations
import base64
import os
import sys
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Sequence
import anthropic
from dotenv import load_dotenv
from PIL import Image, ImageOps
from src.ocr_trocr import Line
# Default Claude model for post-correction; post_correct() accepts an override.
MODEL_ID = "claude-haiku-4-5-20251001"

# JSON schema for a single corrected line inside the tool payload.
# Every property is required so each returned entry is self-describing.
_LINE_ITEM_SCHEMA: dict = {
    "type": "object",
    "properties": {
        "line_id": {"type": "integer"},
        "original": {"type": "string"},
        "corrected": {"type": "string"},
        "changed": {"type": "boolean"},
        "llm_confidence": {
            "type": "number",
            "minimum": 0.0,
            "maximum": 1.0,
        },
    },
    "required": ["line_id", "original", "corrected", "changed", "llm_confidence"],
}

# Tool definition sent with the request. The model is forced to call this tool,
# so its reply arrives as structured per-line corrections rather than free text.
_CORRECTION_TOOL: dict = {
    "name": "submit_corrections",
    "description": "Submit per-line corrections for the transcribed document.",
    "input_schema": {
        "type": "object",
        "properties": {
            "lines": {
                "type": "array",
                "items": _LINE_ITEM_SCHEMA,
            }
        },
        "required": ["lines"],
    },
}
@dataclass
class CorrectedLine:
    """One output line, pairing the TrOCR text with its post-corrected text.

    Attributes:
        line_id: Zero-based position of the line in the input sequence.
        original: Text exactly as transcribed by TrOCR.
        corrected: Text after post-correction; equal to ``original`` when the
            line was passed through untouched.
        changed: Whether the line is reported as modified.
        llm_confidence: The model's self-reported confidence, or ``None`` when
            no model output was used for this line.
        bbox: Bounding box copied from the source TrOCR line, if any.
    """

    line_id: int
    original: str
    corrected: str
    changed: bool
    llm_confidence: float | None
    bbox: tuple[int, int, int, int] | None = None
@lru_cache(maxsize=1)
def _get_client() -> anthropic.Anthropic:
    """Create the Anthropic client on first use and reuse it afterwards.

    Variables from a ``.env`` file are loaded before ``ANTHROPIC_API_KEY`` is
    read from the environment.

    Raises:
        EnvironmentError: If ``ANTHROPIC_API_KEY`` is unset or empty.
    """
    load_dotenv()
    api_key = os.environ.get("ANTHROPIC_API_KEY")
    if api_key:
        return anthropic.Anthropic(api_key=api_key)
    # A raised exception is not cached by lru_cache, so a later call retries
    # once the key has been provided.
    raise EnvironmentError(
        "ANTHROPIC_API_KEY not set. Add it to .env or export it as an environment variable. "
        "See .env.example for the required format."
    )
@lru_cache(maxsize=1)
def _load_system_prompt() -> str:
    """Return the text of ``prompts/v1/postcorrect.md`` (read once, then cached).

    The file is located two directories above this module and decoded as UTF-8.
    """
    base_dir = Path(__file__).parent.parent
    prompt_file = base_dir.joinpath("prompts", "v1", "postcorrect.md")
    return prompt_file.read_text(encoding="utf-8")
def _encode_image(image_path: str | Path) -> tuple[str, str]:
    """Return (base64_data, media_type) for the given image file.

    Always re-encodes through PIL as JPEG. Two reasons:
    1. Anthropic accepts JPEG/PNG/GIF/WebP only — HEIC, TIFF, and other
       formats common on iPhone / scanner exports must be converted first.
       Reading the file as raw bytes and labelling it `image/jpeg` (the
       previous behaviour) trips a 400 "Could not process image" when the
       contents don't actually match the MIME type.
    2. PIL's HEIF plugin (registered in src.preprocess) handles HEIC files
       that are mislabelled with a `.jpg` / `.jpeg` extension by macOS
       Photos / iPhone exports.

    Args:
        image_path: Path to the scan on disk.

    Returns:
        A ``(base64_data, media_type)`` pair; ``media_type`` is always
        ``"image/jpeg"``.
    """
    from io import BytesIO

    # Importing src.preprocess registers the HEIF opener as a side effect, so
    # PIL.Image.open() handles HEIC files even when their extension lies.
    import src.preprocess  # noqa: F401

    # Image.open() is lazy and keeps the underlying file open; for multi-frame
    # formats (e.g. TIFF) it stays open until the image is closed. The context
    # manager releases the handle deterministically instead of leaking it.
    with Image.open(image_path) as source:
        # exif_transpose() returns a new image and convert() yields loaded
        # pixel data, so `rgb` remains usable after the source file is closed.
        rgb = ImageOps.exif_transpose(source).convert("RGB")

    buf = BytesIO()
    rgb.save(buf, format="JPEG", quality=92)
    return base64.standard_b64encode(buf.getvalue()).decode("utf-8"), "image/jpeg"
def _build_transcription_block(trocr_lines: Sequence[Line]) -> str:
    """Render the lines as newline-separated text, each tagged ``[LINE_<index>]``.

    Indices are zero-based positions in ``trocr_lines``. An empty sequence
    yields an empty string.
    """
    tagged = [
        f"[LINE_{position}] {entry.text}"
        for position, entry in enumerate(trocr_lines)
    ]
    return "\n".join(tagged)
def _passthrough_line(index: int, line: Line) -> CorrectedLine:
    """Wrap one TrOCR line as an unchanged CorrectedLine with no LLM confidence."""
    return CorrectedLine(
        line_id=index,
        original=line.text,
        corrected=line.text,
        changed=False,
        llm_confidence=None,
        bbox=line.bbox,
    )


def _passthrough_all(trocr_lines: Sequence[Line]) -> list[CorrectedLine]:
    """Return every TrOCR line unchanged, one CorrectedLine per input, in order."""
    return [_passthrough_line(i, line) for i, line in enumerate(trocr_lines)]


def _index_corrections(tool_input: object) -> dict[int, dict] | None:
    """Index usable correction items by line_id, or return None if there is no 'lines' array.

    The tool schema is a request to the model, not a guarantee: the payload
    can be incomplete (for example when output is cut off) or omit fields.
    Items lacking an integer ``line_id`` or a string ``corrected`` are dropped
    so the caller treats those lines as missing instead of raising.
    """
    if not isinstance(tool_input, dict):
        return None
    items = tool_input.get("lines")
    if not isinstance(items, list):
        return None
    indexed: dict[int, dict] = {}
    for item in items:
        if (
            isinstance(item, dict)
            and isinstance(item.get("line_id"), int)
            and isinstance(item.get("corrected"), str)
        ):
            # As before, a later entry for the same line_id replaces an earlier one.
            indexed[item["line_id"]] = item
    return indexed


def post_correct(
    image_path: str | Path,
    trocr_lines: Sequence[Line],
    *,
    no_api: bool = False,
    model: str = MODEL_ID,
) -> list[CorrectedLine]:
    """Post-correct TrOCR output with Claude vision (one call per document).

    Any problem with the model's reply (no tool call, unusable payload, a line
    missing or malformed) degrades to the raw TrOCR text for the affected
    lines and is reported on stderr; it never raises.

    Args:
        image_path: Path to the original scan.
        trocr_lines: Transcribed lines from ocr_trocr.transcribe().
        no_api: If True, skip the API call and return raw TrOCR text unchanged.
        model: Claude model ID to use.

    Returns:
        List of CorrectedLine, one per input line, in order.

    Raises:
        EnvironmentError: If the API is needed and ANTHROPIC_API_KEY is not set.
    """
    if no_api or not trocr_lines:
        return _passthrough_all(trocr_lines)

    client = _get_client()
    system_prompt = _load_system_prompt()
    image_data, media_type = _encode_image(image_path)
    transcription_block = _build_transcription_block(trocr_lines)

    user_text = (
        "Please correct the following OCR transcription. "
        f"There are {len(trocr_lines)} lines.\n\n"
        f"{transcription_block}"
    )

    response = client.messages.create(
        model=model,
        max_tokens=4096,
        system=[
            {
                "type": "text",
                "text": system_prompt,
                # Cache the system prompt across calls in the same batch run.
                # Saves ~90% of input token cost on the system prompt for docs 2+.
                "cache_control": {"type": "ephemeral"},
            }
        ],
        tools=[_CORRECTION_TOOL],
        # Force the tool call; take the name from the definition so the two
        # cannot drift apart.
        tool_choice={"type": "tool", "name": _CORRECTION_TOOL["name"]},
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": media_type,
                            "data": image_data,
                        },
                    },
                    {"type": "text", "text": user_text},
                ],
            }
        ],
    )

    # A reply cut off at max_tokens can carry an incomplete tool payload;
    # surface that so missing trailing lines are explainable.
    if getattr(response, "stop_reason", None) == "max_tokens":
        print(
            "[postcorrect] response hit max_tokens; corrections may be incomplete",
            file=sys.stderr,
        )

    tool_use_block = next(
        (block for block in response.content if block.type == "tool_use"),
        None,
    )
    if tool_use_block is None:
        print(
            "[postcorrect] no tool_use block in response; returning raw TrOCR lines",
            file=sys.stderr,
        )
        return _passthrough_all(trocr_lines)

    corrections = _index_corrections(tool_use_block.input)
    if corrections is None:
        print(
            "[postcorrect] tool_use input has no 'lines' array; returning raw TrOCR lines",
            file=sys.stderr,
        )
        return _passthrough_all(trocr_lines)

    result: list[CorrectedLine] = []
    for i, line in enumerate(trocr_lines):
        item = corrections.get(i)
        if item is None:
            # Model didn't return this line (or returned it without a usable
            # line_id/corrected field); keep original.
            print(
                f"[postcorrect] LINE_{i} missing from model response; keeping original",
                file=sys.stderr,
            )
            result.append(_passthrough_line(i, line))
            continue

        corrected = item["corrected"]
        changed = item.get("changed")
        if not isinstance(changed, bool):
            # Flag absent or not a boolean: derive it from the text itself.
            changed = corrected != line.text
        confidence = item.get("llm_confidence")
        if isinstance(confidence, bool) or not isinstance(confidence, (int, float)):
            # Absent or non-numeric confidence is recorded as unknown.
            confidence = None
        result.append(
            CorrectedLine(
                line_id=i,
                original=line.text,
                corrected=corrected,
                changed=changed,
                llm_confidence=confidence,
                bbox=line.bbox,
            )
        )
    return result