| import math |
| import typing as tp |
| from dataclasses import dataclass |
| from typing import List, Optional, Union |
|
|
| import hydra |
| import librosa |
| import numpy as np |
| import soundfile as sf |
| import torch |
| from audiotools import AudioSignal |
| from audiotools.ml import BaseModel |
| from dac.model.base import CodecMixin |
| from dac.nn.layers import Snake1d, WNConv1d, WNConvTranspose1d |
| from omegaconf import OmegaConf |
| from torch import Tensor, nn |
| from torch.nn import functional as F |
| from torch.nn.utils.parametrizations import weight_norm |
| from torch.nn.utils.parametrize import remove_parametrizations |
|
|
|
|
@dataclass
class VQResult:
    """Result bundle returned by a vector-quantizer forward pass."""

    z: torch.Tensor  # quantized continuous representation
    codes: torch.Tensor  # discrete codebook indices
    latents: torch.Tensor  # projected latents before quantization
    codebook_loss: torch.Tensor  # loss that updates the codebook entries
    commitment_loss: torch.Tensor  # loss pulling encoder outputs toward codebook vectors
    semantic_distill_z: torch.Tensor | None = None  # presumably a semantic-distillation target — TODO confirm with the quantizer
|
|
|
|
def find_multiple(n: int, k: int) -> int:
    """Round ``n`` up to the nearest multiple of ``k``."""
    remainder = n % k
    return n if remainder == 0 else n + (k - remainder)
|
|
|
|
@dataclass
class ModelArgs:
    """Hyper-parameters for the Transformer stack."""

    block_size: int = 2048  # maximum sequence length (sizes the causal mask / RoPE table)
    n_layer: int = 8  # number of TransformerBlocks
    n_head: int = 8  # number of query heads
    dim: int = 512  # model width
    intermediate_size: int = 1536  # feed-forward hidden width
    n_local_heads: int = -1  # KV heads (grouped-query attention); -1 means "same as n_head"
    head_dim: int = 64  # per-head width
    rope_base: float = 10000  # RoPE frequency base
    norm_eps: float = 1e-5  # RMSNorm epsilon
    dropout_rate: float = 0.1  # feed-forward dropout
    attn_dropout_rate: float = 0.1  # attention-matrix dropout
    channels_first: bool = True  # inputs are (B, C, T) when True
    pos_embed_type: str = "rope"  # "rope" or "conformer"
    max_relative_position: int = 128  # clamp for conformer relative offsets

    def __post_init__(self):
        # Default the number of KV heads to the number of query heads.
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        # Non-None by default; this branch only fires when a caller explicitly
        # passes None to request the 2/3 * 4 * dim heuristic, rounded to 256.
        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)
        assert self.pos_embed_type in [
            "rope",
            "conformer",
        ], "pos_embed_type must be either 'rope' or 'conformer'"
|
|
|
|
class KVCache(nn.Module):
    """Pre-allocated key/value cache for incremental (autoregressive) decoding."""

    def __init__(
        self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16
    ):
        super().__init__()
        shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        """Write new K/V at `input_pos` and return views up to the newest position."""
        # One cache slot per incoming time step.
        assert input_pos.shape[0] == k_val.shape[2]

        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val

        valid = input_pos.max() + 1
        return self.k_cache[:, :, :valid, :], self.v_cache[:, :, :valid, :]

    def clear_cache(self, prompt_len):
        """Zero every slot past the prompt so stale entries cannot leak."""
        self.k_cache[:, :, prompt_len:, :].zero_()
        self.v_cache[:, :, prompt_len:, :].zero_()
|
|
|
|
class Transformer(nn.Module):
    """Stack of pre-norm transformer blocks with a final RMSNorm.

    Supports RoPE ("rope") or conformer-style relative ("conformer")
    positional embeddings, plus an optional KV cache for incremental decoding.
    """

    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config

        self.layers = nn.ModuleList(
            TransformerBlock(config) for _ in range(config.n_layer)
        )
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)

        # RoPE angle table is precomputed once for the maximum block size;
        # "conformer" embeddings live inside each Attention module instead.
        if config.pos_embed_type == "rope":
            freqs_cis = precompute_freqs_cis(
                self.config.block_size, self.config.head_dim, self.config.rope_base
            )
            self.register_buffer("freqs_cis", freqs_cis)
        else:
            self.register_buffer("freqs_cis", None)

        # Lower-triangular causal mask over the maximum sequence length.
        causal_mask = torch.tril(
            torch.ones(self.config.block_size, self.config.block_size, dtype=torch.bool)
        )
        self.register_buffer("causal_mask", causal_mask)

        # Cache bookkeeping; -1 means "caches not set up".
        self.max_batch_size = -1
        self.max_seq_length = -1
        self.use_kv_cache = False

    def setup_caches(self, max_batch_size, max_seq_length):
        """
        This method will only be called during inference when using KV cache.
        """
        head_dim = self.config.dim // self.config.n_head
        # Round the cache length up to a multiple of 8 (alignment-friendly).
        max_seq_length = find_multiple(max_seq_length, 8)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        dtype = self.norm.weight.dtype
        device = self.norm.weight.device

        # One KV cache per layer, sized for the grouped-query KV heads.
        for b in self.layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                max_seq_length,
                self.config.n_local_heads,
                head_dim,
                dtype,
            ).to(device)

        self.use_kv_cache = True

    def forward(
        self,
        x: Tensor,
        input_pos: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Run the block stack; ``input_pos`` selects RoPE angles and mask rows."""
        if self.config.pos_embed_type == "rope":
            assert (
                self.freqs_cis is not None
            ), "RoPE frequencies must be initialized for RoPE positional embedding"
            freqs_cis = self.freqs_cis[input_pos]
        else:
            freqs_cis = None

        if mask is None:
            if not self.training and self.use_kv_cache:
                # KV-cache decode: queries sit at `input_pos`, keys span the
                # cache up to (and including) the newest position.
                mask = self.causal_mask[None, None, input_pos]
                mask = mask[..., : input_pos.max() + 1]
            else:
                # Full pass: square mask over the selected positions.
                mask = self.causal_mask[None, None, input_pos]
                mask = mask[..., input_pos]

        for i, layer in enumerate(self.layers):
            x = layer(x, input_pos, freqs_cis, mask)
        x = self.norm(x)
        return x
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention then feed-forward, each
    wrapped in a residual connection scaled by LayerScale."""

    def __init__(self, config: "ModelArgs") -> None:
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.attention_layer_scale = LayerScale(config.dim, inplace=True)
        self.ffn_layer_scale = LayerScale(config.dim, inplace=True)

    def forward(
        self,
        x: Tensor,
        input_pos: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
    ) -> Tensor:
        # Attention sub-layer (pre-norm, scaled residual).
        attn_out = self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
        x = x + self.attention_layer_scale(attn_out)
        # Feed-forward sub-layer.
        ffn_out = self.feed_forward(self.ffn_norm(x))
        return x + self.ffn_layer_scale(ffn_out)
|
|
|
|
class Attention(nn.Module):
    """Multi-head attention with grouped-query KV heads and either RoPE or
    conformer-style relative positional embeddings."""

    def __init__(self, config: "ModelArgs"):
        super().__init__()
        assert config.dim % config.n_head == 0

        # Fused QKV projection: n_head query heads + 2 * n_local_heads KV heads.
        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
        self.kv_cache = None

        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self.attn_dropout_rate = config.attn_dropout_rate
        self.pos_embed_type = config.pos_embed_type

        if self.pos_embed_type == "conformer":
            # Learned relative-position embeddings, one row per clamped offset
            # in [-max_relative_position, +max_relative_position].
            self.max_relative_position = config.max_relative_position
            num_pos_embeddings = 2 * config.max_relative_position + 1
            self.rel_pos_embeddings = nn.Parameter(
                torch.zeros(num_pos_embeddings, self.head_dim)
            )
            nn.init.normal_(self.rel_pos_embeddings, mean=0.0, std=0.02)

    def _compute_conformer_pos_scores(self, q: Tensor, seqlen: int) -> Tensor:
        """Relative-position attention scores, shape (B, n_head, T, T)."""
        # Pairwise offsets, clamped and shifted into embedding-table indices.
        positions = torch.arange(seqlen, device=q.device)
        relative_positions = positions.unsqueeze(1) - positions.unsqueeze(0)
        relative_positions = torch.clamp(
            relative_positions + self.max_relative_position,
            0,
            2 * self.max_relative_position,
        )
        rel_embeddings = self.rel_pos_embeddings[relative_positions]

        # (B, T, H, D) @ (T, D, T) -> (B, T, H, T), then back to (B, H, T, T).
        q = q.transpose(1, 2)
        rel_logits = torch.matmul(q, rel_embeddings.transpose(-2, -1))
        rel_logits = rel_logits.transpose(1, 2)
        return rel_logits

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        """Attend over ``x`` (B, T, dim); ``mask`` is boolean, True = may attend."""
        bsz, seqlen, _ = x.shape

        # Fix: queries span n_head * head_dim. The previous code split the
        # fused projection into three KV-sized chunks, which fails whenever
        # n_head != n_local_heads (grouped-query attention).
        q_size = self.n_head * self.head_dim
        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([q_size, kv_size, kv_size], dim=-1)
        context_seqlen = seqlen

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)

        if self.pos_embed_type == "rope":
            q = apply_rotary_emb(q, freqs_cis)
            k = apply_rotary_emb(k, freqs_cis)

        q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        # Expand KV heads so every query head has a matching KV head.
        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)

        if self.pos_embed_type == "conformer":
            # Manual attention so relative-position scores can be added.
            scale = 1.0 / math.sqrt(self.head_dim)
            scores = torch.matmul(q, k.transpose(-2, -1)) * scale

            # NOTE(review): rel scores assume query length == key length; with
            # a KV cache the key length may exceed seqlen — confirm the
            # conformer path is never used with caching.
            rel_scores = self._compute_conformer_pos_scores(q, seqlen)
            scores = scores + rel_scores

            if mask is not None:
                scores = scores.masked_fill(~mask, float("-inf"))

            attn = F.softmax(scores, dim=-1)
            if self.attn_dropout_rate > 0 and self.training:
                attn = F.dropout(attn, p=self.attn_dropout_rate)

            y = torch.matmul(attn, v)
        else:
            y = F.scaled_dot_product_attention(
                q,
                k,
                v,
                dropout_p=self.attn_dropout_rate if self.training else 0.0,
                attn_mask=mask,
            )

        # Merge heads and project back to model width.
        y = (
            y.transpose(1, 2)
            .contiguous()
            .view(bsz, seqlen, self.head_dim * self.n_head)
        )
        y = self.wo(y)
        return y
|
|
|
|
class FeedForward(nn.Module):
    """SwiGLU feed-forward network: w2(dropout(silu(w1 x) * w3 x))."""

    def __init__(self, config: "ModelArgs") -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x: Tensor) -> Tensor:
        gate = F.silu(self.w1(x))
        hidden = self.dropout(gate * self.w3(x))
        return self.w2(hidden)
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square layer norm: no mean subtraction, no bias."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Divide by the RMS over the last dimension (eps for stability).
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x: Tensor) -> Tensor:
        # Normalize in float32 for stability, then cast back.
        normed = self._norm(x.float()).type_as(x)
        return normed * self.weight
|
|
|
|
class LayerScale(nn.Module):
    """Learnable per-channel scaling of the residual branch, optionally in-place."""

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-2,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
|
|
|
|
class WindowLimitedTransformer(Transformer):
    """
    Transformer with window limited attention, causal.
    """

    def __init__(
        self,
        config: ModelArgs,
        input_dim: int = 512,
        window_size: Optional[int] = None,
        causal: bool = True,
        look_ahead_conv: nn.Module = None,
    ):
        super().__init__(config)
        self.window_size = window_size
        self.causal = causal
        self.channels_first = config.channels_first
        self.look_ahead_conv = (
            look_ahead_conv if look_ahead_conv is not None else nn.Identity()
        )
        # Project in/out of the transformer width only when dims differ.
        self.input_proj = (
            nn.Linear(input_dim, config.dim)
            if input_dim != config.dim
            else nn.Identity()
        )
        self.output_proj = (
            nn.Linear(config.dim, input_dim)
            if input_dim != config.dim
            else nn.Identity()
        )

    def make_window_limited_mask(
        self,
        max_length: int,
        x_lens: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Build a (1, 1, L, L) boolean mask where each query attends only to the
        previous `window_size` positions (inclusive), causally.

        NOTE(review): `x_lens` is currently ignored here, so padded positions
        are not masked in the windowed path — confirm whether that is intended.
        """
        if self.causal:
            mask = torch.tril(torch.ones(max_length, max_length))
            row_indices = torch.arange(max_length).view(-1, 1)
            window_size = self.window_size or max_length
            valid_range = (row_indices - window_size + 1).clamp(min=0)
            column_indices = torch.arange(max_length)
            mask = (column_indices >= valid_range) & mask.bool()
        else:
            raise NotImplementedError
        return mask.bool()[None, None]

    def make_mask(
        self,
        max_length: int,
        x_lens: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Make ordinary (causal or full) mask if window size is not specified.

        Returns a boolean mask of shape (1, 1, L, L), or (B, 1, L, L) when
        `x_lens` is given; True means "may attend".
        """
        if self.causal:
            base = torch.tril(torch.ones(max_length, max_length, dtype=torch.bool))
        else:
            base = torch.ones(max_length, max_length, dtype=torch.bool)
        mask = base[None, None]
        if x_lens is not None:
            # Fix: the previous code indexed the (1, 1, L, L) mask with
            # (sequence, batch) coordinates, wiping the whole mask for batch
            # item 0 and raising IndexError for larger batches.  Instead,
            # disallow attending to padded key positions >= x_len per item.
            key_positions = torch.arange(max_length, device=x_lens.device)
            valid = key_positions[None, :] < x_lens[:, None].long()  # (B, L)
            mask = mask & valid[:, None, None, :]
        return mask

    def forward(
        self,
        x: Tensor,
        x_lens: Optional[Tensor] = None,
    ) -> Tensor:
        """Apply optional projections/look-ahead conv, then the transformer."""
        if self.channels_first:
            x = x.transpose(1, 2)
        x = self.input_proj(x)
        x = self.look_ahead_conv(x)
        input_pos = torch.arange(x.shape[1], device=x.device)

        max_length = x.shape[1]
        if self.window_size is not None:
            mask = self.make_window_limited_mask(max_length, x_lens)
        else:
            mask = self.make_mask(max_length, x_lens)
        mask = mask.to(x.device)
        x = super().forward(x, input_pos, mask)
        x = self.output_proj(x)
        if self.channels_first:
            x = x.transpose(1, 2)
        return x
|
|
|
|
def precompute_freqs_cis(
    seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16
) -> Tensor:
    """Precompute RoPE angles as a (seq_len, n_elem // 2, 2) table of
    [cos, sin] pairs."""
    half = n_elem // 2
    inv_freq = 1.0 / (
        base ** (torch.arange(0, n_elem, 2)[:half].float() / n_elem)
    )
    positions = torch.arange(seq_len, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq)
    # Unit complex numbers e^{i * angle}, stored as (real, imag) pairs.
    unit = torch.polar(torch.ones_like(angles), angles)
    cache = torch.stack([unit.real, unit.imag], dim=-1)
    return cache.to(dtype=dtype)
|
|
|
|
def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    """Rotate channel pairs of ``x`` (B, T, H, D) by the precomputed
    [cos, sin] table; complex multiply done in real arithmetic."""
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    rot = freqs_cis.view(1, pairs.size(1), 1, pairs.size(3), 2)
    # (a + bi)(c + di) = (ac - bd) + (bc + ad)i
    real = pairs[..., 0] * rot[..., 0] - pairs[..., 1] * rot[..., 1]
    imag = pairs[..., 1] * rot[..., 0] + pairs[..., 0] * rot[..., 1]
    rotated = torch.stack([real, imag], dim=-1)
    return rotated.flatten(3).type_as(x)
|
|
|
|
def init_weights(m):
    """Initialize Conv1d modules: truncated-normal weights, zero bias.

    Fix: guard the bias init — Conv1d layers built with ``bias=False`` have
    ``m.bias is None`` and the previous code crashed on them.
    """
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
|
|
|
|
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Remove padding from x, handling properly zero padding. Only for 1d!"""
    left, right = paddings
    assert left >= 0 and right >= 0, (left, right)
    assert left + right <= x.shape[-1]
    return x[..., left : x.shape[-1] - right]
|
|
|
|
def get_extra_padding_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
    """Extra right padding needed so the last partial frame is not dropped
    by a strided conv1d (see `pad_for_conv1d`)."""
    length = x.shape[-1]
    frames = (length - kernel_size + padding_total) / stride + 1
    # Length that would produce a whole number of frames.
    target_length = (math.ceil(frames) - 1) * stride + (kernel_size - padding_total)
    return target_length - length
|
|
|
|
def pad1d(
    x: torch.Tensor,
    paddings: tp.Tuple[int, int],
    mode: str = "zeros",
    value: float = 0.0,
):
    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
    If this is the case, we insert extra 0 padding to the right
    before the reflection happen.

    Fix: the historical default ``mode="zeros"`` is not a valid ``F.pad`` mode
    and used to raise; it is now mapped to F.pad's "constant" mode.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == "zeros":
        mode = "constant"
    if mode == "reflect":
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            # Reflection needs input longer than the pad; zero-extend first,
            # then trim the extension back off after padding.
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    return F.pad(x, paddings, mode, value)
|
|
|
|
| class CausalConvNet(nn.Module): |
| def __init__( |
| self, |
| in_channels, |
| out_channels, |
| kernel_size, |
| dilation=1, |
| stride=1, |
| groups=1, |
| padding=None, |
| ): |
| super(CausalConvNet, self).__init__() |
| self.conv = nn.Conv1d( |
| in_channels, |
| out_channels, |
| kernel_size, |
| stride=stride, |
| dilation=dilation, |
| groups=groups, |
| ) |
| self.stride = stride |
| self.kernel_size = (kernel_size - 1) * dilation + 1 |
| self.dilation = dilation |
| self.padding = self.kernel_size - self.stride |
|
|
| def forward(self, x): |
| pad = self.padding |
| extra_padding = get_extra_padding_for_conv1d( |
| x, self.kernel_size, self.stride, pad |
| ) |
| x = pad1d(x, (pad, extra_padding), mode="constant", value=0) |
| return self.conv(x).contiguous() |
|
|
| def weight_norm(self, name="weight", dim=0): |
| self.conv = weight_norm(self.conv, name=name, dim=dim) |
| return self |
|
|
| def remove_weight_norm(self): |
| self.conv = remove_parametrizations(self.conv) |
| return self |
|
|
|
|
| class CausalTransConvNet(nn.Module): |
| def __init__( |
| self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None |
| ): |
| super(CausalTransConvNet, self).__init__() |
| self.conv = nn.ConvTranspose1d( |
| in_channels, out_channels, kernel_size, stride=stride, dilation=dilation |
| ) |
| self.stride = stride |
| self.kernel_size = kernel_size |
|
|
| def forward(self, x): |
| x = self.conv(x) |
| pad = self.kernel_size - self.stride |
| padding_right = math.ceil(pad) |
| padding_left = pad - padding_right |
| x = unpad1d(x, (padding_left, padding_right)) |
| return x.contiguous() |
|
|
| def weight_norm(self, name="weight", dim=0): |
| self.conv = weight_norm(self.conv, name=name, dim=dim) |
| return self |
|
|
| def remove_weight_norm(self): |
| self.conv = remove_parametrizations(self.conv) |
| return self |
|
|
|
|
def CausalWNConv1d(*args, **kwargs):
    """Build a weight-normalized causal 1d convolution."""
    module = CausalConvNet(*args, **kwargs)
    return module.weight_norm()
|
|
|
|
def CausalWNConvTranspose1d(*args, **kwargs):
    """Build a weight-normalized causal transposed 1d convolution."""
    module = CausalTransConvNet(*args, **kwargs)
    return module.weight_norm()
|
|
|
|
class ResidualUnit(nn.Module):
    """Dilated residual unit: Snake -> dilated 7-conv -> Snake -> 1x1 conv,
    with the input cropped to match the conv output before the skip add."""

    def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False):
        super().__init__()
        conv_cls = CausalWNConv1d if causal else WNConv1d
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            conv_cls(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
            Snake1d(dim),
            conv_cls(dim, dim, kernel_size=1),
        )
        self.causal = causal

    def forward(self, x):
        y = self.block(x)
        excess = x.shape[-1] - y.shape[-1]
        if excess > 0:
            # Crop the skip path so it lines up with the (shorter) conv output:
            # from the end when causal, symmetrically otherwise.
            x = x[..., :-excess] if self.causal else x[..., excess // 2 : -excess // 2]
        return x + y
|
|
|
|
class EncoderBlock(nn.Module):
    """One encoder stage: three dilated residual units, a strided
    (downsampling, channel-doubling) conv, then an optional windowed
    transformer."""

    def __init__(
        self,
        dim: int = 16,
        stride: int = 1,
        causal: bool = False,
        n_t_layer: int = 0,
        transformer_general_config=None,
    ):
        super().__init__()
        conv_class = CausalWNConv1d if causal else WNConv1d
        # `transformer_general_config` is invoked with keyword overrides, so it
        # is expected to be a callable building a config (e.g. a
        # functools.partial over ModelArgs) — TODO confirm with callers.
        transformer_module = (
            nn.Identity()
            if n_t_layer == 0
            else (
                WindowLimitedTransformer(
                    causal=causal,
                    input_dim=dim,
                    window_size=512,
                    config=transformer_general_config(
                        n_layer=n_t_layer,
                        n_head=dim // 64,
                        dim=dim,
                        intermediate_size=dim * 3,
                    ),
                )
            )
        )
        self.block = nn.Sequential(
            # Residual units run at half the output width; the strided conv
            # below doubles channels while downsampling time by `stride`.
            ResidualUnit(dim // 2, dilation=1, causal=causal),
            ResidualUnit(dim // 2, dilation=3, causal=causal),
            ResidualUnit(dim // 2, dilation=9, causal=causal),
            Snake1d(dim // 2),
            conv_class(
                dim // 2,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
            transformer_module,
        )

    def forward(self, x):
        """Run the stage; output is (B, dim, T / stride)."""
        return self.block(x)
|
|
|
|
class Encoder(nn.Module):
    """Strided convolutional encoder mapping mono audio to latent features."""

    def __init__(
        self,
        d_model: int = 64,
        strides: list = [2, 4, 8, 8],
        d_latent: int = 64,
        n_transformer_layers: list = [0, 0, 4, 4],
        transformer_general_config: ModelArgs = None,
        causal: bool = False,
    ):
        super().__init__()
        conv_class = CausalWNConv1d if causal else WNConv1d

        # Stem: mono waveform -> d_model channels.
        stages = [conv_class(1, d_model, kernel_size=7, padding=3)]

        # Each stage doubles the channel count while downsampling by `stride`.
        for stride, n_t_layer in zip(strides, n_transformer_layers):
            d_model *= 2
            stages.append(
                EncoderBlock(
                    d_model,
                    stride=stride,
                    causal=causal,
                    n_t_layer=n_t_layer,
                    transformer_general_config=transformer_general_config,
                )
            )

        # Head: project to the latent width.
        stages.append(Snake1d(d_model))
        stages.append(conv_class(d_model, d_latent, kernel_size=3, padding=1))

        self.block = nn.Sequential(*stages)
        self.enc_dim = d_model

    def forward(self, x):
        """Map (B, 1, T) audio to (B, d_latent, T / prod(strides)) latents."""
        return self.block(x)
|
|
|
|
class DecoderBlock(nn.Module):
    """One decoder stage: optional windowed transformer config, a transposed
    (upsampling) conv, then three dilated residual units."""

    def __init__(
        self,
        input_dim: int = 16,
        output_dim: int = 8,
        stride: int = 1,
        causal: bool = False,
        n_t_layer: int = 0,
        transformer_general_config=None,
    ):
        super().__init__()
        conv_trans_class = CausalWNConvTranspose1d if causal else WNConvTranspose1d
        # `transformer_general_config` is expected to be a callable building a
        # config (e.g. a functools.partial over ModelArgs) — TODO confirm.
        # NOTE(review): `transformer_module` is built but not added to
        # `self.block` below — confirm whether it should be part of the stage.
        transformer_module = (
            nn.Identity()
            if n_t_layer == 0
            else (
                WindowLimitedTransformer(
                    causal=causal,
                    input_dim=input_dim,
                    window_size=None,
                    config=transformer_general_config(
                        n_layer=n_t_layer,
                        n_head=input_dim // 64,
                        dim=input_dim,
                        intermediate_size=input_dim * 3,
                    ),
                )
            )
        )
        self.block = nn.Sequential(
            # Upsample time by `stride` while halving channels, then refine.
            Snake1d(input_dim),
            conv_trans_class(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
            ResidualUnit(output_dim, dilation=1, causal=causal),
            ResidualUnit(output_dim, dilation=3, causal=causal),
            ResidualUnit(output_dim, dilation=9, causal=causal),
        )

    def forward(self, x):
        """Run the stage; output is (B, output_dim, T * stride)."""
        return self.block(x)
|
|
|
|
class Decoder(nn.Module):
    """Upsampling decoder: input conv, one DecoderBlock per rate, then an
    output conv squashed through tanh."""

    def __init__(
        self,
        input_channel,
        channels,
        rates,
        d_out: int = 1,
        causal: bool = False,
        n_transformer_layers: list = [0, 0, 0, 0],
        transformer_general_config=None,
    ):
        super().__init__()
        conv_class = CausalWNConv1d if causal else WNConv1d
        # Stem: latent channels -> decoder width.
        layers = [conv_class(input_channel, channels, kernel_size=7, padding=3)]

        # Fix: `output_dim` previously leaked out of the loop and would be
        # undefined (NameError) for empty `rates`; initialize it up front.
        output_dim = channels
        # Each stage halves channels while upsampling time by `stride`.
        for i, (stride, n_t_layer) in enumerate(zip(rates, n_transformer_layers)):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            layers += [
                DecoderBlock(
                    input_dim,
                    output_dim,
                    stride,
                    causal=causal,
                    n_t_layer=n_t_layer,
                    transformer_general_config=transformer_general_config,
                )
            ]

        # Head: project to waveform channels, bounded to [-1, 1].
        layers += [
            Snake1d(output_dim),
            conv_class(output_dim, d_out, kernel_size=7, padding=3),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Map (B, input_channel, T) latents to (B, d_out, T * prod(rates)) audio."""
        return self.model(x)
|
|
|
|
class DAC(BaseModel, CodecMixin):
    """DAC-style neural audio codec: convolutional Encoder, pluggable vector
    quantizer, and convolutional Decoder."""

    def __init__(
        self,
        encoder_dim: int = 64,
        encoder_rates: List[int] = [2, 4, 8, 8],
        latent_dim: int = None,
        decoder_dim: int = 1536,
        decoder_rates: List[int] = [8, 8, 4, 2],
        quantizer: torch.nn.Module = None,
        sample_rate: int = 44100,
        causal: bool = True,
        encoder_transformer_layers: List[int] = [0, 0, 0, 0],
        decoder_transformer_layers: List[int] = [0, 0, 0, 0],
        transformer_general_config=None,
    ):
        super().__init__()

        self.encoder_dim = encoder_dim
        self.encoder_rates = encoder_rates
        self.decoder_dim = decoder_dim
        self.decoder_rates = decoder_rates
        self.sample_rate = sample_rate

        # Default latent width: the encoder doubles channels at every stride.
        if latent_dim is None:
            latent_dim = encoder_dim * (2 ** len(encoder_rates))

        self.latent_dim = latent_dim

        # Total downsampling factor of the encoder.
        self.hop_length = np.prod(encoder_rates)
        self.encoder = Encoder(
            encoder_dim,
            encoder_rates,
            latent_dim,
            causal=causal,
            n_transformer_layers=encoder_transformer_layers,
            transformer_general_config=transformer_general_config,
        )

        self.quantizer = quantizer

        self.decoder = Decoder(
            latent_dim,
            decoder_dim,
            decoder_rates,
            causal=causal,
            n_transformer_layers=decoder_transformer_layers,
            transformer_general_config=transformer_general_config,
        )
        self.sample_rate = sample_rate
        self.apply(init_weights)

        self.delay = self.get_delay()

        # One code "frame" spans 4 encoder hops — presumably matching the
        # quantizer's temporal downsampling; TODO confirm.
        self.frame_length = self.hop_length * 4

    def preprocess(self, audio_data, sample_rate):
        """Validate the sample rate and right-pad to a whole number of hops."""
        if sample_rate is None:
            sample_rate = self.sample_rate
        assert sample_rate == self.sample_rate

        length = audio_data.shape[-1]
        right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
        audio_data = nn.functional.pad(audio_data, (0, right_pad))

        return audio_data

    def encode(
        self,
        audio_data: torch.Tensor,
        audio_lengths: torch.Tensor = None,
        n_quantizers: int = None,
        **kwargs,
    ):
        """Encode audio into quantized codebook indices.

        Parameters
        ----------
        audio_data : Tensor[B x T] or Tensor[B x 1 x T]
            Audio to encode; right-padded to a whole number of frames.
        audio_lengths : Tensor[B], optional
            Valid sample count per item; defaults to the padded length.
        n_quantizers : int, optional
            Number of quantizers to use; None means all.

        Returns
        -------
        (codes, code_lengths)
            `codes` are the quantizer's codebook indices; `code_lengths` is
            the number of valid code frames per batch item.
        """
        if audio_data.ndim == 2:
            audio_data = audio_data.unsqueeze(1)

        length = audio_data.shape[-1]
        right_pad = math.ceil(length / self.frame_length) * self.frame_length - length
        audio_data = nn.functional.pad(audio_data, (0, right_pad))
        if audio_lengths is None:
            # Assume every batch item is the full (padded) length.  Fix: the
            # previous code built a single-element tensor regardless of batch
            # size, which was wrong for batches larger than one.
            audio_lengths = torch.LongTensor(
                [length + right_pad] * audio_data.shape[0]
            ).to(audio_data.device)

        z = self.encoder(audio_data)
        vq_results = self.quantizer(z, n_quantizers, **kwargs)
        indices = vq_results.codes
        indices_lens = torch.ceil(audio_lengths / self.frame_length).long()
        return indices, indices_lens

    def decode(self, indices: torch.Tensor, feature_lengths):
        """Decode codebook indices back to audio.

        Returns
        -------
        (audio, audio_lengths)
            Reconstructed waveform and valid sample counts per batch item.
        """
        if indices.ndim == 2:
            indices = indices[None]

        z = self.quantizer.decode(indices)
        audio_lengths = feature_lengths * self.frame_length
        return self.decoder(z), audio_lengths

    def forward(
        self,
        audio_data: torch.Tensor,
        template: torch.Tensor = None,
        mask: torch.Tensor = None,
        sample_rate: int = None,
        n_quantizers: int = None,
        **kwargs,
    ):
        """Full round-trip: preprocess -> encode -> decode, trimmed to input length.

        Parameters
        ----------
        audio_data : Tensor[B x 1 x T]
            Audio data to encode.
        sample_rate : int, optional
            Must equal `self.sample_rate` when given; defaults to it.
        n_quantizers : int, optional
            Number of quantizers to use; None means all.

        Returns
        -------
        (audio, (codes, code_lengths))
            Reconstructed audio trimmed to the original length, plus the
            outputs of `encode`.

        Notes
        -----
        Fixes in this revision: `n_quantizers` was previously passed
        positionally into `encode`'s `audio_lengths` parameter; `decode` was
        called without its required `feature_lengths` argument; and `decode`'s
        tuple return was sliced as if it were a tensor.
        """
        length = audio_data.shape[-1]
        audio_data = self.preprocess(audio_data, sample_rate)
        codes, code_lengths = self.encode(
            audio_data, n_quantizers=n_quantizers, **kwargs
        )
        audio, _ = self.decode(codes, code_lengths)
        return audio[..., :length], (codes, code_lengths)
|
|