import gc
import random
from typing import Dict, Iterator, List, Tuple, Union

import numpy as np
import psutil
import torch
import torch.nn as nn
import torch.nn.functional as F

from modules.embedding import SinePositionalEmbedding, TokenEmbedding
from modules.transformer import (
    AdaptiveLayerNorm,
    LayerNorm,
    TransformerDecoderLayer,
    TransformerEncoder,
    TransformerEncoderLayer,
)

from .macros import NUM_AUDIO_TOKENS, NUM_TEXT_TOKENS


def get_memory_usage():
    """Return the resident set size (RSS) of the current process in MB."""
    process = psutil.Process()
    memory_info = process.memory_info()

    memory_used = memory_info.rss
    memory_used_mb = memory_used / (1024 * 1024)

    return memory_used_mb


class Transpose(nn.Identity):
    """(N, T, D) -> (N, D, T)"""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input.transpose(1, 2)


class VALLF(nn.Module):
    """It implements https://arxiv.org/abs/2301.02111
    "Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers"
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        num_layers: int,
        norm_first: bool = True,
        add_prenet: bool = False,
        decoder_cls: Union[
            nn.TransformerDecoder, nn.TransformerEncoder
        ] = nn.TransformerDecoder,
        decoder_layer_cls: Union[
            TransformerDecoderLayer, TransformerEncoderLayer
        ] = TransformerDecoderLayer,
        prefix_mode: int = 0,
        share_embedding: bool = True,
        nar_scale_factor: float = 1.0,
        prepend_bos: bool = True,
        num_quantizers: int = 8,
    ):
        """
        Args:
          d_model:
            The number of expected features in the input (required).
          nhead:
            The number of heads in the multiheadattention models (required).
          num_layers:
            The number of sub-decoder-layers in the decoder (required).
        """
        super().__init__()
        nar_d_model = int(d_model * nar_scale_factor)

        self.ar_text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS)
        self.nar_text_embedding = TokenEmbedding(nar_d_model, NUM_TEXT_TOKENS)

        # AR audio embedding: one extra id for EOS, plus one more for BOS
        # when prepend_bos is enabled.
        self.ar_audio_prepend_bos = prepend_bos
        self.ar_audio_embedding = TokenEmbedding(
            d_model, NUM_AUDIO_TOKENS + 1 + int(prepend_bos)
        )

        if add_prenet:
            self.ar_text_prenet = nn.Sequential(
                Transpose(),
                nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                Transpose(),
                nn.Linear(d_model, d_model),
            )

            self.ar_audio_prenet = nn.Sequential(
                nn.Linear(d_model, 256),
                nn.ReLU(),
                nn.Dropout(0.25),
                nn.Linear(256, 256),
                nn.ReLU(),
                nn.Dropout(0.25),
                nn.Linear(256, d_model),
            )
        else:
            self.ar_text_prenet = nn.Identity()
            self.ar_audio_prenet = nn.Identity()

        self.ar_text_position = SinePositionalEmbedding(
            d_model,
            dropout=0.1,
            scale=False,
            alpha=True,
        )
        self.ar_audio_position = SinePositionalEmbedding(
            d_model,
            dropout=0.1,
            scale=False,
            alpha=True,
        )

        self.ar_decoder = decoder_cls(
            decoder_layer_cls(
                d_model,
                nhead,
                dim_feedforward=d_model * 4,
                dropout=0.1,
                batch_first=True,
                norm_first=norm_first,
            ),
            num_layers=num_layers,
            norm=LayerNorm(d_model) if norm_first else None,
        )
        self.ar_predict_layer = nn.Linear(
            d_model, NUM_AUDIO_TOKENS + 1, bias=False
        )

        self.rng = random.Random(0)
        self.num_heads = nhead
        self.prefix_mode = prefix_mode
        self.num_quantizers = num_quantizers

        assert num_quantizers >= 1
        if num_quantizers > 1:
            self.nar_audio_embeddings = nn.ModuleList(
                [TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS + 1)]
                + [
                    TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS)
                    for _ in range(num_quantizers - 1)
                ]
            )

            if add_prenet:
                self.nar_text_prenet = nn.Sequential(
                    Transpose(),
                    nn.Conv1d(
                        nar_d_model, nar_d_model, kernel_size=5, padding="same"
                    ),
                    nn.BatchNorm1d(nar_d_model),
                    nn.ReLU(),
                    nn.Dropout(0.5),
                    nn.Conv1d(
                        nar_d_model, nar_d_model, kernel_size=5, padding="same"
                    ),
                    nn.BatchNorm1d(nar_d_model),
                    nn.ReLU(),
                    nn.Dropout(0.5),
                    nn.Conv1d(
                        nar_d_model, nar_d_model, kernel_size=5, padding="same"
                    ),
                    nn.BatchNorm1d(nar_d_model),
                    nn.ReLU(),
                    nn.Dropout(0.5),
                    Transpose(),
                    nn.Linear(nar_d_model, nar_d_model),
                )
                self.nar_audio_prenet = nn.Sequential(
                    nn.Linear(nar_d_model, 256),
                    nn.ReLU(),
                    nn.Dropout(0.25),
                    nn.Linear(256, 256),
                    nn.ReLU(),
                    nn.Dropout(0.25),
                    nn.Linear(256, nar_d_model),
                )
            else:
                self.nar_text_prenet = nn.Identity()
                self.nar_audio_prenet = nn.Identity()

            self.nar_text_position = SinePositionalEmbedding(
                nar_d_model,
                dropout=0.0,
                scale=False,
                alpha=False,
            )
            self.nar_audio_position = SinePositionalEmbedding(
                nar_d_model,
                dropout=0.1,
                scale=False,
                alpha=False,
            )

            self.nar_decoder = decoder_cls(
                decoder_layer_cls(
                    nar_d_model,
                    int(nhead * nar_scale_factor),
                    dim_feedforward=nar_d_model * 4,
                    dropout=0.1,
                    batch_first=True,
                    norm_first=norm_first,
                    adaptive_layer_norm=True,
                ),
                num_layers=int(num_layers * nar_scale_factor),
                norm=AdaptiveLayerNorm(
                    nar_d_model, norm=nn.LayerNorm(nar_d_model)
                )
                if norm_first
                else None,
            )
            self.nar_predict_layers = nn.ModuleList(
                [
                    nn.Linear(nar_d_model, NUM_AUDIO_TOKENS, bias=False)
                    for _ in range(num_quantizers - 1)
                ]
            )
            self.nar_stage_embeddings = nn.ModuleList(
                [
                    TokenEmbedding(nar_d_model, 1)
                    for _ in range(num_quantizers - 1)
                ]
            )

            if share_embedding:
                # Share weights between the NAR output projections and the
                # NAR acoustic embeddings: the j-th prediction layer reuses
                # the weight of nar_audio_embeddings[j + 2].
                for j in range(0, num_quantizers - 2):
                    self.nar_predict_layers[
                        j
                    ].weight = self.nar_audio_embeddings[j + 2].weight

    def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]:
        assert stage > 0
        if stage == 1:
            for name, param in self.named_parameters():
                if name.startswith("ar_"):
                    print(f" AR parameter: {name}")
                    yield param

        if stage == 2:
            for name, param in self.named_parameters():
                if name.startswith("nar_"):
                    print(f"NAR parameter: {name}")
                    yield param

    def stage_named_parameters(
        self, stage: int = 1
    ) -> Iterator[Tuple[str, nn.Parameter]]:
        assert stage > 0
        if stage == 1:
            for pair in self.named_parameters():
                if pair[0].startswith("ar_"):
                    yield pair

        if stage == 2:
            for pair in self.named_parameters():
                if pair[0].startswith("nar_"):
                    yield pair

    def pad_y_eos(self, y, y_mask_int, eos_id):
        # Append one frame and write eos_id there (and at padded positions).
        targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
            y_mask_int, (0, 1), value=1
        )
        # Returns (inputs, targets) for the AR decoder.
        if self.ar_audio_prepend_bos:
            return (
                F.pad(targets[:, :-1], (1, 0), value=NUM_AUDIO_TOKENS + 1),
                targets,
            )

        return targets[:, :-1], targets[:, 1:]
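
    # Illustrative note (added; not part of the original code): for a single
    # un-padded sequence y = [a, b, c] with eos_id = NUM_AUDIO_TOKENS,
    # `targets` above becomes [a, b, c, EOS]. With prepend_bos=True the method
    # returns inputs [BOS, a, b, c] and targets [a, b, c, EOS]; otherwise it
    # returns inputs [a, b, c] and targets [b, c, EOS].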

    def _prepare_prompts(
        self, y, y_lens, codes, nar_stage, y_prompts_codes, prefix_mode
    ):
        if prefix_mode == 0:
            # No acoustic prompt: sum the embeddings of all codebooks below
            # the current NAR stage.
            prefix_len = 0
            y_emb = self.nar_audio_embeddings[0](y)
            for j in range(1, nar_stage):
                y_emb = y_emb + self.nar_audio_embeddings[j](codes[..., j])
        elif prefix_mode == 1:
            # Use a random-length prefix of the same utterance as the prompt,
            # capped at 225 frames (3 s at 75 frames per second).
            int_low = (0.25 * y_lens.min()).type(torch.int64).item()
            prefix_len = torch.randint(0, int_low * 2, size=()).item()
            prefix_len = min(prefix_len, 225)

            y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len])
            y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:])
            for j in range(1, self.num_quantizers):
                y_prompts += self.nar_audio_embeddings[j](
                    codes[:, :prefix_len, j]
                )
                if j < nar_stage:
                    y_emb += self.nar_audio_embeddings[j](
                        codes[:, prefix_len:, j]
                    )
            y_emb = torch.concat([y_prompts, y_emb], dim=1)
        elif prefix_mode in [2, 4]:
            if prefix_mode == 2:
                # Pick a random segment of each utterance as the acoustic
                # prompt and overwrite that segment's codes for the current
                # stage with NUM_AUDIO_TOKENS.
                prefix_len = min(225, int(0.25 * y_lens.min().item()))

                y_prompts_codes = []
                for b in range(codes.shape[0]):
                    start = self.rng.randint(0, y_lens[b].item() - prefix_len)
                    y_prompts_codes.append(
                        torch.clone(codes[b, start : start + prefix_len])
                    )
                    codes[
                        b, start : start + prefix_len, nar_stage
                    ] = NUM_AUDIO_TOKENS
                y_prompts_codes = torch.stack(y_prompts_codes, dim=0)
            else:
                # prefix_mode == 4: the acoustic prompt codes are supplied by
                # the caller via y_prompts_codes.
                prefix_len = y_prompts_codes.shape[1]

            y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0])
            y_emb = self.nar_audio_embeddings[0](y)
            for j in range(1, self.num_quantizers):
                y_prompts += self.nar_audio_embeddings[j](
                    y_prompts_codes[..., j]
                )
                if j < nar_stage:
                    y_emb += self.nar_audio_embeddings[j](codes[..., j])
            y_emb = torch.concat([y_prompts, y_emb], dim=1)
        else:
            raise ValueError(f"Unsupported prefix_mode: {prefix_mode}")

        return y_emb, prefix_len
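
    # Note (added): for prefix modes 1, 2 and 4 the returned y_emb has the
    # prefix_len acoustic-prompt frames concatenated in front of the target
    # frames, so consumers are expected to skip the first prefix_len audio
    # positions of the decoder output (compare the slicing at
    # `text_len + prefix_len` in VALLE.inference below); prefix mode 0
    # returns prefix_len = 0.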

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: torch.Tensor,
        y_lens: torch.Tensor,
        reduction: str = "sum",
        train_stage: int = 0,
        **kwargs,
    ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
        raise NotImplementedError

    def inference(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: torch.Tensor,
        enroll_x_lens: Union[torch.Tensor, None] = None,
        top_k: int = -100,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        raise NotImplementedError

    def visualize(
        self,
        predicts: Tuple[torch.Tensor],
        batch: Dict[str, Union[List, torch.Tensor]],
        output_dir: str,
        limit: int = 4,
    ) -> None:
        raise NotImplementedError


class VALLE(VALLF):
    """It implements https://arxiv.org/abs/2301.02111
    "Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers"
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        num_layers: int,
        norm_first: bool = True,
        add_prenet: bool = False,
        prefix_mode: int = 0,
        share_embedding: bool = True,
        nar_scale_factor: float = 1.0,
        **kwargs,
    ):
        """
        Args:
          d_model:
            The number of expected features in the input (required).
          nhead:
            The number of heads in the multiheadattention models (required).
          num_layers:
            The number of sub-decoder-layers in the decoder (required).
        """
        super(VALLE, self).__init__(
            d_model,
            nhead,
            num_layers,
            norm_first=norm_first,
            add_prenet=add_prenet,
            decoder_cls=TransformerEncoder,
            decoder_layer_cls=TransformerEncoderLayer,
            prefix_mode=prefix_mode,
            share_embedding=share_embedding,
            nar_scale_factor=nar_scale_factor,
            **kwargs,
        )
        # Language ids used to condition the AR and NAR text embeddings.
        self.language_ID = {
            "en": 0,
            "zh": 1,
            "ja": 2,
        }
        self.ar_language_embedding = TokenEmbedding(
            d_model, len(self.language_ID)
        )
        self.nar_language_embedding = TokenEmbedding(
            d_model, len(self.language_ID)
        )

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: torch.Tensor,
        y_lens: torch.Tensor,
        reduction: str = "sum",
        train_stage: int = 0,
        **kwargs,
    ):
        raise NotImplementedError

    def inference(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: torch.Tensor,
        enroll_x_lens: torch.Tensor,
        top_k: int = -100,
        temperature: float = 1.0,
        prompt_language: Union[str, None] = None,
        text_language: Union[str, List[str], None] = None,
    ) -> torch.Tensor:
        """
        Args:
          x:
            A 2-D tensor of shape (1, S).
          x_lens:
            A 1-D tensor of shape (1,). It contains the number of tokens in `x`
            before padding.
          y:
            A 3-D tensor of shape (1, T, 8).
          enroll_x_lens:
            A 1-D tensor with the number of text tokens belonging to the
            enrolled (prompt) utterance.
          top_k: (`optional`) int
            The number of highest-probability tokens to keep for top-k
            filtering; values <= 0 disable the filter. Defaults to -100.
          temperature: (`optional`) float
            The value used to modulate the next-token probabilities. Must be
            strictly positive. Defaults to 1.0.
          prompt_language:
            Language tag of the enrolled prompt text ('en', 'zh' or 'ja').
          text_language:
            Language tag (or list of tags) of the text to synthesize.
        Returns:
          Return the predicted audio code matrix.
        """
        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.ndim == 3, y.shape
        assert y.shape[0] == 1, y.shape

        assert torch.all(x_lens > 0)

        text = x
        x = self.ar_text_embedding(text)
        # Add language embeddings: the enrolled (prompt) part of the text uses
        # the prompt language, the rest uses the target text language.
        prompt_language_id = torch.LongTensor(
            np.array([self.language_ID[prompt_language]])
        ).to(x.device)
        if isinstance(text_language, str):
            text_language_id = torch.LongTensor(
                np.array([self.language_ID[text_language]])
            ).to(x.device)
        elif isinstance(text_language, list):
            text_language_id = torch.LongTensor(
                np.array([self.language_ID[tl] for tl in text_language])
            ).to(x.device)
        else:
            raise ValueError(f"Unsupported text_language: {text_language!r}")
        x[:, :enroll_x_lens, :] += self.ar_language_embedding(prompt_language_id)
        x[:, enroll_x_lens:, :] += self.ar_language_embedding(text_language_id)
        x = self.ar_text_prenet(x)
        x = self.ar_text_position(x)

        text_len = x_lens.max()
        prompts = y
        prefix_len = y.shape[1]

        # AR decoder: generate the first-codebook tokens autoregressively,
        # starting from the acoustic prompt.
        y = prompts[..., 0]
        if self.ar_audio_prepend_bos:
            y = F.pad(y, (1, 0), value=NUM_AUDIO_TOKENS + 1)

        x_len = x_lens.max()
        x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)

        kv_cache = None
        use_kv_caching = True
        while True:
            y_emb = self.ar_audio_embedding(y)
            y_emb = self.ar_audio_prenet(y_emb)
            y_pos = self.ar_audio_position(y_emb)
            xy_pos = torch.concat([x, y_pos], dim=1)

            y_len = y.shape[1]
            # Attention mask layout (rows = queries, cols = keys): text
            # positions attend only to text; audio positions attend to all
            # text plus previous audio (causal).
            x_attn_mask_pad = F.pad(
                x_attn_mask,
                (0, y_len),
                value=True,
            )
            y_attn_mask = F.pad(
                torch.triu(
                    torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1
                ),
                (x_len, 0),
                value=False,
            )
            xy_attn_mask = torch.concat(
                [x_attn_mask_pad, y_attn_mask], dim=0
            ).to(y.device)

            if use_kv_caching and kv_cache is not None:
                # With a warm KV cache only the newest position is fed in.
                xy_pos = xy_pos[:, [-1]]
            else:
                # First step: run the whole text + audio prefix through the
                # decoder to populate the cache.
                pass

            xy_dec, kv_cache = self.ar_decoder.infer(
                xy_pos,
                mask=xy_attn_mask,
                past_kv=kv_cache,
                use_cache=use_kv_caching,
            )

            logits = self.ar_predict_layer(xy_dec[:, -1])
            samples = topk_sampling(
                logits, top_k=top_k, top_p=1, temperature=temperature
            )

            # Stop on EOS, or once the generated audio grows far beyond the
            # text length (16 frames per text token).
            if (
                torch.argmax(logits, dim=-1)[0] == NUM_AUDIO_TOKENS
                or samples[0, 0] == NUM_AUDIO_TOKENS
                or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 16
            ):
                if prompts.shape[1] == y.shape[1]:
                    raise RuntimeError(
                        "A well-trained model should not reach here."
                    )

                print(f"VALL-E EOS [{prompts.shape[1]} -> {y.shape[1]}]")

                memory_used = get_memory_usage()
                print(f"Current memory used: {memory_used:.2f} MB")
                break

            # Hard cap on the number of generated frames.
            if y.shape[1] > 2250:
                print(f"VALL-E EOS [{prompts.shape[1]} -> {y.shape[1]}]")
                break

            y = torch.concat([y, samples], dim=1)

        codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]]
        if self.num_quantizers == 1:
            return torch.stack(codes, dim=-1)

        # Non-AR (NAR) decoders: predict the remaining codebooks, one stage
        # per quantizer.
        y_emb = self.nar_audio_embeddings[0](
            y[:, int(self.ar_audio_prepend_bos) :]
        )

        if self.prefix_mode in [2, 4]:
            enrolled_len = enroll_x_lens.max().item()
            # Remove the enrolled text span, keeping the first token and
            # everything from position enrolled_len - 1 onwards.
            text = torch.concat(
                [
                    text[:, :1],
                    text[:, enrolled_len - 1 :],
                ],
                dim=1,
            )
            text_len = text_len - (enrolled_len - 2)
            assert text.shape[0] == 1

        x = self.nar_text_embedding(text)
        # Add language embeddings, as in the AR stage.
        prompt_language_id = torch.LongTensor(
            np.array([self.language_ID[prompt_language]])
        ).to(x.device)
        if isinstance(text_language, str):
            text_language_id = torch.LongTensor(
                np.array([self.language_ID[text_language]])
            ).to(x.device)
        elif isinstance(text_language, list):
            text_language_id = torch.LongTensor(
                np.array([self.language_ID[tl] for tl in text_language])
            ).to(x.device)
        x[:, :enroll_x_lens, :] += self.nar_language_embedding(prompt_language_id)
        x[:, enroll_x_lens:, :] += self.nar_language_embedding(text_language_id)
        x = self.nar_text_prenet(x)
        x = self.nar_text_position(x)

        if self.prefix_mode == 0:
            for i, (predict_layer, embedding_layer) in enumerate(
                zip(
                    self.nar_predict_layers,
                    self.nar_audio_embeddings[1:],
                )
            ):
                y_pos = self.nar_audio_prenet(y_emb)
                y_pos = self.nar_audio_position(y_pos)
                xy_pos = torch.concat([x, y_pos], dim=1)

                xy_dec, _ = self.nar_decoder(
                    (xy_pos, self.nar_stage_embeddings[i].weight)
                )
                logits = predict_layer(xy_dec[:, text_len + prefix_len :])

                samples = torch.argmax(logits, dim=-1)
                codes.append(samples)

                if i < self.num_quantizers - 2:
                    y_emb[:, :prefix_len] += embedding_layer(
                        prompts[..., i + 1]
                    )
                    y_emb[:, prefix_len:] += embedding_layer(samples)
        else:
            for j in range(1, self.num_quantizers):
                y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](
                    prompts[..., j]
                )

            for i, (predict_layer, embedding_layer) in enumerate(
                zip(
                    self.nar_predict_layers,
                    self.nar_audio_embeddings[1:],
                )
            ):
                y_pos = self.nar_audio_prenet(y_emb)
                y_pos = self.nar_audio_position(y_pos)
                xy_pos = torch.concat([x, y_pos], dim=1)

                xy_dec, _ = self.nar_decoder(
                    (xy_pos, self.nar_stage_embeddings[i].weight)
                )
                logits = predict_layer(xy_dec[:, text_len + prefix_len :])

                samples = torch.argmax(logits, dim=-1)
                codes.append(samples)

                if i < self.num_quantizers - 2:
                    y_emb[:, prefix_len:] += embedding_layer(samples)

        assert len(codes) == self.num_quantizers
        # Free intermediate tensors before returning.
        del text_language_id, prompt_language_id, y_emb, x, y_pos, xy_pos
        del xy_dec, logits, samples, kv_cache
        del x_attn_mask, y_attn_mask, xy_attn_mask
        gc.collect()
        return torch.stack(codes, dim=-1)

    def continual(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
          x:
            A 2-D tensor of shape (1, S).
          x_lens:
            A 1-D tensor of shape (1,). It contains the number of tokens in `x`
            before padding.
          y:
            A 3-D tensor of shape (1, T, 8).
        Returns:
          Return the predicted audio code matrix.
        """
        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.ndim == 3, y.shape
        assert y.shape[0] == 1, y.shape

        assert torch.all(x_lens > 0)
        assert self.num_quantizers == 8

        text = x
        x = self.ar_text_embedding(text)
        x = self.ar_text_prenet(x)
        x = self.ar_text_position(x)

        text_len = x_lens.max()

        # Use up to half of y (at most 3 * 75 = 225 frames) as the prompt.
        prefix_len = min(int(y.shape[1] * 0.5), 3 * 75)

        prompts = y[:, :prefix_len]

        codes = [y[:, prefix_len:, 0]]
        # NAR decoders
        x = self.nar_text_embedding(text)
        x = self.nar_text_prenet(x)
        x = self.nar_text_position(x)

        y_emb = self.nar_audio_embeddings[0](y[..., 0])

        if self.prefix_mode == 0:
            for i, (predict_layer, embedding_layer) in enumerate(
                zip(
                    self.nar_predict_layers,
                    self.nar_audio_embeddings[1:],
                )
            ):
                y_pos = self.nar_audio_position(y_emb)
                y_pos = self.nar_audio_prenet(y_pos)
                xy_pos = torch.concat([x, y_pos], dim=1)

                xy_dec, _ = self.nar_decoder(
                    (xy_pos, self.nar_stage_embeddings[i].weight)
                )
                logits = predict_layer(xy_dec[:, text_len + prefix_len :])

                samples = torch.argmax(logits, dim=-1)
                codes.append(samples)

                if i < 6:
                    y_emb[:, :prefix_len] += embedding_layer(
                        prompts[..., i + 1]
                    )
                    y_emb[:, prefix_len:] += embedding_layer(samples)
        else:
            for j in range(1, 8):
                y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](
                    prompts[..., j]
                )

            for i, (predict_layer, embedding_layer) in enumerate(
                zip(
                    self.nar_predict_layers,
                    self.nar_audio_embeddings[1:],
                )
            ):
                y_pos = self.nar_audio_prenet(y_emb)
                y_pos = self.nar_audio_position(y_pos)
                xy_pos = torch.concat([x, y_pos], dim=1)

                xy_dec, _ = self.nar_decoder(
                    (xy_pos, self.nar_stage_embeddings[i].weight)
                )
                logits = predict_layer(xy_dec[:, text_len + prefix_len :])

                samples = torch.argmax(logits, dim=-1)
                codes.append(samples)

                if i < 6:
                    y_emb[:, prefix_len:] += embedding_layer(samples)

        assert len(codes) == 8
        return torch.stack(codes, dim=-1)


def top_k_top_p_filtering(
    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Args:
        logits: logits distribution of shape (batch size, vocabulary size).
        top_k: if > 0, keep only the top k tokens with the highest probability
            (top-k filtering).
        top_p: if < 1.0, keep the top tokens with cumulative probability >= top_p
            (nucleus filtering). Nucleus filtering is described in Holtzman et al.
            (http://arxiv.org/abs/1904.09751).
        min_tokens_to_keep: make sure we keep at least this many tokens per batch
            example in the output.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(
            max(top_k, min_tokens_to_keep), logits.size(-1)
        )  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(
            F.softmax(sorted_logits, dim=-1), dim=-1
        )

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep tokens
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the mask to the right to also keep the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
            ..., :-1
        ].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Scatter the sorted mask back to the original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value
    return logits
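
# Illustrative example (added; not part of the original code): with
# logits = [[2.0, 1.0, 0.5, -1.0]] and top_k=2, the top-k threshold is the
# second-highest logit (1.0), so the last two entries are set to -inf and
# only the tokens with logits 2.0 and 1.0 survive the filter.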


def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
    """Sample one token id per row of `logits`.

    Args:
        logits: tensor of shape (batch size, vocabulary size).
        top_k: if > 0, keep only the top-k highest-probability tokens.
        top_p: if < 1.0, keep the smallest set of tokens whose cumulative
            probability exceeds top_p (nucleus sampling).
        temperature: value used to modulate the next-token probabilities;
            must be strictly positive.
    """
    # Temperature scaling (a higher temperature flattens the distribution).
    if temperature != 1.0:
        logits = logits / temperature
    # Top-k / top-p filtering (top_k <= 0 and top_p >= 1.0 leave logits unchanged).
    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
    # Sample from the filtered distribution.
    token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
    return token
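

if __name__ == "__main__":
    # Minimal usage sketch (added for illustration; not part of the original
    # module, and it assumes the package-relative imports above resolve).
    torch.manual_seed(0)
    example_logits = torch.randn(1, 1024)  # (batch, vocab); vocab size arbitrary
    sampled = topk_sampling(example_logits, top_k=10, top_p=1.0, temperature=1.0)
    print("sampled token id:", sampled.item())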