Spaces:
Running
Running
| import os | |
| import torch | |
| import torchaudio | |
| import numpy as np | |
| from torch.utils.data import Dataset, DataLoader | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from pathlib import Path | |
| import argparse | |
| from torch.nn.utils.rnn import pad_sequence | |
| # ---- RVC v2 Architecture (Hubert + Pitch + ContentVec) ---- | |
class HubertEncoder(nn.Module):
    """Content encoder: two 1-D convolutions, a bidirectional LSTM and a linear head.

    Maps a feature sequence of shape (B, T, input_dim) to (B, T, 256).

    Args:
        input_dim: Number of features per frame in the input.
        hidden_dim: Channel width of the convolutions. The LSTM uses
            ``hidden_dim // 2`` units per direction, so ``hidden_dim`` has to
            be even for the LSTM output to match the projection layer.
    """

    def __init__(self, input_dim=1024, hidden_dim=768):
        super().__init__()
        kernel = 3
        per_direction = hidden_dim // 2
        self.conv1 = nn.Conv1d(input_dim, hidden_dim, kernel, padding=1)
        self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, kernel, padding=1)
        self.lstm = nn.LSTM(
            hidden_dim,
            per_direction,
            2,
            batch_first=True,
            bidirectional=True,
        )
        self.proj = nn.Linear(hidden_dim, 256)

    def forward(self, x):
        """Encode ``x`` of shape (B, T, input_dim) into (B, T, 256) features."""
        # Conv1d wants channels first: (B, T, F) -> (B, F, T).
        channels_first = x.permute(0, 2, 1)
        for conv in (self.conv1, self.conv2):
            channels_first = nn.functional.relu(conv(channels_first))
        # Back to time-major (B, T, F) for the batch_first LSTM.
        sequence = channels_first.permute(0, 2, 1)
        encoded = self.lstm(sequence)[0]
        return self.proj(encoded)
class PitchEncoder(nn.Module):
    """Summarise an F0 contour into one 256-dimensional embedding per utterance.

    The contour goes through two 1-D convolutions, is averaged over time and
    then projected linearly, so the output carries no time axis.
    """

    def __init__(self):
        super().__init__()
        self.f0_conv = nn.Sequential(
            nn.Conv1d(1, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv1d(64, 128, 3, padding=1),
            nn.ReLU(),
        )
        self.pitch_proj = nn.Linear(128, 256)

    def forward(self, f0):
        """Encode an F0 batch.

        Args:
            f0: Tensor of shape (B, T) with one pitch value per frame.

        Returns:
            Tensor of shape (B, 256).
        """
        # Conv1d expects (B, channels, T) with a single input channel.
        # The previous extra transpose produced (B, T, 1), which only worked
        # when T == 1.
        f0 = f0.unsqueeze(1)  # (B, T) -> (B, 1, T)
        out = self.f0_conv(f0)  # (B, 128, T)
        out = out.mean(-1)  # Global average pool over time -> (B, 128)
        return self.pitch_proj(out)
class RVCDecoder(nn.Module):
    """Fuse content and pitch features and project them to output frames.

    Args:
        dim: Feature width of both the content and the pitch input.
            ``2 * dim`` must be divisible by the 8 attention heads.
        out_dim: Width of each output frame. Defaults to the previously
            hard-coded 1024.
    """

    def __init__(self, dim=256, out_dim=1024):
        super().__init__()
        self.content_lstm = nn.LSTM(dim, dim, 2, batch_first=True, bidirectional=True)
        self.pitch_lstm = nn.LSTM(dim, dim // 2, 1, batch_first=True)
        # Queries/values come from the bidirectional content LSTM (2 * dim wide)
        # while keys come from the pitch LSTM (dim // 2 wide), so the key width
        # has to be declared explicitly. batch_first matches the (B, T, F)
        # layout used by every other module here.
        self.fusion = nn.MultiheadAttention(
            dim * 2, 8, kdim=dim // 2, batch_first=True
        )
        self.output_proj = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.ReLU(),
            nn.Linear(dim, out_dim),  # Mel output
        )

    def forward(self, content, pitch):
        """Decode output frames.

        Args:
            content: Tensor of shape (B, T, dim).
            pitch: Tensor of shape (B, dim) (one embedding per utterance) or
                (B, T_p, dim) where T_p is 1 or T.

        Returns:
            Tensor of shape (B, T, out_dim).
        """
        content_out, _ = self.content_lstm(content)  # (B, T, 2 * dim)

        # A 2-D tensor would be read by nn.LSTM as a single unbatched sequence,
        # mixing the utterances of the batch; give it an explicit time axis.
        if pitch.dim() == 2:
            pitch = pitch.unsqueeze(1)  # (B, 1, dim)
        pitch_out, _ = self.pitch_lstm(pitch)  # (B, T_p, dim // 2)

        # Broadcast a single pitch step over all content frames so that keys
        # and values have the same sequence length.
        n_frames = content_out.size(1)
        if pitch_out.size(1) == 1 and n_frames != 1:
            pitch_out = pitch_out.expand(-1, n_frames, -1)

        # NOTE(review): with a time-invariant pitch embedding every key is
        # identical, so the attention weights are uniform over frames and the
        # fused output is the same for every time step. A frame-level pitch
        # track would be needed for the attention to be informative.
        fused, _ = self.fusion(content_out, pitch_out, content_out, need_weights=False)
        return self.output_proj(fused)
class RVCv2(nn.Module):
    """Full model: content encoder + pitch encoder + decoder.

    The model reconstructs its own input features, so the encoder input width
    and the decoder output width must both equal the feature size of the data.

    Args:
        n_mels: Number of features per input frame. Defaults to 128, the
            number of mel bins produced by the dataset in this file. The
            previous fixed width of 1024 could not consume those features.
    """

    def __init__(self, n_mels=128):
        super().__init__()
        self.hubert = HubertEncoder(input_dim=n_mels)
        self.pitch = PitchEncoder()
        self.decoder = RVCDecoder()
        # The decoder's final layer has its own default width. Replace it when
        # it differs so the reconstruction has the same shape as the input.
        head = self.decoder.output_proj[-1]
        if head.out_features != n_mels:
            self.decoder.output_proj[-1] = nn.Linear(head.in_features, n_mels)

    def forward(self, mel, f0):
        """Reconstruct ``mel``.

        Args:
            mel: Tensor of shape (B, T, n_mels).
            f0: Tensor of shape (B, T) with per-frame pitch values.

        Returns:
            The decoder output for the encoded content and pitch features.
        """
        content = self.hubert(mel)
        pitch_feat = self.pitch(f0)
        return self.decoder(content, pitch_feat)
| # ---- Advanced Audio Dataset ---- | |
class RVCv2Dataset(Dataset):
    """Fixed-length log-mel dataset built from the ``*.wav`` files of a folder.

    Args:
        dataset_dir: Directory that is searched (non-recursively) for
            ``*.wav`` files.
        sample_rate: Target sample rate; files with another rate are resampled.
        duration: Clip length in seconds. Longer files are truncated, shorter
            ones are zero-padded.
    """

    def __init__(self, dataset_dir, sample_rate=40000, duration=10):
        # Sorted so that an index always refers to the same file.
        self.files = sorted(Path(dataset_dir).glob("*.wav"))
        self.sample_rate = sample_rate
        self.duration = duration
        self.n_samples = int(sample_rate * duration)
        # The transform does not depend on the item, so build it only once.
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate, n_mels=128, n_fft=2048, hop_length=512
        )

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Load one clip.

        Returns:
            A tuple ``(mel, f0, waveform)``:
            ``mel`` is the time-major log-mel spectrogram (frames, 128),
            ``f0`` is a placeholder pitch track with one value per frame, and
            ``waveform`` is the mono clip of shape (1, n_samples).
        """
        waveform, sr = torchaudio.load(self.files[idx])

        # Downmix multi-channel audio; otherwise squeeze(0) below would leave a
        # 3-D spectrogram.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Resample
        if sr != self.sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
            waveform = resampler(waveform)

        # Trim/pad
        n_loaded = waveform.shape[1]
        if n_loaded > self.n_samples:
            waveform = waveform[:, :self.n_samples]
        elif n_loaded < self.n_samples:
            waveform = torch.nn.functional.pad(waveform, (0, self.n_samples - n_loaded))

        # Mel spectrogram (target). MelSpectrogram yields (n_mels, frames); the
        # model, pad_sequence and the F0 length below all expect the time axis
        # first, so transpose to (frames, n_mels).
        mel = self.mel_transform(waveform).squeeze(0)
        mel = torch.log(mel + 1e-9)
        mel = mel.transpose(0, 1).contiguous()

        # Dummy F0 (real impl needs crepe/dio): one constant value per frame.
        f0 = torch.full((mel.shape[0],), 200.0)
        return mel, f0, waveform
def collate_fn(batch):
    """Combine ``(mel, f0, waveform)`` samples into a batch.

    Args:
        batch: Sequence of ``(mel, f0, waveform)`` tuples. ``mel`` and ``f0``
            tensors are zero-padded along their first dimension to the longest
            item in the batch.

    Returns:
        ``(mels, f0s, waves)`` where ``mels`` and ``f0s`` are stacked tensors
        with the batch dimension first and ``waves`` is the untouched tuple of
        waveforms.
    """
    mels, f0s, waves = zip(*batch)
    mels = pad_sequence(mels, batch_first=True, padding_value=0.0)
    # f0s is a tuple of 1-D tensors; pad_sequence takes it directly. Calling
    # .unsqueeze on the tuple raised AttributeError.
    f0s = pad_sequence(f0s, batch_first=True, padding_value=0.0)
    return mels, f0s, waves
| # ---- Training Loop ---- | |
def train_rvc_v2(model_name, dataset_dir, sample_rate=40000, epochs=200, batch_size=8, lr=2e-4):
    """Train an RVCv2 model to reconstruct the mel features of a dataset.

    The weights with the lowest average epoch loss so far are written to
    ``weights/<model_name>.pth``.

    Args:
        model_name: Name used for the checkpoint file.
        dataset_dir: Folder containing the training ``*.wav`` files.
        sample_rate: Sample rate the audio is resampled to.
        epochs: Number of passes over the dataset.
        batch_size: Number of clips per optimisation step.
        lr: Initial AdamW learning rate (cosine-annealed over ``epochs``).

    Raises:
        ValueError: If ``dataset_dir`` contains no ``*.wav`` files.
    """
    print(f"๐ RVC v2 Training Started: {model_name}")
    print(f"๐ Dataset: {dataset_dir} ({len(os.listdir(dataset_dir))} files)")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"๐ ๏ธ Device: {device}")

    # Data
    dataset = RVCv2Dataset(dataset_dir, sample_rate)
    if len(dataset) == 0:
        # Fail early with a clear message instead of an obscure sampler error
        # or a division by zero when averaging the loss.
        raise ValueError(f"No .wav files found in {dataset_dir!r}")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

    # Model
    model = RVCv2().to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    criterion = nn.MSELoss()

    os.makedirs("weights", exist_ok=True)
    best_loss = float('inf')

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        for mel, f0, _ in dataloader:
            mel, f0 = mel.to(device), f0.to(device)
            optimizer.zero_grad()
            output = model(mel, f0)
            loss = criterion(output, mel)  # Reconstruction
            loss.backward()
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()

        scheduler.step()
        avg_loss = total_loss / len(dataloader)

        # Keep only the best checkpoint (measured on the training loss).
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), f"weights/{model_name}.pth")

        if epoch % 10 == 0:
            print(f"Epoch {epoch}/{epochs} | Loss: {avg_loss:.4f} | LR: {scheduler.get_last_lr()[0]:.2e}")

    print(f"โ RVC v2 Training Complete! Best model: weights/{model_name}.pth")
if __name__ == "__main__":
    # Command-line entry point: parse the training options and start training.
    parser = argparse.ArgumentParser(description="RVC v2 Training")
    parser.add_argument("--model_name", required=True, help="Model name (e.g., zeynep_rvc)")
    parser.add_argument("--dataset", required=True, help="Path to dataset folder")
    parser.add_argument("--sample_rate", type=int, default=40000)
    parser.add_argument("--epochs", type=int, default=200)
    parser.add_argument("--batch_size", type=int, default=8)
    # train_rvc_v2 already accepts a learning rate; expose it with the same default.
    parser.add_argument("--lr", type=float, default=2e-4, help="Initial learning rate")
    args = parser.parse_args()
    train_rvc_v2(
        args.model_name,
        args.dataset,
        args.sample_rate,
        args.epochs,
        args.batch_size,
        lr=args.lr,
    )