| | import torch |
| | from torch import nn |
| |
|
| | from TTS.encoder.models.base_encoder import BaseEncoder |
| |
|
| |
|
class LSTMWithProjection(nn.Module):
    """Single-layer LSTM whose per-time-step outputs go through a bias-free linear projection.

    Args:
        input_size (int): Size of each input feature vector.
        hidden_size (int): Hidden size of the LSTM.
        proj_size (int): Output size of the linear projection.
    """

    def __init__(self, input_size, hidden_size, proj_size):
        super().__init__()
        self.input_size, self.hidden_size, self.proj_size = input_size, hidden_size, proj_size
        # Construction order (LSTM first, then projection) is kept stable so seeded
        # initialization and parameter registration order do not change.
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)
        self.linear = nn.Linear(in_features=hidden_size, out_features=proj_size, bias=False)

    def forward(self, x):
        """Run the LSTM over ``x`` and project every time step.

        Shapes:
            - x: :math:`(N, T, D_{in})` (the LSTM is batch_first)
            - output: :math:`(N, T, D_{proj})`
        """
        # Compact cuDNN weights into one contiguous buffer before the call.
        self.lstm.flatten_parameters()
        sequence_out = self.lstm(x)[0]
        return self.linear(sequence_out)
| |
|
| |
|
class LSTMWithoutProjection(nn.Module):
    """Multi-layer LSTM whose top-layer final hidden state is linearly projected and passed through ReLU.

    Args:
        input_dim (int): Size of each input feature vector.
        lstm_dim (int): Hidden size of every LSTM layer.
        proj_dim (int): Output size of the linear projection.
        num_lstm_layers (int): Number of stacked LSTM layers.
    """

    def __init__(self, input_dim, lstm_dim, proj_dim, num_lstm_layers):
        super().__init__()
        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=lstm_dim, num_layers=num_lstm_layers, batch_first=True)
        self.linear = nn.Linear(lstm_dim, proj_dim, bias=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(linear(h_top))``, where ``h_top`` is the last layer's final hidden state.

        Shapes:
            - x: :math:`(N, T, D_{in})` (the LSTM is batch_first)
            - output: :math:`(N, D_{proj})`
        """
        # Compact cuDNN weights into one contiguous buffer (avoids the non-contiguous-weights
        # warning and per-call copies, e.g. after DataParallel replication). No-op on CPU.
        self.lstm.flatten_parameters()
        _, (hidden, _) = self.lstm(x)
        # hidden is (num_layers, N, lstm_dim) even with batch_first; [-1] selects the top layer.
        return self.relu(self.linear(hidden[-1]))
| |
|
| |
|
class LSTMSpeakerEncoder(BaseEncoder):
    """LSTM speaker encoder mapping an utterance to a fixed-size (optionally L2-normalized) embedding.

    Args:
        input_dim (int): Number of spectrogram channels fed to the LSTM stack (``D_spec``).
        proj_dim (int): Size of the output embedding. Defaults to 256.
        lstm_dim (int): Hidden size of each LSTM layer. Defaults to 768.
        num_lstm_layers (int): Number of LSTM layers. Defaults to 3.
        use_lstm_with_projection (bool): If True, stack ``num_lstm_layers`` ``LSTMWithProjection`` modules
            (each projecting to ``proj_dim``) and use the last time step as the embedding. If False, use one
            multi-layer ``LSTMWithoutProjection`` whose final hidden state is projected and passed through
            ReLU. Defaults to True.
        use_torch_spec (bool): If True, ``forward`` converts raw waveforms to spectrograms on the fly using the
            module returned by ``get_torch_mel_spectrogram_class(audio_config)``. Defaults to False.
        audio_config: Passed to ``get_torch_mel_spectrogram_class`` when ``use_torch_spec`` is True; its
            expected structure is defined by ``BaseEncoder`` (not visible here — confirm there).
    """

    def __init__(
        self,
        input_dim,
        proj_dim=256,
        lstm_dim=768,
        num_lstm_layers=3,
        use_lstm_with_projection=True,
        use_torch_spec=False,
        audio_config=None,
    ):
        super().__init__()
        self.use_lstm_with_projection = use_lstm_with_projection
        self.use_torch_spec = use_torch_spec
        self.audio_config = audio_config
        self.proj_dim = proj_dim

        if use_lstm_with_projection:
            # First layer consumes spectrogram channels; later layers consume the previous projection.
            stack = [LSTMWithProjection(input_dim, lstm_dim, proj_dim)]
            for _ in range(num_lstm_layers - 1):
                stack.append(LSTMWithProjection(proj_dim, lstm_dim, proj_dim))
            self.layers = nn.Sequential(*stack)
        else:
            self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers)

        self.instancenorm = nn.InstanceNorm1d(input_dim)

        if self.use_torch_spec:
            self.torch_spec = self.get_torch_mel_spectrogram_class(audio_config)
        else:
            self.torch_spec = None

        self._init_layers()

    def _init_layers(self):
        """Zero every bias and Xavier-normal-initialize every weight of the LSTM stack."""
        for name, param in self.layers.named_parameters():
            if "bias" in name:
                nn.init.constant_(param, 0.0)
            elif "weight" in name:
                nn.init.xavier_normal_(param)

    def forward(self, x, l2_norm=True):
        """Forward pass of the model.

        Args:
            x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, ``use_torch_spec``
                must be ``True`` so the spectrogram is computed on-the-fly. ``x`` is not modified in place.
            l2_norm (bool): Whether to L2-normalize the outputs.

        Returns:
            Tensor: speaker embeddings of shape :math:`(N, D_{proj})`.

        Shapes:
            - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
        """
        # Feature extraction runs without gradient tracking and with autocast disabled; only the
        # LSTM stack below participates in autograd.
        with torch.no_grad():
            with torch.cuda.amp.autocast(enabled=False):
                if self.use_torch_spec:
                    # Out-of-place squeeze so the caller's waveform tensor is left untouched.
                    x = x.squeeze(1)
                    x = self.torch_spec(x)
                # (N, D_spec, T) -> (N, T, D_spec) for the batch_first LSTMs.
                x = self.instancenorm(x).transpose(1, 2)
        d = self.layers(x)
        if self.use_lstm_with_projection:
            # The projected stack emits one vector per frame; keep the last frame.
            d = d[:, -1]
        if l2_norm:
            d = torch.nn.functional.normalize(d, p=2, dim=1)
        return d
| |
|