DSTK / evaluation /eval_detok_zh.py
gooorillax's picture
first push of codes and models for g2p, t2u, tokenizer and detokenizer
cd8454d
# Copyright (C) 2025. Huawei Technologies Co., Ltd. All Rights Reserved. (authors: Xiao Chen)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from funasr import AutoModel
import argparse
from zhon.hanzi import punctuation
import zhconv
import string
from tqdm import tqdm
from eval_detok_en import (
get_gt_ref_texts_and_wav_files,
get_ref_texts_and_gen_files,
get_hypo_texts,
calc_wer,
)
# Model identifier handed to funasr's AutoModel in process_wavs().
# The trailing comment looks like an alternative model directory name that
# was used previously -- NOTE(review): confirm before relying on it.
model_path = "./paraformer-zh" # "./speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
def split_text(text):
    """Join the items of ``text`` with single spaces.

    When ``text`` is a string, every character becomes its own
    space-separated token (``"abc"`` -> ``"a b c"``).
    """
    return " ".join(text)
def dummy_split_text(text):
    """Identity stand-in for ``split_text``: hand ``text`` back untouched."""
    return text
def remove_punct(text):
    """Strip punctuation characters from ``text``.

    Every character found in ``zhon.hanzi.punctuation`` or
    ``string.punctuation`` is dropped; all other characters are kept in
    order. Runs of spaces left behind by removed punctuation are then
    collapsed to a single space.

    Args:
        text: The string to clean.

    Returns:
        The cleaned string.
    """
    puncts = set(punctuation + string.punctuation)
    # Build the result in one pass; repeated ``+=`` on a str is quadratic.
    output = "".join(char for char in text if char not in puncts)
    # The previous ``replace(" ", " ")`` swapped a space for an identical
    # space, i.e. did nothing. Collapsing double spaces is presumably what
    # was intended (removing "," from "a , b" leaves "a  b"); loop so that
    # runs longer than two spaces are fully collapsed as well.
    while "  " in output:
        output = output.replace("  ", " ")
    return output
def process_wavs(wav_file_list, batch_size=300):
    """Transcribe every entry of ``wav_file_list`` with the funasr model.

    The model named by the module-level ``model_path`` is instantiated once,
    then each entry is passed to ``generate`` on its own, with a tqdm
    progress bar over the list.

    Args:
        wav_file_list: Iterable of inputs accepted by ``model.generate``
            (the callers in this file pass lists of audio files).
        batch_size: Forwarded to ``generate`` as ``batch_size_s``.

    Returns:
        A list with one ``{"text": ...}`` dict per input, in input order.
        The text is the first result's ``"text"`` field after
        ``zhconv.convert(..., "zh-cn")``.
    """
    asr_model = AutoModel(
        model=model_path,
        disable_update=True,
    )

    def _transcribe(wav_path):
        # Only the first returned result is used for each input.
        output = asr_model.generate(
            input=wav_path,
            batch_size_s=batch_size,
        )
        return {"text": zhconv.convert(output[0]["text"], "zh-cn")}

    return [_transcribe(wav_path) for wav_path in tqdm(wav_file_list)]
def main(args):
    """Run ASR over the test set and score it against the reference texts.

    Steps:
      1. Attach a log handler to the root logger at INFO level.
      2. Collect reference texts and the audio files to transcribe: the
         ground-truth wavs when ``args.eval_gt`` is set, otherwise the
         generated files.
      3. Transcribe the files with ``process_wavs``.
      4. Post-process the transcripts with ``get_hypo_texts`` and pass
         references and hypotheses to ``calc_wer``.

    Args:
        args: Parsed command-line namespace. Reads ``log_file``,
            ``test_path``, ``test_lst`` and ``eval_gt`` here; the whole
            namespace is also forwarded to the ``eval_detok_en`` helpers.

    Raises:
        AssertionError: If the number of hypotheses differs from the number
            of references.
    """
    # ``--log-file`` is optional and defaults to None, but
    # logging.FileHandler(None) raises TypeError. Fall back to a stream
    # handler (stderr) so the script still runs without the flag.
    if args.log_file is None:
        handler = logging.StreamHandler()
    else:
        handler = logging.FileHandler(filename=args.log_file, mode="w")
    logging.root.setLevel(logging.INFO)
    logging.root.addHandler(handler)

    test_path = args.test_path  # e.g. './40ms.AISHELL2.test_with_single_ref.base.chunk25.gen'
    lst_path = args.test_lst  # e.g. "40ms.AISHELL2.test_with_single_ref.base.lst"

    # Both modes return (reference texts, files to transcribe); only the
    # helper that locates the audio differs.
    if args.eval_gt:
        logging.info(f"run ASR for GT: {lst_path}")
        reference, audio_file_list = get_gt_ref_texts_and_wav_files(
            args, lst_path, test_path, remove_punct, split_text
        )
    else:
        logging.info(f"run ASR for detok: {lst_path}")
        reference, audio_file_list = get_ref_texts_and_gen_files(
            args, lst_path, test_path, remove_punct, split_text
        )
    results = process_wavs(audio_file_list, batch_size=300)

    hypothesis = get_hypo_texts(args, results, remove_punct, split_text)
    logging.info(f"Finish runing ASR for {lst_path}")
    logging.info(f"hypothesis: {len(hypothesis)} vs reference: {len(reference)}")
    # Checked after the counts are logged so a mismatch is visible in the log.
    assert len(hypothesis) == len(reference), (
        f"hypothesis/reference count mismatch: "
        f"{len(hypothesis)} vs {len(reference)}"
    )

    calc_wer(reference, hypothesis, test_path)
    logging.info(f"Finish evaluate {lst_path}, results are in {args.log_file}")
if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--test-path",
        required=True,
        type=str,
        help="folder of wav files",
    )
    parser.add_argument(
        "--test-lst",
        required=True,
        type=str,
        help="path to test file lst",
    )
    parser.add_argument(
        "--log-file",
        required=False,
        type=str,
        default=None,
        # Previously a copy of the --test-lst help text.
        help="path to the log file",
    )
    parser.add_argument(
        "--remove-punct",
        default=False,
        action="store_true",
        help="remove punct from GT and hypo texts",
    )
    parser.add_argument(
        "--norm-text",
        default=False,
        action="store_true",
        help="normalize GT and hypo texts (ignored: always disabled by this script)",
    )
    parser.add_argument(
        "--eval-gt",
        default=False,
        action="store_true",
        # Previously a copy of the --remove-punct help text.
        help="run ASR on the ground-truth wav files instead of the generated files",
    )
    args = parser.parse_args()
    # The flag is accepted but unconditionally overridden here.
    # NOTE(review): presumably text normalization is not supported for
    # Chinese and the option only exists for parity with the English
    # script -- confirm.
    args.norm_text = False
    main(args)