import os
import re
import unicodedata
from collections import OrderedDict

import editdistance as ed
import zhconv

from lmms_eval.tasks.librispeech.cn_tn import TextNorm
from lmms_eval.tasks.librispeech.whisper_normalizer.basic import BasicTextNormalizer
from lmms_eval.tasks.librispeech.whisper_normalizer.english import EnglishTextNormalizer

# Maps FLEURS long language names to their dataset language IDs.
_FLEURS_LANG_TO_ID = OrderedDict(
    [
        ("Afrikaans", "af"),
        ("Amharic", "am"),
        ("Arabic", "ar"),
        ("Armenian", "hy"),
        ("Assamese", "as"),
        ("Asturian", "ast"),
        ("Azerbaijani", "az"),
        ("Belarusian", "be"),
        ("Bengali", "bn"),
        ("Bosnian", "bs"),
        ("Bulgarian", "bg"),
        ("Burmese", "my"),
        ("Catalan", "ca"),
        ("Cebuano", "ceb"),
        ("Mandarin Chinese", "cmn_hans"),
        ("Cantonese Chinese", "yue_hant"),
        ("Croatian", "hr"),
        ("Czech", "cs"),
        ("Danish", "da"),
        ("Dutch", "nl"),
        ("English", "en"),
        ("Estonian", "et"),
        ("Filipino", "fil"),
        ("Finnish", "fi"),
        ("French", "fr"),
        ("Fula", "ff"),
        ("Galician", "gl"),
        ("Ganda", "lg"),
        ("Georgian", "ka"),
        ("German", "de"),
        ("Greek", "el"),
        ("Gujarati", "gu"),
        ("Hausa", "ha"),
        ("Hebrew", "he"),
        ("Hindi", "hi"),
        ("Hungarian", "hu"),
        ("Icelandic", "is"),
        ("Igbo", "ig"),
        ("Indonesian", "id"),
        ("Irish", "ga"),
        ("Italian", "it"),
        ("Japanese", "ja"),
        ("Javanese", "jv"),
        ("Kabuverdianu", "kea"),
        ("Kamba", "kam"),
        ("Kannada", "kn"),
        ("Kazakh", "kk"),
        ("Khmer", "km"),
        ("Korean", "ko"),
        ("Kyrgyz", "ky"),
        ("Lao", "lo"),
        ("Latvian", "lv"),
        ("Lingala", "ln"),
        ("Lithuanian", "lt"),
        ("Luo", "luo"),
        ("Luxembourgish", "lb"),
        ("Macedonian", "mk"),
        ("Malay", "ms"),
        ("Malayalam", "ml"),
        ("Maltese", "mt"),
        ("Maori", "mi"),
        ("Marathi", "mr"),
        ("Mongolian", "mn"),
        ("Nepali", "ne"),
        ("Northern-Sotho", "nso"),
        ("Norwegian", "nb"),
        ("Nyanja", "ny"),
        ("Occitan", "oc"),
        ("Oriya", "or"),
        ("Oromo", "om"),
        ("Pashto", "ps"),
        ("Persian", "fa"),
        ("Polish", "pl"),
        ("Portuguese", "pt"),
        ("Punjabi", "pa"),
        ("Romanian", "ro"),
        ("Russian", "ru"),
        ("Serbian", "sr"),
        ("Shona", "sn"),
        ("Sindhi", "sd"),
        ("Slovak", "sk"),
        ("Slovenian", "sl"),
        ("Somali", "so"),
        ("Sorani-Kurdish", "ckb"),
        ("Spanish", "es"),
        ("Swahili", "sw"),
        ("Swedish", "sv"),
        ("Tajik", "tg"),
        ("Tamil", "ta"),
        ("Telugu", "te"),
        ("Thai", "th"),
        ("Turkish", "tr"),
        ("Ukrainian", "uk"),
        ("Umbundu", "umb"),
        ("Urdu", "ur"),
        ("Uzbek", "uz"),
        ("Vietnamese", "vi"),
        ("Welsh", "cy"),
        ("Wolof", "wo"),
        ("Xhosa", "xh"),
        ("Yoruba", "yo"),
        ("Zulu", "zu"),
    ]
)

_FLEURS_LANG_SHORT_TO_LONG = {v: k for k, v in _FLEURS_LANG_TO_ID.items()}

# Note: decoding the audio files requires 'librosa' and 'soundfile' to be installed.
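# For reference, the two maps invert each other, e.g.:
#     _FLEURS_LANG_TO_ID["Mandarin Chinese"]    # -> "cmn_hans"
#     _FLEURS_LANG_SHORT_TO_LONG["cmn_hans"]    # -> "Mandarin Chinese"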
english_normalizer = EnglishTextNormalizer()
chinese_normalizer = TextNorm(
    to_banjiao=False,
    to_upper=False,
    to_lower=False,
    remove_fillers=False,
    remove_erhua=False,
    check_chars=False,
    remove_space=False,
    cc_mode="",
)
basic_normalizer = BasicTextNormalizer()

dir_name = os.path.dirname(os.path.abspath(__file__))


def fleurs_doc_to_audio(doc):
    return [doc["audio"]]


def fleurs_doc_to_text(doc, lmms_eval_specific_kwargs):
    pre_prompt = lmms_eval_specific_kwargs["pre_prompt"]
    post_prompt = lmms_eval_specific_kwargs["post_prompt"]
    return f"{pre_prompt}Please recognize the speech and only output the recognized content:{post_prompt}"


def fleurs_process_result(doc, result):
    # Package one example (reference, prediction, metadata) for the WER aggregator.
    pred = result[0] if len(result) > 0 else ""
    gt = doc["transcription"]
    source = doc["path"]
    language = doc["language"]
    data_dict = {"gt": gt, "pred": pred, "source": source, "language": language}
    return {"wer": data_dict}


PUNCS = "!,.?;:"


def remove_sp(text, language):
    # Strip special tokens such as <|...|> emitted by some models.
    gt = re.sub(r"<\|.*?\|>", " ", text)
    # Replace consecutive whitespace in the text with a single space.
    gt = re.sub(r"\s+", " ", gt)
    # Remove any space preceding a punctuation mark.
    gt = re.sub(f" ?([{PUNCS}])", r"\1", gt)
    gt = gt.lstrip(" ")
    if language == "cmn_hans":
        # Mandarin is scored at the character level, so drop spaces entirely.
        gt = re.sub(r"\s+", "", gt)
    return gt


class EvaluationTokenizer(object):
    """A generic evaluation-time tokenizer, which leverages built-in tokenizers
    in sacreBLEU (https://github.com/mjpost/sacrebleu). It additionally provides
    lowercasing, punctuation removal and character tokenization, which are
    applied after sacreBLEU tokenization.

    Args:
        tokenizer_type (str): the type of sacreBLEU tokenizer to apply.
        lowercase (bool): lowercase the text.
        punctuation_removal (bool): remove punctuation (based on unicode
            category) from text.
        character_tokenization (bool): tokenize the text to characters.
    """

    SPACE = chr(32)
    SPACE_ESCAPE = chr(9601)

    def __init__(
        self,
        tokenizer_type: str = "13a",
        lowercase: bool = False,
        punctuation_removal: bool = False,
        character_tokenization: bool = False,
    ):
        from sacrebleu.tokenizers.tokenizer_13a import Tokenizer13a
        from sacrebleu.tokenizers.tokenizer_char import TokenizerChar
        from sacrebleu.tokenizers.tokenizer_intl import TokenizerV14International
        from sacrebleu.tokenizers.tokenizer_ja_mecab import TokenizerJaMecab
        from sacrebleu.tokenizers.tokenizer_none import NoneTokenizer
        from sacrebleu.tokenizers.tokenizer_zh import TokenizerZh

        TOKENIZERS = {
            "none": NoneTokenizer,
            "13a": Tokenizer13a,
            "intl": TokenizerV14International,
            "zh": TokenizerZh,
            "ja-mecab": TokenizerJaMecab,
            "char": TokenizerChar,
        }

        assert tokenizer_type in TOKENIZERS, f"{tokenizer_type}, {TOKENIZERS}"
        self.lowercase = lowercase
        self.punctuation_removal = punctuation_removal
        self.character_tokenization = character_tokenization
        # Stores the tokenizer class; it is instantiated per call in tokenize().
        self.tokenizer = TOKENIZERS[tokenizer_type]

    @classmethod
    def remove_punctuation(cls, sent: str):
        """Remove punctuation based on Unicode category."""
        return cls.SPACE.join(t for t in sent.split(cls.SPACE) if not all(unicodedata.category(c)[0] == "P" for c in t))

    def tokenize(self, sent: str):
        tokenized = self.tokenizer()(sent)

        if self.punctuation_removal:
            tokenized = self.remove_punctuation(tokenized)

        if self.character_tokenization:
            tokenized = self.SPACE.join(list(tokenized.replace(self.SPACE, self.SPACE_ESCAPE)))

        if self.lowercase:
            tokenized = tokenized.lower()

        return tokenized
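# A minimal usage sketch of EvaluationTokenizer (illustrative only; the input
# string is an assumption, not data from this task):
#
#     tok = EvaluationTokenizer(tokenizer_type="13a", lowercase=True, punctuation_removal=True)
#     tok.tokenize("Hello, World!")  # -> "hello world"
#
# With tokenizer_type="none" (as used by compute_wer below), sacreBLEU leaves
# the string untouched and only the post-processing steps apply.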
def compute_wer(refs, hyps, language):
    distance = 0
    ref_length = 0
    tokenizer = EvaluationTokenizer(
        tokenizer_type="none",
        lowercase=True,
        punctuation_removal=True,
        character_tokenization=False,
    )
    for i in range(len(refs)):
        ref = refs[i]
        pred = hyps[i]
        if language in ["yue_hant"]:
            # Convert traditional-script Cantonese to simplified before comparison.
            ref = zhconv.convert(ref, "zh-cn")
            pred = zhconv.convert(pred, "zh-cn")
        if language in ["en"]:
            ref = english_normalizer(ref)
            pred = english_normalizer(pred)
        elif language in ["cmn_hans"]:
            ref = chinese_normalizer(ref)
            pred = chinese_normalizer(pred)
        else:
            ref = basic_normalizer(ref)
            pred = basic_normalizer(pred)
        ref_items = tokenizer.tokenize(ref).split()
        pred_items = tokenizer.tokenize(pred).split()
        if language in ["cmn_hans", "yue_hant"]:
            # Chinese is scored at the character level.
            ref_items = [x for x in "".join(ref_items)]
            pred_items = [x for x in "".join(pred_items)]
        if i == 0:
            # Debug output for the first example of each batch.
            print(f"ref: {ref}")
            print(f"pred: {pred}")
            print(f"ref_items:\n{ref_items}\n{len(ref_items)}\n{ref_items[0]}")
            print(f"pred_items:\n{pred_items}\n{len(pred_items)}\n{pred_items[0]}")
        distance += ed.eval(ref_items, pred_items)
        ref_length += len(ref_items)
    return distance / ref_length


def fleurs_wer(results, args):
    refs, hyps = [], []
    for result in results:
        # All results in a FLEURS split share one language, so the last value
        # of `lan` applies to the whole batch.
        lan = _FLEURS_LANG_TO_ID[result["language"]]
        gt = result["gt"]
        response = result["pred"]
        gt = remove_sp(gt, lan)
        response = remove_sp(response, lan)
        refs.append(gt)
        hyps.append(response)
    wer = compute_wer(refs, hyps, lan)
    return wer * 100
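# A minimal smoke test, runnable directly (the example sentences are
# illustrative assumptions, not FLEURS data). Two word substitutions
# ("jumps" -> "jumped", "the" -> "a") over a 9-word reference give a
# WER of 2/9 ~= 0.222:
if __name__ == "__main__":
    _refs = ["the quick brown fox jumps over the lazy dog"]
    _hyps = ["the quick brown fox jumped over a lazy dog"]
    print(compute_wer(_refs, _hyps, "en"))  # also emits the i == 0 debug lines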