# Copyright (C) 2025. Huawei Technologies Co., Ltd. All Rights Reserved. (authors: Xiao Chen)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transcribe wav files with Whisper and score them against reference text.

The script reads a ``|``-separated list file, runs the Whisper model found at
``en_asr_model_path`` on either the ground-truth wavs (``--eval-gt``) or the
generated ``*_gen.wav`` files, and accumulates WER statistics with lingvo's
``simple_wer_v2``.
"""

import argparse
import logging
import os
import string

from transformers import WhisperProcessor, WhisperForConditionalGeneration
import soundfile as sf
import scipy
import scipy.signal  # `import scipy` alone does not guarantee the submodule is loaded
from whisper_normalizer.english import EnglishTextNormalizer
import lingvo.tasks.asr.tools.simple_wer_v2 as WER
from tqdm import tqdm
import torch

keyphrases = None
english_normalizer = EnglishTextNormalizer()
# Fall back to CPU so the script still runs on hosts without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
en_asr_model_path = "./whisper-large-v3"
wer_obj = WER.SimpleWER(
    key_phrases=keyphrases,
    html_handler=WER.HighlightAlignedHtmlHandler(WER.HighlightAlignedHtml),
    preprocess_handler=WER.RemoveCommentTxtPreprocess,
)

# Translation table that deletes every character in string.punctuation.
_PUNCT_TABLE = str.maketrans("", "", string.punctuation)


def dummy_split_text(text):
    """Return ``text`` unchanged (placeholder for a real text splitter)."""
    return text


def remove_punct(text):
    """Return ``text`` with every ``string.punctuation`` character removed."""
    # NOTE(review): the previous implementation ended with
    # `output.replace(" ", " ")`, which is a no-op as written; presumably it
    # was meant to collapse double spaces -- confirm before adding that back.
    return text.translate(_PUNCT_TABLE)


def _normalize_text(args, text, punct_remover, text_spliter):
    """Lower-case ``text`` and apply the normalization selected by ``args``.

    ``args.norm_text`` takes precedence over ``args.remove_punct``; when
    neither is set the lower-cased text is only passed through
    ``text_spliter``.
    """
    text = text.lower()
    if args.norm_text:
        text = english_normalizer(text)
    elif args.remove_punct:
        text = punct_remover(text)
    return text_spliter(text)


def get_gt_ref_texts_and_wav_files(
    args, gt_test_lst, gt_folder, punct_remover, text_spliter
):
    """Load reference texts and ground-truth wav paths from a list file.

    Each line of ``gt_test_lst`` is split on ``|``; the first field names the
    wav file (``{gt_folder}/{field}.wav``) and the last field is the
    transcript. Entries whose wav file does not exist are skipped.

    Returns:
        ``(reference, wav_file_list)`` where ``reference`` is a list of
        ``[normalized_text, raw_text]`` pairs aligned with ``wav_file_list``.
    """
    wav_file_list = []
    reference = []
    with open(gt_test_lst, "r") as fp:
        for line in fp:
            line = line.strip()
            if not line:
                continue
            fields = line.split("|")
            wav_file = f"{gt_folder}/{fields[0]}.wav"
            if not os.path.isfile(wav_file):
                logging.warning("skip missing wav file: %s", wav_file)
                continue
            wav_file_list.append(wav_file)
            truth_text = _normalize_text(
                args, fields[-1], punct_remover, text_spliter
            )
            reference.append([truth_text, fields[-1]])
    assert len(reference) == len(wav_file_list)
    return reference, wav_file_list


def get_ref_texts_and_gen_files(
    args, test_lst, test_folder, punct_remover, text_spliter
):
    """Load reference texts and generated wav paths from a list file.

    Each line of ``test_lst`` is split on ``|``. The third field is a path
    whose base name (up to its first ``.``) gives the stem of the generated
    file ``{test_folder}/{stem}_gen.wav``; the last field is the transcript.
    The existence of the generated files is not checked here.

    Returns:
        ``(reference, gen_file_list)`` where ``reference`` is a list of
        ``[normalized_text, raw_text]`` pairs aligned with ``gen_file_list``.

    Raises:
        ValueError: if a non-blank line has fewer than three fields.
    """
    reference = []
    gen_file_list = []
    with open(test_lst, "r") as fp:
        for line_no, line in enumerate(fp, start=1):
            line = line.strip()
            if not line:
                # A blank (e.g. trailing) line used to crash on fields[2].
                continue
            fields = line.split("|")
            if len(fields) < 3:
                raise ValueError(
                    f"{test_lst}:{line_no}: expected at least 3 '|'-separated "
                    f"fields, got {len(fields)}"
                )
            stem = fields[2].split("/")[-1].split(".")[0]
            gen_file_list.append(f"{test_folder}/{stem}_gen.wav")
            truth_text = _normalize_text(
                args, fields[-1], punct_remover, text_spliter
            )
            reference.append([truth_text, fields[-1]])
    assert len(reference) == len(gen_file_list)
    return reference, gen_file_list


def get_hypo_texts(args, results_list, punct_remover, text_spliter):
    """Turn ASR results into ``[normalized_text, raw_text]`` pairs.

    ``results_list`` is a list of dicts carrying the transcript under the
    ``"text"`` key, as produced by ``process_wavs``.
    """
    return [
        [
            _normalize_text(args, res["text"], punct_remover, text_spliter),
            res["text"],
        ]
        for res in results_list
    ]


def calc_wer(reference, hypothesis, test_lst):
    """Accumulate WER over all pairs, log the summaries, dump diagnosis files.

    ``reference`` and ``hypothesis`` are lists of ``[normalized, raw]`` pairs
    matched by index. Results are added to the module-level ``wer_obj``. Two
    files are written next to ``test_lst``: ``<test_lst>_diagnosis.html`` with
    the aligned output and ``<test_lst>_rawtext.lst`` with ``raw_ref|raw_hypo``
    lines. A failure to write them is logged and does not raise.
    """
    logging.info("calc WER:")
    for idx, hypo_entry in enumerate(tqdm(hypothesis)):
        hypo = hypo_entry[0].strip()
        ref = reference[idx][0].strip()
        wer_obj.AddHypRef(hypo, ref)

    str_summary, str_details, str_keyphrases_info = wer_obj.GetSummaries()
    logging.info("WER summary:")
    logging.info(str_summary)
    logging.info(str_details)
    logging.info(str_keyphrases_info)

    fn_output = test_lst + "_diagnosis.html"
    text_output = test_lst + "_rawtext.lst"
    try:
        # NOTE(review): the separator and wrapper strings below contain only
        # newlines / nothing; presumably they once carried HTML tags that
        # were lost -- confirm before relying on the .html output.
        aligned_html = "\n".join(wer_obj.aligned_htmls)
        with open(fn_output, "wt") as fp:
            fp.write("")
            fp.write("\n%s\n" % aligned_html)
            fp.write("")
        with open(text_output, "w") as fp:
            fp.writelines(
                f"{ref[1]}|{hypo[1]}\n" for ref, hypo in zip(reference, hypothesis)
            )
        logging.info(f"Save {fn_output} and {text_output} for diagnosis")
    except IOError:
        # Best effort: keep the WER results even if the dump fails, but
        # record the underlying error.
        logging.exception("failed to write diagnosis html")


def load_en_model():
    """Load the Whisper processor and model from ``en_asr_model_path``.

    Returns:
        ``(processor, model)`` with the model moved to ``device``.
    """
    processor = WhisperProcessor.from_pretrained(en_asr_model_path)
    model = WhisperForConditionalGeneration.from_pretrained(en_asr_model_path).to(
        device
    )
    return processor, model


def process_wavs(wav_file_list, batch_size=300):
    """Transcribe each wav file with Whisper (English, ``transcribe`` task).

    Args:
        wav_file_list: paths of the audio files to transcribe.
        batch_size: currently unused; files are decoded one at a time. Kept
            for interface compatibility.

    Returns:
        A list of ``{"text": transcription}`` dicts, one per input file and
        in the same order.
    """
    if not wav_file_list:
        # Nothing to decode: avoid loading the model at all.
        return []

    results = []
    processor, model = load_en_model()
    # The decoder prompt does not depend on the audio; build it once.
    forced_decoder_ids = processor.get_decoder_prompt_ids(
        language="english", task="transcribe"
    )
    for wav_file_path in tqdm(wav_file_list):
        wav, sr = sf.read(wav_file_path)
        if wav.ndim > 1:
            # sf.read returns (frames, channels) for multi-channel files;
            # down-mix to mono so the array is not taken for a batch.
            wav = wav.mean(axis=1)
        if sr != 16000:
            wav = scipy.signal.resample(wav, int(len(wav) * 16000 / sr))
        input_features = processor(
            wav, sampling_rate=16000, return_tensors="pt"
        ).input_features.to(device)
        predicted_ids = model.generate(
            input_features, forced_decoder_ids=forced_decoder_ids
        )
        transcription = processor.batch_decode(
            predicted_ids, skip_special_tokens=True
        )[0]
        results.append({"text": transcription.strip()})
    return results


def main(args):
    """Run ASR on the selected wav set and report WER.

    ``args`` is the parsed command line (``test_path``, ``test_lst``,
    ``log_file``, ``norm_text``, ``remove_punct``, ``eval_gt``). Logs go to
    ``args.log_file`` when given, otherwise to stderr.
    """
    if args.log_file:
        handler = logging.FileHandler(filename=args.log_file, mode="w")
        log_target = args.log_file
    else:
        # --log-file is optional; FileHandler(None) would raise TypeError.
        handler = logging.StreamHandler()
        log_target = "the console (stderr)"
    logging.root.setLevel(logging.INFO)
    logging.root.addHandler(handler)

    test_path = args.test_path  # e.g. './40ms.AISHELL2.test_with_single_ref.base.chunk25.gen'
    lst_path = args.test_lst  # e.g. '40ms.AISHELL2.test_with_single_ref.base.lst'
    logging.info(
        f"Evaluate {args.test_path} with Text Normalization: {args.norm_text} and Remove Punct: {args.remove_punct}"
    )

    if args.eval_gt:
        logging.info(f"run ASR for GT: {lst_path}")
        reference, wav_file_list = get_gt_ref_texts_and_wav_files(
            args, lst_path, test_path, remove_punct, dummy_split_text
        )
        results = process_wavs(wav_file_list, batch_size=12)
    else:
        logging.info(f"run ASR for detok: {lst_path}")
        reference, gen_file_list = get_ref_texts_and_gen_files(
            args, lst_path, test_path, remove_punct, dummy_split_text
        )
        results = process_wavs(gen_file_list, batch_size=12)

    hypothesis = get_hypo_texts(args, results, remove_punct, dummy_split_text)
    assert len(hypothesis) == len(reference)
    logging.info(f"Finish runing ASR for {lst_path}")
    logging.info(f"hypothesis: {len(hypothesis)} vs reference: {len(reference)}")
    calc_wer(reference, hypothesis, test_path)
    logging.info(f"Finish evaluate {lst_path}, results are in {log_target}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--test-path",
        required=True,
        type=str,
        help="folder of wav files",
    )
    parser.add_argument(
        "--test-lst",
        required=True,
        type=str,
        help="path to test file lst",
    )
    parser.add_argument(
        "--log-file",
        required=False,
        type=str,
        default=None,
        help="path to the log file (logs go to stderr when omitted)",
    )
    parser.add_argument(
        "--norm-text",
        default=False,
        action="store_true",
        help="normalized GT and hypo texts",
    )
    parser.add_argument(
        "--remove-punct",
        default=False,
        action="store_true",
        help="remove punct from GT and hypo texts",
    )
    parser.add_argument(
        "--eval-gt",
        default=False,
        action="store_true",
        help="evaluate the ground-truth wav files instead of the generated *_gen.wav files",
    )
    args = parser.parse_args()
    main(args)