rvc-cpu-trainer / scripts /train_rvc_v2_fixed.py
ayf3's picture
Upload scripts/train_rvc_v2_fixed.py with huggingface_hub
2e07aa1 verified
#!/usr/bin/env python3
"""
RVC v2 训练脚本 (Fixed) - 使用 torchaudio 替代 librosa
NumberBlocks One 音色克隆
修复内容:
- librosa.load → torchaudio.load (避免 numba 兼容问题)
- librosa.feature.melspectrogram → torchaudio.transforms.MelSpectrogram
- librosa.piptrack → 逐帧能量特征 (仅作为 pitch 的简易代理, 并非真实音高估计)
- 音频加载: torchaudio 默认后端, 失败时回退为 ffmpeg 子进程转码
"""
import os
import sys
import yaml
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchaudio
import torchaudio.transforms as T
import numpy as np
from pathlib import Path
import json
import logging
import traceback
# Configure root logging: INFO and above, "time - level - message" lines.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Report the torchaudio build and its available audio I/O backends up front;
# this helps diagnose audio-loading failures later in the run.
logger.info("torchaudio version: %s", torchaudio.__version__)
logger.info("torchaudio backends: %s", torchaudio.list_audio_backends())
class VoiceDataset(Dataset):
    """Voice dataset that loads audio with torchaudio.

    Each item is a fixed-length mono clip (``data.duration`` seconds at
    ``data.sample_rate``) plus its log-mel spectrogram and a frame-energy
    feature.

    Args:
        audio_dir: Directory scanned (non-recursively) for
            .wav/.mp3/.m4a/.flac/.ogg files.
        config: Parsed config dict. Reads ``config['data']['sample_rate']``,
            ``config['data']['duration']`` and the optional
            ``config['model']['spec_n_mels' | 'spec_fmin' | 'spec_fmax']``.
        max_samples: If truthy, keep only the first ``max_samples`` files.
            Files are sorted by path first, so the subset is deterministic.
    """

    # Framing parameters shared by feature extraction and the error fallback,
    # so both always produce tensors of the same shape.
    N_FFT = 1024
    HOP_LENGTH = 256
    ENERGY_FRAME_LENGTH = 256
    ENERGY_HOP_LENGTH = 256
    AUDIO_PATTERNS = ("*.wav", "*.mp3", "*.m4a", "*.flac", "*.ogg")

    def __init__(self, audio_dir, config, max_samples=None):
        self.audio_dir = Path(audio_dir)
        self.config = config
        self.sample_rate = config['data']['sample_rate']
        self.target_duration = config['data']['duration']
        self.target_samples = int(self.sample_rate * self.target_duration)

        # Mel spectrogram front-end.
        model_cfg = config['model']
        self.n_mels = model_cfg.get('spec_n_mels', 128)
        fmin = model_cfg.get('spec_fmin', 0)
        fmax = model_cfg.get('spec_fmax', self.sample_rate // 2)
        self.mel_transform = T.MelSpectrogram(
            sample_rate=self.sample_rate,
            n_mels=self.n_mels,
            f_min=fmin,
            f_max=fmax,
            n_fft=self.N_FFT,
            hop_length=self.HOP_LENGTH,
        )
        self.amp_to_db = T.AmplitudeToDB(stype="power", top_db=80)
        # Resample transforms are cached per source rate: constructing one
        # rebuilds its filter kernel, which is wasteful to redo for every item.
        self._resamplers = {}

        # Collect audio files.
        audio_files = []
        for pattern in self.AUDIO_PATTERNS:
            audio_files.extend(self.audio_dir.glob(pattern))
        # Sort BEFORE truncating: glob order is filesystem-dependent, so
        # truncating first would keep an arbitrary subset.
        audio_files.sort()
        if max_samples:
            audio_files = audio_files[:max_samples]
        self.audio_files = audio_files
        logger.info(f"加载了 {len(self.audio_files)} 个音频文件")

    def __len__(self):
        return len(self.audio_files)

    @staticmethod
    def _num_frames(num_samples, frame_length, hop_length, center=False):
        """Return how many frames framing ``num_samples`` samples yields.

        With ``center=False`` this matches ``Tensor.unfold(dim, frame_length,
        hop_length)``. With ``center=True`` the signal is first padded by
        ``frame_length // 2`` on each side, which is what the centered STFT
        inside ``torchaudio.transforms.MelSpectrogram`` does by default.
        """
        if center:
            num_samples += 2 * (frame_length // 2)
        if num_samples < frame_length:
            return 0
        return (num_samples - frame_length) // hop_length + 1

    def _load_audio(self, audio_file):
        """Load ``audio_file`` with torchaudio, trying several backends.

        The backends are tried in order (soundfile, torchaudio default,
        ffmpeg), and the first successful load is returned.

        Returns:
            ``(waveform, sample_rate)`` on success. If every backend fails,
            returns ``(None, None)`` and logs the last error.
        """
        last_error = None
        for backend in ("soundfile", None, "ffmpeg"):
            try:
                if backend is None:
                    return torchaudio.load(str(audio_file))
                return torchaudio.load(str(audio_file), backend=backend)
            except Exception as exc:  # backend unavailable or cannot decode
                last_error = exc
        logger.error(f"无法加载 {audio_file}: {last_error}")
        return None, None

    def _load_audio_robust(self, audio_file):
        """Robust load: torchaudio → ffmpeg transcode to a temp WAV → silence.

        Decode failures are never raised. If nothing works, the result is
        ``target_samples`` of silence at ``self.sample_rate``. The temporary
        WAV is always removed, whether or not the transcode succeeds.
        """
        import subprocess as sp
        import tempfile

        # Method 1: load directly with torchaudio's default backend.
        try:
            waveform, sr = torchaudio.load(str(audio_file))
            if waveform.numel() > 0:
                return waveform, sr
        except Exception as exc:
            logger.debug(f"torchaudio.load failed for {audio_file}: {exc}")

        # Method 2: transcode with ffmpeg to mono WAV at the target rate, then load it.
        tmp_path = None
        try:
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                tmp_path = tmp.name
            sp.run(
                ["ffmpeg", "-y", "-i", str(audio_file), "-ar", str(self.sample_rate),
                 "-ac", "1", "-f", "wav", tmp_path],
                capture_output=True, timeout=30, check=True,
            )
            waveform, sr = torchaudio.load(tmp_path)
            if waveform.numel() > 0:
                return waveform, sr
        except Exception as exc:
            logger.debug(f"ffmpeg fallback failed for {audio_file}: {exc}")
        finally:
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass

        # Method 3: return silence.
        logger.warning(f"所有加载方式失败: {audio_file.name},返回静音")
        return torch.zeros(1, self.target_samples), self.sample_rate

    def __getitem__(self, idx):
        """Return ``{'audio', 'mel', 'pitch', 'filename'}`` for file ``idx``.

        ``'pitch'`` is NOT a pitch estimate. It is the mean squared amplitude
        of each non-overlapping 256-sample frame, used as a cheap proxy. The
        key name is kept for compatibility. If an item fails, zero tensors are
        returned with the same shapes as a successful item, so batches still
        collate.
        """
        audio_file = self.audio_files[idx]
        try:
            waveform, sr = self._load_audio_robust(audio_file)

            # Convert to mono with shape (1, time).
            if waveform.dim() > 1 and waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)
            elif waveform.dim() == 1:
                waveform = waveform.unsqueeze(0)

            # Resample to the target rate, reusing a cached transform per source rate.
            if sr != self.sample_rate:
                resampler = self._resamplers.get(sr)
                if resampler is None:
                    resampler = T.Resample(orig_freq=sr, new_freq=self.sample_rate)
                    self._resamplers[sr] = resampler
                waveform = resampler(waveform)

            # Random-crop or zero-pad to exactly target_samples.
            num_samples = waveform.shape[1]
            if num_samples > self.target_samples:
                # randint's upper bound is exclusive; +1 lets the window that
                # ends on the very last sample be chosen too.
                start = torch.randint(0, num_samples - self.target_samples + 1, (1,)).item()
                waveform = waveform[:, start:start + self.target_samples]
            elif num_samples < self.target_samples:
                waveform = torch.nn.functional.pad(waveform, (0, self.target_samples - num_samples))

            # Log-mel spectrogram.
            mel_spec = self.amp_to_db(self.mel_transform(waveform))

            # Frame energy used as a simple pitch proxy.
            energy = (
                waveform.unfold(1, self.ENERGY_FRAME_LENGTH, self.ENERGY_HOP_LENGTH)
                .pow(2)
                .mean(dim=2)
            )

            return {
                'audio': waveform.squeeze(0),
                'mel': mel_spec.squeeze(0),
                'pitch': energy.squeeze(0),
                'filename': audio_file.name
            }
        except Exception as e:
            logger.exception(f"处理 {audio_file.name} 失败: {e}")
            # Size the placeholders like real items. A fixed length (e.g. 100
            # frames) would break default_collate when stacked with good items.
            n_mel_frames = self._num_frames(
                self.target_samples, self.N_FFT, self.HOP_LENGTH, center=True)
            n_energy_frames = self._num_frames(
                self.target_samples, self.ENERGY_FRAME_LENGTH, self.ENERGY_HOP_LENGTH)
            return {
                'audio': torch.zeros(self.target_samples),
                'mel': torch.zeros(self.n_mels, n_mel_frames),
                'pitch': torch.zeros(n_energy_frames),
                'filename': audio_file.name
            }
class SimplifiedRVC(nn.Module):
    """Simplified RVC-style 1-D convolutional autoencoder on raw waveforms.

    Three stride-2 convolutions downsample time by 8x to 256 channels. Two
    convolutions squeeze this to a 64-channel bottleneck. A decoder ends in a
    stride-8 transposed convolution back to one channel. The output is
    cropped or zero-padded to the input length.

    Args:
        config: Config dict; stored on ``self.config`` (not read here).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Feature extractor: 1 → 256 channels, time downsampled 8x.
        self.feature_extractor = nn.Sequential(
            nn.Conv1d(1, 64, kernel_size=7, stride=2, padding=3),
            nn.ReLU(),
            nn.Conv1d(64, 128, kernel_size=7, stride=2, padding=3),
            nn.ReLU(),
            nn.Conv1d(128, 256, kernel_size=7, stride=2, padding=3),
            nn.ReLU()
        )

        # Encoder: 256 → 64-channel bottleneck (time preserved).
        self.encoder = nn.Sequential(
            nn.Conv1d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(128, 64, kernel_size=3, padding=1),
            nn.ReLU()
        )

        # Decoder: back to 256 channels, then an 8x transposed-conv upsample to 1 channel.
        self.decoder = nn.Sequential(
            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(256, 1, kernel_size=7, stride=8, padding=3, output_padding=1)
        )

    def forward(self, x):
        """Reconstruct a waveform batch.

        Args:
            x: Tensor of shape ``(batch, time)``.

        Returns:
            Tensor of shape ``(batch, time)``, the same length as ``x``.
        """
        # Capture the time length BEFORE adding the channel axis. After
        # unsqueeze(1), x.shape[1] is the channel dim (== 1), and cropping to
        # it would shrink every output to a single sample.
        target_len = x.shape[-1]
        x = x.unsqueeze(1)  # (batch, 1, time)

        features = self.feature_extractor(x)
        encoded = self.encoder(features)
        decoded = self.decoder(encoded)

        # The transposed conv yields roughly 8 * ceil(time / 8) - 6 samples;
        # crop or pad to the input length.
        output = decoded.squeeze(1)
        out_len = output.shape[-1]
        if out_len > target_len:
            output = output[..., :target_len]
        elif out_len < target_len:
            output = torch.nn.functional.pad(output, (0, target_len - out_len))
        return output
def _probe_first_audio(train_dir, sample_rate):
    """Best-effort check that one .wav/.mp3 in ``train_dir`` can be decoded.

    Tries ``torchaudio.load`` directly, then an ffmpeg transcode to mono WAV at
    ``sample_rate``. Failures are only logged, never raised: the dataset has
    its own per-file fallbacks, so this probe is purely diagnostic.
    """
    import subprocess as sp
    import tempfile

    test_dir = Path(train_dir)
    test_files = list(test_dir.glob("*.wav")) + list(test_dir.glob("*.mp3"))
    if not test_files:
        return
    probe = test_files[0]
    logger.info(f"🔍 测试音频加载: {probe.name}")
    try:
        wav, sr = torchaudio.load(str(probe))
    except Exception as e:
        logger.warning(f" ⚠️ torchaudio 直接加载失败: {e}")
    else:
        logger.info(f" ✅ 成功! shape={wav.shape}, sr={sr}")
        return

    logger.info(" 尝试 ffmpeg fallback...")
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp_path = tmp.name
        sp.run(
            ["ffmpeg", "-y", "-i", str(probe), "-ar", str(sample_rate),
             "-ac", "1", "-f", "wav", tmp_path],
            capture_output=True, timeout=30, check=True,
        )
        wav, sr = torchaudio.load(tmp_path)
        logger.info(f" ✅ ffmpeg fallback 成功! shape={wav.shape}, sr={sr}")
    except Exception as e:
        logger.warning(f" ⚠️ ffmpeg fallback 失败: {e}")
    finally:
        # Always remove the temp WAV, even when ffmpeg or the load failed.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)


def _split_sizes(n_items, val_split):
    """Return ``(train_size, val_size)`` that always sum to ``n_items``.

    With 2+ items, at least one item and at most ``n_items - 1`` items go to
    validation. A single-item dataset is used entirely for training.
    """
    if n_items < 2:
        return n_items, 0
    val_size = min(max(int(n_items * val_split), 1), n_items - 1)
    return n_items - val_size, val_size


def _evaluate(model, loader, criterion, device):
    """Return the mean per-batch loss of ``model`` over ``loader``, without gradients."""
    model.eval()
    total, batches = 0.0, 0
    with torch.no_grad():
        for batch in loader:
            audio = batch['audio'].to(device)
            output = model(audio)
            min_len = min(output.shape[1], audio.shape[1])
            total += criterion(output[:, :min_len], audio[:, :min_len]).item()
            batches += 1
    return total / max(batches, 1)


def train_model(config):
    """Train ``SimplifiedRVC`` as a waveform autoencoder (MSE reconstruction).

    Reads data, model, training and output settings from ``config``. Writes
    ``best_model.pth`` whenever the validation loss improves, and
    ``final_model.pth`` at the end, both under ``config['output']['save_dir']``.

    Returns:
        ``(model, best_val_loss)``, or ``(None, inf)`` if no audio was found.
        ``best_val_loss`` stays ``inf`` if validation never ran.
    """
    logger.info("=" * 60)
    logger.info("🎤 开始RVC v2训练 (torchaudio版)")
    logger.info("=" * 60)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"📊 使用设备: {device}")

    # Build the dataset, after a quick diagnostic load of one file.
    train_dir = config['data']['train_dir']
    logger.info(f"📂 加载数据集: {train_dir}")
    _probe_first_audio(train_dir, config['data'].get('sample_rate', 40000))

    full_dataset = VoiceDataset(train_dir, config)
    if len(full_dataset) == 0:
        logger.error("❌ 没有找到任何音频文件!请检查数据目录。")
        return None, float('inf')

    # Split into train/val. The sizes must sum exactly to len(full_dataset),
    # or random_split raises.
    train_size, val_size = _split_sizes(
        len(full_dataset), config['data'].get('val_split', 0.1))
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset,
        [train_size, val_size]
    )
    logger.info(f" 训练集: {len(train_dataset)} 个样本")
    logger.info(f" 验证集: {len(val_dataset)} 个样本")

    batch_size = config['training']['batch_size']
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        # With fewer samples than one batch, drop_last would yield zero
        # batches and the model would never train.
        drop_last=len(train_dataset) >= batch_size
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0
    )

    logger.info(f"🏗️ 创建模型: {config['model']['name']}")
    model = SimplifiedRVC(config).to(device)
    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f" 参数量: {total_params:,}")

    criterion = nn.MSELoss()
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config['training']['learning_rate'],
        weight_decay=config['training'].get('weight_decay', 1e-5)
    )
    scheduler = optim.lr_scheduler.StepLR(
        optimizer,
        step_size=config['training'].get('step_size', 100),
        gamma=config['training'].get('gamma', 0.5)
    )

    save_dir = Path(config['output']['save_dir'])
    save_dir.mkdir(parents=True, exist_ok=True)

    epochs = config['training']['epochs']
    # Guard against 0 or negative values, which would break the modulo below.
    val_every = max(1, config['training'].get('val_every_n_epochs', 10))
    best_val_loss = float('inf')
    train_loss = float('nan')  # still defined for the final checkpoint if epochs == 0
    logger.info(f"🚀 开始训练: {epochs} 个epoch")
    logger.info("=" * 60)

    for epoch in range(epochs):
        # Training phase.
        model.train()
        running_loss = 0.0
        num_batches = 0
        for batch_idx, batch in enumerate(train_loader):
            audio = batch['audio'].to(device)
            optimizer.zero_grad()
            output = model(audio)
            # Compare over the common length in case the model's output differs.
            min_len = min(output.shape[1], audio.shape[1])
            loss = criterion(output[:, :min_len], audio[:, :min_len])
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
            if (batch_idx + 1) % 10 == 0:
                logger.info(f"Epoch {epoch+1}/{epochs} Batch {batch_idx+1}/{len(train_loader)} loss={loss.item():.6f}")
        train_loss = running_loss / max(num_batches, 1)

        # Validation phase. It is skipped when the val split is empty, since a
        # 0.0 "loss" from zero batches would be saved as a bogus best model.
        if (epoch + 1) % val_every == 0 and len(val_dataset) > 0:
            val_loss = _evaluate(model, val_loader, criterion, device)
            logger.info(f"Epoch {epoch+1}/{epochs}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}")
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                save_path = save_dir / "best_model.pth"
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_loss': val_loss,
                    'config': config,
                    'model_class': 'SimplifiedRVC',
                    'torchaudio_version': torchaudio.__version__,
                }, save_path)
                logger.info(f" ✅ 保存最佳模型: {save_path} (Val Loss = {val_loss:.6f})")
        else:
            logger.info(f"Epoch {epoch+1}/{epochs}: Train Loss = {train_loss:.6f}")

        scheduler.step()

    final_path = save_dir / "final_model.pth"
    torch.save({
        'epoch': epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'config': config,
        'model_class': 'SimplifiedRVC',
        'torchaudio_version': torchaudio.__version__,
    }, final_path)
    logger.info("=" * 60)
    logger.info("✅ 训练完成!")
    logger.info(f"📊 最佳验证损失: {best_val_loss:.6f}")
    logger.info(f"📦 最终模型: {final_path}")
    logger.info("=" * 60)
    return model, best_val_loss
def main(argv=None):
    """Script entry point: load the YAML config and run ``train_model``.

    Args:
        argv: Optional argument list, defaulting to ``sys.argv[1:]``. If a
            first argument is present, it is used as the config path.
            Otherwise ``config_rvc_v2.yaml`` in the current working directory
            is used, as before.

    Exits with status 1 if the config file is missing or not a YAML mapping,
    or if training returns no model.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    config_file = args[0] if args else "config_rvc_v2.yaml"
    if not Path(config_file).exists():
        logger.error(f"配置文件不存在: {config_file}")
        sys.exit(1)

    with open(config_file, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    # safe_load returns None for an empty file (or a scalar/list for malformed
    # content); fail here instead of with a TypeError deep in training.
    if not isinstance(config, dict):
        logger.error(f"配置文件格式无效: {config_file}")
        sys.exit(1)
    logger.info(f"📋 加载配置: {config_file}")

    model, _best_val_loss = train_model(config)
    if model is not None:
        logger.info("🎉 训练成功完成!")
    else:
        logger.error("❌ 训练失败")
        sys.exit(1)


if __name__ == "__main__":
    main()