TNOT commited on
Commit
b0dfe06
·
1 Parent(s): 43d08a3

MFA集成

Browse files
.gitignore CHANGED
@@ -7,11 +7,6 @@ __pycache__/
7
  *$py.class
8
  *.so
9
 
10
- # 分发/打包
11
- dist/
12
- build/
13
- *.egg-info/
14
-
15
  # pip-tools
16
  *.egg
17
 
@@ -29,5 +24,10 @@ build/
29
  temp/
30
  *.tmp
31
 
32
- # 数据目录(根据需要调整)
 
33
  bank/
 
 
 
 
 
7
  *$py.class
8
  *.so
9
 
 
 
 
 
 
10
  # pip-tools
11
  *.egg
12
 
 
24
  temp/
25
  *.tmp
26
 
27
+ # 数据(根据需要调整)
28
+ config.json
29
  bank/
30
+
31
+ # AI 模型相关
32
+ tools/mfa_engine
33
+ models
requirements.in CHANGED
@@ -10,3 +10,7 @@ customtkinter
10
  transformers>=4.25.0
11
  torch
12
  accelerate
 
 
 
 
 
10
  transformers>=4.25.0
11
  torch
12
  accelerate
13
+
14
+ # Silero VAD 语音活动检测
15
+ silero-vad>=5.1
16
+ onnxruntime
requirements.txt CHANGED
@@ -2,7 +2,7 @@
2
  # This file is autogenerated by pip-compile with Python 3.13
3
  # by the following command:
4
  #
5
- # pip-compile requirements.in
6
  #
7
  --index-url https://pypi.tuna.tsinghua.edu.cn/simple
8
 
@@ -28,6 +28,8 @@ colorama==0.4.6
28
  # via
29
  # click
30
  # tqdm
 
 
31
  customtkinter==5.2.2
32
  # via -r requirements.in
33
  darkdetect==0.8.0
@@ -37,6 +39,8 @@ filelock==3.20.3
37
  # huggingface-hub
38
  # torch
39
  # transformers
 
 
40
  fsspec==2026.1.0
41
  # via
42
  # huggingface-hub
@@ -54,6 +58,8 @@ huggingface-hub==1.3.5
54
  # accelerate
55
  # tokenizers
56
  # transformers
 
 
57
  idna==3.11
58
  # via
59
  # anyio
@@ -71,18 +77,29 @@ numpy==2.4.1
71
  # accelerate
72
  # audiofile
73
  # audmath
 
74
  # soundfile
75
  # transformers
 
 
 
 
76
  packaging==26.0
77
  # via
78
  # accelerate
79
  # customtkinter
80
  # huggingface-hub
 
 
81
  # transformers
 
 
82
  psutil==7.2.2
83
  # via accelerate
84
  pycparser==3.0
85
  # via cffi
 
 
86
  pyyaml==6.0.3
87
  # via
88
  # accelerate
@@ -96,10 +113,14 @@ safetensors==0.7.0
96
  # transformers
97
  shellingham==1.5.4
98
  # via huggingface-hub
 
 
99
  soundfile==0.13.1
100
  # via audiofile
101
  sympy==1.14.0
102
- # via torch
 
 
103
  textgrid==1.6.1
104
  # via -r requirements.in
105
  tokenizers==0.22.2
@@ -108,6 +129,10 @@ torch==2.10.0
108
  # via
109
  # -r requirements.in
110
  # accelerate
 
 
 
 
111
  tqdm==4.67.1
112
  # via
113
  # -r requirements.in
 
2
  # This file is autogenerated by pip-compile with Python 3.13
3
  # by the following command:
4
  #
5
+ # pip-compile --output-file=requirements.txt requirements.in
6
  #
7
  --index-url https://pypi.tuna.tsinghua.edu.cn/simple
8
 
 
28
  # via
29
  # click
30
  # tqdm
31
+ coloredlogs==15.0.1
32
+ # via onnxruntime
33
  customtkinter==5.2.2
34
  # via -r requirements.in
35
  darkdetect==0.8.0
 
39
  # huggingface-hub
40
  # torch
41
  # transformers
42
+ flatbuffers==25.12.19
43
+ # via onnxruntime
44
  fsspec==2026.1.0
45
  # via
46
  # huggingface-hub
 
58
  # accelerate
59
  # tokenizers
60
  # transformers
61
+ humanfriendly==10.0
62
+ # via coloredlogs
63
  idna==3.11
64
  # via
65
  # anyio
 
77
  # accelerate
78
  # audiofile
79
  # audmath
80
+ # onnxruntime
81
  # soundfile
82
  # transformers
83
+ onnxruntime==1.23.2
84
+ # via
85
+ # -r requirements.in
86
+ # silero-vad
87
  packaging==26.0
88
  # via
89
  # accelerate
90
  # customtkinter
91
  # huggingface-hub
92
+ # onnxruntime
93
+ # silero-vad
94
  # transformers
95
+ protobuf==6.33.5
96
+ # via onnxruntime
97
  psutil==7.2.2
98
  # via accelerate
99
  pycparser==3.0
100
  # via cffi
101
+ pyreadline3==3.5.4
102
+ # via humanfriendly
103
  pyyaml==6.0.3
104
  # via
105
  # accelerate
 
113
  # transformers
114
  shellingham==1.5.4
115
  # via huggingface-hub
116
+ silero-vad==6.2.0
117
+ # via -r requirements.in
118
  soundfile==0.13.1
119
  # via audiofile
120
  sympy==1.14.0
121
+ # via
122
+ # onnxruntime
123
+ # torch
124
  textgrid==1.6.1
125
  # via -r requirements.in
126
  tokenizers==0.22.2
 
129
  # via
130
  # -r requirements.in
131
  # accelerate
132
+ # silero-vad
133
+ # torchaudio
134
+ torchaudio==2.10.0
135
+ # via silero-vad
136
  tqdm==4.67.1
137
  # via
138
  # -r requirements.in
src/gui.py CHANGED
@@ -353,17 +353,43 @@ class ModelDownloadFrame(ctk.CTkFrame):
353
  ctk.CTkEntry(self, textvariable=self.mfa_dir_var, width=320).grid(row=7, column=1, padx=5, pady=5, sticky="w")
354
  ctk.CTkButton(self, text="浏览", width=60, command=self._browse_mfa_dir).grid(row=7, column=2, padx=5, pady=5)
355
 
356
- # MFA 状态
357
- ctk.CTkLabel(self, text="状态:").grid(row=8, column=0, padx=10, pady=5, sticky="w")
358
- self.mfa_status = ctk.CTkLabel(self, text="🚧 TODO: 自动下载功能开发中", text_color="orange")
359
- self.mfa_status.grid(row=8, column=1, columnspan=2, padx=5, pady=5, sticky="w")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
360
 
361
  # MFA 文件列表
362
- ctk.CTkLabel(self, text="已有文件:").grid(row=9, column=0, padx=10, pady=(10, 5), sticky="nw")
363
  self.mfa_files_text = ctk.CTkTextbox(self, height=70, width=400)
364
- self.mfa_files_text.grid(row=9, column=1, columnspan=2, padx=5, pady=(10, 5), sticky="w")
365
  self.mfa_files_text.insert("end", "选择目录后显示文件列表")
366
  self.mfa_files_text.configure(state="disabled")
 
 
 
367
 
368
  def _get_model_desc(self):
369
  """获取当前选中模型的描述"""
@@ -397,6 +423,51 @@ class ModelDownloadFrame(ctk.CTkFrame):
397
  self._save_config()
398
  self._scan_mfa_dir()
399
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
400
  def _scan_mfa_dir(self):
401
  """扫描 MFA 模型目录"""
402
  mfa_dir = self.mfa_dir_var.get()
@@ -498,82 +569,156 @@ class MakeDatasetFrame(ctk.CTkFrame):
498
  def __init__(self, master, log_callback):
499
  super().__init__(master)
500
  self.log_callback = log_callback
 
501
  self._setup_ui()
 
502
 
503
  def _setup_ui(self):
 
 
 
 
 
 
 
 
504
  # 数据集原始目录
505
- ctk.CTkLabel(self, text="① 切片及LAB目录:").grid(row=0, column=0, padx=10, pady=5, sticky="w")
506
  self.raw_dir_var = ctk.StringVar()
507
- ctk.CTkEntry(self, textvariable=self.raw_dir_var, width=400).grid(row=0, column=1, padx=5, pady=5)
508
- ctk.CTkButton(self, text="浏览", width=60, command=self._browse_raw_dir).grid(row=0, column=2, padx=5, pady=5)
 
 
 
 
 
 
509
 
510
  # 字典路径
511
- ctk.CTkLabel(self, text=" 字典文件:").grid(row=1, column=0, padx=10, pady=5, sticky="w")
512
- self.dict_path_var = ctk.StringVar()
513
- ctk.CTkEntry(self, textvariable=self.dict_path_var, width=400).grid(row=1, column=1, padx=5, pady=5)
514
- ctk.CTkButton(self, text="浏览", width=60, command=self._browse_dict).grid(row=1, column=2, padx=5, pady=5)
515
 
516
  # MFA模型路径
517
- ctk.CTkLabel(self, text=" MFA模型文件:").grid(row=2, column=0, padx=10, pady=5, sticky="w")
518
- self.mfa_model_var = ctk.StringVar()
519
- ctk.CTkEntry(self, textvariable=self.mfa_model_var, width=400).grid(row=2, column=1, padx=5, pady=5)
520
- ctk.CTkButton(self, text="浏览", width=60, command=self._browse_mfa).grid(row=2, column=2, padx=5, pady=5)
521
-
522
- # 临时目录
523
- ctk.CTkLabel(self, text="④ 临时目录:").grid(row=3, column=0, padx=10, pady=5, sticky="w")
524
- self.temp_dir_var = ctk.StringVar(value="temp")
525
- ctk.CTkEntry(self, textvariable=self.temp_dir_var, width=400).grid(row=3, column=1, padx=5, pady=5)
526
- ctk.CTkButton(self, text="浏览", width=60, command=self._browse_temp).grid(row=3, column=2, padx=5, pady=5)
527
-
528
- # 数据集名称
529
- ctk.CTkLabel(self, text="⑤ 数据集名称:").grid(row=4, column=0, padx=10, pady=5, sticky="w")
530
- self.dataset_name_var = ctk.StringVar()
531
- ctk.CTkEntry(self, textvariable=self.dataset_name_var, width=400).grid(row=4, column=1, padx=5, pady=5)
 
 
 
 
 
 
 
532
 
533
  # 执行按钮
534
- ctk.CTkButton(self, text=" 开始制作", command=self._run).grid(row=5, column=1, pady=20)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
535
 
536
  def _browse_raw_dir(self):
537
  path = filedialog.askdirectory(title="选择切片及LAB目录")
538
  if path:
539
  self.raw_dir_var.set(path)
540
 
 
 
 
 
 
541
  def _browse_dict(self):
542
- path = filedialog.askopenfilename(title="选择字典文件", filetypes=[("文本文件", "*.txt")])
 
 
 
543
  if path:
544
  self.dict_path_var.set(path)
545
 
546
  def _browse_mfa(self):
547
- path = filedialog.askopenfilename(title="选择MFA模型", filetypes=[("ZIP文件", "*.zip")])
 
 
 
548
  if path:
549
  self.mfa_model_var.set(path)
550
 
551
- def _browse_temp(self):
552
- path = filedialog.askdirectory(title="选择临时目录")
553
- if path:
554
- self.temp_dir_var.set(path)
555
-
556
  def _run(self):
 
 
 
557
  raw_dir = self.raw_dir_var.get()
 
558
  dict_path = self.dict_path_var.get()
559
  mfa_model = self.mfa_model_var.get()
560
- temp_dir = self.temp_dir_var.get()
561
- dataset_name = self.dataset_name_var.get()
562
 
563
- if not all([raw_dir, dict_path, mfa_model, temp_dir, dataset_name]):
564
- messagebox.showerror("错误", "请填写所有必要字段")
565
  return
566
 
567
- self.log_callback("批量制作数据集功能需要MFA环境支持")
568
- self.log_callback("请确保已安装Montreal Forced Aligner")
569
- self.log_callback(f"配置信息:")
570
- self.log_callback(f" - 原始目录: {raw_dir}")
571
- self.log_callback(f" - 字典: {dict_path}")
572
- self.log_callback(f" - MFA模型: {mfa_model}")
573
- self.log_callback(f" - 临时目录: {temp_dir}")
574
- self.log_callback(f" - 数据集名称: {dataset_name}")
575
- self.log_callback("此功能涉及多个外部脚本调用,建议在命令行中执行")
576
- logger.info("批量制作数据集配置已记录")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
577
 
578
 
579
  class App(ctk.CTk):
 
353
  ctk.CTkEntry(self, textvariable=self.mfa_dir_var, width=320).grid(row=7, column=1, padx=5, pady=5, sticky="w")
354
  ctk.CTkButton(self, text="浏览", width=60, command=self._browse_mfa_dir).grid(row=7, column=2, padx=5, pady=5)
355
 
356
+ # MFA 语言选择
357
+ ctk.CTkLabel(self, text="选择语言:").grid(row=8, column=0, padx=10, pady=5, sticky="w")
358
+ self.mfa_lang_var = ctk.StringVar(value="mandarin")
359
+ self.mfa_lang_dropdown = ctk.CTkComboBox(
360
+ self,
361
+ values=["mandarin", "japanese"],
362
+ variable=self.mfa_lang_var,
363
+ width=200,
364
+ command=self._on_mfa_lang_change
365
+ )
366
+ self.mfa_lang_dropdown.grid(row=8, column=1, padx=5, pady=5, sticky="w")
367
+
368
+ self.mfa_lang_desc = ctk.CTkLabel(self, text="中文 (普通话)", text_color="gray")
369
+ self.mfa_lang_desc.grid(row=8, column=2, padx=5, pady=5, sticky="w")
370
+
371
+ # MFA 下载按钮和状态
372
+ ctk.CTkLabel(self, text="状态:").grid(row=9, column=0, padx=10, pady=5, sticky="w")
373
+ self.mfa_status = ctk.CTkLabel(self, text="⏳ 未下载", text_color="gray")
374
+ self.mfa_status.grid(row=9, column=1, padx=5, pady=5, sticky="w")
375
+
376
+ self.mfa_download_btn = ctk.CTkButton(
377
+ self,
378
+ text="下载模型",
379
+ command=self._download_mfa_models,
380
+ width=140
381
+ )
382
+ self.mfa_download_btn.grid(row=9, column=2, padx=5, pady=5, sticky="w")
383
 
384
  # MFA 文件列表
385
+ ctk.CTkLabel(self, text="已有文件:").grid(row=10, column=0, padx=10, pady=(10, 5), sticky="nw")
386
  self.mfa_files_text = ctk.CTkTextbox(self, height=70, width=400)
387
+ self.mfa_files_text.grid(row=10, column=1, columnspan=2, padx=5, pady=(10, 5), sticky="w")
388
  self.mfa_files_text.insert("end", "选择目录后显示文件列表")
389
  self.mfa_files_text.configure(state="disabled")
390
+
391
+ # 初始扫描
392
+ self._scan_mfa_dir()
393
 
394
  def _get_model_desc(self):
395
  """获取当前选中模型的描述"""
 
423
  self._save_config()
424
  self._scan_mfa_dir()
425
 
426
+ def _on_mfa_lang_change(self, choice):
427
+ """MFA 语言选择变更"""
428
+ from src.mfa_model_downloader import get_available_languages
429
+ langs = get_available_languages()
430
+ self.mfa_lang_desc.configure(text=langs.get(choice, ""))
431
+
432
+ def _download_mfa_models(self):
433
+ """下载 MFA 模型"""
434
+ if self._download_thread and self._download_thread.is_alive():
435
+ return
436
+
437
+ self.mfa_download_btn.configure(state="disabled")
438
+ self.mfa_status.configure(text="⏳ 下载中...", text_color="gray")
439
+ self._download_thread = threading.Thread(target=self._do_download_mfa, daemon=True)
440
+ self._download_thread.start()
441
+
442
+ def _do_download_mfa(self):
443
+ """执行 MFA 模型下载(后台线程)"""
444
+ from src.mfa_model_downloader import download_language_models
445
+
446
+ language = self.mfa_lang_var.get()
447
+ output_dir = self.mfa_dir_var.get()
448
+
449
+ # 确保目录存在
450
+ if not os.path.exists(output_dir):
451
+ os.makedirs(output_dir)
452
+
453
+ self.log_callback(f"开始下载 MFA 模型: {language}")
454
+
455
+ success, acoustic_path, dict_path = download_language_models(
456
+ language=language,
457
+ output_dir=output_dir,
458
+ progress_callback=self.log_callback
459
+ )
460
+
461
+ if success:
462
+ self.after(0, lambda: self.mfa_status.configure(text="✅ 已下载", text_color="green"))
463
+ self.log_callback(f"声学模型: {acoustic_path}")
464
+ self.log_callback(f"字典文件: {dict_path}")
465
+ else:
466
+ self.after(0, lambda: self.mfa_status.configure(text="❌ 下载失败", text_color="red"))
467
+
468
+ self.after(0, lambda: self.mfa_download_btn.configure(state="normal"))
469
+ self.after(0, self._scan_mfa_dir)
470
+
471
  def _scan_mfa_dir(self):
472
  """扫描 MFA 模型目录"""
473
  mfa_dir = self.mfa_dir_var.get()
 
569
  def __init__(self, master, log_callback):
570
  super().__init__(master)
571
  self.log_callback = log_callback
572
+ self._is_running = False
573
  self._setup_ui()
574
+ self._check_mfa_status()
575
 
576
  def _setup_ui(self):
577
+ # MFA 状态提示
578
+ self.mfa_status_label = ctk.CTkLabel(
579
+ self,
580
+ text="⏳ 检查 MFA 环境...",
581
+ font=ctk.CTkFont(size=12)
582
+ )
583
+ self.mfa_status_label.grid(row=0, column=0, columnspan=3, padx=10, pady=(10, 5), sticky="w")
584
+
585
  # 数据集原始目录
586
+ ctk.CTkLabel(self, text="① 切片及LAB目录:").grid(row=1, column=0, padx=10, pady=5, sticky="w")
587
  self.raw_dir_var = ctk.StringVar()
588
+ ctk.CTkEntry(self, textvariable=self.raw_dir_var, width=400).grid(row=1, column=1, padx=5, pady=5)
589
+ ctk.CTkButton(self, text="浏览", width=60, command=self._browse_raw_dir).grid(row=1, column=2, padx=5, pady=5)
590
+
591
+ # 输出目录
592
+ ctk.CTkLabel(self, text="② TextGrid输出目录:").grid(row=2, column=0, padx=10, pady=5, sticky="w")
593
+ self.output_dir_var = ctk.StringVar()
594
+ ctk.CTkEntry(self, textvariable=self.output_dir_var, width=400).grid(row=2, column=1, padx=5, pady=5)
595
+ ctk.CTkButton(self, text="浏览", width=60, command=self._browse_output_dir).grid(row=2, column=2, padx=5, pady=5)
596
 
597
  # 字典路径
598
+ ctk.CTkLabel(self, text=" 字典文件:").grid(row=3, column=0, padx=10, pady=5, sticky="w")
599
+ self.dict_path_var = ctk.StringVar(value="models/mfa/mandarin_china_mfa.dict")
600
+ ctk.CTkEntry(self, textvariable=self.dict_path_var, width=400).grid(row=3, column=1, padx=5, pady=5)
601
+ ctk.CTkButton(self, text="浏览", width=60, command=self._browse_dict).grid(row=3, column=2, padx=5, pady=5)
602
 
603
  # MFA模型路径
604
+ ctk.CTkLabel(self, text=" MFA模型文件:").grid(row=4, column=0, padx=10, pady=5, sticky="w")
605
+ self.mfa_model_var = ctk.StringVar(value="models/mfa/mandarin_mfa.zip")
606
+ ctk.CTkEntry(self, textvariable=self.mfa_model_var, width=400).grid(row=4, column=1, padx=5, pady=5)
607
+ ctk.CTkButton(self, text="浏览", width=60, command=self._browse_mfa).grid(row=4, column=2, padx=5, pady=5)
608
+
609
+ # 选项
610
+ options_frame = ctk.CTkFrame(self)
611
+ options_frame.grid(row=5, column=0, columnspan=3, padx=10, pady=10, sticky="w")
612
+
613
+ self.single_speaker_var = ctk.BooleanVar(value=True)
614
+ ctk.CTkCheckBox(
615
+ options_frame,
616
+ text="单说话人模式",
617
+ variable=self.single_speaker_var
618
+ ).pack(side="left", padx=10)
619
+
620
+ self.clean_var = ctk.BooleanVar(value=True)
621
+ ctk.CTkCheckBox(
622
+ options_frame,
623
+ text="清理旧缓存",
624
+ variable=self.clean_var
625
+ ).pack(side="left", padx=10)
626
 
627
  # 执行按钮
628
+ self.run_btn = ctk.CTkButton(self, text=" 开始对齐", command=self._run)
629
+ self.run_btn.grid(row=6, column=1, pady=20)
630
+
631
+ def _check_mfa_status(self):
632
+ """检查 MFA 环境状态"""
633
+ from src.mfa_runner import check_mfa_available
634
+
635
+ if check_mfa_available():
636
+ self.mfa_status_label.configure(
637
+ text="✅ MFA 外挂环境已就绪 (tools/mfa_engine)",
638
+ text_color="green"
639
+ )
640
+ else:
641
+ self.mfa_status_label.configure(
642
+ text="❌ MFA 外挂环境不可用,请检查 tools/mfa_engine 目录",
643
+ text_color="red"
644
+ )
645
 
646
  def _browse_raw_dir(self):
647
  path = filedialog.askdirectory(title="选择切片及LAB目录")
648
  if path:
649
  self.raw_dir_var.set(path)
650
 
651
+ def _browse_output_dir(self):
652
+ path = filedialog.askdirectory(title="选择TextGrid输出目录")
653
+ if path:
654
+ self.output_dir_var.set(path)
655
+
656
  def _browse_dict(self):
657
+ path = filedialog.askopenfilename(
658
+ title="选择字典文件",
659
+ filetypes=[("字典文件", "*.dict *.txt"), ("所有文件", "*.*")]
660
+ )
661
  if path:
662
  self.dict_path_var.set(path)
663
 
664
  def _browse_mfa(self):
665
+ path = filedialog.askopenfilename(
666
+ title="选择MFA模型",
667
+ filetypes=[("ZIP文件", "*.zip"), ("所有文件", "*.*")]
668
+ )
669
  if path:
670
  self.mfa_model_var.set(path)
671
 
 
 
 
 
 
672
  def _run(self):
673
+ if self._is_running:
674
+ return
675
+
676
  raw_dir = self.raw_dir_var.get()
677
+ output_dir = self.output_dir_var.get()
678
  dict_path = self.dict_path_var.get()
679
  mfa_model = self.mfa_model_var.get()
 
 
680
 
681
+ if not raw_dir or not output_dir:
682
+ messagebox.showerror("错误", "请填写输入目录和输出目录")
683
  return
684
 
685
+ self._is_running = True
686
+ self.run_btn.configure(state="disabled", text="对齐中...")
687
+
688
+ threading.Thread(
689
+ target=self._process,
690
+ args=(raw_dir, output_dir, dict_path, mfa_model),
691
+ daemon=True
692
+ ).start()
693
+
694
+ def _process(self, raw_dir, output_dir, dict_path, mfa_model):
695
+ """执行 MFA 对齐(后台线程)"""
696
+ from src.mfa_runner import run_mfa_alignment
697
+
698
+ self.log_callback("=" * 50)
699
+ self.log_callback("开始 MFA 对齐任务")
700
+
701
+ success, message = run_mfa_alignment(
702
+ corpus_dir=raw_dir,
703
+ output_dir=output_dir,
704
+ dict_path=dict_path if dict_path else None,
705
+ model_path=mfa_model if mfa_model else None,
706
+ single_speaker=self.single_speaker_var.get(),
707
+ clean=self.clean_var.get(),
708
+ progress_callback=self.log_callback
709
+ )
710
+
711
+ if success:
712
+ self.log_callback("✅ MFA 对齐任务完成!")
713
+ self.log_callback(f"TextGrid 文件已输出到: {output_dir}")
714
+ else:
715
+ self.log_callback(f"❌ MFA 对齐失败: {message}")
716
+
717
+ self.log_callback("=" * 50)
718
+
719
+ # 恢复按钮状态
720
+ self.after(0, lambda: self.run_btn.configure(state="normal", text="⑤ 开始对齐"))
721
+ self._is_running = False
722
 
723
 
724
  class App(ctk.CTk):
src/mfa_model_downloader.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ MFA 模型下载模块
4
+ 支持下载中文和日文的声学模型及字典
5
+ """
6
+
7
+ import os
8
+ import logging
9
+ import urllib.request
10
+ import urllib.error
11
+ from pathlib import Path
12
+ from typing import Optional, Callable
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+ # 模型下载基础 URL
17
+ GITHUB_RELEASE_BASE = "https://github.com/MontrealCorpusTools/mfa-models/releases/download"
18
+ GITHUB_RAW_BASE = "https://raw.githubusercontent.com/MontrealCorpusTools/mfa-models/main"
19
+
20
+ # 支持的语言配置
21
+ # 格式: {语言代码: {名称, 声学模型信息, 字典信息}}
22
+ LANGUAGE_MODELS = {
23
+ "mandarin": {
24
+ "name": "中文 (普通话)",
25
+ "acoustic": {
26
+ "tag": "acoustic-mandarin_mfa-v3.0.0",
27
+ "filename": "mandarin_mfa.zip",
28
+ "description": "Mandarin MFA acoustic model v3.0.0"
29
+ },
30
+ "dictionary": {
31
+ # 字典从 releases 下载,tag 格式: dictionary-{name}-v{version}
32
+ "tag": "dictionary-mandarin_china_mfa-v3.0.0",
33
+ "filename": "mandarin_china_mfa.dict",
34
+ "description": "Mandarin (China) MFA dictionary v3.0.0"
35
+ }
36
+ },
37
+ "japanese": {
38
+ "name": "日文",
39
+ "acoustic": {
40
+ "tag": "acoustic-japanese_mfa-v3.0.0",
41
+ "filename": "japanese_mfa.zip",
42
+ "description": "Japanese MFA acoustic model v3.0.0"
43
+ },
44
+ "dictionary": {
45
+ "tag": "dictionary-japanese_mfa-v3.0.0",
46
+ "filename": "japanese_mfa.dict",
47
+ "description": "Japanese MFA dictionary v3.0.0"
48
+ }
49
+ }
50
+ }
51
+
52
+
53
+ def get_available_languages() -> dict:
54
+ """获取可用的语言列表"""
55
+ return {k: v["name"] for k, v in LANGUAGE_MODELS.items()}
56
+
57
+
58
+ def _download_file(
59
+ url: str,
60
+ dest_path: str,
61
+ progress_callback: Optional[Callable[[str], None]] = None
62
+ ) -> bool:
63
+ """
64
+ 下载文件
65
+
66
+ 参数:
67
+ url: 下载地址
68
+ dest_path: 保存路径
69
+ progress_callback: 进度回调
70
+
71
+ 返回:
72
+ 是否成功
73
+ """
74
+ def log(msg: str):
75
+ logger.info(msg)
76
+ if progress_callback:
77
+ progress_callback(msg)
78
+
79
+ try:
80
+ log(f"正在下载: {url}")
81
+
82
+ # 创建目录
83
+ os.makedirs(os.path.dirname(dest_path), exist_ok=True)
84
+
85
+ # 下载文件
86
+ req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"})
87
+
88
+ with urllib.request.urlopen(req, timeout=60) as response:
89
+ total_size = response.headers.get("Content-Length")
90
+ if total_size:
91
+ total_size = int(total_size)
92
+ log(f"文件大小: {total_size / 1024 / 1024:.1f} MB")
93
+
94
+ # 分块下载
95
+ block_size = 8192
96
+ downloaded = 0
97
+
98
+ with open(dest_path, "wb") as f:
99
+ while True:
100
+ chunk = response.read(block_size)
101
+ if not chunk:
102
+ break
103
+ f.write(chunk)
104
+ downloaded += len(chunk)
105
+
106
+ if total_size and downloaded % (block_size * 100) == 0:
107
+ percent = downloaded / total_size * 100
108
+ log(f"下载进度: {percent:.1f}%")
109
+
110
+ log(f"下载完成: {dest_path}")
111
+ return True
112
+
113
+ except urllib.error.HTTPError as e:
114
+ log(f"HTTP 错误: {e.code} - {e.reason}")
115
+ return False
116
+ except urllib.error.URLError as e:
117
+ log(f"网络错误: {e.reason}")
118
+ return False
119
+ except Exception as e:
120
+ log(f"下载失败: {e}")
121
+ return False
122
+
123
+
124
+ def download_acoustic_model(
125
+ language: str,
126
+ output_dir: str,
127
+ progress_callback: Optional[Callable[[str], None]] = None
128
+ ) -> tuple[bool, str]:
129
+ """
130
+ 下载声学模型
131
+
132
+ 参数:
133
+ language: 语言代码 (mandarin/japanese)
134
+ output_dir: 输出目录
135
+ progress_callback: 进度回调
136
+
137
+ 返回:
138
+ (成功标志, 文件路径或错误信息)
139
+ """
140
+ if language not in LANGUAGE_MODELS:
141
+ return False, f"不支持的语言: {language}"
142
+
143
+ config = LANGUAGE_MODELS[language]["acoustic"]
144
+ url = f"{GITHUB_RELEASE_BASE}/{config['tag']}/{config['filename']}"
145
+ dest_path = os.path.join(output_dir, config["filename"])
146
+
147
+ if os.path.exists(dest_path):
148
+ if progress_callback:
149
+ progress_callback(f"声学模型已存在: {dest_path}")
150
+ return True, dest_path
151
+
152
+ if _download_file(url, dest_path, progress_callback):
153
+ return True, dest_path
154
+ else:
155
+ return False, "声学模型下载失败"
156
+
157
+
158
+ def download_dictionary(
159
+ language: str,
160
+ output_dir: str,
161
+ progress_callback: Optional[Callable[[str], None]] = None
162
+ ) -> tuple[bool, str]:
163
+ """
164
+ 下载字典文件
165
+
166
+ 参数:
167
+ language: 语言代码 (mandarin/japanese)
168
+ output_dir: 输出目录
169
+ progress_callback: 进度回调
170
+
171
+ 返回:
172
+ (成功标志, 文件路径或错误信息)
173
+ """
174
+ if language not in LANGUAGE_MODELS:
175
+ return False, f"不支持的语言: {language}"
176
+
177
+ config = LANGUAGE_MODELS[language]["dictionary"]
178
+ # 字典文件从 releases 下载
179
+ url = f"{GITHUB_RELEASE_BASE}/{config['tag']}/{config['filename']}"
180
+ dest_path = os.path.join(output_dir, config["filename"])
181
+
182
+ if os.path.exists(dest_path):
183
+ if progress_callback:
184
+ progress_callback(f"字典文件已存在: {dest_path}")
185
+ return True, dest_path
186
+
187
+ if _download_file(url, dest_path, progress_callback):
188
+ return True, dest_path
189
+ else:
190
+ return False, "字典文件下载失败"
191
+
192
+
193
+ def download_language_models(
194
+ language: str,
195
+ output_dir: str,
196
+ progress_callback: Optional[Callable[[str], None]] = None
197
+ ) -> tuple[bool, str, str]:
198
+ """
199
+ 下载指定语言的声学模型和字典
200
+
201
+ 参数:
202
+ language: 语言代码 (mandarin/japanese)
203
+ output_dir: 输出目录
204
+ progress_callback: 进度回调
205
+
206
+ 返回:
207
+ (成功标志, 声学模型路径, 字典路径)
208
+ """
209
+ def log(msg: str):
210
+ logger.info(msg)
211
+ if progress_callback:
212
+ progress_callback(msg)
213
+
214
+ if language not in LANGUAGE_MODELS:
215
+ return False, "", f"不支持的语言: {language}"
216
+
217
+ lang_name = LANGUAGE_MODELS[language]["name"]
218
+ log(f"开始下载 {lang_name} 模型...")
219
+
220
+ # 下载声学模型
221
+ log("=" * 40)
222
+ log("下载声学模型...")
223
+ success, acoustic_path = download_acoustic_model(language, output_dir, progress_callback)
224
+ if not success:
225
+ return False, "", acoustic_path
226
+
227
+ # 下载字典
228
+ log("=" * 40)
229
+ log("下载字典文件...")
230
+ success, dict_path = download_dictionary(language, output_dir, progress_callback)
231
+ if not success:
232
+ return False, acoustic_path, dict_path
233
+
234
+ log("=" * 40)
235
+ log(f"{lang_name} 模型下载完成!")
236
+ return True, acoustic_path, dict_path
src/mfa_runner.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ MFA 外挂调用模块
4
+ 采用 Sidecar Pattern,通过 subprocess 调用独立的 MFA 环境
5
+ """
6
+
7
+ import os
8
+ import subprocess
9
+ import logging
10
+ from pathlib import Path
11
+ from typing import Optional, Callable
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ # 定位路径
16
+ BASE_DIR = Path(__file__).parent.parent.absolute()
17
+ MFA_ENGINE_DIR = BASE_DIR / "tools" / "mfa_engine"
18
+ MFA_PYTHON = MFA_ENGINE_DIR / "python.exe"
19
+
20
+ # 默认模型路径
21
+ DEFAULT_DICT_PATH = BASE_DIR / "models" / "mandarin.dict"
22
+ DEFAULT_MODEL_PATH = BASE_DIR / "models" / "mandarin.zip"
23
+ DEFAULT_TEMP_DIR = BASE_DIR / "mfa_temp"
24
+
25
+
26
+ def check_mfa_available() -> bool:
27
+ """检查 MFA 外挂环境是否可用"""
28
+ if not MFA_ENGINE_DIR.exists():
29
+ logger.warning(f"MFA 引擎目录不存在: {MFA_ENGINE_DIR}")
30
+ return False
31
+ if not MFA_PYTHON.exists():
32
+ logger.warning(f"MFA Python 不存在: {MFA_PYTHON}")
33
+ return False
34
+ return True
35
+
36
+
37
+ def _build_mfa_env() -> dict:
38
+ """构造 MFA 专用环境变量"""
39
+ env = os.environ.copy()
40
+
41
+ # 必须把 Library\bin 加入 PATH,否则 Kaldi DLL 找不到
42
+ mfa_paths = [
43
+ str(MFA_ENGINE_DIR),
44
+ str(MFA_ENGINE_DIR / "Library" / "bin"),
45
+ str(MFA_ENGINE_DIR / "Scripts"),
46
+ str(MFA_ENGINE_DIR / "bin"),
47
+ ]
48
+ env["PATH"] = ";".join(mfa_paths) + ";" + env.get("PATH", "")
49
+
50
+ return env
51
+
52
+
53
+ def run_mfa_alignment(
54
+ corpus_dir: str,
55
+ output_dir: str,
56
+ dict_path: Optional[str] = None,
57
+ model_path: Optional[str] = None,
58
+ temp_dir: Optional[str] = None,
59
+ single_speaker: bool = True,
60
+ clean: bool = True,
61
+ progress_callback: Optional[Callable[[str], None]] = None
62
+ ) -> tuple[bool, str]:
63
+ """
64
+ 执行 MFA 对齐
65
+
66
+ 参数:
67
+ corpus_dir: 包含 wav 和 lab/txt 的输入目录
68
+ output_dir: TextGrid 输出目录
69
+ dict_path: 字典文件路径,默认使用 models/mandarin.dict
70
+ model_path: 声学模型路径,默认使用 models/mandarin.zip
71
+ temp_dir: 临时目录,默认使用 mfa_temp
72
+ single_speaker: 是否为单说话人模式
73
+ clean: 是否清理旧缓存
74
+ progress_callback: 进度回调函数
75
+
76
+ 返回:
77
+ (成功标志, 输出信息或错误信息)
78
+ """
79
+ def log(msg: str):
80
+ logger.info(msg)
81
+ if progress_callback:
82
+ progress_callback(msg)
83
+
84
+ # 检查环境
85
+ if not check_mfa_available():
86
+ return False, "MFA 外挂环境不可用,请检查 tools/mfa_engine 目录"
87
+
88
+ # 设置默认路径
89
+ dict_path = dict_path or str(DEFAULT_DICT_PATH)
90
+ model_path = model_path or str(DEFAULT_MODEL_PATH)
91
+ temp_dir = temp_dir or str(DEFAULT_TEMP_DIR)
92
+
93
+ # 验证路径
94
+ if not os.path.isdir(corpus_dir):
95
+ return False, f"输入目录不存在: {corpus_dir}"
96
+ if not os.path.isfile(dict_path):
97
+ return False, f"字典文件不存在: {dict_path}"
98
+ if not os.path.isfile(model_path):
99
+ return False, f"声学模型不存在: {model_path}"
100
+
101
+ # 创建输出和临时目录
102
+ os.makedirs(output_dir, exist_ok=True)
103
+ os.makedirs(temp_dir, exist_ok=True)
104
+
105
+ # 构造命令
106
+ cmd = [
107
+ str(MFA_PYTHON),
108
+ "-m", "montreal_forced_aligner",
109
+ "align",
110
+ str(corpus_dir),
111
+ str(dict_path),
112
+ str(model_path),
113
+ str(output_dir),
114
+ "--temp_directory", str(temp_dir),
115
+ ]
116
+
117
+ if clean:
118
+ cmd.append("--clean")
119
+ if single_speaker:
120
+ cmd.append("--single_speaker")
121
+
122
+ log(f"正在启动 MFA 对齐引擎...")
123
+ log(f"输入目录: {corpus_dir}")
124
+ log(f"输出目录: {output_dir}")
125
+
126
+ try:
127
+ env = _build_mfa_env()
128
+
129
+ result = subprocess.run(
130
+ cmd,
131
+ env=env,
132
+ capture_output=True,
133
+ text=True,
134
+ encoding='utf-8',
135
+ errors='replace'
136
+ )
137
+
138
+ if result.returncode == 0:
139
+ log("MFA 对齐完成!")
140
+ return True, result.stdout
141
+ else:
142
+ error_msg = result.stderr or result.stdout or "未知错误"
143
+ log(f"MFA 运行出错: {error_msg}")
144
+ return False, error_msg
145
+
146
+ except FileNotFoundError as e:
147
+ msg = f"找不到 MFA Python: {e}"
148
+ log(msg)
149
+ return False, msg
150
+ except Exception as e:
151
+ msg = f"MFA 执行异常: {e}"
152
+ log(msg)
153
+ return False, msg
154
+
155
+
156
+ def run_mfa_validate(
157
+ corpus_dir: str,
158
+ dict_path: Optional[str] = None,
159
+ progress_callback: Optional[Callable[[str], None]] = None
160
+ ) -> tuple[bool, str]:
161
+ """
162
+ 验证语料库格式是否正确
163
+
164
+ 参数:
165
+ corpus_dir: 语料库目录
166
+ dict_path: 字典文件路径
167
+ progress_callback: 进度回调函数
168
+
169
+ 返回:
170
+ (成功标志, 输出信息)
171
+ """
172
+ def log(msg: str):
173
+ logger.info(msg)
174
+ if progress_callback:
175
+ progress_callback(msg)
176
+
177
+ if not check_mfa_available():
178
+ return False, "MFA 外挂环境不可用"
179
+
180
+ dict_path = dict_path or str(DEFAULT_DICT_PATH)
181
+
182
+ cmd = [
183
+ str(MFA_PYTHON),
184
+ "-m", "montreal_forced_aligner",
185
+ "validate",
186
+ str(corpus_dir),
187
+ str(dict_path),
188
+ ]
189
+
190
+ log("正在验证语料库...")
191
+
192
+ try:
193
+ env = _build_mfa_env()
194
+ result = subprocess.run(
195
+ cmd,
196
+ env=env,
197
+ capture_output=True,
198
+ text=True,
199
+ encoding='utf-8',
200
+ errors='replace'
201
+ )
202
+
203
+ output = result.stdout + "\n" + result.stderr
204
+ log("验证完成")
205
+ return result.returncode == 0, output
206
+
207
+ except Exception as e:
208
+ return False, str(e)
src/silero_vad_downloader.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Silero VAD 模型下载模块
4
+ 支持自动下载 Silero VAD 模型到指定目录
5
+ """
6
+
7
+ import os
8
+ import logging
9
+ import urllib.request
10
+ import urllib.error
11
+ from pathlib import Path
12
+ from typing import Optional, Callable
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+ # Silero VAD 模型配置
17
+ SILERO_VAD_CONFIG = {
18
+ "repo": "snakers4/silero-vad",
19
+ "model_name": "silero_vad",
20
+ "version": "v5.1",
21
+ "onnx_filename": "silero_vad.onnx",
22
+ "jit_filename": "silero_vad.jit",
23
+ "download_base": "https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data"
24
+ }
25
+
26
+
27
+ def _download_file(
28
+ url: str,
29
+ dest_path: str,
30
+ progress_callback: Optional[Callable[[str], None]] = None
31
+ ) -> bool:
32
+ """
33
+ 下载文件
34
+
35
+ 参数:
36
+ url: 下载地址
37
+ dest_path: 保存路径
38
+ progress_callback: 进度回调
39
+
40
+ 返回:
41
+ 是否成功
42
+ """
43
+ def log(msg: str):
44
+ logger.info(msg)
45
+ if progress_callback:
46
+ progress_callback(msg)
47
+
48
+ try:
49
+ log(f"正在下载: {url}")
50
+
51
+ # 创建目录
52
+ os.makedirs(os.path.dirname(dest_path), exist_ok=True)
53
+
54
+ # 下载文件
55
+ req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"})
56
+
57
+ with urllib.request.urlopen(req, timeout=60) as response:
58
+ total_size = response.headers.get("Content-Length")
59
+ if total_size:
60
+ total_size = int(total_size)
61
+ log(f"文件大小: {total_size / 1024 / 1024:.2f} MB")
62
+
63
+ # 分块下载
64
+ block_size = 8192
65
+ downloaded = 0
66
+
67
+ with open(dest_path, "wb") as f:
68
+ while True:
69
+ chunk = response.read(block_size)
70
+ if not chunk:
71
+ break
72
+ f.write(chunk)
73
+ downloaded += len(chunk)
74
+
75
+ if total_size and downloaded % (block_size * 100) == 0:
76
+ percent = downloaded / total_size * 100
77
+ log(f"下载进度: {percent:.1f}%")
78
+
79
+ log(f"下载完成: {dest_path}")
80
+ return True
81
+
82
+ except urllib.error.HTTPError as e:
83
+ log(f"HTTP 错误: {e.code} - {e.reason}")
84
+ return False
85
+ except urllib.error.URLError as e:
86
+ log(f"网络错误: {e.reason}")
87
+ return False
88
+ except Exception as e:
89
+ log(f"下载失败: {e}")
90
+ return False
91
+
92
+
93
+ def get_vad_model_path(models_dir: str) -> str:
94
+ """
95
+ 获取 VAD 模型文件路径
96
+
97
+ 参数:
98
+ models_dir: 模型根目录
99
+
100
+ 返回:
101
+ ONNX 模型文件路径
102
+ """
103
+ return os.path.join(models_dir, "silero_vad", SILERO_VAD_CONFIG["onnx_filename"])
104
+
105
+
106
+ def is_vad_model_downloaded(models_dir: str) -> bool:
107
+ """
108
+ 检查 VAD 模型是否已下载
109
+
110
+ 参数:
111
+ models_dir: 模型根目录
112
+
113
+ 返回:
114
+ 是否已下载
115
+ """
116
+ model_path = get_vad_model_path(models_dir)
117
+ return os.path.exists(model_path)
118
+
119
+
120
+ def download_silero_vad(
121
+ output_dir: str,
122
+ progress_callback: Optional[Callable[[str], None]] = None,
123
+ use_onnx: bool = True
124
+ ) -> tuple[bool, str]:
125
+ """
126
+ 下载 Silero VAD 模型
127
+
128
+ 参数:
129
+ output_dir: 输出目录 (模型根目录)
130
+ progress_callback: 进度回调
131
+ use_onnx: 是否下载 ONNX 格式 (默认 True,否则下载 JIT 格式)
132
+
133
+ 返回:
134
+ (成功标志, 文件路径或错误信息)
135
+ """
136
+ def log(msg: str):
137
+ logger.info(msg)
138
+ if progress_callback:
139
+ progress_callback(msg)
140
+
141
+ # 确定文件名和 URL
142
+ if use_onnx:
143
+ filename = SILERO_VAD_CONFIG["onnx_filename"]
144
+ else:
145
+ filename = SILERO_VAD_CONFIG["jit_filename"]
146
+
147
+ url = f"{SILERO_VAD_CONFIG['download_base']}/{filename}"
148
+ vad_dir = os.path.join(output_dir, "silero_vad")
149
+ dest_path = os.path.join(vad_dir, filename)
150
+
151
+ # 检查是否已存在
152
+ if os.path.exists(dest_path):
153
+ log(f"Silero VAD 模型已存在: {dest_path}")
154
+ return True, dest_path
155
+
156
+ log("开始下载 Silero VAD 模型...")
157
+ log(f"版本: {SILERO_VAD_CONFIG['version']}")
158
+ log(f"格式: {'ONNX' if use_onnx else 'JIT'}")
159
+
160
+ if _download_file(url, dest_path, progress_callback):
161
+ log("Silero VAD 模型下载完成!")
162
+ return True, dest_path
163
+ else:
164
+ return False, "Silero VAD 模型下载失败"
165
+
166
+
167
+ def ensure_vad_model(
168
+ models_dir: str,
169
+ progress_callback: Optional[Callable[[str], None]] = None
170
+ ) -> str:
171
+ """
172
+ 确保 VAD 模型已下载,如未下载则自动下载
173
+
174
+ 参数:
175
+ models_dir: 模型根目录
176
+ progress_callback: 进度回调
177
+
178
+ 返回:
179
+ 模型文件路径
180
+
181
+ 异常:
182
+ RuntimeError: 下载失败时抛出
183
+ """
184
+ model_path = get_vad_model_path(models_dir)
185
+
186
+ if os.path.exists(model_path):
187
+ logger.info(f"Silero VAD 模型已就绪: {model_path}")
188
+ return model_path
189
+
190
+ success, result = download_silero_vad(models_dir, progress_callback)
191
+ if success:
192
+ return result
193
+ else:
194
+ raise RuntimeError(f"Silero VAD 模型下载失败: {result}")
tests/test_mfa_model_downloader.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ MFA 模型下载模块单元测试
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import unittest
9
+ from unittest.mock import patch, MagicMock
10
+ from pathlib import Path
11
+
12
+ # 添加项目根目录到路径
13
+ sys.path.insert(0, str(Path(__file__).parent.parent))
14
+
15
+ from src.mfa_model_downloader import (
16
+ get_available_languages,
17
+ LANGUAGE_MODELS,
18
+ GITHUB_RELEASE_BASE,
19
+ download_acoustic_model,
20
+ download_dictionary,
21
+ download_language_models,
22
+ )
23
+
24
+
25
+ class TestGetAvailableLanguages(unittest.TestCase):
26
+ """测试获取可用语言列表"""
27
+
28
+ def test_returns_dict(self):
29
+ """返回值应为字典"""
30
+ result = get_available_languages()
31
+ self.assertIsInstance(result, dict)
32
+
33
+ def test_contains_mandarin(self):
34
+ """应包含中文"""
35
+ result = get_available_languages()
36
+ self.assertIn("mandarin", result)
37
+ self.assertEqual(result["mandarin"], "中文 (普通话)")
38
+
39
+ def test_contains_japanese(self):
40
+ """应包含日文"""
41
+ result = get_available_languages()
42
+ self.assertIn("japanese", result)
43
+ self.assertEqual(result["japanese"], "日文")
44
+
45
+
46
+ class TestLanguageModelsConfig(unittest.TestCase):
47
+ """测试语言模型配置"""
48
+
49
+ def test_mandarin_config_complete(self):
50
+ """中文配置应完整"""
51
+ config = LANGUAGE_MODELS["mandarin"]
52
+ self.assertIn("name", config)
53
+ self.assertIn("acoustic", config)
54
+ self.assertIn("dictionary", config)
55
+
56
+ # 声学模型配置
57
+ acoustic = config["acoustic"]
58
+ self.assertIn("tag", acoustic)
59
+ self.assertIn("filename", acoustic)
60
+ self.assertTrue(acoustic["filename"].endswith(".zip"))
61
+
62
+ # 字典配置
63
+ dictionary = config["dictionary"]
64
+ self.assertIn("tag", dictionary)
65
+ self.assertIn("filename", dictionary)
66
+ self.assertTrue(dictionary["filename"].endswith(".dict"))
67
+
68
+ def test_japanese_config_complete(self):
69
+ """日文配置应完整"""
70
+ config = LANGUAGE_MODELS["japanese"]
71
+ self.assertIn("name", config)
72
+ self.assertIn("acoustic", config)
73
+ self.assertIn("dictionary", config)
74
+
75
+ def test_acoustic_url_format(self):
76
+ """声学模型 URL 格式应正确"""
77
+ for lang, config in LANGUAGE_MODELS.items():
78
+ acoustic = config["acoustic"]
79
+ url = f"{GITHUB_RELEASE_BASE}/{acoustic['tag']}/{acoustic['filename']}"
80
+ self.assertTrue(url.startswith("https://github.com/"))
81
+ self.assertIn("mfa-models", url)
82
+
83
+ def test_dictionary_url_format(self):
84
+ """字典 URL 格式应正确"""
85
+ for lang, config in LANGUAGE_MODELS.items():
86
+ dictionary = config["dictionary"]
87
+ url = f"{GITHUB_RELEASE_BASE}/{dictionary['tag']}/{dictionary['filename']}"
88
+ self.assertTrue(url.startswith("https://github.com/"))
89
+ self.assertIn("dictionary-", url)
90
+
91
+
92
+ class TestDownloadAcousticModel(unittest.TestCase):
93
+ """测试声学模型下载"""
94
+
95
+ def test_invalid_language(self):
96
+ """不支持的语言应返回失败"""
97
+ success, result = download_acoustic_model("invalid_lang", "/tmp")
98
+ self.assertFalse(success)
99
+ self.assertIn("不支持的语言", result)
100
+
101
+ @patch('src.mfa_model_downloader._download_file')
102
+ def test_download_called_with_correct_url(self, mock_download):
103
+ """应使用正确的 URL 下载"""
104
+ mock_download.return_value = True
105
+
106
+ with patch('os.path.exists', return_value=False):
107
+ download_acoustic_model("mandarin", "/tmp/models")
108
+
109
+ # 验证调用参数
110
+ call_args = mock_download.call_args
111
+ url = call_args[0][0]
112
+ self.assertIn("mandarin_mfa.zip", url)
113
+ self.assertIn("acoustic-mandarin_mfa", url)
114
+
115
+ @patch('os.path.exists')
116
+ def test_skip_if_exists(self, mock_exists):
117
+ """文件已存在时应跳过下载"""
118
+ mock_exists.return_value = True
119
+
120
+ success, result = download_acoustic_model("mandarin", "/tmp/models")
121
+ self.assertTrue(success)
122
+ self.assertIn("mandarin_mfa.zip", result)
123
+
124
+
125
+ class TestDownloadDictionary(unittest.TestCase):
126
+ """测试字典下载"""
127
+
128
+ def test_invalid_language(self):
129
+ """不支持的语言应返回失败"""
130
+ success, result = download_dictionary("invalid_lang", "/tmp")
131
+ self.assertFalse(success)
132
+ self.assertIn("不支持的语言", result)
133
+
134
+ @patch('src.mfa_model_downloader._download_file')
135
+ def test_download_called_with_correct_url(self, mock_download):
136
+ """应使用正确的 URL 下载"""
137
+ mock_download.return_value = True
138
+
139
+ with patch('os.path.exists', return_value=False):
140
+ download_dictionary("japanese", "/tmp/models")
141
+
142
+ call_args = mock_download.call_args
143
+ url = call_args[0][0]
144
+ self.assertIn("github.com", url)
145
+ self.assertIn("dictionary-japanese", url)
146
+
147
+
148
+ class TestDownloadLanguageModels(unittest.TestCase):
149
+ """测试完整语言模型下载"""
150
+
151
+ def test_invalid_language(self):
152
+ """不支持的语言应返回失败"""
153
+ success, acoustic, dict_path = download_language_models("invalid", "/tmp")
154
+ self.assertFalse(success)
155
+
156
+ @patch('src.mfa_model_downloader.download_dictionary')
157
+ @patch('src.mfa_model_downloader.download_acoustic_model')
158
+ def test_downloads_both_models(self, mock_acoustic, mock_dict):
159
+ """应同时下载声学模型和字典"""
160
+ mock_acoustic.return_value = (True, "/tmp/acoustic.zip")
161
+ mock_dict.return_value = (True, "/tmp/dict.dict")
162
+
163
+ success, acoustic, dict_path = download_language_models("mandarin", "/tmp")
164
+
165
+ self.assertTrue(success)
166
+ mock_acoustic.assert_called_once()
167
+ mock_dict.assert_called_once()
168
+
169
+ @patch('src.mfa_model_downloader.download_dictionary')
170
+ @patch('src.mfa_model_downloader.download_acoustic_model')
171
+ def test_stops_on_acoustic_failure(self, mock_acoustic, mock_dict):
172
+ """声学模型下载失败时应停止"""
173
+ mock_acoustic.return_value = (False, "下载失败")
174
+
175
+ success, _, _ = download_language_models("mandarin", "/tmp")
176
+
177
+ self.assertFalse(success)
178
+ mock_dict.assert_not_called()
179
+
180
+
181
+ if __name__ == "__main__":
182
+ unittest.main()
tests/test_mfa_runner.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ MFA 运行模块单元测试
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import unittest
9
+ from unittest.mock import patch, MagicMock
10
+ from pathlib import Path
11
+
12
+ # 添加项目根目录到路径
13
+ sys.path.insert(0, str(Path(__file__).parent.parent))
14
+
15
+ from src.mfa_runner import (
16
+ check_mfa_available,
17
+ _build_mfa_env,
18
+ run_mfa_alignment,
19
+ run_mfa_validate,
20
+ BASE_DIR,
21
+ MFA_ENGINE_DIR,
22
+ MFA_PYTHON,
23
+ )
24
+
25
+
26
+ class TestCheckMfaAvailable(unittest.TestCase):
27
+ """测试 MFA 环境检查"""
28
+
29
+ @patch('src.mfa_runner.MFA_ENGINE_DIR')
30
+ def test_returns_false_when_dir_not_exists(self, mock_dir):
31
+ """目录不存在时应返回 False"""
32
+ mock_path = MagicMock()
33
+ mock_path.exists.return_value = False
34
+
35
+ with patch.object(Path, 'exists', return_value=False):
36
+ # 由于模块级变量,需要重新导入或直接测试逻辑
37
+ pass
38
+
39
+ def test_path_constants_defined(self):
40
+ """路径常量应正确定义"""
41
+ self.assertIsInstance(BASE_DIR, Path)
42
+ self.assertIsInstance(MFA_ENGINE_DIR, Path)
43
+ self.assertIsInstance(MFA_PYTHON, Path)
44
+
45
+ # 验证路径结构
46
+ self.assertTrue(str(MFA_ENGINE_DIR).endswith("mfa_engine"))
47
+ self.assertTrue(str(MFA_PYTHON).endswith("python.exe"))
48
+
49
+
50
+ class TestBuildMfaEnv(unittest.TestCase):
51
+ """测试 MFA 环境变量构建"""
52
+
53
+ def test_returns_dict(self):
54
+ """应返回字典"""
55
+ env = _build_mfa_env()
56
+ self.assertIsInstance(env, dict)
57
+
58
+ def test_path_contains_mfa_dirs(self):
59
+ """PATH 应包含 MFA 相关目录"""
60
+ env = _build_mfa_env()
61
+ path = env.get("PATH", "")
62
+
63
+ self.assertIn("mfa_engine", path)
64
+ self.assertIn("Library", path)
65
+
66
+ def test_preserves_original_path(self):
67
+ """应保留原始 PATH"""
68
+ original_path = os.environ.get("PATH", "")
69
+ env = _build_mfa_env()
70
+
71
+ # 原始 PATH 应在新 PATH 中
72
+ self.assertIn(original_path.split(";")[0], env["PATH"])
73
+
74
+
75
+ class TestRunMfaAlignment(unittest.TestCase):
76
+ """测试 MFA 对齐功能"""
77
+
78
+ @patch('src.mfa_runner.check_mfa_available')
79
+ def test_fails_when_mfa_unavailable(self, mock_check):
80
+ """MFA 不可用时应返回失败"""
81
+ mock_check.return_value = False
82
+
83
+ success, msg = run_mfa_alignment("/input", "/output")
84
+
85
+ self.assertFalse(success)
86
+ self.assertIn("不可用", msg)
87
+
88
+ @patch('src.mfa_runner.check_mfa_available')
89
+ @patch('os.path.isdir')
90
+ def test_fails_when_corpus_not_exists(self, mock_isdir, mock_check):
91
+ """输入目录不存在时应返回失败"""
92
+ mock_check.return_value = True
93
+ mock_isdir.return_value = False
94
+
95
+ success, msg = run_mfa_alignment("/nonexistent", "/output")
96
+
97
+ self.assertFalse(success)
98
+ self.assertIn("不存在", msg)
99
+
100
+ @patch('src.mfa_runner.check_mfa_available')
101
+ @patch('os.path.isdir')
102
+ @patch('os.path.isfile')
103
+ def test_fails_when_dict_not_exists(self, mock_isfile, mock_isdir, mock_check):
104
+ """字典文件不存在时应返回失败"""
105
+ mock_check.return_value = True
106
+ mock_isdir.return_value = True
107
+ mock_isfile.return_value = False
108
+
109
+ success, msg = run_mfa_alignment(
110
+ "/input", "/output",
111
+ dict_path="/nonexistent.dict"
112
+ )
113
+
114
+ self.assertFalse(success)
115
+ self.assertIn("不存在", msg)
116
+
117
+ @patch('src.mfa_runner.check_mfa_available')
118
+ @patch('os.path.isdir')
119
+ @patch('os.path.isfile')
120
+ @patch('os.makedirs')
121
+ @patch('subprocess.run')
122
+ def test_calls_subprocess_with_correct_args(
123
+ self, mock_run, mock_makedirs, mock_isfile, mock_isdir, mock_check
124
+ ):
125
+ """应使用正确的参数调用 subprocess"""
126
+ mock_check.return_value = True
127
+ mock_isdir.return_value = True
128
+ mock_isfile.return_value = True
129
+ mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="")
130
+
131
+ run_mfa_alignment(
132
+ "/input", "/output",
133
+ dict_path="/dict.dict",
134
+ model_path="/model.zip",
135
+ single_speaker=True,
136
+ clean=True
137
+ )
138
+
139
+ # 验证 subprocess.run 被调用
140
+ mock_run.assert_called_once()
141
+
142
+ # 验证命令参数
143
+ call_args = mock_run.call_args
144
+ cmd = call_args[0][0]
145
+
146
+ self.assertIn("align", cmd)
147
+ self.assertIn("/input", cmd)
148
+ self.assertIn("/dict.dict", cmd)
149
+ self.assertIn("/model.zip", cmd)
150
+ self.assertIn("/output", cmd)
151
+ self.assertIn("--single_speaker", cmd)
152
+ self.assertIn("--clean", cmd)
153
+
154
+ @patch('src.mfa_runner.check_mfa_available')
155
+ @patch('os.path.isdir')
156
+ @patch('os.path.isfile')
157
+ @patch('os.makedirs')
158
+ @patch('subprocess.run')
159
+ def test_returns_success_on_zero_returncode(
160
+ self, mock_run, mock_makedirs, mock_isfile, mock_isdir, mock_check
161
+ ):
162
+ """返回码为 0 时应返回成功"""
163
+ mock_check.return_value = True
164
+ mock_isdir.return_value = True
165
+ mock_isfile.return_value = True
166
+ mock_run.return_value = MagicMock(returncode=0, stdout="完成", stderr="")
167
+
168
+ success, msg = run_mfa_alignment(
169
+ "/input", "/output",
170
+ dict_path="/dict.dict",
171
+ model_path="/model.zip"
172
+ )
173
+
174
+ self.assertTrue(success)
175
+
176
+ @patch('src.mfa_runner.check_mfa_available')
177
+ @patch('os.path.isdir')
178
+ @patch('os.path.isfile')
179
+ @patch('os.makedirs')
180
+ @patch('subprocess.run')
181
+ def test_returns_failure_on_nonzero_returncode(
182
+ self, mock_run, mock_makedirs, mock_isfile, mock_isdir, mock_check
183
+ ):
184
+ """返回码非 0 时应返回失败"""
185
+ mock_check.return_value = True
186
+ mock_isdir.return_value = True
187
+ mock_isfile.return_value = True
188
+ mock_run.return_value = MagicMock(returncode=1, stdout="", stderr="错误")
189
+
190
+ success, msg = run_mfa_alignment(
191
+ "/input", "/output",
192
+ dict_path="/dict.dict",
193
+ model_path="/model.zip"
194
+ )
195
+
196
+ self.assertFalse(success)
197
+
198
+
199
+ class TestRunMfaValidate(unittest.TestCase):
200
+ """测试 MFA 验证功能"""
201
+
202
+ @patch('src.mfa_runner.check_mfa_available')
203
+ def test_fails_when_mfa_unavailable(self, mock_check):
204
+ """MFA 不可用时应返回失败"""
205
+ mock_check.return_value = False
206
+
207
+ success, msg = run_mfa_validate("/corpus")
208
+
209
+ self.assertFalse(success)
210
+ self.assertIn("不可用", msg)
211
+
212
+
213
+ class TestProgressCallback(unittest.TestCase):
214
+ """测试进度回调"""
215
+
216
+ @patch('src.mfa_runner.check_mfa_available')
217
+ @patch('os.path.isdir')
218
+ @patch('os.path.isfile')
219
+ @patch('os.makedirs')
220
+ @patch('subprocess.run')
221
+ def test_callback_called_on_success(
222
+ self, mock_run, mock_makedirs, mock_isfile, mock_isdir, mock_check
223
+ ):
224
+ """成功时应调用回调"""
225
+ mock_check.return_value = True
226
+ mock_isdir.return_value = True
227
+ mock_isfile.return_value = True
228
+ mock_run.return_value = MagicMock(returncode=0, stdout="完成", stderr="")
229
+ callback = MagicMock()
230
+
231
+ run_mfa_alignment(
232
+ "/input", "/output",
233
+ dict_path="/dict.dict",
234
+ model_path="/model.zip",
235
+ progress_callback=callback
236
+ )
237
+
238
+ # 回调应被调用(至少一次)
239
+ self.assertTrue(callback.called)
240
+
241
+
242
+ if __name__ == "__main__":
243
+ unittest.main()
tests/test_silero_vad_downloader.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Silero VAD 下载模块测试
4
+ """
5
+
6
+ import os
7
+ import tempfile
8
+ import unittest
9
+ from unittest.mock import patch, MagicMock
10
+
11
+ from src.silero_vad_downloader import (
12
+ get_vad_model_path,
13
+ is_vad_model_downloaded,
14
+ download_silero_vad,
15
+ ensure_vad_model,
16
+ SILERO_VAD_CONFIG
17
+ )
18
+
19
+
20
+ class TestSileroVadDownloader(unittest.TestCase):
21
+ """Silero VAD 下载器测试类"""
22
+
23
+ def test_get_vad_model_path(self):
24
+ """测试获取模型路径"""
25
+ models_dir = "/test/models"
26
+ expected = os.path.join(models_dir, "silero_vad", "silero_vad.onnx")
27
+ self.assertEqual(get_vad_model_path(models_dir), expected)
28
+
29
+ def test_is_vad_model_downloaded_false(self):
30
+ """测试模型未下载时返回 False"""
31
+ with tempfile.TemporaryDirectory() as tmpdir:
32
+ self.assertFalse(is_vad_model_downloaded(tmpdir))
33
+
34
+ def test_is_vad_model_downloaded_true(self):
35
+ """测试模型已下载时返回 True"""
36
+ with tempfile.TemporaryDirectory() as tmpdir:
37
+ vad_dir = os.path.join(tmpdir, "silero_vad")
38
+ os.makedirs(vad_dir)
39
+ model_path = os.path.join(vad_dir, "silero_vad.onnx")
40
+ with open(model_path, "w") as f:
41
+ f.write("dummy")
42
+ self.assertTrue(is_vad_model_downloaded(tmpdir))
43
+
44
+ def test_download_silero_vad_already_exists(self):
45
+ """测试模型已存在时跳过下载"""
46
+ with tempfile.TemporaryDirectory() as tmpdir:
47
+ vad_dir = os.path.join(tmpdir, "silero_vad")
48
+ os.makedirs(vad_dir)
49
+ model_path = os.path.join(vad_dir, "silero_vad.onnx")
50
+ with open(model_path, "w") as f:
51
+ f.write("dummy")
52
+
53
+ success, result = download_silero_vad(tmpdir)
54
+ self.assertTrue(success)
55
+ self.assertEqual(result, model_path)
56
+
57
+ def test_config_values(self):
58
+ """测试配置值正确性"""
59
+ self.assertEqual(SILERO_VAD_CONFIG["onnx_filename"], "silero_vad.onnx")
60
+ self.assertEqual(SILERO_VAD_CONFIG["jit_filename"], "silero_vad.jit")
61
+ self.assertIn("snakers4/silero-vad", SILERO_VAD_CONFIG["repo"])
62
+
63
+
64
+ if __name__ == "__main__":
65
+ unittest.main()