File size: 6,410 Bytes
1d8403e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
# ==================================================================================================
# DEEPFAKE AUDIO - encoder/inference.py (Neural Identity Distillation Interface)
# ==================================================================================================
# 
# 📝 DESCRIPTION
# This module provides the high-level API for using the Speaker Encoder in a 
# production environment. It encapsulates the complexities of model loading, 
# tensor orchestration, and d-vector derivation. It is the primary bridge 
# used by the web interface (app.py) to extract speaker identities from 
# uploaded reference audio samples.
#
# 👀 AUTHORS
# - Amey Thakur (https://github.com/Amey-Thakur)
# - Mega Satish (https://github.com/msatmod)
#
# 🀝🏻 CREDITS
# Original Real-Time Voice Cloning methodology by CorentinJ
# Repository: https://github.com/CorentinJ/Real-Time-Voice-Cloning
#
# 🔗 PROJECT LINKS
# Repository: https://github.com/Amey-Thakur/DEEPFAKE-AUDIO
# Video Demo: https://youtu.be/i3wnBcbHDbs
# Research: https://github.com/Amey-Thakur/DEEPFAKE-AUDIO/blob/main/DEEPFAKE-AUDIO.ipynb
#
# 📜 LICENSE
# Released under the MIT License
# Release Date: 2021-02-06
# ==================================================================================================

from encoder.params_data import *
from encoder.model import SpeakerEncoder
from encoder.audio import preprocess_wav
from matplotlib import cm
from encoder import audio
from pathlib import Path
import numpy as np
import torch

# --- INTERNAL STATE (SINGLETON PATTERN) ---
# Module-wide encoder instance. It starts as None and is set by load_model();
# is_loaded() and embed_frames_batch() treat `_model is None` as "not initialised yet".
_model = None # type: SpeakerEncoder
# Device selected in load_model(); embed_frames_batch() moves its input tensor onto it.
_device = None # type: torch.device

def load_model(weights_fpath: Path, device=None):
    """
    Initializes the Speaker Encoder neural network.
    Deserializes the PyTorch state dictionary and prepares the model for eval mode.

    The module-level `_model` / `_device` are only replaced once the checkpoint has been
    loaded successfully, so a failed call never leaves a half-initialised encoder behind.

    :param weights_fpath: path of the checkpoint file (a Path or a path string). The
        checkpoint must be a mapping holding the state dict under "model_state" and the
        training step under "step".
    :param device: target device for the model. None selects CUDA when it is available and
        the CPU otherwise; a torch.device is used as given; anything else (e.g. the strings
        "cpu" or "cuda") is passed to torch.device().
    """
    global _model, _device

    # Precise hardware targeting. An explicit torch.device must be honoured as well:
    # ignoring it would leave the previously selected device (None on a first call) in place.
    if device is None:
        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif isinstance(device, torch.device):
        target_device = device
    else:
        target_device = torch.device(device)

    # Accept plain strings too; `.name` is needed for the status line below.
    weights_fpath = Path(weights_fpath)

    # Constructing the architecture.
    # NOTE(review): the second argument is presumably the device used for the loss
    # computation - confirm against SpeakerEncoder.__init__.
    model = SpeakerEncoder(target_device, torch.device("cpu"))

    # Loading serialized weights.
    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects, which can execute
    # code. Only load checkpoints that come from a trusted source.
    checkpoint = torch.load(weights_fpath, map_location=target_device, weights_only=False)
    model.load_state_dict(checkpoint["model_state"])
    model.eval()

    # Publish the encoder only after every step above succeeded.
    _model, _device = model, target_device

    print("🤝🏻 Encoder Active: Loaded \"%s\" (Step %d)" % (weights_fpath.name, checkpoint["step"]))

def is_loaded():
    """Tells whether the module-level encoder has been set, i.e. is no longer None."""
    if _model is None:
        return False
    return True

def embed_frames_batch(frames_batch):
    """
    Neural Forward Pass: Computes speaker embeddings for a batch of spectrograms.

    :param frames_batch: numpy array holding the batch of mel spectrograms
        (assumed layout: (batch_size, n_frames, n_channels) - TODO confirm against the model).
    :return: the model output as a numpy array on the CPU
        (presumably l2-normalized d-vectors of shape (batch_size, embedding_size) - the
        normalization happens inside the model, not here).
    :raises Exception: if load_model() has not been called yet.
    """
    if _model is None:
        raise Exception("Fatal: Neural Encoder is not initialized. Invoke load_model().")

    frames = torch.from_numpy(frames_batch).to(_device)
    # Inference only: disabling autograd avoids building a graph that would be thrown away
    # immediately, which saves memory and time without changing the result.
    with torch.no_grad():
        embed = _model.forward(frames).detach().cpu().numpy()
    return embed

def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
                           min_pad_coverage=0.75, overlap=0.5):
    """
    Spatio-Temporal Segmentation: plans the overlapping windows ("partials") an utterance is
    cut into, both in waveform samples and in mel-spectrogram frames.

    :param n_samples: number of samples in the waveform.
    :param partial_utterance_n_frames: number of mel frames in every window.
    :param min_pad_coverage: value in (0, 1]. The last window is dropped when the fraction of
        it that is covered by real samples is below this value, unless it is the only window.
    :param overlap: value in [0, 1); fraction by which consecutive windows overlap.
    :return: (wav_slices, mel_slices), two lists of slice objects of equal length. The last
        wav slice may end beyond n_samples, in which case the caller has to pad the waveform.
    """
    assert 0 <= overlap < 1
    assert 0 < min_pad_coverage <= 1

    # Waveform samples spanned by one mel frame step (mel_window_step is in milliseconds).
    hop_samples = int((sampling_rate * mel_window_step / 1000))
    total_frames = int(np.ceil((n_samples + 1) / hop_samples))
    stride = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)

    # Window Orchestration: one [start, stop) pair of frame indices per window, kept as
    # numpy arrays so that the waveform bounds are a plain element-wise multiplication.
    start_limit = max(1, total_frames - partial_utterance_n_frames + stride + 1)
    mel_bounds = [np.array([start, start + partial_utterance_n_frames])
                  for start in range(0, start_limit, stride)]
    mel_slices = [slice(*bounds) for bounds in mel_bounds]
    wav_slices = [slice(*(bounds * hop_samples)) for bounds in mel_bounds]

    # Defensive Padding Evaluation: discard a trailing window that would be mostly padding.
    tail = wav_slices[-1]
    coverage = (n_samples - tail.start) / (tail.stop - tail.start)
    if len(mel_slices) > 1 and coverage < min_pad_coverage:
        mel_slices.pop()
        wav_slices.pop()

    return wav_slices, mel_slices

def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs):
    """
    Core Identity Extraction: Distills a processed waveform into a single embedding vector.

    :param wav: the waveform as a numpy array (it is passed to np.pad and to
        audio.wav_to_mel_spectrogram).
    :param using_partials: if True, the waveform is cut into overlapping windows, every window
        is embedded, and the window embeddings are averaged and re-normalized. If False, the
        whole waveform is embedded in a single forward pass.
    :param return_partials: if True, a tuple (embed, partial_embeds, wave_slices) is returned
        instead of the embedding alone; the last two items are None when using_partials is
        False.
    :param kwargs: forwarded to compute_partial_slices(); ignored when using_partials is False.
    :return: the embedding, or the tuple described above.
    """
    if using_partials:
        # Windowed Distillation: plan the windows and zero-pad so the last one fits.
        wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
        required_length = wave_slices[-1].stop
        if required_length >= len(wav):
            wav = np.pad(wav, (0, required_length - len(wav)), "constant")

        # Batch Inference on Windows
        mel = audio.wav_to_mel_spectrogram(wav)
        windows = np.array([mel[window] for window in mel_slices])
        partial_embeds = embed_frames_batch(windows)

        # Statistical Averaging & Re-Normalization (the mean is scaled to unit L2 norm)
        mean_embed = np.mean(partial_embeds, axis=0)
        embed = mean_embed / np.linalg.norm(mean_embed, 2)
    else:
        # Full-Waveform Processing: a batch of one spectrogram, no windowing.
        mel = audio.wav_to_mel_spectrogram(wav)
        embed = embed_frames_batch(mel[None, ...])[0]
        partial_embeds, wave_slices = None, None

    if return_partials:
        return embed, partial_embeds, wave_slices
    return embed

def embed_speaker(wavs, **kwargs):
    """
    Placeholder for aggregating several utterances of one speaker into a single identity.

    :raises NotImplementedError: always; the aggregation is not implemented.
    """
    message = "Collaborative development in progress for multi-wav aggregation."
    raise NotImplementedError(message)

def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
    """
    Visualizes the high-dimensional latent vector as a spatial intensity map.

    :param embed: the embedding as a numpy array; it is reshaped to `shape` before drawing.
    :param ax: matplotlib axes to draw on. Defaults to the current axes (plt.gca()).
    :param title: title set on the axes.
    :param shape: shape of the heatmap. When None, the number of rows is the largest divisor
        of len(embed) that does not exceed its square root, giving the most square layout
        that reshapes without error (e.g. 16x16 for 256 values).
    :param color_range: (vmin, vmax) limits of the color scale, applied to both the image
        and its colorbar.
    """
    import matplotlib.pyplot as plt
    if ax is None:
        ax = plt.gca()

    if shape is None:
        n_values = len(embed)
        height = int(np.sqrt(n_values))
        # Step down to a row count that divides the length, otherwise reshape would fail.
        while height > 1 and n_values % height:
            height -= 1
        shape = (height, -1)
    embed = embed.reshape(shape)

    cmap = plt.get_cmap()
    mappable = ax.imshow(embed, cmap=cmap)
    # The limits must be set on the image itself, before the colorbar is created from it;
    # setting them on a separate ScalarMappable has no effect on what is drawn.
    mappable.set_clim(*color_range)
    plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)

    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)