update scripts
Browse files- .gitignore +4 -0
- scripts/asr_utils.py +35 -0
- scripts/caculate_cer.py +12 -10
- scripts/run_funasr_quant.py +39 -4
- scripts/run_whisper.py +36 -2
.gitignore
CHANGED
|
@@ -2,3 +2,7 @@
|
|
| 2 |
.idea
|
| 3 |
__pycache__/
|
| 4 |
*.csv
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
.idea
|
| 3 |
__pycache__/
|
| 4 |
*.csv
|
| 5 |
+
*csv*
|
| 6 |
+
*.mp3
|
| 7 |
+
*.wav
|
| 8 |
+
*.flac
|
scripts/asr_utils.py
CHANGED
|
@@ -3,6 +3,20 @@ import csv
|
|
| 3 |
import wave
|
| 4 |
import re
|
| 5 |
import json
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
def add_text_index():
|
| 8 |
text_file = '../tests/test_data/text/test_asr_zh.txt'
|
|
@@ -86,6 +100,27 @@ def read_dataset(file):
|
|
| 86 |
data = json.loads(line)
|
| 87 |
|
| 88 |
yield data["audio"]["path"], data["sentence"], data["duration"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
|
| 91 |
if __name__ == '__main__':
|
|
|
|
| 3 |
import wave
|
| 4 |
import re
|
| 5 |
import json
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import subprocess
|
| 8 |
+
from subprocess import CompletedProcess
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def cmd(command: str, check=True, capture_output=False) -> CompletedProcess:
|
| 12 |
+
print(command)
|
| 13 |
+
if capture_output:
|
| 14 |
+
ret = subprocess.run(command, shell=True, check=check, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
|
| 15 |
+
universal_newlines=True)
|
| 16 |
+
else:
|
| 17 |
+
ret = subprocess.run(command, shell=True, check=check)
|
| 18 |
+
print(ret.stdout)
|
| 19 |
+
return ret
|
| 20 |
|
| 21 |
def add_text_index():
|
| 22 |
text_file = '../tests/test_data/text/test_asr_zh.txt'
|
|
|
|
| 100 |
data = json.loads(line)
|
| 101 |
|
| 102 |
yield data["audio"]["path"], data["sentence"], data["duration"]
|
| 103 |
+
|
| 104 |
+
def read_emilia(folder: Path, count_limit=None):
|
| 105 |
+
"""读取 emilia 数据集,返回音频路径、文本、时长,
|
| 106 |
+
json 文件样例:
|
| 107 |
+
{"id": "ZH_B00000_S00110_W000000", "wav": "ZH_B00000/ZH_B00000_S00110/mp3/ZH_B00000_S00110_W000000.mp3", "text": "\u628a\u63e1\u6700\u524d\u6cbf\u7684\u91d1\u878d\u9886\u57df\u548c\u533a\u5757\u94fe\u6700\u65b0\u8d44\u8baf\u3002\u6211\u4eec\u4e00\u8d77\u6765\u4e86\u89e3\u4e00\u4e0b\u4eca\u5929\u5e02\u573a\u4e0a\u6709\u53d1\u751f\u54ea\u4e9b\u91cd\u8981\u4e8b\u4ef6\u3002", "duration": 7.963, "speaker": "ZH_B00000_S00110", "language": "zh", "dnsmos": 3.3808}"""
|
| 108 |
+
count = 0
|
| 109 |
+
for json_file in sorted(folder.glob("*.json")):
|
| 110 |
+
count += 1
|
| 111 |
+
if count_limit and count > count_limit:
|
| 112 |
+
break
|
| 113 |
+
with open(json_file, encoding="utf-8") as f:
|
| 114 |
+
data = json.load(f)
|
| 115 |
+
text = data["text"]
|
| 116 |
+
duration = data["duration"]
|
| 117 |
+
wav_path = folder /f'{json_file.stem}.wav'
|
| 118 |
+
if not wav_path.exists():
|
| 119 |
+
mp3_path = folder / f'{json_file.stem}.mp3'
|
| 120 |
+
command=f"ffmpeg -i {mp3_path} -ac 1 -ar 16000 {wav_path}"
|
| 121 |
+
cmd(command)
|
| 122 |
+
yield wav_path, text, duration
|
| 123 |
+
|
| 124 |
|
| 125 |
|
| 126 |
if __name__ == '__main__':
|
scripts/caculate_cer.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
import json
|
| 2 |
from lib.utils import run_textdistance, clean_text_for_comparison_zh, highlight_diff
|
| 3 |
-
|
| 4 |
# import Levenshtein
|
| 5 |
|
| 6 |
def calculate_distance(reference: str, hypothesis: str):
|
|
@@ -34,14 +34,16 @@ if __name__ == '__main__':
|
|
| 34 |
count += 1
|
| 35 |
reference = item["reference"]
|
| 36 |
hypothesis = item["inference_result"]
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
print(f"
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
| 46 |
cer = distance_sum / reference_sum if reference_sum > 0 else 0
|
| 47 |
print(f"Total Distance: {distance_sum}, Total Reference Length: {reference_sum}, CER: {cer:.4f}")
|
|
|
|
| 1 |
import json
|
| 2 |
from lib.utils import run_textdistance, clean_text_for_comparison_zh, highlight_diff
|
| 3 |
+
import re
|
| 4 |
# import Levenshtein
|
| 5 |
|
| 6 |
def calculate_distance(reference: str, hypothesis: str):
|
|
|
|
| 34 |
count += 1
|
| 35 |
reference = item["reference"]
|
| 36 |
hypothesis = item["inference_result"]
|
| 37 |
+
if re.search(r"\d", hypothesis):
|
| 38 |
+
# continue
|
| 39 |
+
distance, diff = calculate_distance(reference, hypothesis)
|
| 40 |
+
print(f"{count}. distance: {distance}")
|
| 41 |
+
if distance > 0:
|
| 42 |
+
print(f"Audio Path: {item['audio_path']}")
|
| 43 |
+
print(f"Reference: {reference}")
|
| 44 |
+
print(f"Hypothesis: {hypothesis}")
|
| 45 |
+
print(f"Diff: {diff}")
|
| 46 |
+
distance_sum += distance
|
| 47 |
+
reference_sum += len(reference)
|
| 48 |
cer = distance_sum / reference_sum if reference_sum > 0 else 0
|
| 49 |
print(f"Total Distance: {distance_sum}, Total Reference Length: {reference_sum}, CER: {cer:.4f}")
|
scripts/run_funasr_quant.py
CHANGED
|
@@ -29,11 +29,11 @@ def inference(vad_model, asr_model, punc_model, audio:Path):
|
|
| 29 |
print(audio.name)
|
| 30 |
t1 = time.time()
|
| 31 |
vad_res = vad_model(str(audio))
|
| 32 |
-
t2 = time.time()
|
| 33 |
# print("vad time:", t2-t1)
|
| 34 |
asr_res = asr_model(str(audio), hotwords="")
|
| 35 |
asr_text = asr_res[0]["preds"]
|
| 36 |
-
t3 = time.time()
|
| 37 |
# print("asr time:", t3-t2)
|
| 38 |
# print("asr text:", asr_text)
|
| 39 |
result = punc_model(asr_text)
|
|
@@ -69,6 +69,41 @@ def run_test_audios():
|
|
| 69 |
rows.append([f"{audio.parent.name}/{audio.name}", round(t, 3), text])
|
| 70 |
file_name = "csv/funasr_quant.csv" if quantize else "funasr_onnx.csv"
|
| 71 |
save_csv(file_name, rows)
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
if __name__ == '__main__':
|
| 74 |
-
|
|
|
|
| 29 |
print(audio.name)
|
| 30 |
t1 = time.time()
|
| 31 |
vad_res = vad_model(str(audio))
|
| 32 |
+
# t2 = time.time()
|
| 33 |
# print("vad time:", t2-t1)
|
| 34 |
asr_res = asr_model(str(audio), hotwords="")
|
| 35 |
asr_text = asr_res[0]["preds"]
|
| 36 |
+
# t3 = time.time()
|
| 37 |
# print("asr time:", t3-t2)
|
| 38 |
# print("asr text:", asr_text)
|
| 39 |
result = punc_model(asr_text)
|
|
|
|
| 69 |
rows.append([f"{audio.parent.name}/{audio.name}", round(t, 3), text])
|
| 70 |
file_name = "csv/funasr_quant.csv" if quantize else "funasr_onnx.csv"
|
| 71 |
save_csv(file_name, rows)
|
| 72 |
+
|
| 73 |
+
def run_test_dataset():
|
| 74 |
+
from scripts.asr_utils import read_dataset
|
| 75 |
+
quantize = True
|
| 76 |
+
vad_model, asr_model, punc_model = load_model(quantize)
|
| 77 |
+
test_data = Path("../tests/test_data/dataset.txt")
|
| 78 |
+
audio_parent = Path("../tests/test_data/")
|
| 79 |
+
rows = [["file_name", "time", "inference_result"]]
|
| 80 |
+
result_list = []
|
| 81 |
+
count = 0
|
| 82 |
+
try:
|
| 83 |
+
for audio_path, sentence, duration in read_dataset(test_data):
|
| 84 |
+
count += 1
|
| 85 |
+
print(f"processing {count}: {audio_path}")
|
| 86 |
+
|
| 87 |
+
t1 = time.time()
|
| 88 |
+
text, t = inference(vad_model, asr_model, punc_model, audio_parent/audio_path)
|
| 89 |
+
t = time.time() - t1
|
| 90 |
+
print("inference time:", t)
|
| 91 |
+
print(text)
|
| 92 |
+
result_list.append({
|
| 93 |
+
"index": count,
|
| 94 |
+
"audio_path": audio_path,
|
| 95 |
+
"reference": sentence,
|
| 96 |
+
"duration": duration,
|
| 97 |
+
"inference_time": round(t, 3),
|
| 98 |
+
"inference_result": text
|
| 99 |
+
})
|
| 100 |
+
except Exception as e:
|
| 101 |
+
print(e)
|
| 102 |
+
except KeyboardInterrupt as e:
|
| 103 |
+
print(e)
|
| 104 |
+
import json
|
| 105 |
+
with open("csv/funasr_dataset_results.json", "w", encoding="utf-8") as f:
|
| 106 |
+
json.dump(result_list, f, ensure_ascii=False, indent=2)
|
| 107 |
+
|
| 108 |
if __name__ == '__main__':
|
| 109 |
+
run_test_dataset()
|
scripts/run_whisper.py
CHANGED
|
@@ -4,7 +4,7 @@ import time
|
|
| 4 |
import csv
|
| 5 |
|
| 6 |
from silero_vad.utils_vad import languages
|
| 7 |
-
from scripts.asr_utils import get_origin_text_dict, get_text_distance
|
| 8 |
|
| 9 |
def save_csv(file_path, rows):
|
| 10 |
with open(file_path, "w", encoding="utf-8") as f:
|
|
@@ -67,6 +67,7 @@ def run_test_audios():
|
|
| 67 |
save_csv("csv/whisper.csv", rows)
|
| 68 |
|
| 69 |
def run_test_dataset():
|
|
|
|
| 70 |
model = load_model()
|
| 71 |
test_data = Path("../tests/test_data/dataset.txt")
|
| 72 |
audio_parent = Path("../tests/test_data/")
|
|
@@ -99,5 +100,38 @@ def run_test_dataset():
|
|
| 99 |
import json
|
| 100 |
with open("csv/whisper_dataset_results.json", "w", encoding="utf-8") as f:
|
| 101 |
json.dump(result_list, f, ensure_ascii=False, indent=2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
if __name__ == '__main__':
|
| 103 |
-
|
|
|
|
| 4 |
import csv
|
| 5 |
|
| 6 |
from silero_vad.utils_vad import languages
|
| 7 |
+
from scripts.asr_utils import get_origin_text_dict, get_text_distance
|
| 8 |
|
| 9 |
def save_csv(file_path, rows):
|
| 10 |
with open(file_path, "w", encoding="utf-8") as f:
|
|
|
|
| 67 |
save_csv("csv/whisper.csv", rows)
|
| 68 |
|
| 69 |
def run_test_dataset():
|
| 70 |
+
from scripts.asr_utils import read_dataset
|
| 71 |
model = load_model()
|
| 72 |
test_data = Path("../tests/test_data/dataset.txt")
|
| 73 |
audio_parent = Path("../tests/test_data/")
|
|
|
|
| 100 |
import json
|
| 101 |
with open("csv/whisper_dataset_results.json", "w", encoding="utf-8") as f:
|
| 102 |
json.dump(result_list, f, ensure_ascii=False, indent=2)
|
| 103 |
+
|
| 104 |
+
def run_test_emilia():
|
| 105 |
+
from scripts.asr_utils import read_emilia
|
| 106 |
+
model = load_model()
|
| 107 |
+
parent = Path("../tests/test_data/ZH-B000000")
|
| 108 |
+
result_list = []
|
| 109 |
+
count = 0
|
| 110 |
+
try:
|
| 111 |
+
for audio_path, sentence, duration in read_emilia(parent, count_limit=5000):
|
| 112 |
+
count += 1
|
| 113 |
+
print(f"processing {count}: {audio_path.name}")
|
| 114 |
+
|
| 115 |
+
t1 = time.time()
|
| 116 |
+
output = model.transcribe(str(audio_path), language="zh")# , initial_prompt="以下是普通话句子,这是一段会议内容。"
|
| 117 |
+
t = time.time() - t1
|
| 118 |
+
print("inference time:", t)
|
| 119 |
+
text = " ".join([a.text for a in output])
|
| 120 |
+
print(text)
|
| 121 |
+
result_list.append({
|
| 122 |
+
"index": count,
|
| 123 |
+
"audio_path": audio_path.name,
|
| 124 |
+
"reference": sentence,
|
| 125 |
+
"duration": duration,
|
| 126 |
+
"inference_time": round(t, 3),
|
| 127 |
+
"inference_result": text
|
| 128 |
+
})
|
| 129 |
+
except Exception as e:
|
| 130 |
+
print(e)
|
| 131 |
+
except KeyboardInterrupt as e:
|
| 132 |
+
print(e)
|
| 133 |
+
import json
|
| 134 |
+
with open("csv/whisper_emilia_results.json", "w", encoding="utf-8") as f:
|
| 135 |
+
json.dump(result_list, f, ensure_ascii=False, indent=2)
|
| 136 |
if __name__ == '__main__':
|
| 137 |
+
run_test_emilia()
|