TNOT commited on
Commit
9361148
·
1 Parent(s): b0dfe06

完成了单音频或选择文件夹 → Silero VAD 切片 → Whisper 转录 → 生成 .lab → MFA 强制对齐 → 简单导出,已可用

Browse files
.gitignore CHANGED
@@ -24,9 +24,10 @@ __pycache__/
24
  temp/
25
  *.tmp
26
 
27
- # 数据(根据需要调整)
28
  config.json
29
  bank/
 
30
 
31
  # AI 模型相关
32
  tools/mfa_engine
 
24
  temp/
25
  *.tmp
26
 
27
+ # 数据
28
  config.json
29
  bank/
30
+ export/
31
 
32
  # AI 模型相关
33
  tools/mfa_engine
requirements.in CHANGED
@@ -9,8 +9,12 @@ customtkinter
9
  # Whisper 语音识别
10
  transformers>=4.25.0
11
  torch
 
12
  accelerate
13
 
14
  # Silero VAD 语音活动检测
15
  silero-vad>=5.1
16
  onnxruntime
 
 
 
 
9
  # Whisper 语音识别
10
  transformers>=4.25.0
11
  torch
12
+ torchaudio
13
  accelerate
14
 
15
  # Silero VAD 语音活动检测
16
  silero-vad>=5.1
17
  onnxruntime
18
+
19
+ # 文本处理
20
+ pypinyin
requirements.txt CHANGED
@@ -132,7 +132,9 @@ torch==2.10.0
132
  # silero-vad
133
  # torchaudio
134
  torchaudio==2.10.0
135
- # via silero-vad
 
 
136
  tqdm==4.67.1
137
  # via
138
  # -r requirements.in
 
132
  # silero-vad
133
  # torchaudio
134
  torchaudio==2.10.0
135
+ # via
136
+ # -r requirements.in
137
+ # silero-vad
138
  tqdm==4.67.1
139
  # via
140
  # -r requirements.in
src/audio_processor.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ 音频处理模块
4
+ 包含 Silero VAD 切片和 Whisper 转录功能
5
+ """
6
+
7
+ import os
8
+ import logging
9
+ from pathlib import Path
10
+ from typing import Optional, Callable, List, Tuple
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ class AudioProcessor:
16
+ """音频处理器,整合VAD切片和Whisper转录"""
17
+
18
+ def __init__(
19
+ self,
20
+ models_dir: str,
21
+ progress_callback: Optional[Callable[[str], None]] = None
22
+ ):
23
+ """
24
+ 初始化音频处理器
25
+
26
+ 参数:
27
+ models_dir: 模型目录
28
+ progress_callback: 进度回调函数
29
+ """
30
+ self.models_dir = models_dir
31
+ self.progress_callback = progress_callback
32
+ self.vad_model = None
33
+ self.whisper_pipe = None
34
+
35
+ def _log(self, msg: str):
36
+ """记录日志并回调"""
37
+ logger.info(msg)
38
+ if self.progress_callback:
39
+ self.progress_callback(msg)
40
+
41
+ def load_vad_model(self):
42
+ """加载 Silero VAD 模型"""
43
+ if self.vad_model is not None:
44
+ return
45
+
46
+ self._log("正在加载 Silero VAD 模型...")
47
+
48
+ from src.silero_vad_downloader import ensure_vad_model
49
+ import torch
50
+
51
+ # 确保模型已下载
52
+ model_path = ensure_vad_model(self.models_dir, self.progress_callback)
53
+
54
+ # 加载模型
55
+ self.vad_model, utils = torch.hub.load(
56
+ repo_or_dir='snakers4/silero-vad',
57
+ model='silero_vad',
58
+ force_reload=False,
59
+ onnx=True
60
+ )
61
+ self.vad_utils = utils
62
+ self._log("Silero VAD 模型加载完成")
63
+
64
+ def load_whisper_model(self, model_name: str = "openai/whisper-small"):
65
+ """
66
+ 加载 Whisper 模型
67
+
68
+ 参数:
69
+ model_name: 模型名称
70
+ """
71
+ if self.whisper_pipe is not None:
72
+ return
73
+
74
+ self._log(f"正在加载 Whisper 模型: {model_name}...")
75
+
76
+ from transformers import pipeline
77
+ import torch
78
+
79
+ cache_dir = os.path.join(self.models_dir, "whisper")
80
+ os.makedirs(cache_dir, exist_ok=True)
81
+
82
+ os.environ["HF_HOME"] = cache_dir
83
+ os.environ["TRANSFORMERS_CACHE"] = cache_dir
84
+
85
+ self.whisper_pipe = pipeline(
86
+ "automatic-speech-recognition",
87
+ model=model_name,
88
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
89
+ device_map="auto",
90
+ model_kwargs={"cache_dir": cache_dir}
91
+ )
92
+ self._log("Whisper 模型加载完成")
93
+
94
+ def vad_split(
95
+ self,
96
+ audio_path: str,
97
+ output_dir: str,
98
+ min_speech_duration_ms: int = 250,
99
+ min_silence_duration_ms: int = 100,
100
+ threshold: float = 0.5
101
+ ) -> List[str]:
102
+ """
103
+ 使用 VAD 对音频进行切片
104
+
105
+ 参数:
106
+ audio_path: 输入音频路径
107
+ output_dir: 输出目录
108
+ min_speech_duration_ms: 最小语音时长(毫秒)
109
+ min_silence_duration_ms: 最小静音时长(毫秒)
110
+ threshold: VAD阈值
111
+
112
+ 返回:
113
+ 切片文件路径列表
114
+ """
115
+ import torch
116
+ import torchaudio
117
+
118
+ self.load_vad_model()
119
+
120
+ basename = Path(audio_path).stem
121
+ os.makedirs(output_dir, exist_ok=True)
122
+
123
+ self._log(f"正在处理: {audio_path}")
124
+
125
+ # 读取音频
126
+ wav, sr = torchaudio.load(audio_path)
127
+
128
+ # 转换为单声道
129
+ if wav.shape[0] > 1:
130
+ wav = wav.mean(dim=0, keepdim=True)
131
+ wav = wav.squeeze(0)
132
+
133
+ # 重采样到16kHz (VAD要求)
134
+ if sr != 16000:
135
+ resampler = torchaudio.transforms.Resample(sr, 16000)
136
+ wav_16k = resampler(wav)
137
+ sr_vad = 16000
138
+ else:
139
+ wav_16k = wav
140
+ sr_vad = sr
141
+
142
+ # 获取语音时间戳
143
+ get_speech_timestamps = self.vad_utils[0]
144
+ speech_timestamps = get_speech_timestamps(
145
+ wav_16k,
146
+ self.vad_model,
147
+ threshold=threshold,
148
+ min_speech_duration_ms=min_speech_duration_ms,
149
+ min_silence_duration_ms=min_silence_duration_ms,
150
+ sampling_rate=sr_vad
151
+ )
152
+
153
+ self._log(f"检测到 {len(speech_timestamps)} 个语音片段")
154
+
155
+ # 切片并保存
156
+ output_files = []
157
+ for i, ts in enumerate(speech_timestamps):
158
+ # 转换回原始采样率的索引
159
+ start = int(ts['start'] * sr / sr_vad)
160
+ end = int(ts['end'] * sr / sr_vad)
161
+
162
+ segment = wav[start:end]
163
+
164
+ output_path = os.path.join(output_dir, f"{basename}_{i:04d}.wav")
165
+ torchaudio.save(output_path, segment.unsqueeze(0), sr)
166
+ output_files.append(output_path)
167
+
168
+ self._log(f"切片完成,共 {len(output_files)} 个文件")
169
+ return output_files
170
+
171
+ def transcribe(self, audio_path: str, language: str = "chinese") -> str:
172
+ """
173
+ 使用 Whisper 转录音频
174
+
175
+ 参数:
176
+ audio_path: 音频文件路径
177
+ language: 语言
178
+
179
+ 返回:
180
+ 转录文本
181
+ """
182
+ if self.whisper_pipe is None:
183
+ raise RuntimeError("Whisper 模型未加载")
184
+
185
+ result = self.whisper_pipe(
186
+ audio_path,
187
+ generate_kwargs={"language": language}
188
+ )
189
+ return result["text"].strip()
190
+
191
+ def generate_lab(self, audio_path: str, text: str) -> str:
192
+ """
193
+ 生成 .lab 文件
194
+
195
+ 参数:
196
+ audio_path: 音频文件路径
197
+ text: 转录文本
198
+
199
+ 返回:
200
+ lab文件路径
201
+ """
202
+ lab_path = os.path.splitext(audio_path)[0] + ".lab"
203
+ with open(lab_path, "w", encoding="utf-8") as f:
204
+ f.write(text)
205
+ return lab_path
206
+
207
+ def process_full_pipeline(
208
+ self,
209
+ input_path: str,
210
+ output_dir: str,
211
+ language: str = "chinese",
212
+ whisper_model: str = "openai/whisper-small"
213
+ ) -> Tuple[bool, str, List[str]]:
214
+ """
215
+ 完整处理流程: VAD切片 → Whisper转录 → 生成.lab
216
+
217
+ 参数:
218
+ input_path: 输入音频文件或目录
219
+ output_dir: 输出目录
220
+ language: 转录语言
221
+ whisper_model: Whisper模型名称
222
+
223
+ 返回:
224
+ (成功标志, 消息, 输出文件列表)
225
+ """
226
+ try:
227
+ # 加载模型
228
+ self.load_vad_model()
229
+ self.load_whisper_model(whisper_model)
230
+
231
+ # 收集输入文件
232
+ input_files = []
233
+ if os.path.isfile(input_path):
234
+ input_files = [input_path]
235
+ elif os.path.isdir(input_path):
236
+ for f in os.listdir(input_path):
237
+ if f.lower().endswith(('.wav', '.mp3', '.flac', '.ogg', '.m4a')):
238
+ input_files.append(os.path.join(input_path, f))
239
+
240
+ if not input_files:
241
+ return False, "未找到音频文件", []
242
+
243
+ self._log(f"找到 {len(input_files)} 个音频文件")
244
+
245
+ # 创建输出目录
246
+ slices_dir = os.path.join(output_dir, "slices")
247
+ os.makedirs(slices_dir, exist_ok=True)
248
+
249
+ all_output_files = []
250
+
251
+ for idx, audio_file in enumerate(input_files):
252
+ self._log(f"处理 [{idx+1}/{len(input_files)}]: {os.path.basename(audio_file)}")
253
+
254
+ # VAD切片
255
+ slice_files = self.vad_split(audio_file, slices_dir)
256
+
257
+ # 转录每个切片
258
+ for slice_file in slice_files:
259
+ self._log(f"转录: {os.path.basename(slice_file)}")
260
+ text = self.transcribe(slice_file, language)
261
+
262
+ if text:
263
+ lab_path = self.generate_lab(slice_file, text)
264
+ self._log(f"生成: {os.path.basename(lab_path)} -> {text[:30]}...")
265
+ all_output_files.append(slice_file)
266
+ else:
267
+ self._log(f"跳过空转录: {os.path.basename(slice_file)}")
268
+
269
+ return True, f"处理完成,共 {len(all_output_files)} 个切片", all_output_files
270
+
271
+ except Exception as e:
272
+ logger.error(f"处理失败: {e}", exc_info=True)
273
+ return False, str(e), []
274
+
275
+
276
+ def process_audio_pipeline(
277
+ input_path: str,
278
+ output_dir: str,
279
+ models_dir: str,
280
+ language: str = "chinese",
281
+ whisper_model: str = "openai/whisper-small",
282
+ progress_callback: Optional[Callable[[str], None]] = None
283
+ ) -> Tuple[bool, str, List[str]]:
284
+ """
285
+ 便捷函数:执行完整音频处理流程
286
+
287
+ 参数:
288
+ input_path: 输入音频文件或目录
289
+ output_dir: 输出目录
290
+ models_dir: 模型目录
291
+ language: 转录语言
292
+ whisper_model: Whisper模型名称
293
+ progress_callback: 进度回调
294
+
295
+ 返回:
296
+ (成功标志, 消息, 输出文件列表)
297
+ """
298
+ processor = AudioProcessor(models_dir, progress_callback)
299
+ return processor.process_full_pipeline(input_path, output_dir, language, whisper_model)
src/bank_sort.py CHANGED
@@ -1,18 +1,102 @@
 
 
 
 
 
 
 
1
  import glob
2
- import audiofile
3
- import tqdm
4
  import shutil
 
 
 
 
 
5
 
6
- stats = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- for path in tqdm.tqdm(glob.glob('bank/**/*.wav')):
9
- word, filename = path.split('\\')[1:]
10
- if word not in stats:
11
- stats[word] = []
12
- stats[word].append((filename, audiofile.duration(path)))
13
 
14
- for word in tqdm.tqdm(stats):
15
- for index, v in enumerate(sorted(stats[word], key=lambda x:-x[1])):
16
- if index >= 100:
17
- continue
18
- shutil.copyfile('bank\\' + word + '\\' + v[0], 'bank\\' + word + '_' + str(index) + '.wav')
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ 音频库排序模块
4
+ 按时长排序并导出音频片段
5
+ """
6
+
7
+ import os
8
  import glob
 
 
9
  import shutil
10
+ import logging
11
+ from typing import Optional, Callable, Tuple, Dict, List
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
 
16
+ def sort_and_export_bank(
17
+ bank_dir: str,
18
+ output_dir: str,
19
+ max_per_word: int = 100,
20
+ progress_callback: Optional[Callable[[str], None]] = None
21
+ ) -> Tuple[bool, str, Dict[str, int]]:
22
+ """
23
+ 对音频库进行排序并导出
24
+
25
+ 参数:
26
+ bank_dir: 音频库目录 (包含 [词]/[编号].wav 结构)
27
+ output_dir: 导出目录
28
+ max_per_word: 每个词最多保留的样本数
29
+ progress_callback: 进度回调函数
30
+
31
+ 返回:
32
+ (成功标志, 消息, 导出统计)
33
+ """
34
+ import audiofile
35
+
36
+ def log(msg: str):
37
+ logger.info(msg)
38
+ if progress_callback:
39
+ progress_callback(msg)
40
+
41
+ try:
42
+ os.makedirs(output_dir, exist_ok=True)
43
+
44
+ # 统计所有词条
45
+ stats: Dict[str, List[Tuple[str, float]]] = {}
46
+ wav_files = glob.glob(os.path.join(bank_dir, '**', '*.wav'), recursive=True)
47
+
48
+ if not wav_files:
49
+ return False, "未找到WAV文件", {}
50
+
51
+ log(f"扫描到 {len(wav_files)} 个WAV文件")
52
+
53
+ for path in wav_files:
54
+ rel_path = os.path.relpath(path, bank_dir)
55
+ parts = rel_path.split(os.sep)
56
+ if len(parts) >= 2:
57
+ word = parts[0]
58
+ if word not in stats:
59
+ stats[word] = []
60
+ try:
61
+ duration = audiofile.duration(path)
62
+ stats[word].append((path, duration))
63
+ except Exception as e:
64
+ log(f"警告: 无法读取 {path}: {e}")
65
+
66
+ log(f"统计到 {len(stats)} 个词条")
67
+
68
+ # 按时长排序并导出
69
+ export_counts = {}
70
+ for word, files in stats.items():
71
+ # 按时长降序排序
72
+ sorted_files = sorted(files, key=lambda x: -x[1])
73
+ count = 0
74
+ for idx, (src_path, _) in enumerate(sorted_files[:max_per_word]):
75
+ dst_path = os.path.join(output_dir, f'{word}_{idx}.wav')
76
+ shutil.copyfile(src_path, dst_path)
77
+ count += 1
78
+ export_counts[word] = count
79
+ log(f"处理词条: {word} ({count} 个文件)")
80
+
81
+ total = sum(export_counts.values())
82
+ log(f"导出完成: {len(export_counts)} 个词条,{total} 个文件")
83
+
84
+ return True, f"导出完成: {len(export_counts)} 个词条,{total} 个文件", export_counts
85
+
86
+ except Exception as e:
87
+ logger.error(f"排序导出失败: {e}", exc_info=True)
88
+ return False, str(e), {}
89
 
 
 
 
 
 
90
 
91
+ # 保留原有脚本入口以兼容
92
+ if __name__ == "__main__":
93
+ bank_dir = 'bank'
94
+ output_dir = 'bank_export'
95
+
96
+ success, msg, stats = sort_and_export_bank(
97
+ bank_dir=bank_dir,
98
+ output_dir=output_dir,
99
+ max_per_word=100,
100
+ progress_callback=print
101
+ )
102
+ print(f"结果: {msg}")
src/export_plugins/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ 导出插件系统
4
+
5
+ 插件式导出架构,支持动态加载和配置
6
+ """
7
+
8
+ from .base import ExportPlugin, PluginOption, OptionType
9
+ from .loader import load_plugins, get_builtin_plugins
10
+
11
+ __all__ = [
12
+ 'ExportPlugin',
13
+ 'PluginOption',
14
+ 'OptionType',
15
+ 'load_plugins',
16
+ 'get_builtin_plugins'
17
+ ]
src/export_plugins/base.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ 导出插件基类
4
+
5
+ 定义插件接口和配置选项类型
6
+ """
7
+
8
+ import os
9
+ import logging
10
+ from abc import ABC, abstractmethod
11
+ from enum import Enum
12
+ from dataclasses import dataclass, field
13
+ from typing import Any, Callable, Dict, List, Optional, Tuple
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class OptionType(Enum):
19
+ """配置选项类型"""
20
+ TEXT = "text" # 文本输入框
21
+ NUMBER = "number" # 数字输入框
22
+ SWITCH = "switch" # 开关
23
+ LABEL = "label" # 纯文本标签(不可编辑)
24
+ FILE = "file" # 文件选择
25
+ FOLDER = "folder" # 文件夹选择
26
+ COMBO = "combo" # 下拉选择框
27
+
28
+
29
+ @dataclass
30
+ class PluginOption:
31
+ """插件配置选项"""
32
+ key: str # 选项键名
33
+ label: str # 显示标签
34
+ option_type: OptionType # 选项类型
35
+ default: Any = None # 默认值
36
+ description: str = "" # 描述说明
37
+ choices: List[str] = field(default_factory=list) # 下拉选项(仅COMBO类型)
38
+ min_value: Optional[float] = None # 最小值(仅NUMBER类型)
39
+ max_value: Optional[float] = None # 最大值(仅NUMBER类型)
40
+ file_types: List[Tuple[str, str]] = field(default_factory=list) # 文件类型过滤
41
+
42
+
43
+ class ExportPlugin(ABC):
44
+ """导出插件基类"""
45
+
46
+ # 插件元信息(子类必须覆盖)
47
+ name: str = "未命名插件"
48
+ description: str = ""
49
+ version: str = "1.0.0"
50
+ author: str = ""
51
+
52
+ def __init__(self):
53
+ self._options: Dict[str, Any] = {}
54
+ self._progress_callback: Optional[Callable[[str], None]] = None
55
+ # 初始化默认值
56
+ for opt in self.get_options():
57
+ self._options[opt.key] = opt.default
58
+
59
+ @abstractmethod
60
+ def get_options(self) -> List[PluginOption]:
61
+ """
62
+ 获取插件配置选项列表
63
+
64
+ 返回:
65
+ 配置选项列表
66
+ """
67
+ pass
68
+
69
+ @abstractmethod
70
+ def export(
71
+ self,
72
+ source_name: str,
73
+ bank_dir: str,
74
+ options: Dict[str, Any]
75
+ ) -> Tuple[bool, str]:
76
+ """
77
+ 执行导出
78
+
79
+ 参数:
80
+ source_name: 音源名称
81
+ bank_dir: bank目录路径
82
+ options: 用户配置的选项值
83
+
84
+ 返回:
85
+ (成功标志, 消息)
86
+ """
87
+ pass
88
+
89
+ def set_progress_callback(self, callback: Callable[[str], None]):
90
+ """设置进度回调"""
91
+ self._progress_callback = callback
92
+
93
+ def _log(self, msg: str):
94
+ """记录日志"""
95
+ logger.info(msg)
96
+ if self._progress_callback:
97
+ self._progress_callback(msg)
98
+
99
+ def get_option_value(self, key: str) -> Any:
100
+ """获取选项值"""
101
+ return self._options.get(key)
102
+
103
+ def set_option_value(self, key: str, value: Any):
104
+ """设置选项值"""
105
+ self._options[key] = value
106
+
107
+ def reset_to_defaults(self):
108
+ """重置为默认值"""
109
+ for opt in self.get_options():
110
+ self._options[opt.key] = opt.default
111
+
112
+ def get_export_dir(self, bank_dir: str, source_name: str, subdir: str) -> str:
113
+ """
114
+ 获取导出目录路径
115
+
116
+ 参数:
117
+ bank_dir: bank目录
118
+ source_name: 音源名称
119
+ subdir: 子目录名
120
+
121
+ 返回:
122
+ export/[音源名称]/[subdir]/ 路径
123
+ """
124
+ from pathlib import Path
125
+ base = Path(bank_dir).parent
126
+ return os.path.join(base, "export", source_name, subdir)
127
+
128
+ def get_source_paths(self, bank_dir: str, source_name: str) -> Dict[str, str]:
129
+ """
130
+ 获取音源相关路径
131
+
132
+ 返回:
133
+ {
134
+ "source_dir": 音源目录,
135
+ "slices_dir": 切片目录,
136
+ "textgrid_dir": TextGrid目录
137
+ }
138
+ """
139
+ source_dir = os.path.join(bank_dir, source_name)
140
+ return {
141
+ "source_dir": source_dir,
142
+ "slices_dir": os.path.join(source_dir, "slices"),
143
+ "textgrid_dir": os.path.join(source_dir, "textgrid")
144
+ }
src/export_plugins/loader.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ 导出插件加载器
4
+
5
+ 负责扫描和加载内置及外部插件
6
+ """
7
+
8
+ import os
9
+ import sys
10
+ import logging
11
+ import importlib.util
12
+ from typing import Dict, List, Type
13
+
14
+ from .base import ExportPlugin
15
+ from .simple_export import SimpleExportPlugin
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ def get_builtin_plugins() -> List[Type[ExportPlugin]]:
21
+ """获取内置插件列表"""
22
+ return [SimpleExportPlugin]
23
+
24
+
25
+ def load_plugins(plugins_dir: str = None) -> Dict[str, ExportPlugin]:
26
+ """
27
+ 加载所有插件
28
+
29
+ 参数:
30
+ plugins_dir: 外部插件目录路径,默认为 export_plugins 同级目录
31
+
32
+ 返回:
33
+ {插件名称: 插件实例} 字典
34
+ """
35
+ plugins: Dict[str, ExportPlugin] = {}
36
+
37
+ # 加载内置插件
38
+ for plugin_cls in get_builtin_plugins():
39
+ try:
40
+ instance = plugin_cls()
41
+ plugins[instance.name] = instance
42
+ logger.info(f"加载内置插件: {instance.name}")
43
+ except Exception as e:
44
+ logger.error(f"加载内置插件失败: {plugin_cls.__name__}, {e}")
45
+
46
+ # 加载外部插件
47
+ if plugins_dir and os.path.exists(plugins_dir):
48
+ for filename in os.listdir(plugins_dir):
49
+ if filename.endswith('.py') and not filename.startswith('_'):
50
+ plugin_path = os.path.join(plugins_dir, filename)
51
+ try:
52
+ plugin = _load_plugin_from_file(plugin_path)
53
+ if plugin:
54
+ plugins[plugin.name] = plugin
55
+ logger.info(f"加载外部插件: {plugin.name} ({filename})")
56
+ except Exception as e:
57
+ logger.error(f"加载外部插件失败: {filename}, {e}")
58
+
59
+ return plugins
60
+
61
+
62
+ def _load_plugin_from_file(filepath: str) -> ExportPlugin:
63
+ """
64
+ 从文件加载插件
65
+
66
+ 参数:
67
+ filepath: 插件文件路径
68
+
69
+ 返回:
70
+ 插件实例,加载失败返回None
71
+ """
72
+ try:
73
+ module_name = os.path.splitext(os.path.basename(filepath))[0]
74
+ spec = importlib.util.spec_from_file_location(module_name, filepath)
75
+ if spec is None or spec.loader is None:
76
+ return None
77
+
78
+ module = importlib.util.module_from_spec(spec)
79
+ sys.modules[module_name] = module
80
+ spec.loader.exec_module(module)
81
+
82
+ # 查找 ExportPlugin 子类
83
+ for attr_name in dir(module):
84
+ attr = getattr(module, attr_name)
85
+ if (isinstance(attr, type) and
86
+ issubclass(attr, ExportPlugin) and
87
+ attr is not ExportPlugin):
88
+ return attr()
89
+
90
+ return None
91
+
92
+ except Exception as e:
93
+ logger.error(f"加载插件文件失败: {filepath}, {e}", exc_info=True)
94
+ return None
src/export_plugins/simple_export.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ 简单单字导出插件
4
+
5
+ 从TextGrid提取分词片段,按拼音排序导出
6
+ """
7
+
8
+ import os
9
+ import json
10
+ import glob
11
+ import shutil
12
+ import logging
13
+ from typing import Any, Dict, List, Tuple
14
+
15
+ from .base import ExportPlugin, PluginOption, OptionType
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ class SimpleExportPlugin(ExportPlugin):
21
+ """简单单字导出插件"""
22
+
23
+ name = "简单单字导出"
24
+ description = "从TextGrid提取分词片段,按拼音排序导出"
25
+ version = "1.1.0"
26
+ author = "内置"
27
+
28
+ def get_options(self) -> List[PluginOption]:
29
+ return [
30
+ PluginOption(
31
+ key="info",
32
+ label="将每个汉字按拼音分类,选取最佳样本导出",
33
+ option_type=OptionType.LABEL
34
+ ),
35
+ PluginOption(
36
+ key="max_samples",
37
+ label="每个拼音最大样本数",
38
+ option_type=OptionType.NUMBER,
39
+ default=10,
40
+ min_value=1,
41
+ max_value=1000,
42
+ description="按时长排序,保留最长的N个"
43
+ ),
44
+ PluginOption(
45
+ key="naming_rule",
46
+ label="命名规则",
47
+ option_type=OptionType.TEXT,
48
+ default="%p%%n%",
49
+ description="变量: %p%=拼音, %n%=序号。示例: %p%_%n% → ba_1.wav"
50
+ ),
51
+ PluginOption(
52
+ key="first_naming_rule",
53
+ label="首个样本命名规则",
54
+ option_type=OptionType.TEXT,
55
+ default="%p%",
56
+ description="第0个样本的特殊规则,留空则使用通用规则。示例: %p% → ba.wav"
57
+ ),
58
+ PluginOption(
59
+ key="clean_temp",
60
+ label="导出后清理临时文件",
61
+ option_type=OptionType.SWITCH,
62
+ default=True,
63
+ description="删除临时的segments目录"
64
+ )
65
+ ]
66
+
67
+ def _load_language_from_meta(self, bank_dir: str, source_name: str) -> str:
68
+ """从meta.json加载语言设置"""
69
+ meta_path = os.path.join(bank_dir, source_name, "meta.json")
70
+ try:
71
+ if os.path.exists(meta_path):
72
+ with open(meta_path, 'r', encoding='utf-8') as f:
73
+ meta = json.load(f)
74
+ language = meta.get("language", "chinese")
75
+ self._log(f"从meta.json读取语言设置: {language}")
76
+ return language
77
+ except Exception as e:
78
+ logger.warning(f"读取meta.json失败: {e}")
79
+ return "chinese"
80
+
81
+ def _apply_naming_rule(self, rule: str, pinyin: str, index: int) -> str:
82
+ """应用命名规则生成文件名"""
83
+ name = rule.replace("%p%", pinyin).replace("%n%", str(index))
84
+ return name
85
+
86
+ def export(
87
+ self,
88
+ source_name: str,
89
+ bank_dir: str,
90
+ options: Dict[str, Any]
91
+ ) -> Tuple[bool, str]:
92
+ """执行简单单字导出"""
93
+ try:
94
+ # 自动从meta.json获取语言设置
95
+ language = self._load_language_from_meta(bank_dir, source_name)
96
+ max_samples = int(options.get("max_samples", 10))
97
+ naming_rule = options.get("naming_rule", "%p%_%n%")
98
+ first_naming_rule = options.get("first_naming_rule", "")
99
+ clean_temp = options.get("clean_temp", True)
100
+
101
+ paths = self.get_source_paths(bank_dir, source_name)
102
+ export_dir = self.get_export_dir(bank_dir, source_name, "simple_export")
103
+
104
+ # 临时segments目录
105
+ temp_base = os.path.join(bank_dir, ".temp_segments")
106
+ segments_dir = os.path.join(temp_base, source_name)
107
+
108
+ # 步骤1: 提取分词片段
109
+ self._log("【提取分词片段】")
110
+ success, msg, pinyin_counts = self._extract_segments(
111
+ paths["slices_dir"],
112
+ paths["textgrid_dir"],
113
+ segments_dir,
114
+ language
115
+ )
116
+ if not success:
117
+ return False, msg
118
+
119
+ # 步骤2: 排序导出
120
+ self._log("\n【排序导出】")
121
+ success, msg = self._sort_and_export(
122
+ segments_dir,
123
+ export_dir,
124
+ max_samples,
125
+ naming_rule,
126
+ first_naming_rule
127
+ )
128
+ if not success:
129
+ return False, msg
130
+
131
+ # 清理临时目录
132
+ if clean_temp and os.path.exists(segments_dir):
133
+ self._log(f"\n清理临时目录: {segments_dir}")
134
+ shutil.rmtree(segments_dir)
135
+ if os.path.exists(temp_base) and not os.listdir(temp_base):
136
+ shutil.rmtree(temp_base)
137
+
138
+ return True, f"导出完成: {export_dir}"
139
+
140
+ except Exception as e:
141
+ logger.error(f"简单单字导出失败: {e}", exc_info=True)
142
+ return False, str(e)
143
+
144
+ def _extract_segments(
145
+ self,
146
+ slices_dir: str,
147
+ textgrid_dir: str,
148
+ segments_dir: str,
149
+ language: str
150
+ ) -> Tuple[bool, str, Dict[str, int]]:
151
+ """提取分词片段"""
152
+ try:
153
+ import textgrid
154
+ import soundfile as sf
155
+ from src.text_processor import char_to_pinyin, is_valid_char
156
+
157
+ os.makedirs(segments_dir, exist_ok=True)
158
+
159
+ tg_files = glob.glob(os.path.join(textgrid_dir, '*.TextGrid'))
160
+ if not tg_files:
161
+ return False, "未找到TextGrid文件", {}
162
+
163
+ self._log(f"处理 {len(tg_files)} 个TextGrid文件")
164
+
165
+ # 使用全局计数器避免重复
166
+ pinyin_counts: Dict[str, int] = {}
167
+
168
+ for tg_path in tg_files:
169
+ basename = os.path.basename(tg_path).replace('.TextGrid', '.wav')
170
+ wav_path = os.path.join(slices_dir, basename)
171
+
172
+ if not os.path.exists(wav_path):
173
+ self._log(f"警告: 找不到 {basename}")
174
+ continue
175
+
176
+ tg = textgrid.TextGrid.fromFile(tg_path)
177
+ audio, sr = sf.read(wav_path, dtype='float32')
178
+
179
+ for interval in tg[0]:
180
+ word_text = interval.mark.strip()
181
+
182
+ if not word_text or word_text in ['', 'SP', 'AP', '<unk>', 'spn', 'sil']:
183
+ continue
184
+
185
+ start_time = interval.minTime
186
+ end_time = interval.maxTime
187
+ duration = end_time - start_time
188
+
189
+ chars = list(word_text)
190
+ valid_chars = [c for c in chars if is_valid_char(c, language)]
191
+
192
+ if not valid_chars:
193
+ continue
194
+
195
+ char_duration = duration / len(valid_chars)
196
+
197
+ for i, char in enumerate(valid_chars):
198
+ pinyin = char_to_pinyin(char, language)
199
+ if not pinyin:
200
+ continue
201
+
202
+ char_start = start_time + i * char_duration
203
+ char_end = char_start + char_duration
204
+
205
+ pinyin_dir = os.path.join(segments_dir, pinyin)
206
+ os.makedirs(pinyin_dir, exist_ok=True)
207
+
208
+ # 使用全局计数器
209
+ current_count = pinyin_counts.get(pinyin, 0)
210
+ index = current_count + 1
211
+ pinyin_counts[pinyin] = index
212
+
213
+ start_sample = int(round(char_start * sr))
214
+ end_sample = int(round(char_end * sr))
215
+ segment = audio[start_sample:end_sample]
216
+
217
+ if len(segment) == 0:
218
+ pinyin_counts[pinyin] = current_count # 回退计数
219
+ continue
220
+
221
+ output_path = os.path.join(pinyin_dir, f'{index}.wav')
222
+ sf.write(output_path, segment, sr, subtype='PCM_16')
223
+
224
+ total = sum(pinyin_counts.values())
225
+ self._log(f"提取完成: {len(pinyin_counts)} 个拼音,共 {total} 个片段")
226
+
227
+ return True, f"提取完成: {len(pinyin_counts)} 个拼音", pinyin_counts
228
+
229
+ except Exception as e:
230
+ logger.error(f"提取分词失败: {e}", exc_info=True)
231
+ return False, str(e), {}
232
+
233
+ def _sort_and_export(
234
+ self,
235
+ segments_dir: str,
236
+ export_dir: str,
237
+ max_samples: int,
238
+ naming_rule: str,
239
+ first_naming_rule: str
240
+ ) -> Tuple[bool, str]:
241
+ """排序并导出"""
242
+ try:
243
+ import soundfile as sf
244
+
245
+ os.makedirs(export_dir, exist_ok=True)
246
+
247
+ # 清空已有导出
248
+ for f in os.listdir(export_dir):
249
+ fp = os.path.join(export_dir, f)
250
+ if os.path.isfile(fp):
251
+ os.remove(fp)
252
+
253
+ wav_files = glob.glob(
254
+ os.path.join(segments_dir, '**', '*.wav'),
255
+ recursive=True
256
+ )
257
+
258
+ if not wav_files:
259
+ return False, "未找到分字片段"
260
+
261
+ self._log(f"扫描到 {len(wav_files)} 个片段")
262
+
263
+ # 按拼音分组
264
+ stats: Dict[str, List[Tuple[str, float]]] = {}
265
+ for path in wav_files:
266
+ rel_path = os.path.relpath(path, segments_dir)
267
+ parts = rel_path.split(os.sep)
268
+ if len(parts) >= 2:
269
+ pinyin = parts[0]
270
+ if pinyin not in stats:
271
+ stats[pinyin] = []
272
+ info = sf.info(path)
273
+ stats[pinyin].append((path, info.duration))
274
+
275
+ self._log(f"统计到 {len(stats)} 个拼音")
276
+ self._log(f"命名规则: {naming_rule}")
277
+ if first_naming_rule:
278
+ self._log(f"首个样本规则: {first_naming_rule}")
279
+
280
+ # 按时长排序并导出
281
+ exported = 0
282
+ for pinyin, files in stats.items():
283
+ sorted_files = sorted(files, key=lambda x: -x[1])
284
+ for idx, (src_path, _) in enumerate(sorted_files[:max_samples]):
285
+ # 第0个样本使用特殊规则(如果设置了)
286
+ if idx == 0 and first_naming_rule:
287
+ filename = self._apply_naming_rule(first_naming_rule, pinyin, idx)
288
+ else:
289
+ filename = self._apply_naming_rule(naming_rule, pinyin, idx)
290
+
291
+ dst_path = os.path.join(export_dir, f'{filename}.wav')
292
+ shutil.copyfile(src_path, dst_path)
293
+ exported += 1
294
+
295
+ self._log(f"导出完成: {exported} 个文件")
296
+ return True, f"导出完成: {len(stats)} 个拼音,{exported} 个文件"
297
+
298
+ except Exception as e:
299
+ logger.error(f"排序导出失败: {e}", exc_info=True)
300
+ return False, str(e)
src/gui.py CHANGED
@@ -10,6 +10,7 @@ import threading
10
  import logging
11
  import os
12
  import sys
 
13
 
14
  # 配置日志
15
  logging.basicConfig(
@@ -24,510 +25,251 @@ ctk.set_appearance_mode("System")
24
  ctk.set_default_color_theme("blue")
25
 
26
 
27
- class TextGridToBankFrame(ctk.CTkFrame):
28
- """TextGrid转音频库功能面板"""
29
 
30
- def __init__(self, master, log_callback):
31
- super().__init__(master)
32
- self.log_callback = log_callback
33
- self._setup_ui()
34
 
35
- def _setup_ui(self):
36
- # WAV目录
37
- ctk.CTkLabel(self, text="① WAV文件目录:").grid(row=0, column=0, padx=10, pady=5, sticky="w")
38
- self.wav_dir_var = ctk.StringVar()
39
- ctk.CTkEntry(self, textvariable=self.wav_dir_var, width=400).grid(row=0, column=1, padx=5, pady=5)
40
- ctk.CTkButton(self, text="浏览", width=60, command=self._browse_wav_dir).grid(row=0, column=2, padx=5, pady=5)
41
-
42
- # TextGrid目录
43
- ctk.CTkLabel(self, text="② TextGrid目录:").grid(row=1, column=0, padx=10, pady=5, sticky="w")
44
- self.tg_dir_var = ctk.StringVar()
45
- ctk.CTkEntry(self, textvariable=self.tg_dir_var, width=400).grid(row=1, column=1, padx=5, pady=5)
46
- ctk.CTkButton(self, text="浏览", width=60, command=self._browse_tg_dir).grid(row=1, column=2, padx=5, pady=5)
47
-
48
- # 输出目录
49
- ctk.CTkLabel(self, text="③ 输出目录:").grid(row=2, column=0, padx=10, pady=5, sticky="w")
50
- self.save_dir_var = ctk.StringVar(value="bank")
51
- ctk.CTkEntry(self, textvariable=self.save_dir_var, width=400).grid(row=2, column=1, padx=5, pady=5)
52
- ctk.CTkButton(self, text="浏览", width=60, command=self._browse_save_dir).grid(row=2, column=2, padx=5, pady=5)
53
-
54
- # 执行按钮
55
- ctk.CTkButton(self, text="④ 开始转换", command=self._run).grid(row=3, column=1, pady=20)
56
-
57
- def _browse_wav_dir(self):
58
- path = filedialog.askdirectory(title="选择WAV文件目录")
59
- if path:
60
- self.wav_dir_var.set(path)
61
-
62
- def _browse_tg_dir(self):
63
- path = filedialog.askdirectory(title="选择TextGrid目录")
64
- if path:
65
- self.tg_dir_var.set(path)
66
-
67
- def _browse_save_dir(self):
68
- path = filedialog.askdirectory(title="选择输出目录")
69
- if path:
70
- self.save_dir_var.set(path)
71
-
72
- def _run(self):
73
- wav_dir = self.wav_dir_var.get()
74
- tg_dir = self.tg_dir_var.get()
75
- save_dir = self.save_dir_var.get()
76
-
77
- if not wav_dir or not tg_dir or not save_dir:
78
- messagebox.showerror("错误", "请填写所有目录路径")
79
- return
80
-
81
- threading.Thread(target=self._process, args=(wav_dir, tg_dir, save_dir), daemon=True).start()
82
-
83
- def _process(self, wav_dir, tg_dir, save_dir):
84
- import textgrid
85
- import glob
86
- import audiofile
87
-
88
- self.log_callback("开始TextGrid转音频库...")
89
- logger.info(f"WAV目录: {wav_dir}, TextGrid目录: {tg_dir}, 输出目录: {save_dir}")
90
-
91
- try:
92
- if not os.path.exists(save_dir):
93
- os.makedirs(save_dir)
94
-
95
- tg_files = glob.glob(os.path.join(tg_dir, '*.TextGrid'))
96
- total = len(tg_files)
97
-
98
- for i, path in enumerate(tg_files):
99
- basename = os.path.basename(path).replace('.TextGrid', '.wav')
100
- wav_path = os.path.join(wav_dir, basename)
101
-
102
- if not os.path.exists(wav_path):
103
- self.log_callback(f"警告: 找不到对应WAV文件 {wav_path}")
104
- continue
105
-
106
- tg = textgrid.TextGrid.fromFile(path)
107
- audio, sr = audiofile.read(wav_path)
108
-
109
- for word in tg[0]:
110
- if word.mark in ['SP', 'AP', '']:
111
- continue
112
-
113
- word_text = word.mark.split(':')[0]
114
- word_dir = os.path.join(save_dir, word_text)
115
-
116
- if not os.path.exists(word_dir):
117
- os.makedirs(word_dir)
118
-
119
- index = 1
120
- while True:
121
- filename = os.path.join(word_dir, f'{index}.wav')
122
- if not os.path.exists(filename):
123
- break
124
- index += 1
125
-
126
- start_sample = int(word.minTime * sr)
127
- end_sample = int(word.maxTime * sr)
128
- audiofile.write(filename, audio[start_sample:end_sample], sr)
129
-
130
- self.log_callback(f"进度: {i+1}/{total} - {basename}")
131
-
132
- self.log_callback("TextGrid转音频库完成!")
133
- logger.info("TextGrid转音频库处理完成")
134
- except Exception as e:
135
- self.log_callback(f"错误: {str(e)}")
136
- logger.error(f"处理失败: {e}", exc_info=True)
137
-
138
-
139
- class BankSortFrame(ctk.CTkFrame):
140
- """音频库排序功能面板"""
141
-
142
- def __init__(self, master, log_callback):
143
- super().__init__(master)
144
- self.log_callback = log_callback
145
- self._setup_ui()
146
-
147
- def _setup_ui(self):
148
- # 音频库目录
149
- ctk.CTkLabel(self, text="① 音频库目录:").grid(row=0, column=0, padx=10, pady=5, sticky="w")
150
- self.bank_dir_var = ctk.StringVar(value="bank")
151
- ctk.CTkEntry(self, textvariable=self.bank_dir_var, width=400).grid(row=0, column=1, padx=5, pady=5)
152
- ctk.CTkButton(self, text="浏览", width=60, command=self._browse_bank_dir).grid(row=0, column=2, padx=5, pady=5)
153
-
154
- # 最大数量
155
- ctk.CTkLabel(self, text="② 每词最大数量:").grid(row=1, column=0, padx=10, pady=5, sticky="w")
156
- self.max_count_var = ctk.StringVar(value="100")
157
- ctk.CTkEntry(self, textvariable=self.max_count_var, width=100).grid(row=1, column=1, padx=5, pady=5, sticky="w")
158
-
159
- # 执行按钮
160
- ctk.CTkButton(self, text="③ 开始排序", command=self._run).grid(row=2, column=1, pady=20)
161
-
162
- def _browse_bank_dir(self):
163
- path = filedialog.askdirectory(title="选择音频库目录")
164
- if path:
165
- self.bank_dir_var.set(path)
166
-
167
- def _run(self):
168
- bank_dir = self.bank_dir_var.get()
169
- try:
170
- max_count = int(self.max_count_var.get())
171
- except ValueError:
172
- messagebox.showerror("错误", "最大数量必须是整数")
173
- return
174
-
175
- if not bank_dir:
176
- messagebox.showerror("错误", "请选择音频库目录")
177
- return
178
-
179
- threading.Thread(target=self._process, args=(bank_dir, max_count), daemon=True).start()
180
-
181
- def _process(self, bank_dir, max_count):
182
- import glob
183
- import audiofile
184
- import shutil
185
-
186
- self.log_callback("开始音频库排序...")
187
- logger.info(f"音频库目录: {bank_dir}, 最大数量: {max_count}")
188
-
189
- try:
190
- stats = {}
191
- wav_files = glob.glob(os.path.join(bank_dir, '**', '*.wav'), recursive=True)
192
-
193
- self.log_callback(f"扫描到 {len(wav_files)} 个WAV文件")
194
-
195
- for path in wav_files:
196
- rel_path = os.path.relpath(path, bank_dir)
197
- parts = rel_path.split(os.sep)
198
- if len(parts) >= 2:
199
- word = parts[0]
200
- filename = parts[-1]
201
- if word not in stats:
202
- stats[word] = []
203
- stats[word].append((path, audiofile.duration(path)))
204
-
205
- self.log_callback(f"统计到 {len(stats)} 个词条")
206
-
207
- for word in stats:
208
- sorted_files = sorted(stats[word], key=lambda x: -x[1])
209
- for index, (src_path, duration) in enumerate(sorted_files):
210
- if index >= max_count:
211
- break
212
- dst_path = os.path.join(bank_dir, f'{word}_{index}.wav')
213
- shutil.copyfile(src_path, dst_path)
214
- self.log_callback(f"处理词条: {word} ({min(len(sorted_files), max_count)} 个文件)")
215
-
216
- self.log_callback("音频库排序完成!")
217
- logger.info("音频库排序处理完成")
218
- except Exception as e:
219
- self.log_callback(f"错误: {str(e)}")
220
- logger.error(f"处理失败: {e}", exc_info=True)
221
-
222
-
223
- class ModelDownloadFrame(ctk.CTkFrame):
224
- """模型配置功能面板"""
225
-
226
- # Whisper 模型选项
227
  WHISPER_MODELS = {
228
- "whisper-small": {
229
- "name": "openai/whisper-small",
230
- "desc": "小型模型,约500MB,速度快",
231
- "size": "~500MB"
232
- },
233
- "whisper-medium": {
234
- "name": "openai/whisper-medium",
235
- "desc": "中型模型,约1.5GB,精度更高",
236
- "size": "~1.5GB"
237
- }
238
  }
239
 
240
- # 配置文件路径
241
- CONFIG_FILE = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "config.json")
242
-
243
- def __init__(self, master, log_callback):
244
- super().__init__(master)
245
- self.log_callback = log_callback
246
- self.whisper_pipe = None
247
- self._download_thread = None
248
- self._load_config()
249
- self._setup_ui()
250
-
251
- def _get_default_models_dir(self):
252
- """获取默认模型目录"""
253
- return os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "models")
254
 
255
- def _load_config(self):
256
  """加载配置"""
257
- self.config = {
258
  "whisper_model": "whisper-small",
259
- "models_dir": self._get_default_models_dir(),
260
- "mfa_dir": os.path.join(self._get_default_models_dir(), "mfa")
 
 
 
 
 
261
  }
262
-
263
  if os.path.exists(self.CONFIG_FILE):
264
  try:
265
- import json
266
  with open(self.CONFIG_FILE, 'r', encoding='utf-8') as f:
267
- saved = json.load(f)
268
- self.config.update(saved)
269
- logger.info(f"已加载配置: {self.CONFIG_FILE}")
270
  except Exception as e:
271
  logger.warning(f"加载配置失败: {e}")
 
272
 
273
- def _save_config(self):
274
  """保存配置"""
275
  try:
276
- import json
277
  with open(self.CONFIG_FILE, 'w', encoding='utf-8') as f:
278
  json.dump(self.config, f, ensure_ascii=False, indent=2)
279
- logger.info(f"配置已保存: {self.CONFIG_FILE}")
280
  except Exception as e:
281
  logger.error(f"保存配置失败: {e}")
282
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
  def _setup_ui(self):
284
- # ========== Whisper 模型区域 ==========
285
- whisper_label = ctk.CTkLabel(
286
- self,
287
- text="Whisper 语音识别模型",
288
- font=ctk.CTkFont(size=14, weight="bold")
289
  )
290
- whisper_label.grid(row=0, column=0, columnspan=3, padx=10, pady=(10, 5), sticky="w")
291
 
292
- # 模型选择
293
  ctk.CTkLabel(self, text="模型版本:").grid(row=1, column=0, padx=10, pady=5, sticky="w")
294
- self.whisper_model_var = ctk.StringVar(value=self.config["whisper_model"])
295
- self.model_dropdown = ctk.CTkComboBox(
296
- self,
297
- values=list(self.WHISPER_MODELS.keys()),
298
- variable=self.whisper_model_var,
299
- width=200,
300
  command=self._on_model_change
301
- )
302
- self.model_dropdown.grid(row=1, column=1, padx=5, pady=5, sticky="w")
303
 
304
- # 模型说明
305
- self.model_desc_label = ctk.CTkLabel(
306
- self,
307
- text=self._get_model_desc(),
308
- text_color="gray"
309
- )
310
  self.model_desc_label.grid(row=1, column=2, padx=10, pady=5, sticky="w")
311
 
312
- # 下载目录
313
- ctk.CTkLabel(self, text="下载目录:").grid(row=2, column=0, padx=10, pady=5, sticky="w")
314
- self.models_dir_var = ctk.StringVar(value=self.config["models_dir"])
315
  ctk.CTkEntry(self, textvariable=self.models_dir_var, width=320).grid(row=2, column=1, padx=5, pady=5, sticky="w")
316
  ctk.CTkButton(self, text="浏览", width=60, command=self._browse_models_dir).grid(row=2, column=2, padx=5, pady=5, sticky="w")
317
 
318
- # Whisper 状态和按钮
319
  ctk.CTkLabel(self, text="状态:").grid(row=3, column=0, padx=10, pady=5, sticky="w")
320
  self.whisper_status = ctk.CTkLabel(self, text="⏳ 未加载", text_color="gray")
321
  self.whisper_status.grid(row=3, column=1, padx=5, pady=5, sticky="w")
322
-
323
- self.whisper_btn = ctk.CTkButton(
324
- self,
325
- text="下载 / 加载模型",
326
- command=self._download_whisper,
327
- width=140
328
- )
329
  self.whisper_btn.grid(row=3, column=2, padx=5, pady=5, sticky="w")
330
 
331
- # 下载进度
332
  self.progress_label = ctk.CTkLabel(self, text="", text_color="gray")
333
  self.progress_label.grid(row=4, column=0, columnspan=3, padx=10, pady=5, sticky="w")
334
 
335
- # ========== MFA 模型区域 ==========
336
- mfa_label = ctk.CTkLabel(
337
- self,
338
- text="MFA 声学模型",
339
- font=ctk.CTkFont(size=14, weight="bold")
 
340
  )
341
- mfa_label.grid(row=5, column=0, columnspan=3, padx=10, pady=(20, 5), sticky="w")
342
 
343
- mfa_desc = ctk.CTkLabel(
344
- self,
345
- text="Montreal Forced Aligner 模型,用于语音对齐",
346
- text_color="gray"
 
 
 
 
 
 
 
 
347
  )
348
- mfa_desc.grid(row=6, column=0, columnspan=3, padx=10, pady=(0, 10), sticky="w")
349
 
350
- # MFA 模型目录
351
- ctk.CTkLabel(self, text="模型目录:").grid(row=7, column=0, padx=10, pady=5, sticky="w")
352
- self.mfa_dir_var = ctk.StringVar(value=self.config["mfa_dir"])
353
- ctk.CTkEntry(self, textvariable=self.mfa_dir_var, width=320).grid(row=7, column=1, padx=5, pady=5, sticky="w")
354
- ctk.CTkButton(self, text="浏览", width=60, command=self._browse_mfa_dir).grid(row=7, column=2, padx=5, pady=5)
355
 
356
- # MFA 语言选择
357
- ctk.CTkLabel(self, text="选择语言:").grid(row=8, column=0, padx=10, pady=5, sticky="w")
358
  self.mfa_lang_var = ctk.StringVar(value="mandarin")
359
- self.mfa_lang_dropdown = ctk.CTkComboBox(
360
- self,
361
- values=["mandarin", "japanese"],
362
- variable=self.mfa_lang_var,
363
- width=200,
364
  command=self._on_mfa_lang_change
365
- )
366
- self.mfa_lang_dropdown.grid(row=8, column=1, padx=5, pady=5, sticky="w")
367
-
368
  self.mfa_lang_desc = ctk.CTkLabel(self, text="中文 (普通话)", text_color="gray")
369
- self.mfa_lang_desc.grid(row=8, column=2, padx=5, pady=5, sticky="w")
370
 
371
- # MFA 下载按钮和状态
372
- ctk.CTkLabel(self, text="状态:").grid(row=9, column=0, padx=10, pady=5, sticky="w")
373
  self.mfa_status = ctk.CTkLabel(self, text="⏳ 未下载", text_color="gray")
374
- self.mfa_status.grid(row=9, column=1, padx=5, pady=5, sticky="w")
375
-
376
- self.mfa_download_btn = ctk.CTkButton(
377
- self,
378
- text="下载模型",
379
- command=self._download_mfa_models,
380
- width=140
381
- )
382
- self.mfa_download_btn.grid(row=9, column=2, padx=5, pady=5, sticky="w")
383
 
384
- # MFA 文件列表
385
- ctk.CTkLabel(self, text="已有文件:").grid(row=10, column=0, padx=10, pady=(10, 5), sticky="nw")
386
- self.mfa_files_text = ctk.CTkTextbox(self, height=70, width=400)
387
- self.mfa_files_text.grid(row=10, column=1, columnspan=2, padx=5, pady=(10, 5), sticky="w")
388
- self.mfa_files_text.insert("end", "选择目录后显示文件列表")
389
- self.mfa_files_text.configure(state="disabled")
390
-
391
- # 初始扫描
392
- self._scan_mfa_dir()
393
-
394
  def _get_model_desc(self):
395
- """获取当前选中模型的描述"""
396
- model_key = self.whisper_model_var.get()
397
- info = self.WHISPER_MODELS.get(model_key, {})
398
- return f"{info.get('desc', '')} ({info.get('size', '')})"
399
 
400
  def _on_model_change(self, choice):
401
- """模型选择变更"""
402
  self.model_desc_label.configure(text=self._get_model_desc())
403
- self.config["whisper_model"] = choice
404
- self._save_config()
405
- # 重置状态
406
  self.whisper_status.configure(text="⏳ 未加载", text_color="gray")
407
  self.whisper_pipe = None
408
 
409
  def _browse_models_dir(self):
410
- """浏览选择模型下载目录"""
411
  path = filedialog.askdirectory(title="选择模型下载目录")
412
  if path:
413
  self.models_dir_var.set(path)
414
- self.config["models_dir"] = path
415
- self._save_config()
416
 
417
  def _browse_mfa_dir(self):
418
- """浏览选择 MFA 模型目录"""
419
  path = filedialog.askdirectory(title="选择 MFA 模型目录")
420
  if path:
421
  self.mfa_dir_var.set(path)
422
- self.config["mfa_dir"] = path
423
- self._save_config()
424
- self._scan_mfa_dir()
425
 
426
  def _on_mfa_lang_change(self, choice):
427
- """MFA 语言选择变更"""
428
  from src.mfa_model_downloader import get_available_languages
429
- langs = get_available_languages()
430
- self.mfa_lang_desc.configure(text=langs.get(choice, ""))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
431
 
432
  def _download_mfa_models(self):
433
- """下载 MFA 模型"""
434
  if self._download_thread and self._download_thread.is_alive():
435
  return
436
-
437
  self.mfa_download_btn.configure(state="disabled")
438
  self.mfa_status.configure(text="⏳ 下载中...", text_color="gray")
439
  self._download_thread = threading.Thread(target=self._do_download_mfa, daemon=True)
440
  self._download_thread.start()
441
 
442
  def _do_download_mfa(self):
443
- """执行 MFA 模型下载(后台线程)"""
444
  from src.mfa_model_downloader import download_language_models
445
-
446
  language = self.mfa_lang_var.get()
447
  output_dir = self.mfa_dir_var.get()
448
-
449
- # 确保目录存在
450
- if not os.path.exists(output_dir):
451
- os.makedirs(output_dir)
452
-
453
  self.log_callback(f"开始下载 MFA 模型: {language}")
454
-
455
  success, acoustic_path, dict_path = download_language_models(
456
- language=language,
457
- output_dir=output_dir,
458
- progress_callback=self.log_callback
459
  )
460
-
461
  if success:
462
  self.after(0, lambda: self.mfa_status.configure(text="✅ 已下载", text_color="green"))
463
  self.log_callback(f"声学模型: {acoustic_path}")
464
  self.log_callback(f"字典文件: {dict_path}")
465
  else:
466
  self.after(0, lambda: self.mfa_status.configure(text="❌ 下载失败", text_color="red"))
467
-
468
  self.after(0, lambda: self.mfa_download_btn.configure(state="normal"))
469
- self.after(0, self._scan_mfa_dir)
470
-
471
- def _scan_mfa_dir(self):
472
- """扫描 MFA 模型目录"""
473
- mfa_dir = self.mfa_dir_var.get()
474
-
475
- self.mfa_files_text.configure(state="normal")
476
- self.mfa_files_text.delete("1.0", "end")
477
-
478
- if not os.path.exists(mfa_dir):
479
- self.mfa_files_text.insert("end", "目录不存在")
480
- else:
481
- files = []
482
- for f in os.listdir(mfa_dir):
483
- if f.endswith(('.zip', '.dict', '.txt')):
484
- fpath = os.path.join(mfa_dir, f)
485
- size = os.path.getsize(fpath)
486
- size_str = f"{size / 1024 / 1024:.1f}MB" if size > 1024 * 1024 else f"{size / 1024:.0f}KB"
487
- files.append(f"• {f} ({size_str})")
488
-
489
- if files:
490
- self.mfa_files_text.insert("end", "\n".join(files))
491
- else:
492
- self.mfa_files_text.insert("end", "目录为空,请手动放入 MFA 模型文件")
493
-
494
- self.mfa_files_text.configure(state="disabled")
495
 
496
  def _download_whisper(self):
497
- """下载/加载 Whisper 模型"""
498
  if self._download_thread and self._download_thread.is_alive():
499
  return
500
-
501
  self.whisper_btn.configure(state="disabled")
502
  self.whisper_status.configure(text="⏳ 加载中...", text_color="gray")
503
  self._download_thread = threading.Thread(target=self._do_download_whisper, daemon=True)
504
  self._download_thread.start()
505
 
506
  def _do_download_whisper(self):
507
- """执行 Whisper 模型下载(后台线程)"""
508
  try:
509
  self._update_progress("正在加载 transformers 库...")
510
  from transformers import pipeline
511
  import torch
512
 
513
  model_key = self.whisper_model_var.get()
514
- model_name = self.WHISPER_MODELS[model_key]["name"]
515
  cache_dir = os.path.join(self.models_dir_var.get(), "whisper")
516
-
517
- # 确保目录存在
518
- if not os.path.exists(cache_dir):
519
- os.makedirs(cache_dir)
520
 
521
  self._update_progress(f"正在下载/加载 {model_key}...")
522
  self.log_callback(f"开始加载 Whisper 模型: {model_name}")
523
- self.log_callback(f"缓存目录: {cache_dir}")
524
- logger.info(f"加载 Whisper 模型: {model_name}, 缓存目录: {cache_dir}")
525
 
526
- # 设置环境变量指定缓存目录
527
  os.environ["HF_HOME"] = cache_dir
528
  os.environ["TRANSFORMERS_CACHE"] = cache_dir
529
 
530
- # 加载模型
531
  self.whisper_pipe = pipeline(
532
  "automatic-speech-recognition",
533
  model=model_name,
@@ -540,185 +282,872 @@ class ModelDownloadFrame(ctk.CTkFrame):
540
  self.after(0, lambda: self.whisper_status.configure(text="✅ 已就绪", text_color="green"))
541
  self.after(0, lambda: self.whisper_btn.configure(state="normal", text="重新加载"))
542
  self.log_callback("Whisper 模型加载完成")
543
- logger.info("Whisper 模型加载成功")
544
-
545
  except Exception as e:
546
- error_msg = str(e)
547
  self._update_progress("")
548
  self.after(0, lambda: self.whisper_status.configure(text="❌ 加载失败", text_color="red"))
549
  self.after(0, lambda: self.whisper_btn.configure(state="normal"))
550
- self.log_callback(f"Whisper 模型加载失败: {error_msg}")
551
  logger.error(f"Whisper 模型加载失败: {e}", exc_info=True)
552
 
553
  def _update_progress(self, text):
554
- """更新进度文本(线程安全)"""
555
  self.after(0, lambda: self.progress_label.configure(text=text))
556
 
557
  def get_whisper_pipeline(self):
558
- """获取 Whisper pipeline(供其他模块调用)"""
559
  return self.whisper_pipe
560
 
 
 
 
561
  def get_mfa_dir(self):
562
- """获取 MFA 模型目录路径(供其他模块调用)"""
563
  return self.mfa_dir_var.get()
 
 
 
564
 
565
 
566
- class MakeDatasetFrame(ctk.CTkFrame):
567
- """批量制作数据集功能"""
568
 
569
- def __init__(self, master, log_callback):
570
  super().__init__(master)
571
  self.log_callback = log_callback
 
 
572
  self._is_running = False
573
  self._setup_ui()
574
  self._check_mfa_status()
575
 
576
  def _setup_ui(self):
577
- # MFA 状态提示
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
578
  self.mfa_status_label = ctk.CTkLabel(
579
- self,
580
- text="⏳ 检查 MFA 环境...",
581
  font=ctk.CTkFont(size=12)
582
  )
583
- self.mfa_status_label.grid(row=0, column=0, columnspan=3, padx=10, pady=(10, 5), sticky="w")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
584
 
585
- # 数据集原始目录
586
- ctk.CTkLabel(self, text="① 切片及LAB目录:").grid(row=1, column=0, padx=10, pady=5, sticky="w")
587
- self.raw_dir_var = ctk.StringVar()
588
- ctk.CTkEntry(self, textvariable=self.raw_dir_var, width=400).grid(row=1, column=1, padx=5, pady=5)
589
- ctk.CTkButton(self, text="浏览", width=60, command=self._browse_raw_dir).grid(row=1, column=2, padx=5, pady=5)
590
 
591
- # 输出目录
592
- ctk.CTkLabel(self, text="② TextGrid输出目录:").grid(row=2, column=0, padx=10, pady=5, sticky="w")
593
- self.output_dir_var = ctk.StringVar()
594
- ctk.CTkEntry(self, textvariable=self.output_dir_var, width=400).grid(row=2, column=1, padx=5, pady=5)
595
- ctk.CTkButton(self, text="浏览", width=60, command=self._browse_output_dir).grid(row=2, column=2, padx=5, pady=5)
596
-
597
- # 字典路径
598
- ctk.CTkLabel(self, text="③ 字典文件:").grid(row=3, column=0, padx=10, pady=5, sticky="w")
599
- self.dict_path_var = ctk.StringVar(value="models/mfa/mandarin_china_mfa.dict")
600
- ctk.CTkEntry(self, textvariable=self.dict_path_var, width=400).grid(row=3, column=1, padx=5, pady=5)
601
- ctk.CTkButton(self, text="浏览", width=60, command=self._browse_dict).grid(row=3, column=2, padx=5, pady=5)
602
-
603
- # MFA模型路径
604
- ctk.CTkLabel(self, text="④ MFA模型文件:").grid(row=4, column=0, padx=10, pady=5, sticky="w")
605
- self.mfa_model_var = ctk.StringVar(value="models/mfa/mandarin_mfa.zip")
606
- ctk.CTkEntry(self, textvariable=self.mfa_model_var, width=400).grid(row=4, column=1, padx=5, pady=5)
607
- ctk.CTkButton(self, text="浏览", width=60, command=self._browse_mfa).grid(row=4, column=2, padx=5, pady=5)
608
-
609
- # 选项
610
- options_frame = ctk.CTkFrame(self)
611
- options_frame.grid(row=5, column=0, columnspan=3, padx=10, pady=10, sticky="w")
612
-
613
- self.single_speaker_var = ctk.BooleanVar(value=True)
614
- ctk.CTkCheckBox(
615
- options_frame,
616
- text="单说话人模式",
617
- variable=self.single_speaker_var
618
- ).pack(side="left", padx=10)
619
-
620
- self.clean_var = ctk.BooleanVar(value=True)
621
- ctk.CTkCheckBox(
622
- options_frame,
623
- text="清理旧缓存",
624
- variable=self.clean_var
625
- ).pack(side="left", padx=10)
626
-
627
- # 执行按钮
628
- self.run_btn = ctk.CTkButton(self, text="⑤ 开始对齐", command=self._run)
629
- self.run_btn.grid(row=6, column=1, pady=20)
630
 
631
  def _check_mfa_status(self):
632
- """检查 MFA 环境状态"""
633
  from src.mfa_runner import check_mfa_available
634
-
635
  if check_mfa_available():
636
- self.mfa_status_label.configure(
637
- text="✅ MFA 外挂环境已就绪 (tools/mfa_engine)",
638
- text_color="green"
639
- )
640
  else:
641
- self.mfa_status_label.configure(
642
- text="❌ MFA 外挂环境不可用,请检查 tools/mfa_engine 目录",
643
- text_color="red"
644
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
645
 
646
- def _browse_raw_dir(self):
647
- path = filedialog.askdirectory(title="选择切片及LAB目录")
 
 
 
 
 
 
 
 
648
  if path:
649
- self.raw_dir_var.set(path)
650
 
651
  def _browse_output_dir(self):
652
- path = filedialog.askdirectory(title="选择TextGrid输出目录")
653
  if path:
654
  self.output_dir_var.set(path)
 
655
 
656
- def _browse_dict(self):
657
- path = filedialog.askopenfilename(
658
- title="选择字典文件",
659
- filetypes=[("字典文件", "*.dict *.txt"), ("所有文件", "*.*")]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
660
  )
661
- if path:
662
- self.dict_path_var.set(path)
663
 
664
- def _browse_mfa(self):
665
- path = filedialog.askopenfilename(
666
- title="选择MFA模型",
667
- filetypes=[("ZIP文件", "*.zip"), ("所有文件", "*.*")]
668
- )
669
- if path:
670
- self.mfa_model_var.set(path)
671
 
672
- def _run(self):
673
  if self._is_running:
674
  return
 
 
 
 
 
 
 
 
 
 
675
 
676
- raw_dir = self.raw_dir_var.get()
677
- output_dir = self.output_dir_var.get()
678
- dict_path = self.dict_path_var.get()
679
- mfa_model = self.mfa_model_var.get()
 
 
 
 
 
680
 
681
- if not raw_dir or not output_dir:
682
- messagebox.showerror("错误", "请填写输入目录和输出目录")
 
 
 
683
  return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
684
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
685
  self._is_running = True
686
- self.run_btn.configure(state="disabled", text="对齐中...")
 
 
 
 
 
 
687
 
688
- threading.Thread(
689
- target=self._process,
690
- args=(raw_dir, output_dir, dict_path, mfa_model),
691
- daemon=True
692
- ).start()
 
 
693
 
694
- def _process(self, raw_dir, output_dir, dict_path, mfa_model):
695
- """执行 MFA 对齐(后台线程)"""
696
- from src.mfa_runner import run_mfa_alignment
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
697
 
698
- self.log_callback("=" * 50)
699
- self.log_callback("开始 MFA 对齐任务")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
700
 
701
- success, message = run_mfa_alignment(
702
- corpus_dir=raw_dir,
703
- output_dir=output_dir,
704
- dict_path=dict_path if dict_path else None,
705
- model_path=mfa_model if mfa_model else None,
706
- single_speaker=self.single_speaker_var.get(),
707
- clean=self.clean_var.get(),
708
- progress_callback=self.log_callback
709
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
710
 
711
  if success:
712
- self.log_callback("✅ MFA 对齐任务完成!")
713
- self.log_callback(f"TextGrid 文件已输出到: {output_dir}")
714
  else:
715
- self.log_callback(f"❌ MFA 对齐失败: {message}")
716
-
717
  self.log_callback("=" * 50)
718
 
719
- # 恢复按钮状态
720
- self.after(0, lambda: self.run_btn.configure(state="normal", text="⑤ 开始对齐"))
 
721
  self._is_running = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
722
 
723
 
724
  class App(ctk.CTk):
@@ -726,55 +1155,58 @@ class App(ctk.CTk):
726
 
727
  def __init__(self):
728
  super().__init__()
729
-
730
  self.title("语音数据集处理工具")
731
- self.geometry("700x600")
732
- self.minsize(600, 500)
733
 
 
734
  self._setup_ui()
735
  logger.info("应用启动")
736
 
737
  def _setup_ui(self):
738
- # 标签页
739
  self.tabview = ctk.CTkTabview(self)
740
  self.tabview.pack(fill="both", expand=True, padx=10, pady=10)
741
 
742
- # 添加标签页(按工作流程顺序排列)
743
- tab1 = self.tabview.add("1. 模型下载")
744
- tab2 = self.tabview.add("2. 批量制作数据集")
745
- tab3 = self.tabview.add("3. TextGrid转音频库")
746
- tab4 = self.tabview.add("4. 音频库排序")
747
 
748
- # 各功能面板
749
- self.download_frame = ModelDownloadFrame(tab1, self._log)
750
  self.download_frame.pack(fill="both", expand=True, padx=5, pady=5)
751
 
752
- self.dataset_frame = MakeDatasetFrame(tab2, self._log)
753
- self.dataset_frame.pack(fill="both", expand=True, padx=5, pady=5)
754
 
755
- self.tg_frame = TextGridToBankFrame(tab3, self._log)
756
- self.tg_frame.pack(fill="both", expand=True, padx=5, pady=5)
757
 
758
- self.sort_frame = BankSortFrame(tab4, self._log)
759
- self.sort_frame.pack(fill="both", expand=True, padx=5, pady=5)
760
 
761
- # 日志区域
762
- log_frame = ctk.CTkFrame(self)
763
- log_frame.pack(fill="x", padx=10, pady=(0, 10))
764
-
765
- ctk.CTkLabel(log_frame, text="日志输出:").pack(anchor="w", padx=5, pady=2)
766
-
767
- self.log_text = ctk.CTkTextbox(log_frame, height=150)
768
  self.log_text.pack(fill="x", padx=5, pady=5)
 
 
 
 
 
 
 
 
 
 
 
769
 
770
  def _log(self, message):
771
- """添加日志消息"""
772
  self.log_text.insert("end", f"{message}\n")
773
  self.log_text.see("end")
774
 
775
 
776
  def main():
777
- """程序入口"""
778
  app = App()
779
  app.mainloop()
780
 
 
10
  import logging
11
  import os
12
  import sys
13
+ import json
14
 
15
  # 配置日志
16
  logging.basicConfig(
 
25
  ctk.set_default_color_theme("blue")
26
 
27
 
28
+ class ConfigManager:
29
+ """配置管理器"""
30
 
31
+ CONFIG_FILE = os.path.join(
32
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
33
+ "config.json"
34
+ )
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  WHISPER_MODELS = {
37
+ "whisper-small": {"name": "openai/whisper-small", "desc": "小型模型 (~500MB)", "size": "~500MB"},
38
+ "whisper-medium": {"name": "openai/whisper-medium", "desc": "中型模型 (~1.5GB)", "size": "~1.5GB"}
 
 
 
 
 
 
 
 
39
  }
40
 
41
+ def __init__(self):
42
+ self._default_models_dir = os.path.join(
43
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
44
+ "models"
45
+ )
46
+ self.config = self._load()
 
 
 
 
 
 
 
 
47
 
48
+ def _load(self) -> dict:
49
  """加载配置"""
50
+ default = {
51
  "whisper_model": "whisper-small",
52
+ "models_dir": self._default_models_dir,
53
+ "mfa_dir": os.path.join(self._default_models_dir, "mfa"),
54
+ "bank_dir": os.path.join(
55
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
56
+ "bank"
57
+ ),
58
+ "show_log": False # 默认关闭日志
59
  }
 
60
  if os.path.exists(self.CONFIG_FILE):
61
  try:
 
62
  with open(self.CONFIG_FILE, 'r', encoding='utf-8') as f:
63
+ default.update(json.load(f))
 
 
64
  except Exception as e:
65
  logger.warning(f"加载配置失败: {e}")
66
+ return default
67
 
68
+ def save(self):
69
  """保存配置"""
70
  try:
 
71
  with open(self.CONFIG_FILE, 'w', encoding='utf-8') as f:
72
  json.dump(self.config, f, ensure_ascii=False, indent=2)
 
73
  except Exception as e:
74
  logger.error(f"保存配置失败: {e}")
75
 
76
+ def get(self, key: str, default=None):
77
+ return self.config.get(key, default)
78
+
79
+ def set(self, key: str, value):
80
+ self.config[key] = value
81
+ self.save()
82
+
83
+
84
+ class ModelDownloadFrame(ctk.CTkFrame):
85
+ """模型配置功能面板"""
86
+
87
+ def __init__(self, master, log_callback, config: ConfigManager):
88
+ super().__init__(master)
89
+ self.log_callback = log_callback
90
+ self.config = config
91
+ self.whisper_pipe = None
92
+ self._download_thread = None
93
+ self._setup_ui()
94
+
95
  def _setup_ui(self):
96
+ # Whisper 模型区域
97
+ ctk.CTkLabel(self, text="Whisper 语音识别模型", font=ctk.CTkFont(size=14, weight="bold")).grid(
98
+ row=0, column=0, columnspan=3, padx=10, pady=(10, 5), sticky="w"
 
 
99
  )
 
100
 
 
101
  ctk.CTkLabel(self, text="模型版本:").grid(row=1, column=0, padx=10, pady=5, sticky="w")
102
+ self.whisper_model_var = ctk.StringVar(value=self.config.get("whisper_model"))
103
+ ctk.CTkComboBox(
104
+ self, values=list(ConfigManager.WHISPER_MODELS.keys()),
105
+ variable=self.whisper_model_var, width=200,
 
 
106
  command=self._on_model_change
107
+ ).grid(row=1, column=1, padx=5, pady=5, sticky="w")
 
108
 
109
+ self.model_desc_label = ctk.CTkLabel(self, text=self._get_model_desc(), text_color="gray")
 
 
 
 
 
110
  self.model_desc_label.grid(row=1, column=2, padx=10, pady=5, sticky="w")
111
 
112
+ ctk.CTkLabel(self, text="模型目录:").grid(row=2, column=0, padx=10, pady=5, sticky="w")
113
+ self.models_dir_var = ctk.StringVar(value=self.config.get("models_dir"))
 
114
  ctk.CTkEntry(self, textvariable=self.models_dir_var, width=320).grid(row=2, column=1, padx=5, pady=5, sticky="w")
115
  ctk.CTkButton(self, text="浏览", width=60, command=self._browse_models_dir).grid(row=2, column=2, padx=5, pady=5, sticky="w")
116
 
 
117
  ctk.CTkLabel(self, text="状态:").grid(row=3, column=0, padx=10, pady=5, sticky="w")
118
  self.whisper_status = ctk.CTkLabel(self, text="⏳ 未加载", text_color="gray")
119
  self.whisper_status.grid(row=3, column=1, padx=5, pady=5, sticky="w")
120
+ self.whisper_btn = ctk.CTkButton(self, text="下载 / 加载模型", command=self._download_whisper, width=140)
 
 
 
 
 
 
121
  self.whisper_btn.grid(row=3, column=2, padx=5, pady=5, sticky="w")
122
 
 
123
  self.progress_label = ctk.CTkLabel(self, text="", text_color="gray")
124
  self.progress_label.grid(row=4, column=0, columnspan=3, padx=10, pady=5, sticky="w")
125
 
126
+ # Silero VAD 区域
127
+ ctk.CTkLabel(self, text="Silero VAD 模型", font=ctk.CTkFont(size=14, weight="bold")).grid(
128
+ row=5, column=0, columnspan=3, padx=10, pady=(20, 5), sticky="w"
129
+ )
130
+ ctk.CTkLabel(self, text="用于语音活动检测和音频切片", text_color="gray").grid(
131
+ row=6, column=0, columnspan=3, padx=10, pady=(0, 10), sticky="w"
132
  )
 
133
 
134
+ ctk.CTkLabel(self, text="状态:").grid(row=7, column=0, padx=10, pady=5, sticky="w")
135
+ self.vad_status = ctk.CTkLabel(self, text="⏳ 未下载", text_color="gray")
136
+ self.vad_status.grid(row=7, column=1, padx=5, pady=5, sticky="w")
137
+ self.vad_btn = ctk.CTkButton(self, text="下载模型", command=self._download_vad, width=140)
138
+ self.vad_btn.grid(row=7, column=2, padx=5, pady=5, sticky="w")
139
+
140
+ # MFA 模型区域
141
+ ctk.CTkLabel(self, text="MFA 声学模型", font=ctk.CTkFont(size=14, weight="bold")).grid(
142
+ row=8, column=0, columnspan=3, padx=10, pady=(20, 5), sticky="w"
143
+ )
144
+ ctk.CTkLabel(self, text="Montreal Forced Aligner 模型,用于语音对齐", text_color="gray").grid(
145
+ row=9, column=0, columnspan=3, padx=10, pady=(0, 10), sticky="w"
146
  )
 
147
 
148
+ ctk.CTkLabel(self, text="模型目录:").grid(row=10, column=0, padx=10, pady=5, sticky="w")
149
+ self.mfa_dir_var = ctk.StringVar(value=self.config.get("mfa_dir"))
150
+ ctk.CTkEntry(self, textvariable=self.mfa_dir_var, width=320).grid(row=10, column=1, padx=5, pady=5, sticky="w")
151
+ ctk.CTkButton(self, text="浏览", width=60, command=self._browse_mfa_dir).grid(row=10, column=2, padx=5, pady=5)
 
152
 
153
+ ctk.CTkLabel(self, text="选择语言:").grid(row=11, column=0, padx=10, pady=5, sticky="w")
 
154
  self.mfa_lang_var = ctk.StringVar(value="mandarin")
155
+ ctk.CTkComboBox(
156
+ self, values=["mandarin", "japanese"],
157
+ variable=self.mfa_lang_var, width=200,
 
 
158
  command=self._on_mfa_lang_change
159
+ ).grid(row=11, column=1, padx=5, pady=5, sticky="w")
 
 
160
  self.mfa_lang_desc = ctk.CTkLabel(self, text="中文 (普通话)", text_color="gray")
161
+ self.mfa_lang_desc.grid(row=11, column=2, padx=5, pady=5, sticky="w")
162
 
163
+ ctk.CTkLabel(self, text="状态:").grid(row=12, column=0, padx=10, pady=5, sticky="w")
 
164
  self.mfa_status = ctk.CTkLabel(self, text="⏳ 未下载", text_color="gray")
165
+ self.mfa_status.grid(row=12, column=1, padx=5, pady=5, sticky="w")
166
+ self.mfa_download_btn = ctk.CTkButton(self, text="下载模型", command=self._download_mfa_models, width=140)
167
+ self.mfa_download_btn.grid(row=12, column=2, padx=5, pady=5, sticky="w")
 
 
 
 
 
 
168
 
169
+ self._check_vad_status()
170
+
 
 
 
 
 
 
 
 
171
  def _get_model_desc(self):
172
+ info = ConfigManager.WHISPER_MODELS.get(self.whisper_model_var.get(), {})
173
+ return info.get('desc', '')
 
 
174
 
175
  def _on_model_change(self, choice):
 
176
  self.model_desc_label.configure(text=self._get_model_desc())
177
+ self.config.set("whisper_model", choice)
 
 
178
  self.whisper_status.configure(text="⏳ 未加载", text_color="gray")
179
  self.whisper_pipe = None
180
 
181
  def _browse_models_dir(self):
 
182
  path = filedialog.askdirectory(title="选择模型下载目录")
183
  if path:
184
  self.models_dir_var.set(path)
185
+ self.config.set("models_dir", path)
 
186
 
187
  def _browse_mfa_dir(self):
 
188
  path = filedialog.askdirectory(title="选择 MFA 模型目录")
189
  if path:
190
  self.mfa_dir_var.set(path)
191
+ self.config.set("mfa_dir", path)
 
 
192
 
193
  def _on_mfa_lang_change(self, choice):
 
194
  from src.mfa_model_downloader import get_available_languages
195
+ self.mfa_lang_desc.configure(text=get_available_languages().get(choice, ""))
196
+
197
+ def _check_vad_status(self):
198
+ from src.silero_vad_downloader import is_vad_model_downloaded
199
+ if is_vad_model_downloaded(self.config.get("models_dir")):
200
+ self.vad_status.configure(text="✅ 已下载", text_color="green")
201
+ else:
202
+ self.vad_status.configure(text="⏳ 未下载", text_color="gray")
203
+
204
+ def _download_vad(self):
205
+ if self._download_thread and self._download_thread.is_alive():
206
+ return
207
+ self.vad_btn.configure(state="disabled")
208
+ self.vad_status.configure(text="⏳ 下载中...", text_color="gray")
209
+ self._download_thread = threading.Thread(target=self._do_download_vad, daemon=True)
210
+ self._download_thread.start()
211
+
212
+ def _do_download_vad(self):
213
+ from src.silero_vad_downloader import download_silero_vad
214
+ self.log_callback("开始下载 Silero VAD 模型...")
215
+ success, result = download_silero_vad(self.config.get("models_dir"), self.log_callback)
216
+ if success:
217
+ self.after(0, lambda: self.vad_status.configure(text="✅ 已下载", text_color="green"))
218
+ self.log_callback(f"VAD 模型已保存: {result}")
219
+ else:
220
+ self.after(0, lambda: self.vad_status.configure(text="❌ 下载失败", text_color="red"))
221
+ self.after(0, lambda: self.vad_btn.configure(state="normal"))
222
 
223
  def _download_mfa_models(self):
 
224
  if self._download_thread and self._download_thread.is_alive():
225
  return
 
226
  self.mfa_download_btn.configure(state="disabled")
227
  self.mfa_status.configure(text="⏳ 下载中...", text_color="gray")
228
  self._download_thread = threading.Thread(target=self._do_download_mfa, daemon=True)
229
  self._download_thread.start()
230
 
231
  def _do_download_mfa(self):
 
232
  from src.mfa_model_downloader import download_language_models
 
233
  language = self.mfa_lang_var.get()
234
  output_dir = self.mfa_dir_var.get()
235
+ os.makedirs(output_dir, exist_ok=True)
 
 
 
 
236
  self.log_callback(f"开始下载 MFA 模型: {language}")
 
237
  success, acoustic_path, dict_path = download_language_models(
238
+ language=language, output_dir=output_dir, progress_callback=self.log_callback
 
 
239
  )
 
240
  if success:
241
  self.after(0, lambda: self.mfa_status.configure(text="✅ 已下载", text_color="green"))
242
  self.log_callback(f"声学模型: {acoustic_path}")
243
  self.log_callback(f"字典文件: {dict_path}")
244
  else:
245
  self.after(0, lambda: self.mfa_status.configure(text="❌ 下载失败", text_color="red"))
 
246
  self.after(0, lambda: self.mfa_download_btn.configure(state="normal"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
 
248
  def _download_whisper(self):
 
249
  if self._download_thread and self._download_thread.is_alive():
250
  return
 
251
  self.whisper_btn.configure(state="disabled")
252
  self.whisper_status.configure(text="⏳ 加载中...", text_color="gray")
253
  self._download_thread = threading.Thread(target=self._do_download_whisper, daemon=True)
254
  self._download_thread.start()
255
 
256
  def _do_download_whisper(self):
 
257
  try:
258
  self._update_progress("正在加载 transformers 库...")
259
  from transformers import pipeline
260
  import torch
261
 
262
  model_key = self.whisper_model_var.get()
263
+ model_name = ConfigManager.WHISPER_MODELS[model_key]["name"]
264
  cache_dir = os.path.join(self.models_dir_var.get(), "whisper")
265
+ os.makedirs(cache_dir, exist_ok=True)
 
 
 
266
 
267
  self._update_progress(f"正在下载/加载 {model_key}...")
268
  self.log_callback(f"开始加载 Whisper 模型: {model_name}")
 
 
269
 
 
270
  os.environ["HF_HOME"] = cache_dir
271
  os.environ["TRANSFORMERS_CACHE"] = cache_dir
272
 
 
273
  self.whisper_pipe = pipeline(
274
  "automatic-speech-recognition",
275
  model=model_name,
 
282
  self.after(0, lambda: self.whisper_status.configure(text="✅ 已就绪", text_color="green"))
283
  self.after(0, lambda: self.whisper_btn.configure(state="normal", text="重新加载"))
284
  self.log_callback("Whisper 模型加载完成")
 
 
285
  except Exception as e:
 
286
  self._update_progress("")
287
  self.after(0, lambda: self.whisper_status.configure(text="❌ 加载失败", text_color="red"))
288
  self.after(0, lambda: self.whisper_btn.configure(state="normal"))
289
+ self.log_callback(f"Whisper 模型加载失败: {e}")
290
  logger.error(f"Whisper 模型加载失败: {e}", exc_info=True)
291
 
292
  def _update_progress(self, text):
 
293
  self.after(0, lambda: self.progress_label.configure(text=text))
294
 
295
  def get_whisper_pipeline(self):
 
296
  return self.whisper_pipe
297
 
298
+ def get_models_dir(self):
299
+ return self.models_dir_var.get()
300
+
301
  def get_mfa_dir(self):
 
302
  return self.mfa_dir_var.get()
303
+
304
+ def get_whisper_model_name(self):
305
+ return ConfigManager.WHISPER_MODELS[self.whisper_model_var.get()]["name"]
306
 
307
 
308
+ class MakeVoiceBankFrame(ctk.CTkFrame):
309
+ """制作音源页 - 简化工作流"""
310
 
311
+ def __init__(self, master, log_callback, config: ConfigManager, model_frame: ModelDownloadFrame):
312
  super().__init__(master)
313
  self.log_callback = log_callback
314
+ self.config = config
315
+ self.model_frame = model_frame
316
  self._is_running = False
317
  self._setup_ui()
318
  self._check_mfa_status()
319
 
320
  def _setup_ui(self):
321
+ self.scroll_frame = ctk.CTkScrollableFrame(self)
322
+ self.scroll_frame.pack(fill="both", expand=True, padx=5, pady=5)
323
+ row = 0
324
+
325
+ # ========== 基本设置 ==========
326
+ ctk.CTkLabel(
327
+ self.scroll_frame, text="基本设置",
328
+ font=ctk.CTkFont(size=16, weight="bold")
329
+ ).grid(row=row, column=0, columnspan=3, padx=10, pady=(10, 15), sticky="w")
330
+ row += 1
331
+
332
+ # 音源名称
333
+ ctk.CTkLabel(self.scroll_frame, text="音源名称:").grid(row=row, column=0, padx=10, pady=5, sticky="w")
334
+ self.source_name_var = ctk.StringVar(value="my_voice")
335
+ ctk.CTkEntry(self.scroll_frame, textvariable=self.source_name_var, width=200).grid(
336
+ row=row, column=1, padx=5, pady=5, sticky="w"
337
+ )
338
+ ctk.CTkLabel(self.scroll_frame, text="输出到 bank/[音源名称]/", text_color="gray").grid(
339
+ row=row, column=2, padx=5, pady=5, sticky="w"
340
+ )
341
+ row += 1
342
+
343
+ # 输入音频
344
+ ctk.CTkLabel(self.scroll_frame, text="输入音频:").grid(row=row, column=0, padx=10, pady=5, sticky="w")
345
+ self.input_audio_var = ctk.StringVar()
346
+ ctk.CTkEntry(self.scroll_frame, textvariable=self.input_audio_var, width=300).grid(
347
+ row=row, column=1, padx=5, pady=5
348
+ )
349
+ btn_frame = ctk.CTkFrame(self.scroll_frame, fg_color="transparent")
350
+ btn_frame.grid(row=row, column=2, padx=5, pady=5)
351
+ ctk.CTkButton(btn_frame, text="文件", width=50, command=self._browse_input_file,
352
+ fg_color="#5a6a7a", hover_color="#4a5a6a").pack(side="left", padx=2)
353
+ ctk.CTkButton(btn_frame, text="文件夹", width=60, command=self._browse_input_dir,
354
+ fg_color="#5a6a7a", hover_color="#4a5a6a").pack(side="left", padx=2)
355
+ row += 1
356
+
357
+ # 输出目录
358
+ ctk.CTkLabel(self.scroll_frame, text="输出目录:").grid(row=row, column=0, padx=10, pady=5, sticky="w")
359
+ self.output_dir_var = ctk.StringVar(value=self.config.get("bank_dir", "bank"))
360
+ ctk.CTkEntry(self.scroll_frame, textvariable=self.output_dir_var, width=300).grid(
361
+ row=row, column=1, padx=5, pady=5
362
+ )
363
+ ctk.CTkButton(self.scroll_frame, text="浏览", width=60, command=self._browse_output_dir,
364
+ fg_color="#5a6a7a", hover_color="#4a5a6a").grid(
365
+ row=row, column=2, padx=5, pady=5, sticky="w"
366
+ )
367
+ row += 1
368
+
369
+ # 分隔线
370
+ ctk.CTkFrame(self.scroll_frame, height=2, fg_color="gray50").grid(
371
+ row=row, column=0, columnspan=3, padx=10, pady=15, sticky="ew"
372
+ )
373
+ row += 1
374
+
375
+ # ========== 模型选择 ==========
376
+ ctk.CTkLabel(
377
+ self.scroll_frame, text="模型选择",
378
+ font=ctk.CTkFont(size=16, weight="bold")
379
+ ).grid(row=row, column=0, columnspan=3, padx=10, pady=(10, 15), sticky="w")
380
+ row += 1
381
+
382
+ # Whisper模型
383
+ ctk.CTkLabel(self.scroll_frame, text="Whisper模型:").grid(row=row, column=0, padx=10, pady=5, sticky="w")
384
+ self.whisper_combo = ctk.CTkComboBox(
385
+ self.scroll_frame, values=["(扫描中...)"], width=250
386
+ )
387
+ self.whisper_combo.grid(row=row, column=1, padx=5, pady=5, sticky="w")
388
+ ctk.CTkButton(self.scroll_frame, text="刷新", width=60, command=self._refresh_whisper_models,
389
+ fg_color="#5a6a7a", hover_color="#4a5a6a").grid(
390
+ row=row, column=2, padx=5, pady=5, sticky="w"
391
+ )
392
+ row += 1
393
+
394
+ # MFA字典
395
+ ctk.CTkLabel(self.scroll_frame, text="MFA字典:").grid(row=row, column=0, padx=10, pady=5, sticky="w")
396
+ self.dict_combo = ctk.CTkComboBox(self.scroll_frame, values=["(扫描中...)"], width=250)
397
+ self.dict_combo.grid(row=row, column=1, padx=5, pady=5, sticky="w")
398
+ row += 1
399
+
400
+ # MFA声学模型
401
+ ctk.CTkLabel(self.scroll_frame, text="MFA声学模型:").grid(row=row, column=0, padx=10, pady=5, sticky="w")
402
+ self.acoustic_combo = ctk.CTkComboBox(self.scroll_frame, values=["(扫描中...)"], width=250)
403
+ self.acoustic_combo.grid(row=row, column=1, padx=5, pady=5, sticky="w")
404
+ ctk.CTkButton(self.scroll_frame, text="刷新", width=60, command=self._refresh_mfa_models,
405
+ fg_color="#5a6a7a", hover_color="#4a5a6a").grid(
406
+ row=row, column=2, padx=5, pady=5, sticky="w"
407
+ )
408
+ row += 1
409
+
410
+ # 语言
411
+ ctk.CTkLabel(self.scroll_frame, text="转录语言:").grid(row=row, column=0, padx=10, pady=5, sticky="w")
412
+ self.language_var = ctk.StringVar(value="chinese")
413
+ ctk.CTkComboBox(
414
+ self.scroll_frame, values=["chinese", "japanese", "english"],
415
+ variable=self.language_var, width=150
416
+ ).grid(row=row, column=1, padx=5, pady=5, sticky="w")
417
+ row += 1
418
+
419
+ # 分隔线
420
+ ctk.CTkFrame(self.scroll_frame, height=2, fg_color="gray50").grid(
421
+ row=row, column=0, columnspan=3, padx=10, pady=15, sticky="ew"
422
+ )
423
+ row += 1
424
+
425
+ # ========== MFA状态 ==========
426
  self.mfa_status_label = ctk.CTkLabel(
427
+ self.scroll_frame, text="⏳ 检查 MFA 环境...",
 
428
  font=ctk.CTkFont(size=12)
429
  )
430
+ self.mfa_status_label.grid(row=row, column=0, columnspan=3, padx=10, pady=5, sticky="w")
431
+ row += 1
432
+
433
+ # 分隔线
434
+ ctk.CTkFrame(self.scroll_frame, height=2, fg_color="gray50").grid(
435
+ row=row, column=0, columnspan=3, padx=10, pady=15, sticky="ew"
436
+ )
437
+ row += 1
438
+
439
+ # ========== 执行按钮 ==========
440
+ ctk.CTkLabel(
441
+ self.scroll_frame, text="执行流程",
442
+ font=ctk.CTkFont(size=16, weight="bold")
443
+ ).grid(row=row, column=0, columnspan=3, padx=10, pady=(10, 15), sticky="w")
444
+ row += 1
445
+
446
+ # 按钮容器 - 优化排版
447
+ btn_container = ctk.CTkFrame(self.scroll_frame, fg_color="transparent")
448
+ btn_container.grid(row=row, column=0, columnspan=3, padx=10, pady=10, sticky="ew")
449
+
450
+ # 分步执行按钮 - 降饱和颜色
451
+ self.step0_btn = ctk.CTkButton(
452
+ btn_container, text="步骤0: 切片+转录",
453
+ command=self._run_step0, width=150, height=36,
454
+ fg_color="#5c7a9a", hover_color="#4a6888"
455
+ )
456
+ self.step0_btn.pack(side="left", padx=8)
457
+
458
+ self.step1_btn = ctk.CTkButton(
459
+ btn_container, text="步骤1: MFA对齐",
460
+ command=self._run_step1, width=150, height=36,
461
+ fg_color="#6a9a7a", hover_color="#588868"
462
+ )
463
+ self.step1_btn.pack(side="left", padx=8)
464
+ row += 1
465
+
466
+ # 一键执行 - 降饱和
467
+ self.full_btn = ctk.CTkButton(
468
+ self.scroll_frame, text="▶ 一键执行全部流程",
469
+ command=self._run_full, width=320, height=40,
470
+ fg_color="#8a6a8a", hover_color="#785878",
471
+ font=ctk.CTkFont(size=14, weight="bold")
472
+ )
473
+ self.full_btn.grid(row=row, column=0, columnspan=3, pady=15)
474
+ row += 1
475
 
476
+ # 进度提示
477
+ self.progress_label = ctk.CTkLabel(self.scroll_frame, text="", text_color="gray")
478
+ self.progress_label.grid(row=row, column=0, columnspan=3, padx=10, pady=5, sticky="w")
 
 
479
 
480
+ # 初始化模型列表
481
+ self.after(500, self._refresh_all_models)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
482
 
483
  def _check_mfa_status(self):
 
484
  from src.mfa_runner import check_mfa_available
 
485
  if check_mfa_available():
486
+ self.mfa_status_label.configure(text="✅ MFA 环境已就绪", text_color="green")
 
 
 
487
  else:
488
+ self.mfa_status_label.configure(text="❌ MFA 环境不可用,请检查 tools/mfa_engine", text_color="red")
489
+
490
+ def _refresh_all_models(self):
491
+ self._refresh_whisper_models()
492
+ self._refresh_mfa_models()
493
+
494
+ def _refresh_whisper_models(self):
495
+ from src.pipeline import scan_whisper_models
496
+ models_dir = self.model_frame.get_models_dir()
497
+ models = scan_whisper_models(models_dir)
498
+
499
+ all_models = list(ConfigManager.WHISPER_MODELS.values())
500
+ preset_names = [m["name"] for m in all_models]
501
+
502
+ for m in models:
503
+ if m not in preset_names:
504
+ preset_names.append(m)
505
+
506
+ if preset_names:
507
+ self.whisper_combo.configure(values=preset_names)
508
+ self.whisper_combo.set(preset_names[0])
509
+ else:
510
+ self.whisper_combo.configure(values=["openai/whisper-small"])
511
+ self.whisper_combo.set("openai/whisper-small")
512
+
513
+ def _refresh_mfa_models(self):
514
+ from src.pipeline import scan_mfa_models
515
+ mfa_dir = self.model_frame.get_mfa_dir()
516
+ models = scan_mfa_models(os.path.dirname(mfa_dir))
517
+
518
+ if models["dictionary"]:
519
+ self.dict_combo.configure(values=models["dictionary"])
520
+ self.dict_combo.set(models["dictionary"][0])
521
+ else:
522
+ self.dict_combo.configure(values=["(未找到字典文件)"])
523
+ self.dict_combo.set("(未找到字典文件)")
524
+
525
+ if models["acoustic"]:
526
+ self.acoustic_combo.configure(values=models["acoustic"])
527
+ self.acoustic_combo.set(models["acoustic"][0])
528
+ else:
529
+ self.acoustic_combo.configure(values=["(未找到声学模型)"])
530
+ self.acoustic_combo.set("(未找到声学模型)")
531
 
532
+ def _browse_input_file(self):
533
+ path = filedialog.askopenfilename(
534
+ title="选择音频文件",
535
+ filetypes=[("音频文件", "*.wav *.mp3 *.flac *.ogg *.m4a"), ("所有文件", "*.*")]
536
+ )
537
+ if path:
538
+ self.input_audio_var.set(path)
539
+
540
+ def _browse_input_dir(self):
541
+ path = filedialog.askdirectory(title="选择音频文件夹")
542
  if path:
543
+ self.input_audio_var.set(path)
544
 
545
  def _browse_output_dir(self):
546
+ path = filedialog.askdirectory(title="选择输出目录")
547
  if path:
548
  self.output_dir_var.set(path)
549
+ self.config.set("bank_dir", path)
550
 
551
+ def _get_pipeline_config(self):
552
+ """获取流水线配置"""
553
+ from src.pipeline import PipelineConfig
554
+
555
+ mfa_dir = self.model_frame.get_mfa_dir()
556
+ dict_file = self.dict_combo.get()
557
+ acoustic_file = self.acoustic_combo.get()
558
+
559
+ dict_path = None
560
+ if dict_file and not dict_file.startswith("("):
561
+ dict_path = os.path.join(mfa_dir, dict_file)
562
+
563
+ acoustic_path = None
564
+ if acoustic_file and not acoustic_file.startswith("("):
565
+ acoustic_path = os.path.join(mfa_dir, acoustic_file)
566
+
567
+ return PipelineConfig(
568
+ source_name=self.source_name_var.get(),
569
+ input_path=self.input_audio_var.get(),
570
+ output_base_dir=self.output_dir_var.get(),
571
+ models_dir=self.model_frame.get_models_dir(),
572
+ whisper_model=self.whisper_combo.get(),
573
+ mfa_dict_path=dict_path,
574
+ mfa_model_path=acoustic_path,
575
+ language=self.language_var.get()
576
  )
 
 
577
 
578
+ def _set_buttons_state(self, state: str):
579
+ """设置所有按钮状态"""
580
+ for btn in [self.step0_btn, self.step1_btn, self.full_btn]:
581
+ btn.configure(state=state)
 
 
 
582
 
583
+ def _run_step0(self):
584
  if self._is_running:
585
  return
586
+ if not self._validate_input():
587
+ return
588
+ self._is_running = True
589
+ self._set_buttons_state("disabled")
590
+ threading.Thread(target=self._do_step0, daemon=True).start()
591
+
592
+ def _do_step0(self):
593
+ from src.pipeline import VoiceBankPipeline
594
+ config = self._get_pipeline_config()
595
+ pipeline = VoiceBankPipeline(config, self.log_callback)
596
 
597
+ self.log_callback("=" * 50)
598
+ self.log_callback("【步骤0】音频预处理 (VAD切片 + Whisper转录)")
599
+ success, msg, _ = pipeline.step0_preprocess()
600
+
601
+ if success:
602
+ self.log_callback(f"✅ {msg}")
603
+ else:
604
+ self.log_callback(f"❌ {msg}")
605
+ self.log_callback("=" * 50)
606
 
607
+ self.after(0, lambda: self._set_buttons_state("normal"))
608
+ self._is_running = False
609
+
610
+ def _run_step1(self):
611
+ if self._is_running:
612
  return
613
+ if not self._validate_source_name():
614
+ return
615
+ self._is_running = True
616
+ self._set_buttons_state("disabled")
617
+ threading.Thread(target=self._do_step1, daemon=True).start()
618
+
619
+ def _do_step1(self):
620
+ from src.pipeline import VoiceBankPipeline
621
+ config = self._get_pipeline_config()
622
+ pipeline = VoiceBankPipeline(config, self.log_callback)
623
+
624
+ self.log_callback("=" * 50)
625
+ self.log_callback("【步骤1】MFA语音对齐")
626
+ success, msg = pipeline.step1_mfa_align()
627
 
628
+ if success:
629
+ self.log_callback(f"✅ {msg}")
630
+ else:
631
+ self.log_callback(f"❌ {msg}")
632
+ self.log_callback("=" * 50)
633
+
634
+ self.after(0, lambda: self._set_buttons_state("normal"))
635
+ self._is_running = False
636
+
637
+ def _run_full(self):
638
+ if self._is_running:
639
+ return
640
+ if not self._validate_input():
641
+ return
642
  self._is_running = True
643
+ self._set_buttons_state("disabled")
644
+ threading.Thread(target=self._do_full, daemon=True).start()
645
+
646
+ def _do_full(self):
647
+ from src.pipeline import VoiceBankPipeline
648
+ config = self._get_pipeline_config()
649
+ pipeline = VoiceBankPipeline(config, self.log_callback)
650
 
651
+ success, msg = pipeline.run_make_pipeline()
652
+
653
+ if not success:
654
+ self.log_callback(f"❌ 流程中断: {msg}")
655
+
656
+ self.after(0, lambda: self._set_buttons_state("normal"))
657
+ self._is_running = False
658
 
659
+ def _validate_input(self) -> bool:
660
+ """验证输入"""
661
+ if not self.source_name_var.get().strip():
662
+ messagebox.showerror("错误", "请输入音源名称")
663
+ return False
664
+ if not self.input_audio_var.get().strip():
665
+ messagebox.showerror("错误", "请选择输入音频")
666
+ return False
667
+ if not self.output_dir_var.get().strip():
668
+ messagebox.showerror("错误", "请选择输出目录")
669
+ return False
670
+ return True
671
+
672
+ def _validate_source_name(self) -> bool:
673
+ """验证音源名称"""
674
+ if not self.source_name_var.get().strip():
675
+ messagebox.showerror("错误", "请输入音源名称")
676
+ return False
677
+ return True
678
+
679
+
680
+ class ExportSettingsDialog(ctk.CTkToplevel):
681
+ """导出设置弹窗"""
682
+
683
+ def __init__(self, master, plugin, voice_bank: str, bank_dir: str, log_callback):
684
+ super().__init__(master)
685
+ self.plugin = plugin
686
+ self.voice_bank = voice_bank
687
+ self.bank_dir = bank_dir
688
+ self.log_callback = log_callback
689
+ self._option_widgets = {}
690
+ self._is_running = False
691
 
692
+ self.title(f"导出设置 - {plugin.name}")
693
+ self.geometry("500x400")
694
+ self.resizable(True, True)
695
+ self.transient(master)
696
+ self.grab_set()
697
+
698
+ self._setup_ui()
699
+ self._center_window()
700
+
701
+ def _center_window(self):
702
+ """居中显示"""
703
+ self.update_idletasks()
704
+ w = self.winfo_width()
705
+ h = self.winfo_height()
706
+ x = (self.winfo_screenwidth() - w) // 2
707
+ y = (self.winfo_screenheight() - h) // 2
708
+ self.geometry(f"{w}x{h}+{x}+{y}")
709
+
710
+ def _setup_ui(self):
711
+ from src.export_plugins import OptionType
712
+
713
+ # 标题
714
+ header = ctk.CTkFrame(self)
715
+ header.pack(fill="x", padx=10, pady=10)
716
+ ctk.CTkLabel(
717
+ header, text=self.plugin.name,
718
+ font=ctk.CTkFont(size=16, weight="bold")
719
+ ).pack(anchor="w")
720
+ ctk.CTkLabel(
721
+ header, text=self.plugin.description,
722
+ text_color="gray"
723
+ ).pack(anchor="w")
724
+ ctk.CTkLabel(
725
+ header, text=f"音源: {self.voice_bank}",
726
+ text_color="gray"
727
+ ).pack(anchor="w")
728
+
729
+ # 选项区域(可滚动)
730
+ self.options_frame = ctk.CTkScrollableFrame(self)
731
+ self.options_frame.pack(fill="both", expand=True, padx=10, pady=5)
732
+
733
+ # 动态生成选项控件
734
+ for opt in self.plugin.get_options():
735
+ self._create_option_widget(opt)
736
+
737
+ # 底部按钮
738
+ btn_frame = ctk.CTkFrame(self, fg_color="transparent")
739
+ btn_frame.pack(fill="x", padx=10, pady=10)
740
+
741
+ self.cancel_btn = ctk.CTkButton(
742
+ btn_frame, text="取消", width=80,
743
+ fg_color="gray", command=self.destroy
744
+ )
745
+ self.cancel_btn.pack(side="left", padx=5)
746
+
747
+ self.reset_btn = ctk.CTkButton(
748
+ btn_frame, text="恢复默认", width=100,
749
+ fg_color="#607D8B", command=self._reset_defaults
750
+ )
751
+ self.reset_btn.pack(side="left", padx=5)
752
 
753
+ self.export_btn = ctk.CTkButton(
754
+ btn_frame, text="导出", width=100,
755
+ fg_color="#6a9a7a", hover_color="#588868", command=self._do_export
 
 
 
 
 
756
  )
757
+ self.export_btn.pack(side="right", padx=5)
758
+
759
+ def _create_option_widget(self, opt):
760
+ """创建选项控件"""
761
+ from src.export_plugins import OptionType
762
+
763
+ frame = ctk.CTkFrame(self.options_frame, fg_color="transparent")
764
+ frame.pack(fill="x", pady=5)
765
+
766
+ if opt.option_type == OptionType.LABEL:
767
+ ctk.CTkLabel(frame, text=opt.label, text_color="gray").pack(anchor="w")
768
+ return
769
+
770
+ ctk.CTkLabel(frame, text=opt.label).pack(anchor="w")
771
+
772
+ if opt.option_type == OptionType.TEXT:
773
+ var = ctk.StringVar(value=str(opt.default or ""))
774
+ widget = ctk.CTkEntry(frame, textvariable=var, width=300)
775
+ widget.pack(anchor="w", pady=2)
776
+ self._option_widgets[opt.key] = ("text", var)
777
+
778
+ elif opt.option_type == OptionType.NUMBER:
779
+ var = ctk.StringVar(value=str(opt.default or 0))
780
+ widget = ctk.CTkEntry(frame, textvariable=var, width=150)
781
+ widget.pack(anchor="w", pady=2)
782
+ self._option_widgets[opt.key] = ("number", var, opt.min_value, opt.max_value)
783
+
784
+ elif opt.option_type == OptionType.SWITCH:
785
+ var = ctk.BooleanVar(value=bool(opt.default))
786
+ widget = ctk.CTkSwitch(frame, text="", variable=var)
787
+ widget.pack(anchor="w", pady=2)
788
+ self._option_widgets[opt.key] = ("switch", var)
789
+
790
+ elif opt.option_type == OptionType.COMBO:
791
+ var = ctk.StringVar(value=str(opt.default or ""))
792
+ widget = ctk.CTkComboBox(frame, values=opt.choices, variable=var, width=200)
793
+ widget.pack(anchor="w", pady=2)
794
+ self._option_widgets[opt.key] = ("combo", var)
795
+
796
+ elif opt.option_type == OptionType.FILE:
797
+ var = ctk.StringVar(value=str(opt.default or ""))
798
+ entry_frame = ctk.CTkFrame(frame, fg_color="transparent")
799
+ entry_frame.pack(anchor="w", pady=2)
800
+ entry = ctk.CTkEntry(entry_frame, textvariable=var, width=250)
801
+ entry.pack(side="left")
802
+ btn = ctk.CTkButton(
803
+ entry_frame, text="浏览", width=60,
804
+ command=lambda v=var, ft=opt.file_types: self._browse_file(v, ft)
805
+ )
806
+ btn.pack(side="left", padx=5)
807
+ self._option_widgets[opt.key] = ("file", var)
808
+
809
+ elif opt.option_type == OptionType.FOLDER:
810
+ var = ctk.StringVar(value=str(opt.default or ""))
811
+ entry_frame = ctk.CTkFrame(frame, fg_color="transparent")
812
+ entry_frame.pack(anchor="w", pady=2)
813
+ entry = ctk.CTkEntry(entry_frame, textvariable=var, width=250)
814
+ entry.pack(side="left")
815
+ btn = ctk.CTkButton(
816
+ entry_frame, text="浏览", width=60,
817
+ command=lambda v=var: self._browse_folder(v)
818
+ )
819
+ btn.pack(side="left", padx=5)
820
+ self._option_widgets[opt.key] = ("folder", var)
821
+
822
+ if opt.description:
823
+ ctk.CTkLabel(
824
+ frame, text=opt.description,
825
+ text_color="gray", font=ctk.CTkFont(size=11)
826
+ ).pack(anchor="w")
827
+
828
+ def _browse_file(self, var, file_types):
829
+ ft = file_types if file_types else [("所有文件", "*.*")]
830
+ path = filedialog.askopenfilename(filetypes=ft)
831
+ if path:
832
+ var.set(path)
833
+
834
+ def _browse_folder(self, var):
835
+ path = filedialog.askdirectory()
836
+ if path:
837
+ var.set(path)
838
+
839
+ def _get_options_values(self) -> dict:
840
+ values = {}
841
+ for key, widget_info in self._option_widgets.items():
842
+ widget_type = widget_info[0]
843
+ var = widget_info[1]
844
+
845
+ if widget_type == "number":
846
+ try:
847
+ val = float(var.get())
848
+ min_val = widget_info[2]
849
+ max_val = widget_info[3]
850
+ if min_val is not None:
851
+ val = max(min_val, val)
852
+ if max_val is not None:
853
+ val = min(max_val, val)
854
+ values[key] = int(val) if val == int(val) else val
855
+ except ValueError:
856
+ values[key] = 0
857
+ elif widget_type == "switch":
858
+ values[key] = var.get()
859
+ else:
860
+ values[key] = var.get()
861
+
862
+ return values
863
+
864
+ def _reset_defaults(self):
865
+ for opt in self.plugin.get_options():
866
+ if opt.key in self._option_widgets:
867
+ widget_info = self._option_widgets[opt.key]
868
+ var = widget_info[1]
869
+ if widget_info[0] == "switch":
870
+ var.set(bool(opt.default))
871
+ else:
872
+ var.set(str(opt.default or ""))
873
+
874
+ def _do_export(self):
875
+ if self._is_running:
876
+ return
877
+
878
+ self._is_running = True
879
+ self._set_buttons_state("disabled")
880
+
881
+ options = self._get_options_values()
882
+ threading.Thread(target=self._run_export, args=(options,), daemon=True).start()
883
+
884
+ def _run_export(self, options: dict):
885
+ self.log_callback("=" * 50)
886
+ self.log_callback(f"【{self.plugin.name}】音源: {self.voice_bank}")
887
+
888
+ self.plugin.set_progress_callback(self.log_callback)
889
+ success, msg = self.plugin.export(self.voice_bank, self.bank_dir, options)
890
 
891
  if success:
892
+ self.log_callback(f"✅ {msg}")
 
893
  else:
894
+ self.log_callback(f"❌ {msg}")
 
895
  self.log_callback("=" * 50)
896
 
897
+ self.after(0, self._on_export_complete)
898
+
899
+ def _on_export_complete(self):
900
  self._is_running = False
901
+ self._set_buttons_state("normal")
902
+ messagebox.showinfo("完成", "导出完成")
903
+
904
+ def _set_buttons_state(self, state: str):
905
+ self.cancel_btn.configure(state=state)
906
+ self.reset_btn.configure(state=state)
907
+ self.export_btn.configure(state=state)
908
+
909
+
910
+ class ExportVoiceBankFrame(ctk.CTkFrame):
911
+ """导出音源页面"""
912
+
913
+ def __init__(self, master, log_callback, config: ConfigManager):
914
+ super().__init__(master)
915
+ self.log_callback = log_callback
916
+ self.config = config
917
+ self._plugins = {}
918
+ self._load_plugins()
919
+ self._setup_ui()
920
+ self.after(500, self._refresh_voice_banks)
921
+
922
+ def _load_plugins(self):
923
+ from src.export_plugins import load_plugins
924
+ plugins_dir = os.path.join(
925
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
926
+ "export_plugins"
927
+ )
928
+ self._plugins = load_plugins(plugins_dir)
929
+
930
+ def _setup_ui(self):
931
+ # 音源选择区域
932
+ ctk.CTkLabel(
933
+ self, text="选择音源",
934
+ font=ctk.CTkFont(size=16, weight="bold")
935
+ ).grid(row=0, column=0, columnspan=3, padx=10, pady=(10, 5), sticky="w")
936
+
937
+ ctk.CTkLabel(self, text="音源:").grid(row=1, column=0, padx=10, pady=5, sticky="w")
938
+ self.voice_bank_var = ctk.StringVar()
939
+ self.voice_bank_combo = ctk.CTkComboBox(
940
+ self, values=["(扫描中...)"],
941
+ variable=self.voice_bank_var, width=250,
942
+ command=self._on_voice_bank_change
943
+ )
944
+ self.voice_bank_combo.grid(row=1, column=1, padx=5, pady=5, sticky="w")
945
+ ctk.CTkButton(self, text="刷新", width=60, command=self._refresh_voice_banks,
946
+ fg_color="#5a6a7a", hover_color="#4a5a6a").grid(
947
+ row=1, column=2, padx=5, pady=5, sticky="w"
948
+ )
949
+
950
+ # 音源信息
951
+ self.info_label = ctk.CTkLabel(self, text="", text_color="gray")
952
+ self.info_label.grid(row=2, column=0, columnspan=3, padx=10, pady=5, sticky="w")
953
+
954
+ # 分隔线
955
+ ctk.CTkFrame(self, height=2, fg_color="gray50").grid(
956
+ row=3, column=0, columnspan=3, padx=10, pady=15, sticky="ew"
957
+ )
958
+
959
+ # 导出方式区域
960
+ ctk.CTkLabel(
961
+ self, text="导出方式",
962
+ font=ctk.CTkFont(size=16, weight="bold")
963
+ ).grid(row=4, column=0, columnspan=3, padx=10, pady=(10, 5), sticky="w")
964
+
965
+ # 插件列表(可滚动)
966
+ self.plugins_frame = ctk.CTkScrollableFrame(self, height=250)
967
+ self.plugins_frame.grid(row=5, column=0, columnspan=3, padx=10, pady=10, sticky="nsew")
968
+
969
+ # 动态生成插件卡片
970
+ self._create_plugin_cards()
971
+
972
+ # 配置行列权重
973
+ self.grid_columnconfigure(1, weight=1)
974
+ self.grid_rowconfigure(5, weight=1)
975
+
976
+ def _create_plugin_cards(self):
977
+ """创建插件卡片 - 整个卡片可点击"""
978
+ for idx, (name, plugin) in enumerate(self._plugins.items()):
979
+ # 卡片容器 - 作为按钮
980
+ card = ctk.CTkFrame(
981
+ self.plugins_frame,
982
+ fg_color=("#e8e8e8", "#2a2a2a"),
983
+ corner_radius=8
984
+ )
985
+ card.pack(fill="x", pady=6, padx=4)
986
+ card.bind("<Enter>", lambda e, c=card: c.configure(fg_color=("#d8d8d8", "#3a3a3a")))
987
+ card.bind("<Leave>", lambda e, c=card: c.configure(fg_color=("#e8e8e8", "#2a2a2a")))
988
+ card.bind("<Button-1>", lambda e, p=plugin: self._open_plugin_settings(p))
989
+
990
+ # 内容容器
991
+ content = ctk.CTkFrame(card, fg_color="transparent")
992
+ content.pack(fill="x", padx=12, pady=10)
993
+ content.bind("<Button-1>", lambda e, p=plugin: self._open_plugin_settings(p))
994
+
995
+ # 插件名称 - 白色,较大,左中部
996
+ name_label = ctk.CTkLabel(
997
+ content, text=name,
998
+ font=ctk.CTkFont(size=15, weight="bold"),
999
+ text_color=("#1a1a1a", "#ffffff")
1000
+ )
1001
+ name_label.pack(anchor="w")
1002
+ name_label.bind("<Button-1>", lambda e, p=plugin: self._open_plugin_settings(p))
1003
+
1004
+ # 描述
1005
+ desc_label = ctk.CTkLabel(
1006
+ content, text=plugin.description,
1007
+ text_color="gray",
1008
+ font=ctk.CTkFont(size=12)
1009
+ )
1010
+ desc_label.pack(anchor="w", pady=(2, 0))
1011
+ desc_label.bind("<Button-1>", lambda e, p=plugin: self._open_plugin_settings(p))
1012
+
1013
+ # 作者和版本
1014
+ if plugin.author:
1015
+ meta_label = ctk.CTkLabel(
1016
+ content, text=f"作者: {plugin.author} | 版本: {plugin.version}",
1017
+ text_color="gray",
1018
+ font=ctk.CTkFont(size=10)
1019
+ )
1020
+ meta_label.pack(anchor="w", pady=(2, 0))
1021
+ meta_label.bind("<Button-1>", lambda e, p=plugin: self._open_plugin_settings(p))
1022
+
1023
+ def _open_plugin_settings(self, plugin):
1024
+ """打开插件设置弹窗"""
1025
+ voice_bank = self.voice_bank_var.get()
1026
+ if not voice_bank or voice_bank.startswith("("):
1027
+ messagebox.showerror("错误", "请先选择有效的音源")
1028
+ return
1029
+
1030
+ bank_dir = self.config.get("bank_dir", "bank")
1031
+ ExportSettingsDialog(self, plugin, voice_bank, bank_dir, self.log_callback)
1032
+
1033
+ def _refresh_voice_banks(self):
1034
+ """刷新音源列表"""
1035
+ bank_dir = self.config.get("bank_dir", "bank")
1036
+ voice_banks = []
1037
+
1038
+ if os.path.exists(bank_dir):
1039
+ for name in os.listdir(bank_dir):
1040
+ source_dir = os.path.join(bank_dir, name)
1041
+ if os.path.isdir(source_dir) and not name.startswith('.'):
1042
+ slices_dir = os.path.join(source_dir, "slices")
1043
+ textgrid_dir = os.path.join(source_dir, "textgrid")
1044
+ if os.path.exists(slices_dir) or os.path.exists(textgrid_dir):
1045
+ voice_banks.append(name)
1046
+
1047
+ if voice_banks:
1048
+ self.voice_bank_combo.configure(values=voice_banks)
1049
+ self.voice_bank_combo.set(voice_banks[0])
1050
+ self._on_voice_bank_change(voice_banks[0])
1051
+ else:
1052
+ self.voice_bank_combo.configure(values=["(未找到音源)"])
1053
+ self.voice_bank_combo.set("(未找到音源)")
1054
+ self.info_label.configure(text="")
1055
+
1056
+ def _on_voice_bank_change(self, choice):
1057
+ """音源选择变化"""
1058
+ if choice.startswith("("):
1059
+ self.info_label.configure(text="")
1060
+ return
1061
+
1062
+ bank_dir = self.config.get("bank_dir", "bank")
1063
+ source_dir = os.path.join(bank_dir, choice)
1064
+ slices_dir = os.path.join(source_dir, "slices")
1065
+ textgrid_dir = os.path.join(source_dir, "textgrid")
1066
+
1067
+ slices_count = 0
1068
+ textgrid_count = 0
1069
+
1070
+ if os.path.exists(slices_dir):
1071
+ slices_count = len([f for f in os.listdir(slices_dir) if f.endswith('.wav')])
1072
+ if os.path.exists(textgrid_dir):
1073
+ textgrid_count = len([f for f in os.listdir(textgrid_dir) if f.endswith('.TextGrid')])
1074
+
1075
+ self.info_label.configure(
1076
+ text=f"切片: {slices_count} 个 | TextGrid: {textgrid_count} 个"
1077
+ )
1078
+
1079
+
1080
+ class SettingsFrame(ctk.CTkFrame):
1081
+ """设置页面"""
1082
+
1083
+ def __init__(self, master, config: ConfigManager, on_log_toggle):
1084
+ super().__init__(master)
1085
+ self.config = config
1086
+ self.on_log_toggle = on_log_toggle
1087
+ self._setup_ui()
1088
+
1089
+ def _setup_ui(self):
1090
+ # 标题
1091
+ ctk.CTkLabel(
1092
+ self, text="应用设置",
1093
+ font=ctk.CTkFont(size=18, weight="bold")
1094
+ ).pack(anchor="w", padx=15, pady=(15, 20))
1095
+
1096
+ # 日志设置区域
1097
+ log_frame = ctk.CTkFrame(self, fg_color="transparent")
1098
+ log_frame.pack(fill="x", padx=15, pady=10)
1099
+
1100
+ ctk.CTkLabel(
1101
+ log_frame, text="界面设置",
1102
+ font=ctk.CTkFont(size=14, weight="bold")
1103
+ ).pack(anchor="w", pady=(0, 10))
1104
+
1105
+ # 显示日志开关
1106
+ log_switch_frame = ctk.CTkFrame(log_frame, fg_color="transparent")
1107
+ log_switch_frame.pack(fill="x", pady=5)
1108
+
1109
+ ctk.CTkLabel(log_switch_frame, text="显示日志输出面板").pack(side="left")
1110
+
1111
+ self.show_log_var = ctk.BooleanVar(value=self.config.get("show_log", False))
1112
+ self.log_switch = ctk.CTkSwitch(
1113
+ log_switch_frame, text="",
1114
+ variable=self.show_log_var,
1115
+ command=self._on_log_switch_change
1116
+ )
1117
+ self.log_switch.pack(side="right")
1118
+
1119
+ ctk.CTkLabel(
1120
+ log_frame, text="开启后将在主界面底部显示日志输出区域",
1121
+ text_color="gray", font=ctk.CTkFont(size=11)
1122
+ ).pack(anchor="w", pady=(2, 0))
1123
+
1124
+ # 分隔线
1125
+ ctk.CTkFrame(self, height=1, fg_color="gray50").pack(fill="x", padx=15, pady=20)
1126
+
1127
+ # 关于区域
1128
+ about_frame = ctk.CTkFrame(self, fg_color="transparent")
1129
+ about_frame.pack(fill="x", padx=15, pady=10)
1130
+
1131
+ ctk.CTkLabel(
1132
+ about_frame, text="关于",
1133
+ font=ctk.CTkFont(size=14, weight="bold")
1134
+ ).pack(anchor="w", pady=(0, 10))
1135
+
1136
+ ctk.CTkLabel(
1137
+ about_frame, text="语音数据集处理工具",
1138
+ font=ctk.CTkFont(size=12)
1139
+ ).pack(anchor="w")
1140
+
1141
+ ctk.CTkLabel(
1142
+ about_frame, text="基于 CustomTkinter 构建",
1143
+ text_color="gray", font=ctk.CTkFont(size=11)
1144
+ ).pack(anchor="w", pady=(2, 0))
1145
+
1146
+ def _on_log_switch_change(self):
1147
+ """日志开关变化"""
1148
+ show_log = self.show_log_var.get()
1149
+ self.config.set("show_log", show_log)
1150
+ self.on_log_toggle(show_log)
1151
 
1152
 
1153
  class App(ctk.CTk):
 
1155
 
1156
  def __init__(self):
1157
  super().__init__()
 
1158
  self.title("语音数据集处理工具")
1159
+ self.geometry("750x720")
1160
+ self.minsize(700, 620)
1161
 
1162
+ self.config = ConfigManager()
1163
  self._setup_ui()
1164
  logger.info("应用启动")
1165
 
1166
  def _setup_ui(self):
 
1167
  self.tabview = ctk.CTkTabview(self)
1168
  self.tabview.pack(fill="both", expand=True, padx=10, pady=10)
1169
 
1170
+ tab1 = self.tabview.add("模型下载")
1171
+ tab2 = self.tabview.add("制作音源")
1172
+ tab3 = self.tabview.add("导出音源")
1173
+ tab4 = self.tabview.add("设置")
 
1174
 
1175
+ self.download_frame = ModelDownloadFrame(tab1, self._log, self.config)
 
1176
  self.download_frame.pack(fill="both", expand=True, padx=5, pady=5)
1177
 
1178
+ self.make_frame = MakeVoiceBankFrame(tab2, self._log, self.config, self.download_frame)
1179
+ self.make_frame.pack(fill="both", expand=True, padx=5, pady=5)
1180
 
1181
+ self.export_frame = ExportVoiceBankFrame(tab3, self._log, self.config)
1182
+ self.export_frame.pack(fill="both", expand=True, padx=5, pady=5)
1183
 
1184
+ self.settings_frame = SettingsFrame(tab4, self.config, self._toggle_log_panel)
1185
+ self.settings_frame.pack(fill="both", expand=True, padx=5, pady=5)
1186
 
1187
+ # 日志区域 - 默认隐藏
1188
+ self.log_frame = ctk.CTkFrame(self)
1189
+ ctk.CTkLabel(self.log_frame, text="日志输出:").pack(anchor="w", padx=5, pady=2)
1190
+ self.log_text = ctk.CTkTextbox(self.log_frame, height=100)
 
 
 
1191
  self.log_text.pack(fill="x", padx=5, pady=5)
1192
+
1193
+ # 根据配置决定是否显示日志
1194
+ if self.config.get("show_log", False):
1195
+ self.log_frame.pack(fill="x", padx=10, pady=(0, 10))
1196
+
1197
+ def _toggle_log_panel(self, show: bool):
1198
+ """切换日志面板显示"""
1199
+ if show:
1200
+ self.log_frame.pack(fill="x", padx=10, pady=(0, 10))
1201
+ else:
1202
+ self.log_frame.pack_forget()
1203
 
1204
  def _log(self, message):
 
1205
  self.log_text.insert("end", f"{message}\n")
1206
  self.log_text.see("end")
1207
 
1208
 
1209
  def main():
 
1210
  app = App()
1211
  app.mainloop()
1212
 
src/mfa_runner.py CHANGED
@@ -112,6 +112,7 @@ def run_mfa_alignment(
112
  str(model_path),
113
  str(output_dir),
114
  "--temp_directory", str(temp_dir),
 
115
  ]
116
 
117
  if clean:
 
112
  str(model_path),
113
  str(output_dir),
114
  "--temp_directory", str(temp_dir),
115
+ "--use_mp", "false", # 禁用多进程,避免Windows问题
116
  ]
117
 
118
  if clean:
src/pipeline.py ADDED
@@ -0,0 +1,597 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ 音源制作流水线
4
+ 将所有非GUI的业务逻辑集中管理
5
+ """
6
+
7
+ import os
8
+ import glob
9
+ import json
10
+ import logging
11
+ import shutil
12
+ from datetime import datetime
13
+ from pathlib import Path
14
+ from typing import Optional, Callable, List, Tuple, Dict
15
+ from dataclasses import dataclass, asdict
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ @dataclass
21
+ class PipelineConfig:
22
+ """流水线配置"""
23
+ source_name: str # 音源名称
24
+ input_path: str # 输入音频路径(文件或目录)
25
+ output_base_dir: str # 输出基础目录 (bank)
26
+ models_dir: str # 模型目录
27
+
28
+ # 模型配置
29
+ whisper_model: str = "openai/whisper-small"
30
+ mfa_dict_path: Optional[str] = None
31
+ mfa_model_path: Optional[str] = None
32
+
33
+ # 处理参数
34
+ language: str = "chinese"
35
+ single_speaker: bool = True
36
+ clean_mfa_cache: bool = True
37
+ max_samples_per_word: int = 100
38
+
39
+ @property
40
+ def source_dir(self) -> str:
41
+ """音源目录: bank/[音源名称]"""
42
+ return os.path.join(self.output_base_dir, self.source_name)
43
+
44
+ @property
45
+ def slices_dir(self) -> str:
46
+ """切片目录: bank/[音源名称]/slices"""
47
+ return os.path.join(self.source_dir, "slices")
48
+
49
+ @property
50
+ def textgrid_dir(self) -> str:
51
+ """TextGrid目录: bank/[音源名称]/textgrid"""
52
+ return os.path.join(self.source_dir, "textgrid")
53
+
54
+ @property
55
+ def segments_dir(self) -> str:
56
+ """分字片段临时目录(处理完成后可删除)"""
57
+ return os.path.join(self.output_base_dir, ".temp_segments", self.source_name)
58
+
59
+ @property
60
+ def export_dir(self) -> str:
61
+ """导出目录: export/[音源名称]/simple_export"""
62
+ # 导出到项目根目录的 export 文件夹
63
+ base = Path(self.output_base_dir).parent
64
+ return os.path.join(base, "export", self.source_name, "simple_export")
65
+
66
+ @property
67
+ def meta_file(self) -> str:
68
+ """元文件路径: bank/[音源名称]/meta.json"""
69
+ return os.path.join(self.source_dir, "meta.json")
70
+
71
+
72
+ @dataclass
73
+ class VoiceBankMeta:
74
+ """
75
+ 音源元信息
76
+
77
+ 存储制作音源时的设置和模型信息
78
+ """
79
+ # 基本信息
80
+ source_name: str
81
+ created_at: str # ISO格式时间戳
82
+ updated_at: str # ISO格式时间戳
83
+
84
+ # 模型信息
85
+ whisper_model: str # Whisper模型名称
86
+ mfa_dict: str # MFA字典文件名
87
+ mfa_acoustic: str # MFA声学模型文件名
88
+
89
+ # 处理参数
90
+ language: str # 转录语言
91
+ single_speaker: bool # 单说话人模式
92
+
93
+ # 统计信息
94
+ slice_count: int = 0 # 切片数量
95
+ textgrid_count: int = 0 # TextGrid文件数量
96
+
97
+ @classmethod
98
+ def from_config(cls, config: PipelineConfig) -> "VoiceBankMeta":
99
+ """从流水线配置创建元信息"""
100
+ now = datetime.now().isoformat()
101
+
102
+ # 提取模型文件名(不含路径)
103
+ mfa_dict = ""
104
+ if config.mfa_dict_path:
105
+ mfa_dict = os.path.basename(config.mfa_dict_path)
106
+
107
+ mfa_acoustic = ""
108
+ if config.mfa_model_path:
109
+ mfa_acoustic = os.path.basename(config.mfa_model_path)
110
+
111
+ return cls(
112
+ source_name=config.source_name,
113
+ created_at=now,
114
+ updated_at=now,
115
+ whisper_model=config.whisper_model,
116
+ mfa_dict=mfa_dict,
117
+ mfa_acoustic=mfa_acoustic,
118
+ language=config.language,
119
+ single_speaker=config.single_speaker
120
+ )
121
+
122
+ def to_dict(self) -> Dict:
123
+ """转换为字典"""
124
+ return asdict(self)
125
+
126
+ @classmethod
127
+ def from_dict(cls, data: Dict) -> "VoiceBankMeta":
128
+ """从字典创建"""
129
+ return cls(**data)
130
+
131
+ def save(self, path: str):
132
+ """保存到文件"""
133
+ with open(path, 'w', encoding='utf-8') as f:
134
+ json.dump(self.to_dict(), f, ensure_ascii=False, indent=2)
135
+
136
+ @classmethod
137
+ def load(cls, path: str) -> Optional["VoiceBankMeta"]:
138
+ """从文件加载"""
139
+ if not os.path.exists(path):
140
+ return None
141
+ try:
142
+ with open(path, 'r', encoding='utf-8') as f:
143
+ data = json.load(f)
144
+ return cls.from_dict(data)
145
+ except Exception as e:
146
+ logger.warning(f"加载元文件失败: {e}")
147
+ return None
148
+
149
+
150
+ class VoiceBankPipeline:
151
+ """音源制作流水线"""
152
+
153
+ def __init__(
154
+ self,
155
+ config: PipelineConfig,
156
+ progress_callback: Optional[Callable[[str], None]] = None
157
+ ):
158
+ self.config = config
159
+ self.progress_callback = progress_callback
160
+ self._vad_model = None
161
+ self._get_speech_timestamps = None
162
+ self._whisper_model = None
163
+ self._whisper_processor = None
164
+ self._device = None
165
+
166
+ def _log(self, msg: str):
167
+ """记录日志"""
168
+ logger.info(msg)
169
+ if self.progress_callback:
170
+ self.progress_callback(msg)
171
+
172
+ def _ensure_dirs(self):
173
+ """确保目录存在"""
174
+ os.makedirs(self.config.source_dir, exist_ok=True)
175
+ os.makedirs(self.config.slices_dir, exist_ok=True)
176
+
177
+ def _save_meta(self, slice_count: int = 0):
178
+ """
179
+ 保存元文件
180
+
181
+ 参数:
182
+ slice_count: 切片数量
183
+ """
184
+ # 统计TextGrid文件数量
185
+ textgrid_count = 0
186
+ if os.path.exists(self.config.textgrid_dir):
187
+ textgrid_count = len([
188
+ f for f in os.listdir(self.config.textgrid_dir)
189
+ if f.endswith('.TextGrid')
190
+ ])
191
+
192
+ # 检查是否已有元文件(更新而非覆盖)
193
+ existing_meta = VoiceBankMeta.load(self.config.meta_file)
194
+
195
+ if existing_meta:
196
+ # 更新现有元文件
197
+ existing_meta.updated_at = datetime.now().isoformat()
198
+ existing_meta.whisper_model = self.config.whisper_model
199
+ existing_meta.mfa_dict = os.path.basename(self.config.mfa_dict_path) if self.config.mfa_dict_path else ""
200
+ existing_meta.mfa_acoustic = os.path.basename(self.config.mfa_model_path) if self.config.mfa_model_path else ""
201
+ existing_meta.language = self.config.language
202
+ existing_meta.single_speaker = self.config.single_speaker
203
+ if slice_count > 0:
204
+ existing_meta.slice_count = slice_count
205
+ existing_meta.textgrid_count = textgrid_count
206
+ meta = existing_meta
207
+ else:
208
+ # 创建新元文件
209
+ meta = VoiceBankMeta.from_config(self.config)
210
+ meta.slice_count = slice_count
211
+ meta.textgrid_count = textgrid_count
212
+
213
+ meta.save(self.config.meta_file)
214
+ self._log(f"元文件已保存: {self.config.meta_file}")
215
+
216
+ # ==================== 模型加载 ====================
217
+
218
+ def _load_vad_model(self):
219
+ """加载VAD模型"""
220
+ if self._vad_model is not None:
221
+ return
222
+
223
+ self._log("正在加载 Silero VAD 模型...")
224
+ from src.silero_vad_downloader import ensure_vad_model
225
+ from silero_vad import load_silero_vad, get_speech_timestamps
226
+
227
+ # 确保模型已下载
228
+ model_path = ensure_vad_model(self.config.models_dir, self.progress_callback)
229
+
230
+ # 使用 silero_vad 包加载本地 ONNX 模型
231
+ self._vad_model = load_silero_vad(onnx=True)
232
+ self._get_speech_timestamps = get_speech_timestamps
233
+ self._log("VAD 模型加载完成")
234
+
235
+ def _load_whisper_model(self):
236
+ """加载Whisper模型"""
237
+ if self._whisper_model is not None:
238
+ return
239
+
240
+ self._log(f"正在加载 Whisper 模型: {self.config.whisper_model}...")
241
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
242
+ import torch
243
+
244
+ cache_dir = os.path.join(self.config.models_dir, "whisper")
245
+ os.makedirs(cache_dir, exist_ok=True)
246
+ os.environ["HF_HOME"] = cache_dir
247
+ os.environ["TRANSFORMERS_CACHE"] = cache_dir
248
+
249
+ self._whisper_processor = WhisperProcessor.from_pretrained(
250
+ self.config.whisper_model,
251
+ cache_dir=cache_dir
252
+ )
253
+ self._whisper_model = WhisperForConditionalGeneration.from_pretrained(
254
+ self.config.whisper_model,
255
+ cache_dir=cache_dir,
256
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
257
+ )
258
+
259
+ # 移动到GPU(如果可用)
260
+ self._device = "cuda" if torch.cuda.is_available() else "cpu"
261
+ self._whisper_model.to(self._device)
262
+
263
+ self._log("Whisper 模型加载完成")
264
+
265
+ # ==================== 步骤0: VAD切片 + 转录 ====================
266
+
267
+ def step0_preprocess(self) -> Tuple[bool, str, List[str]]:
268
+ """
269
+ 步骤0: VAD切片 + Whisper转录
270
+
271
+ 输入: 原始音频文件
272
+ 输出: bank/[音源名称]/slices/ 下的 .wav 和 .lab 文件
273
+ """
274
+ try:
275
+ self._ensure_dirs()
276
+ self._load_vad_model()
277
+ self._load_whisper_model()
278
+
279
+ # 收集输入文件
280
+ input_files = self._collect_audio_files(self.config.input_path)
281
+ if not input_files:
282
+ return False, "未找到音频文件", []
283
+
284
+ self._log(f"找到 {len(input_files)} 个音频文件")
285
+
286
+ all_slices = []
287
+ for idx, audio_file in enumerate(input_files):
288
+ basename = Path(audio_file).stem
289
+ self._log(f"处理 [{idx+1}/{len(input_files)}]: {basename}")
290
+
291
+ # VAD切片
292
+ slices = self._vad_split(audio_file, self.config.slices_dir, basename)
293
+
294
+ # 转录每个切片
295
+ for slice_path in slices:
296
+ text = self._transcribe(slice_path)
297
+ if text:
298
+ self._write_lab(slice_path, text)
299
+ all_slices.append(slice_path)
300
+ self._log(f" {Path(slice_path).name} -> {text[:30]}...")
301
+ else:
302
+ self._log(f" 跳过空转录: {Path(slice_path).name}")
303
+
304
+ # 保存元文件
305
+ self._save_meta(slice_count=len(all_slices))
306
+
307
+ return True, f"预处理完成,共 {len(all_slices)} 个切片", all_slices
308
+
309
+ except Exception as e:
310
+ logger.error(f"预处理失败: {e}", exc_info=True)
311
+ return False, str(e), []
312
+
313
+ def _collect_audio_files(self, path: str) -> List[str]:
314
+ """收集音频文件"""
315
+ extensions = ('.wav', '.mp3', '.flac', '.ogg', '.m4a')
316
+ if os.path.isfile(path):
317
+ return [path] if path.lower().endswith(extensions) else []
318
+ elif os.path.isdir(path):
319
+ return [
320
+ os.path.join(path, f) for f in os.listdir(path)
321
+ if f.lower().endswith(extensions)
322
+ ]
323
+ return []
324
+
325
+ def _vad_split(self, audio_path: str, output_dir: str, prefix: str) -> List[str]:
326
+ """
327
+ VAD切片
328
+
329
+ 输出格式统一为: 16bit 44.1kHz 单声道 WAV
330
+ """
331
+ import torch
332
+ import soundfile as sf
333
+ import numpy as np
334
+
335
+ # 标准输出格式
336
+ TARGET_SR = 44100
337
+
338
+ # 读取并转换为标准格式
339
+ audio, sr = sf.read(audio_path, dtype='float32')
340
+
341
+ # 转换为单声道
342
+ if len(audio.shape) > 1:
343
+ audio = np.mean(audio, axis=1)
344
+
345
+ # 重采样到 44.1kHz(标准格式)
346
+ if sr != TARGET_SR:
347
+ import torchaudio
348
+ audio_tensor = torch.from_numpy(audio).float()
349
+ resampler = torchaudio.transforms.Resample(sr, TARGET_SR)
350
+ audio = resampler(audio_tensor).numpy()
351
+
352
+ # VAD 需要 16kHz,单独重采样用于检测
353
+ import torchaudio
354
+ audio_tensor = torch.from_numpy(audio).float()
355
+ resampler_16k = torchaudio.transforms.Resample(TARGET_SR, 16000)
356
+ wav_16k = resampler_16k(audio_tensor)
357
+
358
+ # 获取语音时间戳(基于16kHz)
359
+ timestamps = self._get_speech_timestamps(
360
+ wav_16k, self._vad_model,
361
+ threshold=0.5,
362
+ min_speech_duration_ms=250,
363
+ min_silence_duration_ms=100,
364
+ sampling_rate=16000
365
+ )
366
+
367
+ self._log(f" 检测到 {len(timestamps)} 个语音片段")
368
+
369
+ output_files = []
370
+ for i, ts in enumerate(timestamps):
371
+ # 将16kHz的时间戳转换为44.1kHz的采样点索引
372
+ start = int(ts['start'] * TARGET_SR / 16000)
373
+ end = int(ts['end'] * TARGET_SR / 16000)
374
+ segment = audio[start:end]
375
+
376
+ output_path = os.path.join(output_dir, f"{prefix}_{i:04d}.wav")
377
+ # 写入 16bit 44.1kHz 单声道 WAV
378
+ sf.write(output_path, segment, TARGET_SR, subtype='PCM_16')
379
+ output_files.append(output_path)
380
+
381
+ return output_files
382
+
383
+ def _transcribe(self, audio_path: str) -> str:
384
+ """Whisper转录(输入已是44.1kHz,需转为16kHz)"""
385
+ import soundfile as sf
386
+ import numpy as np
387
+ import torch
388
+ import torchaudio
389
+
390
+ # 读取音频(已是44.1kHz单声道)
391
+ audio, sr = sf.read(audio_path, dtype='float32')
392
+
393
+ # Whisper 需要 16kHz
394
+ audio_tensor = torch.from_numpy(audio).float()
395
+ resampler = torchaudio.transforms.Resample(sr, 16000)
396
+ audio_16k = resampler(audio_tensor).numpy()
397
+
398
+ # 处理输入
399
+ input_features = self._whisper_processor(
400
+ audio_16k,
401
+ sampling_rate=16000,
402
+ return_tensors="pt"
403
+ ).input_features.to(self._device)
404
+
405
+ # 设置语言
406
+ forced_decoder_ids = self._whisper_processor.get_decoder_prompt_ids(
407
+ language=self.config.language,
408
+ task="transcribe"
409
+ )
410
+
411
+ # 生成
412
+ with torch.no_grad():
413
+ predicted_ids = self._whisper_model.generate(
414
+ input_features,
415
+ forced_decoder_ids=forced_decoder_ids
416
+ )
417
+
418
+ # 解码
419
+ transcription = self._whisper_processor.batch_decode(
420
+ predicted_ids,
421
+ skip_special_tokens=True
422
+ )[0]
423
+
424
+ return transcription.strip()
425
+
426
+ def _write_lab(self, audio_path: str, text: str):
427
+ """写入.lab文件"""
428
+ lab_path = os.path.splitext(audio_path)[0] + ".lab"
429
+ with open(lab_path, "w", encoding="utf-8") as f:
430
+ f.write(text)
431
+
432
+ # ==================== ���骤1: MFA对齐 ====================
433
+
434
+ def step1_mfa_align(self) -> Tuple[bool, str]:
435
+ """
436
+ 步骤1: MFA语音对齐
437
+
438
+ 输入: bank/[音源名称]/slices/ 下的 .wav 和 .lab 文件
439
+ 输出: bank/[音源名称]/textgrid/ 下的 .TextGrid 文件
440
+
441
+ 注意: 直接使用中文文本,MFA字典为汉字到音素映射
442
+ """
443
+ try:
444
+ os.makedirs(self.config.textgrid_dir, exist_ok=True)
445
+
446
+ # 调用 MFA 对齐(直接使用中文文本,不转拼音)
447
+ from src.mfa_runner import run_mfa_alignment
448
+
449
+ success, message = run_mfa_alignment(
450
+ corpus_dir=self.config.slices_dir,
451
+ output_dir=self.config.textgrid_dir,
452
+ dict_path=self.config.mfa_dict_path,
453
+ model_path=self.config.mfa_model_path,
454
+ single_speaker=self.config.single_speaker,
455
+ clean=self.config.clean_mfa_cache,
456
+ progress_callback=self.progress_callback
457
+ )
458
+
459
+ # 更新元文件(更新TextGrid数量)
460
+ if success:
461
+ self._save_meta()
462
+
463
+ return success, message
464
+
465
+ except Exception as e:
466
+ logger.error(f"MFA对齐失败: {e}", exc_info=True)
467
+ return False, str(e)
468
+
469
+ # ==================== 制作流程(步骤0+1) ====================
470
+
471
+ def run_make_pipeline(self) -> Tuple[bool, str]:
472
+ """运行制作流水线(仅步骤0和步骤1)"""
473
+ self._log("=" * 50)
474
+ self._log(f"开始制作音源: {self.config.source_name}")
475
+ self._log("=" * 50)
476
+
477
+ # 步骤0
478
+ self._log("\n【步骤0】音频预处理 (VAD切片 + Whisper转录)")
479
+ success, msg, _ = self.step0_preprocess()
480
+ if not success:
481
+ return False, f"步骤0失败: {msg}"
482
+
483
+ # 步骤1
484
+ self._log("\n【步骤1】MFA语音对齐")
485
+ success, msg = self.step1_mfa_align()
486
+ if not success:
487
+ return False, f"步骤1失败: {msg}"
488
+
489
+ self._log("\n" + "=" * 50)
490
+ self._log("✅ 音源制作完成!")
491
+ self._log(f"输出目录: {self.config.source_dir}")
492
+ self._log("提示: 请到「导出音源」页面进行导出")
493
+ self._log("=" * 50)
494
+
495
+ return True, "音源制作完成"
496
+
497
+
498
+ # ==================== 模型扫描工具 ====================
499
+
500
+ def scan_mfa_models(models_dir: str) -> Dict[str, List[str]]:
501
+ """
502
+ 扫描MFA模型目录
503
+
504
+ 返回:
505
+ {
506
+ "acoustic": ["mandarin_mfa.zip", ...],
507
+ "dictionary": ["mandarin_china_mfa.dict", ...]
508
+ }
509
+ """
510
+ mfa_dir = os.path.join(models_dir, "mfa")
511
+ result = {"acoustic": [], "dictionary": []}
512
+
513
+ if not os.path.exists(mfa_dir):
514
+ return result
515
+
516
+ for f in os.listdir(mfa_dir):
517
+ if f.endswith('.zip'):
518
+ result["acoustic"].append(f)
519
+ elif f.endswith('.dict') or f.endswith('.txt'):
520
+ result["dictionary"].append(f)
521
+
522
+ return result
523
+
524
+
525
+ def scan_whisper_models(models_dir: str) -> List[str]:
526
+ """
527
+ 扫描已下载的Whisper模型
528
+
529
+ 返回模型名称列表
530
+ """
531
+ whisper_dir = os.path.join(models_dir, "whisper")
532
+ models = []
533
+
534
+ if not os.path.exists(whisper_dir):
535
+ return models
536
+
537
+ # 检查 HuggingFace 缓存目录结构
538
+ for item in os.listdir(whisper_dir):
539
+ if item.startswith("models--"):
540
+ # 格式: models--openai--whisper-small
541
+ parts = item.replace("models--", "").split("--")
542
+ if len(parts) >= 2:
543
+ models.append("/".join(parts))
544
+
545
+ return models
546
+
547
+
548
+ def load_voice_bank_meta(bank_dir: str, source_name: str) -> Optional[VoiceBankMeta]:
549
+ """
550
+ 加载音源元信息
551
+
552
+ 参数:
553
+ bank_dir: bank目录路径
554
+ source_name: 音源名称
555
+
556
+ 返回:
557
+ VoiceBankMeta对象,如果不存在则返回None
558
+ """
559
+ meta_path = os.path.join(bank_dir, source_name, "meta.json")
560
+ return VoiceBankMeta.load(meta_path)
561
+
562
+
563
+ def list_voice_banks_with_meta(bank_dir: str) -> List[Dict]:
564
+ """
565
+ 列出所有音源及其元信息
566
+
567
+ 参数:
568
+ bank_dir: bank目录路径
569
+
570
+ 返回:
571
+ 包含音源信息的字典列表
572
+ """
573
+ result = []
574
+
575
+ if not os.path.exists(bank_dir):
576
+ return result
577
+
578
+ for name in os.listdir(bank_dir):
579
+ source_dir = os.path.join(bank_dir, name)
580
+ if not os.path.isdir(source_dir):
581
+ continue
582
+
583
+ # 检查是否为有效音源目录(包含slices子目录)
584
+ slices_dir = os.path.join(source_dir, "slices")
585
+ if not os.path.exists(slices_dir):
586
+ continue
587
+
588
+ info = {"name": name, "meta": None}
589
+
590
+ # 尝试加载元信息
591
+ meta = load_voice_bank_meta(bank_dir, name)
592
+ if meta:
593
+ info["meta"] = meta.to_dict()
594
+
595
+ result.append(info)
596
+
597
+ return result
src/text_processor.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ 文本处理模块
4
+ 将中文文本转换为拼音,供 MFA 对齐使用
5
+ """
6
+
7
+ import os
8
+ import re
9
+ import logging
10
+ from pathlib import Path
11
+ from typing import Optional, Callable, List, Tuple
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ # ==================== 单字转拼音/罗马音 ====================
17
+
18
+ def char_to_pinyin(char: str, language: str = "chinese") -> Optional[str]:
19
+ """
20
+ 将单个字符转换为拼音/罗马音
21
+
22
+ 参数:
23
+ char: 单个字符
24
+ language: 语言 (chinese/japanese)
25
+
26
+ 返回:
27
+ 拼音/罗马音字符串,无法转换返回 None
28
+ """
29
+ if not char or len(char) != 1:
30
+ return None
31
+
32
+ if language in ("chinese", "zh", "mandarin"):
33
+ return _chinese_char_to_pinyin(char)
34
+ elif language in ("japanese", "ja", "jp"):
35
+ return _japanese_char_to_romaji(char)
36
+ else:
37
+ # 英文字母直接返回小写
38
+ if char.isalpha():
39
+ return char.lower()
40
+ return None
41
+
42
+
43
+ def _chinese_char_to_pinyin(char: str) -> Optional[str]:
44
+ """中文单字转拼音"""
45
+ try:
46
+ from pypinyin import pinyin, Style
47
+
48
+ # 数字转中文读法
49
+ digit_map = {
50
+ '0': 'ling', '1': 'yi', '2': 'er', '3': 'san', '4': 'si',
51
+ '5': 'wu', '6': 'liu', '7': 'qi', '8': 'ba', '9': 'jiu',
52
+ '0': 'ling', '1': 'yi', '2': 'er', '3': 'san', '4': 'si',
53
+ '5': 'wu', '6': 'liu', '7': 'qi', '8': 'ba', '9': 'jiu',
54
+ }
55
+ if char in digit_map:
56
+ return digit_map[char]
57
+
58
+ # 英文字母按中文读法
59
+ letter_map = {
60
+ 'a': 'ei', 'b': 'bi', 'c': 'xi', 'd': 'di', 'e': 'yi',
61
+ 'f': 'ai fu', 'g': 'ji', 'h': 'ai qi', 'i': 'ai', 'j': 'jie',
62
+ 'k': 'kai', 'l': 'ai lu', 'm': 'ai mu', 'n': 'en', 'o': 'ou',
63
+ 'p': 'pi', 'q': 'kiu', 'r': 'a', 's': 'ai si', 't': 'ti',
64
+ 'u': 'you', 'v': 'wei', 'w': 'da bu liu', 'x': 'ai ke si',
65
+ 'y': 'wai', 'z': 'zei',
66
+ }
67
+ lower_char = char.lower()
68
+ if lower_char in letter_map:
69
+ # 返回第一个音节
70
+ return letter_map[lower_char].split()[0]
71
+
72
+ # 汉字转拼音
73
+ result = pinyin(char, style=Style.NORMAL, heteronym=False)
74
+ if result and result[0] and result[0][0]:
75
+ return result[0][0].strip()
76
+
77
+ return None
78
+ except ImportError:
79
+ logger.error("pypinyin 未安装")
80
+ return None
81
+
82
+
83
+ def _japanese_char_to_romaji(char: str) -> Optional[str]:
84
+ """日文单字转罗马音"""
85
+ try:
86
+ import pykakasi
87
+
88
+ # 数字转日文读法
89
+ digit_map = {
90
+ '0': 'zero', '1': 'ichi', '2': 'ni', '3': 'san', '4': 'yon',
91
+ '5': 'go', '6': 'roku', '7': 'nana', '8': 'hachi', '9': 'kyuu',
92
+ }
93
+ if char in digit_map:
94
+ return digit_map[char]
95
+
96
+ kks = pykakasi.kakasi()
97
+ result = kks.convert(char)
98
+ if result and result[0]:
99
+ romaji = result[0].get('hepburn', result[0].get('orig', ''))
100
+ return romaji if romaji else None
101
+ return None
102
+ except ImportError:
103
+ logger.error("pykakasi 未安装")
104
+ return None
105
+
106
+
107
+ def is_valid_char(char: str, language: str = "chinese") -> bool:
108
+ """
109
+ 判断字符是否为有效的可转换字符
110
+
111
+ 参数:
112
+ char: 单个字符
113
+ language: 语言
114
+
115
+ 返回:
116
+ 是否有效
117
+ """
118
+ if not char or len(char) != 1:
119
+ return False
120
+
121
+ # 数字有效
122
+ if char.isdigit():
123
+ return True
124
+
125
+ # 英文字母有效
126
+ if char.isalpha() and char.isascii():
127
+ return True
128
+
129
+ if language in ("chinese", "zh", "mandarin"):
130
+ # 中文字符范围
131
+ return '\u4e00' <= char <= '\u9fff' or '\u3400' <= char <= '\u4dbf'
132
+ elif language in ("japanese", "ja", "jp"):
133
+ # 日文假名和汉字
134
+ return (
135
+ '\u3040' <= char <= '\u309f' or # 平假名
136
+ '\u30a0' <= char <= '\u30ff' or # 片假名
137
+ '\u4e00' <= char <= '\u9fff' # 汉字
138
+ )
139
+
140
+ return False
141
+
142
+
143
+ def chinese_to_pinyin(text: str) -> str:
144
+ """
145
+ 将中文文本转换为拼音(空格分隔)
146
+
147
+ 参数:
148
+ text: 中文文本
149
+
150
+ 返回:
151
+ 拼音字符串,空格分隔
152
+ """
153
+ try:
154
+ from pypinyin import pinyin, Style
155
+
156
+ # 获取拼音,不带声调
157
+ result = pinyin(text, style=Style.NORMAL, heteronym=False)
158
+
159
+ # 展平并过滤空值
160
+ pinyins = []
161
+ for item in result:
162
+ if item and item[0]:
163
+ py = item[0].strip()
164
+ if py:
165
+ pinyins.append(py)
166
+
167
+ return ' '.join(pinyins)
168
+ except ImportError:
169
+ logger.error("pypinyin 未安装,请运行: pip install pypinyin")
170
+ raise
171
+
172
+
173
+ def japanese_to_romaji(text: str) -> str:
174
+ """
175
+ 将日文文本转换为罗马字
176
+
177
+ 参数:
178
+ text: 日文文本
179
+
180
+ 返回:
181
+ 罗马字字符串,空格分隔
182
+ """
183
+ try:
184
+ import pykakasi
185
+
186
+ kks = pykakasi.kakasi()
187
+ result = kks.convert(text)
188
+
189
+ romajis = []
190
+ for item in result:
191
+ romaji = item.get('hepburn', item.get('orig', ''))
192
+ if romaji:
193
+ romajis.append(romaji)
194
+
195
+ return ' '.join(romajis)
196
+ except ImportError:
197
+ logger.error("pykakasi 未安装,请运行: pip install pykakasi")
198
+ raise
199
+
200
+
201
+ def process_lab_file(
202
+ lab_path: str,
203
+ language: str = "chinese",
204
+ output_path: Optional[str] = None
205
+ ) -> Tuple[bool, str]:
206
+ """
207
+ 处理单个 .lab 文件,将文本转换为拼音/罗马字
208
+
209
+ 参数:
210
+ lab_path: .lab 文件路径
211
+ language: 语言 (chinese/japanese)
212
+ output_path: 输出路径,默认覆盖原文件
213
+
214
+ 返回:
215
+ (成功标志, 转换后的文本或错误信息)
216
+ """
217
+ try:
218
+ with open(lab_path, 'r', encoding='utf-8') as f:
219
+ text = f.read().strip()
220
+
221
+ if not text:
222
+ return False, "空文件"
223
+
224
+ # 根据语言选择转换函数
225
+ if language in ("chinese", "zh", "mandarin"):
226
+ converted = chinese_to_pinyin(text)
227
+ elif language in ("japanese", "ja", "jp"):
228
+ converted = japanese_to_romaji(text)
229
+ else:
230
+ # 英文或其他语言,保持原样但分词
231
+ converted = ' '.join(text.split())
232
+
233
+ # 写入文件
234
+ output = output_path or lab_path
235
+ with open(output, 'w', encoding='utf-8') as f:
236
+ f.write(converted)
237
+
238
+ return True, converted
239
+
240
+ except Exception as e:
241
+ logger.error(f"处理 {lab_path} 失败: {e}")
242
+ return False, str(e)
243
+
244
+
245
+ def process_lab_directory(
246
+ input_dir: str,
247
+ language: str = "chinese",
248
+ output_dir: Optional[str] = None,
249
+ progress_callback: Optional[Callable[[str], None]] = None
250
+ ) -> Tuple[bool, str, int]:
251
+ """
252
+ 批量处理目录下的所有 .lab 文件
253
+
254
+ 参数:
255
+ input_dir: 输入目录
256
+ language: 语言
257
+ output_dir: 输出目录,默认覆盖原文件
258
+ progress_callback: 进度回调
259
+
260
+ 返回:
261
+ (成功标志, 消息, 处理文件数)
262
+ """
263
+ def log(msg: str):
264
+ logger.info(msg)
265
+ if progress_callback:
266
+ progress_callback(msg)
267
+
268
+ try:
269
+ lab_files = list(Path(input_dir).glob('*.lab'))
270
+
271
+ if not lab_files:
272
+ return False, "未找到 .lab 文件", 0
273
+
274
+ log(f"找到 {len(lab_files)} 个 .lab 文件")
275
+
276
+ if output_dir:
277
+ os.makedirs(output_dir, exist_ok=True)
278
+
279
+ success_count = 0
280
+ for i, lab_path in enumerate(lab_files):
281
+ output_path = None
282
+ if output_dir:
283
+ output_path = os.path.join(output_dir, lab_path.name)
284
+
285
+ success, result = process_lab_file(
286
+ str(lab_path),
287
+ language,
288
+ output_path
289
+ )
290
+
291
+ if success:
292
+ success_count += 1
293
+ log(f"[{i+1}/{len(lab_files)}] {lab_path.name} -> {result[:30]}...")
294
+ else:
295
+ log(f"[{i+1}/{len(lab_files)}] {lab_path.name} 失败: {result}")
296
+
297
+ return True, f"处理完成: {success_count}/{len(lab_files)}", success_count
298
+
299
+ except Exception as e:
300
+ logger.error(f"批量处理失败: {e}", exc_info=True)
301
+ return False, str(e), 0
src/textgrid2bank.py CHANGED
@@ -1,33 +1,112 @@
1
- import textgrid
2
- import glob
3
- import audiofile
4
- import os
5
- import tqdm
6
-
7
- wavDir = r'E:\Workspace\umamusume-voice-text-extractor\extracted'
8
- tgDir = r'E:\SVS\DiffSinger\MakeDiffSinger\temp\revised'
9
 
10
- saveDir = 'bank'
 
 
 
11
 
 
12
 
13
- for path in tqdm.tqdm(glob.glob(tgDir + '/*.TextGrid')):
14
- wavPath = path.replace(tgDir, wavDir).replace('.TextGrid', '.wav')
15
- tg = textgrid.TextGrid.fromFile(path)
16
- audio, sr = audiofile.read(wavPath)
17
- for word in tg[0]:
18
- if word.mark in ['SP', 'AP']:
19
- continue
20
 
21
- wordText = word.mark.split(':')[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- wordDir = saveDir + '/' + wordText
24
- if not os.path.exists(wordDir):
25
- os.mkdir(wordDir)
26
- index = 1
27
- while True:
28
- filename = '{}/{}.wav'.format(wordDir, index)
29
- if not os.path.exists(filename):
30
- break
31
- index += 1
32
 
33
- audiofile.write(filename, audio[int(word.minTime * sr) : int(word.maxTime * sr)], sr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ TextGrid 转音频库模块
4
+ TextGrid 对齐结果中提取分词片段
5
+ """
 
 
 
6
 
7
+ import os
8
+ import glob
9
+ import logging
10
+ from typing import Optional, Callable, Dict, Tuple
11
 
12
+ logger = logging.getLogger(__name__)
13
 
 
 
 
 
 
 
 
14
 
15
+ def textgrid_to_bank(
16
+ wav_dir: str,
17
+ textgrid_dir: str,
18
+ output_dir: str,
19
+ progress_callback: Optional[Callable[[str], None]] = None
20
+ ) -> Tuple[bool, str, Dict[str, int]]:
21
+ """
22
+ 将 TextGrid 对齐结果转换为音频库
23
+
24
+ 参数:
25
+ wav_dir: WAV文件目录
26
+ textgrid_dir: TextGrid文件目录
27
+ output_dir: 输出目录
28
+ progress_callback: 进度回调函数
29
+
30
+ 返回:
31
+ (成功标志, 消息, 词条统计)
32
+ """
33
+ import textgrid
34
+ import audiofile
35
+
36
+ def log(msg: str):
37
+ logger.info(msg)
38
+ if progress_callback:
39
+ progress_callback(msg)
40
+
41
+ try:
42
+ os.makedirs(output_dir, exist_ok=True)
43
+
44
+ tg_files = glob.glob(os.path.join(textgrid_dir, '*.TextGrid'))
45
+ if not tg_files:
46
+ return False, "未找到TextGrid文件", {}
47
+
48
+ log(f"处理 {len(tg_files)} 个TextGrid文件")
49
+
50
+ word_counts = {}
51
+
52
+ for idx, tg_path in enumerate(tg_files):
53
+ basename = os.path.basename(tg_path).replace('.TextGrid', '.wav')
54
+ wav_path = os.path.join(wav_dir, basename)
55
+
56
+ if not os.path.exists(wav_path):
57
+ log(f"警告: 找不到 {basename}")
58
+ continue
59
+
60
+ tg = textgrid.TextGrid.fromFile(tg_path)
61
+ audio, sr = audiofile.read(wav_path)
62
+
63
+ for word in tg[0]:
64
+ if word.mark in ['SP', 'AP', '']:
65
+ continue
66
+
67
+ word_text = word.mark.split(':')[0]
68
+ word_dir = os.path.join(output_dir, word_text)
69
+ os.makedirs(word_dir, exist_ok=True)
70
+
71
+ # 找到下一个可用编号
72
+ index = 1
73
+ while os.path.exists(os.path.join(word_dir, f'{index}.wav')):
74
+ index += 1
75
+
76
+ # 切出片段并保存
77
+ start_sample = int(word.minTime * sr)
78
+ end_sample = int(word.maxTime * sr)
79
+ segment = audio[start_sample:end_sample]
80
+
81
+ output_path = os.path.join(word_dir, f'{index}.wav')
82
+ audiofile.write(output_path, segment, sr)
83
+
84
+ word_counts[word_text] = word_counts.get(word_text, 0) + 1
85
+
86
+ log(f"进度: {idx+1}/{len(tg_files)} - {basename}")
87
+
88
+ total = sum(word_counts.values())
89
+ log(f"提取完成: {len(word_counts)} 个词条,共 {total} 个片段")
90
+
91
+ return True, f"提取完成: {len(word_counts)} 个词条", word_counts
92
+
93
+ except Exception as e:
94
+ logger.error(f"TextGrid转换失败: {e}", exc_info=True)
95
+ return False, str(e), {}
96
 
 
 
 
 
 
 
 
 
 
97
 
98
+ # 保留原有脚本入口以兼容
99
+ if __name__ == "__main__":
100
+ import tqdm
101
+
102
+ wavDir = r'E:\Workspace\umamusume-voice-text-extractor\extracted'
103
+ tgDir = r'E:\SVS\DiffSinger\MakeDiffSinger\temp\revised'
104
+ saveDir = 'bank'
105
+
106
+ success, msg, stats = textgrid_to_bank(
107
+ wav_dir=wavDir,
108
+ textgrid_dir=tgDir,
109
+ output_dir=saveDir,
110
+ progress_callback=print
111
+ )
112
+ print(f"结果: {msg}")