# -*- coding: utf-8 -*-
"""
Voice bank production pipeline.
Centralizes all non-GUI business logic.
"""
import os
import glob
import json
import logging
import shutil
from datetime import datetime
from pathlib import Path
from typing import Optional, Callable, List, Tuple, Dict
from dataclasses import dataclass, asdict

logger = logging.getLogger(__name__)

@dataclass
class PipelineConfig:
    """Pipeline configuration."""
    source_name: str       # voice bank name
    input_path: str        # input audio path (file or directory)
    output_base_dir: str   # output base directory (bank)
    models_dir: str        # models directory

    # Model settings
    whisper_model: str = "openai/whisper-small"
    mfa_dict_path: Optional[str] = None
    mfa_model_path: Optional[str] = None

    # Processing parameters
    language: str = "chinese"
    single_speaker: bool = True
    clean_mfa_cache: bool = True
    max_samples_per_word: int = 100
    cancel_checker: Optional[Callable[[], bool]] = None

    @property
    def source_dir(self) -> str:
        """Voice bank directory: bank/[source_name]"""
        return os.path.join(self.output_base_dir, self.source_name)

    @property
    def slices_dir(self) -> str:
        """Slices directory: bank/[source_name]/slices"""
        return os.path.join(self.source_dir, "slices")

    @property
    def textgrid_dir(self) -> str:
        """TextGrid directory: bank/[source_name]/textgrid"""
        return os.path.join(self.source_dir, "textgrid")

    @property
    def segments_dir(self) -> str:
        """Temporary per-character segment directory (safe to delete after processing)."""
        return os.path.join(self.output_base_dir, ".temp_segments", self.source_name)

    @property
    def export_dir(self) -> str:
        """Export directory: export/[source_name]/simple_export"""
        # Export into the `export` folder at the project root
        base = Path(self.output_base_dir).parent
        return os.path.join(base, "export", self.source_name, "simple_export")

    @property
    def meta_file(self) -> str:
        """Meta file path: bank/[source_name]/meta.json"""
        return os.path.join(self.source_dir, "meta.json")
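
# Illustrative sketch of how PipelineConfig derives its directory layout
# (the paths below are hypothetical, not part of the module):
#
#   config = PipelineConfig(
#       source_name="demo",
#       input_path="recordings/",
#       output_base_dir="bank",
#       models_dir="models",
#   )
#   config.source_dir    # -> "bank/demo"
#   config.slices_dir    # -> "bank/demo/slices"
#   config.textgrid_dir  # -> "bank/demo/textgrid"
#   config.meta_file     # -> "bank/demo/meta.json"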

@dataclass
class VoiceBankMeta:
    """
    Voice bank metadata.
    Stores the settings and model information used when the bank was built.
    """
    # Basic info
    source_name: str
    created_at: str   # ISO-format timestamp
    updated_at: str   # ISO-format timestamp

    # Model info
    whisper_model: str   # Whisper model name
    mfa_dict: str        # MFA dictionary file name
    mfa_acoustic: str    # MFA acoustic model file name

    # Processing parameters
    language: str          # transcription language
    single_speaker: bool   # single-speaker mode

    # Statistics
    slice_count: int = 0      # number of slices
    textgrid_count: int = 0   # number of TextGrid files

    @classmethod
    def from_config(cls, config: PipelineConfig) -> "VoiceBankMeta":
        """Create metadata from a pipeline config."""
        now = datetime.now().isoformat()
        # Extract model file names (without paths)
        mfa_dict = ""
        if config.mfa_dict_path:
            mfa_dict = os.path.basename(config.mfa_dict_path)
        mfa_acoustic = ""
        if config.mfa_model_path:
            mfa_acoustic = os.path.basename(config.mfa_model_path)
        return cls(
            source_name=config.source_name,
            created_at=now,
            updated_at=now,
            whisper_model=config.whisper_model,
            mfa_dict=mfa_dict,
            mfa_acoustic=mfa_acoustic,
            language=config.language,
            single_speaker=config.single_speaker
        )

    def to_dict(self) -> Dict:
        """Convert to a dict."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict) -> "VoiceBankMeta":
        """Create from a dict."""
        return cls(**data)

    def save(self, path: str):
        """Save to a JSON file."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f, ensure_ascii=False, indent=2)

    @classmethod
    def load(cls, path: str) -> Optional["VoiceBankMeta"]:
        """Load from a JSON file."""
        if not os.path.exists(path):
            return None
        try:
            with open(path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            return cls.from_dict(data)
        except Exception as e:
            logger.warning(f"Failed to load meta file: {e}")
            return None
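
# Minimal round-trip sketch (assumes `config` is a PipelineConfig and the
# file path is hypothetical):
#
#   meta = VoiceBankMeta.from_config(config)
#   meta.save("bank/demo/meta.json")
#   same = VoiceBankMeta.load("bank/demo/meta.json")
#   assert same is not None and same.source_name == meta.source_name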

class VoiceBankPipeline:
    """Voice bank production pipeline."""

    def __init__(
        self,
        config: PipelineConfig,
        progress_callback: Optional[Callable[[str], None]] = None
    ):
        self.config = config
        self.progress_callback = progress_callback
        self._vad_model = None
        self._get_speech_timestamps = None
        self._whisper_model = None
        self._whisper_processor = None
        self._device = None

    def _log(self, msg: str):
        """Log a message and forward it to the progress callback."""
        logger.info(msg)
        if self.progress_callback:
            self.progress_callback(msg)

    def _is_cancelled(self) -> bool:
        checker = self.config.cancel_checker
        return bool(checker and checker())

    def _check_cancel(self):
        if self._is_cancelled():
            raise RuntimeError("Task cancelled")

    def _ensure_dirs(self):
        """Ensure the output directories exist."""
        os.makedirs(self.config.source_dir, exist_ok=True)
        os.makedirs(self.config.slices_dir, exist_ok=True)

    def _save_meta(self, slice_count: int = 0):
        """
        Save the meta file.

        Args:
            slice_count: number of slices
        """
        # Count TextGrid files
        textgrid_count = 0
        if os.path.exists(self.config.textgrid_dir):
            textgrid_count = len([
                f for f in os.listdir(self.config.textgrid_dir)
                if f.endswith('.TextGrid')
            ])
        # Check for an existing meta file (update rather than overwrite)
        existing_meta = VoiceBankMeta.load(self.config.meta_file)
        if existing_meta:
            # Update the existing meta file
            existing_meta.updated_at = datetime.now().isoformat()
            existing_meta.whisper_model = self.config.whisper_model
            existing_meta.mfa_dict = os.path.basename(self.config.mfa_dict_path) if self.config.mfa_dict_path else ""
            existing_meta.mfa_acoustic = os.path.basename(self.config.mfa_model_path) if self.config.mfa_model_path else ""
            existing_meta.language = self.config.language
            existing_meta.single_speaker = self.config.single_speaker
            if slice_count > 0:
                existing_meta.slice_count = slice_count
            existing_meta.textgrid_count = textgrid_count
            meta = existing_meta
        else:
            # Create a new meta file
            meta = VoiceBankMeta.from_config(self.config)
            meta.slice_count = slice_count
            meta.textgrid_count = textgrid_count
        meta.save(self.config.meta_file)
        self._log(f"Meta file saved: {self.config.meta_file}")

    # ==================== Model loading ====================

    def _load_vad_model(self):
        """Load the Silero VAD model."""
        if self._vad_model is not None:
            return
        self._log("Loading Silero VAD model...")
        from src.silero_vad_downloader import ensure_vad_model
        from silero_vad import load_silero_vad, get_speech_timestamps
        # Make sure the model files have been downloaded
        model_path = ensure_vad_model(self.config.models_dir, self.progress_callback)
        # Load the local ONNX model via the silero_vad package
        self._vad_model = load_silero_vad(onnx=True)
        self._get_speech_timestamps = get_speech_timestamps
        self._log("VAD model loaded")

    def _load_whisper_model(self):
        """Load the Whisper model."""
        if self._whisper_model is not None:
            return
        self._log(f"Loading Whisper model: {self.config.whisper_model}...")
        from transformers import WhisperProcessor, WhisperForConditionalGeneration
        import torch
        cache_dir = os.path.join(self.config.models_dir, "whisper")
        os.makedirs(cache_dir, exist_ok=True)
        os.environ["HF_HOME"] = cache_dir
        os.environ["TRANSFORMERS_CACHE"] = cache_dir
        # Use a HuggingFace mirror in the cloud environment
        if Path("/home/studio_service").exists() and not os.environ.get("HF_ENDPOINT"):
            os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
        # If the model is already cached, switch to offline mode
        model_cache_name = self.config.whisper_model.replace("/", "--")
        model_cache_path = Path(cache_dir) / f"models--{model_cache_name}"
        local_files_only = model_cache_path.exists()
        if local_files_only:
            self._log("Using locally cached model (offline mode)")
        self._whisper_processor = WhisperProcessor.from_pretrained(
            self.config.whisper_model,
            cache_dir=cache_dir,
            local_files_only=local_files_only
        )
        self._whisper_model = WhisperForConditionalGeneration.from_pretrained(
            self.config.whisper_model,
            cache_dir=cache_dir,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            local_files_only=local_files_only
        )
        # Move to GPU if available
        self._device = "cuda" if torch.cuda.is_available() else "cpu"
        self._whisper_model.to(self._device)
        self._log("Whisper model loaded")

    # ==================== Step 0: VAD slicing + transcription ====================

    def step0_preprocess(self) -> Tuple[bool, str, List[str]]:
        """
        Step 0: VAD slicing + Whisper transcription.

        Input:  raw audio files
        Output: .wav and .lab files under bank/[source_name]/slices/
        """
        try:
            self._check_cancel()
            self._ensure_dirs()
            self._load_vad_model()
            self._load_whisper_model()
            # Collect input files
            input_files = self._collect_audio_files(self.config.input_path)
            if not input_files:
                return False, "No audio files found", []
            self._log(f"Found {len(input_files)} audio file(s)")
            all_slices = []
            for idx, audio_file in enumerate(input_files):
                self._check_cancel()
                basename = Path(audio_file).stem
                self._log(f"Processing [{idx+1}/{len(input_files)}]: {basename}")
                # VAD slicing
                slices = self._vad_split(audio_file, self.config.slices_dir, basename)
                # Transcribe each slice
                for slice_path in slices:
                    self._check_cancel()
                    text = self._transcribe(slice_path)
                    if text:
                        self._write_lab(slice_path, text)
                        all_slices.append(slice_path)
                        self._log(f"  {Path(slice_path).name} -> {text[:30]}...")
                    else:
                        self._log(f"  Skipping empty transcription: {Path(slice_path).name}")
            # Save the meta file
            self._save_meta(slice_count=len(all_slices))
            return True, f"Preprocessing complete, {len(all_slices)} slice(s) total", all_slices
        except Exception as e:
            logger.error(f"Preprocessing failed: {e}", exc_info=True)
            return False, str(e), []

    def _collect_audio_files(self, path: str) -> List[str]:
        """Collect audio files from a file or directory path."""
        extensions = ('.wav', '.mp3', '.flac', '.ogg', '.m4a')
        if os.path.isfile(path):
            return [path] if path.lower().endswith(extensions) else []
        elif os.path.isdir(path):
            return [
                os.path.join(path, f) for f in os.listdir(path)
                if f.lower().endswith(extensions)
            ]
        return []

    def _load_audio_ffmpeg(self, audio_path: str) -> tuple:
        """
        Read an audio file with ffmpeg.

        Returns:
            (audio, sr): mono float32 numpy array and sample rate
        """
        import subprocess
        import numpy as np
        # Probe the sample rate with ffprobe
        probe_cmd = [
            'ffprobe', '-v', 'quiet', '-print_format', 'json',
            '-show_streams', audio_path
        ]
        probe_result = subprocess.run(probe_cmd, capture_output=True, text=True)
        sr = 44100  # default sample rate
        if probe_result.returncode == 0:
            try:
                info = json.loads(probe_result.stdout)
                for stream in info.get('streams', []):
                    if stream.get('codec_type') == 'audio':
                        sr = int(stream.get('sample_rate', 44100))
                        break
            except (json.JSONDecodeError, KeyError):
                pass
        # Convert to mono PCM with ffmpeg
        cmd = [
            'ffmpeg', '-i', audio_path,
            '-f', 's16le',    # 16-bit signed little-endian
            '-acodec', 'pcm_s16le',
            '-ac', '1',       # mono
            '-ar', str(sr),   # keep the original sample rate
            '-v', 'quiet',
            '-'
        ]
        result = subprocess.run(cmd, capture_output=True)
        if result.returncode != 0:
            raise RuntimeError(f"ffmpeg failed to read audio: {audio_path}")
        # Convert to a numpy array
        audio = np.frombuffer(result.stdout, dtype=np.int16).astype(np.float32) / 32768.0
        return audio, sr

    def _vad_split(self, audio_path: str, output_dir: str, prefix: str) -> List[str]:
        """
        VAD slicing.

        Output is normalized to 16-bit 44.1 kHz mono WAV.
        Supported input formats: wav, mp3, flac, ogg, m4a, etc. (via ffmpeg).
        """
        import torch
        import torchaudio
        import soundfile as sf
        import numpy as np
        # Standard output format
        TARGET_SR = 44100
        # Read audio with ffmpeg (supports more formats, including m4a);
        # the returned audio is already a mono float32 numpy array
        audio, sr = self._load_audio_ffmpeg(audio_path)
        # Resample to 44.1 kHz (the standard format)
        if sr != TARGET_SR:
            resampler = torchaudio.transforms.Resample(sr, TARGET_SR)
            audio = resampler(torch.from_numpy(audio).float()).numpy()
        # VAD needs 16 kHz; resample a separate copy for detection
        audio_tensor = torch.from_numpy(audio).float()
        resampler_16k = torchaudio.transforms.Resample(TARGET_SR, 16000)
        wav_16k = resampler_16k(audio_tensor)
        # Get speech timestamps (computed on the 16 kHz audio)
        timestamps = self._get_speech_timestamps(
            wav_16k, self._vad_model,
            threshold=0.5,
            min_speech_duration_ms=250,
            min_silence_duration_ms=100,
            sampling_rate=16000
        )
        self._log(f"  Detected {len(timestamps)} speech segment(s)")
        output_files = []
        for i, ts in enumerate(timestamps):
            self._check_cancel()
            # Convert 16 kHz timestamps to 44.1 kHz sample indices
            start = int(ts['start'] * TARGET_SR / 16000)
            end = int(ts['end'] * TARGET_SR / 16000)
            segment = audio[start:end]
            output_path = os.path.join(output_dir, f"{prefix}_{i:04d}.wav")
            # Write 16-bit 44.1 kHz mono WAV
            sf.write(output_path, segment, TARGET_SR, subtype='PCM_16')
            output_files.append(output_path)
        return output_files

    def _transcribe(self, audio_path: str) -> str:
        """Whisper transcription (input is already 44.1 kHz; resampled to 16 kHz)."""
        self._check_cancel()
        import soundfile as sf
        import numpy as np
        import torch
        import torchaudio
        # Read audio (already 44.1 kHz mono)
        audio, sr = sf.read(audio_path, dtype='float32')
        # Whisper needs 16 kHz
        audio_tensor = torch.from_numpy(audio).float()
        resampler = torchaudio.transforms.Resample(sr, 16000)
        audio_16k = resampler(audio_tensor).numpy()
        # Prepare input features; cast to the model dtype so float16 GPU
        # inference does not hit a dtype mismatch
        input_features = self._whisper_processor(
            audio_16k,
            sampling_rate=16000,
            return_tensors="pt"
        ).input_features.to(self._device, dtype=self._whisper_model.dtype)
        # Set the language
        forced_decoder_ids = self._whisper_processor.get_decoder_prompt_ids(
            language=self.config.language,
            task="transcribe"
        )
        # Generate
        with torch.no_grad():
            predicted_ids = self._whisper_model.generate(
                input_features,
                forced_decoder_ids=forced_decoder_ids
            )
        # Decode
        transcription = self._whisper_processor.batch_decode(
            predicted_ids,
            skip_special_tokens=True
        )[0]
        return transcription.strip()

    def _write_lab(self, audio_path: str, text: str):
        """Write the .lab transcript file next to the audio file."""
        lab_path = os.path.splitext(audio_path)[0] + ".lab"
        with open(lab_path, "w", encoding="utf-8") as f:
            f.write(text)

    # ==================== Step 1: MFA alignment ====================

    def step1_mfa_align(self) -> Tuple[bool, str]:
        """
        Step 1: MFA forced alignment.

        Input:  .wav and .lab files under bank/[source_name]/slices/
        Output: .TextGrid files under bank/[source_name]/textgrid/
        Note: the Chinese text is used directly; the MFA dictionary maps
        characters to phonemes.
        """
        try:
            self._check_cancel()
            os.makedirs(self.config.textgrid_dir, exist_ok=True)
            # Run MFA alignment (uses the Chinese text directly, no pinyin conversion)
            from src.mfa_runner import run_mfa_alignment
            success, message = run_mfa_alignment(
                corpus_dir=self.config.slices_dir,
                output_dir=self.config.textgrid_dir,
                dict_path=self.config.mfa_dict_path,
                model_path=self.config.mfa_model_path,
                single_speaker=self.config.single_speaker,
                clean=self.config.clean_mfa_cache,
                progress_callback=self.progress_callback,
                cancel_checker=self.config.cancel_checker
            )
            # Update the meta file (refresh the TextGrid count)
            if success:
                self._save_meta()
            return success, message
        except Exception as e:
            logger.error(f"MFA alignment failed: {e}", exc_info=True)
            return False, str(e)

    # ==================== Build flow (steps 0 + 1) ====================

    def run_make_pipeline(self) -> Tuple[bool, str]:
        """Run the build pipeline (steps 0 and 1 only)."""
        self._log("=" * 50)
        self._log(f"Building voice bank: {self.config.source_name}")
        self._log("=" * 50)
        # Step 0
        self._log("\n[Step 0] Audio preprocessing (VAD slicing + Whisper transcription)")
        success, msg, _ = self.step0_preprocess()
        if not success:
            return False, f"Step 0 failed: {msg}"
        # Step 1
        self._log("\n[Step 1] MFA forced alignment")
        success, msg = self.step1_mfa_align()
        if not success:
            return False, f"Step 1 failed: {msg}"
        self._log("\n" + "=" * 50)
        self._log("✅ Voice bank build complete!")
        self._log(f"Output directory: {self.config.source_dir}")
        self._log("Tip: use the 'Export Voice Bank' page to export")
        self._log("=" * 50)
        return True, "Voice bank build complete"
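
# Usage sketch (illustrative; the callback wiring is an assumption):
#
#   pipeline = VoiceBankPipeline(config, progress_callback=print)
#   ok, msg = pipeline.run_make_pipeline()
#
# Cancellation is cooperative: pass a `cancel_checker` in PipelineConfig
# (e.g. a lambda reading a threading.Event). Each step polls it through
# _check_cancel(), which raises RuntimeError once the checker returns True.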

# ==================== Model scanning utilities ====================

def scan_mfa_models(models_dir: str) -> Dict[str, List[str]]:
    """
    Scan the MFA models directory.

    Returns:
        {
            "acoustic": ["mandarin_mfa.zip", ...],
            "dictionary": ["mandarin_china_mfa.dict", ...]
        }
    """
    mfa_dir = os.path.join(models_dir, "mfa")
    result = {"acoustic": [], "dictionary": []}
    if not os.path.exists(mfa_dir):
        return result
    for f in os.listdir(mfa_dir):
        if f.endswith('.zip'):
            result["acoustic"].append(f)
        elif f.endswith('.dict') or f.endswith('.txt'):
            result["dictionary"].append(f)
    return result

def scan_whisper_models(models_dir: str) -> List[str]:
    """
    Scan for downloaded Whisper models.

    Returns a list of model names.
    """
    whisper_dir = os.path.join(models_dir, "whisper")
    models = []
    if not os.path.exists(whisper_dir):
        return models
    # Walk the HuggingFace cache directory layout
    for item in os.listdir(whisper_dir):
        if item.startswith("models--"):
            # Format: models--openai--whisper-small
            parts = item.replace("models--", "").split("--")
            if len(parts) >= 2:
                models.append("/".join(parts))
    return models
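
# Example of the cache-name mapping this reverses (hypothetical listing):
#   "models--openai--whisper-small" -> "openai/whisper-small"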

def load_voice_bank_meta(bank_dir: str, source_name: str) -> Optional[VoiceBankMeta]:
    """
    Load voice bank metadata.

    Args:
        bank_dir: path to the bank directory
        source_name: voice bank name

    Returns:
        A VoiceBankMeta object, or None if it does not exist.
    """
    meta_path = os.path.join(bank_dir, source_name, "meta.json")
    return VoiceBankMeta.load(meta_path)

def list_voice_banks_with_meta(bank_dir: str) -> List[Dict]:
    """
    List all voice banks along with their metadata.

    Args:
        bank_dir: path to the bank directory

    Returns:
        A list of dicts describing each voice bank.
    """
    result = []
    if not os.path.exists(bank_dir):
        return result
    for name in os.listdir(bank_dir):
        source_dir = os.path.join(bank_dir, name)
        if not os.path.isdir(source_dir):
            continue
        # Only treat directories with a slices subdirectory as valid voice banks
        slices_dir = os.path.join(source_dir, "slices")
        if not os.path.exists(slices_dir):
            continue
        info = {"name": name, "meta": None}
        # Try to load the metadata
        meta = load_voice_bank_meta(bank_dir, name)
        if meta:
            info["meta"] = meta.to_dict()
        result.append(info)
    return result
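
# Minimal end-to-end sketch, assuming the directories and MFA model files
# below exist; all paths here are hypothetical, not shipped with the module.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    demo_config = PipelineConfig(
        source_name="demo",
        input_path="recordings",   # directory of raw audio files
        output_base_dir="bank",
        models_dir="models",
        mfa_dict_path="models/mfa/mandarin_china_mfa.dict",
        mfa_model_path="models/mfa/mandarin_mfa.zip",
    )
    pipeline = VoiceBankPipeline(demo_config, progress_callback=print)
    ok, msg = pipeline.run_make_pipeline()
    print(f"{'OK' if ok else 'FAILED'}: {msg}")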