File size: 7,100 Bytes
5e4028d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
"""Flagger: predicts which post-corrected lines still need human review.

Two implementations behind a single flag(features) -> FlaggerOutput interface:
  - learned (Phase 6+): sklearn classifier loaded from models/flagger_v1.pkl
  - rule-based: hand-tuned thresholds on the same feature set

Phase 5 ships rule-based only; the learned path is wired in Phase 6 with the
same FlaggerOutput contract so the pipeline / Streamlit UI don't need to
change. A missing or unloadable model file always degrades to rule-based
(load failures are logged loudly to stderr; a missing file falls back
silently), so the review queue surface never disappears.
"""

from __future__ import annotations

import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Iterable

import numpy as np
from rapidfuzz.distance import Levenshtein

from src.ocr_trocr import Line
from src.postcorrect import CorrectedLine

# Location of the serialized learned-flagger bundle: the "models" directory
# two levels above this file, file name "flagger_v1.pkl".
MODEL_PATH = Path(__file__).parent.parent / "models" / "flagger_v1.pkl"
# Decision threshold on prob_wrong used by flag() when the caller does not
# supply one (and, on the learned path, the bundle does not carry its own).
DEFAULT_THRESHOLD = 0.5

# Lazy-loaded model bundle. Sentinel-vs-None lets us distinguish "not yet
# attempted" from "attempted and failed", so we don't retry on every flag()
# call when the file is genuinely missing.
_SENTINEL = object()
# Holds _SENTINEL until _try_load_model() has run once; afterwards either the
# loaded bundle or None (meaning: use the rule-based flagger).
_MODEL_BUNDLE_CACHE = _SENTINEL


@dataclass
class FlaggerOutput:
    """Result of flagging one line; shared by the rule-based and learned paths."""

    # Score for "this line is still wrong": the clamped rule score on the
    # rule-based path, the classifier probability on the learned path.
    prob_wrong: float
    # True when the score reached the decision threshold, i.e. the line
    # should be put in front of a human reviewer.
    flagged: bool
    # Reason codes that fired for this line (keys of REASON_DESCRIPTIONS);
    # empty when no rule fired.
    reasons: list[str] = field(default_factory=list)


# Human-readable explanation for each reason code; looked up by describe().
# NOTE(review): "NO_API_VERIFICATION" is not emitted by the rule-based flagger
# in this module — presumably it is attached by the caller in --no-api mode;
# verify against the pipeline.
REASON_DESCRIPTIONS: dict[str, str] = {
    "LOW_MIN_LOGPROB_TOKEN": "TrOCR was very uncertain about at least one token in this line.",
    "LOW_MEAN_LOGPROB": "TrOCR was uncertain across the whole line.",
    "HIGH_CORRECTION_DELTA": "Claude vision changed many characters relative to the TrOCR output.",
    "LOW_LLM_CONFIDENCE": "Claude reported low confidence in the corrected text.",
    "VERY_SHORT_LINE": "Very short line; the OCR had little context to work with.",
    "NO_API_VERIFICATION": "Claude post-correction was skipped (--no-api mode); the transcription is raw TrOCR output and has not been verified by a second model.",
}


def describe(reason_code: str) -> str:
    """Return the human-readable description for ``reason_code``.

    Codes without an entry in ``REASON_DESCRIPTIONS`` are returned unchanged,
    so callers can always display the result.
    """
    try:
        return REASON_DESCRIPTIONS[reason_code]
    except KeyError:
        # Unknown code: echo it back rather than failing.
        return reason_code


def compute_features(trocr_line: Line, corrected_line: CorrectedLine) -> dict:
    """Build the per-line feature dict consumed by both flagger paths.

    Args:
        trocr_line: OCR result; ``token_logprobs``, ``text`` and ``bbox`` are read.
        corrected_line: Post-correction result; ``corrected`` and
            ``llm_confidence`` are read.

    Returns:
        A dict with token log-probability statistics, the edit distance
        between the OCR text and the corrected text (absolute and as a
        fraction of the OCR text length), the LLM confidence passed through
        as-is, and two bounding-box entries.
    """
    token_lps = np.asarray(trocr_line.token_logprobs, dtype=np.float64)
    token_count = int(token_lps.size)

    if token_count == 0:
        # No tokens: statistics are undefined, so report neutral zeros.
        lp_mean = lp_min = lp_std = lp_per_token = 0.0
    else:
        lp_mean = float(token_lps.mean())
        lp_min = float(token_lps.min())
        lp_std = float(token_lps.std())
        lp_per_token = float(token_lps.sum() / token_count)

    ocr_text = trocr_line.text
    edit_ops = Levenshtein.distance(ocr_text, corrected_line.corrected)
    # Guard against division by zero when the OCR text is empty.
    denominator = max(len(ocr_text), 1)
    changed_fraction = float(edit_ops / denominator)

    features: dict = {}
    features["n_tokens"] = token_count
    features["mean_logprob"] = lp_mean
    features["min_logprob"] = lp_min
    features["std_logprob"] = lp_std
    features["length_normalized_logprob"] = lp_per_token
    # The same distance is exposed under two names.
    features["edit_distance_trocr_vs_corrected"] = int(edit_ops)
    features["n_chars_changed"] = int(edit_ops)
    features["frac_chars_changed"] = changed_fraction
    features["llm_confidence"] = corrected_line.llm_confidence
    # presumably bbox is (x, y, width, height) — TODO confirm against Line
    features["line_height_px"] = int(trocr_line.bbox[3])
    features["line_width_px"] = int(trocr_line.bbox[2])
    return features


def _flag_rule_based(features: dict, threshold: float) -> FlaggerOutput:
    """Score a line with hand-tuned threshold rules.

    Every rule that fires contributes a reason code and a candidate score;
    the final score is the largest candidate (0.0 when nothing fires).
    Deliberately crude — the learned model is meant to replace it.

    Args:
        features: Feature dict; reads ``n_tokens``, ``min_logprob``,
            ``mean_logprob``, ``frac_chars_changed`` and ``llm_confidence``.
        threshold: Score at or above which the line is flagged.

    Returns:
        FlaggerOutput with the score clamped to at most 1.0, the flag
        decision, and the reason codes in rule order.
    """
    token_count = features["n_tokens"]
    has_tokens = token_count > 0

    # (reason code, candidate score) for every rule that fired, in rule order.
    hits: list[tuple[str, float]] = []

    # Log-probability rules only make sense when there is at least one token.
    if has_tokens:
        if features["min_logprob"] < -3.0:
            hits.append(("LOW_MIN_LOGPROB_TOKEN", 0.6))
        if features["mean_logprob"] < -1.0:
            hits.append(("LOW_MEAN_LOGPROB", 0.5))

    if features["frac_chars_changed"] > 0.3:
        hits.append(("HIGH_CORRECTION_DELTA", 0.7))

    confidence = features["llm_confidence"]
    if confidence is not None and confidence < 0.6:
        # The lower the reported confidence, the higher the candidate score.
        hits.append(("LOW_LLM_CONFIDENCE", 1.0 - float(confidence)))

    if has_tokens and token_count <= 2:
        hits.append(("VERY_SHORT_LINE", 0.4))

    worst = max([0.0] + [candidate for _, candidate in hits])

    return FlaggerOutput(
        prob_wrong=float(min(worst, 1.0)),
        # The flag decision uses the unclamped score.
        flagged=bool(worst >= threshold),
        reasons=[code for code, _ in hits],
    )


def _try_load_model():
    """Return the loaded model bundle dict, or None to fall back to rules.

    The result is cached after the first attempt (success or failure) —
    joblib.load is expensive enough to matter when batch runs call flag()
    per line, and a missing file should not be re-probed on every call.

    A bundle is only accepted when it is a dict containing the keys the
    learned path indexes ("feature_names", "model", "scaler"). Anything
    else is reported on stderr and treated like a failed load, so a
    malformed model file degrades to rule-based instead of raising on
    every flag() call.

    Returns:
        The bundle dict, or None when the file is missing, fails to load,
        or does not have the expected shape.
    """
    global _MODEL_BUNDLE_CACHE
    if _MODEL_BUNDLE_CACHE is not _SENTINEL:
        return _MODEL_BUNDLE_CACHE

    bundle = None
    if MODEL_PATH.exists():
        try:
            import joblib

            # NOTE: joblib.load unpickles the file, which can execute
            # arbitrary code. MODEL_PATH must only ever point at a trusted,
            # locally produced artifact.
            bundle = joblib.load(MODEL_PATH)
        except Exception as exc:
            print(
                f"[flagger] failed to load {MODEL_PATH}: {exc!r}; "
                f"falling back to rule-based",
                file=sys.stderr,
            )
            bundle = None
        else:
            problem = None
            if not isinstance(bundle, dict):
                problem = f"expected a dict bundle, got {type(bundle).__name__}"
            else:
                missing = [
                    key
                    for key in ("feature_names", "model", "scaler")
                    if key not in bundle
                ]
                if missing:
                    problem = f"bundle is missing keys {missing}"
            if problem is not None:
                print(
                    f"[flagger] invalid model bundle at {MODEL_PATH}: {problem}; "
                    f"falling back to rule-based",
                    file=sys.stderr,
                )
                bundle = None

    _MODEL_BUNDLE_CACHE = bundle
    return bundle


def _flag_learned(features: dict, bundle: dict, threshold: float) -> FlaggerOutput:
    """Score a line with the trained classifier from ``bundle``.

    The probability comes from the model; the reason codes are borrowed from
    the rule-based flagger so a flagged line still carries a human-readable
    explanation. Those reasons describe which features fired strongly, not
    why the model predicted what it did.

    Args:
        features: Feature dict for the line.
        bundle: Model bundle; ``feature_names``, ``model`` and ``scaler``
            are read.
        threshold: Probability at or above which the line is flagged.

    Returns:
        FlaggerOutput with the model probability, the flag decision, and the
        rule-based reason codes.
    """
    column_order = bundle["feature_names"]
    classifier = bundle["model"]
    scaler = bundle["scaler"]

    # One row, columns in the order the bundle lists them. Missing, None
    # and other falsy values all become 0.0.
    row = []
    for column in column_order:
        raw_value = features.get(column, 0.0)
        row.append(float(raw_value or 0.0))
    sample = np.array([row], dtype=np.float64)

    probabilities = classifier.predict_proba(scaler.transform(sample))
    # Row 0, column 1 — presumably the "wrong" class; confirm against training.
    prob_wrong = float(probabilities[0, 1])

    attribution = _flag_rule_based(features, threshold).reasons
    return FlaggerOutput(
        prob_wrong=prob_wrong,
        flagged=bool(prob_wrong >= threshold),
        reasons=attribution,
    )


def flag(features: dict, *, threshold: float | None = None) -> FlaggerOutput:
    """Flag a single line.

    Uses the learned model when (a) it's loaded and (b) we have
    post-correction context (llm_confidence is not None). The learned model
    was trained on lines that went through Claude vision post-correction;
    calling it on --no-api features (llm_confidence=None,
    frac_chars_changed=0) is out of distribution and produces nonsense, so
    we fall back to rule-based for that case. Reason codes are identical
    across both paths.

    Args:
        features: Feature dict for the line, as built by compute_features().
        threshold: Decision threshold on prob_wrong. ``None`` means "use the
            default": the bundle's stored threshold on the learned path (if
            it has one), otherwise DEFAULT_THRESHOLD. Any explicit number is
            honoured, including 0.0.

    Returns:
        The FlaggerOutput from whichever path was used.
    """
    bundle = _try_load_model()

    if bundle is None or features.get("llm_confidence") is None:
        # `is None`, not truthiness: an explicit threshold of 0.0 is a valid
        # request and must not be replaced by the default.
        effective = DEFAULT_THRESHOLD if threshold is None else threshold
        return _flag_rule_based(features, effective)

    if threshold is None:
        threshold = bundle.get("threshold")
        if threshold is None:
            # Bundle carries no usable threshold of its own.
            threshold = DEFAULT_THRESHOLD
    return _flag_learned(features, bundle, threshold)


def flag_many(
    features_iter: Iterable[dict], *, threshold: float | None = None
) -> list[FlaggerOutput]:
    """Flag every feature dict in ``features_iter``, preserving order.

    Args:
        features_iter: Feature dicts, one per line.
        threshold: Passed through unchanged to flag() for every line.

    Returns:
        One FlaggerOutput per input, in input order.
    """
    outputs: list[FlaggerOutput] = []
    for line_features in features_iter:
        outputs.append(flag(line_features, threshold=threshold))
    return outputs