"""
ERRANT-based grammatical error evaluation.

Uses the ERRANT toolkit for standardised GEC evaluation with precision,
recall, and F0.5 scores.
"""

from typing import Dict, List, Set, Tuple

from loguru import logger

# Key used to match one ERRANT edit between prediction and reference:
# (original start token, original end token, correction string).
# Error type is deliberately NOT part of the key (see ERRANTEvaluator._edit_set).
_EditKey = Tuple[int, int, str]


class ERRANTEvaluator:
    """Evaluates grammar correction quality using ERRANT annotations.

    Each prediction and each reference is aligned against its source sentence
    with ERRANT. The resulting edit sets are compared per sample, and true
    positives / false positives / false negatives are accumulated over the
    whole corpus before precision, recall and F0.5 are computed.
    """

    def __init__(self):
        try:
            import errant

            self.annotator = errant.load("en")
            logger.info("ERRANT annotator loaded")
        except Exception as e:
            # Deliberate best-effort: if ERRANT (or its language model) cannot be
            # loaded, evaluate() degrades to zero scores instead of crashing.
            logger.warning(f"ERRANT failed to load: {e}. Evaluation will use fallback.")
            self.annotator = None

    def _edit_set(self, orig, cor) -> Set[_EditKey]:
        """Return the ERRANT edits turning parsed ``orig`` into parsed ``cor`` as match keys.

        Edits are keyed on span + correction only. ERRANT derives the error type
        from each correction's own parse, so the same edit can be typed
        differently in a prediction and a reference. Including the type would
        turn a correct edit into a false positive plus a false negative.
        """
        return {(e.o_start, e.o_end, e.c_str) for e in self.annotator.annotate(orig, cor)}

    @staticmethod
    def _f_beta(precision: float, recall: float, beta: float) -> float:
        """Return the F-beta score, or 0.0 when precision and recall are both zero."""
        if precision + recall <= 0:
            return 0.0
        b2 = beta**2
        return (1 + b2) * precision * recall / (b2 * precision + recall)

    def evaluate(
        self,
        sources: List[str],
        predictions: List[str],
        references: List[str],
    ) -> Dict[str, float]:
        """Compute corpus-level ERRANT precision, recall and F0.5.

        Args:
            sources: Original (uncorrected) sentences.
            predictions: System corrections, index-aligned with ``sources``.
            references: Gold corrections, index-aligned with ``sources``.

        Returns:
            Dict with keys ``"precision"``, ``"recall"`` and ``"f0.5"``. All are
            0.0 when ERRANT is unavailable. Samples whose parsing/annotation
            raises are skipped (logged) and do not contribute to the counts.
        """
        if self.annotator is None:
            logger.warning("ERRANT not available, returning zero scores")
            return {"precision": 0.0, "recall": 0.0, "f0.5": 0.0}

        lengths = (len(sources), len(predictions), len(references))
        if len(set(lengths)) > 1:
            # zip() silently truncates to the shortest list; make a misaligned
            # corpus visible instead of quietly scoring a subset of it.
            logger.warning(
                f"ERRANT evaluation inputs differ in length "
                f"(sources={lengths[0]}, predictions={lengths[1]}, references={lengths[2]}); "
                f"only the first {min(lengths)} samples will be scored"
            )

        tp = fp = fn = 0
        skipped = 0
        for src, pred, ref in zip(sources, predictions, references):
            try:
                orig = self.annotator.parse(src)
                pred_set = self._edit_set(orig, self.annotator.parse(pred))
                ref_set = self._edit_set(orig, self.annotator.parse(ref))
            except Exception as e:
                # Best-effort: one sample ERRANT cannot handle should not abort
                # the whole evaluation.
                logger.debug(f"ERRANT annotation failed for a sample: {e}")
                skipped += 1
                continue

            tp += len(pred_set & ref_set)
            fp += len(pred_set - ref_set)
            fn += len(ref_set - pred_set)

        if skipped:
            logger.warning(
                f"ERRANT annotation failed for {skipped} sample(s); they were excluded from scoring"
            )

        # NOTE(review): ERRANT's own compare_m2 reports P=1.0 when fp == 0 (and
        # R=1.0 when fn == 0); this evaluator keeps its existing 0.0 convention
        # for empty denominators. Confirm which convention downstream reports expect.
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        # F0.5 weighs precision higher than recall (beta = 0.5), the GEC standard.
        f_score = self._f_beta(precision, recall, beta=0.5)

        return {
            "precision": precision,
            "recall": recall,
            "f0.5": f_score,
        }