|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import unicodedata |
|
|
|
|
|
import sacrebleu as sb |
|
|
|
|
|
from fairseq.dataclass import ChoiceEnum |
|
|
|
|
|
SACREBLEU_V2_ABOVE = int(sb.__version__[0]) >= 2 |
|
|
|
|
|
|
|
|
class EvaluationTokenizer(object): |
|
|
"""A generic evaluation-time tokenizer, which leverages built-in tokenizers |
|
|
in sacreBLEU (https://github.com/mjpost/sacrebleu). It additionally provides |
|
|
lowercasing, punctuation removal and character tokenization, which are |
|
|
applied after sacreBLEU tokenization. |
|
|
|
|
|
Args: |
|
|
tokenizer_type (str): the type of sacreBLEU tokenizer to apply. |
|
|
lowercase (bool): lowercase the text. |
|
|
punctuation_removal (bool): remove punctuation (based on unicode |
|
|
category) from text. |
|
|
character_tokenization (bool): tokenize the text to characters. |
|
|
""" |
|
|
|
|
|
SPACE = chr(32) |
|
|
SPACE_ESCAPE = chr(9601) |
|
|
_ALL_TOKENIZER_TYPES = ( |
|
|
sb.BLEU.TOKENIZERS |
|
|
if SACREBLEU_V2_ABOVE |
|
|
else ["none", "13a", "intl", "zh", "ja-mecab"] |
|
|
) |
|
|
ALL_TOKENIZER_TYPES = ChoiceEnum(_ALL_TOKENIZER_TYPES) |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
tokenizer_type: str = "13a", |
|
|
lowercase: bool = False, |
|
|
punctuation_removal: bool = False, |
|
|
character_tokenization: bool = False, |
|
|
): |
|
|
|
|
|
assert ( |
|
|
tokenizer_type in self._ALL_TOKENIZER_TYPES |
|
|
), f"{tokenizer_type}, {self._ALL_TOKENIZER_TYPES}" |
|
|
self.lowercase = lowercase |
|
|
self.punctuation_removal = punctuation_removal |
|
|
self.character_tokenization = character_tokenization |
|
|
if SACREBLEU_V2_ABOVE: |
|
|
self.tokenizer = sb.BLEU(tokenize=str(tokenizer_type)).tokenizer |
|
|
else: |
|
|
self.tokenizer = sb.tokenizers.TOKENIZERS[tokenizer_type]() |
|
|
|
|
|
@classmethod |
|
|
def remove_punctuation(cls, sent: str): |
|
|
"""Remove punctuation based on Unicode category.""" |
|
|
return cls.SPACE.join( |
|
|
t |
|
|
for t in sent.split(cls.SPACE) |
|
|
if not all(unicodedata.category(c)[0] == "P" for c in t) |
|
|
) |
|
|
|
|
|
def tokenize(self, sent: str): |
|
|
tokenized = self.tokenizer(sent) |
|
|
|
|
|
if self.punctuation_removal: |
|
|
tokenized = self.remove_punctuation(tokenized) |
|
|
|
|
|
if self.character_tokenization: |
|
|
tokenized = self.SPACE.join( |
|
|
list(tokenized.replace(self.SPACE, self.SPACE_ESCAPE)) |
|
|
) |
|
|
|
|
|
if self.lowercase: |
|
|
tokenized = tokenized.lower() |
|
|
|
|
|
return tokenized |
|
|
|