update script
Browse files- scripts/caculate_cer.py +1 -1
- scripts/run_whisper_finetuned.py +39 -1
scripts/caculate_cer.py
CHANGED
|
@@ -3,7 +3,7 @@ from lib.utils import run_textdistance, clean_text_for_comparison_zh, highlight_
|
|
| 3 |
|
| 4 |
# import Levenshtein
|
| 5 |
|
| 6 |
-
def calculate_distance(reference: str, hypothesis: str)
|
| 7 |
"""
|
| 8 |
使用 python-Levenshtein 库计算字符错误率 (CER)。
|
| 9 |
|
|
|
|
| 3 |
|
| 4 |
# import Levenshtein
|
| 5 |
|
| 6 |
+
def calculate_distance(reference: str, hypothesis: str):
|
| 7 |
"""
|
| 8 |
使用 python-Levenshtein 库计算字符错误率 (CER)。
|
| 9 |
|
scripts/run_whisper_finetuned.py
CHANGED
|
@@ -176,6 +176,44 @@ def run_recordings():
|
|
| 176 |
except Exception as e:
|
| 177 |
print(f"{audio.name} -> 失败: {e}")
|
| 178 |
save_csv("csv/fine-tune_whisper.csv", rows)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
if __name__ == "__main__":
|
| 180 |
# main()
|
| 181 |
-
|
|
|
|
| 176 |
except Exception as e:
|
| 177 |
print(f"{audio.name} -> 失败: {e}")
|
| 178 |
save_csv("csv/fine-tune_whisper.csv", rows)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def run_test_dataset():
    """Transcribe every entry of the test dataset and save the results as JSON.

    Entries are read with ``read_dataset`` from
    ``../tests/test_data/dataset.txt``; each yields an audio path, a
    reference sentence and a duration. Audio paths are resolved against
    ``../tests/test_data/``. One record per successfully transcribed file
    is written to ``csv/whisper_finetuned_dataset_results.json``.

    A failure on one file is reported and skipped so the rest of the
    dataset is still evaluated. If the run is interrupted with Ctrl-C, or
    reading the dataset itself fails, the results gathered so far are
    still written.
    """
    import json

    from scripts.asr_utils import read_dataset

    model, processor = load_model()
    # NOTE(review): both paths are relative to the current working
    # directory, not to this script's location.
    test_data = Path("../tests/test_data/dataset.txt")
    audio_parent = Path("../tests/test_data/")
    output_path = Path("csv/whisper_finetuned_dataset_results.json")
    result_list = []

    try:
        for count, (audio_path, sentence, duration) in enumerate(
            read_dataset(test_data), start=1
        ):
            print(f"processing {count}: {audio_path}")

            # perf_counter is monotonic, so the measured interval cannot be
            # skewed by system clock adjustments.
            started = time.perf_counter()
            try:
                text = transcribe_file(
                    str(audio_parent / audio_path), model, processor
                )
            except Exception as e:
                # Skip only this file; keep evaluating the remaining ones.
                print(f"{audio_path} -> failed: {e}")
                continue
            elapsed = time.perf_counter() - started

            print("inference time:", elapsed)
            print(text)
            result_list.append({
                "index": count,
                # str() keeps the record JSON-serializable if read_dataset
                # yields Path objects (TODO confirm what it yields).
                "audio_path": str(audio_path),
                "reference": sentence,
                "duration": duration,
                "inference_time": round(elapsed, 3),
                "inference_result": text,
            })
    except Exception as e:
        # Raised while reading the dataset; fall through and save what we have.
        print(e)
    except KeyboardInterrupt:
        print("interrupted, saving partial results")

    # Serialize before touching the file so a serialization error cannot
    # leave a truncated results file behind.
    payload = json.dumps(result_list, ensure_ascii=False, indent=2)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(payload, encoding="utf-8")
|
| 216 |
+
|
| 217 |
if __name__ == "__main__":
|
| 218 |
# main()
|
| 219 |
+
run_test_dataset()
|