# TestTranslator/scripts/run_funasr_quant.py
# NOTE: web-page residue removed (author: yujuanqin, commit 42742c6 "update scripts").
from pathlib import Path
import time
import csv
from funasr_onnx import SeacoParaformer, CT_Transformer, Fsmn_vad
from scripts.asr_utils import get_origin_text_dict, get_text_distance
def save_csv(file_path, rows):
    """Write *rows* (an iterable of row sequences) to *file_path* as CSV.

    Opens the file with newline="" as the csv module requires; without it
    the writer emits an extra blank line between rows on Windows.
    """
    with open(file_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerows(rows)
    print(f"write csv to {file_path}")
def load_model(quantize=True):
    """Load the FunASR ONNX VAD, ASR and punctuation models.

    Args:
        quantize: whether to load quantized model weights. Previously this
            parameter was unconditionally overwritten to True inside the
            function, making it dead — fixed so callers' choice is honored.

    Returns:
        (vad_model, asr_model, punc_model) tuple.
    """
    # NOTE(review): hard-coded absolute model dir — fine for a local
    # benchmark script, but not portable.
    model_dir = Path("/Users/jeqin/work/code/Translator/python_server/moyoyo_asr_models")
    asr_model_path = model_dir / 'speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'
    vad_model_path = model_dir / 'speech_fsmn_vad_zh-cn-16k-common-pytorch'
    punc_model_path = model_dir / 'punc_ct-transformer_cn-en-common-vocab471067-large'
    t0 = time.time()
    vad_model = Fsmn_vad(vad_model_path, quantize=quantize)
    asr_model = SeacoParaformer(asr_model_path, quantize=quantize)
    punc_model = CT_Transformer(punc_model_path, quantize=quantize)
    t1 = time.time()
    print("load model time:", t1 - t0)
    return vad_model, asr_model, punc_model
def inference(vad_model, asr_model, punc_model, audio: Path):
    """Run ASR then punctuation restoration on a single audio file.

    The vad_model argument is accepted for interface compatibility but is
    not currently used (the VAD pass below is disabled).

    Returns:
        (text, elapsed): restored transcript (empty string when the ASR
        produced no segments) and wall-clock seconds for the whole call.
    """
    start = time.time()
    # vad_res = vad_model(str(audio))
    recognized = asr_model(str(audio), hotwords="")
    text = ""
    if recognized:
        punctuated = punc_model(recognized[0]["preds"])
        text = punctuated[0]
    elapsed = time.time() - start
    return text, elapsed
def run_once(audio):
    """Load the (quantized) models, transcribe one audio file, print the text."""
    models = load_model(True)
    text, _elapsed = inference(*models, audio)
    print(text)
def run_recordings():
    """Transcribe every numbered .wav under ../test_data/recordings and
    write per-file timing, transcript and edit-distance stats to CSV.
    """
    quantize = True
    vad_model, asr_model, punc_model = load_model(quantize)
    audios = Path("../test_data/recordings/")
    # Header fixed to match the 6 columns appended per row below
    # (the three distance columns were previously missing).
    rows = [["file_name", "time", "inference_result", "distance", "norm_distance", "diff"]]
    original = get_origin_text_dict()
    # Files are named by integer stem, e.g. 1.wav, 2.wav — sort numerically.
    for audio in sorted(audios.glob("*.wav"), key=lambda x: int(x.stem)):
        text, t = inference(vad_model, asr_model, punc_model, audio)
        d, nd, diff = get_text_distance(original[audio.stem], text)
        rows.append([audio.name, round(t, 3), text, d, round(nd, 3), diff])
    file_name = "csv/funasr_quant.csv" if quantize else "funasr_onnx.csv"
    save_csv(file_name, rows)
def run_test_audios():
    """Transcribe the zh* clips under ../test_data/audio_clips and dump a CSV."""
    quantize = True
    models = load_model(quantize)
    clip_root = Path("../test_data/audio_clips/")
    table = [["file_name", "time", "inference_result"]]
    for clip in sorted(clip_root.glob("*s/zh*.wav")):
        text, elapsed = inference(*models, clip)
        table.append([f"{clip.parent.name}/{clip.name}", round(elapsed, 3), text])
    out_name = "csv/funasr_quant.csv" if quantize else "funasr_onnx.csv"
    save_csv(out_name, table)
def run_test_dataset():
    """Transcribe each entry of ../test_data/dataset/dataset.txt and dump
    references, transcripts and timings to csv/funasr_dataset_results.json.

    Best-effort: exceptions and Ctrl-C are logged and the partial results
    collected so far are still written out.

    Fixed: removed an unused `rows` CSV-header local, and dropped the
    redundant second timer that clobbered the elapsed time already
    returned by inference().
    """
    from test_data.audios import read_dataset
    import json
    vad_model, asr_model, punc_model = load_model(quantize=True)
    test_data = Path("../test_data/dataset/dataset.txt")
    audio_parent = Path("../test_data/")
    result_list = []
    count = 0
    try:
        for audio_path, sentence, duration in read_dataset(test_data):
            count += 1
            print(f"processing {count}: {audio_path}")
            # inference() already returns the wall-clock time of the call.
            text, t = inference(vad_model, asr_model, punc_model, audio_parent / audio_path)
            print("inference time:", t)
            print(text)
            result_list.append({
                "index": count,
                "audio_path": audio_path,
                "reference": sentence,
                "duration": duration,
                "inference_time": round(t, 3),
                "inference_result": text,
            })
    except KeyboardInterrupt as e:
        print(e)
    except Exception as e:
        # Deliberate broad catch: log and fall through so partial results
        # are still saved below.
        print(e)
    with open("csv/funasr_dataset_results.json", "w", encoding="utf-8") as f:
        json.dump(result_list, f, ensure_ascii=False, indent=2)
def run_test_emilia():
    """Transcribe up to 5000 Emilia ZH-B000000 clips and dump the results
    (reference text, duration, timing, transcript) as JSON.

    Best-effort: errors and Ctrl-C are logged, and whatever was collected
    so far is still written out.
    """
    from test_data.audios import read_emilia
    models = load_model(True)
    corpus_root = Path("../test_data/ZH-B000000")
    results = []
    processed = 0
    try:
        for clip_path, reference, clip_duration in read_emilia(corpus_root, count_limit=5000):
            processed += 1
            print(f"processing {processed}: {clip_path.name}")
            transcript, elapsed = inference(*models, clip_path)
            print("inference time:", elapsed)
            print(transcript)
            results.append({
                "index": processed,
                "audio_path": clip_path.name,
                "reference": reference,
                "duration": clip_duration,
                "inference_time": round(elapsed, 3),
                "inference_result": transcript,
            })
    except Exception as e:
        print(e)
    except KeyboardInterrupt as e:
        print(e)
    import json
    with open("csv/funasr_emilia_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
def run_test_wenet():
    """Transcribe up to 5000 WeNet clips and dump the results as JSON to
    csv/funasr_wenet_results.json.

    Fixed: re-enabled the (previously commented-out) Exception handler so
    a mid-loop error no longer discards all partial results before the
    JSON dump — consistent with run_test_dataset / run_test_emilia.
    Removed the dead commented-out "duration" line (read_wenet yields no
    duration).
    """
    from test_data.audios import read_wenet
    import json
    vad_model, asr_model, punc_model = load_model(quantize=True)
    result_list = []
    count = 0
    try:
        for audio_path, sentence in read_wenet(count_limit=5000):
            count += 1
            print(f"processing {count}: {audio_path.name}")
            text, t = inference(vad_model, asr_model, punc_model, audio_path)
            print("inference time:", t)
            print(text)
            result_list.append({
                "index": count,
                "audio_path": audio_path.name,
                "reference": sentence,
                "inference_time": round(t, 3),
                "inference_result": text,
            })
    except KeyboardInterrupt as e:
        print(e)
    except Exception as e:
        # Log and fall through so partial results are still saved,
        # matching the other run_test_* helpers.
        print(e)
    with open("csv/funasr_wenet_results.json", "w", encoding="utf-8") as f:
        json.dump(result_list, f, ensure_ascii=False, indent=2)
if __name__ == '__main__':
    # Entry point: exactly one benchmark is active at a time; the other
    # calls are kept commented for quick switching between test sets.
    # run_recordings()
    run_test_wenet()
    # run_once(Path("/Users/jeqin/work/code/TestTranslator/test_data/audio_clips/zhengyaowei-part1.mp3"))