# apparatus-ocr/src/evaluation/build_result.py
# Uploaded by al1808th
# Commit message: first commit
# Commit: 69dc570
import argparse
import json
import os
from pathlib import Path
from src.evaluation.metrics import bleu_score, flatten_ocr_json, levenshtein_similarity
# Third ancestor of this file's resolved path; with the module living at
# src/evaluation/build_result.py this is the repository root.
REPO_ROOT = Path(__file__).resolve().parents[2]
# Directory holding the gold OCR JSON files used as scoring references.
GOLD_ROOT = REPO_ROOT / "data" / "lloyd-jones-soph-170" / "ocr"
# Gold JSON for the main-text section.
TEXT_GOLD_PATH = GOLD_ROOT / "lloyd-jones-text.json"
# Gold JSON for the apparatus section.
APPARATUS_GOLD_PATH = GOLD_ROOT / "lloyd-jones-apparatus.json"
def build_result_payload(
    model_name: str,
    revision: str,
    precision: str,
    easy_text_path: str,
    easy_apparatus_path: str,
    hard_text_path: str,
    hard_apparatus_path: str,
) -> dict:
    """Score the easy and hard OCR predictions and assemble the result dict.

    Both splits are compared against the same gold reference, built from
    ``TEXT_GOLD_PATH`` and ``APPARATUS_GOLD_PATH``.

    Args:
        model_name: Stored under ``config.model_name``.
        revision: Stored under ``config.model_sha``.
        precision: Dtype name; stored under ``config.model_dtype`` after
            passing through ``_normalize_precision``.
        easy_text_path: JSON file with the easy-split text prediction.
        easy_apparatus_path: JSON file with the easy-split apparatus prediction.
        hard_text_path: JSON file with the hard-split text prediction.
        hard_apparatus_path: JSON file with the hard-split apparatus prediction.

    Returns:
        A dict with a ``config`` section and a ``results`` section holding
        ``{"score": ...}`` entries for Levenshtein similarity and BLEU on
        each split.
    """
    gold_text = _load_json(TEXT_GOLD_PATH)
    gold_apparatus = _load_json(APPARATUS_GOLD_PATH)

    # (levenshtein key, bleu key, text prediction path, apparatus prediction path)
    splits = (
        ("easy_levenshtein", "easy_bleu", easy_text_path, easy_apparatus_path),
        ("hard_levenshtein", "hard_bleu", hard_text_path, hard_apparatus_path),
    )

    # Load and join everything before scoring, so an unreadable prediction
    # file fails before any metric is computed.
    prepared = []
    for lev_key, bleu_key, text_path, apparatus_path in splits:
        reference = _join_sections(gold_text, gold_apparatus)
        prediction = _join_sections(_load_json(text_path), _load_json(apparatus_path))
        prepared.append((lev_key, bleu_key, reference, prediction))

    results = {}
    for lev_key, bleu_key, reference, prediction in prepared:
        # Metric values are divided by 100.0 — presumably they come back on a
        # 0-100 scale; confirm against src.evaluation.metrics.
        results[lev_key] = {"score": levenshtein_similarity(reference, prediction) / 100.0}
        results[bleu_key] = {"score": bleu_score(reference, prediction) / 100.0}

    config = {
        "model_dtype": _normalize_precision(precision),
        "model_name": model_name,
        "model_sha": revision,
    }
    return {"config": config, "results": results}
def main(argv: list[str] | None = None) -> None:
    """Command-line entry point: score predictions and write the result JSON.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, so calling ``main()`` behaves as before.
    """
    parser = argparse.ArgumentParser(description="Build a leaderboard-compatible result JSON for the OCR benchmark.")
    parser.add_argument("--model-name", required=True)
    parser.add_argument("--revision", default="main")
    parser.add_argument("--precision", default="float16")
    parser.add_argument("--easy-text", required=True)
    parser.add_argument("--easy-apparatus", required=True)
    parser.add_argument("--hard-text", required=True)
    parser.add_argument("--hard-apparatus", required=True)
    parser.add_argument("--output", required=True)
    args = parser.parse_args(argv)

    payload = build_result_payload(
        model_name=args.model_name,
        revision=args.revision,
        precision=args.precision,
        easy_text_path=args.easy_text,
        easy_apparatus_path=args.easy_apparatus,
        hard_text_path=args.hard_text,
        hard_apparatus_path=args.hard_apparatus,
    )

    # A bare filename has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when one was given.
    output_dir = os.path.dirname(args.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # ensure_ascii=False can emit non-ASCII characters, so write UTF-8
    # explicitly instead of relying on the platform's locale encoding.
    with open(args.output, "w", encoding="utf-8") as handle:
        json.dump(payload, handle, ensure_ascii=False, indent=2)
def _load_json(path: str | Path) -> dict[str, str]:
with open(path) as handle:
return json.load(handle)
def _join_sections(text_json: dict[str, str], apparatus_json: dict[str, str]) -> str:
    """Combine the text and apparatus sections into one labelled string.

    Each section is passed through ``flatten_ocr_json`` and placed under its
    own marker line, giving four newline-separated parts:
    ``[TEXT]``, the flattened text, ``[APPARATUS]``, the flattened apparatus.
    """
    text_body = flatten_ocr_json(text_json)
    apparatus_body = flatten_ocr_json(apparatus_json)
    parts = ("[TEXT]", text_body, "[APPARATUS]", apparatus_body)
    return "\n".join(parts)
def _normalize_precision(precision: str) -> str:
if precision.startswith("torch."):
return precision
return f"torch.{precision}"
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()