# LongCat-Next-4bit / modular_longcat_next_audio.py
# kernelpool's picture
# Add files using upload-large-folder tool
# 74da6da verified
import math
import copy
from abc import ABC
from dataclasses import dataclass
from typing import Any, Dict, Optional
import numpy as np
import torch
import torchaudio
from einops import pack, rearrange, repeat
from flash_attn import flash_attn_varlen_func
from torch import nn
from torch.cuda.amp import autocast
from torch.nn import functional as F
from diffusers.models.activations import get_activation
from diffusers.models.attention import (
GEGLU,
GELU,
AdaLayerNorm,
AdaLayerNormZero,
ApproximateGELU,
)
from diffusers.models.attention_processor import Attention
from diffusers.models.lora import LoRACompatibleLinear
from diffusers.utils.torch_utils import maybe_allow_in_graph
from transformers.activations import ACT2FN
from transformers.modeling_outputs import ModelOutput
from transformers.utils import logging
from .cosy24k_vocoder import Cosy24kVocoder
logger = logging.get_logger(__name__)
def sinusoids(length, channels, max_timescale=10000):
    """Build a fixed sinusoidal positional-embedding table.

    Args:
        length: Number of positions (rows of the returned table).
        channels: Embedding width; must be even. The first half of every row
            holds the sine terms and the second half the cosine terms.
        max_timescale: Largest timescale of the geometric progression.

    Returns:
        A tensor of shape ``(length, channels)``.
    """
    assert channels % 2 == 0
    half = channels // 2
    # Inverse timescales form a geometric progression from 1 down to
    # 1 / max_timescale over ``half`` steps.
    log_step = np.log(max_timescale) / (half - 1)
    inv_timescales = torch.exp(-log_step * torch.arange(half))
    positions = torch.arange(length)
    angles = positions[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
def get_sequence_mask(inputs, inputs_length):
    """Build a padding mask and the index map used to unpack packed sequences.

    Args:
        inputs: Either a 3-D tensor whose first two sizes give the batch size
            and padded length, or any other tensor, in which case the padded
            length is ``max(inputs_length)``. Only its rank, sizes and device
            are read.
        inputs_length: Per-sample valid lengths, one entry per batch element.

    Returns:
        A tuple ``(sequence_mask, unpacking_index)``:
        ``sequence_mask`` is a boolean tensor of shape ``(bsz, tgt_len, 1)``
        that is True at valid positions; ``unpacking_index`` is a flat int64
        tensor of length ``bsz * tgt_len`` holding, for every padded slot, the
        running count of valid slots so far minus one.
    """
    if inputs.dim() == 3:
        bsz, tgt_len, _ = inputs.size()
    else:
        bsz = inputs_length.shape[0]
        tgt_len = torch.max(inputs_length)

    positions = torch.arange(0, tgt_len).to(inputs.device)
    is_valid = torch.lt(positions, inputs_length.reshape(bsz, 1))
    sequence_mask = is_valid.view(bsz, tgt_len, 1)

    flat_valid = sequence_mask.to(torch.int64).view(-1)
    unpacking_index = torch.cumsum(flat_valid, dim=0) - 1  # convert counts to indices
    return sequence_mask, unpacking_index
def unpack_hidden_states(hidden_states, lengths):
    """Scatter packed hidden states back into a zero-padded batch.

    Args:
        hidden_states: Packed features indexed along dim 0; the last dimension
            is the feature size.
        lengths: Per-sample valid lengths.

    Returns:
        A tensor of shape ``(bsz, max(lengths), d)`` with padded positions set
        to zero.
    """
    mask, gather_index = get_sequence_mask(hidden_states, lengths)
    batch = lengths.shape[0]
    feature_dim = hidden_states.shape[-1]

    # Replicate packed rows into every padded slot, then blank out the padding.
    padded = torch.index_select(hidden_states, 0, gather_index)
    padded = padded.view(batch, torch.max(lengths), feature_dim)
    return torch.where(mask, padded, 0)  # 3d (bsz, max_input_len, d)
def uniform_init(*shape):
    """Return a new tensor of the given shape filled by Kaiming-uniform init.

    Args:
        *shape: Sizes of the tensor to create.

    Returns:
        The freshly initialised tensor.
    """
    # ``kaiming_uniform_`` fills in place and hands back the same tensor.
    return nn.init.kaiming_uniform_(torch.zeros(shape))
def cdist(x, y):
    """Pairwise Euclidean distance between the rows of ``x`` and ``y``.

    Uses the expansion ``|x|^2 + |y|^2 - 2 x.y`` and clamps at zero before the
    square root so rounding error cannot produce NaN.

    Args:
        x: Tensor of shape ``(b, d)``.
        y: Tensor of shape ``(c, d)``.

    Returns:
        Tensor of shape ``(b, c)``.
    """
    x_sq = x.pow(2).sum(dim=-1, keepdim=True)  # (b, 1)
    y_sq = y.pow(2).sum(dim=-1).reshape(1, -1)  # (1, c)
    cross = torch.einsum('bd,cd->bc', x, y) * -2  # (b, c)
    squared = (x_sq + y_sq + cross).clamp(min=0)
    return squared.sqrt()
def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Turn a boolean keep-mask into an additive attention bias.

    Args:
        mask: Boolean tensor; True marks positions that may be attended to.
        dtype: Floating dtype of the bias (float32, bfloat16 or float16).

    Returns:
        A tensor of ``dtype`` that is 0 where ``mask`` is True and
        ``torch.finfo(dtype).min`` where it is False.
    """
    assert mask.dtype == torch.bool
    assert dtype in [torch.float32, torch.bfloat16, torch.float16]
    keep = mask.to(dtype)
    most_negative = torch.finfo(dtype).min
    # (1 - keep) is 1 exactly at the masked-out positions.
    return (1.0 - keep) * most_negative
def subsequent_chunk_mask(
    size: int,
    chunk_size: int,
    num_left_chunks: int = -1,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create a chunk-wise attention mask of shape (size, size).

    Position ``i`` may attend to every position up to the end of the chunk
    that contains ``i``. This is used for the streaming encoder.

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk
        num_left_chunks (int): number of left chunks. Accepted for interface
            compatibility but NOT applied by this implementation (see NOTE).
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: boolean mask

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    # NOTE this modified implementation meets onnx export requirements, but it
    # doesn't support num_left_chunks. It is not needed once an inference cache
    # is implemented and is expected to be removed later.
    positions = torch.arange(size, device=device)
    chunk_index = torch.div(positions, chunk_size, rounding_mode='trunc')
    # Exclusive end of the chunk each row position belongs to.
    chunk_end = (chunk_index + 1) * chunk_size
    return positions.unsqueeze(0) < chunk_end.unsqueeze(1)
def add_optional_chunk_mask(xs: torch.Tensor,
                            masks: torch.Tensor,
                            use_dynamic_chunk: bool,
                            use_dynamic_left_chunk: bool,
                            decoding_chunk_size: int,
                            static_chunk_size: int,
                            num_decoding_left_chunks: int,
                            enable_full_context: bool = True):
    """ Apply optional mask for encoder.

    Args:
        xs (torch.Tensor): padded input, (B, L, D), L for max length
        masks (torch.Tensor): mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding
            if it's greater than 0, if use_dynamic_chunk is true,
            this parameter will be ignored
        num_decoding_left_chunks: number of left chunks, this is for decoding,
            the chunk size is decoding_chunk_size.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks
        enable_full_context (bool):
            True: chunk size is either [1, 25] or full context(max_len)
            False: chunk size ~ U[1, 25]

    Returns:
        torch.Tensor: chunk mask of the input xs. When neither dynamic nor
        static chunking is active, ``masks`` is returned unchanged.
    """
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # chunk size is either [1, 25] or full context(max_len).
            # Since we use 4 times subsampling and allow up to 1s(100 frames)
            # delay, the maximum frame is 100 / 4 = 25.
            # torch.randint needs low < high, so inputs shorter than two frames
            # fall back to a chunk size of 1 instead of raising.
            if max_len > 1:
                chunk_size = torch.randint(1, max_len, (1, )).item()
            else:
                chunk_size = 1
            num_left_chunks = -1
            if chunk_size > max_len // 2 and enable_full_context:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    # Same low < high constraint: with no complete left chunk
                    # available there is nothing to sample, so use 0.
                    if max_left_chunks > 0:
                        num_left_chunks = torch.randint(
                            0, max_left_chunks, (1, )).item()
                    else:
                        num_left_chunks = 0
    elif static_chunk_size > 0:
        chunk_size = static_chunk_size
        num_left_chunks = num_decoding_left_chunks
    else:
        # No chunking requested: the padding mask is used as-is.
        return masks

    chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
                                        num_left_chunks,
                                        xs.device)  # (L, L)
    chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
    return masks & chunk_masks  # (B, L, L)
class EuclideanCodebook(nn.Module):
    """Codebook that maps each input vector to its nearest code (L2 distance).

    Parameters held by the module:
        - ``cluster_size``: one value per code, initialised to ones.
        - ``embed_avg``: copy of the initial code vectors.
        - ``embed``: the code vectors, shape ``(codebook_size, dim)``.

    ``forward`` only reads ``embed``.
    """

    def __init__(
        self,
        dim,
        codebook_size,
        init_std=0.02,
    ):
        super().__init__()
        self.init_std = init_std
        self.dim = dim
        self.codebook_size = codebook_size

        initial_codes = uniform_init(codebook_size, dim).to(torch.float32)
        # Registration order is kept: cluster_size, embed_avg, embed.
        self.cluster_size = nn.Parameter(torch.ones(codebook_size))
        self.embed_avg = nn.Parameter(initial_codes.clone())
        self.embed = nn.Parameter(initial_codes)

    @autocast(enabled=True, dtype=torch.float32)
    @torch.no_grad()
    def forward(self, x):
        """Look up the nearest code for every row of ``x``.

        Args:
            x: float32 tensor of shape ``(n, d)``.

        Returns:
            A tuple ``(quantize, embed_ind, dist)``: the selected code vectors
            ``(n, d)``, their indices ``(n,)``, and the negated distances to
            every code ``(n, c)``.
        """
        assert x.dim() == 2
        assert x.dtype == torch.float32
        codes = self.embed.detach().to(x.device)
        # Negated so the closest code has the largest score.
        neg_dist = -cdist(x, codes)  # dist((bs*sl, d), (c, d)) --> (bs*sl, c)
        nearest = neg_dist.argmax(dim=-1)
        return codes[nearest], nearest, neg_dist
class VectorQuantize(nn.Module):
    """Vector quantiser over padded batches, backed by ``EuclideanCodebook``.

    ``config`` must expose ``dim`` and ``codebook_size``.
    """

    def __init__(self, config, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.config = config
        self.codebook = EuclideanCodebook(
            dim=config.dim, codebook_size=config.codebook_size
        )

    def forward(self, x, input_length):
        """Quantise the valid frames of ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, dim)``; cast to float32 if
                it is not already.
            input_length: Per-sample valid lengths.

        Returns:
            A tuple ``(quantize, embed_ind)``. ``quantize`` has shape
            ``(batch, seq_len, dim)`` with zeros at padded positions.
            ``embed_ind`` holds code indices with -1 at padded positions.
        """
        batch_size, seq_len, _ = x.shape
        dim = self.config.dim
        valid, scatter_index = get_sequence_mask(x, input_length)

        x = x if x.dtype == torch.float32 else x.to(torch.float32)
        # Keep only valid frames so padding never reaches the codebook.
        packed = torch.masked_select(x, valid).reshape(-1, dim)  # (bs*sl?, d)
        packed_quantized, packed_codes, _ = self.codebook(packed)

        quantize = torch.index_select(packed_quantized, 0, scatter_index)
        quantize = quantize.view(batch_size, seq_len, dim)
        quantize = torch.where(valid, quantize, 0)

        embed_ind = torch.index_select(packed_codes.reshape(-1, 1), 0, scatter_index)
        embed_ind = embed_ind.view(batch_size, seq_len, 1)
        # NOTE(review): squeeze() without a dim removes every size-1 axis, so
        # the batch axis also disappears when batch_size == 1 (and the time
        # axis when seq_len == 1). Preserved as-is; confirm callers before
        # narrowing it to squeeze(-1).
        embed_ind = torch.where(valid, embed_ind, -1).squeeze()
        return quantize, embed_ind

    def get_output_from_indices(self, indices):
        """Return the code vectors selected by ``indices``."""
        codes = self.codebook.embed
        return codes[indices.to(codes.device)]
class SnakeBeta(nn.Module):
    """
    A modified Snake activation with separate parameters for the frequency and
    the magnitude of the periodic component, preceded by a linear projection.

    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
    References:
        - This activation function is a modified version based on this paper by
          Liu Ziyin, Tilman Hartwig, Masahito Ueda:
          https://arxiv.org/abs/2006.08195
    Examples:
        >>> act = SnakeBeta(256, 256)
        >>> x = torch.randn(4, 256)
        >>> x = act(x)
    """

    def __init__(
        self,
        in_features,
        out_features,
        alpha=1.0,
        alpha_trainable=True,
        alpha_logscale=True,
    ):
        """
        Initialization.

        INPUT:
            - in_features: input width of the linear projection
            - out_features: output width of the projection; also the shape of
              alpha and beta
            - alpha: scale applied to the initial value of both alpha and beta
            - alpha_trainable: whether alpha and beta receive gradients
            - alpha_logscale: if True the parameters are stored in log space
              (initialised from zeros) and exponentiated in ``forward``;
              otherwise they are initialised from ones and used directly
        """
        super().__init__()
        # Despite the name, this stores the shape of alpha/beta, which follows
        # ``out_features``.
        if isinstance(out_features, list):
            self.in_features = out_features
        else:
            self.in_features = [out_features]
        self.proj = LoRACompatibleLinear(in_features, out_features)

        self.alpha_logscale = alpha_logscale
        make_init = torch.zeros if self.alpha_logscale else torch.ones
        self.alpha = nn.Parameter(make_init(self.in_features) * alpha)
        self.beta = nn.Parameter(make_init(self.in_features) * alpha)
        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """
        Project ``x`` and apply the activation elementwise:
        SnakeBeta := x + 1/b * sin^2(x * a)
        """
        x = self.proj(x)
        alpha, beta = self.alpha, self.beta
        if self.alpha_logscale:
            alpha, beta = torch.exp(alpha), torch.exp(beta)
        periodic = torch.pow(torch.sin(x * alpha), 2)
        return x + (1.0 / (beta + self.no_div_by_zero)) * periodic
class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
            One of `"gelu"`, `"gelu-approximate"`, `"geglu"`, `"geglu-approximate"`, `"snakebeta"`.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.

    Raises:
        ValueError: If `activation_fn` is not one of the supported names.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        # Each activation module also performs the input projection
        # dim -> inner_dim.
        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh")
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)
        elif activation_fn == "snakebeta":
            act_fn = SnakeBeta(dim, inner_dim)
        else:
            # Previously an unknown name left `act_fn` unbound and surfaced as
            # an opaque UnboundLocalError a few lines below.
            raise ValueError(
                f"Unsupported activation_fn {activation_fn!r}; expected one of "
                "'gelu', 'gelu-approximate', 'geglu', 'geglu-approximate', 'snakebeta'."
            )

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states):
        """Apply the sub-modules of ``self.net`` in order."""
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states
@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
    r"""
    A basic pre-norm Transformer block: self-attention, optional second
    attention, then a feed-forward layer, each wrapped in a residual connection.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (`int`, *optional*):
            The number of diffusion steps used during training. See `Transformer2DModel`.
            Required when `norm_type` is `"ada_norm"` or `"ada_norm_zero"`.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
        upcast_attention (`bool`, *optional*, defaults to `False`): Forwarded to `Attention`.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            `elementwise_affine` flag of the plain `nn.LayerNorm` layers.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
        final_dropout (`bool`, *optional*, defaults to `False`): Forwarded to `FeedForward`.
        use_omni_attn (`bool`, *optional*, defaults to `False`):
            Use `OmniWhisperAttention` on packed variable-length sequences for the
            first attention instead of `Attention`. Not supported together with
            `only_cross_attention`.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
        use_omni_attn: bool = False,
    ):
        super().__init__()
        self.use_omni_attn = use_omni_attn
        self.dim = dim
        self.only_cross_attention = only_cross_attention
        # Adaptive norms are only active when a timestep-embedding count is given.
        self.use_ada_layer_norm_zero = (
            num_embeds_ada_norm is not None
        ) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (
            num_embeds_ada_norm is not None
        ) and norm_type == "ada_norm"
        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )
        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        if self.use_omni_attn:
            if only_cross_attention:
                raise NotImplementedError
            print(
                "Use OmniWhisperAttention with flash attention. Dropout is ignored."
            )
            self.attn1 = OmniWhisperAttention(
                embed_dim=dim, num_heads=num_attention_heads, causal=False
            )
        else:
            self.attn1 = Attention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                cross_attention_dim=(
                    cross_attention_dim if only_cross_attention else None
                ),
                upcast_attention=upcast_attention,
            )
        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
            )
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=(
                    cross_attention_dim if not double_self_attention else None
                ),
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
                # scale_qk=False, # uncomment this to not to use flash attention
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None
        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
        )
        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        """Enable (or disable with ``None``) chunked feed-forward along ``dim``."""
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        class_labels: Optional[torch.LongTensor] = None,
    ):
        """Run the block on ``hidden_states`` of shape ``(bsz, tgt_len, d_model)``.

        Args:
            hidden_states: Input features.
            attention_mask: Mask for the first attention. On the
                ``use_omni_attn`` path it is required and its first row is
                summed to obtain per-sample lengths.
            encoder_hidden_states: Context for cross-attention, if any.
            encoder_attention_mask: Mask for the cross-attention context.
            timestep: Conditioning input of the adaptive norm layers.
            cross_attention_kwargs: Extra keyword arguments forwarded to the
                ``Attention`` modules.
            class_labels: Extra conditioning for ``AdaLayerNormZero``.

        Returns:
            Tensor with the same shape as ``hidden_states``.
        """
        bsz, tgt_len, d_model = hidden_states.shape
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 1. Self-Attention
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            # AdaLayerNormZero also returns the gates/shift/scale used further down.
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        else:
            norm_hidden_states = self.norm1(hidden_states)
        cross_attention_kwargs = (
            cross_attention_kwargs if cross_attention_kwargs is not None else {}
        )
        if self.use_omni_attn:
            # Per-sample valid length, taken from the first row of the mask.
            # NOTE(review): assumes attention_mask holds 1 for valid and 0 for
            # padded positions on this path (not an additive bias) - confirm.
            seq_len = attention_mask[:, 0, :].float().long().sum(dim=1)
            var_len_attention_mask, unpacking_index = get_sequence_mask(
                norm_hidden_states, seq_len
            )
            # Pack valid frames of all samples into one (sum(seq_len), dim) tensor.
            norm_hidden_states = torch.masked_select(
                norm_hidden_states, var_len_attention_mask
            )
            norm_hidden_states = norm_hidden_states.view(torch.sum(seq_len), self.dim)
            attn_output = self.attn1(norm_hidden_states, seq_len)
            # unpacking
            attn_output = torch.index_select(attn_output, 0, unpacking_index).view(
                bsz, tgt_len, d_model
            )
            # Zero the padded positions after scattering back.
            attn_output = torch.where(var_len_attention_mask, attn_output, 0)
        else:
            attn_output = self.attn1(
                norm_hidden_states,
                encoder_hidden_states=(
                    encoder_hidden_states if self.only_cross_attention else None
                ),
                attention_mask=(
                    encoder_attention_mask
                    if self.only_cross_attention
                    else attention_mask
                ),
                **cross_attention_kwargs,
            )
        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = attn_output + hidden_states
        # 2. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = (
                self.norm2(hidden_states, timestep)
                if self.use_ada_layer_norm
                else self.norm2(hidden_states)
            )
            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states
        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)
        if self.use_ada_layer_norm_zero:
            norm_hidden_states = (
                norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
            )
        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )
            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            # Run the feed-forward slice by slice and stitch the results back.
            ff_output = torch.cat(
                [
                    self.ff(hid_slice)
                    for hid_slice in norm_hidden_states.chunk(
                        num_chunks, dim=self._chunk_dim
                    )
                ],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states)
        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output
        hidden_states = ff_output + hidden_states
        return hidden_states
class Transpose(torch.nn.Module):
    """Module wrapper around ``torch.transpose`` for use inside ``Sequential``."""

    def __init__(self, dim0: int, dim1: int):
        super().__init__()
        self.dim0, self.dim1 = dim0, dim1

    def forward(self, x: torch.Tensor):
        """Return ``x`` with dimensions ``dim0`` and ``dim1`` swapped."""
        return x.transpose(self.dim0, self.dim1)
class Block1D(torch.nn.Module):
    """Masked Conv1d(kernel 3) -> GroupNorm -> Mish block.

    Args:
        dim: Input channels.
        dim_out: Output channels.
        groups: Number of groups for ``GroupNorm``.
    """

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        layers = [
            torch.nn.Conv1d(dim, dim_out, 3, padding=1),
            torch.nn.GroupNorm(groups, dim_out),
            nn.Mish(),
        ]
        self.block = torch.nn.Sequential(*layers)

    def forward(self, x, mask):
        """Apply the block; the mask is multiplied in before and after."""
        masked_input = x * mask
        return self.block(masked_input) * mask
class ResnetBlock1D(torch.nn.Module):
    """Two ``Block1D`` layers with a time-embedding bias and a 1x1 residual conv.

    Args:
        dim: Input channels.
        dim_out: Output channels.
        time_emb_dim: Width of the time embedding fed to the MLP.
        groups: Number of groups passed to both ``Block1D`` layers.
    """

    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            nn.Mish(),
            torch.nn.Linear(time_emb_dim, dim_out),
        )
        self.block1 = Block1D(dim, dim_out, groups=groups)
        self.block2 = Block1D(dim_out, dim_out, groups=groups)
        self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)

    def forward(self, x, mask, time_emb):
        """Return ``block2(block1(x) + mlp(time_emb)) + res_conv(x * mask)``."""
        # Trailing axis added so the bias broadcasts over the time dimension.
        time_bias = self.mlp(time_emb).unsqueeze(-1)
        hidden = self.block1(x, mask)
        hidden += time_bias
        hidden = self.block2(hidden, mask)
        shortcut = self.res_conv(x * mask)
        return hidden + shortcut
class CausalBlock1D(Block1D):
    """``Block1D`` variant built from a causal convolution and LayerNorm.

    The parent constructor runs first (with its default ``groups``) and its
    ``block`` is then replaced by
    CausalConv1d -> LayerNorm over channels -> Mish.
    """

    def __init__(self, dim: int, dim_out: int):
        super().__init__(dim, dim_out)
        causal_layers = [
            CausalConv1d(dim, dim_out, 3),
            # LayerNorm normalises the last axis, so move channels there and back.
            Transpose(1, 2),
            nn.LayerNorm(dim_out),
            Transpose(1, 2),
            nn.Mish(),
        ]
        self.block = torch.nn.Sequential(*causal_layers)

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
        """Apply the block; the mask is multiplied in before and after."""
        masked_input = x * mask
        return self.block(masked_input) * mask
class CausalResnetBlock1D(ResnetBlock1D):
    """``ResnetBlock1D`` whose two inner blocks are ``CausalBlock1D``."""

    def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
        super().__init__(dim, dim_out, time_emb_dim, groups)
        # Override the blocks created by the parent with causal ones.
        self.block1 = CausalBlock1D(dim, dim_out)
        self.block2 = CausalBlock1D(dim_out, dim_out)
class CausalConv1d(torch.nn.Conv1d):
    """1-D convolution that only looks at current and past time steps.

    The input is left-padded so the output has the same length as the input
    and output step ``t`` depends only on inputs at steps ``<= t``.

    Args:
        in_channels, out_channels, kernel_size, dilation, groups, bias,
        padding_mode, device, dtype: forwarded to ``torch.nn.Conv1d``.
        stride: Must be 1.

    Attributes:
        causal_padding: ``(left, right)`` padding applied in ``forward``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None
    ) -> None:
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )
        assert stride == 1
        # The receptive field spans dilation * (kernel_size - 1) past steps.
        # Padding by only kernel_size - 1 would shorten the output whenever
        # dilation > 1; for the default dilation of 1 the value is unchanged.
        self.causal_padding = (dilation * (kernel_size - 1), 0)

    def forward(self, x: torch.Tensor):
        """Left-pad ``x`` along the last axis and apply the convolution."""
        padded = F.pad(x, self.causal_padding)
        return super().forward(padded)
class BASECFM(torch.nn.Module, ABC):
    """Base class for conditional flow matching with a fixed-step Euler solver.

    ``self.estimator`` is initialised to ``None`` and must be assigned before
    ``forward``, ``solve_euler`` or ``compute_loss`` is used.

    Args:
        n_feats: Number of feature channels.
        cfm_params: Object exposing ``solver`` and, optionally, ``sigma_min``
            (defaults to 1e-4 when absent).
        n_spks: Number of speakers.
        spk_emb_dim: Speaker-embedding width.
    """

    def __init__(
        self,
        n_feats,
        cfm_params,
        n_spks=1,
        spk_emb_dim=128,
    ):
        super().__init__()
        self.n_feats = n_feats
        self.n_spks = n_spks
        self.spk_emb_dim = spk_emb_dim
        self.solver = cfm_params.solver
        self.sigma_min = getattr(cfm_params, "sigma_min", 1e-4)
        self.estimator = None

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """
        noise = torch.randn_like(mu) * temperature
        # n_timesteps Euler steps need n_timesteps + 1 grid points on [0, 1].
        time_grid = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        return self.solve_euler(
            noise, t_span=time_grid, mu=mu, mask=mask, spks=spks, cond=cond
        )

    def solve_euler(self, x, t_span, mu, mask, spks, cond):
        """
        Fixed euler solver for ODEs.

        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes

        Returns:
            The state after the last Euler step.
        """
        t = t_span[0]
        dt = t_span[1] - t_span[0]
        last_step = len(t_span) - 1
        # Every intermediate state is kept so it can be inspected from a
        # debugger; only the final one is returned.
        trajectory = []
        for step in range(1, len(t_span)):
            velocity = self.estimator(x, mask, mu, t, spks, cond)
            x = x + dt * velocity
            t = t + dt
            trajectory.append(x)
            if step < last_step:
                # Step to the next grid point measured from the accumulated t.
                dt = t_span[step + 1] - t
        return trajectory[-1]

    def compute_loss(self, x1, mask, mu, spks=None, cond=None):
        """Computes diffusion loss

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Accepted but not forwarded to the estimator here.

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        batch, _, _ = mu.shape
        # random timestep
        t = torch.rand([batch, 1, 1], device=mu.device, dtype=mu.dtype)
        # sample noise p(x_0)
        z = torch.randn_like(x1)

        noise_scale = 1 - self.sigma_min
        y = (1 - noise_scale * t) * z + t * x1
        u = x1 - noise_scale * z

        prediction = self.estimator(y, mask, mu, t.squeeze(), spks)
        normaliser = torch.sum(mask) * u.shape[1]
        loss = F.mse_loss(prediction, u, reduction="sum") / normaliser
        return loss, y
class ConditionalDecoder(nn.Module):
    """1-D conditional U-Net (ResNet + Transformer stages) used as the flow estimator.

    The network has a down path, a stack of mid blocks and an up path with
    skip connections, followed by a final block and a 1x1 output projection.
    Every stage is a ResNet block followed by ``n_blocks`` transformer blocks.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        causal=False,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
        gradient_checkpointing=False,
    ):
        """
        This decoder requires an input with the same shape of the target. So, if your text content
        is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.

        Args:
            in_channels: Channel count of the tensor entering the first ResNet
                block, i.e. after ``forward`` has concatenated ``x``, ``mu``
                and any ``spks`` / ``cond`` along the channel axis.
            out_channels: Channel count of the output.
            causal: Use the causal ResNet/conv variants instead of the
                symmetric ones.
            channels: Output width of each down stage.
            dropout: Dropout passed to the transformer blocks.
            attention_head_dim: Per-head width of the transformer blocks.
            n_blocks: Transformer blocks per stage.
            num_mid_blocks: Number of mid stages.
            num_heads: Attention heads per transformer block.
            act_fn: Feed-forward activation name passed to the transformer
                blocks.
            gradient_checkpointing: Checkpoint the transformer blocks while
                training.
        """
        super().__init__()
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.causal = causal
        self.static_chunk_size = 2 * 25 * 2  # 2*input_frame_rate*token_mel_ratio
        self.gradient_checkpointing = gradient_checkpointing
        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )
        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        resnet_cls = CausalResnetBlock1D if self.causal else ResnetBlock1D

        def make_resnet(dim_in, dim_out):
            return resnet_cls(dim=dim_in, dim_out=dim_out, time_emb_dim=time_embed_dim)

        def make_transformers(dim):
            return nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=dim,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )

        def make_same_length_conv(dim):
            # Used in place of a resampling layer on the last stage of a path.
            if self.causal:
                return CausalConv1d(dim, dim, 3)
            return nn.Conv1d(dim, dim, 3, padding=1)

        # Down path. Sub-modules are created in the order resnet, transformers,
        # resampler so parameter registration order is unchanged.
        output_channel = in_channels
        for i, stage_channel in enumerate(channels):
            input_channel = output_channel
            output_channel = stage_channel
            is_last = i == len(channels) - 1
            resnet = make_resnet(input_channel, output_channel)
            transformer_blocks = make_transformers(output_channel)
            downsample = (
                Downsample1D(output_channel)
                if not is_last
                else make_same_length_conv(output_channel)
            )
            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))

        # Mid path: width stays at the last down-stage width.
        for _ in range(num_mid_blocks):
            input_channel = channels[-1]
            output_channel = channels[-1]
            resnet = make_resnet(input_channel, output_channel)
            transformer_blocks = make_transformers(output_channel)
            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))

        # Up path: mirror of the down path; inputs are doubled by the skip concat.
        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i] * 2
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2
            resnet = make_resnet(input_channel, output_channel)
            transformer_blocks = make_transformers(output_channel)
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else make_same_length_conv(output_channel)
            )
            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))

        self.final_block = (
            CausalBlock1D(channels[-1], channels[-1])
            if self.causal
            else Block1D(channels[-1], channels[-1])
        )
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
        self.initialize_weights()

    def initialize_weights(self):
        """Kaiming-normal init for Conv1d/Linear weights; unit GroupNorm; zero biases."""
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    @staticmethod
    def _checkpoint_wrapper(module):
        """Wrap a transformer block for ``torch.utils.checkpoint``.

        Arguments are forwarded by keyword so the checkpointed call binds
        ``timestep`` exactly like the non-checkpointed call does.
        """
        def custom_forward(hidden_states, attention_mask, timestep):
            return module(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                timestep=timestep,
            )

        return custom_forward

    def _run_transformer_blocks(self, x, mask, t, transformer_blocks):
        """Apply one stage's transformer blocks to ``x`` of shape (b, c, t)."""
        x = rearrange(x, "b c t -> b t c").contiguous()
        # Combine the padding mask with the static chunk mask, then turn it
        # into an additive attention bias.
        attn_mask = add_optional_chunk_mask(
            x, mask.bool(), False, False, 0, self.static_chunk_size, -1
        )
        attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
        for transformer_block in transformer_blocks:
            if self.gradient_checkpointing and self.training:
                x = torch.utils.checkpoint.checkpoint(
                    self._checkpoint_wrapper(transformer_block),
                    x,
                    attn_mask,
                    t,
                )
            else:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
        return rearrange(x, "b t c -> b c t").contiguous()

    def forward(self, x, mask, mu, t, spks=None, cond=None):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (torch.Tensor): shape (batch_size, 1, time)
            mu (torch.Tensor): conditioning, concatenated with ``x`` along the
                channel axis.
            t (torch.Tensor): shape (batch_size)
            spks (torch.Tensor, optional): shape: (batch_size, condition_channels).
                Repeated over time and concatenated along channels. Defaults to None.
            cond (torch.Tensor, optional): extra conditioning concatenated
                along channels. Defaults to None.

        Returns:
            torch.Tensor: output of ``final_proj`` multiplied by ``mask``.
        """
        t = self.time_embeddings(t)
        t = t.to(x.dtype)
        t = self.time_mlp(t)

        x = pack([x, mu], "b * t")[0]
        mask = mask.to(x.dtype)
        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]

        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            x = self._run_transformer_blocks(x, mask_down, t, transformer_blocks)
            hiddens.append(x)  # Save hidden states for skip connections
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, ::2])

        # Drop the mask appended after the last down stage.
        masks = masks[:-1]
        mask_mid = masks[-1]
        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = self._run_transformer_blocks(x, mask_mid, t, transformer_blocks)

        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            skip = hiddens.pop()
            # Crop to the skip length before concatenating along channels.
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            x = resnet(x, mask_up, t)
            x = self._run_transformer_blocks(x, mask_up, t, transformer_blocks)
            x = upsample(x * mask_up)

        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)
        return output * mask
class ConditionalCFM(BASECFM):
    """Conditional flow-matching sampler and training objective.

    ``forward`` draws Gaussian noise and integrates it into a sample with a
    fixed-step Euler solver driven by an external ``estimator``;
    ``compute_loss`` evaluates the flow-matching regression loss at one random
    timestep per batch item.  ``self.sigma_min`` is read but never assigned in
    this class (presumably set by ``BASECFM`` -- confirm against the base).
    """

    def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64):
        """Forward the sizes to ``BASECFM`` and cache scheduler / CFG settings.

        Args:
            in_channels: passed to the base class as ``n_feats``.
            cfm_params: object exposing ``t_scheduler``, ``training_cfg_rate``
                and ``inference_cfg_rate``.
            n_spks: passed through to the base class.
            spk_emb_dim: passed through to the base class.
        """
        super().__init__(
            n_feats=in_channels,
            cfm_params=cfm_params,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
        )
        self.t_scheduler = cfm_params.t_scheduler
        self.training_cfg_rate = cfm_params.training_cfg_rate
        self.inference_cfg_rate = cfm_params.inference_cfg_rate

    @torch.inference_mode()
    def forward(self, estimator, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
        """Sample by integrating the flow from noise.

        Args:
            estimator: callable invoked as
                ``estimator(x, mask, mu, t, spks, cond)``.
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of Euler steps
            temperature (float, optional): scale applied to the initial noise.
                Defaults to 1.0.
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: forwarded unchanged to the estimator.

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """
        z = torch.randn_like(mu) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        if self.t_scheduler == 'cosine':
            # Warp the uniform grid so steps are smaller near t = 0.
            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        return self.solve_euler(estimator, z, t_span=t_span.to(mu.dtype), mu=mu, mask=mask, spks=spks, cond=cond)

    def solve_euler(self, estimator, x, t_span, mu, mask, spks, cond):
        """Integrate the estimator's vector field with fixed Euler steps.

        Only the running state is kept, so memory use does not grow with the
        number of steps.  When ``t_span`` holds a single point no step is taken
        and ``x`` is returned unchanged.

        Args:
            estimator: callable invoked as
                ``estimator(x, mask, mu, t, spks, cond)``.
            x (torch.Tensor): initial state (random noise)
            t_span (torch.Tensor): time grid
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding.
                shape: (batch_size, spk_emb_dim)
            cond: forwarded unchanged to the estimator.

        Returns:
            torch.Tensor: the state after the final step.
        """
        t = t_span[0]
        for step in range(1, len(t_span)):
            # The step size is measured from the accumulated time ``t`` to the
            # next grid point, not between neighbouring grid points.
            dt = t_span[step] - t
            dphi_dt = estimator(x, mask, mu, t, spks, cond)
            # Classifier-free guidance: extrapolate away from the prediction
            # obtained with zeroed conditioning.
            if self.inference_cfg_rate > 0:
                cfg_dphi_dt = estimator(
                    x, mask,
                    torch.zeros_like(mu), t,
                    torch.zeros_like(spks) if spks is not None else None,
                    cond=cond
                )
                dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
                           self.inference_cfg_rate * cfg_dphi_dt)
            x = x + dt * dphi_dt
            t = t + dt
        return x

    def compute_loss(self, estimator, x1, mask, mu, spks=None, cond=None):
        """Compute the conditional flow-matching loss.

        Args:
            estimator: callable invoked as
                ``estimator(y, mask, mu, t, spks, cond)``.
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond (torch.Tensor, optional): extra conditioning; zeroed together
                with ``mu`` / ``spks`` when the condition is dropped.

        Returns:
            loss: masked MSE between prediction and target flow, normalised by
                ``sum(mask) * n_feats`` and cast back to ``x1``'s dtype.
            y: the noisy interpolant fed to the estimator
                shape: (batch_size, n_feats, mel_timesteps)
        """
        org_dtype = x1.dtype
        b, _, _ = mu.shape
        # One random timestep per batch item.
        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            t = 1 - torch.cos(t * 0.5 * torch.pi)
        # Sample noise p(x_0).
        z = torch.randn_like(x1)
        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z
        # During training the condition is randomly dropped per item to trade
        # off mode coverage and sample fidelity.
        if self.training_cfg_rate > 0:
            cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
            mu = mu * cfg_mask.view(-1, 1, 1)
            if spks is not None:
                spks = spks * cfg_mask.view(-1, 1)
            if cond is not None:
                cond = cond * cfg_mask.view(-1, 1, 1)
        pred = estimator(y, mask, mu, t.squeeze(), spks, cond)
        # The loss is accumulated in float32 regardless of the input dtype.
        pred = pred.float()
        u = u.float()
        loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
        loss = loss.to(org_dtype)
        return loss, y
class SinusoidalPosEmb(torch.nn.Module):
    """Sinusoidal embedding of scalar positions / timesteps.

    Each input value is mapped to ``dim // 2`` sines followed by ``dim // 2``
    cosines whose frequencies fall geometrically from 1 to 1 / 10000.
    NOTE: ``dim == 2`` divides by zero in ``forward`` (``half_dim - 1 == 0``).
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"

    def forward(self, x, scale=1000):
        """Embed ``x`` (0-dim or 1-D tensor) into shape ``(len(x), dim)``.

        Args:
            x: positions; a 0-dim tensor is treated as a batch of one.
            scale: multiplier applied to ``x`` before the sinusoids.
        """
        positions = x.unsqueeze(0) if x.ndim < 1 else x
        n_freqs = self.dim // 2
        log_step = math.log(10000) / (n_freqs - 1)
        indices = torch.arange(n_freqs, device=positions.device).float()
        freqs = torch.exp(indices * -log_step)
        angles = scale * positions.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
class Downsample1D(nn.Module):
    """Strided 1-D convolution that keeps the channel count and halves time."""

    def __init__(self, dim):
        super().__init__()
        # kernel 3 / stride 2 / padding 1 -> output length ceil(T / 2)
        self.conv = torch.nn.Conv1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        """Apply the strided convolution to ``x``."""
        downsampled = self.conv(x)
        return downsampled
class TimestepEmbedding(nn.Module):
    """Two-layer MLP applied to a timestep feature vector.

    Computes ``linear_2(act(linear_1(sample)))`` with an optional additive
    conditioning projection on the input and an optional activation on the
    output.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
    ):
        """Build the MLP.

        Args:
            in_channels: width of the incoming feature vector.
            time_embed_dim: hidden width.
            act_fn: activation name resolved through ``get_activation``.
            out_dim: output width; falls back to ``time_embed_dim`` when None.
            post_act_fn: optional activation name applied after ``linear_2``.
            cond_proj_dim: when given, width of the ``condition`` input that is
                projected (without bias) onto ``in_channels``.
        """
        super().__init__()
        self.linear_1 = nn.Linear(in_channels, time_embed_dim)
        self.cond_proj = (
            nn.Linear(cond_proj_dim, in_channels, bias=False)
            if cond_proj_dim is not None
            else None
        )
        self.act = get_activation(act_fn)
        time_embed_dim_out = time_embed_dim if out_dim is None else out_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
        self.post_act = None if post_act_fn is None else get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        """Embed ``sample``.

        Passing ``condition`` requires the layer to have been built with
        ``cond_proj_dim``; otherwise ``cond_proj`` is None and the call fails.
        """
        hidden = sample
        if condition is not None:
            hidden = hidden + self.cond_proj(condition)
        hidden = self.linear_1(hidden)
        if self.act is not None:
            hidden = self.act(hidden)
        hidden = self.linear_2(hidden)
        if self.post_act is None:
            return hidden
        return self.post_act(hidden)
class Upsample1D(nn.Module):
    """1-D upsampling layer that doubles the temporal length.

    Parameters:
        channels (`int`):
            number of channels expected in the inputs.
        use_conv (`bool`, default `False`):
            when no transposed convolution is used, follow the nearest-neighbour
            interpolation with a kernel-3 convolution.
        use_conv_transpose (`bool`, default `True`):
            upsample with a kernel-4 / stride-2 transposed convolution instead
            of interpolation.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        name (`str`, default `"conv"`):
            stored on the module as `self.name`.
    """

    def __init__(
        self,
        channels,
        use_conv=False,
        use_conv_transpose=True,
        out_channels=None,
        name="conv",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        # The transposed convolution takes precedence over the plain one.
        if use_conv_transpose:
            conv = nn.ConvTranspose1d(
                channels, self.out_channels, kernel_size=4, stride=2, padding=1
            )
        elif use_conv:
            conv = nn.Conv1d(
                self.channels, self.out_channels, kernel_size=3, padding=1
            )
        else:
            conv = None
        self.conv = conv

    def forward(self, inputs):
        """Upsample ``inputs``, whose channel axis (dim 1) must equal ``channels``."""
        assert inputs.shape[1] == self.channels
        if self.use_conv_transpose:
            return self.conv(inputs)
        doubled = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
        return self.conv(doubled) if self.use_conv else doubled
class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Root-mean-square layer norm with a learnable per-feature gain and no
        mean subtraction or bias (equivalent to T5LayerNorm).
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Scale ``hidden_states`` by the inverse RMS over the last dimension."""
        # The mean square is always computed in float32.
        mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normed = hidden_states * inv_rms
        # Cast back down only when the gain itself is stored in half precision.
        if self.weight.dtype in (torch.float16, torch.bfloat16):
            normed = normed.to(self.weight.dtype)
        return self.weight * normed
class OmniWhisperAttention(nn.Module):
    """Multi-head self-attention over packed variable-length sequences.

    The input is a 2-D ``(total_tokens, embed_dim)`` tensor holding all
    sequences back to back; ``seq_len`` gives each sequence's length so
    ``flash_attn_varlen_func`` can keep them separate.
    """

    def __init__(self, embed_dim, num_heads, causal=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # The key projection is the only one without a bias.
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.causal = causal

    def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor):
        """Attend within each packed sequence.

        Args:
            hidden_states: packed tokens, shape ``(total_tokens, embed_dim)``.
            seq_len: 1-D tensor of per-sequence lengths.

        Returns:
            Tensor of shape ``(total_tokens, embed_dim)``.
        """
        total_tokens, _ = hidden_states.size()
        per_head = (total_tokens, self.num_heads, self.head_dim)
        query_states = self.q_proj(hidden_states).view(per_head)
        key_states = self.k_proj(hidden_states).view(per_head)
        value_states = self.v_proj(hidden_states).view(per_head)

        # Cumulative boundaries [0, len_0, len_0 + len_1, ...] as int32.
        boundaries = torch.cumsum(seq_len, dim=0)
        cu_len = F.pad(boundaries, (1, 0), "constant", 0).to(torch.int32)
        max_seqlen = torch.max(seq_len).to(torch.int32).detach()

        attn_output = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_len,
            cu_len,
            max_seqlen,
            max_seqlen,
            causal=self.causal,
        )  # (total_tokens, nheads, headdim)
        merged = attn_output.reshape(total_tokens, self.embed_dim)
        return self.out_proj(merged)
class OmniWhisperTransformerLayer(nn.Module):
    """Pre-norm transformer block (self-attention + MLP) for packed sequences."""

    def __init__(
        self,
        act,
        d_model,
        encoder_attention_heads,
        encoder_ffn_dim,
        causal,
        ln_type="LayerNorm",
    ):
        """Build the block.

        Args:
            act: callable activation used between ``fc1`` and ``fc2``.
            d_model: model width.
            encoder_attention_heads: number of attention heads.
            encoder_ffn_dim: hidden width of the feed-forward part.
            causal: forwarded to ``OmniWhisperAttention``.
            ln_type: ``"LayerNorm"`` or ``"RMSNorm"``.

        Raises:
            ValueError: if ``ln_type`` is neither of the supported names.
        """
        super().__init__()
        self.embed_dim = d_model
        self.self_attn = OmniWhisperAttention(
            self.embed_dim, encoder_attention_heads, causal
        )
        self.self_attn_layer_norm = self._make_norm(ln_type, self.embed_dim)
        self.activation_fn = act
        self.fc1 = nn.Linear(self.embed_dim, encoder_ffn_dim)
        self.fc2 = nn.Linear(encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = self._make_norm(ln_type, self.embed_dim)

    @staticmethod
    def _make_norm(ln_type, dim):
        """Return the normalisation layer selected by ``ln_type``."""
        if ln_type == "LayerNorm":
            return nn.LayerNorm(dim)
        if ln_type == "RMSNorm":
            return RMSNorm(dim)
        raise ValueError(f"Unknown ln_type: {ln_type}")

    def forward(
        self, hidden_states: torch.Tensor, seq_len: torch.Tensor
    ) -> torch.Tensor:
        """Run attention and feed-forward, each with a residual connection.

        Args:
            hidden_states: packed tokens, passed on to the attention module.
            seq_len: per-sequence lengths, passed on to the attention module.
        """
        attn_out = self.self_attn(self.self_attn_layer_norm(hidden_states), seq_len)
        hidden_states = hidden_states + attn_out

        ffn_in = self.final_layer_norm(hidden_states)
        ffn_out = self.fc2(self.activation_fn(self.fc1(ffn_in)))
        hidden_states = hidden_states + ffn_out

        # In half precision, clamp just below the dtype maximum when any
        # inf / nan shows up.
        is_half = hidden_states.dtype in (torch.float16, torch.bfloat16)
        if is_half and (
            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            limit = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-limit, max=limit)
        return hidden_states
class LongcatNextAudioEncoder(nn.Module):
    """Convolutional front-end followed by non-causal transformer layers.

    Mel features go through two GELU-activated 1-D convolutions (the second one
    strided), receive sinusoidal position embeddings, are packed into one
    ``(total_frames, d_model)`` tensor for the transformer stack, and are
    finally unpacked with padded positions zeroed.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        mel_frames = config.max_audio_seconds * config.sampling_rate // config.hop_length
        self.max_source_positions = mel_frames // config.stride_size
        # Stored on the module; this class's forward does not read it.
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
        self.conv1 = nn.Conv1d(
            config.num_mel_bins,
            config.d_model,
            kernel_size=config.kernel_size,
            padding=1,
        )
        self.conv2 = nn.Conv1d(
            config.d_model,
            config.d_model,
            kernel_size=config.kernel_size,
            stride=config.stride_size,
            padding=1,
        )
        # Shape: (max_source_positions, d_model)
        self.register_buffer(
            "positional_embedding",
            sinusoids(self.max_source_positions, config.d_model),
        )
        self.layers = nn.ModuleList(
            [
                OmniWhisperTransformerLayer(
                    ACT2FN[config.activation_function],
                    config.d_model,
                    config.encoder_attention_heads,
                    config.encoder_ffn_dim,
                    False,  # non-causal
                )
                for _ in range(config.encoder_layers)
            ]
        )
        self.layer_norm = nn.LayerNorm(config.d_model)

    def forward(
        self,
        input_features,
        output_length,
    ):
        """Encode a padded batch of mel features.

        Args:
            input_features: ``(bs, num_mel_bins, frames)`` input to ``conv1``.
            output_length: per-item number of valid frames after the
                convolutions, shape ``(bs,)``.

        Returns:
            ``(bs, frames_after_conv, d_model)`` tensor with padding zeroed.
        """
        features = input_features.to(self.conv1.weight.dtype)
        embeds = nn.functional.gelu(self.conv1(features))  # (bs, channels, frames)
        embeds = nn.functional.gelu(self.conv2(embeds))  # (bs, channels, strided frames)
        embeds = embeds.permute(0, 2, 1)  # (bs, frames, channels)
        bsz, tgt_len, _ = embeds.size()

        # Slicing past the end of the table simply returns the whole table.
        positions = self.positional_embedding[:tgt_len]
        hidden_states = (embeds.to(torch.float32) + positions).to(embeds.dtype)

        # Pack: keep only valid frames, laid out back to back.
        attention_mask, unpacking_index = get_sequence_mask(hidden_states, output_length)
        packed = torch.masked_select(hidden_states, attention_mask).view(
            torch.sum(output_length), self.config.d_model
        )
        for encoder_layer in self.layers:
            packed = encoder_layer(packed, output_length)
        packed = self.layer_norm(packed)

        # Unpack to (bs, frames, d_model) and zero the padded positions.
        unpacked = torch.index_select(packed, 0, unpacking_index).view(
            bsz, tgt_len, self.config.d_model
        )
        return torch.where(attention_mask, unpacked, 0)
class CasualConvTranspose1d(nn.Module):
    """Transposed 1-D convolution + GroupNorm with the right-hand tail trimmed.

    The last ``max(0, kernel_size - stride)`` output frames are dropped.
    Accepts packed 2-D or padded 3-D input and can return either layout.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)
        self.norm = nn.GroupNorm(1, out_channels)
        self.in_channels = in_channels
        self.out_channels = out_channels

    def forward(self, hidden_states, input_length, output_dim=None):
        """Upsample ``hidden_states`` along time.

        Args:
            hidden_states: packed ``(total_frames, in_channels)`` or padded
                ``(bsz, frames, in_channels)`` tensor.
            input_length: valid length of every batch item, shape ``(bsz,)``.
            output_dim: rank of the returned tensor; defaults to the rank of
                the input (``<= 2`` packed, otherwise padded 3-D).

        Returns:
            Tuple ``(hidden_states, output_length)``.
        """
        kernel_size = self.conv.kernel_size[0]
        stride = self.conv.stride[0]
        bsz = input_length.shape[0]
        if output_dim is None:
            output_dim = hidden_states.dim()

        if hidden_states.dim() <= 2:
            # Unpack the flat sequence into a zero-padded 3-D batch.
            in_mask, unpacking_index = get_sequence_mask(hidden_states, input_length)
            gathered = torch.index_select(hidden_states, 0, unpacking_index)
            gathered = gathered.view(bsz, torch.max(input_length), self.in_channels)
            hidden_states = torch.where(in_mask, gathered, 0)  # (bsz, max_input_len, d)

        channels_first = hidden_states.transpose(2, 1)  # (N, L, C) -> (N, C, L)
        channels_first = self.norm(self.conv(channels_first))
        hidden_states = channels_first.transpose(2, 1)  # (N, C, L) -> (N, L, C)

        # Drop the frames produced by the part of the kernel beyond one stride.
        casual_padding_right = max(0, kernel_size - stride)
        kept_frames = hidden_states.shape[1] - casual_padding_right
        hidden_states = hidden_states[:, :kept_frames, :]
        output_length = (input_length - 1) * stride + kernel_size - casual_padding_right

        out_mask, _ = get_sequence_mask(hidden_states, output_length)
        if output_dim <= 2:
            packed = torch.masked_select(hidden_states, out_mask)
            hidden_states = packed.view(-1, self.out_channels)
        else:
            hidden_states = torch.where(out_mask, hidden_states, 0)
            hidden_states = hidden_states[:, :torch.max(output_length), :]
        return hidden_states, output_length
class MelSpecRefineNet(nn.Module):
    """
    Post-net that turns coarse mel frames into refined ones: the coarse mel is
    resampled to the vocoder frame rate, passed through a convolutional stack
    and added back as a residual.

    # ref1: Autoregressive Speech Synthesis without Vector Quantization
    # ref2: CosyVoice length_regulator.py
    # ref3: Neural Speech Synthesis with Transformer Network https://github.com/soobinseo/Transformer-TTS/blob/master/network.py
    """

    def __init__(self, encoder_config, vocoder_config):
        super().__init__()
        self.encoder_config = encoder_config
        self.vocoder_config = vocoder_config

        n_mels = self.vocoder_config.num_mel_bins
        stack = []
        prev_width = n_mels
        # One Conv1d -> GroupNorm -> Mish group per entry of channels[:-1].
        for width in self.vocoder_config.channels[:-1]:
            stack += [
                nn.Conv1d(prev_width, width, kernel_size=5, stride=1, padding=2),  # cosyvoice kernel=3, stride=1, pad=1
                nn.GroupNorm(1, width),
                nn.Mish(),
            ]
            prev_width = width
        # Pointwise projection back to the mel dimension.
        stack.append(nn.Conv1d(prev_width, n_mels, kernel_size=1, stride=1))
        self.layers = nn.Sequential(*stack)

    def compute_output_length(self, input_length):
        """Convert frame counts from the encoder frame rate to the vocoder's.

        The fractional part of the result is truncated.
        """
        seconds = input_length.to(
            torch.float32) * self.encoder_config.hop_length / self.encoder_config.sampling_rate
        frames = seconds * self.vocoder_config.sampling_rate / self.vocoder_config.hop_length
        return frames.to(torch.int64)

    def forward(self, coarse_mel, input_length, output_length=None):
        """Refine a padded batch of coarse mel frames.

        Args:
            coarse_mel: ``(bs, frames, num_mel_bins)`` tensor.
            input_length: valid frames per item, shape ``(bs,)``.
            output_length: target lengths; only honoured in training mode,
                otherwise recomputed from ``input_length``.

        Returns:
            Tuple ``(refined_mel, coarse_mel, output_length)``; both mels are
            ``(bs, max(output_length), num_mel_bins)`` with padding zeroed.
        """
        bsz, _, d = coarse_mel.shape
        assert (d == self.vocoder_config.num_mel_bins)
        if output_length is None or not self.training:
            output_length = self.compute_output_length(input_length)

        default_dtype = coarse_mel.dtype
        trimmed = coarse_mel[:, :torch.max(input_length), :]
        # Nearest-neighbour resampling is done in float32, channels first.
        channels_first = trimmed.to(torch.float32).transpose(1, 2).contiguous()
        coarse_mel = F.interpolate(
            channels_first, size=output_length.max(), mode='nearest'
        ).to(default_dtype)

        refined_mel = self.layers(coarse_mel).transpose(1, 2).contiguous()  # (bs, t, d)
        coarse_mel = coarse_mel.transpose(1, 2)  # (bs, max(output_length), d)
        refined_mel += coarse_mel  # residual connection

        sequence_mask, _ = get_sequence_mask(refined_mel, output_length)
        coarse_mel = torch.where(sequence_mask, coarse_mel, 0)
        refined_mel = torch.where(sequence_mask, refined_mel, 0)
        return refined_mel, coarse_mel, output_length
@dataclass
class OmniAudioDecoderOutput(ModelOutput):
    """Result container returned by ``LongcatNextAudioDecoder.forward``.

    Every field defaults to ``None``.
    """

    # Mel frames after the post-net (first value returned by ``post_net``).
    refined_mel: Optional[torch.FloatTensor] = None
    # Resampled, masked mel before refinement (second value returned by ``post_net``).
    coarse_mel: Optional[torch.FloatTensor] = None
    # Per-item mel lengths (third value returned by ``post_net``).
    mel_length: Optional[torch.Tensor] = None
    # Packed transformer output after the final layer norm, i.e. the input of ``dconv2``.
    hidden_states_before_dconv2: Optional[torch.FloatTensor] = None
    # Per-item lengths matching ``hidden_states_before_dconv2`` (lengths produced by ``dconv1``).
    output_length_before_dconv2: Optional[torch.Tensor] = None
class LongcatNextAudioDecoder(nn.Module):
    """Decode audio embeddings into coarse and refined mel spectrograms.

    Pipeline: transposed-conv upsampling (``dconv1``) -> sinusoidal positions
    -> causal transformer layers on packed frames -> layer norm ->
    transposed-conv projection to mel bins (``dconv2``) -> ``post_net``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.vocoder_config = config.vocoder_config
        cfg = self.config
        self.max_source_positions = cfg.max_audio_seconds * cfg.sampling_rate // cfg.hop_length
        self.dconv1 = CasualConvTranspose1d(
            cfg.d_model,
            cfg.d_model,
            cfg.decoder_kernel_size,
            cfg.avg_pooler,
        )
        self.register_buffer(
            "positional_embedding",
            sinusoids(self.max_source_positions, cfg.d_model),
        )
        # Causal transformer layers.
        self.layers = nn.ModuleList(
            [
                OmniWhisperTransformerLayer(
                    ACT2FN[cfg.activation_function],
                    cfg.d_model,
                    cfg.decoder_attention_heads,
                    cfg.decoder_ffn_dim,
                    True,  # causal
                )
                for _ in range(cfg.decoder_layers)
            ]
        )
        self.layer_norm = nn.LayerNorm(cfg.d_model)
        self.dconv2 = CasualConvTranspose1d(
            cfg.d_model,
            self.vocoder_config.num_mel_bins,
            cfg.decoder_kernel_size,
            cfg.decoder_stride_size,
        )
        self.post_net = MelSpecRefineNet(cfg, self.vocoder_config)
        self.gradient_checkpointing = False

    def forward(self,
                audio_embed,
                input_length,
                mel_labels=None,
                mel_labels_length=None,
                ):
        """Decode ``audio_embed`` into mel spectrograms.

        Args:
            audio_embed: embeddings whose last dimension is ``d_model``; packed
                or padded, as accepted by ``dconv1``.
            input_length: per-item lengths of ``audio_embed``.
            mel_labels: accepted but not read by this method.
            mel_labels_length: optional target mel lengths forwarded to
                ``post_net``.

        Returns:
            OmniAudioDecoderOutput with refined / coarse mel, mel lengths and
            the packed hidden states (with lengths) that were fed to ``dconv2``.
        """
        assert (audio_embed.shape[-1] == self.config.d_model)
        audio_embed = audio_embed.to(self.layer_norm.weight)  # device and type
        upsampled, frame_length = self.dconv1(audio_embed, input_length, output_dim=3)  # (b, upsampled_len, d_model)
        _, tgt_len, _ = upsampled.size()

        # Slicing past the end of the table simply returns the whole table.
        positions = self.positional_embedding[:tgt_len]
        hidden_states = (upsampled.to(torch.float32) + positions).to(upsampled.dtype)

        # Pack: keep only valid frames, laid out back to back.
        attention_mask, _ = get_sequence_mask(hidden_states, frame_length)
        packed = torch.masked_select(hidden_states, attention_mask).view(
            torch.sum(frame_length), self.config.d_model
        )
        for decoder_layer in self.layers:
            packed = decoder_layer(packed, frame_length)
        packed = self.layer_norm(packed)

        coarse_mel, mel_frames = self.dconv2(packed, frame_length, output_dim=3)
        refined_mel, coarse_mel, mel_labels_length = self.post_net(
            coarse_mel, mel_frames, mel_labels_length
        )
        return OmniAudioDecoderOutput(
            refined_mel=refined_mel,
            coarse_mel=coarse_mel,
            mel_length=mel_labels_length,
            hidden_states_before_dconv2=packed,
            output_length_before_dconv2=frame_length,
        )
class LongcatNextAudioVQBridger(nn.Module):
    """Pool encoder frames and quantise them with a residual VQ stack.

    ``forward`` groups ``avg_pooler`` consecutive frames through a gated
    convolutional block and returns one code id per quantiser for every valid
    pooled frame.  ``decode`` / ``recover`` map code ids back to vectors.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.gradient_checkpointing = False
        pool = self.config.avg_pooler
        d_model = self.config.d_model
        self.intermediate_dim = d_model * pool
        # Kernel size == stride == avg_pooler: non-overlapping windows.
        self.gate_proj = nn.Conv1d(d_model, self.intermediate_dim, pool, pool, bias=False)
        self.up_proj = nn.Conv1d(d_model, self.intermediate_dim, pool, pool, bias=False)
        self.down_proj = nn.Linear(self.intermediate_dim, self.intermediate_dim, bias=False)
        self.act_fn = ACT2FN['silu']
        self.layer_norm = nn.LayerNorm(self.intermediate_dim)
        self.proj_decoder = nn.Linear(self.intermediate_dim, d_model)
        self.vq_list = nn.ModuleList([])
        # One quantiser per configured codebook size, each with its own config copy.
        for codebook_size in self.config.vq_config.codebook_sizes:
            layer_config = copy.deepcopy(self.config.vq_config)
            layer_config.dim = self.intermediate_dim
            layer_config.codebook_size = codebook_size
            self.vq_list.append(VectorQuantize(layer_config))

    def rvq_op(self, inputs, output_length):
        """Run residual vector quantisation; return ids stacked on the last axis."""
        residual = inputs
        ids_per_layer = []
        for vq_layer in self.vq_list:
            quantized, code_ids = vq_layer(residual, output_length)
            # Subtract in float32, then return to the input dtype.
            residual = (residual.float() - quantized.float()).to(inputs.dtype)
            ids_per_layer.append(code_ids)
        return torch.stack(ids_per_layer, -1)

    def forward(self, x, output_length):
        """Quantise ``x``.

        Args:
            x: ``(bs, frames, d_model)`` encoder output.
            output_length: number of valid pooled frames per item.

        Returns:
            ``(sum(output_length), num_quantisers)`` tensor of code ids.
        """
        batch_size, _, _ = x.shape
        output_length = output_length.to(x.device)
        pool = self.config.avg_pooler
        remainder = x.shape[1] % pool
        if remainder != 0:
            # Zero-pad the time axis up to a multiple of avg_pooler.
            x = F.pad(x, (0, 0, 0, pool - remainder), "constant", 0)
        channels_first = x.permute(0, 2, 1)
        gate = self.gate_proj(channels_first).permute(0, 2, 1)  # (bs, ceil(sl / pool), d_model * pool)
        up = self.up_proj(channels_first).permute(0, 2, 1)
        # Residual path: concatenate every `pool` consecutive frames.
        grouped = x.reshape(batch_size, -1, self.intermediate_dim)  # (bs, ceil(sl / pool), d_model * pool)
        mixed = self.down_proj(self.act_fn(gate) * up)
        res = self.layer_norm(mixed + grouped)
        valid_mask, _ = get_sequence_mask(res, output_length)
        code_ids = self.rvq_op(res, output_length)
        # Keep only valid frames -> (sum(valid_sequence_length), vq_num)
        return torch.masked_select(code_ids, valid_mask).reshape(-1, len(self.vq_list))

    def _sum_codebook_vectors(self, code_ids):
        """Sum the vectors of all quantisers (last to first) in float32."""
        vq_num = code_ids.shape[-1]
        total = sum(
            self.vq_list[i].get_output_from_indices(code_ids[:, i]).float()
            for i in range(vq_num - 1, -1, -1)
        )
        return total.to(self.proj_decoder.weight)

    @torch.no_grad()
    def decode(self, code_ids):
        """Map code ids to decoder embeddings via ``proj_decoder``."""
        return self.proj_decoder(self._sum_codebook_vectors(code_ids))

    @torch.no_grad()
    def recover(self, code_ids):
        """Map code ids back to the summed codebook vectors (no projection)."""
        return self._sum_codebook_vectors(code_ids)
class FlowmatchingPrenet(nn.Module):
    """Length-regulating transformer in front of the flow-matching decoder.

    Input features are resampled to the target mel length, projected by an MLP,
    processed by causal RMSNorm transformer layers on packed frames and
    projected to ``out_feat_dim``.  A ``positional_embedding`` buffer is
    registered, but ``forward`` does not add it.
    """

    def __init__(
        self,
        input_feat_dim,
        out_feat_dim,
        d_model,
        attention_heads,
        ffn_dim,
        nlayers,
        activation_function,
        max_source_positions,
        target_mel_length_scale_ratio,
    ):
        super().__init__()
        self.d_model = d_model
        self.target_mel_length_scale_ratio = target_mel_length_scale_ratio
        self.gradient_checkpointing = False
        self.register_buffer(
            "positional_embedding", sinusoids(max_source_positions, d_model)
        )
        hidden_width = d_model * 4
        self.in_mlp = nn.Sequential(
            nn.Linear(input_feat_dim, hidden_width),
            nn.SiLU(),
            nn.Linear(hidden_width, d_model),
        )
        blocks = []
        for _ in range(nlayers):
            blocks.append(
                OmniWhisperTransformerLayer(
                    act=ACT2FN[activation_function],
                    d_model=d_model,
                    encoder_attention_heads=attention_heads,
                    encoder_ffn_dim=ffn_dim,
                    causal=True,  # causal
                    ln_type="RMSNorm",
                )
            )
        self.transformer_layers = nn.ModuleList(blocks)
        self.final_norm = RMSNorm(self.d_model)
        self.out_proj = nn.Linear(d_model, out_feat_dim, bias=False)

    def compute_output_length(self, input_length):
        """Scale lengths by ``target_mel_length_scale_ratio`` and truncate to int64."""
        scaled = input_length.float() * self.target_mel_length_scale_ratio
        return scaled.to(torch.int64)

    def forward(self, input_feat, input_length, output_length=None):
        """
        Args:
            input_feat: [B, T, input_feat_dim]
            input_length: [B]
            output_length: [B]; only honoured in training mode, otherwise
                recomputed from ``input_length``.

        Returns:
            Tuple ``(output, output_length)`` where ``output`` is
            [B, max(output_length), out_feat_dim].
        """
        if output_length is None or not self.training:
            output_length = self.compute_output_length(input_length)

        trimmed = input_feat[:, : input_length.max(), :]  # [B, T, D]
        orig_dtype = trimmed.dtype
        # Nearest-neighbour resampling in float32, channels first.
        resampled = F.interpolate(
            input=trimmed.to(torch.float32).transpose(1, 2).contiguous(),
            size=output_length.max(),
            mode="nearest",
        ).to(orig_dtype)
        resampled = resampled.transpose(1, 2).contiguous()  # [B, T, D]
        hidden_states = self.in_mlp(resampled)

        # Pack: keep only valid frames, laid out back to back.
        bsz, tgt_len, d_model = hidden_states.shape
        attention_mask, unpacking_index = get_sequence_mask(
            hidden_states, output_length
        )
        packed = torch.masked_select(hidden_states, attention_mask).view(
            torch.sum(output_length), self.d_model
        )
        for layer in self.transformer_layers:
            packed = layer(packed, output_length)

        # Unpack and zero the padded positions before the final norm.
        unpacked = torch.index_select(packed, 0, unpacking_index).view(
            bsz, tgt_len, d_model
        )
        unpacked = torch.where(attention_mask, unpacked, 0)
        output = self.out_proj(self.final_norm(unpacked))
        return output, output_length
@dataclass
class OmniAudioFlowMatchingDecoderOutput(ModelOutput):
    """Result container returned by ``LongcatNextAudioFlowMatchingDecoder.forward``."""

    # Generated mel, time-major: the sampler output with its last two axes swapped.
    flow_matching_mel: Optional[torch.FloatTensor] = None
    # Per-item mel lengths.  NOTE(review): annotated as FloatTensor, but the
    # lengths produced by ``compute_output_length`` are int64 -- confirm.
    flow_matching_mel_lengths: Optional[torch.FloatTensor] = None
class LongcatNextAudioFlowMatchingDecoder(nn.Module):
    """Flow-matching mel generator: prenet -> conditional decoder -> CFM sampler."""

    def __init__(self, config):
        """Build the prenet, estimator and sampler from ``config.flow_matching_config``."""
        super().__init__()
        self.config = config.flow_matching_config
        self.in_channels = self.config.in_channels
        self.spk_emb_dim = self.config.spk_emb_dim
        self.diffusion_steps = self.config.diffusion_steps
        self.cal_mel_mae = self.config.cal_mel_mae
        # Counts forward calls; incremented at the start of every call.
        self.forward_step = -1
        self.prenet = FlowmatchingPrenet(
            input_feat_dim=self.config.prenet_in_dim,
            out_feat_dim=self.config.prenet_out_dim,
            d_model=self.config.prenet_d_model,
            attention_heads=self.config.prenet_attention_heads,
            ffn_dim=self.config.prenet_ffn_dim,
            nlayers=self.config.prenet_nlayers,
            activation_function=self.config.prenet_activation_function,
            max_source_positions=self.config.prenet_max_source_positions,
            target_mel_length_scale_ratio=self.config.prenet_target_mel_length_scale_ratio,
        )
        self.conditional_decoder = ConditionalDecoder(
            in_channels=self.in_channels * 2 + self.spk_emb_dim,
            out_channels=self.in_channels,
            causal=True,
            channels=self.config.channels,
            dropout=self.config.dropout,
            attention_head_dim=self.config.attention_head_dim,
            n_blocks=self.config.n_blocks,
            num_mid_blocks=self.config.num_mid_blocks,
            num_heads=self.config.num_heads,
            act_fn=self.config.act_fn,
        )
        self.cfm = ConditionalCFM(
            in_channels=self.in_channels,
            cfm_params=self.config.cfm_params,
            n_spks=0,
            spk_emb_dim=self.spk_emb_dim,
        )

    def unpack_hidden_states(self, hidden_states, output_length):
        """Unpack with the module-level helper and return the lengths unchanged."""
        unpacked = unpack_hidden_states(hidden_states, output_length)
        return unpacked, output_length

    def forward(
        self, refined_mel, input_length, mel_labels=None, mel_labels_length=None
    ):
        """Generate a mel spectrogram with the flow-matching sampler.

        :param refined_mel: [bs, max_input_len, feat_dim] conditioning features
        :param input_length: [batch_size] valid length of every item
        :param mel_labels: accepted but not read by this method
        :param mel_labels_length: optional target lengths handed to the prenet
            (the prenet only honours them in training mode)
        :return: OmniAudioFlowMatchingDecoderOutput holding the generated mel
            as [bs, max_len, mel_bin] and the per-item lengths actually used to
            build it, so both fields always describe the same frames.
        """
        self.forward_step += 1
        if self.prenet is not None:
            refined_mel = refined_mel[:, : torch.max(input_length), :]
            if mel_labels_length is None:
                mel_labels_length = self.prenet.compute_output_length(input_length)
            # The prenet returns the lengths it really used; from here on
            # ``input_length`` refers to the resampled sequence.
            refined_mel, input_length = self.prenet(
                refined_mel, input_length, mel_labels_length
            )
        float_dtype = refined_mel.dtype
        refined_mel = refined_mel.float()
        input_length = input_length.long()
        refined_mel = refined_mel[:, : torch.max(input_length), :]
        sequence_mask, _ = get_sequence_mask(refined_mel, input_length)
        refined_mel = refined_mel.transpose(1, 2)  # (bs, mel_bin, max_input_len)
        sequence_mask = sequence_mask.transpose(2, 1)  # (bs, 1, sl)
        fm_mel = self.cfm.forward(
            estimator=self.conditional_decoder,
            mu=refined_mel.to(float_dtype),
            mask=sequence_mask.float(),
            n_timesteps=self.diffusion_steps,
        )
        return OmniAudioFlowMatchingDecoderOutput(
            flow_matching_mel=fm_mel.transpose(1, 2),
            # Report the lengths the mel was generated with, not the caller's
            # request, which the prenet may have overridden.
            flow_matching_mel_lengths=input_length,
        )
@torch.no_grad()
def decode_wave_vocoder2(response, vocoder, audio_tokenizer):
    """Decode a batch of audio-token sequences into waveforms.

    Args:
        response: ``(batch, steps, n_codebooks)`` token ids.  A row ends at the
            first step whose first codebook equals the end token
            (``codebook_sizes[0]``); rows with no end token, or with it at step
            0, count as empty.
        vocoder: object whose ``decode`` turns a ``(1, mel_bin, frames)`` mel
            into a waveform.
        audio_tokenizer: object exposing ``config`` and
            ``decode(ids, bridge_length=...)``.

    Returns:
        ``[]`` when every row is empty; otherwise one entry per row, ``None``
        for empty rows and the decoded result for the others.
    """
    end_token = audio_tokenizer.config.audio_config.vq_config.codebook_sizes[0]
    response_len = (response[:, :, 0] == end_token).long().argmax(dim=1)
    lengths = response_len.tolist()

    kept = [response[row, :length, :] for row, length in enumerate(lengths) if length > 0]
    if not kept:
        return []
    flat = kept[0] if len(kept) == 1 else torch.cat(kept, dim=0)
    ret = audio_tokenizer.decode(
        flat.view(-1, response.shape[-1]),
        bridge_length=response_len[response_len > 0],
    )

    returns_tensor = isinstance(ret, torch.Tensor)
    waves = []
    cursor = 0  # index into the decoded (non-empty) items
    for length in lengths:
        if length == 0:
            waves.append(None)
            continue
        if returns_tensor:
            waves.append(ret[cursor:cursor + 1])
        else:
            n_frames = ret.flow_matching_mel_lengths[cursor]
            mel = ret.flow_matching_mel[cursor][:n_frames, :]
            wave = vocoder.decode(mel.transpose(0, 1).to(torch.float32).unsqueeze(0))
            waves.append(wave.cpu())
        cursor += 1
    return waves
@torch.no_grad()
def decode_save_concat2(response_list, vocoder, model, path, sampling_rate=16000, wave_concat_overlap=800):
    """Decode several token batches, join the waveforms and save them.

    Between two neighbouring waveforms that are both longer than
    ``wave_concat_overlap`` samples, a blended segment (tail of the first
    fading out plus head of the second fading in) is inserted; the two
    waveforms themselves are kept in full.

    Args:
        response_list: iterable of token batches for ``decode_wave_vocoder2``.
        vocoder: vocoder passed through to ``decode_wave_vocoder2``.
        model: audio tokenizer passed through to ``decode_wave_vocoder2``.
        path: destination handed to ``torchaudio.save``.
        sampling_rate: sample rate handed to ``torchaudio.save``.
        wave_concat_overlap: length in samples of the blended segment.

    Raises:
        IndexError: when no response decodes to a waveform.
    """
    wave_list = []
    for response in response_list:
        for wave in decode_wave_vocoder2(response, vocoder, model):
            if wave is not None:
                wave_list.append(wave)

    pieces = [wave_list[0]]
    for wave in wave_list[1:]:
        previous = pieces[-1]
        if previous.shape[1] > wave_concat_overlap and wave.shape[1] > wave_concat_overlap:
            fade_out = torch.linspace(1.0, 0.0, wave_concat_overlap, device=previous.device)[None, :]
            fade_in = torch.linspace(0.0, 1.0, wave_concat_overlap, device=previous.device)[None, :]
            blended = (previous[:, -wave_concat_overlap:] * fade_out
                       + wave[:, :wave_concat_overlap] * fade_in)
            pieces.append(blended)
        pieces.append(wave)

    full_wave = pieces[0] if len(pieces) == 1 else torch.cat(pieces, dim=1)
    torchaudio.save(path, full_wave, sampling_rate)
class LongcatNextAudioTokenizer(nn.Module):
    """Audio tokenizer: encoder + VQ bridge for encoding, decoder +
    flow-matching decoder (+ lazily loaded vocoder) for synthesis."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        audio_config = config.audio_config
        self.audio_model = LongcatNextAudioEncoder(audio_config)
        self.audio_bridge_model = LongcatNextAudioVQBridger(audio_config)
        self.audio_decoder = LongcatNextAudioDecoder(audio_config)
        self.audio_flow_matching_decoder = LongcatNextAudioFlowMatchingDecoder(audio_config)
        # Loaded on first use by ``lazy_decode_and_save``.
        self.cosy24kvocoder = None

    @torch.no_grad()
    def encode(self, x, encoder_length: Optional[torch.Tensor] = None, bridge_length: Optional[torch.Tensor] = None):
        """Encode features ``x`` into audio token ids.

        ``encoder_length`` is passed to the encoder and ``bridge_length`` to
        the VQ bridge as their per-item length arguments.
        """
        encoded = self.audio_model(x, encoder_length)
        return self.audio_bridge_model(encoded, bridge_length)

    @torch.no_grad()
    def decode(self, audio_ids, bridge_length: Optional[torch.Tensor] = None):
        """Decode token ids into the flow-matching decoder's output.

        Depending on ``use_hidden_states_before_dconv2`` the flow-matching
        decoder is fed either the audio decoder's unpacked hidden states or its
        refined mel.
        """
        audio_emb = self.audio_bridge_model.decode(audio_ids)
        # Match the decoder's device and dtype.
        decoder_param = next(self.audio_decoder.parameters())
        audio_dec = self.audio_decoder(audio_emb.to(decoder_param), bridge_length)

        fm_decoder = self.audio_flow_matching_decoder
        if not self.config.audio_config.flow_matching_config.use_hidden_states_before_dconv2:
            return fm_decoder(audio_dec.refined_mel, audio_dec.mel_length)
        hidden_states, hidden_states_length = fm_decoder.unpack_hidden_states(
            audio_dec.hidden_states_before_dconv2,
            audio_dec.output_length_before_dconv2,
        )
        return fm_decoder(hidden_states, hidden_states_length)

    @torch.no_grad()
    def lazy_decode_and_save(self, audio_ids, sampling_rate, wave_concat_overlap, save_path):
        """Split ``audio_ids`` at end tokens, decode every chunk and save the audio.

        Args:
            audio_ids: ``(steps, n_codebooks)`` token ids; a step whose first
                codebook equals ``codebook_sizes[0]`` closes a chunk.
            sampling_rate: sample rate of the written file.
            wave_concat_overlap: blend length in samples used when joining chunks.
            save_path: destination of the audio file.
        """
        if self.cosy24kvocoder is None:
            print("lazy load cosy24kvocoder ...")
            device = next(self.parameters()).device
            weight_path = self.config.audio_config.cosy24kvocoder_config.weight_path
            self.cosy24kvocoder = Cosy24kVocoder.from_pretrained(weight_path).to(device)

        end_token = self.config.audio_config.vq_config.codebook_sizes[0]
        if audio_ids[-1, 0] != end_token:  # exceed max_new_tokens
            # Close the last chunk with an explicit end-token row.
            audio_ids = F.pad(audio_ids, (0, 0, 0, 1), value=end_token)

        end_positions = (audio_ids[:, 0] == end_token).nonzero().view(-1).tolist()
        boundaries = [-1] + end_positions
        # Every chunk runs from just after the previous end token up to and
        # including its own end token.
        chunks = [
            audio_ids[previous_end + 1:this_end + 1].unsqueeze(0)
            for previous_end, this_end in zip(boundaries, boundaries[1:])
        ]
        decode_save_concat2(
            response_list=chunks,
            vocoder=self.cosy24kvocoder,
            model=self,
            path=save_path,
            sampling_rate=sampling_rate,
            wave_concat_overlap=wave_concat_overlap,
        )