|
|
import math |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch import Tensor |
|
|
|
|
|
|
|
|
class Transpose(nn.Identity):
    """Swap the time and feature axes: (N, T, D) -> (N, D, T)."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Override nn.Identity's pass-through forward with a dim-1/dim-2 swap.
        return torch.transpose(input, 1, 2)
|
|
|
|
|
|
|
|
class AdaptiveLayerNorm(nn.Module):
    r"""Adaptive Layer Normalization.

    Normalizes ``input`` with the wrapped ``norm`` module, then modulates the
    result with a per-sample affine transform (scale and shift) predicted from
    a conditioning ``embedding`` via a linear projection.

    Args:
        d_model: feature dimension of the input (and of the embedding fed to
            the projection layer).
        norm: normalization module to wrap (e.g. ``nn.LayerNorm(d_model)``);
            it must expose an ``eps`` attribute, which is mirrored here.
    """

    def __init__(self, d_model, norm) -> None:
        super().__init__()
        # Projects the conditioning embedding to a concatenated (weight, bias)
        # pair, each of size d_model.
        self.project_layer = nn.Linear(d_model, 2 * d_model)
        self.norm = norm
        self.d_model = d_model
        self.eps = self.norm.eps

    def _modulate(self, input: Tensor, embedding: Tensor) -> Tensor:
        """Apply embedding-conditioned scale and shift around ``self.norm``."""
        weight, bias = torch.split(
            self.project_layer(embedding),
            split_size_or_sections=self.d_model,
            dim=-1,
        )
        return weight * self.norm(input) + bias

    def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
        """Return the adaptively normalized input.

        Accepts either ``forward(input, embedding)`` or a single
        ``(input, embedding)`` tuple as ``input``; in the tuple case the
        result is returned as ``(output, embedding)`` so the module can be
        chained inside ``nn.Sequential``-style pipelines.
        """
        if isinstance(input, tuple):
            input, embedding = input
            return (self._modulate(input, embedding), embedding)
        # NOTE: embedding is annotated with a None default but must actually be
        # provided on this path.
        return self._modulate(input, embedding)
|
|
|
|
|
|
|
|
class TokenEmbedding(nn.Module):
    """Token-id lookup embedding with optional dropout applied to the vectors."""

    def __init__(
        self,
        dim_model: int,
        vocab_size: int,
        dropout: float = 0.0,
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.dim_model = dim_model

        self.dropout = torch.nn.Dropout(p=dropout)
        self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)

    @property
    def weight(self) -> torch.Tensor:
        """The full (vocab_size, dim_model) embedding matrix."""
        return self.word_embeddings.weight

    def embedding(self, index: int) -> torch.Tensor:
        """Return the embedding row for a single token id, keeping a leading dim."""
        return self.word_embeddings.weight[index : index + 1]

    def forward(self, x: torch.Tensor):
        """Look up embeddings for token ids ``x`` and apply dropout."""
        embedded = self.word_embeddings(x)
        return self.dropout(embedded)
|
|
|
|
|
|
|
|
class SinePositionalEmbedding(nn.Module):
    """Sinusoidal positional embedding with optional input scaling and a
    (optionally learnable) gain ``alpha`` on the positional term.

    The sin/cos table is cached in ``self.pe`` and grown lazily whenever a
    longer sequence is seen; 4000 positions are pre-built at construction.
    """

    def __init__(
        self,
        dim_model: int,
        dropout: float = 0.0,
        scale: bool = False,
        alpha: bool = False,
    ):
        super().__init__()
        self.dim_model = dim_model
        # Scale inputs by sqrt(d) (Transformer convention) only when requested.
        self.x_scale = math.sqrt(dim_model) if scale else 1.0
        # Gain on the positional term; trainable only when alpha=True.
        self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
        self.dropout = torch.nn.Dropout(p=dropout)

        self.reverse = False
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, 4000))

    def extend_pe(self, x):
        """Reset the positional encodings."""
        needed = x.size(1)
        if self.pe is not None and self.pe.size(1) >= needed:
            # Table is long enough; just keep dtype/device in sync with x.
            if self.pe.dtype != x.dtype or self.pe.device != x.device:
                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
            return
        # Rebuild the table to exactly `needed` positions.
        if self.reverse:
            position = torch.arange(needed - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
        else:
            position = torch.arange(0, needed, dtype=torch.float32).unsqueeze(1)
        inv_freq = torch.exp(
            torch.arange(0, self.dim_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.dim_model)
        )
        table = torch.zeros(needed, self.dim_model)
        table[:, 0::2] = torch.sin(position * inv_freq)
        table[:, 1::2] = torch.cos(position * inv_freq)
        self.pe = table.unsqueeze(0).to(device=x.device, dtype=x.dtype).detach()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add (scaled) positional encodings to ``x`` and apply dropout."""
        self.extend_pe(x)
        out = x.unsqueeze(-1) if x.ndim == 2 else x
        out = out * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
        return self.dropout(out)
|
|
|
|
|
|
|
|
class PreNet(nn.Module):
    """PreNet for NAR model: a 3-layer MLP (d -> 256 -> 256 -> d) with ReLU
    activations and dropout 0.25 after the two hidden layers."""

    def __init__(self, nar_d_model=1024) -> None:
        super().__init__()

        hidden_dim = 256
        drop_p = 0.25
        layers = [
            nn.Linear(nar_d_model, hidden_dim),
            nn.ReLU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden_dim, nar_d_model),
        ]
        self.nar_audio_prenet = nn.Sequential(*layers)

    def forward(self, input):
        """Project ``input`` through the prenet MLP (shape preserved)."""
        return self.nar_audio_prenet(input)
|
|
|
|
|
|
|
|
|
|
|
def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering

    Args:
        logits: logits distribution shape (batch size, vocabulary size)
        if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
        if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        Make sure we keep at least min_tokens_to_keep per batch example in the output
    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317

    Note: ``logits`` is modified in place; the same tensor is returned.
    """
    if top_k > 0:
        # Clamp k into [min_tokens_to_keep, vocab_size].
        k = min(max(top_k, min_tokens_to_keep), logits.size(-1))
        # Anything strictly below the k-th best logit is filtered out.
        kth_best = torch.topk(logits, k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value

    if top_p < 1.0:
        desc_logits, desc_indices = torch.sort(logits, descending=True)
        cum_probs = F.softmax(desc_logits, dim=-1).cumsum(dim=-1)

        # Drop tokens once cumulative probability exceeds the nucleus threshold.
        drop_mask = cum_probs > top_p
        if min_tokens_to_keep > 1:
            # Always retain the first min_tokens_to_keep entries of the sorted order.
            drop_mask[..., :min_tokens_to_keep] = 0
        # Shift right so the first token ABOVE the threshold is also kept.
        drop_mask[..., 1:] = drop_mask[..., :-1].clone()
        drop_mask[..., 0] = 0

        # Map the mask from sorted order back to the original vocabulary order.
        remove_mask = drop_mask.scatter(1, desc_indices, drop_mask)
        logits[remove_mask] = filter_value
    return logits
|
|
|
|
|
|
|
|
def top_k_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
    """Sample one token id per row from ``logits`` after temperature scaling
    and top-k / top-p filtering.

    Args:
        logits: (batch, vocab) unnormalized scores.
        top_k: keep only the k highest-probability tokens (0 disables).
        top_p: nucleus-sampling threshold (1.0 disables).
        temperature: softmax temperature; != 1.0 rescales the logits.

    Returns:
        (batch, 1) tensor of sampled token indices.
    """
    if temperature != 1.0:
        logits = logits / temperature

    filtered = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

    probs = F.softmax(filtered, dim=-1)
    return torch.multinomial(probs, num_samples=1)
|
|
|