# TestTranslator/scripts/run_funasr_nano.py
# Last commit by yujuanqin: "add asr test" (db0d138)
from pathlib import Path
import time
import csv
from funasr import AutoModel
def main():
    """Transcribe one sample clip three times and print each result.

    The Fun-ASR-Nano-2512 checkpoint is loaded on the ``mps`` device and the
    same clip is decoded with three option sets, in this order:

    1. ``language="中文"``
    2. no extra options
    3. a hotword list
    """
    asr = AutoModel(
        model="/Users/jeqin/work/code/Fun-ASR-Nano-2512",
        trust_remote_code=True,
        remote_code="./model.py",
        device="mps",
    )
    clip = "/Users/jeqin/work/code/TestTranslator/test_data/audio_clips/zhengyaowei-part1.mp3"

    # Per-run extras, merged between ``batch_size`` and ``itn`` below.
    #
    # Language notes carried over from the original comments (there the
    # languages are listed by their Chinese names, as in "中文"):
    #   * Fun-ASR-Nano-2512: Chinese, English, Japanese.
    #   * Fun-ASR-MLT-Nano-2512: Chinese, English, Cantonese, Japanese, Korean,
    #     Vietnamese, Indonesian, Thai, Malay, Filipino, Arabic, Hindi,
    #     Bulgarian, Croatian, Czech, Danish, Dutch, Estonian, Finnish, Greek,
    #     Hungarian, Irish, Latvian, Lithuanian, Maltese, Polish, Portuguese,
    #     Romanian, Slovak, Slovenian, Swedish.
    # Another hotword example from the original comments: ["开放时间"].
    option_sets = (
        {"language": "中文"},
        {},
        {"hotwords": ["头数", "llama", "decode", "query"]},
    )

    for extra in option_sets:
        # A new ``cache`` dict is created for every call; ``itn`` may also be
        # set to False.
        outputs = asr.generate(
            input=[clip],
            cache={},
            batch_size=1,
            **extra,
            itn=True,
        )
        print(outputs[0]["text"])

    # Not enabled: an alternative setup builds AutoModel with the additional
    # arguments vad_model="fsmn-vad" and
    # vad_kwargs={"max_single_segment_time": 30000}, then calls
    # generate(input=[clip], cache={}, batch_size=1) and prints the text.
def save_csv(file_path, rows):
    """Write *rows* to *file_path* as a UTF-8 CSV file.

    Args:
        file_path: Destination path (``str`` or path-like). Missing parent
            directories are created.
        rows: Iterable of row sequences passed to ``csv.writer.writerows``.
    """
    target = Path(file_path)
    # Callers write into "csv/..."; make sure that folder exists so a long
    # benchmark run does not fail at the very last step.
    target.parent.mkdir(parents=True, exist_ok=True)
    # newline="" lets the csv module control line endings itself; without it
    # Windows output gets an extra blank line after every row.
    with open(target, "w", encoding="utf-8", newline="") as f:
        csv.writer(f).writerows(rows)
    print(f"write csv to {file_path}")
def load_model():
    """Build the Fun-ASR-MLT-Nano-2512 model on ``mps`` and print the load time.

    Returns:
        The constructed ``AutoModel`` instance.
    """
    # The Fun-ASR-Nano-2512 checkpoint sits alongside this one at
    # "/Users/jeqin/work/code/Fun-ASR-Nano-2512".
    started = time.time()
    asr = AutoModel(
        model="/Users/jeqin/work/code/Fun-ASR-MLT-Nano-2512",
        trust_remote_code=True,
        remote_code="./model.py",
        device="mps",
        disable_update=True,
    )
    elapsed = time.time() - started
    print("load model cost:", elapsed)
    return asr
def inference(model, wav_path):
    """Run one recognition pass and time it.

    Args:
        model: Object exposing ``generate(input=..., cache=..., batch_size=...)``
            that returns a sequence whose first item has a ``"text"`` key.
        wav_path: Audio file location; converted with ``str()`` before use.

    Returns:
        A ``(text, seconds)`` tuple: the ``"text"`` of the first result and
        the elapsed wall-clock time measured with ``time.time()``.
    """
    began = time.time()
    # Only the default options are used here. The original comments sketch a
    # richer call with hotwords=["开放时间", "llama", "decode"],
    # language="中文" and itn=True.
    outputs = model.generate(input=[str(wav_path)], cache={}, batch_size=1)
    first_result = outputs[0]
    transcript = first_result["text"]
    return transcript, time.time() - began
def run_audio_clips():
    """Transcribe every ``*.wav`` in the ``10s-mix`` clip folder and print results.

    Files are processed in sorted path order. For each one the path, the
    inference time and the recognised text are printed.
    """
    asr = load_model()
    clip_dir = Path("/Users/jeqin/work/code/TestTranslator/test_data/audio_clips/10s-mix")

    header = ["file_name", "time", "inference_result"]
    rows = [header]
    for clip in sorted(clip_dir.glob("*.wav")):
        print(clip)
        transcript, seconds = inference(asr, clip)
        print("inference cost: ", seconds)
        print(transcript)
        rows.append([clip.name, round(seconds, 3), transcript])

    # Writing the table is currently commented out:
    # save_csv("csv/funasr_nano.csv", rows)
def run_recordings():
    """Transcribe the numbered recordings and save a comparison CSV.

    Every ``*.wav`` under the recordings folder is processed in numeric order
    of its file stem. Each transcript is compared against the reference text
    looked up by that stem, and one row per file is written to
    ``csv/funasr_nano.csv``.

    Raises:
        ValueError: If a ``*.wav`` file stem is not an integer.
        KeyError: If a stem has no entry in the reference dictionary.
    """
    from scripts.asr_utils import get_origin_text_dict, get_text_distance

    model = load_model()
    audios = Path("/Users/jeqin/work/code/TestTranslator/test_data/recordings/")
    # One header name per value appended below. The header previously listed
    # only three names for five-value rows, which misaligned the CSV columns.
    rows = [["file_name", "time", "inference_result", "distance", "diff"]]
    original = get_origin_text_dict()

    # Sort numerically so that "10.wav" comes after "9.wav".
    for audio in sorted(audios.glob("*.wav"), key=lambda x: int(x.stem)):
        print("processing: ", audio)
        text, cost = inference(model, audio)
        print("inference cost: ", cost)
        print(text)
        # The helper returns three values; the middle one is not exported.
        # NOTE(review): presumably a normalised distance; confirm in
        # scripts.asr_utils.
        d, _nd, diff = get_text_distance(original[audio.stem], text)
        rows.append([audio.name, round(cost, 3), text, d, diff])

    file_name = "csv/funasr_nano.csv"
    save_csv(file_name, rows)
def run_test_wenet():
    """Transcribe up to 5000 WeNet samples and dump the results as JSON.

    Each ``(audio, sentence)`` pair yielded by ``read_wenet`` is transcribed.
    A record holding the 1-based index, audio file name, reference sentence,
    rounded inference time and recognised text is collected, and the full
    list is written to ``csv/funasr_nano_wenet.json``.
    """
    import json

    from test_data.audios import read_wenet

    asr = load_model()
    records = []
    samples = read_wenet(count_limit=5000)
    for index, (audio, sentence) in enumerate(samples, start=1):
        print(f"processing {index}: {audio}")
        transcript, seconds = inference(asr, audio)
        print("inference time:", seconds)
        # A "duration" field is left out; the original has it commented out.
        record = {
            "index": index,
            "audio_path": audio.name,
            "reference": sentence,
            "inference_time": round(seconds, 3),
            "inference_result": transcript,
        }
        records.append(record)
        print("inference cost: ", seconds)
        print(transcript)

    with open("csv/funasr_nano_wenet.json", "w", encoding="utf-8") as out:
        json.dump(records, out, ensure_ascii=False, indent=2)
if __name__ == "__main__":
    # Script entry point. Only the recordings benchmark is enabled; replace
    # the call with main(), run_audio_clips() or run_test_wenet() to run one
    # of the other scenarios.
    run_recordings()