Kuangwei Chen commited on
Commit ·
1c2bf4d
1
Parent(s): d4a3b2c
Update for realtime
Browse files- modeling_moss_audio_tokenizer.py +808 -21
modeling_moss_audio_tokenizer.py
CHANGED
|
@@ -17,14 +17,25 @@ from __future__ import annotations
|
|
| 17 |
|
| 18 |
import copy
|
| 19 |
import math
|
|
|
|
|
|
|
| 20 |
from contextlib import ExitStack, contextmanager
|
| 21 |
from dataclasses import dataclass
|
|
|
|
| 22 |
from typing import cast
|
| 23 |
|
| 24 |
import torch
|
| 25 |
import torch.nn as nn
|
| 26 |
import torch.nn.functional as F
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
try:
|
| 29 |
from transformers.modeling_utils import PreTrainedAudioTokenizerBase
|
| 30 |
except ImportError:
|
|
@@ -32,9 +43,12 @@ except ImportError:
|
|
| 32 |
from transformers.utils import ModelOutput, logging
|
| 33 |
|
| 34 |
try:
|
| 35 |
-
from transformers.utils import auto_docstring
|
| 36 |
except ImportError:
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
| 38 |
if len(args) == 1 and callable(args[0]) and not kwargs:
|
| 39 |
return args[0]
|
| 40 |
|
|
@@ -43,9 +57,35 @@ except ImportError:
|
|
| 43 |
|
| 44 |
return decorator
|
| 45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
try:
|
| 47 |
from .configuration_moss_audio_tokenizer import MossAudioTokenizerConfig
|
| 48 |
except ImportError:
|
|
|
|
|
|
|
|
|
|
| 49 |
from configuration_moss_audio_tokenizer import MossAudioTokenizerConfig
|
| 50 |
|
| 51 |
|
|
@@ -64,6 +104,25 @@ SUPPORTED_ATTENTION_IMPLEMENTATIONS = {"sdpa", "flash_attention_2"}
|
|
| 64 |
SUPPORTED_COMPUTE_DTYPES = {"fp32": None, "bf16": torch.bfloat16, "fp16": torch.float16}
|
| 65 |
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
def resolve_compute_dtype(compute_dtype: str) -> torch.dtype | None:
|
| 68 |
if compute_dtype not in SUPPORTED_COMPUTE_DTYPES:
|
| 69 |
raise ValueError(
|
|
@@ -83,6 +142,7 @@ def disable_cuda_autocast():
|
|
| 83 |
# =============================================================================
|
| 84 |
|
| 85 |
|
|
|
|
| 86 |
@dataclass
|
| 87 |
@auto_docstring
|
| 88 |
class MossAudioTokenizerEncoderOutput(ModelOutput):
|
|
@@ -100,6 +160,7 @@ class MossAudioTokenizerEncoderOutput(ModelOutput):
|
|
| 100 |
encoder_hidden_states: torch.Tensor | None = None
|
| 101 |
|
| 102 |
|
|
|
|
| 103 |
@dataclass
|
| 104 |
@auto_docstring
|
| 105 |
class MossAudioTokenizerDecoderOutput(ModelOutput):
|
|
@@ -114,6 +175,7 @@ class MossAudioTokenizerDecoderOutput(ModelOutput):
|
|
| 114 |
audio_lengths: torch.Tensor | None = None
|
| 115 |
|
| 116 |
|
|
|
|
| 117 |
@dataclass
|
| 118 |
@auto_docstring
|
| 119 |
class MossAudioTokenizerOutput(ModelOutput):
|
|
@@ -139,6 +201,7 @@ class MossAudioTokenizerOutput(ModelOutput):
|
|
| 139 |
# =============================================================================
|
| 140 |
|
| 141 |
|
|
|
|
| 142 |
@dataclass
|
| 143 |
class StreamingState:
|
| 144 |
"""Base state for streaming modules."""
|
|
@@ -228,6 +291,463 @@ class StreamingContainer(StreamingModule):
|
|
| 228 |
pass
|
| 229 |
|
| 230 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
# =============================================================================
|
| 232 |
# Normalization Layers
|
| 233 |
# =============================================================================
|
|
@@ -598,6 +1118,7 @@ class RingKVCache:
|
|
| 598 |
# =============================================================================
|
| 599 |
|
| 600 |
|
|
|
|
| 601 |
@dataclass
|
| 602 |
class MHAState(StreamingState):
|
| 603 |
cached_keys: torch.Tensor | None
|
|
@@ -677,6 +1198,7 @@ class MossAudioTokenizerMultiheadAttention(StreamingModule):
|
|
| 677 |
f"Expected one of {sorted(SUPPORTED_ATTENTION_IMPLEMENTATIONS)}."
|
| 678 |
)
|
| 679 |
self.attention_implementation = attention_implementation
|
|
|
|
| 680 |
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=False, **factory_kwargs)
|
| 681 |
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False, **factory_kwargs)
|
| 682 |
|
|
@@ -811,6 +1333,34 @@ class MossAudioTokenizerMultiheadAttention(StreamingModule):
|
|
| 811 |
state.cached_positions = state.cached_positions.to(device=device)
|
| 812 |
return state.cached_keys, state.cached_values, state.cached_positions
|
| 813 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 814 |
def _build_streaming_kv(
|
| 815 |
self,
|
| 816 |
cached_k: torch.Tensor,
|
|
@@ -845,12 +1395,15 @@ class MossAudioTokenizerMultiheadAttention(StreamingModule):
|
|
| 845 |
state.cached_positions = pos_k.contiguous()
|
| 846 |
return
|
| 847 |
|
|
|
|
|
|
|
|
|
|
| 848 |
new_cached_k = k_all[:, :, -self.context :, :].contiguous()
|
| 849 |
new_cached_v = v_all[:, :, -self.context :, :].contiguous()
|
| 850 |
new_cached_pos = pos_k[:, -self.context :].contiguous()
|
| 851 |
-
state.cached_keys
|
| 852 |
-
state.cached_values
|
| 853 |
-
state.cached_positions
|
| 854 |
|
| 855 |
def _build_streaming_sdpa_bias(self, pos_q: torch.Tensor, pos_k: torch.Tensor) -> torch.Tensor:
|
| 856 |
delta = pos_q[:, :, None] - pos_k[:, None, :]
|
|
@@ -890,16 +1443,19 @@ class MossAudioTokenizerMultiheadAttention(StreamingModule):
|
|
| 890 |
if flash_attn_varlen_func is None:
|
| 891 |
raise RuntimeError("flash-attn is not installed.")
|
| 892 |
window_size = (self.context, 0) if (self.context is not None and self.causal) else (-1, -1)
|
| 893 |
-
return
|
| 894 |
-
|
| 895 |
-
|
| 896 |
-
|
| 897 |
-
|
| 898 |
-
|
| 899 |
-
|
| 900 |
-
|
| 901 |
-
|
| 902 |
-
|
|
|
|
|
|
|
|
|
|
| 903 |
)
|
| 904 |
|
| 905 |
def _forward_streaming_sdpa(self, x: torch.Tensor, state: MHAState) -> torch.Tensor:
|
|
@@ -968,6 +1524,46 @@ class MossAudioTokenizerMultiheadAttention(StreamingModule):
|
|
| 968 |
state.offset[:] = torch.where(state.exec_mask, state.offset + chunk_length, state.offset)
|
| 969 |
return out
|
| 970 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 971 |
def _forward_non_streaming_sdpa(self, x: torch.Tensor, input_lengths: torch.Tensor) -> torch.Tensor:
|
| 972 |
batch_size, max_seqlen, _ = x.shape
|
| 973 |
q, k, v = self._project_qkv(x)
|
|
@@ -1009,11 +1605,12 @@ class MossAudioTokenizerMultiheadAttention(StreamingModule):
|
|
| 1009 |
if state is not None:
|
| 1010 |
if query.dim() != 3:
|
| 1011 |
raise ValueError(f"Streaming attention expects a 3D tensor, got shape {tuple(query.shape)}")
|
| 1012 |
-
|
| 1013 |
-
self.
|
| 1014 |
-
|
| 1015 |
-
|
| 1016 |
-
|
|
|
|
| 1017 |
return self.out_proj(out)
|
| 1018 |
|
| 1019 |
if backend == "flash_attention_2":
|
|
@@ -1037,6 +1634,7 @@ class MossAudioTokenizerMultiheadAttention(StreamingModule):
|
|
| 1037 |
# =============================================================================
|
| 1038 |
|
| 1039 |
|
|
|
|
| 1040 |
@dataclass
|
| 1041 |
class LayerState(StreamingState):
|
| 1042 |
pass
|
|
@@ -1128,6 +1726,7 @@ class MossAudioTokenizerTransformerLayer(StreamingModule):
|
|
| 1128 |
# =============================================================================
|
| 1129 |
|
| 1130 |
|
|
|
|
| 1131 |
@dataclass
|
| 1132 |
class TransformerState(StreamingState):
|
| 1133 |
offsets: torch.Tensor
|
|
@@ -1800,9 +2399,129 @@ class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel):
|
|
| 1800 |
)
|
| 1801 |
|
| 1802 |
self.post_init()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1803 |
|
| 1804 |
def _start_streaming(self, batch_size: int):
|
| 1805 |
"""Start streaming mode for all modules."""
|
|
|
|
|
|
|
|
|
|
| 1806 |
|
| 1807 |
def _start(module):
|
| 1808 |
if isinstance(module, StreamingModule):
|
|
@@ -1812,6 +2531,9 @@ class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel):
|
|
| 1812 |
|
| 1813 |
def _stop_streaming(self):
|
| 1814 |
"""Stop streaming mode for all modules."""
|
|
|
|
|
|
|
|
|
|
| 1815 |
|
| 1816 |
def _stop(module):
|
| 1817 |
if isinstance(module, StreamingModule):
|
|
@@ -2183,7 +2905,27 @@ class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel):
|
|
| 2183 |
codes_list: list[torch.Tensor],
|
| 2184 |
num_quantizers: int | None = None,
|
| 2185 |
chunk_duration: float | None = None,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2186 |
) -> MossAudioTokenizerDecoderOutput:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2187 |
audio_codes, audio_codes_lengths, num_quantizers_used = self._prepare_codes_batch(
|
| 2188 |
codes_list,
|
| 2189 |
num_quantizers=num_quantizers,
|
|
@@ -2191,9 +2933,53 @@ class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel):
|
|
| 2191 |
batch_size = len(codes_list)
|
| 2192 |
device = audio_codes.device
|
| 2193 |
|
| 2194 |
-
if chunk_duration is None:
|
| 2195 |
return self._decode_frame(audio_codes, audio_codes_lengths)
|
| 2196 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2197 |
if chunk_duration <= 0:
|
| 2198 |
raise ValueError("`chunk_duration` must be > 0 when provided.")
|
| 2199 |
|
|
@@ -2366,6 +3152,7 @@ class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel):
|
|
| 2366 |
`MossAudioTokenizerDecoderOutput` or tuple containing decoded audio.
|
| 2367 |
"""
|
| 2368 |
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
|
|
|
| 2369 |
|
| 2370 |
if audio_codes.dim() == 2:
|
| 2371 |
codes_list = [audio_codes[:num_quantizers] if num_quantizers is not None else audio_codes]
|
|
|
|
| 17 |
|
| 18 |
import copy
|
| 19 |
import math
|
| 20 |
+
import sys
|
| 21 |
+
import types
|
| 22 |
from contextlib import ExitStack, contextmanager
|
| 23 |
from dataclasses import dataclass
|
| 24 |
+
from pathlib import Path
|
| 25 |
from typing import cast
|
| 26 |
|
| 27 |
import torch
|
| 28 |
import torch.nn as nn
|
| 29 |
import torch.nn.functional as F
|
| 30 |
|
| 31 |
+
if __name__ not in sys.modules:
|
| 32 |
+
_module_proxy = types.ModuleType(__name__)
|
| 33 |
+
sys.modules[__name__] = _module_proxy
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _sync_module_proxy() -> None:
|
| 37 |
+
sys.modules[__name__].__dict__.update(globals())
|
| 38 |
+
|
| 39 |
try:
|
| 40 |
from transformers.modeling_utils import PreTrainedAudioTokenizerBase
|
| 41 |
except ImportError:
|
|
|
|
| 43 |
from transformers.utils import ModelOutput, logging
|
| 44 |
|
| 45 |
try:
|
| 46 |
+
from transformers.utils import auto_docstring as _hf_auto_docstring
|
| 47 |
except ImportError:
|
| 48 |
+
_hf_auto_docstring = None
|
| 49 |
+
|
| 50 |
+
def auto_docstring(*args, **kwargs):
|
| 51 |
+
if _hf_auto_docstring is None:
|
| 52 |
if len(args) == 1 and callable(args[0]) and not kwargs:
|
| 53 |
return args[0]
|
| 54 |
|
|
|
|
| 57 |
|
| 58 |
return decorator
|
| 59 |
|
| 60 |
+
if len(args) == 1 and callable(args[0]) and not kwargs:
|
| 61 |
+
obj = args[0]
|
| 62 |
+
try:
|
| 63 |
+
return _hf_auto_docstring(obj)
|
| 64 |
+
except Exception:
|
| 65 |
+
return obj
|
| 66 |
+
|
| 67 |
+
try:
|
| 68 |
+
decorator = _hf_auto_docstring(*args, **kwargs)
|
| 69 |
+
except Exception:
|
| 70 |
+
def decorator(obj):
|
| 71 |
+
return obj
|
| 72 |
+
|
| 73 |
+
return decorator
|
| 74 |
+
|
| 75 |
+
def safe_decorator(obj):
|
| 76 |
+
try:
|
| 77 |
+
return decorator(obj)
|
| 78 |
+
except Exception:
|
| 79 |
+
return obj
|
| 80 |
+
|
| 81 |
+
return safe_decorator
|
| 82 |
+
|
| 83 |
try:
|
| 84 |
from .configuration_moss_audio_tokenizer import MossAudioTokenizerConfig
|
| 85 |
except ImportError:
|
| 86 |
+
_module_dir = str(Path(__file__).resolve().parent)
|
| 87 |
+
if _module_dir not in sys.path:
|
| 88 |
+
sys.path.insert(0, _module_dir)
|
| 89 |
from configuration_moss_audio_tokenizer import MossAudioTokenizerConfig
|
| 90 |
|
| 91 |
|
|
|
|
| 104 |
SUPPORTED_COMPUTE_DTYPES = {"fp32": None, "bf16": torch.bfloat16, "fp16": torch.float16}
|
| 105 |
|
| 106 |
|
| 107 |
+
_ACTIVE_DECODE_SESSION_ERROR_MESSAGE = "MossAudioTokenizerModel only supports one active decode session at a time."
|
| 108 |
+
_CLOSED_DECODE_SESSION_ERROR_MESSAGE = "This decode session is closed."
|
| 109 |
+
_MODEL_STREAMING_CONFLICT_ERROR_MESSAGE = "Model-level streaming helpers cannot be used while a decode session is active."
|
| 110 |
+
_PLAIN_DECODE_SESSION_CONFLICT_ERROR_MESSAGE = "Plain decode helpers cannot be used while a decode session is active."
|
| 111 |
+
_DUPLICATE_DECODE_REQUEST_ERROR_TEMPLATE = "Decode session already contains request_id={request_id!r}."
|
| 112 |
+
_UNKNOWN_DECODE_REQUEST_ERROR_TEMPLATE = "Decode session does not contain an active request_id={request_id!r}."
|
| 113 |
+
_DECODE_SESSION_FULL_ERROR_TEMPLATE = "Decode session has no free slots remaining (max_batch_size={max_batch_size})."
|
| 114 |
+
_INVALID_DECODE_STEP_REQUEST_IDS_ERROR_MESSAGE = (
|
| 115 |
+
"`request_ids` must exactly match the current active decode request order."
|
| 116 |
+
)
|
| 117 |
+
_BATCH_DECODE_STREAMING_DUPLICATE_FINALIZE_INDICES_ERROR_MESSAGE = "`finalize_indices` must not contain duplicates."
|
| 118 |
+
_BATCH_DECODE_STREAMING_FINALIZE_INDEX_OUT_OF_RANGE_ERROR_TEMPLATE = (
|
| 119 |
+
"`finalize_indices` index {index} is out of range for the pre-call logical batch of size {batch_size}."
|
| 120 |
+
)
|
| 121 |
+
_BATCH_DECODE_STREAMING_SHRINK_ERROR_MESSAGE = (
|
| 122 |
+
"`batch_decode(streaming=True)` must include all pre-call active rows in the current call before applying `finalize_indices`."
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
def resolve_compute_dtype(compute_dtype: str) -> torch.dtype | None:
|
| 127 |
if compute_dtype not in SUPPORTED_COMPUTE_DTYPES:
|
| 128 |
raise ValueError(
|
|
|
|
| 142 |
# =============================================================================
|
| 143 |
|
| 144 |
|
| 145 |
+
_sync_module_proxy()
|
| 146 |
@dataclass
|
| 147 |
@auto_docstring
|
| 148 |
class MossAudioTokenizerEncoderOutput(ModelOutput):
|
|
|
|
| 160 |
encoder_hidden_states: torch.Tensor | None = None
|
| 161 |
|
| 162 |
|
| 163 |
+
_sync_module_proxy()
|
| 164 |
@dataclass
|
| 165 |
@auto_docstring
|
| 166 |
class MossAudioTokenizerDecoderOutput(ModelOutput):
|
|
|
|
| 175 |
audio_lengths: torch.Tensor | None = None
|
| 176 |
|
| 177 |
|
| 178 |
+
_sync_module_proxy()
|
| 179 |
@dataclass
|
| 180 |
@auto_docstring
|
| 181 |
class MossAudioTokenizerOutput(ModelOutput):
|
|
|
|
| 201 |
# =============================================================================
|
| 202 |
|
| 203 |
|
| 204 |
+
_sync_module_proxy()
|
| 205 |
@dataclass
|
| 206 |
class StreamingState:
|
| 207 |
"""Base state for streaming modules."""
|
|
|
|
| 291 |
pass
|
| 292 |
|
| 293 |
|
| 294 |
+
class MossAudioTokenizerDecodeSession:
|
| 295 |
+
model: MossAudioTokenizerModel
|
| 296 |
+
max_batch_size: int
|
| 297 |
+
_use_cuda_graph: bool
|
| 298 |
+
active_request_ids: list[str | int]
|
| 299 |
+
request_id_to_slot_index: dict[str | int, int]
|
| 300 |
+
slot_index_to_request_id: list[str | int | None]
|
| 301 |
+
slot_is_free: list[bool]
|
| 302 |
+
request_id_to_code_offset: dict[str | int, int]
|
| 303 |
+
request_id_to_audio_offset: dict[str | int, int]
|
| 304 |
+
_flash_kvcache_attention_modules: list[MossAudioTokenizerMultiheadAttention]
|
| 305 |
+
_graph_num_quantizers_capacity: int | None
|
| 306 |
+
_graph_input_codes: torch.Tensor | None
|
| 307 |
+
_graph_input_code_lengths: torch.Tensor | None
|
| 308 |
+
_graph_output_audio: torch.Tensor | None
|
| 309 |
+
_graph_output_audio_lengths: torch.Tensor | None
|
| 310 |
+
_cuda_graph: torch.cuda.CUDAGraph | None
|
| 311 |
+
_cuda_graph_key: tuple[str, int, int, str] | None
|
| 312 |
+
_decode_streaming_exit_stack: ExitStack | None
|
| 313 |
+
_closed: bool
|
| 314 |
+
|
| 315 |
+
def __init__(self, model: MossAudioTokenizerModel, max_batch_size: int, use_cuda_graph: bool = False):
|
| 316 |
+
if max_batch_size <= 0:
|
| 317 |
+
raise ValueError("`max_batch_size` must be > 0.")
|
| 318 |
+
|
| 319 |
+
decoder_attention_modules: list[MossAudioTokenizerMultiheadAttention] = []
|
| 320 |
+
for decoder_module in model.decoder:
|
| 321 |
+
for module in decoder_module.modules():
|
| 322 |
+
if isinstance(module, MossAudioTokenizerMultiheadAttention):
|
| 323 |
+
if module.context is None:
|
| 324 |
+
raise ValueError(
|
| 325 |
+
"MossAudioTokenizerDecodeSession requires all decoder MHA modules to have a finite "
|
| 326 |
+
"`context` (context=None is unsupported for continuous-batch streaming)."
|
| 327 |
+
)
|
| 328 |
+
decoder_attention_modules.append(module)
|
| 329 |
+
|
| 330 |
+
flash_kvcache_attention_modules: list[MossAudioTokenizerMultiheadAttention] = []
|
| 331 |
+
if use_cuda_graph and HAS_FLASH_ATTN:
|
| 332 |
+
for module in decoder_attention_modules:
|
| 333 |
+
module._use_flash_kvcache = True
|
| 334 |
+
flash_kvcache_attention_modules.append(module)
|
| 335 |
+
|
| 336 |
+
decode_streaming_exit_stack = ExitStack()
|
| 337 |
+
try:
|
| 338 |
+
for decoder_module in model.decoder:
|
| 339 |
+
if isinstance(decoder_module, StreamingModule):
|
| 340 |
+
inner_stack = decoder_module.streaming(batch_size=max_batch_size)
|
| 341 |
+
_ = decode_streaming_exit_stack.enter_context(inner_stack)
|
| 342 |
+
except Exception:
|
| 343 |
+
decode_streaming_exit_stack.close()
|
| 344 |
+
for module in flash_kvcache_attention_modules:
|
| 345 |
+
module._use_flash_kvcache = False
|
| 346 |
+
raise
|
| 347 |
+
|
| 348 |
+
self.model = model
|
| 349 |
+
self.max_batch_size = max_batch_size
|
| 350 |
+
self._use_cuda_graph = use_cuda_graph
|
| 351 |
+
self.active_request_ids: list[str | int] = []
|
| 352 |
+
self.request_id_to_slot_index: dict[str | int, int] = {}
|
| 353 |
+
self.slot_index_to_request_id: list[str | int | None] = [None] * max_batch_size
|
| 354 |
+
self.slot_is_free: list[bool] = [True] * max_batch_size
|
| 355 |
+
self.request_id_to_code_offset: dict[str | int, int] = {}
|
| 356 |
+
self.request_id_to_audio_offset: dict[str | int, int] = {}
|
| 357 |
+
self._flash_kvcache_attention_modules = flash_kvcache_attention_modules
|
| 358 |
+
self._graph_num_quantizers_capacity = int(getattr(model.quantizer, "num_quantizers", 0)) if use_cuda_graph else None
|
| 359 |
+
self._graph_input_codes = None
|
| 360 |
+
self._graph_input_code_lengths = None
|
| 361 |
+
self._graph_output_audio = None
|
| 362 |
+
self._graph_output_audio_lengths = None
|
| 363 |
+
self._cuda_graph = None
|
| 364 |
+
self._cuda_graph_key = None
|
| 365 |
+
self._decode_streaming_exit_stack: ExitStack | None = decode_streaming_exit_stack
|
| 366 |
+
self._closed = False
|
| 367 |
+
if use_cuda_graph:
|
| 368 |
+
device = next(iter(model.parameters())).device
|
| 369 |
+
if device.type == "cuda":
|
| 370 |
+
self._ensure_cuda_graph_buffers(device)
|
| 371 |
+
model._active_decode_session = self
|
| 372 |
+
|
| 373 |
+
def _ensure_open(self) -> None:
|
| 374 |
+
if self._closed:
|
| 375 |
+
raise RuntimeError(_CLOSED_DECODE_SESSION_ERROR_MESSAGE)
|
| 376 |
+
|
| 377 |
+
def append(self, request_id: str | int) -> None:
|
| 378 |
+
self._ensure_open()
|
| 379 |
+
|
| 380 |
+
if request_id in self.request_id_to_slot_index:
|
| 381 |
+
raise RuntimeError(_DUPLICATE_DECODE_REQUEST_ERROR_TEMPLATE.format(request_id=request_id))
|
| 382 |
+
|
| 383 |
+
slot_index = next((index for index, is_free in enumerate(self.slot_is_free) if is_free), None)
|
| 384 |
+
if slot_index is None:
|
| 385 |
+
raise RuntimeError(_DECODE_SESSION_FULL_ERROR_TEMPLATE.format(max_batch_size=self.max_batch_size))
|
| 386 |
+
|
| 387 |
+
self.active_request_ids.append(request_id)
|
| 388 |
+
self.request_id_to_slot_index[request_id] = slot_index
|
| 389 |
+
self.slot_index_to_request_id[slot_index] = request_id
|
| 390 |
+
self.slot_is_free[slot_index] = False
|
| 391 |
+
self.request_id_to_code_offset[request_id] = 0
|
| 392 |
+
self.request_id_to_audio_offset[request_id] = 0
|
| 393 |
+
|
| 394 |
+
def _decoder_streaming_states(self) -> list[StreamingState]:
|
| 395 |
+
decoder_streaming_states: list[StreamingState] = []
|
| 396 |
+
for decoder_module in self.model.decoder:
|
| 397 |
+
for module in decoder_module.modules():
|
| 398 |
+
if isinstance(module, StreamingModule) and module._streaming_state is not None:
|
| 399 |
+
decoder_streaming_states.append(module._streaming_state)
|
| 400 |
+
return decoder_streaming_states
|
| 401 |
+
|
| 402 |
+
def _ensure_cuda_graph_buffers(self, device: torch.device) -> None:
|
| 403 |
+
if not self._use_cuda_graph or device.type != "cuda":
|
| 404 |
+
return
|
| 405 |
+
graph_num_quantizers_capacity = self._graph_num_quantizers_capacity
|
| 406 |
+
if graph_num_quantizers_capacity is None:
|
| 407 |
+
graph_num_quantizers_capacity = int(getattr(self.model.quantizer, "num_quantizers", 0))
|
| 408 |
+
self._graph_num_quantizers_capacity = graph_num_quantizers_capacity
|
| 409 |
+
if graph_num_quantizers_capacity <= 0:
|
| 410 |
+
raise RuntimeError("`use_cuda_graph=True` requires a quantizer with `num_quantizers > 0`.")
|
| 411 |
+
if self._graph_input_codes is None or self._graph_input_codes.device != device:
|
| 412 |
+
self._graph_input_codes = torch.zeros(
|
| 413 |
+
(graph_num_quantizers_capacity, self.max_batch_size, 1),
|
| 414 |
+
device=device,
|
| 415 |
+
dtype=torch.long,
|
| 416 |
+
)
|
| 417 |
+
self._graph_input_code_lengths = torch.zeros(self.max_batch_size, device=device, dtype=torch.long)
|
| 418 |
+
self._graph_output_audio = None
|
| 419 |
+
self._graph_output_audio_lengths = None
|
| 420 |
+
self._cuda_graph = None
|
| 421 |
+
self._cuda_graph_key = None
|
| 422 |
+
|
| 423 |
+
def _snapshot_decoder_streaming_states(self) -> list[tuple[StreamingState, dict[str, torch.Tensor | None]]]:
|
| 424 |
+
snapshots: list[tuple[StreamingState, dict[str, torch.Tensor | None]]] = []
|
| 425 |
+
for streaming_state in self._decoder_streaming_states():
|
| 426 |
+
state_snapshot: dict[str, torch.Tensor | None] = {"exec_mask": streaming_state.exec_mask.clone()}
|
| 427 |
+
if isinstance(streaming_state, TransformerState):
|
| 428 |
+
state_snapshot["offsets"] = streaming_state.offsets.clone()
|
| 429 |
+
if isinstance(streaming_state, MHAState):
|
| 430 |
+
state_snapshot["offset"] = streaming_state.offset.clone()
|
| 431 |
+
state_snapshot["cached_keys"] = None if streaming_state.cached_keys is None else streaming_state.cached_keys.clone()
|
| 432 |
+
state_snapshot["cached_values"] = None if streaming_state.cached_values is None else streaming_state.cached_values.clone()
|
| 433 |
+
state_snapshot["cached_positions"] = (
|
| 434 |
+
None if streaming_state.cached_positions is None else streaming_state.cached_positions.clone()
|
| 435 |
+
)
|
| 436 |
+
state_snapshot["flash_cached_keys"] = (
|
| 437 |
+
None
|
| 438 |
+
if getattr(streaming_state, "_flash_cached_keys", None) is None
|
| 439 |
+
else cast(torch.Tensor, getattr(streaming_state, "_flash_cached_keys")).clone()
|
| 440 |
+
)
|
| 441 |
+
state_snapshot["flash_cached_values"] = (
|
| 442 |
+
None
|
| 443 |
+
if getattr(streaming_state, "_flash_cached_values", None) is None
|
| 444 |
+
else cast(torch.Tensor, getattr(streaming_state, "_flash_cached_values")).clone()
|
| 445 |
+
)
|
| 446 |
+
snapshots.append((streaming_state, state_snapshot))
|
| 447 |
+
return snapshots
|
| 448 |
+
|
| 449 |
+
def _restore_decoder_streaming_states(
|
| 450 |
+
self,
|
| 451 |
+
snapshots: list[tuple[StreamingState, dict[str, torch.Tensor | None]]],
|
| 452 |
+
) -> None:
|
| 453 |
+
for streaming_state, state_snapshot in snapshots:
|
| 454 |
+
exec_mask = state_snapshot["exec_mask"]
|
| 455 |
+
assert exec_mask is not None
|
| 456 |
+
streaming_state.exec_mask.copy_(exec_mask)
|
| 457 |
+
if isinstance(streaming_state, TransformerState):
|
| 458 |
+
offsets = state_snapshot.get("offsets")
|
| 459 |
+
assert offsets is not None
|
| 460 |
+
streaming_state.offsets.copy_(offsets)
|
| 461 |
+
if isinstance(streaming_state, MHAState):
|
| 462 |
+
offset = state_snapshot.get("offset")
|
| 463 |
+
assert offset is not None
|
| 464 |
+
streaming_state.offset.copy_(offset)
|
| 465 |
+
cached_keys = state_snapshot.get("cached_keys")
|
| 466 |
+
cached_values = state_snapshot.get("cached_values")
|
| 467 |
+
cached_positions = state_snapshot.get("cached_positions")
|
| 468 |
+
if cached_keys is None or cached_values is None or cached_positions is None:
|
| 469 |
+
if streaming_state.cached_keys is not None:
|
| 470 |
+
streaming_state.cached_keys.zero_()
|
| 471 |
+
if streaming_state.cached_values is not None:
|
| 472 |
+
streaming_state.cached_values.zero_()
|
| 473 |
+
if streaming_state.cached_positions is not None:
|
| 474 |
+
streaming_state.cached_positions.fill_(-1)
|
| 475 |
+
else:
|
| 476 |
+
if streaming_state.cached_keys is None or streaming_state.cached_keys.shape != cached_keys.shape:
|
| 477 |
+
streaming_state.cached_keys = cached_keys.clone()
|
| 478 |
+
else:
|
| 479 |
+
streaming_state.cached_keys.copy_(cached_keys)
|
| 480 |
+
if streaming_state.cached_values is None or streaming_state.cached_values.shape != cached_values.shape:
|
| 481 |
+
streaming_state.cached_values = cached_values.clone()
|
| 482 |
+
else:
|
| 483 |
+
streaming_state.cached_values.copy_(cached_values)
|
| 484 |
+
if streaming_state.cached_positions is None or streaming_state.cached_positions.shape != cached_positions.shape:
|
| 485 |
+
streaming_state.cached_positions = cached_positions.clone()
|
| 486 |
+
else:
|
| 487 |
+
streaming_state.cached_positions.copy_(cached_positions)
|
| 488 |
+
|
| 489 |
+
flash_cached_keys = state_snapshot.get("flash_cached_keys")
|
| 490 |
+
flash_cached_values = state_snapshot.get("flash_cached_values")
|
| 491 |
+
current_flash_cached_keys = cast(torch.Tensor | None, getattr(streaming_state, "_flash_cached_keys", None))
|
| 492 |
+
current_flash_cached_values = cast(torch.Tensor | None, getattr(streaming_state, "_flash_cached_values", None))
|
| 493 |
+
if flash_cached_keys is None or flash_cached_values is None:
|
| 494 |
+
if current_flash_cached_keys is not None:
|
| 495 |
+
current_flash_cached_keys.zero_()
|
| 496 |
+
if current_flash_cached_values is not None:
|
| 497 |
+
current_flash_cached_values.zero_()
|
| 498 |
+
else:
|
| 499 |
+
if current_flash_cached_keys is None or current_flash_cached_keys.shape != flash_cached_keys.shape:
|
| 500 |
+
setattr(streaming_state, "_flash_cached_keys", flash_cached_keys.clone())
|
| 501 |
+
else:
|
| 502 |
+
current_flash_cached_keys.copy_(flash_cached_keys)
|
| 503 |
+
if current_flash_cached_values is None or current_flash_cached_values.shape != flash_cached_values.shape:
|
| 504 |
+
setattr(streaming_state, "_flash_cached_values", flash_cached_values.clone())
|
| 505 |
+
else:
|
| 506 |
+
current_flash_cached_values.copy_(flash_cached_values)
|
| 507 |
+
|
| 508 |
+
def _graphed_decode_frame(
|
| 509 |
+
self,
|
| 510 |
+
codes: torch.Tensor,
|
| 511 |
+
code_lengths: torch.Tensor,
|
| 512 |
+
) -> MossAudioTokenizerDecoderOutput:
|
| 513 |
+
self._ensure_cuda_graph_buffers(codes.device)
|
| 514 |
+
graph_input_codes = self._graph_input_codes
|
| 515 |
+
graph_input_code_lengths = self._graph_input_code_lengths
|
| 516 |
+
if graph_input_codes is None or graph_input_code_lengths is None:
|
| 517 |
+
raise RuntimeError("CUDA graph buffers are unavailable.")
|
| 518 |
+
|
| 519 |
+
num_quantizers = codes.shape[0]
|
| 520 |
+
graph_input_codes_view = graph_input_codes[:num_quantizers]
|
| 521 |
+
graph_input_codes_view.copy_(codes)
|
| 522 |
+
graph_input_code_lengths.copy_(code_lengths)
|
| 523 |
+
cuda_graph_key = (str(codes.device), self.max_batch_size, num_quantizers, self.model.compute_dtype_name)
|
| 524 |
+
|
| 525 |
+
if self._cuda_graph is None or self._cuda_graph_key != cuda_graph_key:
|
| 526 |
+
state_snapshots = self._snapshot_decoder_streaming_states()
|
| 527 |
+
current_stream = torch.cuda.current_stream(device=codes.device)
|
| 528 |
+
warmup_stream = torch.cuda.Stream(device=codes.device)
|
| 529 |
+
warmup_stream.wait_stream(current_stream)
|
| 530 |
+
with torch.cuda.stream(warmup_stream):
|
| 531 |
+
_ = self.model._decode_frame(graph_input_codes_view, graph_input_code_lengths)
|
| 532 |
+
current_stream.wait_stream(warmup_stream)
|
| 533 |
+
self._restore_decoder_streaming_states(state_snapshots)
|
| 534 |
+
|
| 535 |
+
cuda_graph = torch.cuda.CUDAGraph()
|
| 536 |
+
with torch.cuda.graph(cuda_graph):
|
| 537 |
+
decoder_output = self.model._decode_frame(graph_input_codes_view, graph_input_code_lengths)
|
| 538 |
+
|
| 539 |
+
self._cuda_graph = cuda_graph
|
| 540 |
+
self._cuda_graph_key = cuda_graph_key
|
| 541 |
+
self._graph_output_audio = decoder_output.audio
|
| 542 |
+
self._graph_output_audio_lengths = decoder_output.audio_lengths
|
| 543 |
+
else:
|
| 544 |
+
self._cuda_graph.replay()
|
| 545 |
+
|
| 546 |
+
return MossAudioTokenizerDecoderOutput(
|
| 547 |
+
audio=self._graph_output_audio,
|
| 548 |
+
audio_lengths=self._graph_output_audio_lengths,
|
| 549 |
+
)
|
| 550 |
+
|
| 551 |
+
def _reset_slot(self, slot_index: int) -> None:
|
| 552 |
+
for streaming_state in self._decoder_streaming_states():
|
| 553 |
+
reset_mask = torch.zeros(streaming_state.batch_size, dtype=torch.bool, device=streaming_state.exec_mask.device)
|
| 554 |
+
reset_mask[slot_index] = True
|
| 555 |
+
streaming_state.reset(reset_mask)
|
| 556 |
+
|
| 557 |
+
def _pack_logical_codes_to_physical_slots(
|
| 558 |
+
self,
|
| 559 |
+
request_ids: list[str | int],
|
| 560 |
+
codes: torch.Tensor,
|
| 561 |
+
code_lengths: torch.Tensor,
|
| 562 |
+
) -> tuple[torch.Tensor, torch.Tensor, list[int], torch.Tensor]:
|
| 563 |
+
if request_ids != self.active_request_ids:
|
| 564 |
+
raise ValueError(_INVALID_DECODE_STEP_REQUEST_IDS_ERROR_MESSAGE)
|
| 565 |
+
|
| 566 |
+
if not request_ids:
|
| 567 |
+
raise ValueError("`step()` requires at least one active request.")
|
| 568 |
+
|
| 569 |
+
if codes.dim() == 2:
|
| 570 |
+
codes = codes.unsqueeze(1)
|
| 571 |
+
if codes.dim() != 3:
|
| 572 |
+
raise ValueError(f"`codes` must be 3D with shape `(num_quantizers, batch_size, sequence_length)`, got {codes.shape}.")
|
| 573 |
+
|
| 574 |
+
code_lengths = code_lengths.to(device=codes.device, dtype=torch.long)
|
| 575 |
+
if code_lengths.dim() != 1:
|
| 576 |
+
raise ValueError(f"`code_lengths` must be 1D with shape `(batch_size,)`, got {code_lengths.shape}.")
|
| 577 |
+
|
| 578 |
+
num_quantizers, logical_batch_size, max_code_length = codes.shape
|
| 579 |
+
if logical_batch_size != len(request_ids):
|
| 580 |
+
raise ValueError(
|
| 581 |
+
f"`codes.shape[1]` ({logical_batch_size}) must match len(`request_ids`) ({len(request_ids)})."
|
| 582 |
+
)
|
| 583 |
+
if code_lengths.shape[0] != logical_batch_size:
|
| 584 |
+
raise ValueError(
|
| 585 |
+
f"`code_lengths.shape[0]` ({code_lengths.shape[0]}) must match len(`request_ids`) ({len(request_ids)})."
|
| 586 |
+
)
|
| 587 |
+
if torch.any(code_lengths < 0):
|
| 588 |
+
raise ValueError("`code_lengths` must be >= 0.")
|
| 589 |
+
if torch.any(code_lengths > max_code_length):
|
| 590 |
+
raise ValueError(f"`code_lengths` must be <= codes.shape[-1] ({max_code_length}).")
|
| 591 |
+
|
| 592 |
+
packed_codes = codes.new_zeros((num_quantizers, self.max_batch_size, max_code_length))
|
| 593 |
+
packed_code_lengths = code_lengths.new_zeros((self.max_batch_size,))
|
| 594 |
+
logical_row_to_slot_index: list[int] = []
|
| 595 |
+
|
| 596 |
+
for logical_row_index, request_id in enumerate(request_ids):
|
| 597 |
+
slot_index = self.request_id_to_slot_index[request_id]
|
| 598 |
+
logical_row_to_slot_index.append(slot_index)
|
| 599 |
+
row_length = int(code_lengths[logical_row_index].item())
|
| 600 |
+
if row_length > 0:
|
| 601 |
+
packed_codes[:, slot_index, :row_length] = codes[:, logical_row_index, :row_length]
|
| 602 |
+
packed_code_lengths[slot_index] = row_length
|
| 603 |
+
|
| 604 |
+
return packed_codes, packed_code_lengths, logical_row_to_slot_index, code_lengths
|
| 605 |
+
|
| 606 |
+
def _advance_request_progress(
|
| 607 |
+
self,
|
| 608 |
+
request_ids: list[str | int],
|
| 609 |
+
code_lengths: torch.Tensor,
|
| 610 |
+
audio_lengths: torch.Tensor,
|
| 611 |
+
) -> None:
|
| 612 |
+
for logical_row_index, request_id in enumerate(request_ids):
|
| 613 |
+
self.request_id_to_code_offset[request_id] += int(code_lengths[logical_row_index].item())
|
| 614 |
+
self.request_id_to_audio_offset[request_id] += int(audio_lengths[logical_row_index].item())
|
| 615 |
+
|
| 616 |
+
def step(
    self,
    request_ids: list[str | int],
    codes: torch.Tensor,
    code_lengths: torch.Tensor,
) -> tuple[list[str | int], torch.Tensor, torch.Tensor]:
    """Decode one batch of new codes for all active requests, frame by frame.

    `request_ids` must equal the session's current `active_request_ids`
    (enforced by `_pack_logical_codes_to_physical_slots`). Returns
    `(request_ids, audio, audio_lengths)` where `audio` is packed in logical
    (caller) row order, right-padded with zeros to the longest row.

    Raises:
        ValueError: if no row has `code_length > 0` or the inputs are malformed.
        RuntimeError: if the underlying frame decoder produces no audio.
    """
    self._ensure_open()

    # Scatter logical rows into the fixed physical slot layout reserved at
    # session creation; `logical_code_lengths` keeps the caller-order lengths.
    packed_codes, packed_code_lengths, logical_row_to_slot_index, logical_code_lengths = (
        self._pack_logical_codes_to_physical_slots(
            request_ids=request_ids,
            codes=codes,
            code_lengths=code_lengths,
        )
    )
    max_step_length = int(packed_code_lengths.max().item())

    if max_step_length <= 0:
        raise ValueError("`step()` requires at least one row with `code_length > 0`.")

    decoder_streaming_states = self._decoder_streaming_states()
    # Per logical row, the audio chunks produced across frames (concatenated later).
    logical_audio_chunks: list[list[torch.Tensor]] = [[] for _ in request_ids]
    audio_device: torch.device | None = None
    audio_dtype: torch.dtype | None = None
    audio_num_channels: int | None = None

    try:
        # Decode one code frame at a time so rows of unequal length can share
        # the batch: rows that ran out of codes are masked out via exec_mask.
        for frame_index in range(max_step_length):
            frame_exec_mask = packed_code_lengths > frame_index
            for streaming_state in decoder_streaming_states:
                streaming_state.set_exec_mask(frame_exec_mask)

            frame_codes = packed_codes[:, :, frame_index : frame_index + 1]
            # Active rows contribute exactly one code this frame; inactive rows zero.
            frame_code_lengths = frame_exec_mask.to(dtype=packed_code_lengths.dtype)
            if self._use_cuda_graph and frame_codes.is_cuda:
                decoder_output = self._graphed_decode_frame(frame_codes, frame_code_lengths)
            else:
                decoder_output = self.model._decode_frame(frame_codes, frame_code_lengths)

            if decoder_output.audio is None or decoder_output.audio_lengths is None:
                raise RuntimeError("Internal error: `_decode_frame` returned empty audio.")

            audio = decoder_output.audio
            audio_lengths = decoder_output.audio_lengths
            # Recorded every frame; the last frame's metadata is what survives.
            audio_device = audio.device
            audio_dtype = audio.dtype
            audio_num_channels = audio.shape[1]

            for logical_row_index, slot_index in enumerate(logical_row_to_slot_index):
                audio_length = int(audio_lengths[slot_index].item())
                if audio_length <= 0:
                    continue
                logical_audio_chunks[logical_row_index].append(audio[slot_index : slot_index + 1, :, :audio_length])
    except Exception:
        # The session's streaming state is unrecoverable mid-step; tear it down.
        self.close()
        raise
    finally:
        # Restore the all-active exec mask even on failure (states still exist
        # after close(); this leaves them in their default configuration).
        for streaming_state in decoder_streaming_states:
            streaming_state.set_exec_mask(torch.ones_like(streaming_state.exec_mask))

    if audio_device is None or audio_dtype is None or audio_num_channels is None:
        raise RuntimeError("Internal error: `step()` produced no decoder outputs.")

    # Concatenate each logical row's per-frame chunks along the time axis;
    # rows that produced nothing become empty (length-0) audio.
    logical_audio_rows: list[torch.Tensor] = []
    logical_audio_lengths: list[int] = []
    for row_chunks in logical_audio_chunks:
        if row_chunks:
            row_audio = torch.cat(row_chunks, dim=-1)
        else:
            row_audio = torch.zeros((1, audio_num_channels, 0), device=audio_device, dtype=audio_dtype)
        logical_audio_rows.append(row_audio)
        logical_audio_lengths.append(row_audio.shape[-1])

    audio_lengths = torch.tensor(logical_audio_lengths, device=audio_device, dtype=torch.long)
    max_audio_length = max(logical_audio_lengths)
    # Right-pad every row with zeros to the longest row in this step.
    audio = torch.zeros(
        (len(request_ids), audio_num_channels, max_audio_length),
        device=audio_device,
        dtype=audio_dtype,
    )
    for logical_row_index, row_audio in enumerate(logical_audio_rows):
        row_audio_length = row_audio.shape[-1]
        if row_audio_length > 0:
            audio[logical_row_index, :, :row_audio_length] = row_audio[0]

    logical_request_ids = list(request_ids)
    self._advance_request_progress(
        request_ids=logical_request_ids,
        code_lengths=logical_code_lengths,
        audio_lengths=audio_lengths,
    )

    return logical_request_ids, audio, audio_lengths
|
| 709 |
+
|
| 710 |
+
def remove(self, request_id: str | int) -> None:
    """Detach a finished request from the session and release its physical slot."""
    self._ensure_open()

    slot_index = self.request_id_to_slot_index.get(request_id)
    # Every stale-mapping situation surfaces as the same "unknown request"
    # error; the short-circuit OR keeps the None slot from being indexed.
    request_is_unknown = (
        slot_index is None
        or request_id not in self.active_request_ids
        or self.slot_is_free[slot_index]
        or self.slot_index_to_request_id[slot_index] != request_id
    )
    if request_is_unknown:
        raise RuntimeError(_UNKNOWN_DECODE_REQUEST_ERROR_TEMPLATE.format(request_id=request_id))

    self.active_request_ids.remove(request_id)
    self._reset_slot(slot_index)
    self.request_id_to_slot_index.pop(request_id)
    self.slot_index_to_request_id[slot_index] = None
    self.slot_is_free[slot_index] = True
    self.request_id_to_code_offset.pop(request_id, None)
    self.request_id_to_audio_offset.pop(request_id, None)
|
| 726 |
+
|
| 727 |
+
def close(self) -> None:
    """Tear the session down: close streaming contexts, disable flash-kvcache
    shortcuts, drop CUDA-graph buffers, and detach from the model. Idempotent."""
    if self._closed:
        return
    self._closed = True

    exit_stack = self._decode_streaming_exit_stack
    self._decode_streaming_exit_stack = None
    try:
        if exit_stack is not None:
            exit_stack.close()
    finally:
        # The remaining cleanup must happen even if closing the stack raised.
        for attention_module in self._flash_kvcache_attention_modules:
            attention_module._use_flash_kvcache = False
        self._flash_kvcache_attention_modules = []
        for graph_attr in (
            "_cuda_graph",
            "_cuda_graph_key",
            "_graph_input_codes",
            "_graph_input_code_lengths",
            "_graph_output_audio",
            "_graph_output_audio_lengths",
        ):
            setattr(self, graph_attr, None)
        if self.model._active_decode_session is self:
            self.model._active_decode_session = None
|
| 749 |
+
|
| 750 |
+
|
| 751 |
# =============================================================================
|
| 752 |
# Normalization Layers
|
| 753 |
# =============================================================================
|
|
|
|
| 1118 |
# =============================================================================
|
| 1119 |
|
| 1120 |
|
| 1121 |
+
_sync_module_proxy()
|
| 1122 |
@dataclass
|
| 1123 |
class MHAState(StreamingState):
|
| 1124 |
cached_keys: torch.Tensor | None
|
|
|
|
| 1198 |
f"Expected one of {sorted(SUPPORTED_ATTENTION_IMPLEMENTATIONS)}."
|
| 1199 |
)
|
| 1200 |
self.attention_implementation = attention_implementation
|
| 1201 |
+
self._use_flash_kvcache = False
|
| 1202 |
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=False, **factory_kwargs)
|
| 1203 |
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False, **factory_kwargs)
|
| 1204 |
|
|
|
|
| 1333 |
state.cached_positions = state.cached_positions.to(device=device)
|
| 1334 |
return state.cached_keys, state.cached_values, state.cached_positions
|
| 1335 |
|
| 1336 |
+
def _ensure_flash_kvcache(
|
| 1337 |
+
self,
|
| 1338 |
+
state: MHAState,
|
| 1339 |
+
batch_size: int,
|
| 1340 |
+
device: torch.device,
|
| 1341 |
+
dtype: torch.dtype,
|
| 1342 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 1343 |
+
if self.context is None:
|
| 1344 |
+
raise RuntimeError("flash_attn_with_kvcache requires a finite streaming context.")
|
| 1345 |
+
head_dim = self.embed_dim // self.num_heads
|
| 1346 |
+
flash_cached_keys = cast(torch.Tensor | None, getattr(state, "_flash_cached_keys", None))
|
| 1347 |
+
flash_cached_values = cast(torch.Tensor | None, getattr(state, "_flash_cached_values", None))
|
| 1348 |
+
if flash_cached_keys is None or flash_cached_values is None:
|
| 1349 |
+
flash_cached_keys = torch.zeros(
|
| 1350 |
+
(batch_size, self.context, self.num_heads, head_dim),
|
| 1351 |
+
device=device,
|
| 1352 |
+
dtype=dtype,
|
| 1353 |
+
)
|
| 1354 |
+
flash_cached_values = torch.zeros_like(flash_cached_keys)
|
| 1355 |
+
else:
|
| 1356 |
+
if flash_cached_keys.device != device or flash_cached_keys.dtype != dtype:
|
| 1357 |
+
flash_cached_keys = flash_cached_keys.to(device=device, dtype=dtype)
|
| 1358 |
+
if flash_cached_values.device != device or flash_cached_values.dtype != dtype:
|
| 1359 |
+
flash_cached_values = flash_cached_values.to(device=device, dtype=dtype)
|
| 1360 |
+
setattr(state, "_flash_cached_keys", flash_cached_keys)
|
| 1361 |
+
setattr(state, "_flash_cached_values", flash_cached_values)
|
| 1362 |
+
return flash_cached_keys, flash_cached_values
|
| 1363 |
+
|
| 1364 |
def _build_streaming_kv(
|
| 1365 |
self,
|
| 1366 |
cached_k: torch.Tensor,
|
|
|
|
| 1395 |
state.cached_positions = pos_k.contiguous()
|
| 1396 |
return
|
| 1397 |
|
| 1398 |
+
assert state.cached_keys is not None
|
| 1399 |
+
assert state.cached_values is not None
|
| 1400 |
+
assert state.cached_positions is not None
|
| 1401 |
new_cached_k = k_all[:, :, -self.context :, :].contiguous()
|
| 1402 |
new_cached_v = v_all[:, :, -self.context :, :].contiguous()
|
| 1403 |
new_cached_pos = pos_k[:, -self.context :].contiguous()
|
| 1404 |
+
state.cached_keys.copy_(torch.where(exec_mask, new_cached_k, cached_k))
|
| 1405 |
+
state.cached_values.copy_(torch.where(exec_mask, new_cached_v, cached_v))
|
| 1406 |
+
state.cached_positions.copy_(torch.where(exec_mask_pos, new_cached_pos, cached_pos))
|
| 1407 |
|
| 1408 |
def _build_streaming_sdpa_bias(self, pos_q: torch.Tensor, pos_k: torch.Tensor) -> torch.Tensor:
|
| 1409 |
delta = pos_q[:, :, None] - pos_k[:, None, :]
|
|
|
|
| 1443 |
if flash_attn_varlen_func is None:
|
| 1444 |
raise RuntimeError("flash-attn is not installed.")
|
| 1445 |
window_size = (self.context, 0) if (self.context is not None and self.causal) else (-1, -1)
|
| 1446 |
+
return cast(
|
| 1447 |
+
torch.Tensor,
|
| 1448 |
+
flash_attn_varlen_func(
|
| 1449 |
+
q.contiguous(),
|
| 1450 |
+
k.contiguous(),
|
| 1451 |
+
v.contiguous(),
|
| 1452 |
+
cu_seqlens_q,
|
| 1453 |
+
cu_seqlens_k,
|
| 1454 |
+
max_seqlen_q,
|
| 1455 |
+
max_seqlen_k,
|
| 1456 |
+
causal=self.causal,
|
| 1457 |
+
window_size=window_size,
|
| 1458 |
+
),
|
| 1459 |
)
|
| 1460 |
|
| 1461 |
def _forward_streaming_sdpa(self, x: torch.Tensor, state: MHAState) -> torch.Tensor:
|
|
|
|
| 1524 |
state.offset[:] = torch.where(state.exec_mask, state.offset + chunk_length, state.offset)
|
| 1525 |
return out
|
| 1526 |
|
| 1527 |
+
def _forward_streaming_flash_kvcache(self, x: torch.Tensor, state: MHAState) -> torch.Tensor:
    """Streaming attention step backed by flash-attn's in-place KV cache.

    Processes `x` of shape `(batch_size, chunk_length, embed_dim)` against the
    persistent cache held on `state`, then advances `state.offset` for the
    rows selected by `state.exec_mask`.
    """
    # Imported lazily so the module stays importable without flash-attn.
    from flash_attn import flash_attn_with_kvcache

    if self.context is None:
        raise RuntimeError("flash_attn_with_kvcache requires a finite streaming context.")

    batch_size, chunk_length, _ = x.shape
    q, k_cur, v_cur = self._project_qkv(x)
    if self.rope is not None:
        q, k_cur = self.rope(q, k_cur, state.offset, time_before_heads=False)

    # flash-attn expects (batch, seq, heads, head_dim) layout.
    q = q.transpose(1, 2).contiguous()
    k_cur = k_cur.transpose(1, 2).contiguous()
    v_cur = v_cur.transpose(1, 2).contiguous()

    # Zero out K/V of inactive rows so the entries written into the cache for
    # masked slots contribute nothing.
    exec_mask = state.exec_mask.view(batch_size, 1, 1, 1).to(dtype=k_cur.dtype)
    k_cur = k_cur * exec_mask
    v_cur = v_cur * exec_mask

    k_cache, v_cache = self._ensure_flash_kvcache(state, batch_size, k_cur.device, k_cur.dtype)
    # The cache only holds `context` timesteps, so the effective filled length
    # saturates there.
    cache_seqlens = state.offset.clamp(max=self.context).to(torch.int32)
    # (left, right) attention window; left = context - 1 past positions —
    # presumably the current position completes the `context`-wide window;
    # confirm against the flash-attn window_size semantics.
    window_size = (self.context - 1, 0)

    out = cast(
        torch.Tensor,
        flash_attn_with_kvcache(
            q,
            k_cache,
            v_cache,
            k=k_cur,
            v=v_cur,
            cache_seqlens=cache_seqlens,
            causal=True,
            window_size=window_size,
        ),
    )
    out = out.reshape(batch_size, chunk_length, self.embed_dim)
    # Only rows active this step advance their stream offset.
    state.offset[:] = torch.where(state.exec_mask, state.offset + chunk_length, state.offset)
    return out
|
| 1566 |
+
|
| 1567 |
def _forward_non_streaming_sdpa(self, x: torch.Tensor, input_lengths: torch.Tensor) -> torch.Tensor:
|
| 1568 |
batch_size, max_seqlen, _ = x.shape
|
| 1569 |
q, k, v = self._project_qkv(x)
|
|
|
|
| 1605 |
if state is not None:
|
| 1606 |
if query.dim() != 3:
|
| 1607 |
raise ValueError(f"Streaming attention expects a 3D tensor, got shape {tuple(query.shape)}")
|
| 1608 |
+
if backend == "flash_attention_2" and self._use_flash_kvcache:
|
| 1609 |
+
out = self._forward_streaming_flash_kvcache(query, state)
|
| 1610 |
+
elif backend == "flash_attention_2":
|
| 1611 |
+
out = self._forward_streaming_flash(query, state)
|
| 1612 |
+
else:
|
| 1613 |
+
out = self._forward_streaming_sdpa(query, state)
|
| 1614 |
return self.out_proj(out)
|
| 1615 |
|
| 1616 |
if backend == "flash_attention_2":
|
|
|
|
| 1634 |
# =============================================================================
|
| 1635 |
|
| 1636 |
|
| 1637 |
+
_sync_module_proxy()
|
| 1638 |
@dataclass
|
| 1639 |
class LayerState(StreamingState):
|
| 1640 |
pass
|
|
|
|
| 1726 |
# =============================================================================
|
| 1727 |
|
| 1728 |
|
| 1729 |
+
_sync_module_proxy()
|
| 1730 |
@dataclass
|
| 1731 |
class TransformerState(StreamingState):
|
| 1732 |
offsets: torch.Tensor
|
|
|
|
| 2399 |
)
|
| 2400 |
|
| 2401 |
self.post_init()
|
| 2402 |
+
self._active_decode_session: "MossAudioTokenizerDecodeSession | None" = None
|
| 2403 |
+
self._batch_decode_streaming_max_batch_size: int | None = None
|
| 2404 |
+
self._batch_decode_streaming_batch_size: int | None = None
|
| 2405 |
+
self._batch_decode_streaming_session: "MossAudioTokenizerDecodeSession | None" = None
|
| 2406 |
+
self._batch_decode_streaming_next_request_id: int = 0
|
| 2407 |
+
|
| 2408 |
+
def create_decode_session(
    self,
    max_batch_size: int,
    use_cuda_graph: bool = False,
) -> MossAudioTokenizerDecodeSession:
    """Open a realtime decode session reserving `max_batch_size` physical slots.

    Fails if another session is still live or if any submodule is already in
    streaming mode.
    """
    existing = self._active_decode_session
    if existing is not None and not existing._closed:
        raise RuntimeError(_ACTIVE_DECODE_SESSION_ERROR_MESSAGE)

    streaming_in_progress = any(
        isinstance(module, StreamingModule) and module._streaming_state is not None
        for module in self.modules()
    )
    if streaming_in_progress:
        raise RuntimeError(_MODEL_STREAMING_CONFLICT_ERROR_MESSAGE)

    return MossAudioTokenizerDecodeSession(self, max_batch_size, use_cuda_graph=use_cuda_graph)
|
| 2423 |
+
|
| 2424 |
+
def _reset_batch_decode_streaming_state(self) -> None:
|
| 2425 |
+
streaming_session = self._batch_decode_streaming_session
|
| 2426 |
+
self._batch_decode_streaming_session = None
|
| 2427 |
+
self._batch_decode_streaming_max_batch_size = None
|
| 2428 |
+
self._batch_decode_streaming_batch_size = None
|
| 2429 |
+
self._batch_decode_streaming_next_request_id = 0
|
| 2430 |
+
if streaming_session is not None and not streaming_session._closed:
|
| 2431 |
+
streaming_session.close()
|
| 2432 |
+
|
| 2433 |
+
def _prepare_batch_decode_streaming_state(
|
| 2434 |
+
self,
|
| 2435 |
+
batch_size: int,
|
| 2436 |
+
max_batch_size: int | None,
|
| 2437 |
+
reset_stream: bool,
|
| 2438 |
+
) -> int:
|
| 2439 |
+
if reset_stream:
|
| 2440 |
+
self._reset_batch_decode_streaming_state()
|
| 2441 |
+
|
| 2442 |
+
if max_batch_size is not None and max_batch_size <= 0:
|
| 2443 |
+
raise ValueError("`max_batch_size` must be > 0 when provided.")
|
| 2444 |
+
|
| 2445 |
+
streaming_max_batch_size = self._batch_decode_streaming_max_batch_size
|
| 2446 |
+
if streaming_max_batch_size is None:
|
| 2447 |
+
streaming_max_batch_size = batch_size if max_batch_size is None else max_batch_size
|
| 2448 |
+
elif max_batch_size is not None and max_batch_size != streaming_max_batch_size:
|
| 2449 |
+
raise ValueError(
|
| 2450 |
+
"`max_batch_size` can only be set on the first streaming `batch_decode()` call for now. "
|
| 2451 |
+
f"Expected {streaming_max_batch_size}, got {max_batch_size}."
|
| 2452 |
+
)
|
| 2453 |
+
|
| 2454 |
+
if batch_size > streaming_max_batch_size:
|
| 2455 |
+
raise ValueError(
|
| 2456 |
+
"Streaming `batch_decode()` received a batch larger than the reserved `max_batch_size`. "
|
| 2457 |
+
f"Got batch_size={batch_size}, max_batch_size={streaming_max_batch_size}."
|
| 2458 |
+
)
|
| 2459 |
+
|
| 2460 |
+
return streaming_max_batch_size
|
| 2461 |
+
|
| 2462 |
+
def _ensure_batch_decode_streaming_session(
|
| 2463 |
+
self,
|
| 2464 |
+
max_batch_size: int,
|
| 2465 |
+
use_cuda_graph: bool = False,
|
| 2466 |
+
) -> MossAudioTokenizerDecodeSession:
|
| 2467 |
+
session = self._batch_decode_streaming_session
|
| 2468 |
+
if session is not None and not session._closed:
|
| 2469 |
+
if session._use_cuda_graph != use_cuda_graph:
|
| 2470 |
+
raise ValueError(
|
| 2471 |
+
"`use_cuda_graph` must match the existing streaming `batch_decode()` session configuration. "
|
| 2472 |
+
f"Expected {session._use_cuda_graph}, got {use_cuda_graph}."
|
| 2473 |
+
)
|
| 2474 |
+
return session
|
| 2475 |
+
|
| 2476 |
+
session = self.create_decode_session(max_batch_size=max_batch_size, use_cuda_graph=use_cuda_graph)
|
| 2477 |
+
self._batch_decode_streaming_session = session
|
| 2478 |
+
self._batch_decode_streaming_max_batch_size = max_batch_size
|
| 2479 |
+
self._batch_decode_streaming_next_request_id = 0
|
| 2480 |
+
return session
|
| 2481 |
+
|
| 2482 |
+
def _append_batch_decode_streaming_requests(
|
| 2483 |
+
self,
|
| 2484 |
+
session: MossAudioTokenizerDecodeSession,
|
| 2485 |
+
target_batch_size: int,
|
| 2486 |
+
) -> None:
|
| 2487 |
+
requests_to_append = target_batch_size - len(session.active_request_ids)
|
| 2488 |
+
for _ in range(requests_to_append):
|
| 2489 |
+
request_id = self._batch_decode_streaming_next_request_id
|
| 2490 |
+
session.append(request_id)
|
| 2491 |
+
self._batch_decode_streaming_next_request_id += 1
|
| 2492 |
+
|
| 2493 |
+
def _resolve_batch_decode_streaming_finalize_request_ids(
|
| 2494 |
+
self,
|
| 2495 |
+
request_ids: list[str | int],
|
| 2496 |
+
finalize_indices: list[int] | tuple[int, ...] | None,
|
| 2497 |
+
) -> list[str | int]:
|
| 2498 |
+
normalized_finalize_indices = tuple(finalize_indices) if finalize_indices is not None else ()
|
| 2499 |
+
if len(set(normalized_finalize_indices)) != len(normalized_finalize_indices):
|
| 2500 |
+
raise ValueError(_BATCH_DECODE_STREAMING_DUPLICATE_FINALIZE_INDICES_ERROR_MESSAGE)
|
| 2501 |
+
|
| 2502 |
+
batch_size = len(request_ids)
|
| 2503 |
+
finalize_request_ids: list[str | int] = []
|
| 2504 |
+
for index in normalized_finalize_indices:
|
| 2505 |
+
if index < 0 or index >= batch_size:
|
| 2506 |
+
raise ValueError(
|
| 2507 |
+
_BATCH_DECODE_STREAMING_FINALIZE_INDEX_OUT_OF_RANGE_ERROR_TEMPLATE.format(
|
| 2508 |
+
index=index, batch_size=batch_size
|
| 2509 |
+
)
|
| 2510 |
+
)
|
| 2511 |
+
finalize_request_ids.append(request_ids[index])
|
| 2512 |
+
|
| 2513 |
+
return finalize_request_ids
|
| 2514 |
+
|
| 2515 |
+
def _raise_if_plain_decode_conflicts_with_active_session(self) -> None:
|
| 2516 |
+
active_session = self._active_decode_session
|
| 2517 |
+
if active_session is not None and not getattr(active_session, "_closed", False):
|
| 2518 |
+
raise RuntimeError(_PLAIN_DECODE_SESSION_CONFLICT_ERROR_MESSAGE)
|
| 2519 |
|
| 2520 |
def _start_streaming(self, batch_size: int):
|
| 2521 |
"""Start streaming mode for all modules."""
|
| 2522 |
+
active_session = self._active_decode_session
|
| 2523 |
+
if active_session is not None and not getattr(active_session, "_closed", False):
|
| 2524 |
+
raise RuntimeError(_MODEL_STREAMING_CONFLICT_ERROR_MESSAGE)
|
| 2525 |
|
| 2526 |
def _start(module):
|
| 2527 |
if isinstance(module, StreamingModule):
|
|
|
|
| 2531 |
|
| 2532 |
def _stop_streaming(self):
|
| 2533 |
"""Stop streaming mode for all modules."""
|
| 2534 |
+
active_session = self._active_decode_session
|
| 2535 |
+
if active_session is not None and not getattr(active_session, "_closed", False):
|
| 2536 |
+
raise RuntimeError(_MODEL_STREAMING_CONFLICT_ERROR_MESSAGE)
|
| 2537 |
|
| 2538 |
def _stop(module):
|
| 2539 |
if isinstance(module, StreamingModule):
|
|
|
|
| 2905 |
codes_list: list[torch.Tensor],
|
| 2906 |
num_quantizers: int | None = None,
|
| 2907 |
chunk_duration: float | None = None,
|
| 2908 |
+
streaming: bool = False,
|
| 2909 |
+
max_batch_size: int | None = None,
|
| 2910 |
+
finalize_indices: list[int] | tuple[int, ...] | None = None,
|
| 2911 |
+
reset_stream: bool = False,
|
| 2912 |
+
use_cuda_graph: bool = False,
|
| 2913 |
) -> MossAudioTokenizerDecoderOutput:
|
| 2914 |
+
if len(codes_list) == 0:
|
| 2915 |
+
raise ValueError("`codes_list` must contain at least one code tensor.")
|
| 2916 |
+
|
| 2917 |
+
streaming_max_batch_size: int | None = None
|
| 2918 |
+
if streaming:
|
| 2919 |
+
streaming_max_batch_size = self._prepare_batch_decode_streaming_state(
|
| 2920 |
+
batch_size=len(codes_list),
|
| 2921 |
+
max_batch_size=max_batch_size,
|
| 2922 |
+
reset_stream=reset_stream,
|
| 2923 |
+
)
|
| 2924 |
+
else:
|
| 2925 |
+
if reset_stream:
|
| 2926 |
+
self._reset_batch_decode_streaming_state()
|
| 2927 |
+
self._raise_if_plain_decode_conflicts_with_active_session()
|
| 2928 |
+
|
| 2929 |
audio_codes, audio_codes_lengths, num_quantizers_used = self._prepare_codes_batch(
|
| 2930 |
codes_list,
|
| 2931 |
num_quantizers=num_quantizers,
|
|
|
|
| 2933 |
batch_size = len(codes_list)
|
| 2934 |
device = audio_codes.device
|
| 2935 |
|
| 2936 |
+
if not streaming and chunk_duration is None:
|
| 2937 |
return self._decode_frame(audio_codes, audio_codes_lengths)
|
| 2938 |
|
| 2939 |
+
if streaming:
|
| 2940 |
+
assert streaming_max_batch_size is not None
|
| 2941 |
+
existing_session = self._batch_decode_streaming_session
|
| 2942 |
+
reusing_streaming_session = existing_session is not None and not existing_session._closed
|
| 2943 |
+
session = self._ensure_batch_decode_streaming_session(
|
| 2944 |
+
max_batch_size=streaming_max_batch_size,
|
| 2945 |
+
use_cuda_graph=use_cuda_graph,
|
| 2946 |
+
)
|
| 2947 |
+
pre_call_request_ids = list(session.active_request_ids)
|
| 2948 |
+
pre_call_batch_size = len(pre_call_request_ids)
|
| 2949 |
+
if batch_size < pre_call_batch_size:
|
| 2950 |
+
raise ValueError(_BATCH_DECODE_STREAMING_SHRINK_ERROR_MESSAGE)
|
| 2951 |
+
|
| 2952 |
+
try:
|
| 2953 |
+
finalize_request_ids = self._resolve_batch_decode_streaming_finalize_request_ids(
|
| 2954 |
+
request_ids=pre_call_request_ids,
|
| 2955 |
+
finalize_indices=finalize_indices,
|
| 2956 |
+
)
|
| 2957 |
+
except Exception:
|
| 2958 |
+
if not reusing_streaming_session and pre_call_batch_size == 0:
|
| 2959 |
+
self._reset_batch_decode_streaming_state()
|
| 2960 |
+
raise
|
| 2961 |
+
|
| 2962 |
+
try:
|
| 2963 |
+
if batch_size > pre_call_batch_size:
|
| 2964 |
+
self._append_batch_decode_streaming_requests(session=session, target_batch_size=batch_size)
|
| 2965 |
+
|
| 2966 |
+
request_ids = list(session.active_request_ids)
|
| 2967 |
+
_, audio, audio_lengths = session.step(
|
| 2968 |
+
request_ids=request_ids,
|
| 2969 |
+
codes=audio_codes,
|
| 2970 |
+
code_lengths=audio_codes_lengths,
|
| 2971 |
+
)
|
| 2972 |
+
for request_id in finalize_request_ids:
|
| 2973 |
+
session.remove(request_id)
|
| 2974 |
+
except Exception:
|
| 2975 |
+
self._reset_batch_decode_streaming_state()
|
| 2976 |
+
raise
|
| 2977 |
+
|
| 2978 |
+
self._batch_decode_streaming_max_batch_size = session.max_batch_size
|
| 2979 |
+
self._batch_decode_streaming_batch_size = len(session.active_request_ids)
|
| 2980 |
+
return MossAudioTokenizerDecoderOutput(audio=audio, audio_lengths=audio_lengths)
|
| 2981 |
+
|
| 2982 |
+
assert chunk_duration is not None
|
| 2983 |
if chunk_duration <= 0:
|
| 2984 |
raise ValueError("`chunk_duration` must be > 0 when provided.")
|
| 2985 |
|
|
|
|
| 3152 |
`MossAudioTokenizerDecoderOutput` or tuple containing decoded audio.
|
| 3153 |
"""
|
| 3154 |
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
| 3155 |
+
self._raise_if_plain_decode_conflicts_with_active_session()
|
| 3156 |
|
| 3157 |
if audio_codes.dim() == 2:
|
| 3158 |
codes_list = [audio_codes[:num_quantizers] if num_quantizers is not None else audio_codes]
|