File size: 8,916 Bytes
9b8727e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2eb45d
 
9b8727e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e6f86e
9b8727e
 
 
6e6f86e
 
 
9b8727e
6e6f86e
 
 
 
 
 
 
 
 
9b8727e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2eb45d
9b8727e
 
 
 
 
 
 
 
 
a2eb45d
 
9b8727e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2eb45d
 
 
 
9b8727e
 
6e6f86e
 
9b8727e
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
"""ERRANT metric for Grammatical Error Correction evaluation.

This metric uses the ERRANT (ERRor ANnotation Toolkit) to evaluate
grammatical error correction systems by comparing edit operations
between source, prediction, and reference sentences.
"""

import datasets
import evaluate


_CITATION = """\
@inproceedings{bryant-etal-2017-automatic,
    title = "Automatic Annotation and Evaluation of Error Types for Grammatical Error Correction",
    author = "Bryant, Christopher  and
      Felice, Mariano  and
      Briscoe, Ted",
    booktitle = "Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
    month = jul,
    year = "2017",
    address = "Vancouver, Canada",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/P17-1074",
    doi = "10.18653/v1/P17-1074",
    pages = "793--805",
}
"""

_DESCRIPTION = """\
ERRANT (ERRor ANnotation Toolkit) is a metric for evaluating grammatical error
correction (GEC) systems. It computes precision, recall, and F-score by comparing
the edit operations needed to transform source sentences into predictions versus
the edit operations needed to transform source sentences into references.

This metric requires three inputs:
- sources: The original (uncorrected) sentences
- predictions: The model's corrected sentences
- references: The gold standard corrected sentences

The metric extracts edits using the ERRANT library and computes:
- Precision: What fraction of predicted edits are correct
- Recall: What fraction of gold edits were predicted
- F0.5: F-score with beta=0.5 (weighing precision twice as much as recall)
"""

_KWARGS_DESCRIPTION = """
Args:
    sources: list of source (original/uncorrected) sentences
    predictions: list of predicted (corrected) sentences
    references: list of reference (gold corrected) sentences
    lang: language code for spaCy model (default: "en")
        - "en": English (requires en_core_web_sm)
        - "nb": Norwegian Bokmål (requires nb_core_news_sm)
        - "de": German (requires de_core_news_sm)
        - etc. (any language with a spaCy model)
    beta: beta value for F-score calculation (default: 0.5)
    simple_tok: use simple whitespace tokenization instead of spaCy (default: True)
        This matches the behavior of errant_parallel's -tok flag.

Returns:
    precision: fraction of predicted edits that are correct
    recall: fraction of gold edits that were predicted
    f0.5: F-score with the specified beta value

Examples:
    >>> import evaluate
    >>> errant_gec = evaluate.load("marksverdhei/errant_gec")
    >>> results = errant_gec.compute(
    ...     sources=["This are a sentence ."],
    ...     predictions=["This is a sentence ."],
    ...     references=["This is a sentence ."],
    ...     lang="en"
    ... )
    >>> print(results)
    {'precision': 1.0, 'recall': 1.0, 'f0.5': 1.0}
"""

# Map language codes to spaCy model names.
# Codes not listed here fall back to the f"{lang}_core_news_sm" naming
# convention in Errant._get_annotator.
SPACY_MODELS: dict[str, str] = {
    "en": "en_core_web_sm",
    "nb": "nb_core_news_sm",
    "nn": "nb_core_news_sm",  # Use Bokmål model for Nynorsk as fallback
    "de": "de_core_news_sm",
    "es": "es_core_news_sm",
    "fr": "fr_core_news_sm",
    "it": "it_core_news_sm",
    "nl": "nl_core_news_sm",
    "pt": "pt_core_news_sm",
    "ru": "ru_core_news_sm",
    "zh": "zh_core_web_sm",
    "ja": "ja_core_news_sm",
}


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Errant(evaluate.Metric):
    """ERRANT metric for grammatical error correction evaluation.

    For each (source, prediction, reference) triple, the edits needed to
    turn the source into the prediction are compared against the edits
    needed to turn the source into the reference; span-level precision,
    recall, and F-beta are accumulated over the whole corpus.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Loading a spaCy pipeline is expensive, so cache one ERRANT
        # annotator per language code and reuse it across compute() calls.
        self._annotators = {}

    def _info(self):
        """Return metric metadata: input features, citation, reference URLs."""
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "sources": datasets.Value("string"),
                    "predictions": datasets.Value("string"),
                    "references": datasets.Value("string"),
                }
            ),
            reference_urls=["https://github.com/chrisjbryant/errant"],
        )

    def _get_annotator(self, lang: str):
        """Get or create an ERRANT annotator for the specified language.

        The spaCy pipeline for ``lang`` is resolved through SPACY_MODELS
        (falling back to the ``{lang}_core_news_sm`` naming convention)
        and handed to ERRANT. Annotators are cached per language.

        Raises:
            OSError: if the required spaCy model is not installed; the
                original spaCy error is chained as the cause.
        """
        # EAFP: cache hit is the common case after the first call.
        try:
            return self._annotators[lang]
        except KeyError:
            pass

        # Imported lazily so the module can be loaded (e.g. for _info())
        # without errant/spacy installed.
        import errant
        import spacy

        model_name = SPACY_MODELS.get(lang, f"{lang}_core_news_sm")

        try:
            nlp = spacy.load(model_name)
        except OSError as err:
            # Chain the original error so the underlying spaCy failure
            # stays visible in the traceback.
            raise OSError(
                f"spaCy model '{model_name}' not found. "
                f"Please install it with: python -m spacy download {model_name}"
            ) from err

        # ERRANT only ships English resources, so it is always loaded with
        # its "en" base; language-specific tokenization/tagging comes from
        # the spaCy pipeline we pass in. (The original code computed
        # `lang if lang == "en" else "en"`, which always evaluated to "en" —
        # this makes that intent explicit without changing behavior.)
        annotator = errant.load("en", nlp)
        self._annotators[lang] = annotator
        return annotator

    def _get_edits(self, annotator, orig_doc, cor_doc, lang: str = "en"):
        """Extract edits between original and corrected documents.

        Returns a set of (o_start, o_end, o_str, c_str) tuples, one per
        non-noop edit; two edits match when both span and replacement
        string agree.

        For non-English languages, classification is skipped since
        ERRANT's classifier uses English-specific POS tag mappings.
        """
        # Align and merge without classification; `lev=True` matches the
        # behavior of errant_parallel with the -lev flag.
        alignment = annotator.align(orig_doc, cor_doc, lev=True)
        edits = annotator.merge(alignment)

        # Only classify for English (classifier uses English POS tags).
        # Classification does not change the spans, so the edit tuples
        # below are comparable either way.
        if lang == "en":
            edits = [annotator.classify(edit) for edit in edits]

        edit_set = set()
        for edit in edits:
            # Skip noop edits (no actual change).
            if edit.o_str == edit.c_str:
                continue
            # Use span positions and strings as the edit identifier.
            edit_set.add((edit.o_start, edit.o_end, edit.o_str, edit.c_str))
        return edit_set

    def _compute_fscore(self, tp: int, fp: int, fn: int, beta: float = 0.5) -> dict:
        """Compute precision, recall, and F-score from edit counts.

        Follows the ERRANT convention: precision (resp. recall) is 1.0
        when there are no predicted (resp. gold) edits, and the F-score
        is 0.0 when both precision and recall are 0.
        """
        precision = float(tp) / (tp + fp) if (tp + fp) > 0 else 1.0
        recall = float(tp) / (tp + fn) if (tp + fn) > 0 else 1.0

        if precision + recall > 0:
            f_score = float((1 + beta**2) * precision * recall) / (
                (beta**2 * precision) + recall
            )
        else:
            f_score = 0.0

        return {
            "precision": precision,
            "recall": recall,
            # Key is e.g. "f0.5" for the default beta.
            f"f{beta}": f_score,
        }

    def _compute(
        self,
        sources: list[str],
        predictions: list[str],
        references: list[str],
        lang: str = "en",
        beta: float = 0.5,
        simple_tok: bool = True,
    ) -> dict:
        """Compute ERRANT scores for the given inputs.

        Args:
            sources: Original (uncorrected) sentences
            predictions: Model's corrected sentences
            references: Gold standard corrected sentences
            lang: Language code for spaCy model
            beta: Beta value for F-score (default 0.5)
            simple_tok: Use simple whitespace tokenization (default True)
                This matches the behavior of errant_parallel's -tok flag.

        Returns:
            Dictionary with precision, recall, and f{beta} scores

        Raises:
            ValueError: if the three input lists differ in length.
            OSError: if the spaCy model for ``lang`` is not installed.
        """
        if not (len(sources) == len(predictions) == len(references)):
            raise ValueError(
                f"Inputs must have the same length. Got sources={len(sources)}, "
                f"predictions={len(predictions)}, references={len(references)}"
            )

        annotator = self._get_annotator(lang)

        total_tp = 0
        total_fp = 0
        total_fn = 0

        for source, prediction, reference in zip(sources, predictions, references):
            # Parse sentences (tokenise=True uses simple whitespace tokenization).
            orig_doc = annotator.parse(source, tokenise=simple_tok)
            hyp_doc = annotator.parse(prediction, tokenise=simple_tok)
            ref_doc = annotator.parse(reference, tokenise=simple_tok)

            # Edit sets for hypothesis and reference, both relative to source.
            hyp_edits = self._get_edits(annotator, orig_doc, hyp_doc, lang)
            ref_edits = self._get_edits(annotator, orig_doc, ref_doc, lang)

            # Micro-average: accumulate raw counts, score once at the end.
            total_tp += len(ref_edits & hyp_edits)
            total_fp += len(hyp_edits - ref_edits)
            total_fn += len(ref_edits - hyp_edits)

        return self._compute_fscore(total_tp, total_fp, total_fn, beta=beta)