# dpss-exp3-TTS / eval / run_wer.py
# Uploaded by lglg666 via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 6766eda (verified).
import sys, os
from tqdm import tqdm
import multiprocessing
from jiwer import compute_measures
from zhon.hanzi import punctuation
import string
import numpy as np
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import soundfile as sf
import scipy
import zhconv
from funasr import AutoModel
# Characters stripped before scoring: Chinese punctuation (zhon) plus ASCII.
punctuation_all = punctuation + string.punctuation

# Command line: run_wer.py <wav_res_text_path> <res_path> <zh|en>
if len(sys.argv) < 4:
    sys.exit(f"usage: {sys.argv[0]} <wav_res_text_path> <res_path> <zh|en>")
wav_res_text_path = sys.argv[1]
res_path = sys.argv[2]
lang = sys.argv[3]  # zh or en
# Reject unsupported languages up front, before the ASR model is loaded and
# the whole list is processed; process_one only knows "zh" and "en".
if lang not in ("zh", "en"):
    sys.exit(f"unsupported language {lang!r}; expected 'zh' or 'en'")
# NOTE(review): not referenced by the code visible in this file.
device = "cuda:0"
def load_zh_model(
    model_dir="./speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
):
    """Load the FunASR Paraformer model used to transcribe Chinese speech.

    Args:
        model_dir: value passed as ``model=`` to ``funasr.AutoModel``.
            The default is the relative path the script has always used,
            so it is resolved against the current working directory.

    Returns:
        The ``funasr.AutoModel`` instance.
    """
    # disable_update=True turns off FunASR's version-update check.
    return AutoModel(model=model_dir, disable_update=True)
def process_one(hypo, truth, *, language=None):
    """Normalize one hypothesis/reference pair and score it with jiwer.

    Every character of ``punctuation_all`` except the ASCII apostrophe is
    deleted from both strings. For Chinese the text is then split into
    single-space-separated characters, so jiwer's word error rate is a
    character error rate. For English the text is lower-cased and runs of
    whitespace are collapsed.

    Args:
        hypo: ASR transcription (hypothesis).
        truth: reference text.
        language: ``"zh"`` or ``"en"``. Defaults to the module-level
            ``lang`` taken from the command line.

    Returns:
        ``(raw_truth, raw_hypo, wer, subs, dele, inse)``:
        ``raw_*`` are the unnormalized inputs, ``wer`` is jiwer's error
        rate, and ``subs``/``dele``/``inse`` are the substitution, deletion
        and insertion counts divided by the number of reference tokens.

    Raises:
        NotImplementedError: for a language other than "zh" or "en".
    """
    if language is None:
        language = lang
    raw_truth, raw_hypo = truth, hypo

    # Delete all punctuation in a single pass. The ASCII apostrophe is kept
    # so English contractions ("don't") remain one word.
    strip_table = str.maketrans("", "", punctuation_all.replace("'", ""))
    truth = truth.translate(strip_table)
    hypo = hypo.translate(strip_table)

    if language == "zh":
        # One token per non-whitespace character (character-level scoring).
        # Whitespace is dropped so it never becomes an (empty) token.
        truth = " ".join(ch for ch in truth if not ch.isspace())
        hypo = " ".join(ch for ch in hypo if not ch.isspace())
    elif language == "en":
        # Lower-case and collapse whitespace runs to single spaces.
        truth = " ".join(truth.lower().split())
        hypo = " ".join(hypo.lower().split())
    else:
        raise NotImplementedError(f"unsupported language: {language!r}")

    measures = compute_measures(truth, hypo)
    # Count real reference tokens. Splitting on " " would also count the
    # empty strings produced by repeated or edge spaces and skew the rates.
    # max(..., 1) keeps the division safe if the reference normalizes to "".
    n_ref = max(len(truth.split()), 1)
    wer = measures["wer"]
    subs = measures["substitutions"] / n_ref
    dele = measures["deletions"] / n_ref
    inse = measures["insertions"] / n_ref
    return (raw_truth, raw_hypo, wer, subs, dele, inse)
def _parse_eval_list(wav_res_text_path):
    """Return ``(wav_res_path, text_ref)`` pairs from a ``|``-separated list file.

    Accepted line layouts (blank lines are ignored):

    * ``wav_res_path|text_ref``
    * ``wav_res_path|wav_ref_path|text_ref``
    * ``wav_res_path|<unused>|text_ref|wav_ref_path`` (editing task)

    Entries whose ``wav_res_path`` does not exist on disk are skipped.

    Raises:
        NotImplementedError: if a line has any other number of fields.
    """
    params = []
    with open(wav_res_text_path, encoding="utf-8") as fin:
        for line in fin:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines, e.g. at the end of the file
            fields = line.split("|")
            if len(fields) == 2:
                text_ref = fields[1]
            elif len(fields) in (3, 4):
                text_ref = fields[2]
            else:
                raise NotImplementedError(
                    f"expected 2-4 '|'-separated fields, got {len(fields)}: {line!r}"
                )
            wav_res_path = fields[0]
            if not os.path.exists(wav_res_path):
                continue
            params.append((wav_res_path, text_ref))
    return params


def run_asr(wav_res_text_path, res_path):
    """Transcribe every listed wav and write per-utterance error rates.

    Args:
        wav_res_text_path: list file parsed by ``_parse_eval_list``.
        res_path: output file. Each line is tab-separated:
            ``wav_path, wer, reference, hypothesis, ins, del, sub``.
    """
    params = _parse_eval_list(wav_res_text_path)
    # NOTE(review): the Chinese Paraformer model is used even when
    # lang == "en"; the Whisper classes imported at the top of the file are
    # not used here. Confirm whether English evaluation should use Whisper.
    model = load_zh_model()
    with open(res_path, "w", encoding="utf-8") as fout:
        for wav_res_path, text_ref in tqdm(params):
            res = model.generate(input=wav_res_path, batch_size_s=300)
            # Assumes generate() returns a non-empty list whose first item
            # has a "text" key.
            # Convert to Simplified Chinese to match the reference script.
            transcription = zhconv.convert(res[0]["text"], "zh-cn")
            raw_truth, raw_hypo, wer, subs, dele, inse = process_one(transcription, text_ref)
            fout.write(f"{wav_res_path}\t{wer}\t{raw_truth}\t{raw_hypo}\t{inse}\t{dele}\t{subs}\n")
            # Flush per line so partial results survive an interrupted run.
            fout.flush()
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    run_asr(wav_res_text_path, res_path)