# Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch MossAudioTokenizer model."""
from __future__ import annotations
import copy
import math
from contextlib import ExitStack, contextmanager
from dataclasses import dataclass
from typing import cast
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_utils import PreTrainedAudioTokenizerBase
from transformers.utils import ModelOutput, auto_docstring, logging
from .configuration_moss_audio_tokenizer import MossAudioTokenizerConfig
logger = logging.get_logger(__name__)
# =============================================================================
# Output Classes
# =============================================================================
@dataclass
@auto_docstring
class MossAudioTokenizerEncoderOutput(ModelOutput):
r"""
audio_codes (`torch.LongTensor` of shape `(num_quantizers, batch_size, sequence_length)`, *optional*):
Discrete audio codes computed using the encoder and quantizer.
audio_codes_lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Valid lengths for each sample's audio codes.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, hidden_size, sequence_length)`, *optional*):
Hidden states from the encoder before quantization.
"""
audio_codes: torch.Tensor | None = None
audio_codes_lengths: torch.Tensor | None = None
encoder_hidden_states: torch.Tensor | None = None
@dataclass
@auto_docstring
class MossAudioTokenizerDecoderOutput(ModelOutput):
r"""
audio (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*):
Decoded audio waveform.
audio_lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Valid lengths for each sample's audio.
"""
audio: torch.Tensor | None = None
audio_lengths: torch.Tensor | None = None
@dataclass
@auto_docstring
class MossAudioTokenizerOutput(ModelOutput):
r"""
audio (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*):
Decoded audio waveform.
audio_lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Valid lengths for each sample's audio.
audio_codes (`torch.LongTensor` of shape `(num_quantizers, batch_size, sequence_length)`, *optional*):
Discrete audio codes computed using the encoder and quantizer.
audio_codes_lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Valid lengths for each sample's audio codes.
"""
audio: torch.Tensor | None = None
audio_lengths: torch.Tensor | None = None
audio_codes: torch.Tensor | None = None
audio_codes_lengths: torch.Tensor | None = None
# =============================================================================
# Streaming Module Base Classes
# =============================================================================
@dataclass
class StreamingState:
"""Base state for streaming modules."""
batch_size: int
device: torch.device
def __post_init__(self):
self.exec_mask = torch.ones(self.batch_size, dtype=torch.bool, device=self.device)
def set_exec_mask(self, exec_mask: torch.Tensor):
self.exec_mask[:] = exec_mask
def reset(self, reset_mask: torch.Tensor) -> None:
self.exec_mask[:] = torch.where(reset_mask, torch.ones_like(self.exec_mask), self.exec_mask)
def __enter__(self):
# ExitStack expects a context manager; returning self is conventional and useful for debugging.
return self
def __exit__(self, exc_type, exc_value, traceback) -> None:
pass
class StreamingModule(nn.Module):
"""Base class for streaming components."""
def __init__(self) -> None:
super().__init__()
self._streaming_state: StreamingState | None = None
self._streaming_detached: bool = False
self._cached_children: list[tuple[str, StreamingModule]] | None = None
@property
def is_streaming(self):
return self._streaming_state is not None
def _apply_named_streaming(self, fn):
def _handle_module(prefix: str, module: nn.Module):
if isinstance(module, StreamingModule):
if module._streaming_detached and prefix != "":
return
if self._cached_children is None:
raise RuntimeError("Internal error: _cached_children should be initialized before traversal.")
self._cached_children.append((prefix, module))
for name, child in module.named_children():
new_prefix = f"{prefix}.{name}" if prefix else name
_handle_module(new_prefix, child)
if self._cached_children is None:
self._cached_children = []
_handle_module("", self)
for name, child in self._cached_children:
fn(name, child)
def _start_streaming(self, batch_size: int, exit_stack: ExitStack):
def _start_streaming_fn(name: str, module: StreamingModule):
if module._streaming_state is not None:
raise RuntimeError(f"{name} is already streaming!")
state = module._init_streaming_state(batch_size)
exit_stack.enter_context(state)
module._streaming_state = state
self._apply_named_streaming(_start_streaming_fn)
def _stop_streaming(self) -> None:
def _stop_streaming_fn(name: str, module: StreamingModule):
module._streaming_state = None
self._apply_named_streaming(_stop_streaming_fn)
def _init_streaming_state(self, batch_size: int) -> StreamingState:
device = next(iter(self.parameters())).device
return StreamingState(batch_size, device)
def streaming(self, batch_size: int) -> ExitStack:
"""Context manager to enter streaming mode."""
exit_stack = ExitStack()
self._start_streaming(batch_size, exit_stack)
exit_stack.callback(self._stop_streaming)
return exit_stack
class StreamingContainer(StreamingModule):
"""Container for streaming modules."""
pass
# =============================================================================
# Normalization Layers
# =============================================================================
class MossAudioTokenizerRMSNorm(nn.Module):
"""Root Mean Square Layer Normalization."""
def __init__(
self,
dim: int,
eps: float = 1e-5,
dtype: torch.dtype | None = None,
device=None,
):
super().__init__()
self.eps = eps
self.dtype = dtype
self.alpha = nn.Parameter(torch.full((1, 1, dim), 1.0, requires_grad=True, device=device, dtype=dtype))
def forward(self, x: torch.Tensor):
x_dtype = x.dtype
if self.dtype is not None:
x = x.to(self.dtype)
var = self.eps + torch.mean(x**2, dim=2, keepdim=True)
y = (x * (self.alpha.to(var) * torch.rsqrt(var))).to(x_dtype)
return y
class MossAudioTokenizerLayerScale(nn.Module):
"""Layer scale from Touvron et al. 2021."""
def __init__(
self,
channels: int,
init: float = 1e-4,
channel_last: bool = True,
device=None,
dtype=None,
):
super().__init__()
self.channel_last = channel_last
self.scale = nn.Parameter(torch.full((channels,), init, requires_grad=True, device=device, dtype=dtype))
def forward(self, x: torch.Tensor):
if self.channel_last:
return self.scale * x
else:
return self.scale[:, None] * x
def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module:
"""Create normalization module."""
if norm_type == "layer_norm":
return nn.LayerNorm(dim, eps=1e-5, **kwargs)
elif norm_type in {"rms_norm"}:
return MossAudioTokenizerRMSNorm(dim, eps=1e-5, **kwargs)
elif norm_type in {"rms_norm_f32"}:
kwargs.pop("dtype", None)
return MossAudioTokenizerRMSNorm(dim, eps=1e-8, dtype=torch.float, **kwargs)
else:
raise ValueError(f"Unknown norm type: {norm_type}")
# =============================================================================
# Rotary Position Embedding
# =============================================================================
def apply_rope(
q: torch.Tensor,
k: torch.Tensor,
offset: torch.Tensor,
max_period: float = 10_000,
time_before_heads: bool = False,
):
"""Apply rotary position embedding."""
if time_before_heads:
B, T, H, D = q.shape
else:
B, H, T, D = q.shape
if k.shape != q.shape:
raise ValueError(f"Expected k.shape == q.shape, got k={tuple(k.shape)} q={tuple(q.shape)}")
if D <= 0 or (D % 2) != 0:
raise ValueError(f"RoPE requires an even last dimension, got D={D}")
ds = torch.arange(D // 2, device=q.device, dtype=torch.float32)
freqs = torch.exp(ds * (-math.log(max_period) * 2 / D))
ts = offset.float().view(-1, 1) + torch.arange(T, device=q.device, dtype=torch.float32)
if time_before_heads:
ts = ts.view(B, -1, 1, 1)
else:
ts = ts.view(B, 1, -1, 1)
dims = q.shape[:-1]
q = q.view(*dims, D // 2, 2)
k = k.view(*dims, D // 2, 2)
qr, qi = q[..., 0].float(), q[..., 1].float()
kr, ki = k[..., 0].float(), k[..., 1].float()
rotr = torch.cos(freqs * ts)
roti = torch.sin(freqs * ts)
qor = qr * rotr - qi * roti
qoi = qr * roti + qi * rotr
kor = kr * rotr - ki * roti
koi = kr * roti + ki * rotr
dtype = q.dtype
qo = torch.stack([qor.to(dtype), qoi.to(dtype)], dim=-1)
ko = torch.stack([kor.to(dtype), koi.to(dtype)], dim=-1)
return qo.view(*dims, D), ko.view(*dims, D)
class MossAudioTokenizerRotaryEmbedding(nn.Module):
"""Rotary positional embedding (RoPE)."""
def __init__(self, max_period: float = 10000.0):
super().__init__()
self.max_period = max_period
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
offset: torch.Tensor,
time_before_heads: bool = False,
):
return apply_rope(q, k, offset, self.max_period, time_before_heads)
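# A note on the rotation applied by apply_rope above: each channel pair
# (x[2i], x[2i + 1]) is treated as the real/imaginary parts of a complex number and
# rotated by an angle theta_i * t, where theta_i = max_period ** (-2 * i / D) and t
# is the absolute position (offset + step). A minimal sketch of the equivalent
# complex-multiplication form, with hypothetical values not taken from the model:
#
#     import torch
#     D, t, max_period = 4, 3.0, 10_000.0
#     ds = torch.arange(D // 2, dtype=torch.float32)
#     freqs = torch.exp(ds * (-torch.log(torch.tensor(max_period)) * 2 / D))
#     x = torch.randn(D)
#     xc = torch.view_as_complex(x.view(D // 2, 2))
#     rot = torch.polar(torch.ones(D // 2), freqs * t)   # e^{i * theta_i * t}
#     x_rot = torch.view_as_real(xc * rot).view(D)       # matches the qor/qoi math above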
# =============================================================================
# Gating Modules
# =============================================================================
class MossAudioTokenizerActivationGating(nn.Module):
"""Gating FFN layer with activation."""
def __init__(self, dim: int, dim_feedforward: int, activation, **factory_kwargs):
super().__init__()
if dim_feedforward == 4 * dim:
hidden = (21 * dim) // 8
else:
hidden = (2 * dim_feedforward) // 3
self.linear_in = nn.Linear(dim, 2 * hidden, bias=False, **factory_kwargs)
self.linear_out = nn.Linear(hidden, dim, bias=False, **factory_kwargs)
self.activation = activation
def forward(self, x: torch.Tensor):
x = self.linear_in(x)
B, T, _ = x.shape
x = x.view(B, T, 2, -1)
x = self.activation(x[..., 0, :]) * x[..., 1, :]
x = self.linear_out(x)
return x
def _get_activation(name: str):
if name in ["sigmoid", "tanh", "relu"]:
return getattr(torch, name)
elif name in ["leaky_relu", "elu", "gelu", "silu", "mish", "softsign"]:
return getattr(F, name)
elif name == "identity":
return nn.Identity()
else:
raise ValueError(f"Unknown activation {name}")
def make_gating(name: str, dim: int, dim_feedforward: int, **factory_kwargs) -> nn.Module:
return MossAudioTokenizerActivationGating(dim, dim_feedforward, _get_activation(name), **factory_kwargs)
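# Why the gating hidden size in MossAudioTokenizerActivationGating: the gated FFN
# uses three weight matrices (linear_in produces 2 * hidden channels, linear_out
# consumes hidden), so its parameter count is roughly 3 * hidden * dim. Choosing
# hidden = 2/3 * dim_feedforward (or 21 * dim // 8 when dim_feedforward == 4 * dim)
# keeps that close to the 2 * dim * dim_feedforward parameters of a plain two-layer
# FFN. Worked example for dim = 512, dim_feedforward = 2048:
#     plain FFN:  2 * 512 * 2048                  = 2_097_152 params
#     gated FFN:  3 * ((21 * 512) // 8) * 512     = 2_064_384 params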
# =============================================================================
# Positional Embeddings
# =============================================================================
def create_sin_embedding(
positions: torch.Tensor,
dim: int,
max_period: float = 10000,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""Create sinusoidal positional embedding with shape [B, T, C]."""
if dim % 2 != 0:
raise ValueError(f"Sinusoidal embedding requires even dim, got dim={dim}")
half_dim = dim // 2
if half_dim <= 1:
raise ValueError(f"Sinusoidal embedding requires dim >= 4, got dim={dim}")
positions = positions.to(dtype)
adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype)
phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
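# Shape sketch for create_sin_embedding (illustrative values): positions of shape
# (B, T, 1) broadcast against the (1, 1, dim // 2) frequency grid, giving a phase of
# shape (B, T, dim // 2); the returned embedding is (B, T, dim), cosines in the first
# half of the channels and sines in the second half.
#
#     pos = torch.arange(10).view(1, -1, 1)       # (1, 10, 1)
#     emb = create_sin_embedding(pos, dim=64)     # (1, 10, 64)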
# =============================================================================
# KV Cache for Attention
# =============================================================================
class KVCacheResult:
"""Container for KV cache results that supports tuple unpacking."""
__slots__ = ("keys", "values", "positions")
def __init__(self, keys: torch.Tensor, values: torch.Tensor, positions: torch.Tensor):
self.keys = keys
self.values = values
self.positions = positions
def __iter__(self):
"""Allow unpacking as (keys, values, positions)."""
return iter((self.keys, self.values, self.positions))
@staticmethod
def from_kv(keys: torch.Tensor, values: torch.Tensor) -> KVCacheResult:
B, H, T, D = keys.shape
positions = torch.arange(T, device=keys.device, dtype=torch.long)
return KVCacheResult(keys, values, positions.expand(B, -1))
class RingKVCache:
"""Efficient streaming KVCache compatible with CUDA Graph."""
def __init__(
self,
batch_size: int,
num_heads: int,
dim_per_head: int,
capacity: int,
respect_exec_mask: bool = True,
device: torch.device = torch.device("cuda"),
dtype: torch.dtype = torch.bfloat16,
):
self.capacity = capacity
self.cache = torch.zeros(
(2, batch_size, num_heads, capacity, dim_per_head),
device=device,
dtype=dtype,
)
self.respect_exec_mask = respect_exec_mask
if self.respect_exec_mask:
self.end_offset = torch.zeros(batch_size, device=device, dtype=torch.long)
else:
self.end_offset = torch.zeros(1, device=device, dtype=torch.long)
def reset(self, reset_mask: torch.Tensor) -> None:
self.end_offset[:] = torch.where(reset_mask, torch.zeros_like(self.end_offset), self.end_offset)
def complete(self, k: torch.Tensor, v: torch.Tensor, exec_mask: torch.Tensor) -> KVCacheResult:
B, H, T, D = k.shape
if T <= 0:
raise ValueError(f"Expected T > 0, got T={T}")
indexes = torch.arange(T, device=self.end_offset.device, dtype=self.end_offset.dtype)
indexes = indexes + self.end_offset.view(-1, 1)
indexes = indexes % self.capacity
if self.respect_exec_mask:
this_indexes = indexes.view(B, 1, T, 1).expand(-1, H, T, D)
self.cache[0].scatter_(2, this_indexes, k)
self.cache[1].scatter_(2, this_indexes, v)
else:
self.cache[0].index_copy_(2, indexes[0], k)
self.cache[1].index_copy_(2, indexes[0], v)
keys = self.cache[0]
values = self.cache[1]
indexes = torch.arange(self.capacity, device=self.end_offset.device, dtype=torch.long)
last_offset = self.end_offset.view(-1, 1) + T - 1
end_index = last_offset % self.capacity
delta = indexes - end_index
positions = torch.where(
delta <= 0,
last_offset + delta,
last_offset + delta - self.capacity,
)
if self.respect_exec_mask:
self.end_offset[:] = torch.where(exec_mask, self.end_offset + T, self.end_offset)
else:
self.end_offset.add_(T)
invalid = indexes >= self.end_offset.view(-1, 1)
positions = torch.where(invalid, torch.full_like(positions, -1), positions)
return KVCacheResult(keys, values, positions)
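# Worked example of the ring-buffer bookkeeping above (hypothetical sizes): with
# capacity = 4 and end_offset = 5 (frames 0..4 already written, frame t stored in
# slot t % 4), calling complete() with T = 1 writes frame 5 into slot 5 % 4 = 1.
# Then last_offset = 5, end_index = 1, delta = [-1, 0, 1, 2] over the four slots and
# positions = [4, 5, 2, 3], i.e. the absolute timestep held by each slot; slots that
# were never written get position -1 and are masked out by the attention bias.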
# =============================================================================
# Multi-Head Attention
# =============================================================================
@dataclass
class MHAState(StreamingState):
kv_cache: RingKVCache | None
offset: torch.Tensor
offset_cpu: int
def reset(self, reset_mask: torch.Tensor):
super().reset(reset_mask)
self.offset[:] = torch.where(reset_mask, torch.zeros_like(self.offset), self.offset)
if self.kv_cache is not None:
self.kv_cache.reset(reset_mask)
self.offset_cpu = 0
def apply_weights_per_step(
modules: nn.ModuleList,
schedule: list[int] | None,
x: torch.Tensor,
offset: int | None,
) -> torch.Tensor:
"""Apply different weights for each time step."""
if len(modules) == 1:
return modules[0](x)
if offset is None:
raise ValueError("offset must be provided when using per-step weights (len(modules) > 1).")
ys = []
B, T, C = x.shape
for t in range(T):
module_index = t + offset
if schedule is not None:
if module_index >= len(schedule) or module_index < 0:
raise ValueError(
f"weights_per_step_schedule is too short for module_index={module_index} (len={len(schedule)})."
)
module_index = schedule[module_index]
if module_index >= len(modules) or module_index < 0:
raise ValueError(f"module_index={module_index} out of range for len(modules)={len(modules)}.")
y = modules[module_index](x[:, t : t + 1])
ys.append(y)
return torch.cat(ys, 1)
class MossAudioTokenizerMultiheadAttention(StreamingModule):
"""Multi-head attention with streaming support."""
def __init__(
self,
embed_dim: int,
num_heads: int,
causal: bool = False,
context: int | None = None,
rope: MossAudioTokenizerRotaryEmbedding | None = None,
weights_per_step: int = 0,
weights_per_step_schedule: list[int] | None = None,
device=None,
dtype=None,
):
super().__init__()
factory_kwargs = {"device": device, "dtype": dtype}
self.embed_dim = embed_dim
self.causal = causal
self.context = context
self.rope = rope
self.num_heads = num_heads
self.weights_per_step = weights_per_step
self.weights_per_step_schedule = weights_per_step_schedule
out_dim = 3 * embed_dim
mult = 1
if weights_per_step:
mult = max(weights_per_step_schedule) + 1 if weights_per_step_schedule else weights_per_step
self.mult = mult
self.out_projs = nn.ModuleList(
[nn.Linear(embed_dim, embed_dim, bias=False, **factory_kwargs) for _ in range(mult)]
)
self.in_projs = nn.ModuleList(
[nn.Linear(embed_dim, out_dim, bias=False, **factory_kwargs) for _ in range(mult)]
)
self._register_load_state_dict_pre_hook(self._load_hook, with_module=True)
@staticmethod
def _load_hook(module, state_dict, prefix, *_):
mappings = {
"in_proj_weight": "in_projs.{i}.weight",
"in_proj.weight": "in_projs.{i}.weight",
"out_proj.weight": "out_projs.{i}.weight",
}
mult = module.mult
for suffix in ["", "_scb"]:
for source, target in mappings.items():
this_source = prefix + source + suffix
if this_source in state_dict:
weight = state_dict[this_source]
_, *OD = weight.shape
weight = weight.view(mult, -1, *OD)
for i in range(mult):
state_dict[prefix + target.format(i=i) + suffix] = weight[i]
state_dict.pop(this_source)
def _init_streaming_state(self, batch_size: int) -> MHAState:
in_proj = cast(nn.Linear, self.in_projs[0])
device = cast(torch.device, in_proj.weight.device)
dtype = cast(torch.dtype, in_proj.weight.dtype)
dim_per_head = self.embed_dim // self.num_heads
if self.context is None:
capacity = self.weights_per_step if self.weights_per_step else 1024
else:
capacity = self.context
kv_cache = RingKVCache(
batch_size,
self.num_heads,
dim_per_head,
capacity,
respect_exec_mask=not self.weights_per_step,
device=cast(torch.device, device),
dtype=cast(torch.dtype, dtype),
)
return MHAState(
batch_size,
cast(torch.device, device),
kv_cache,
offset=torch.zeros(batch_size, device=cast(torch.device, device), dtype=torch.long),
offset_cpu=0,
)
def _complete_kv(self, k, v) -> KVCacheResult:
state = cast(MHAState | None, self._streaming_state)
if state is None:
return KVCacheResult.from_kv(k, v)
if state.kv_cache is None:
return KVCacheResult.from_kv(k, v)
return state.kv_cache.complete(k, v, state.exec_mask)
def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
state = cast(MHAState | None, self._streaming_state)
B, T = query.shape[:2]
if state is None:
offset = torch.zeros(B, device=query.device, dtype=torch.long)
offset_cpu = 0
else:
offset = state.offset
offset_cpu = state.offset_cpu
projected = apply_weights_per_step(self.in_projs, self.weights_per_step_schedule, query, offset_cpu)
dim_per_head = self.embed_dim // self.num_heads
projected = projected.reshape(B, T, 3, self.num_heads, dim_per_head).permute(2, 0, 3, 1, 4)
q, k, v = projected[0], projected[1], projected[2]
if self.rope:
q, k = self.rope(q, k, offset, time_before_heads=False)
k, v, pos_k = self._complete_kv(k, v)
pos_k = pos_k[:, None]
if self.causal:
pos_q = offset.view(-1, 1, 1) + torch.arange(T, device=q.device, dtype=torch.long).view(-1, 1)
delta = pos_q - pos_k
attn_bias = (pos_k >= 0) & (delta >= 0)
if self.context is not None:
attn_bias = attn_bias & (delta < self.context)
attn_bias = attn_bias[:, None]
else:
attn_bias = None
x = F.scaled_dot_product_attention(q, k, v, attn_bias, dropout_p=0.0)
x = x.transpose(1, 2).reshape(B, T, self.embed_dim)
x = apply_weights_per_step(self.out_projs, self.weights_per_step_schedule, x, offset_cpu)
if state is not None:
state.offset[:] = torch.where(state.exec_mask, state.offset + T, state.offset)
state.offset_cpu += T
return x
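# The boolean attention bias built in forward() keeps key j for query i iff
# pos_k[j] >= 0 (the cache slot holds real data), pos_q[i] >= pos_k[j] (causality),
# and, when a context window is set, pos_q[i] - pos_k[j] < context. For example,
# with context = 3 and pos_q = 7, only keys at absolute positions 5, 6 and 7 are
# visible to that query.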
# =============================================================================
# Transformer Layer
# =============================================================================
@dataclass
class LayerState(StreamingState):
offset_cpu: int = 0
def reset(self, reset_mask: torch.Tensor):
super().reset(reset_mask)
self.offset_cpu = 0
class MossAudioTokenizerTransformerLayer(StreamingModule):
"""Transformer layer with streaming support."""
def __init__(
self,
d_model: int,
num_heads: int,
dim_feedforward: int = 2048,
causal: bool = False,
context: int | None = None,
rope: MossAudioTokenizerRotaryEmbedding | None = None,
norm: str = "layer_norm",
layer_scale: float | None = None,
gating: str = "none",
weights_per_step: int = 0,
weights_per_step_schedule: list[int] | None = None,
activation=F.gelu,
device=None,
dtype=None,
):
super().__init__()
factory_kwargs = {"device": device, "dtype": dtype}
self.self_attn = MossAudioTokenizerMultiheadAttention(
embed_dim=d_model,
num_heads=num_heads,
causal=causal,
context=context,
rope=rope,
weights_per_step=weights_per_step,
weights_per_step_schedule=weights_per_step_schedule,
**factory_kwargs,
)
self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs)
self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs)
self.weights_per_step = weights_per_step
self.weights_per_step_schedule = weights_per_step_schedule
self.gating: nn.Module | nn.ModuleList | None = None
self.linear1: nn.Module | None = None
self.linear2: nn.Module | None = None
self.activation = activation
num_weights = 1
if weights_per_step:
num_weights = max(weights_per_step_schedule) + 1 if weights_per_step_schedule else weights_per_step
if gating == "none":
self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False, **factory_kwargs)
self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False, **factory_kwargs)
else:
if weights_per_step:
dim_ff_list = [dim_feedforward] * num_weights if isinstance(dim_feedforward, int) else dim_feedforward
self.gating = nn.ModuleList(
[make_gating(gating, d_model, dim, **factory_kwargs) for dim in dim_ff_list]
)
else:
self.gating = make_gating(gating, d_model, dim_feedforward, **factory_kwargs)
if layer_scale is None:
self.layer_scale_1 = nn.Identity()
self.layer_scale_2 = nn.Identity()
else:
self.layer_scale_1 = MossAudioTokenizerLayerScale(
channels=d_model, init=layer_scale, channel_last=True, **cast(dict[str, object], factory_kwargs)
)
self.layer_scale_2 = MossAudioTokenizerLayerScale(
channels=d_model, init=layer_scale, channel_last=True, **cast(dict[str, object], factory_kwargs)
)
def _init_streaming_state(self, batch_size: int) -> LayerState:
device = next(iter(self.parameters())).device
return LayerState(batch_size, device, offset_cpu=0)
def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
state = self._streaming_state
offset = state.offset_cpu if isinstance(state, LayerState) else 0
x_orig = x
x = self.norm2(x)
if self.gating is None:
assert self.linear1 is not None
assert self.linear2 is not None
update = self.linear2(self.activation(self.linear1(x)))
else:
if self.weights_per_step:
assert isinstance(self.gating, nn.ModuleList)
update = apply_weights_per_step(self.gating, self.weights_per_step_schedule, x, offset)
else:
update = self.gating(x)
return x_orig.to(update) + self.layer_scale_2(update)
def _sa_block(self, x: torch.Tensor):
x_orig = x
x = self.norm1(x)
update = self.self_attn(x, x, x)
return x_orig.to(update) + self.layer_scale_1(update)
def forward(self, x: torch.Tensor):
x = self._sa_block(x)
x = self._ff_block(x)
state = self._streaming_state
if state is not None:
assert isinstance(state, LayerState)
state.offset_cpu += x.shape[1]
return x
# =============================================================================
# Streaming Transformer
# =============================================================================
@dataclass
class TransformerState(StreamingState):
offsets: torch.Tensor
def reset(self, reset_mask: torch.Tensor):
super().reset(reset_mask)
self.offsets[:] = torch.where(reset_mask, torch.zeros_like(self.offsets), self.offsets)
class MossAudioTokenizerTransformer(StreamingModule):
"""Transformer with streaming/causal support."""
def __init__(
self,
d_model: int,
num_heads: int,
num_layers: int,
dim_feedforward: int = 2048,
causal: bool = False,
context: int | None = None,
positional_embedding: str = "sin",
max_period: float = 10_000,
positional_scale: float = 1.0,
device=None,
dtype=None,
**kwargs,
):
super().__init__()
if d_model % num_heads != 0:
raise ValueError(f"d_model must be divisible by num_heads, got d_model={d_model}, num_heads={num_heads}")
self.positional_embedding = positional_embedding
self.max_period = max_period
self.positional_scale = positional_scale
self.rope: MossAudioTokenizerRotaryEmbedding | None = None
if positional_embedding in {"rope", "sin_rope"}:
self.rope = MossAudioTokenizerRotaryEmbedding(max_period=max_period)
self.layers = nn.ModuleList()
for _ in range(num_layers):
self.layers.append(
MossAudioTokenizerTransformerLayer(
d_model=d_model,
num_heads=num_heads,
dim_feedforward=dim_feedforward,
causal=causal,
context=context,
rope=self.rope,
device=device,
dtype=dtype,
**kwargs,
)
)
def _init_streaming_state(self, batch_size: int) -> TransformerState:
device = next(self.parameters()).device
return TransformerState(
batch_size,
device,
offsets=torch.zeros(batch_size, device=device, dtype=torch.long),
)
def forward(self, x: torch.Tensor, *args, **kwargs):
B, T, C = x.shape
state = self._streaming_state
offsets = (
torch.zeros(1, dtype=torch.long, device=x.device)
if state is None
else (
state.offsets
if isinstance(state, TransformerState)
else torch.zeros(1, dtype=torch.long, device=x.device)
)
)
if self.positional_embedding in {"sin", "sin_rope"}:
positions = torch.arange(T, device=x.device).view(1, -1, 1)
positions = positions + offsets.view(-1, 1, 1)
pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
x = x + self.positional_scale * pos_emb
for layer in self.layers:
x = layer(x, *args, **kwargs)
if state is not None:
assert isinstance(state, TransformerState)
state.offsets[:] = torch.where(state.exec_mask, state.offsets + T, state.offsets)
return x
class MossAudioTokenizerProjectedTransformer(StreamingContainer):
"""Transformer with input/output projections."""
def __init__(
self,
input_dimension: int,
output_dimension: int,
d_model: int,
*,
conv_layout: bool = False,
module_type: str,
**kwargs,
):
super().__init__()
self.module_type = module_type
self.downsample_ratio: int = 1
self.input_dimension = input_dimension
self.output_dimension = output_dimension
self.input_proj = (
nn.Linear(input_dimension, d_model, bias=False) if d_model != input_dimension else nn.Identity()
)
self.transformer = MossAudioTokenizerTransformer(d_model=d_model, **kwargs)
self.conv_layout = conv_layout
self.output_proj = (
nn.Linear(d_model, output_dimension, bias=False) if d_model != output_dimension else nn.Identity()
)
def forward(self, x, input_lengths, *args, **kwargs):
x = self.input_proj(x.transpose(1, 2)) # (B, D, T) -> (B, T, D)
x = self.transformer(x, *args, **kwargs)
x = self.output_proj(x).transpose(1, 2) # (B, T, D) -> (B, D, T)
return x, input_lengths
# =============================================================================
# Patched Pretransform Module
# =============================================================================
class MossAudioTokenizerPatchedPretransform(nn.Module):
"""Patching module for downsampling/upsampling."""
def __init__(self, patch_size: int, is_downsample: bool, module_type: str, **kwargs):
super().__init__()
self.patch_size = patch_size
self.downsample_ratio: int = patch_size
self.is_downsample = is_downsample
self.module_type = module_type
def encode(self, x, input_lengths):
b, d, _ = x.shape
h = self.patch_size
x = x.reshape(b, d, -1, h).permute(0, 1, 3, 2).reshape(b, d * h, -1)
# We pad the input waveform to a multiple of `downsample_rate` before applying the encoder.
# Use a ceil division to match that padding and avoid dropping the last (partially padded) frame.
output_lengths = (input_lengths + self.patch_size - 1) // self.patch_size
return x, output_lengths
def decode(self, x, input_lengths):
b, dh, l = x.shape
h = self.patch_size
d = dh // h
x = x.reshape(b, d, h, l).permute(0, 1, 3, 2).reshape(b, d, l * h)
output_lengths = input_lengths * self.patch_size
return x, output_lengths
def forward(self, x, input_lengths):
if self.is_downsample:
return self.encode(x, input_lengths)
else:
return self.decode(x, input_lengths)
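# Shape sketch for the patching above (illustrative numbers): with patch_size = 4,
# encode() folds time into channels, (B, D, T) -> (B, 4 * D, T / 4), and decode()
# unfolds it back, (B, 4 * D, T / 4) -> (B, D, T). A minimal check, assuming a
# freshly constructed module:
#
#     m = MossAudioTokenizerPatchedPretransform(patch_size=4, is_downsample=True, module_type="PatchedPretransform")
#     x = torch.randn(2, 8, 16)
#     y, lens = m.encode(x, torch.tensor([16, 16]))   # y: (2, 32, 4), lens: [4, 4]
#     x2, _ = m.decode(y, lens)                       # x2: (2, 8, 16), recovers x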
# =============================================================================
# Vector Quantization
# =============================================================================
def WNConv1d(*args, **kwargs):
"""Weight-normalized Conv1d."""
return nn.utils.parametrizations.weight_norm(nn.Conv1d(*args, **kwargs))
class MossAudioTokenizerVectorQuantize(nn.Module):
"""Single codebook vector quantization (inference only)."""
def __init__(
self,
input_dim: int,
codebook_size: int,
codebook_dim: int,
**kwargs,
):
super().__init__()
self.input_dim = input_dim
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
if input_dim != codebook_dim:
self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
else:
self.in_proj = nn.Identity()
self.out_proj = nn.Identity()
self.codebook = nn.Embedding(codebook_size, codebook_dim)
@torch.no_grad()
def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
z: Input tensor of shape (B, D, T)
Returns:
z_q: Quantized tensor of shape (B, D, T)
indices: Code indices of shape (B, T)
z_e: Encoded tensor before quantization
"""
z = z.float()
z_e = self.in_proj(z).float()
encodings = z_e.transpose(1, 2).reshape(-1, z_e.shape[1])
codebook_weight = self.codebook.weight
dist = (
encodings.pow(2).sum(1, keepdim=True)
- 2 * encodings @ codebook_weight.float().t()
+ codebook_weight.float().pow(2).sum(1, keepdim=True).t()
)
indices = (-dist).max(1)[1]
indices = indices.reshape(z.size(0), -1)
z_q = self.decode_code(indices)
z_q = self.out_proj(z_q).float()
return z_q, indices, z_e
def decode_code(self, embed_id: torch.Tensor) -> torch.Tensor:
"""Decode code indices to embeddings."""
return self.codebook(embed_id).transpose(1, 2).float()
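# The nearest-codebook search above uses the standard expansion of the squared
# Euclidean distance, ||e - c||^2 = ||e||^2 - 2 * e . c + ||c||^2, evaluated for all
# codebook entries at once with a single matmul; taking argmax of -dist is the same
# as argmin of dist. MossAudioTokenizerLFQ below applies the same trick to
# L2-normalized vectors, which makes the search equivalent to maximizing cosine
# similarity with the codebook entries.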
class MossAudioTokenizerLFQ(nn.Module):
"""LFQ (inference-only) used by ResidualLFQ."""
def __init__(
self,
input_dim: int,
codebook_size: int,
codebook_dim: int,
**kwargs,
):
super().__init__()
self.input_dim = input_dim
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
if self.input_dim != self.codebook_dim:
self.in_proj = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
self.out_proj = WNConv1d(self.codebook_dim, self.input_dim, kernel_size=1)
else:
self.in_proj = nn.Identity()
self.out_proj = nn.Identity()
self.codebook = nn.Embedding(codebook_size, codebook_dim)
@torch.no_grad()
def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Quantize z into codebook vectors."""
z = z.float()
z_e = self.in_proj(z).float()
z_q, indices = self.decode_latents(z_e)
z_q = (z_e + (z_q - z_e).detach()).float()
z_q = self.out_proj(z_q).float()
return z_q, indices, z_e
def embed_code(self, embed_id: torch.Tensor) -> torch.Tensor:
return F.embedding(embed_id, self.codebook.weight)
def decode_code_wo_out_proj(self, embed_id: torch.Tensor) -> torch.Tensor:
return self.embed_code(embed_id).transpose(1, 2)
def decode_code(self, embed_id: torch.Tensor) -> torch.Tensor:
z_q = self.decode_code_wo_out_proj(embed_id).float()
z_q = self.out_proj(z_q).float()
return z_q
def decode_latents(self, latents: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Match training LFQ: L2-normalize then argmin squared distance."""
encodings = latents.transpose(1, 2).reshape(-1, latents.shape[1]).float()
codebook = self.codebook.weight.float()
encodings = F.normalize(encodings)
codebook = F.normalize(codebook)
dist = (
encodings.pow(2).sum(1, keepdim=True)
- 2 * encodings @ codebook.t()
+ codebook.pow(2).sum(1, keepdim=True).t()
)
indices = (-dist).max(1)[1]
indices = indices.reshape(latents.size(0), -1)
z_q = self.decode_code_wo_out_proj(indices).float()
return z_q, indices
class MossAudioTokenizerResidualVQ(nn.Module):
"""Residual Vector Quantization (inference only)."""
def __init__(
self,
input_dim: int = 1024,
rvq_dim: int | None = None,
output_dim: int | None = None,
num_quantizers: int = 32,
codebook_size: int = 1024,
codebook_dim: int = 8,
**kwargs,
):
super().__init__()
self.input_dim = input_dim
self.rvq_dim = rvq_dim or input_dim
self.output_dim = output_dim or input_dim
self.num_quantizers = num_quantizers
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
self.input_proj = (
WNConv1d(input_dim, self.rvq_dim, kernel_size=1) if input_dim != self.rvq_dim else nn.Identity()
)
self.output_proj = (
WNConv1d(self.rvq_dim, self.output_dim, kernel_size=1)
if self.rvq_dim != self.output_dim
else nn.Identity()
)
self.quantizers = nn.ModuleList(
[
MossAudioTokenizerVectorQuantize(
input_dim=self.rvq_dim,
codebook_size=codebook_size,
codebook_dim=codebook_dim,
**kwargs,
)
for _ in range(num_quantizers)
]
)
@torch.no_grad()
def forward(
self,
z: torch.Tensor,
input_length: torch.Tensor,
n_quantizers: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
z: Input tensor of shape (B, D, T)
input_length: Valid lengths for each sample (B,)
n_quantizers: Number of quantizers to use
Returns:
quantized_out: Quantized output (B, D, T)
all_indices: All code indices (N, B, T)
output_length: Output lengths (B,)
"""
z = self.input_proj(z)
batch_size, _, max_time = z.shape
mask = torch.arange(max_time, device=z.device).expand(batch_size, max_time) < input_length.unsqueeze(1)
quantized_out = torch.zeros_like(z, dtype=torch.float32)
residual = z.clone().float()
all_indices = []
n_quantizers = n_quantizers or self.num_quantizers
for i, quantizer in enumerate(self.quantizers):
if i >= n_quantizers:
break
masked_residual = residual * mask.unsqueeze(1)
z_q_i, indices_i, _ = quantizer(masked_residual)
update_mask = mask.unsqueeze(1)
quantized_out = quantized_out + z_q_i * update_mask
residual = residual - z_q_i * update_mask
all_indices.append(indices_i)
all_indices = torch.stack(all_indices) # (N, B, T)
quantized_out = self.output_proj(quantized_out)
return quantized_out, all_indices, input_length
def decode_codes(self, codes: torch.Tensor) -> torch.Tensor:
"""Decode codes from multiple quantizers to embeddings."""
nq, B, T = codes.shape
emb = torch.zeros(B, self.rvq_dim, T, device=codes.device, dtype=torch.float32)
for i, quantizer in enumerate(self.quantizers[:nq]):
quantizer = cast(MossAudioTokenizerVectorQuantize, quantizer)
quantized_i = quantizer.decode_code(codes[i])
emb += quantized_i
emb = self.output_proj(emb)
return emb
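# Residual quantization in a nutshell: each quantizer encodes whatever the previous
# stages missed, so the reconstruction is the sum of per-stage codebook vectors.
# A minimal sketch of the loop in forward() above (padding mask omitted):
#
#     residual = z
#     quantized = 0
#     for q in quantizers[:n_quantizers]:
#         z_q_i, idx_i, _ = q(residual)    # quantize the current residual
#         quantized = quantized + z_q_i    # accumulate the reconstruction
#         residual = residual - z_q_i      # pass on what is still unexplained
#
# decode_codes() mirrors this: it sums q.decode_code(codes[i]) over the stages and
# applies output_proj once at the end.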
class MossAudioTokenizerResidualLFQ(nn.Module):
"""Residual LFQ (inference only)."""
def __init__(
self,
input_dim: int = 1024,
rvq_dim: int | None = None,
output_dim: int | None = None,
num_quantizers: int = 32,
codebook_size: int = 1024,
codebook_dim: int = 8,
**kwargs,
):
super().__init__()
self.input_dim = input_dim
self.rvq_dim = rvq_dim or input_dim
self.output_dim = output_dim or input_dim
self.num_quantizers = num_quantizers
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
self.input_proj = (
WNConv1d(input_dim, self.rvq_dim, kernel_size=1) if input_dim != self.rvq_dim else nn.Identity()
)
self.output_proj = (
WNConv1d(self.rvq_dim, self.output_dim, kernel_size=1)
if self.rvq_dim != self.output_dim
else nn.Identity()
)
self.quantizers = nn.ModuleList(
[
MossAudioTokenizerLFQ(
input_dim=self.rvq_dim,
codebook_size=codebook_size,
codebook_dim=codebook_dim,
**kwargs,
)
for _ in range(num_quantizers)
]
)
@torch.no_grad()
def forward(
self,
z: torch.Tensor,
input_length: torch.Tensor,
n_quantizers: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Inference quantization."""
z = self.input_proj(z).float()
batch_size, _, max_time = z.shape
mask = torch.arange(max_time, device=z.device).expand(batch_size, max_time) < input_length.unsqueeze(1)
quantized_out = torch.zeros_like(z, dtype=torch.float32)
residual = z.clone().float()
all_indices = []
n_quantizers = n_quantizers or self.num_quantizers
for i, quantizer in enumerate(self.quantizers):
if i >= n_quantizers:
break
masked_residual = residual * mask.unsqueeze(1)
z_q_i, indices_i, _ = quantizer(masked_residual)
update_mask = mask.unsqueeze(1)
quantized_out = quantized_out + z_q_i * update_mask
residual = residual - z_q_i * update_mask
all_indices.append(indices_i)
all_indices = (
torch.stack(all_indices)
if all_indices
else torch.empty(0, batch_size, max_time, device=z.device, dtype=torch.long)
)
quantized_out = self.output_proj(quantized_out)
return quantized_out, all_indices, input_length
def decode_codes(self, codes: torch.Tensor) -> torch.Tensor:
nq, B, T = codes.shape
emb = torch.zeros(B, self.rvq_dim, T, device=codes.device, dtype=torch.float32)
for i, quantizer in enumerate(self.quantizers[:nq]):
quantizer = cast(MossAudioTokenizerLFQ, quantizer)
emb += quantizer.decode_code(codes[i]).float()
emb = self.output_proj(emb)
return emb
# =============================================================================
# Main Model Classes
# =============================================================================
@auto_docstring
class MossAudioTokenizerPreTrainedModel(PreTrainedAudioTokenizerBase):
"""Base class for MossAudioTokenizer models."""
config_class = MossAudioTokenizerConfig
base_model_prefix = ""
main_input_name = "input_values"
input_modalities = "audio"
supports_gradient_checkpointing = False
_no_split_modules = [
"MossAudioTokenizerTransformerLayer",
"MossAudioTokenizerResidualVQ",
"MossAudioTokenizerResidualLFQ",
]
@auto_docstring(
custom_intro="""
The MossAudioTokenizer neural audio codec model for audio tokenization and synthesis.
"""
)
class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel):
"""
MossAudioTokenizer model for audio tokenization and synthesis.
This model can encode audio waveforms into discrete tokens and decode
tokens back into audio waveforms.
"""
def __init__(self, config: MossAudioTokenizerConfig):
super().__init__(config)
self.config = config
_ = config.version
self.sampling_rate = config.sampling_rate
self.downsample_rate = config.downsample_rate
self.causal_transformer_context_duration = config.causal_transformer_context_duration
# Build encoder
current_frame_rate: float = float(self.sampling_rate)
self.encoder = nn.ModuleList()
for encoder_kwargs_i in config.encoder_kwargs:
encoder_kwargs_i = dict(encoder_kwargs_i) # Make a copy
if encoder_kwargs_i["module_type"] == "PatchedPretransform":
self.encoder.append(MossAudioTokenizerPatchedPretransform(**encoder_kwargs_i, is_downsample=True))
elif encoder_kwargs_i["module_type"] == "Transformer":
self.encoder.append(
MossAudioTokenizerProjectedTransformer(
**encoder_kwargs_i,
context=int(current_frame_rate * self.causal_transformer_context_duration),
)
)
current_frame_rate /= self.encoder[-1].downsample_ratio
# Build quantizer
quantizer_kwargs = dict(config.quantizer_kwargs)
quantizer_type = quantizer_kwargs.get("quantizer_type", getattr(config, "quantizer_type", "rvq"))
if quantizer_type in {"rvq", "spec_rvq"}:
self.quantizer = MossAudioTokenizerResidualVQ(**quantizer_kwargs)
elif quantizer_type in {"rlfq", "random_prefix_rlfq"}:
self.quantizer = MossAudioTokenizerResidualLFQ(**quantizer_kwargs)
else:
raise ValueError(f"Unsupported quantizer_type: {quantizer_type}")
# Build decoder
decoder_kwargs_list = copy.deepcopy(config.decoder_kwargs)
self.decoder = nn.ModuleList()
for decoder_kwargs_i in decoder_kwargs_list:
decoder_kwargs_i = dict(decoder_kwargs_i)
if decoder_kwargs_i["module_type"] == "PatchedPretransform":
self.decoder.append(MossAudioTokenizerPatchedPretransform(**decoder_kwargs_i, is_downsample=False))
elif decoder_kwargs_i["module_type"] == "Transformer":
self.decoder.append(
MossAudioTokenizerProjectedTransformer(
**decoder_kwargs_i,
context=int(current_frame_rate * self.causal_transformer_context_duration),
)
)
current_frame_rate *= self.decoder[-1].downsample_ratio
self.post_init()
def _start_streaming(self, batch_size: int):
"""Start streaming mode for all modules."""
def _start(module):
if isinstance(module, StreamingModule):
module._streaming_state = module._init_streaming_state(batch_size)
self.apply(_start)
def _stop_streaming(self):
"""Stop streaming mode for all modules."""
def _stop(module):
if isinstance(module, StreamingModule):
module._streaming_state = None
self.apply(_stop)
@contextmanager
def streaming(self, batch_size: int = 1):
"""Context manager for streaming mode."""
self._start_streaming(batch_size)
try:
yield
finally:
self._stop_streaming()
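# Usage sketch for streaming (hypothetical inputs, assuming `model` is a loaded
# MossAudioTokenizerModel): inside `streaming(batch_size)` every StreamingModule
# keeps per-call state (KV caches, offsets), so successive forward calls are treated
# as a continuation of the same sequence; leaving the context clears that state. The
# public entry point for this behaviour is the `chunk_duration` argument of
# encode()/decode(), e.g.:
#
#     enc = model.encode(wav, chunk_duration=1.0)             # streamed encoding, 1 s chunks
#     dec = model.decode(enc.audio_codes, chunk_duration=1.0) # streamed decoding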
@torch.no_grad()
def batch_encode(
self, wav_list: list[torch.Tensor], num_quantizers: int | None = None
) -> MossAudioTokenizerEncoderOutput:
"""Batch encode a list of audio waveforms.
Args:
wav_list: List of audio tensors, each of shape `(num_samples,)`.
num_quantizers: Number of quantizers to use. By default, all quantizers are used.
Returns:
[`MossAudioTokenizerEncoderOutput`] with `audio_codes` and `audio_codes_lengths`.
"""
if len(wav_list) == 0:
raise ValueError("`wav_list` must contain at least one waveform.")
device = wav_list[0].device
batch_size = len(wav_list)
max_length = max(wav.shape[-1] for wav in wav_list)
input_values = torch.zeros(batch_size, 1, max_length, device=device)
input_lengths = torch.zeros(batch_size, device=device, dtype=torch.long)
for i, wav in enumerate(wav_list):
input_values[i, 0, : wav.shape[-1]] = wav
input_lengths[i] = wav.shape[-1]
return self._encode_frame(input_values, input_lengths, n_quantizers=num_quantizers)
@torch.no_grad()
def batch_decode(
self, codes_list: list[torch.Tensor], num_quantizers: int | None = None
) -> MossAudioTokenizerDecoderOutput:
"""Batch decode a list of audio codes.
Args:
codes_list: List of audio code tensors, each of shape `(num_quantizers, codes_length)`.
num_quantizers: If provided, decode only the first `num_quantizers` quantizers from each element in
`codes_list`. If omitted, all elements in `codes_list` must have the same number of quantizers.
Returns:
[`MossAudioTokenizerDecoderOutput`] with `audio` and `audio_lengths`.
"""
if len(codes_list) == 0:
raise ValueError("`codes_list` must contain at least one code tensor.")
batch_size = len(codes_list)
device = codes_list[0].device
nqs = [codes.shape[0] for codes in codes_list]
if num_quantizers is None:
num_quantizers = nqs[0]
if any(nq != num_quantizers for nq in nqs):
raise ValueError(
"All elements in `codes_list` must have the same number of quantizers when `num_quantizers` is None. "
"Pass `num_quantizers=...` to decode a common prefix."
)
else:
min_nq = min(nqs)
if min_nq < num_quantizers:
raise ValueError(
"`num_quantizers` must be <= the number of quantizers for every element in `codes_list`. "
f"Got num_quantizers={num_quantizers}, min(codes.shape[0])={min_nq}."
)
max_length = max(codes.shape[-1] for codes in codes_list)
audio_codes = torch.zeros(num_quantizers, batch_size, max_length, device=device, dtype=torch.long)
audio_codes_lengths = torch.zeros(batch_size, device=device, dtype=torch.long)
for i, codes in enumerate(codes_list):
codes = codes[:num_quantizers]
audio_codes[:, i, : codes.shape[-1]] = codes
audio_codes_lengths[i] = codes.shape[-1]
return self._decode_frame(audio_codes, audio_codes_lengths)
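# Convenience round trip with the batch helpers above (hypothetical tensors,
# assuming `model` is loaded and the inputs live on the model's device):
#
#     wavs = [torch.randn(24_000), torch.randn(16_000)]   # variable-length mono audio
#     enc = model.batch_encode(wavs)                       # pads internally, returns codes + lengths
#     codes = [enc.audio_codes[:, i, : int(l)] for i, l in enumerate(enc.audio_codes_lengths)]
#     dec = model.batch_decode(codes)                      # audio: (2, 1, max_len), audio_lengths: (2,)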
@torch.no_grad()
def _encode_frame(
self,
input_values: torch.Tensor,
input_lengths: torch.Tensor | None = None,
n_quantizers: int | None = None,
) -> MossAudioTokenizerEncoderOutput:
"""Tokenize audio waveform into discrete tokens."""
# Handle input shape
if input_values.dim() == 2:
input_values = input_values.unsqueeze(1)
B, _, T = input_values.shape
device = input_values.device
if input_lengths is None:
input_lengths = torch.full((B,), T, device=device, dtype=torch.long)
# Pad to multiple of downsample_rate
if T % self.downsample_rate != 0:
pad_length = self.downsample_rate - (T % self.downsample_rate)
input_values = F.pad(input_values, (0, pad_length))
# Encode
e, e_lengths = input_values, input_lengths
for encoder_module in self.encoder:
e, e_lengths = encoder_module(e, e_lengths)
# Quantize
quantizer = cast(MossAudioTokenizerResidualVQ | MossAudioTokenizerResidualLFQ, self.quantizer)
zq, audio_codes, audio_codes_lengths = quantizer(e, e_lengths, n_quantizers)
return MossAudioTokenizerEncoderOutput(
audio_codes=audio_codes, audio_codes_lengths=audio_codes_lengths, encoder_hidden_states=e
)
@torch.no_grad()
def _decode_frame(
self,
codes: torch.Tensor,
codes_lengths: torch.Tensor | None = None,
) -> MossAudioTokenizerDecoderOutput:
"""Detokenize discrete tokens into audio waveform."""
nq, B, T = codes.shape
device = codes.device
if codes_lengths is None:
codes_lengths = torch.full((B,), T, device=device, dtype=torch.long)
# Decode from codes
quantizer = cast(MossAudioTokenizerResidualVQ | MossAudioTokenizerResidualLFQ, self.quantizer)
zq = quantizer.decode_codes(codes)
d, d_lengths = zq, codes_lengths
for decoder_module in self.decoder:
d, d_lengths = decoder_module(d, d_lengths)
return MossAudioTokenizerDecoderOutput(audio=d, audio_lengths=d_lengths)
def encode( # type: ignore[override]
self,
input_values: torch.Tensor,
padding_mask: torch.Tensor | None = None,
num_quantizers: int | None = None,
return_dict: bool | None = None,
chunk_duration: float | None = None,
):
"""
Encodes the input audio waveform into discrete codes.
Args:
input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
Float values of the input audio waveform.
padding_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to indicate valid audio samples.
num_quantizers (`int`, *optional*):
Number of quantizers to use. By default, all quantizers are used.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
chunk_duration (`float`, *optional*):
If provided, encode the input waveform in successive chunks of `chunk_duration` seconds while keeping a
streaming KV cache for the causal transformers.
`chunk_duration` must be <= `config.causal_transformer_context_duration`, and
`chunk_duration * config.sampling_rate` must be divisible by `config.downsample_rate`.
Returns:
`MossAudioTokenizerEncoderOutput` or tuple containing audio codes and lengths.
"""
return_dict = return_dict if return_dict is not None else self.config.return_dict
# Handle input shape
if input_values.dim() == 2:
input_values = input_values.unsqueeze(1)
B, _, T = input_values.shape
device = input_values.device
if padding_mask is not None:
input_lengths = padding_mask.sum(dim=-1).long()
else:
input_lengths = torch.full((B,), T, device=device, dtype=torch.long)
if chunk_duration is None:
encoder_output = self._encode_frame(input_values, input_lengths, num_quantizers)
else:
if chunk_duration <= 0:
raise ValueError("`chunk_duration` must be > 0 when provided.")
if chunk_duration > self.causal_transformer_context_duration:
raise ValueError(
"`chunk_duration` must be <= `config.causal_transformer_context_duration` "
f"({self.causal_transformer_context_duration}), got {chunk_duration}."
)
if B != 1:
raise ValueError("Streaming encode via `chunk_duration` currently only supports batch_size=1.")
chunk_length = int(round(chunk_duration * self.sampling_rate))
if chunk_length <= 0:
raise ValueError("`chunk_duration` is too small and results in chunk_length <= 0.")
if chunk_length % self.downsample_rate != 0:
raise ValueError(
"`chunk_duration * config.sampling_rate` must be divisible by `config.downsample_rate`. "
f"Got chunk_length={chunk_length}, downsample_rate={self.downsample_rate}."
)
input_length = int(input_lengths[0].item())
if input_length <= chunk_length:
encoder_output = self._encode_frame(input_values[..., :input_length], input_lengths, num_quantizers)
else:
codes_chunks: list[torch.Tensor] = []
hidden_chunks: list[torch.Tensor] = []
with ExitStack() as exit_stack:
for encoder_module in self.encoder:
if isinstance(encoder_module, StreamingModule):
exit_stack.enter_context(encoder_module.streaming(batch_size=B))
for start_idx in range(0, input_length, chunk_length):
input_length_i = min(chunk_length, input_length - start_idx)
if input_length_i <= 0:
break
input_lengths_i = torch.tensor([input_length_i], device=device, dtype=torch.long)
input_values_i = input_values[..., start_idx : start_idx + input_length_i]
result_i = self._encode_frame(input_values_i, input_lengths_i, num_quantizers)
if result_i.audio_codes is None or result_i.audio_codes_lengths is None:
raise RuntimeError("Internal error: `_encode_frame` returned empty audio codes.")
if result_i.encoder_hidden_states is None:
raise RuntimeError("Internal error: `_encode_frame` returned empty encoder hidden states.")
codes_length_i = result_i.audio_codes_lengths
codes_chunks.append(result_i.audio_codes[:, :, : codes_length_i[0]])
hidden_chunks.append(result_i.encoder_hidden_states[:, :, : codes_length_i[0]])
audio_codes = torch.cat(codes_chunks, dim=-1)
encoder_hidden_states = torch.cat(hidden_chunks, dim=-1)
audio_codes_lengths = torch.tensor([audio_codes.shape[-1]], device=device, dtype=torch.long)
encoder_output = MossAudioTokenizerEncoderOutput(
audio_codes=audio_codes,
audio_codes_lengths=audio_codes_lengths,
encoder_hidden_states=encoder_hidden_states,
)
if not return_dict:
assert encoder_output.audio_codes is not None
assert encoder_output.audio_codes_lengths is not None
return (
cast(torch.Tensor, encoder_output.audio_codes),
cast(torch.Tensor, encoder_output.audio_codes_lengths),
)
return encoder_output
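# Example of the chunked path above (illustrative numbers, not taken from any
# specific config): with sampling_rate = 24_000 and downsample_rate = 1_920, a
# chunk_duration of 0.4 gives chunk_length = 9_600 samples, which is divisible by
# 1_920 and yields 5 code frames per chunk. The KV caches of the causal transformers
# carry context across chunks, so the streamed result should match encoding the full
# waveform in one pass, up to the causal context window.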
def decode( # type: ignore[override]
self,
audio_codes: torch.Tensor,
padding_mask: torch.Tensor | None = None,
return_dict: bool | None = None,
chunk_duration: float | None = None,
num_quantizers: int | None = None,
):
"""
Decodes the given codes into an output audio waveform.
Args:
audio_codes (`torch.LongTensor` of shape `(num_quantizers, batch_size, sequence_length)`):
Discrete code embeddings computed using `model.encode`.
padding_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to indicate valid code positions.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
chunk_duration (`float`, *optional*):
If provided, decode the input codes in successive chunks of `chunk_duration` seconds while keeping a
streaming KV cache for the causal transformers.
`chunk_duration` must be <= `config.causal_transformer_context_duration`, and
`chunk_duration * config.sampling_rate` must be divisible by `config.downsample_rate`.
num_quantizers (`int`, *optional*):
Number of quantizers to use. By default, all quantizers in `audio_codes` are used.
Returns:
`MossAudioTokenizerDecoderOutput` or tuple containing decoded audio.
"""
return_dict = return_dict if return_dict is not None else self.config.return_dict
if audio_codes.dim() == 2:
audio_codes = audio_codes.unsqueeze(1) # nq, T -> nq, B=1, T
if num_quantizers is not None:
if num_quantizers > audio_codes.shape[0]:
raise ValueError(
f"`num_quantizers` ({num_quantizers}) must be <= audio_codes.shape[0] ({audio_codes.shape[0]})."
)
audio_codes = audio_codes[:num_quantizers]
_, B, T = audio_codes.shape
device = audio_codes.device
if padding_mask is not None:
codes_lengths = padding_mask.sum(dim=-1).long()
else:
codes_lengths = torch.full((B,), T, device=device, dtype=torch.long)
if chunk_duration is None:
decoder_output = self._decode_frame(audio_codes, codes_lengths)
else:
if chunk_duration <= 0:
raise ValueError("`chunk_duration` must be > 0 when provided.")
if chunk_duration > self.causal_transformer_context_duration:
raise ValueError(
"`chunk_duration` must be <= `config.causal_transformer_context_duration` "
f"({self.causal_transformer_context_duration}), got {chunk_duration}."
)
if B != 1:
raise ValueError("Streaming decode via `chunk_duration` currently only supports batch_size=1.")
chunk_length = int(round(chunk_duration * self.sampling_rate))
if chunk_length <= 0:
raise ValueError("`chunk_duration` is too small and results in chunk_length <= 0.")
if chunk_length % self.downsample_rate != 0:
raise ValueError(
"`chunk_duration * config.sampling_rate` must be divisible by `config.downsample_rate`. "
f"Got chunk_length={chunk_length}, downsample_rate={self.downsample_rate}."
)
chunk_frame_length = chunk_length // self.downsample_rate
codes_length = int(codes_lengths[0].item())
if codes_length <= chunk_frame_length:
decoder_output = self._decode_frame(audio_codes[..., :codes_length], codes_lengths)
else:
wav_chunks: list[torch.Tensor] = []
with ExitStack() as exit_stack:
for decoder_module in self.decoder:
if isinstance(decoder_module, StreamingModule):
exit_stack.enter_context(decoder_module.streaming(batch_size=B))
for start_idx in range(0, codes_length, chunk_frame_length):
codes_length_i = min(chunk_frame_length, codes_length - start_idx)
if codes_length_i <= 0:
break
codes_lengths_i = torch.tensor([codes_length_i], device=device, dtype=torch.long)
codes_i = audio_codes[:, :, start_idx : start_idx + codes_length_i]
result_i = self._decode_frame(codes_i, codes_lengths_i)
if result_i.audio is None or result_i.audio_lengths is None:
raise RuntimeError("Internal error: `_decode_frame` returned empty audio.")
wav_chunks.append(result_i.audio[:, :, : result_i.audio_lengths[0]])
wav = torch.cat(wav_chunks, dim=-1)
audio_lengths = torch.tensor([wav.shape[-1]], device=device, dtype=torch.long)
decoder_output = MossAudioTokenizerDecoderOutput(audio=wav, audio_lengths=audio_lengths)
if not return_dict:
assert decoder_output.audio is not None
return (cast(torch.Tensor, decoder_output.audio),)
return decoder_output
@auto_docstring
def forward(
self,
input_values: torch.FloatTensor | None = None,
padding_mask: torch.BoolTensor | None = None,
audio_codes: torch.Tensor | None = None,
num_quantizers: int | None = None,
return_dict: bool | None = None,
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None] | MossAudioTokenizerOutput: # type: ignore[override]
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*):
Raw audio input converted to Float.
padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid computing on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
audio_codes (`torch.LongTensor` of shape `(num_quantizers, batch_size, sequence_length)`, *optional*):
Discrete code embeddings computed using `model.encode`.
num_quantizers (`int`, *optional*):
Number of quantizers (codebooks) to use. By default, all quantizers are used.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
Examples:
```python
>>> import torch
>>> from transformers import MossAudioTokenizerModel
>>> model = MossAudioTokenizerModel.from_pretrained("moss_audio_tokenizer-model")
>>> # Create dummy audio input
>>> audio = torch.randn(1, 1, 24000) # 1 second of audio at 24kHz
>>> outputs = model(input_values=audio)
>>> audio_codes = outputs.audio_codes
>>> audio_values = outputs.audio
```
"""
return_dict = return_dict if return_dict is not None else self.config.return_dict
output_audio_codes: torch.Tensor | None = None
output_audio_codes_lengths: torch.Tensor | None = None
output_audio: torch.Tensor | None = None
output_audio_lengths: torch.Tensor | None = None
decoded_from_encoded_codes = False
# Encode if input_values provided
if input_values is not None:
encoder_output = self.encode(input_values, padding_mask, num_quantizers, return_dict=True)
encoder_output = cast(MossAudioTokenizerEncoderOutput, encoder_output)
output_audio_codes = encoder_output.audio_codes
output_audio_codes_lengths = encoder_output.audio_codes_lengths
# If codes not provided separately, use encoded codes for decoding
if audio_codes is None:
audio_codes = output_audio_codes
decoded_from_encoded_codes = True
# Decode if codes available
if audio_codes is not None:
# If we're decoding the codes we just produced, use the computed lengths so we don't decode padded garbage.
if decoded_from_encoded_codes and output_audio_codes_lengths is not None:
decoder_output = self._decode_frame(audio_codes, output_audio_codes_lengths)
else:
decoder_output = self.decode(
audio_codes,
padding_mask=padding_mask,
return_dict=True,
num_quantizers=num_quantizers,
)
decoder_output = cast(MossAudioTokenizerDecoderOutput, decoder_output)
output_audio = decoder_output.audio
output_audio_lengths = decoder_output.audio_lengths
if not return_dict:
return (output_audio_codes, output_audio, output_audio_lengths)
return MossAudioTokenizerOutput(
audio=output_audio,
audio_lengths=output_audio_lengths,
audio_codes=output_audio_codes,
audio_codes_lengths=output_audio_codes_lengths,
)
__all__ = ["MossAudioTokenizerModel", "MossAudioTokenizerPreTrainedModel"]