ayf3 commited on
Commit
2e07aa1
·
verified ·
1 Parent(s): dc67603

Upload scripts/train_rvc_v2_fixed.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/train_rvc_v2_fixed.py +463 -0
scripts/train_rvc_v2_fixed.py ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ RVC v2 训练脚本 (Fixed) - 使用 torchaudio 替代 librosa
4
+ NumberBlocks One 音色克隆
5
+
6
+ 修复内容:
7
+ - librosa.load → torchaudio.load (避免 numba 兼容问题)
8
+ - librosa.feature.melspectrogram → torchaudio.transforms.MelSpectrogram
9
+ - librosa.piptrack → torch-based pitch estimation
10
+ - 支持 soundfile / sox_backend 双后端
11
+ """
12
+
13
+ import os
14
+ import sys
15
+ import yaml
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.optim as optim
19
+ from torch.utils.data import Dataset, DataLoader
20
+ import torchaudio
21
+ import torchaudio.transforms as T
22
+ import numpy as np
23
+ from pathlib import Path
24
+ import json
25
+ import logging
26
+ import traceback
27
+
28
+ # 配置日志
29
+ logging.basicConfig(
30
+ level=logging.INFO,
31
+ format='%(asctime)s - %(levelname)s - %(message)s'
32
+ )
33
+ logger = logging.getLogger(__name__)
34
+
35
+ # 检查 torchaudio backend
36
+ logger.info(f"torchaudio version: {torchaudio.__version__}")
37
+ logger.info(f"torchaudio backends: {torchaudio.list_audio_backends()}")
38
+
39
+ class VoiceDataset(Dataset):
40
+ """语音数据集 - 使用 torchaudio 加载"""
41
+
42
+ def __init__(self, audio_dir, config, max_samples=None):
43
+ self.audio_dir = Path(audio_dir)
44
+ self.config = config
45
+ self.sample_rate = config['data']['sample_rate']
46
+ self.target_duration = config['data']['duration']
47
+ self.target_samples = int(self.sample_rate * self.target_duration)
48
+
49
+ # mel 频谱转换器
50
+ n_mels = config['model'].get('spec_n_mels', 128)
51
+ fmin = config['model'].get('spec_fmin', 0)
52
+ fmax = config['model'].get('spec_fmax', self.sample_rate // 2)
53
+ self.mel_transform = T.MelSpectrogram(
54
+ sample_rate=self.sample_rate,
55
+ n_mels=n_mels,
56
+ f_min=fmin,
57
+ f_max=fmax,
58
+ n_fft=1024,
59
+ hop_length=256,
60
+ )
61
+ self.amp_to_db = T.AmplitudeToDB(stype="power", top_db=80)
62
+
63
+ # 获取音频文件
64
+ extensions = ["*.wav", "*.mp3", "*.m4a", "*.flac", "*.ogg"]
65
+ audio_files = []
66
+ for ext in extensions:
67
+ audio_files.extend(self.audio_dir.glob(ext))
68
+
69
+ if max_samples:
70
+ audio_files = audio_files[:max_samples]
71
+
72
+ self.audio_files = sorted(audio_files)
73
+ logger.info(f"加载了 {len(self.audio_files)} 个音频文件")
74
+
75
+ def __len__(self):
76
+ return len(self.audio_files)
77
+
78
+ def _load_audio(self, audio_file):
79
+ """使用 torchaudio 加载音频,带 fallback"""
80
+ # 尝试 soundfile backend
81
+ try:
82
+ waveform, sr = torchaudio.load(str(audio_file), backend="soundfile")
83
+ except Exception:
84
+ pass
85
+
86
+ # 尝试默认 backend
87
+ try:
88
+ waveform, sr = torchaudio.load(str(audio_file))
89
+ except Exception as e:
90
+ # 最后尝试 ffmpeg 后端
91
+ try:
92
+ waveform, sr = torchaudio.load(str(audio_file), backend="ffmpeg")
93
+ except Exception:
94
+ logger.error(f"无法加载 {audio_file}: {e}")
95
+ return None, sr
96
+
97
+ return waveform, sr
98
+
99
+ def _load_audio_robust(self, audio_file):
100
+ """鲁棒的音频加载:torchaudio → ffmpeg subprocess → zeros"""
101
+ # Method 1: torchaudio 直接加载
102
+ try:
103
+ waveform, sr = torchaudio.load(str(audio_file))
104
+ if waveform.numel() > 0:
105
+ return waveform, sr
106
+ except Exception:
107
+ pass
108
+
109
+ # Method 2: ffmpeg subprocess 转 WAV 到临时文件再加载
110
+ try:
111
+ import tempfile
112
+ import subprocess as sp
113
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
114
+ tmp_path = tmp.name
115
+ sp.run(
116
+ ["ffmpeg", "-y", "-i", str(audio_file), "-ar", str(self.sample_rate),
117
+ "-ac", "1", "-f", "wav", tmp_path],
118
+ capture_output=True, timeout=30
119
+ )
120
+ waveform, sr = torchaudio.load(tmp_path)
121
+ os.unlink(tmp_path)
122
+ if waveform.numel() > 0:
123
+ return waveform, sr
124
+ except Exception:
125
+ pass
126
+
127
+ # Method 3: 返回静音
128
+ logger.warning(f"所有加载方式失败: {audio_file.name},返回静音")
129
+ return torch.zeros(1, self.target_samples), self.sample_rate
130
+
131
+ def __getitem__(self, idx):
132
+ audio_file = self.audio_files[idx]
133
+
134
+ try:
135
+ waveform, sr = self._load_audio_robust(audio_file)
136
+
137
+ # 单声道
138
+ if waveform.dim() > 1 and waveform.shape[0] > 1:
139
+ waveform = waveform.mean(dim=0, keepdim=True)
140
+ elif waveform.dim() == 1:
141
+ waveform = waveform.unsqueeze(0)
142
+
143
+ # 重采样
144
+ if sr != self.sample_rate:
145
+ resampler = T.Resample(orig_freq=sr, new_freq=self.sample_rate)
146
+ waveform = resampler(waveform)
147
+
148
+ # 裁剪或填充到目标长度
149
+ if waveform.shape[1] > self.target_samples:
150
+ start = torch.randint(0, waveform.shape[1] - self.target_samples, (1,)).item()
151
+ waveform = waveform[:, start:start + self.target_samples]
152
+ elif waveform.shape[1] < self.target_samples:
153
+ padding = self.target_samples - waveform.shape[1]
154
+ waveform = torch.nn.functional.pad(waveform, (0, padding))
155
+
156
+ # 提取 mel 频谱
157
+ mel_spec = self.mel_transform(waveform)
158
+ mel_spec = self.amp_to_db(mel_spec)
159
+
160
+ # 简单 pitch 特征 (用 energy 作为 proxy)
161
+ frame_length = 256
162
+ hop_length = 256
163
+ energy = waveform.unfold(1, frame_length, hop_length).pow(2).mean(dim=2)
164
+ pitch_feat = energy.squeeze(0)
165
+
166
+ return {
167
+ 'audio': waveform.squeeze(0),
168
+ 'mel': mel_spec.squeeze(0),
169
+ 'pitch': pitch_feat,
170
+ 'filename': audio_file.name
171
+ }
172
+
173
+ except Exception as e:
174
+ logger.error(f"处理 {audio_file.name} 失败: {e}")
175
+ traceback.print_exc()
176
+ return {
177
+ 'audio': torch.zeros(self.target_samples),
178
+ 'mel': torch.zeros(self.config['model'].get('spec_n_mels', 128), 100),
179
+ 'pitch': torch.zeros(100),
180
+ 'filename': audio_file.name
181
+ }
182
+
183
+
184
+ class SimplifiedRVC(nn.Module):
185
+ """简化版RVC模型"""
186
+
187
+ def __init__(self, config):
188
+ super().__init__()
189
+ self.config = config
190
+
191
+ # 特征提取器
192
+ self.feature_extractor = nn.Sequential(
193
+ nn.Conv1d(1, 64, kernel_size=7, stride=2, padding=3),
194
+ nn.ReLU(),
195
+ nn.Conv1d(64, 128, kernel_size=7, stride=2, padding=3),
196
+ nn.ReLU(),
197
+ nn.Conv1d(128, 256, kernel_size=7, stride=2, padding=3),
198
+ nn.ReLU()
199
+ )
200
+
201
+ # 编码器
202
+ self.encoder = nn.Sequential(
203
+ nn.Conv1d(256, 128, kernel_size=3, padding=1),
204
+ nn.ReLU(),
205
+ nn.Conv1d(128, 64, kernel_size=3, padding=1),
206
+ nn.ReLU()
207
+ )
208
+
209
+ # 解码器
210
+ self.decoder = nn.Sequential(
211
+ nn.Conv1d(64, 128, kernel_size=3, padding=1),
212
+ nn.ReLU(),
213
+ nn.Conv1d(128, 256, kernel_size=3, padding=1),
214
+ nn.ReLU(),
215
+ nn.ConvTranspose1d(256, 1, kernel_size=7, stride=8, padding=3, output_padding=1)
216
+ )
217
+
218
+ def forward(self, x):
219
+ # x: (batch, time)
220
+ x = x.unsqueeze(1) # (batch, 1, time)
221
+
222
+ # 特征提取
223
+ features = self.feature_extractor(x)
224
+
225
+ # 编码
226
+ encoded = self.encoder(features)
227
+
228
+ # 解码
229
+ decoded = self.decoder(encoded)
230
+
231
+ # 输出 - 裁剪到和输入一致
232
+ output = decoded.squeeze(1)
233
+ if output.shape[1] > x.shape[1]:
234
+ output = output[:, :x.shape[1]]
235
+ elif output.shape[1] < x.shape[1]:
236
+ output = torch.nn.functional.pad(output, (0, x.shape[1] - output.shape[1]))
237
+
238
+ return output
239
+
240
+
241
+ def train_model(config):
242
+ """训练模型"""
243
+ logger.info("=" * 60)
244
+ logger.info("🎤 开始RVC v2训练 (torchaudio版)")
245
+ logger.info("=" * 60)
246
+
247
+ # 设备
248
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
249
+ logger.info(f"📊 使用设备: {device}")
250
+
251
+ # 创建数据集
252
+ train_dir = config['data']['train_dir']
253
+ logger.info(f"📂 加载数据集: {train_dir}")
254
+
255
+ # 先测试能否加载至少一个音频
256
+ test_dir = Path(train_dir)
257
+ test_files = list(test_dir.glob("*.wav")) + list(test_dir.glob("*.mp3"))
258
+ if test_files:
259
+ logger.info(f"🔍 测试音频加载: {test_files[0].name}")
260
+ try:
261
+ wav, sr = torchaudio.load(str(test_files[0]))
262
+ logger.info(f" ✅ 成功! shape={wav.shape}, sr={sr}")
263
+ except Exception as e:
264
+ logger.warning(f" ⚠️ torchaudio 直接加载失败: {e}")
265
+ logger.info(" 尝试 ffmpeg fallback...")
266
+ import subprocess as sp
267
+ import tempfile
268
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
269
+ tmp_path = tmp.name
270
+ sp.run(
271
+ ["ffmpeg", "-y", "-i", str(test_files[0]), "-ar", "40000",
272
+ "-ac", "1", "-f", "wav", tmp_path],
273
+ capture_output=True, timeout=30
274
+ )
275
+ wav, sr = torchaudio.load(tmp_path)
276
+ os.unlink(tmp_path)
277
+ logger.info(f" ✅ ffmpeg fallback 成功! shape={wav.shape}, sr={sr}")
278
+
279
+ full_dataset = VoiceDataset(train_dir, config)
280
+
281
+ if len(full_dataset) == 0:
282
+ logger.error("❌ 没有找到任何音频文件!请检查数据目录。")
283
+ return None, float('inf')
284
+
285
+ # 分割训练集和验证集
286
+ val_split = config['data'].get('val_split', 0.1)
287
+ val_size = int(len(full_dataset) * val_split)
288
+ train_size = len(full_dataset) - val_size
289
+
290
+ train_dataset, val_dataset = torch.utils.data.random_split(
291
+ full_dataset,
292
+ [train_size, max(val_size, 1)]
293
+ )
294
+
295
+ logger.info(f" 训练集: {len(train_dataset)} 个样本")
296
+ logger.info(f" 验证集: {len(val_dataset)} 个样本")
297
+
298
+ # 创建数据加载器
299
+ train_loader = DataLoader(
300
+ train_dataset,
301
+ batch_size=config['training']['batch_size'],
302
+ shuffle=True,
303
+ num_workers=0,
304
+ drop_last=True
305
+ )
306
+
307
+ val_loader = DataLoader(
308
+ val_dataset,
309
+ batch_size=config['training']['batch_size'],
310
+ shuffle=False,
311
+ num_workers=0
312
+ )
313
+
314
+ # 创建模型
315
+ logger.info(f"🏗️ 创建模型: {config['model']['name']}")
316
+ model = SimplifiedRVC(config).to(device)
317
+
318
+ total_params = sum(p.numel() for p in model.parameters())
319
+ logger.info(f" 参数量: {total_params:,}")
320
+
321
+ # 损失函数
322
+ criterion = nn.MSELoss()
323
+
324
+ # 优化器
325
+ optimizer = optim.AdamW(
326
+ model.parameters(),
327
+ lr=config['training']['learning_rate'],
328
+ weight_decay=config['training'].get('weight_decay', 1e-5)
329
+ )
330
+
331
+ # 学习率调度器
332
+ scheduler = optim.lr_scheduler.StepLR(
333
+ optimizer,
334
+ step_size=config['training'].get('step_size', 100),
335
+ gamma=config['training'].get('gamma', 0.5)
336
+ )
337
+
338
+ # 创建输出目录
339
+ save_dir = Path(config['output']['save_dir'])
340
+ save_dir.mkdir(parents=True, exist_ok=True)
341
+
342
+ # 训练循环
343
+ epochs = config['training']['epochs']
344
+ best_val_loss = float('inf')
345
+
346
+ logger.info(f"🚀 开始训练: {epochs} 个epoch")
347
+ logger.info("=" * 60)
348
+
349
+ for epoch in range(epochs):
350
+ # 训练阶段
351
+ model.train()
352
+ train_loss = 0.0
353
+ num_batches = 0
354
+
355
+ for batch_idx, batch in enumerate(train_loader):
356
+ audio = batch['audio'].to(device)
357
+
358
+ # 前向传播
359
+ optimizer.zero_grad()
360
+ output = model(audio)
361
+
362
+ # 确保输出和目标长度一致
363
+ min_len = min(output.shape[1], audio.shape[1])
364
+ loss = criterion(output[:, :min_len], audio[:, :min_len])
365
+
366
+ # 反向传播
367
+ loss.backward()
368
+ optimizer.step()
369
+
370
+ train_loss += loss.item()
371
+ num_batches += 1
372
+
373
+ if (batch_idx + 1) % 10 == 0:
374
+ logger.info(f"Epoch {epoch+1}/{epochs} Batch {batch_idx+1}/{len(train_loader)} loss={loss.item():.6f}")
375
+
376
+ train_loss /= max(num_batches, 1)
377
+
378
+ # 验证阶段
379
+ val_every = config['training'].get('val_every_n_epochs', 10)
380
+ if (epoch + 1) % val_every == 0:
381
+ model.eval()
382
+ val_loss = 0.0
383
+ val_batches = 0
384
+
385
+ with torch.no_grad():
386
+ for batch in val_loader:
387
+ audio = batch['audio'].to(device)
388
+ output = model(audio)
389
+ min_len = min(output.shape[1], audio.shape[1])
390
+ loss = criterion(output[:, :min_len], audio[:, :min_len])
391
+ val_loss += loss.item()
392
+ val_batches += 1
393
+
394
+ val_loss /= max(val_batches, 1)
395
+
396
+ logger.info(f"Epoch {epoch+1}/{epochs}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}")
397
+
398
+ # 保存最佳模型
399
+ if val_loss < best_val_loss:
400
+ best_val_loss = val_loss
401
+ save_path = save_dir / "best_model.pth"
402
+ torch.save({
403
+ 'epoch': epoch,
404
+ 'model_state_dict': model.state_dict(),
405
+ 'optimizer_state_dict': optimizer.state_dict(),
406
+ 'val_loss': val_loss,
407
+ 'config': config,
408
+ 'model_class': 'SimplifiedRVC',
409
+ 'torchaudio_version': torchaudio.__version__,
410
+ }, save_path)
411
+ logger.info(f" ✅ 保存最佳模型: {save_path} (Val Loss = {val_loss:.6f})")
412
+ else:
413
+ logger.info(f"Epoch {epoch+1}/{epochs}: Train Loss = {train_loss:.6f}")
414
+
415
+ # 更新学习率
416
+ scheduler.step()
417
+
418
+ # 保存最终模型
419
+ final_path = save_dir / "final_model.pth"
420
+ torch.save({
421
+ 'epoch': epochs,
422
+ 'model_state_dict': model.state_dict(),
423
+ 'optimizer_state_dict': optimizer.state_dict(),
424
+ 'train_loss': train_loss,
425
+ 'config': config,
426
+ 'model_class': 'SimplifiedRVC',
427
+ 'torchaudio_version': torchaudio.__version__,
428
+ }, final_path)
429
+
430
+ logger.info("=" * 60)
431
+ logger.info("✅ 训练完成!")
432
+ logger.info(f"📊 最佳验证损失: {best_val_loss:.6f}")
433
+ logger.info(f"📦 最终模型: {final_path}")
434
+ logger.info("=" * 60)
435
+
436
+ return model, best_val_loss
437
+
438
+
439
+ def main():
440
+ """主函数"""
441
+ # 加载配置
442
+ config_file = "config_rvc_v2.yaml"
443
+ if not Path(config_file).exists():
444
+ logger.error(f"配置文件不存在: {config_file}")
445
+ sys.exit(1)
446
+
447
+ with open(config_file, 'r', encoding='utf-8') as f:
448
+ config = yaml.safe_load(f)
449
+
450
+ logger.info(f"📋 加载配置: {config_file}")
451
+
452
+ # 训练模型
453
+ model, best_val_loss = train_model(config)
454
+
455
+ if model is not None:
456
+ logger.info("🎉 训练成功完成!")
457
+ else:
458
+ logger.error("❌ 训练失败")
459
+ sys.exit(1)
460
+
461
+
462
+ if __name__ == "__main__":
463
+ main()