# Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch MossAudioTokenizer model.""" from __future__ import annotations import copy import importlib import math import sys import types from contextlib import ExitStack, contextmanager from dataclasses import dataclass from functools import lru_cache from pathlib import Path from typing import cast import torch import torch.nn as nn import torch.nn.functional as F if __name__ not in sys.modules: _module_proxy = types.ModuleType(__name__) sys.modules[__name__] = _module_proxy def _sync_module_proxy() -> None: sys.modules[__name__].__dict__.update(globals()) try: from transformers.modeling_utils import PreTrainedAudioTokenizerBase except ImportError: from transformers.modeling_utils import PreTrainedModel as PreTrainedAudioTokenizerBase from transformers.utils import ModelOutput, logging try: from transformers.utils import auto_docstring as _hf_auto_docstring except ImportError: _hf_auto_docstring = None def auto_docstring(*args, **kwargs): if _hf_auto_docstring is None: if len(args) == 1 and callable(args[0]) and not kwargs: return args[0] def decorator(obj): return obj return decorator if len(args) == 1 and callable(args[0]) and not kwargs: obj = args[0] try: return _hf_auto_docstring(obj) except Exception: return obj try: decorator = _hf_auto_docstring(*args, **kwargs) except Exception: def decorator(obj): return obj return decorator def safe_decorator(obj): try: return decorator(obj) except Exception: return obj return safe_decorator try: from .configuration_moss_audio_tokenizer import MossAudioTokenizerConfig except ImportError: _module_dir = str(Path(__file__).resolve().parent) if _module_dir not in sys.path: sys.path.insert(0, _module_dir) from configuration_moss_audio_tokenizer import MossAudioTokenizerConfig logger = logging.get_logger(__name__) @lru_cache(maxsize=1) def _get_flash_attn_module(): try: return importlib.import_module("flash_attn") except Exception: return None def _has_flash_attn() -> bool: return _get_flash_attn_module() is not None def _get_flash_attn_varlen_func(): flash_attn_module = _get_flash_attn_module() if flash_attn_module is None: return None return getattr(flash_attn_module, "flash_attn_varlen_func", None) def _get_flash_attn_with_kvcache(): flash_attn_module = _get_flash_attn_module() if flash_attn_module is None: return None return getattr(flash_attn_module, "flash_attn_with_kvcache", None) SUPPORTED_ATTENTION_IMPLEMENTATIONS = {"sdpa", "flash_attention_2"} SUPPORTED_COMPUTE_DTYPES = {"fp32": None, "bf16": torch.bfloat16, "fp16": torch.float16} _ACTIVE_DECODE_SESSION_ERROR_MESSAGE = "MossAudioTokenizerModel only supports one active decode session at a time." _CLOSED_DECODE_SESSION_ERROR_MESSAGE = "This decode session is closed." _MODEL_STREAMING_CONFLICT_ERROR_MESSAGE = "Model-level streaming helpers cannot be used while a decode session is active." _PLAIN_DECODE_SESSION_CONFLICT_ERROR_MESSAGE = "Plain decode helpers cannot be used while a decode session is active." _DUPLICATE_DECODE_REQUEST_ERROR_TEMPLATE = "Decode session already contains request_id={request_id!r}." _UNKNOWN_DECODE_REQUEST_ERROR_TEMPLATE = "Decode session does not contain an active request_id={request_id!r}." _DECODE_SESSION_FULL_ERROR_TEMPLATE = "Decode session has no free slots remaining (max_batch_size={max_batch_size})." _INVALID_DECODE_STEP_REQUEST_IDS_ERROR_MESSAGE = ( "`request_ids` must exactly match the current active decode request order." ) _BATCH_DECODE_STREAMING_DUPLICATE_FINALIZE_INDICES_ERROR_MESSAGE = "`finalize_indices` must not contain duplicates." _BATCH_DECODE_STREAMING_FINALIZE_INDEX_OUT_OF_RANGE_ERROR_TEMPLATE = ( "`finalize_indices` index {index} is out of range for the pre-call logical batch of size {batch_size}." ) _BATCH_DECODE_STREAMING_SHRINK_ERROR_MESSAGE = ( "`batch_decode(streaming=True)` must include all pre-call active rows in the current call before applying `finalize_indices`." ) def resolve_compute_dtype(compute_dtype: str) -> torch.dtype | None: if compute_dtype not in SUPPORTED_COMPUTE_DTYPES: raise ValueError( f"Unsupported compute_dtype={compute_dtype!r}. Expected one of {sorted(SUPPORTED_COMPUTE_DTYPES)}." ) return SUPPORTED_COMPUTE_DTYPES[compute_dtype] @contextmanager def disable_cuda_autocast(): with torch.autocast(device_type="cuda", enabled=False): yield # ============================================================================= # Output Classes # ============================================================================= _sync_module_proxy() @dataclass @auto_docstring class MossAudioTokenizerEncoderOutput(ModelOutput): r""" audio_codes (`torch.LongTensor` of shape `(num_quantizers, batch_size, sequence_length)`, *optional*): Discrete audio codes computed using the encoder and quantizer. audio_codes_lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Valid lengths for each sample's audio codes. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, hidden_size, sequence_length)`, *optional*): Hidden states from the encoder before quantization. """ audio_codes: torch.Tensor | None = None audio_codes_lengths: torch.Tensor | None = None encoder_hidden_states: torch.Tensor | None = None _sync_module_proxy() @dataclass @auto_docstring class MossAudioTokenizerDecoderOutput(ModelOutput): r""" audio (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*): Decoded audio waveform. audio_lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Valid lengths for each sample's audio. """ audio: torch.Tensor | None = None audio_lengths: torch.Tensor | None = None _sync_module_proxy() @dataclass @auto_docstring class MossAudioTokenizerOutput(ModelOutput): r""" audio (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*): Decoded audio waveform. audio_lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Valid lengths for each sample's audio. audio_codes (`torch.LongTensor` of shape `(num_quantizers, batch_size, sequence_length)`, *optional*): Discrete audio codes computed using the encoder and quantizer. audio_codes_lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Valid lengths for each sample's audio codes. """ audio: torch.Tensor | None = None audio_lengths: torch.Tensor | None = None audio_codes: torch.Tensor | None = None audio_codes_lengths: torch.Tensor | None = None # ============================================================================= # Streaming Module Base Classes # ============================================================================= _sync_module_proxy() @dataclass class StreamingState: """Base state for streaming modules.""" batch_size: int device: torch.device def __post_init__(self): self.exec_mask = torch.ones(self.batch_size, dtype=torch.bool, device=self.device) def set_exec_mask(self, exec_mask: torch.Tensor): self.exec_mask[:] = exec_mask def reset(self, reset_mask: torch.Tensor) -> None: self.exec_mask[:] = torch.where(reset_mask, torch.ones_like(self.exec_mask), self.exec_mask) def __enter__(self): # ExitStack expects a context manager; returning self is conventional and useful for debugging. return self def __exit__(self, exc_type, exc_value, traceback) -> None: pass class StreamingModule(nn.Module): """Base class for streaming components.""" def __init__(self) -> None: super().__init__() self._streaming_state: StreamingState | None = None self._streaming_detached: bool = False self._cached_children: list[tuple[str, StreamingModule]] | None = None @property def is_streaming(self): return self._streaming_state is not None def _apply_named_streaming(self, fn): def _handle_module(prefix: str, module: nn.Module): if isinstance(module, StreamingModule): if module._streaming_detached and prefix != "": return if self._cached_children is None: raise RuntimeError("Internal error: _cached_children should be initialized before traversal.") self._cached_children.append((prefix, module)) for name, child in module.named_children(): new_prefix = f"{prefix}.{name}" if prefix else name _handle_module(new_prefix, child) if self._cached_children is None: self._cached_children = [] _handle_module("", self) for name, child in self._cached_children: fn(name, child) def _start_streaming(self, batch_size: int, exit_stack: ExitStack): def _start_streaming_fn(name: str, module: StreamingModule): if module._streaming_state is not None: raise RuntimeError(f"{name} is already streaming!") state = module._init_streaming_state(batch_size) exit_stack.enter_context(state) module._streaming_state = state self._apply_named_streaming(_start_streaming_fn) def _stop_streaming(self) -> None: def _stop_streaming_fn(name: str, module: StreamingModule): module._streaming_state = None self._apply_named_streaming(_stop_streaming_fn) def _init_streaming_state(self, batch_size: int) -> StreamingState: device = next(iter(self.parameters())).device return StreamingState(batch_size, device) def streaming(self, batch_size: int) -> ExitStack: """Context manager to enter streaming mode.""" exit_stack = ExitStack() self._start_streaming(batch_size, exit_stack) exit_stack.callback(self._stop_streaming) return exit_stack class StreamingContainer(StreamingModule): """Container for streaming modules.""" pass class MossAudioTokenizerDecodeSession: model: MossAudioTokenizerModel max_batch_size: int _use_cuda_graph: bool active_request_ids: list[str | int] request_id_to_slot_index: dict[str | int, int] slot_index_to_request_id: list[str | int | None] slot_is_free: list[bool] request_id_to_code_offset: dict[str | int, int] request_id_to_audio_offset: dict[str | int, int] _flash_kvcache_attention_modules: list[MossAudioTokenizerMultiheadAttention] _graph_num_quantizers_capacity: int | None _graph_input_codes: torch.Tensor | None _graph_input_code_lengths: torch.Tensor | None _graph_output_audio: torch.Tensor | None _graph_output_audio_lengths: torch.Tensor | None _cuda_graph: torch.cuda.CUDAGraph | None _cuda_graph_key: tuple[str, int, int, str] | None _decode_streaming_exit_stack: ExitStack | None _closed: bool def __init__(self, model: MossAudioTokenizerModel, max_batch_size: int, use_cuda_graph: bool = False): if max_batch_size <= 0: raise ValueError("`max_batch_size` must be > 0.") decoder_attention_modules: list[MossAudioTokenizerMultiheadAttention] = [] for decoder_module in model.decoder: for module in decoder_module.modules(): if isinstance(module, MossAudioTokenizerMultiheadAttention): if module.context is None: raise ValueError( "MossAudioTokenizerDecodeSession requires all decoder MHA modules to have a finite " "`context` (context=None is unsupported for continuous-batch streaming)." ) decoder_attention_modules.append(module) flash_kvcache_attention_modules: list[MossAudioTokenizerMultiheadAttention] = [] if use_cuda_graph and _has_flash_attn(): for module in decoder_attention_modules: module._use_flash_kvcache = True flash_kvcache_attention_modules.append(module) decode_streaming_exit_stack = ExitStack() try: for decoder_module in model.decoder: if isinstance(decoder_module, StreamingModule): inner_stack = decoder_module.streaming(batch_size=max_batch_size) _ = decode_streaming_exit_stack.enter_context(inner_stack) except Exception: decode_streaming_exit_stack.close() for module in flash_kvcache_attention_modules: module._use_flash_kvcache = False raise self.model = model self.max_batch_size = max_batch_size self._use_cuda_graph = use_cuda_graph self.active_request_ids: list[str | int] = [] self.request_id_to_slot_index: dict[str | int, int] = {} self.slot_index_to_request_id: list[str | int | None] = [None] * max_batch_size self.slot_is_free: list[bool] = [True] * max_batch_size self.request_id_to_code_offset: dict[str | int, int] = {} self.request_id_to_audio_offset: dict[str | int, int] = {} self._flash_kvcache_attention_modules = flash_kvcache_attention_modules self._graph_num_quantizers_capacity = int(getattr(model.quantizer, "num_quantizers", 0)) if use_cuda_graph else None self._graph_input_codes = None self._graph_input_code_lengths = None self._graph_output_audio = None self._graph_output_audio_lengths = None self._cuda_graph = None self._cuda_graph_key = None self._decode_streaming_exit_stack: ExitStack | None = decode_streaming_exit_stack self._closed = False if use_cuda_graph: device = next(iter(model.parameters())).device if device.type == "cuda": self._ensure_cuda_graph_buffers(device) model._active_decode_session = self def _ensure_open(self) -> None: if self._closed: raise RuntimeError(_CLOSED_DECODE_SESSION_ERROR_MESSAGE) def append(self, request_id: str | int) -> None: self._ensure_open() if request_id in self.request_id_to_slot_index: raise RuntimeError(_DUPLICATE_DECODE_REQUEST_ERROR_TEMPLATE.format(request_id=request_id)) slot_index = next((index for index, is_free in enumerate(self.slot_is_free) if is_free), None) if slot_index is None: raise RuntimeError(_DECODE_SESSION_FULL_ERROR_TEMPLATE.format(max_batch_size=self.max_batch_size)) self.active_request_ids.append(request_id) self.request_id_to_slot_index[request_id] = slot_index self.slot_index_to_request_id[slot_index] = request_id self.slot_is_free[slot_index] = False self.request_id_to_code_offset[request_id] = 0 self.request_id_to_audio_offset[request_id] = 0 def _decoder_streaming_states(self) -> list[StreamingState]: decoder_streaming_states: list[StreamingState] = [] for decoder_module in self.model.decoder: for module in decoder_module.modules(): if isinstance(module, StreamingModule) and module._streaming_state is not None: decoder_streaming_states.append(module._streaming_state) return decoder_streaming_states def _ensure_cuda_graph_buffers(self, device: torch.device) -> None: if not self._use_cuda_graph or device.type != "cuda": return graph_num_quantizers_capacity = self._graph_num_quantizers_capacity if graph_num_quantizers_capacity is None: graph_num_quantizers_capacity = int(getattr(self.model.quantizer, "num_quantizers", 0)) self._graph_num_quantizers_capacity = graph_num_quantizers_capacity if graph_num_quantizers_capacity <= 0: raise RuntimeError("`use_cuda_graph=True` requires a quantizer with `num_quantizers > 0`.") if self._graph_input_codes is None or self._graph_input_codes.device != device: self._graph_input_codes = torch.zeros( (graph_num_quantizers_capacity, self.max_batch_size, 1), device=device, dtype=torch.long, ) self._graph_input_code_lengths = torch.zeros(self.max_batch_size, device=device, dtype=torch.long) self._graph_output_audio = None self._graph_output_audio_lengths = None self._cuda_graph = None self._cuda_graph_key = None def _snapshot_decoder_streaming_states(self) -> list[tuple[StreamingState, dict[str, torch.Tensor | None]]]: snapshots: list[tuple[StreamingState, dict[str, torch.Tensor | None]]] = [] for streaming_state in self._decoder_streaming_states(): state_snapshot: dict[str, torch.Tensor | None] = {"exec_mask": streaming_state.exec_mask.clone()} if isinstance(streaming_state, TransformerState): state_snapshot["offsets"] = streaming_state.offsets.clone() if isinstance(streaming_state, MHAState): state_snapshot["offset"] = streaming_state.offset.clone() state_snapshot["cached_keys"] = None if streaming_state.cached_keys is None else streaming_state.cached_keys.clone() state_snapshot["cached_values"] = None if streaming_state.cached_values is None else streaming_state.cached_values.clone() state_snapshot["cached_positions"] = ( None if streaming_state.cached_positions is None else streaming_state.cached_positions.clone() ) state_snapshot["flash_cached_keys"] = ( None if getattr(streaming_state, "_flash_cached_keys", None) is None else cast(torch.Tensor, getattr(streaming_state, "_flash_cached_keys")).clone() ) state_snapshot["flash_cached_values"] = ( None if getattr(streaming_state, "_flash_cached_values", None) is None else cast(torch.Tensor, getattr(streaming_state, "_flash_cached_values")).clone() ) snapshots.append((streaming_state, state_snapshot)) return snapshots def _restore_decoder_streaming_states( self, snapshots: list[tuple[StreamingState, dict[str, torch.Tensor | None]]], ) -> None: for streaming_state, state_snapshot in snapshots: exec_mask = state_snapshot["exec_mask"] assert exec_mask is not None streaming_state.exec_mask.copy_(exec_mask) if isinstance(streaming_state, TransformerState): offsets = state_snapshot.get("offsets") assert offsets is not None streaming_state.offsets.copy_(offsets) if isinstance(streaming_state, MHAState): offset = state_snapshot.get("offset") assert offset is not None streaming_state.offset.copy_(offset) cached_keys = state_snapshot.get("cached_keys") cached_values = state_snapshot.get("cached_values") cached_positions = state_snapshot.get("cached_positions") if cached_keys is None or cached_values is None or cached_positions is None: if streaming_state.cached_keys is not None: streaming_state.cached_keys.zero_() if streaming_state.cached_values is not None: streaming_state.cached_values.zero_() if streaming_state.cached_positions is not None: streaming_state.cached_positions.fill_(-1) else: if streaming_state.cached_keys is None or streaming_state.cached_keys.shape != cached_keys.shape: streaming_state.cached_keys = cached_keys.clone() else: streaming_state.cached_keys.copy_(cached_keys) if streaming_state.cached_values is None or streaming_state.cached_values.shape != cached_values.shape: streaming_state.cached_values = cached_values.clone() else: streaming_state.cached_values.copy_(cached_values) if streaming_state.cached_positions is None or streaming_state.cached_positions.shape != cached_positions.shape: streaming_state.cached_positions = cached_positions.clone() else: streaming_state.cached_positions.copy_(cached_positions) flash_cached_keys = state_snapshot.get("flash_cached_keys") flash_cached_values = state_snapshot.get("flash_cached_values") current_flash_cached_keys = cast(torch.Tensor | None, getattr(streaming_state, "_flash_cached_keys", None)) current_flash_cached_values = cast(torch.Tensor | None, getattr(streaming_state, "_flash_cached_values", None)) if flash_cached_keys is None or flash_cached_values is None: if current_flash_cached_keys is not None: current_flash_cached_keys.zero_() if current_flash_cached_values is not None: current_flash_cached_values.zero_() else: if current_flash_cached_keys is None or current_flash_cached_keys.shape != flash_cached_keys.shape: setattr(streaming_state, "_flash_cached_keys", flash_cached_keys.clone()) else: current_flash_cached_keys.copy_(flash_cached_keys) if current_flash_cached_values is None or current_flash_cached_values.shape != flash_cached_values.shape: setattr(streaming_state, "_flash_cached_values", flash_cached_values.clone()) else: current_flash_cached_values.copy_(flash_cached_values) def _graphed_decode_frame( self, codes: torch.Tensor, code_lengths: torch.Tensor, ) -> MossAudioTokenizerDecoderOutput: self._ensure_cuda_graph_buffers(codes.device) graph_input_codes = self._graph_input_codes graph_input_code_lengths = self._graph_input_code_lengths if graph_input_codes is None or graph_input_code_lengths is None: raise RuntimeError("CUDA graph buffers are unavailable.") num_quantizers = codes.shape[0] graph_input_codes_view = graph_input_codes[:num_quantizers] graph_input_codes_view.copy_(codes) graph_input_code_lengths.copy_(code_lengths) cuda_graph_key = (str(codes.device), self.max_batch_size, num_quantizers, self.model.compute_dtype_name) if self._cuda_graph is None or self._cuda_graph_key != cuda_graph_key: state_snapshots = self._snapshot_decoder_streaming_states() current_stream = torch.cuda.current_stream(device=codes.device) warmup_stream = torch.cuda.Stream(device=codes.device) warmup_stream.wait_stream(current_stream) with torch.cuda.stream(warmup_stream): _ = self.model._decode_frame(graph_input_codes_view, graph_input_code_lengths) current_stream.wait_stream(warmup_stream) self._restore_decoder_streaming_states(state_snapshots) cuda_graph = torch.cuda.CUDAGraph() with torch.cuda.graph(cuda_graph): decoder_output = self.model._decode_frame(graph_input_codes_view, graph_input_code_lengths) self._cuda_graph = cuda_graph self._cuda_graph_key = cuda_graph_key self._graph_output_audio = decoder_output.audio self._graph_output_audio_lengths = decoder_output.audio_lengths else: self._cuda_graph.replay() return MossAudioTokenizerDecoderOutput( audio=self._graph_output_audio, audio_lengths=self._graph_output_audio_lengths, ) def _reset_slot(self, slot_index: int) -> None: for streaming_state in self._decoder_streaming_states(): reset_mask = torch.zeros(streaming_state.batch_size, dtype=torch.bool, device=streaming_state.exec_mask.device) reset_mask[slot_index] = True streaming_state.reset(reset_mask) def _pack_logical_codes_to_physical_slots( self, request_ids: list[str | int], codes: torch.Tensor, code_lengths: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, list[int], torch.Tensor]: if request_ids != self.active_request_ids: raise ValueError(_INVALID_DECODE_STEP_REQUEST_IDS_ERROR_MESSAGE) if not request_ids: raise ValueError("`step()` requires at least one active request.") if codes.dim() == 2: codes = codes.unsqueeze(1) if codes.dim() != 3: raise ValueError(f"`codes` must be 3D with shape `(num_quantizers, batch_size, sequence_length)`, got {codes.shape}.") code_lengths = code_lengths.to(device=codes.device, dtype=torch.long) if code_lengths.dim() != 1: raise ValueError(f"`code_lengths` must be 1D with shape `(batch_size,)`, got {code_lengths.shape}.") num_quantizers, logical_batch_size, max_code_length = codes.shape if logical_batch_size != len(request_ids): raise ValueError( f"`codes.shape[1]` ({logical_batch_size}) must match len(`request_ids`) ({len(request_ids)})." ) if code_lengths.shape[0] != logical_batch_size: raise ValueError( f"`code_lengths.shape[0]` ({code_lengths.shape[0]}) must match len(`request_ids`) ({len(request_ids)})." ) if torch.any(code_lengths < 0): raise ValueError("`code_lengths` must be >= 0.") if torch.any(code_lengths > max_code_length): raise ValueError(f"`code_lengths` must be <= codes.shape[-1] ({max_code_length}).") packed_codes = codes.new_zeros((num_quantizers, self.max_batch_size, max_code_length)) packed_code_lengths = code_lengths.new_zeros((self.max_batch_size,)) logical_row_to_slot_index: list[int] = [] for logical_row_index, request_id in enumerate(request_ids): slot_index = self.request_id_to_slot_index[request_id] logical_row_to_slot_index.append(slot_index) row_length = int(code_lengths[logical_row_index].item()) if row_length > 0: packed_codes[:, slot_index, :row_length] = codes[:, logical_row_index, :row_length] packed_code_lengths[slot_index] = row_length return packed_codes, packed_code_lengths, logical_row_to_slot_index, code_lengths def _advance_request_progress( self, request_ids: list[str | int], code_lengths: torch.Tensor, audio_lengths: torch.Tensor, ) -> None: for logical_row_index, request_id in enumerate(request_ids): self.request_id_to_code_offset[request_id] += int(code_lengths[logical_row_index].item()) self.request_id_to_audio_offset[request_id] += int(audio_lengths[logical_row_index].item()) def step( self, request_ids: list[str | int], codes: torch.Tensor, code_lengths: torch.Tensor, ) -> tuple[list[str | int], torch.Tensor, torch.Tensor]: self._ensure_open() packed_codes, packed_code_lengths, logical_row_to_slot_index, logical_code_lengths = ( self._pack_logical_codes_to_physical_slots( request_ids=request_ids, codes=codes, code_lengths=code_lengths, ) ) max_step_length = int(packed_code_lengths.max().item()) if max_step_length <= 0: raise ValueError("`step()` requires at least one row with `code_length > 0`.") decoder_streaming_states = self._decoder_streaming_states() logical_audio_chunks: list[list[torch.Tensor]] = [[] for _ in request_ids] audio_device: torch.device | None = None audio_dtype: torch.dtype | None = None audio_num_channels: int | None = None try: for frame_index in range(max_step_length): frame_exec_mask = packed_code_lengths > frame_index for streaming_state in decoder_streaming_states: streaming_state.set_exec_mask(frame_exec_mask) frame_codes = packed_codes[:, :, frame_index : frame_index + 1] frame_code_lengths = frame_exec_mask.to(dtype=packed_code_lengths.dtype) if self._use_cuda_graph and frame_codes.is_cuda: decoder_output = self._graphed_decode_frame(frame_codes, frame_code_lengths) else: decoder_output = self.model._decode_frame(frame_codes, frame_code_lengths) if decoder_output.audio is None or decoder_output.audio_lengths is None: raise RuntimeError("Internal error: `_decode_frame` returned empty audio.") audio = decoder_output.audio audio_lengths = decoder_output.audio_lengths audio_device = audio.device audio_dtype = audio.dtype audio_num_channels = audio.shape[1] for logical_row_index, slot_index in enumerate(logical_row_to_slot_index): audio_length = int(audio_lengths[slot_index].item()) if audio_length <= 0: continue logical_audio_chunks[logical_row_index].append(audio[slot_index : slot_index + 1, :, :audio_length]) except Exception: self.close() raise finally: for streaming_state in decoder_streaming_states: streaming_state.set_exec_mask(torch.ones_like(streaming_state.exec_mask)) if audio_device is None or audio_dtype is None or audio_num_channels is None: raise RuntimeError("Internal error: `step()` produced no decoder outputs.") logical_audio_rows: list[torch.Tensor] = [] logical_audio_lengths: list[int] = [] for row_chunks in logical_audio_chunks: if row_chunks: row_audio = torch.cat(row_chunks, dim=-1) else: row_audio = torch.zeros((1, audio_num_channels, 0), device=audio_device, dtype=audio_dtype) logical_audio_rows.append(row_audio) logical_audio_lengths.append(row_audio.shape[-1]) audio_lengths = torch.tensor(logical_audio_lengths, device=audio_device, dtype=torch.long) max_audio_length = max(logical_audio_lengths) audio = torch.zeros( (len(request_ids), audio_num_channels, max_audio_length), device=audio_device, dtype=audio_dtype, ) for logical_row_index, row_audio in enumerate(logical_audio_rows): row_audio_length = row_audio.shape[-1] if row_audio_length > 0: audio[logical_row_index, :, :row_audio_length] = row_audio[0] logical_request_ids = list(request_ids) self._advance_request_progress( request_ids=logical_request_ids, code_lengths=logical_code_lengths, audio_lengths=audio_lengths, ) return logical_request_ids, audio, audio_lengths def remove(self, request_id: str | int) -> None: self._ensure_open() slot_index = self.request_id_to_slot_index.get(request_id) if slot_index is None or request_id not in self.active_request_ids: raise RuntimeError(_UNKNOWN_DECODE_REQUEST_ERROR_TEMPLATE.format(request_id=request_id)) if self.slot_is_free[slot_index] or self.slot_index_to_request_id[slot_index] != request_id: raise RuntimeError(_UNKNOWN_DECODE_REQUEST_ERROR_TEMPLATE.format(request_id=request_id)) self.active_request_ids.remove(request_id) self._reset_slot(slot_index) _ = self.request_id_to_slot_index.pop(request_id) self.slot_index_to_request_id[slot_index] = None self.slot_is_free[slot_index] = True _ = self.request_id_to_code_offset.pop(request_id, None) _ = self.request_id_to_audio_offset.pop(request_id, None) def close(self) -> None: if self._closed: return self._closed = True decode_streaming_exit_stack = self._decode_streaming_exit_stack self._decode_streaming_exit_stack = None try: if decode_streaming_exit_stack is not None: decode_streaming_exit_stack.close() finally: for module in self._flash_kvcache_attention_modules: module._use_flash_kvcache = False self._flash_kvcache_attention_modules = [] self._cuda_graph = None self._cuda_graph_key = None self._graph_input_codes = None self._graph_input_code_lengths = None self._graph_output_audio = None self._graph_output_audio_lengths = None if self.model._active_decode_session is self: self.model._active_decode_session = None # ============================================================================= # Normalization Layers # ============================================================================= class MossAudioTokenizerRMSNorm(nn.Module): """Root Mean Square Layer Normalization.""" def __init__( self, dim: int, eps: float = 1e-5, dtype: torch.dtype | None = None, device=None, ): super().__init__() self.eps = eps self.dtype = dtype self.alpha = nn.Parameter(torch.full((1, 1, dim), 1.0, requires_grad=True, device=device, dtype=dtype)) def forward(self, x: torch.Tensor): x_dtype = x.dtype if self.dtype is not None: x = x.to(self.dtype) var = self.eps + torch.mean(x**2, dim=-1, keepdim=True) alpha = self.alpha.to(var) if x.dim() == 2: alpha = alpha.view(1, -1) y = (x * (alpha * torch.rsqrt(var))).to(x_dtype) return y class MossAudioTokenizerLayerScale(nn.Module): """Layer scale from Touvron et al. 2021.""" def __init__( self, channels: int, init: float = 1e-4, channel_last: bool = True, device=None, dtype=None, ): super().__init__() self.channel_last = channel_last self.scale = nn.Parameter(torch.full((channels,), init, requires_grad=True, device=device, dtype=dtype)) def forward(self, x: torch.Tensor): if self.channel_last: return self.scale * x else: return self.scale[:, None] * x def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module: """Create normalization module.""" if norm_type == "layer_norm": return nn.LayerNorm(dim, eps=1e-5, **kwargs) elif norm_type in {"rms_norm"}: return MossAudioTokenizerRMSNorm(dim, eps=1e-5, **kwargs) elif norm_type in {"rms_norm_f32"}: kwargs.pop("dtype", None) return MossAudioTokenizerRMSNorm(dim, eps=1e-8, dtype=torch.float, **kwargs) else: raise ValueError(f"Unknown norm type: {norm_type}") # ============================================================================= # Rotary Position Embedding # ============================================================================= def apply_rope( q: torch.Tensor, k: torch.Tensor, offset: torch.Tensor, max_period: float = 10_000, time_before_heads: bool = False, ): """Apply rotary position embedding.""" if time_before_heads: B, T, H, D = q.shape else: B, H, T, D = q.shape if k.shape != q.shape: raise ValueError(f"Expected k.shape == q.shape, got k={tuple(k.shape)} q={tuple(q.shape)}") if D <= 0 or (D % 2) != 0: raise ValueError(f"RoPE requires an even last dimension, got D={D}") ds = torch.arange(D // 2, device=q.device, dtype=torch.float32) freqs = torch.exp(ds * (-math.log(max_period) * 2 / D)) ts = offset.float().view(-1, 1) + torch.arange(T, device=q.device, dtype=torch.float32) if time_before_heads: ts = ts.view(B, -1, 1, 1) else: ts = ts.view(B, 1, -1, 1) dims = q.shape[:-1] q = q.view(*dims, D // 2, 2) k = k.view(*dims, D // 2, 2) qr, qi = q[..., 0].float(), q[..., 1].float() kr, ki = k[..., 0].float(), k[..., 1].float() rotr = torch.cos(freqs * ts) roti = torch.sin(freqs * ts) qor = qr * rotr - qi * roti qoi = qr * roti + qi * rotr kor = kr * rotr - ki * roti koi = kr * roti + ki * rotr dtype = q.dtype qo = torch.stack([qor.to(dtype), qoi.to(dtype)], dim=-1) ko = torch.stack([kor.to(dtype), koi.to(dtype)], dim=-1) return qo.view(*dims, D), ko.view(*dims, D) def apply_rope_with_positions( q: torch.Tensor, k: torch.Tensor, positions: torch.Tensor, max_period: float = 10_000, ): """Apply rotary position embedding to packed `[N, H, D]` tensors.""" N, H, D = q.shape if k.shape != q.shape: raise ValueError(f"Expected k.shape == q.shape, got k={tuple(k.shape)} q={tuple(q.shape)}") if D <= 0 or (D % 2) != 0: raise ValueError(f"RoPE requires an even last dimension, got D={D}") ds = torch.arange(D // 2, device=q.device, dtype=torch.float32) freqs = torch.exp(ds * (-math.log(max_period) * 2 / D)) ts = positions.to(torch.float32).view(N, 1, 1) qr = q.float().view(N, H, D // 2, 2)[..., 0] qi = q.float().view(N, H, D // 2, 2)[..., 1] kr = k.float().view(N, H, D // 2, 2)[..., 0] ki = k.float().view(N, H, D // 2, 2)[..., 1] rotr = torch.cos(ts * freqs.view(1, 1, -1)) roti = torch.sin(ts * freqs.view(1, 1, -1)) qor = qr * rotr - qi * roti qoi = qr * roti + qi * rotr kor = kr * rotr - ki * roti koi = kr * roti + ki * rotr dtype = q.dtype qo = torch.stack([qor.to(dtype), qoi.to(dtype)], dim=-1) ko = torch.stack([kor.to(dtype), koi.to(dtype)], dim=-1) return qo.view(N, H, D), ko.view(N, H, D) class MossAudioTokenizerRotaryEmbedding(nn.Module): """Rotary positional embedding (RoPE).""" def __init__(self, max_period: float = 10000.0): super().__init__() self.max_period = max_period def forward( self, q: torch.Tensor, k: torch.Tensor, offset: torch.Tensor, time_before_heads: bool = False, ): return apply_rope(q, k, offset, self.max_period, time_before_heads) # ============================================================================= # Gating Modules # ============================================================================= class MossAudioTokenizerActivationGating(nn.Module): """Gating FFN layer with activation.""" def __init__(self, dim: int, dim_feedforward: int, activation, **factory_kwargs): super().__init__() if dim_feedforward == 4 * dim: hidden = (21 * dim) // 8 else: hidden = (2 * dim_feedforward) // 3 self.linear_in = nn.Linear(dim, 2 * hidden, bias=False, **factory_kwargs) self.linear_out = nn.Linear(hidden, dim, bias=False, **factory_kwargs) self.activation = activation def forward(self, x: torch.Tensor): x = self.linear_in(x) B, T, _ = x.shape x = x.view(B, T, 2, -1) x = self.activation(x[..., 0, :]) * x[..., 1, :] x = self.linear_out(x) return x def _get_activation(name: str): if name in ["sigmoid", "tanh", "relu"]: return getattr(torch, name) elif name in ["leaky_relu", "elu", "gelu", "silu", "mish", "softsign"]: return getattr(F, name) elif name == "identity": return nn.Identity() else: raise ValueError(f"Unknown activation {name}") def make_gating(name: str, dim: int, dim_feedforward: int, **factory_kwargs) -> nn.Module: return MossAudioTokenizerActivationGating(dim, dim_feedforward, _get_activation(name), **factory_kwargs) # ============================================================================= # Positional Embeddings # ============================================================================= def create_sin_embedding( positions: torch.Tensor, dim: int, max_period: float = 10000, dtype: torch.dtype = torch.float32, ) -> torch.Tensor: """Create sinusoidal positional embedding with shape [..., C].""" if dim % 2 != 0: raise ValueError(f"Sinusoidal embedding requires even dim, got dim={dim}") half_dim = dim // 2 if half_dim <= 1: raise ValueError(f"Sinusoidal embedding requires dim >= 4, got dim={dim}") if positions.dim() == 0: positions = positions.view(1) positions = positions.to(dtype).unsqueeze(-1) adim = torch.arange(half_dim, device=positions.device, dtype=dtype) max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype) phase = positions / (max_period_tensor ** (adim / (half_dim - 1))) return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1) def pack_padded_sequence( x: torch.Tensor, input_lengths: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Pack a padded `[B, T, D]` tensor into `[N, D]` plus metadata.""" batch_size, max_seqlen, _ = x.shape positions = torch.arange(max_seqlen, device=x.device, dtype=torch.long) valid_mask = positions.view(1, max_seqlen) < input_lengths.view(batch_size, 1) packed_x = x[valid_mask] cu_seqlens = torch.zeros(batch_size + 1, device=x.device, dtype=torch.int32) cu_seqlens[1:] = torch.cumsum(input_lengths.to(torch.int32), dim=0) position_ids = positions.view(1, max_seqlen).expand(batch_size, -1)[valid_mask] return packed_x, valid_mask, cu_seqlens, position_ids def unpack_packed_sequence( packed_x: torch.Tensor, valid_mask: torch.Tensor, batch_size: int, max_seqlen: int, ) -> torch.Tensor: """Unpack a packed `[N, D]` tensor back into `[B, T, D]`.""" output = packed_x.new_zeros((batch_size, max_seqlen, packed_x.shape[-1])) output[valid_mask] = packed_x return output # ============================================================================= # KV Cache for Attention # ============================================================================= class KVCacheResult: """Container for KV cache results that supports tuple unpacking.""" __slots__ = ("keys", "values", "positions") def __init__(self, keys: torch.Tensor, values: torch.Tensor, positions: torch.Tensor): self.keys = keys self.values = values self.positions = positions def __iter__(self): """Allow unpacking as (keys, values, positions).""" return iter((self.keys, self.values, self.positions)) @staticmethod def from_kv(keys: torch.Tensor, values: torch.Tensor) -> KVCacheResult: B, H, T, D = keys.shape positions = torch.arange(T, device=keys.device, dtype=torch.long) return KVCacheResult(keys, values, positions.expand(B, -1)) class RingKVCache: """Efficient streaming KVCache compatible with CUDA Graph.""" def __init__( self, batch_size: int, num_heads: int, dim_per_head: int, capacity: int, respect_exec_mask: bool = True, device: torch.device = torch.device("cuda"), dtype: torch.dtype = torch.bfloat16, ): self.capacity = capacity self.cache = torch.zeros( (2, batch_size, num_heads, capacity, dim_per_head), device=device, dtype=dtype, ) self.respect_exec_mask = respect_exec_mask if self.respect_exec_mask: self.end_offset = torch.zeros(batch_size, device=device, dtype=torch.long) else: self.end_offset = torch.zeros(1, device=device, dtype=torch.long) def reset(self, reset_mask: torch.Tensor) -> None: self.end_offset[:] = torch.where(reset_mask, torch.zeros_like(self.end_offset), self.end_offset) def complete(self, k: torch.Tensor, v: torch.Tensor, exec_mask: torch.Tensor) -> KVCacheResult: B, H, T, D = k.shape if T <= 0: raise ValueError(f"Expected T > 0, got T={T}") indexes = torch.arange(T, device=self.end_offset.device, dtype=self.end_offset.dtype) indexes = indexes + self.end_offset.view(-1, 1) indexes = indexes % self.capacity if self.respect_exec_mask: this_indexes = indexes.view(B, 1, T, 1).expand(-1, H, T, D) self.cache[0].scatter_(2, this_indexes, k) self.cache[1].scatter_(2, this_indexes, v) else: self.cache[0].index_copy_(2, indexes[0], k) self.cache[1].index_copy_(2, indexes[0], v) keys = self.cache[0] values = self.cache[1] indexes = torch.arange(self.capacity, device=self.end_offset.device, dtype=torch.long) last_offset = self.end_offset.view(-1, 1) + T - 1 end_index = last_offset % self.capacity delta = indexes - end_index positions = torch.where( delta <= 0, last_offset + delta, last_offset + delta - self.capacity, ) if self.respect_exec_mask: self.end_offset[:] = torch.where(exec_mask, self.end_offset + T, self.end_offset) else: self.end_offset.add_(T) invalid = indexes >= self.end_offset.view(-1, 1) positions = torch.where(invalid, torch.full_like(positions, -1), positions) return KVCacheResult(keys, values, positions) # ============================================================================= # Multi-Head Attention # ============================================================================= _sync_module_proxy() @dataclass class MHAState(StreamingState): cached_keys: torch.Tensor | None cached_values: torch.Tensor | None cached_positions: torch.Tensor | None offset: torch.Tensor def reset(self, reset_mask: torch.Tensor): super().reset(reset_mask) self.offset[:] = torch.where(reset_mask, torch.zeros_like(self.offset), self.offset) if self.cached_positions is not None: self.cached_positions[reset_mask] = -1 if self.cached_keys is not None: self.cached_keys[reset_mask] = 0 if self.cached_values is not None: self.cached_values[reset_mask] = 0 def apply_weights_per_step( modules: nn.ModuleList, schedule: list[int] | None, x: torch.Tensor, offset: int | None, ) -> torch.Tensor: """Apply different weights for each time step.""" if len(modules) == 1: return modules[0](x) if offset is None: raise ValueError("offset must be provided when using per-step weights (len(modules) > 1).") if x.dim() != 3: raise ValueError( f"Per-step weights require a dense `[B, T, C]` tensor when len(modules) > 1, got shape {tuple(x.shape)}." ) ys = [] B, T, C = x.shape for t in range(T): module_index = t + offset if schedule is not None: if module_index >= len(schedule) or module_index < 0: raise ValueError( f"weights_per_step_schedule is too short for module_index={module_index} (len={len(schedule)})." ) module_index = schedule[module_index] if module_index >= len(modules) or module_index < 0: raise ValueError(f"module_index={module_index} out of range for len(modules)={len(modules)}.") y = modules[module_index](x[:, t : t + 1]) ys.append(y) return torch.cat(ys, 1) class MossAudioTokenizerMultiheadAttention(StreamingModule): """Multi-head attention with streaming support.""" def __init__( self, embed_dim: int, num_heads: int, causal: bool = False, context: int | None = None, rope: MossAudioTokenizerRotaryEmbedding | None = None, attention_implementation: str = "sdpa", device=None, dtype=None, ): super().__init__() factory_kwargs = {"device": device, "dtype": dtype} self.embed_dim = embed_dim self.causal = causal self.context = context self.rope = rope self.num_heads = num_heads if attention_implementation not in SUPPORTED_ATTENTION_IMPLEMENTATIONS: raise ValueError( f"Unsupported attention_implementation={attention_implementation!r}. " f"Expected one of {sorted(SUPPORTED_ATTENTION_IMPLEMENTATIONS)}." ) self.attention_implementation = attention_implementation self._use_flash_kvcache = False self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=False, **factory_kwargs) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False, **factory_kwargs) self._register_load_state_dict_pre_hook(self._load_hook, with_module=True) def set_attention_implementation(self, attention_implementation: str) -> None: if attention_implementation not in SUPPORTED_ATTENTION_IMPLEMENTATIONS: raise ValueError( f"Unsupported attention_implementation={attention_implementation!r}. " f"Expected one of {sorted(SUPPORTED_ATTENTION_IMPLEMENTATIONS)}." ) self.attention_implementation = attention_implementation @staticmethod def _load_hook(module, state_dict, prefix, *_): mappings = { "in_proj_weight": "in_proj.weight", "in_projs.0.weight": "in_proj.weight", "out_projs.0.weight": "out_proj.weight", } for suffix in ["", "_scb"]: for source, target in mappings.items(): this_source = prefix + source + suffix if this_source in state_dict: state_dict[prefix + target + suffix] = state_dict.pop(this_source) def _init_streaming_state(self, batch_size: int) -> MHAState: device = cast(torch.device, self.in_proj.weight.device) return MHAState( batch_size, device, cached_keys=None, cached_values=None, cached_positions=None, offset=torch.zeros(batch_size, device=cast(torch.device, device), dtype=torch.long), ) def _supports_flash_attention(self, device: torch.device, dtype: torch.dtype) -> bool: return _has_flash_attn() and device.type == "cuda" and dtype in {torch.float16, torch.bfloat16} def _get_backend_check_dtype(self, x: torch.Tensor) -> torch.dtype: if x.device.type != "cuda": return x.dtype try: autocast_enabled = torch.is_autocast_enabled("cuda") except TypeError: autocast_enabled = torch.is_autocast_enabled() if not autocast_enabled: return x.dtype try: return torch.get_autocast_dtype("cuda") except TypeError: return torch.get_autocast_gpu_dtype() def resolve_attention_implementation(self, x: torch.Tensor, is_streaming: bool) -> str: if self.attention_implementation == "sdpa": return "sdpa" backend_dtype = self._get_backend_check_dtype(x) if self._supports_flash_attention(x.device, backend_dtype): return "flash_attention_2" if self.attention_implementation == "flash_attention_2": logger.warning_once( "Falling back to SDPA because flash_attention_2 is unavailable for device=%s dtype=%s " "(HAS_FLASH_ATTN=%s).", x.device, backend_dtype, _has_flash_attn(), ) return "sdpa" def _project_qkv( self, x: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: dim_per_head = self.embed_dim // self.num_heads if x.dim() == 3: projected = self.in_proj(x) projected = projected.reshape(x.shape[0], x.shape[1], 3, self.num_heads, dim_per_head).permute( 2, 0, 3, 1, 4 ) return projected[0], projected[1], projected[2] if x.dim() == 2: projected = self.in_proj(x) projected = projected.view(x.shape[0], 3, self.num_heads, dim_per_head) return projected[:, 0], projected[:, 1], projected[:, 2] raise ValueError(f"Expected a 2D or 3D tensor, got shape {tuple(x.shape)}") def _apply_dense_rope(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: if self.rope is None: return q, k offset = torch.zeros(q.shape[0], device=q.device, dtype=torch.long) return self.rope(q, k, offset, time_before_heads=False) def _apply_packed_rope( self, q: torch.Tensor, k: torch.Tensor, position_ids: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: if self.rope is None: return q, k return apply_rope_with_positions(q, k, position_ids, max_period=self.rope.max_period) def _ensure_streaming_cache( self, state: MHAState, batch_size: int, device: torch.device, dtype: torch.dtype, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: head_dim = self.embed_dim // self.num_heads cache_length = 0 if self.context is None else self.context if state.cached_keys is None or state.cached_values is None or state.cached_positions is None: state.cached_keys = torch.zeros( (batch_size, self.num_heads, cache_length, head_dim), device=device, dtype=dtype, ) state.cached_values = torch.zeros_like(state.cached_keys) state.cached_positions = torch.full( (batch_size, cache_length), -1, device=device, dtype=torch.long, ) else: if state.cached_keys.device != device or state.cached_keys.dtype != dtype: state.cached_keys = state.cached_keys.to(device=device, dtype=dtype) if state.cached_values.device != device or state.cached_values.dtype != dtype: state.cached_values = state.cached_values.to(device=device, dtype=dtype) if state.cached_positions.device != device: state.cached_positions = state.cached_positions.to(device=device) return state.cached_keys, state.cached_values, state.cached_positions def _ensure_flash_kvcache( self, state: MHAState, batch_size: int, device: torch.device, dtype: torch.dtype, ) -> tuple[torch.Tensor, torch.Tensor]: if self.context is None: raise RuntimeError("flash_attn_with_kvcache requires a finite streaming context.") head_dim = self.embed_dim // self.num_heads flash_cached_keys = cast(torch.Tensor | None, getattr(state, "_flash_cached_keys", None)) flash_cached_values = cast(torch.Tensor | None, getattr(state, "_flash_cached_values", None)) if flash_cached_keys is None or flash_cached_values is None: flash_cached_keys = torch.zeros( (batch_size, self.context, self.num_heads, head_dim), device=device, dtype=dtype, ) flash_cached_values = torch.zeros_like(flash_cached_keys) else: if flash_cached_keys.device != device or flash_cached_keys.dtype != dtype: flash_cached_keys = flash_cached_keys.to(device=device, dtype=dtype) if flash_cached_values.device != device or flash_cached_values.dtype != dtype: flash_cached_values = flash_cached_values.to(device=device, dtype=dtype) setattr(state, "_flash_cached_keys", flash_cached_keys) setattr(state, "_flash_cached_values", flash_cached_values) return flash_cached_keys, flash_cached_values def _build_streaming_kv( self, cached_k: torch.Tensor, cached_v: torch.Tensor, cached_pos: torch.Tensor, k_cur: torch.Tensor, v_cur: torch.Tensor, pos_q: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: k_all = torch.cat([cached_k, k_cur], dim=2) v_all = torch.cat([cached_v, v_cur], dim=2) pos_k = torch.cat([cached_pos, pos_q], dim=1) return k_all, v_all, pos_k def _update_streaming_cache( self, state: MHAState, cached_k: torch.Tensor, cached_v: torch.Tensor, cached_pos: torch.Tensor, k_all: torch.Tensor, v_all: torch.Tensor, pos_k: torch.Tensor, ) -> None: exec_mask = state.exec_mask.view(-1, 1, 1, 1) exec_mask_pos = state.exec_mask.view(-1, 1) if self.context is None: if not bool(state.exec_mask.all().item()): raise RuntimeError("Streaming exec_mask with context=None is not supported.") state.cached_keys = k_all.contiguous() state.cached_values = v_all.contiguous() state.cached_positions = pos_k.contiguous() return assert state.cached_keys is not None assert state.cached_values is not None assert state.cached_positions is not None new_cached_k = k_all[:, :, -self.context :, :].contiguous() new_cached_v = v_all[:, :, -self.context :, :].contiguous() new_cached_pos = pos_k[:, -self.context :].contiguous() state.cached_keys.copy_(torch.where(exec_mask, new_cached_k, cached_k)) state.cached_values.copy_(torch.where(exec_mask, new_cached_v, cached_v)) state.cached_positions.copy_(torch.where(exec_mask_pos, new_cached_pos, cached_pos)) def _build_streaming_sdpa_bias(self, pos_q: torch.Tensor, pos_k: torch.Tensor) -> torch.Tensor: delta = pos_q[:, :, None] - pos_k[:, None, :] attn_bias = (pos_k[:, None, :] >= 0) & (delta >= 0) if self.context is not None: attn_bias = attn_bias & (delta < self.context) return attn_bias[:, None, :, :] def _build_non_streaming_sdpa_bias( self, input_lengths: torch.Tensor, max_seqlen: int, device: torch.device, ) -> torch.Tensor: positions = torch.arange(max_seqlen, device=device, dtype=torch.long) valid_k = positions.view(1, 1, max_seqlen) < input_lengths.view(-1, 1, 1) if not self.causal and self.context is None: return valid_k[:, None, :, :].expand(-1, 1, max_seqlen, -1) delta = positions.view(1, max_seqlen, 1) - positions.view(1, 1, max_seqlen) attn_bias = torch.ones((1, max_seqlen, max_seqlen), device=device, dtype=torch.bool) if self.causal: attn_bias = attn_bias & (delta >= 0) if self.context is not None: attn_bias = attn_bias & (delta < self.context) return (attn_bias & valid_k)[:, None, :, :] def _run_flash_attention( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int, max_seqlen_k: int, ) -> torch.Tensor: flash_attn_varlen_func = _get_flash_attn_varlen_func() if flash_attn_varlen_func is None: raise RuntimeError("flash-attn is not installed.") window_size = (self.context, 0) if (self.context is not None and self.causal) else (-1, -1) return cast( torch.Tensor, flash_attn_varlen_func( q.contiguous(), k.contiguous(), v.contiguous(), cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal=self.causal, window_size=window_size, ), ) def _forward_streaming_sdpa(self, x: torch.Tensor, state: MHAState) -> torch.Tensor: batch_size, chunk_length, _ = x.shape q, k_cur, v_cur = self._project_qkv(x) if self.rope is not None: q, k_cur = self.rope(q, k_cur, state.offset, time_before_heads=False) pos_q = state.offset.view(-1, 1) + torch.arange(chunk_length, device=x.device, dtype=torch.long).view(1, -1) cached_k, cached_v, cached_pos = self._ensure_streaming_cache(state, batch_size, k_cur.device, k_cur.dtype) k_all, v_all, pos_k = self._build_streaming_kv(cached_k, cached_v, cached_pos, k_cur, v_cur, pos_q) attn_bias = self._build_streaming_sdpa_bias(pos_q, pos_k) out = F.scaled_dot_product_attention(q, k_all, v_all, attn_bias, dropout_p=0.0) out = out.transpose(1, 2).reshape(batch_size, chunk_length, self.embed_dim) self._update_streaming_cache(state, cached_k, cached_v, cached_pos, k_all, v_all, pos_k) state.offset[:] = torch.where(state.exec_mask, state.offset + chunk_length, state.offset) return out def _forward_streaming_flash(self, x: torch.Tensor, state: MHAState) -> torch.Tensor: batch_size, chunk_length, _ = x.shape q, k_cur, v_cur = self._project_qkv(x) if self.rope is not None: q, k_cur = self.rope(q, k_cur, state.offset, time_before_heads=False) pos_q = state.offset.view(-1, 1) + torch.arange(chunk_length, device=x.device, dtype=torch.long).view(1, -1) cached_k, cached_v, cached_pos = self._ensure_streaming_cache(state, batch_size, k_cur.device, k_cur.dtype) k_all, v_all, pos_k = self._build_streaming_kv(cached_k, cached_v, cached_pos, k_cur, v_cur, pos_q) q_chunks = [] k_chunks = [] v_chunks = [] cu_q = [0] cu_k = [0] max_kv_len = 0 for batch_idx in range(batch_size): valid_k = pos_k[batch_idx] >= 0 q_i = q[batch_idx].transpose(0, 1).contiguous() k_i = k_all[batch_idx, :, valid_k, :].transpose(0, 1).contiguous() v_i = v_all[batch_idx, :, valid_k, :].transpose(0, 1).contiguous() q_chunks.append(q_i) k_chunks.append(k_i) v_chunks.append(v_i) cu_q.append(cu_q[-1] + q_i.shape[0]) cu_k.append(cu_k[-1] + k_i.shape[0]) max_kv_len = max(max_kv_len, int(k_i.shape[0])) out_flat = self._run_flash_attention( torch.cat(q_chunks, dim=0), torch.cat(k_chunks, dim=0), torch.cat(v_chunks, dim=0), torch.tensor(cu_q, device=x.device, dtype=torch.int32), torch.tensor(cu_k, device=x.device, dtype=torch.int32), max_seqlen_q=chunk_length, max_seqlen_k=max_kv_len, ) outputs = [] start = 0 for _ in range(batch_size): outputs.append(out_flat[start : start + chunk_length].transpose(0, 1).contiguous()) start += chunk_length out = torch.stack(outputs, dim=0) out = out.transpose(1, 2).reshape(batch_size, chunk_length, self.embed_dim) self._update_streaming_cache(state, cached_k, cached_v, cached_pos, k_all, v_all, pos_k) state.offset[:] = torch.where(state.exec_mask, state.offset + chunk_length, state.offset) return out def _forward_streaming_flash_kvcache(self, x: torch.Tensor, state: MHAState) -> torch.Tensor: flash_attn_with_kvcache = _get_flash_attn_with_kvcache() if self.context is None: raise RuntimeError("flash_attn_with_kvcache requires a finite streaming context.") if flash_attn_with_kvcache is None: raise RuntimeError("flash-attn is not installed.") batch_size, chunk_length, _ = x.shape q, k_cur, v_cur = self._project_qkv(x) if self.rope is not None: q, k_cur = self.rope(q, k_cur, state.offset, time_before_heads=False) q = q.transpose(1, 2).contiguous() k_cur = k_cur.transpose(1, 2).contiguous() v_cur = v_cur.transpose(1, 2).contiguous() exec_mask = state.exec_mask.view(batch_size, 1, 1, 1).to(dtype=k_cur.dtype) k_cur = k_cur * exec_mask v_cur = v_cur * exec_mask k_cache, v_cache = self._ensure_flash_kvcache(state, batch_size, k_cur.device, k_cur.dtype) cache_seqlens = state.offset.clamp(max=self.context).to(torch.int32) window_size = (self.context - 1, 0) out = cast( torch.Tensor, flash_attn_with_kvcache( q, k_cache, v_cache, k=k_cur, v=v_cur, cache_seqlens=cache_seqlens, causal=True, window_size=window_size, ), ) out = out.reshape(batch_size, chunk_length, self.embed_dim) state.offset[:] = torch.where(state.exec_mask, state.offset + chunk_length, state.offset) return out def _forward_non_streaming_sdpa(self, x: torch.Tensor, input_lengths: torch.Tensor) -> torch.Tensor: batch_size, max_seqlen, _ = x.shape q, k, v = self._project_qkv(x) q, k = self._apply_dense_rope(q, k) attn_bias = self._build_non_streaming_sdpa_bias(input_lengths, max_seqlen, x.device) out = F.scaled_dot_product_attention(q, k, v, attn_bias, dropout_p=0.0) valid_q = (torch.arange(max_seqlen, device=x.device).view(1, max_seqlen) < input_lengths.view(-1, 1)).view( batch_size, 1, max_seqlen, 1 ) # Some SDPA backends return NaNs for fully-masked padded query rows in local-causal attention. # Multiplying by zero is not sufficient because NaN * 0 is still NaN; use torch.where so padded # rows are materialized as exact zeros before they can leak into later layers as masked K/V values. out = torch.where(valid_q, out, torch.zeros((), device=out.device, dtype=out.dtype)) return out.transpose(1, 2).reshape(batch_size, max_seqlen, self.embed_dim) def _forward_non_streaming_flash( self, x: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: int, position_ids: torch.Tensor, ) -> torch.Tensor: q, k, v = self._project_qkv(x) q, k = self._apply_packed_rope(q, k, position_ids) out = self._run_flash_attention(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen) return out.reshape(x.shape[0], self.embed_dim) def forward( self, query: torch.Tensor, cu_seqlens: torch.Tensor | None = None, max_seqlen: int | None = None, position_ids: torch.Tensor | None = None, input_lengths: torch.Tensor | None = None, ): state = cast(MHAState | None, self._streaming_state) backend = self.resolve_attention_implementation(query, is_streaming=state is not None) if state is not None: if query.dim() != 3: raise ValueError(f"Streaming attention expects a 3D tensor, got shape {tuple(query.shape)}") if backend == "flash_attention_2" and self._use_flash_kvcache: out = self._forward_streaming_flash_kvcache(query, state) elif backend == "flash_attention_2": out = self._forward_streaming_flash(query, state) else: out = self._forward_streaming_sdpa(query, state) return self.out_proj(out) if backend == "flash_attention_2": if query.dim() != 2: raise ValueError(f"Packed flash attention expects a 2D tensor, got shape {tuple(query.shape)}") if cu_seqlens is None or max_seqlen is None or position_ids is None: raise ValueError("Packed flash attention requires cu_seqlens, max_seqlen, and position_ids.") out = self._forward_non_streaming_flash(query, cu_seqlens, max_seqlen, position_ids) return self.out_proj(out) if query.dim() != 3: raise ValueError(f"Non-streaming SDPA expects a 3D tensor, got shape {tuple(query.shape)}") if input_lengths is None: raise ValueError("Non-streaming SDPA requires input_lengths.") out = self._forward_non_streaming_sdpa(query, input_lengths) return self.out_proj(out) # ============================================================================= # Transformer Layer # ============================================================================= _sync_module_proxy() @dataclass class LayerState(StreamingState): pass class MossAudioTokenizerTransformerLayer(StreamingModule): """Transformer layer with streaming support.""" def __init__( self, d_model: int, num_heads: int, dim_feedforward: int = 2048, causal: bool = False, context: int | None = None, rope: MossAudioTokenizerRotaryEmbedding | None = None, attention_implementation: str = "sdpa", norm: str = "layer_norm", layer_scale: float | None = None, gating: str = "none", device=None, dtype=None, ): super().__init__() factory_kwargs = {"device": device, "dtype": dtype} self.self_attn = MossAudioTokenizerMultiheadAttention( embed_dim=d_model, num_heads=num_heads, causal=causal, context=context, rope=rope, attention_implementation=attention_implementation, **factory_kwargs, ) self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs) self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs) if gating == "none": self.ffn = nn.Sequential( nn.Linear(d_model, dim_feedforward, bias=False, **factory_kwargs), nn.GELU(), nn.Linear(dim_feedforward, d_model, bias=False, **factory_kwargs), ) else: self.ffn = make_gating(gating, d_model, dim_feedforward, **factory_kwargs) if layer_scale is None: self.layer_scale_1 = nn.Identity() self.layer_scale_2 = nn.Identity() else: self.layer_scale_1 = MossAudioTokenizerLayerScale( channels=d_model, init=layer_scale, channel_last=True, **cast(dict[str, object], factory_kwargs) ) self.layer_scale_2 = MossAudioTokenizerLayerScale( channels=d_model, init=layer_scale, channel_last=True, **cast(dict[str, object], factory_kwargs) ) self._register_load_state_dict_pre_hook(self._load_hook, with_module=True) @staticmethod def _load_hook(module, state_dict, prefix, *_): mappings = { "linear1.weight": "ffn.0.weight", "linear2.weight": "ffn.2.weight", "linear1.bias": "ffn.0.bias", "linear2.bias": "ffn.2.bias", } for source, target in mappings.items(): this_source = prefix + source if this_source in state_dict: state_dict[prefix + target] = state_dict.pop(this_source) def _init_streaming_state(self, batch_size: int) -> LayerState: device = next(iter(self.parameters())).device return LayerState(batch_size, device) def forward(self, x: torch.Tensor, **kwargs): residual = x x = self.norm1(x) x = residual.to(x) + self.layer_scale_1(self.self_attn(x, **kwargs)) residual = x x = self.norm2(x) x = residual.to(x) + self.layer_scale_2(self.ffn(x)) return x # ============================================================================= # Streaming Transformer # ============================================================================= _sync_module_proxy() @dataclass class TransformerState(StreamingState): offsets: torch.Tensor def reset(self, reset_mask: torch.Tensor): super().reset(reset_mask) self.offsets[:] = torch.where(reset_mask, torch.zeros_like(self.offsets), self.offsets) class MossAudioTokenizerTransformer(StreamingModule): """Transformer with streaming/causal support.""" def __init__( self, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048, causal: bool = False, context: int | None = None, positional_embedding: str = "sin", max_period: float = 10_000, positional_scale: float = 1.0, attention_implementation: str = "sdpa", device=None, dtype=None, **kwargs, ): super().__init__() if d_model % num_heads != 0: raise ValueError(f"d_model must be divisible by num_heads, got d_model={d_model}, num_heads={num_heads}") self.positional_embedding = positional_embedding self.max_period = max_period self.positional_scale = positional_scale self.rope: MossAudioTokenizerRotaryEmbedding | None = None if positional_embedding in {"rope", "sin_rope"}: self.rope = MossAudioTokenizerRotaryEmbedding(max_period=max_period) self.layers = nn.ModuleList() for _ in range(num_layers): self.layers.append( MossAudioTokenizerTransformerLayer( d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward, causal=causal, context=context, rope=self.rope, attention_implementation=attention_implementation, device=device, dtype=dtype, **kwargs, ) ) def _init_streaming_state(self, batch_size: int) -> TransformerState: device = next(self.parameters()).device return TransformerState( batch_size, device, offsets=torch.zeros(batch_size, device=device, dtype=torch.long), ) def resolve_attention_implementation(self, x: torch.Tensor) -> str: if len(self.layers) == 0: return "sdpa" first_layer = cast(MossAudioTokenizerTransformerLayer, self.layers[0]) return first_layer.self_attn.resolve_attention_implementation(x, is_streaming=self._streaming_state is not None) def set_attention_implementation(self, attention_implementation: str) -> None: for layer in self.layers: cast(MossAudioTokenizerTransformerLayer, layer).self_attn.set_attention_implementation(attention_implementation) def forward(self, x: torch.Tensor, **kwargs): C = x.shape[-1] state = self._streaming_state if x.dim() == 3: B, T, _ = x.shape offsets = ( torch.zeros(1, dtype=torch.long, device=x.device) if state is None else ( state.offsets if isinstance(state, TransformerState) else torch.zeros(1, dtype=torch.long, device=x.device) ) ) else: B = 0 T = 0 offsets = None if self.positional_embedding in {"sin", "sin_rope"}: if x.dim() == 3: positions = torch.arange(T, device=x.device).view(1, -1) + cast(torch.Tensor, offsets).view(-1, 1) else: position_ids = kwargs.get("position_ids") if position_ids is None: raise ValueError("Packed transformer inputs require position_ids when using sinusoidal embeddings.") positions = position_ids pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype) x = x + self.positional_scale * pos_emb for layer in self.layers: x = layer(x, **kwargs) if state is not None and x.dim() == 3: assert isinstance(state, TransformerState) state.offsets[:] = torch.where(state.exec_mask, state.offsets + T, state.offsets) return x class MossAudioTokenizerProjectedTransformer(StreamingContainer): """Transformer with input/output projections.""" def __init__( self, input_dimension: int, output_dimension: int, d_model: int, *, conv_layout: bool = False, module_type: str, **kwargs, ): super().__init__() self.module_type = module_type self.downsample_ratio: int = 1 self.input_dimension = input_dimension self.output_dimension = output_dimension self.input_proj = nn.Linear(input_dimension, d_model, bias=False) self.transformer = MossAudioTokenizerTransformer(d_model=d_model, **kwargs) self.conv_layout = conv_layout self.output_proj = nn.Linear(d_model, output_dimension, bias=False) def set_attention_implementation(self, attention_implementation: str) -> None: self.transformer.set_attention_implementation(attention_implementation) def forward(self, x, input_lengths, **kwargs): x = self.input_proj(x.transpose(1, 2)) # (B, D, T) -> (B, T, D) if not self.is_streaming and self.transformer.resolve_attention_implementation(x) == "flash_attention_2": batch_size, max_seqlen, _ = x.shape if max_seqlen > 0 and bool(input_lengths.any().item()): max_valid_seqlen = int(input_lengths.max().item()) packed_x, valid_mask, cu_seqlens, position_ids = pack_padded_sequence(x, input_lengths) packed_x = self.transformer( packed_x, cu_seqlens=cu_seqlens, max_seqlen=max_valid_seqlen, position_ids=position_ids, input_lengths=input_lengths, **kwargs, ) x = unpack_packed_sequence(packed_x, valid_mask, batch_size, max_seqlen) else: x = x.new_zeros(x.shape) else: x = self.transformer(x, input_lengths=input_lengths, **kwargs) x = self.output_proj(x).transpose(1, 2) # (B, T, D) -> (B, D, T) return x, input_lengths # ============================================================================= # Patched Pretransform Module # ============================================================================= class MossAudioTokenizerPatchedPretransform(nn.Module): """Patching module for downsampling/upsampling.""" def __init__(self, patch_size: int, is_downsample: bool, module_type: str, **kwargs): super().__init__() self.patch_size = patch_size self.downsample_ratio: int = patch_size self.is_downsample = is_downsample self.module_type = module_type def encode(self, x, input_lengths): b, d, _ = x.shape h = self.patch_size x = x.reshape(b, d, -1, h).permute(0, 1, 3, 2).reshape(b, d * h, -1) # We pad the input waveform to a multiple of `downsample_rate` before applying the encoder. # Use a ceil division to match that padding and avoid dropping the last (partially padded) frame. output_lengths = input_lengths // self.patch_size return x, output_lengths def decode(self, x, input_lengths): b, dh, l = x.shape h = self.patch_size d = dh // h x = x.reshape(b, d, h, l).permute(0, 1, 3, 2).reshape(b, d, l * h) output_lengths = input_lengths * self.patch_size return x, output_lengths def forward(self, x, input_lengths): if self.is_downsample: return self.encode(x, input_lengths) else: return self.decode(x, input_lengths) # ============================================================================= # Vector Quantization # ============================================================================= def WNConv1d(*args, **kwargs): """Weight-normalized Conv1d.""" return nn.utils.parametrizations.weight_norm(nn.Conv1d(*args, **kwargs)) def remap_weight_norm_state_dict_keys(state_dict: dict[str, torch.Tensor], prefix: str) -> None: replacements = ( (".weight_g", ".parametrizations.weight.original0"), (".weight_v", ".parametrizations.weight.original1"), ) for key in list(state_dict.keys()): if not key.startswith(prefix): continue new_key = key for source, target in replacements: new_key = new_key.replace(source, target) if new_key != key: state_dict[new_key] = state_dict.pop(key) class MossAudioTokenizerVectorQuantize(nn.Module): """Single codebook vector quantization (inference only).""" def __init__( self, input_dim: int, codebook_size: int, codebook_dim: int, **kwargs, ): super().__init__() self.input_dim = input_dim self.codebook_size = codebook_size self.codebook_dim = codebook_dim if input_dim != codebook_dim: self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) else: self.in_proj = nn.Identity() self.out_proj = nn.Identity() self.codebook = nn.Embedding(codebook_size, codebook_dim) @torch.no_grad() def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: z: Input tensor of shape (B, D, T) Returns: z_q: Quantized tensor of shape (B, D, T) indices: Code indices of shape (B, T) z_e: Encoded tensor before quantization """ z = z.float() z_e = self.in_proj(z).float() encodings = z_e.transpose(1, 2).reshape(-1, z_e.shape[1]) codebook_weight = self.codebook.weight dist = ( encodings.pow(2).sum(1, keepdim=True) - 2 * encodings @ codebook_weight.float().t() + codebook_weight.float().pow(2).sum(1, keepdim=True).t() ) indices = (-dist).max(1)[1] indices = indices.reshape(z.size(0), -1) z_q = self.decode_code(indices) z_q = self.out_proj(z_q).float() return z_q, indices, z_e def decode_code(self, embed_id: torch.Tensor) -> torch.Tensor: """Decode code indices to embeddings.""" return self.codebook(embed_id).transpose(1, 2).float() class MossAudioTokenizerLFQ(nn.Module): """LFQ (inference-only) used by ResidualLFQ.""" def __init__( self, input_dim: int, codebook_size: int, codebook_dim: int, **kwargs, ): super().__init__() self.input_dim = input_dim self.codebook_size = codebook_size self.codebook_dim = codebook_dim if self.input_dim != self.codebook_dim: self.in_proj = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1) self.out_proj = WNConv1d(self.codebook_dim, self.input_dim, kernel_size=1) else: self.in_proj = nn.Identity() self.out_proj = nn.Identity() self.codebook = nn.Embedding(codebook_size, codebook_dim) @torch.no_grad() def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Quantize z into codebook vectors.""" z = z.float() z_e = self.in_proj(z).float() z_q, indices = self.decode_latents(z_e) z_q = (z_e + (z_q - z_e).detach()).float() z_q = self.out_proj(z_q).float() return z_q, indices, z_e def embed_code(self, embed_id: torch.Tensor) -> torch.Tensor: return F.embedding(embed_id, self.codebook.weight) def decode_code_wo_out_proj(self, embed_id: torch.Tensor) -> torch.Tensor: return self.embed_code(embed_id).transpose(1, 2) def decode_code(self, embed_id: torch.Tensor) -> torch.Tensor: z_q = self.decode_code_wo_out_proj(embed_id).float() z_q = self.out_proj(z_q).float() return z_q def decode_latents(self, latents: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Match training LFQ: L2-normalize then argmin squared distance.""" encodings = latents.transpose(1, 2).reshape(-1, latents.shape[1]).float() codebook = self.codebook.weight.float() encodings = F.normalize(encodings) codebook = F.normalize(codebook) dist = ( encodings.pow(2).sum(1, keepdim=True) - 2 * encodings @ codebook.t() + codebook.pow(2).sum(1, keepdim=True).t() ) indices = (-dist).max(1)[1] indices = indices.reshape(latents.size(0), -1) z_q = self.decode_code_wo_out_proj(indices).float() return z_q, indices class MossAudioTokenizerResidualVQ(nn.Module): """Residual Vector Quantization (inference only).""" def __init__( self, input_dim: int = 1024, rvq_dim: int | None = None, output_dim: int | None = None, num_quantizers: int = 32, codebook_size: int = 1024, codebook_dim: int = 8, **kwargs, ): super().__init__() self.input_dim = input_dim self.rvq_dim = rvq_dim or input_dim self.output_dim = output_dim or input_dim self.num_quantizers = num_quantizers self.codebook_size = codebook_size self.codebook_dim = codebook_dim self.input_proj = ( WNConv1d(input_dim, self.rvq_dim, kernel_size=1) if input_dim != self.rvq_dim else nn.Identity() ) self.output_proj = ( WNConv1d(self.rvq_dim, self.output_dim, kernel_size=1) if self.rvq_dim != self.output_dim else nn.Identity() ) self.quantizers = nn.ModuleList( [ MossAudioTokenizerVectorQuantize( input_dim=self.rvq_dim, codebook_size=codebook_size, codebook_dim=codebook_dim, **kwargs, ) for _ in range(num_quantizers) ] ) self._register_load_state_dict_pre_hook(self._load_hook, with_module=True) @staticmethod def _load_hook(module, state_dict, prefix, *_): remap_weight_norm_state_dict_keys(state_dict, prefix) @torch.no_grad() def forward( self, z: torch.Tensor, input_length: torch.Tensor, n_quantizers: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: z: Input tensor of shape (B, D, T) input_length: Valid lengths for each sample (B,) n_quantizers: Number of quantizers to use Returns: quantized_out: Quantized output (B, D, T) all_indices: All code indices (N, B, T) output_length: Output lengths (B,) """ with disable_cuda_autocast(): z = self.input_proj(z).float() batch_size, _, max_time = z.shape mask = torch.arange(max_time, device=z.device).expand(batch_size, max_time) < input_length.unsqueeze(1) quantized_out = torch.zeros_like(z, dtype=torch.float32) residual = z.clone().float() all_indices = [] n_quantizers = n_quantizers or self.num_quantizers for i, quantizer in enumerate(self.quantizers): if i >= n_quantizers: break masked_residual = residual * mask.unsqueeze(1) z_q_i, indices_i, _ = quantizer(masked_residual.float()) update_mask = mask.unsqueeze(1) quantized_out = quantized_out + z_q_i * update_mask residual = residual - z_q_i * update_mask all_indices.append(indices_i) all_indices = torch.stack(all_indices) # (N, B, T) quantized_out = self.output_proj(quantized_out.float()).float() return quantized_out, all_indices, input_length def decode_codes(self, codes: torch.Tensor) -> torch.Tensor: """Decode codes from multiple quantizers to embeddings.""" with disable_cuda_autocast(): nq, B, T = codes.shape emb = torch.zeros(B, self.rvq_dim, T, device=codes.device, dtype=torch.float32) for i, quantizer in enumerate(self.quantizers[:nq]): quantizer = cast(MossAudioTokenizerVectorQuantize, quantizer) quantized_i = quantizer.decode_code(codes[i]).float() emb += quantized_i emb = self.output_proj(emb.float()).float() return emb class MossAudioTokenizerResidualLFQ(nn.Module): """Residual LFQ (inference only).""" def __init__( self, input_dim: int = 1024, rvq_dim: int | None = None, output_dim: int | None = None, num_quantizers: int = 32, codebook_size: int = 1024, codebook_dim: int = 8, **kwargs, ): super().__init__() self.input_dim = input_dim self.rvq_dim = rvq_dim or input_dim self.output_dim = output_dim or input_dim self.num_quantizers = num_quantizers self.codebook_size = codebook_size self.codebook_dim = codebook_dim self.input_proj = ( WNConv1d(input_dim, self.rvq_dim, kernel_size=1) if input_dim != self.rvq_dim else nn.Identity() ) self.output_proj = ( WNConv1d(self.rvq_dim, self.output_dim, kernel_size=1) if self.rvq_dim != self.output_dim else nn.Identity() ) self.quantizers = nn.ModuleList( [ MossAudioTokenizerLFQ( input_dim=self.rvq_dim, codebook_size=codebook_size, codebook_dim=codebook_dim, **kwargs, ) for _ in range(num_quantizers) ] ) self._register_load_state_dict_pre_hook(self._load_hook, with_module=True) @staticmethod def _load_hook(module, state_dict, prefix, *_): remap_weight_norm_state_dict_keys(state_dict, prefix) @torch.no_grad() def forward( self, z: torch.Tensor, input_length: torch.Tensor, n_quantizers: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Inference quantization.""" with disable_cuda_autocast(): z = self.input_proj(z).float() batch_size, _, max_time = z.shape mask = torch.arange(max_time, device=z.device).expand(batch_size, max_time) < input_length.unsqueeze(1) quantized_out = torch.zeros_like(z, dtype=torch.float32) residual = z.clone().float() all_indices = [] n_quantizers = n_quantizers or self.num_quantizers for i, quantizer in enumerate(self.quantizers): if i >= n_quantizers: break masked_residual = residual * mask.unsqueeze(1) z_q_i, indices_i, _ = quantizer(masked_residual.float()) update_mask = mask.unsqueeze(1) quantized_out = quantized_out + z_q_i * update_mask residual = residual - z_q_i * update_mask all_indices.append(indices_i) all_indices = ( torch.stack(all_indices) if all_indices else torch.empty(0, batch_size, max_time, device=z.device, dtype=torch.long) ) quantized_out = self.output_proj(quantized_out.float()).float() return quantized_out, all_indices, input_length def decode_codes(self, codes: torch.Tensor) -> torch.Tensor: with disable_cuda_autocast(): nq, B, T = codes.shape emb = torch.zeros(B, self.rvq_dim, T, device=codes.device, dtype=torch.float32) for i, quantizer in enumerate(self.quantizers[:nq]): quantizer = cast(MossAudioTokenizerLFQ, quantizer) emb += quantizer.decode_code(codes[i]).float() emb = self.output_proj(emb.float()).float() return emb # ============================================================================= # Main Model Classes # ============================================================================= @auto_docstring class MossAudioTokenizerPreTrainedModel(PreTrainedAudioTokenizerBase): """Base class for MossAudioTokenizer models.""" config_class = MossAudioTokenizerConfig base_model_prefix = "" main_input_name = "input_values" input_modalities = "audio" supports_gradient_checkpointing = False _no_split_modules = [ "MossAudioTokenizerTransformerLayer", "MossAudioTokenizerResidualVQ", "MossAudioTokenizerResidualLFQ", ] @auto_docstring( custom_intro=""" The MossAudioTokenizer neural audio codec model for audio tokenization and synthesis. """ ) class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel): """ MossAudioTokenizer model for audio tokenization and synthesis. This model can encode audio waveforms into discrete tokens and decode tokens back into audio waveforms. """ def __init__(self, config: MossAudioTokenizerConfig): super().__init__(config) self.config = config _ = config.version self.sampling_rate = config.sampling_rate self.downsample_rate = config.downsample_rate self.number_channels = config.number_channels self.enable_channel_interleave = getattr(config, "enable_channel_interleave", True) self.attention_implementation = config.attention_implementation self.compute_dtype_name = config.compute_dtype self.compute_dtype = resolve_compute_dtype(config.compute_dtype) encoder_context_durations = [ float(module_kwargs.get("context_duration", config.causal_transformer_context_duration)) for module_kwargs in config.encoder_kwargs if module_kwargs["module_type"] == "Transformer" ] self.causal_transformer_context_duration = ( min(encoder_context_durations) if encoder_context_durations else config.causal_transformer_context_duration ) # Build encoder channel_interleave_factor = ( self.number_channels if self.enable_channel_interleave and self.number_channels > 1 else 1 ) current_frame_rate: float = float(self.sampling_rate * channel_interleave_factor) self.encoder = nn.ModuleList() for encoder_kwargs_i in config.encoder_kwargs: encoder_kwargs_i = dict(encoder_kwargs_i) # Make a copy if encoder_kwargs_i["module_type"] == "PatchedPretransform": self.encoder.append(MossAudioTokenizerPatchedPretransform(**encoder_kwargs_i, is_downsample=True)) elif encoder_kwargs_i["module_type"] == "Transformer": context_duration = float(encoder_kwargs_i.pop("context_duration", self.causal_transformer_context_duration)) self.encoder.append( MossAudioTokenizerProjectedTransformer( **encoder_kwargs_i, context=int(round(current_frame_rate * context_duration)), attention_implementation=self.attention_implementation, ) ) current_frame_rate /= self.encoder[-1].downsample_ratio # Build quantizer quantizer_kwargs = dict(config.quantizer_kwargs) quantizer_type = quantizer_kwargs.get("quantizer_type", getattr(config, "quantizer_type", "rvq")) if quantizer_type in {"rvq", "spec_rvq"}: self.quantizer = MossAudioTokenizerResidualVQ(**quantizer_kwargs) elif quantizer_type in {"rlfq", "random_prefix_rlfq"}: self.quantizer = MossAudioTokenizerResidualLFQ(**quantizer_kwargs) else: raise ValueError(f"Unsupported quantizer_type: {quantizer_type}") # Build decoder decoder_kwargs_list = copy.deepcopy(config.decoder_kwargs) self.decoder = nn.ModuleList() for decoder_kwargs_i in decoder_kwargs_list: decoder_kwargs_i = dict(decoder_kwargs_i) if decoder_kwargs_i["module_type"] == "PatchedPretransform": self.decoder.append(MossAudioTokenizerPatchedPretransform(**decoder_kwargs_i, is_downsample=False)) elif decoder_kwargs_i["module_type"] == "Transformer": context_duration = float(decoder_kwargs_i.pop("context_duration", self.causal_transformer_context_duration)) self.decoder.append( MossAudioTokenizerProjectedTransformer( **decoder_kwargs_i, context=int(round(current_frame_rate * context_duration)), attention_implementation=self.attention_implementation, ) ) current_frame_rate *= self.decoder[-1].downsample_ratio expected_output_frame_rate = float(self.sampling_rate * channel_interleave_factor) if int(round(current_frame_rate)) != int(round(expected_output_frame_rate)): raise ValueError( "Decoder stack does not invert the encoder frame rate correctly: " f"got current_frame_rate={current_frame_rate}, expected={expected_output_frame_rate}." ) self.post_init() self._active_decode_session: "MossAudioTokenizerDecodeSession | None" = None self._batch_decode_streaming_max_batch_size: int | None = None self._batch_decode_streaming_batch_size: int | None = None self._batch_decode_streaming_session: "MossAudioTokenizerDecodeSession | None" = None self._batch_decode_streaming_next_request_id: int = 0 def create_decode_session( self, max_batch_size: int, use_cuda_graph: bool = False, ) -> MossAudioTokenizerDecodeSession: active_session = self._active_decode_session if active_session is not None and not active_session._closed: raise RuntimeError(_ACTIVE_DECODE_SESSION_ERROR_MESSAGE) for module in self.modules(): if isinstance(module, StreamingModule) and module._streaming_state is not None: raise RuntimeError(_MODEL_STREAMING_CONFLICT_ERROR_MESSAGE) session = MossAudioTokenizerDecodeSession(self, max_batch_size, use_cuda_graph=use_cuda_graph) return session def _reset_batch_decode_streaming_state(self) -> None: streaming_session = self._batch_decode_streaming_session self._batch_decode_streaming_session = None self._batch_decode_streaming_max_batch_size = None self._batch_decode_streaming_batch_size = None self._batch_decode_streaming_next_request_id = 0 if streaming_session is not None and not streaming_session._closed: streaming_session.close() def _prepare_batch_decode_streaming_state( self, batch_size: int, max_batch_size: int | None, reset_stream: bool, ) -> int: if reset_stream: self._reset_batch_decode_streaming_state() if max_batch_size is not None and max_batch_size <= 0: raise ValueError("`max_batch_size` must be > 0 when provided.") streaming_max_batch_size = self._batch_decode_streaming_max_batch_size if streaming_max_batch_size is None: streaming_max_batch_size = batch_size if max_batch_size is None else max_batch_size elif max_batch_size is not None and max_batch_size != streaming_max_batch_size: raise ValueError( "`max_batch_size` can only be set on the first streaming `batch_decode()` call for now. " f"Expected {streaming_max_batch_size}, got {max_batch_size}." ) if batch_size > streaming_max_batch_size: raise ValueError( "Streaming `batch_decode()` received a batch larger than the reserved `max_batch_size`. " f"Got batch_size={batch_size}, max_batch_size={streaming_max_batch_size}." ) return streaming_max_batch_size def _ensure_batch_decode_streaming_session( self, max_batch_size: int, use_cuda_graph: bool = False, ) -> MossAudioTokenizerDecodeSession: session = self._batch_decode_streaming_session if session is not None and not session._closed: if session._use_cuda_graph != use_cuda_graph: raise ValueError( "`use_cuda_graph` must match the existing streaming `batch_decode()` session configuration. " f"Expected {session._use_cuda_graph}, got {use_cuda_graph}." ) return session session = self.create_decode_session(max_batch_size=max_batch_size, use_cuda_graph=use_cuda_graph) self._batch_decode_streaming_session = session self._batch_decode_streaming_max_batch_size = max_batch_size self._batch_decode_streaming_next_request_id = 0 return session def _append_batch_decode_streaming_requests( self, session: MossAudioTokenizerDecodeSession, target_batch_size: int, ) -> None: requests_to_append = target_batch_size - len(session.active_request_ids) for _ in range(requests_to_append): request_id = self._batch_decode_streaming_next_request_id session.append(request_id) self._batch_decode_streaming_next_request_id += 1 def _resolve_batch_decode_streaming_finalize_request_ids( self, request_ids: list[str | int], finalize_indices: list[int] | tuple[int, ...] | None, ) -> list[str | int]: normalized_finalize_indices = tuple(finalize_indices) if finalize_indices is not None else () if len(set(normalized_finalize_indices)) != len(normalized_finalize_indices): raise ValueError(_BATCH_DECODE_STREAMING_DUPLICATE_FINALIZE_INDICES_ERROR_MESSAGE) batch_size = len(request_ids) finalize_request_ids: list[str | int] = [] for index in normalized_finalize_indices: if index < 0 or index >= batch_size: raise ValueError( _BATCH_DECODE_STREAMING_FINALIZE_INDEX_OUT_OF_RANGE_ERROR_TEMPLATE.format( index=index, batch_size=batch_size ) ) finalize_request_ids.append(request_ids[index]) return finalize_request_ids def _raise_if_plain_decode_conflicts_with_active_session(self) -> None: active_session = self._active_decode_session if active_session is not None and not getattr(active_session, "_closed", False): raise RuntimeError(_PLAIN_DECODE_SESSION_CONFLICT_ERROR_MESSAGE) def _start_streaming(self, batch_size: int): """Start streaming mode for all modules.""" active_session = self._active_decode_session if active_session is not None and not getattr(active_session, "_closed", False): raise RuntimeError(_MODEL_STREAMING_CONFLICT_ERROR_MESSAGE) def _start(module): if isinstance(module, StreamingModule): module._streaming_state = module._init_streaming_state(batch_size) self.apply(_start) def _stop_streaming(self): """Stop streaming mode for all modules.""" active_session = self._active_decode_session if active_session is not None and not getattr(active_session, "_closed", False): raise RuntimeError(_MODEL_STREAMING_CONFLICT_ERROR_MESSAGE) def _stop(module): if isinstance(module, StreamingModule): module._streaming_state = None self.apply(_stop) @contextmanager def streaming(self, batch_size: int = 1): """Context manager for streaming mode.""" self._start_streaming(batch_size) try: yield finally: self._stop_streaming() def _set_streaming_exec_mask(self, exec_mask: torch.Tensor) -> None: exec_mask = exec_mask.to(torch.bool) def _set_exec_mask(module: nn.Module): if isinstance(module, StreamingModule) and module._streaming_state is not None: module._streaming_state.set_exec_mask(exec_mask.to(module._streaming_state.device)) self.apply(_set_exec_mask) def _plan_batch_stream_step( self, remaining: torch.Tensor, max_step_length: int, alignment: int, ) -> tuple[int, torch.Tensor]: positive_mask = remaining > 0 if not bool(positive_mask.any().item()): raise RuntimeError("Cannot plan a streaming step when no samples remain.") if max_step_length > 0: full_step_mask = remaining >= max_step_length if bool(full_step_mask.any().item()): return max_step_length, full_step_mask positive_remaining = remaining[positive_mask] min_remaining = int(positive_remaining.min().item()) if alignment > 1: aligned_step = (min_remaining // alignment) * alignment if aligned_step > 0: return aligned_step, remaining >= aligned_step return min_remaining, remaining == min_remaining step_length = min_remaining if max_step_length > 0: step_length = min(step_length, max_step_length) return step_length, remaining >= step_length def _infer_num_quantizers(self, codes_chunks: list[list[torch.Tensor]], requested_num_quantizers: int | None) -> int: if requested_num_quantizers is not None: return requested_num_quantizers for chunks_i in codes_chunks: if chunks_i: return int(chunks_i[0].shape[0]) num_quantizers = getattr(self.quantizer, "num_quantizers", None) if num_quantizers is None: raise RuntimeError("Unable to infer the number of quantizers from empty streaming output.") return int(num_quantizers) def _infer_waveform_dtype(self, wav_chunks: list[list[torch.Tensor]]) -> torch.dtype: for chunks_i in wav_chunks: if chunks_i: return chunks_i[0].dtype return torch.float32 @contextmanager def _codec_inference_autocast(self): device = next(self.parameters()).device if device.type == "cuda" and self.compute_dtype is not None: with torch.autocast(device_type="cuda", dtype=self.compute_dtype): yield else: yield def set_attention_implementation(self, attention_implementation: str) -> None: self.attention_implementation = attention_implementation for module in self.modules(): if isinstance(module, MossAudioTokenizerProjectedTransformer): module.set_attention_implementation(attention_implementation) def set_compute_dtype(self, compute_dtype: str) -> None: self.compute_dtype_name = compute_dtype self.compute_dtype = resolve_compute_dtype(compute_dtype) def _prepare_waveform_batch( self, wav_list: list[torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor]: if len(wav_list) == 0: raise ValueError("`wav_list` must contain at least one waveform.") device = wav_list[0].device dtype = wav_list[0].dtype batch_size = len(wav_list) lengths = torch.zeros(batch_size, device=device, dtype=torch.long) normalized_wavs: list[torch.Tensor] = [] for i, wav in enumerate(wav_list): if self.number_channels == 1: if wav.dim() == 1: wav_i = wav.unsqueeze(0) elif wav.dim() == 2 and wav.shape[0] == 1: wav_i = wav else: raise ValueError( f"Expected wav_list[{i}] to have shape `(T,)` or `(1, T)` for a mono model, got {tuple(wav.shape)}." ) else: if wav.dim() != 2 or wav.shape[0] != self.number_channels: raise ValueError( f"Expected wav_list[{i}] to have shape `({self.number_channels}, T)`, got {tuple(wav.shape)}." ) wav_i = wav normalized_wavs.append(wav_i) lengths[i] = wav_i.shape[-1] max_length = int(lengths.max().item()) if batch_size > 0 else 0 input_values = torch.zeros(batch_size, self.number_channels, max_length, device=device, dtype=dtype) for i, wav_i in enumerate(normalized_wavs): input_values[i, :, : wav_i.shape[-1]] = wav_i return input_values, lengths def _prepare_codes_batch( self, codes_list: list[torch.Tensor], num_quantizers: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor, int]: if len(codes_list) == 0: raise ValueError("`codes_list` must contain at least one code tensor.") batch_size = len(codes_list) device = codes_list[0].device nqs = [codes.shape[0] for codes in codes_list] if num_quantizers is None: num_quantizers = nqs[0] if any(nq != num_quantizers for nq in nqs): raise ValueError( "All elements in `codes_list` must have the same number of quantizers when `num_quantizers` is None. " "Pass `num_quantizers=...` to decode a common prefix." ) elif min(nqs) < num_quantizers: raise ValueError( "`num_quantizers` must be <= the number of quantizers for every element in `codes_list`. " f"Got num_quantizers={num_quantizers}, min(codes.shape[0])={min(nqs)}." ) lengths = torch.tensor([codes.shape[-1] for codes in codes_list], device=device, dtype=torch.long) max_length = int(lengths.max().item()) if batch_size > 0 else 0 audio_codes = torch.zeros(num_quantizers, batch_size, max_length, device=device, dtype=torch.long) for i, codes in enumerate(codes_list): codes_i = codes[:num_quantizers] audio_codes[:, i, : codes_i.shape[-1]] = codes_i return audio_codes, lengths, num_quantizers def _flatten_channels_for_codec( self, input_values: torch.Tensor, input_lengths: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: if input_values.dim() != 3: raise ValueError(f"Expected `input_values` with shape `(B, C, T)`, got {tuple(input_values.shape)}.") if input_values.shape[1] != self.number_channels: raise ValueError( f"Expected `input_values.shape[1] == {self.number_channels}`, got {input_values.shape[1]}." ) if input_values.shape[-1] % self.downsample_rate != 0: pad_length = self.downsample_rate - (input_values.shape[-1] % self.downsample_rate) input_values = F.pad(input_values, (0, pad_length)) if self.number_channels > 1 and self.enable_channel_interleave: input_values = input_values.transpose(1, 2).contiguous().view(input_values.shape[0], 1, -1) input_lengths = input_lengths * self.number_channels return input_values, input_lengths def _restore_channels_from_codec( self, output_values: torch.Tensor, output_lengths: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: if self.number_channels == 1 or not self.enable_channel_interleave: return output_values.float(), output_lengths output_values = ( output_values.squeeze(1) .contiguous() .view(output_values.shape[0], -1, self.number_channels) .transpose(1, 2) .contiguous() .float() ) output_lengths = torch.div(output_lengths, self.number_channels, rounding_mode="floor") return output_values, output_lengths def _stack_hidden_states( self, hidden_chunks: list[list[torch.Tensor]], lengths: torch.Tensor, ) -> torch.Tensor | None: hidden_dim = None for chunks_i in hidden_chunks: if chunks_i: hidden_dim = chunks_i[0].shape[0] break if hidden_dim is None: return None batch_size = len(hidden_chunks) max_length = int(lengths.max().item()) if batch_size > 0 else 0 device = lengths.device hidden_states = torch.zeros(batch_size, hidden_dim, max_length, device=device, dtype=torch.float32) for i, chunks_i in enumerate(hidden_chunks): if not chunks_i: continue hidden_i = torch.cat(chunks_i, dim=-1).float() hidden_states[i, :, : hidden_i.shape[-1]] = hidden_i return hidden_states @torch.no_grad() def _encode_frame( self, input_values: torch.Tensor, input_lengths: torch.Tensor | None = None, n_quantizers: int | None = None, ) -> MossAudioTokenizerEncoderOutput: if input_values.dim() == 1: input_values = input_values.view(1, 1, -1) elif input_values.dim() == 2: if self.number_channels == 1: input_values = input_values.unsqueeze(1) else: input_values = input_values.unsqueeze(0) batch_size, _, time = input_values.shape device = input_values.device if input_lengths is None: input_lengths = torch.full((batch_size,), time, device=device, dtype=torch.long) input_values, input_lengths = self._flatten_channels_for_codec(input_values, input_lengths) with self._codec_inference_autocast(): encoder_hidden_states, encoder_hidden_lengths = input_values, input_lengths for encoder_module in self.encoder: encoder_hidden_states, encoder_hidden_lengths = encoder_module( encoder_hidden_states, encoder_hidden_lengths, ) quantizer = cast(MossAudioTokenizerResidualVQ | MossAudioTokenizerResidualLFQ, self.quantizer) _, audio_codes, audio_codes_lengths = quantizer(encoder_hidden_states.float(), encoder_hidden_lengths, n_quantizers) return MossAudioTokenizerEncoderOutput( audio_codes=audio_codes, audio_codes_lengths=audio_codes_lengths, encoder_hidden_states=encoder_hidden_states.float(), ) @torch.no_grad() def _decode_frame( self, codes: torch.Tensor, codes_lengths: torch.Tensor | None = None, ) -> MossAudioTokenizerDecoderOutput: _, batch_size, time = codes.shape device = codes.device if codes_lengths is None: codes_lengths = torch.full((batch_size,), time, device=device, dtype=torch.long) quantizer = cast(MossAudioTokenizerResidualVQ | MossAudioTokenizerResidualLFQ, self.quantizer) decoder_hidden_states = quantizer.decode_codes(codes).float() with self._codec_inference_autocast(): audio, audio_lengths = decoder_hidden_states, codes_lengths for decoder_module in self.decoder: audio, audio_lengths = decoder_module(audio, audio_lengths) audio, audio_lengths = self._restore_channels_from_codec(audio, audio_lengths) return MossAudioTokenizerDecoderOutput(audio=audio, audio_lengths=audio_lengths) @torch.no_grad() def batch_encode( self, wav_list: list[torch.Tensor], num_quantizers: int | None = None, chunk_duration: float | None = None, ) -> MossAudioTokenizerEncoderOutput: input_values, input_lengths = self._prepare_waveform_batch(wav_list) batch_size = len(wav_list) device = input_values.device if chunk_duration is None: return self._encode_frame(input_values, input_lengths, n_quantizers=num_quantizers) if chunk_duration <= 0: raise ValueError("`chunk_duration` must be > 0 when provided.") chunk_length = int(round(chunk_duration * self.sampling_rate)) if chunk_length <= 0: raise ValueError("`chunk_duration` is too small and results in chunk_length <= 0.") if chunk_length % self.downsample_rate != 0: raise ValueError( "`chunk_duration * config.sampling_rate` must be divisible by `config.downsample_rate`. " f"Got chunk_length={chunk_length}, downsample_rate={self.downsample_rate}." ) cursors = torch.zeros_like(input_lengths) codes_chunks: list[list[torch.Tensor]] = [[] for _ in range(batch_size)] hidden_chunks: list[list[torch.Tensor]] = [[] for _ in range(batch_size)] with self.streaming(batch_size=batch_size): while bool((cursors < input_lengths).any().item()): remaining = input_lengths - cursors step_length, active_mask = self._plan_batch_stream_step( remaining=remaining, max_step_length=chunk_length, alignment=self.downsample_rate, ) x_step = torch.zeros( batch_size, self.number_channels, step_length, device=device, dtype=input_values.dtype, ) input_lengths_step = torch.zeros(batch_size, device=device, dtype=torch.long) active_indices = torch.nonzero(active_mask, as_tuple=False).flatten().tolist() for i in active_indices: start = int(cursors[i].item()) end = start + step_length x_step[i] = input_values[i, :, start:end] input_lengths_step[i] = step_length self._set_streaming_exec_mask(active_mask) result = self._encode_frame(x_step, input_lengths_step, n_quantizers=num_quantizers) assert result.audio_codes is not None assert result.audio_codes_lengths is not None for i in active_indices: codes_length_i = int(result.audio_codes_lengths[i].item()) if codes_length_i > 0: codes_chunks[i].append(result.audio_codes[:, i, :codes_length_i].clone()) if result.encoder_hidden_states is not None: hidden_chunks[i].append(result.encoder_hidden_states[i, :, :codes_length_i].clone()) cursors[i] += step_length num_quantizers_used = self._infer_num_quantizers(codes_chunks, num_quantizers) empty_codes = torch.empty((num_quantizers_used, 0), device=device, dtype=torch.long) codes_list = [torch.cat(chunks_i, dim=-1) if chunks_i else empty_codes.clone() for chunks_i in codes_chunks] audio_codes, audio_codes_lengths, _ = self._prepare_codes_batch(codes_list, num_quantizers=num_quantizers_used) encoder_hidden_states = self._stack_hidden_states(hidden_chunks, audio_codes_lengths) return MossAudioTokenizerEncoderOutput( audio_codes=audio_codes, audio_codes_lengths=audio_codes_lengths, encoder_hidden_states=encoder_hidden_states, ) @torch.no_grad() def batch_decode( self, codes_list: list[torch.Tensor], num_quantizers: int | None = None, chunk_duration: float | None = None, streaming: bool = False, max_batch_size: int | None = None, finalize_indices: list[int] | tuple[int, ...] | None = None, reset_stream: bool = False, use_cuda_graph: bool = False, ) -> MossAudioTokenizerDecoderOutput: if len(codes_list) == 0: raise ValueError("`codes_list` must contain at least one code tensor.") streaming_max_batch_size: int | None = None if streaming: streaming_max_batch_size = self._prepare_batch_decode_streaming_state( batch_size=len(codes_list), max_batch_size=max_batch_size, reset_stream=reset_stream, ) else: if reset_stream: self._reset_batch_decode_streaming_state() self._raise_if_plain_decode_conflicts_with_active_session() audio_codes, audio_codes_lengths, num_quantizers_used = self._prepare_codes_batch( codes_list, num_quantizers=num_quantizers, ) batch_size = len(codes_list) device = audio_codes.device if not streaming and chunk_duration is None: return self._decode_frame(audio_codes, audio_codes_lengths) if streaming: assert streaming_max_batch_size is not None existing_session = self._batch_decode_streaming_session reusing_streaming_session = existing_session is not None and not existing_session._closed session = self._ensure_batch_decode_streaming_session( max_batch_size=streaming_max_batch_size, use_cuda_graph=use_cuda_graph, ) pre_call_request_ids = list(session.active_request_ids) pre_call_batch_size = len(pre_call_request_ids) if batch_size < pre_call_batch_size: raise ValueError(_BATCH_DECODE_STREAMING_SHRINK_ERROR_MESSAGE) try: finalize_request_ids = self._resolve_batch_decode_streaming_finalize_request_ids( request_ids=pre_call_request_ids, finalize_indices=finalize_indices, ) except Exception: if not reusing_streaming_session and pre_call_batch_size == 0: self._reset_batch_decode_streaming_state() raise try: if batch_size > pre_call_batch_size: self._append_batch_decode_streaming_requests(session=session, target_batch_size=batch_size) request_ids = list(session.active_request_ids) _, audio, audio_lengths = session.step( request_ids=request_ids, codes=audio_codes, code_lengths=audio_codes_lengths, ) for request_id in finalize_request_ids: session.remove(request_id) except Exception: self._reset_batch_decode_streaming_state() raise self._batch_decode_streaming_max_batch_size = session.max_batch_size self._batch_decode_streaming_batch_size = len(session.active_request_ids) return MossAudioTokenizerDecoderOutput(audio=audio, audio_lengths=audio_lengths) assert chunk_duration is not None if chunk_duration <= 0: raise ValueError("`chunk_duration` must be > 0 when provided.") chunk_length = int(round(chunk_duration * self.sampling_rate)) if chunk_length <= 0: raise ValueError("`chunk_duration` is too small and results in chunk_length <= 0.") if chunk_length % self.downsample_rate != 0: raise ValueError( "`chunk_duration * config.sampling_rate` must be divisible by `config.downsample_rate`. " f"Got chunk_length={chunk_length}, downsample_rate={self.downsample_rate}." ) chunk_frame_length = chunk_length // self.downsample_rate cursors = torch.zeros_like(audio_codes_lengths) wav_chunks: list[list[torch.Tensor]] = [[] for _ in range(batch_size)] with self.streaming(batch_size=batch_size): while bool((cursors < audio_codes_lengths).any().item()): remaining = audio_codes_lengths - cursors step_frames, active_mask = self._plan_batch_stream_step( remaining=remaining, max_step_length=chunk_frame_length, alignment=1, ) codes_step = torch.zeros( num_quantizers_used, batch_size, step_frames, device=device, dtype=torch.long, ) codes_lengths_step = torch.zeros(batch_size, device=device, dtype=torch.long) active_indices = torch.nonzero(active_mask, as_tuple=False).flatten().tolist() for i in active_indices: start = int(cursors[i].item()) end = start + step_frames codes_step[:, i, :] = audio_codes[:, i, start:end] codes_lengths_step[i] = step_frames self._set_streaming_exec_mask(active_mask) result = self._decode_frame(codes_step, codes_lengths_step) assert result.audio is not None assert result.audio_lengths is not None for i in active_indices: audio_length_i = int(result.audio_lengths[i].item()) if audio_length_i > 0: wav_chunks[i].append(result.audio[i, :, :audio_length_i].clone()) cursors[i] += step_frames wav_dtype = self._infer_waveform_dtype(wav_chunks) audio_lengths = torch.tensor( [sum(chunk.shape[-1] for chunk in chunks_i) for chunks_i in wav_chunks], device=device, dtype=torch.long, ) max_audio_length = int(audio_lengths.max().item()) if batch_size > 0 else 0 audio = torch.zeros(batch_size, self.number_channels, max_audio_length, device=device, dtype=wav_dtype) for i, chunks_i in enumerate(wav_chunks): if not chunks_i: continue wav_i = torch.cat(chunks_i, dim=-1) audio[i, :, : wav_i.shape[-1]] = wav_i return MossAudioTokenizerDecoderOutput(audio=audio, audio_lengths=audio_lengths) def encode( # type: ignore[override] self, input_values: torch.Tensor, padding_mask: torch.Tensor | None = None, num_quantizers: int | None = None, return_dict: bool | None = None, chunk_duration: float | None = None, ): """ Encodes the input audio waveform into discrete codes. Args: input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`): Float values of the input audio waveform. padding_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to indicate valid audio samples. num_quantizers (`int`, *optional*): Number of quantizers to use. By default, all quantizers are used. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. chunk_duration (`float`, *optional*): If provided, encode the input waveform in successive chunks of `chunk_duration` seconds while keeping a streaming KV cache for the causal transformers. `chunk_duration` must be <= `config.causal_transformer_context_duration`, and `chunk_duration * config.sampling_rate` must be divisible by `config.downsample_rate`. Returns: `MossAudioTokenizerEncoderOutput` or tuple containing audio codes and lengths. """ return_dict = return_dict if return_dict is not None else self.config.return_dict wav_list: list[torch.Tensor] if input_values.dim() == 1: wav_list = [input_values] elif input_values.dim() == 2: if self.number_channels == 1: lengths = ( padding_mask.sum(dim=-1).long() if padding_mask is not None and padding_mask.dim() == 2 else torch.full((input_values.shape[0],), input_values.shape[-1], device=input_values.device, dtype=torch.long) ) wav_list = [input_values[i, : int(lengths[i].item())] for i in range(input_values.shape[0])] else: length = ( int(padding_mask.sum().item()) if padding_mask is not None and padding_mask.dim() == 1 else int(input_values.shape[-1]) ) wav_list = [input_values[:, :length]] elif input_values.dim() == 3: if input_values.shape[1] != self.number_channels: raise ValueError( f"Expected `input_values.shape[1] == {self.number_channels}`, got {input_values.shape[1]}." ) lengths = ( padding_mask.sum(dim=-1).long() if padding_mask is not None else torch.full((input_values.shape[0],), input_values.shape[-1], device=input_values.device, dtype=torch.long) ) wav_list = [input_values[i, :, : int(lengths[i].item())] for i in range(input_values.shape[0])] else: raise ValueError(f"Unsupported `input_values` shape: {tuple(input_values.shape)}") encoder_output = self.batch_encode(wav_list, num_quantizers=num_quantizers, chunk_duration=chunk_duration) if not return_dict: assert encoder_output.audio_codes is not None assert encoder_output.audio_codes_lengths is not None return ( cast(torch.Tensor, encoder_output.audio_codes), cast(torch.Tensor, encoder_output.audio_codes_lengths), ) return encoder_output def decode( # type: ignore[override] self, audio_codes: torch.Tensor, padding_mask: torch.Tensor | None = None, return_dict: bool | None = None, chunk_duration: float | None = None, num_quantizers: int | None = None, ): """ Decodes the given codes into an output audio waveform. Args: audio_codes (`torch.LongTensor` of shape `(num_quantizers, batch_size, sequence_length)`): Discrete code embeddings computed using `model.encode`. padding_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to indicate valid code positions. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. chunk_duration (`float`, *optional*): If provided, decode the input codes in successive chunks of `chunk_duration` seconds while keeping a streaming KV cache for the causal transformers. num_quantizers (`int`, *optional*): Number of quantizers to use. By default, all quantizers in `audio_codes` are used. `chunk_duration` must be <= `config.causal_transformer_context_duration`, and `chunk_duration * config.sampling_rate` must be divisible by `config.downsample_rate`. Returns: `MossAudioTokenizerDecoderOutput` or tuple containing decoded audio. """ return_dict = return_dict if return_dict is not None else self.config.return_dict self._raise_if_plain_decode_conflicts_with_active_session() if audio_codes.dim() == 2: codes_list = [audio_codes[:num_quantizers] if num_quantizers is not None else audio_codes] elif audio_codes.dim() == 3: if num_quantizers is not None and num_quantizers > audio_codes.shape[0]: raise ValueError( f"`num_quantizers` ({num_quantizers}) must be <= audio_codes.shape[0] ({audio_codes.shape[0]})." ) codes_lengths = ( padding_mask.sum(dim=-1).long() if padding_mask is not None else torch.full((audio_codes.shape[1],), audio_codes.shape[-1], device=audio_codes.device, dtype=torch.long) ) codes_list = [ (audio_codes[:num_quantizers, i, : int(codes_lengths[i].item())] if num_quantizers is not None else audio_codes[:, i, : int(codes_lengths[i].item())]) for i in range(audio_codes.shape[1]) ] else: raise ValueError(f"Unsupported `audio_codes` shape: {tuple(audio_codes.shape)}") decoder_output = self.batch_decode(codes_list, num_quantizers=num_quantizers, chunk_duration=chunk_duration) if not return_dict: assert decoder_output.audio is not None return (cast(torch.Tensor, decoder_output.audio),) return decoder_output @auto_docstring def forward( self, input_values: torch.FloatTensor | None = None, padding_mask: torch.BoolTensor | None = None, audio_codes: torch.Tensor | None = None, num_quantizers: int | None = None, return_dict: bool | None = None, ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None] | MossAudioTokenizerOutput: # type: ignore[override] r""" input_values (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*): Raw audio input converted to Float. padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid computing on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. audio_codes (`torch.LongTensor` of shape `(num_quantizers, batch_size, sequence_length)`, *optional*): Discrete code embeddings computed using `model.encode`. num_quantizers (`int`, *optional*): Number of quantizers (codebooks) to use. By default, all quantizers are used. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. Examples: ```python >>> import torch >>> from transformers import MossAudioTokenizerModel >>> model = MossAudioTokenizerModel.from_pretrained("moss_audio_tokenizer-model") >>> # Create dummy audio input >>> audio = torch.randn(1, 1, 24000) # 1 second of audio at 24kHz >>> outputs = model(input_values=audio) >>> audio_codes = outputs.audio_codes >>> audio_values = outputs.audio ``` """ return_dict = return_dict if return_dict is not None else self.config.return_dict output_audio_codes: torch.Tensor | None = None output_audio_codes_lengths: torch.Tensor | None = None output_audio: torch.Tensor | None = None output_audio_lengths: torch.Tensor | None = None decoded_from_encoded_codes = False # Encode if input_values provided if input_values is not None: encoder_output = self.encode(input_values, padding_mask, num_quantizers, return_dict=True) encoder_output = cast(MossAudioTokenizerEncoderOutput, encoder_output) output_audio_codes = encoder_output.audio_codes output_audio_codes_lengths = encoder_output.audio_codes_lengths # If codes not provided separately, use encoded codes for decoding if audio_codes is None: audio_codes = output_audio_codes decoded_from_encoded_codes = True # Decode if codes available if audio_codes is not None: # If we're decoding the codes we just produced, use the computed lengths so we don't decode padded garbage. if decoded_from_encoded_codes and output_audio_codes_lengths is not None: decoder_output = self._decode_frame(audio_codes, output_audio_codes_lengths) else: decoder_output = self.decode( audio_codes, padding_mask=padding_mask, return_dict=True, num_quantizers=num_quantizers, ) decoder_output = cast(MossAudioTokenizerDecoderOutput, decoder_output) output_audio = decoder_output.audio output_audio_lengths = decoder_output.audio_lengths if not return_dict: return (output_audio_codes, output_audio, output_audio_lengths) return MossAudioTokenizerOutput( audio=output_audio, audio_lengths=output_audio_lengths, audio_codes=output_audio_codes, audio_codes_lengths=output_audio_codes_lengths, ) __all__ = ["MossAudioTokenizerModel", "MossAudioTokenizerPreTrainedModel"]