import math
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class Transpose(nn.Identity):
    """(N, T, D) -> (N, D, T)"""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``input`` with dimensions 1 and 2 swapped."""
        return input.transpose(1, 2)


class AdaptiveLayerNorm(nn.Module):
    r"""Adaptive Layer Normalization.

    Wraps a normalization module and applies an affine transform whose
    scale and shift are predicted from a conditioning ``embedding`` by a
    single linear layer: ``weight * norm(input) + bias``.

    Args:
        d_model: Size of the last dimension; the linear layer maps
            ``d_model -> 2 * d_model`` (scale and shift halves).
        norm: Normalization module applied to the input. It must expose an
            ``eps`` attribute, which is mirrored on this module.
    """

    def __init__(self, d_model, norm) -> None:
        super().__init__()
        self.project_layer = nn.Linear(d_model, 2 * d_model)
        self.norm = norm
        self.d_model = d_model
        self.eps = self.norm.eps

    def forward(
        self,
        input: Union[Tensor, Tuple[Tensor, Tensor]],
        embedding: Optional[Tensor] = None,
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Normalize ``input`` and modulate it with values derived from ``embedding``.

        Args:
            input: Either the tensor to normalize, or an
                ``(input, embedding)`` tuple. In the tuple form the
                ``embedding`` keyword argument is ignored.
            embedding: Conditioning tensor fed to the projection layer.

        Returns:
            The modulated tensor. When called with a tuple, returns
            ``(modulated, embedding)`` so the embedding can be passed on to
            the next layer.
        """
        packed = isinstance(input, tuple)
        if packed:
            input, embedding = input

        # First half of the projection is the scale, second half the shift.
        weight, bias = torch.split(
            self.project_layer(embedding),
            split_size_or_sections=self.d_model,
            dim=-1,
        )
        output = weight * self.norm(input) + bias
        return (output, embedding) if packed else output


class TokenEmbedding(nn.Module):
    """Embedding lookup table followed by dropout.

    Args:
        dim_model: Embedding dimension.
        vocab_size: Number of rows in the embedding table.
        dropout: Dropout probability applied to the looked-up embeddings.
    """

    def __init__(
        self,
        dim_model: int,
        vocab_size: int,
        dropout: float = 0.0,
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.dim_model = dim_model

        self.dropout = torch.nn.Dropout(p=dropout)
        self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)

    @property
    def weight(self) -> torch.Tensor:
        """The underlying embedding matrix."""
        return self.word_embeddings.weight

    def embedding(self, index: int) -> torch.Tensor:
        """Return row ``index`` of the embedding matrix as a ``(1, dim_model)`` slice."""
        return self.word_embeddings.weight[index : index + 1]

    def forward(self, x: torch.Tensor):
        """Look up embeddings for the indices in ``x`` and apply dropout."""
        X = self.word_embeddings(x)
        X = self.dropout(X)
        return X


class SinePositionalEmbedding(nn.Module):
    """Sinusoidal positional encoding added to the input.

    Args:
        dim_model: Size of the encoding (last) dimension.
        dropout: Dropout probability applied to the output.
        scale: If true, the input is multiplied by ``sqrt(dim_model)``
            before the encoding is added.
        alpha: If true, the scalar weight on the positional encoding is
            trainable; otherwise it stays fixed at 1.
    """

    def __init__(
        self,
        dim_model: int,
        dropout: float = 0.0,
        scale: bool = False,
        alpha: bool = False,
    ):
        super().__init__()
        self.dim_model = dim_model
        self.x_scale = math.sqrt(dim_model) if scale else 1.0
        self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
        self.dropout = torch.nn.Dropout(p=dropout)

        self.reverse = False
        # Cached encoding table of shape (1, length, dim_model). It is a
        # plain attribute, so extend_pe() handles dtype/device moves itself.
        self.pe = None
        # Pre-compute encodings for 4000 positions.
        self.extend_pe(torch.tensor(0.0).expand(1, 4000))

    def extend_pe(self, x):
        """Make sure the cached table covers ``x.size(1)`` positions.

        If the cache is already long enough it is only moved to ``x``'s
        dtype/device when those differ; otherwise it is rebuilt.
        """
        if self.pe is not None:
            if self.pe.size(1) >= x.size(1):
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        pe = torch.zeros(x.size(1), self.dim_model)
        if self.reverse:
            position = torch.arange(
                x.size(1) - 1, -1, -1.0, dtype=torch.float32
            ).unsqueeze(1)
        else:
            position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.dim_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.dim_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # For an odd dim_model there is one fewer cosine column than sine
        # column, so trim div_term to match (a no-op for even dim_model).
        pe[:, 1::2] = torch.cos(position * div_term[: self.dim_model // 2])
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype).detach()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(x * x_scale + alpha * pe[:, :x.size(1)])``.

        A 2-D ``x`` gets a trailing singleton dimension before the
        encoding is added.
        """
        self.extend_pe(x)
        output = x.unsqueeze(-1) if x.ndim == 2 else x
        output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
        return self.dropout(output)


class PreNet(nn.Module):
    """PreNet for NAR model"""

    def __init__(self, nar_d_model=1024) -> None:
        super().__init__()
        # NOTE(review): disabled convolutional text prenet, left commented out
        # as found in the original source.
        # self.nar_text_prenet = nn.Sequential(
        #     Transpose(),
        #     nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"),
        #     nn.BatchNorm1d(nar_d_model),
        #     nn.ReLU(),
        #     nn.Dropout(0.5),
        #     nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"),
        #     nn.BatchNorm1d(nar_d_model),
        #     nn.ReLU(),
        #     nn.Dropout(0.5),
        #     nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"),
        #     nn.BatchNorm1d(nar_d_model),
        #     nn.ReLU(),
        #     nn.Dropout(0.5),
        #     Transpose(),
        #     nn.Linear(nar_d_model, nar_d_model),
        # )
        # Bottleneck MLP: nar_d_model -> 256 -> 256 -> nar_d_model.
        self.nar_audio_prenet = nn.Sequential(
            nn.Linear(nar_d_model, 256),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(256, nar_d_model),
        )

    def forward(self, input):
        """Apply the audio prenet MLP to ``input``."""
        return self.nar_audio_prenet(input)


# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
def top_k_top_p_filtering(
    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    NOTE: ``logits`` is modified in place and also returned; pass a copy if
    the original values are still needed.

    Args:
        logits: Logits with the vocabulary on the last dimension, typically
            shaped (batch size, vocabulary size).
        top_k: If > 0, keep only the ``top_k`` tokens with the highest
            logits (top-k filtering).
        top_p: If < 1.0, keep the top tokens with cumulative probability
            >= ``top_p`` (nucleus filtering). Nucleus filtering is described
            in Holtzman et al. (http://arxiv.org/abs/1904.09751).
        filter_value: Value written over the removed logits.
        min_tokens_to_keep: Make sure at least this many tokens per example
            survive filtering.

    Returns:
        The filtered ``logits`` tensor (the same object that was passed in).

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove all tokens with a logit below the k-th largest one.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Protect the leading entries. Combined with the right-shift
            # below, at least min_tokens_to_keep tokens are always kept.
            sorted_indices_to_remove[..., :min_tokens_to_keep] = False
        # Shift the mask right so the first token above the threshold is kept too.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Scatter the mask back to the original ordering. The sort ran over
        # the last dimension, so the scatter must use that dimension as well.
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value
    return logits


def top_k_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
    """Sample one token per row after temperature scaling and top-k/top-p filtering.

    Args:
        logits: Logits with the vocabulary on the last dimension. The
            caller's tensor is left unmodified.
        top_k: Number of highest-probability vocabulary tokens to keep for
            top-k filtering. Defaults to 10.
        top_p: Cumulative probability of the highest-probability vocabulary
            tokens to keep for nucleus sampling. Must be between 0 and 1.
            Defaults to 1.
        temperature: Value used to modulate the next-token probabilities.
            Must be strictly positive (higher temperature => more likely to
            sample low-probability tokens). Defaults to 1.0.

    Returns:
        Tensor of sampled token indices, with one sample per distribution.
    """
    if temperature != 1.0:
        # Division already yields a fresh tensor.
        logits = logits / temperature
    else:
        # The filter below writes in place; work on a copy so the caller's
        # logits are not overwritten with filter values.
        logits = logits.clone()
    # Top-p/top-k filtering
    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
    # Sample
    token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
    return token