# ==================================================================================================
# DEEPFAKE AUDIO - encoder/inference.py (Neural Identity Distillation Interface)
# ==================================================================================================
#
# DESCRIPTION
# This module provides the high-level API for using the Speaker Encoder in a
# production environment. It encapsulates the complexities of model loading,
# tensor orchestration, and d-vector derivation. It is the primary bridge
# used by the web interface (app.py) to extract speaker identities from
# uploaded reference audio samples.
#
# AUTHORS
# - Amey Thakur (https://github.com/Amey-Thakur)
# - Mega Satish (https://github.com/msatmod)
#
# CREDITS
# Original Real-Time Voice Cloning methodology by CorentinJ
# Repository: https://github.com/CorentinJ/Real-Time-Voice-Cloning
#
# PROJECT LINKS
# Repository: https://github.com/Amey-Thakur/DEEPFAKE-AUDIO
# Video Demo: https://youtu.be/i3wnBcbHDbs
# Research: https://github.com/Amey-Thakur/DEEPFAKE-AUDIO/blob/main/DEEPFAKE-AUDIO.ipynb
#
# LICENSE
# Released under the MIT License
# Release Date: 2021-02-06
# ==================================================================================================
from encoder.params_data import *
from encoder.model import SpeakerEncoder
from encoder.audio import preprocess_wav
from matplotlib import cm
from encoder import audio
from pathlib import Path
import numpy as np
import torch
# --- INTERNAL STATE (SINGLETON PATTERN) ---
# Module-level singletons populated once by load_model() and read by the
# embedding helpers below. Both remain None until the model is loaded.
_model: "SpeakerEncoder | None" = None  # loaded speaker-encoder network (eval mode)
_device: "torch.device | None" = None  # device inference tensors are moved to
def load_model(weights_fpath: Path, device=None):
    """
    Initializes the Speaker Encoder neural network.

    Deserializes the PyTorch state dictionary and prepares the model for eval mode.

    :param weights_fpath: path to a checkpoint containing "model_state" and
        "step" entries.
    :param device: None (auto-select CUDA when available), a device-name
        string, or a torch.device instance. The loss-computation device passed
        to SpeakerEncoder stays on CPU by design.
    """
    global _model, _device
    # Precise hardware targeting
    if device is None:
        _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif isinstance(device, str):
        _device = torch.device(device)
    else:
        # Fix: a torch.device instance previously fell through and left
        # _device unset, crashing later in embed_frames_batch().
        _device = device
    # Constructing the architecture
    _model = SpeakerEncoder(_device, torch.device("cpu"))
    # Loading serialized weights
    checkpoint = torch.load(weights_fpath, map_location=_device, weights_only=False)
    _model.load_state_dict(checkpoint["model_state"])
    _model.eval()
    print("π€π» Encoder Active: Loaded \"%s\" (Step %d)" % (weights_fpath.name, checkpoint["step"]))
def is_loaded():
    """Report whether load_model() has already installed the encoder singleton."""
    loaded = _model is not None
    return loaded
def embed_frames_batch(frames_batch):
    """
    Neural Forward Pass: Computes speaker embeddings for a batch of spectrograms.

    :param frames_batch: numpy array of mel-spectrogram windows; presumably
        shaped (batch, n_frames, n_channels) — confirm against callers.
    :return: l2-normalized d-vectors as a numpy array, one row per window.
    :raises RuntimeError: if load_model() has not been invoked yet.
    """
    if _model is None:
        # RuntimeError is more precise than bare Exception while remaining
        # backward-compatible for any caller catching Exception.
        raise RuntimeError("Fatal: Neural Encoder is not initialized. Invoke load_model().")
    frames = torch.from_numpy(frames_batch).to(_device)
    # Inference only: skip autograd bookkeeping (same output, less memory).
    with torch.no_grad():
        embed = _model(frames).cpu().numpy()
    return embed
def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
                           min_pad_coverage=0.75, overlap=0.5):
    """
    Spatio-Temporal Segmentation: Defines how a long utterance is sliced into
    overlapping windows for stable embedding derivation.

    :param n_samples: waveform length in samples.
    :param partial_utterance_n_frames: mel frames per window.
    :param min_pad_coverage: minimum fraction of real (non-padded) audio the
        final window must cover to be kept, in (0, 1].
    :param overlap: fraction by which consecutive windows overlap, in [0, 1).
    :return: (wav_slices, mel_slices) — parallel lists of slice objects into
        the waveform and the mel spectrogram, respectively.
    """
    assert 0 <= overlap < 1
    assert 0 < min_pad_coverage <= 1

    samples_per_frame = int((sampling_rate * mel_window_step / 1000))
    n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
    frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)

    # Window Orchestration: walk the frame axis in strides of frame_step and
    # derive each matching waveform span by scaling frames to samples.
    mel_slices, wav_slices = [], []
    last_start = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
    for start in range(0, last_start, frame_step):
        frame_span = np.array([start, start + partial_utterance_n_frames])
        sample_span = frame_span * samples_per_frame
        mel_slices.append(slice(*frame_span))
        wav_slices.append(slice(*sample_span))

    # Defensive Padding Evaluation: discard a trailing window that would be
    # mostly zero-padding, unless it is the only window available.
    final = wav_slices[-1]
    coverage = (n_samples - final.start) / (final.stop - final.start)
    if coverage < min_pad_coverage and len(mel_slices) > 1:
        del mel_slices[-1]
        del wav_slices[-1]

    return wav_slices, mel_slices
def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs):
    """
    Core Identity Extraction: Distills a processed waveform into a single
    256-dimensional identity vector (d-vector).

    :param wav: preprocessed waveform as a numpy array of samples.
    :param using_partials: when False, embed the entire utterance in one pass.
    :param return_partials: when True, also return the per-window embeddings
        and their waveform slices (both None when using_partials is False).
    :param kwargs: forwarded to compute_partial_slices().
    :return: the l2-normalized embedding, optionally followed by the partial
        embeddings and the wave slices.
    """
    # 1. Full-Waveform Processing (Fallback for short utterances)
    if not using_partials:
        mel = audio.wav_to_mel_spectrogram(wav)
        full_embed = embed_frames_batch(mel[None, ...])[0]
        return (full_embed, None, None) if return_partials else full_embed

    # 2. Windowed Distillation: pad the waveform so the last window is full.
    wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
    required_length = wave_slices[-1].stop
    if required_length >= len(wav):
        wav = np.pad(wav, (0, required_length - len(wav)), "constant")

    # 3. Batch Inference on Windows
    mel = audio.wav_to_mel_spectrogram(wav)
    batch = np.array([mel[s] for s in mel_slices])
    partial_embeds = embed_frames_batch(batch)

    # 4. Statistical Averaging & Re-Normalization
    mean_embed = np.mean(partial_embeds, axis=0)
    embed = mean_embed / np.linalg.norm(mean_embed, 2)

    return (embed, partial_embeds, wave_slices) if return_partials else embed
def embed_speaker(wavs, **kwargs):
    """Aggregate identity extraction for multiple utterances from the same speaker."""
    # Placeholder kept so the public API surface is stable until implemented.
    raise NotImplementedError("Collaborative development in progress for multi-wav aggregation.")
def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
    """
    Visualizes the high-dimensional latent vector as a spatial intensity map.

    :param embed: 1-D embedding (d-vector) as a numpy array.
    :param ax: matplotlib axes to draw on; defaults to the current axes.
    :param title: title string for the axes.
    :param shape: 2-D shape to fold the vector into; defaults to a
        square-ish (sqrt(len), -1) grid.
    :param color_range: (vmin, vmax) color limits applied to the heatmap.
    """
    import matplotlib.pyplot as plt
    if ax is None:
        ax = plt.gca()
    if shape is None:
        height = int(np.sqrt(len(embed)))
        shape = (height, -1)
    embed = embed.reshape(shape)

    cmap = plt.get_cmap()
    mappable = ax.imshow(embed, cmap=cmap)
    # Fix: apply the color limits to the drawn image itself. The previous code
    # configured a separate, never-attached ScalarMappable, so color_range had
    # no effect on the rendered heatmap or its colorbar.
    mappable.set_clim(*color_range)
    plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)

    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)
|