import os
import argparse
from pathlib import Path

import torch
import torchaudio
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader


# ---- Simplified RVC-like Architecture ----
class HubertEncoder(nn.Module):
    """Content encoder: two 1-D convolutions, a 2-layer BiLSTM, and a linear
    projection to a fixed 256-dim content representation.

    Input  : (B, T, input_dim) frame features (log-mels in this script).
    Output : (B, T, 256) content embeddings.
    """

    def __init__(self, input_dim=128, hidden_dim=256):
        super().__init__()
        self.conv1 = nn.Conv1d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, 3, padding=1)
        # hidden_dim // 2 per direction, so the bidirectional output is hidden_dim wide.
        self.lstm = nn.LSTM(
            hidden_dim,
            hidden_dim // 2,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )
        self.proj = nn.Linear(hidden_dim, 256)

    def forward(self, x):
        """Encode frame features.  x: (B, T, input_dim) -> (B, T, 256)."""
        conv_in = x.transpose(1, 2)  # Conv1d wants (B, C, T)
        conv_out = torch.relu(self.conv2(torch.relu(self.conv1(conv_in))))
        seq_out, _ = self.lstm(conv_out.transpose(1, 2))  # back to (B, T, hidden)
        return self.proj(seq_out)


class PitchEncoder(nn.Module):
    """Encodes a per-frame F0 contour into 256-dim pitch features.

    Input  : (B, T) fundamental-frequency values.
    Output : (B, T, 256) pitch embeddings.
    """

    def __init__(self):
        super().__init__()
        self.f0_conv = nn.Sequential(
            nn.Conv1d(1, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv1d(64, 128, 3, padding=1),
            nn.ReLU(),
        )
        self.pitch_proj = nn.Linear(128, 256)

    def forward(self, f0):
        """Encode an F0 contour.  f0: (B, T) -> (B, T, 256)."""
        conv_out = self.f0_conv(f0.unsqueeze(1))  # (B, 1, T) -> (B, 128, T)
        return self.pitch_proj(conv_out.transpose(1, 2))  # (B, T, 128) -> (B, T, 256)


class RVCDecoder(nn.Module):
    """Fuses content and pitch streams with cross-attention and projects to mel bins.

    The content stream (BiLSTM output) serves as both query and value; the
    projected pitch stream serves as the attention key.
    """

    def __init__(self, dim=256, mel_dim=128):
        super().__init__()
        self.content_lstm = nn.LSTM(
            dim,
            dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )
        self.pitch_proj = nn.Linear(dim, dim * 2)
        # embed_dim = 2 * dim = 512, evenly divisible by 8 heads.
        self.fusion = nn.MultiheadAttention(
            embed_dim=dim * 2,
            num_heads=8,
            batch_first=True,
        )
        self.output_proj = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.ReLU(),
            nn.Linear(dim, mel_dim),
        )

    def forward(self, content, pitch):
        """Fuse the two streams.  content, pitch: (B, T, dim) -> (B, T, mel_dim)."""
        query, _ = self.content_lstm(content)  # (B, T, 2*dim)
        keys = self.pitch_proj(pitch)          # (B, T, 2*dim)
        fused, _ = self.fusion(query=query, key=keys, value=query)
        return self.output_proj(fused)         # (B, T, mel_dim)
class RVCv2(nn.Module):
    """Full model: content encoder + pitch encoder feeding a fusion decoder.

    Trained as a mel autoencoder in this script (output is compared to the
    input mel).
    """

    def __init__(self):
        super().__init__()
        self.hubert = HubertEncoder(input_dim=128)
        self.pitch = PitchEncoder()
        self.decoder = RVCDecoder(dim=256, mel_dim=128)

    def forward(self, mel, f0):
        """mel: (B, T, 128), f0: (B, T) -> reconstructed mel (B, T, 128)."""
        return self.decoder(self.hubert(mel), self.pitch(f0))


# ---- Dataset ----
class RVCv2Dataset(Dataset):
    """Reads .wav files from a directory and yields fixed-length training items.

    Each item is (log-mel (T, 128), f0 (T,), waveform (1, n_samples)).
    Audio is mono-mixed, resampled to `sample_rate`, and trimmed/padded to
    `duration` seconds so every clip produces the same number of frames.
    """

    def __init__(self, dataset_dir, sample_rate=40000, duration=10):
        self.files = list(Path(dataset_dir).glob("*.wav"))
        if not self.files:
            raise ValueError(f"No .wav files found in {dataset_dir}")
        self.sample_rate = sample_rate
        self.duration = duration
        self.n_samples = int(sample_rate * duration)
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.sample_rate,
            n_mels=128,
            n_fft=2048,
            hop_length=512,
        )

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        audio, sr = torchaudio.load(self.files[idx])

        # Mix any multi-channel audio down to mono.
        if audio.shape[0] > 1:
            audio = audio.mean(dim=0, keepdim=True)

        # Bring the clip to the dataset's sample rate.
        if sr != self.sample_rate:
            audio = torchaudio.transforms.Resample(sr, self.sample_rate)(audio)

        # Force a fixed length: crop long clips, zero-pad short ones.
        length = audio.shape[1]
        if length > self.n_samples:
            audio = audio[:, :self.n_samples]
        else:
            audio = torch.nn.functional.pad(audio, (0, self.n_samples - length))

        # Log-mel spectrogram, (1, 128, T) -> (T, 128).
        mel = torch.log(self.mel_transform(audio).squeeze(0) + 1e-9)
        mel = mel.transpose(0, 1)

        # Dummy F0 placeholder: constant 200 Hz, one value per frame.
        f0 = torch.full((mel.shape[0],), 200.0, dtype=torch.float32)

        return mel, f0, audio


def collate_fn(batch):
    """Pad variable-length mels and f0s to a common T; waveforms pass through.

    Returns (mels (B, T_max, 128), f0s (B, T_max), waveforms tuple).
    """
    mels, f0s, waves = zip(*batch)
    padded_mels = pad_sequence(mels, batch_first=True, padding_value=0.0)
    padded_f0s = pad_sequence(f0s, batch_first=True, padding_value=0.0)
    return padded_mels, padded_f0s, waves
# ---- Training Loop ----
def train_rvc_v2(
    model_name,
    dataset_dir,
    sample_rate=40000,
    epochs=200,
    batch_size=8,
    lr=2e-4,
):
    """Train RVCv2 as a mel autoencoder and save the best checkpoint.

    Args:
        model_name: Base name for the checkpoint file (weights/<model_name>.pth).
        dataset_dir: Directory containing the training .wav files.
        sample_rate: Target sample rate for loaded audio.
        epochs: Number of training epochs.
        batch_size: DataLoader batch size.
        lr: Initial AdamW learning rate (annealed by cosine schedule).

    Raises:
        ValueError: If `dataset_dir` contains no .wav files (from RVCv2Dataset).

    The checkpoint is overwritten whenever the epoch-average MSE improves.
    """
    print(f"🚀 RVC v2 Training Started: {model_name}")
    print(f"📂 Dataset: {dataset_dir}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"🛠️ Device: {device}")

    dataset = RVCv2Dataset(dataset_dir, sample_rate)
    print(f"🎧 Files found: {len(dataset)}")

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=0,
    )

    model = RVCv2().to(device)
    optimizer = optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=1e-5,
    )
    # One cosine cycle spanning the whole run.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=epochs,
    )
    criterion = nn.MSELoss()

    os.makedirs("weights", exist_ok=True)
    best_loss = float("inf")

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0

        for batch_idx, (mel, f0, _) in enumerate(dataloader):
            mel = mel.to(device)
            f0 = f0.to(device)

            optimizer.zero_grad()
            output = model(mel, f0)
            # Reconstruction objective: output mel vs. input mel.
            loss = criterion(output, mel)
            loss.backward()
            # Clip to stabilize LSTM training.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()

        scheduler.step()
        avg_loss = total_loss / len(dataloader)

        # Keep only the best-performing weights on disk.
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), f"weights/{model_name}.pth")

        if epoch % 10 == 0:
            print(
                f"Epoch {epoch}/{epochs} | "
                f"Loss: {avg_loss:.4f} | "
                f"Best: {best_loss:.4f} | "
                f"LR: {scheduler.get_last_lr()[0]:.2e}"
            )

    print(f"✅ Training Complete! Best model: weights/{model_name}.pth")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="RVC v2 Training")
    parser.add_argument("--model_name", required=True)
    parser.add_argument("--dataset", required=True)
    parser.add_argument("--sample_rate", type=int, default=40000)
    parser.add_argument("--epochs", type=int, default=200)
    parser.add_argument("--batch_size", type=int, default=8)
    # New: the function already accepted lr but the CLI could not set it.
    # Default matches the function's previous hard-coded behavior.
    parser.add_argument("--lr", type=float, default=2e-4)
    args = parser.parse_args()

    train_rvc_v2(
        model_name=args.model_name,
        dataset_dir=args.dataset,
        sample_rate=args.sample_rate,
        epochs=args.epochs,
        batch_size=args.batch_size,
        lr=args.lr,
    )