File size: 7,100 Bytes
"""Flagger: predicts which post-corrected lines still need human review.

Two implementations behind a single flag(features) -> FlaggerOutput interface:
- learned (Phase 6+): sklearn classifier loaded from models/flagger_v1.pkl
- rule-based: hand-tuned thresholds on the same feature set

Phase 5 ships rule-based only; the learned path is wired in Phase 6 with the
same FlaggerOutput contract so the pipeline / Streamlit UI don't need to
change. A missing or unloadable model file always degrades to rule-based, so
the review queue surface never disappears: a load failure is logged loudly to
stderr, while a missing file falls back silently.
"""
from __future__ import annotations
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Iterable
import numpy as np
from rapidfuzz.distance import Levenshtein
from src.ocr_trocr import Line
from src.postcorrect import CorrectedLine
# Serialized learned-model bundle: the "models" directory one level above the
# directory that contains this file.
MODEL_PATH = Path(__file__).parent.parent / "models" / "flagger_v1.pkl"
# Decision threshold used by flag() when neither the caller nor the loaded
# model bundle supplies one.
DEFAULT_THRESHOLD = 0.5
# Lazy-loaded model bundle. Sentinel-vs-None lets us distinguish "not yet
# attempted" from "attempted and failed", so we don't retry on every flag()
# call when the file is genuinely missing.
#   _SENTINEL -> no load attempted yet
#   None      -> attempted; file missing or load failed (use rule-based)
#   other     -> whatever joblib.load returned for MODEL_PATH
_SENTINEL = object()
_MODEL_BUNDLE_CACHE = _SENTINEL
@dataclass
class FlaggerOutput:
    """Verdict for a single line, shared by the rule-based and learned paths."""

    # Score that the line is still wrong: the capped rule score on the
    # rule-based path, the classifier probability on the learned path.
    prob_wrong: float
    # True when the score reached the decision threshold.
    flagged: bool
    # Reason codes that fired for this line (see REASON_DESCRIPTIONS); each
    # instance gets its own fresh list.
    reasons: list[str] = field(default_factory=list)
# Reason code -> human-readable sentence; looked up by describe().
# NOTE(review): "NO_API_VERIFICATION" is not emitted by _flag_rule_based in
# this module; presumably the caller attaches it in --no-api mode — confirm.
REASON_DESCRIPTIONS: dict[str, str] = {
    "LOW_MIN_LOGPROB_TOKEN": "TrOCR was very uncertain about at least one token in this line.",
    "LOW_MEAN_LOGPROB": "TrOCR was uncertain across the whole line.",
    "HIGH_CORRECTION_DELTA": "Claude vision changed many characters relative to the TrOCR output.",
    "LOW_LLM_CONFIDENCE": "Claude reported low confidence in the corrected text.",
    "VERY_SHORT_LINE": "Very short line; the OCR had little context to work with.",
    "NO_API_VERIFICATION": "Claude post-correction was skipped (--no-api mode); the transcription is raw TrOCR output and has not been verified by a second model.",
}
def describe(reason_code: str) -> str:
    """Return the human-readable sentence for *reason_code*.

    A code with no entry in ``REASON_DESCRIPTIONS`` is returned unchanged, so
    callers always get something displayable.
    """
    try:
        return REASON_DESCRIPTIONS[reason_code]
    except KeyError:
        return reason_code
def compute_features(trocr_line: Line, corrected_line: CorrectedLine) -> dict:
    """Build the per-line feature dict consumed by both flagger paths.

    Reads ``token_logprobs``, ``text`` and ``bbox`` from *trocr_line* and
    ``corrected`` and ``llm_confidence`` from *corrected_line*.

    Returns a dict with token log-prob statistics, the Levenshtein distance
    between the raw and corrected text (absolute and relative to the raw
    length), the LLM confidence passed through untouched, and the box size.
    """
    token_lps = np.asarray(trocr_line.token_logprobs, dtype=np.float64)
    token_count = int(token_lps.size)

    # A line without tokens keeps neutral 0.0 statistics instead of the NaNs
    # numpy would produce on an empty array.
    lp_mean = lp_min = lp_std = lp_per_token = 0.0
    if token_count > 0:
        lp_mean = float(token_lps.mean())
        lp_min = float(token_lps.min())
        lp_std = float(token_lps.std())
        lp_per_token = float(token_lps.sum() / max(token_count, 1))

    raw_text = trocr_line.text
    edits = Levenshtein.distance(raw_text, corrected_line.corrected)
    # Clamp the denominator so an empty raw line cannot divide by zero.
    raw_len = max(len(raw_text), 1)

    features = {
        "n_tokens": token_count,
        "mean_logprob": lp_mean,
        "min_logprob": lp_min,
        "std_logprob": lp_std,
        "length_normalized_logprob": lp_per_token,
        # Both keys carry the same edit distance.
        "edit_distance_trocr_vs_corrected": int(edits),
        "n_chars_changed": int(edits),
        "frac_chars_changed": float(edits / raw_len),
        "llm_confidence": corrected_line.llm_confidence,
        # assumes bbox is laid out as (x, y, width, height) — TODO confirm
        "line_height_px": int(trocr_line.bbox[3]),
        "line_width_px": int(trocr_line.bbox[2]),
    }
    return features
def _flag_rule_based(features: dict, threshold: float) -> FlaggerOutput:
    """Score a line with hand-tuned threshold rules.

    Every rule that fires contributes a reason code and a candidate score;
    the line's score is the largest candidate (0.0 when nothing fires). The
    line is flagged when that score reaches *threshold*. ``prob_wrong`` is
    the score capped at 1.0. Deliberately crude: the learned model is meant
    to replace it.
    """
    fired: list[tuple[str, float]] = []

    token_count = features["n_tokens"]
    has_tokens = token_count > 0

    # Log-prob rules only apply when the line actually has tokens.
    if has_tokens and features["min_logprob"] < -3.0:
        fired.append(("LOW_MIN_LOGPROB_TOKEN", 0.6))
    if has_tokens and features["mean_logprob"] < -1.0:
        fired.append(("LOW_MEAN_LOGPROB", 0.5))

    if features["frac_chars_changed"] > 0.3:
        fired.append(("HIGH_CORRECTION_DELTA", 0.7))

    confidence = features["llm_confidence"]
    if confidence is not None and confidence < 0.6:
        # The less confident the LLM was, the higher the candidate score.
        fired.append(("LOW_LLM_CONFIDENCE", 1.0 - float(confidence)))

    if has_tokens and token_count <= 2:
        fired.append(("VERY_SHORT_LINE", 0.4))

    worst = max([0.0] + [candidate for _, candidate in fired])
    return FlaggerOutput(
        prob_wrong=float(min(worst, 1.0)),
        # Compared uncapped, exactly like the score that produced it.
        flagged=bool(worst >= threshold),
        reasons=[code for code, _ in fired],
    )
def _try_load_model():
    """Return the cached model bundle, loading it on the first call.

    Returns whatever ``joblib.load`` produced for ``MODEL_PATH``, or ``None``
    when the file is missing or loading raised. The outcome — including a
    failure — is remembered in ``_MODEL_BUNDLE_CACHE`` so per-line flag()
    calls in a batch run do not pay for the load (or a retry) every time.
    """
    global _MODEL_BUNDLE_CACHE

    # Anything other than the sentinel means a load was already attempted.
    if _MODEL_BUNDLE_CACHE is not _SENTINEL:
        return _MODEL_BUNDLE_CACHE

    loaded = None
    if MODEL_PATH.exists():
        try:
            # Imported lazily so joblib is only needed when a model exists.
            import joblib

            # NOTE: joblib.load unpickles the file; only point MODEL_PATH at
            # trusted artifacts.
            loaded = joblib.load(MODEL_PATH)
        except Exception as exc:
            # Best-effort by design: report the failure and use rule-based.
            print(
                f"[flagger] failed to load {MODEL_PATH}: {exc!r}; "
                f"falling back to rule-based",
                file=sys.stderr,
            )

    _MODEL_BUNDLE_CACHE = loaded
    return loaded
def _flag_learned(features: dict, bundle: dict, threshold: float) -> FlaggerOutput:
    """Score a line with the trained classifier stored in *bundle*.

    *bundle* must provide ``"feature_names"``, ``"model"`` and ``"scaler"``.
    The feature row is built in ``feature_names`` order; a feature that is
    missing, ``None`` or otherwise falsy becomes 0.0. The probability comes
    from the model, while the reason codes are borrowed from the rule-based
    path so a flagged line still carries a human-readable explanation.
    """
    ordered_names = bundle["feature_names"]
    classifier = bundle["model"]
    scaler = bundle["scaler"]

    row = []
    for name in ordered_names:
        raw = features.get(name, 0.0)
        row.append(float(raw or 0.0))
    matrix = np.array([row], dtype=np.float64)

    probabilities = classifier.predict_proba(scaler.transform(matrix))
    # assumes column 1 is the "line is wrong" class — TODO confirm against
    # the training code
    prob_wrong = float(probabilities[0, 1])

    # The rule-based reasons describe which features fired strongly; they do
    # not explain why the model predicted what it did.
    reasons = _flag_rule_based(features, threshold).reasons
    return FlaggerOutput(
        prob_wrong=prob_wrong,
        flagged=bool(prob_wrong >= threshold),
        reasons=reasons,
    )
def flag(features: dict, *, threshold: float | None = None) -> FlaggerOutput:
    """Flag a single line.

    Uses the learned model when (a) it's loaded and (b) we have
    post-correction context (llm_confidence is not None). The learned model
    was trained on lines that went through Claude vision post-correction;
    calling it on --no-api features (llm_confidence=None,
    frac_chars_changed=0) is out of distribution and produces nonsense, so we
    fall back to rule-based for that case. Reason codes are identical across
    both paths.

    Args:
        features: Per-line feature dict, as produced by compute_features().
        threshold: Score at or above which the line is flagged. ``None``
            (the default) means "use the bundle's stored threshold on the
            learned path, else DEFAULT_THRESHOLD". Any explicit number is
            honoured, including 0.0.

    Returns:
        The FlaggerOutput from whichever path was used.
    """
    bundle = _try_load_model()
    use_rules = bundle is None or features.get("llm_confidence") is None

    # Only None means "not supplied". A truthiness test would silently
    # replace an explicit threshold of 0.0 with the default.
    if threshold is None:
        if use_rules:
            threshold = DEFAULT_THRESHOLD
        else:
            threshold = bundle.get("threshold", DEFAULT_THRESHOLD)

    if use_rules:
        return _flag_rule_based(features, threshold)
    return _flag_learned(features, bundle, threshold)
def flag_many(
    features_iter: Iterable[dict], *, threshold: float | None = None
) -> list[FlaggerOutput]:
    """Flag every feature dict in *features_iter*, preserving input order.

    *threshold* is forwarded unchanged to each flag() call. The iterable is
    consumed eagerly and the results are returned as a list.
    """

    def _flag_one(line_features: dict) -> FlaggerOutput:
        return flag(line_features, threshold=threshold)

    return list(map(_flag_one, features_iter))
|