File size: 1,958 Bytes
7cc1e71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""Run the corrector against the bundled dataset and report per-category accuracy."""

import json
import time
from collections import defaultdict
from pathlib import Path

import pytest

from spelling.corrector import SpellingCorrector

# Bundled sample file: lives in the "dataset" directory two levels above the
# directory that contains this module.
DATASET = Path(__file__).parent.parent.parent.joinpath("dataset", "samples.jsonl")


@pytest.fixture(scope="module")
def samples():
    """Load the bundled JSONL dataset once per test module.

    Returns:
        A list holding one decoded JSON value per non-blank line of
        ``DATASET``.
    """
    # Decode explicitly as UTF-8 (the interchange encoding for JSON) so the
    # result does not depend on the platform's locale encoding.
    raw = DATASET.read_text(encoding="utf-8")
    return [json.loads(line) for line in raw.splitlines() if line.strip()]


def test_dataset_accuracy(samples, capsys):
    """Score the corrector on every dataset sample and enforce an accuracy floor.

    Each sample is routed to a corrector method by its ``category``:
    ``spell`` and ``domain`` use ``correct``, ``grammar`` uses
    ``correct_compound`` and ``phonetic`` uses ``correct_phonetic``. Samples
    with any other category are not scored. A per-category table (hits,
    total, accuracy, mean latency in ms) is printed with output capture
    disabled, then the test fails if overall accuracy is under 70%.

    Args:
        samples: Decoded dataset records; each scored record provides
            ``category``, ``original`` and ``corrected``, and optionally
            ``domain_terms``.
        capsys: pytest's capture fixture, used only to bypass capture while
            printing the report.
    """
    corrector = SpellingCorrector()
    results = defaultdict(lambda: {"hits": 0, "total": 0, "ms": 0.0})

    for ex in samples:
        cat = ex["category"]
        # Registered outside the timed region so setup cost is not counted
        # as correction latency. One corrector instance serves all samples.
        if ex.get("domain_terms"):
            corrector.add_domain_terms(ex["domain_terms"])

        start = time.perf_counter()
        if cat in ("spell", "domain"):
            out = corrector.correct(ex["original"])
        elif cat == "grammar":
            out = corrector.correct_compound(ex["original"])
        elif cat == "phonetic":
            out = corrector.correct_phonetic(ex["original"])
        else:
            # Unknown categories are left out of the statistics entirely.
            continue
        elapsed = (time.perf_counter() - start) * 1000  # seconds -> ms

        stats = results[cat]
        stats["total"] += 1
        stats["ms"] += elapsed
        # Comparison ignores case and surrounding whitespace.
        if out.lower().strip() == ex["corrected"].lower().strip():
            stats["hits"] += 1

    with capsys.disabled():
        print()
        print(f"{'category':<10} {'hits':>5} {'total':>6} {'acc':>7} {'mean ms':>8}")
        print("-" * 42)
        for cat, r in sorted(results.items()):
            acc = r["hits"] / r["total"] if r["total"] else 0
            mean_ms = r["ms"] / r["total"] if r["total"] else 0
            print(f"{cat:<10} {r['hits']:>5} {r['total']:>6} {acc:>6.0%} {mean_ms:>7.2f}ms")

    total_hits = sum(r["hits"] for r in results.values())
    total_scored = sum(r["total"] for r in results.values())
    # Fail with a clear message instead of a ZeroDivisionError when the
    # dataset is empty or contains no recognised category.
    assert total_scored, "No samples with a recognised category were scored"
    total_acc = total_hits / total_scored
    assert total_acc >= 0.70, f"Total accuracy {total_acc:.0%} below 70% floor"