| import os |
| import torch |
| import torchaudio |
| from torch.utils.data import Dataset, DataLoader |
| import torch.nn as nn |
| import torch.optim as optim |
| from pathlib import Path |
| import argparse |
| from torch.nn.utils.rnn import pad_sequence |
|
|
|
|
| |
class HubertEncoder(nn.Module):
    """Content encoder: two 1-D convolutions, a 2-layer BiLSTM, and a linear head.

    Maps a (batch, time, input_dim) feature sequence to (batch, time, 256),
    preserving the time dimension (both convs use padding=1 with kernel 3).
    """

    def __init__(self, input_dim=128, hidden_dim=256):
        super().__init__()
        self.conv1 = nn.Conv1d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, 3, padding=1)
        # hidden_dim // 2 per direction: the bidirectional outputs concatenate
        # back to hidden_dim, matching the projection's input size.
        self.lstm = nn.LSTM(
            hidden_dim,
            hidden_dim // 2,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )
        self.proj = nn.Linear(hidden_dim, 256)

    def forward(self, x):
        """Encode x of shape (batch, time, input_dim) into (batch, time, 256)."""
        # Conv1d expects channels-first; flip back to time-major for the LSTM.
        h = torch.relu(self.conv1(x.transpose(1, 2)))
        h = torch.relu(self.conv2(h))
        seq_out, _ = self.lstm(h.transpose(1, 2))
        return self.proj(seq_out)
|
|
|
|
class PitchEncoder(nn.Module):
    """Encode a frame-level F0 contour (batch, time) into 256-dim per-frame features."""

    def __init__(self):
        super().__init__()
        # Two padded 1-D convs: 1 channel -> 64 -> 128, time length unchanged.
        self.f0_conv = nn.Sequential(
            nn.Conv1d(1, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv1d(64, 128, 3, padding=1),
            nn.ReLU(),
        )
        self.pitch_proj = nn.Linear(128, 256)

    def forward(self, f0):
        """Return (batch, time, 256) features for f0 of shape (batch, time)."""
        feats = self.f0_conv(f0.unsqueeze(1))       # (batch, 128, time)
        return self.pitch_proj(feats.transpose(1, 2))
|
|
|
|
class RVCDecoder(nn.Module):
    """Fuse content and pitch streams via cross-attention and project to mel frames."""

    def __init__(self, dim=256, mel_dim=128):
        super().__init__()
        # Bidirectional with hidden size `dim` -> output width dim * 2.
        self.content_lstm = nn.LSTM(
            dim,
            dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )
        # Bring pitch features up to the attention embedding width.
        self.pitch_proj = nn.Linear(dim, dim * 2)
        # NOTE(review): keys come from pitch while values come from content, so
        # attention weights are pitch-driven but the mixed-in information is
        # content-based — unconventional; confirm this asymmetry is intended.
        self.fusion = nn.MultiheadAttention(
            embed_dim=dim * 2,
            num_heads=8,
            batch_first=True,
        )
        self.output_proj = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.ReLU(),
            nn.Linear(dim, mel_dim),
        )

    def forward(self, content, pitch):
        """Decode content and pitch features, each (batch, time, dim), into (batch, time, mel_dim)."""
        encoded, _ = self.content_lstm(content)
        fused, _ = self.fusion(
            query=encoded,
            key=self.pitch_proj(pitch),
            value=encoded,
        )
        return self.output_proj(fused)
|
|
|
|
class RVCv2(nn.Module):
    """End-to-end model: mel -> content encoder, F0 -> pitch encoder, fused decoder."""

    def __init__(self):
        super().__init__()
        self.hubert = HubertEncoder(input_dim=128)
        self.pitch = PitchEncoder()
        self.decoder = RVCDecoder(dim=256, mel_dim=128)

    def forward(self, mel, f0):
        """Map mel (batch, time, 128) and f0 (batch, time) to (batch, time, 128) mel frames."""
        return self.decoder(self.hubert(mel), self.pitch(f0))
|
|
|
|
| |
class RVCv2Dataset(Dataset):
    """Load .wav files from a directory and yield (mel, f0, waveform) training triples.

    Each clip is mixed down to mono, resampled to `sample_rate`, and trimmed or
    zero-padded to exactly `duration` seconds, so every item produces the same
    number of mel frames.
    """

    def __init__(self, dataset_dir, sample_rate=40000, duration=10):
        # sorted() makes the file order — and therefore epoch content and any
        # index-based lookup — reproducible; Path.glob order is filesystem-dependent.
        self.files = sorted(Path(dataset_dir).glob("*.wav"))

        if not self.files:
            raise ValueError(f"No .wav files found in {dataset_dir}")

        self.sample_rate = sample_rate
        self.duration = duration
        self.n_samples = int(sample_rate * duration)

        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.sample_rate,
            n_mels=128,
            n_fft=2048,
            hop_length=512,
        )

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return (mel, f0, waveform): mel is (frames, 128) log-mel, f0 is (frames,)."""
        waveform, sr = torchaudio.load(self.files[idx])

        # Mix multi-channel audio down to mono.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Resample to the target rate if the file's rate differs.
        if sr != self.sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
            waveform = resampler(waveform)

        # Force a fixed length: trim long clips, zero-pad short ones.
        if waveform.shape[1] > self.n_samples:
            waveform = waveform[:, :self.n_samples]
        else:
            pad_amount = self.n_samples - waveform.shape[1]
            waveform = torch.nn.functional.pad(waveform, (0, pad_amount))

        # Log-mel with a small floor to avoid log(0) on silent frames.
        mel = self.mel_transform(waveform).squeeze(0)
        mel = torch.log(mel + 1e-9)

        # (n_mels, frames) -> (frames, n_mels) for the batch-first models.
        mel = mel.transpose(0, 1)

        # Placeholder F0: constant 200 Hz for every frame.
        # TODO(review): replace with a real pitch tracker (e.g. pYIN / CREPE).
        f0 = torch.ones(mel.shape[0], dtype=torch.float32) * 200.0

        return mel, f0, waveform
|
|
|
|
def collate_fn(batch):
    """Collate (mel, f0, waveform) triples: pad mels/f0s along time, keep waves as a tuple.

    Returns (mels, f0s, waves) where mels is (batch, max_frames, n_mels),
    f0s is (batch, max_frames), and waves is the untouched tuple of waveforms.
    """
    mels, f0s, waves = zip(*batch)

    # Zero-pad along the time axis so every item in the batch has equal length.
    padded_mels = pad_sequence(list(mels), batch_first=True, padding_value=0.0)
    padded_f0s = pad_sequence(list(f0s), batch_first=True, padding_value=0.0)

    return padded_mels, padded_f0s, waves
|
|
|
|
| |
def train_rvc_v2(
    model_name,
    dataset_dir,
    sample_rate=40000,
    epochs=200,
    batch_size=8,
    lr=2e-4,
):
    """Train an RVCv2 mel-reconstruction model on the .wav files in `dataset_dir`.

    Saves the best (lowest epoch-average loss) weights to weights/<model_name>.pth.

    Args:
        model_name: Basename for the checkpoint under weights/.
        dataset_dir: Directory containing training .wav files.
        sample_rate: Target sample rate the audio is resampled to.
        epochs: Number of training epochs.
        batch_size: DataLoader batch size.
        lr: Initial AdamW learning rate (cosine-annealed over `epochs`).

    Raises:
        ValueError: Propagated from RVCv2Dataset when no .wav files are found.
    """
    print(f"🚀 RVC v2 Training Started: {model_name}")
    print(f"📂 Dataset: {dataset_dir}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"🛠️ Device: {device}")

    dataset = RVCv2Dataset(dataset_dir, sample_rate)
    print(f"🎧 Files found: {len(dataset)}")

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=0,
    )

    model = RVCv2().to(device)

    optimizer = optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=1e-5,
    )

    # Cosine decay from lr toward zero over the full run.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=epochs,
    )

    # Self-reconstruction objective: predict the input mel from (mel, f0).
    criterion = nn.MSELoss()

    os.makedirs("weights", exist_ok=True)

    best_loss = float("inf")

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0

        for mel, f0, _ in dataloader:
            mel = mel.to(device)
            f0 = f0.to(device)

            # set_to_none=True frees gradient tensors instead of zero-filling.
            optimizer.zero_grad(set_to_none=True)

            output = model(mel, f0)
            loss = criterion(output, mel)
            loss.backward()

            # Clip to stabilize the LSTM layers against exploding gradients.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()

        scheduler.step()

        avg_loss = total_loss / len(dataloader)

        # Keep only the best-performing checkpoint.
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), f"weights/{model_name}.pth")

        if epoch % 10 == 0:
            # 1-based display: the previous version printed "Epoch 0/<epochs>".
            print(
                f"Epoch {epoch + 1}/{epochs} | "
                f"Loss: {avg_loss:.4f} | "
                f"Best: {best_loss:.4f} | "
                f"LR: {scheduler.get_last_lr()[0]:.2e}"
            )

    print(f"✅ Training Complete! Best model: weights/{model_name}.pth")
|
|
|
|
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser(description="RVC v2 Training") |
|
|
| parser.add_argument("--model_name", required=True) |
| parser.add_argument("--dataset", required=True) |
| parser.add_argument("--sample_rate", type=int, default=40000) |
| parser.add_argument("--epochs", type=int, default=200) |
| parser.add_argument("--batch_size", type=int, default=8) |
|
|
| args = parser.parse_args() |
|
|
| train_rvc_v2( |
| model_name=args.model_name, |
| dataset_dir=args.dataset, |
| sample_rate=args.sample_rate, |
| epochs=args.epochs, |
| batch_size=args.batch_size, |
| ) |