File size: 8,025 Bytes
5e4028d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
"""Claude vision post-correction.

One API call per document: the full scan image plus all TrOCR-transcribed lines
(prefixed with [LINE_N]) are sent together so the model has cross-line context.
The model returns per-line corrections and a self-reported confidence score.

Pass no_api=True to skip the API call and return raw TrOCR text unchanged
(useful for the --no-api CLI flag and offline testing).
"""

from __future__ import annotations

import base64
import os
import sys
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Sequence

import anthropic
from dotenv import load_dotenv
from PIL import Image, ImageOps

from src.ocr_trocr import Line

# Default Claude model ID; used as the default of post_correct()'s `model` argument.
MODEL_ID = "claude-haiku-4-5-20251001"

# JSON-schema description of the fields the model must return for each line.
# Key order is kept stable so the serialized tool definition does not change.
_LINE_ITEM_PROPERTIES: dict = {
    "line_id": {"type": "integer"},
    "original": {"type": "string"},
    "corrected": {"type": "string"},
    "changed": {"type": "boolean"},
    "llm_confidence": {
        "type": "number",
        "minimum": 0.0,
        "maximum": 1.0,
    },
}

# Schema of one element of the "lines" array; every declared field is required.
_LINE_ITEM_SCHEMA: dict = {
    "type": "object",
    "properties": _LINE_ITEM_PROPERTIES,
    "required": list(_LINE_ITEM_PROPERTIES),
}

# Tool definition handed to the Messages API. post_correct() forces the model
# to call this tool so that corrections come back as structured data.
_CORRECTION_TOOL: dict = {
    "name": "submit_corrections",
    "description": "Submit per-line corrections for the transcribed document.",
    "input_schema": {
        "type": "object",
        "properties": {
            "lines": {
                "type": "array",
                "items": _LINE_ITEM_SCHEMA,
            }
        },
        "required": ["lines"],
    },
}


@dataclass
class CorrectedLine:
    """A post-corrected line with provenance from both TrOCR and Claude.

    One instance is produced per input TrOCR line by post_correct(). When the
    API is skipped or the model returns nothing for a line, ``corrected``
    equals ``original``, ``changed`` is False and ``llm_confidence`` is None.
    """

    # Zero-based index of the line within the document (matches the [LINE_N] tag).
    line_id: int
    # Text exactly as transcribed by TrOCR.
    original: str
    # Text after post-correction; identical to `original` when nothing was corrected.
    corrected: str
    # Whether the line was reported as modified.
    changed: bool
    # Confidence reported by the model (schema range 0.0-1.0); None when no
    # model output is available for this line.
    llm_confidence: float | None
    # Bounding box copied from the source Line's `bbox`, if any.
    # NOTE(review): coordinate convention is defined by src.ocr_trocr.Line — confirm there.
    bbox: tuple[int, int, int, int] | None = None


@lru_cache(maxsize=1)
def _get_client() -> anthropic.Anthropic:
    """Return the shared Anthropic client, building it on first use.

    ``load_dotenv()`` runs before the lookup so that a key kept in a ``.env``
    file can be picked up from the environment. The successfully built client
    is memoised by ``lru_cache``; a failed call is not cached, so the lookup
    is retried on the next call.

    Raises:
        EnvironmentError: If ANTHROPIC_API_KEY is unset or empty.
    """
    load_dotenv()
    api_key = os.environ.get("ANTHROPIC_API_KEY")
    if api_key:
        return anthropic.Anthropic(api_key=api_key)
    raise EnvironmentError(
        "ANTHROPIC_API_KEY not set. Add it to .env or export it as an environment variable. "
        "See .env.example for the required format."
    )


@lru_cache(maxsize=1)
def _load_system_prompt() -> str:
    """Read the post-correction system prompt from disk (cached after first read).

    The prompt lives at ``prompts/v1/postcorrect.md`` under the directory that
    contains this module's own directory, and is decoded as UTF-8.
    """
    base_dir = Path(__file__).parent.parent
    prompt_file = base_dir / "prompts" / "v1" / "postcorrect.md"
    with prompt_file.open(encoding="utf-8") as handle:
        return handle.read()


def _encode_image(image_path: str | Path) -> tuple[str, str]:
    """Return (base64_data, media_type) for the given image file.

    Always re-encodes through PIL as JPEG. Two reasons:
      1. Anthropic accepts JPEG/PNG/GIF/WebP only — HEIC, TIFF, and other
         formats common on iPhone / scanner exports must be converted first.
         Reading the file as raw bytes and labelling it `image/jpeg` (the
         previous behaviour) trips a 400 "Could not process image" when the
         contents don't actually match the MIME type.
      2. PIL's HEIF plugin (registered in src.preprocess) handles HEIC files
         that are mislabelled with a `.jpg` / `.jpeg` extension by macOS
         Photos / iPhone exports.

    Args:
        image_path: Path to the scan to encode.

    Returns:
        A tuple of the base64-encoded JPEG bytes (as a str) and the media
        type, which is always "image/jpeg".
    """
    from io import BytesIO

    # Importing src.preprocess registers the HEIF opener as a side effect, so
    # PIL.Image.open() handles HEIC files even when their extension lies.
    import src.preprocess  # noqa: F401

    # Image.open() keeps the underlying file open until the pixel data is
    # loaded; the context manager guarantees the handle is released even if
    # decoding fails. exif_transpose()/convert() produce an in-memory image,
    # so nothing below depends on the file staying open.
    with Image.open(image_path) as source:
        rgb = ImageOps.exif_transpose(source).convert("RGB")

    buf = BytesIO()
    rgb.save(buf, format="JPEG", quality=92)
    return base64.standard_b64encode(buf.getvalue()).decode("utf-8"), "image/jpeg"


def _build_transcription_block(trocr_lines: Sequence[Line]) -> str:
    """Render the TrOCR lines as one newline-separated block of tagged text.

    Each line becomes ``[LINE_<index>] <text>`` with a zero-based index, so
    the model can refer back to individual lines by ID. An empty input
    yields an empty string.
    """
    tagged = [f"[LINE_{idx}] {entry.text}" for idx, entry in enumerate(trocr_lines)]
    return "\n".join(tagged)


def _passthrough_line(line_id: int, line: Line) -> CorrectedLine:
    """Wrap a TrOCR line as an unchanged CorrectedLine with no LLM confidence."""
    return CorrectedLine(
        line_id=line_id,
        original=line.text,
        corrected=line.text,
        changed=False,
        llm_confidence=None,
        bbox=line.bbox,
    )


def _index_corrections(tool_input: object) -> dict[int, dict]:
    """Map line_id -> correction item, silently dropping malformed entries.

    An entry is kept only if it is a dict with a numeric ``line_id`` and a
    string ``corrected``. If the same ``line_id`` appears more than once, the
    last entry wins.
    """
    if not isinstance(tool_input, dict):
        return {}
    items = tool_input.get("lines")
    if not isinstance(items, list):
        return {}

    indexed: dict[int, dict] = {}
    for item in items:
        if not isinstance(item, dict):
            continue
        if not isinstance(item.get("line_id"), (int, float)):
            continue
        if not isinstance(item.get("corrected"), str):
            continue
        indexed[item["line_id"]] = item
    return indexed


def post_correct(
    image_path: str | Path,
    trocr_lines: Sequence[Line],
    *,
    no_api: bool = False,
    model: str = MODEL_ID,
) -> list[CorrectedLine]:
    """Post-correct TrOCR output with Claude vision (one call per document).

    The scan image and every transcribed line (tagged ``[LINE_N]``) are sent
    in a single request, and the model is forced to answer through the
    ``submit_corrections`` tool. Any line the model omits, or returns in a
    malformed shape, falls back to the raw TrOCR text so that one bad entry
    never discards the rest of the document.

    Args:
        image_path: Path to the original scan.
        trocr_lines: Transcribed lines from ocr_trocr.transcribe().
        no_api: If True, skip the API call and return raw TrOCR text unchanged.
        model: Claude model ID to use.

    Returns:
        List of CorrectedLine, one per input line, in order.

    Raises:
        EnvironmentError: If the API is needed and ANTHROPIC_API_KEY is not set.
    """
    if no_api or not trocr_lines:
        return [_passthrough_line(i, line) for i, line in enumerate(trocr_lines)]

    client = _get_client()
    system_prompt = _load_system_prompt()
    image_data, media_type = _encode_image(image_path)
    transcription_block = _build_transcription_block(trocr_lines)

    user_text = (
        f"Please correct the following OCR transcription. "
        f"There are {len(trocr_lines)} lines.\n\n"
        f"{transcription_block}"
    )

    response = client.messages.create(
        model=model,
        max_tokens=4096,
        system=[
            {
                "type": "text",
                "text": system_prompt,
                # Cache the system prompt across calls in the same batch run.
                # Saves ~90% of input token cost on the system prompt for docs 2+.
                "cache_control": {"type": "ephemeral"},
            }
        ],
        tools=[_CORRECTION_TOOL],
        tool_choice={"type": "tool", "name": "submit_corrections"},
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": media_type,
                            "data": image_data,
                        },
                    },
                    {"type": "text", "text": user_text},
                ],
            }
        ],
    )

    # A response cut off at the token limit is the most likely reason for
    # missing or incomplete lines below, so surface it explicitly.
    if getattr(response, "stop_reason", None) == "max_tokens":
        print(
            "[postcorrect] response hit max_tokens; some lines may be missing",
            file=sys.stderr,
        )

    tool_use_block = next(
        (b for b in response.content if b.type == "tool_use"),
        None,
    )
    if tool_use_block is None:
        print(
            "[postcorrect] no tool_use block in response; returning raw TrOCR lines",
            file=sys.stderr,
        )
        return [_passthrough_line(i, line) for i, line in enumerate(trocr_lines)]

    # The tool schema is a request, not a guarantee: validate the payload
    # rather than indexing into it and crashing on a missing key.
    corrections = _index_corrections(tool_use_block.input)

    result: list[CorrectedLine] = []
    for i, line in enumerate(trocr_lines):
        item = corrections.get(i)
        if item is None:
            # Model didn't return a usable entry for this line; keep original.
            print(
                f"[postcorrect] LINE_{i} missing from model response; keeping original",
                file=sys.stderr,
            )
            result.append(_passthrough_line(i, line))
            continue

        corrected = item["corrected"]

        # Trust the model's flag when present; otherwise derive it from the text.
        changed = item.get("changed")
        if not isinstance(changed, bool):
            changed = corrected != line.text

        # A missing or non-numeric confidence is reported as "unknown".
        confidence = item.get("llm_confidence")
        if isinstance(confidence, bool) or not isinstance(confidence, (int, float)):
            confidence = None

        result.append(
            CorrectedLine(
                line_id=i,
                original=line.text,
                corrected=corrected,
                changed=changed,
                llm_confidence=confidence,
                bbox=line.bbox,
            )
        )

    return result