File size: 5,884 Bytes
db0d138 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
from pathlib import Path
import time
import csv
from funasr import AutoModel
def main():
    """Smoke-test Fun-ASR-Nano: load the model and transcribe one clip three
    ways — with an explicit language hint, with defaults, and with hotwords.

    Side effects: loads the model from a hard-coded local path and prints
    each transcription to stdout.
    """
    device = "mps"
    model_dir = "/Users/jeqin/work/code/Fun-ASR-Nano-2512"
    model = AutoModel(
        model=model_dir,
        trust_remote_code=True,
        remote_code="./model.py",
        device=device,
    )
    # Fixed test clip (f-string prefix removed — the literal has no placeholders).
    wav_path = "/Users/jeqin/work/code/TestTranslator/test_data/audio_clips/zhengyaowei-part1.mp3"
    # Pass 1: explicit language hint + inverse text normalization.
    res = model.generate(
        input=[wav_path],
        cache={},
        batch_size=1,
        # hotwords=["开放时间"],
        # Valid `language` values per model variant (keep the native names —
        # they are the literal strings the API expects):
        # 中文、英文、日文 for Fun-ASR-Nano-2512
        # 中文、英文、粤语、日文、韩文、越南语、印尼语、泰语、马来语、菲律宾语、阿拉伯语、
        # 印地语、保加利亚语、克罗地亚语、捷克语、丹麦语、荷兰语、爱沙尼亚语、芬兰语、希腊语、
        # 匈牙利语、爱尔兰语、拉脱维亚语、立陶宛语、马耳他语、波兰语、葡萄牙语、罗马尼亚语、
        # 斯洛伐克语、斯洛文尼亚语、瑞典语 for Fun-ASR-MLT-Nano-2512
        language="中文",
        itn=True,  # or False
    )
    text = res[0]["text"]
    print(text)
    # Pass 2: same clip, no language hint (auto-detect).
    text = model.generate(
        input=[wav_path],
        cache={},
        batch_size=1,
        # hotwords=["开放时间"],
        # language="中文",
        itn=True,  # or False
    )[0]["text"]
    print(text)
    # Pass 3: same clip, biased with hotwords.
    text = model.generate(
        input=[wav_path],
        cache={},
        batch_size=1,
        hotwords=["头数", "llama", "decode", "query"],
        # language="中文",
        itn=True,  # or False
    )[0]["text"]
    print(text)
    # Alternative: same model fronted by an FSMN VAD for long audio.
    # model = AutoModel(
    #     model=model_dir,
    #     trust_remote_code=True,
    #     vad_model="fsmn-vad",
    #     vad_kwargs={"max_single_segment_time": 30000},
    #     remote_code="./model.py",
    #     device=device,
    # )
    # res = model.generate(input=[wav_path], cache={}, batch_size=1)
    # text = res[0]["text"]
    # print(text)
def save_csv(file_path, rows):
    """Write *rows* to a CSV file at *file_path* and log the destination.

    Args:
        file_path: destination path of the CSV file (created/overwritten).
        rows: iterable of row sequences to write.
    """
    # newline="" is required by the csv module: without it, platforms that
    # translate line endings (e.g. Windows) emit a blank line after each row.
    with open(file_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerows(rows)
    print(f"write csv to {file_path}")
def load_model():
    """Load the Fun-ASR MLT Nano model from its local checkout onto MPS.

    Prints the load time in seconds and returns the ready AutoModel instance.
    """
    started = time.time()
    # model_dir = "/Users/jeqin/work/code/Fun-ASR-Nano-2512"
    checkpoint_dir = "/Users/jeqin/work/code/Fun-ASR-MLT-Nano-2512"
    asr_model = AutoModel(
        model=checkpoint_dir,
        trust_remote_code=True,
        remote_code="./model.py",
        device="mps",
        disable_update=True,
    )
    print("load model cost:", time.time() - started)
    return asr_model
def inference(model, wav_path):
    """Transcribe one audio file and time the call.

    Args:
        model: a loaded FunASR AutoModel instance.
        wav_path: path to the audio file (str or Path).

    Returns:
        Tuple of (recognized text, elapsed wall-clock seconds).
    """
    started = time.time()
    outputs = model.generate(input=[str(wav_path)], cache={}, batch_size=1)
    recognized = outputs[0]["text"]
    return recognized, time.time() - started
def run_audio_clips():
    """Transcribe every .wav in the 10s-mix clip directory, printing each result.

    Collects [file_name, time, inference_result] rows; the CSV save is
    currently disabled, so this only prints to stdout.
    """
    asr = load_model()
    clip_dir = Path("/Users/jeqin/work/code/TestTranslator/test_data/audio_clips/10s-mix")
    rows = [["file_name", "time", "inference_result"]]
    for clip in sorted(clip_dir.glob("*.wav")):
        print(clip)
        result, elapsed = inference(asr, clip)
        print("inference cost: ", elapsed)
        print(result)
        rows.append([clip.name, round(elapsed, 3), result])  # f"{clip.parent.name}/{clip.name}"
    file_name = "csv/funasr_nano.csv"
    # save_csv(file_name, rows)
def run_recordings():
    """Transcribe numbered recordings, score each against its reference text,
    and save the results to csv/funasr_nano.csv.

    Each data row is [file_name, time, inference_result, distance, diff];
    the header row now matches (it previously listed only the first three
    columns, leaving the CSV schema inconsistent).
    """
    from scripts.asr_utils import get_origin_text_dict, get_text_distance
    model = load_model()
    audios = Path("/Users/jeqin/work/code/TestTranslator/test_data/recordings/")
    rows = [["file_name", "time", "inference_result", "distance", "diff"]]
    original = get_origin_text_dict()
    # Recordings are named by integer stem — sort numerically, not lexically.
    for audio in sorted(audios.glob("*.wav"), key=lambda x: int(x.stem)):
        print("processing: ", audio)
        text, cost = inference(model, audio)
        print("inference cost: ", cost)
        print(text)
        # get_text_distance returns (distance, normalized_distance, diff);
        # the normalized distance is not recorded in the CSV.
        d, _nd, diff = get_text_distance(original[audio.stem], text)
        rows.append([audio.name, round(cost, 3), text, d, diff])  # f"{audio.parent.name}/{audio.name}"
    file_name = "csv/funasr_nano.csv"
    save_csv(file_name, rows)
def run_test_wenet():
    """Transcribe up to 5000 WeNet test-set clips and dump results to JSON.

    Writes a list of per-clip records (index, audio_path, reference,
    inference_time, inference_result) to csv/funasr_nano_wenet.json.
    """
    import json
    from test_data.audios import read_wenet
    asr = load_model()
    results = []
    for idx, (clip, reference) in enumerate(read_wenet(count_limit=5000), start=1):
        print(f"processing {idx}: {clip}")
        recognized, elapsed = inference(asr, clip)
        print("inference time:", elapsed)
        results.append({
            "index": idx,
            "audio_path": clip.name,
            "reference": reference,
            # "duration": duration,
            "inference_time": round(elapsed, 3),
            "inference_result": recognized,
        })
        print("inference cost: ", elapsed)
        print(recognized)
    with open("csv/funasr_nano_wenet.json", "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
if __name__ == "__main__":
    # Entry point: enable exactly one experiment driver at a time.
    # main()
    run_recordings()
    # run_audio_clips()
    # run_test_wenet()