Kuangwei Chen committed on
Commit
1c2bf4d
·
1 Parent(s): d4a3b2c

Update for realtime

Browse files
Files changed (1) hide show
  1. modeling_moss_audio_tokenizer.py +808 -21
modeling_moss_audio_tokenizer.py CHANGED
@@ -17,14 +17,25 @@ from __future__ import annotations
17
 
18
  import copy
19
  import math
 
 
20
  from contextlib import ExitStack, contextmanager
21
  from dataclasses import dataclass
 
22
  from typing import cast
23
 
24
  import torch
25
  import torch.nn as nn
26
  import torch.nn.functional as F
27
 
 
 
 
 
 
 
 
 
28
  try:
29
  from transformers.modeling_utils import PreTrainedAudioTokenizerBase
30
  except ImportError:
@@ -32,9 +43,12 @@ except ImportError:
32
  from transformers.utils import ModelOutput, logging
33
 
34
  try:
35
- from transformers.utils import auto_docstring
36
  except ImportError:
37
- def auto_docstring(*args, **kwargs):
 
 
 
38
  if len(args) == 1 and callable(args[0]) and not kwargs:
39
  return args[0]
40
 
@@ -43,9 +57,35 @@ except ImportError:
43
 
44
  return decorator
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  try:
47
  from .configuration_moss_audio_tokenizer import MossAudioTokenizerConfig
48
  except ImportError:
 
 
 
49
  from configuration_moss_audio_tokenizer import MossAudioTokenizerConfig
50
 
51
 
@@ -64,6 +104,25 @@ SUPPORTED_ATTENTION_IMPLEMENTATIONS = {"sdpa", "flash_attention_2"}
64
  SUPPORTED_COMPUTE_DTYPES = {"fp32": None, "bf16": torch.bfloat16, "fp16": torch.float16}
65
 
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  def resolve_compute_dtype(compute_dtype: str) -> torch.dtype | None:
68
  if compute_dtype not in SUPPORTED_COMPUTE_DTYPES:
69
  raise ValueError(
@@ -83,6 +142,7 @@ def disable_cuda_autocast():
83
  # =============================================================================
84
 
85
 
 
86
  @dataclass
87
  @auto_docstring
88
  class MossAudioTokenizerEncoderOutput(ModelOutput):
@@ -100,6 +160,7 @@ class MossAudioTokenizerEncoderOutput(ModelOutput):
100
  encoder_hidden_states: torch.Tensor | None = None
101
 
102
 
 
103
  @dataclass
104
  @auto_docstring
105
  class MossAudioTokenizerDecoderOutput(ModelOutput):
@@ -114,6 +175,7 @@ class MossAudioTokenizerDecoderOutput(ModelOutput):
114
  audio_lengths: torch.Tensor | None = None
115
 
116
 
 
117
  @dataclass
118
  @auto_docstring
119
  class MossAudioTokenizerOutput(ModelOutput):
@@ -139,6 +201,7 @@ class MossAudioTokenizerOutput(ModelOutput):
139
  # =============================================================================
140
 
141
 
 
142
  @dataclass
143
  class StreamingState:
144
  """Base state for streaming modules."""
@@ -228,6 +291,463 @@ class StreamingContainer(StreamingModule):
228
  pass
229
 
230
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  # =============================================================================
232
  # Normalization Layers
233
  # =============================================================================
@@ -598,6 +1118,7 @@ class RingKVCache:
598
  # =============================================================================
599
 
600
 
 
601
  @dataclass
602
  class MHAState(StreamingState):
603
  cached_keys: torch.Tensor | None
@@ -677,6 +1198,7 @@ class MossAudioTokenizerMultiheadAttention(StreamingModule):
677
  f"Expected one of {sorted(SUPPORTED_ATTENTION_IMPLEMENTATIONS)}."
678
  )
679
  self.attention_implementation = attention_implementation
 
680
  self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=False, **factory_kwargs)
681
  self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False, **factory_kwargs)
682
 
@@ -811,6 +1333,34 @@ class MossAudioTokenizerMultiheadAttention(StreamingModule):
811
  state.cached_positions = state.cached_positions.to(device=device)
812
  return state.cached_keys, state.cached_values, state.cached_positions
813
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
814
  def _build_streaming_kv(
815
  self,
816
  cached_k: torch.Tensor,
@@ -845,12 +1395,15 @@ class MossAudioTokenizerMultiheadAttention(StreamingModule):
845
  state.cached_positions = pos_k.contiguous()
846
  return
847
 
 
 
 
848
  new_cached_k = k_all[:, :, -self.context :, :].contiguous()
849
  new_cached_v = v_all[:, :, -self.context :, :].contiguous()
850
  new_cached_pos = pos_k[:, -self.context :].contiguous()
851
- state.cached_keys = torch.where(exec_mask, new_cached_k, cached_k)
852
- state.cached_values = torch.where(exec_mask, new_cached_v, cached_v)
853
- state.cached_positions = torch.where(exec_mask_pos, new_cached_pos, cached_pos)
854
 
855
  def _build_streaming_sdpa_bias(self, pos_q: torch.Tensor, pos_k: torch.Tensor) -> torch.Tensor:
856
  delta = pos_q[:, :, None] - pos_k[:, None, :]
@@ -890,16 +1443,19 @@ class MossAudioTokenizerMultiheadAttention(StreamingModule):
890
  if flash_attn_varlen_func is None:
891
  raise RuntimeError("flash-attn is not installed.")
892
  window_size = (self.context, 0) if (self.context is not None and self.causal) else (-1, -1)
893
- return flash_attn_varlen_func(
894
- q.contiguous(),
895
- k.contiguous(),
896
- v.contiguous(),
897
- cu_seqlens_q,
898
- cu_seqlens_k,
899
- max_seqlen_q,
900
- max_seqlen_k,
901
- causal=self.causal,
902
- window_size=window_size,
 
 
 
903
  )
904
 
905
  def _forward_streaming_sdpa(self, x: torch.Tensor, state: MHAState) -> torch.Tensor:
@@ -968,6 +1524,46 @@ class MossAudioTokenizerMultiheadAttention(StreamingModule):
968
  state.offset[:] = torch.where(state.exec_mask, state.offset + chunk_length, state.offset)
969
  return out
970
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
971
  def _forward_non_streaming_sdpa(self, x: torch.Tensor, input_lengths: torch.Tensor) -> torch.Tensor:
972
  batch_size, max_seqlen, _ = x.shape
973
  q, k, v = self._project_qkv(x)
@@ -1009,11 +1605,12 @@ class MossAudioTokenizerMultiheadAttention(StreamingModule):
1009
  if state is not None:
1010
  if query.dim() != 3:
1011
  raise ValueError(f"Streaming attention expects a 3D tensor, got shape {tuple(query.shape)}")
1012
- out = (
1013
- self._forward_streaming_flash(query, state)
1014
- if backend == "flash_attention_2"
1015
- else self._forward_streaming_sdpa(query, state)
1016
- )
 
1017
  return self.out_proj(out)
1018
 
1019
  if backend == "flash_attention_2":
@@ -1037,6 +1634,7 @@ class MossAudioTokenizerMultiheadAttention(StreamingModule):
1037
  # =============================================================================
1038
 
1039
 
 
1040
  @dataclass
1041
  class LayerState(StreamingState):
1042
  pass
@@ -1128,6 +1726,7 @@ class MossAudioTokenizerTransformerLayer(StreamingModule):
1128
  # =============================================================================
1129
 
1130
 
 
1131
  @dataclass
1132
  class TransformerState(StreamingState):
1133
  offsets: torch.Tensor
@@ -1800,9 +2399,129 @@ class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel):
1800
  )
1801
 
1802
  self.post_init()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1803
 
1804
  def _start_streaming(self, batch_size: int):
1805
  """Start streaming mode for all modules."""
 
 
 
1806
 
1807
  def _start(module):
1808
  if isinstance(module, StreamingModule):
@@ -1812,6 +2531,9 @@ class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel):
1812
 
1813
  def _stop_streaming(self):
1814
  """Stop streaming mode for all modules."""
 
 
 
1815
 
1816
  def _stop(module):
1817
  if isinstance(module, StreamingModule):
@@ -2183,7 +2905,27 @@ class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel):
2183
  codes_list: list[torch.Tensor],
2184
  num_quantizers: int | None = None,
2185
  chunk_duration: float | None = None,
 
 
 
 
 
2186
  ) -> MossAudioTokenizerDecoderOutput:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2187
  audio_codes, audio_codes_lengths, num_quantizers_used = self._prepare_codes_batch(
2188
  codes_list,
2189
  num_quantizers=num_quantizers,
@@ -2191,9 +2933,53 @@ class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel):
2191
  batch_size = len(codes_list)
2192
  device = audio_codes.device
2193
 
2194
- if chunk_duration is None:
2195
  return self._decode_frame(audio_codes, audio_codes_lengths)
2196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2197
  if chunk_duration <= 0:
2198
  raise ValueError("`chunk_duration` must be > 0 when provided.")
2199
 
@@ -2366,6 +3152,7 @@ class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel):
2366
  `MossAudioTokenizerDecoderOutput` or tuple containing decoded audio.
2367
  """
2368
  return_dict = return_dict if return_dict is not None else self.config.return_dict
 
2369
 
2370
  if audio_codes.dim() == 2:
2371
  codes_list = [audio_codes[:num_quantizers] if num_quantizers is not None else audio_codes]
 
17
 
18
  import copy
19
  import math
20
+ import sys
21
+ import types
22
  from contextlib import ExitStack, contextmanager
23
  from dataclasses import dataclass
24
+ from pathlib import Path
25
  from typing import cast
26
 
27
  import torch
28
  import torch.nn as nn
29
  import torch.nn.functional as F
30
 
31
+ if __name__ not in sys.modules:
32
+ _module_proxy = types.ModuleType(__name__)
33
+ sys.modules[__name__] = _module_proxy
34
+
35
+
36
def _sync_module_proxy() -> None:
    """Push this module's current globals onto its proxy entry in ``sys.modules``.

    Keeps the module object registered under ``__name__`` in step with the
    namespace that is actually executing (used when the file runs outside a
    normal package import, e.g. via ``trust_remote_code`` loading).
    """
    proxy_namespace = vars(sys.modules[__name__])
    proxy_namespace.update(globals())
38
+
39
  try:
40
  from transformers.modeling_utils import PreTrainedAudioTokenizerBase
41
  except ImportError:
 
43
  from transformers.utils import ModelOutput, logging
44
 
45
  try:
46
+ from transformers.utils import auto_docstring as _hf_auto_docstring
47
  except ImportError:
48
+ _hf_auto_docstring = None
49
+
50
+ def auto_docstring(*args, **kwargs):
51
+ if _hf_auto_docstring is None:
52
  if len(args) == 1 and callable(args[0]) and not kwargs:
53
  return args[0]
54
 
 
57
 
58
  return decorator
59
 
60
+ if len(args) == 1 and callable(args[0]) and not kwargs:
61
+ obj = args[0]
62
+ try:
63
+ return _hf_auto_docstring(obj)
64
+ except Exception:
65
+ return obj
66
+
67
+ try:
68
+ decorator = _hf_auto_docstring(*args, **kwargs)
69
+ except Exception:
70
+ def decorator(obj):
71
+ return obj
72
+
73
+ return decorator
74
+
75
+ def safe_decorator(obj):
76
+ try:
77
+ return decorator(obj)
78
+ except Exception:
79
+ return obj
80
+
81
+ return safe_decorator
82
+
83
  try:
84
  from .configuration_moss_audio_tokenizer import MossAudioTokenizerConfig
85
  except ImportError:
86
+ _module_dir = str(Path(__file__).resolve().parent)
87
+ if _module_dir not in sys.path:
88
+ sys.path.insert(0, _module_dir)
89
  from configuration_moss_audio_tokenizer import MossAudioTokenizerConfig
90
 
91
 
 
104
  SUPPORTED_COMPUTE_DTYPES = {"fp32": None, "bf16": torch.bfloat16, "fp16": torch.float16}
105
 
106
 
107
+ _ACTIVE_DECODE_SESSION_ERROR_MESSAGE = "MossAudioTokenizerModel only supports one active decode session at a time."
108
+ _CLOSED_DECODE_SESSION_ERROR_MESSAGE = "This decode session is closed."
109
+ _MODEL_STREAMING_CONFLICT_ERROR_MESSAGE = "Model-level streaming helpers cannot be used while a decode session is active."
110
+ _PLAIN_DECODE_SESSION_CONFLICT_ERROR_MESSAGE = "Plain decode helpers cannot be used while a decode session is active."
111
+ _DUPLICATE_DECODE_REQUEST_ERROR_TEMPLATE = "Decode session already contains request_id={request_id!r}."
112
+ _UNKNOWN_DECODE_REQUEST_ERROR_TEMPLATE = "Decode session does not contain an active request_id={request_id!r}."
113
+ _DECODE_SESSION_FULL_ERROR_TEMPLATE = "Decode session has no free slots remaining (max_batch_size={max_batch_size})."
114
+ _INVALID_DECODE_STEP_REQUEST_IDS_ERROR_MESSAGE = (
115
+ "`request_ids` must exactly match the current active decode request order."
116
+ )
117
+ _BATCH_DECODE_STREAMING_DUPLICATE_FINALIZE_INDICES_ERROR_MESSAGE = "`finalize_indices` must not contain duplicates."
118
+ _BATCH_DECODE_STREAMING_FINALIZE_INDEX_OUT_OF_RANGE_ERROR_TEMPLATE = (
119
+ "`finalize_indices` index {index} is out of range for the pre-call logical batch of size {batch_size}."
120
+ )
121
+ _BATCH_DECODE_STREAMING_SHRINK_ERROR_MESSAGE = (
122
+ "`batch_decode(streaming=True)` must include all pre-call active rows in the current call before applying `finalize_indices`."
123
+ )
124
+
125
+
126
  def resolve_compute_dtype(compute_dtype: str) -> torch.dtype | None:
127
  if compute_dtype not in SUPPORTED_COMPUTE_DTYPES:
128
  raise ValueError(
 
142
  # =============================================================================
143
 
144
 
145
+ _sync_module_proxy()
146
  @dataclass
147
  @auto_docstring
148
  class MossAudioTokenizerEncoderOutput(ModelOutput):
 
160
  encoder_hidden_states: torch.Tensor | None = None
161
 
162
 
163
+ _sync_module_proxy()
164
  @dataclass
165
  @auto_docstring
166
  class MossAudioTokenizerDecoderOutput(ModelOutput):
 
175
  audio_lengths: torch.Tensor | None = None
176
 
177
 
178
+ _sync_module_proxy()
179
  @dataclass
180
  @auto_docstring
181
  class MossAudioTokenizerOutput(ModelOutput):
 
201
  # =============================================================================
202
 
203
 
204
+ _sync_module_proxy()
205
  @dataclass
206
  class StreamingState:
207
  """Base state for streaming modules."""
 
291
  pass
292
 
293
 
294
+ class MossAudioTokenizerDecodeSession:
295
+ model: MossAudioTokenizerModel
296
+ max_batch_size: int
297
+ _use_cuda_graph: bool
298
+ active_request_ids: list[str | int]
299
+ request_id_to_slot_index: dict[str | int, int]
300
+ slot_index_to_request_id: list[str | int | None]
301
+ slot_is_free: list[bool]
302
+ request_id_to_code_offset: dict[str | int, int]
303
+ request_id_to_audio_offset: dict[str | int, int]
304
+ _flash_kvcache_attention_modules: list[MossAudioTokenizerMultiheadAttention]
305
+ _graph_num_quantizers_capacity: int | None
306
+ _graph_input_codes: torch.Tensor | None
307
+ _graph_input_code_lengths: torch.Tensor | None
308
+ _graph_output_audio: torch.Tensor | None
309
+ _graph_output_audio_lengths: torch.Tensor | None
310
+ _cuda_graph: torch.cuda.CUDAGraph | None
311
+ _cuda_graph_key: tuple[str, int, int, str] | None
312
+ _decode_streaming_exit_stack: ExitStack | None
313
+ _closed: bool
314
+
315
+ def __init__(self, model: MossAudioTokenizerModel, max_batch_size: int, use_cuda_graph: bool = False):
316
+ if max_batch_size <= 0:
317
+ raise ValueError("`max_batch_size` must be > 0.")
318
+
319
+ decoder_attention_modules: list[MossAudioTokenizerMultiheadAttention] = []
320
+ for decoder_module in model.decoder:
321
+ for module in decoder_module.modules():
322
+ if isinstance(module, MossAudioTokenizerMultiheadAttention):
323
+ if module.context is None:
324
+ raise ValueError(
325
+ "MossAudioTokenizerDecodeSession requires all decoder MHA modules to have a finite "
326
+ "`context` (context=None is unsupported for continuous-batch streaming)."
327
+ )
328
+ decoder_attention_modules.append(module)
329
+
330
+ flash_kvcache_attention_modules: list[MossAudioTokenizerMultiheadAttention] = []
331
+ if use_cuda_graph and HAS_FLASH_ATTN:
332
+ for module in decoder_attention_modules:
333
+ module._use_flash_kvcache = True
334
+ flash_kvcache_attention_modules.append(module)
335
+
336
+ decode_streaming_exit_stack = ExitStack()
337
+ try:
338
+ for decoder_module in model.decoder:
339
+ if isinstance(decoder_module, StreamingModule):
340
+ inner_stack = decoder_module.streaming(batch_size=max_batch_size)
341
+ _ = decode_streaming_exit_stack.enter_context(inner_stack)
342
+ except Exception:
343
+ decode_streaming_exit_stack.close()
344
+ for module in flash_kvcache_attention_modules:
345
+ module._use_flash_kvcache = False
346
+ raise
347
+
348
+ self.model = model
349
+ self.max_batch_size = max_batch_size
350
+ self._use_cuda_graph = use_cuda_graph
351
+ self.active_request_ids: list[str | int] = []
352
+ self.request_id_to_slot_index: dict[str | int, int] = {}
353
+ self.slot_index_to_request_id: list[str | int | None] = [None] * max_batch_size
354
+ self.slot_is_free: list[bool] = [True] * max_batch_size
355
+ self.request_id_to_code_offset: dict[str | int, int] = {}
356
+ self.request_id_to_audio_offset: dict[str | int, int] = {}
357
+ self._flash_kvcache_attention_modules = flash_kvcache_attention_modules
358
+ self._graph_num_quantizers_capacity = int(getattr(model.quantizer, "num_quantizers", 0)) if use_cuda_graph else None
359
+ self._graph_input_codes = None
360
+ self._graph_input_code_lengths = None
361
+ self._graph_output_audio = None
362
+ self._graph_output_audio_lengths = None
363
+ self._cuda_graph = None
364
+ self._cuda_graph_key = None
365
+ self._decode_streaming_exit_stack: ExitStack | None = decode_streaming_exit_stack
366
+ self._closed = False
367
+ if use_cuda_graph:
368
+ device = next(iter(model.parameters())).device
369
+ if device.type == "cuda":
370
+ self._ensure_cuda_graph_buffers(device)
371
+ model._active_decode_session = self
372
+
373
+ def _ensure_open(self) -> None:
374
+ if self._closed:
375
+ raise RuntimeError(_CLOSED_DECODE_SESSION_ERROR_MESSAGE)
376
+
377
+ def append(self, request_id: str | int) -> None:
378
+ self._ensure_open()
379
+
380
+ if request_id in self.request_id_to_slot_index:
381
+ raise RuntimeError(_DUPLICATE_DECODE_REQUEST_ERROR_TEMPLATE.format(request_id=request_id))
382
+
383
+ slot_index = next((index for index, is_free in enumerate(self.slot_is_free) if is_free), None)
384
+ if slot_index is None:
385
+ raise RuntimeError(_DECODE_SESSION_FULL_ERROR_TEMPLATE.format(max_batch_size=self.max_batch_size))
386
+
387
+ self.active_request_ids.append(request_id)
388
+ self.request_id_to_slot_index[request_id] = slot_index
389
+ self.slot_index_to_request_id[slot_index] = request_id
390
+ self.slot_is_free[slot_index] = False
391
+ self.request_id_to_code_offset[request_id] = 0
392
+ self.request_id_to_audio_offset[request_id] = 0
393
+
394
+ def _decoder_streaming_states(self) -> list[StreamingState]:
395
+ decoder_streaming_states: list[StreamingState] = []
396
+ for decoder_module in self.model.decoder:
397
+ for module in decoder_module.modules():
398
+ if isinstance(module, StreamingModule) and module._streaming_state is not None:
399
+ decoder_streaming_states.append(module._streaming_state)
400
+ return decoder_streaming_states
401
+
402
+ def _ensure_cuda_graph_buffers(self, device: torch.device) -> None:
403
+ if not self._use_cuda_graph or device.type != "cuda":
404
+ return
405
+ graph_num_quantizers_capacity = self._graph_num_quantizers_capacity
406
+ if graph_num_quantizers_capacity is None:
407
+ graph_num_quantizers_capacity = int(getattr(self.model.quantizer, "num_quantizers", 0))
408
+ self._graph_num_quantizers_capacity = graph_num_quantizers_capacity
409
+ if graph_num_quantizers_capacity <= 0:
410
+ raise RuntimeError("`use_cuda_graph=True` requires a quantizer with `num_quantizers > 0`.")
411
+ if self._graph_input_codes is None or self._graph_input_codes.device != device:
412
+ self._graph_input_codes = torch.zeros(
413
+ (graph_num_quantizers_capacity, self.max_batch_size, 1),
414
+ device=device,
415
+ dtype=torch.long,
416
+ )
417
+ self._graph_input_code_lengths = torch.zeros(self.max_batch_size, device=device, dtype=torch.long)
418
+ self._graph_output_audio = None
419
+ self._graph_output_audio_lengths = None
420
+ self._cuda_graph = None
421
+ self._cuda_graph_key = None
422
+
423
+ def _snapshot_decoder_streaming_states(self) -> list[tuple[StreamingState, dict[str, torch.Tensor | None]]]:
424
+ snapshots: list[tuple[StreamingState, dict[str, torch.Tensor | None]]] = []
425
+ for streaming_state in self._decoder_streaming_states():
426
+ state_snapshot: dict[str, torch.Tensor | None] = {"exec_mask": streaming_state.exec_mask.clone()}
427
+ if isinstance(streaming_state, TransformerState):
428
+ state_snapshot["offsets"] = streaming_state.offsets.clone()
429
+ if isinstance(streaming_state, MHAState):
430
+ state_snapshot["offset"] = streaming_state.offset.clone()
431
+ state_snapshot["cached_keys"] = None if streaming_state.cached_keys is None else streaming_state.cached_keys.clone()
432
+ state_snapshot["cached_values"] = None if streaming_state.cached_values is None else streaming_state.cached_values.clone()
433
+ state_snapshot["cached_positions"] = (
434
+ None if streaming_state.cached_positions is None else streaming_state.cached_positions.clone()
435
+ )
436
+ state_snapshot["flash_cached_keys"] = (
437
+ None
438
+ if getattr(streaming_state, "_flash_cached_keys", None) is None
439
+ else cast(torch.Tensor, getattr(streaming_state, "_flash_cached_keys")).clone()
440
+ )
441
+ state_snapshot["flash_cached_values"] = (
442
+ None
443
+ if getattr(streaming_state, "_flash_cached_values", None) is None
444
+ else cast(torch.Tensor, getattr(streaming_state, "_flash_cached_values")).clone()
445
+ )
446
+ snapshots.append((streaming_state, state_snapshot))
447
+ return snapshots
448
+
449
+ def _restore_decoder_streaming_states(
450
+ self,
451
+ snapshots: list[tuple[StreamingState, dict[str, torch.Tensor | None]]],
452
+ ) -> None:
453
+ for streaming_state, state_snapshot in snapshots:
454
+ exec_mask = state_snapshot["exec_mask"]
455
+ assert exec_mask is not None
456
+ streaming_state.exec_mask.copy_(exec_mask)
457
+ if isinstance(streaming_state, TransformerState):
458
+ offsets = state_snapshot.get("offsets")
459
+ assert offsets is not None
460
+ streaming_state.offsets.copy_(offsets)
461
+ if isinstance(streaming_state, MHAState):
462
+ offset = state_snapshot.get("offset")
463
+ assert offset is not None
464
+ streaming_state.offset.copy_(offset)
465
+ cached_keys = state_snapshot.get("cached_keys")
466
+ cached_values = state_snapshot.get("cached_values")
467
+ cached_positions = state_snapshot.get("cached_positions")
468
+ if cached_keys is None or cached_values is None or cached_positions is None:
469
+ if streaming_state.cached_keys is not None:
470
+ streaming_state.cached_keys.zero_()
471
+ if streaming_state.cached_values is not None:
472
+ streaming_state.cached_values.zero_()
473
+ if streaming_state.cached_positions is not None:
474
+ streaming_state.cached_positions.fill_(-1)
475
+ else:
476
+ if streaming_state.cached_keys is None or streaming_state.cached_keys.shape != cached_keys.shape:
477
+ streaming_state.cached_keys = cached_keys.clone()
478
+ else:
479
+ streaming_state.cached_keys.copy_(cached_keys)
480
+ if streaming_state.cached_values is None or streaming_state.cached_values.shape != cached_values.shape:
481
+ streaming_state.cached_values = cached_values.clone()
482
+ else:
483
+ streaming_state.cached_values.copy_(cached_values)
484
+ if streaming_state.cached_positions is None or streaming_state.cached_positions.shape != cached_positions.shape:
485
+ streaming_state.cached_positions = cached_positions.clone()
486
+ else:
487
+ streaming_state.cached_positions.copy_(cached_positions)
488
+
489
+ flash_cached_keys = state_snapshot.get("flash_cached_keys")
490
+ flash_cached_values = state_snapshot.get("flash_cached_values")
491
+ current_flash_cached_keys = cast(torch.Tensor | None, getattr(streaming_state, "_flash_cached_keys", None))
492
+ current_flash_cached_values = cast(torch.Tensor | None, getattr(streaming_state, "_flash_cached_values", None))
493
+ if flash_cached_keys is None or flash_cached_values is None:
494
+ if current_flash_cached_keys is not None:
495
+ current_flash_cached_keys.zero_()
496
+ if current_flash_cached_values is not None:
497
+ current_flash_cached_values.zero_()
498
+ else:
499
+ if current_flash_cached_keys is None or current_flash_cached_keys.shape != flash_cached_keys.shape:
500
+ setattr(streaming_state, "_flash_cached_keys", flash_cached_keys.clone())
501
+ else:
502
+ current_flash_cached_keys.copy_(flash_cached_keys)
503
+ if current_flash_cached_values is None or current_flash_cached_values.shape != flash_cached_values.shape:
504
+ setattr(streaming_state, "_flash_cached_values", flash_cached_values.clone())
505
+ else:
506
+ current_flash_cached_values.copy_(flash_cached_values)
507
+
508
+ def _graphed_decode_frame(
509
+ self,
510
+ codes: torch.Tensor,
511
+ code_lengths: torch.Tensor,
512
+ ) -> MossAudioTokenizerDecoderOutput:
513
+ self._ensure_cuda_graph_buffers(codes.device)
514
+ graph_input_codes = self._graph_input_codes
515
+ graph_input_code_lengths = self._graph_input_code_lengths
516
+ if graph_input_codes is None or graph_input_code_lengths is None:
517
+ raise RuntimeError("CUDA graph buffers are unavailable.")
518
+
519
+ num_quantizers = codes.shape[0]
520
+ graph_input_codes_view = graph_input_codes[:num_quantizers]
521
+ graph_input_codes_view.copy_(codes)
522
+ graph_input_code_lengths.copy_(code_lengths)
523
+ cuda_graph_key = (str(codes.device), self.max_batch_size, num_quantizers, self.model.compute_dtype_name)
524
+
525
+ if self._cuda_graph is None or self._cuda_graph_key != cuda_graph_key:
526
+ state_snapshots = self._snapshot_decoder_streaming_states()
527
+ current_stream = torch.cuda.current_stream(device=codes.device)
528
+ warmup_stream = torch.cuda.Stream(device=codes.device)
529
+ warmup_stream.wait_stream(current_stream)
530
+ with torch.cuda.stream(warmup_stream):
531
+ _ = self.model._decode_frame(graph_input_codes_view, graph_input_code_lengths)
532
+ current_stream.wait_stream(warmup_stream)
533
+ self._restore_decoder_streaming_states(state_snapshots)
534
+
535
+ cuda_graph = torch.cuda.CUDAGraph()
536
+ with torch.cuda.graph(cuda_graph):
537
+ decoder_output = self.model._decode_frame(graph_input_codes_view, graph_input_code_lengths)
538
+
539
+ self._cuda_graph = cuda_graph
540
+ self._cuda_graph_key = cuda_graph_key
541
+ self._graph_output_audio = decoder_output.audio
542
+ self._graph_output_audio_lengths = decoder_output.audio_lengths
543
+ else:
544
+ self._cuda_graph.replay()
545
+
546
+ return MossAudioTokenizerDecoderOutput(
547
+ audio=self._graph_output_audio,
548
+ audio_lengths=self._graph_output_audio_lengths,
549
+ )
550
+
551
+ def _reset_slot(self, slot_index: int) -> None:
552
+ for streaming_state in self._decoder_streaming_states():
553
+ reset_mask = torch.zeros(streaming_state.batch_size, dtype=torch.bool, device=streaming_state.exec_mask.device)
554
+ reset_mask[slot_index] = True
555
+ streaming_state.reset(reset_mask)
556
+
557
+ def _pack_logical_codes_to_physical_slots(
558
+ self,
559
+ request_ids: list[str | int],
560
+ codes: torch.Tensor,
561
+ code_lengths: torch.Tensor,
562
+ ) -> tuple[torch.Tensor, torch.Tensor, list[int], torch.Tensor]:
563
+ if request_ids != self.active_request_ids:
564
+ raise ValueError(_INVALID_DECODE_STEP_REQUEST_IDS_ERROR_MESSAGE)
565
+
566
+ if not request_ids:
567
+ raise ValueError("`step()` requires at least one active request.")
568
+
569
+ if codes.dim() == 2:
570
+ codes = codes.unsqueeze(1)
571
+ if codes.dim() != 3:
572
+ raise ValueError(f"`codes` must be 3D with shape `(num_quantizers, batch_size, sequence_length)`, got {codes.shape}.")
573
+
574
+ code_lengths = code_lengths.to(device=codes.device, dtype=torch.long)
575
+ if code_lengths.dim() != 1:
576
+ raise ValueError(f"`code_lengths` must be 1D with shape `(batch_size,)`, got {code_lengths.shape}.")
577
+
578
+ num_quantizers, logical_batch_size, max_code_length = codes.shape
579
+ if logical_batch_size != len(request_ids):
580
+ raise ValueError(
581
+ f"`codes.shape[1]` ({logical_batch_size}) must match len(`request_ids`) ({len(request_ids)})."
582
+ )
583
+ if code_lengths.shape[0] != logical_batch_size:
584
+ raise ValueError(
585
+ f"`code_lengths.shape[0]` ({code_lengths.shape[0]}) must match len(`request_ids`) ({len(request_ids)})."
586
+ )
587
+ if torch.any(code_lengths < 0):
588
+ raise ValueError("`code_lengths` must be >= 0.")
589
+ if torch.any(code_lengths > max_code_length):
590
+ raise ValueError(f"`code_lengths` must be <= codes.shape[-1] ({max_code_length}).")
591
+
592
+ packed_codes = codes.new_zeros((num_quantizers, self.max_batch_size, max_code_length))
593
+ packed_code_lengths = code_lengths.new_zeros((self.max_batch_size,))
594
+ logical_row_to_slot_index: list[int] = []
595
+
596
+ for logical_row_index, request_id in enumerate(request_ids):
597
+ slot_index = self.request_id_to_slot_index[request_id]
598
+ logical_row_to_slot_index.append(slot_index)
599
+ row_length = int(code_lengths[logical_row_index].item())
600
+ if row_length > 0:
601
+ packed_codes[:, slot_index, :row_length] = codes[:, logical_row_index, :row_length]
602
+ packed_code_lengths[slot_index] = row_length
603
+
604
+ return packed_codes, packed_code_lengths, logical_row_to_slot_index, code_lengths
605
+
606
+ def _advance_request_progress(
607
+ self,
608
+ request_ids: list[str | int],
609
+ code_lengths: torch.Tensor,
610
+ audio_lengths: torch.Tensor,
611
+ ) -> None:
612
+ for logical_row_index, request_id in enumerate(request_ids):
613
+ self.request_id_to_code_offset[request_id] += int(code_lengths[logical_row_index].item())
614
+ self.request_id_to_audio_offset[request_id] += int(audio_lengths[logical_row_index].item())
615
+
616
    def step(
        self,
        request_ids: list[str | int],
        codes: torch.Tensor,
        code_lengths: torch.Tensor,
    ) -> tuple[list[str | int], torch.Tensor, torch.Tensor]:
        """Decode one batched chunk of codes for the given active requests.

        Args:
            request_ids: Logical request identifiers, one per row of `codes`.
            codes: Code tensor for this step; frames are taken along the last
                dim as `[:, :, frame]`, so presumably shaped
                (batch, num_quantizers, frames) -- TODO confirm against caller.
            code_lengths: Per-row count of valid frames in `codes`.

        Returns:
            `(request_ids, audio, audio_lengths)` where `audio` is a
            zero-padded (batch, channels, samples) tensor in caller order and
            `audio_lengths` holds each row's valid sample count.

        Raises:
            ValueError: If no row has `code_length > 0`.
            RuntimeError: If the decoder yields no audio (internal error).
        """
        self._ensure_open()

        # Re-pack the caller-ordered (logical) rows into this session's
        # physical slot layout expected by the streaming decoder state.
        packed_codes, packed_code_lengths, logical_row_to_slot_index, logical_code_lengths = (
            self._pack_logical_codes_to_physical_slots(
                request_ids=request_ids,
                codes=codes,
                code_lengths=code_lengths,
            )
        )
        max_step_length = int(packed_code_lengths.max().item())

        if max_step_length <= 0:
            raise ValueError("`step()` requires at least one row with `code_length > 0`.")

        decoder_streaming_states = self._decoder_streaming_states()
        # One chunk list per logical row; concatenated along time afterwards.
        logical_audio_chunks: list[list[torch.Tensor]] = [[] for _ in request_ids]
        audio_device: torch.device | None = None
        audio_dtype: torch.dtype | None = None
        audio_num_channels: int | None = None

        try:
            # Decode frame by frame; rows shorter than the current frame index
            # are masked out via the per-state exec mask so their streaming
            # state does not advance.
            for frame_index in range(max_step_length):
                frame_exec_mask = packed_code_lengths > frame_index
                for streaming_state in decoder_streaming_states:
                    streaming_state.set_exec_mask(frame_exec_mask)

                frame_codes = packed_codes[:, :, frame_index : frame_index + 1]
                frame_code_lengths = frame_exec_mask.to(dtype=packed_code_lengths.dtype)
                # Replay the captured CUDA graph only when enabled and the
                # inputs actually live on a CUDA device.
                if self._use_cuda_graph and frame_codes.is_cuda:
                    decoder_output = self._graphed_decode_frame(frame_codes, frame_code_lengths)
                else:
                    decoder_output = self.model._decode_frame(frame_codes, frame_code_lengths)

                if decoder_output.audio is None or decoder_output.audio_lengths is None:
                    raise RuntimeError("Internal error: `_decode_frame` returned empty audio.")

                audio = decoder_output.audio
                audio_lengths = decoder_output.audio_lengths
                audio_device = audio.device
                audio_dtype = audio.dtype
                audio_num_channels = audio.shape[1]

                # Copy each active slot's valid samples back to its logical row.
                for logical_row_index, slot_index in enumerate(logical_row_to_slot_index):
                    audio_length = int(audio_lengths[slot_index].item())
                    if audio_length <= 0:
                        continue
                    logical_audio_chunks[logical_row_index].append(audio[slot_index : slot_index + 1, :, :audio_length])
        except Exception:
            # Streaming state may be inconsistent after a partial step; tear
            # the whole session down rather than continue from bad state.
            self.close()
            raise
        finally:
            # Restore all-rows-active exec masks regardless of outcome.
            for streaming_state in decoder_streaming_states:
                streaming_state.set_exec_mask(torch.ones_like(streaming_state.exec_mask))

        if audio_device is None or audio_dtype is None or audio_num_channels is None:
            raise RuntimeError("Internal error: `step()` produced no decoder outputs.")

        # Concatenate each row's per-frame chunks along time; rows that
        # produced nothing get an empty (1, C, 0) placeholder.
        logical_audio_rows: list[torch.Tensor] = []
        logical_audio_lengths: list[int] = []
        for row_chunks in logical_audio_chunks:
            if row_chunks:
                row_audio = torch.cat(row_chunks, dim=-1)
            else:
                row_audio = torch.zeros((1, audio_num_channels, 0), device=audio_device, dtype=audio_dtype)
            logical_audio_rows.append(row_audio)
            logical_audio_lengths.append(row_audio.shape[-1])

        audio_lengths = torch.tensor(logical_audio_lengths, device=audio_device, dtype=torch.long)
        max_audio_length = max(logical_audio_lengths)
        # Zero-pad every row to the longest row so the batch is rectangular.
        audio = torch.zeros(
            (len(request_ids), audio_num_channels, max_audio_length),
            device=audio_device,
            dtype=audio_dtype,
        )
        for logical_row_index, row_audio in enumerate(logical_audio_rows):
            row_audio_length = row_audio.shape[-1]
            if row_audio_length > 0:
                audio[logical_row_index, :, :row_audio_length] = row_audio[0]

        logical_request_ids = list(request_ids)
        # Advance per-request code/audio offsets for the frames just consumed.
        self._advance_request_progress(
            request_ids=logical_request_ids,
            code_lengths=logical_code_lengths,
            audio_lengths=audio_lengths,
        )

        return logical_request_ids, audio, audio_lengths
709
+
710
+ def remove(self, request_id: str | int) -> None:
711
+ self._ensure_open()
712
+
713
+ slot_index = self.request_id_to_slot_index.get(request_id)
714
+ if slot_index is None or request_id not in self.active_request_ids:
715
+ raise RuntimeError(_UNKNOWN_DECODE_REQUEST_ERROR_TEMPLATE.format(request_id=request_id))
716
+ if self.slot_is_free[slot_index] or self.slot_index_to_request_id[slot_index] != request_id:
717
+ raise RuntimeError(_UNKNOWN_DECODE_REQUEST_ERROR_TEMPLATE.format(request_id=request_id))
718
+
719
+ self.active_request_ids.remove(request_id)
720
+ self._reset_slot(slot_index)
721
+ _ = self.request_id_to_slot_index.pop(request_id)
722
+ self.slot_index_to_request_id[slot_index] = None
723
+ self.slot_is_free[slot_index] = True
724
+ _ = self.request_id_to_code_offset.pop(request_id, None)
725
+ _ = self.request_id_to_audio_offset.pop(request_id, None)
726
+
727
+ def close(self) -> None:
728
+ if self._closed:
729
+ return
730
+
731
+ self._closed = True
732
+ decode_streaming_exit_stack = self._decode_streaming_exit_stack
733
+ self._decode_streaming_exit_stack = None
734
+ try:
735
+ if decode_streaming_exit_stack is not None:
736
+ decode_streaming_exit_stack.close()
737
+ finally:
738
+ for module in self._flash_kvcache_attention_modules:
739
+ module._use_flash_kvcache = False
740
+ self._flash_kvcache_attention_modules = []
741
+ self._cuda_graph = None
742
+ self._cuda_graph_key = None
743
+ self._graph_input_codes = None
744
+ self._graph_input_code_lengths = None
745
+ self._graph_output_audio = None
746
+ self._graph_output_audio_lengths = None
747
+ if self.model._active_decode_session is self:
748
+ self.model._active_decode_session = None
749
+
750
+
751
  # =============================================================================
752
  # Normalization Layers
753
  # =============================================================================
 
1118
  # =============================================================================
1119
 
1120
 
1121
+ _sync_module_proxy()
1122
  @dataclass
1123
  class MHAState(StreamingState):
1124
  cached_keys: torch.Tensor | None
 
1198
  f"Expected one of {sorted(SUPPORTED_ATTENTION_IMPLEMENTATIONS)}."
1199
  )
1200
  self.attention_implementation = attention_implementation
1201
+ self._use_flash_kvcache = False
1202
  self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=False, **factory_kwargs)
1203
  self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False, **factory_kwargs)
1204
 
 
1333
  state.cached_positions = state.cached_positions.to(device=device)
1334
  return state.cached_keys, state.cached_values, state.cached_positions
1335
 
1336
+ def _ensure_flash_kvcache(
1337
+ self,
1338
+ state: MHAState,
1339
+ batch_size: int,
1340
+ device: torch.device,
1341
+ dtype: torch.dtype,
1342
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1343
+ if self.context is None:
1344
+ raise RuntimeError("flash_attn_with_kvcache requires a finite streaming context.")
1345
+ head_dim = self.embed_dim // self.num_heads
1346
+ flash_cached_keys = cast(torch.Tensor | None, getattr(state, "_flash_cached_keys", None))
1347
+ flash_cached_values = cast(torch.Tensor | None, getattr(state, "_flash_cached_values", None))
1348
+ if flash_cached_keys is None or flash_cached_values is None:
1349
+ flash_cached_keys = torch.zeros(
1350
+ (batch_size, self.context, self.num_heads, head_dim),
1351
+ device=device,
1352
+ dtype=dtype,
1353
+ )
1354
+ flash_cached_values = torch.zeros_like(flash_cached_keys)
1355
+ else:
1356
+ if flash_cached_keys.device != device or flash_cached_keys.dtype != dtype:
1357
+ flash_cached_keys = flash_cached_keys.to(device=device, dtype=dtype)
1358
+ if flash_cached_values.device != device or flash_cached_values.dtype != dtype:
1359
+ flash_cached_values = flash_cached_values.to(device=device, dtype=dtype)
1360
+ setattr(state, "_flash_cached_keys", flash_cached_keys)
1361
+ setattr(state, "_flash_cached_values", flash_cached_values)
1362
+ return flash_cached_keys, flash_cached_values
1363
+
1364
  def _build_streaming_kv(
1365
  self,
1366
  cached_k: torch.Tensor,
 
1395
  state.cached_positions = pos_k.contiguous()
1396
  return
1397
 
1398
+ assert state.cached_keys is not None
1399
+ assert state.cached_values is not None
1400
+ assert state.cached_positions is not None
1401
  new_cached_k = k_all[:, :, -self.context :, :].contiguous()
1402
  new_cached_v = v_all[:, :, -self.context :, :].contiguous()
1403
  new_cached_pos = pos_k[:, -self.context :].contiguous()
1404
+ state.cached_keys.copy_(torch.where(exec_mask, new_cached_k, cached_k))
1405
+ state.cached_values.copy_(torch.where(exec_mask, new_cached_v, cached_v))
1406
+ state.cached_positions.copy_(torch.where(exec_mask_pos, new_cached_pos, cached_pos))
1407
 
1408
  def _build_streaming_sdpa_bias(self, pos_q: torch.Tensor, pos_k: torch.Tensor) -> torch.Tensor:
1409
  delta = pos_q[:, :, None] - pos_k[:, None, :]
 
1443
  if flash_attn_varlen_func is None:
1444
  raise RuntimeError("flash-attn is not installed.")
1445
  window_size = (self.context, 0) if (self.context is not None and self.causal) else (-1, -1)
1446
+ return cast(
1447
+ torch.Tensor,
1448
+ flash_attn_varlen_func(
1449
+ q.contiguous(),
1450
+ k.contiguous(),
1451
+ v.contiguous(),
1452
+ cu_seqlens_q,
1453
+ cu_seqlens_k,
1454
+ max_seqlen_q,
1455
+ max_seqlen_k,
1456
+ causal=self.causal,
1457
+ window_size=window_size,
1458
+ ),
1459
  )
1460
 
1461
  def _forward_streaming_sdpa(self, x: torch.Tensor, state: MHAState) -> torch.Tensor:
 
1524
  state.offset[:] = torch.where(state.exec_mask, state.offset + chunk_length, state.offset)
1525
  return out
1526
 
1527
    def _forward_streaming_flash_kvcache(self, x: torch.Tensor, state: MHAState) -> torch.Tensor:
        """Streaming attention step backed by `flash_attn_with_kvcache`.

        Processes one chunk `x` of shape (batch, chunk_length, embed_dim),
        appending the chunk's keys/values into the persistent KV cache held on
        `state` and advancing `state.offset` only for rows enabled by
        `state.exec_mask`.

        Raises:
            RuntimeError: If this layer has no finite streaming context.
        """
        # Imported lazily so the module stays importable without flash-attn.
        from flash_attn import flash_attn_with_kvcache

        if self.context is None:
            raise RuntimeError("flash_attn_with_kvcache requires a finite streaming context.")

        batch_size, chunk_length, _ = x.shape
        q, k_cur, v_cur = self._project_qkv(x)
        if self.rope is not None:
            # Rotary embedding applied at the absolute per-row stream offset.
            q, k_cur = self.rope(q, k_cur, state.offset, time_before_heads=False)

        # flash_attn_with_kvcache expects (batch, seqlen, nheads, head_dim);
        # assumes _project_qkv returns heads-before-time -- TODO confirm.
        q = q.transpose(1, 2).contiguous()
        k_cur = k_cur.transpose(1, 2).contiguous()
        v_cur = v_cur.transpose(1, 2).contiguous()

        # Zero the new K/V for masked-out rows so inactive rows append zeros
        # instead of stale chunk data into the shared cache buffers.
        exec_mask = state.exec_mask.view(batch_size, 1, 1, 1).to(dtype=k_cur.dtype)
        k_cur = k_cur * exec_mask
        v_cur = v_cur * exec_mask

        k_cache, v_cache = self._ensure_flash_kvcache(state, batch_size, k_cur.device, k_cur.dtype)
        # NOTE(review): cache_seqlens is clamped to `context`; this relies on
        # flash-attn's in-place KV append semantics once the fixed-size cache
        # is full -- confirm overwrite behavior against the flash-attn docs.
        cache_seqlens = state.offset.clamp(max=self.context).to(torch.int32)
        # Left window of context-1 past positions plus the current one.
        window_size = (self.context - 1, 0)

        out = cast(
            torch.Tensor,
            flash_attn_with_kvcache(
                q,
                k_cache,
                v_cache,
                k=k_cur,
                v=v_cur,
                cache_seqlens=cache_seqlens,
                causal=True,
                window_size=window_size,
            ),
        )
        out = out.reshape(batch_size, chunk_length, self.embed_dim)
        # Advance stream offsets only for rows that actually executed.
        state.offset[:] = torch.where(state.exec_mask, state.offset + chunk_length, state.offset)
        return out
1566
+
1567
  def _forward_non_streaming_sdpa(self, x: torch.Tensor, input_lengths: torch.Tensor) -> torch.Tensor:
1568
  batch_size, max_seqlen, _ = x.shape
1569
  q, k, v = self._project_qkv(x)
 
1605
  if state is not None:
1606
  if query.dim() != 3:
1607
  raise ValueError(f"Streaming attention expects a 3D tensor, got shape {tuple(query.shape)}")
1608
+ if backend == "flash_attention_2" and self._use_flash_kvcache:
1609
+ out = self._forward_streaming_flash_kvcache(query, state)
1610
+ elif backend == "flash_attention_2":
1611
+ out = self._forward_streaming_flash(query, state)
1612
+ else:
1613
+ out = self._forward_streaming_sdpa(query, state)
1614
  return self.out_proj(out)
1615
 
1616
  if backend == "flash_attention_2":
 
1634
  # =============================================================================
1635
 
1636
 
1637
+ _sync_module_proxy()
1638
  @dataclass
1639
  class LayerState(StreamingState):
1640
  pass
 
1726
  # =============================================================================
1727
 
1728
 
1729
+ _sync_module_proxy()
1730
  @dataclass
1731
  class TransformerState(StreamingState):
1732
  offsets: torch.Tensor
 
2399
  )
2400
 
2401
  self.post_init()
2402
+ self._active_decode_session: "MossAudioTokenizerDecodeSession | None" = None
2403
+ self._batch_decode_streaming_max_batch_size: int | None = None
2404
+ self._batch_decode_streaming_batch_size: int | None = None
2405
+ self._batch_decode_streaming_session: "MossAudioTokenizerDecodeSession | None" = None
2406
+ self._batch_decode_streaming_next_request_id: int = 0
2407
+
2408
+ def create_decode_session(
2409
+ self,
2410
+ max_batch_size: int,
2411
+ use_cuda_graph: bool = False,
2412
+ ) -> MossAudioTokenizerDecodeSession:
2413
+ active_session = self._active_decode_session
2414
+ if active_session is not None and not active_session._closed:
2415
+ raise RuntimeError(_ACTIVE_DECODE_SESSION_ERROR_MESSAGE)
2416
+
2417
+ for module in self.modules():
2418
+ if isinstance(module, StreamingModule) and module._streaming_state is not None:
2419
+ raise RuntimeError(_MODEL_STREAMING_CONFLICT_ERROR_MESSAGE)
2420
+
2421
+ session = MossAudioTokenizerDecodeSession(self, max_batch_size, use_cuda_graph=use_cuda_graph)
2422
+ return session
2423
+
2424
+ def _reset_batch_decode_streaming_state(self) -> None:
2425
+ streaming_session = self._batch_decode_streaming_session
2426
+ self._batch_decode_streaming_session = None
2427
+ self._batch_decode_streaming_max_batch_size = None
2428
+ self._batch_decode_streaming_batch_size = None
2429
+ self._batch_decode_streaming_next_request_id = 0
2430
+ if streaming_session is not None and not streaming_session._closed:
2431
+ streaming_session.close()
2432
+
2433
+ def _prepare_batch_decode_streaming_state(
2434
+ self,
2435
+ batch_size: int,
2436
+ max_batch_size: int | None,
2437
+ reset_stream: bool,
2438
+ ) -> int:
2439
+ if reset_stream:
2440
+ self._reset_batch_decode_streaming_state()
2441
+
2442
+ if max_batch_size is not None and max_batch_size <= 0:
2443
+ raise ValueError("`max_batch_size` must be > 0 when provided.")
2444
+
2445
+ streaming_max_batch_size = self._batch_decode_streaming_max_batch_size
2446
+ if streaming_max_batch_size is None:
2447
+ streaming_max_batch_size = batch_size if max_batch_size is None else max_batch_size
2448
+ elif max_batch_size is not None and max_batch_size != streaming_max_batch_size:
2449
+ raise ValueError(
2450
+ "`max_batch_size` can only be set on the first streaming `batch_decode()` call for now. "
2451
+ f"Expected {streaming_max_batch_size}, got {max_batch_size}."
2452
+ )
2453
+
2454
+ if batch_size > streaming_max_batch_size:
2455
+ raise ValueError(
2456
+ "Streaming `batch_decode()` received a batch larger than the reserved `max_batch_size`. "
2457
+ f"Got batch_size={batch_size}, max_batch_size={streaming_max_batch_size}."
2458
+ )
2459
+
2460
+ return streaming_max_batch_size
2461
+
2462
+ def _ensure_batch_decode_streaming_session(
2463
+ self,
2464
+ max_batch_size: int,
2465
+ use_cuda_graph: bool = False,
2466
+ ) -> MossAudioTokenizerDecodeSession:
2467
+ session = self._batch_decode_streaming_session
2468
+ if session is not None and not session._closed:
2469
+ if session._use_cuda_graph != use_cuda_graph:
2470
+ raise ValueError(
2471
+ "`use_cuda_graph` must match the existing streaming `batch_decode()` session configuration. "
2472
+ f"Expected {session._use_cuda_graph}, got {use_cuda_graph}."
2473
+ )
2474
+ return session
2475
+
2476
+ session = self.create_decode_session(max_batch_size=max_batch_size, use_cuda_graph=use_cuda_graph)
2477
+ self._batch_decode_streaming_session = session
2478
+ self._batch_decode_streaming_max_batch_size = max_batch_size
2479
+ self._batch_decode_streaming_next_request_id = 0
2480
+ return session
2481
+
2482
+ def _append_batch_decode_streaming_requests(
2483
+ self,
2484
+ session: MossAudioTokenizerDecodeSession,
2485
+ target_batch_size: int,
2486
+ ) -> None:
2487
+ requests_to_append = target_batch_size - len(session.active_request_ids)
2488
+ for _ in range(requests_to_append):
2489
+ request_id = self._batch_decode_streaming_next_request_id
2490
+ session.append(request_id)
2491
+ self._batch_decode_streaming_next_request_id += 1
2492
+
2493
+ def _resolve_batch_decode_streaming_finalize_request_ids(
2494
+ self,
2495
+ request_ids: list[str | int],
2496
+ finalize_indices: list[int] | tuple[int, ...] | None,
2497
+ ) -> list[str | int]:
2498
+ normalized_finalize_indices = tuple(finalize_indices) if finalize_indices is not None else ()
2499
+ if len(set(normalized_finalize_indices)) != len(normalized_finalize_indices):
2500
+ raise ValueError(_BATCH_DECODE_STREAMING_DUPLICATE_FINALIZE_INDICES_ERROR_MESSAGE)
2501
+
2502
+ batch_size = len(request_ids)
2503
+ finalize_request_ids: list[str | int] = []
2504
+ for index in normalized_finalize_indices:
2505
+ if index < 0 or index >= batch_size:
2506
+ raise ValueError(
2507
+ _BATCH_DECODE_STREAMING_FINALIZE_INDEX_OUT_OF_RANGE_ERROR_TEMPLATE.format(
2508
+ index=index, batch_size=batch_size
2509
+ )
2510
+ )
2511
+ finalize_request_ids.append(request_ids[index])
2512
+
2513
+ return finalize_request_ids
2514
+
2515
+ def _raise_if_plain_decode_conflicts_with_active_session(self) -> None:
2516
+ active_session = self._active_decode_session
2517
+ if active_session is not None and not getattr(active_session, "_closed", False):
2518
+ raise RuntimeError(_PLAIN_DECODE_SESSION_CONFLICT_ERROR_MESSAGE)
2519
 
2520
  def _start_streaming(self, batch_size: int):
2521
  """Start streaming mode for all modules."""
2522
+ active_session = self._active_decode_session
2523
+ if active_session is not None and not getattr(active_session, "_closed", False):
2524
+ raise RuntimeError(_MODEL_STREAMING_CONFLICT_ERROR_MESSAGE)
2525
 
2526
  def _start(module):
2527
  if isinstance(module, StreamingModule):
 
2531
 
2532
  def _stop_streaming(self):
2533
  """Stop streaming mode for all modules."""
2534
+ active_session = self._active_decode_session
2535
+ if active_session is not None and not getattr(active_session, "_closed", False):
2536
+ raise RuntimeError(_MODEL_STREAMING_CONFLICT_ERROR_MESSAGE)
2537
 
2538
  def _stop(module):
2539
  if isinstance(module, StreamingModule):
 
2905
  codes_list: list[torch.Tensor],
2906
  num_quantizers: int | None = None,
2907
  chunk_duration: float | None = None,
2908
+ streaming: bool = False,
2909
+ max_batch_size: int | None = None,
2910
+ finalize_indices: list[int] | tuple[int, ...] | None = None,
2911
+ reset_stream: bool = False,
2912
+ use_cuda_graph: bool = False,
2913
  ) -> MossAudioTokenizerDecoderOutput:
2914
+ if len(codes_list) == 0:
2915
+ raise ValueError("`codes_list` must contain at least one code tensor.")
2916
+
2917
+ streaming_max_batch_size: int | None = None
2918
+ if streaming:
2919
+ streaming_max_batch_size = self._prepare_batch_decode_streaming_state(
2920
+ batch_size=len(codes_list),
2921
+ max_batch_size=max_batch_size,
2922
+ reset_stream=reset_stream,
2923
+ )
2924
+ else:
2925
+ if reset_stream:
2926
+ self._reset_batch_decode_streaming_state()
2927
+ self._raise_if_plain_decode_conflicts_with_active_session()
2928
+
2929
  audio_codes, audio_codes_lengths, num_quantizers_used = self._prepare_codes_batch(
2930
  codes_list,
2931
  num_quantizers=num_quantizers,
 
2933
  batch_size = len(codes_list)
2934
  device = audio_codes.device
2935
 
2936
+ if not streaming and chunk_duration is None:
2937
  return self._decode_frame(audio_codes, audio_codes_lengths)
2938
 
2939
+ if streaming:
2940
+ assert streaming_max_batch_size is not None
2941
+ existing_session = self._batch_decode_streaming_session
2942
+ reusing_streaming_session = existing_session is not None and not existing_session._closed
2943
+ session = self._ensure_batch_decode_streaming_session(
2944
+ max_batch_size=streaming_max_batch_size,
2945
+ use_cuda_graph=use_cuda_graph,
2946
+ )
2947
+ pre_call_request_ids = list(session.active_request_ids)
2948
+ pre_call_batch_size = len(pre_call_request_ids)
2949
+ if batch_size < pre_call_batch_size:
2950
+ raise ValueError(_BATCH_DECODE_STREAMING_SHRINK_ERROR_MESSAGE)
2951
+
2952
+ try:
2953
+ finalize_request_ids = self._resolve_batch_decode_streaming_finalize_request_ids(
2954
+ request_ids=pre_call_request_ids,
2955
+ finalize_indices=finalize_indices,
2956
+ )
2957
+ except Exception:
2958
+ if not reusing_streaming_session and pre_call_batch_size == 0:
2959
+ self._reset_batch_decode_streaming_state()
2960
+ raise
2961
+
2962
+ try:
2963
+ if batch_size > pre_call_batch_size:
2964
+ self._append_batch_decode_streaming_requests(session=session, target_batch_size=batch_size)
2965
+
2966
+ request_ids = list(session.active_request_ids)
2967
+ _, audio, audio_lengths = session.step(
2968
+ request_ids=request_ids,
2969
+ codes=audio_codes,
2970
+ code_lengths=audio_codes_lengths,
2971
+ )
2972
+ for request_id in finalize_request_ids:
2973
+ session.remove(request_id)
2974
+ except Exception:
2975
+ self._reset_batch_decode_streaming_state()
2976
+ raise
2977
+
2978
+ self._batch_decode_streaming_max_batch_size = session.max_batch_size
2979
+ self._batch_decode_streaming_batch_size = len(session.active_request_ids)
2980
+ return MossAudioTokenizerDecoderOutput(audio=audio, audio_lengths=audio_lengths)
2981
+
2982
+ assert chunk_duration is not None
2983
  if chunk_duration <= 0:
2984
  raise ValueError("`chunk_duration` must be > 0 when provided.")
2985
 
 
3152
  `MossAudioTokenizerDecoderOutput` or tuple containing decoded audio.
3153
  """
3154
  return_dict = return_dict if return_dict is not None else self.config.return_dict
3155
+ self._raise_if_plain_decode_conflicts_with_active_session()
3156
 
3157
  if audio_codes.dim() == 2:
3158
  codes_list = [audio_codes[:num_quantizers] if num_quantizers is not None else audio_codes]