Spaces:
Sleeping
Sleeping
MFA集成
Browse files- .gitignore +6 -6
- requirements.in +4 -0
- requirements.txt +27 -2
- src/gui.py +195 -50
- src/mfa_model_downloader.py +236 -0
- src/mfa_runner.py +208 -0
- src/silero_vad_downloader.py +194 -0
- tests/test_mfa_model_downloader.py +182 -0
- tests/test_mfa_runner.py +243 -0
- tests/test_silero_vad_downloader.py +65 -0
.gitignore
CHANGED
|
@@ -7,11 +7,6 @@ __pycache__/
|
|
| 7 |
*$py.class
|
| 8 |
*.so
|
| 9 |
|
| 10 |
-
# 分发/打包
|
| 11 |
-
dist/
|
| 12 |
-
build/
|
| 13 |
-
*.egg-info/
|
| 14 |
-
|
| 15 |
# pip-tools
|
| 16 |
*.egg
|
| 17 |
|
|
@@ -29,5 +24,10 @@ build/
|
|
| 29 |
temp/
|
| 30 |
*.tmp
|
| 31 |
|
| 32 |
-
# 数据
|
|
|
|
| 33 |
bank/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
*$py.class
|
| 8 |
*.so
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
# pip-tools
|
| 11 |
*.egg
|
| 12 |
|
|
|
|
| 24 |
temp/
|
| 25 |
*.tmp
|
| 26 |
|
| 27 |
+
# 数据(根据需要调整)
|
| 28 |
+
config.json
|
| 29 |
bank/
|
| 30 |
+
|
| 31 |
+
# AI 模型相关
|
| 32 |
+
tools/mfa_engine
|
| 33 |
+
models
|
requirements.in
CHANGED
|
@@ -10,3 +10,7 @@ customtkinter
|
|
| 10 |
transformers>=4.25.0
|
| 11 |
torch
|
| 12 |
accelerate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
transformers>=4.25.0
|
| 11 |
torch
|
| 12 |
accelerate
|
| 13 |
+
|
| 14 |
+
# Silero VAD 语音活动检测
|
| 15 |
+
silero-vad>=5.1
|
| 16 |
+
onnxruntime
|
requirements.txt
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
# This file is autogenerated by pip-compile with Python 3.13
|
| 3 |
# by the following command:
|
| 4 |
#
|
| 5 |
-
# pip-compile requirements.in
|
| 6 |
#
|
| 7 |
--index-url https://pypi.tuna.tsinghua.edu.cn/simple
|
| 8 |
|
|
@@ -28,6 +28,8 @@ colorama==0.4.6
|
|
| 28 |
# via
|
| 29 |
# click
|
| 30 |
# tqdm
|
|
|
|
|
|
|
| 31 |
customtkinter==5.2.2
|
| 32 |
# via -r requirements.in
|
| 33 |
darkdetect==0.8.0
|
|
@@ -37,6 +39,8 @@ filelock==3.20.3
|
|
| 37 |
# huggingface-hub
|
| 38 |
# torch
|
| 39 |
# transformers
|
|
|
|
|
|
|
| 40 |
fsspec==2026.1.0
|
| 41 |
# via
|
| 42 |
# huggingface-hub
|
|
@@ -54,6 +58,8 @@ huggingface-hub==1.3.5
|
|
| 54 |
# accelerate
|
| 55 |
# tokenizers
|
| 56 |
# transformers
|
|
|
|
|
|
|
| 57 |
idna==3.11
|
| 58 |
# via
|
| 59 |
# anyio
|
|
@@ -71,18 +77,29 @@ numpy==2.4.1
|
|
| 71 |
# accelerate
|
| 72 |
# audiofile
|
| 73 |
# audmath
|
|
|
|
| 74 |
# soundfile
|
| 75 |
# transformers
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
packaging==26.0
|
| 77 |
# via
|
| 78 |
# accelerate
|
| 79 |
# customtkinter
|
| 80 |
# huggingface-hub
|
|
|
|
|
|
|
| 81 |
# transformers
|
|
|
|
|
|
|
| 82 |
psutil==7.2.2
|
| 83 |
# via accelerate
|
| 84 |
pycparser==3.0
|
| 85 |
# via cffi
|
|
|
|
|
|
|
| 86 |
pyyaml==6.0.3
|
| 87 |
# via
|
| 88 |
# accelerate
|
|
@@ -96,10 +113,14 @@ safetensors==0.7.0
|
|
| 96 |
# transformers
|
| 97 |
shellingham==1.5.4
|
| 98 |
# via huggingface-hub
|
|
|
|
|
|
|
| 99 |
soundfile==0.13.1
|
| 100 |
# via audiofile
|
| 101 |
sympy==1.14.0
|
| 102 |
-
# via
|
|
|
|
|
|
|
| 103 |
textgrid==1.6.1
|
| 104 |
# via -r requirements.in
|
| 105 |
tokenizers==0.22.2
|
|
@@ -108,6 +129,10 @@ torch==2.10.0
|
|
| 108 |
# via
|
| 109 |
# -r requirements.in
|
| 110 |
# accelerate
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
tqdm==4.67.1
|
| 112 |
# via
|
| 113 |
# -r requirements.in
|
|
|
|
| 2 |
# This file is autogenerated by pip-compile with Python 3.13
|
| 3 |
# by the following command:
|
| 4 |
#
|
| 5 |
+
# pip-compile --output-file=requirements.txt requirements.in
|
| 6 |
#
|
| 7 |
--index-url https://pypi.tuna.tsinghua.edu.cn/simple
|
| 8 |
|
|
|
|
| 28 |
# via
|
| 29 |
# click
|
| 30 |
# tqdm
|
| 31 |
+
coloredlogs==15.0.1
|
| 32 |
+
# via onnxruntime
|
| 33 |
customtkinter==5.2.2
|
| 34 |
# via -r requirements.in
|
| 35 |
darkdetect==0.8.0
|
|
|
|
| 39 |
# huggingface-hub
|
| 40 |
# torch
|
| 41 |
# transformers
|
| 42 |
+
flatbuffers==25.12.19
|
| 43 |
+
# via onnxruntime
|
| 44 |
fsspec==2026.1.0
|
| 45 |
# via
|
| 46 |
# huggingface-hub
|
|
|
|
| 58 |
# accelerate
|
| 59 |
# tokenizers
|
| 60 |
# transformers
|
| 61 |
+
humanfriendly==10.0
|
| 62 |
+
# via coloredlogs
|
| 63 |
idna==3.11
|
| 64 |
# via
|
| 65 |
# anyio
|
|
|
|
| 77 |
# accelerate
|
| 78 |
# audiofile
|
| 79 |
# audmath
|
| 80 |
+
# onnxruntime
|
| 81 |
# soundfile
|
| 82 |
# transformers
|
| 83 |
+
onnxruntime==1.23.2
|
| 84 |
+
# via
|
| 85 |
+
# -r requirements.in
|
| 86 |
+
# silero-vad
|
| 87 |
packaging==26.0
|
| 88 |
# via
|
| 89 |
# accelerate
|
| 90 |
# customtkinter
|
| 91 |
# huggingface-hub
|
| 92 |
+
# onnxruntime
|
| 93 |
+
# silero-vad
|
| 94 |
# transformers
|
| 95 |
+
protobuf==6.33.5
|
| 96 |
+
# via onnxruntime
|
| 97 |
psutil==7.2.2
|
| 98 |
# via accelerate
|
| 99 |
pycparser==3.0
|
| 100 |
# via cffi
|
| 101 |
+
pyreadline3==3.5.4
|
| 102 |
+
# via humanfriendly
|
| 103 |
pyyaml==6.0.3
|
| 104 |
# via
|
| 105 |
# accelerate
|
|
|
|
| 113 |
# transformers
|
| 114 |
shellingham==1.5.4
|
| 115 |
# via huggingface-hub
|
| 116 |
+
silero-vad==6.2.0
|
| 117 |
+
# via -r requirements.in
|
| 118 |
soundfile==0.13.1
|
| 119 |
# via audiofile
|
| 120 |
sympy==1.14.0
|
| 121 |
+
# via
|
| 122 |
+
# onnxruntime
|
| 123 |
+
# torch
|
| 124 |
textgrid==1.6.1
|
| 125 |
# via -r requirements.in
|
| 126 |
tokenizers==0.22.2
|
|
|
|
| 129 |
# via
|
| 130 |
# -r requirements.in
|
| 131 |
# accelerate
|
| 132 |
+
# silero-vad
|
| 133 |
+
# torchaudio
|
| 134 |
+
torchaudio==2.10.0
|
| 135 |
+
# via silero-vad
|
| 136 |
tqdm==4.67.1
|
| 137 |
# via
|
| 138 |
# -r requirements.in
|
src/gui.py
CHANGED
|
@@ -353,17 +353,43 @@ class ModelDownloadFrame(ctk.CTkFrame):
|
|
| 353 |
ctk.CTkEntry(self, textvariable=self.mfa_dir_var, width=320).grid(row=7, column=1, padx=5, pady=5, sticky="w")
|
| 354 |
ctk.CTkButton(self, text="浏览", width=60, command=self._browse_mfa_dir).grid(row=7, column=2, padx=5, pady=5)
|
| 355 |
|
| 356 |
-
# MFA
|
| 357 |
-
ctk.CTkLabel(self, text="
|
| 358 |
-
self.
|
| 359 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 360 |
|
| 361 |
# MFA 文件列表
|
| 362 |
-
ctk.CTkLabel(self, text="已有文件:").grid(row=
|
| 363 |
self.mfa_files_text = ctk.CTkTextbox(self, height=70, width=400)
|
| 364 |
-
self.mfa_files_text.grid(row=
|
| 365 |
self.mfa_files_text.insert("end", "选择目录后显示文件列表")
|
| 366 |
self.mfa_files_text.configure(state="disabled")
|
|
|
|
|
|
|
|
|
|
| 367 |
|
| 368 |
def _get_model_desc(self):
|
| 369 |
"""获取当前选中模型的描述"""
|
|
@@ -397,6 +423,51 @@ class ModelDownloadFrame(ctk.CTkFrame):
|
|
| 397 |
self._save_config()
|
| 398 |
self._scan_mfa_dir()
|
| 399 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 400 |
def _scan_mfa_dir(self):
|
| 401 |
"""扫描 MFA 模型目录"""
|
| 402 |
mfa_dir = self.mfa_dir_var.get()
|
|
@@ -498,82 +569,156 @@ class MakeDatasetFrame(ctk.CTkFrame):
|
|
| 498 |
def __init__(self, master, log_callback):
|
| 499 |
super().__init__(master)
|
| 500 |
self.log_callback = log_callback
|
|
|
|
| 501 |
self._setup_ui()
|
|
|
|
| 502 |
|
| 503 |
def _setup_ui(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 504 |
# 数据集原始目录
|
| 505 |
-
ctk.CTkLabel(self, text="① 切片及LAB目录:").grid(row=
|
| 506 |
self.raw_dir_var = ctk.StringVar()
|
| 507 |
-
ctk.CTkEntry(self, textvariable=self.raw_dir_var, width=400).grid(row=
|
| 508 |
-
ctk.CTkButton(self, text="浏览", width=60, command=self._browse_raw_dir).grid(row=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 509 |
|
| 510 |
# 字典路径
|
| 511 |
-
ctk.CTkLabel(self, text="
|
| 512 |
-
self.dict_path_var = ctk.StringVar()
|
| 513 |
-
ctk.CTkEntry(self, textvariable=self.dict_path_var, width=400).grid(row=
|
| 514 |
-
ctk.CTkButton(self, text="浏览", width=60, command=self._browse_dict).grid(row=
|
| 515 |
|
| 516 |
# MFA模型路径
|
| 517 |
-
ctk.CTkLabel(self, text="
|
| 518 |
-
self.mfa_model_var = ctk.StringVar()
|
| 519 |
-
ctk.CTkEntry(self, textvariable=self.mfa_model_var, width=400).grid(row=
|
| 520 |
-
ctk.CTkButton(self, text="浏览", width=60, command=self._browse_mfa).grid(row=
|
| 521 |
-
|
| 522 |
-
#
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 532 |
|
| 533 |
# 执行按钮
|
| 534 |
-
ctk.CTkButton(self, text="
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 535 |
|
| 536 |
def _browse_raw_dir(self):
|
| 537 |
path = filedialog.askdirectory(title="选择切片及LAB目录")
|
| 538 |
if path:
|
| 539 |
self.raw_dir_var.set(path)
|
| 540 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 541 |
def _browse_dict(self):
|
| 542 |
-
path = filedialog.askopenfilename(
|
|
|
|
|
|
|
|
|
|
| 543 |
if path:
|
| 544 |
self.dict_path_var.set(path)
|
| 545 |
|
| 546 |
def _browse_mfa(self):
|
| 547 |
-
path = filedialog.askopenfilename(
|
|
|
|
|
|
|
|
|
|
| 548 |
if path:
|
| 549 |
self.mfa_model_var.set(path)
|
| 550 |
|
| 551 |
-
def _browse_temp(self):
|
| 552 |
-
path = filedialog.askdirectory(title="选择临时目录")
|
| 553 |
-
if path:
|
| 554 |
-
self.temp_dir_var.set(path)
|
| 555 |
-
|
| 556 |
def _run(self):
|
|
|
|
|
|
|
|
|
|
| 557 |
raw_dir = self.raw_dir_var.get()
|
|
|
|
| 558 |
dict_path = self.dict_path_var.get()
|
| 559 |
mfa_model = self.mfa_model_var.get()
|
| 560 |
-
temp_dir = self.temp_dir_var.get()
|
| 561 |
-
dataset_name = self.dataset_name_var.get()
|
| 562 |
|
| 563 |
-
if not
|
| 564 |
-
messagebox.showerror("错误", "请填写
|
| 565 |
return
|
| 566 |
|
| 567 |
-
self.
|
| 568 |
-
self.
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 577 |
|
| 578 |
|
| 579 |
class App(ctk.CTk):
|
|
|
|
| 353 |
ctk.CTkEntry(self, textvariable=self.mfa_dir_var, width=320).grid(row=7, column=1, padx=5, pady=5, sticky="w")
|
| 354 |
ctk.CTkButton(self, text="浏览", width=60, command=self._browse_mfa_dir).grid(row=7, column=2, padx=5, pady=5)
|
| 355 |
|
| 356 |
+
# MFA 语言选择
|
| 357 |
+
ctk.CTkLabel(self, text="选择语言:").grid(row=8, column=0, padx=10, pady=5, sticky="w")
|
| 358 |
+
self.mfa_lang_var = ctk.StringVar(value="mandarin")
|
| 359 |
+
self.mfa_lang_dropdown = ctk.CTkComboBox(
|
| 360 |
+
self,
|
| 361 |
+
values=["mandarin", "japanese"],
|
| 362 |
+
variable=self.mfa_lang_var,
|
| 363 |
+
width=200,
|
| 364 |
+
command=self._on_mfa_lang_change
|
| 365 |
+
)
|
| 366 |
+
self.mfa_lang_dropdown.grid(row=8, column=1, padx=5, pady=5, sticky="w")
|
| 367 |
+
|
| 368 |
+
self.mfa_lang_desc = ctk.CTkLabel(self, text="中文 (普通话)", text_color="gray")
|
| 369 |
+
self.mfa_lang_desc.grid(row=8, column=2, padx=5, pady=5, sticky="w")
|
| 370 |
+
|
| 371 |
+
# MFA 下载按钮和状态
|
| 372 |
+
ctk.CTkLabel(self, text="状态:").grid(row=9, column=0, padx=10, pady=5, sticky="w")
|
| 373 |
+
self.mfa_status = ctk.CTkLabel(self, text="⏳ 未下载", text_color="gray")
|
| 374 |
+
self.mfa_status.grid(row=9, column=1, padx=5, pady=5, sticky="w")
|
| 375 |
+
|
| 376 |
+
self.mfa_download_btn = ctk.CTkButton(
|
| 377 |
+
self,
|
| 378 |
+
text="下载模型",
|
| 379 |
+
command=self._download_mfa_models,
|
| 380 |
+
width=140
|
| 381 |
+
)
|
| 382 |
+
self.mfa_download_btn.grid(row=9, column=2, padx=5, pady=5, sticky="w")
|
| 383 |
|
| 384 |
# MFA 文件列表
|
| 385 |
+
ctk.CTkLabel(self, text="已有文件:").grid(row=10, column=0, padx=10, pady=(10, 5), sticky="nw")
|
| 386 |
self.mfa_files_text = ctk.CTkTextbox(self, height=70, width=400)
|
| 387 |
+
self.mfa_files_text.grid(row=10, column=1, columnspan=2, padx=5, pady=(10, 5), sticky="w")
|
| 388 |
self.mfa_files_text.insert("end", "选择目录后显示文件列表")
|
| 389 |
self.mfa_files_text.configure(state="disabled")
|
| 390 |
+
|
| 391 |
+
# 初始扫描
|
| 392 |
+
self._scan_mfa_dir()
|
| 393 |
|
| 394 |
def _get_model_desc(self):
|
| 395 |
"""获取当前选中模型的描述"""
|
|
|
|
| 423 |
self._save_config()
|
| 424 |
self._scan_mfa_dir()
|
| 425 |
|
| 426 |
+
def _on_mfa_lang_change(self, choice):
    """Refresh the description label after the MFA language selection changes.

    Args:
        choice: Language code handed over by the combo box callback.
    """
    # Imported at call time rather than at module import.
    from src.mfa_model_downloader import get_available_languages

    # Unknown codes fall back to an empty description.
    description = get_available_languages().get(choice, "")
    self.mfa_lang_desc.configure(text=description)
|
| 431 |
+
|
| 432 |
+
def _download_mfa_models(self):
    """Start downloading the selected MFA language models on a worker thread.

    Does nothing while a previous download thread is still alive.
    """
    if self._download_thread and self._download_thread.is_alive():
        return

    # Read the Tk variables here, on the UI thread, and hand plain strings to
    # the worker so it does not have to touch Tk variables itself.
    language = self.mfa_lang_var.get()
    output_dir = self.mfa_dir_var.get()

    self.mfa_download_btn.configure(state="disabled")
    self.mfa_status.configure(text="⏳ 下载中...", text_color="gray")
    self._download_thread = threading.Thread(
        target=self._do_download_mfa,
        args=(language, output_dir),
        daemon=True,
    )
    self._download_thread.start()


def _do_download_mfa(self, language=None, output_dir=None):
    """Download the MFA acoustic model and dictionary (worker thread).

    Args:
        language: Language code to download. When omitted, the value is read
            from ``self.mfa_lang_var`` (the behaviour of the old zero-argument
            call).
        output_dir: Target directory. When omitted, the value is read from
            ``self.mfa_dir_var``.

    Whatever happens, the status label is updated, the download button is
    re-enabled and the model directory is rescanned, so a failure can never
    leave the UI stuck in the "downloading" state.
    """
    if language is None:
        language = self.mfa_lang_var.get()
    if output_dir is None:
        output_dir = self.mfa_dir_var.get()

    success = False
    try:
        from src.mfa_model_downloader import download_language_models

        # exist_ok avoids the check-then-create race; an empty or unwritable
        # path raises here and is reported by the handler below.
        os.makedirs(output_dir, exist_ok=True)

        self.log_callback(f"开始下载 MFA 模型: {language}")

        success, acoustic_path, dict_path = download_language_models(
            language=language,
            output_dir=output_dir,
            progress_callback=self.log_callback
        )

        if success:
            self.log_callback(f"声学模型: {acoustic_path}")
            self.log_callback(f"字典文件: {dict_path}")
    except Exception as exc:
        # Thread boundary: report the problem instead of dying silently.
        success = False
        self.log_callback(f"MFA 模型下载出错: {exc}")
    finally:
        if success:
            self.after(0, lambda: self.mfa_status.configure(text="✅ 已下载", text_color="green"))
        else:
            self.after(0, lambda: self.mfa_status.configure(text="❌ 下载失败", text_color="red"))

        self.after(0, lambda: self.mfa_download_btn.configure(state="normal"))
        self.after(0, self._scan_mfa_dir)
|
| 470 |
+
|
| 471 |
def _scan_mfa_dir(self):
|
| 472 |
"""扫描 MFA 模型目录"""
|
| 473 |
mfa_dir = self.mfa_dir_var.get()
|
|
|
|
| 569 |
def __init__(self, master, log_callback):
    """Create the frame, build its widgets and show the MFA environment status.

    Args:
        master: Parent widget.
        log_callback: Callable that receives log messages as strings.
    """
    super().__init__(master)
    self._is_running = False  # True while an alignment worker is active
    self.log_callback = log_callback
    self._setup_ui()
    self._check_mfa_status()
|
| 575 |
|
| 576 |
def _setup_ui(self):
    """Build the widgets of the MFA alignment tab."""
    # MFA environment status line
    status_label = ctk.CTkLabel(
        self,
        text="⏳ 检查 MFA 环境...",
        font=ctk.CTkFont(size=12)
    )
    status_label.grid(row=0, column=0, columnspan=3, padx=10, pady=(10, 5), sticky="w")
    self.mfa_status_label = status_label

    # Variables behind the four path rows
    self.raw_dir_var = ctk.StringVar()
    self.output_dir_var = ctk.StringVar()
    self.dict_path_var = ctk.StringVar(value="models/mfa/mandarin_china_mfa.dict")
    self.mfa_model_var = ctk.StringVar(value="models/mfa/mandarin_mfa.zip")

    # Each path row is: caption label, entry bound to the variable, browse button.
    path_rows = (
        ("① 切片及LAB目录:", self.raw_dir_var, self._browse_raw_dir),      # corpus directory
        ("② TextGrid输出目录:", self.output_dir_var, self._browse_output_dir),  # output directory
        ("③ 字典文件:", self.dict_path_var, self._browse_dict),           # dictionary file
        ("④ MFA模型文件:", self.mfa_model_var, self._browse_mfa),          # acoustic model file
    )
    for row, (caption, variable, on_browse) in enumerate(path_rows, start=1):
        ctk.CTkLabel(self, text=caption).grid(row=row, column=0, padx=10, pady=5, sticky="w")
        ctk.CTkEntry(self, textvariable=variable, width=400).grid(row=row, column=1, padx=5, pady=5)
        ctk.CTkButton(self, text="浏览", width=60, command=on_browse).grid(row=row, column=2, padx=5, pady=5)

    # Options
    options_frame = ctk.CTkFrame(self)
    options_frame.grid(row=5, column=0, columnspan=3, padx=10, pady=10, sticky="w")

    self.single_speaker_var = ctk.BooleanVar(value=True)
    single_speaker_box = ctk.CTkCheckBox(
        options_frame,
        text="单说话人模式",
        variable=self.single_speaker_var
    )
    single_speaker_box.pack(side="left", padx=10)

    self.clean_var = ctk.BooleanVar(value=True)
    clean_box = ctk.CTkCheckBox(
        options_frame,
        text="清理旧缓存",
        variable=self.clean_var
    )
    clean_box.pack(side="left", padx=10)

    # Run button
    self.run_btn = ctk.CTkButton(self, text="⑤ 开始对齐", command=self._run)
    self.run_btn.grid(row=6, column=1, pady=20)
|
| 630 |
+
|
| 631 |
+
def _check_mfa_status(self):
    """Show in the status label whether the external MFA environment is available."""
    from src.mfa_runner import check_mfa_available

    if check_mfa_available():
        message = "✅ MFA 外挂环境已就绪 (tools/mfa_engine)"
        colour = "green"
    else:
        message = "❌ MFA 外挂环境不可用,请检查 tools/mfa_engine 目录"
        colour = "red"

    self.mfa_status_label.configure(text=message, text_color=colour)
|
| 645 |
|
| 646 |
def _browse_raw_dir(self):
    """Ask for the slice/LAB directory and store a non-empty selection."""
    chosen = filedialog.askdirectory(title="选择切片及LAB目录")
    if not chosen:
        return  # dialog cancelled
    self.raw_dir_var.set(chosen)
|
| 650 |
|
| 651 |
+
def _browse_output_dir(self):
    """Ask for the TextGrid output directory and store a non-empty selection."""
    chosen = filedialog.askdirectory(title="选择TextGrid输出目录")
    if not chosen:
        return  # dialog cancelled
    self.output_dir_var.set(chosen)
|
| 655 |
+
|
| 656 |
def _browse_dict(self):
    """Ask for the pronunciation dictionary file and store a non-empty selection."""
    file_types = [("字典文件", "*.dict *.txt"), ("所有文件", "*.*")]
    chosen = filedialog.askopenfilename(title="选择字典文件", filetypes=file_types)
    if not chosen:
        return  # dialog cancelled
    self.dict_path_var.set(chosen)
|
| 663 |
|
| 664 |
def _browse_mfa(self):
    """Ask for the MFA acoustic model archive and store a non-empty selection."""
    file_types = [("ZIP文件", "*.zip"), ("所有文件", "*.*")]
    chosen = filedialog.askopenfilename(title="选择MFA模型", filetypes=file_types)
    if not chosen:
        return  # dialog cancelled
    self.mfa_model_var.set(chosen)
|
| 671 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 672 |
def _run(self):
    """Validate the form and start the MFA alignment on a worker thread.

    Ignored while a previous alignment is still running. Shows an error
    dialog and returns when the input or output directory is empty.
    """
    if self._is_running:
        return

    raw_dir = self.raw_dir_var.get()
    output_dir = self.output_dir_var.get()
    dict_path = self.dict_path_var.get()
    mfa_model = self.mfa_model_var.get()

    if not raw_dir or not output_dir:
        messagebox.showerror("错误", "请填写输入目录和输出目录")
        return

    # Snapshot the checkbox values on the UI thread so the worker does not
    # have to read Tk variables.
    single_speaker = self.single_speaker_var.get()
    clean = self.clean_var.get()

    self._is_running = True
    self.run_btn.configure(state="disabled", text="对齐中...")

    threading.Thread(
        target=self._process,
        args=(raw_dir, output_dir, dict_path, mfa_model, single_speaker, clean),
        daemon=True
    ).start()


def _process(self, raw_dir, output_dir, dict_path, mfa_model,
             single_speaker=None, clean=None):
    """Run the MFA alignment (worker thread).

    Args:
        raw_dir: Corpus directory passed on as ``corpus_dir``.
        output_dir: Directory passed on as ``output_dir``.
        dict_path: Dictionary path; an empty string is passed on as ``None``.
        mfa_model: Acoustic model path; an empty string is passed on as ``None``.
        single_speaker: Single-speaker flag. When omitted it is read from
            ``self.single_speaker_var`` (old four-argument call).
        clean: Clean-cache flag. When omitted it is read from ``self.clean_var``.

    The running flag is cleared and the run button restored even when the
    alignment raises, so a crash cannot leave the tab permanently disabled.
    """
    try:
        from src.mfa_runner import run_mfa_alignment

        if single_speaker is None:
            single_speaker = self.single_speaker_var.get()
        if clean is None:
            clean = self.clean_var.get()

        self.log_callback("=" * 50)
        self.log_callback("开始 MFA 对齐任务")

        success, message = run_mfa_alignment(
            corpus_dir=raw_dir,
            output_dir=output_dir,
            dict_path=dict_path if dict_path else None,
            model_path=mfa_model if mfa_model else None,
            single_speaker=single_speaker,
            clean=clean,
            progress_callback=self.log_callback
        )

        if success:
            self.log_callback("✅ MFA 对齐任务完成!")
            self.log_callback(f"TextGrid 文件已输出到: {output_dir}")
        else:
            self.log_callback(f"❌ MFA 对齐失败: {message}")

        self.log_callback("=" * 50)
    except Exception as exc:
        # Thread boundary: report the problem instead of dying silently.
        self.log_callback(f"❌ MFA 对齐失败: {exc}")
    finally:
        # Clear the flag first so it is reset even if scheduling the UI
        # update fails.
        self._is_running = False
        self.after(0, lambda: self.run_btn.configure(state="normal", text="⑤ 开始对齐"))
|
| 722 |
|
| 723 |
|
| 724 |
class App(ctk.CTk):
|
src/mfa_model_downloader.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
MFA 模型下载模块
|
| 4 |
+
支持下载中文和日文的声学模型及字典
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import logging
|
| 9 |
+
import urllib.request
|
| 10 |
+
import urllib.error
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import Optional, Callable
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)

# Base URLs for model downloads
GITHUB_RELEASE_BASE = "https://github.com/MontrealCorpusTools/mfa-models/releases/download"
GITHUB_RAW_BASE = "https://raw.githubusercontent.com/MontrealCorpusTools/mfa-models/main"

# Supported languages.
# Layout: {language code: {display name, acoustic model info, dictionary info}}
# Both assets are fetched from GitHub releases; dictionary tags follow the
# pattern dictionary-{name}-v{version}.
LANGUAGE_MODELS = {
    "mandarin": dict(
        name="中文 (普通话)",
        acoustic=dict(
            tag="acoustic-mandarin_mfa-v3.0.0",
            filename="mandarin_mfa.zip",
            description="Mandarin MFA acoustic model v3.0.0",
        ),
        dictionary=dict(
            tag="dictionary-mandarin_china_mfa-v3.0.0",
            filename="mandarin_china_mfa.dict",
            description="Mandarin (China) MFA dictionary v3.0.0",
        ),
    ),
    "japanese": dict(
        name="日文",
        acoustic=dict(
            tag="acoustic-japanese_mfa-v3.0.0",
            filename="japanese_mfa.zip",
            description="Japanese MFA acoustic model v3.0.0",
        ),
        dictionary=dict(
            tag="dictionary-japanese_mfa-v3.0.0",
            filename="japanese_mfa.dict",
            description="Japanese MFA dictionary v3.0.0",
        ),
    ),
}
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def get_available_languages() -> dict:
    """Return a mapping of language code to display name for every supported language."""
    names = {}
    for code, entry in LANGUAGE_MODELS.items():
        names[code] = entry["name"]
    return names
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _download_file(
    url: str,
    dest_path: str,
    progress_callback: Optional[Callable[[str], None]] = None
) -> bool:
    """
    Download ``url`` to ``dest_path``.

    The payload is streamed into ``dest_path + ".part"`` and moved onto
    ``dest_path`` only after the transfer has finished, so an interrupted
    download never leaves a truncated file under the final name.

    Args:
        url: Address to download from.
        dest_path: Path the file is saved to. Missing parent directories are
            created; a bare file name is written to the current directory.
        progress_callback: Optional callable that receives every status
            message as a string.

    Returns:
        True on success. False on any error; the error is logged and sent to
        ``progress_callback`` and no partial file is left behind.
    """
    def log(msg: str):
        logger.info(msg)
        if progress_callback:
            progress_callback(msg)

    block_size = 8192
    report_step = block_size * 100  # progress message every 100 blocks
    part_path = dest_path + ".part"
    finished = False

    try:
        log(f"正在下载: {url}")

        # dirname is "" for a bare file name and os.makedirs("") would raise.
        dest_dir = os.path.dirname(dest_path)
        if dest_dir:
            os.makedirs(dest_dir, exist_ok=True)

        req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"})

        with urllib.request.urlopen(req, timeout=60) as response:
            total_size = response.headers.get("Content-Length")
            if total_size:
                total_size = int(total_size)
                log(f"文件大小: {total_size / 1024 / 1024:.1f} MB")

            downloaded = 0
            next_report = report_step

            with open(part_path, "wb") as f:
                while True:
                    chunk = response.read(block_size)
                    if not chunk:
                        break
                    f.write(chunk)
                    downloaded += len(chunk)

                    # Threshold based, so short reads cannot suppress the
                    # progress messages.
                    if total_size and downloaded >= next_report:
                        percent = downloaded / total_size * 100
                        log(f"下载进度: {percent:.1f}%")
                        next_report += report_step

        # A connection that ends early must not be mistaken for a complete file.
        if total_size and downloaded != total_size:
            raise OSError(f"下载不完整: {downloaded}/{total_size} 字节")

        os.replace(part_path, dest_path)
        finished = True

        log(f"下载完成: {dest_path}")
        return True

    except urllib.error.HTTPError as e:
        log(f"HTTP 错误: {e.code} - {e.reason}")
        return False
    except urllib.error.URLError as e:
        log(f"网络错误: {e.reason}")
        return False
    except Exception as e:
        log(f"下载失败: {e}")
        return False
    finally:
        if not finished:
            # Best-effort cleanup of the temporary file.
            try:
                os.remove(part_path)
            except OSError:
                pass
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def _download_model_asset(
    language: str,
    asset_key: str,
    asset_label: str,
    output_dir: str,
    progress_callback: Optional[Callable[[str], None]] = None
) -> tuple[bool, str]:
    """Download one release asset (acoustic model or dictionary) of a language.

    Args:
        language: Language code, a key of ``LANGUAGE_MODELS``.
        asset_key: Key of the asset inside the language entry
            ("acoustic" or "dictionary").
        asset_label: Label used as prefix of the status and error messages.
        output_dir: Directory the asset is stored in.
        progress_callback: Optional callable receiving status messages.

    Returns:
        (True, file path) on success or when the file already exists,
        (False, error message) otherwise.
    """
    if language not in LANGUAGE_MODELS:
        return False, f"不支持的语言: {language}"

    config = LANGUAGE_MODELS[language][asset_key]
    # Both asset kinds are published as GitHub release assets.
    url = f"{GITHUB_RELEASE_BASE}/{config['tag']}/{config['filename']}"
    dest_path = os.path.join(output_dir, config["filename"])

    # An existing file is reused without contacting the network.
    if os.path.exists(dest_path):
        if progress_callback:
            progress_callback(f"{asset_label}已存在: {dest_path}")
        return True, dest_path

    if _download_file(url, dest_path, progress_callback):
        return True, dest_path
    return False, f"{asset_label}下载失败"


def download_acoustic_model(
    language: str,
    output_dir: str,
    progress_callback: Optional[Callable[[str], None]] = None
) -> tuple[bool, str]:
    """
    Download the acoustic model of a language.

    Args:
        language: Language code (mandarin/japanese).
        output_dir: Output directory.
        progress_callback: Optional progress callback.

    Returns:
        (success flag, file path or error message)
    """
    return _download_model_asset(
        language, "acoustic", "声学模型", output_dir, progress_callback
    )


def download_dictionary(
    language: str,
    output_dir: str,
    progress_callback: Optional[Callable[[str], None]] = None
) -> tuple[bool, str]:
    """
    Download the pronunciation dictionary of a language.

    Args:
        language: Language code (mandarin/japanese).
        output_dir: Output directory.
        progress_callback: Optional progress callback.

    Returns:
        (success flag, file path or error message)
    """
    return _download_model_asset(
        language, "dictionary", "字典文件", output_dir, progress_callback
    )
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def download_language_models(
    language: str,
    output_dir: str,
    progress_callback: Optional[Callable[[str], None]] = None
) -> tuple[bool, str, str]:
    """
    Download the acoustic model and the dictionary of one language.

    Args:
        language: Language code (mandarin/japanese).
        output_dir: Output directory.
        progress_callback: Optional progress callback.

    Returns:
        ``(True, acoustic model path, dictionary path)`` on success.
        On failure the first element is False and the third element carries
        the error message instead of a dictionary path; the second element is
        ``""`` unless the acoustic model had already been obtained.
    """
    def log(msg: str):
        logger.info(msg)
        if progress_callback:
            progress_callback(msg)

    if language not in LANGUAGE_MODELS:
        # Report the reason; callers that only look at the success flag
        # would otherwise see a failure without any explanation.
        error = f"不支持的语言: {language}"
        log(error)
        return False, "", error

    lang_name = LANGUAGE_MODELS[language]["name"]
    log(f"开始下载 {lang_name} 模型...")

    # Acoustic model
    log("=" * 40)
    log("下载声学模型...")
    success, acoustic_path = download_acoustic_model(language, output_dir, progress_callback)
    if not success:
        # acoustic_path holds the error message here.
        return False, "", acoustic_path

    # Dictionary
    log("=" * 40)
    log("下载字典文件...")
    success, dict_path = download_dictionary(language, output_dir, progress_callback)
    if not success:
        # dict_path holds the error message here.
        return False, acoustic_path, dict_path

    log("=" * 40)
    log(f"{lang_name} 模型下载完成!")
    return True, acoustic_path, dict_path
|
src/mfa_runner.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
MFA 外挂调用模块
|
| 4 |
+
采用 Sidecar Pattern,通过 subprocess 调用独立的 MFA 环境
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import subprocess
|
| 9 |
+
import logging
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Optional, Callable
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)

# Locations of the project root and of the external MFA environment
BASE_DIR = Path(__file__).parent.parent.absolute()
MFA_ENGINE_DIR = BASE_DIR.joinpath("tools", "mfa_engine")
MFA_PYTHON = MFA_ENGINE_DIR.joinpath("python.exe")

# Default model locations
DEFAULT_DICT_PATH = BASE_DIR.joinpath("models", "mandarin.dict")
DEFAULT_MODEL_PATH = BASE_DIR.joinpath("models", "mandarin.zip")
DEFAULT_TEMP_DIR = BASE_DIR.joinpath("mfa_temp")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def check_mfa_available() -> bool:
    """Report whether the external MFA environment is present on disk.

    Returns:
        True when both the engine directory and its Python executable exist;
        otherwise False, after logging a warning that names the missing path.
    """
    required = (
        (MFA_ENGINE_DIR, "MFA 引擎目录不存在"),
        (MFA_PYTHON, "MFA Python 不存在"),
    )
    for location, problem in required:
        if not location.exists():
            logger.warning(f"{problem}: {location}")
            return False
    return True
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _build_mfa_env() -> dict:
    """Build the environment variables used to launch the MFA subprocess.

    Returns:
        A copy of the current process environment whose PATH starts with
        the MFA engine directories, followed by the inherited PATH.
    """
    env = os.environ.copy()

    # Library\bin must be on PATH, otherwise the Kaldi DLLs cannot be found.
    mfa_paths = [
        str(MFA_ENGINE_DIR),
        str(MFA_ENGINE_DIR / "Library" / "bin"),
        str(MFA_ENGINE_DIR / "Scripts"),
        str(MFA_ENGINE_DIR / "bin"),
    ]

    # Only append the inherited PATH when it is non-empty, so the result
    # never ends with a dangling separator (an empty PATH entry).
    inherited_path = env.get("PATH", "")
    if inherited_path:
        mfa_paths.append(inherited_path)

    # os.pathsep is ";" on Windows (same as before) and ":" elsewhere.
    env["PATH"] = os.pathsep.join(mfa_paths)

    return env
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def run_mfa_alignment(
    corpus_dir: str,
    output_dir: str,
    dict_path: Optional[str] = None,
    model_path: Optional[str] = None,
    temp_dir: Optional[str] = None,
    single_speaker: bool = True,
    clean: bool = True,
    progress_callback: Optional[Callable[[str], None]] = None
) -> tuple[bool, str]:
    """
    Run an MFA forced alignment in the sidecar environment.

    Args:
        corpus_dir: Input directory (wav files plus lab/txt transcripts).
        output_dir: Directory that receives the TextGrid output.
        dict_path: Dictionary file; falls back to DEFAULT_DICT_PATH.
        model_path: Acoustic model; falls back to DEFAULT_MODEL_PATH.
        temp_dir: MFA temp directory; falls back to DEFAULT_TEMP_DIR.
        single_speaker: Append ``--single_speaker`` to the command.
        clean: Append ``--clean`` to the command.
        progress_callback: Optional sink that receives every status message.

    Returns:
        ``(success, text)`` where ``text`` is the subprocess stdout on
        success, or an error description on failure.
    """
    def notify(text: str):
        # Every status line goes to the module logger and, if given,
        # to the caller's callback.
        logger.info(text)
        if progress_callback:
            progress_callback(text)

    # Guard: the sidecar environment must exist.
    if not check_mfa_available():
        return False, "MFA 外挂环境不可用,请检查 tools/mfa_engine 目录"

    # Fill in defaults for anything the caller left out.
    resolved_dict = dict_path or str(DEFAULT_DICT_PATH)
    resolved_model = model_path or str(DEFAULT_MODEL_PATH)
    resolved_temp = temp_dir or str(DEFAULT_TEMP_DIR)

    # Guard: the inputs must exist.
    if not os.path.isdir(corpus_dir):
        return False, f"输入目录不存在: {corpus_dir}"
    if not os.path.isfile(resolved_dict):
        return False, f"字典文件不存在: {resolved_dict}"
    if not os.path.isfile(resolved_model):
        return False, f"声学模型不存在: {resolved_model}"

    # Make sure the output and temp directories exist.
    for folder in (output_dir, resolved_temp):
        os.makedirs(folder, exist_ok=True)

    # Assemble the command line; optional flags keep their original order.
    switches = ((clean, "--clean"), (single_speaker, "--single_speaker"))
    command = [
        str(MFA_PYTHON),
        "-m", "montreal_forced_aligner",
        "align",
        str(corpus_dir),
        str(resolved_dict),
        str(resolved_model),
        str(output_dir),
        "--temp_directory", str(resolved_temp),
        *(flag for enabled, flag in switches if enabled),
    ]

    notify("正在启动 MFA 对齐引擎...")
    notify(f"输入目录: {corpus_dir}")
    notify(f"输出目录: {output_dir}")

    try:
        completed = subprocess.run(
            command,
            env=_build_mfa_env(),
            capture_output=True,
            text=True,
            encoding='utf-8',
            errors='replace'
        )

        if completed.returncode != 0:
            failure = completed.stderr or completed.stdout or "未知错误"
            notify(f"MFA 运行出错: {failure}")
            return False, failure

        notify("MFA 对齐完成!")
        return True, completed.stdout

    except FileNotFoundError as exc:
        report = f"找不到 MFA Python: {exc}"
        notify(report)
        return False, report
    except Exception as exc:
        report = f"MFA 执行异常: {exc}"
        notify(report)
        return False, report
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def run_mfa_validate(
    corpus_dir: str,
    dict_path: Optional[str] = None,
    progress_callback: Optional[Callable[[str], None]] = None
) -> tuple[bool, str]:
    """
    Run ``mfa validate`` on a corpus in the sidecar environment.

    Args:
        corpus_dir: Corpus directory to validate.
        dict_path: Dictionary file; falls back to DEFAULT_DICT_PATH.
        progress_callback: Optional sink that receives every status message.

    Returns:
        ``(success, text)``. On a completed run ``text`` is stdout and
        stderr joined by a newline and ``success`` reflects the exit code.
        On an early failure ``text`` is the error description.
    """
    def log(msg: str):
        logger.info(msg)
        if progress_callback:
            progress_callback(msg)

    if not check_mfa_available():
        return False, "MFA 外挂环境不可用"

    # Same up-front check as run_mfa_alignment: fail fast with a clear
    # message instead of launching MFA on a missing directory.
    if not os.path.isdir(corpus_dir):
        return False, f"输入目录不存在: {corpus_dir}"

    dict_path = dict_path or str(DEFAULT_DICT_PATH)

    cmd = [
        str(MFA_PYTHON),
        "-m", "montreal_forced_aligner",
        "validate",
        str(corpus_dir),
        str(dict_path),
    ]

    log("正在验证语料库...")

    try:
        env = _build_mfa_env()
        result = subprocess.run(
            cmd,
            env=env,
            capture_output=True,
            text=True,
            encoding='utf-8',
            errors='replace'
        )

        output = result.stdout + "\n" + result.stderr
        log("验证完成")
        return result.returncode == 0, output

    except Exception as e:
        # Report the failure through the same channels as other status
        # messages; the returned text stays the plain exception string.
        log(f"MFA 执行异常: {e}")
        return False, str(e)
|
src/silero_vad_downloader.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
Silero VAD 模型下载模块
|
| 4 |
+
支持自动下载 Silero VAD 模型到指定目录
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import logging
|
| 9 |
+
import urllib.request
|
| 10 |
+
import urllib.error
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import Optional, Callable
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
# Silero VAD 模型配置
|
| 17 |
+
SILERO_VAD_CONFIG = {
|
| 18 |
+
"repo": "snakers4/silero-vad",
|
| 19 |
+
"model_name": "silero_vad",
|
| 20 |
+
"version": "v5.1",
|
| 21 |
+
"onnx_filename": "silero_vad.onnx",
|
| 22 |
+
"jit_filename": "silero_vad.jit",
|
| 23 |
+
"download_base": "https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data"
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _download_file(
    url: str,
    dest_path: str,
    progress_callback: Optional[Callable[[str], None]] = None
) -> bool:
    """
    Download ``url`` to ``dest_path``.

    The payload is streamed into a sibling ``<dest_path>.part`` file and
    renamed onto ``dest_path`` only after the transfer has finished. An
    interrupted or truncated download therefore never leaves a partial
    file at the final location, where an existence check would mistake it
    for a complete model.

    Args:
        url: Address to download from.
        dest_path: Final path of the downloaded file.
        progress_callback: Optional sink that receives every status message.

    Returns:
        True on success, False on any failure. Errors are logged and
        reported through the callback, not raised.
    """
    def log(msg: str):
        logger.info(msg)
        if progress_callback:
            progress_callback(msg)

    part_path = f"{dest_path}.part"

    try:
        log(f"正在下载: {url}")

        # A bare file name has no directory component, and makedirs("")
        # would raise, so only create the directory when there is one.
        dest_dir = os.path.dirname(dest_path)
        if dest_dir:
            os.makedirs(dest_dir, exist_ok=True)

        req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"})

        with urllib.request.urlopen(req, timeout=60) as response:
            content_length = response.headers.get("Content-Length")
            total_size = int(content_length) if content_length else 0
            if total_size:
                log(f"文件大小: {total_size / 1024 / 1024:.2f} MB")

            # Stream the body in fixed-size blocks.
            block_size = 8192
            report_step = block_size * 100
            next_report = report_step
            downloaded = 0

            with open(part_path, "wb") as f:
                while True:
                    chunk = response.read(block_size)
                    if not chunk:
                        break
                    f.write(chunk)
                    downloaded += len(chunk)

                    # Threshold-based reporting: a short read cannot make
                    # the byte count skip past a reporting point.
                    if total_size and downloaded >= next_report:
                        percent = downloaded / total_size * 100
                        log(f"下载进度: {percent:.1f}%")
                        next_report = downloaded + report_step

        # A connection closed early can end the read loop without an
        # exception; compare against the announced size when known.
        if total_size and downloaded != total_size:
            log(f"下载失败: 文件不完整 ({downloaded}/{total_size} 字节)")
            return False

        os.replace(part_path, dest_path)
        log(f"下载完成: {dest_path}")
        return True

    except urllib.error.HTTPError as e:
        log(f"HTTP 错误: {e.code} - {e.reason}")
        return False
    except urllib.error.URLError as e:
        log(f"网络错误: {e.reason}")
        return False
    except Exception as e:
        log(f"下载失败: {e}")
        return False
    finally:
        # Best-effort cleanup: after a successful rename the .part file
        # no longer exists, so this only fires on failure paths.
        if os.path.exists(part_path):
            try:
                os.remove(part_path)
            except OSError as e:
                logger.warning(f"无法删除临时文件: {part_path} ({e})")
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def get_vad_model_path(models_dir: str) -> str:
    """
    Build the path of the ONNX VAD model file.

    Args:
        models_dir: Root directory that holds all models.

    Returns:
        ``<models_dir>/silero_vad/<onnx file name>``.
    """
    onnx_name = SILERO_VAD_CONFIG["onnx_filename"]
    return os.path.join(models_dir, "silero_vad", onnx_name)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def is_vad_model_downloaded(models_dir: str) -> bool:
    """
    Tell whether the ONNX VAD model is already present.

    Args:
        models_dir: Root directory that holds all models.

    Returns:
        True when the path from ``get_vad_model_path`` exists.
    """
    return os.path.exists(get_vad_model_path(models_dir))
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def download_silero_vad(
    output_dir: str,
    progress_callback: Optional[Callable[[str], None]] = None,
    use_onnx: bool = True
) -> tuple[bool, str]:
    """
    Download the Silero VAD model unless it is already on disk.

    Args:
        output_dir: Models root; the file lands in ``<output_dir>/silero_vad``.
        progress_callback: Optional sink that receives every status message.
        use_onnx: Fetch the ONNX file when True, the JIT file otherwise.

    Returns:
        ``(success, text)`` where ``text`` is the model path on success
        or an error message on failure.
    """
    def emit(text: str):
        logger.info(text)
        if progress_callback:
            progress_callback(text)

    # Pick the artefact, then derive its URL and local destination.
    filename_key = "onnx_filename" if use_onnx else "jit_filename"
    filename = SILERO_VAD_CONFIG[filename_key]
    url = f"{SILERO_VAD_CONFIG['download_base']}/{filename}"
    dest_path = os.path.join(output_dir, "silero_vad", filename)

    # Nothing to do when the file is already there.
    if os.path.exists(dest_path):
        emit(f"Silero VAD 模型已存在: {dest_path}")
        return True, dest_path

    emit("开始下载 Silero VAD 模型...")
    emit(f"版本: {SILERO_VAD_CONFIG['version']}")
    emit(f"格式: {'ONNX' if use_onnx else 'JIT'}")

    if not _download_file(url, dest_path, progress_callback):
        return False, "Silero VAD 模型下载失败"

    emit("Silero VAD 模型下载完成!")
    return True, dest_path
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def ensure_vad_model(
    models_dir: str,
    progress_callback: Optional[Callable[[str], None]] = None
) -> str:
    """
    Return the VAD model path, downloading the model first if it is missing.

    Args:
        models_dir: Root directory that holds all models.
        progress_callback: Optional sink passed on to the downloader.

    Returns:
        Path of the model file.

    Raises:
        RuntimeError: If the download fails.
    """
    model_path = get_vad_model_path(models_dir)

    if not os.path.exists(model_path):
        ok, outcome = download_silero_vad(models_dir, progress_callback)
        if not ok:
            raise RuntimeError(f"Silero VAD 模型下载失败: {outcome}")
        return outcome

    logger.info(f"Silero VAD 模型已就绪: {model_path}")
    return model_path
|
tests/test_mfa_model_downloader.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
MFA 模型下载模块单元测试
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
import unittest
|
| 9 |
+
from unittest.mock import patch, MagicMock
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
# 添加项目根目录到路径
|
| 13 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 14 |
+
|
| 15 |
+
from src.mfa_model_downloader import (
|
| 16 |
+
get_available_languages,
|
| 17 |
+
LANGUAGE_MODELS,
|
| 18 |
+
GITHUB_RELEASE_BASE,
|
| 19 |
+
download_acoustic_model,
|
| 20 |
+
download_dictionary,
|
| 21 |
+
download_language_models,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class TestGetAvailableLanguages(unittest.TestCase):
    """Tests for the available-language listing."""

    def _assert_language(self, code, label):
        # The language code must be listed and map to the expected label.
        languages = get_available_languages()
        self.assertIn(code, languages)
        self.assertEqual(languages[code], label)

    def test_returns_dict(self):
        """The return value is a dict."""
        self.assertIsInstance(get_available_languages(), dict)

    def test_contains_mandarin(self):
        """Mandarin is listed."""
        self._assert_language("mandarin", "中文 (普通话)")

    def test_contains_japanese(self):
        """Japanese is listed."""
        self._assert_language("japanese", "日文")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class TestLanguageModelsConfig(unittest.TestCase):
    """Tests for the language model configuration table."""

    _TOP_LEVEL_KEYS = ("name", "acoustic", "dictionary")
    _ARTIFACT_KEYS = ("tag", "filename")

    def test_mandarin_config_complete(self):
        """The Mandarin entry has every expected key."""
        mandarin = LANGUAGE_MODELS["mandarin"]
        for key in self._TOP_LEVEL_KEYS:
            self.assertIn(key, mandarin)

        # Acoustic model entry.
        acoustic_cfg = mandarin["acoustic"]
        for key in self._ARTIFACT_KEYS:
            self.assertIn(key, acoustic_cfg)
        self.assertTrue(acoustic_cfg["filename"].endswith(".zip"))

        # Dictionary entry.
        dictionary_cfg = mandarin["dictionary"]
        for key in self._ARTIFACT_KEYS:
            self.assertIn(key, dictionary_cfg)
        self.assertTrue(dictionary_cfg["filename"].endswith(".dict"))

    def test_japanese_config_complete(self):
        """The Japanese entry has every top-level key."""
        japanese = LANGUAGE_MODELS["japanese"]
        for key in self._TOP_LEVEL_KEYS:
            self.assertIn(key, japanese)

    def test_acoustic_url_format(self):
        """Acoustic model URLs are well-formed."""
        for entry in LANGUAGE_MODELS.values():
            artifact = entry["acoustic"]
            link = f"{GITHUB_RELEASE_BASE}/{artifact['tag']}/{artifact['filename']}"
            self.assertTrue(link.startswith("https://github.com/"))
            self.assertIn("mfa-models", link)

    def test_dictionary_url_format(self):
        """Dictionary URLs are well-formed."""
        for entry in LANGUAGE_MODELS.values():
            artifact = entry["dictionary"]
            link = f"{GITHUB_RELEASE_BASE}/{artifact['tag']}/{artifact['filename']}"
            self.assertTrue(link.startswith("https://github.com/"))
            self.assertIn("dictionary-", link)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class TestDownloadAcousticModel(unittest.TestCase):
    """Tests for the acoustic model download."""

    def test_invalid_language(self):
        """An unsupported language is reported as a failure."""
        ok, message = download_acoustic_model("invalid_lang", "/tmp")
        self.assertFalse(ok)
        self.assertIn("不支持的语言", message)

    @patch('src.mfa_model_downloader._download_file')
    def test_download_called_with_correct_url(self, mock_download):
        """The download is requested with the expected URL."""
        mock_download.return_value = True

        # Pretend nothing is on disk so the download path is taken.
        with patch('os.path.exists', return_value=False):
            download_acoustic_model("mandarin", "/tmp/models")

        # Inspect the first positional argument of the download call.
        positional, _ = mock_download.call_args
        requested_url = positional[0]
        self.assertIn("mandarin_mfa.zip", requested_url)
        self.assertIn("acoustic-mandarin_mfa", requested_url)

    @patch('os.path.exists')
    def test_skip_if_exists(self, mock_exists):
        """An existing file short-circuits the download."""
        mock_exists.return_value = True

        ok, location = download_acoustic_model("mandarin", "/tmp/models")
        self.assertTrue(ok)
        self.assertIn("mandarin_mfa.zip", location)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class TestDownloadDictionary(unittest.TestCase):
    """Tests for the dictionary download."""

    def test_invalid_language(self):
        """An unsupported language is reported as a failure."""
        ok, message = download_dictionary("invalid_lang", "/tmp")
        self.assertFalse(ok)
        self.assertIn("不支持的语言", message)

    @patch('src.mfa_model_downloader._download_file')
    def test_download_called_with_correct_url(self, mock_download):
        """The download is requested with the expected URL."""
        mock_download.return_value = True

        # Pretend nothing is on disk so the download path is taken.
        with patch('os.path.exists', return_value=False):
            download_dictionary("japanese", "/tmp/models")

        positional, _ = mock_download.call_args
        requested_url = positional[0]
        self.assertIn("github.com", requested_url)
        self.assertIn("dictionary-japanese", requested_url)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class TestDownloadLanguageModels(unittest.TestCase):
    """Tests for downloading a language's full model set."""

    def test_invalid_language(self):
        """An unsupported language is reported as a failure."""
        outcome = download_language_models("invalid", "/tmp")
        ok, _acoustic, _dictionary = outcome
        self.assertFalse(ok)

    @patch('src.mfa_model_downloader.download_dictionary')
    @patch('src.mfa_model_downloader.download_acoustic_model')
    def test_downloads_both_models(self, mock_acoustic, mock_dict):
        """Both the acoustic model and the dictionary are fetched."""
        mock_acoustic.return_value = (True, "/tmp/acoustic.zip")
        mock_dict.return_value = (True, "/tmp/dict.dict")

        outcome = download_language_models("mandarin", "/tmp")
        ok, _acoustic, _dictionary = outcome

        self.assertTrue(ok)
        mock_acoustic.assert_called_once()
        mock_dict.assert_called_once()

    @patch('src.mfa_model_downloader.download_dictionary')
    @patch('src.mfa_model_downloader.download_acoustic_model')
    def test_stops_on_acoustic_failure(self, mock_acoustic, mock_dict):
        """A failed acoustic download stops before the dictionary."""
        mock_acoustic.return_value = (False, "下载失败")

        ok, _acoustic, _dictionary = download_language_models("mandarin", "/tmp")

        self.assertFalse(ok)
        mock_dict.assert_not_called()
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
if __name__ == "__main__":
|
| 182 |
+
unittest.main()
|
tests/test_mfa_runner.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
MFA 运行模块单元测试
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
import unittest
|
| 9 |
+
from unittest.mock import patch, MagicMock
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
# 添加项目根目录到路径
|
| 13 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 14 |
+
|
| 15 |
+
from src.mfa_runner import (
|
| 16 |
+
check_mfa_available,
|
| 17 |
+
_build_mfa_env,
|
| 18 |
+
run_mfa_alignment,
|
| 19 |
+
run_mfa_validate,
|
| 20 |
+
BASE_DIR,
|
| 21 |
+
MFA_ENGINE_DIR,
|
| 22 |
+
MFA_PYTHON,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class TestCheckMfaAvailable(unittest.TestCase):
    """Tests for the MFA environment check."""

    @patch('src.mfa_runner.MFA_ENGINE_DIR')
    def test_returns_false_when_dir_not_exists(self, mock_dir):
        """A missing engine directory makes the check return False."""
        # patch() swaps the module-level constant for a mock, and
        # check_mfa_available reads that global at call time, so the
        # stubbed exists() is the one consulted.
        mock_dir.exists.return_value = False

        self.assertFalse(check_mfa_available())
        mock_dir.exists.assert_called_once_with()

    def test_path_constants_defined(self):
        """The path constants are Path objects with the expected names."""
        self.assertIsInstance(BASE_DIR, Path)
        self.assertIsInstance(MFA_ENGINE_DIR, Path)
        self.assertIsInstance(MFA_PYTHON, Path)

        # Check the tail of each path.
        self.assertTrue(str(MFA_ENGINE_DIR).endswith("mfa_engine"))
        self.assertTrue(str(MFA_PYTHON).endswith("python.exe"))
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class TestBuildMfaEnv(unittest.TestCase):
    """Tests for building the MFA environment variables."""

    def test_returns_dict(self):
        """The result is a dict."""
        env = _build_mfa_env()
        self.assertIsInstance(env, dict)

    def test_path_contains_mfa_dirs(self):
        """PATH contains the MFA directories."""
        env = _build_mfa_env()
        path = env.get("PATH", "")

        self.assertIn("mfa_engine", path)
        self.assertIn("Library", path)

    def test_preserves_original_path(self):
        """The inherited PATH is kept intact at the end of the new PATH."""
        original_path = os.environ.get("PATH", "")
        env = _build_mfa_env()

        # The MFA directories are prepended, so the whole inherited value
        # must survive as the suffix. This is stricter than checking only
        # the first entry and does not depend on the platform separator.
        self.assertTrue(env["PATH"].endswith(original_path))
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class TestRunMfaAlignment(unittest.TestCase):
    """Tests for run_mfa_alignment."""

    @staticmethod
    def _prime(mock_run, mock_isfile, mock_isdir, mock_check,
               returncode, stdout, stderr):
        # Make every precondition pass and stub the subprocess result.
        mock_check.return_value = True
        mock_isdir.return_value = True
        mock_isfile.return_value = True
        mock_run.return_value = MagicMock(
            returncode=returncode, stdout=stdout, stderr=stderr
        )

    @patch('src.mfa_runner.check_mfa_available')
    def test_fails_when_mfa_unavailable(self, mock_check):
        """An unavailable MFA environment yields a failure."""
        mock_check.return_value = False

        ok, message = run_mfa_alignment("/input", "/output")

        self.assertFalse(ok)
        self.assertIn("不可用", message)

    @patch('src.mfa_runner.check_mfa_available')
    @patch('os.path.isdir')
    def test_fails_when_corpus_not_exists(self, mock_isdir, mock_check):
        """A missing input directory yields a failure."""
        mock_check.return_value = True
        mock_isdir.return_value = False

        ok, message = run_mfa_alignment("/nonexistent", "/output")

        self.assertFalse(ok)
        self.assertIn("不存在", message)

    @patch('src.mfa_runner.check_mfa_available')
    @patch('os.path.isdir')
    @patch('os.path.isfile')
    def test_fails_when_dict_not_exists(self, mock_isfile, mock_isdir, mock_check):
        """A missing dictionary file yields a failure."""
        mock_check.return_value = True
        mock_isdir.return_value = True
        mock_isfile.return_value = False

        ok, message = run_mfa_alignment(
            "/input", "/output",
            dict_path="/nonexistent.dict"
        )

        self.assertFalse(ok)
        self.assertIn("不存在", message)

    @patch('src.mfa_runner.check_mfa_available')
    @patch('os.path.isdir')
    @patch('os.path.isfile')
    @patch('os.makedirs')
    @patch('subprocess.run')
    def test_calls_subprocess_with_correct_args(
        self, mock_run, mock_makedirs, mock_isfile, mock_isdir, mock_check
    ):
        """subprocess.run receives the expected command."""
        self._prime(mock_run, mock_isfile, mock_isdir, mock_check,
                    returncode=0, stdout="", stderr="")

        run_mfa_alignment(
            "/input", "/output",
            dict_path="/dict.dict",
            model_path="/model.zip",
            single_speaker=True,
            clean=True
        )

        # The subprocess must be launched exactly once.
        mock_run.assert_called_once()

        # The command is the first positional argument.
        command = mock_run.call_args[0][0]
        expected_tokens = (
            "align",
            "/input",
            "/dict.dict",
            "/model.zip",
            "/output",
            "--single_speaker",
            "--clean",
        )
        for token in expected_tokens:
            self.assertIn(token, command)

    @patch('src.mfa_runner.check_mfa_available')
    @patch('os.path.isdir')
    @patch('os.path.isfile')
    @patch('os.makedirs')
    @patch('subprocess.run')
    def test_returns_success_on_zero_returncode(
        self, mock_run, mock_makedirs, mock_isfile, mock_isdir, mock_check
    ):
        """A zero return code is reported as success."""
        self._prime(mock_run, mock_isfile, mock_isdir, mock_check,
                    returncode=0, stdout="完成", stderr="")

        ok, _message = run_mfa_alignment(
            "/input", "/output",
            dict_path="/dict.dict",
            model_path="/model.zip"
        )

        self.assertTrue(ok)

    @patch('src.mfa_runner.check_mfa_available')
    @patch('os.path.isdir')
    @patch('os.path.isfile')
    @patch('os.makedirs')
    @patch('subprocess.run')
    def test_returns_failure_on_nonzero_returncode(
        self, mock_run, mock_makedirs, mock_isfile, mock_isdir, mock_check
    ):
        """A non-zero return code is reported as failure."""
        self._prime(mock_run, mock_isfile, mock_isdir, mock_check,
                    returncode=1, stdout="", stderr="错误")

        ok, _message = run_mfa_alignment(
            "/input", "/output",
            dict_path="/dict.dict",
            model_path="/model.zip"
        )

        self.assertFalse(ok)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
class TestRunMfaValidate(unittest.TestCase):
    """Tests for run_mfa_validate."""

    @patch('src.mfa_runner.check_mfa_available')
    def test_fails_when_mfa_unavailable(self, mock_check):
        """An unavailable MFA environment yields a failure."""
        mock_check.return_value = False

        outcome = run_mfa_validate("/corpus")
        ok, message = outcome

        self.assertFalse(ok)
        self.assertIn("不可用", message)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
class TestProgressCallback(unittest.TestCase):
    """Tests for the progress callback."""

    @patch('src.mfa_runner.check_mfa_available')
    @patch('os.path.isdir')
    @patch('os.path.isfile')
    @patch('os.makedirs')
    @patch('subprocess.run')
    def test_callback_called_on_success(
        self, mock_run, mock_makedirs, mock_isfile, mock_isdir, mock_check
    ):
        """The callback is invoked during a successful run."""
        # Let every precondition pass.
        for precondition in (mock_check, mock_isdir, mock_isfile):
            precondition.return_value = True
        mock_run.return_value = MagicMock(returncode=0, stdout="完成", stderr="")
        listener = MagicMock()

        run_mfa_alignment(
            "/input", "/output",
            dict_path="/dict.dict",
            model_path="/model.zip",
            progress_callback=listener
        )

        # The callback must have fired at least once.
        self.assertTrue(listener.called)
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
if __name__ == "__main__":
|
| 243 |
+
unittest.main()
|
tests/test_silero_vad_downloader.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
Silero VAD 下载模块测试
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import tempfile
|
| 8 |
+
import unittest
|
| 9 |
+
from unittest.mock import patch, MagicMock
|
| 10 |
+
|
| 11 |
+
from src.silero_vad_downloader import (
|
| 12 |
+
get_vad_model_path,
|
| 13 |
+
is_vad_model_downloaded,
|
| 14 |
+
download_silero_vad,
|
| 15 |
+
ensure_vad_model,
|
| 16 |
+
SILERO_VAD_CONFIG
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class TestSileroVadDownloader(unittest.TestCase):
    """Tests for the Silero VAD downloader."""

    @staticmethod
    def _plant_dummy_model(root):
        # Create <root>/silero_vad/silero_vad.onnx with placeholder content
        # and return its path.
        vad_dir = os.path.join(root, "silero_vad")
        os.makedirs(vad_dir)
        model_path = os.path.join(vad_dir, "silero_vad.onnx")
        with open(model_path, "w") as handle:
            handle.write("dummy")
        return model_path

    def test_get_vad_model_path(self):
        """The model path is built from the models root."""
        root = "/test/models"
        self.assertEqual(
            get_vad_model_path(root),
            os.path.join(root, "silero_vad", "silero_vad.onnx"),
        )

    def test_is_vad_model_downloaded_false(self):
        """A missing model is reported as not downloaded."""
        with tempfile.TemporaryDirectory() as scratch:
            self.assertFalse(is_vad_model_downloaded(scratch))

    def test_is_vad_model_downloaded_true(self):
        """A present model is reported as downloaded."""
        with tempfile.TemporaryDirectory() as scratch:
            self._plant_dummy_model(scratch)
            self.assertTrue(is_vad_model_downloaded(scratch))

    def test_download_silero_vad_already_exists(self):
        """An existing model short-circuits the download."""
        with tempfile.TemporaryDirectory() as scratch:
            planted = self._plant_dummy_model(scratch)

            ok, location = download_silero_vad(scratch)
            self.assertTrue(ok)
            self.assertEqual(location, planted)

    def test_config_values(self):
        """The configuration holds the expected values."""
        self.assertEqual(SILERO_VAD_CONFIG["onnx_filename"], "silero_vad.onnx")
        self.assertEqual(SILERO_VAD_CONFIG["jit_filename"], "silero_vad.jit")
        self.assertIn("snakers4/silero-vad", SILERO_VAD_CONFIG["repo"])
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
if __name__ == "__main__":
|
| 65 |
+
unittest.main()
|