yujuanqin committed on
Commit
8a34d65
·
1 Parent(s): 152a3e8

update scripts

Browse files
.gitignore CHANGED
@@ -2,3 +2,7 @@
2
  .idea
3
  __pycache__/
4
  *.csv
 
 
 
 
 
2
  .idea
3
  __pycache__/
4
  *.csv
5
+ *csv*
6
+ *.mp3
7
+ *.wav
8
+ *.flac
scripts/asr_utils.py CHANGED
@@ -3,6 +3,20 @@ import csv
3
  import wave
4
  import re
5
  import json
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  def add_text_index():
8
  text_file = '../tests/test_data/text/test_asr_zh.txt'
@@ -86,6 +100,27 @@ def read_dataset(file):
86
  data = json.loads(line)
87
 
88
  yield data["audio"]["path"], data["sentence"], data["duration"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
 
91
  if __name__ == '__main__':
 
3
  import wave
4
  import re
5
  import json
6
+ from pathlib import Path
7
+ import subprocess
8
+ from subprocess import CompletedProcess
9
+
10
+
11
def cmd(command: str, check=True, capture_output=False) -> CompletedProcess:
    """Run *command* through the shell and return the CompletedProcess.

    Args:
        command: shell command line to execute.
        check: if true, raise ``subprocess.CalledProcessError`` on a
            non-zero exit status.
        capture_output: if true, merge stderr into stdout and capture it
            as text (available on the returned object's ``stdout``).

    Returns:
        The ``subprocess.CompletedProcess`` for the finished command.
    """
    # NOTE(review): shell=True passes the string to the shell; callers
    # must quote any untrusted or space-containing arguments themselves.
    print(command)
    if capture_output:
        ret = subprocess.run(command, shell=True, check=check, stdout=subprocess.PIPE,
                             stderr=subprocess.STDOUT, universal_newlines=True)
        # Only captured runs carry text on stdout; uncaptured runs have
        # stdout=None, so printing it there would just emit "None".
        print(ret.stdout)
    else:
        ret = subprocess.run(command, shell=True, check=check)
    return ret
20
 
21
  def add_text_index():
22
  text_file = '../tests/test_data/text/test_asr_zh.txt'
 
100
  data = json.loads(line)
101
 
102
  yield data["audio"]["path"], data["sentence"], data["duration"]
103
+
104
def read_emilia(folder: Path, count_limit=None):
    """Iterate the Emilia dataset, yielding (wav_path, text, duration).

    Each ``*.json`` metadata file in *folder* looks like::

        {"id": "ZH_B00000_S00110_W000000",
         "wav": "ZH_B00000/ZH_B00000_S00110/mp3/ZH_B00000_S00110_W000000.mp3",
         "text": "...", "duration": 7.963,
         "speaker": "ZH_B00000_S00110", "language": "zh", "dnsmos": 3.3808}

    A sibling ``<stem>.wav`` (mono, 16 kHz) is produced from ``<stem>.mp3``
    with ffmpeg the first time it is found missing.

    Args:
        folder: directory containing the ``*.json`` / ``*.mp3`` files.
        count_limit: stop after this many files; a falsy value means no limit.

    Yields:
        Tuples of (wav ``Path``, transcript text, duration in seconds).
    """
    import shlex  # local: only needed to build the ffmpeg command safely

    produced = 0
    for json_file in sorted(folder.glob("*.json")):
        produced += 1
        if count_limit and produced > count_limit:
            break
        with open(json_file, encoding="utf-8") as f:
            meta = json.load(f)
        wav_path = folder / f'{json_file.stem}.wav'
        if not wav_path.exists():
            mp3_path = folder / f'{json_file.stem}.mp3'
            # Quote both paths so spaces/special characters survive the
            # shell=True invocation inside cmd().
            cmd(f"ffmpeg -i {shlex.quote(str(mp3_path))} -ac 1 -ar 16000 {shlex.quote(str(wav_path))}")
        yield wav_path, meta["text"], meta["duration"]
123
+
124
 
125
 
126
  if __name__ == '__main__':
scripts/caculate_cer.py CHANGED
@@ -1,6 +1,6 @@
1
  import json
2
  from lib.utils import run_textdistance, clean_text_for_comparison_zh, highlight_diff
3
-
4
  # import Levenshtein
5
 
6
  def calculate_distance(reference: str, hypothesis: str):
@@ -34,14 +34,16 @@ if __name__ == '__main__':
34
  count += 1
35
  reference = item["reference"]
36
  hypothesis = item["inference_result"]
37
- distance, diff = calculate_distance(reference, hypothesis)
38
- print(f"{count}. distance: {distance}")
39
- if distance > 0:
40
- print(f"Audio Path: {item['audio_path']}")
41
- print(f"Reference: {reference}")
42
- print(f"Hypothesis: {hypothesis}")
43
- print(f"Diff: {diff}")
44
- distance_sum += distance
45
- reference_sum += len(reference)
 
 
46
  cer = distance_sum / reference_sum if reference_sum > 0 else 0
47
  print(f"Total Distance: {distance_sum}, Total Reference Length: {reference_sum}, CER: {cer:.4f}")
 
1
  import json
2
  from lib.utils import run_textdistance, clean_text_for_comparison_zh, highlight_diff
3
+ import re
4
  # import Levenshtein
5
 
6
  def calculate_distance(reference: str, hypothesis: str):
 
34
  count += 1
35
  reference = item["reference"]
36
  hypothesis = item["inference_result"]
37
+ if re.search(r"\d", hypothesis):
38
+ # continue
39
+ distance, diff = calculate_distance(reference, hypothesis)
40
+ print(f"{count}. distance: {distance}")
41
+ if distance > 0:
42
+ print(f"Audio Path: {item['audio_path']}")
43
+ print(f"Reference: {reference}")
44
+ print(f"Hypothesis: {hypothesis}")
45
+ print(f"Diff: {diff}")
46
+ distance_sum += distance
47
+ reference_sum += len(reference)
48
  cer = distance_sum / reference_sum if reference_sum > 0 else 0
49
  print(f"Total Distance: {distance_sum}, Total Reference Length: {reference_sum}, CER: {cer:.4f}")
scripts/run_funasr_quant.py CHANGED
@@ -29,11 +29,11 @@ def inference(vad_model, asr_model, punc_model, audio:Path):
29
  print(audio.name)
30
  t1 = time.time()
31
  vad_res = vad_model(str(audio))
32
- t2 = time.time()
33
  # print("vad time:", t2-t1)
34
  asr_res = asr_model(str(audio), hotwords="")
35
  asr_text = asr_res[0]["preds"]
36
- t3 = time.time()
37
  # print("asr time:", t3-t2)
38
  # print("asr text:", asr_text)
39
  result = punc_model(asr_text)
@@ -69,6 +69,41 @@ def run_test_audios():
69
  rows.append([f"{audio.parent.name}/{audio.name}", round(t, 3), text])
70
  file_name = "csv/funasr_quant.csv" if quantize else "funasr_onnx.csv"
71
  save_csv(file_name, rows)
72
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  if __name__ == '__main__':
74
- run_recordings()
 
29
  print(audio.name)
30
  t1 = time.time()
31
  vad_res = vad_model(str(audio))
32
+ # t2 = time.time()
33
  # print("vad time:", t2-t1)
34
  asr_res = asr_model(str(audio), hotwords="")
35
  asr_text = asr_res[0]["preds"]
36
+ # t3 = time.time()
37
  # print("asr time:", t3-t2)
38
  # print("asr text:", asr_text)
39
  result = punc_model(asr_text)
 
69
  rows.append([f"{audio.parent.name}/{audio.name}", round(t, 3), text])
70
  file_name = "csv/funasr_quant.csv" if quantize else "funasr_onnx.csv"
71
  save_csv(file_name, rows)
72
+
73
def run_test_dataset():
    """Run the FunASR pipeline over the line-delimited test dataset and
    dump per-utterance results to ``csv/funasr_dataset_results.json``.

    Partial results are still written when the run fails or is
    interrupted with Ctrl-C.
    """
    import json
    import traceback
    from scripts.asr_utils import read_dataset

    quantize = True
    vad_model, asr_model, punc_model = load_model(quantize)
    test_data = Path("../tests/test_data/dataset.txt")
    audio_parent = Path("../tests/test_data/")
    result_list = []
    count = 0
    try:
        for audio_path, sentence, duration in read_dataset(test_data):
            count += 1
            print(f"processing {count}: {audio_path}")

            t1 = time.time()
            # inference() also returns its own timing, but the recorded
            # "inference_time" is the wall-clock time measured here.
            text, t = inference(vad_model, asr_model, punc_model, audio_parent / audio_path)
            t = time.time() - t1
            print("inference time:", t)
            print(text)
            result_list.append({
                "index": count,
                "audio_path": audio_path,
                "reference": sentence,
                "duration": duration,
                "inference_time": round(t, 3),
                "inference_result": text
            })
    except KeyboardInterrupt:
        print("interrupted; saving partial results")
    except Exception:
        # Best-effort: log the full traceback but still save what we have.
        traceback.print_exc()
    with open("csv/funasr_dataset_results.json", "w", encoding="utf-8") as f:
        json.dump(result_list, f, ensure_ascii=False, indent=2)
107
+
108
if __name__ == '__main__':
    # Script entry point: run the FunASR pipeline over the test dataset.
    run_test_dataset()
scripts/run_whisper.py CHANGED
@@ -4,7 +4,7 @@ import time
4
  import csv
5
 
6
  from silero_vad.utils_vad import languages
7
- from scripts.asr_utils import get_origin_text_dict, get_text_distance, read_dataset
8
 
9
  def save_csv(file_path, rows):
10
  with open(file_path, "w", encoding="utf-8") as f:
@@ -67,6 +67,7 @@ def run_test_audios():
67
  save_csv("csv/whisper.csv", rows)
68
 
69
  def run_test_dataset():
 
70
  model = load_model()
71
  test_data = Path("../tests/test_data/dataset.txt")
72
  audio_parent = Path("../tests/test_data/")
@@ -99,5 +100,38 @@ def run_test_dataset():
99
  import json
100
  with open("csv/whisper_dataset_results.json", "w", encoding="utf-8") as f:
101
  json.dump(result_list, f, ensure_ascii=False, indent=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  if __name__ == '__main__':
103
- run_test_dataset()
 
4
  import csv
5
 
6
  from silero_vad.utils_vad import languages
7
+ from scripts.asr_utils import get_origin_text_dict, get_text_distance
8
 
9
  def save_csv(file_path, rows):
10
  with open(file_path, "w", encoding="utf-8") as f:
 
67
  save_csv("csv/whisper.csv", rows)
68
 
69
  def run_test_dataset():
70
+ from scripts.asr_utils import read_dataset
71
  model = load_model()
72
  test_data = Path("../tests/test_data/dataset.txt")
73
  audio_parent = Path("../tests/test_data/")
 
100
  import json
101
  with open("csv/whisper_dataset_results.json", "w", encoding="utf-8") as f:
102
  json.dump(result_list, f, ensure_ascii=False, indent=2)
103
+
104
def run_test_emilia():
    """Transcribe the Emilia ZH sample set with Whisper and dump
    per-utterance results to ``csv/whisper_emilia_results.json``.

    Partial results are still written when the run fails or is
    interrupted with Ctrl-C.
    """
    import json
    import traceback
    from scripts.asr_utils import read_emilia

    model = load_model()
    parent = Path("../tests/test_data/ZH-B000000")
    result_list = []
    count = 0
    try:
        for audio_path, sentence, duration in read_emilia(parent, count_limit=5000):
            count += 1
            print(f"processing {count}: {audio_path.name}")

            t1 = time.time()
            # Optionally bias decoding with an initial_prompt, e.g.
            # initial_prompt="以下是普通话句子,这是一段会议内容。"
            output = model.transcribe(str(audio_path), language="zh")
            t = time.time() - t1
            print("inference time:", t)
            text = " ".join(segment.text for segment in output)
            print(text)
            result_list.append({
                "index": count,
                "audio_path": audio_path.name,
                "reference": sentence,
                "duration": duration,
                "inference_time": round(t, 3),
                "inference_result": text
            })
    except KeyboardInterrupt:
        print("interrupted; saving partial results")
    except Exception:
        # Best-effort: log the full traceback but still save what we have.
        traceback.print_exc()
    with open("csv/whisper_emilia_results.json", "w", encoding="utf-8") as f:
        json.dump(result_list, f, ensure_ascii=False, indent=2)
136
if __name__ == '__main__':
    # Script entry point: transcribe the Emilia sample set with Whisper.
    run_test_emilia()