| """Flagger: predicts which post-corrected lines still need human review. |
| |
| Two implementations behind a single flag(features) -> FlaggerOutput interface: |
| - learned (Phase 6+): sklearn classifier loaded from models/flagger_v1.pkl |
| - rule-based: hand-tuned thresholds on the same feature set |
| |
| Phase 5 ships rule-based only; the learned path is wired in Phase 6 with the |
| same FlaggerOutput contract so the pipeline / Streamlit UI don't need to |
| change. A missing model file degrades silently to rule-based; a model file |
| that exists but fails to load also degrades to rule-based, logged loudly |
| to stderr, so the review queue surface never disappears. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import sys |
| from dataclasses import dataclass, field |
| from pathlib import Path |
| from typing import Iterable |
|
|
| import numpy as np |
| from rapidfuzz.distance import Levenshtein |
|
|
| from src.ocr_trocr import Line |
| from src.postcorrect import CorrectedLine |
|
|
# Location of the serialized learned-model bundle: the models/ directory one
# level above the directory that contains this file.
MODEL_PATH = Path(__file__).parent.parent / "models" / "flagger_v1.pkl"
# Cut-off applied when the caller supplies no threshold (and, on the learned
# path, the bundle carries none): a score/probability >= this is flagged.
DEFAULT_THRESHOLD = 0.5


# Unique marker meaning "model load has not been attempted yet". None cannot
# play that role because _try_load_model() caches None when no usable model
# is available.
_SENTINEL = object()
_MODEL_BUNDLE_CACHE = _SENTINEL
|
|
|
|
@dataclass
class FlaggerOutput:
    """Verdict for one line, shared by the rule-based and learned flaggers."""

    # Estimated probability that the line is still wrong.
    prob_wrong: float
    # True when the score/probability reached the threshold in effect.
    flagged: bool
    # Reason codes explaining the verdict (keys of REASON_DESCRIPTIONS);
    # each instance gets its own fresh list.
    reasons: list[str] = field(default_factory=list)
|
|
|
|
# Human-readable explanation for each reason code a FlaggerOutput can carry.
# describe() looks codes up here and falls back to the raw code when a code
# has no entry.
# NOTE(review): nothing in the visible part of this module emits
# "NO_API_VERIFICATION"; presumably a caller attaches it in --no-api mode —
# confirm.
REASON_DESCRIPTIONS: dict[str, str] = {
    "LOW_MIN_LOGPROB_TOKEN": "TrOCR was very uncertain about at least one token in this line.",
    "LOW_MEAN_LOGPROB": "TrOCR was uncertain across the whole line.",
    "HIGH_CORRECTION_DELTA": "Claude vision changed many characters relative to the TrOCR output.",
    "LOW_LLM_CONFIDENCE": "Claude reported low confidence in the corrected text.",
    "VERY_SHORT_LINE": "Very short line; the OCR had little context to work with.",
    "NO_API_VERIFICATION": "Claude post-correction was skipped (--no-api mode); the transcription is raw TrOCR output and has not been verified by a second model.",
}
|
|
|
|
def describe(reason_code: str) -> str:
    """Return the human-readable description for *reason_code*.

    Codes without an entry in REASON_DESCRIPTIONS are returned unchanged, so
    callers always get a displayable string.
    """
    try:
        return REASON_DESCRIPTIONS[reason_code]
    except KeyError:
        return reason_code
|
|
|
|
def compute_features(trocr_line: Line, corrected_line: CorrectedLine) -> dict:
    """Build the per-line feature dict consumed by both flagger paths.

    Args:
        trocr_line: Raw OCR line; its ``token_logprobs``, ``text`` and
            ``bbox`` attributes are read.
        corrected_line: Post-corrected line; its ``corrected`` text and
            ``llm_confidence`` are read.

    Returns:
        A dict of token log-prob statistics, edit-distance measures between
        the raw and corrected text, the LLM confidence (passed through
        unchanged), and the line's pixel dimensions.
    """
    token_lps = np.asarray(trocr_line.token_logprobs, dtype=np.float64)
    token_count = int(token_lps.size)

    # With no tokens every log-prob statistic is reported as 0.0.
    lp_mean = lp_min = lp_std = lp_per_token = 0.0
    if token_count > 0:
        lp_mean = float(token_lps.mean())
        lp_min = float(token_lps.min())
        lp_std = float(token_lps.std())
        # Sum of token log-probs divided by the token count.
        lp_per_token = float(token_lps.sum() / token_count)

    edits = Levenshtein.distance(trocr_line.text, corrected_line.corrected)
    # Guard the division for an empty OCR string.
    denominator = max(len(trocr_line.text), 1)
    changed_fraction = float(edits / denominator)

    features = {}
    features["n_tokens"] = token_count
    features["mean_logprob"] = lp_mean
    features["min_logprob"] = lp_min
    features["std_logprob"] = lp_std
    features["length_normalized_logprob"] = lp_per_token
    # The edit distance is exposed under two names with the same value.
    features["edit_distance_trocr_vs_corrected"] = int(edits)
    features["n_chars_changed"] = int(edits)
    features["frac_chars_changed"] = changed_fraction
    features["llm_confidence"] = corrected_line.llm_confidence
    # NOTE(review): assumes bbox is (x, y, width, height) — confirm.
    features["line_height_px"] = int(trocr_line.bbox[3])
    features["line_width_px"] = int(trocr_line.bbox[2])
    return features
|
|
|
|
def _flag_rule_based(features: dict, threshold: float) -> FlaggerOutput:
    """Score a line with fixed, hand-tuned rules.

    Every rule that fires contributes its reason code and a candidate score;
    the final score is the largest candidate (0.0 when nothing fires). The
    line is flagged when that score reaches *threshold*; the reported
    probability is the score capped at 1.0.
    """
    token_count = features["n_tokens"]
    has_tokens = token_count > 0

    # Conditions are evaluated in the same order the reason codes are
    # reported; the log-prob rules are skipped entirely for token-less lines.
    low_min_token = has_tokens and features["min_logprob"] < -3.0
    low_mean = has_tokens and features["mean_logprob"] < -1.0
    heavy_rewrite = features["frac_chars_changed"] > 0.3
    confidence = features["llm_confidence"]
    low_confidence = confidence is not None and confidence < 0.6
    very_short = 0 < token_count <= 2

    # (fired?, reason code, candidate score)
    rules = (
        (low_min_token, "LOW_MIN_LOGPROB_TOKEN", 0.6),
        (low_mean, "LOW_MEAN_LOGPROB", 0.5),
        (heavy_rewrite, "HIGH_CORRECTION_DELTA", 0.7),
        # The lower the reported confidence, the higher the candidate score.
        (
            low_confidence,
            "LOW_LLM_CONFIDENCE",
            1.0 - float(confidence) if low_confidence else 0.0,
        ),
        (very_short, "VERY_SHORT_LINE", 0.4),
    )

    fired_codes: list[str] = []
    worst = 0.0
    for fired, code, candidate in rules:
        if fired:
            fired_codes.append(code)
            worst = max(worst, candidate)

    return FlaggerOutput(
        prob_wrong=float(min(worst, 1.0)),
        flagged=bool(worst >= threshold),
        reasons=fired_codes,
    )
|
|
|
|
def _try_load_model():
    """Return the cached model bundle, loading it on the first call.

    Returns None when MODEL_PATH does not exist or when loading it raises; a
    load failure is reported on stderr. Either outcome is cached, so the file
    is read at most once per process however many lines are flagged.
    """
    global _MODEL_BUNDLE_CACHE

    # Anything other than the sentinel means a load was already attempted.
    if _MODEL_BUNDLE_CACHE is not _SENTINEL:
        return _MODEL_BUNDLE_CACHE

    loaded = None
    if MODEL_PATH.exists():
        try:
            # Imported lazily so joblib is only needed when a model file exists.
            import joblib

            # NOTE(review): joblib.load unpickles the file; MODEL_PATH must
            # only ever point at a trusted artifact.
            loaded = joblib.load(MODEL_PATH)
        except Exception as exc:
            print(
                f"[flagger] failed to load {MODEL_PATH}: {exc!r}; "
                f"falling back to rule-based",
                file=sys.stderr,
            )

    _MODEL_BUNDLE_CACHE = loaded
    return loaded
|
|
|
|
def _flag_learned(features: dict, bundle: dict, threshold: float) -> FlaggerOutput:
    """Score a line with the trained classifier from *bundle*.

    The bundle must provide ``feature_names``, ``model`` and ``scaler``. The
    probability comes from the classifier; the reason codes are borrowed from
    the rule-based path so a flagged line still carries a human-readable
    explanation.
    """
    columns = bundle["feature_names"]
    classifier = bundle["model"]
    scaler = bundle["scaler"]

    # Missing, None and other falsy feature values are all fed in as 0.0.
    row = [float(features.get(column) or 0.0) for column in columns]
    sample = np.array([row], dtype=np.float64)
    scaled_sample = scaler.transform(sample)

    # Column 1 of predict_proba is taken as the probability of "wrong".
    probability = float(classifier.predict_proba(scaled_sample)[0, 1])

    # Only the reasons of the rule-based verdict are used; its score is dropped.
    explanation = _flag_rule_based(features, threshold).reasons
    return FlaggerOutput(
        prob_wrong=probability,
        flagged=bool(probability >= threshold),
        reasons=explanation,
    )
|
|
|
|
def flag(features: dict, *, threshold: float | None = None) -> FlaggerOutput:
    """Flag a single line for human review.

    The learned model is used only when (a) a model bundle loaded and (b) the
    line has post-correction context (``llm_confidence`` is not None). The
    model was trained on lines that went through Claude vision
    post-correction; --no-api features (llm_confidence=None,
    frac_chars_changed=0) are out of distribution for it, so those lines use
    the rule-based path instead. Reason codes are identical across both paths.

    Args:
        features: Feature dict as produced by compute_features().
        threshold: Score/probability at or above which the line is flagged.
            None means "use the default": the bundle's own ``threshold`` on
            the learned path when it has one, otherwise DEFAULT_THRESHOLD.
            An explicit 0.0 is honoured as given.

    Returns:
        The FlaggerOutput from whichever path was taken.
    """
    bundle = _try_load_model()

    if bundle is None or features.get("llm_confidence") is None:
        # Compare against None rather than relying on truthiness: 0.0 is a
        # legitimate threshold and must not be replaced by the default.
        cutoff = DEFAULT_THRESHOLD if threshold is None else threshold
        return _flag_rule_based(features, cutoff)

    if threshold is None:
        cutoff = bundle.get("threshold", DEFAULT_THRESHOLD)
    else:
        cutoff = threshold
    return _flag_learned(features, bundle, cutoff)
|
|
|
|
def flag_many(features_iter: Iterable[dict], *, threshold: float | None = None) -> list[FlaggerOutput]:
    """Apply flag() to every feature dict, preserving input order.

    *threshold* is forwarded unchanged to each flag() call.
    """
    return [flag(line_features, threshold=threshold) for line_features in features_iter]
|
|