| |
| |
| |
| |
|
|
| import ctypes |
| import math |
| import sys |
| from dataclasses import dataclass, field |
|
|
| import torch |
| from fairseq.dataclass import FairseqDataclass |
| from fairseq.scoring import BaseScorer, register_scorer |
| from fairseq.scoring.tokenizer import EvaluationTokenizer |
|
|
|
|
| class BleuStat(ctypes.Structure): |
| _fields_ = [ |
| ("reflen", ctypes.c_size_t), |
| ("predlen", ctypes.c_size_t), |
| ("match1", ctypes.c_size_t), |
| ("count1", ctypes.c_size_t), |
| ("match2", ctypes.c_size_t), |
| ("count2", ctypes.c_size_t), |
| ("match3", ctypes.c_size_t), |
| ("count3", ctypes.c_size_t), |
| ("match4", ctypes.c_size_t), |
| ("count4", ctypes.c_size_t), |
| ] |
|
|
|
|
| @dataclass |
| class SacrebleuConfig(FairseqDataclass): |
| sacrebleu_tokenizer: EvaluationTokenizer.ALL_TOKENIZER_TYPES = field( |
| default="13a", metadata={"help": "tokenizer"} |
| ) |
| sacrebleu_lowercase: bool = field( |
| default=False, metadata={"help": "apply lowercasing"} |
| ) |
| sacrebleu_char_level: bool = field( |
| default=False, metadata={"help": "evaluate at character level"} |
| ) |
|
|
|
|
| @register_scorer("sacrebleu", dataclass=SacrebleuConfig) |
| class SacrebleuScorer(BaseScorer): |
| def __init__(self, cfg): |
| super(SacrebleuScorer, self).__init__(cfg) |
| import sacrebleu |
|
|
| self.sacrebleu = sacrebleu |
| self.tokenizer = EvaluationTokenizer( |
| tokenizer_type=cfg.sacrebleu_tokenizer, |
| lowercase=cfg.sacrebleu_lowercase, |
| character_tokenization=cfg.sacrebleu_char_level, |
| ) |
|
|
| def add_string(self, ref, pred): |
| self.ref.append(self.tokenizer.tokenize(ref)) |
| self.pred.append(self.tokenizer.tokenize(pred)) |
|
|
| def score(self, order=4): |
| return self.result_string(order).score |
|
|
| def result_string(self, order=4): |
| if order != 4: |
| raise NotImplementedError |
| |
| return self.sacrebleu.corpus_bleu( |
| self.pred, [self.ref], tokenize="none" |
| ).format() |
|
|
|
|
| @dataclass |
| class BleuConfig(FairseqDataclass): |
| pad: int = field(default=1, metadata={"help": "padding index"}) |
| eos: int = field(default=2, metadata={"help": "eos index"}) |
| unk: int = field(default=3, metadata={"help": "unk index"}) |
|
|
|
|
| @register_scorer("bleu", dataclass=BleuConfig) |
| class Scorer(object): |
| def __init__(self, cfg): |
| self.stat = BleuStat() |
| self.pad = cfg.pad |
| self.eos = cfg.eos |
| self.unk = cfg.unk |
|
|
| try: |
| from fairseq import libbleu |
| except ImportError as e: |
| sys.stderr.write( |
| "ERROR: missing libbleu.so. run `pip install --editable .`\n" |
| ) |
| raise e |
|
|
| self.C = ctypes.cdll.LoadLibrary(libbleu.__file__) |
|
|
| self.reset() |
|
|
| def reset(self, one_init=False): |
| if one_init: |
| self.C.bleu_one_init(ctypes.byref(self.stat)) |
| else: |
| self.C.bleu_zero_init(ctypes.byref(self.stat)) |
|
|
| def add(self, ref, pred): |
| if not isinstance(ref, torch.IntTensor): |
| raise TypeError("ref must be a torch.IntTensor (got {})".format(type(ref))) |
| if not isinstance(pred, torch.IntTensor): |
| raise TypeError("pred must be a torch.IntTensor(got {})".format(type(pred))) |
|
|
| |
| rref = ref.clone() |
| assert not rref.lt(0).any() |
| rref[rref.eq(self.unk)] = -999 |
|
|
| rref = rref.contiguous().view(-1) |
| pred = pred.contiguous().view(-1) |
|
|
| self.C.bleu_add( |
| ctypes.byref(self.stat), |
| ctypes.c_size_t(rref.size(0)), |
| ctypes.c_void_p(rref.data_ptr()), |
| ctypes.c_size_t(pred.size(0)), |
| ctypes.c_void_p(pred.data_ptr()), |
| ctypes.c_int(self.pad), |
| ctypes.c_int(self.eos), |
| ) |
|
|
| def score(self, order=4): |
| psum = sum( |
| math.log(p) if p > 0 else float("-Inf") for p in self.precision()[:order] |
| ) |
| return self.brevity() * math.exp(psum / order) * 100 |
|
|
| def precision(self): |
| def ratio(a, b): |
| return a / b if b > 0 else 0 |
|
|
| return [ |
| ratio(self.stat.match1, self.stat.count1), |
| ratio(self.stat.match2, self.stat.count2), |
| ratio(self.stat.match3, self.stat.count3), |
| ratio(self.stat.match4, self.stat.count4), |
| ] |
|
|
| def brevity(self): |
| r = self.stat.reflen / self.stat.predlen |
| return min(1, math.exp(1 - r)) |
|
|
| def result_string(self, order=4): |
| assert order <= 4, "BLEU scores for order > 4 aren't supported" |
| fmt = "BLEU{} = {:2.2f}, {:2.1f}" |
| for _ in range(1, order): |
| fmt += "/{:2.1f}" |
| fmt += " (BP={:.3f}, ratio={:.3f}, syslen={}, reflen={})" |
| bleup = [p * 100 for p in self.precision()[:order]] |
| return fmt.format( |
| order, |
| self.score(order=order), |
| *bleup, |
| self.brevity(), |
| self.stat.predlen / self.stat.reflen, |
| self.stat.predlen, |
| self.stat.reflen |
| ) |
|
|