| """ERRANT metric for Grammatical Error Correction evaluation. | |
| This metric uses the ERRANT (ERRor ANnotation Toolkit) to evaluate | |
| grammatical error correction systems by comparing edit operations | |
| between source, prediction, and reference sentences. | |
| """ | |
| import datasets | |
| import evaluate | |
_CITATION = """\
@inproceedings{bryant-etal-2017-automatic,
    title = "Automatic Annotation and Evaluation of Error Types for Grammatical Error Correction",
    author = "Bryant, Christopher and
      Felice, Mariano and
      Briscoe, Ted",
    booktitle = "Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
    month = jul,
    year = "2017",
    address = "Vancouver, Canada",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/P17-1074",
    doi = "10.18653/v1/P17-1074",
    pages = "793--805",
}
"""

_DESCRIPTION = """\
ERRANT (ERRor ANnotation Toolkit) is a metric for evaluating grammatical error
correction (GEC) systems. It computes precision, recall, and F-score by comparing
the edit operations needed to transform source sentences into predictions versus
the edit operations needed to transform source sentences into references.

This metric requires three inputs:
- sources: The original (uncorrected) sentences
- predictions: The model's corrected sentences
- references: The gold standard corrected sentences

The metric extracts edits using the ERRANT library and computes:
- Precision: What fraction of predicted edits are correct
- Recall: What fraction of gold edits were predicted
- F0.5: F-score with beta=0.5 (weighting precision twice as much as recall)
"""

_KWARGS_DESCRIPTION = """
Args:
    sources: list of source (original/uncorrected) sentences
    predictions: list of predicted (corrected) sentences
    references: list of reference (gold corrected) sentences
    lang: language code for spaCy model (default: "en")
        - "en": English (requires en_core_web_sm)
        - "nb": Norwegian Bokmål (requires nb_core_news_sm)
        - "de": German (requires de_core_news_sm)
        - etc. (any language with a spaCy model)
    beta: beta value for F-score calculation (default: 0.5)
    simple_tok: use simple whitespace tokenization instead of spaCy (default: True)
        This matches errant_parallel's default behavior; errant_parallel's
        -tok flag enables spaCy tokenization instead.
Returns:
    precision: fraction of predicted edits that are correct
    recall: fraction of gold edits that were predicted
    f0.5: F-score with the specified beta value
Examples:
    >>> import evaluate
    >>> errant_gec = evaluate.load("marksverdhei/errant_gec")
    >>> results = errant_gec.compute(
    ...     sources=["This are a sentence ."],
    ...     predictions=["This is a sentence ."],
    ...     references=["This is a sentence ."],
    ...     lang="en"
    ... )
    >>> print(results)
    {'precision': 1.0, 'recall': 1.0, 'f0.5': 1.0}
"""

# Map language codes to spaCy model names
SPACY_MODELS = {
    "en": "en_core_web_sm",
    "nb": "nb_core_news_sm",
    "nn": "nb_core_news_sm",  # Use Bokmål model for Nynorsk as fallback
    "de": "de_core_news_sm",
    "es": "es_core_news_sm",
    "fr": "fr_core_news_sm",
    "it": "it_core_news_sm",
    "nl": "nl_core_news_sm",
    "pt": "pt_core_news_sm",
    "ru": "ru_core_news_sm",
    "zh": "zh_core_web_sm",
    "ja": "ja_core_news_sm",
}
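
# Note (added guidance): languages not listed above fall back to the
# "<lang>_core_news_sm" naming convention in _get_annotator, e.g.
# "da" -> "da_core_news_sm". Not every spaCy model follows this
# convention, so unlisted languages may need an explicit entry here.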


class Errant(evaluate.Metric):
    """ERRANT metric for grammatical error correction evaluation."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._annotators = {}  # Cache annotators per language

    def _info(self):
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "sources": datasets.Value("string"),
                    "predictions": datasets.Value("string"),
                    "references": datasets.Value("string"),
                }
            ),
            reference_urls=["https://github.com/chrisjbryant/errant"],
        )

    def _get_annotator(self, lang: str):
        """Get or create an ERRANT annotator for the specified language."""
        if lang in self._annotators:
            return self._annotators[lang]

        import errant
        import spacy

        model_name = SPACY_MODELS.get(lang, f"{lang}_core_news_sm")
        try:
            nlp = spacy.load(model_name)
        except OSError:
            raise OSError(
                f"spaCy model '{model_name}' not found. "
                f"Please install it with: python -m spacy download {model_name}"
            )
        # ERRANT itself only ships English resources, so we always load its
        # 'en' pipeline and swap in the language-specific spaCy model.
        annotator = errant.load("en", nlp)
        self._annotators[lang] = annotator
        return annotator

    def _get_edits(self, annotator, orig_doc, cor_doc, lang: str = "en"):
        """Extract edits between original and corrected documents.

        Returns a set of (o_start, o_end, o_str, c_str) tuples.

        For non-English languages, we skip classification since ERRANT's
        classifier uses English-specific POS tag mappings.
        """
        # Use plain Levenshtein alignment (the equivalent of errant_parallel's
        # -lev flag), which does not depend on English-specific linguistic
        # alignment costs.
        alignment = annotator.align(orig_doc, cor_doc, lev=True)
        edits = annotator.merge(alignment)
        # Only classify for English (the classifier uses English POS tags)
        if lang == "en":
            edits = [annotator.classify(edit) for edit in edits]
        edit_set = set()
        for edit in edits:
            # Skip noop edits (no actual change)
            if edit.o_str == edit.c_str:
                continue
            # Use span positions and strings as the edit identifier
            edit_set.add((edit.o_start, edit.o_end, edit.o_str, edit.c_str))
        return edit_set
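
    # Worked example (illustrative only): aligning
    #   orig: "This are a sentence ."  ->  cor: "This is a sentence ."
    # is expected to yield the edit set {(1, 2, "are", "is")}: original token
    # span [1, 2) ("are") is replaced by "is". Spans are token indices, so
    # hypothesis and reference edits must come from the same tokenization to
    # be comparable as set elements.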

    def _compute_fscore(self, tp: int, fp: int, fn: int, beta: float = 0.5) -> dict:
        """Compute precision, recall, and F-score."""
        precision = float(tp) / (tp + fp) if (tp + fp) > 0 else 1.0
        recall = float(tp) / (tp + fn) if (tp + fn) > 0 else 1.0
        if precision + recall > 0:
            f_score = float((1 + beta**2) * precision * recall) / (
                (beta**2 * precision) + recall
            )
        else:
            f_score = 0.0
        return {
            "precision": precision,
            "recall": recall,
            f"f{beta}": f_score,
        }
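
    # Worked example: tp=3, fp=1, fn=2 with beta=0.5 gives
    #   precision = 3/4 = 0.75, recall = 3/5 = 0.6,
    #   f0.5 = 1.25 * 0.75 * 0.6 / (0.25 * 0.75 + 0.6) ≈ 0.7143.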

    def _compute(
        self,
        sources: list[str],
        predictions: list[str],
        references: list[str],
        lang: str = "en",
        beta: float = 0.5,
        simple_tok: bool = True,
    ) -> dict:
        """Compute ERRANT scores for the given inputs.

        Args:
            sources: Original (uncorrected) sentences
            predictions: Model's corrected sentences
            references: Gold standard corrected sentences
            lang: Language code for spaCy model
            beta: Beta value for F-score (default 0.5)
            simple_tok: Use simple whitespace tokenization (default True).
                This matches errant_parallel's default; its -tok flag
                switches to spaCy tokenization instead.

        Returns:
            Dictionary with precision, recall, and f{beta} scores
        """
        if not (len(sources) == len(predictions) == len(references)):
            raise ValueError(
                f"Inputs must have the same length. Got sources={len(sources)}, "
                f"predictions={len(predictions)}, references={len(references)}"
            )
        annotator = self._get_annotator(lang)
        total_tp = 0
        total_fp = 0
        total_fn = 0
        for source, prediction, reference in zip(sources, predictions, references):
            # In ERRANT's parse(), tokenise=True means spaCy tokenization and
            # tokenise=False means whitespace splitting, so simple_tok must be
            # inverted when passed through.
            orig_doc = annotator.parse(source, tokenise=not simple_tok)
            hyp_doc = annotator.parse(prediction, tokenise=not simple_tok)
            ref_doc = annotator.parse(reference, tokenise=not simple_tok)
            # Get edit sets
            hyp_edits = self._get_edits(annotator, orig_doc, hyp_doc, lang)
            ref_edits = self._get_edits(annotator, orig_doc, ref_doc, lang)
            # Compute TP, FP, FN for this sample
            tp = len(ref_edits & hyp_edits)
            fp = len(hyp_edits - ref_edits)
            fn = len(ref_edits - hyp_edits)
            total_tp += tp
            total_fp += fp
            total_fn += fn
        return self._compute_fscore(total_tp, total_fp, total_fn, beta=beta)
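

if __name__ == "__main__":
    # Minimal smoke test (a sketch; assumes errant, spacy, and the
    # en_core_web_sm model are installed locally). The usual entry point
    # is evaluate.load("marksverdhei/errant_gec"), as in the docstring.
    metric = Errant()
    scores = metric.compute(
        sources=["This are a sentence ."],
        predictions=["This is a sentence ."],
        references=["This is a sentence ."],
        lang="en",
    )
    print(scores)  # expected: {'precision': 1.0, 'recall': 1.0, 'f0.5': 1.0}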