# Train_RVC / train.py: simplified RVC-style voice model training script
import os
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from pathlib import Path
import argparse
from torch.nn.utils.rnn import pad_sequence
# ---- Simplified RVC-like Architecture ----
class HubertEncoder(nn.Module):
    """Content encoder. Despite the name, this is a lightweight stand-in for a
    real HuBERT feature extractor: it encodes the mel spectrogram directly."""

    def __init__(self, input_dim=128, hidden_dim=256):
        super().__init__()
        self.conv1 = nn.Conv1d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, 3, padding=1)
        self.lstm = nn.LSTM(
            hidden_dim,
            hidden_dim // 2,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )
        # Bidirectional LSTM outputs 2 * (hidden_dim // 2) = hidden_dim features.
        self.proj = nn.Linear(hidden_dim, 256)

    def forward(self, x):
        # x: (B, T, 128)
        x = x.transpose(1, 2)          # (B, 128, T) for Conv1d
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = x.transpose(1, 2)          # (B, T, hidden_dim)
        out, _ = self.lstm(x)
        return self.proj(out)          # (B, T, 256)

class PitchEncoder(nn.Module):
    """Encodes a frame-level F0 (pitch) contour into 256-dim features."""

    def __init__(self):
        super().__init__()
        self.f0_conv = nn.Sequential(
            nn.Conv1d(1, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv1d(64, 128, 3, padding=1),
            nn.ReLU(),
        )
        self.pitch_proj = nn.Linear(128, 256)

    def forward(self, f0):
        # f0: (B, T)
        x = f0.unsqueeze(1)            # (B, 1, T)
        x = self.f0_conv(x)            # (B, 128, T)
        x = x.transpose(1, 2)          # (B, T, 128)
        return self.pitch_proj(x)      # (B, T, 256)

class RVCDecoder(nn.Module):
    """Fuses content and pitch features with multi-head attention and projects
    the result back to a mel spectrogram."""

    def __init__(self, dim=256, mel_dim=128):
        super().__init__()
        self.content_lstm = nn.LSTM(
            dim,
            dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )
        self.pitch_proj = nn.Linear(dim, dim * 2)
        self.fusion = nn.MultiheadAttention(
            embed_dim=dim * 2,
            num_heads=8,
            batch_first=True,
        )
        self.output_proj = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.ReLU(),
            nn.Linear(dim, mel_dim),
        )

    def forward(self, content, pitch):
        # content: (B, T, 256), pitch: (B, T, 256)
        content_out, _ = self.content_lstm(content)    # (B, T, 512), bidirectional
        pitch_out = self.pitch_proj(pitch)             # (B, T, 512)
        # Content queries attend over pitch keys; values come from the content path.
        fused, _ = self.fusion(
            query=content_out,
            key=pitch_out,
            value=content_out,
        )
        return self.output_proj(fused)                 # (B, T, mel_dim)

class RVCv2(nn.Module):
    """Simplified RVC-v2-style model: content encoder + pitch encoder + decoder."""

    def __init__(self):
        super().__init__()
        self.hubert = HubertEncoder(input_dim=128)
        self.pitch = PitchEncoder()
        self.decoder = RVCDecoder(dim=256, mel_dim=128)

    def forward(self, mel, f0):
        # mel: (B, T, 128), f0: (B, T)
        content = self.hubert(mel)
        pitch_feat = self.pitch(f0)
        return self.decoder(content, pitch_feat)       # (B, T, 128)

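# Illustrative only: a minimal shape smoke test for the model above. The batch
# size, frame count, and constant 200 Hz F0 are arbitrary placeholder values,
# not part of the training flow; call _smoke_test() manually if desired.
def _smoke_test():
    model = RVCv2()
    mel = torch.randn(2, 100, 128)        # (B, T, n_mels)
    f0 = torch.full((2, 100), 200.0)      # (B, T), constant pitch
    with torch.no_grad():
        out = model(mel, f0)
    assert out.shape == (2, 100, 128), out.shape
    print("smoke test OK:", out.shape)
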
# ---- Dataset ----
class RVCv2Dataset(Dataset):
    """Loads .wav files, resamples to a fixed rate, trims/pads to a fixed
    duration, and returns (log-mel, F0, waveform) per clip."""

    def __init__(self, dataset_dir, sample_rate=40000, duration=10):
        self.files = list(Path(dataset_dir).glob("*.wav"))
        if len(self.files) == 0:
            raise ValueError(f"No .wav files found in {dataset_dir}")
        self.sample_rate = sample_rate
        self.duration = duration
        self.n_samples = int(sample_rate * duration)
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.sample_rate,
            n_mels=128,
            n_fft=2048,
            hop_length=512,
        )

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        waveform, sr = torchaudio.load(str(self.files[idx]))
        # Convert stereo/multi-channel audio to mono
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        # Resample to the target sample rate if needed
        if sr != self.sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
            waveform = resampler(waveform)
        # Trim or zero-pad to a fixed number of samples
        if waveform.shape[1] > self.n_samples:
            waveform = waveform[:, :self.n_samples]
        else:
            pad_amount = self.n_samples - waveform.shape[1]
            waveform = torch.nn.functional.pad(waveform, (0, pad_amount))
        # Log-mel spectrogram: (1, 128, T) -> (128, T)
        mel = self.mel_transform(waveform).squeeze(0)
        mel = torch.log(mel + 1e-9)
        # Convert to (T, 128) for the model
        mel = mel.transpose(0, 1)
        # Dummy F0 placeholder (constant 200 Hz), one value per time frame;
        # a real pipeline would extract F0 with e.g. pyin, CREPE, or RMVPE.
        f0 = torch.ones(mel.shape[0], dtype=torch.float32) * 200.0
        return mel, f0, waveform

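# Illustrative only: shapes of a single dataset item with the defaults above.
# The dataset path is a placeholder; this helper is never called by training.
def _example_item_shapes(dataset_dir="./dataset"):
    ds = RVCv2Dataset(dataset_dir, sample_rate=40000, duration=10)
    mel, f0, wav = ds[0]
    # With torchaudio's center-padded STFT, frames ≈ n_samples // hop_length + 1.
    print("mel:", mel.shape, "f0:", f0.shape, "wav:", wav.shape)
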
def collate_fn(batch):
    mels, f0s, waves = zip(*batch)
    # mels: list of (T, 128) tensors; f0s: list of (T,) tensors. All clips are
    # trimmed/padded to the same duration, so padding here is a safety net for
    # variable-length inputs.
    mels = pad_sequence(
        mels,
        batch_first=True,
        padding_value=0.0,
    )
    f0s = pad_sequence(
        f0s,
        batch_first=True,
        padding_value=0.0,
    )
    return mels, f0s, waves

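# Illustrative only: what one batch from the DataLoader looks like with the
# collate function above. The dataset path and batch size are placeholders.
def _example_batch_shapes(dataset_dir="./dataset"):
    loader = DataLoader(
        RVCv2Dataset(dataset_dir),
        batch_size=2,
        shuffle=False,
        collate_fn=collate_fn,
    )
    mel, f0, waves = next(iter(loader))
    print("mel batch:", mel.shape)      # (B, T, 128)
    print("f0 batch:", f0.shape)        # (B, T)
    print("waveforms:", len(waves))     # tuple of B tensors, each (1, n_samples)
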
# ---- Training Loop ----
def train_rvc_v2(
    model_name,
    dataset_dir,
    sample_rate=40000,
    epochs=200,
    batch_size=8,
    lr=2e-4,
):
    print(f"🚀 RVC v2 Training Started: {model_name}")
    print(f"📂 Dataset: {dataset_dir}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"🛠️ Device: {device}")

    dataset = RVCv2Dataset(dataset_dir, sample_rate)
    print(f"🎧 Files found: {len(dataset)}")
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=0,
    )

    model = RVCv2().to(device)
    optimizer = optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=1e-5,
    )
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=epochs,
    )
    # Simplified objective: reconstruct the input log-mel spectrogram.
    criterion = nn.MSELoss()

    os.makedirs("weights", exist_ok=True)
    best_loss = float("inf")

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        for batch_idx, (mel, f0, _) in enumerate(dataloader):
            mel = mel.to(device)
            f0 = f0.to(device)

            optimizer.zero_grad()
            output = model(mel, f0)
            loss = criterion(output, mel)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()

        scheduler.step()
        avg_loss = total_loss / len(dataloader)

        # Keep the best checkpoint by average training loss.
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), f"weights/{model_name}.pth")

        if epoch % 10 == 0:
            print(
                f"Epoch {epoch}/{epochs} | "
                f"Loss: {avg_loss:.4f} | "
                f"Best: {best_loss:.4f} | "
                f"LR: {scheduler.get_last_lr()[0]:.2e}"
            )

    print(f"✅ Training Complete! Best model: weights/{model_name}.pth")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="RVC v2 Training")
    parser.add_argument("--model_name", required=True)
    parser.add_argument("--dataset", required=True)
    parser.add_argument("--sample_rate", type=int, default=40000)
    parser.add_argument("--epochs", type=int, default=200)
    parser.add_argument("--batch_size", type=int, default=8)
    args = parser.parse_args()

    train_rvc_v2(
        model_name=args.model_name,
        dataset_dir=args.dataset,
        sample_rate=args.sample_rate,
        epochs=args.epochs,
        batch_size=args.batch_size,
    )
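# Example invocation (hypothetical model name and dataset path):
#   python train.py --model_name my_voice --dataset ./dataset/my_voice \
#       --sample_rate 40000 --epochs 200 --batch_size 8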