File size: 2,718 Bytes
12fd5f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""
ERRANT-based grammatical error evaluation.
Uses the ERRANT toolkit for standardised GEC evaluation with
precision, recall, and F0.5 scores.
"""

from typing import List, Dict
from loguru import logger


class ERRANTEvaluator:
    """Evaluates grammar correction quality using ERRANT annotations.

    The ERRANT annotator is loaded once, at construction time. If loading
    fails for any reason, ``self.annotator`` is set to ``None`` and
    :meth:`evaluate` returns all-zero scores instead of raising.
    """

    def __init__(self):
        try:
            # Imported lazily so that a missing or broken ERRANT install
            # degrades to the fallback path instead of failing at import time.
            import errant
            self.annotator = errant.load("en")
            logger.info("ERRANT annotator loaded")
        except Exception as e:
            logger.warning(f"ERRANT failed to load: {e}. Evaluation will use fallback.")
            self.annotator = None

    def evaluate(
        self,
        sources: List[str],
        predictions: List[str],
        references: List[str],
    ) -> Dict[str, float]:
        """Compute ERRANT precision, recall, F0.5.

        For every (source, prediction, reference) triple, the edits that turn
        the source into the prediction are compared with the edits that turn
        the source into the reference. Edits are matched on
        ``(o_start, o_end, c_str, type)``. Counts are accumulated over all
        samples before the scores are computed.

        Args:
            sources: Original (uncorrected) sentences.
            predictions: System corrections, aligned with ``sources``.
            references: Gold corrections, aligned with ``sources``.

        Returns:
            A dict with the keys ``"precision"``, ``"recall"`` and ``"f0.5"``.
            All three are ``0.0`` when the annotator is unavailable or when
            no edits were counted.

        Notes:
            * If the inputs differ in length, only the first
              ``min(len(...))`` triples are scored and a warning is logged.
            * A sample whose annotation raises is skipped; the number of
              skipped samples is reported in a single warning at the end.
        """
        if self.annotator is None:
            logger.warning("ERRANT not available, returning zero scores")
            return {"precision": 0.0, "recall": 0.0, "f0.5": 0.0}

        # Materialise so lengths can be compared even for non-list iterables.
        src_list = list(sources)
        pred_list = list(predictions)
        ref_list = list(references)

        if not (len(src_list) == len(pred_list) == len(ref_list)):
            # zip() below stops at the shortest input; make that visible
            # instead of silently scoring a subset.
            logger.warning(
                f"Input length mismatch (sources={len(src_list)}, "
                f"predictions={len(pred_list)}, references={len(ref_list)}); "
                f"only the first {min(len(src_list), len(pred_list), len(ref_list))} "
                f"samples will be evaluated"
            )

        tp = 0
        fp = 0
        fn = 0
        skipped = 0

        for src, pred, ref in zip(src_list, pred_list, ref_list):
            try:
                # The parsed source is shared by both annotate calls.
                orig = self.annotator.parse(src)
                pred_set = self._edit_keys(orig, pred)
                ref_set = self._edit_keys(orig, ref)
            except Exception as e:
                # Best effort: one bad sample must not abort the whole run.
                skipped += 1
                logger.debug(f"ERRANT annotation failed for a sample: {e}")
                continue

            tp += len(pred_set & ref_set)   # edits proposed and in the reference
            fp += len(pred_set - ref_set)   # edits proposed but not in the reference
            fn += len(ref_set - pred_set)   # reference edits that were missed

        if skipped:
            logger.warning(
                f"ERRANT annotation failed for {skipped} sample(s); "
                f"they were excluded from the scores"
            )

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0

        return {
            "precision": precision,
            "recall": recall,
            # F0.5 weighs precision higher than recall (beta = 0.5).
            "f0.5": self._f_beta(precision, recall, 0.5),
        }

    def _edit_keys(self, orig, corrected_text: str) -> set:
        """Return the edits from ``orig`` to ``corrected_text`` as hashable tuples."""
        corrected = self.annotator.parse(corrected_text)
        return {
            (e.o_start, e.o_end, e.c_str, e.type)
            for e in self.annotator.annotate(orig, corrected)
        }

    @staticmethod
    def _f_beta(precision: float, recall: float, beta: float) -> float:
        """Return the F-beta score, or 0.0 when precision and recall are both 0."""
        if precision + recall > 0:
            return (1 + beta**2) * (precision * recall) / (beta**2 * precision + recall)
        return 0.0