Spaces:
Running
Running
#!/usr/bin/env python3
"""
RVC v2 training script (fixed) - uses torchaudio instead of librosa.

NumberBlocks One voice cloning.

Fixes in this version:
- librosa.load -> torchaudio.load (avoids numba compatibility problems)
- librosa.feature.melspectrogram -> torchaudio.transforms.MelSpectrogram
- librosa.piptrack -> torch-based pitch estimation
- supports both the soundfile and sox backends
"""
import os
import sys
import yaml
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchaudio
import torchaudio.transforms as T
import numpy as np
from pathlib import Path
import json
import logging
import traceback

# Configure logging: timestamped INFO-level messages.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Log the torchaudio version/backends at import time so audio-loading
# failures later can be diagnosed from the log alone.
logger.info(f"torchaudio version: {torchaudio.__version__}")
logger.info(f"torchaudio backends: {torchaudio.list_audio_backends()}")
class VoiceDataset(Dataset):
    """Speech dataset loaded with torchaudio.

    Each item is a dict with:
      'audio':    mono waveform tensor, shape (target_samples,)
      'mel':      log-mel spectrogram in dB, shape (n_mels, frames)
      'pitch':    per-frame mean energy used as a crude pitch proxy
      'filename': source file name
    """

    def __init__(self, audio_dir, config, max_samples=None):
        """
        Args:
            audio_dir: directory scanned (non-recursively) for audio files.
            config: parsed YAML config; reads data.sample_rate, data.duration
                and the optional model.spec_* keys.
            max_samples: optional cap on the number of files used.
        """
        self.audio_dir = Path(audio_dir)
        self.config = config
        self.sample_rate = config['data']['sample_rate']
        self.target_duration = config['data']['duration']
        self.target_samples = int(self.sample_rate * self.target_duration)

        # Mel-spectrogram transform (power spectrogram, converted to dB below).
        n_mels = config['model'].get('spec_n_mels', 128)
        fmin = config['model'].get('spec_fmin', 0)
        fmax = config['model'].get('spec_fmax', self.sample_rate // 2)
        self.mel_transform = T.MelSpectrogram(
            sample_rate=self.sample_rate,
            n_mels=n_mels,
            f_min=fmin,
            f_max=fmax,
            n_fft=1024,
            hop_length=256,
        )
        self.amp_to_db = T.AmplitudeToDB(stype="power", top_db=80)

        # Collect audio files. Sort BEFORE applying max_samples so the
        # selected subset is deterministic (glob order is filesystem-dependent).
        extensions = ["*.wav", "*.mp3", "*.m4a", "*.flac", "*.ogg"]
        audio_files = []
        for ext in extensions:
            audio_files.extend(self.audio_dir.glob(ext))
        audio_files = sorted(audio_files)
        if max_samples:
            audio_files = audio_files[:max_samples]
        self.audio_files = audio_files
        logger.info(f"加载了 {len(self.audio_files)} 个音频文件")

    def __len__(self):
        return len(self.audio_files)

    def _load_audio(self, audio_file):
        """Load audio via torchaudio, trying several backends in order.

        Tries the soundfile backend, then the default backend, then ffmpeg.
        Returns (waveform, sample_rate), or (None, None) if every backend fails.

        Fixed: the original fell through and reloaded the file even when the
        soundfile backend had already succeeded, and its failure path returned
        a possibly-unbound `sr` (NameError).
        """
        try:
            return torchaudio.load(str(audio_file), backend="soundfile")
        except Exception:
            pass
        try:
            return torchaudio.load(str(audio_file))
        except Exception as e:
            try:
                return torchaudio.load(str(audio_file), backend="ffmpeg")
            except Exception:
                logger.error(f"无法加载 {audio_file}: {e}")
                return None, None

    def _load_audio_robust(self, audio_file):
        """Robust audio loading: torchaudio -> ffmpeg subprocess -> silence.

        Never raises; the final fallback returns `target_samples` of silence
        at the configured sample rate.
        """
        # Method 1: direct torchaudio load.
        try:
            waveform, sr = torchaudio.load(str(audio_file))
            if waveform.numel() > 0:
                return waveform, sr
        except Exception:
            pass

        # Method 2: transcode to a temporary mono WAV with ffmpeg, then load.
        try:
            import tempfile
            import subprocess as sp
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                tmp_path = tmp.name
            try:
                # check=True surfaces ffmpeg failures instead of letting the
                # subsequent load fail on a truncated/empty file.
                sp.run(
                    ["ffmpeg", "-y", "-i", str(audio_file), "-ar", str(self.sample_rate),
                     "-ac", "1", "-f", "wav", tmp_path],
                    capture_output=True, timeout=30, check=True
                )
                waveform, sr = torchaudio.load(tmp_path)
            finally:
                # Fixed: the temp file used to leak when loading raised.
                if os.path.exists(tmp_path):
                    os.unlink(tmp_path)
            if waveform.numel() > 0:
                return waveform, sr
        except Exception:
            pass

        # Method 3: give up and return silence.
        logger.warning(f"所有加载方式失败: {audio_file.name},返回静音")
        return torch.zeros(1, self.target_samples), self.sample_rate

    def __getitem__(self, idx):
        """Return one training example; on any failure, a zero-filled example."""
        audio_file = self.audio_files[idx]
        try:
            waveform, sr = self._load_audio_robust(audio_file)

            # Downmix to mono, keeping a (1, time) layout.
            if waveform.dim() > 1 and waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)
            elif waveform.dim() == 1:
                waveform = waveform.unsqueeze(0)

            # Resample to the configured rate if needed.
            if sr != self.sample_rate:
                resampler = T.Resample(orig_freq=sr, new_freq=self.sample_rate)
                waveform = resampler(waveform)

            # Random-crop long clips / zero-pad short ones to the target length.
            if waveform.shape[1] > self.target_samples:
                start = torch.randint(0, waveform.shape[1] - self.target_samples, (1,)).item()
                waveform = waveform[:, start:start + self.target_samples]
            elif waveform.shape[1] < self.target_samples:
                padding = self.target_samples - waveform.shape[1]
                waveform = torch.nn.functional.pad(waveform, (0, padding))

            # Log-mel spectrogram features.
            mel_spec = self.mel_transform(waveform)
            mel_spec = self.amp_to_db(mel_spec)

            # Crude pitch feature: per-frame mean energy as a proxy for f0.
            frame_length = 256
            hop_length = 256
            energy = waveform.unfold(1, frame_length, hop_length).pow(2).mean(dim=2)
            pitch_feat = energy.squeeze(0)

            return {
                'audio': waveform.squeeze(0),
                'mel': mel_spec.squeeze(0),
                'pitch': pitch_feat,
                'filename': audio_file.name
            }
        except Exception as e:
            logger.error(f"处理 {audio_file.name} 失败: {e}")
            traceback.print_exc()
            # Placeholder shapes only need to be plausible; the training loop
            # visible in this file consumes only 'audio'.
            return {
                'audio': torch.zeros(self.target_samples),
                'mel': torch.zeros(self.config['model'].get('spec_n_mels', 128), 100),
                'pitch': torch.zeros(100),
                'filename': audio_file.name
            }
class SimplifiedRVC(nn.Module):
    """Simplified RVC-style autoencoder over raw waveforms.

    Downsamples the waveform 8x through three stride-2 convolutions,
    bottlenecks to 64 channels, and reconstructs with a stride-8 transposed
    convolution. Input and output are (batch, time) waveforms of equal length.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Feature extractor: 3 x (stride-2 conv) -> 8x temporal downsampling.
        self.feature_extractor = nn.Sequential(
            nn.Conv1d(1, 64, kernel_size=7, stride=2, padding=3),
            nn.ReLU(),
            nn.Conv1d(64, 128, kernel_size=7, stride=2, padding=3),
            nn.ReLU(),
            nn.Conv1d(128, 256, kernel_size=7, stride=2, padding=3),
            nn.ReLU()
        )
        # Bottleneck encoder.
        self.encoder = nn.Sequential(
            nn.Conv1d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(128, 64, kernel_size=3, padding=1),
            nn.ReLU()
        )
        # Decoder: channel expansion, then 8x upsampling back to a waveform.
        self.decoder = nn.Sequential(
            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(256, 1, kernel_size=7, stride=8, padding=3, output_padding=1)
        )

    def forward(self, x):
        """Reconstruct the input waveform.

        Args:
            x: (batch, time) waveform.

        Returns:
            (batch, time) reconstruction, trimmed or zero-padded to match
            the input length exactly.
        """
        target_len = x.shape[1]  # time length BEFORE adding the channel dim
        x = x.unsqueeze(1)  # (batch, 1, time)
        features = self.feature_extractor(x)
        encoded = self.encoder(features)
        decoded = self.decoder(encoded)
        output = decoded.squeeze(1)  # (batch, time')
        # Fixed: the original compared against x.shape[1], which after
        # unsqueeze(1) is the channel dimension (= 1), so the output was
        # truncated to a single sample per batch item.
        if output.shape[1] > target_len:
            output = output[:, :target_len]
        elif output.shape[1] < target_len:
            output = torch.nn.functional.pad(output, (0, target_len - output.shape[1]))
        return output
def train_model(config):
    """Train the simplified RVC model with a self-reconstruction objective.

    Args:
        config: parsed YAML config (data / model / training / output sections).

    Returns:
        (model, best_val_loss) on success, or (None, inf) when no audio
        files were found in the training directory.
    """
    logger.info("=" * 60)
    logger.info("🎤 开始RVC v2训练 (torchaudio版)")
    logger.info("=" * 60)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"📊 使用设备: {device}")

    train_dir = config['data']['train_dir']
    logger.info(f"📂 加载数据集: {train_dir}")

    # Smoke-test audio loading on a single file so backend problems surface
    # immediately instead of mid-training.
    test_dir = Path(train_dir)
    test_files = list(test_dir.glob("*.wav")) + list(test_dir.glob("*.mp3"))
    if test_files:
        logger.info(f"🔍 测试音频加载: {test_files[0].name}")
        try:
            wav, sr = torchaudio.load(str(test_files[0]))
            logger.info(f" ✅ 成功! shape={wav.shape}, sr={sr}")
        except Exception as e:
            logger.warning(f" ⚠️ torchaudio 直接加载失败: {e}")
            logger.info(" 尝试 ffmpeg fallback...")
            import subprocess as sp
            import tempfile
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                tmp_path = tmp.name
            try:
                # Fixed: resample to the configured rate instead of a
                # hard-coded 40000 Hz.
                sp.run(
                    ["ffmpeg", "-y", "-i", str(test_files[0]),
                     "-ar", str(config['data']['sample_rate']),
                     "-ac", "1", "-f", "wav", tmp_path],
                    capture_output=True, timeout=30
                )
                wav, sr = torchaudio.load(tmp_path)
            finally:
                # Fixed: the temp file used to leak when loading raised.
                if os.path.exists(tmp_path):
                    os.unlink(tmp_path)
            logger.info(f" ✅ ffmpeg fallback 成功! shape={wav.shape}, sr={sr}")

    full_dataset = VoiceDataset(train_dir, config)
    if len(full_dataset) == 0:
        logger.error("❌ 没有找到任何音频文件!请检查数据目录。")
        return None, float('inf')

    # Train/validation split. Fixed: the original passed
    # [train_size, max(val_size, 1)] to random_split, whose lengths must sum
    # to len(dataset) — when val_size rounded down to 0 this raised
    # ValueError. Clamp val_size into [1, len-1] instead (0 only if len == 1).
    val_split = config['data'].get('val_split', 0.1)
    val_size = int(len(full_dataset) * val_split)
    if len(full_dataset) > 1:
        val_size = min(max(val_size, 1), len(full_dataset) - 1)
    train_size = len(full_dataset) - val_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset,
        [train_size, val_size]
    )
    logger.info(f" 训练集: {len(train_dataset)} 个样本")
    logger.info(f" 验证集: {len(val_dataset)} 个样本")

    # Data loaders (num_workers=0: workers would re-trigger the fragile
    # audio-backend probing per process).
    train_loader = DataLoader(
        train_dataset,
        batch_size=config['training']['batch_size'],
        shuffle=True,
        num_workers=0,
        drop_last=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config['training']['batch_size'],
        shuffle=False,
        num_workers=0
    )

    logger.info(f"🏗️ 创建模型: {config['model']['name']}")
    model = SimplifiedRVC(config).to(device)
    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f" 参数量: {total_params:,}")

    # Self-reconstruction loss: the model learns to reproduce its input.
    criterion = nn.MSELoss()
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config['training']['learning_rate'],
        weight_decay=config['training'].get('weight_decay', 1e-5)
    )
    scheduler = optim.lr_scheduler.StepLR(
        optimizer,
        step_size=config['training'].get('step_size', 100),
        gamma=config['training'].get('gamma', 0.5)
    )

    save_dir = Path(config['output']['save_dir'])
    save_dir.mkdir(parents=True, exist_ok=True)

    epochs = config['training']['epochs']
    best_val_loss = float('inf')
    train_loss = 0.0  # fixed: defined even if epochs == 0 (final save reads it)
    logger.info(f"🚀 开始训练: {epochs} 个epoch")
    logger.info("=" * 60)

    for epoch in range(epochs):
        # ---- training pass ----
        model.train()
        train_loss = 0.0
        num_batches = 0
        for batch_idx, batch in enumerate(train_loader):
            audio = batch['audio'].to(device)
            optimizer.zero_grad()
            output = model(audio)
            # Guard against any residual output/target length mismatch.
            min_len = min(output.shape[1], audio.shape[1])
            loss = criterion(output[:, :min_len], audio[:, :min_len])
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            num_batches += 1
            if (batch_idx + 1) % 10 == 0:
                logger.info(f"Epoch {epoch+1}/{epochs} Batch {batch_idx+1}/{len(train_loader)} loss={loss.item():.6f}")
        train_loss /= max(num_batches, 1)

        # ---- periodic validation + best-model checkpoint ----
        val_every = config['training'].get('val_every_n_epochs', 10)
        if (epoch + 1) % val_every == 0:
            model.eval()
            val_loss = 0.0
            val_batches = 0
            with torch.no_grad():
                for batch in val_loader:
                    audio = batch['audio'].to(device)
                    output = model(audio)
                    min_len = min(output.shape[1], audio.shape[1])
                    loss = criterion(output[:, :min_len], audio[:, :min_len])
                    val_loss += loss.item()
                    val_batches += 1
            val_loss /= max(val_batches, 1)
            logger.info(f"Epoch {epoch+1}/{epochs}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}")
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                save_path = save_dir / "best_model.pth"
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_loss': val_loss,
                    'config': config,
                    'model_class': 'SimplifiedRVC',
                    'torchaudio_version': torchaudio.__version__,
                }, save_path)
                logger.info(f" ✅ 保存最佳模型: {save_path} (Val Loss = {val_loss:.6f})")
        else:
            logger.info(f"Epoch {epoch+1}/{epochs}: Train Loss = {train_loss:.6f}")

        scheduler.step()

    # ---- final checkpoint (always written, even if never validated) ----
    final_path = save_dir / "final_model.pth"
    torch.save({
        'epoch': epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'config': config,
        'model_class': 'SimplifiedRVC',
        'torchaudio_version': torchaudio.__version__,
    }, final_path)
    logger.info("=" * 60)
    logger.info("✅ 训练完成!")
    logger.info(f"📊 最佳验证损失: {best_val_loss:.6f}")
    logger.info(f"📦 最终模型: {final_path}")
    logger.info("=" * 60)
    return model, best_val_loss
def main():
    """Entry point: load the YAML config, then run training."""
    config_file = "config_rvc_v2.yaml"
    config_path = Path(config_file)

    # Bail out early if the config is missing.
    if not config_path.exists():
        logger.error(f"配置文件不存在: {config_file}")
        sys.exit(1)

    with config_path.open('r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    logger.info(f"📋 加载配置: {config_file}")

    model, best_val_loss = train_model(config)

    # train_model returns None when no audio data was found.
    if model is None:
        logger.error("❌ 训练失败")
        sys.exit(1)
    logger.info("🎉 训练成功完成!")


if __name__ == "__main__":
    main()