Commit
·
a09ce64
1
Parent(s):
ca3e491
Create modeling_svector.py
Browse files- modeling_svector.py +548 -0
modeling_svector.py
ADDED
|
@@ -0,0 +1,548 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
import typing as tp
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from transformers.utils import ModelOutput
|
| 7 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 8 |
+
from transformers.modeling_outputs import SequenceClassifierOutput
|
| 9 |
+
|
| 10 |
+
from .helpers_svector import Fbank
|
| 11 |
+
from .configuration_svector import SvectorConfig
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class InputNormalization(nn.Module):
|
| 15 |
+
|
| 16 |
+
spk_dict_mean: tp.Dict[int, torch.Tensor]
|
| 17 |
+
spk_dict_std: tp.Dict[int, torch.Tensor]
|
| 18 |
+
spk_dict_count: tp.Dict[int, int]
|
| 19 |
+
|
| 20 |
+
def __init__(
|
| 21 |
+
self,
|
| 22 |
+
mean_norm=True,
|
| 23 |
+
std_norm=True,
|
| 24 |
+
norm_type="global",
|
| 25 |
+
avg_factor=None,
|
| 26 |
+
requires_grad=False,
|
| 27 |
+
update_until_epoch=3,
|
| 28 |
+
):
|
| 29 |
+
super().__init__()
|
| 30 |
+
self.mean_norm = mean_norm
|
| 31 |
+
self.std_norm = std_norm
|
| 32 |
+
self.norm_type = norm_type
|
| 33 |
+
self.avg_factor = avg_factor
|
| 34 |
+
self.requires_grad = requires_grad
|
| 35 |
+
self.glob_mean = torch.tensor([0])
|
| 36 |
+
self.glob_std = torch.tensor([0])
|
| 37 |
+
self.spk_dict_mean = {}
|
| 38 |
+
self.spk_dict_std = {}
|
| 39 |
+
self.spk_dict_count = {}
|
| 40 |
+
self.weight = 1.0
|
| 41 |
+
self.count = 0
|
| 42 |
+
self.eps = 1e-10
|
| 43 |
+
self.update_until_epoch = update_until_epoch
|
| 44 |
+
|
| 45 |
+
def forward(self, input_values, lengths=None, spk_ids=torch.tensor([]), epoch=0):
|
| 46 |
+
"""Returns the tensor with the surrounding context.
|
| 47 |
+
|
| 48 |
+
Arguments
|
| 49 |
+
---------
|
| 50 |
+
x : tensor
|
| 51 |
+
A batch of tensors.
|
| 52 |
+
lengths : tensor
|
| 53 |
+
A batch of tensors containing the relative length of each
|
| 54 |
+
sentence (e.g, [0.7, 0.9, 1.0]). It is used to avoid
|
| 55 |
+
computing stats on zero-padded steps.
|
| 56 |
+
spk_ids : tensor containing the ids of each speaker (e.g, [0 10 6]).
|
| 57 |
+
It is used to perform per-speaker normalization when
|
| 58 |
+
norm_type='speaker'.
|
| 59 |
+
"""
|
| 60 |
+
x = input_values
|
| 61 |
+
N_batches = x.shape[0]
|
| 62 |
+
|
| 63 |
+
current_means = []
|
| 64 |
+
current_stds = []
|
| 65 |
+
|
| 66 |
+
for snt_id in range(N_batches):
|
| 67 |
+
# Avoiding padded time steps
|
| 68 |
+
# lengths = torch.sum(attention_mask, dim=1)
|
| 69 |
+
# relative_lengths = lengths / torch.max(lengths)
|
| 70 |
+
# actual_size = torch.round(relative_lengths[snt_id] * x.shape[1]).int()
|
| 71 |
+
actual_size = torch.round(lengths[snt_id] * x.shape[1]).int()
|
| 72 |
+
|
| 73 |
+
# computing statistics
|
| 74 |
+
current_mean, current_std = self._compute_current_stats(
|
| 75 |
+
x[snt_id, 0:actual_size, ...]
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
current_means.append(current_mean)
|
| 79 |
+
current_stds.append(current_std)
|
| 80 |
+
|
| 81 |
+
if self.norm_type == "sentence":
|
| 82 |
+
x[snt_id] = (x[snt_id] - current_mean.data) / current_std.data
|
| 83 |
+
|
| 84 |
+
if self.norm_type == "speaker":
|
| 85 |
+
spk_id = int(spk_ids[snt_id][0])
|
| 86 |
+
|
| 87 |
+
if self.training:
|
| 88 |
+
if spk_id not in self.spk_dict_mean:
|
| 89 |
+
# Initialization of the dictionary
|
| 90 |
+
self.spk_dict_mean[spk_id] = current_mean
|
| 91 |
+
self.spk_dict_std[spk_id] = current_std
|
| 92 |
+
self.spk_dict_count[spk_id] = 1
|
| 93 |
+
|
| 94 |
+
else:
|
| 95 |
+
self.spk_dict_count[spk_id] = (
|
| 96 |
+
self.spk_dict_count[spk_id] + 1
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
if self.avg_factor is None:
|
| 100 |
+
self.weight = 1 / self.spk_dict_count[spk_id]
|
| 101 |
+
else:
|
| 102 |
+
self.weight = self.avg_factor
|
| 103 |
+
|
| 104 |
+
self.spk_dict_mean[spk_id] = (
|
| 105 |
+
(1 - self.weight) * self.spk_dict_mean[spk_id]
|
| 106 |
+
+ self.weight * current_mean
|
| 107 |
+
)
|
| 108 |
+
self.spk_dict_std[spk_id] = (
|
| 109 |
+
(1 - self.weight) * self.spk_dict_std[spk_id]
|
| 110 |
+
+ self.weight * current_std
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
self.spk_dict_mean[spk_id].detach()
|
| 114 |
+
self.spk_dict_std[spk_id].detach()
|
| 115 |
+
|
| 116 |
+
speaker_mean = self.spk_dict_mean[spk_id].data
|
| 117 |
+
speaker_std = self.spk_dict_std[spk_id].data
|
| 118 |
+
else:
|
| 119 |
+
if spk_id in self.spk_dict_mean:
|
| 120 |
+
speaker_mean = self.spk_dict_mean[spk_id].data
|
| 121 |
+
speaker_std = self.spk_dict_std[spk_id].data
|
| 122 |
+
else:
|
| 123 |
+
speaker_mean = current_mean.data
|
| 124 |
+
speaker_std = current_std.data
|
| 125 |
+
|
| 126 |
+
x[snt_id] = (x[snt_id] - speaker_mean) / speaker_std
|
| 127 |
+
|
| 128 |
+
if self.norm_type == "batch" or self.norm_type == "global":
|
| 129 |
+
current_mean = torch.mean(torch.stack(current_means), dim=0)
|
| 130 |
+
current_std = torch.mean(torch.stack(current_stds), dim=0)
|
| 131 |
+
|
| 132 |
+
if self.norm_type == "batch":
|
| 133 |
+
x = (x - current_mean.data) / (current_std.data)
|
| 134 |
+
|
| 135 |
+
if self.norm_type == "global":
|
| 136 |
+
if self.training:
|
| 137 |
+
if self.count == 0:
|
| 138 |
+
self.glob_mean = current_mean
|
| 139 |
+
self.glob_std = current_std
|
| 140 |
+
|
| 141 |
+
elif epoch < self.update_until_epoch:
|
| 142 |
+
if self.avg_factor is None:
|
| 143 |
+
self.weight = 1 / (self.count + 1)
|
| 144 |
+
else:
|
| 145 |
+
self.weight = self.avg_factor
|
| 146 |
+
|
| 147 |
+
self.glob_mean = (
|
| 148 |
+
1 - self.weight
|
| 149 |
+
) * self.glob_mean + self.weight * current_mean
|
| 150 |
+
|
| 151 |
+
self.glob_std = (
|
| 152 |
+
1 - self.weight
|
| 153 |
+
) * self.glob_std + self.weight * current_std
|
| 154 |
+
|
| 155 |
+
self.glob_mean.detach()
|
| 156 |
+
self.glob_std.detach()
|
| 157 |
+
|
| 158 |
+
self.count = self.count + 1
|
| 159 |
+
|
| 160 |
+
x = (x - self.glob_mean.data) / (self.glob_std.data)
|
| 161 |
+
|
| 162 |
+
return x
|
| 163 |
+
|
| 164 |
+
def _compute_current_stats(self, x):
|
| 165 |
+
"""Returns the tensor with the surrounding context.
|
| 166 |
+
|
| 167 |
+
Arguments
|
| 168 |
+
---------
|
| 169 |
+
x : tensor
|
| 170 |
+
A batch of tensors.
|
| 171 |
+
"""
|
| 172 |
+
# Compute current mean
|
| 173 |
+
if self.mean_norm:
|
| 174 |
+
current_mean = torch.mean(x, dim=0).detach().data
|
| 175 |
+
else:
|
| 176 |
+
current_mean = torch.tensor([0.0], device=x.device)
|
| 177 |
+
|
| 178 |
+
# Compute current std
|
| 179 |
+
if self.std_norm:
|
| 180 |
+
current_std = torch.std(x, dim=0).detach().data
|
| 181 |
+
else:
|
| 182 |
+
current_std = torch.tensor([1.0], device=x.device)
|
| 183 |
+
|
| 184 |
+
# Improving numerical stability of std
|
| 185 |
+
current_std = torch.max(
|
| 186 |
+
current_std, self.eps * torch.ones_like(current_std)
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
return current_mean, current_std
|
| 190 |
+
|
| 191 |
+
def _statistics_dict(self):
|
| 192 |
+
"""Fills the dictionary containing the normalization statistics."""
|
| 193 |
+
state = {}
|
| 194 |
+
state["count"] = self.count
|
| 195 |
+
state["glob_mean"] = self.glob_mean
|
| 196 |
+
state["glob_std"] = self.glob_std
|
| 197 |
+
state["spk_dict_mean"] = self.spk_dict_mean
|
| 198 |
+
state["spk_dict_std"] = self.spk_dict_std
|
| 199 |
+
state["spk_dict_count"] = self.spk_dict_count
|
| 200 |
+
|
| 201 |
+
return state
|
| 202 |
+
|
| 203 |
+
def _load_statistics_dict(self, state):
|
| 204 |
+
"""Loads the dictionary containing the statistics.
|
| 205 |
+
|
| 206 |
+
Arguments
|
| 207 |
+
---------
|
| 208 |
+
state : dict
|
| 209 |
+
A dictionary containing the normalization statistics.
|
| 210 |
+
"""
|
| 211 |
+
self.count = state["count"]
|
| 212 |
+
if isinstance(state["glob_mean"], int):
|
| 213 |
+
self.glob_mean = state["glob_mean"]
|
| 214 |
+
self.glob_std = state["glob_std"]
|
| 215 |
+
else:
|
| 216 |
+
self.glob_mean = state["glob_mean"] # .to(self.device_inp)
|
| 217 |
+
self.glob_std = state["glob_std"] # .to(self.device_inp)
|
| 218 |
+
|
| 219 |
+
# Loading the spk_dict_mean in the right device
|
| 220 |
+
self.spk_dict_mean = {}
|
| 221 |
+
for spk in state["spk_dict_mean"]:
|
| 222 |
+
self.spk_dict_mean[spk] = state["spk_dict_mean"][spk].to(
|
| 223 |
+
self.device_inp
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
# Loading the spk_dict_std in the right device
|
| 227 |
+
self.spk_dict_std = {}
|
| 228 |
+
for spk in state["spk_dict_std"]:
|
| 229 |
+
self.spk_dict_std[spk] = state["spk_dict_std"][spk].to(
|
| 230 |
+
self.device_inp
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
self.spk_dict_count = state["spk_dict_count"]
|
| 234 |
+
|
| 235 |
+
return state
|
| 236 |
+
|
| 237 |
+
def to(self, device):
|
| 238 |
+
"""Puts the needed tensors in the right device."""
|
| 239 |
+
self = super(InputNormalization, self).to(device)
|
| 240 |
+
self.glob_mean = self.glob_mean.to(device)
|
| 241 |
+
self.glob_std = self.glob_std.to(device)
|
| 242 |
+
for spk in self.spk_dict_mean:
|
| 243 |
+
self.spk_dict_mean[spk] = self.spk_dict_mean[spk].to(device)
|
| 244 |
+
self.spk_dict_std[spk] = self.spk_dict_std[spk].to(device)
|
| 245 |
+
return self
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
class TdnnLayer(nn.Module):
|
| 249 |
+
|
| 250 |
+
def __init__(
|
| 251 |
+
self,
|
| 252 |
+
in_channels,
|
| 253 |
+
out_channels,
|
| 254 |
+
kernel_size,
|
| 255 |
+
dilation=1,
|
| 256 |
+
stride=1,
|
| 257 |
+
padding=0,
|
| 258 |
+
padding_mode="reflect",
|
| 259 |
+
activation=torch.nn.LeakyReLU,
|
| 260 |
+
):
|
| 261 |
+
super(TdnnLayer, self).__init__()
|
| 262 |
+
self.in_channels = in_channels
|
| 263 |
+
self.out_channels = out_channels
|
| 264 |
+
self.kernel_size = kernel_size
|
| 265 |
+
self.dilation = dilation
|
| 266 |
+
self.stride = stride
|
| 267 |
+
self.padding = padding
|
| 268 |
+
self.padding_mode = padding_mode
|
| 269 |
+
self.activation = activation
|
| 270 |
+
|
| 271 |
+
self.conv = nn.Conv1d(
|
| 272 |
+
self.in_channels,
|
| 273 |
+
self.out_channels,
|
| 274 |
+
self.kernel_size,
|
| 275 |
+
dilation=self.dilation,
|
| 276 |
+
padding=self.padding
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
# Set Affine=false to be compatible with the original kaldi version
|
| 280 |
+
# self.ln = nn.LayerNorm(out_channels, elementwise_affine=False)
|
| 281 |
+
self.norm = nn.BatchNorm1d(out_channels, affine=False)
|
| 282 |
+
|
| 283 |
+
def forward(self, x):
|
| 284 |
+
x = self._manage_padding(x, self.kernel_size, self.dilation, self.stride)
|
| 285 |
+
out = self.conv(x)
|
| 286 |
+
out = self.activation()(out)
|
| 287 |
+
out = self.norm(out)
|
| 288 |
+
return out
|
| 289 |
+
|
| 290 |
+
def _manage_padding(
|
| 291 |
+
self, x, kernel_size: int, dilation: int, stride: int,
|
| 292 |
+
):
|
| 293 |
+
# Detecting input shape
|
| 294 |
+
L_in = self.in_channels
|
| 295 |
+
|
| 296 |
+
# Time padding
|
| 297 |
+
padding = get_padding_elem(L_in, stride, kernel_size, dilation)
|
| 298 |
+
|
| 299 |
+
# Applying padding
|
| 300 |
+
x = F.pad(x, padding, mode=self.padding_mode)
|
| 301 |
+
|
| 302 |
+
return x
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
|
| 306 |
+
"""This function computes the number of elements to add for zero-padding.
|
| 307 |
+
|
| 308 |
+
Arguments
|
| 309 |
+
---------
|
| 310 |
+
L_in : int
|
| 311 |
+
stride: int
|
| 312 |
+
kernel_size : int
|
| 313 |
+
dilation : int
|
| 314 |
+
"""
|
| 315 |
+
if stride > 1:
|
| 316 |
+
padding = [math.floor(kernel_size / 2), math.floor(kernel_size / 2)]
|
| 317 |
+
|
| 318 |
+
else:
|
| 319 |
+
L_out = (
|
| 320 |
+
math.floor((L_in - dilation * (kernel_size - 1) - 1) / stride) + 1
|
| 321 |
+
)
|
| 322 |
+
padding = [
|
| 323 |
+
math.floor((L_in - L_out) / 2),
|
| 324 |
+
math.floor((L_in - L_out) / 2),
|
| 325 |
+
]
|
| 326 |
+
return padding
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
class StatisticsPooling(nn.Module):
|
| 330 |
+
|
| 331 |
+
def __init__(self, return_mean=True, return_std=True):
|
| 332 |
+
super().__init__()
|
| 333 |
+
|
| 334 |
+
# Small value for GaussNoise
|
| 335 |
+
self.eps = 1e-5
|
| 336 |
+
self.return_mean = return_mean
|
| 337 |
+
self.return_std = return_std
|
| 338 |
+
if not (self.return_mean or self.return_std):
|
| 339 |
+
raise ValueError(
|
| 340 |
+
"both of statistics are equal to False \n"
|
| 341 |
+
"consider enabling mean and/or std statistic pooling"
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
def forward(self, input_values, lengths=None):
|
| 345 |
+
"""Calculates mean and std for a batch (input tensor).
|
| 346 |
+
|
| 347 |
+
Arguments
|
| 348 |
+
---------
|
| 349 |
+
x : torch.Tensor
|
| 350 |
+
It represents a tensor for a mini-batch.
|
| 351 |
+
"""
|
| 352 |
+
x = input_values
|
| 353 |
+
if lengths is None:
|
| 354 |
+
if self.return_mean:
|
| 355 |
+
mean = x.mean(dim=1)
|
| 356 |
+
if self.return_std:
|
| 357 |
+
std = x.std(dim=1)
|
| 358 |
+
else:
|
| 359 |
+
mean = []
|
| 360 |
+
std = []
|
| 361 |
+
for snt_id in range(x.shape[0]):
|
| 362 |
+
# Avoiding padded time steps
|
| 363 |
+
# lengths = torch.sum(attention_mask, dim=1)
|
| 364 |
+
# relative_lengths = lengths / torch.max(lengths)
|
| 365 |
+
# actual_size = torch.round(relative_lengths[snt_id] * x.shape[1]).int()
|
| 366 |
+
actual_size = int(torch.round(lengths[snt_id] * x.shape[1]))
|
| 367 |
+
|
| 368 |
+
# computing statistics
|
| 369 |
+
if self.return_mean:
|
| 370 |
+
mean.append(
|
| 371 |
+
torch.mean(x[snt_id, 0:actual_size, ...], dim=0)
|
| 372 |
+
)
|
| 373 |
+
if self.return_std:
|
| 374 |
+
std.append(torch.std(x[snt_id, 0:actual_size, ...], dim=0))
|
| 375 |
+
if self.return_mean:
|
| 376 |
+
mean = torch.stack(mean)
|
| 377 |
+
if self.return_std:
|
| 378 |
+
std = torch.stack(std)
|
| 379 |
+
|
| 380 |
+
if self.return_mean:
|
| 381 |
+
gnoise = self._get_gauss_noise(mean.size(), device=mean.device)
|
| 382 |
+
gnoise = gnoise
|
| 383 |
+
mean += gnoise
|
| 384 |
+
if self.return_std:
|
| 385 |
+
std = std + self.eps
|
| 386 |
+
|
| 387 |
+
# Append mean and std of the batch
|
| 388 |
+
if self.return_mean and self.return_std:
|
| 389 |
+
pooled_stats = torch.cat((mean, std), dim=1)
|
| 390 |
+
pooled_stats = pooled_stats.unsqueeze(1)
|
| 391 |
+
elif self.return_mean:
|
| 392 |
+
pooled_stats = mean.unsqueeze(1)
|
| 393 |
+
elif self.return_std:
|
| 394 |
+
pooled_stats = std.unsqueeze(1)
|
| 395 |
+
|
| 396 |
+
return pooled_stats
|
| 397 |
+
|
| 398 |
+
def _get_gauss_noise(self, shape_of_tensor, device="cpu"):
|
| 399 |
+
"""Returns a tensor of epsilon Gaussian noise.
|
| 400 |
+
|
| 401 |
+
Arguments
|
| 402 |
+
---------
|
| 403 |
+
shape_of_tensor : tensor
|
| 404 |
+
It represents the size of tensor for generating Gaussian noise.
|
| 405 |
+
"""
|
| 406 |
+
gnoise = torch.randn(shape_of_tensor, device=device)
|
| 407 |
+
gnoise -= torch.min(gnoise)
|
| 408 |
+
gnoise /= torch.max(gnoise)
|
| 409 |
+
gnoise = self.eps * ((1 - 9) * gnoise + 9)
|
| 410 |
+
|
| 411 |
+
return gnoise
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
class SvectorEmbedder(nn.Module):
|
| 415 |
+
|
| 416 |
+
def __init__(
|
| 417 |
+
self,
|
| 418 |
+
in_channels=40,
|
| 419 |
+
num_heads=8,
|
| 420 |
+
num_layers=5,
|
| 421 |
+
activation=torch.nn.LeakyReLU,
|
| 422 |
+
hidden_size=512,
|
| 423 |
+
) -> None:
|
| 424 |
+
super(SvectorEmbedder, self).__init__()
|
| 425 |
+
self.tdnn = TdnnLayer(
|
| 426 |
+
in_channels=in_channels,
|
| 427 |
+
out_channels=hidden_size,
|
| 428 |
+
kernel_size=1,
|
| 429 |
+
dilation=1,
|
| 430 |
+
activation=activation,
|
| 431 |
+
)
|
| 432 |
+
encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_size, nhead=num_heads)
|
| 433 |
+
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
|
| 434 |
+
self.pooler = StatisticsPooling()
|
| 435 |
+
self.fc = nn.Linear(2 * hidden_size, hidden_size)
|
| 436 |
+
|
| 437 |
+
def forward(self, input_values, lengths=None):
|
| 438 |
+
"""
|
| 439 |
+
x: [B, T, F]
|
| 440 |
+
"""
|
| 441 |
+
x = input_values
|
| 442 |
+
x = self.tdnn(x.transpose(1, 2))
|
| 443 |
+
last_hidden_state = self.transformer_encoder(x.transpose(1, 2))
|
| 444 |
+
pooler_output = self.pooler(last_hidden_state, lengths)
|
| 445 |
+
pooler_output = self.fc(pooler_output.squeeze(1))
|
| 446 |
+
return ModelOutput(
|
| 447 |
+
last_hidden_state=last_hidden_state,
|
| 448 |
+
pooler_output=pooler_output
|
| 449 |
+
)
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
class CosineSimilarityHead(torch.nn.Module):
|
| 453 |
+
"""
|
| 454 |
+
This class implements the cosine similarity on the top of features.
|
| 455 |
+
"""
|
| 456 |
+
def __init__(
|
| 457 |
+
self,
|
| 458 |
+
in_channels,
|
| 459 |
+
lin_blocks=0,
|
| 460 |
+
hidden_size=192,
|
| 461 |
+
num_classes=1211,
|
| 462 |
+
):
|
| 463 |
+
super().__init__()
|
| 464 |
+
self.blocks = nn.ModuleList()
|
| 465 |
+
|
| 466 |
+
for block_index in range(lin_blocks):
|
| 467 |
+
self.blocks.extend(
|
| 468 |
+
[
|
| 469 |
+
nn.BatchNorm1d(num_features=in_channels),
|
| 470 |
+
nn.Linear(in_features=in_channels, out_features=hidden_size),
|
| 471 |
+
]
|
| 472 |
+
)
|
| 473 |
+
in_channels = hidden_size
|
| 474 |
+
|
| 475 |
+
# Final Layer
|
| 476 |
+
self.weight = nn.Parameter(
|
| 477 |
+
torch.FloatTensor(num_classes, in_channels)
|
| 478 |
+
)
|
| 479 |
+
nn.init.xavier_uniform_(self.weight)
|
| 480 |
+
|
| 481 |
+
def forward(self, x):
|
| 482 |
+
"""Returns the output probabilities over speakers.
|
| 483 |
+
|
| 484 |
+
Arguments
|
| 485 |
+
---------
|
| 486 |
+
x : torch.Tensor
|
| 487 |
+
Torch tensor.
|
| 488 |
+
"""
|
| 489 |
+
for layer in self.blocks:
|
| 490 |
+
x = layer(x)
|
| 491 |
+
|
| 492 |
+
# Need to be normalized
|
| 493 |
+
x = F.linear(F.normalize(x), F.normalize(self.weight))
|
| 494 |
+
return x
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
class SvectorPreTrainedModel(PreTrainedModel):
|
| 498 |
+
|
| 499 |
+
config_class = SvectorConfig
|
| 500 |
+
base_model_prefix = "svector"
|
| 501 |
+
main_input_name = "input_values"
|
| 502 |
+
supports_gradient_checkpointing = True
|
| 503 |
+
|
| 504 |
+
def _init_weights(self, module):
|
| 505 |
+
"""Initialize the weights"""
|
| 506 |
+
if isinstance(module, nn.Linear):
|
| 507 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 508 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 509 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 510 |
+
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
|
| 511 |
+
module.bias.data.zero_()
|
| 512 |
+
module.weight.data.fill_(1.0)
|
| 513 |
+
elif isinstance(module, nn.Conv1d):
|
| 514 |
+
nn.init.kaiming_normal_(module.weight.data)
|
| 515 |
+
|
| 516 |
+
if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
|
| 517 |
+
module.bias.data.zero_()
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
class SvectorModel(SvectorPreTrainedModel):
|
| 521 |
+
|
| 522 |
+
def __init__(self, config):
|
| 523 |
+
super().__init__(config)
|
| 524 |
+
self.compute_features = Fbank(
|
| 525 |
+
n_mels=config.n_mels,
|
| 526 |
+
sample_rate=config.sample_rate,
|
| 527 |
+
win_length=config.win_length,
|
| 528 |
+
hop_length=config.hop_length,
|
| 529 |
+
)
|
| 530 |
+
self.mean_var_norm = InputNormalization(
|
| 531 |
+
mean_norm=config.mean_norm,
|
| 532 |
+
std_norm=config.std_norm,
|
| 533 |
+
norm_type=config.norm_type
|
| 534 |
+
)
|
| 535 |
+
self.embedding_model = SvectorEmbedder(
|
| 536 |
+
in_channels=config.n_mels,
|
| 537 |
+
activation=nn.LeakyReLU,
|
| 538 |
+
num_heads=config.num_heads,
|
| 539 |
+
num_layers=config.num_layers,
|
| 540 |
+
hidden_size=config.hidden_size,
|
| 541 |
+
)
|
| 542 |
+
|
| 543 |
+
def forward(self, input_values, lengths=None):
|
| 544 |
+
x = input_values
|
| 545 |
+
x = self.compute_features(x)
|
| 546 |
+
x = self.mean_var_norm(x, lengths)
|
| 547 |
+
output = self.embedding_model(x, lengths)
|
| 548 |
+
return output
|