| |
| |
| |
| |
|
|
| from dataclasses import dataclass, field |
|
|
| from fairseq.dataclass import FairseqDataclass |
| from fairseq.scoring import BaseScorer, register_scorer |
| from fairseq.scoring.tokenizer import EvaluationTokenizer |
|
|
|
|
@dataclass
class WerScorerConfig(FairseqDataclass):
    """Options for the ``wer`` scorer.

    ``WerScorer`` forwards all four fields to ``EvaluationTokenizer`` to
    control how reference and prediction strings are normalized before the
    edit distance is computed.
    """

    # Tokenizer type handed to EvaluationTokenizer as ``tokenizer_type``.
    wer_tokenizer: EvaluationTokenizer.ALL_TOKENIZER_TYPES = field(
        default="none",
        metadata={"help": "sacreBLEU tokenizer to use for evaluation"},
    )
    # Handed to EvaluationTokenizer as ``punctuation_removal``.
    wer_remove_punct: bool = field(
        default=False,
        metadata={"help": "remove punctuation"},
    )
    # Handed to EvaluationTokenizer as ``character_tokenization``.
    wer_char_level: bool = field(
        default=False,
        metadata={"help": "evaluate at character level"},
    )
    # Handed to EvaluationTokenizer as ``lowercase``.
    wer_lowercase: bool = field(
        default=False,
        metadata={"help": "lowercasing"},
    )
|
|
|
|
@register_scorer("wer", dataclass=WerScorerConfig)
class WerScorer(BaseScorer):
    """WER scorer: accumulated edit distance relative to reference length.

    Each (reference, prediction) pair is tokenized with an
    ``EvaluationTokenizer`` configured from ``WerScorerConfig`` and split on
    whitespace. The edit distance between the two token lists and the number
    of reference tokens are accumulated across calls; the score is
    ``100 * distance / ref_length``.
    """

    def __init__(self, cfg):
        """Set up the counters, the edit-distance backend and the tokenizer.

        Args:
            cfg: scorer configuration exposing the ``wer_tokenizer``,
                ``wer_lowercase``, ``wer_remove_punct`` and
                ``wer_char_level`` options.

        Raises:
            ImportError: if the optional ``editdistance`` package is not
                installed.
        """
        super().__init__(cfg)
        self.reset()
        # editdistance is an optional dependency, so import it lazily and
        # fail with an actionable message when it is missing.
        try:
            import editdistance as ed
        except ImportError as err:
            raise ImportError(
                "Please install editdistance to use WER scorer"
            ) from err
        self.ed = ed
        self.tokenizer = EvaluationTokenizer(
            tokenizer_type=self.cfg.wer_tokenizer,
            lowercase=self.cfg.wer_lowercase,
            punctuation_removal=self.cfg.wer_remove_punct,
            character_tokenization=self.cfg.wer_char_level,
        )

    def reset(self):
        """Clear the accumulated edit distance and reference length."""
        self.distance = 0
        self.ref_length = 0

    def add_string(self, ref, pred):
        """Accumulate statistics for one reference/prediction pair.

        Args:
            ref: reference string.
            pred: predicted (hypothesis) string.
        """
        ref_items = self.tokenizer.tokenize(ref).split()
        pred_items = self.tokenizer.tokenize(pred).split()
        self.distance += self.ed.eval(ref_items, pred_items)
        self.ref_length += len(ref_items)

    def result_string(self):
        """Return the current score formatted as ``"WER: <score>"``."""
        return f"WER: {self.score():.2f}"

    def score(self):
        """Return ``100 * distance / ref_length`` as a float.

        Returns ``0.0`` when no reference tokens have been accumulated, which
        avoids a division by zero and keeps the return type a float.
        """
        if self.ref_length > 0:
            return 100.0 * self.distance / self.ref_length
        return 0.0
|
|