| | import os |
| | import torch |
| | import torchaudio |
| | import numpy as np |
| | from torch.utils.data import Dataset, DataLoader |
| | import torch.nn as nn |
| | import torch.optim as optim |
| | from pathlib import Path |
| | import argparse |
| | from torch.nn.utils.rnn import pad_sequence |
| |
|
| | |
class HubertEncoder(nn.Module):
    """Content encoder: two Conv1d layers, a BiLSTM, and a linear projection.

    Expects input of shape (batch, time, input_dim); returns (batch, time, 256).
    """

    def __init__(self, input_dim=1024, hidden_dim=768):
        super().__init__()
        self.conv1 = nn.Conv1d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, 3, padding=1)
        self.lstm = nn.LSTM(hidden_dim, hidden_dim // 2, 2, batch_first=True, bidirectional=True)
        self.proj = nn.Linear(hidden_dim, 256)

    def forward(self, x):
        # Conv1d wants channels in dim 1; the LSTM wants features last.
        feats = x.transpose(1, 2)
        feats = torch.relu(self.conv2(torch.relu(self.conv1(feats))))
        seq, _ = self.lstm(feats.transpose(1, 2))
        return self.proj(seq)
| |
|
class PitchEncoder(nn.Module):
    """Encodes an F0 (fundamental-frequency) contour into one 256-dim vector.

    forward(f0):
        f0: float tensor of shape (batch, time) with per-frame F0 values.
    Returns:
        Tensor of shape (batch, 256).
    """

    def __init__(self):
        super().__init__()
        self.f0_conv = nn.Sequential(
            nn.Conv1d(1, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv1d(64, 128, 3, padding=1),
            nn.ReLU()
        )
        self.pitch_proj = nn.Linear(128, 256)

    def forward(self, f0):
        # BUG FIX: the original did unsqueeze(1).transpose(1, 2), producing
        # (batch, time, 1). Conv1d(1, ...) needs (batch, 1, time); the extra
        # transpose made the channel dim equal to `time` and crashed whenever
        # time != 1. unsqueeze(1) alone is the correct layout.
        f0 = f0.unsqueeze(1)          # (batch, 1, time)
        out = self.f0_conv(f0)        # (batch, 128, time)
        out = out.mean(-1)            # temporal average -> (batch, 128)
        return self.pitch_proj(out)   # (batch, 256)
| |
|
class RVCDecoder(nn.Module):
    """Fuses content features with a pitch embedding via multi-head attention.

    forward(content, pitch):
        content: (batch, time, dim) content features.
        pitch:   (batch, dim) or (batch, 1, dim) pitch embedding.
    Returns:
        Tensor of shape (batch, time, 1024).
    """

    def __init__(self, dim=256):
        super().__init__()
        self.content_lstm = nn.LSTM(dim, dim, 2, batch_first=True, bidirectional=True)
        self.pitch_lstm = nn.LSTM(dim, dim // 2, 1, batch_first=True)
        # BUG FIX: project the pitch stream (dim // 2 features) up to the
        # attention embed dim (dim * 2); the original fed the raw (dim // 2)
        # pitch output as the attention key, which raises a size error.
        self.pitch_key_proj = nn.Linear(dim // 2, dim * 2)
        # BUG FIX: batch_first=True so attention sees (batch, time, embed),
        # matching the batch-first LSTMs above (the original default expected
        # sequence-first input).
        self.fusion = nn.MultiheadAttention(dim * 2, 8, batch_first=True)
        self.output_proj = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.ReLU(),
            nn.Linear(dim, 1024)
        )

    def forward(self, content, pitch):
        content_out, _ = self.content_lstm(content)      # (B, T, 2*dim)
        # Accept a (batch, dim) vector from PitchEncoder as a length-1 sequence.
        if pitch.dim() == 2:
            pitch = pitch.unsqueeze(1)                   # (B, 1, dim)
        pitch_out, _ = self.pitch_lstm(pitch)            # (B, 1, dim // 2)
        pitch_key = self.pitch_key_proj(pitch_out)       # (B, 1, 2*dim)
        # Broadcast the single pitch step across the content timeline
        # (same intent as the original repeat(1, T, 1)).
        pitch_key = pitch_key.expand(-1, content_out.size(1), -1)

        fused, _ = self.fusion(content_out, pitch_key, content_out)
        return self.output_proj(fused)
| |
|
class RVCv2(nn.Module):
    """End-to-end RVC v2 model: content encoder + pitch encoder + decoder."""

    def __init__(self):
        super().__init__()
        # CONSISTENCY FIX: the training dataset produces 128 mel bins
        # (MelSpectrogram(n_mels=128)), so the content encoder must accept
        # 128-dim frames; the original default of 1024 could never match.
        self.hubert = HubertEncoder(input_dim=128)
        self.pitch = PitchEncoder()
        self.decoder = RVCDecoder()

    def forward(self, mel, f0):
        # mel: (batch, time, n_mels) log-mel frames; f0: (batch, time) contour.
        content = self.hubert(mel)
        pitch_feat = self.pitch(f0)
        # NOTE(review): the decoder emits 1024-dim frames while the training
        # loss compares against 128-bin mels — confirm the intended target.
        return self.decoder(content, pitch_feat)
| |
|
| | |
class RVCv2Dataset(Dataset):
    """Loads .wav files, resamples, pads/crops to a fixed length, and yields
    (mel, f0, waveform) training triples.

    Yields:
        mel:      (frames, 128) log-mel spectrogram, time-major so that
                  pad_sequence in the collate function pads along time.
        f0:       (frames,) placeholder pitch contour (constant 200 Hz).
        waveform: (channels, n_samples) fixed-length audio.
    """

    def __init__(self, dataset_dir, sample_rate=40000, duration=10):
        self.files = list(Path(dataset_dir).glob("*.wav"))
        self.sample_rate = sample_rate
        self.duration = duration
        self.n_samples = int(sample_rate * duration)
        # PERF FIX: build the transform once instead of on every __getitem__.
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate, n_mels=128, n_fft=2048, hop_length=512
        )

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        waveform, sr = torchaudio.load(self.files[idx])

        # Resample to the target rate if the file differs.
        if sr != self.sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
            waveform = resampler(waveform)

        # Crop or zero-pad to exactly `duration` seconds.
        if waveform.shape[1] > self.n_samples:
            waveform = waveform[:, :self.n_samples]
        else:
            waveform = torch.nn.functional.pad(waveform, (0, self.n_samples - waveform.shape[1]))

        mel = self.mel_transform(waveform).squeeze(0)   # (128, frames)
        mel = torch.log(mel + 1e-9)
        # BUG FIX: return time-major (frames, 128) so the collate function's
        # pad_sequence pads along time instead of along mel bins.
        mel = mel.transpose(0, 1)

        # BUG FIX: f0 must have one value per frame; the original used
        # mel.shape[0] before transposing, which was n_mels (always 128),
        # not the frame count. (Also dropped the redundant torch.tensor()
        # call on an existing tensor.)
        # NOTE(review): constant 200 Hz is a placeholder — wire in a real
        # F0 extractor (e.g. pyin/crepe) for actual training.
        f0 = torch.full((mel.shape[0],), 200.0)

        return mel, f0, waveform
| |
|
def collate_fn(batch):
    """Pads variable-length mels and f0 contours into batched tensors.

    Args:
        batch: list of (mel, f0, waveform) triples from RVCv2Dataset.
    Returns:
        mels:  (batch, max_len, feat) padded mel tensor.
        f0s:   (batch, max_len) padded pitch tensor.
        waves: tuple of waveform tensors, passed through unpadded.
    """
    mels, f0s, waves = zip(*batch)
    mels = pad_sequence(mels, batch_first=True, padding_value=0.0)
    # BUG FIX: `f0s` is a tuple of 1-D tensors — the original called
    # f0s.unsqueeze(1) on the tuple itself, which raised AttributeError.
    # pad_sequence handles the 1-D tensors directly.
    f0s = pad_sequence(f0s, batch_first=True, padding_value=0.0)
    return mels, f0s, waves
| |
|
| | |
def train_rvc_v2(model_name, dataset_dir, sample_rate=40000, epochs=200, batch_size=8, lr=2e-4):
    """Trains the RVCv2 model on a folder of .wav files.

    Saves the best checkpoint (lowest epoch loss) to weights/<model_name>.pth.

    Args:
        model_name: name used for the checkpoint file.
        dataset_dir: directory containing .wav training files.
        sample_rate: target sample rate for resampling.
        epochs: number of training epochs.
        batch_size: DataLoader batch size.
        lr: initial AdamW learning rate.
    """
    # BUG FIX: the original status strings were mojibake from a broken
    # encoding (one print was even split mid-literal, a syntax error);
    # replaced with plain ASCII log messages.
    print(f"[RVC v2] Training started: {model_name}")
    print(f"[RVC v2] Dataset: {dataset_dir} ({len(os.listdir(dataset_dir))} files)")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[RVC v2] Device: {device}")

    dataset = RVCv2Dataset(dataset_dir, sample_rate)
    # ROBUSTNESS: fail fast with a clear error instead of a ZeroDivisionError
    # at the first epoch's loss averaging when the folder has no .wav files.
    if len(dataset) == 0:
        raise ValueError(f"No .wav files found in {dataset_dir}")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

    model = RVCv2().to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    criterion = nn.MSELoss()

    os.makedirs("weights", exist_ok=True)

    best_loss = float('inf')
    for epoch in range(epochs):
        model.train()
        total_loss = 0

        for batch_idx, (mel, f0, _) in enumerate(dataloader):
            mel, f0 = mel.to(device), f0.to(device)

            optimizer.zero_grad()
            output = model(mel, f0)
            # NOTE(review): the decoder outputs 1024-dim frames while mel has
            # 128 bins — confirm the intended reconstruction target.
            loss = criterion(output, mel)
            loss.backward()

            # Gradient clipping keeps the stacked LSTMs from exploding.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()

        scheduler.step()
        avg_loss = total_loss / len(dataloader)

        # Checkpoint only on improvement.
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), f"weights/{model_name}.pth")

        if epoch % 10 == 0:
            print(f"Epoch {epoch}/{epochs} | Loss: {avg_loss:.4f} | LR: {scheduler.get_last_lr()[0]:.2e}")

    print(f"[RVC v2] Training complete! Best model: weights/{model_name}.pth")
| |
|
if __name__ == "__main__":
    # CLI entry point: parse training hyper-parameters and launch training.
    parser = argparse.ArgumentParser(description="RVC v2 Training")
    parser.add_argument("--model_name", required=True, help="Model name (e.g., zeynep_rvc)")
    parser.add_argument("--dataset", required=True, help="Path to dataset folder")
    parser.add_argument("--sample_rate", type=int, default=40000)
    parser.add_argument("--epochs", type=int, default=200)
    parser.add_argument("--batch_size", type=int, default=8)
    # GENERALIZATION: expose the learning rate that train_rvc_v2 already
    # accepts; the default matches the function's own default.
    parser.add_argument("--lr", type=float, default=2e-4)

    args = parser.parse_args()
    train_rvc_v2(args.model_name, args.dataset, args.sample_rate, args.epochs,
                 args.batch_size, lr=args.lr)
| |
|