|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
|
from funasr import AutoModel |
|
|
import argparse |
|
|
from zhon.hanzi import punctuation |
|
|
import zhconv |
|
|
import string |
|
|
from tqdm import tqdm |
|
|
from eval_detok_en import ( |
|
|
get_gt_ref_texts_and_wav_files, |
|
|
get_ref_texts_and_gen_files, |
|
|
get_hypo_texts, |
|
|
calc_wer, |
|
|
) |
|
|
|
|
|
# Model location handed to funasr's AutoModel as ``model=`` in process_wavs.
model_path: str = "./paraformer-zh"
|
|
|
|
|
|
|
|
def split_text(text):
    """Join the items of ``text`` with single spaces.

    Iterating a string yields its characters, so ``"abc"`` becomes
    ``"a b c"``. An empty input yields an empty string.
    """
    return " ".join(text)
|
|
|
|
|
|
|
|
def dummy_split_text(text):
    """Return ``text`` exactly as given, performing no splitting."""
    return text
|
|
|
|
|
|
|
|
def remove_punct(text):
    """Strip Chinese and ASCII punctuation characters from ``text``.

    Args:
        text: String to clean.

    Returns:
        ``text`` without any character contained in ``zhon.hanzi.punctuation``
        or ``string.punctuation``, with each run of two spaces left behind
        replaced by a single space.
    """
    puncts = set(punctuation + string.punctuation)
    # Filter in one pass and join once instead of growing a string with
    # ``+=`` for every kept character.
    output = "".join(char for char in text if char not in puncts)
    # Removing a punctuation mark that sat between two spaces leaves a double
    # space. The search pattern must be TWO spaces: a one-space pattern
    # replaced by one space would be a no-op.
    return output.replace("  ", " ")
|
|
|
|
|
|
|
|
def process_wavs(wav_file_list, batch_size=300):
    """Transcribe every entry of ``wav_file_list`` with the FunASR model.

    Args:
        wav_file_list: Iterable of inputs passed, one at a time, as ``input=``
            to ``model.generate`` (presumably wav file paths; confirm against
            the callers).
        batch_size: Forwarded to ``model.generate`` as ``batch_size_s``.

    Returns:
        A list holding one ``{"text": transcription}`` dict per input, in
        input order. Each transcription is the ``"text"`` field of the first
        result entry, converted to ``zh-cn`` with ``zhconv``.
    """
    asr_model = AutoModel(model=model_path, disable_update=True)

    def transcribe(wav_path):
        # Only the first entry of the generate() result is used.
        output = asr_model.generate(input=wav_path, batch_size_s=batch_size)
        return {"text": zhconv.convert(output[0]["text"], "zh-cn")}

    return [transcribe(wav_path) for wav_path in tqdm(wav_file_list)]
|
|
|
|
|
|
|
|
def main(args):
    """Run ASR over the listed files and score the output against references.

    Args:
        args: Parsed command-line namespace. ``log_file``, ``test_path``,
            ``test_lst`` and ``eval_gt`` are read here; the whole namespace is
            also passed on to the ``eval_detok_en`` helpers.

    Raises:
        AssertionError: If the number of hypotheses differs from the number
            of references.
    """
    # ``log_file`` may be None (the option is not required), and
    # logging.FileHandler raises TypeError for a None filename. Fall back to
    # a stream handler so the script still runs and logs.
    if args.log_file is None:
        handler = logging.StreamHandler()
    else:
        handler = logging.FileHandler(filename=args.log_file, mode="w")
    logging.root.setLevel(logging.INFO)
    logging.root.addHandler(handler)

    test_path = args.test_path
    lst_path = args.test_lst

    if args.eval_gt:
        # Ground-truth mode: the helper returns the wav files to transcribe.
        logging.info(f"run ASR for GT: {lst_path}")
        reference, wav_file_list = get_gt_ref_texts_and_wav_files(
            args, lst_path, test_path, remove_punct, split_text
        )
        results = process_wavs(wav_file_list, batch_size=300)
    else:
        # Detok mode: the helper returns the generated files to transcribe.
        logging.info(f"run ASR for detok: {lst_path}")
        reference, gen_file_list = get_ref_texts_and_gen_files(
            args, lst_path, test_path, remove_punct, split_text
        )
        results = process_wavs(gen_file_list, batch_size=300)

    hypothesis = get_hypo_texts(args, results, remove_punct, split_text)

    # Include both counts in the failure message; the log line that reports
    # them is only reached when they match.
    assert len(hypothesis) == len(reference), (
        f"hypothesis: {len(hypothesis)} vs reference: {len(reference)}"
    )
    logging.info(f"Finish runing ASR for {lst_path}")
    logging.info(f"hypothesis: {len(hypothesis)} vs reference: {len(reference)}")

    calc_wer(reference, hypothesis, test_path)
    logging.info(f"Finish evaluate {lst_path}, results are in {args.log_file}")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--test-path",
        required=True,
        type=str,
        help="folder of wav files",
    )
    parser.add_argument(
        "--test-lst",
        required=True,
        type=str,
        help="path to test file lst",
    )
    parser.add_argument(
        "--log-file",
        required=False,
        type=str,
        default=None,
        # Previously duplicated the --test-lst help text.
        help="path to the log file to write",
    )
    parser.add_argument(
        "--remove-punct",
        default=False,
        action="store_true",
        help="remove punct from GT and hypo texts",
    )
    parser.add_argument(
        "--norm-text",
        default=False,
        action="store_true",
        help="normalized GT and hypo texts",
    )
    parser.add_argument(
        "--eval-gt",
        default=False,
        action="store_true",
        # Previously duplicated the --remove-punct help text.
        help="run ASR on the GT wav files instead of the detok files",
    )
    args = parser.parse_args()
    # The parsed value is always overridden, so passing --norm-text has no
    # effect. NOTE(review): presumably text normalization is deliberately
    # disabled for this Chinese evaluation; confirm before removing.
    args.norm_text = False

    main(args)
|
|
|