# ==================================================================================================
# DEEPFAKE AUDIO - vocoder/train.py (Model Training Orchestrator)
# ==================================================================================================
#
# DESCRIPTION
# This script manages the training lifecycle of the WaveRNN vocoder. It
# handles dataset loading, gradient optimization via Adam, periodic
# checkpointing, and per-epoch generation of validation samples to monitor
# convergence.
#
# AUTHORS
# - Amey Thakur (https://github.com/Amey-Thakur)
# - Mega Satish (https://github.com/msatmod)
#
# CREDITS
# Original Real-Time Voice Cloning methodology by CorentinJ
# Repository: https://github.com/CorentinJ/Real-Time-Voice-Cloning
#
# PROJECT LINKS
# Repository: https://github.com/Amey-Thakur/DEEPFAKE-AUDIO
# Video Demo: https://youtu.be/i3wnBcbHDbs
# Research: https://github.com/Amey-Thakur/DEEPFAKE-AUDIO/blob/main/DEEPFAKE-AUDIO.ipynb
#
# LICENSE
# Released under the MIT License
# Release Date: 2021-02-06
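#
# USAGE (illustrative sketch only; in the upstream repo this function is
# normally invoked through its vocoder_train.py entry point, and the paths
# below are hypothetical):
#
#     from pathlib import Path
#     from vocoder.train import train
#
#     train(run_id="my_run",
#           syn_dir=Path("datasets/SV2TTS/synthesizer"),
#           voc_dir=Path("datasets/SV2TTS/vocoder"),
#           models_dir=Path("saved_models"),
#           ground_truth=True,
#           save_every=1000, backup_every=25000,
#           force_restart=False)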
# ==================================================================================================
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader

import vocoder.hparams as hp
from vocoder.display import stream, simple_table
from vocoder.distribution import discretized_mix_logistic_loss
from vocoder.gen_wavernn import gen_testset
from vocoder.models.fatchord_version import WaveRNN
from vocoder.vocoder_dataset import VocoderDataset, collate_vocoder


def train(run_id: str, syn_dir: Path, voc_dir: Path, models_dir: Path, ground_truth: bool,
          save_every: int, backup_every: int, force_restart: bool):
    """
    Main training loop: runs the WaveRNN optimization for high-fidelity audio
    synthesis, with periodic checkpointing and per-epoch validation samples.
    """
    # Sanity check: the product of the upsampling factors must equal the hop
    # length, so that one mel frame expands to exactly hop_length samples
    assert np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length
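    # (For illustration, assuming the upstream repo's default hparams:
    # voc_upsample_factors = (5, 5, 8) gives 5 * 5 * 8 = 200 == hop_length,
    # i.e. each mel frame is upsampled to exactly 200 waveform samples.)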

    # Model initialization: build the WaveRNN from the hyperparameters
    print("Initializing the WaveRNN model...")
    model = WaveRNN(
        rnn_dims=hp.voc_rnn_dims,
        fc_dims=hp.voc_fc_dims,
        bits=hp.bits,
        pad=hp.voc_pad,
        upsample_factors=hp.voc_upsample_factors,
        feat_dims=hp.num_mels,
        compute_dims=hp.voc_compute_dims,
        res_out_dims=hp.voc_res_out_dims,
        res_blocks=hp.voc_res_blocks,
        hop_length=hp.hop_length,
        sample_rate=hp.sample_rate,
        mode=hp.voc_mode
    )
    if torch.cuda.is_available():
        model = model.cuda()

    # Optimizer: Adam, with the learning rate taken from the hyperparameters
    optimizer = optim.Adam(model.parameters())
    for p in optimizer.param_groups:
        p["lr"] = hp.voc_lr
    # Loss objective depends on the output mode: RAW trains a categorical
    # distribution over quantized sample values with cross-entropy, while MOL
    # fits a discretized mixture of logistics to the waveform
    loss_func = F.cross_entropy if model.mode == "RAW" else discretized_mix_logistic_loss

    # Persistence: load existing weights or initialize from scratch
    model_dir = models_dir / run_id
    model_dir.mkdir(exist_ok=True)
    weights_fpath = model_dir / "vocoder.pt"
    if force_restart or not weights_fpath.exists():
        print("\nClean Start: Training WaveRNN from scratch\n")
        model.save(weights_fpath, optimizer)
    else:
        print("\nRestoration: Loading weights from %s" % weights_fpath)
        model.load(weights_fpath, optimizer)
        print("Status: WaveRNN weights loaded from global step %d" % model.step)

    # Data ingestion: pick ground-truth mels or the synthesizer's
    # ground-truth-aligned (GTA) mels, plus the matching audio
    metadata_fpath = syn_dir.joinpath("train.txt") if ground_truth else \
        voc_dir.joinpath("synthesized.txt")
    mel_dir = syn_dir.joinpath("mels") if ground_truth else voc_dir.joinpath("mels_gta")
    wav_dir = syn_dir.joinpath("audio")
    dataset = VocoderDataset(metadata_fpath, mel_dir, wav_dir)
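    # A single-item loader lets the per-epoch validation below synthesize
    # whole utterances one at a time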
    test_loader = DataLoader(dataset, batch_size=1, shuffle=True)

    # Display the key training hyperparameters
    simple_table([('Batch size', hp.voc_batch_size),
                  ('LR', hp.voc_lr),
                  ('Sequence Len', hp.voc_seq_len)])

    # Training epochs: iterate over the dataset
    for epoch in range(1, 350):
        data_loader = DataLoader(dataset, hp.voc_batch_size, shuffle=True,
                                 num_workers=2, collate_fn=collate_vocoder)
        start = time.time()
        running_loss = 0.

        for i, (x, y, m) in enumerate(data_loader, 1):
            if torch.cuda.is_available():
                x, m, y = x.cuda(), m.cuda(), y.cuda()

            # Forward pass: predict each waveform sample from past samples
            # and the conditioning mel frames
            y_hat = model(x, m)
            if model.mode == 'RAW':
                # cross_entropy expects logits shaped (N, C, ...), so move the
                # class dimension into second position
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
            elif model.mode == 'MOL':
                # The mixture-of-logistics loss takes float targets
                y = y.float()
            y = y.unsqueeze(-1)

            # Backward pass: compute the loss and update the weights
            loss = loss_func(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Running statistics for the progress readout
            running_loss += loss.item()
            speed = i / (time.time() - start)
            avg_loss = running_loss / i

            step = model.get_step()
            k = step // 1000  # global step in thousands, for display

            # Periodic persistence: separate backup checkpoints, plus the
            # rolling weights file that resuming and inference load from
            if backup_every != 0 and step % backup_every == 0:
                model.checkpoint(model_dir, optimizer)

            if save_every != 0 and step % save_every == 0:
                model.save(weights_fpath, optimizer)

            msg = f"| Epoch: {epoch} ({i}/{len(data_loader)}) | " \
                  f"Loss: {avg_loss:.4f} | {speed:.1f} " \
                  f"steps/s | Step: {k}k | "
            stream(msg)

        # Validation: generate qualitative audio samples after each epoch
        gen_testset(model, test_loader, hp.voc_gen_at_checkpoint, hp.voc_gen_batched,
                    hp.voc_target, hp.voc_overlap, model_dir)
        print("")