|
|
from pywhispercpp.model import Model |
|
|
from pathlib import Path |
|
|
import time |
|
|
import csv |
|
|
|
|
|
from silero_vad.utils_vad import languages |
|
|
from scripts.asr_utils import get_origin_text_dict, get_text_distance |
|
|
|
|
|
def save_csv(file_path, rows):
    """Write *rows* to *file_path* as a UTF-8 CSV file.

    Args:
        file_path: Destination path of the CSV file.
        rows: Iterable of row sequences, passed straight to csv.writer.writerows.
    """
    # newline="" is required by the csv module: the writer emits its own
    # "\r\n" terminators, and without it they get translated by the text
    # layer, producing blank rows on Windows.
    with open(file_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerows(rows)
    print(f"write csv to {file_path}")
|
|
|
|
|
def load_model():
    """Load and configure the whisper.cpp model used by every benchmark run.

    Returns:
        A pywhispercpp ``Model`` configured for quiet, non-translating,
        greedy (temperature 0) transcription with no cross-segment context.
    """
    # NOTE(review): absolute, machine-specific model directory — confirm
    # before running on another machine.
    model_dir = Path("/Users/jeqin/work/code/Translator/python_server/moyoyo_asr_models")
    model_name = 'large-v3-turbo-q5_0'

    started = time.time()
    whisper = Model(
        model=model_name,
        models_dir=model_dir,
        print_realtime=False,
        print_progress=False,
        print_timestamps=False,
        translate=False,
        temperature=0.,
        no_context=True
    )
    print("load model time: ", time.time() - started)
    return whisper
|
|
|
|
|
def run_recordings():
    """Transcribe every recording in ../test_data/recordings and save a CSV report.

    Each .wav (sorted by numeric filename stem) is transcribed in Chinese;
    the transcript, timing, and edit-distance stats against the reference
    text from get_origin_text_dict() are written to csv/pywhisper.csv.
    """
    model = load_model()
    audios = Path("../test_data/recordings/")
    # BUG FIX: the header previously listed only 3 columns while each data
    # row below appends 6 — distance columns were unlabeled in the CSV.
    rows = [["file_name", "time", "inference_result",
             "distance", "normalized_distance", "diff"]]
    original = get_origin_text_dict()
    for audio in sorted(audios.glob("*.wav"), key=lambda x: int(x.stem)):
        print(audio)
        t1 = time.time()
        output = model.transcribe(str(audio), language="zh", initial_prompt="以下是普通话句子,这是一段会议内容。")
        t = time.time() - t1
        print("inference time:", t)
        text = " ".join([a.text for a in output])
        print(text)
        # edit distance / normalized distance / diff vs. the reference text
        d, nd, diff = get_text_distance(original[audio.stem], text)
        rows.append([audio.name, round(t, 3), text, d, round(nd, 3), diff])
    save_csv("csv/pywhisper.csv", rows)
|
|
|
|
|
|
|
|
def run_test_audios():
    """Transcribe the Chinese audio clips under ../test_data/audio_clips
    and write one CSV row per file to csv/whisper.csv."""
    model = load_model()
    lang = "zh"
    clip_root = Path("../test_data/audio_clips/")
    rows = [["file_name", "time", "inference_result"]]
    # only clip directories whose name contains the language code
    for wav in sorted(clip_root.glob(f"*{lang}*/*.wav")):
        print(wav)
        started = time.time()
        segments = model.transcribe(str(wav), language=lang, initial_prompt="以下是普通话句子,这是一段会议内容。")
        elapsed = time.time() - started
        print("inference time:", elapsed)
        transcript = " ".join(seg.text for seg in segments)
        print(transcript)
        rows.append([f"{wav.parent.name}/{wav.name}", round(elapsed, 3), transcript])
    save_csv("csv/whisper.csv", rows)
|
|
|
|
|
def run_test_dataset():
    """Run the model over every entry of ../test_data/dataset/dataset.txt and
    dump the results as JSON to csv/whisper_dataset_results.json.

    Each result records the reference sentence, audio duration, inference
    time and the transcript.  Any error or a Ctrl-C stops the loop early,
    but the partial results collected so far are still written out.
    """
    from test_data.audios import read_dataset
    model = load_model()
    test_data = Path("../test_data/dataset/dataset.txt")
    audio_parent = Path("../test_data/")
    # BUG FIX: removed the dead `rows` CSV-header list — this function only
    # ever writes JSON, so the list was built and never used.
    result_list = []
    count = 0
    try:
        for audio_path, sentence, duration in read_dataset(test_data):
            count += 1
            print(f"processing {count}: {audio_path}")

            t1 = time.time()
            output = model.transcribe(str(audio_parent/audio_path), language="zh")
            t = time.time() - t1
            print("inference time:", t)
            text = " ".join([a.text for a in output])
            print(text)
            result_list.append({
                "index": count,
                "audio_path": audio_path,
                "reference": sentence,
                "duration": duration,
                "inference_time": round(t, 3),
                "inference_result": text
            })
    except KeyboardInterrupt as e:
        # Ctrl-C: stop early but keep the partial results
        print(e)
    except Exception as e:
        # best-effort run: log the failure and fall through to save results
        print(e)

    import json
    with open("csv/whisper_dataset_results.json", "w", encoding="utf-8") as f:
        json.dump(result_list, f, ensure_ascii=False, indent=2)
|
|
|
|
|
def run_test_emilia():
    """Benchmark the model on up to 5000 clips of the Emilia ZH-B000000 subset
    and dump the results as JSON to csv/whisper_emilia_results.json."""
    from test_data.audios import read_emilia
    model = load_model()
    emilia_root = Path("../test_data/ZH-B000000")
    results = []
    idx = 0
    try:
        for clip, sentence, duration in read_emilia(emilia_root, count_limit=5000):
            idx += 1
            print(f"processing {idx}: {clip.name}")

            started = time.time()
            segments = model.transcribe(str(clip), language="zh")
            elapsed = time.time() - started
            print("inference time:", elapsed)
            transcript = " ".join(seg.text for seg in segments)
            print(transcript)
            results.append({
                "index": idx,
                "audio_path": clip.name,
                "reference": sentence,
                "duration": duration,
                "inference_time": round(elapsed, 3),
                "inference_result": transcript
            })
    except Exception as err:
        # log and fall through so partial results are still saved
        print(err)
    except KeyboardInterrupt as err:
        print(err)

    import json
    with open("csv/whisper_emilia_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
|
|
|
|
|
def run_test_st():
    """Benchmark the model on up to 5000 ST-corpus clips and dump the
    results as JSON to csv/whisper_st_results.json."""
    from test_data.audios import read_st
    model = load_model()

    results = []
    idx = 0
    try:
        for wav_path, sentence in read_st(count_limit=5000):
            idx += 1
            print(f"processing {idx}: {wav_path}")

            started = time.time()
            segments = model.transcribe(
                str(wav_path), language="zh"
            )
            elapsed = time.time() - started
            print("inference time:", elapsed)
            transcript = " ".join(seg.text for seg in segments)
            print(transcript)
            results.append({
                "index": idx,
                "audio_path": wav_path.name,
                "reference": sentence,

                "inference_time": round(elapsed, 3),
                "inference_result": transcript
            })
    except Exception as err:
        # log and fall through so partial results are still saved
        print(err)
    except KeyboardInterrupt as err:
        print(err)

    import json
    with open("csv/whisper_st_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
|
|
|
|
|
|
|
|
def run_test_wenet():
    """Benchmark the model on up to 5000 WenetSpeech clips and dump the
    results as JSON to csv/whisper_wenet_results.json."""
    from test_data.audios import read_wenet
    model = load_model()
    results = []
    idx = 0
    try:
        for wav_path, sentence in read_wenet(count_limit=5000):
            idx += 1
            print(f"processing {idx}: {wav_path}")

            started = time.time()
            segments = model.transcribe(
                str(wav_path), language="zh"
            )
            elapsed = time.time() - started
            print("inference time:", elapsed)
            transcript = " ".join(seg.text for seg in segments)
            print(transcript)
            results.append({
                "index": idx,
                "audio_path": wav_path.name,
                "reference": sentence,

                "inference_time": round(elapsed, 3),
                "inference_result": transcript
            })
    except Exception as err:
        # log and fall through so partial results are still saved
        print(err)
    except KeyboardInterrupt as err:
        print(err)

    import json
    with open("csv/whisper_wenet_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
|
|
|
|
|
|
|
|
if __name__ == '__main__':

    # Entry point: runs only the WenetSpeech benchmark by default.
    # Switch to run_recordings / run_test_audios / run_test_dataset /
    # run_test_emilia / run_test_st to benchmark the other datasets.
    run_test_wenet()