|
|
import argparse |
|
|
import opencc |
|
|
import csv |
|
|
import editdistance |
|
|
from tqdm import tqdm |
|
|
from collections import defaultdict |
|
|
from functools import partial |
|
|
from pypinyin import pinyin, lazy_pinyin, Style |
|
|
from g2p_en import G2p |
|
|
import edit_distance |
|
|
lexicon_fpath = '/root/distil-whisper/utils/lexicon.lst' |
|
|
|
|
|
def cal_complete_mer(ref_data, hyp_data):
    """Aggregate (S, D, I, N) edit statistics over a corpus of ref/hyp pairs.

    Returns (S, D, I, N, count) where count is the number of pairs scored.
    Pairs beyond the shorter of the two sequences are ignored (zip semantics).
    """
    totals = [0, 0, 0, 0]  # running S, D, I, N
    pair_count = 0
    for ref, hyp in zip(ref_data, hyp_data):
        for slot, value in enumerate(cal_single_complete_mer(ref, hyp)):
            totals[slot] += value
        pair_count += 1
    return totals[0], totals[1], totals[2], totals[3], pair_count
|
|
|
|
|
def cal_single_complete_mer(ref, hyp):
    """Score one ref/hyp pair with edit-distance opcodes.

    Returns (S, D, I, N): substitution, deletion, and insertion counts plus
    the reference length N. Each replace/delete/insert opcode contributes the
    larger of its two span lengths.
    """
    matcher = edit_distance.SequenceMatcher(a=ref, b=hyp)
    subs = dels = ins = 0
    for tag, ref_lo, ref_hi, hyp_lo, hyp_hi in matcher.get_opcodes():
        span = max(ref_hi - ref_lo, hyp_hi - hyp_lo)
        if tag == 'replace':
            subs += span
        elif tag == 'delete':
            dels += span
        elif tag == 'insert':
            ins += span
    return subs, dels, ins, len(ref)
|
|
|
|
|
class MixErrorRate(object):
    """Mixed (Chinese + English) error-rate metric for code-switched ASR output.

    Chinese is scored per character and English per word; both token streams
    are merged into one list so a single edit distance covers code-switched
    utterances. Optionally converts between simplified/traditional Chinese,
    phonemizes both languages (bopomofo + lexicon phonemes), scores each
    language separately, counts repetitive n-gram hallucinations, and reports
    complete substitution/deletion/insertion statistics.
    """

    # Separator/punctuation characters dropped during tokenization (mix of
    # ASCII and fullwidth forms). NOTE: the original list used '\[' and '\]',
    # two-character strings that can never match a single character, so '['
    # and ']' were treated as unknown characters — fixed here.
    _PUNCTUATIONS = frozenset([
        ' ', '\t', '\n', '\r', ',', '.', '!', '?', '。', ',', '!', '?', '、',
        ';', ':', '「', '」', '『', '』', '(', ')', '(', ')', '[', ']',
        '{', '}', '<', '>', '《', '》', '“', '”', '‘', '’', '…', '—', '~',
        '·', '•',
    ])

    def __init__(
        self,
        to_simplified_chinese=True,
        to_traditional_chinese=False,
        phonemize=False,
        separate_language=False,
        test_only=False,
        count_repetitive_hallucination=False,
        calculate_complete_mer=False
    ):
        """Configure the metric.

        Args:
            to_simplified_chinese: convert all Chinese to simplified before scoring.
            to_traditional_chinese: convert all Chinese to traditional (mutually
                exclusive with to_simplified_chinese).
            phonemize: score at phoneme level (forces traditional->simplified;
                incompatible with separate_language).
            separate_language: additionally report EN WER and ZH CER.
            test_only: only score the first 10 pairs and print debug output.
            count_repetitive_hallucination: report repeated-6gram counts.
            calculate_complete_mer: accumulate and print S/D/I/N statistics.
        """
        self.converter = None
        if to_simplified_chinese and to_traditional_chinese:
            raise ValueError("Can't convert to both simplified and traditional chinese at the same time.")
        if to_simplified_chinese:
            print("Convert to simplified chinese")
            self.converter = opencc.OpenCC('t2s.json')
        elif to_traditional_chinese:
            print("Convert to traditional chinese")
            self.converter = opencc.OpenCC('s2t.json')
        else:
            print("No chinese conversion")

        if phonemize:
            if separate_language:
                raise NotImplementedError("Can't separate language and phonemize at the same time.")
            print("Phonemize chinese and english words")
            print("Force traditional to simplified conversion")
            self.converter = opencc.OpenCC('t2s.json')
            # Chinese -> bopomofo symbols; tone marks are stripped before scoring.
            self.zh_phonemizer = partial(lazy_pinyin, style=Style.BOPOMOFO, errors='ignore')
            self.zh_bopomofo_stress_marks = ['ˊ', 'ˇ', 'ˋ', '˙']
            # English word -> phoneme list; unknown words fall back to [] via
            # the default factory. Loaded lazily inside the phonemize branch so
            # non-phonemize use does not require the lexicon file to exist
            # (NOTE(review): original indentation was ambiguous here — the
            # lexicon is only consumed by _phonemized_cs_list, so this is safe).
            self.en_wrd2phn = defaultdict(list)
            with open(lexicon_fpath, 'r', encoding='utf-8') as f:
                for line in f:
                    word, phonemes = line.strip().split('\t')
                    self.en_wrd2phn[word] = phonemes.split()

        self.phonemize = phonemize
        self.test_only = test_only
        self.separate_language = separate_language
        self.count_repetitive_hallucination = count_repetitive_hallucination
        self.calculate_complete_mer = calculate_complete_mer
        if self.count_repetitive_hallucination:
            print("Count repetitive hallucination (6gram-5repeat)")

    def _from_str_to_list(self, cs_string):
        """Tokenize a code-switched string: one token per CJK char, one per English word.

        Punctuation/whitespace terminates any pending English word; Chinese
        characters are optionally converted (t2s/s2t) before being appended.
        Characters that are neither CJK, alphanumeric, apostrophe, nor hyphen
        are reported and skipped.
        """
        cs_list = []
        cur_en_word = ''
        for s in cs_string:
            if s in self._PUNCTUATIONS:
                if cur_en_word != '':
                    cs_list.append(cur_en_word)
                    cur_en_word = ''
                continue
            if u'\u4e00' <= s <= u'\u9fff':
                # CJK char: flush any pending English word, then emit the char.
                if cur_en_word != '':
                    cs_list.append(cur_en_word)
                    cur_en_word = ''
                if self.converter is not None:
                    s = self.converter.convert(s)
                cs_list.append(s)
            elif s.isalnum() or s in ["'", "-"]:
                cur_en_word += s
            else:
                print(f"Unknown character during conversion: {s}")
        if cur_en_word != '':
            cs_list.append(cur_en_word)
        return cs_list

    def _unit_is_en(self, token):
        """True when the token is not a CJK character (treated as English/other)."""
        return not (u'\u4e00' <= token[0] <= u'\u9fff')

    def _unit_is_zh(self, token):
        """True when the token is a CJK character."""
        return u'\u4e00' <= token[0] <= u'\u9fff'

    def _append_zh_phonemes(self, zh_chars, phonemes):
        """Phonemize buffered Chinese chars to bopomofo and extend `phonemes`.

        The joined bopomofo string is iterated per symbol; tone/stress marks
        are filtered out.
        """
        zh_phns = ''.join(self.zh_phonemizer(''.join(zh_chars)))
        phonemes.extend(p for p in zh_phns if p not in self.zh_bopomofo_stress_marks)

    def _phonemized_cs_list(self, cs_list):
        """Convert a mixed token list into a flat phoneme list.

        Consecutive Chinese characters are buffered and phonemized together
        (pinyin benefits from context); English words are looked up in the
        lexicon (unknown words contribute no phonemes).
        """
        phonemes = []
        cur_zh_chars = []
        for unit in cs_list:
            if u'\u4e00' <= unit[0] <= u'\u9fff':
                cur_zh_chars.append(unit)
            else:
                if cur_zh_chars:
                    self._append_zh_phonemes(cur_zh_chars, phonemes)
                    cur_zh_chars = []
                phonemes.extend(self.en_wrd2phn[unit])
        if cur_zh_chars:
            self._append_zh_phonemes(cur_zh_chars, phonemes)
        return phonemes

    def _count_repetitive_hallucination(self, cs_str, n=6, repeat=5, reset_len=100):
        """Count n-grams repeating `repeat`+ times within a rolling window.

        Counters reset after each detection and every `reset_len` positions,
        so a single long repetition is counted in chunks rather than once per
        overlapping n-gram. N-grams touching special-token markers ('<|', '|>')
        are ignored.
        """
        if len(cs_str) < n:
            return 0
        count = 0
        ngram_counts = defaultdict(int)
        prev_reset_idx = 0
        for i in range(len(cs_str) - n + 1):
            ngram = cs_str[i:i + n]
            if '|>' in ngram or '<|' in ngram:
                continue
            ngram_counts[ngram] += 1
            if ngram_counts[ngram] >= repeat:
                count += 1
                ngram_counts = defaultdict(int)  # restart after a detection
            if i - prev_reset_idx >= reset_len:
                ngram_counts = defaultdict(int)  # periodic window reset
                prev_reset_idx = i
        return count

    def compute(self, predictions=None, references=None, show_progress=True, empty_error_rate=1.0, **kwargs):
        """Compute the corpus-level mix error rate.

        Returns a float MER, or a dict of metrics when separate_language or
        count_repetitive_hallucination is enabled. When the references contain
        no tokens at all, returns `empty_error_rate` instead of dividing by zero.
        """
        total_err = 0
        total_ref_len = 0
        total_en_err = 0
        total_en_ref_len = 0
        total_zh_err = 0
        total_zh_ref_len = 0
        repetitive_hallucination_count = 0
        ref_repetitive_hallucination_count = 0
        if self.test_only:
            predictions = predictions[:10]
            references = references[:10]
        if self.calculate_complete_mer:
            S, D, I, N = 0, 0, 0, 0
        pairs = enumerate(zip(predictions, references))
        if show_progress and len(predictions) > 20:
            pairs = tqdm(pairs, total=len(predictions), desc="Computing Mix Error Rate...")
        for i, (pred, ref) in pairs:
            if self.count_repetitive_hallucination:
                repetitive_hallucination_count += self._count_repetitive_hallucination(pred)
                ref_repetitive_hallucination_count += self._count_repetitive_hallucination(ref)
            pred_list = self._from_str_to_list(pred)
            ref_list = self._from_str_to_list(ref)
            if self.test_only:
                print(f"Prediction List First 20@{i}: {pred_list[:20]}")
                print(f"Reference List First 20@{i}: {ref_list[:20]}")
            if self.phonemize:
                pred_list = self._phonemized_cs_list(pred_list)
                ref_list = self._phonemized_cs_list(ref_list)
            if self.calculate_complete_mer:
                _S, _D, _I, _N = cal_single_complete_mer(ref_list, pred_list)
                S += _S
                D += _D
                I += _I
                N += _N
            if self.separate_language:
                en_pred_list = list(filter(self._unit_is_en, pred_list))
                en_ref_list = list(filter(self._unit_is_en, ref_list))
                zh_pred_list = list(filter(self._unit_is_zh, pred_list))
                zh_ref_list = list(filter(self._unit_is_zh, ref_list))
                en_err = editdistance.eval(en_pred_list, en_ref_list)
                total_en_err += en_err
                total_en_ref_len += len(en_ref_list)
                zh_err = editdistance.eval(zh_pred_list, zh_ref_list)
                total_zh_err += zh_err
                total_zh_ref_len += len(zh_ref_list)
            err = editdistance.eval(pred_list, ref_list)
            total_err += err
            total_ref_len += len(ref_list)
            if self.test_only:
                # Bug fix: en_err/zh_err are only defined under
                # separate_language, so the per-language entries must be
                # conditional; also guard against an empty reference.
                local_mer = {"MER": err / len(ref_list) if ref_list else 0}
                if self.separate_language:
                    local_mer["EN WER"] = en_err / len(en_ref_list) if len(en_ref_list) != 0 else 0
                    local_mer["ZH CER"] = zh_err / len(zh_ref_list) if len(zh_ref_list) != 0 else 0
                print(f"Local MER@{i}: {local_mer}")
        if total_ref_len == 0:
            print(f"No reference found, return {empty_error_rate*100}% error rate instead")
            return empty_error_rate
        mer = total_err / total_ref_len
        if self.calculate_complete_mer:
            # Bug fix: print these stats before any early return so they are
            # not silently dropped when separate_language or
            # count_repetitive_hallucination is also enabled. N > 0 here
            # because N accumulates len(ref_list) exactly like total_ref_len.
            print(f"SUB={S/N}, DEL={D/N}, INS={I/N}, (S, D, I, N)={(S, D, I, N)}, total_len={total_ref_len}")
        if self.separate_language or self.count_repetitive_hallucination:
            result = {
                "MER": mer,
            }
            if self.separate_language:
                result["EN WER"] = total_en_err / total_en_ref_len if total_en_ref_len != 0 else 0
                result["ZH CER"] = total_zh_err / total_zh_ref_len if total_zh_ref_len != 0 else 0
            if self.count_repetitive_hallucination:
                result["Hyp Repetitive Hallucination Count"] = repetitive_hallucination_count
                result["Ref Repetitive Hallucination Count"] = ref_repetitive_hallucination_count
            return result
        return mer
|
|
|
|
|
|
|
|
|
|
|
def load_output_csv(fpath, skip_header=True, hyp_col=1, ref_col=2, delimiter='\t'):
    """Load hypothesis/reference columns from a delimited output file.

    Args:
        fpath: path to the delimited file (UTF-8).
        skip_header: consume and print the first row as column names.
        hyp_col: column index holding hypotheses.
        ref_col: column index holding references.
        delimiter: field delimiter (tab by default).

    Returns:
        (predictions, references) as parallel lists of strings.
    """
    predictions, references = [], []
    with open(fpath, 'r', encoding='utf-8') as f:
        reader = csv.reader(f, delimiter=delimiter)
        if skip_header:
            columns = next(reader)
            print(f"Columns: {columns}")
        for row in reader:
            predictions.append(row[hyp_col])
            references.append(row[ref_col])
    return predictions, references
|
|
|
|
|
def main(args):
    """Build the metric from CLI flags, score the CSV, and print the result."""
    print(args)
    metric = MixErrorRate(
        to_simplified_chinese=args.to_simplified_chinese,
        to_traditional_chinese=args.to_traditional_chinese,
        separate_language=args.separate_language,
        test_only=args.test_only,
        count_repetitive_hallucination=args.count_repetitive_hallucination,
        calculate_complete_mer=args.calculate_complete_mer,
    )
    predictions, references = load_output_csv(args.csv_fpath)
    score = metric.compute(predictions=predictions, references=references)
    print(f"Mix Error Rate: {score}")
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Compute Mix Error Rate")
    parser.add_argument("--csv_fpath", type=str, required=True, help="Path to the csv file")
    # All remaining options are boolean switches; declare them table-style.
    for flag, help_text in [
        ("--to_simplified_chinese", "Convert chinese to simplified chinese"),
        ("--to_traditional_chinese", "Convert chinese to traditional chinese"),
        ("--separate_language", "Compute MER separately for chinese and english"),
        ("--test_only", "Run test and give some cs_list examples"),
        ("--count_repetitive_hallucination", "Count repetitive hallucination"),
        ("--calculate_complete_mer", "Calculate complete MER"),
    ]:
        parser.add_argument(flag, action="store_true", help=help_text)
    main(parser.parse_args())
|
|
|