|
|
import sys, os |
|
|
from tqdm import tqdm |
|
|
import multiprocessing |
|
|
from jiwer import compute_measures |
|
|
from zhon.hanzi import punctuation |
|
|
import string |
|
|
import numpy as np |
|
|
from transformers import WhisperProcessor, WhisperForConditionalGeneration |
|
|
import soundfile as sf |
|
|
import scipy |
|
|
import zhconv |
|
|
from funasr import AutoModel |
|
|
|
|
|
# Union of Chinese (zhon) and ASCII punctuation; stripped before scoring.
punctuation_all = punctuation + string.punctuation

# CLI arguments: <manifest path> <result output path> <language: "zh" | "en">
wav_res_text_path = sys.argv[1]

res_path = sys.argv[2]

lang = sys.argv[3]

# NOTE(review): hard-coded GPU device; appears unused in this file — confirm.
device = "cuda:0"
|
|
|
|
|
|
|
|
def load_zh_model():
    """Load the local SeACo-Paraformer Chinese ASR model via funasr.

    Returns:
        An initialized ``funasr.AutoModel`` ready for ``generate`` calls.
    """
    # disable_update=True keeps funasr from checking for model updates online.
    model_dir = "./speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
    return AutoModel(model=model_dir, disable_update=True)
|
|
|
|
|
def process_one(hypo, truth):
    """Score one (hypothesis, reference) pair with jiwer.

    Punctuation — both Chinese (zhon) and ASCII, apostrophes excepted —
    is stripped from working copies of the inputs, whitespace is
    normalized, and the texts are tokenized per the module-level
    ``lang``: one token per character for "zh" (i.e. CER), lowercase
    whitespace-split words for "en".

    Args:
        hypo: ASR transcription text.
        truth: Reference text.

    Returns:
        tuple: (raw_truth, raw_hypo, wer, substitution_rate,
        deletion_rate, insertion_rate); the three rates are normalized
        by the reference token count.

    Raises:
        NotImplementedError: if ``lang`` is neither "zh" nor "en".
    """
    raw_truth = truth
    raw_hypo = hypo

    # Strip punctuation; apostrophes are kept so English contractions survive.
    for ch in punctuation_all:
        if ch == "'":
            continue
        truth = truth.replace(ch, '')
        hypo = hypo.replace(ch, '')

    # Collapse whitespace runs left behind by punctuation removal.
    # (Bug fix: the previous replace(' ', ' ') replaced a single space
    # with a single space — a no-op.)
    truth = ' '.join(truth.split())
    hypo = ' '.join(hypo.split())

    if lang == "zh":
        # Character-level tokenization: one "word" per character.
        truth = ' '.join(truth)
        hypo = ' '.join(hypo)
    elif lang == "en":
        truth = truth.lower()
        hypo = hypo.lower()
    else:
        raise NotImplementedError(f"unsupported lang: {lang!r}")

    measures = compute_measures(truth, hypo)
    # ''.split(' ') yields [''], so the divisor is always >= 1.
    ref_len = len(truth.split(' '))
    wer = measures["wer"]
    subs = measures["substitutions"] / ref_len
    dele = measures["deletions"] / ref_len
    inse = measures["insertions"] / ref_len
    return (raw_truth, raw_hypo, wer, subs, dele, inse)
|
|
|
|
|
|
|
|
def run_asr(wav_res_text_path, res_path):
    """Transcribe every wav listed in a manifest and write per-line scores.

    Each manifest line is '|'-separated with 2, 3, or 4 fields; the first
    field is always the synthesized-wav path, and the reference text is
    the 2nd, 3rd, or 3rd field respectively. Lines whose wav file does
    not exist are skipped. Output lines are tab-separated:
    wav path, WER, raw reference, raw hypothesis, insertion rate,
    deletion rate, substitution rate.

    Args:
        wav_res_text_path: Path to the '|'-separated manifest file.
        res_path: Path of the result file to (over)write.

    Raises:
        NotImplementedError: on a manifest line with an unexpected
            number of fields.
    """
    model = load_zh_model()

    # Collect (wav path, reference text) pairs, skipping missing wavs.
    # Fix: the manifest handle was previously never closed.
    params = []
    with open(wav_res_text_path) as fin:
        for line in fin:
            fields = line.strip().split('|')
            if len(fields) == 2:
                wav_res_path, text_ref = fields
            elif len(fields) == 3:
                wav_res_path, _wav_ref_path, text_ref = fields
            elif len(fields) == 4:
                wav_res_path, _, text_ref, _wav_ref_path = fields
            else:
                raise NotImplementedError(f"unexpected manifest line: {line!r}")

            if not os.path.exists(wav_res_path):
                continue
            params.append((wav_res_path, text_ref))

    # Fix: use a context manager so the result file is always closed;
    # also removed the dead n_higher_than_50 / wers_below_50 locals.
    with open(res_path, "w") as fout:
        for wav_res_path, text_ref in tqdm(params):
            res = model.generate(input=wav_res_path,
                                 batch_size_s=300)
            # Normalize any traditional characters to simplified Chinese.
            transcription = zhconv.convert(res[0]["text"], 'zh-cn')

            raw_truth, raw_hypo, wer, subs, dele, inse = process_one(transcription, text_ref)
            fout.write(f"{wav_res_path}\t{wer}\t{raw_truth}\t{raw_hypo}\t{inse}\t{dele}\t{subs}\n")
            # Flush per utterance so partial results survive a crash.
            fout.flush()
|
|
|
|
|
# Guard the entry point so importing this module does not kick off ASR.
if __name__ == "__main__":
    run_asr(wav_res_text_path, res_path)
|
|
|