Spaces:
Running
Running
File size: 6,463 Bytes
1d8403e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 | # ==================================================================================================
# DEEPFAKE AUDIO - encoder/train.py (Neural Identity Optimization Cycle)
# ==================================================================================================
#
# DESCRIPTION
# This module orchestrates the complete training cycle for the Speaker Encoder.
# It manages the GE2E (Generalized End-to-End) loss computation, gradient-based
# optimization with Adam, and provides rich diagnostic telemetry through
# Visdom and UMAP projections. It ensures that the model learns a robust
# identity manifold for zero-shot speaker adaptation.
#
# AUTHORS
# - Amey Thakur (https://github.com/Amey-Thakur)
# - Mega Satish (https://github.com/msatmod)
#
# CREDITS
# Original Real-Time Voice Cloning methodology by CorentinJ
# Repository: https://github.com/CorentinJ/Real-Time-Voice-Cloning
#
# PROJECT LINKS
# Repository: https://github.com/Amey-Thakur/DEEPFAKE-AUDIO
# Video Demo: https://youtu.be/i3wnBcbHDbs
# Research: https://github.com/Amey-Thakur/DEEPFAKE-AUDIO/blob/main/DEEPFAKE-AUDIO.ipynb
#
# LICENSE
# Released under the MIT License
# Release Date: 2021-02-06
# ==================================================================================================
from pathlib import Path
import torch
# --- PROJECT CORE MODULES ---
from encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
from encoder.model import SpeakerEncoder
from encoder.params_model import *
from encoder.visualizations import Visualizations
from utils.profiler import Profiler
def sync(device: torch.device):
    """Block until all queued work on ``device`` has finished.

    CUDA kernel launches are asynchronous, so without this barrier a profiler
    tick would be charged to whichever later step happens to wait on the GPU.
    For any non-CUDA device this is a no-op.
    """
    if device.type != "cuda":
        return
    torch.cuda.synchronize(device)
def _save_checkpoint(fpath: Path, step: int, model: torch.nn.Module,
                     optimizer: torch.optim.Optimizer):
    """Write a resumable checkpoint (step, model state, optimizer state) to ``fpath``.

    The stored ``step`` is ``step + 1`` so that a resumed run starts at the next
    step instead of repeating this one. The data is first written to a sibling
    ``.tmp`` file and then moved over ``fpath``, so an interrupted save cannot
    leave a truncated checkpoint in place of a previously valid one.
    """
    tmp_fpath = fpath.with_name(fpath.name + ".tmp")
    torch.save({
        "step": step + 1,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
    }, tmp_fpath)
    # Same directory => same filesystem, so the rename replaces fpath in one step.
    tmp_fpath.replace(fpath)


def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
          backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
          no_visdom: bool):
    """
    Train the speaker encoder with the GE2E loss, resuming from a checkpoint if one exists.

    Stages:
    1. Dataset & DataLoader initialization (speaker-grouped batches)
    2. Model and Adam optimizer construction
    3. Checkpoint resumption from ``models_dir / run_id / "encoder.pt"`` (unless forced restart)
    4. Optimization loop with Visdom/UMAP telemetry and periodic checkpointing; it runs for
       as long as the data loader keeps yielding batches.

    Args:
        run_id: Run name; outputs are written under ``models_dir / run_id``.
        clean_data_root: Root directory handed to ``SpeakerVerificationDataset``.
        models_dir: Parent directory for per-run output folders (created if missing).
        umap_every: Draw and save a UMAP projection every N steps (0 disables).
        save_every: Overwrite ``encoder.pt`` every N steps (0 disables).
        backup_every: Write an ``encoder_<step>.bak`` snapshot every N steps (0 disables).
        vis_every: Passed through to ``Visualizations``.
        force_restart: Ignore any existing ``encoder.pt`` and start from fresh weights.
        visdom_server: Visdom server address passed to ``Visualizations``.
        no_visdom: Disable Visdom telemetry.
    """
    # Speaker-grouped data pipeline: batch geometry (speakers_per_batch x
    # utterances_per_speaker) comes from encoder.params_model.
    dataset = SpeakerVerificationDataset(clean_data_root)
    loader = SpeakerVerificationDataLoader(
        dataset,
        speakers_per_batch,
        utterances_per_speaker,
        num_workers=4,
    )

    # The model runs on the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # The GE2E loss is deliberately computed on CPU. NOTE(review): the rationale is not
    # visible in this file — confirm before moving the loss to the GPU.
    loss_device = torch.device("cpu")

    model = SpeakerEncoder(device, loss_device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
    init_step = 1

    # Per-run output directory holding the rolling checkpoint, backups and projections.
    model_dir = models_dir / run_id
    model_dir.mkdir(exist_ok=True, parents=True)
    state_fpath = model_dir / "encoder.pt"

    if not force_restart:
        if state_fpath.exists():
            print("π€π» Resuming Training Session: Found existing model \"%s\"" % run_id)
            # Load to CPU so a checkpoint written on a GPU machine can be resumed anywhere;
            # load_state_dict then copies the tensors onto each parameter's own device.
            checkpoint = torch.load(state_fpath, map_location="cpu")
            init_step = checkpoint["step"]
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            # Resumed runs always restart from the configured initial learning rate.
            optimizer.param_groups[0]["lr"] = learning_rate_init
        else:
            print("π Initiating New Session: Model \"%s\" not found." % run_id)
    else:
        print("π Force Restart: Re-initializing weights from scratch.")
    model.train()

    # Telemetry (Visdom); fully disabled when no_visdom is set.
    vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
    vis.log_dataset(dataset)
    vis.log_params()
    device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
    vis.log_implementation({"Device": device_name})

    # Training loop; step numbering continues from the checkpoint when resuming.
    profiler = Profiler(summarize_every=10, disabled=False)
    for step, speaker_batch in enumerate(loader, init_step):
        profiler.tick("Blocking - Queue Ingestion")

        # 1. Forward pass. sync() makes each profiler tick measure finished GPU work.
        inputs = torch.from_numpy(speaker_batch.data).to(device)
        sync(device)
        profiler.tick("H2D Transfer")
        embeds = model(inputs)
        sync(device)
        profiler.tick("LSTM Backbone Inference")

        # 2. GE2E loss: regroup the embeddings as
        #    (speakers_per_batch, utterances_per_speaker, -1) on the loss device.
        embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
        loss, eer = model.loss(embeds_loss)
        sync(loss_device)
        profiler.tick("GE2E Loss Computation")

        # 3. Backward pass and parameter update.
        model.zero_grad()
        loss.backward()
        profiler.tick("Backpropagation")
        model.do_gradient_ops()  # SpeakerEncoder's gradient post-processing (noted as clipping & scaling)
        optimizer.step()
        profiler.tick("Parameter Update")

        # 4. Telemetry update.
        vis.update(loss.item(), eer, step)

        # 5. UMAP projection of the current batch's embeddings.
        if umap_every != 0 and step % umap_every == 0:
            print("\nπ Generating Identity Manifold Projection (step %d)" % step)
            projection_fpath = model_dir / f"umap_{step:06d}.png"
            embeds_npy = embeds.detach().cpu().numpy()
            vis.draw_projections(embeds_npy, utterances_per_speaker, step, projection_fpath)
            vis.save()

        # 6. Rolling checkpoint (overwrites encoder.pt).
        if save_every != 0 and step % save_every == 0:
            print("\nπΎ Persisting Latest Weights (step %d)" % step)
            _save_checkpoint(state_fpath, step, model, optimizer)

        # 7. Rolling backup: one step-stamped snapshot file per backup interval.
        if backup_every != 0 and step % backup_every == 0:
            print("\nπ Creating Immutable Snapshot (step %d)" % step)
            backup_fpath = model_dir / f"encoder_{step:06d}.bak"
            _save_checkpoint(backup_fpath, step, model, optimizer)

        profiler.tick("Housekeeping (Telemetry & Storage)")
|