File size: 5,884 Bytes
db0d138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
from pathlib import Path
import time
import csv
from funasr import AutoModel


def main():
    """Smoke-test Fun-ASR-Nano on a single mp3 clip.

    Loads the local Fun-ASR-Nano-2512 checkpoint on Apple MPS and runs the
    same clip through three generate() configurations (explicit language,
    default language, and hotword biasing), printing each transcript.
    """
    device = "mps"
    model_dir = "/Users/jeqin/work/code/Fun-ASR-Nano-2512"
    model = AutoModel(
        model=model_dir,
        trust_remote_code=True,
        remote_code="./model.py",
        device=device,
    )

    # Plain string: the original had an f-prefix with no placeholders.
    wav_path = "/Users/jeqin/work/code/TestTranslator/test_data/audio_clips/zhengyaowei-part1.mp3"

    # Supported `language` values (per the model cards):
    #   Fun-ASR-Nano-2512: Chinese, English, Japanese
    #   Fun-ASR-MLT-Nano-2512: Chinese, English, Cantonese, Japanese, Korean,
    #   Vietnamese, Indonesian, Thai, Malay, Filipino, Arabic, Hindi and ~20
    #   more European languages.
    # Each dict below is one generate() variant to try on the same clip;
    # `itn` toggles inverse text normalization (True or False).
    variants = [
        {"language": "中文", "itn": True},
        {"itn": True},
        {"hotwords": ["头数", "llama", "decode", "query"], "itn": True},
    ]
    for extra in variants:
        res = model.generate(
            input=[wav_path],
            cache={},
            batch_size=1,
            **extra,
        )
        print(res[0]["text"])

def save_csv(file_path, rows):
    """Write *rows* to *file_path* as a CSV file and log the destination.

    Args:
        file_path: destination path for the CSV file.
        rows: iterable of row sequences, passed to ``csv.writer.writerows``.
    """
    # newline="" is required by the csv module: without it the writer's own
    # "\r\n" terminators get translated again, producing blank lines on Windows.
    with open(file_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerows(rows)
    print(f"write csv to {file_path}")

def load_model():
    """Build the Fun-ASR AutoModel on Apple MPS and report the load time."""
    started = time.time()
    # model_dir = "/Users/jeqin/work/code/Fun-ASR-Nano-2512"
    model_dir = "/Users/jeqin/work/code/Fun-ASR-MLT-Nano-2512"
    asr_model = AutoModel(
        model=model_dir,
        trust_remote_code=True,
        remote_code="./model.py",
        device="mps",
        disable_update=True,
    )
    print("load model cost:", time.time() - started)
    return asr_model

def inference(model, wav_path):
    """Run one ASR pass over *wav_path*.

    Returns:
        A ``(text, elapsed_seconds)`` tuple: the transcript of the first
        result and the wall-clock time the generate() call took.
    """
    started = time.time()
    results = model.generate(input=[str(wav_path)], cache={}, batch_size=1)
    elapsed = time.time() - started
    return results[0]["text"], elapsed

def run_audio_clips():
    """Transcribe every .wav in the 10s-mix clip directory and print results.

    Rows are collected for a CSV report; the save call is currently disabled.
    """
    asr_model = load_model()
    clip_dir = Path("/Users/jeqin/work/code/TestTranslator/test_data/audio_clips/10s-mix")
    rows = [["file_name", "time", "inference_result"]]
    for clip in sorted(clip_dir.glob("*.wav")):
        print(clip)
        text, cost = inference(asr_model, clip)
        print("inference cost: ", cost)
        print(text)
        rows.append([clip.name, round(cost, 3), text])  # f"{clip.parent.name}/{clip.name}"
    file_name = "csv/funasr_nano.csv"
    # save_csv(file_name, rows)


def run_recordings():
    """Transcribe the numbered recordings, score each against its reference
    text, and save the results as a CSV.

    Each row carries the file name, inference time, transcript, edit distance
    against the reference, and the diff string from get_text_distance.
    """
    from scripts.asr_utils import get_origin_text_dict, get_text_distance
    model = load_model()
    audios = Path("/Users/jeqin/work/code/TestTranslator/test_data/recordings/")
    # Fixed: header previously listed only 3 columns while each data row
    # appended 5 values (distance and diff had no header entries).
    rows = [["file_name", "time", "inference_result", "distance", "diff"]]
    original = get_origin_text_dict()
    # Recording stems are integers -> sort numerically, not lexicographically.
    for audio in sorted(audios.glob("*.wav"), key=lambda x: int(x.stem)):
        print("processing: ", audio)
        text, cost = inference(model, audio)
        print("inference cost: ", cost)
        print(text)
        # nd (normalized distance) is unused here; only d and diff are reported.
        d, nd, diff = get_text_distance(original[audio.stem], text)
        rows.append([audio.name, round(cost, 3), text, d, diff])  # f"{audio.parent.name}/{audio.name}"
    file_name = "csv/funasr_nano.csv"
    save_csv(file_name, rows)

def run_test_wenet():
    """Run the model over the WeNet test set (up to 5000 clips) and dump the
    per-clip results to a JSON file."""
    from test_data.audios import read_wenet
    import json

    model = load_model()
    result_list = []
    for count, (audio, sentence) in enumerate(read_wenet(count_limit=5000), start=1):
        print(f"processing {count}: {audio}")
        text, cost = inference(model, audio)
        print("inference time:", cost)
        result_list.append(
            {
                "index": count,
                "audio_path": audio.name,
                "reference": sentence,
                # "duration": duration,
                "inference_time": round(cost, 3),
                "inference_result": text,
            }
        )
        print("inference cost: ", cost)
        print(text)

    with open("csv/funasr_nano_wenet.json", "w", encoding="utf-8") as f:
        json.dump(result_list, f, ensure_ascii=False, indent=2)

if __name__ == "__main__":
    # Manual toggle: uncomment exactly one entry point.
    # run_recordings is the currently active experiment.
    # main()
    run_recordings()
    # run_audio_clips()
    # run_test_wenet()