File size: 4,398 Bytes
7f9dfed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from __future__ import annotations

import csv
import json
import math
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any

from datasets.field_notes import FieldNote, FieldNoteStore


@dataclass(frozen=True)
class OCRPrediction:
    """One local OCR prediction that may need human correction.

    Instances are immutable (``frozen=True``). Use :meth:`from_row` to build
    one from a loosely-structured mapping such as a CSV row or a decoded JSON
    object.
    """

    # Location of the scanned input; from_row accepts several key aliases.
    source_path: str
    # The text the OCR engine produced.
    text: str
    # Engine confidence as a float; from_row falls back to 0.0 when unparsable.
    confidence: float
    # Optional page identifier, kept as a string; "" means "not provided".
    page: str = ""

    @classmethod
    def from_row(cls, row: dict[str, Any]) -> OCRPrediction:
        """Build a prediction from a mapping with flexible column names.

        For each field the first key holding a non-``None`` value wins:

        * source path: ``source_path``, ``image_path``, ``file_path``, ``path``
        * text: ``text``, ``prediction``, ``ocr_text``, ``response``
        * confidence: ``confidence``, ``score``, ``probability`` (default ``"0"``)
        * page: ``page``

        Args:
            row: One record, e.g. from ``csv.DictReader`` or ``json.loads``.

        Returns:
            A new ``OCRPrediction``; missing string fields become ``""``.
        """
        source_path = _first_present(row, ["source_path", "image_path", "file_path", "path"])
        text = _first_present(row, ["text", "prediction", "ocr_text", "response"])
        confidence_raw = _first_present(row, ["confidence", "score", "probability"], "0")
        # Go through the same helper as the other fields so an explicit null
        # (JSON ``null``, or a short CSV row padded with None) becomes ""
        # instead of the literal string "None".
        page = _first_present(row, ["page"])
        return cls(
            source_path=source_path,
            text=text,
            confidence=_parse_confidence(confidence_raw),
            page=page,
        )

    def to_dict(self) -> dict[str, object]:
        """Return the prediction's fields as a plain dictionary."""
        return asdict(self)


def load_ocr_predictions(path: str | Path) -> list[OCRPrediction]:
    """Load OCR predictions from local CSV, JSONL, or NDJSON files.

    The format is chosen from the file suffix (case-insensitive). CSV files
    are read with ``csv.DictReader``; JSONL/NDJSON files are read one JSON
    document per line, skipping blank lines. Every record is converted with
    ``OCRPrediction.from_row``.

    Args:
        path: Location of the prediction file.

    Returns:
        The predictions in file order.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the suffix is not ``.csv``, ``.jsonl`` or ``.ndjson``
            (``json.JSONDecodeError``, a ``ValueError`` subclass, is raised
            for malformed JSON lines).
    """

    source = Path(path)
    if not source.exists():
        raise FileNotFoundError(f"OCR prediction file not found: {source}")

    suffix = source.suffix.lower()
    # "utf-8-sig" decodes plain UTF-8 identically but also drops a leading
    # BOM, which would otherwise corrupt the first CSV header name or make
    # json.loads reject the first JSONL line.
    if suffix == ".csv":
        with source.open(newline="", encoding="utf-8-sig") as f:
            return [OCRPrediction.from_row(row) for row in csv.DictReader(f)]
    if suffix in {".jsonl", ".ndjson"}:
        # Iterate the file handle rather than str.splitlines(): splitlines()
        # also breaks on U+2028/U+2029/U+0085/\x0b/\x0c, which may legally
        # appear unescaped inside JSON strings and would cut a record in two.
        with source.open(encoding="utf-8-sig") as f:
            return [
                OCRPrediction.from_row(json.loads(line))
                for line in f
                if line.strip()
            ]
    raise ValueError("OCR predictions must be a .csv, .jsonl, or .ndjson file.")


def uncertain_predictions(
    predictions: list[OCRPrediction],
    confidence_threshold: float = 0.8,
) -> list[OCRPrediction]:
    """Select the predictions that should be reviewed by a human.

    A prediction is kept when its confidence is at or below
    ``confidence_threshold``, or when its text is empty after stripping
    whitespace. Input order is preserved.
    """
    flagged: list[OCRPrediction] = []
    for candidate in predictions:
        low_confidence = candidate.confidence <= confidence_threshold
        # The text is only inspected when the confidence check did not
        # already flag the prediction.
        if low_confidence or not candidate.text.strip():
            flagged.append(candidate)
    return flagged


def import_uncertain_predictions(
    store: FieldNoteStore,
    predictions: list[OCRPrediction],
    model_id: str,
    confidence_threshold: float = 0.8,
    tags: str = "ocr,uncertain",
) -> int:
    """Save one review note per uncertain prediction and return the count.

    Predictions are filtered with ``uncertain_predictions``. Each survivor is
    turned into a ``FieldNote`` with an empty correction and
    ``use_for_training=False``, then written through ``store.save``.
    """
    flagged = uncertain_predictions(predictions, confidence_threshold)
    for prediction in flagged:
        # Mention the page in the prompt only when one was recorded.
        if prediction.page:
            page_note = f" page {prediction.page}"
        else:
            page_note = ""
        store.save(
            FieldNote.create(
                model_id=model_id,
                prompt=f"Review OCR text for {prediction.source_path}{page_note}.",
                response=prediction.text,
                correction="",
                tags=tags,
                image_path=prediction.source_path,
                use_for_training=False,
            )
        )
    # Every flagged prediction was saved, so the list length is the count.
    return len(flagged)


def export_corrected_ocr_notes(
    store: FieldNoteStore,
    output_path: str | Path = "data/ocr_corrections.jsonl",
) -> Path:
    """Write corrected OCR notes to a JSONL file and return its path.

    Notes are fetched with ``store.list_notes(corrected_only=True, tag="ocr")``.
    Parent directories of ``output_path`` are created when missing, and the
    file is overwritten. Each note becomes one JSON object per line, with
    non-ASCII characters written as-is.
    """

    def _as_line(note: Any) -> str:
        # One JSONL record: the stored prediction next to its correction.
        record = {
            "source_path": note.image_path,
            "predicted_text": note.response,
            "corrected_text": note.correction,
            "model_id": note.model_id,
            "created_at": note.created_at,
            "tags": note.tags,
        }
        return json.dumps(record, ensure_ascii=False) + "\n"

    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    corrected = store.list_notes(corrected_only=True, tag="ocr")
    with destination.open("w", encoding="utf-8") as handle:
        # The generator is consumed lazily, so records are written one by one.
        handle.writelines(_as_line(note) for note in corrected)
    return destination


def ocr_import_summary(path: str | Path, confidence_threshold: float = 0.8) -> dict[str, object]:
    """Summarize a prediction file before importing it.

    Loads the file with ``load_ocr_predictions``, filters it with
    ``uncertain_predictions``, and reports the source, total row count,
    uncertain row count, the threshold used, and up to five uncertain
    predictions as dictionaries.
    """
    loaded = load_ocr_predictions(path)
    flagged = uncertain_predictions(loaded, confidence_threshold)

    summary: dict[str, object] = {}
    summary["source"] = str(path)
    summary["rows"] = len(loaded)
    summary["uncertain_rows"] = len(flagged)
    summary["confidence_threshold"] = confidence_threshold
    # Only a short preview is included, not the full uncertain list.
    summary["sample"] = [item.to_dict() for item in flagged[:5]]
    return summary


def _first_present(row: dict[str, Any], names: list[str], default: str = "") -> str:
    """Return, as a string, the first non-``None`` value found under ``names``.

    Keys are tried in the order given. Values such as ``""`` or ``0`` count
    as present; only ``None`` (or a missing key) is skipped. ``default`` is
    returned unchanged when no key yields a value.
    """
    looked_up = (row.get(key) for key in names)
    hit = next((value for value in looked_up if value is not None), None)
    return default if hit is None else str(hit)


def _parse_confidence(value: str) -> float:
    """Convert a raw confidence value to a float, defaulting to ``0.0``.

    Args:
        value: The raw confidence, normally a string such as ``"0.87"``.

    Returns:
        The parsed float, or ``0.0`` when the value cannot be converted or
        is not a finite number (NaN or +/-infinity).
    """
    try:
        confidence = float(value)
    except (TypeError, ValueError):
        # TypeError covers non-string, non-numeric input such as None.
        return 0.0
    # float() accepts "nan" and "inf". NaN compares False against any
    # threshold with <=, and +inf exceeds every threshold, so either one
    # would never be flagged as uncertain by a threshold comparison.
    # Treat them as unparsable instead.
    return confidence if math.isfinite(confidence) else 0.0