Feature Extraction
Transformers
Safetensors
moss-audio-tokenizer
audio
audio-tokenizer
neural-codec
moss-tts-family
MOSS Audio Tokenizer
speech-tokenizer
trust-remote-code
custom_code
Instructions to use koibor/MOSS-Audio-Tokenizer-Nano with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use koibor/MOSS-Audio-Tokenizer-Nano with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("feature-extraction", model="koibor/MOSS-Audio-Tokenizer-Nano", trust_remote_code=True)# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("koibor/MOSS-Audio-Tokenizer-Nano", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
| # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """PyTorch MossAudioTokenizer model.""" | |
| from __future__ import annotations | |
| import copy | |
| import importlib | |
| import math | |
| import sys | |
| import types | |
| from contextlib import ExitStack, contextmanager | |
| from dataclasses import dataclass | |
| from functools import lru_cache | |
| from pathlib import Path | |
| from typing import cast | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| if __name__ not in sys.modules: | |
| _module_proxy = types.ModuleType(__name__) | |
| sys.modules[__name__] = _module_proxy | |
| def _sync_module_proxy() -> None: | |
| sys.modules[__name__].__dict__.update(globals()) | |
| try: | |
| from transformers.modeling_utils import PreTrainedAudioTokenizerBase | |
| except ImportError: | |
| from transformers.modeling_utils import PreTrainedModel as PreTrainedAudioTokenizerBase | |
| from transformers.utils import ModelOutput, logging | |
| try: | |
| from transformers.utils import auto_docstring as _hf_auto_docstring | |
| except ImportError: | |
| _hf_auto_docstring = None | |
| def auto_docstring(*args, **kwargs): | |
| if _hf_auto_docstring is None: | |
| if len(args) == 1 and callable(args[0]) and not kwargs: | |
| return args[0] | |
| def decorator(obj): | |
| return obj | |
| return decorator | |
| if len(args) == 1 and callable(args[0]) and not kwargs: | |
| obj = args[0] | |
| try: | |
| return _hf_auto_docstring(obj) | |
| except Exception: | |
| return obj | |
| try: | |
| decorator = _hf_auto_docstring(*args, **kwargs) | |
| except Exception: | |
| def decorator(obj): | |
| return obj | |
| return decorator | |
| def safe_decorator(obj): | |
| try: | |
| return decorator(obj) | |
| except Exception: | |
| return obj | |
| return safe_decorator | |
| try: | |
| from .configuration_moss_audio_tokenizer import MossAudioTokenizerConfig | |
| except ImportError: | |
| _module_dir = str(Path(__file__).resolve().parent) | |
| if _module_dir not in sys.path: | |
| sys.path.insert(0, _module_dir) | |
| from configuration_moss_audio_tokenizer import MossAudioTokenizerConfig | |
| logger = logging.get_logger(__name__) | |
| def _get_flash_attn_module(): | |
| try: | |
| return importlib.import_module("flash_attn") | |
| except Exception: | |
| return None | |
| def _has_flash_attn() -> bool: | |
| return _get_flash_attn_module() is not None | |
| def _get_flash_attn_varlen_func(): | |
| flash_attn_module = _get_flash_attn_module() | |
| if flash_attn_module is None: | |
| return None | |
| return getattr(flash_attn_module, "flash_attn_varlen_func", None) | |
| def _get_flash_attn_with_kvcache(): | |
| flash_attn_module = _get_flash_attn_module() | |
| if flash_attn_module is None: | |
| return None | |
| return getattr(flash_attn_module, "flash_attn_with_kvcache", None) | |
| SUPPORTED_ATTENTION_IMPLEMENTATIONS = {"sdpa", "flash_attention_2"} | |
| SUPPORTED_COMPUTE_DTYPES = {"fp32": None, "bf16": torch.bfloat16, "fp16": torch.float16} | |
| _ACTIVE_DECODE_SESSION_ERROR_MESSAGE = "MossAudioTokenizerModel only supports one active decode session at a time." | |
| _CLOSED_DECODE_SESSION_ERROR_MESSAGE = "This decode session is closed." | |
| _MODEL_STREAMING_CONFLICT_ERROR_MESSAGE = "Model-level streaming helpers cannot be used while a decode session is active." | |
| _PLAIN_DECODE_SESSION_CONFLICT_ERROR_MESSAGE = "Plain decode helpers cannot be used while a decode session is active." | |
| _DUPLICATE_DECODE_REQUEST_ERROR_TEMPLATE = "Decode session already contains request_id={request_id!r}." | |
| _UNKNOWN_DECODE_REQUEST_ERROR_TEMPLATE = "Decode session does not contain an active request_id={request_id!r}." | |
| _DECODE_SESSION_FULL_ERROR_TEMPLATE = "Decode session has no free slots remaining (max_batch_size={max_batch_size})." | |
| _INVALID_DECODE_STEP_REQUEST_IDS_ERROR_MESSAGE = ( | |
| "`request_ids` must exactly match the current active decode request order." | |
| ) | |
| _BATCH_DECODE_STREAMING_DUPLICATE_FINALIZE_INDICES_ERROR_MESSAGE = "`finalize_indices` must not contain duplicates." | |
| _BATCH_DECODE_STREAMING_FINALIZE_INDEX_OUT_OF_RANGE_ERROR_TEMPLATE = ( | |
| "`finalize_indices` index {index} is out of range for the pre-call logical batch of size {batch_size}." | |
| ) | |
| _BATCH_DECODE_STREAMING_SHRINK_ERROR_MESSAGE = ( | |
| "`batch_decode(streaming=True)` must include all pre-call active rows in the current call before applying `finalize_indices`." | |
| ) | |
| def resolve_compute_dtype(compute_dtype: str) -> torch.dtype | None: | |
| if compute_dtype not in SUPPORTED_COMPUTE_DTYPES: | |
| raise ValueError( | |
| f"Unsupported compute_dtype={compute_dtype!r}. Expected one of {sorted(SUPPORTED_COMPUTE_DTYPES)}." | |
| ) | |
| return SUPPORTED_COMPUTE_DTYPES[compute_dtype] | |
| def disable_cuda_autocast(): | |
| with torch.autocast(device_type="cuda", enabled=False): | |
| yield | |
| # ============================================================================= | |
| # Output Classes | |
| # ============================================================================= | |
| _sync_module_proxy() | |
| class MossAudioTokenizerEncoderOutput(ModelOutput): | |
| r""" | |
| audio_codes (`torch.LongTensor` of shape `(num_quantizers, batch_size, sequence_length)`, *optional*): | |
| Discrete audio codes computed using the encoder and quantizer. | |
| audio_codes_lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*): | |
| Valid lengths for each sample's audio codes. | |
| encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, hidden_size, sequence_length)`, *optional*): | |
| Hidden states from the encoder before quantization. | |
| """ | |
| audio_codes: torch.Tensor | None = None | |
| audio_codes_lengths: torch.Tensor | None = None | |
| encoder_hidden_states: torch.Tensor | None = None | |
| _sync_module_proxy() | |
| class MossAudioTokenizerDecoderOutput(ModelOutput): | |
| r""" | |
| audio (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*): | |
| Decoded audio waveform. | |
| audio_lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*): | |
| Valid lengths for each sample's audio. | |
| """ | |
| audio: torch.Tensor | None = None | |
| audio_lengths: torch.Tensor | None = None | |
| _sync_module_proxy() | |
| class MossAudioTokenizerOutput(ModelOutput): | |
| r""" | |
| audio (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*): | |
| Decoded audio waveform. | |
| audio_lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*): | |
| Valid lengths for each sample's audio. | |
| audio_codes (`torch.LongTensor` of shape `(num_quantizers, batch_size, sequence_length)`, *optional*): | |
| Discrete audio codes computed using the encoder and quantizer. | |
| audio_codes_lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*): | |
| Valid lengths for each sample's audio codes. | |
| """ | |
| audio: torch.Tensor | None = None | |
| audio_lengths: torch.Tensor | None = None | |
| audio_codes: torch.Tensor | None = None | |
| audio_codes_lengths: torch.Tensor | None = None | |
| # ============================================================================= | |
| # Streaming Module Base Classes | |
| # ============================================================================= | |
| _sync_module_proxy() | |
| class StreamingState: | |
| """Base state for streaming modules.""" | |
| batch_size: int | |
| device: torch.device | |
| def __post_init__(self): | |
| self.exec_mask = torch.ones(self.batch_size, dtype=torch.bool, device=self.device) | |
| def set_exec_mask(self, exec_mask: torch.Tensor): | |
| self.exec_mask[:] = exec_mask | |
| def reset(self, reset_mask: torch.Tensor) -> None: | |
| self.exec_mask[:] = torch.where(reset_mask, torch.ones_like(self.exec_mask), self.exec_mask) | |
| def __enter__(self): | |
| # ExitStack expects a context manager; returning self is conventional and useful for debugging. | |
| return self | |
| def __exit__(self, exc_type, exc_value, traceback) -> None: | |
| pass | |
| class StreamingModule(nn.Module): | |
| """Base class for streaming components.""" | |
| def __init__(self) -> None: | |
| super().__init__() | |
| self._streaming_state: StreamingState | None = None | |
| self._streaming_detached: bool = False | |
| self._cached_children: list[tuple[str, StreamingModule]] | None = None | |
| def is_streaming(self): | |
| return self._streaming_state is not None | |
| def _apply_named_streaming(self, fn): | |
| def _handle_module(prefix: str, module: nn.Module): | |
| if isinstance(module, StreamingModule): | |
| if module._streaming_detached and prefix != "": | |
| return | |
| if self._cached_children is None: | |
| raise RuntimeError("Internal error: _cached_children should be initialized before traversal.") | |
| self._cached_children.append((prefix, module)) | |
| for name, child in module.named_children(): | |
| new_prefix = f"{prefix}.{name}" if prefix else name | |
| _handle_module(new_prefix, child) | |
| if self._cached_children is None: | |
| self._cached_children = [] | |
| _handle_module("", self) | |
| for name, child in self._cached_children: | |
| fn(name, child) | |
| def _start_streaming(self, batch_size: int, exit_stack: ExitStack): | |
| def _start_streaming_fn(name: str, module: StreamingModule): | |
| if module._streaming_state is not None: | |
| raise RuntimeError(f"{name} is already streaming!") | |
| state = module._init_streaming_state(batch_size) | |
| exit_stack.enter_context(state) | |
| module._streaming_state = state | |
| self._apply_named_streaming(_start_streaming_fn) | |
| def _stop_streaming(self) -> None: | |
| def _stop_streaming_fn(name: str, module: StreamingModule): | |
| module._streaming_state = None | |
| self._apply_named_streaming(_stop_streaming_fn) | |
| def _init_streaming_state(self, batch_size: int) -> StreamingState: | |
| device = next(iter(self.parameters())).device | |
| return StreamingState(batch_size, device) | |
| def streaming(self, batch_size: int) -> ExitStack: | |
| """Context manager to enter streaming mode.""" | |
| exit_stack = ExitStack() | |
| self._start_streaming(batch_size, exit_stack) | |
| exit_stack.callback(self._stop_streaming) | |
| return exit_stack | |
| class StreamingContainer(StreamingModule): | |
| """Container for streaming modules.""" | |
| pass | |
| class MossAudioTokenizerDecodeSession: | |
| model: MossAudioTokenizerModel | |
| max_batch_size: int | |
| _use_cuda_graph: bool | |
| active_request_ids: list[str | int] | |
| request_id_to_slot_index: dict[str | int, int] | |
| slot_index_to_request_id: list[str | int | None] | |
| slot_is_free: list[bool] | |
| request_id_to_code_offset: dict[str | int, int] | |
| request_id_to_audio_offset: dict[str | int, int] | |
| _flash_kvcache_attention_modules: list[MossAudioTokenizerMultiheadAttention] | |
| _graph_num_quantizers_capacity: int | None | |
| _graph_input_codes: torch.Tensor | None | |
| _graph_input_code_lengths: torch.Tensor | None | |
| _graph_output_audio: torch.Tensor | None | |
| _graph_output_audio_lengths: torch.Tensor | None | |
| _cuda_graph: torch.cuda.CUDAGraph | None | |
| _cuda_graph_key: tuple[str, int, int, str] | None | |
| _decode_streaming_exit_stack: ExitStack | None | |
| _closed: bool | |
| def __init__(self, model: MossAudioTokenizerModel, max_batch_size: int, use_cuda_graph: bool = False): | |
| if max_batch_size <= 0: | |
| raise ValueError("`max_batch_size` must be > 0.") | |
| decoder_attention_modules: list[MossAudioTokenizerMultiheadAttention] = [] | |
| for decoder_module in model.decoder: | |
| for module in decoder_module.modules(): | |
| if isinstance(module, MossAudioTokenizerMultiheadAttention): | |
| if module.context is None: | |
| raise ValueError( | |
| "MossAudioTokenizerDecodeSession requires all decoder MHA modules to have a finite " | |
| "`context` (context=None is unsupported for continuous-batch streaming)." | |
| ) | |
| decoder_attention_modules.append(module) | |
| flash_kvcache_attention_modules: list[MossAudioTokenizerMultiheadAttention] = [] | |
| if use_cuda_graph and _has_flash_attn(): | |
| for module in decoder_attention_modules: | |
| module._use_flash_kvcache = True | |
| flash_kvcache_attention_modules.append(module) | |
| decode_streaming_exit_stack = ExitStack() | |
| try: | |
| for decoder_module in model.decoder: | |
| if isinstance(decoder_module, StreamingModule): | |
| inner_stack = decoder_module.streaming(batch_size=max_batch_size) | |
| _ = decode_streaming_exit_stack.enter_context(inner_stack) | |
| except Exception: | |
| decode_streaming_exit_stack.close() | |
| for module in flash_kvcache_attention_modules: | |
| module._use_flash_kvcache = False | |
| raise | |
| self.model = model | |
| self.max_batch_size = max_batch_size | |
| self._use_cuda_graph = use_cuda_graph | |
| self.active_request_ids: list[str | int] = [] | |
| self.request_id_to_slot_index: dict[str | int, int] = {} | |
| self.slot_index_to_request_id: list[str | int | None] = [None] * max_batch_size | |
| self.slot_is_free: list[bool] = [True] * max_batch_size | |
| self.request_id_to_code_offset: dict[str | int, int] = {} | |
| self.request_id_to_audio_offset: dict[str | int, int] = {} | |
| self._flash_kvcache_attention_modules = flash_kvcache_attention_modules | |
| self._graph_num_quantizers_capacity = int(getattr(model.quantizer, "num_quantizers", 0)) if use_cuda_graph else None | |
| self._graph_input_codes = None | |
| self._graph_input_code_lengths = None | |
| self._graph_output_audio = None | |
| self._graph_output_audio_lengths = None | |
| self._cuda_graph = None | |
| self._cuda_graph_key = None | |
| self._decode_streaming_exit_stack: ExitStack | None = decode_streaming_exit_stack | |
| self._closed = False | |
| if use_cuda_graph: | |
| device = next(iter(model.parameters())).device | |
| if device.type == "cuda": | |
| self._ensure_cuda_graph_buffers(device) | |
| model._active_decode_session = self | |
| def _ensure_open(self) -> None: | |
| if self._closed: | |
| raise RuntimeError(_CLOSED_DECODE_SESSION_ERROR_MESSAGE) | |
| def append(self, request_id: str | int) -> None: | |
| self._ensure_open() | |
| if request_id in self.request_id_to_slot_index: | |
| raise RuntimeError(_DUPLICATE_DECODE_REQUEST_ERROR_TEMPLATE.format(request_id=request_id)) | |
| slot_index = next((index for index, is_free in enumerate(self.slot_is_free) if is_free), None) | |
| if slot_index is None: | |
| raise RuntimeError(_DECODE_SESSION_FULL_ERROR_TEMPLATE.format(max_batch_size=self.max_batch_size)) | |
| self.active_request_ids.append(request_id) | |
| self.request_id_to_slot_index[request_id] = slot_index | |
| self.slot_index_to_request_id[slot_index] = request_id | |
| self.slot_is_free[slot_index] = False | |
| self.request_id_to_code_offset[request_id] = 0 | |
| self.request_id_to_audio_offset[request_id] = 0 | |
| def _decoder_streaming_states(self) -> list[StreamingState]: | |
| decoder_streaming_states: list[StreamingState] = [] | |
| for decoder_module in self.model.decoder: | |
| for module in decoder_module.modules(): | |
| if isinstance(module, StreamingModule) and module._streaming_state is not None: | |
| decoder_streaming_states.append(module._streaming_state) | |
| return decoder_streaming_states | |
| def _ensure_cuda_graph_buffers(self, device: torch.device) -> None: | |
| if not self._use_cuda_graph or device.type != "cuda": | |
| return | |
| graph_num_quantizers_capacity = self._graph_num_quantizers_capacity | |
| if graph_num_quantizers_capacity is None: | |
| graph_num_quantizers_capacity = int(getattr(self.model.quantizer, "num_quantizers", 0)) | |
| self._graph_num_quantizers_capacity = graph_num_quantizers_capacity | |
| if graph_num_quantizers_capacity <= 0: | |
| raise RuntimeError("`use_cuda_graph=True` requires a quantizer with `num_quantizers > 0`.") | |
| if self._graph_input_codes is None or self._graph_input_codes.device != device: | |
| self._graph_input_codes = torch.zeros( | |
| (graph_num_quantizers_capacity, self.max_batch_size, 1), | |
| device=device, | |
| dtype=torch.long, | |
| ) | |
| self._graph_input_code_lengths = torch.zeros(self.max_batch_size, device=device, dtype=torch.long) | |
| self._graph_output_audio = None | |
| self._graph_output_audio_lengths = None | |
| self._cuda_graph = None | |
| self._cuda_graph_key = None | |
| def _snapshot_decoder_streaming_states(self) -> list[tuple[StreamingState, dict[str, torch.Tensor | None]]]: | |
| snapshots: list[tuple[StreamingState, dict[str, torch.Tensor | None]]] = [] | |
| for streaming_state in self._decoder_streaming_states(): | |
| state_snapshot: dict[str, torch.Tensor | None] = {"exec_mask": streaming_state.exec_mask.clone()} | |
| if isinstance(streaming_state, TransformerState): | |
| state_snapshot["offsets"] = streaming_state.offsets.clone() | |
| if isinstance(streaming_state, MHAState): | |
| state_snapshot["offset"] = streaming_state.offset.clone() | |
| state_snapshot["cached_keys"] = None if streaming_state.cached_keys is None else streaming_state.cached_keys.clone() | |
| state_snapshot["cached_values"] = None if streaming_state.cached_values is None else streaming_state.cached_values.clone() | |
| state_snapshot["cached_positions"] = ( | |
| None if streaming_state.cached_positions is None else streaming_state.cached_positions.clone() | |
| ) | |
| state_snapshot["flash_cached_keys"] = ( | |
| None | |
| if getattr(streaming_state, "_flash_cached_keys", None) is None | |
| else cast(torch.Tensor, getattr(streaming_state, "_flash_cached_keys")).clone() | |
| ) | |
| state_snapshot["flash_cached_values"] = ( | |
| None | |
| if getattr(streaming_state, "_flash_cached_values", None) is None | |
| else cast(torch.Tensor, getattr(streaming_state, "_flash_cached_values")).clone() | |
| ) | |
| snapshots.append((streaming_state, state_snapshot)) | |
| return snapshots | |
| def _restore_decoder_streaming_states( | |
| self, | |
| snapshots: list[tuple[StreamingState, dict[str, torch.Tensor | None]]], | |
| ) -> None: | |
| for streaming_state, state_snapshot in snapshots: | |
| exec_mask = state_snapshot["exec_mask"] | |
| assert exec_mask is not None | |
| streaming_state.exec_mask.copy_(exec_mask) | |
| if isinstance(streaming_state, TransformerState): | |
| offsets = state_snapshot.get("offsets") | |
| assert offsets is not None | |
| streaming_state.offsets.copy_(offsets) | |
| if isinstance(streaming_state, MHAState): | |
| offset = state_snapshot.get("offset") | |
| assert offset is not None | |
| streaming_state.offset.copy_(offset) | |
| cached_keys = state_snapshot.get("cached_keys") | |
| cached_values = state_snapshot.get("cached_values") | |
| cached_positions = state_snapshot.get("cached_positions") | |
| if cached_keys is None or cached_values is None or cached_positions is None: | |
| if streaming_state.cached_keys is not None: | |
| streaming_state.cached_keys.zero_() | |
| if streaming_state.cached_values is not None: | |
| streaming_state.cached_values.zero_() | |
| if streaming_state.cached_positions is not None: | |
| streaming_state.cached_positions.fill_(-1) | |
| else: | |
| if streaming_state.cached_keys is None or streaming_state.cached_keys.shape != cached_keys.shape: | |
| streaming_state.cached_keys = cached_keys.clone() | |
| else: | |
| streaming_state.cached_keys.copy_(cached_keys) | |
| if streaming_state.cached_values is None or streaming_state.cached_values.shape != cached_values.shape: | |
| streaming_state.cached_values = cached_values.clone() | |
| else: | |
| streaming_state.cached_values.copy_(cached_values) | |
| if streaming_state.cached_positions is None or streaming_state.cached_positions.shape != cached_positions.shape: | |
| streaming_state.cached_positions = cached_positions.clone() | |
| else: | |
| streaming_state.cached_positions.copy_(cached_positions) | |
| flash_cached_keys = state_snapshot.get("flash_cached_keys") | |
| flash_cached_values = state_snapshot.get("flash_cached_values") | |
| current_flash_cached_keys = cast(torch.Tensor | None, getattr(streaming_state, "_flash_cached_keys", None)) | |
| current_flash_cached_values = cast(torch.Tensor | None, getattr(streaming_state, "_flash_cached_values", None)) | |
| if flash_cached_keys is None or flash_cached_values is None: | |
| if current_flash_cached_keys is not None: | |
| current_flash_cached_keys.zero_() | |
| if current_flash_cached_values is not None: | |
| current_flash_cached_values.zero_() | |
| else: | |
| if current_flash_cached_keys is None or current_flash_cached_keys.shape != flash_cached_keys.shape: | |
| setattr(streaming_state, "_flash_cached_keys", flash_cached_keys.clone()) | |
| else: | |
| current_flash_cached_keys.copy_(flash_cached_keys) | |
| if current_flash_cached_values is None or current_flash_cached_values.shape != flash_cached_values.shape: | |
| setattr(streaming_state, "_flash_cached_values", flash_cached_values.clone()) | |
| else: | |
| current_flash_cached_values.copy_(flash_cached_values) | |
| def _graphed_decode_frame( | |
| self, | |
| codes: torch.Tensor, | |
| code_lengths: torch.Tensor, | |
| ) -> MossAudioTokenizerDecoderOutput: | |
| self._ensure_cuda_graph_buffers(codes.device) | |
| graph_input_codes = self._graph_input_codes | |
| graph_input_code_lengths = self._graph_input_code_lengths | |
| if graph_input_codes is None or graph_input_code_lengths is None: | |
| raise RuntimeError("CUDA graph buffers are unavailable.") | |
| num_quantizers = codes.shape[0] | |
| graph_input_codes_view = graph_input_codes[:num_quantizers] | |
| graph_input_codes_view.copy_(codes) | |
| graph_input_code_lengths.copy_(code_lengths) | |
| cuda_graph_key = (str(codes.device), self.max_batch_size, num_quantizers, self.model.compute_dtype_name) | |
| if self._cuda_graph is None or self._cuda_graph_key != cuda_graph_key: | |
| state_snapshots = self._snapshot_decoder_streaming_states() | |
| current_stream = torch.cuda.current_stream(device=codes.device) | |
| warmup_stream = torch.cuda.Stream(device=codes.device) | |
| warmup_stream.wait_stream(current_stream) | |
| with torch.cuda.stream(warmup_stream): | |
| _ = self.model._decode_frame(graph_input_codes_view, graph_input_code_lengths) | |
| current_stream.wait_stream(warmup_stream) | |
| self._restore_decoder_streaming_states(state_snapshots) | |
| cuda_graph = torch.cuda.CUDAGraph() | |
| with torch.cuda.graph(cuda_graph): | |
| decoder_output = self.model._decode_frame(graph_input_codes_view, graph_input_code_lengths) | |
| self._cuda_graph = cuda_graph | |
| self._cuda_graph_key = cuda_graph_key | |
| self._graph_output_audio = decoder_output.audio | |
| self._graph_output_audio_lengths = decoder_output.audio_lengths | |
| else: | |
| self._cuda_graph.replay() | |
| return MossAudioTokenizerDecoderOutput( | |
| audio=self._graph_output_audio, | |
| audio_lengths=self._graph_output_audio_lengths, | |
| ) | |
| def _reset_slot(self, slot_index: int) -> None: | |
| for streaming_state in self._decoder_streaming_states(): | |
| reset_mask = torch.zeros(streaming_state.batch_size, dtype=torch.bool, device=streaming_state.exec_mask.device) | |
| reset_mask[slot_index] = True | |
| streaming_state.reset(reset_mask) | |
| def _pack_logical_codes_to_physical_slots( | |
| self, | |
| request_ids: list[str | int], | |
| codes: torch.Tensor, | |
| code_lengths: torch.Tensor, | |
| ) -> tuple[torch.Tensor, torch.Tensor, list[int], torch.Tensor]: | |
| if request_ids != self.active_request_ids: | |
| raise ValueError(_INVALID_DECODE_STEP_REQUEST_IDS_ERROR_MESSAGE) | |
| if not request_ids: | |
| raise ValueError("`step()` requires at least one active request.") | |
| if codes.dim() == 2: | |
| codes = codes.unsqueeze(1) | |
| if codes.dim() != 3: | |
| raise ValueError(f"`codes` must be 3D with shape `(num_quantizers, batch_size, sequence_length)`, got {codes.shape}.") | |
| code_lengths = code_lengths.to(device=codes.device, dtype=torch.long) | |
| if code_lengths.dim() != 1: | |
| raise ValueError(f"`code_lengths` must be 1D with shape `(batch_size,)`, got {code_lengths.shape}.") | |
| num_quantizers, logical_batch_size, max_code_length = codes.shape | |
| if logical_batch_size != len(request_ids): | |
| raise ValueError( | |
| f"`codes.shape[1]` ({logical_batch_size}) must match len(`request_ids`) ({len(request_ids)})." | |
| ) | |
| if code_lengths.shape[0] != logical_batch_size: | |
| raise ValueError( | |
| f"`code_lengths.shape[0]` ({code_lengths.shape[0]}) must match len(`request_ids`) ({len(request_ids)})." | |
| ) | |
| if torch.any(code_lengths < 0): | |
| raise ValueError("`code_lengths` must be >= 0.") | |
| if torch.any(code_lengths > max_code_length): | |
| raise ValueError(f"`code_lengths` must be <= codes.shape[-1] ({max_code_length}).") | |
| packed_codes = codes.new_zeros((num_quantizers, self.max_batch_size, max_code_length)) | |
| packed_code_lengths = code_lengths.new_zeros((self.max_batch_size,)) | |
| logical_row_to_slot_index: list[int] = [] | |
| for logical_row_index, request_id in enumerate(request_ids): | |
| slot_index = self.request_id_to_slot_index[request_id] | |
| logical_row_to_slot_index.append(slot_index) | |
| row_length = int(code_lengths[logical_row_index].item()) | |
| if row_length > 0: | |
| packed_codes[:, slot_index, :row_length] = codes[:, logical_row_index, :row_length] | |
| packed_code_lengths[slot_index] = row_length | |
| return packed_codes, packed_code_lengths, logical_row_to_slot_index, code_lengths | |
| def _advance_request_progress( | |
| self, | |
| request_ids: list[str | int], | |
| code_lengths: torch.Tensor, | |
| audio_lengths: torch.Tensor, | |
| ) -> None: | |
| for logical_row_index, request_id in enumerate(request_ids): | |
| self.request_id_to_code_offset[request_id] += int(code_lengths[logical_row_index].item()) | |
| self.request_id_to_audio_offset[request_id] += int(audio_lengths[logical_row_index].item()) | |
| def step( | |
| self, | |
| request_ids: list[str | int], | |
| codes: torch.Tensor, | |
| code_lengths: torch.Tensor, | |
| ) -> tuple[list[str | int], torch.Tensor, torch.Tensor]: | |
| self._ensure_open() | |
| packed_codes, packed_code_lengths, logical_row_to_slot_index, logical_code_lengths = ( | |
| self._pack_logical_codes_to_physical_slots( | |
| request_ids=request_ids, | |
| codes=codes, | |
| code_lengths=code_lengths, | |
| ) | |
| ) | |
| max_step_length = int(packed_code_lengths.max().item()) | |
| if max_step_length <= 0: | |
| raise ValueError("`step()` requires at least one row with `code_length > 0`.") | |
| decoder_streaming_states = self._decoder_streaming_states() | |
| logical_audio_chunks: list[list[torch.Tensor]] = [[] for _ in request_ids] | |
| audio_device: torch.device | None = None | |
| audio_dtype: torch.dtype | None = None | |
| audio_num_channels: int | None = None | |
| try: | |
| for frame_index in range(max_step_length): | |
| frame_exec_mask = packed_code_lengths > frame_index | |
| for streaming_state in decoder_streaming_states: | |
| streaming_state.set_exec_mask(frame_exec_mask) | |
| frame_codes = packed_codes[:, :, frame_index : frame_index + 1] | |
| frame_code_lengths = frame_exec_mask.to(dtype=packed_code_lengths.dtype) | |
| if self._use_cuda_graph and frame_codes.is_cuda: | |
| decoder_output = self._graphed_decode_frame(frame_codes, frame_code_lengths) | |
| else: | |
| decoder_output = self.model._decode_frame(frame_codes, frame_code_lengths) | |
| if decoder_output.audio is None or decoder_output.audio_lengths is None: | |
| raise RuntimeError("Internal error: `_decode_frame` returned empty audio.") | |
| audio = decoder_output.audio | |
| audio_lengths = decoder_output.audio_lengths | |
| audio_device = audio.device | |
| audio_dtype = audio.dtype | |
| audio_num_channels = audio.shape[1] | |
| for logical_row_index, slot_index in enumerate(logical_row_to_slot_index): | |
| audio_length = int(audio_lengths[slot_index].item()) | |
| if audio_length <= 0: | |
| continue | |
| logical_audio_chunks[logical_row_index].append(audio[slot_index : slot_index + 1, :, :audio_length]) | |
| except Exception: | |
| self.close() | |
| raise | |
| finally: | |
| for streaming_state in decoder_streaming_states: | |
| streaming_state.set_exec_mask(torch.ones_like(streaming_state.exec_mask)) | |
| if audio_device is None or audio_dtype is None or audio_num_channels is None: | |
| raise RuntimeError("Internal error: `step()` produced no decoder outputs.") | |
| logical_audio_rows: list[torch.Tensor] = [] | |
| logical_audio_lengths: list[int] = [] | |
| for row_chunks in logical_audio_chunks: | |
| if row_chunks: | |
| row_audio = torch.cat(row_chunks, dim=-1) | |
| else: | |
| row_audio = torch.zeros((1, audio_num_channels, 0), device=audio_device, dtype=audio_dtype) | |
| logical_audio_rows.append(row_audio) | |
| logical_audio_lengths.append(row_audio.shape[-1]) | |
| audio_lengths = torch.tensor(logical_audio_lengths, device=audio_device, dtype=torch.long) | |
| max_audio_length = max(logical_audio_lengths) | |
| audio = torch.zeros( | |
| (len(request_ids), audio_num_channels, max_audio_length), | |
| device=audio_device, | |
| dtype=audio_dtype, | |
| ) | |
| for logical_row_index, row_audio in enumerate(logical_audio_rows): | |
| row_audio_length = row_audio.shape[-1] | |
| if row_audio_length > 0: | |
| audio[logical_row_index, :, :row_audio_length] = row_audio[0] | |
| logical_request_ids = list(request_ids) | |
| self._advance_request_progress( | |
| request_ids=logical_request_ids, | |
| code_lengths=logical_code_lengths, | |
| audio_lengths=audio_lengths, | |
| ) | |
| return logical_request_ids, audio, audio_lengths | |
| def remove(self, request_id: str | int) -> None: | |
| self._ensure_open() | |
| slot_index = self.request_id_to_slot_index.get(request_id) | |
| if slot_index is None or request_id not in self.active_request_ids: | |
| raise RuntimeError(_UNKNOWN_DECODE_REQUEST_ERROR_TEMPLATE.format(request_id=request_id)) | |
| if self.slot_is_free[slot_index] or self.slot_index_to_request_id[slot_index] != request_id: | |
| raise RuntimeError(_UNKNOWN_DECODE_REQUEST_ERROR_TEMPLATE.format(request_id=request_id)) | |
| self.active_request_ids.remove(request_id) | |
| self._reset_slot(slot_index) | |
| _ = self.request_id_to_slot_index.pop(request_id) | |
| self.slot_index_to_request_id[slot_index] = None | |
| self.slot_is_free[slot_index] = True | |
| _ = self.request_id_to_code_offset.pop(request_id, None) | |
| _ = self.request_id_to_audio_offset.pop(request_id, None) | |
| def close(self) -> None: | |
| if self._closed: | |
| return | |
| self._closed = True | |
| decode_streaming_exit_stack = self._decode_streaming_exit_stack | |
| self._decode_streaming_exit_stack = None | |
| try: | |
| if decode_streaming_exit_stack is not None: | |
| decode_streaming_exit_stack.close() | |
| finally: | |
| for module in self._flash_kvcache_attention_modules: | |
| module._use_flash_kvcache = False | |
| self._flash_kvcache_attention_modules = [] | |
| self._cuda_graph = None | |
| self._cuda_graph_key = None | |
| self._graph_input_codes = None | |
| self._graph_input_code_lengths = None | |
| self._graph_output_audio = None | |
| self._graph_output_audio_lengths = None | |
| if self.model._active_decode_session is self: | |
| self.model._active_decode_session = None | |
| # ============================================================================= | |
| # Normalization Layers | |
| # ============================================================================= | |
| class MossAudioTokenizerRMSNorm(nn.Module): | |
| """Root Mean Square Layer Normalization.""" | |
| def __init__( | |
| self, | |
| dim: int, | |
| eps: float = 1e-5, | |
| dtype: torch.dtype | None = None, | |
| device=None, | |
| ): | |
| super().__init__() | |
| self.eps = eps | |
| self.dtype = dtype | |
| self.alpha = nn.Parameter(torch.full((1, 1, dim), 1.0, requires_grad=True, device=device, dtype=dtype)) | |
| def forward(self, x: torch.Tensor): | |
| x_dtype = x.dtype | |
| if self.dtype is not None: | |
| x = x.to(self.dtype) | |
| var = self.eps + torch.mean(x**2, dim=-1, keepdim=True) | |
| alpha = self.alpha.to(var) | |
| if x.dim() == 2: | |
| alpha = alpha.view(1, -1) | |
| y = (x * (alpha * torch.rsqrt(var))).to(x_dtype) | |
| return y | |
| class MossAudioTokenizerLayerScale(nn.Module): | |
| """Layer scale from Touvron et al. 2021.""" | |
| def __init__( | |
| self, | |
| channels: int, | |
| init: float = 1e-4, | |
| channel_last: bool = True, | |
| device=None, | |
| dtype=None, | |
| ): | |
| super().__init__() | |
| self.channel_last = channel_last | |
| self.scale = nn.Parameter(torch.full((channels,), init, requires_grad=True, device=device, dtype=dtype)) | |
| def forward(self, x: torch.Tensor): | |
| if self.channel_last: | |
| return self.scale * x | |
| else: | |
| return self.scale[:, None] * x | |
| def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module: | |
| """Create normalization module.""" | |
| if norm_type == "layer_norm": | |
| return nn.LayerNorm(dim, eps=1e-5, **kwargs) | |
| elif norm_type in {"rms_norm"}: | |
| return MossAudioTokenizerRMSNorm(dim, eps=1e-5, **kwargs) | |
| elif norm_type in {"rms_norm_f32"}: | |
| kwargs.pop("dtype", None) | |
| return MossAudioTokenizerRMSNorm(dim, eps=1e-8, dtype=torch.float, **kwargs) | |
| else: | |
| raise ValueError(f"Unknown norm type: {norm_type}") | |
| # ============================================================================= | |
| # Rotary Position Embedding | |
| # ============================================================================= | |
| def apply_rope( | |
| q: torch.Tensor, | |
| k: torch.Tensor, | |
| offset: torch.Tensor, | |
| max_period: float = 10_000, | |
| time_before_heads: bool = False, | |
| ): | |
| """Apply rotary position embedding.""" | |
| if time_before_heads: | |
| B, T, H, D = q.shape | |
| else: | |
| B, H, T, D = q.shape | |
| if k.shape != q.shape: | |
| raise ValueError(f"Expected k.shape == q.shape, got k={tuple(k.shape)} q={tuple(q.shape)}") | |
| if D <= 0 or (D % 2) != 0: | |
| raise ValueError(f"RoPE requires an even last dimension, got D={D}") | |
| ds = torch.arange(D // 2, device=q.device, dtype=torch.float32) | |
| freqs = torch.exp(ds * (-math.log(max_period) * 2 / D)) | |
| ts = offset.float().view(-1, 1) + torch.arange(T, device=q.device, dtype=torch.float32) | |
| if time_before_heads: | |
| ts = ts.view(B, -1, 1, 1) | |
| else: | |
| ts = ts.view(B, 1, -1, 1) | |
| dims = q.shape[:-1] | |
| q = q.view(*dims, D // 2, 2) | |
| k = k.view(*dims, D // 2, 2) | |
| qr, qi = q[..., 0].float(), q[..., 1].float() | |
| kr, ki = k[..., 0].float(), k[..., 1].float() | |
| rotr = torch.cos(freqs * ts) | |
| roti = torch.sin(freqs * ts) | |
| qor = qr * rotr - qi * roti | |
| qoi = qr * roti + qi * rotr | |
| kor = kr * rotr - ki * roti | |
| koi = kr * roti + ki * rotr | |
| dtype = q.dtype | |
| qo = torch.stack([qor.to(dtype), qoi.to(dtype)], dim=-1) | |
| ko = torch.stack([kor.to(dtype), koi.to(dtype)], dim=-1) | |
| return qo.view(*dims, D), ko.view(*dims, D) | |
| def apply_rope_with_positions( | |
| q: torch.Tensor, | |
| k: torch.Tensor, | |
| positions: torch.Tensor, | |
| max_period: float = 10_000, | |
| ): | |
| """Apply rotary position embedding to packed `[N, H, D]` tensors.""" | |
| N, H, D = q.shape | |
| if k.shape != q.shape: | |
| raise ValueError(f"Expected k.shape == q.shape, got k={tuple(k.shape)} q={tuple(q.shape)}") | |
| if D <= 0 or (D % 2) != 0: | |
| raise ValueError(f"RoPE requires an even last dimension, got D={D}") | |
| ds = torch.arange(D // 2, device=q.device, dtype=torch.float32) | |
| freqs = torch.exp(ds * (-math.log(max_period) * 2 / D)) | |
| ts = positions.to(torch.float32).view(N, 1, 1) | |
| qr = q.float().view(N, H, D // 2, 2)[..., 0] | |
| qi = q.float().view(N, H, D // 2, 2)[..., 1] | |
| kr = k.float().view(N, H, D // 2, 2)[..., 0] | |
| ki = k.float().view(N, H, D // 2, 2)[..., 1] | |
| rotr = torch.cos(ts * freqs.view(1, 1, -1)) | |
| roti = torch.sin(ts * freqs.view(1, 1, -1)) | |
| qor = qr * rotr - qi * roti | |
| qoi = qr * roti + qi * rotr | |
| kor = kr * rotr - ki * roti | |
| koi = kr * roti + ki * rotr | |
| dtype = q.dtype | |
| qo = torch.stack([qor.to(dtype), qoi.to(dtype)], dim=-1) | |
| ko = torch.stack([kor.to(dtype), koi.to(dtype)], dim=-1) | |
| return qo.view(N, H, D), ko.view(N, H, D) | |
| class MossAudioTokenizerRotaryEmbedding(nn.Module): | |
| """Rotary positional embedding (RoPE).""" | |
| def __init__(self, max_period: float = 10000.0): | |
| super().__init__() | |
| self.max_period = max_period | |
| def forward( | |
| self, | |
| q: torch.Tensor, | |
| k: torch.Tensor, | |
| offset: torch.Tensor, | |
| time_before_heads: bool = False, | |
| ): | |
| return apply_rope(q, k, offset, self.max_period, time_before_heads) | |
| # ============================================================================= | |
| # Gating Modules | |
| # ============================================================================= | |
| class MossAudioTokenizerActivationGating(nn.Module): | |
| """Gating FFN layer with activation.""" | |
| def __init__(self, dim: int, dim_feedforward: int, activation, **factory_kwargs): | |
| super().__init__() | |
| if dim_feedforward == 4 * dim: | |
| hidden = (21 * dim) // 8 | |
| else: | |
| hidden = (2 * dim_feedforward) // 3 | |
| self.linear_in = nn.Linear(dim, 2 * hidden, bias=False, **factory_kwargs) | |
| self.linear_out = nn.Linear(hidden, dim, bias=False, **factory_kwargs) | |
| self.activation = activation | |
| def forward(self, x: torch.Tensor): | |
| x = self.linear_in(x) | |
| B, T, _ = x.shape | |
| x = x.view(B, T, 2, -1) | |
| x = self.activation(x[..., 0, :]) * x[..., 1, :] | |
| x = self.linear_out(x) | |
| return x | |
| def _get_activation(name: str): | |
| if name in ["sigmoid", "tanh", "relu"]: | |
| return getattr(torch, name) | |
| elif name in ["leaky_relu", "elu", "gelu", "silu", "mish", "softsign"]: | |
| return getattr(F, name) | |
| elif name == "identity": | |
| return nn.Identity() | |
| else: | |
| raise ValueError(f"Unknown activation {name}") | |
| def make_gating(name: str, dim: int, dim_feedforward: int, **factory_kwargs) -> nn.Module: | |
| return MossAudioTokenizerActivationGating(dim, dim_feedforward, _get_activation(name), **factory_kwargs) | |
| # ============================================================================= | |
| # Positional Embeddings | |
| # ============================================================================= | |
| def create_sin_embedding( | |
| positions: torch.Tensor, | |
| dim: int, | |
| max_period: float = 10000, | |
| dtype: torch.dtype = torch.float32, | |
| ) -> torch.Tensor: | |
| """Create sinusoidal positional embedding with shape [..., C].""" | |
| if dim % 2 != 0: | |
| raise ValueError(f"Sinusoidal embedding requires even dim, got dim={dim}") | |
| half_dim = dim // 2 | |
| if half_dim <= 1: | |
| raise ValueError(f"Sinusoidal embedding requires dim >= 4, got dim={dim}") | |
| if positions.dim() == 0: | |
| positions = positions.view(1) | |
| positions = positions.to(dtype).unsqueeze(-1) | |
| adim = torch.arange(half_dim, device=positions.device, dtype=dtype) | |
| max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype) | |
| phase = positions / (max_period_tensor ** (adim / (half_dim - 1))) | |
| return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1) | |
| def pack_padded_sequence( | |
| x: torch.Tensor, | |
| input_lengths: torch.Tensor, | |
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: | |
| """Pack a padded `[B, T, D]` tensor into `[N, D]` plus metadata.""" | |
| batch_size, max_seqlen, _ = x.shape | |
| positions = torch.arange(max_seqlen, device=x.device, dtype=torch.long) | |
| valid_mask = positions.view(1, max_seqlen) < input_lengths.view(batch_size, 1) | |
| packed_x = x[valid_mask] | |
| cu_seqlens = torch.zeros(batch_size + 1, device=x.device, dtype=torch.int32) | |
| cu_seqlens[1:] = torch.cumsum(input_lengths.to(torch.int32), dim=0) | |
| position_ids = positions.view(1, max_seqlen).expand(batch_size, -1)[valid_mask] | |
| return packed_x, valid_mask, cu_seqlens, position_ids | |
| def unpack_packed_sequence( | |
| packed_x: torch.Tensor, | |
| valid_mask: torch.Tensor, | |
| batch_size: int, | |
| max_seqlen: int, | |
| ) -> torch.Tensor: | |
| """Unpack a packed `[N, D]` tensor back into `[B, T, D]`.""" | |
| output = packed_x.new_zeros((batch_size, max_seqlen, packed_x.shape[-1])) | |
| output[valid_mask] = packed_x | |
| return output | |
| # ============================================================================= | |
| # KV Cache for Attention | |
| # ============================================================================= | |
| class KVCacheResult: | |
| """Container for KV cache results that supports tuple unpacking.""" | |
| __slots__ = ("keys", "values", "positions") | |
| def __init__(self, keys: torch.Tensor, values: torch.Tensor, positions: torch.Tensor): | |
| self.keys = keys | |
| self.values = values | |
| self.positions = positions | |
| def __iter__(self): | |
| """Allow unpacking as (keys, values, positions).""" | |
| return iter((self.keys, self.values, self.positions)) | |
| def from_kv(keys: torch.Tensor, values: torch.Tensor) -> KVCacheResult: | |
| B, H, T, D = keys.shape | |
| positions = torch.arange(T, device=keys.device, dtype=torch.long) | |
| return KVCacheResult(keys, values, positions.expand(B, -1)) | |
| class RingKVCache: | |
| """Efficient streaming KVCache compatible with CUDA Graph.""" | |
| def __init__( | |
| self, | |
| batch_size: int, | |
| num_heads: int, | |
| dim_per_head: int, | |
| capacity: int, | |
| respect_exec_mask: bool = True, | |
| device: torch.device = torch.device("cuda"), | |
| dtype: torch.dtype = torch.bfloat16, | |
| ): | |
| self.capacity = capacity | |
| self.cache = torch.zeros( | |
| (2, batch_size, num_heads, capacity, dim_per_head), | |
| device=device, | |
| dtype=dtype, | |
| ) | |
| self.respect_exec_mask = respect_exec_mask | |
| if self.respect_exec_mask: | |
| self.end_offset = torch.zeros(batch_size, device=device, dtype=torch.long) | |
| else: | |
| self.end_offset = torch.zeros(1, device=device, dtype=torch.long) | |
| def reset(self, reset_mask: torch.Tensor) -> None: | |
| self.end_offset[:] = torch.where(reset_mask, torch.zeros_like(self.end_offset), self.end_offset) | |
| def complete(self, k: torch.Tensor, v: torch.Tensor, exec_mask: torch.Tensor) -> KVCacheResult: | |
| B, H, T, D = k.shape | |
| if T <= 0: | |
| raise ValueError(f"Expected T > 0, got T={T}") | |
| indexes = torch.arange(T, device=self.end_offset.device, dtype=self.end_offset.dtype) | |
| indexes = indexes + self.end_offset.view(-1, 1) | |
| indexes = indexes % self.capacity | |
| if self.respect_exec_mask: | |
| this_indexes = indexes.view(B, 1, T, 1).expand(-1, H, T, D) | |
| self.cache[0].scatter_(2, this_indexes, k) | |
| self.cache[1].scatter_(2, this_indexes, v) | |
| else: | |
| self.cache[0].index_copy_(2, indexes[0], k) | |
| self.cache[1].index_copy_(2, indexes[0], v) | |
| keys = self.cache[0] | |
| values = self.cache[1] | |
| indexes = torch.arange(self.capacity, device=self.end_offset.device, dtype=torch.long) | |
| last_offset = self.end_offset.view(-1, 1) + T - 1 | |
| end_index = last_offset % self.capacity | |
| delta = indexes - end_index | |
| positions = torch.where( | |
| delta <= 0, | |
| last_offset + delta, | |
| last_offset + delta - self.capacity, | |
| ) | |
| if self.respect_exec_mask: | |
| self.end_offset[:] = torch.where(exec_mask, self.end_offset + T, self.end_offset) | |
| else: | |
| self.end_offset.add_(T) | |
| invalid = indexes >= self.end_offset.view(-1, 1) | |
| positions = torch.where(invalid, torch.full_like(positions, -1), positions) | |
| return KVCacheResult(keys, values, positions) | |
| # ============================================================================= | |
| # Multi-Head Attention | |
| # ============================================================================= | |
| _sync_module_proxy() | |
| class MHAState(StreamingState): | |
| cached_keys: torch.Tensor | None | |
| cached_values: torch.Tensor | None | |
| cached_positions: torch.Tensor | None | |
| offset: torch.Tensor | |
| def reset(self, reset_mask: torch.Tensor): | |
| super().reset(reset_mask) | |
| self.offset[:] = torch.where(reset_mask, torch.zeros_like(self.offset), self.offset) | |
| if self.cached_positions is not None: | |
| self.cached_positions[reset_mask] = -1 | |
| if self.cached_keys is not None: | |
| self.cached_keys[reset_mask] = 0 | |
| if self.cached_values is not None: | |
| self.cached_values[reset_mask] = 0 | |
| def apply_weights_per_step( | |
| modules: nn.ModuleList, | |
| schedule: list[int] | None, | |
| x: torch.Tensor, | |
| offset: int | None, | |
| ) -> torch.Tensor: | |
| """Apply different weights for each time step.""" | |
| if len(modules) == 1: | |
| return modules[0](x) | |
| if offset is None: | |
| raise ValueError("offset must be provided when using per-step weights (len(modules) > 1).") | |
| if x.dim() != 3: | |
| raise ValueError( | |
| f"Per-step weights require a dense `[B, T, C]` tensor when len(modules) > 1, got shape {tuple(x.shape)}." | |
| ) | |
| ys = [] | |
| B, T, C = x.shape | |
| for t in range(T): | |
| module_index = t + offset | |
| if schedule is not None: | |
| if module_index >= len(schedule) or module_index < 0: | |
| raise ValueError( | |
| f"weights_per_step_schedule is too short for module_index={module_index} (len={len(schedule)})." | |
| ) | |
| module_index = schedule[module_index] | |
| if module_index >= len(modules) or module_index < 0: | |
| raise ValueError(f"module_index={module_index} out of range for len(modules)={len(modules)}.") | |
| y = modules[module_index](x[:, t : t + 1]) | |
| ys.append(y) | |
| return torch.cat(ys, 1) | |
| class MossAudioTokenizerMultiheadAttention(StreamingModule): | |
| """Multi-head attention with streaming support.""" | |
| def __init__( | |
| self, | |
| embed_dim: int, | |
| num_heads: int, | |
| causal: bool = False, | |
| context: int | None = None, | |
| rope: MossAudioTokenizerRotaryEmbedding | None = None, | |
| attention_implementation: str = "sdpa", | |
| device=None, | |
| dtype=None, | |
| ): | |
| super().__init__() | |
| factory_kwargs = {"device": device, "dtype": dtype} | |
| self.embed_dim = embed_dim | |
| self.causal = causal | |
| self.context = context | |
| self.rope = rope | |
| self.num_heads = num_heads | |
| if attention_implementation not in SUPPORTED_ATTENTION_IMPLEMENTATIONS: | |
| raise ValueError( | |
| f"Unsupported attention_implementation={attention_implementation!r}. " | |
| f"Expected one of {sorted(SUPPORTED_ATTENTION_IMPLEMENTATIONS)}." | |
| ) | |
| self.attention_implementation = attention_implementation | |
| self._use_flash_kvcache = False | |
| self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=False, **factory_kwargs) | |
| self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False, **factory_kwargs) | |
| self._register_load_state_dict_pre_hook(self._load_hook, with_module=True) | |
| def set_attention_implementation(self, attention_implementation: str) -> None: | |
| if attention_implementation not in SUPPORTED_ATTENTION_IMPLEMENTATIONS: | |
| raise ValueError( | |
| f"Unsupported attention_implementation={attention_implementation!r}. " | |
| f"Expected one of {sorted(SUPPORTED_ATTENTION_IMPLEMENTATIONS)}." | |
| ) | |
| self.attention_implementation = attention_implementation | |
| def _load_hook(module, state_dict, prefix, *_): | |
| mappings = { | |
| "in_proj_weight": "in_proj.weight", | |
| "in_projs.0.weight": "in_proj.weight", | |
| "out_projs.0.weight": "out_proj.weight", | |
| } | |
| for suffix in ["", "_scb"]: | |
| for source, target in mappings.items(): | |
| this_source = prefix + source + suffix | |
| if this_source in state_dict: | |
| state_dict[prefix + target + suffix] = state_dict.pop(this_source) | |
| def _init_streaming_state(self, batch_size: int) -> MHAState: | |
| device = cast(torch.device, self.in_proj.weight.device) | |
| return MHAState( | |
| batch_size, | |
| device, | |
| cached_keys=None, | |
| cached_values=None, | |
| cached_positions=None, | |
| offset=torch.zeros(batch_size, device=cast(torch.device, device), dtype=torch.long), | |
| ) | |
| def _supports_flash_attention(self, device: torch.device, dtype: torch.dtype) -> bool: | |
| return _has_flash_attn() and device.type == "cuda" and dtype in {torch.float16, torch.bfloat16} | |
| def _get_backend_check_dtype(self, x: torch.Tensor) -> torch.dtype: | |
| if x.device.type != "cuda": | |
| return x.dtype | |
| try: | |
| autocast_enabled = torch.is_autocast_enabled("cuda") | |
| except TypeError: | |
| autocast_enabled = torch.is_autocast_enabled() | |
| if not autocast_enabled: | |
| return x.dtype | |
| try: | |
| return torch.get_autocast_dtype("cuda") | |
| except TypeError: | |
| return torch.get_autocast_gpu_dtype() | |
| def resolve_attention_implementation(self, x: torch.Tensor, is_streaming: bool) -> str: | |
| if self.attention_implementation == "sdpa": | |
| return "sdpa" | |
| backend_dtype = self._get_backend_check_dtype(x) | |
| if self._supports_flash_attention(x.device, backend_dtype): | |
| return "flash_attention_2" | |
| if self.attention_implementation == "flash_attention_2": | |
| logger.warning_once( | |
| "Falling back to SDPA because flash_attention_2 is unavailable for device=%s dtype=%s " | |
| "(HAS_FLASH_ATTN=%s).", | |
| x.device, | |
| backend_dtype, | |
| _has_flash_attn(), | |
| ) | |
| return "sdpa" | |
| def _project_qkv( | |
| self, | |
| x: torch.Tensor, | |
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | |
| dim_per_head = self.embed_dim // self.num_heads | |
| if x.dim() == 3: | |
| projected = self.in_proj(x) | |
| projected = projected.reshape(x.shape[0], x.shape[1], 3, self.num_heads, dim_per_head).permute( | |
| 2, 0, 3, 1, 4 | |
| ) | |
| return projected[0], projected[1], projected[2] | |
| if x.dim() == 2: | |
| projected = self.in_proj(x) | |
| projected = projected.view(x.shape[0], 3, self.num_heads, dim_per_head) | |
| return projected[:, 0], projected[:, 1], projected[:, 2] | |
| raise ValueError(f"Expected a 2D or 3D tensor, got shape {tuple(x.shape)}") | |
| def _apply_dense_rope(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: | |
| if self.rope is None: | |
| return q, k | |
| offset = torch.zeros(q.shape[0], device=q.device, dtype=torch.long) | |
| return self.rope(q, k, offset, time_before_heads=False) | |
| def _apply_packed_rope( | |
| self, | |
| q: torch.Tensor, | |
| k: torch.Tensor, | |
| position_ids: torch.Tensor, | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| if self.rope is None: | |
| return q, k | |
| return apply_rope_with_positions(q, k, position_ids, max_period=self.rope.max_period) | |
| def _ensure_streaming_cache( | |
| self, | |
| state: MHAState, | |
| batch_size: int, | |
| device: torch.device, | |
| dtype: torch.dtype, | |
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | |
| head_dim = self.embed_dim // self.num_heads | |
| cache_length = 0 if self.context is None else self.context | |
| if state.cached_keys is None or state.cached_values is None or state.cached_positions is None: | |
| state.cached_keys = torch.zeros( | |
| (batch_size, self.num_heads, cache_length, head_dim), | |
| device=device, | |
| dtype=dtype, | |
| ) | |
| state.cached_values = torch.zeros_like(state.cached_keys) | |
| state.cached_positions = torch.full( | |
| (batch_size, cache_length), | |
| -1, | |
| device=device, | |
| dtype=torch.long, | |
| ) | |
| else: | |
| if state.cached_keys.device != device or state.cached_keys.dtype != dtype: | |
| state.cached_keys = state.cached_keys.to(device=device, dtype=dtype) | |
| if state.cached_values.device != device or state.cached_values.dtype != dtype: | |
| state.cached_values = state.cached_values.to(device=device, dtype=dtype) | |
| if state.cached_positions.device != device: | |
| state.cached_positions = state.cached_positions.to(device=device) | |
| return state.cached_keys, state.cached_values, state.cached_positions | |
| def _ensure_flash_kvcache( | |
| self, | |
| state: MHAState, | |
| batch_size: int, | |
| device: torch.device, | |
| dtype: torch.dtype, | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| if self.context is None: | |
| raise RuntimeError("flash_attn_with_kvcache requires a finite streaming context.") | |
| head_dim = self.embed_dim // self.num_heads | |
| flash_cached_keys = cast(torch.Tensor | None, getattr(state, "_flash_cached_keys", None)) | |
| flash_cached_values = cast(torch.Tensor | None, getattr(state, "_flash_cached_values", None)) | |
| if flash_cached_keys is None or flash_cached_values is None: | |
| flash_cached_keys = torch.zeros( | |
| (batch_size, self.context, self.num_heads, head_dim), | |
| device=device, | |
| dtype=dtype, | |
| ) | |
| flash_cached_values = torch.zeros_like(flash_cached_keys) | |
| else: | |
| if flash_cached_keys.device != device or flash_cached_keys.dtype != dtype: | |
| flash_cached_keys = flash_cached_keys.to(device=device, dtype=dtype) | |
| if flash_cached_values.device != device or flash_cached_values.dtype != dtype: | |
| flash_cached_values = flash_cached_values.to(device=device, dtype=dtype) | |
| setattr(state, "_flash_cached_keys", flash_cached_keys) | |
| setattr(state, "_flash_cached_values", flash_cached_values) | |
| return flash_cached_keys, flash_cached_values | |
| def _build_streaming_kv( | |
| self, | |
| cached_k: torch.Tensor, | |
| cached_v: torch.Tensor, | |
| cached_pos: torch.Tensor, | |
| k_cur: torch.Tensor, | |
| v_cur: torch.Tensor, | |
| pos_q: torch.Tensor, | |
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | |
| k_all = torch.cat([cached_k, k_cur], dim=2) | |
| v_all = torch.cat([cached_v, v_cur], dim=2) | |
| pos_k = torch.cat([cached_pos, pos_q], dim=1) | |
| return k_all, v_all, pos_k | |
| def _update_streaming_cache( | |
| self, | |
| state: MHAState, | |
| cached_k: torch.Tensor, | |
| cached_v: torch.Tensor, | |
| cached_pos: torch.Tensor, | |
| k_all: torch.Tensor, | |
| v_all: torch.Tensor, | |
| pos_k: torch.Tensor, | |
| ) -> None: | |
| exec_mask = state.exec_mask.view(-1, 1, 1, 1) | |
| exec_mask_pos = state.exec_mask.view(-1, 1) | |
| if self.context is None: | |
| if not bool(state.exec_mask.all().item()): | |
| raise RuntimeError("Streaming exec_mask with context=None is not supported.") | |
| state.cached_keys = k_all.contiguous() | |
| state.cached_values = v_all.contiguous() | |
| state.cached_positions = pos_k.contiguous() | |
| return | |
| assert state.cached_keys is not None | |
| assert state.cached_values is not None | |
| assert state.cached_positions is not None | |
| new_cached_k = k_all[:, :, -self.context :, :].contiguous() | |
| new_cached_v = v_all[:, :, -self.context :, :].contiguous() | |
| new_cached_pos = pos_k[:, -self.context :].contiguous() | |
| state.cached_keys.copy_(torch.where(exec_mask, new_cached_k, cached_k)) | |
| state.cached_values.copy_(torch.where(exec_mask, new_cached_v, cached_v)) | |
| state.cached_positions.copy_(torch.where(exec_mask_pos, new_cached_pos, cached_pos)) | |
| def _build_streaming_sdpa_bias(self, pos_q: torch.Tensor, pos_k: torch.Tensor) -> torch.Tensor: | |
| delta = pos_q[:, :, None] - pos_k[:, None, :] | |
| attn_bias = (pos_k[:, None, :] >= 0) & (delta >= 0) | |
| if self.context is not None: | |
| attn_bias = attn_bias & (delta < self.context) | |
| return attn_bias[:, None, :, :] | |
| def _build_non_streaming_sdpa_bias( | |
| self, | |
| input_lengths: torch.Tensor, | |
| max_seqlen: int, | |
| device: torch.device, | |
| ) -> torch.Tensor: | |
| positions = torch.arange(max_seqlen, device=device, dtype=torch.long) | |
| valid_k = positions.view(1, 1, max_seqlen) < input_lengths.view(-1, 1, 1) | |
| if not self.causal and self.context is None: | |
| return valid_k[:, None, :, :].expand(-1, 1, max_seqlen, -1) | |
| delta = positions.view(1, max_seqlen, 1) - positions.view(1, 1, max_seqlen) | |
| attn_bias = torch.ones((1, max_seqlen, max_seqlen), device=device, dtype=torch.bool) | |
| if self.causal: | |
| attn_bias = attn_bias & (delta >= 0) | |
| if self.context is not None: | |
| attn_bias = attn_bias & (delta < self.context) | |
| return (attn_bias & valid_k)[:, None, :, :] | |
| def _run_flash_attention( | |
| self, | |
| q: torch.Tensor, | |
| k: torch.Tensor, | |
| v: torch.Tensor, | |
| cu_seqlens_q: torch.Tensor, | |
| cu_seqlens_k: torch.Tensor, | |
| max_seqlen_q: int, | |
| max_seqlen_k: int, | |
| ) -> torch.Tensor: | |
| flash_attn_varlen_func = _get_flash_attn_varlen_func() | |
| if flash_attn_varlen_func is None: | |
| raise RuntimeError("flash-attn is not installed.") | |
| window_size = (self.context, 0) if (self.context is not None and self.causal) else (-1, -1) | |
| return cast( | |
| torch.Tensor, | |
| flash_attn_varlen_func( | |
| q.contiguous(), | |
| k.contiguous(), | |
| v.contiguous(), | |
| cu_seqlens_q, | |
| cu_seqlens_k, | |
| max_seqlen_q, | |
| max_seqlen_k, | |
| causal=self.causal, | |
| window_size=window_size, | |
| ), | |
| ) | |
| def _forward_streaming_sdpa(self, x: torch.Tensor, state: MHAState) -> torch.Tensor: | |
| batch_size, chunk_length, _ = x.shape | |
| q, k_cur, v_cur = self._project_qkv(x) | |
| if self.rope is not None: | |
| q, k_cur = self.rope(q, k_cur, state.offset, time_before_heads=False) | |
| pos_q = state.offset.view(-1, 1) + torch.arange(chunk_length, device=x.device, dtype=torch.long).view(1, -1) | |
| cached_k, cached_v, cached_pos = self._ensure_streaming_cache(state, batch_size, k_cur.device, k_cur.dtype) | |
| k_all, v_all, pos_k = self._build_streaming_kv(cached_k, cached_v, cached_pos, k_cur, v_cur, pos_q) | |
| attn_bias = self._build_streaming_sdpa_bias(pos_q, pos_k) | |
| out = F.scaled_dot_product_attention(q, k_all, v_all, attn_bias, dropout_p=0.0) | |
| out = out.transpose(1, 2).reshape(batch_size, chunk_length, self.embed_dim) | |
| self._update_streaming_cache(state, cached_k, cached_v, cached_pos, k_all, v_all, pos_k) | |
| state.offset[:] = torch.where(state.exec_mask, state.offset + chunk_length, state.offset) | |
| return out | |
| def _forward_streaming_flash(self, x: torch.Tensor, state: MHAState) -> torch.Tensor: | |
| batch_size, chunk_length, _ = x.shape | |
| q, k_cur, v_cur = self._project_qkv(x) | |
| if self.rope is not None: | |
| q, k_cur = self.rope(q, k_cur, state.offset, time_before_heads=False) | |
| pos_q = state.offset.view(-1, 1) + torch.arange(chunk_length, device=x.device, dtype=torch.long).view(1, -1) | |
| cached_k, cached_v, cached_pos = self._ensure_streaming_cache(state, batch_size, k_cur.device, k_cur.dtype) | |
| k_all, v_all, pos_k = self._build_streaming_kv(cached_k, cached_v, cached_pos, k_cur, v_cur, pos_q) | |
| q_chunks = [] | |
| k_chunks = [] | |
| v_chunks = [] | |
| cu_q = [0] | |
| cu_k = [0] | |
| max_kv_len = 0 | |
| for batch_idx in range(batch_size): | |
| valid_k = pos_k[batch_idx] >= 0 | |
| q_i = q[batch_idx].transpose(0, 1).contiguous() | |
| k_i = k_all[batch_idx, :, valid_k, :].transpose(0, 1).contiguous() | |
| v_i = v_all[batch_idx, :, valid_k, :].transpose(0, 1).contiguous() | |
| q_chunks.append(q_i) | |
| k_chunks.append(k_i) | |
| v_chunks.append(v_i) | |
| cu_q.append(cu_q[-1] + q_i.shape[0]) | |
| cu_k.append(cu_k[-1] + k_i.shape[0]) | |
| max_kv_len = max(max_kv_len, int(k_i.shape[0])) | |
| out_flat = self._run_flash_attention( | |
| torch.cat(q_chunks, dim=0), | |
| torch.cat(k_chunks, dim=0), | |
| torch.cat(v_chunks, dim=0), | |
| torch.tensor(cu_q, device=x.device, dtype=torch.int32), | |
| torch.tensor(cu_k, device=x.device, dtype=torch.int32), | |
| max_seqlen_q=chunk_length, | |
| max_seqlen_k=max_kv_len, | |
| ) | |
| outputs = [] | |
| start = 0 | |
| for _ in range(batch_size): | |
| outputs.append(out_flat[start : start + chunk_length].transpose(0, 1).contiguous()) | |
| start += chunk_length | |
| out = torch.stack(outputs, dim=0) | |
| out = out.transpose(1, 2).reshape(batch_size, chunk_length, self.embed_dim) | |
| self._update_streaming_cache(state, cached_k, cached_v, cached_pos, k_all, v_all, pos_k) | |
| state.offset[:] = torch.where(state.exec_mask, state.offset + chunk_length, state.offset) | |
| return out | |
| def _forward_streaming_flash_kvcache(self, x: torch.Tensor, state: MHAState) -> torch.Tensor: | |
| flash_attn_with_kvcache = _get_flash_attn_with_kvcache() | |
| if self.context is None: | |
| raise RuntimeError("flash_attn_with_kvcache requires a finite streaming context.") | |
| if flash_attn_with_kvcache is None: | |
| raise RuntimeError("flash-attn is not installed.") | |
| batch_size, chunk_length, _ = x.shape | |
| q, k_cur, v_cur = self._project_qkv(x) | |
| if self.rope is not None: | |
| q, k_cur = self.rope(q, k_cur, state.offset, time_before_heads=False) | |
| q = q.transpose(1, 2).contiguous() | |
| k_cur = k_cur.transpose(1, 2).contiguous() | |
| v_cur = v_cur.transpose(1, 2).contiguous() | |
| exec_mask = state.exec_mask.view(batch_size, 1, 1, 1).to(dtype=k_cur.dtype) | |
| k_cur = k_cur * exec_mask | |
| v_cur = v_cur * exec_mask | |
| k_cache, v_cache = self._ensure_flash_kvcache(state, batch_size, k_cur.device, k_cur.dtype) | |
| cache_seqlens = state.offset.clamp(max=self.context).to(torch.int32) | |
| window_size = (self.context - 1, 0) | |
| out = cast( | |
| torch.Tensor, | |
| flash_attn_with_kvcache( | |
| q, | |
| k_cache, | |
| v_cache, | |
| k=k_cur, | |
| v=v_cur, | |
| cache_seqlens=cache_seqlens, | |
| causal=True, | |
| window_size=window_size, | |
| ), | |
| ) | |
| out = out.reshape(batch_size, chunk_length, self.embed_dim) | |
| state.offset[:] = torch.where(state.exec_mask, state.offset + chunk_length, state.offset) | |
| return out | |
| def _forward_non_streaming_sdpa(self, x: torch.Tensor, input_lengths: torch.Tensor) -> torch.Tensor: | |
| batch_size, max_seqlen, _ = x.shape | |
| q, k, v = self._project_qkv(x) | |
| q, k = self._apply_dense_rope(q, k) | |
| attn_bias = self._build_non_streaming_sdpa_bias(input_lengths, max_seqlen, x.device) | |
| out = F.scaled_dot_product_attention(q, k, v, attn_bias, dropout_p=0.0) | |
| valid_q = (torch.arange(max_seqlen, device=x.device).view(1, max_seqlen) < input_lengths.view(-1, 1)).view( | |
| batch_size, 1, max_seqlen, 1 | |
| ) | |
| # Some SDPA backends return NaNs for fully-masked padded query rows in local-causal attention. | |
| # Multiplying by zero is not sufficient because NaN * 0 is still NaN; use torch.where so padded | |
| # rows are materialized as exact zeros before they can leak into later layers as masked K/V values. | |
| out = torch.where(valid_q, out, torch.zeros((), device=out.device, dtype=out.dtype)) | |
| return out.transpose(1, 2).reshape(batch_size, max_seqlen, self.embed_dim) | |
| def _forward_non_streaming_flash( | |
| self, | |
| x: torch.Tensor, | |
| cu_seqlens: torch.Tensor, | |
| max_seqlen: int, | |
| position_ids: torch.Tensor, | |
| ) -> torch.Tensor: | |
| q, k, v = self._project_qkv(x) | |
| q, k = self._apply_packed_rope(q, k, position_ids) | |
| out = self._run_flash_attention(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen) | |
| return out.reshape(x.shape[0], self.embed_dim) | |
| def forward( | |
| self, | |
| query: torch.Tensor, | |
| cu_seqlens: torch.Tensor | None = None, | |
| max_seqlen: int | None = None, | |
| position_ids: torch.Tensor | None = None, | |
| input_lengths: torch.Tensor | None = None, | |
| ): | |
| state = cast(MHAState | None, self._streaming_state) | |
| backend = self.resolve_attention_implementation(query, is_streaming=state is not None) | |
| if state is not None: | |
| if query.dim() != 3: | |
| raise ValueError(f"Streaming attention expects a 3D tensor, got shape {tuple(query.shape)}") | |
| if backend == "flash_attention_2" and self._use_flash_kvcache: | |
| out = self._forward_streaming_flash_kvcache(query, state) | |
| elif backend == "flash_attention_2": | |
| out = self._forward_streaming_flash(query, state) | |
| else: | |
| out = self._forward_streaming_sdpa(query, state) | |
| return self.out_proj(out) | |
| if backend == "flash_attention_2": | |
| if query.dim() != 2: | |
| raise ValueError(f"Packed flash attention expects a 2D tensor, got shape {tuple(query.shape)}") | |
| if cu_seqlens is None or max_seqlen is None or position_ids is None: | |
| raise ValueError("Packed flash attention requires cu_seqlens, max_seqlen, and position_ids.") | |
| out = self._forward_non_streaming_flash(query, cu_seqlens, max_seqlen, position_ids) | |
| return self.out_proj(out) | |
| if query.dim() != 3: | |
| raise ValueError(f"Non-streaming SDPA expects a 3D tensor, got shape {tuple(query.shape)}") | |
| if input_lengths is None: | |
| raise ValueError("Non-streaming SDPA requires input_lengths.") | |
| out = self._forward_non_streaming_sdpa(query, input_lengths) | |
| return self.out_proj(out) | |
| # ============================================================================= | |
| # Transformer Layer | |
| # ============================================================================= | |
| _sync_module_proxy() | |
| class LayerState(StreamingState): | |
| pass | |
| class MossAudioTokenizerTransformerLayer(StreamingModule): | |
| """Transformer layer with streaming support.""" | |
| def __init__( | |
| self, | |
| d_model: int, | |
| num_heads: int, | |
| dim_feedforward: int = 2048, | |
| causal: bool = False, | |
| context: int | None = None, | |
| rope: MossAudioTokenizerRotaryEmbedding | None = None, | |
| attention_implementation: str = "sdpa", | |
| norm: str = "layer_norm", | |
| layer_scale: float | None = None, | |
| gating: str = "none", | |
| device=None, | |
| dtype=None, | |
| ): | |
| super().__init__() | |
| factory_kwargs = {"device": device, "dtype": dtype} | |
| self.self_attn = MossAudioTokenizerMultiheadAttention( | |
| embed_dim=d_model, | |
| num_heads=num_heads, | |
| causal=causal, | |
| context=context, | |
| rope=rope, | |
| attention_implementation=attention_implementation, | |
| **factory_kwargs, | |
| ) | |
| self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs) | |
| self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs) | |
| if gating == "none": | |
| self.ffn = nn.Sequential( | |
| nn.Linear(d_model, dim_feedforward, bias=False, **factory_kwargs), | |
| nn.GELU(), | |
| nn.Linear(dim_feedforward, d_model, bias=False, **factory_kwargs), | |
| ) | |
| else: | |
| self.ffn = make_gating(gating, d_model, dim_feedforward, **factory_kwargs) | |
| if layer_scale is None: | |
| self.layer_scale_1 = nn.Identity() | |
| self.layer_scale_2 = nn.Identity() | |
| else: | |
| self.layer_scale_1 = MossAudioTokenizerLayerScale( | |
| channels=d_model, init=layer_scale, channel_last=True, **cast(dict[str, object], factory_kwargs) | |
| ) | |
| self.layer_scale_2 = MossAudioTokenizerLayerScale( | |
| channels=d_model, init=layer_scale, channel_last=True, **cast(dict[str, object], factory_kwargs) | |
| ) | |
| self._register_load_state_dict_pre_hook(self._load_hook, with_module=True) | |
| def _load_hook(module, state_dict, prefix, *_): | |
| mappings = { | |
| "linear1.weight": "ffn.0.weight", | |
| "linear2.weight": "ffn.2.weight", | |
| "linear1.bias": "ffn.0.bias", | |
| "linear2.bias": "ffn.2.bias", | |
| } | |
| for source, target in mappings.items(): | |
| this_source = prefix + source | |
| if this_source in state_dict: | |
| state_dict[prefix + target] = state_dict.pop(this_source) | |
| def _init_streaming_state(self, batch_size: int) -> LayerState: | |
| device = next(iter(self.parameters())).device | |
| return LayerState(batch_size, device) | |
| def forward(self, x: torch.Tensor, **kwargs): | |
| residual = x | |
| x = self.norm1(x) | |
| x = residual.to(x) + self.layer_scale_1(self.self_attn(x, **kwargs)) | |
| residual = x | |
| x = self.norm2(x) | |
| x = residual.to(x) + self.layer_scale_2(self.ffn(x)) | |
| return x | |
| # ============================================================================= | |
| # Streaming Transformer | |
| # ============================================================================= | |
| _sync_module_proxy() | |
| class TransformerState(StreamingState): | |
| offsets: torch.Tensor | |
| def reset(self, reset_mask: torch.Tensor): | |
| super().reset(reset_mask) | |
| self.offsets[:] = torch.where(reset_mask, torch.zeros_like(self.offsets), self.offsets) | |
| class MossAudioTokenizerTransformer(StreamingModule): | |
| """Transformer with streaming/causal support.""" | |
| def __init__( | |
| self, | |
| d_model: int, | |
| num_heads: int, | |
| num_layers: int, | |
| dim_feedforward: int = 2048, | |
| causal: bool = False, | |
| context: int | None = None, | |
| positional_embedding: str = "sin", | |
| max_period: float = 10_000, | |
| positional_scale: float = 1.0, | |
| attention_implementation: str = "sdpa", | |
| device=None, | |
| dtype=None, | |
| **kwargs, | |
| ): | |
| super().__init__() | |
| if d_model % num_heads != 0: | |
| raise ValueError(f"d_model must be divisible by num_heads, got d_model={d_model}, num_heads={num_heads}") | |
| self.positional_embedding = positional_embedding | |
| self.max_period = max_period | |
| self.positional_scale = positional_scale | |
| self.rope: MossAudioTokenizerRotaryEmbedding | None = None | |
| if positional_embedding in {"rope", "sin_rope"}: | |
| self.rope = MossAudioTokenizerRotaryEmbedding(max_period=max_period) | |
| self.layers = nn.ModuleList() | |
| for _ in range(num_layers): | |
| self.layers.append( | |
| MossAudioTokenizerTransformerLayer( | |
| d_model=d_model, | |
| num_heads=num_heads, | |
| dim_feedforward=dim_feedforward, | |
| causal=causal, | |
| context=context, | |
| rope=self.rope, | |
| attention_implementation=attention_implementation, | |
| device=device, | |
| dtype=dtype, | |
| **kwargs, | |
| ) | |
| ) | |
| def _init_streaming_state(self, batch_size: int) -> TransformerState: | |
| device = next(self.parameters()).device | |
| return TransformerState( | |
| batch_size, | |
| device, | |
| offsets=torch.zeros(batch_size, device=device, dtype=torch.long), | |
| ) | |
| def resolve_attention_implementation(self, x: torch.Tensor) -> str: | |
| if len(self.layers) == 0: | |
| return "sdpa" | |
| first_layer = cast(MossAudioTokenizerTransformerLayer, self.layers[0]) | |
| return first_layer.self_attn.resolve_attention_implementation(x, is_streaming=self._streaming_state is not None) | |
| def set_attention_implementation(self, attention_implementation: str) -> None: | |
| for layer in self.layers: | |
| cast(MossAudioTokenizerTransformerLayer, layer).self_attn.set_attention_implementation(attention_implementation) | |
| def forward(self, x: torch.Tensor, **kwargs): | |
| C = x.shape[-1] | |
| state = self._streaming_state | |
| if x.dim() == 3: | |
| B, T, _ = x.shape | |
| offsets = ( | |
| torch.zeros(1, dtype=torch.long, device=x.device) | |
| if state is None | |
| else ( | |
| state.offsets | |
| if isinstance(state, TransformerState) | |
| else torch.zeros(1, dtype=torch.long, device=x.device) | |
| ) | |
| ) | |
| else: | |
| B = 0 | |
| T = 0 | |
| offsets = None | |
| if self.positional_embedding in {"sin", "sin_rope"}: | |
| if x.dim() == 3: | |
| positions = torch.arange(T, device=x.device).view(1, -1) + cast(torch.Tensor, offsets).view(-1, 1) | |
| else: | |
| position_ids = kwargs.get("position_ids") | |
| if position_ids is None: | |
| raise ValueError("Packed transformer inputs require position_ids when using sinusoidal embeddings.") | |
| positions = position_ids | |
| pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype) | |
| x = x + self.positional_scale * pos_emb | |
| for layer in self.layers: | |
| x = layer(x, **kwargs) | |
| if state is not None and x.dim() == 3: | |
| assert isinstance(state, TransformerState) | |
| state.offsets[:] = torch.where(state.exec_mask, state.offsets + T, state.offsets) | |
| return x | |
| class MossAudioTokenizerProjectedTransformer(StreamingContainer): | |
| """Transformer with input/output projections.""" | |
| def __init__( | |
| self, | |
| input_dimension: int, | |
| output_dimension: int, | |
| d_model: int, | |
| *, | |
| conv_layout: bool = False, | |
| module_type: str, | |
| **kwargs, | |
| ): | |
| super().__init__() | |
| self.module_type = module_type | |
| self.downsample_ratio: int = 1 | |
| self.input_dimension = input_dimension | |
| self.output_dimension = output_dimension | |
| self.input_proj = nn.Linear(input_dimension, d_model, bias=False) | |
| self.transformer = MossAudioTokenizerTransformer(d_model=d_model, **kwargs) | |
| self.conv_layout = conv_layout | |
| self.output_proj = nn.Linear(d_model, output_dimension, bias=False) | |
| def set_attention_implementation(self, attention_implementation: str) -> None: | |
| self.transformer.set_attention_implementation(attention_implementation) | |
| def forward(self, x, input_lengths, **kwargs): | |
| x = self.input_proj(x.transpose(1, 2)) # (B, D, T) -> (B, T, D) | |
| if not self.is_streaming and self.transformer.resolve_attention_implementation(x) == "flash_attention_2": | |
| batch_size, max_seqlen, _ = x.shape | |
| if max_seqlen > 0 and bool(input_lengths.any().item()): | |
| max_valid_seqlen = int(input_lengths.max().item()) | |
| packed_x, valid_mask, cu_seqlens, position_ids = pack_padded_sequence(x, input_lengths) | |
| packed_x = self.transformer( | |
| packed_x, | |
| cu_seqlens=cu_seqlens, | |
| max_seqlen=max_valid_seqlen, | |
| position_ids=position_ids, | |
| input_lengths=input_lengths, | |
| **kwargs, | |
| ) | |
| x = unpack_packed_sequence(packed_x, valid_mask, batch_size, max_seqlen) | |
| else: | |
| x = x.new_zeros(x.shape) | |
| else: | |
| x = self.transformer(x, input_lengths=input_lengths, **kwargs) | |
| x = self.output_proj(x).transpose(1, 2) # (B, T, D) -> (B, D, T) | |
| return x, input_lengths | |
| # ============================================================================= | |
| # Patched Pretransform Module | |
| # ============================================================================= | |
| class MossAudioTokenizerPatchedPretransform(nn.Module): | |
| """Patching module for downsampling/upsampling.""" | |
| def __init__(self, patch_size: int, is_downsample: bool, module_type: str, **kwargs): | |
| super().__init__() | |
| self.patch_size = patch_size | |
| self.downsample_ratio: int = patch_size | |
| self.is_downsample = is_downsample | |
| self.module_type = module_type | |
| def encode(self, x, input_lengths): | |
| b, d, _ = x.shape | |
| h = self.patch_size | |
| x = x.reshape(b, d, -1, h).permute(0, 1, 3, 2).reshape(b, d * h, -1) | |
| # We pad the input waveform to a multiple of `downsample_rate` before applying the encoder. | |
| # Use a ceil division to match that padding and avoid dropping the last (partially padded) frame. | |
| output_lengths = input_lengths // self.patch_size | |
| return x, output_lengths | |
| def decode(self, x, input_lengths): | |
| b, dh, l = x.shape | |
| h = self.patch_size | |
| d = dh // h | |
| x = x.reshape(b, d, h, l).permute(0, 1, 3, 2).reshape(b, d, l * h) | |
| output_lengths = input_lengths * self.patch_size | |
| return x, output_lengths | |
| def forward(self, x, input_lengths): | |
| if self.is_downsample: | |
| return self.encode(x, input_lengths) | |
| else: | |
| return self.decode(x, input_lengths) | |
| # ============================================================================= | |
| # Vector Quantization | |
| # ============================================================================= | |
| def WNConv1d(*args, **kwargs): | |
| """Weight-normalized Conv1d.""" | |
| return nn.utils.parametrizations.weight_norm(nn.Conv1d(*args, **kwargs)) | |
| def remap_weight_norm_state_dict_keys(state_dict: dict[str, torch.Tensor], prefix: str) -> None: | |
| replacements = ( | |
| (".weight_g", ".parametrizations.weight.original0"), | |
| (".weight_v", ".parametrizations.weight.original1"), | |
| ) | |
| for key in list(state_dict.keys()): | |
| if not key.startswith(prefix): | |
| continue | |
| new_key = key | |
| for source, target in replacements: | |
| new_key = new_key.replace(source, target) | |
| if new_key != key: | |
| state_dict[new_key] = state_dict.pop(key) | |
| class MossAudioTokenizerVectorQuantize(nn.Module): | |
| """Single codebook vector quantization (inference only).""" | |
| def __init__( | |
| self, | |
| input_dim: int, | |
| codebook_size: int, | |
| codebook_dim: int, | |
| **kwargs, | |
| ): | |
| super().__init__() | |
| self.input_dim = input_dim | |
| self.codebook_size = codebook_size | |
| self.codebook_dim = codebook_dim | |
| if input_dim != codebook_dim: | |
| self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) | |
| self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) | |
| else: | |
| self.in_proj = nn.Identity() | |
| self.out_proj = nn.Identity() | |
| self.codebook = nn.Embedding(codebook_size, codebook_dim) | |
| def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | |
| """ | |
| Args: | |
| z: Input tensor of shape (B, D, T) | |
| Returns: | |
| z_q: Quantized tensor of shape (B, D, T) | |
| indices: Code indices of shape (B, T) | |
| z_e: Encoded tensor before quantization | |
| """ | |
| z = z.float() | |
| z_e = self.in_proj(z).float() | |
| encodings = z_e.transpose(1, 2).reshape(-1, z_e.shape[1]) | |
| codebook_weight = self.codebook.weight | |
| dist = ( | |
| encodings.pow(2).sum(1, keepdim=True) | |
| - 2 * encodings @ codebook_weight.float().t() | |
| + codebook_weight.float().pow(2).sum(1, keepdim=True).t() | |
| ) | |
| indices = (-dist).max(1)[1] | |
| indices = indices.reshape(z.size(0), -1) | |
| z_q = self.decode_code(indices) | |
| z_q = self.out_proj(z_q).float() | |
| return z_q, indices, z_e | |
| def decode_code(self, embed_id: torch.Tensor) -> torch.Tensor: | |
| """Decode code indices to embeddings.""" | |
| return self.codebook(embed_id).transpose(1, 2).float() | |
| class MossAudioTokenizerLFQ(nn.Module): | |
| """LFQ (inference-only) used by ResidualLFQ.""" | |
| def __init__( | |
| self, | |
| input_dim: int, | |
| codebook_size: int, | |
| codebook_dim: int, | |
| **kwargs, | |
| ): | |
| super().__init__() | |
| self.input_dim = input_dim | |
| self.codebook_size = codebook_size | |
| self.codebook_dim = codebook_dim | |
| if self.input_dim != self.codebook_dim: | |
| self.in_proj = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1) | |
| self.out_proj = WNConv1d(self.codebook_dim, self.input_dim, kernel_size=1) | |
| else: | |
| self.in_proj = nn.Identity() | |
| self.out_proj = nn.Identity() | |
| self.codebook = nn.Embedding(codebook_size, codebook_dim) | |
| def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | |
| """Quantize z into codebook vectors.""" | |
| z = z.float() | |
| z_e = self.in_proj(z).float() | |
| z_q, indices = self.decode_latents(z_e) | |
| z_q = (z_e + (z_q - z_e).detach()).float() | |
| z_q = self.out_proj(z_q).float() | |
| return z_q, indices, z_e | |
| def embed_code(self, embed_id: torch.Tensor) -> torch.Tensor: | |
| return F.embedding(embed_id, self.codebook.weight) | |
| def decode_code_wo_out_proj(self, embed_id: torch.Tensor) -> torch.Tensor: | |
| return self.embed_code(embed_id).transpose(1, 2) | |
| def decode_code(self, embed_id: torch.Tensor) -> torch.Tensor: | |
| z_q = self.decode_code_wo_out_proj(embed_id).float() | |
| z_q = self.out_proj(z_q).float() | |
| return z_q | |
| def decode_latents(self, latents: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: | |
| """Match training LFQ: L2-normalize then argmin squared distance.""" | |
| encodings = latents.transpose(1, 2).reshape(-1, latents.shape[1]).float() | |
| codebook = self.codebook.weight.float() | |
| encodings = F.normalize(encodings) | |
| codebook = F.normalize(codebook) | |
| dist = ( | |
| encodings.pow(2).sum(1, keepdim=True) | |
| - 2 * encodings @ codebook.t() | |
| + codebook.pow(2).sum(1, keepdim=True).t() | |
| ) | |
| indices = (-dist).max(1)[1] | |
| indices = indices.reshape(latents.size(0), -1) | |
| z_q = self.decode_code_wo_out_proj(indices).float() | |
| return z_q, indices | |
| class MossAudioTokenizerResidualVQ(nn.Module): | |
| """Residual Vector Quantization (inference only).""" | |
| def __init__( | |
| self, | |
| input_dim: int = 1024, | |
| rvq_dim: int | None = None, | |
| output_dim: int | None = None, | |
| num_quantizers: int = 32, | |
| codebook_size: int = 1024, | |
| codebook_dim: int = 8, | |
| **kwargs, | |
| ): | |
| super().__init__() | |
| self.input_dim = input_dim | |
| self.rvq_dim = rvq_dim or input_dim | |
| self.output_dim = output_dim or input_dim | |
| self.num_quantizers = num_quantizers | |
| self.codebook_size = codebook_size | |
| self.codebook_dim = codebook_dim | |
| self.input_proj = ( | |
| WNConv1d(input_dim, self.rvq_dim, kernel_size=1) if input_dim != self.rvq_dim else nn.Identity() | |
| ) | |
| self.output_proj = ( | |
| WNConv1d(self.rvq_dim, self.output_dim, kernel_size=1) | |
| if self.rvq_dim != self.output_dim | |
| else nn.Identity() | |
| ) | |
| self.quantizers = nn.ModuleList( | |
| [ | |
| MossAudioTokenizerVectorQuantize( | |
| input_dim=self.rvq_dim, | |
| codebook_size=codebook_size, | |
| codebook_dim=codebook_dim, | |
| **kwargs, | |
| ) | |
| for _ in range(num_quantizers) | |
| ] | |
| ) | |
| self._register_load_state_dict_pre_hook(self._load_hook, with_module=True) | |
| def _load_hook(module, state_dict, prefix, *_): | |
| remap_weight_norm_state_dict_keys(state_dict, prefix) | |
| def forward( | |
| self, | |
| z: torch.Tensor, | |
| input_length: torch.Tensor, | |
| n_quantizers: int | None = None, | |
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | |
| """ | |
| Args: | |
| z: Input tensor of shape (B, D, T) | |
| input_length: Valid lengths for each sample (B,) | |
| n_quantizers: Number of quantizers to use | |
| Returns: | |
| quantized_out: Quantized output (B, D, T) | |
| all_indices: All code indices (N, B, T) | |
| output_length: Output lengths (B,) | |
| """ | |
| with disable_cuda_autocast(): | |
| z = self.input_proj(z).float() | |
| batch_size, _, max_time = z.shape | |
| mask = torch.arange(max_time, device=z.device).expand(batch_size, max_time) < input_length.unsqueeze(1) | |
| quantized_out = torch.zeros_like(z, dtype=torch.float32) | |
| residual = z.clone().float() | |
| all_indices = [] | |
| n_quantizers = n_quantizers or self.num_quantizers | |
| for i, quantizer in enumerate(self.quantizers): | |
| if i >= n_quantizers: | |
| break | |
| masked_residual = residual * mask.unsqueeze(1) | |
| z_q_i, indices_i, _ = quantizer(masked_residual.float()) | |
| update_mask = mask.unsqueeze(1) | |
| quantized_out = quantized_out + z_q_i * update_mask | |
| residual = residual - z_q_i * update_mask | |
| all_indices.append(indices_i) | |
| all_indices = torch.stack(all_indices) # (N, B, T) | |
| quantized_out = self.output_proj(quantized_out.float()).float() | |
| return quantized_out, all_indices, input_length | |
| def decode_codes(self, codes: torch.Tensor) -> torch.Tensor: | |
| """Decode codes from multiple quantizers to embeddings.""" | |
| with disable_cuda_autocast(): | |
| nq, B, T = codes.shape | |
| emb = torch.zeros(B, self.rvq_dim, T, device=codes.device, dtype=torch.float32) | |
| for i, quantizer in enumerate(self.quantizers[:nq]): | |
| quantizer = cast(MossAudioTokenizerVectorQuantize, quantizer) | |
| quantized_i = quantizer.decode_code(codes[i]).float() | |
| emb += quantized_i | |
| emb = self.output_proj(emb.float()).float() | |
| return emb | |
| class MossAudioTokenizerResidualLFQ(nn.Module): | |
| """Residual LFQ (inference only).""" | |
| def __init__( | |
| self, | |
| input_dim: int = 1024, | |
| rvq_dim: int | None = None, | |
| output_dim: int | None = None, | |
| num_quantizers: int = 32, | |
| codebook_size: int = 1024, | |
| codebook_dim: int = 8, | |
| **kwargs, | |
| ): | |
| super().__init__() | |
| self.input_dim = input_dim | |
| self.rvq_dim = rvq_dim or input_dim | |
| self.output_dim = output_dim or input_dim | |
| self.num_quantizers = num_quantizers | |
| self.codebook_size = codebook_size | |
| self.codebook_dim = codebook_dim | |
| self.input_proj = ( | |
| WNConv1d(input_dim, self.rvq_dim, kernel_size=1) if input_dim != self.rvq_dim else nn.Identity() | |
| ) | |
| self.output_proj = ( | |
| WNConv1d(self.rvq_dim, self.output_dim, kernel_size=1) | |
| if self.rvq_dim != self.output_dim | |
| else nn.Identity() | |
| ) | |
| self.quantizers = nn.ModuleList( | |
| [ | |
| MossAudioTokenizerLFQ( | |
| input_dim=self.rvq_dim, | |
| codebook_size=codebook_size, | |
| codebook_dim=codebook_dim, | |
| **kwargs, | |
| ) | |
| for _ in range(num_quantizers) | |
| ] | |
| ) | |
| self._register_load_state_dict_pre_hook(self._load_hook, with_module=True) | |
| def _load_hook(module, state_dict, prefix, *_): | |
| remap_weight_norm_state_dict_keys(state_dict, prefix) | |
| def forward( | |
| self, | |
| z: torch.Tensor, | |
| input_length: torch.Tensor, | |
| n_quantizers: int | None = None, | |
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | |
| """Inference quantization.""" | |
| with disable_cuda_autocast(): | |
| z = self.input_proj(z).float() | |
| batch_size, _, max_time = z.shape | |
| mask = torch.arange(max_time, device=z.device).expand(batch_size, max_time) < input_length.unsqueeze(1) | |
| quantized_out = torch.zeros_like(z, dtype=torch.float32) | |
| residual = z.clone().float() | |
| all_indices = [] | |
| n_quantizers = n_quantizers or self.num_quantizers | |
| for i, quantizer in enumerate(self.quantizers): | |
| if i >= n_quantizers: | |
| break | |
| masked_residual = residual * mask.unsqueeze(1) | |
| z_q_i, indices_i, _ = quantizer(masked_residual.float()) | |
| update_mask = mask.unsqueeze(1) | |
| quantized_out = quantized_out + z_q_i * update_mask | |
| residual = residual - z_q_i * update_mask | |
| all_indices.append(indices_i) | |
| all_indices = ( | |
| torch.stack(all_indices) | |
| if all_indices | |
| else torch.empty(0, batch_size, max_time, device=z.device, dtype=torch.long) | |
| ) | |
| quantized_out = self.output_proj(quantized_out.float()).float() | |
| return quantized_out, all_indices, input_length | |
| def decode_codes(self, codes: torch.Tensor) -> torch.Tensor: | |
| with disable_cuda_autocast(): | |
| nq, B, T = codes.shape | |
| emb = torch.zeros(B, self.rvq_dim, T, device=codes.device, dtype=torch.float32) | |
| for i, quantizer in enumerate(self.quantizers[:nq]): | |
| quantizer = cast(MossAudioTokenizerLFQ, quantizer) | |
| emb += quantizer.decode_code(codes[i]).float() | |
| emb = self.output_proj(emb.float()).float() | |
| return emb | |
| # ============================================================================= | |
| # Main Model Classes | |
| # ============================================================================= | |
| class MossAudioTokenizerPreTrainedModel(PreTrainedAudioTokenizerBase): | |
| """Base class for MossAudioTokenizer models.""" | |
| config_class = MossAudioTokenizerConfig | |
| base_model_prefix = "" | |
| main_input_name = "input_values" | |
| input_modalities = "audio" | |
| supports_gradient_checkpointing = False | |
| _no_split_modules = [ | |
| "MossAudioTokenizerTransformerLayer", | |
| "MossAudioTokenizerResidualVQ", | |
| "MossAudioTokenizerResidualLFQ", | |
| ] | |
| class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel): | |
| """ | |
| MossAudioTokenizer model for audio tokenization and synthesis. | |
| This model can encode audio waveforms into discrete tokens and decode | |
| tokens back into audio waveforms. | |
| """ | |
| def __init__(self, config: MossAudioTokenizerConfig): | |
| super().__init__(config) | |
| self.config = config | |
| _ = config.version | |
| self.sampling_rate = config.sampling_rate | |
| self.downsample_rate = config.downsample_rate | |
| self.number_channels = config.number_channels | |
| self.enable_channel_interleave = getattr(config, "enable_channel_interleave", True) | |
| self.attention_implementation = config.attention_implementation | |
| self.compute_dtype_name = config.compute_dtype | |
| self.compute_dtype = resolve_compute_dtype(config.compute_dtype) | |
| encoder_context_durations = [ | |
| float(module_kwargs.get("context_duration", config.causal_transformer_context_duration)) | |
| for module_kwargs in config.encoder_kwargs | |
| if module_kwargs["module_type"] == "Transformer" | |
| ] | |
| self.causal_transformer_context_duration = ( | |
| min(encoder_context_durations) if encoder_context_durations else config.causal_transformer_context_duration | |
| ) | |
| # Build encoder | |
| channel_interleave_factor = ( | |
| self.number_channels if self.enable_channel_interleave and self.number_channels > 1 else 1 | |
| ) | |
| current_frame_rate: float = float(self.sampling_rate * channel_interleave_factor) | |
| self.encoder = nn.ModuleList() | |
| for encoder_kwargs_i in config.encoder_kwargs: | |
| encoder_kwargs_i = dict(encoder_kwargs_i) # Make a copy | |
| if encoder_kwargs_i["module_type"] == "PatchedPretransform": | |
| self.encoder.append(MossAudioTokenizerPatchedPretransform(**encoder_kwargs_i, is_downsample=True)) | |
| elif encoder_kwargs_i["module_type"] == "Transformer": | |
| context_duration = float(encoder_kwargs_i.pop("context_duration", self.causal_transformer_context_duration)) | |
| self.encoder.append( | |
| MossAudioTokenizerProjectedTransformer( | |
| **encoder_kwargs_i, | |
| context=int(round(current_frame_rate * context_duration)), | |
| attention_implementation=self.attention_implementation, | |
| ) | |
| ) | |
| current_frame_rate /= self.encoder[-1].downsample_ratio | |
| # Build quantizer | |
| quantizer_kwargs = dict(config.quantizer_kwargs) | |
| quantizer_type = quantizer_kwargs.get("quantizer_type", getattr(config, "quantizer_type", "rvq")) | |
| if quantizer_type in {"rvq", "spec_rvq"}: | |
| self.quantizer = MossAudioTokenizerResidualVQ(**quantizer_kwargs) | |
| elif quantizer_type in {"rlfq", "random_prefix_rlfq"}: | |
| self.quantizer = MossAudioTokenizerResidualLFQ(**quantizer_kwargs) | |
| else: | |
| raise ValueError(f"Unsupported quantizer_type: {quantizer_type}") | |
| # Build decoder | |
| decoder_kwargs_list = copy.deepcopy(config.decoder_kwargs) | |
| self.decoder = nn.ModuleList() | |
| for decoder_kwargs_i in decoder_kwargs_list: | |
| decoder_kwargs_i = dict(decoder_kwargs_i) | |
| if decoder_kwargs_i["module_type"] == "PatchedPretransform": | |
| self.decoder.append(MossAudioTokenizerPatchedPretransform(**decoder_kwargs_i, is_downsample=False)) | |
| elif decoder_kwargs_i["module_type"] == "Transformer": | |
| context_duration = float(decoder_kwargs_i.pop("context_duration", self.causal_transformer_context_duration)) | |
| self.decoder.append( | |
| MossAudioTokenizerProjectedTransformer( | |
| **decoder_kwargs_i, | |
| context=int(round(current_frame_rate * context_duration)), | |
| attention_implementation=self.attention_implementation, | |
| ) | |
| ) | |
| current_frame_rate *= self.decoder[-1].downsample_ratio | |
| expected_output_frame_rate = float(self.sampling_rate * channel_interleave_factor) | |
| if int(round(current_frame_rate)) != int(round(expected_output_frame_rate)): | |
| raise ValueError( | |
| "Decoder stack does not invert the encoder frame rate correctly: " | |
| f"got current_frame_rate={current_frame_rate}, expected={expected_output_frame_rate}." | |
| ) | |
| self.post_init() | |
| self._active_decode_session: "MossAudioTokenizerDecodeSession | None" = None | |
| self._batch_decode_streaming_max_batch_size: int | None = None | |
| self._batch_decode_streaming_batch_size: int | None = None | |
| self._batch_decode_streaming_session: "MossAudioTokenizerDecodeSession | None" = None | |
| self._batch_decode_streaming_next_request_id: int = 0 | |
| def create_decode_session( | |
| self, | |
| max_batch_size: int, | |
| use_cuda_graph: bool = False, | |
| ) -> MossAudioTokenizerDecodeSession: | |
| active_session = self._active_decode_session | |
| if active_session is not None and not active_session._closed: | |
| raise RuntimeError(_ACTIVE_DECODE_SESSION_ERROR_MESSAGE) | |
| for module in self.modules(): | |
| if isinstance(module, StreamingModule) and module._streaming_state is not None: | |
| raise RuntimeError(_MODEL_STREAMING_CONFLICT_ERROR_MESSAGE) | |
| session = MossAudioTokenizerDecodeSession(self, max_batch_size, use_cuda_graph=use_cuda_graph) | |
| return session | |
| def _reset_batch_decode_streaming_state(self) -> None: | |
| streaming_session = self._batch_decode_streaming_session | |
| self._batch_decode_streaming_session = None | |
| self._batch_decode_streaming_max_batch_size = None | |
| self._batch_decode_streaming_batch_size = None | |
| self._batch_decode_streaming_next_request_id = 0 | |
| if streaming_session is not None and not streaming_session._closed: | |
| streaming_session.close() | |
| def _prepare_batch_decode_streaming_state( | |
| self, | |
| batch_size: int, | |
| max_batch_size: int | None, | |
| reset_stream: bool, | |
| ) -> int: | |
| if reset_stream: | |
| self._reset_batch_decode_streaming_state() | |
| if max_batch_size is not None and max_batch_size <= 0: | |
| raise ValueError("`max_batch_size` must be > 0 when provided.") | |
| streaming_max_batch_size = self._batch_decode_streaming_max_batch_size | |
| if streaming_max_batch_size is None: | |
| streaming_max_batch_size = batch_size if max_batch_size is None else max_batch_size | |
| elif max_batch_size is not None and max_batch_size != streaming_max_batch_size: | |
| raise ValueError( | |
| "`max_batch_size` can only be set on the first streaming `batch_decode()` call for now. " | |
| f"Expected {streaming_max_batch_size}, got {max_batch_size}." | |
| ) | |
| if batch_size > streaming_max_batch_size: | |
| raise ValueError( | |
| "Streaming `batch_decode()` received a batch larger than the reserved `max_batch_size`. " | |
| f"Got batch_size={batch_size}, max_batch_size={streaming_max_batch_size}." | |
| ) | |
| return streaming_max_batch_size | |
| def _ensure_batch_decode_streaming_session( | |
| self, | |
| max_batch_size: int, | |
| use_cuda_graph: bool = False, | |
| ) -> MossAudioTokenizerDecodeSession: | |
| session = self._batch_decode_streaming_session | |
| if session is not None and not session._closed: | |
| if session._use_cuda_graph != use_cuda_graph: | |
| raise ValueError( | |
| "`use_cuda_graph` must match the existing streaming `batch_decode()` session configuration. " | |
| f"Expected {session._use_cuda_graph}, got {use_cuda_graph}." | |
| ) | |
| return session | |
| session = self.create_decode_session(max_batch_size=max_batch_size, use_cuda_graph=use_cuda_graph) | |
| self._batch_decode_streaming_session = session | |
| self._batch_decode_streaming_max_batch_size = max_batch_size | |
| self._batch_decode_streaming_next_request_id = 0 | |
| return session | |
| def _append_batch_decode_streaming_requests( | |
| self, | |
| session: MossAudioTokenizerDecodeSession, | |
| target_batch_size: int, | |
| ) -> None: | |
| requests_to_append = target_batch_size - len(session.active_request_ids) | |
| for _ in range(requests_to_append): | |
| request_id = self._batch_decode_streaming_next_request_id | |
| session.append(request_id) | |
| self._batch_decode_streaming_next_request_id += 1 | |
| def _resolve_batch_decode_streaming_finalize_request_ids( | |
| self, | |
| request_ids: list[str | int], | |
| finalize_indices: list[int] | tuple[int, ...] | None, | |
| ) -> list[str | int]: | |
| normalized_finalize_indices = tuple(finalize_indices) if finalize_indices is not None else () | |
| if len(set(normalized_finalize_indices)) != len(normalized_finalize_indices): | |
| raise ValueError(_BATCH_DECODE_STREAMING_DUPLICATE_FINALIZE_INDICES_ERROR_MESSAGE) | |
| batch_size = len(request_ids) | |
| finalize_request_ids: list[str | int] = [] | |
| for index in normalized_finalize_indices: | |
| if index < 0 or index >= batch_size: | |
| raise ValueError( | |
| _BATCH_DECODE_STREAMING_FINALIZE_INDEX_OUT_OF_RANGE_ERROR_TEMPLATE.format( | |
| index=index, batch_size=batch_size | |
| ) | |
| ) | |
| finalize_request_ids.append(request_ids[index]) | |
| return finalize_request_ids | |
| def _raise_if_plain_decode_conflicts_with_active_session(self) -> None: | |
| active_session = self._active_decode_session | |
| if active_session is not None and not getattr(active_session, "_closed", False): | |
| raise RuntimeError(_PLAIN_DECODE_SESSION_CONFLICT_ERROR_MESSAGE) | |
| def _start_streaming(self, batch_size: int): | |
| """Start streaming mode for all modules.""" | |
| active_session = self._active_decode_session | |
| if active_session is not None and not getattr(active_session, "_closed", False): | |
| raise RuntimeError(_MODEL_STREAMING_CONFLICT_ERROR_MESSAGE) | |
| def _start(module): | |
| if isinstance(module, StreamingModule): | |
| module._streaming_state = module._init_streaming_state(batch_size) | |
| self.apply(_start) | |
| def _stop_streaming(self): | |
| """Stop streaming mode for all modules.""" | |
| active_session = self._active_decode_session | |
| if active_session is not None and not getattr(active_session, "_closed", False): | |
| raise RuntimeError(_MODEL_STREAMING_CONFLICT_ERROR_MESSAGE) | |
| def _stop(module): | |
| if isinstance(module, StreamingModule): | |
| module._streaming_state = None | |
| self.apply(_stop) | |
| def streaming(self, batch_size: int = 1): | |
| """Context manager for streaming mode.""" | |
| self._start_streaming(batch_size) | |
| try: | |
| yield | |
| finally: | |
| self._stop_streaming() | |
| def _set_streaming_exec_mask(self, exec_mask: torch.Tensor) -> None: | |
| exec_mask = exec_mask.to(torch.bool) | |
| def _set_exec_mask(module: nn.Module): | |
| if isinstance(module, StreamingModule) and module._streaming_state is not None: | |
| module._streaming_state.set_exec_mask(exec_mask.to(module._streaming_state.device)) | |
| self.apply(_set_exec_mask) | |
| def _plan_batch_stream_step( | |
| self, | |
| remaining: torch.Tensor, | |
| max_step_length: int, | |
| alignment: int, | |
| ) -> tuple[int, torch.Tensor]: | |
| positive_mask = remaining > 0 | |
| if not bool(positive_mask.any().item()): | |
| raise RuntimeError("Cannot plan a streaming step when no samples remain.") | |
| if max_step_length > 0: | |
| full_step_mask = remaining >= max_step_length | |
| if bool(full_step_mask.any().item()): | |
| return max_step_length, full_step_mask | |
| positive_remaining = remaining[positive_mask] | |
| min_remaining = int(positive_remaining.min().item()) | |
| if alignment > 1: | |
| aligned_step = (min_remaining // alignment) * alignment | |
| if aligned_step > 0: | |
| return aligned_step, remaining >= aligned_step | |
| return min_remaining, remaining == min_remaining | |
| step_length = min_remaining | |
| if max_step_length > 0: | |
| step_length = min(step_length, max_step_length) | |
| return step_length, remaining >= step_length | |
| def _infer_num_quantizers(self, codes_chunks: list[list[torch.Tensor]], requested_num_quantizers: int | None) -> int: | |
| if requested_num_quantizers is not None: | |
| return requested_num_quantizers | |
| for chunks_i in codes_chunks: | |
| if chunks_i: | |
| return int(chunks_i[0].shape[0]) | |
| num_quantizers = getattr(self.quantizer, "num_quantizers", None) | |
| if num_quantizers is None: | |
| raise RuntimeError("Unable to infer the number of quantizers from empty streaming output.") | |
| return int(num_quantizers) | |
| def _infer_waveform_dtype(self, wav_chunks: list[list[torch.Tensor]]) -> torch.dtype: | |
| for chunks_i in wav_chunks: | |
| if chunks_i: | |
| return chunks_i[0].dtype | |
| return torch.float32 | |
| def _codec_inference_autocast(self): | |
| device = next(self.parameters()).device | |
| if device.type == "cuda" and self.compute_dtype is not None: | |
| with torch.autocast(device_type="cuda", dtype=self.compute_dtype): | |
| yield | |
| else: | |
| yield | |
| def set_attention_implementation(self, attention_implementation: str) -> None: | |
| self.attention_implementation = attention_implementation | |
| for module in self.modules(): | |
| if isinstance(module, MossAudioTokenizerProjectedTransformer): | |
| module.set_attention_implementation(attention_implementation) | |
| def set_compute_dtype(self, compute_dtype: str) -> None: | |
| self.compute_dtype_name = compute_dtype | |
| self.compute_dtype = resolve_compute_dtype(compute_dtype) | |
| def _prepare_waveform_batch( | |
| self, | |
| wav_list: list[torch.Tensor], | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| if len(wav_list) == 0: | |
| raise ValueError("`wav_list` must contain at least one waveform.") | |
| device = wav_list[0].device | |
| dtype = wav_list[0].dtype | |
| batch_size = len(wav_list) | |
| lengths = torch.zeros(batch_size, device=device, dtype=torch.long) | |
| normalized_wavs: list[torch.Tensor] = [] | |
| for i, wav in enumerate(wav_list): | |
| if self.number_channels == 1: | |
| if wav.dim() == 1: | |
| wav_i = wav.unsqueeze(0) | |
| elif wav.dim() == 2 and wav.shape[0] == 1: | |
| wav_i = wav | |
| else: | |
| raise ValueError( | |
| f"Expected wav_list[{i}] to have shape `(T,)` or `(1, T)` for a mono model, got {tuple(wav.shape)}." | |
| ) | |
| else: | |
| if wav.dim() != 2 or wav.shape[0] != self.number_channels: | |
| raise ValueError( | |
| f"Expected wav_list[{i}] to have shape `({self.number_channels}, T)`, got {tuple(wav.shape)}." | |
| ) | |
| wav_i = wav | |
| normalized_wavs.append(wav_i) | |
| lengths[i] = wav_i.shape[-1] | |
| max_length = int(lengths.max().item()) if batch_size > 0 else 0 | |
| input_values = torch.zeros(batch_size, self.number_channels, max_length, device=device, dtype=dtype) | |
| for i, wav_i in enumerate(normalized_wavs): | |
| input_values[i, :, : wav_i.shape[-1]] = wav_i | |
| return input_values, lengths | |
| def _prepare_codes_batch( | |
| self, | |
| codes_list: list[torch.Tensor], | |
| num_quantizers: int | None = None, | |
| ) -> tuple[torch.Tensor, torch.Tensor, int]: | |
| if len(codes_list) == 0: | |
| raise ValueError("`codes_list` must contain at least one code tensor.") | |
| batch_size = len(codes_list) | |
| device = codes_list[0].device | |
| nqs = [codes.shape[0] for codes in codes_list] | |
| if num_quantizers is None: | |
| num_quantizers = nqs[0] | |
| if any(nq != num_quantizers for nq in nqs): | |
| raise ValueError( | |
| "All elements in `codes_list` must have the same number of quantizers when `num_quantizers` is None. " | |
| "Pass `num_quantizers=...` to decode a common prefix." | |
| ) | |
| elif min(nqs) < num_quantizers: | |
| raise ValueError( | |
| "`num_quantizers` must be <= the number of quantizers for every element in `codes_list`. " | |
| f"Got num_quantizers={num_quantizers}, min(codes.shape[0])={min(nqs)}." | |
| ) | |
| lengths = torch.tensor([codes.shape[-1] for codes in codes_list], device=device, dtype=torch.long) | |
| max_length = int(lengths.max().item()) if batch_size > 0 else 0 | |
| audio_codes = torch.zeros(num_quantizers, batch_size, max_length, device=device, dtype=torch.long) | |
| for i, codes in enumerate(codes_list): | |
| codes_i = codes[:num_quantizers] | |
| audio_codes[:, i, : codes_i.shape[-1]] = codes_i | |
| return audio_codes, lengths, num_quantizers | |
| def _flatten_channels_for_codec( | |
| self, | |
| input_values: torch.Tensor, | |
| input_lengths: torch.Tensor, | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| if input_values.dim() != 3: | |
| raise ValueError(f"Expected `input_values` with shape `(B, C, T)`, got {tuple(input_values.shape)}.") | |
| if input_values.shape[1] != self.number_channels: | |
| raise ValueError( | |
| f"Expected `input_values.shape[1] == {self.number_channels}`, got {input_values.shape[1]}." | |
| ) | |
| if input_values.shape[-1] % self.downsample_rate != 0: | |
| pad_length = self.downsample_rate - (input_values.shape[-1] % self.downsample_rate) | |
| input_values = F.pad(input_values, (0, pad_length)) | |
| if self.number_channels > 1 and self.enable_channel_interleave: | |
| input_values = input_values.transpose(1, 2).contiguous().view(input_values.shape[0], 1, -1) | |
| input_lengths = input_lengths * self.number_channels | |
| return input_values, input_lengths | |
| def _restore_channels_from_codec( | |
| self, | |
| output_values: torch.Tensor, | |
| output_lengths: torch.Tensor, | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| if self.number_channels == 1 or not self.enable_channel_interleave: | |
| return output_values.float(), output_lengths | |
| output_values = ( | |
| output_values.squeeze(1) | |
| .contiguous() | |
| .view(output_values.shape[0], -1, self.number_channels) | |
| .transpose(1, 2) | |
| .contiguous() | |
| .float() | |
| ) | |
| output_lengths = torch.div(output_lengths, self.number_channels, rounding_mode="floor") | |
| return output_values, output_lengths | |
| def _stack_hidden_states( | |
| self, | |
| hidden_chunks: list[list[torch.Tensor]], | |
| lengths: torch.Tensor, | |
| ) -> torch.Tensor | None: | |
| hidden_dim = None | |
| for chunks_i in hidden_chunks: | |
| if chunks_i: | |
| hidden_dim = chunks_i[0].shape[0] | |
| break | |
| if hidden_dim is None: | |
| return None | |
| batch_size = len(hidden_chunks) | |
| max_length = int(lengths.max().item()) if batch_size > 0 else 0 | |
| device = lengths.device | |
| hidden_states = torch.zeros(batch_size, hidden_dim, max_length, device=device, dtype=torch.float32) | |
| for i, chunks_i in enumerate(hidden_chunks): | |
| if not chunks_i: | |
| continue | |
| hidden_i = torch.cat(chunks_i, dim=-1).float() | |
| hidden_states[i, :, : hidden_i.shape[-1]] = hidden_i | |
| return hidden_states | |
| def _encode_frame( | |
| self, | |
| input_values: torch.Tensor, | |
| input_lengths: torch.Tensor | None = None, | |
| n_quantizers: int | None = None, | |
| ) -> MossAudioTokenizerEncoderOutput: | |
| if input_values.dim() == 1: | |
| input_values = input_values.view(1, 1, -1) | |
| elif input_values.dim() == 2: | |
| if self.number_channels == 1: | |
| input_values = input_values.unsqueeze(1) | |
| else: | |
| input_values = input_values.unsqueeze(0) | |
| batch_size, _, time = input_values.shape | |
| device = input_values.device | |
| if input_lengths is None: | |
| input_lengths = torch.full((batch_size,), time, device=device, dtype=torch.long) | |
| input_values, input_lengths = self._flatten_channels_for_codec(input_values, input_lengths) | |
| with self._codec_inference_autocast(): | |
| encoder_hidden_states, encoder_hidden_lengths = input_values, input_lengths | |
| for encoder_module in self.encoder: | |
| encoder_hidden_states, encoder_hidden_lengths = encoder_module( | |
| encoder_hidden_states, | |
| encoder_hidden_lengths, | |
| ) | |
| quantizer = cast(MossAudioTokenizerResidualVQ | MossAudioTokenizerResidualLFQ, self.quantizer) | |
| _, audio_codes, audio_codes_lengths = quantizer(encoder_hidden_states.float(), encoder_hidden_lengths, n_quantizers) | |
| return MossAudioTokenizerEncoderOutput( | |
| audio_codes=audio_codes, | |
| audio_codes_lengths=audio_codes_lengths, | |
| encoder_hidden_states=encoder_hidden_states.float(), | |
| ) | |
| def _decode_frame( | |
| self, | |
| codes: torch.Tensor, | |
| codes_lengths: torch.Tensor | None = None, | |
| ) -> MossAudioTokenizerDecoderOutput: | |
| _, batch_size, time = codes.shape | |
| device = codes.device | |
| if codes_lengths is None: | |
| codes_lengths = torch.full((batch_size,), time, device=device, dtype=torch.long) | |
| quantizer = cast(MossAudioTokenizerResidualVQ | MossAudioTokenizerResidualLFQ, self.quantizer) | |
| decoder_hidden_states = quantizer.decode_codes(codes).float() | |
| with self._codec_inference_autocast(): | |
| audio, audio_lengths = decoder_hidden_states, codes_lengths | |
| for decoder_module in self.decoder: | |
| audio, audio_lengths = decoder_module(audio, audio_lengths) | |
| audio, audio_lengths = self._restore_channels_from_codec(audio, audio_lengths) | |
| return MossAudioTokenizerDecoderOutput(audio=audio, audio_lengths=audio_lengths) | |
| def batch_encode( | |
| self, | |
| wav_list: list[torch.Tensor], | |
| num_quantizers: int | None = None, | |
| chunk_duration: float | None = None, | |
| ) -> MossAudioTokenizerEncoderOutput: | |
| input_values, input_lengths = self._prepare_waveform_batch(wav_list) | |
| batch_size = len(wav_list) | |
| device = input_values.device | |
| if chunk_duration is None: | |
| return self._encode_frame(input_values, input_lengths, n_quantizers=num_quantizers) | |
| if chunk_duration <= 0: | |
| raise ValueError("`chunk_duration` must be > 0 when provided.") | |
| chunk_length = int(round(chunk_duration * self.sampling_rate)) | |
| if chunk_length <= 0: | |
| raise ValueError("`chunk_duration` is too small and results in chunk_length <= 0.") | |
| if chunk_length % self.downsample_rate != 0: | |
| raise ValueError( | |
| "`chunk_duration * config.sampling_rate` must be divisible by `config.downsample_rate`. " | |
| f"Got chunk_length={chunk_length}, downsample_rate={self.downsample_rate}." | |
| ) | |
| cursors = torch.zeros_like(input_lengths) | |
| codes_chunks: list[list[torch.Tensor]] = [[] for _ in range(batch_size)] | |
| hidden_chunks: list[list[torch.Tensor]] = [[] for _ in range(batch_size)] | |
| with self.streaming(batch_size=batch_size): | |
| while bool((cursors < input_lengths).any().item()): | |
| remaining = input_lengths - cursors | |
| step_length, active_mask = self._plan_batch_stream_step( | |
| remaining=remaining, | |
| max_step_length=chunk_length, | |
| alignment=self.downsample_rate, | |
| ) | |
| x_step = torch.zeros( | |
| batch_size, | |
| self.number_channels, | |
| step_length, | |
| device=device, | |
| dtype=input_values.dtype, | |
| ) | |
| input_lengths_step = torch.zeros(batch_size, device=device, dtype=torch.long) | |
| active_indices = torch.nonzero(active_mask, as_tuple=False).flatten().tolist() | |
| for i in active_indices: | |
| start = int(cursors[i].item()) | |
| end = start + step_length | |
| x_step[i] = input_values[i, :, start:end] | |
| input_lengths_step[i] = step_length | |
| self._set_streaming_exec_mask(active_mask) | |
| result = self._encode_frame(x_step, input_lengths_step, n_quantizers=num_quantizers) | |
| assert result.audio_codes is not None | |
| assert result.audio_codes_lengths is not None | |
| for i in active_indices: | |
| codes_length_i = int(result.audio_codes_lengths[i].item()) | |
| if codes_length_i > 0: | |
| codes_chunks[i].append(result.audio_codes[:, i, :codes_length_i].clone()) | |
| if result.encoder_hidden_states is not None: | |
| hidden_chunks[i].append(result.encoder_hidden_states[i, :, :codes_length_i].clone()) | |
| cursors[i] += step_length | |
| num_quantizers_used = self._infer_num_quantizers(codes_chunks, num_quantizers) | |
| empty_codes = torch.empty((num_quantizers_used, 0), device=device, dtype=torch.long) | |
| codes_list = [torch.cat(chunks_i, dim=-1) if chunks_i else empty_codes.clone() for chunks_i in codes_chunks] | |
| audio_codes, audio_codes_lengths, _ = self._prepare_codes_batch(codes_list, num_quantizers=num_quantizers_used) | |
| encoder_hidden_states = self._stack_hidden_states(hidden_chunks, audio_codes_lengths) | |
| return MossAudioTokenizerEncoderOutput( | |
| audio_codes=audio_codes, | |
| audio_codes_lengths=audio_codes_lengths, | |
| encoder_hidden_states=encoder_hidden_states, | |
| ) | |
| def batch_decode( | |
| self, | |
| codes_list: list[torch.Tensor], | |
| num_quantizers: int | None = None, | |
| chunk_duration: float | None = None, | |
| streaming: bool = False, | |
| max_batch_size: int | None = None, | |
| finalize_indices: list[int] | tuple[int, ...] | None = None, | |
| reset_stream: bool = False, | |
| use_cuda_graph: bool = False, | |
| ) -> MossAudioTokenizerDecoderOutput: | |
| if len(codes_list) == 0: | |
| raise ValueError("`codes_list` must contain at least one code tensor.") | |
| streaming_max_batch_size: int | None = None | |
| if streaming: | |
| streaming_max_batch_size = self._prepare_batch_decode_streaming_state( | |
| batch_size=len(codes_list), | |
| max_batch_size=max_batch_size, | |
| reset_stream=reset_stream, | |
| ) | |
| else: | |
| if reset_stream: | |
| self._reset_batch_decode_streaming_state() | |
| self._raise_if_plain_decode_conflicts_with_active_session() | |
| audio_codes, audio_codes_lengths, num_quantizers_used = self._prepare_codes_batch( | |
| codes_list, | |
| num_quantizers=num_quantizers, | |
| ) | |
| batch_size = len(codes_list) | |
| device = audio_codes.device | |
| if not streaming and chunk_duration is None: | |
| return self._decode_frame(audio_codes, audio_codes_lengths) | |
| if streaming: | |
| assert streaming_max_batch_size is not None | |
| existing_session = self._batch_decode_streaming_session | |
| reusing_streaming_session = existing_session is not None and not existing_session._closed | |
| session = self._ensure_batch_decode_streaming_session( | |
| max_batch_size=streaming_max_batch_size, | |
| use_cuda_graph=use_cuda_graph, | |
| ) | |
| pre_call_request_ids = list(session.active_request_ids) | |
| pre_call_batch_size = len(pre_call_request_ids) | |
| if batch_size < pre_call_batch_size: | |
| raise ValueError(_BATCH_DECODE_STREAMING_SHRINK_ERROR_MESSAGE) | |
| try: | |
| finalize_request_ids = self._resolve_batch_decode_streaming_finalize_request_ids( | |
| request_ids=pre_call_request_ids, | |
| finalize_indices=finalize_indices, | |
| ) | |
| except Exception: | |
| if not reusing_streaming_session and pre_call_batch_size == 0: | |
| self._reset_batch_decode_streaming_state() | |
| raise | |
| try: | |
| if batch_size > pre_call_batch_size: | |
| self._append_batch_decode_streaming_requests(session=session, target_batch_size=batch_size) | |
| request_ids = list(session.active_request_ids) | |
| _, audio, audio_lengths = session.step( | |
| request_ids=request_ids, | |
| codes=audio_codes, | |
| code_lengths=audio_codes_lengths, | |
| ) | |
| for request_id in finalize_request_ids: | |
| session.remove(request_id) | |
| except Exception: | |
| self._reset_batch_decode_streaming_state() | |
| raise | |
| self._batch_decode_streaming_max_batch_size = session.max_batch_size | |
| self._batch_decode_streaming_batch_size = len(session.active_request_ids) | |
| return MossAudioTokenizerDecoderOutput(audio=audio, audio_lengths=audio_lengths) | |
| assert chunk_duration is not None | |
| if chunk_duration <= 0: | |
| raise ValueError("`chunk_duration` must be > 0 when provided.") | |
| chunk_length = int(round(chunk_duration * self.sampling_rate)) | |
| if chunk_length <= 0: | |
| raise ValueError("`chunk_duration` is too small and results in chunk_length <= 0.") | |
| if chunk_length % self.downsample_rate != 0: | |
| raise ValueError( | |
| "`chunk_duration * config.sampling_rate` must be divisible by `config.downsample_rate`. " | |
| f"Got chunk_length={chunk_length}, downsample_rate={self.downsample_rate}." | |
| ) | |
| chunk_frame_length = chunk_length // self.downsample_rate | |
| cursors = torch.zeros_like(audio_codes_lengths) | |
| wav_chunks: list[list[torch.Tensor]] = [[] for _ in range(batch_size)] | |
| with self.streaming(batch_size=batch_size): | |
| while bool((cursors < audio_codes_lengths).any().item()): | |
| remaining = audio_codes_lengths - cursors | |
| step_frames, active_mask = self._plan_batch_stream_step( | |
| remaining=remaining, | |
| max_step_length=chunk_frame_length, | |
| alignment=1, | |
| ) | |
| codes_step = torch.zeros( | |
| num_quantizers_used, | |
| batch_size, | |
| step_frames, | |
| device=device, | |
| dtype=torch.long, | |
| ) | |
| codes_lengths_step = torch.zeros(batch_size, device=device, dtype=torch.long) | |
| active_indices = torch.nonzero(active_mask, as_tuple=False).flatten().tolist() | |
| for i in active_indices: | |
| start = int(cursors[i].item()) | |
| end = start + step_frames | |
| codes_step[:, i, :] = audio_codes[:, i, start:end] | |
| codes_lengths_step[i] = step_frames | |
| self._set_streaming_exec_mask(active_mask) | |
| result = self._decode_frame(codes_step, codes_lengths_step) | |
| assert result.audio is not None | |
| assert result.audio_lengths is not None | |
| for i in active_indices: | |
| audio_length_i = int(result.audio_lengths[i].item()) | |
| if audio_length_i > 0: | |
| wav_chunks[i].append(result.audio[i, :, :audio_length_i].clone()) | |
| cursors[i] += step_frames | |
| wav_dtype = self._infer_waveform_dtype(wav_chunks) | |
| audio_lengths = torch.tensor( | |
| [sum(chunk.shape[-1] for chunk in chunks_i) for chunks_i in wav_chunks], | |
| device=device, | |
| dtype=torch.long, | |
| ) | |
| max_audio_length = int(audio_lengths.max().item()) if batch_size > 0 else 0 | |
| audio = torch.zeros(batch_size, self.number_channels, max_audio_length, device=device, dtype=wav_dtype) | |
| for i, chunks_i in enumerate(wav_chunks): | |
| if not chunks_i: | |
| continue | |
| wav_i = torch.cat(chunks_i, dim=-1) | |
| audio[i, :, : wav_i.shape[-1]] = wav_i | |
| return MossAudioTokenizerDecoderOutput(audio=audio, audio_lengths=audio_lengths) | |
| def encode( # type: ignore[override] | |
| self, | |
| input_values: torch.Tensor, | |
| padding_mask: torch.Tensor | None = None, | |
| num_quantizers: int | None = None, | |
| return_dict: bool | None = None, | |
| chunk_duration: float | None = None, | |
| ): | |
| """ | |
| Encodes the input audio waveform into discrete codes. | |
| Args: | |
| input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`): | |
| Float values of the input audio waveform. | |
| padding_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): | |
| Mask to indicate valid audio samples. | |
| num_quantizers (`int`, *optional*): | |
| Number of quantizers to use. By default, all quantizers are used. | |
| return_dict (`bool`, *optional*): | |
| Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. | |
| chunk_duration (`float`, *optional*): | |
| If provided, encode the input waveform in successive chunks of `chunk_duration` seconds while keeping a | |
| streaming KV cache for the causal transformers. | |
| `chunk_duration` must be <= `config.causal_transformer_context_duration`, and | |
| `chunk_duration * config.sampling_rate` must be divisible by `config.downsample_rate`. | |
| Returns: | |
| `MossAudioTokenizerEncoderOutput` or tuple containing audio codes and lengths. | |
| """ | |
| return_dict = return_dict if return_dict is not None else self.config.return_dict | |
| wav_list: list[torch.Tensor] | |
| if input_values.dim() == 1: | |
| wav_list = [input_values] | |
| elif input_values.dim() == 2: | |
| if self.number_channels == 1: | |
| lengths = ( | |
| padding_mask.sum(dim=-1).long() | |
| if padding_mask is not None and padding_mask.dim() == 2 | |
| else torch.full((input_values.shape[0],), input_values.shape[-1], device=input_values.device, dtype=torch.long) | |
| ) | |
| wav_list = [input_values[i, : int(lengths[i].item())] for i in range(input_values.shape[0])] | |
| else: | |
| length = ( | |
| int(padding_mask.sum().item()) | |
| if padding_mask is not None and padding_mask.dim() == 1 | |
| else int(input_values.shape[-1]) | |
| ) | |
| wav_list = [input_values[:, :length]] | |
| elif input_values.dim() == 3: | |
| if input_values.shape[1] != self.number_channels: | |
| raise ValueError( | |
| f"Expected `input_values.shape[1] == {self.number_channels}`, got {input_values.shape[1]}." | |
| ) | |
| lengths = ( | |
| padding_mask.sum(dim=-1).long() | |
| if padding_mask is not None | |
| else torch.full((input_values.shape[0],), input_values.shape[-1], device=input_values.device, dtype=torch.long) | |
| ) | |
| wav_list = [input_values[i, :, : int(lengths[i].item())] for i in range(input_values.shape[0])] | |
| else: | |
| raise ValueError(f"Unsupported `input_values` shape: {tuple(input_values.shape)}") | |
| encoder_output = self.batch_encode(wav_list, num_quantizers=num_quantizers, chunk_duration=chunk_duration) | |
| if not return_dict: | |
| assert encoder_output.audio_codes is not None | |
| assert encoder_output.audio_codes_lengths is not None | |
| return ( | |
| cast(torch.Tensor, encoder_output.audio_codes), | |
| cast(torch.Tensor, encoder_output.audio_codes_lengths), | |
| ) | |
| return encoder_output | |
| def decode( # type: ignore[override] | |
| self, | |
| audio_codes: torch.Tensor, | |
| padding_mask: torch.Tensor | None = None, | |
| return_dict: bool | None = None, | |
| chunk_duration: float | None = None, | |
| num_quantizers: int | None = None, | |
| ): | |
| """ | |
| Decodes the given codes into an output audio waveform. | |
| Args: | |
| audio_codes (`torch.LongTensor` of shape `(num_quantizers, batch_size, sequence_length)`): | |
| Discrete code embeddings computed using `model.encode`. | |
| padding_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): | |
| Mask to indicate valid code positions. | |
| return_dict (`bool`, *optional*): | |
| Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. | |
| chunk_duration (`float`, *optional*): | |
| If provided, decode the input codes in successive chunks of `chunk_duration` seconds while keeping a | |
| streaming KV cache for the causal transformers. | |
| num_quantizers (`int`, *optional*): | |
| Number of quantizers to use. By default, all quantizers in `audio_codes` are used. | |
| `chunk_duration` must be <= `config.causal_transformer_context_duration`, and | |
| `chunk_duration * config.sampling_rate` must be divisible by `config.downsample_rate`. | |
| Returns: | |
| `MossAudioTokenizerDecoderOutput` or tuple containing decoded audio. | |
| """ | |
| return_dict = return_dict if return_dict is not None else self.config.return_dict | |
| self._raise_if_plain_decode_conflicts_with_active_session() | |
| if audio_codes.dim() == 2: | |
| codes_list = [audio_codes[:num_quantizers] if num_quantizers is not None else audio_codes] | |
| elif audio_codes.dim() == 3: | |
| if num_quantizers is not None and num_quantizers > audio_codes.shape[0]: | |
| raise ValueError( | |
| f"`num_quantizers` ({num_quantizers}) must be <= audio_codes.shape[0] ({audio_codes.shape[0]})." | |
| ) | |
| codes_lengths = ( | |
| padding_mask.sum(dim=-1).long() | |
| if padding_mask is not None | |
| else torch.full((audio_codes.shape[1],), audio_codes.shape[-1], device=audio_codes.device, dtype=torch.long) | |
| ) | |
| codes_list = [ | |
| (audio_codes[:num_quantizers, i, : int(codes_lengths[i].item())] if num_quantizers is not None else audio_codes[:, i, : int(codes_lengths[i].item())]) | |
| for i in range(audio_codes.shape[1]) | |
| ] | |
| else: | |
| raise ValueError(f"Unsupported `audio_codes` shape: {tuple(audio_codes.shape)}") | |
| decoder_output = self.batch_decode(codes_list, num_quantizers=num_quantizers, chunk_duration=chunk_duration) | |
| if not return_dict: | |
| assert decoder_output.audio is not None | |
| return (cast(torch.Tensor, decoder_output.audio),) | |
| return decoder_output | |
| def forward( | |
| self, | |
| input_values: torch.FloatTensor | None = None, | |
| padding_mask: torch.BoolTensor | None = None, | |
| audio_codes: torch.Tensor | None = None, | |
| num_quantizers: int | None = None, | |
| return_dict: bool | None = None, | |
| ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None] | MossAudioTokenizerOutput: # type: ignore[override] | |
| r""" | |
| input_values (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*): | |
| Raw audio input converted to Float. | |
| padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): | |
| Mask to avoid computing on padding token indices. Mask values selected in `[0, 1]`: | |
| - 1 for tokens that are **not masked**, | |
| - 0 for tokens that are **masked**. | |
| audio_codes (`torch.LongTensor` of shape `(num_quantizers, batch_size, sequence_length)`, *optional*): | |
| Discrete code embeddings computed using `model.encode`. | |
| num_quantizers (`int`, *optional*): | |
| Number of quantizers (codebooks) to use. By default, all quantizers are used. | |
| return_dict (`bool`, *optional*): | |
| Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. | |
| Examples: | |
| ```python | |
| >>> import torch | |
| >>> from transformers import MossAudioTokenizerModel | |
| >>> model = MossAudioTokenizerModel.from_pretrained("moss_audio_tokenizer-model") | |
| >>> # Create dummy audio input | |
| >>> audio = torch.randn(1, 1, 24000) # 1 second of audio at 24kHz | |
| >>> outputs = model(input_values=audio) | |
| >>> audio_codes = outputs.audio_codes | |
| >>> audio_values = outputs.audio | |
| ``` | |
| """ | |
| return_dict = return_dict if return_dict is not None else self.config.return_dict | |
| output_audio_codes: torch.Tensor | None = None | |
| output_audio_codes_lengths: torch.Tensor | None = None | |
| output_audio: torch.Tensor | None = None | |
| output_audio_lengths: torch.Tensor | None = None | |
| decoded_from_encoded_codes = False | |
| # Encode if input_values provided | |
| if input_values is not None: | |
| encoder_output = self.encode(input_values, padding_mask, num_quantizers, return_dict=True) | |
| encoder_output = cast(MossAudioTokenizerEncoderOutput, encoder_output) | |
| output_audio_codes = encoder_output.audio_codes | |
| output_audio_codes_lengths = encoder_output.audio_codes_lengths | |
| # If codes not provided separately, use encoded codes for decoding | |
| if audio_codes is None: | |
| audio_codes = output_audio_codes | |
| decoded_from_encoded_codes = True | |
| # Decode if codes available | |
| if audio_codes is not None: | |
| # If we're decoding the codes we just produced, use the computed lengths so we don't decode padded garbage. | |
| if decoded_from_encoded_codes and output_audio_codes_lengths is not None: | |
| decoder_output = self._decode_frame(audio_codes, output_audio_codes_lengths) | |
| else: | |
| decoder_output = self.decode( | |
| audio_codes, | |
| padding_mask=padding_mask, | |
| return_dict=True, | |
| num_quantizers=num_quantizers, | |
| ) | |
| decoder_output = cast(MossAudioTokenizerDecoderOutput, decoder_output) | |
| output_audio = decoder_output.audio | |
| output_audio_lengths = decoder_output.audio_lengths | |
| if not return_dict: | |
| return (output_audio_codes, output_audio, output_audio_lengths) | |
| return MossAudioTokenizerOutput( | |
| audio=output_audio, | |
| audio_lengths=output_audio_lengths, | |
| audio_codes=output_audio_codes, | |
| audio_codes_lengths=output_audio_codes_lengths, | |
| ) | |
| __all__ = ["MossAudioTokenizerModel", "MossAudioTokenizerPreTrainedModel"] | |