TestTranslator / scripts /run_funasr_quant.py
yujuanqin's picture
update path to relative
8a27f7c
raw
history blame
2.94 kB
from pathlib import Path
import time
import csv
from funasr_onnx import SeacoParaformer, CT_Transformer, Fsmn_vad
from scripts.asr_utils import get_origin_text_dict, get_text_distance
def save_csv(file_path, rows):
    """Write ``rows`` to ``file_path`` as a UTF-8 CSV file and print the destination.

    Args:
        file_path: Output path (``str`` or ``Path``). Missing parent
            directories are created.
        rows: Iterable of row sequences, written as-is via
            ``csv.writer.writerows``.
    """
    path = Path(file_path)
    # Output paths such as "csv/..." may point into a directory that does not
    # exist yet; create it rather than failing after a long benchmark run.
    path.parent.mkdir(parents=True, exist_ok=True)
    # newline="" lets the csv module control line endings itself; without it
    # Windows gets a blank line after every row and embedded newlines in
    # quoted fields are mangled.
    with open(path, "w", encoding="utf-8", newline="") as f:
        csv.writer(f).writerows(rows)
    print(f"write csv to {file_path}")
def load_model(quantize=True, model_dir=None):
    """Load the FunASR ONNX VAD, ASR and punctuation models and print the load time.

    Args:
        quantize: Forwarded as the ``quantize`` argument of every model
            constructor. It is honoured rather than overridden, so
            ``quantize=False`` now actually loads the non-quantized models.
        model_dir: Directory holding the three model sub-directories. When
            ``None``, the original hard-coded location is used for backward
            compatibility.

    Returns:
        A ``(vad_model, asr_model, punc_model)`` tuple.
    """
    if model_dir is None:
        # NOTE(review): machine-specific default kept for backward
        # compatibility; on other machines pass ``model_dir`` explicitly.
        model_dir = "/Users/jeqin/work/code/Translator/python_server/moyoyo_asr_models"
    model_dir = Path(model_dir)
    asr_model_path = model_dir / 'speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'
    vad_model_path = model_dir / 'speech_fsmn_vad_zh-cn-16k-common-pytorch'
    punc_model_path = model_dir / 'punc_ct-transformer_cn-en-common-vocab471067-large'
    t0 = time.time()
    vad_model = Fsmn_vad(vad_model_path, quantize=quantize)
    asr_model = SeacoParaformer(asr_model_path, quantize=quantize)
    punc_model = CT_Transformer(punc_model_path, quantize=quantize)
    t1 = time.time()
    print("load model time:", t1 - t0)
    return vad_model, asr_model, punc_model
def inference(vad_model, asr_model, punc_model, audio: Path):
    """Transcribe one audio file with VAD + ASR + punctuation and time the pipeline.

    Prints the file name, the punctuated transcript and the elapsed wall-clock
    seconds.

    Args:
        vad_model: Callable taking the audio path as a string.
        asr_model: Callable taking the audio path and ``hotwords``; its result
            is indexed as ``[0]["preds"]`` to get the raw transcript.
        punc_model: Callable taking the raw transcript; element ``[0]`` of its
            result is used as the final text.
        audio: Path of the audio file.

    Returns:
        ``(text, seconds)``: the punctuated transcript and the time spent in
        the three model calls.
    """
    print(audio.name)
    audio_path = str(audio)
    started = time.time()
    # The VAD output is not consumed below, but the call still runs so its
    # cost is included in the measured time.
    vad_model(audio_path)
    raw_text = asr_model(audio_path, hotwords="")[0]["preds"]
    punctuated = punc_model(raw_text)[0]
    elapsed = time.time() - started
    print(punctuated)
    print("inference:", elapsed)
    return punctuated, elapsed
def run_recordings(quantize=True):
    """Benchmark every numbered recording and save timings and accuracy to CSV.

    Each ``*.wav`` under ``tests/test_data/recordings/`` is transcribed with
    ``inference`` and compared against its reference text from
    ``get_origin_text_dict()`` (keyed by file stem) using ``get_text_distance``.

    Args:
        quantize: Forwarded to ``load_model``; also selects the output file
            (``csv/funasr_quant.csv`` vs ``csv/funasr_onnx.csv``).

    Raises:
        FileNotFoundError: If the recordings directory does not exist. This is
            checked before the slow model load.
        KeyError: If a recording has no reference text.
        ValueError: If a recording's file stem is not an integer.
    """
    # Anchor to the repository (this file lives in scripts/) instead of the
    # current working directory, so the data is found regardless of where the
    # script is launched from; this is the same location as "../tests" seen
    # from scripts/.
    audios = Path(__file__).resolve().parent.parent / "tests" / "test_data" / "recordings"
    if not audios.is_dir():
        raise FileNotFoundError(f"recordings directory not found: {audios}")
    vad_model, asr_model, punc_model = load_model(quantize)
    # Header now matches the six values appended per row below.
    # NOTE(review): d / nd / diff are assumed to be absolute distance,
    # normalized distance and a diff rendering; confirm against asr_utils.
    rows = [["file_name", "time", "inference_result", "distance", "normalized_distance", "diff"]]
    original = get_origin_text_dict()
    # Stems are expected to be integers so files sort numerically (2 before 10).
    for audio in sorted(audios.glob("*.wav"), key=lambda x: int(x.stem)):
        text, t = inference(vad_model, asr_model, punc_model, audio)
        d, nd, diff = get_text_distance(original[audio.stem], text)
        rows.append([audio.name, round(t, 3), text, d, round(nd, 3), diff])
    # Both variants go into the same csv/ directory.
    file_name = "csv/funasr_quant.csv" if quantize else "csv/funasr_onnx.csv"
    save_csv(file_name, rows)
def run_test_audios(quantize=True):
    """Benchmark the Chinese test clips and save per-file timings to CSV.

    Transcribes every file matching ``*s/zh*.wav`` under
    ``tests/test_data/test_audios/`` with ``inference``. That means files
    named ``zh*.wav`` inside sub-directories whose names end in ``s``.

    Args:
        quantize: Forwarded to ``load_model``; also selects the output file
            (``csv/funasr_quant.csv`` vs ``csv/funasr_onnx.csv``).

    Raises:
        FileNotFoundError: If the test-audio directory does not exist. This is
            checked before the slow model load.
    """
    # Anchor to the repository (this file lives in scripts/) rather than the
    # current working directory; same location as "../tests" seen from scripts/.
    audios = Path(__file__).resolve().parent.parent / "tests" / "test_data" / "test_audios"
    if not audios.is_dir():
        raise FileNotFoundError(f"test audio directory not found: {audios}")
    vad_model, asr_model, punc_model = load_model(quantize)
    rows = [["file_name", "time", "inference_result"]]
    for audio in sorted(audios.glob("*s/zh*.wav")):
        text, t = inference(vad_model, asr_model, punc_model, audio)
        # Keep the sub-directory in the name: file names may repeat across them.
        rows.append([f"{audio.parent.name}/{audio.name}", round(t, 3), text])
    # NOTE(review): this is the same output file run_recordings() writes, so
    # running both overwrites one result set; consider distinct names.
    file_name = "csv/funasr_quant.csv" if quantize else "csv/funasr_onnx.csv"
    save_csv(file_name, rows)
if __name__ == '__main__':
    # Script entry point: runs only the recordings benchmark.
    # run_test_audios() exists in this module but is not invoked here.
    run_recordings()