| | import numpy as np |
| | import torch |
| | import torchaudio |
| | from coqpit import Coqpit |
| | from torch import nn |
| |
|
| | from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss |
| | from TTS.utils.generic_utils import set_init_dict |
| | from TTS.utils.io import load_fsspec |
| |
|
| |
|
class PreEmphasis(nn.Module):
    """Pre-emphasis FIR filter: ``y[t] = x[t] - coefficient * x[t-1]``.

    The first sample is handled by reflection padding, so the output has
    the same length as the input.
    """

    def __init__(self, coefficient=0.97):
        super().__init__()
        self.coefficient = coefficient
        # Kernel shaped (out_channels=1, in_channels=1, width=2) for conv1d.
        # Registered as a buffer (name kept as "filter" for state_dict
        # compatibility) so it moves with the module across devices.
        kernel = torch.FloatTensor([-coefficient, 1.0]).reshape(1, 1, 2)
        self.register_buffer("filter", kernel)

    def forward(self, x):
        # Expect a 2D batch of waveforms: (batch, samples).
        assert x.dim() == 2

        # Reflect-pad one sample on the left so the convolution output
        # keeps the input length.
        padded = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
        return torch.nn.functional.conv1d(padded, self.filter).squeeze(1)
| |
|
| |
|
class BaseEncoder(nn.Module):
    """Base `encoder` class. Every new `encoder` model must inherit this.

    It defines common `encoder` specific functions.

    Subclasses are expected to define ``forward(x, l2_norm)`` and, when
    ``use_torch_spec`` is enabled, the ``audio_config`` attribute.
    """

    def __init__(self):
        super(BaseEncoder, self).__init__()

    def get_torch_mel_spectrogram_class(self, audio_config):
        """Build the torch-based acoustic frontend described by ``audio_config``.

        Args:
            audio_config (dict): must provide ``preemphasis``, ``sample_rate``,
                ``fft_size``, ``win_length``, ``hop_length`` and ``num_mels``.

        Returns:
            torch.nn.Sequential: pre-emphasis followed by a torchaudio
            ``MelSpectrogram`` using a Hamming window.
        """
        return torch.nn.Sequential(
            PreEmphasis(audio_config["preemphasis"]),
            torchaudio.transforms.MelSpectrogram(
                sample_rate=audio_config["sample_rate"],
                n_fft=audio_config["fft_size"],
                win_length=audio_config["win_length"],
                hop_length=audio_config["hop_length"],
                window_fn=torch.hamming_window,
                n_mels=audio_config["num_mels"],
            ),
        )

    @torch.no_grad()
    def inference(self, x, l2_norm=True):
        """Run a forward pass without gradient tracking."""
        return self.forward(x, l2_norm)

    @torch.no_grad()
    def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True):
        """
        Generate embeddings for a batch of utterances
        x: 1xTxD
        """
        # When the model computes the spectrogram itself, `x` is a raw
        # waveform, so convert the frame budget into a sample count.
        if self.use_torch_spec:
            num_frames = num_frames * self.audio_config["hop_length"]

        max_len = x.shape[1]

        # Short utterance: shrink the window to fit (all offsets become 0).
        if max_len < num_frames:
            num_frames = max_len

        # `num_eval` evenly spaced window start positions over the utterance.
        offsets = np.linspace(0, max_len - num_frames, num=num_eval)

        frames_batch = []
        for offset in offsets:
            offset = int(offset)
            end_offset = int(offset + num_frames)
            frames_batch.append(x[:, offset:end_offset])

        # Stack the windows along the batch dim and embed them in one pass.
        frames_batch = torch.cat(frames_batch, dim=0)
        embeddings = self.inference(frames_batch, l2_norm=l2_norm)

        if return_mean:
            # Average the per-window embeddings into one utterance embedding.
            embeddings = torch.mean(embeddings, dim=0, keepdim=True)
        return embeddings

    def get_criterion(self, c: Coqpit, num_classes=None):
        """Instantiate the training criterion selected by ``c.loss``.

        Args:
            c (Coqpit): config with ``loss`` and (for softmaxproto)
                ``model_params["proj_dim"]``.
            num_classes (int, optional): number of classes for the
                softmaxproto classifier head.

        Raises:
            Exception: if ``c.loss`` names an unsupported loss.
        """
        if c.loss == "ge2e":
            criterion = GE2ELoss(loss_method="softmax")
        elif c.loss == "angleproto":
            criterion = AngleProtoLoss()
        elif c.loss == "softmaxproto":
            criterion = SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_classes)
        else:
            # Fixed garbled wording of the original error message.
            raise Exception("The %s is not a supported loss" % c.loss)
        return criterion

    def load_checkpoint(
        self,
        config: Coqpit,
        checkpoint_path: str,
        eval: bool = False,  # noqa: A002 -- keeps the public signature; shadows builtin `eval`
        use_cuda: bool = False,
        criterion=None,
        cache=False,
    ):
        """Restore model (and optionally criterion) weights from a checkpoint.

        Args:
            config (Coqpit): model config; used for partial init and to
                rebuild the criterion at inference time.
            checkpoint_path (str): path/URL understood by ``load_fsspec``.
            eval (bool): if True, fail hard on mismatched weights and put
                the model in eval mode.
            use_cuda (bool): move model (and criterion) to CUDA.
            criterion: existing criterion to restore state into, if any.
            cache (bool): forwarded to ``load_fsspec``.

        Returns:
            ``(criterion, step)`` in training mode, else just ``criterion``.
        """
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
        try:
            self.load_state_dict(state["model"])
            print(" > Model fully restored. ")
        except (KeyError, RuntimeError) as error:
            # In eval mode a partial restore is not acceptable.
            if eval:
                raise error

            # Fall back to partial initialization: copy only matching layers.
            print(" > Partial model initialization.")
            model_dict = self.state_dict()
            # BUGFIX: the original passed undefined name `c` here (NameError
            # on this path); the in-scope config parameter is `config`.
            model_dict = set_init_dict(model_dict, state["model"], config)
            self.load_state_dict(model_dict)
            del model_dict

        # Restore the criterion state when resuming training.
        if criterion is not None and "criterion" in state:
            try:
                criterion.load_state_dict(state["criterion"])
            except (KeyError, RuntimeError) as error:
                print(" > Criterion load ignored because of:", error)

        # Instantiate and load the classifier criterion at inference time
        # (needed when the checkpoint was trained with a classification head).
        if (
            eval
            and criterion is None
            and "criterion" in state
            and getattr(config, "map_classid_to_classname", None) is not None
        ):
            criterion = self.get_criterion(config, len(config.map_classid_to_classname))
            criterion.load_state_dict(state["criterion"])

        if use_cuda:
            self.cuda()
            if criterion is not None:
                criterion = criterion.cuda()

        if eval:
            self.eval()
            assert not self.training

        if not eval:
            return criterion, state["step"]
        return criterion
| |
|