"""ERRANT metric for Grammatical Error Correction evaluation.
This metric uses the ERRANT (ERRor ANnotation Toolkit) to evaluate
grammatical error correction systems by comparing edit operations
between source, prediction, and reference sentences.
"""
import datasets
import evaluate
_CITATION = """\
@inproceedings{bryant-etal-2017-automatic,
title = "Automatic Annotation and Evaluation of Error Types for Grammatical Error Correction",
author = "Bryant, Christopher and
Felice, Mariano and
Briscoe, Ted",
booktitle = "Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
month = jul,
year = "2017",
address = "Vancouver, Canada",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/P17-1074",
doi = "10.18653/v1/P17-1074",
pages = "793--805",
}
"""
_DESCRIPTION = """\
ERRANT (ERRor ANnotation Toolkit) is a metric for evaluating grammatical error
correction (GEC) systems. It computes precision, recall, and F-score by comparing
the edit operations needed to transform source sentences into predictions versus
the edit operations needed to transform source sentences into references.
This metric requires three inputs:
- sources: The original (uncorrected) sentences
- predictions: The model's corrected sentences
- references: The gold standard corrected sentences
The metric extracts edits using the ERRANT library and computes:
- Precision: What fraction of predicted edits are correct
- Recall: What fraction of gold edits were predicted
- F0.5: F-score with beta=0.5 (weighing precision twice as much as recall)
"""
_KWARGS_DESCRIPTION = """
Args:
sources: list of source (original/uncorrected) sentences
predictions: list of predicted (corrected) sentences
references: list of reference (gold corrected) sentences
lang: language code for spaCy model (default: "en")
- "en": English (requires en_core_web_sm)
- "nb": Norwegian Bokmål (requires nb_core_news_sm)
- "de": German (requires de_core_news_sm)
- etc. (any language with a spaCy model)
beta: beta value for F-score calculation (default: 0.5)
simple_tok: forwarded to errant's Annotator.parse as its tokenise flag
(default: True). When True, the raw text is word-tokenised before
alignment, matching errant_parallel's -tok flag; set it to False if the
inputs are already whitespace-tokenised.
Returns:
precision: fraction of predicted edits that are correct
recall: fraction of gold edits that were predicted
f{beta}: F-score for the specified beta (the key is "f0.5" for the default beta=0.5)
Examples:
>>> import evaluate
>>> errant_gec = evaluate.load("marksverdhei/errant_gec")
>>> results = errant_gec.compute(
... sources=["This are a sentence ."],
... predictions=["This is a sentence ."],
... references=["This is a sentence ."],
... lang="en"
... )
>>> print(results)
{'precision': 1.0, 'recall': 1.0, 'f0.5': 1.0}
"""
# Map language codes to spaCy model names
SPACY_MODELS = {
"en": "en_core_web_sm",
"nb": "nb_core_news_sm",
"nn": "nb_core_news_sm", # Use Bokmål model for Nynorsk as fallback
"de": "de_core_news_sm",
"es": "es_core_news_sm",
"fr": "fr_core_news_sm",
"it": "it_core_news_sm",
"nl": "nl_core_news_sm",
"pt": "pt_core_news_sm",
"ru": "ru_core_news_sm",
"zh": "zh_core_web_sm",
"ja": "ja_core_news_sm",
}
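# Codes missing from the table fall back to f"{lang}_core_news_sm" in
# _get_annotator below, e.g. "da" -> "da_core_news_sm" (a naming guess that
# holds for most spaCy news pipelines, though not for every language:
# English, for instance, uses the *_web_sm naming).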
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Errant(evaluate.Metric):
"""ERRANT metric for grammatical error correction evaluation."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._annotators = {} # Cache annotators per language
def _info(self):
return evaluate.MetricInfo(
module_type="metric",
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
features=datasets.Features(
{
"sources": datasets.Value("string"),
"predictions": datasets.Value("string"),
"references": datasets.Value("string"),
}
),
reference_urls=["https://github.com/chrisjbryant/errant"],
)
def _get_annotator(self, lang: str):
"""Get or create an ERRANT annotator for the specified language."""
if lang in self._annotators:
return self._annotators[lang]
import errant
import spacy
model_name = SPACY_MODELS.get(lang, f"{lang}_core_news_sm")
try:
nlp = spacy.load(model_name)
except OSError:
raise OSError(
f"spaCy model '{model_name}' not found. "
f"Please install it with: python -m spacy download {model_name}"
)
# ERRANT only ships annotation resources for English, so always load the
# "en" annotator and attach the language-specific spaCy pipeline; lang is
# otherwise only used to pick the spaCy model and to gate classification.
annotator = errant.load("en", nlp)
self._annotators[lang] = annotator
return annotator
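# For example, _get_annotator("nb") loads the nb_core_news_sm pipeline
# (assuming it is installed) but wraps it in ERRANT's English annotator,
# since errant.load() only bundles resources for "en"; a second call for
# "nb" returns the cached instance instead of reloading spaCy.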
def _get_edits(self, annotator, orig_doc, cor_doc, lang: str = "en"):
"""Extract edits between original and corrected documents.
Returns a set of (o_start, o_end, o_str, c_str) tuples.
For non-English languages, we skip classification since ERRANT's
classifier uses English-specific POS tag mappings.
"""
# Align with standard Levenshtein (as errant_parallel's -lev flag does)
# and merge the alignment into edits; classification happens below, for
# English only.
alignment = annotator.align(orig_doc, cor_doc, lev=True)
edits = annotator.merge(alignment)
# Only classify for English (classifier uses English POS tags)
if lang == "en":
edits = [annotator.classify(edit) for edit in edits]
edit_set = set()
for edit in edits:
# Skip noop edits (no actual change)
if edit.o_str == edit.c_str:
continue
# Use span positions and strings as edit identifier
edit_set.add((edit.o_start, edit.o_end, edit.o_str, edit.c_str))
return edit_set
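# Illustrative hand-worked example of the tuples above: aligning
# "This are a sentence ." with "This is a sentence ." yields a single
# edit over token span 1-2, i.e. {(1, 2, "are", "is")}; identical
# hypothesis and reference edit sets then count as one true positive
# in _compute below.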
def _compute_fscore(self, tp: int, fp: int, fn: int, beta: float = 0.5) -> dict:
"""Compute precision, recall, and F-score."""
# Follow ERRANT's convention of scoring 1.0 when a denominator is empty
precision = float(tp) / (tp + fp) if (tp + fp) > 0 else 1.0
recall = float(tp) / (tp + fn) if (tp + fn) > 0 else 1.0
if precision + recall > 0:
f_score = float((1 + beta**2) * precision * recall) / (
(beta**2 * precision) + recall
)
else:
f_score = 0.0
return {
"precision": precision,
"recall": recall,
f"f{beta}": f_score,
}
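# Hand-worked example of the formula above: tp=3, fp=1, fn=2 gives
# precision = 3/4 = 0.75 and recall = 3/5 = 0.6, so
# F0.5 = 1.25 * 0.75 * 0.6 / (0.25 * 0.75 + 0.6) = 0.5625 / 0.7875 ≈ 0.714,
# which lands closer to precision than to recall, as beta=0.5 intends.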
def _compute(
self,
sources: list[str],
predictions: list[str],
references: list[str],
lang: str = "en",
beta: float = 0.5,
simple_tok: bool = True,
) -> dict:
"""Compute ERRANT scores for the given inputs.
Args:
sources: Original (uncorrected) sentences
predictions: Model's corrected sentences
references: Gold standard corrected sentences
lang: Language code for spaCy model
beta: Beta value for F-score (default 0.5)
simple_tok: Forwarded to errant's Annotator.parse as its tokenise
flag (default True): word-tokenise raw text, as errant_parallel's
-tok flag does; pass False for pre-tokenised input.
Returns:
Dictionary with precision, recall, and f{beta} scores
"""
if not (len(sources) == len(predictions) == len(references)):
raise ValueError(
f"Inputs must have the same length. Got sources={len(sources)}, "
f"predictions={len(predictions)}, references={len(references)}"
)
annotator = self._get_annotator(lang)
total_tp = 0
total_fp = 0
total_fn = 0
for source, prediction, reference in zip(sources, predictions, references):
# Parse sentences; simple_tok is passed through as errant's tokenise flag
orig_doc = annotator.parse(source, tokenise=simple_tok)
hyp_doc = annotator.parse(prediction, tokenise=simple_tok)
ref_doc = annotator.parse(reference, tokenise=simple_tok)
# Get edit sets
hyp_edits = self._get_edits(annotator, orig_doc, hyp_doc, lang)
ref_edits = self._get_edits(annotator, orig_doc, ref_doc, lang)
# Compute TP, FP, FN for this sample
tp = len(ref_edits & hyp_edits)
fp = len(hyp_edits - ref_edits)
fn = len(ref_edits - hyp_edits)
total_tp += tp
total_fp += fp
total_fn += fn
return self._compute_fscore(total_tp, total_fp, total_fn, beta=beta)
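# A minimal local smoke test; a sketch assuming errant and en_core_web_sm are
# installed. When published on the Hub, evaluate.load("marksverdhei/errant_gec")
# is the usual entry point, but instantiating Errant directly suffices for a
# quick local check.
if __name__ == "__main__":
    metric = Errant()
    scores = metric._compute(
        sources=["This are a sentence ."],
        predictions=["This is a sentence ."],
        references=["This is a sentence ."],
    )
    print(scores)  # expected: {'precision': 1.0, 'recall': 1.0, 'f0.5': 1.0}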