File size: 2,936 Bytes
8a3bc32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a27f7c
8a3bc32
 
 
 
 
 
 
 
 
 
 
 
8a27f7c
8a3bc32
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from pathlib import Path
import time
import csv
from funasr_onnx import SeacoParaformer, CT_Transformer, Fsmn_vad
from scripts.asr_utils import get_origin_text_dict, get_text_distance

def save_csv(file_path, rows):
    """Write ``rows`` to ``file_path`` as a UTF-8 CSV file and report the path.

    Args:
        file_path: Destination path (str or Path). Missing parent directories
            are created, since callers write into a ``csv/`` sub-directory.
        rows: Iterable of row sequences passed straight to ``csv.writer``.
    """
    path = Path(file_path)
    path.parent.mkdir(parents=True, exist_ok=True)
    # The csv module requires newline="" so it controls line endings itself;
    # otherwise "\r\n" is translated to "\r\r\n" on Windows (blank rows).
    with open(path, "w", encoding="utf-8", newline="") as f:
        csv.writer(f).writerows(rows)
    print(f"write csv to {file_path}")

def load_model(quantize=True, model_dir=None):
    """Load the FunASR ONNX VAD, ASR and punctuation models and print load time.

    Args:
        quantize: Forwarded as ``quantize=`` to each of the three model
            constructors (``Fsmn_vad``, ``SeacoParaformer``, ``CT_Transformer``).
        model_dir: Directory containing the three model sub-directories.
            Defaults to the original hard-coded local model directory.

    Returns:
        Tuple ``(vad_model, asr_model, punc_model)``.
    """
    if model_dir is None:
        model_dir = "/Users/jeqin/work/code/Translator/python_server/moyoyo_asr_models"
    model_dir = Path(model_dir)

    asr_model_path = model_dir / 'speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'
    vad_model_path = model_dir / 'speech_fsmn_vad_zh-cn-16k-common-pytorch'
    punc_model_path = model_dir / 'punc_ct-transformer_cn-en-common-vocab471067-large'
    t0 = time.perf_counter()
    # Honour the caller's `quantize` (it used to be overwritten with True here).
    vad_model = Fsmn_vad(vad_model_path, quantize=quantize)
    asr_model = SeacoParaformer(asr_model_path, quantize=quantize)
    punc_model = CT_Transformer(punc_model_path, quantize=quantize)
    t1 = time.perf_counter()
    print("load model time:", t1 - t0)
    return vad_model, asr_model, punc_model

def inference(vad_model, asr_model, punc_model, audio:Path, *, verbose=False):
    """Transcribe one audio file, add punctuation, and time the whole pipeline.

    The VAD model is invoked on the file, but its output is not used further;
    the ASR model runs on the full file. The reported time covers the VAD,
    ASR and punctuation calls together.

    Args:
        vad_model: Callable taking the audio path as ``str``.
        asr_model: Callable taking the audio path as ``str`` and a ``hotwords``
            keyword; its result is read as ``result[0]["preds"]``.
        punc_model: Callable taking the raw ASR text; element 0 of its result
            is taken as the punctuated text.
        audio: Path of the audio file to transcribe.
        verbose: If true, also print per-stage timings and intermediate text.

    Returns:
        Tuple ``(text, seconds)``: punctuated text and total elapsed seconds.
    """
    print(audio.name)
    audio_path = str(audio)
    # perf_counter is monotonic and high-resolution, unlike time.time().
    t_start = time.perf_counter()
    # NOTE(review): VAD result is discarded; kept so timings stay comparable.
    vad_model(audio_path)
    t_vad = time.perf_counter()
    asr_res = asr_model(audio_path, hotwords="")
    asr_text = asr_res[0]["preds"]
    t_asr = time.perf_counter()
    text = punc_model(asr_text)[0]
    t_end = time.perf_counter()
    if verbose:
        print("vad time:", t_vad - t_start)
        print("asr time:", t_asr - t_vad)
        print("asr text:", asr_text)
        print("punc time:", t_end - t_asr)
        print("punc text:", text)
    print(text)
    elapsed = t_end - t_start
    print("inference:", elapsed)
    return text, elapsed

def run_recordings(quantize=True, audio_dir="../tests/test_data/recordings/", file_name=None):
    """Benchmark the pipeline on numbered recordings and score against references.

    Each ``<int>.wav`` in ``audio_dir`` is transcribed in numeric order, and
    the result is compared with the reference text from
    ``get_origin_text_dict()`` (looked up by file stem) via
    ``get_text_distance``.

    Args:
        quantize: Whether to load quantized models (see ``load_model``).
        audio_dir: Directory containing the recordings; relative to the cwd.
        file_name: Output CSV path. Defaults to ``csv/funasr_quant.csv`` when
            quantized, else ``csv/funasr_onnx.csv``.
    """
    vad_model, asr_model, punc_model = load_model(quantize)
    audios = Path(audio_dir)
    # Header names all six columns written per row. The last three are named
    # after get_text_distance's return values (d, nd, diff) — confirm meaning.
    rows = [["file_name", "time", "inference_result", "distance", "normalized_distance", "diff"]]
    original = get_origin_text_dict()
    # Stems must be integers (e.g. "12.wav") so files sort numerically.
    for audio in sorted(audios.glob("*.wav"), key=lambda x: int(x.stem)):
        text, t = inference(vad_model, asr_model, punc_model, audio)
        d, nd, diff = get_text_distance(original[audio.stem], text)
        rows.append([audio.name, round(t, 3), text, d, round(nd, 3), diff])
    if file_name is None:
        # Both variants go under csv/ (the non-quantized one used to land in cwd).
        file_name = "csv/funasr_quant.csv" if quantize else "csv/funasr_onnx.csv"
    save_csv(file_name, rows)

def run_test_audios(quantize=True, audio_dir="../tests/test_data/test_audios/", file_name=None):
    """Benchmark the pipeline on the Chinese clips in the test-audio folders.

    Files matching ``*s/zh*.wav`` under ``audio_dir`` (``zh*.wav`` in any
    sub-directory whose name ends in ``s``) are transcribed in sorted order.
    No reference scoring is done here.

    Args:
        quantize: Whether to load quantized models (see ``load_model``).
        audio_dir: Root directory of the test audio; relative to the cwd.
        file_name: Output CSV path. Defaults to ``csv/funasr_quant.csv`` when
            quantized, else ``csv/funasr_onnx.csv``.
    """
    vad_model, asr_model, punc_model = load_model(quantize)
    audios = Path(audio_dir)
    rows = [["file_name", "time", "inference_result"]]
    for audio in sorted(audios.glob("*s/zh*.wav")):
        text, t = inference(vad_model, asr_model, punc_model, audio)
        rows.append([f"{audio.parent.name}/{audio.name}", round(t, 3), text])
    if file_name is None:
        # NOTE: the default quantized path is the same file run_recordings
        # writes, so running both overwrites one result; pass file_name to
        # keep them separate.
        file_name = "csv/funasr_quant.csv" if quantize else "csv/funasr_onnx.csv"
    save_csv(file_name, rows)
    
if __name__ == "__main__":
    # Script entry point: benchmark the numbered recordings by default.
    run_recordings()