ankira / ocr /grading.py
nofater's picture
mvp
2a0825b
Raw
History Blame Contribute Delete
8.08 kB
"""Deterministic dictation grading: align a blind transcription against a
known reference and emit a strict JSON report.
See specs/ocr.md. The VLM transcription step lives elsewhere; this module
never sees the image and is fully testable with a fixed transcription string.
"""
import re
import unicodedata
from difflib import SequenceMatcher
from typing import Literal, NamedTuple
from pydantic import BaseModel
# The four classifications a graded token can receive in the report.
Status = Literal["correct", "misspelled", "missing", "extra"]
# A word is a maximal run of word characters (Unicode letters/digits/_);
# any other non-space character is a standalone punctuation token.
# NOTE(review): combining marks are presumably not matched by \w, so a mark
# that NFC cannot fold into a precomposed letter would be split off as its own
# punctuation token — confirm whether such input can occur.
_TOKEN_RE = re.compile(r"\w+|[^\w\s]", re.UNICODE)
class Token(NamedTuple):
    """One lexical unit produced by ``tokenize``: the surface text of the
    piece plus a flag separating word tokens from punctuation tokens."""

    text: str  # the matched piece of the NFC-normalized input
    is_word: bool  # True when the piece starts with a word character
class GradeOptions(NamedTuple):
    """Pedagogy knobs for grading (see specs/ocr.md §8). Defaults match the
    spec: case matters, punctuation is reported but not counted, diacritics
    are graded."""

    case_sensitive: bool = True  # False: compare words after case folding
    grade_punctuation: bool = False  # True: punctuation tokens are aligned too
    grade_diacritics: bool = True  # False: combining marks are ignored


def comparison_key(word: str, options: GradeOptions) -> str:
    """Map a surface word to the key used for alignment/equality, applying the
    active grading options. The surface form is preserved for display; only
    this key decides whether two words count as 'the same'.

    Never maps ß→ss or ö→oe — diacritic *grading* only strips combining marks
    (accents), and ß is not a combining mark, so it always stays distinct.

    Args:
        word: Surface form of a single token.
        options: Active grading options; ``case_sensitive`` and
            ``grade_diacritics`` are the two knobs read here.

    Returns:
        The NFC-normalized key, case-folded and/or stripped of combining
        marks according to ``options``.
    """
    key = unicodedata.normalize("NFC", word)
    if not options.case_sensitive:
        # A plain ``key.casefold()`` rewrites ß (and capital ẞ) as "ss", which
        # would make "Straße" and "Strasse" compare equal and break the
        # guarantee above. Case folding is a per-character mapping, so fold
        # each character individually and pin both sharp-s forms to ß.
        key = "".join("ß" if ch in ("ß", "ẞ") else ch.casefold() for ch in key)
    if not options.grade_diacritics:
        # Decompose so accents become separate combining marks, drop those
        # marks, then recompose whatever is left.
        decomposed = unicodedata.normalize("NFD", key)
        key = "".join(ch for ch in decomposed if not unicodedata.combining(ch))
        key = unicodedata.normalize("NFC", key)
    return key
def tokenize(text: str) -> list[Token]:
    """Break ``text`` into an ordered list of word and punctuation tokens.

    The input is brought to NFC first, then scanned with ``_TOKEN_RE``; each
    matched piece becomes a ``Token`` whose ``is_word`` flag is set when the
    piece begins with a word character. Whitespace produces no tokens.
    """
    normalized = unicodedata.normalize("NFC", text)
    pieces = (found.group() for found in _TOKEN_RE.finditer(normalized))
    return [
        Token(piece, is_word=re.match(r"\w", piece) is not None)
        for piece in pieces
    ]
def graphemes(text: str) -> list[str]:
    """Split a string into grapheme clusters (a base char plus any trailing
    combining marks), so combining diacritics don't desync a character diff.

    A combining mark with nothing before it (start of string) forms a cluster
    of its own. An empty string yields an empty list.
    """
    clusters: list[str] = []
    for ch in text:
        # unicodedata.combining() is non-zero only for combining marks.
        if clusters and unicodedata.combining(ch):
            clusters[-1] += ch
        else:
            clusters.append(ch)
    return clusters


def char_diff(expected: str, read: str) -> str:
    """Human-readable grapheme-level diff describing how ``read`` deviates from
    ``expected`` (e.g. ``ß→ss``, ``-t``, ``+e``). Empty string if identical.

    Args:
        expected: The reference spelling.
        read: The spelling that was actually transcribed.

    Returns:
        The individual edits joined with ``", "``: ``X→Y`` for a replacement,
        ``-X`` for graphemes missing from ``read``, ``+Y`` for graphemes that
        only appear in ``read``.
    """
    a, b = graphemes(expected), graphemes(read)
    parts: list[str] = []
    # autojunk=False keeps difflib's "popular element" heuristic from ever
    # influencing the alignment, whatever the input length.
    matcher = SequenceMatcher(a=a, b=b, autojunk=False)
    for op, i1, i2, j1, j2 in matcher.get_opcodes():
        if op == "equal":
            continue
        exp_chunk, read_chunk = "".join(a[i1:i2]), "".join(b[j1:j2])
        if op == "replace":
            # The arrow separates the two sides; without it "ß→ss" would be
            # rendered as the ambiguous "ßss".
            parts.append(f"{exp_chunk}→{read_chunk}")
        elif op == "delete":
            parts.append(f"-{exp_chunk}")
        elif op == "insert":
            parts.append(f"+{read_chunk}")
    return ", ".join(parts)
class Word(BaseModel):
    """One graded word in the report (specs/ocr.md §7).

    ``expected`` is None for ``extra``; ``read`` is None for ``missing``;
    ``diff`` is present only for ``misspelled``."""

    index: int  # position of this entry in the emitted word list
    expected: str | None  # surface form from the reference, if any
    read: str | None  # surface form from the transcription, if any
    status: Status  # one of "correct", "misspelled", "missing", "extra"
    diff: str | None = None  # grapheme-level diff text for misspellings
def _word_texts(text: str, options: GradeOptions) -> list[str]:
    """Collect the surface forms that take part in alignment.

    Word tokens are always kept; punctuation tokens are kept only when
    ``options.grade_punctuation`` is set. Order follows the input text.
    """
    kept: list[str] = []
    for token in tokenize(text):
        if token.is_word or options.grade_punctuation:
            kept.append(token.text)
    return kept
def align_words(
    reference: str, transcription: str, options: GradeOptions | None = None
) -> list[Word]:
    """Align the blind transcription against the reference at word level and
    classify each token (specs/ocr.md stages 3-4). Deterministic, no model.

    Args:
        reference: The known text that was dictated.
        transcription: The text that was read back from the handwriting.
        options: Grading options; ``None`` means the defaults.

    Returns:
        One ``Word`` per aligned position, indexed from 0 in emission order.
        Tokens whose comparison keys match are ``correct``; tokens paired
        inside a replaced run are ``misspelled`` (with a grapheme diff);
        unpaired reference tokens are ``missing`` and unpaired transcription
        tokens are ``extra``.
    """
    options = options or GradeOptions()
    ref = _word_texts(reference, options)
    read = _word_texts(transcription, options)
    # Alignment runs on the comparison keys; the surface forms are what gets
    # reported.
    ref_keys = [comparison_key(w, options) for w in ref]
    read_keys = [comparison_key(w, options) for w in read]
    words: list[Word] = []

    def emit(expected: str | None, got: str | None, status: Status) -> None:
        # Only misspellings carry a diff; an empty diff is stored as None.
        diff = char_diff(expected, got) if status == "misspelled" else None
        words.append(
            Word(
                index=len(words),
                expected=expected,
                read=got,
                status=status,
                diff=diff or None,
            )
        )

    # autojunk must be off: with the default heuristic, once the transcription
    # reaches 200 tokens difflib treats frequently repeated words as junk and
    # stops anchoring on them, which degrades the alignment of long dictations.
    matcher = SequenceMatcher(a=ref_keys, b=read_keys, autojunk=False)
    for op, i1, i2, j1, j2 in matcher.get_opcodes():
        if op == "equal":
            for i, j in zip(range(i1, i2), range(j1, j2)):
                emit(ref[i], read[j], "correct")
        elif op == "replace":
            # Pair up as misspellings; leftovers are missing/extra.
            paired = min(i2 - i1, j2 - j1)
            for k in range(paired):
                emit(ref[i1 + k], read[j1 + k], "misspelled")
            for i in range(i1 + paired, i2):
                emit(ref[i], None, "missing")
            for j in range(j1 + paired, j2):
                emit(None, read[j], "extra")
        elif op == "delete":
            for i in range(i1, i2):
                emit(ref[i], None, "missing")
        elif op == "insert":
            for j in range(j1, j2):
                emit(None, read[j], "extra")
    return words
class Summary(BaseModel):
    """Tally over the graded words. ``total`` counts reference words only
    (correct + misspelled + missing); extras don't inflate it."""

    total: int  # correct + misspelled + missing
    correct: int
    misspelled: int
    missing: int
    extra: int  # transcription-only tokens; excluded from ``total``
    accuracy: float  # correct / total, as filled in by ``_summarize``
class GradeReport(BaseModel):
    """The strict JSON grading report (specs/ocr.md §7)."""

    lang: str  # language tag, passed through unchanged by ``grade``
    reference: str  # the reference text (``grade`` stores it NFC-normalized)
    transcription: str  # the transcribed text (``grade`` stores it NFC-normalized)
    words: list[Word]  # per-token grading results in alignment order
    summary: Summary  # counts and accuracy derived from ``words``
def _summarize(words: list[Word]) -> Summary:
    """Count the graded words per status and derive the accuracy.

    The denominator covers reference tokens only (correct, misspelled and
    missing). Accuracy is rounded to four decimals and is 0.0 when there are
    no reference tokens at all.
    """
    tally = dict.fromkeys(("correct", "misspelled", "missing", "extra"), 0)
    for word in words:
        tally[word.status] = tally[word.status] + 1
    # Extras are not part of the reference, so they stay out of the total.
    reference_total = tally["correct"] + tally["misspelled"] + tally["missing"]
    if reference_total:
        accuracy = round(tally["correct"] / reference_total, 4)
    else:
        accuracy = 0.0
    return Summary(total=reference_total, accuracy=accuracy, **tally)
def grade(
    reference: str,
    transcription: str,
    lang: str,
    options: GradeOptions | None = None,
) -> GradeReport:
    """Produce the validated grading report for one dictation
    (specs/ocr.md stages 3-5).

    Works purely on text: the transcription is aligned against the reference,
    the per-word results are tallied, and both texts are stored in the report
    in NFC form. Omitting ``options`` selects the default grading options.
    """
    active = options if options else GradeOptions()
    graded = align_words(reference, transcription, active)
    nfc_reference = unicodedata.normalize("NFC", reference)
    nfc_transcription = unicodedata.normalize("NFC", transcription)
    tally = _summarize(graded)
    return GradeReport(
        lang=lang,
        reference=nfc_reference,
        transcription=nfc_transcription,
        words=graded,
        summary=tally,
    )
# One-character marker shown in front of each word line, keyed by status.
_MARKS = {"correct": "✓", "misspelled": "✗", "missing": "·", "extra": "+"}


def format_text_report(report: GradeReport) -> str:
    """Render a GradeReport as a human-readable plain-text report (derivable
    purely from the JSON, specs/ocr.md §2).

    The first line is a summary (language tag, correct/total, accuracy as a
    whole percentage, and the misspelled/missing/extra counts). Each graded
    word follows on its own line, prefixed with its status marker.

    Args:
        report: The grading report to render.

    Returns:
        The report as newline-joined text, without a trailing newline.
    """
    s = report.summary
    lines = [
        f"[{report.lang}] {s.correct}/{s.total} correct "
        f"({s.accuracy * 100:.0f}%) "
        f"misspelled={s.misspelled} missing={s.missing} extra={s.extra}",
    ]
    for w in report.words:
        mark = _MARKS[w.status]
        if w.status == "correct":
            lines.append(f" {mark} {w.read}")
        elif w.status == "misspelled":
            # Separate expected and read with an arrow; running them together
            # makes it impossible to tell where one ends and the other begins.
            line = f" {mark} {w.expected} → {w.read}"
            # ``diff`` is optional on the model; don't print "[None]".
            if w.diff:
                line += f" [{w.diff}]"
            lines.append(line)
        elif w.status == "missing":
            lines.append(f" {mark} {w.expected} (missing)")
        else:  # extra
            lines.append(f" {mark} {w.read} (extra)")
    return "\n".join(lines)