yujuanqin commited on
Commit
152a3e8
·
1 Parent(s): 79f9224

update script

Browse files
scripts/caculate_cer.py CHANGED
@@ -3,7 +3,7 @@ from lib.utils import run_textdistance, clean_text_for_comparison_zh, highlight_
3
 
4
  # import Levenshtein
5
 
6
- def calculate_distance(reference: str, hypothesis: str) -> float:
7
  """
8
  使用 python-Levenshtein 库计算字符错误率 (CER)。
9
 
 
3
 
4
  # import Levenshtein
5
 
6
+ def calculate_distance(reference: str, hypothesis: str):
7
  """
8
  使用 python-Levenshtein 库计算字符错误率 (CER)。
9
 
scripts/run_whisper_finetuned.py CHANGED
@@ -176,6 +176,44 @@ def run_recordings():
176
  except Exception as e:
177
  print(f"{audio.name} -> 失败: {e}")
178
  save_csv("csv/fine-tune_whisper.csv", rows)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  if __name__ == "__main__":
180
  # main()
181
- run_recordings()
 
176
  except Exception as e:
177
  print(f"{audio.name} -> 失败: {e}")
178
  save_csv("csv/fine-tune_whisper.csv", rows)
179
+
180
+
181
+ def run_test_dataset():
182
+ from scripts.asr_utils import read_dataset
183
+ model, processor = load_model()
184
+ test_data = Path("../tests/test_data/dataset.txt")
185
+ audio_parent = Path("../tests/test_data/")
186
+ rows = [["file_name", "time", "inference_result"]]
187
+ result_list = []
188
+ count = 0
189
+ try:
190
+ for audio_path, sentence, duration in read_dataset(test_data):
191
+ count += 1
192
+ print(f"processing {count}: {audio_path}")
193
+
194
+ t1 = time.time()
195
+ text = transcribe_file(
196
+ str(audio_parent/audio_path), model, processor
197
+ )
198
+ t = time.time() - t1
199
+ print("inference time:", t)
200
+ print(text)
201
+ result_list.append({
202
+ "index": count,
203
+ "audio_path": audio_path,
204
+ "reference": sentence,
205
+ "duration": duration,
206
+ "inference_time": round(t, 3),
207
+ "inference_result": text
208
+ })
209
+ except Exception as e:
210
+ print(e)
211
+ except KeyboardInterrupt as e:
212
+ print(e)
213
+ import json
214
+ with open("csv/whisper_finetuned_dataset_results.json", "w", encoding="utf-8") as f:
215
+ json.dump(result_list, f, ensure_ascii=False, indent=2)
216
+
217
  if __name__ == "__main__":
218
  # main()
219
+ run_test_dataset()