Spaces:
Running
Running
| # ================================================================================================== | |
| # DEEPFAKE AUDIO - encoder/inference.py (Neural Identity Distillation Interface) | |
| # ================================================================================================== | |
| # | |
| # DESCRIPTION | |
| # This module provides the high-level API for using the Speaker Encoder in a | |
| # production environment. It encapsulates the complexities of model loading, | |
| # tensor orchestration, and d-vector derivation. It is the primary bridge | |
| # used by the web interface (app.py) to extract speaker identities from | |
| # uploaded reference audio samples. | |
| # | |
| # AUTHORS | |
| # - Amey Thakur (https://github.com/Amey-Thakur) | |
| # - Mega Satish (https://github.com/msatmod) | |
| # | |
| # CREDITS | |
| # Original Real-Time Voice Cloning methodology by CorentinJ | |
| # Repository: https://github.com/CorentinJ/Real-Time-Voice-Cloning | |
| # | |
| # PROJECT LINKS | |
| # Repository: https://github.com/Amey-Thakur/DEEPFAKE-AUDIO | |
| # Video Demo: https://youtu.be/i3wnBcbHDbs | |
| # Research: https://github.com/Amey-Thakur/DEEPFAKE-AUDIO/blob/main/DEEPFAKE-AUDIO.ipynb | |
| # | |
| # LICENSE | |
| # Released under the MIT License | |
| # Release Date: 2021-02-06 | |
| # ================================================================================================== | |
| from encoder.params_data import * | |
| from encoder.model import SpeakerEncoder | |
| from encoder.audio import preprocess_wav | |
| from matplotlib import cm | |
| from encoder import audio | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
# --- INTERNAL STATE (module-level, replaced on every load_model() call) ---
# _model: the loaded SpeakerEncoder; None until load_model() succeeds.
# _device: torch.device that input frames are moved to before the forward pass.
_model, _device = None, None  # type: SpeakerEncoder, torch.device
def load_model(weights_fpath: Path, device=None):
    """
    Initialize the module-level Speaker Encoder from a saved checkpoint.

    Builds a SpeakerEncoder, loads the checkpoint's ``"model_state"`` entry into
    it and switches it to eval mode. Only after all of that succeeds are the
    module globals ``_model`` / ``_device`` replaced, so a failed load never
    leaves a half-initialized (randomly weighted) model behind.

    Args:
        weights_fpath: Path (a ``str`` is also accepted) to the checkpoint file.
            The checkpoint must be a mapping containing at least the keys
            ``"model_state"`` and ``"step"``.
        device: Target device. ``None`` selects CUDA when available and CPU
            otherwise; a ``torch.device`` is used as-is; anything else
            (e.g. ``"cpu"``, ``"cuda:0"``) is passed to ``torch.device(...)``.
    """
    global _model, _device
    weights_fpath = Path(weights_fpath)  # `.name` is used below; accept str paths too

    # Precise hardware targeting
    if device is None:
        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif isinstance(device, torch.device):
        # Previously a torch.device argument was silently ignored.
        target_device = device
    else:
        target_device = torch.device(device)

    # Constructing the architecture. The second argument is always CPU —
    # presumably the loss-computation device; TODO confirm in encoder/model.py.
    model = SpeakerEncoder(target_device, torch.device("cpu"))

    # Loading serialized weights.
    # NOTE(security): weights_only=False unpickles arbitrary Python objects and
    # can execute code; only load checkpoints from trusted sources.
    checkpoint = torch.load(weights_fpath, map_location=target_device, weights_only=False)
    model.load_state_dict(checkpoint["model_state"])
    model.eval()

    # Publish the fully initialized model only once everything above succeeded.
    _model, _device = model, target_device
    print("π€π» Encoder Active: Loaded \"%s\" (Step %d)" % (weights_fpath.name, checkpoint["step"]))
def is_loaded():
    """
    Report whether an encoder has been installed by load_model().

    Returns:
        bool: False while the module-level model is still None, True otherwise.
    """
    if _model is None:
        return False
    return True
def embed_frames_batch(frames_batch):
    """
    Neural forward pass: run the loaded encoder on a batch of spectrogram frames.

    Args:
        frames_batch: NumPy array handed to ``torch.from_numpy``; presumably
            shaped (batch, n_frames, n_mel_channels) and float32 — TODO confirm
            against SpeakerEncoder.forward.

    Returns:
        np.ndarray: The encoder output moved to CPU, presumably one embedding
        row per batch item. Any L2 normalization happens inside
        SpeakerEncoder.forward (not visible in this module).

    Raises:
        RuntimeError: If load_model() has not been called yet.
    """
    if _model is None:
        raise RuntimeError("Fatal: Neural Encoder is not initialized. Invoke load_model().")
    frames = torch.from_numpy(frames_batch).to(_device)
    # Pure inference: skip building an autograd graph (saves memory and time).
    with torch.no_grad():
        embed = _model.forward(frames)
    return embed.cpu().numpy()
def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
                           min_pad_coverage=0.75, overlap=0.5):
    """
    Split an utterance of ``n_samples`` waveform samples into overlapping windows.

    Every window spans ``partial_utterance_n_frames`` mel frames, and consecutive
    windows start ``max(round(partial_utterance_n_frames * (1 - overlap)), 1)``
    frames apart. The same windows are returned in two unit systems: mel frames
    and waveform samples.

    Args:
        n_samples: Number of samples in the waveform.
        partial_utterance_n_frames: Window length, in mel frames.
        min_pad_coverage: If the final window covers less than this fraction of
            real (non-padded) samples it is discarded, unless it is the only one.
        overlap: Fraction of each window shared with the next; must be in [0, 1).

    Returns:
        (wav_slices, mel_slices): two equally long lists of ``slice`` objects.
        The final wav slice may extend beyond ``n_samples``.
    """
    assert 0 <= overlap < 1
    assert 0 < min_pad_coverage <= 1

    # Waveform samples per mel frame (assuming mel_window_step is in milliseconds).
    hop = int((sampling_rate * mel_window_step / 1000))
    total_frames = int(np.ceil((n_samples + 1) / hop))
    stride = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)

    # [start, stop) bounds of every window in mel frames; at least one window.
    start_limit = max(1, total_frames - partial_utterance_n_frames + stride + 1)
    frame_bounds = [np.array([start, start + partial_utterance_n_frames])
                    for start in range(0, start_limit, stride)]
    mel_slices = [slice(*bounds) for bounds in frame_bounds]
    wav_slices = [slice(*(bounds * hop)) for bounds in frame_bounds]

    # Share of the last window that lands on real samples instead of padding.
    tail = wav_slices[-1]
    coverage = (n_samples - tail.start) / (tail.stop - tail.start)
    if coverage < min_pad_coverage and len(mel_slices) > 1:
        wav_slices, mel_slices = wav_slices[:-1], mel_slices[:-1]

    return wav_slices, mel_slices
def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs):
    """
    Distill one waveform into a single speaker embedding (d-vector).

    With ``using_partials`` (the default), the waveform is zero-padded so the
    windows from compute_partial_slices fit. Each window's mel frames are
    embedded in one batch, and the mean of those partial embeddings is
    re-normalized to unit L2 norm. Without partials, the whole utterance's mel
    spectrogram is embedded in a single forward pass.

    Args:
        wav: Waveform as a 1-D NumPy array (presumably already preprocessed,
            e.g. by preprocess_wav).
        using_partials: Embed overlapping windows and average them (True), or
            embed the full spectrogram at once (False).
        return_partials: Also return the per-window embeddings and wav slices.
        **kwargs: Forwarded to compute_partial_slices.

    Returns:
        The embedding, or ``(embed, partial_embeds, wave_slices)`` when
        ``return_partials`` is set. Both extras are None without partials.
    """
    if using_partials:
        wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
        # Zero-pad so the final window is fully backed by samples.
        required_len = wave_slices[-1].stop
        if required_len >= len(wav):
            wav = np.pad(wav, (0, required_len - len(wav)), "constant")
        mel = audio.wav_to_mel_spectrogram(wav)
        window_batch = np.array([mel[window] for window in mel_slices])
        partial_embeds = embed_frames_batch(window_batch)
        # Average the window embeddings, then project back onto the unit sphere.
        mean_embed = np.mean(partial_embeds, axis=0)
        embed = mean_embed / np.linalg.norm(mean_embed, 2)
    else:
        mel = audio.wav_to_mel_spectrogram(wav)
        embed = embed_frames_batch(mel[None, ...])[0]
        partial_embeds, wave_slices = None, None

    if return_partials:
        return embed, partial_embeds, wave_slices
    return embed
def embed_speaker(wavs, **kwargs):
    """
    Aggregate identity extraction for multiple utterances from the same speaker.

    Each waveform is embedded with embed_utterance(). The resulting vectors are
    averaged, and the mean is re-normalized to unit L2 norm. This is the same
    aggregation embed_utterance applies to its partial embeddings.

    Args:
        wavs: Iterable of waveforms accepted by embed_utterance.
        **kwargs: Forwarded to embed_utterance (e.g. ``using_partials`` or the
            slicing options of compute_partial_slices). ``return_partials`` is
            always forced to False.

    Returns:
        np.ndarray: The unit-norm speaker embedding.

    Raises:
        ValueError: If ``wavs`` is empty.
    """
    options = {**kwargs, "return_partials": False}
    utterance_embeds = [embed_utterance(wav, **options) for wav in wavs]
    if not utterance_embeds:
        raise ValueError("embed_speaker() requires at least one waveform.")
    raw_embed = np.mean(utterance_embeds, axis=0)
    return raw_embed / np.linalg.norm(raw_embed, 2)
def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
    """
    Visualize an embedding vector as a 2-D heatmap with a colorbar.

    Args:
        embed: Embedding as an array-like; reshaped to ``shape`` before plotting.
        ax: Matplotlib Axes to draw on; defaults to the current Axes (plt.gca()).
        title: Title set on ``ax``.
        shape: Target 2-D shape. Defaults to ``(int(sqrt(len(embed))), -1)``,
            i.e. roughly square.
        color_range: ``(vmin, vmax)`` limits of the color scale.
    """
    import matplotlib.pyplot as plt

    if ax is None:
        ax = plt.gca()

    embed = np.asarray(embed)
    if shape is None:
        height = int(np.sqrt(len(embed)))
        shape = (height, -1)
    embed = embed.reshape(shape)

    cmap = plt.get_cmap()
    vmin, vmax = color_range
    # Apply color_range to the image itself. The previous version configured a
    # separate ScalarMappable that was never attached, so the range was ignored.
    mappable = ax.imshow(embed, cmap=cmap, vmin=vmin, vmax=vmax)
    # Attach the colorbar to the figure that owns `ax`, not to whichever
    # figure happens to be current.
    ax.figure.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)

    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)