# apparatus-ocr — src/evaluation/metrics.py
# (al1808th, first commit, 69dc570)
import math
from collections import Counter
def flatten_ocr_json(payload: dict[str, str]) -> str:
    """Serialize an OCR result mapping to one ``key<TAB>value`` line per entry.

    Keys are ordered with ``_sort_key`` so "page-line" style numeric keys
    sort numerically, with non-numeric keys after them.
    """
    ordered_keys = sorted(payload, key=_sort_key)
    return "\n".join(f"{name}\t{payload[name]}" for name in ordered_keys)
def _sort_key(value: str):
head, _, tail = value.partition("-")
try:
return (int(head), int(tail) if tail else -1, value)
except ValueError:
return (math.inf, math.inf, value)
def levenshtein_similarity(reference: str, prediction: str) -> float:
    """Edit-distance similarity as a percentage in [0, 100].

    Returns ``100 * (1 - distance / max(len(reference), len(prediction)))``.
    Identical strings (including two empties) score 100.0; exactly one
    empty side scores 0.0.
    """
    if reference == prediction:
        # Also covers the both-empty case.
        return 100.0
    if not reference or not prediction:
        return 0.0

    # Single-row dynamic programme over the classic edit-distance table,
    # updated in place; `diag` carries the previous row's j-1 entry.
    row = list(range(len(prediction) + 1))
    for row_idx, ref_char in enumerate(reference, start=1):
        diag = row[0]
        row[0] = row_idx
        for col_idx, pred_char in enumerate(prediction, start=1):
            above = row[col_idx]
            substitution = diag if ref_char == pred_char else diag + 1
            row[col_idx] = min(above + 1, row[col_idx - 1] + 1, substitution)
            diag = above

    distance = row[-1]
    longest = max(len(reference), len(prediction))
    return max(0.0, (1 - distance / longest) * 100.0)
def bleu_score(reference: str, prediction: str, max_order: int = 4) -> float:
    """Smoothed sentence-level BLEU (as a percentage) over whitespace tokens.

    Each n-gram precision up to ``max_order`` is add-one smoothed so the
    geometric mean stays defined with zero matches; the standard brevity
    penalty is applied. Two empty strings score 100.0, one empty side 0.0.
    """
    ref_tokens = reference.split()
    pred_tokens = prediction.split()
    if not ref_tokens and not pred_tokens:
        return 100.0
    if not ref_tokens or not pred_tokens:
        return 0.0

    log_precision_sum = 0.0
    for order in range(1, max_order + 1):
        reference_grams = _ngram_counts(ref_tokens, order)
        predicted_grams = _ngram_counts(pred_tokens, order)
        # Clip each predicted n-gram's count by its reference count.
        matched = sum(
            min(hits, reference_grams[gram])
            for gram, hits in predicted_grams.items()
        )
        candidates = max(sum(predicted_grams.values()), 1)
        log_precision_sum += math.log((matched + 1) / (candidates + 1))

    geometric_mean = math.exp(log_precision_sum / max_order)
    ref_len, pred_len = len(ref_tokens), len(pred_tokens)
    # Brevity penalty: only penalize predictions not longer than the reference.
    brevity = 1.0 if pred_len > ref_len else math.exp(1 - (ref_len / pred_len))
    return geometric_mean * brevity * 100.0
def paired_ocr_metrics(reference: dict[str, str], prediction: dict[str, str]) -> dict[str, float]:
    """Score a predicted OCR payload against a reference payload.

    Both payloads are flattened to canonical tab-separated text first;
    returns ``{"levenshtein": ..., "bleu": ...}`` percentage scores.
    """
    flat_reference = flatten_ocr_json(reference)
    flat_prediction = flatten_ocr_json(prediction)
    return {
        "levenshtein": levenshtein_similarity(flat_reference, flat_prediction),
        "bleu": bleu_score(flat_reference, flat_prediction),
    }
def _ngram_counts(tokens: list[str], order: int) -> Counter:
if len(tokens) < order:
return Counter()
return Counter(tuple(tokens[i : i + order]) for i in range(len(tokens) - order + 1))