# Copyright (C) 2025. Huawei Technologies Co., Ltd. All Rights Reserved. (authors: Xiao Chen)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run FunASR (Paraformer) ASR on Chinese wavs and score the transcripts.

Either the ground-truth wavs (``--eval-gt``) or the generated/detokenized wavs
listed in ``--test-lst`` are transcribed. Reference and hypothesis texts are
cleaned with ``remove_punct`` / ``split_text`` and passed to ``calc_wer`` from
``eval_detok_en``.
"""
import argparse
import logging
import string

import zhconv
from funasr import AutoModel
from tqdm import tqdm
from zhon.hanzi import punctuation

from eval_detok_en import (
    calc_wer,
    get_gt_ref_texts_and_wav_files,
    get_hypo_texts,
    get_ref_texts_and_gen_files,
)

# Default local FunASR model directory (overridable with --model-path).
# Alternative checkpoint used previously:
# "./speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
model_path = "./paraformer-zh"

# Chinese (zhon) plus ASCII punctuation characters removed by remove_punct().
# Built once at import time instead of on every call.
_PUNCTS = frozenset(punctuation + string.punctuation)


def split_text(text):
    """Return *text* with a single space inserted between every character.

    For Chinese this turns each character into its own token, so a word-level
    error rate computed on the result counts per-character errors.
    """
    return " ".join(text)


def dummy_split_text(text):
    """Return *text* unchanged (no-op alternative to ``split_text``)."""
    return text


def remove_punct(text):
    """Strip Chinese and ASCII punctuation from *text*.

    After removal, each double space is replaced by a single space (one
    left-to-right pass of ``str.replace``).
    """
    output = "".join(char for char in text if char not in _PUNCTS)
    return output.replace("  ", " ")


def process_wavs(wav_file_list, batch_size=300, model_dir=None):
    """Transcribe each wav file with a FunASR ``AutoModel``.

    Args:
        wav_file_list: iterable of wav file paths; each is passed as ``input``
            to ``AutoModel.generate``.
        batch_size: forwarded to ``generate`` as ``batch_size_s``.
        model_dir: model directory to load; defaults to the module-level
            ``model_path`` when ``None``.

    Returns:
        A list with one ``{"text": transcription}`` dict per input file, in
        input order. Transcriptions are converted to simplified Chinese
        (``"zh-cn"``) with zhconv.
    """
    model = AutoModel(
        model=model_path if model_dir is None else model_dir,
        disable_update=True,
    )
    results = []
    for wav_file_path in tqdm(wav_file_list):
        res = model.generate(input=wav_file_path, batch_size_s=batch_size)
        # A single path is passed per call, so only the first result is used.
        transcription = zhconv.convert(res[0]["text"], "zh-cn")
        results.append({"text": transcription})
    return results


def main(args):
    """Transcribe the selected wavs and log error-rate results.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block).
            ``args.model_path`` is optional; when absent the module-level
            ``model_path`` is used.

    Raises:
        ValueError: if the number of hypotheses differs from the number of
            references.
    """
    # --log-file is optional: FileHandler(filename=None) would raise, so fall
    # back to logging on stderr when no file was given.
    if args.log_file:
        handler = logging.FileHandler(filename=args.log_file, mode="w")
    else:
        handler = logging.StreamHandler()
    logging.root.setLevel(logging.INFO)
    logging.root.addHandler(handler)

    test_path = args.test_path  # e.g. './40ms.AISHELL2.test_with_single_ref.base.chunk25.gen'
    lst_path = args.test_lst  # e.g. "40ms.AISHELL2.test_with_single_ref.base.lst"
    model_dir = getattr(args, "model_path", None)

    if args.eval_gt:
        logging.info(f"run ASR for GT: {lst_path}")
        reference, wav_file_list = get_gt_ref_texts_and_wav_files(
            args, lst_path, test_path, remove_punct, split_text
        )
    else:
        logging.info(f"run ASR for detok: {lst_path}")
        reference, wav_file_list = get_ref_texts_and_gen_files(
            args, lst_path, test_path, remove_punct, split_text
        )
    results = process_wavs(wav_file_list, batch_size=300, model_dir=model_dir)

    hypothesis = get_hypo_texts(args, results, remove_punct, split_text)
    logging.info(f"Finish runing ASR for {lst_path}")
    # Log the counts before validating so a mismatch is visible in the log.
    logging.info(f"hypothesis: {len(hypothesis)} vs reference: {len(reference)}")
    if len(hypothesis) != len(reference):
        raise ValueError(
            f"hypothesis/reference count mismatch: "
            f"{len(hypothesis)} vs {len(reference)}"
        )

    calc_wer(reference, hypothesis, test_path)
    log_target = args.log_file or "stderr"
    logging.info(f"Finish evaluate {lst_path}, results are in {log_target}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--test-path",
        required=True,
        type=str,
        help="folder of wav files",
    )
    parser.add_argument(
        "--test-lst",
        required=True,
        type=str,
        help="path to test file lst",
    )
    parser.add_argument(
        "--log-file",
        required=False,
        type=str,
        default=None,
        help="path to the output log file (logs go to stderr if omitted)",
    )
    parser.add_argument(
        "--model-path",
        required=False,
        type=str,
        default=model_path,
        help="FunASR model directory (default: %(default)s)",
    )
    parser.add_argument(
        "--remove-punct",
        default=False,
        action="store_true",
        help="remove punct from GT and hypo texts",
    )
    parser.add_argument(
        "--norm-text",
        default=False,
        action="store_true",
        help="normalize GT and hypo texts (currently ignored: forced off)",
    )
    parser.add_argument(
        "--eval-gt",
        default=False,
        action="store_true",
        help="run ASR on the ground-truth wavs instead of the generated ones",
    )
    args = parser.parse_args()
    # NOTE(review): text normalization is forced off regardless of --norm-text;
    # presumably the shared normalizer in eval_detok_en is not suited to
    # Chinese — confirm before re-enabling.
    args.norm_text = False
    main(args)