"""Score synthesized speech against reference transcripts with an ASR model.

Usage:
    python eval_wer.py <wav_res_text_path> <res_path> <lang>

``wav_res_text_path`` is a '|'-separated manifest (2, 3, or 4 fields per
line; the wav path and reference text positions depend on field count).
``res_path`` receives one tab-separated result line per wav:
``path  wer  raw_truth  raw_hypo  inse  dele  subs``.
``lang`` is "zh" (character-level CER) or "en" (word-level WER).
"""
import os
import string
import sys

import multiprocessing
import numpy as np
import scipy
import soundfile as sf
import zhconv
from funasr import AutoModel
from jiwer import compute_measures
from tqdm import tqdm
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from zhon.hanzi import punctuation

# Union of Chinese (zhon) and ASCII punctuation, stripped before scoring.
punctuation_all = punctuation + string.punctuation

wav_res_text_path = sys.argv[1]
res_path = sys.argv[2]
lang = sys.argv[3]  # zh or en

device = "cuda:0"


def load_zh_model():
    """Load the local funasr SeACo-Paraformer Mandarin ASR model."""
    model = AutoModel(
        model="./speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
        disable_update=True,
    )
    return model


def process_one(hypo, truth):
    """Normalize a (hypothesis, reference) pair and compute error rates.

    Args:
        hypo: ASR transcription of the synthesized wav.
        truth: reference transcript.

    Returns:
        Tuple ``(raw_truth, raw_hypo, wer, subs, dele, inse)`` where the
        last four are jiwer measures; subs/dele/inse are normalized by
        the reference token count.

    Raises:
        NotImplementedError: if the global ``lang`` is neither "zh" nor "en".
    """
    raw_truth, raw_hypo = truth, hypo

    # Strip punctuation; keep apostrophes so English contractions
    # ("don't") remain a single token.
    for ch in punctuation_all:
        if ch == "'":
            continue
        truth = truth.replace(ch, '')
        hypo = hypo.replace(ch, '')

    # Collapse the double spaces left behind by punctuation removal.
    # (Original code replaced ' ' with ' ' — a no-op typo.)
    truth = truth.replace('  ', ' ')
    hypo = hypo.replace('  ', ' ')

    if lang == "zh":
        # Character-level scoring: put a space between every character
        # so jiwer's word-level WER becomes a CER.
        truth = " ".join(truth)
        hypo = " ".join(hypo)
    elif lang == "en":
        truth = truth.lower()
        hypo = hypo.lower()
    else:
        raise NotImplementedError(f"unsupported lang: {lang!r}")

    measures = compute_measures(truth, hypo)
    n_ref = len(truth.split(" "))  # >= 1 even for empty truth
    wer = measures["wer"]
    subs = measures["substitutions"] / n_ref
    dele = measures["deletions"] / n_ref
    inse = measures["insertions"] / n_ref
    return (raw_truth, raw_hypo, wer, subs, dele, inse)


def run_asr(wav_res_text_path, res_path):
    """Transcribe every wav in the manifest and write per-file error rates.

    Args:
        wav_res_text_path: path to the '|'-separated manifest.
        res_path: output path for tab-separated results (flushed per line).
    """
    model = load_zh_model()

    # Parse the manifest first so tqdm can show a meaningful total.
    params = []
    with open(wav_res_text_path) as fin:
        for line in fin:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines
            fields = line.split('|')
            if len(fields) == 2:
                wav_res_path, text_ref = fields
            elif len(fields) == 3:
                wav_res_path, _wav_ref_path, text_ref = fields
            elif len(fields) == 4:  # for edit
                wav_res_path, _, text_ref, _wav_ref_path = fields
            else:
                raise NotImplementedError(f"unexpected manifest line: {line!r}")
            if not os.path.exists(wav_res_path):
                continue  # skip missing synthesized wavs
            params.append((wav_res_path, text_ref))

    with open(res_path, "w") as fout:
        for wav_res_path, text_ref in tqdm(params):
            res = model.generate(input=wav_res_path, batch_size_s=300)
            transcription = res[0]["text"]
            # Normalize Traditional -> Simplified so comparison is fair.
            transcription = zhconv.convert(transcription, 'zh-cn')
            raw_truth, raw_hypo, wer, subs, dele, inse = process_one(transcription, text_ref)
            fout.write(f"{wav_res_path}\t{wer}\t{raw_truth}\t{raw_hypo}\t{inse}\t{dele}\t{subs}\n")
            fout.flush()  # keep partial results on long runs


if __name__ == "__main__":
    run_asr(wav_res_text_path, res_path)