"""Deterministic dictation grading: align a blind transcription against a known reference and emit a strict JSON report. See specs/ocr.md. The VLM transcription step lives elsewhere; this module never sees the image and is fully testable with a fixed transcription string. """ import re import unicodedata from difflib import SequenceMatcher from typing import Literal, NamedTuple from pydantic import BaseModel Status = Literal["correct", "misspelled", "missing", "extra"] # A word is a maximal run of word characters (Unicode letters/digits/_); # any other non-space character is a standalone punctuation token. _TOKEN_RE = re.compile(r"\w+|[^\w\s]", re.UNICODE) class Token(NamedTuple): text: str is_word: bool class GradeOptions(NamedTuple): """Pedagogy knobs for grading (see specs/ocr.md §8). Defaults match the spec: case matters, punctuation is reported but not counted, diacritics are graded.""" case_sensitive: bool = True grade_punctuation: bool = False grade_diacritics: bool = True def comparison_key(word: str, options: GradeOptions) -> str: """Map a surface word to the key used for alignment/equality, applying the active grading options. The surface form is preserved for display; only this key decides whether two words count as 'the same'. Never maps ß→ss or ö→oe — diacritic *grading* only strips combining marks (accents), and ß is not a combining mark, so it always stays distinct.""" key = unicodedata.normalize("NFC", word) if not options.case_sensitive: key = key.casefold() if not options.grade_diacritics: decomposed = unicodedata.normalize("NFD", key) key = "".join(ch for ch in decomposed if not unicodedata.combining(ch)) key = unicodedata.normalize("NFC", key) return key def tokenize(text: str) -> list[Token]: """NFC-normalize then split into word and punctuation tokens.""" text = unicodedata.normalize("NFC", text) tokens: list[Token] = [] for match in _TOKEN_RE.finditer(text): piece = match.group() tokens.append(Token(piece, is_word=bool(re.match(r"\w", piece)))) return tokens def graphemes(text: str) -> list[str]: """Split a string into grapheme clusters (a base char plus any trailing combining marks), so combining diacritics don't desync a character diff.""" clusters: list[str] = [] for ch in text: if clusters and unicodedata.combining(ch): clusters[-1] += ch else: clusters.append(ch) return clusters def char_diff(expected: str, read: str) -> str: """Human-readable grapheme-level diff describing how ``read`` deviates from ``expected`` (e.g. ``ß→ss``, ``-t``, ``+e``). Empty string if identical.""" a, b = graphemes(expected), graphemes(read) parts: list[str] = [] for op, i1, i2, j1, j2 in SequenceMatcher(a=a, b=b).get_opcodes(): if op == "equal": continue exp_chunk, read_chunk = "".join(a[i1:i2]), "".join(b[j1:j2]) if op == "replace": parts.append(f"{exp_chunk}→{read_chunk}") elif op == "delete": parts.append(f"-{exp_chunk}") elif op == "insert": parts.append(f"+{read_chunk}") return ", ".join(parts) class Word(BaseModel): """One graded word in the report (specs/ocr.md §7). ``expected`` is None for ``extra``; ``read`` is None for ``missing``; ``diff`` is present only for ``misspelled``.""" index: int expected: str | None read: str | None status: Status diff: str | None = None def _word_texts(text: str, options: GradeOptions) -> list[str]: """Token surface forms to align: words always, punctuation only when graded.""" return [ tok.text for tok in tokenize(text) if tok.is_word or options.grade_punctuation ] def align_words( reference: str, transcription: str, options: GradeOptions | None = None ) -> list[Word]: """Align the blind transcription against the reference at word level and classify each token (specs/ocr.md stages 3-4). Deterministic, no model.""" options = options or GradeOptions() ref = _word_texts(reference, options) read = _word_texts(transcription, options) ref_keys = [comparison_key(w, options) for w in ref] read_keys = [comparison_key(w, options) for w in read] words: list[Word] = [] def emit(expected: str | None, got: str | None, status: Status) -> None: diff = char_diff(expected, got) if status == "misspelled" else None words.append( Word( index=len(words), expected=expected, read=got, status=status, diff=diff or None, ) ) for op, i1, i2, j1, j2 in SequenceMatcher(a=ref_keys, b=read_keys).get_opcodes(): if op == "equal": for i, j in zip(range(i1, i2), range(j1, j2)): emit(ref[i], read[j], "correct") elif op == "replace": # Pair up as misspellings; leftovers are missing/extra. paired = min(i2 - i1, j2 - j1) for k in range(paired): emit(ref[i1 + k], read[j1 + k], "misspelled") for i in range(i1 + paired, i2): emit(ref[i], None, "missing") for j in range(j1 + paired, j2): emit(None, read[j], "extra") elif op == "delete": for i in range(i1, i2): emit(ref[i], None, "missing") elif op == "insert": for j in range(j1, j2): emit(None, read[j], "extra") return words class Summary(BaseModel): """Tally over the graded words. ``total`` counts reference words only (correct + misspelled + missing); extras don't inflate it.""" total: int correct: int misspelled: int missing: int extra: int accuracy: float class GradeReport(BaseModel): """The strict JSON grading report (specs/ocr.md §7).""" lang: str reference: str transcription: str words: list[Word] summary: Summary def _summarize(words: list[Word]) -> Summary: counts = {"correct": 0, "misspelled": 0, "missing": 0, "extra": 0} for w in words: counts[w.status] += 1 total = counts["correct"] + counts["misspelled"] + counts["missing"] accuracy = round(counts["correct"] / total, 4) if total else 0.0 return Summary(total=total, accuracy=accuracy, **counts) def grade( reference: str, transcription: str, lang: str, options: GradeOptions | None = None, ) -> GradeReport: """Grade a blind transcription against the reference and return a validated report (specs/ocr.md stages 3-5). The image/VLM never enters here.""" options = options or GradeOptions() words = align_words(reference, transcription, options) return GradeReport( lang=lang, reference=unicodedata.normalize("NFC", reference), transcription=unicodedata.normalize("NFC", transcription), words=words, summary=_summarize(words), ) _MARKS = {"correct": "✓", "misspelled": "✗", "missing": "·", "extra": "+"} def format_text_report(report: GradeReport) -> str: """Render a GradeReport as a human-readable plain-text report (derivable purely from the JSON, specs/ocr.md §2).""" s = report.summary lines = [ f"[{report.lang}] {s.correct}/{s.total} correct " f"({s.accuracy * 100:.0f}%) " f"misspelled={s.misspelled} missing={s.missing} extra={s.extra}", ] for w in report.words: mark = _MARKS[w.status] if w.status == "correct": lines.append(f" {mark} {w.read}") elif w.status == "misspelled": lines.append(f" {mark} {w.expected} → {w.read} [{w.diff}]") elif w.status == "missing": lines.append(f" {mark} {w.expected} (missing)") else: # extra lines.append(f" {mark} {w.read} (extra)") return "\n".join(lines)