|
|
from pathlib import Path |
|
|
import time |
|
|
import csv |
|
|
from funasr import AutoModel |
|
|
|
|
|
|
|
|
def main():
    """Smoke-test the Fun-ASR Nano model on one local audio clip.

    Loads the model once, then transcribes the same file three ways —
    with an explicit language hint, with defaults, and with a hotword
    list — printing each transcript for eyeball comparison.
    """
    device = "mps"  # Apple-silicon GPU backend
    model_dir = "/Users/jeqin/work/code/Fun-ASR-Nano-2512"
    model = AutoModel(
        model=model_dir,
        trust_remote_code=True,
        remote_code="./model.py",
        device=device,
    )

    # Plain literal: the original used an f-string with no placeholders.
    wav_path = "/Users/jeqin/work/code/TestTranslator/test_data/audio_clips/zhengyaowei-part1.mp3"

    def _transcribe(**kwargs):
        # Single shared generate() call; returns the first result's text.
        res = model.generate(input=[wav_path], cache={}, batch_size=1, **kwargs)
        return res[0]["text"]

    # 1) Explicit language hint plus inverse text normalization (ITN).
    print(_transcribe(language="中文", itn=True))

    # 2) ITN only; language left to auto-detection.
    print(_transcribe(itn=True))

    # 3) ITN plus a hotword list to bias recognition toward domain terms.
    print(_transcribe(itn=True, hotwords=["头数", "llama", "decode", "query"]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_csv(file_path, rows):
    """Write *rows* (an iterable of row sequences) to *file_path* as CSV.

    Opens the file with newline="" as the csv module requires; without
    it the writer emits a spurious blank line after every row on
    Windows.
    """
    with open(file_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerows(rows)
    print(f"write csv to {file_path}")
|
|
|
|
|
def load_model():
    """Build the Fun-ASR MLT Nano AutoModel on the MPS device.

    Prints the wall-clock load time and returns the ready model.
    """
    started = time.time()
    model = AutoModel(
        model="/Users/jeqin/work/code/Fun-ASR-MLT-Nano-2512",
        trust_remote_code=True,
        remote_code="./model.py",
        device="mps",
        disable_update=True,
    )
    print("load model cost:", time.time() - started)
    return model
|
|
|
|
|
def inference(model, wav_path):
    """Transcribe one audio file with *model*.

    Returns a (text, seconds) pair: the first result's transcript and
    the wall-clock time the generate() call took.
    """
    started = time.time()
    results = model.generate(input=[str(wav_path)], cache={}, batch_size=1)
    elapsed = time.time() - started
    return results[0]["text"], elapsed
|
|
|
|
|
def run_audio_clips():
    """Transcribe every .wav in the 10s-mix clip folder and save a CSV.

    Bug fix: the original built `rows` and `file_name` but never wrote
    the CSV; it now calls save_csv(), matching run_recordings().
    """
    model = load_model()
    audios = Path("/Users/jeqin/work/code/TestTranslator/test_data/audio_clips/10s-mix")
    rows = [["file_name", "time", "inference_result"]]
    for audio in sorted(audios.glob("*.wav")):
        print(audio)
        text, cost = inference(model, audio)
        print("inference cost: ", cost)
        print(text)
        rows.append([audio.name, round(cost, 3), text])
    file_name = "csv/funasr_nano.csv"
    save_csv(file_name, rows)
|
|
|
|
|
|
|
|
|
|
|
def run_recordings():
    """Transcribe the recordings set, score each against its reference
    text, and save the results as CSV.

    Bug fix: the header row listed only 3 columns while the data rows
    appended below carry 5 (edit distance and diff); the header now
    labels all 5.
    """
    from scripts.asr_utils import get_origin_text_dict, get_text_distance

    model = load_model()
    audios = Path("/Users/jeqin/work/code/TestTranslator/test_data/recordings/")
    rows = [["file_name", "time", "inference_result", "distance", "diff"]]
    original = get_origin_text_dict()
    # Files are named by integer stem; sort numerically, not lexically.
    for audio in sorted(audios.glob("*.wav"), key=lambda x: int(x.stem)):
        print("processing: ", audio)
        text, cost = inference(model, audio)
        print("inference cost: ", cost)
        print(text)
        # nd (normalized distance) is unused here but part of the tuple.
        d, nd, diff = get_text_distance(original[audio.stem], text)
        rows.append([audio.name, round(cost, 3), text, d, diff])
    file_name = "csv/funasr_nano.csv"
    save_csv(file_name, rows)
|
|
|
|
|
def run_test_wenet():
    """Transcribe up to 5000 wenet samples and dump the results as JSON."""
    import json

    from test_data.audios import read_wenet

    model = load_model()
    result_list = []
    for count, (audio, sentence) in enumerate(read_wenet(count_limit=5000), start=1):
        print(f"processing {count}: {audio}")
        text, cost = inference(model, audio)
        print("inference time:", cost)
        result_list.append({
            "index": count,
            "audio_path": audio.name,
            "reference": sentence,
            "inference_time": round(cost, 3),
            "inference_result": text,
        })
        print("inference cost: ", cost)
        print(text)

    with open("csv/funasr_nano_wenet.json", "w", encoding="utf-8") as f:
        json.dump(result_list, f, ensure_ascii=False, indent=2)
|
|
|
|
|
if __name__ == "__main__":


    # Entry point: run the recordings benchmark and write its CSV report.
    run_recordings()
|
|
|
|
|
|