Zhyw commited on
Commit
fe61555
·
verified ·
1 Parent(s): a3a48d4

Update mossttsrealtime/streaming_mossttsrealtime.py

Browse files
mossttsrealtime/streaming_mossttsrealtime.py CHANGED
@@ -23,7 +23,7 @@ import numpy as np
23
  import torch
24
  import torch.nn.functional as F
25
 
26
- from transformers.cache_utils import DynamicCache, StaticCache
27
  from transformers.utils import is_torchaudio_available, requires_backends
28
  from transformers.utils.import_utils import requires
29
 
@@ -34,6 +34,7 @@ if is_torchaudio_available():
34
  @requires(backends=("torch",))
35
  class MossTTSRealtimeInference:
36
  """Step-wise inference wrapper for MossTTSRealtime.
 
37
  This class mirrors the non-streaming inference logic but exposes a
38
  prefill/step/finish API for streaming usage.
39
  """
@@ -66,24 +67,6 @@ class MossTTSRealtimeInference:
66
  self._is_stopping = None
67
  self._last_audio_tokens = None
68
  self._step_idx = 0
69
- attn_impl = ""
70
- for cfg in (
71
- getattr(getattr(self.model, "local_transformer", None), "config", None),
72
- getattr(getattr(self.model, "config", None), "local_config", None),
73
- getattr(self.model, "config", None),
74
- ):
75
- if cfg is None:
76
- continue
77
- for name in ("_attn_implementation", "attn_implementation"):
78
- candidate = getattr(cfg, name, None)
79
- if isinstance(candidate, str) and candidate.strip():
80
- attn_impl = candidate.strip().lower()
81
- break
82
- if attn_impl:
83
- break
84
- self._use_dynamic_local_cache = attn_impl == "flash_attention_2"
85
- self._should_compile_local_transformer = not self._use_dynamic_local_cache
86
- self._compiled_local_transformer = None
87
 
88
  @property
89
  def device(self):
@@ -93,18 +76,6 @@ class MossTTSRealtimeInference:
93
  def is_finished(self) -> bool:
94
  return self._is_stopping is not None and bool(self._is_stopping.all())
95
 
96
- def _build_local_past_key_values(self):
97
- if self._use_dynamic_local_cache:
98
- return DynamicCache()
99
- return StaticCache(config=self.model.local_transformer.config, max_cache_len=self.channels)
100
-
101
- def _get_local_transformer_runner(self):
102
- if not self._should_compile_local_transformer:
103
- return self._generate_local_transformer_impl
104
- if self._compiled_local_transformer is None:
105
- self._compiled_local_transformer = torch.compile(self._generate_local_transformer_impl, fullgraph=False)
106
- return self._compiled_local_transformer
107
-
108
  def reset_generation_state(self, keep_cache: bool = True):
109
  if not keep_cache:
110
  self.past_key_values = None
@@ -328,6 +299,7 @@ class MossTTSRealtimeInference:
328
  steps_left -= 1
329
  return outputs
330
 
 
331
  def generate_local_transformer(
332
  self,
333
  hidden_states: torch.Tensor,
@@ -339,40 +311,16 @@ class MossTTSRealtimeInference:
339
  repetition_window: Optional[int],
340
  generated_tokens: Optional[torch.Tensor],
341
  gen_step: int,
342
- ) -> torch.Tensor:
343
- runner = self._get_local_transformer_runner()
344
- return runner(
345
- hidden_states=hidden_states,
346
- temperature=temperature,
347
- top_p=top_p,
348
- top_k=top_k,
349
- do_sample=do_sample,
350
- repetition_penalty=repetition_penalty,
351
- repetition_window=repetition_window,
352
- generated_tokens=generated_tokens,
353
- gen_step=gen_step,
354
- )
355
-
356
- def _generate_local_transformer_impl(
357
- self,
358
- hidden_states: torch.Tensor,
359
- temperature: float,
360
- top_p: float,
361
- top_k: int,
362
- do_sample: bool,
363
- repetition_penalty: Optional[float],
364
- repetition_window: Optional[int],
365
- generated_tokens: Optional[torch.Tensor],
366
- gen_step: int,
367
  ) -> torch.Tensor:
368
  batch_size = hidden_states.shape[0]
 
369
  local_inputs = hidden_states.reshape(-1, 1, self.model.config.local_config.hidden_size)
370
- output_token = torch.empty(batch_size, self.channels, dtype=torch.long)
371
 
372
- past_key_values = self._build_local_past_key_values()
373
  local_token = None
374
 
375
- cache_pos_t = torch.zeros(1, dtype=torch.long)
376
 
377
  for i in range(self.channels):
378
  cache_pos_t.fill_(i)
@@ -531,6 +479,7 @@ class MossTTSRealtimeStreamingSession:
531
 
532
  def set_voice_prompt(self, audio, sample_rate: Optional[int] = None):
533
  """Set voice prompt from either audio tokens or waveform.
 
534
  If `audio` is a 2D array whose shape matches the codebook channels, it is
535
  treated as audio tokens. Otherwise a codec is required to encode waveform
536
  prompts into tokens.
@@ -737,23 +686,18 @@ class AudioStreamDecoder:
737
  codec,
738
  chunk_frames: int = 40,
739
  overlap_frames: int = 4,
740
- initial_chunk_frames: Optional[int] = None,
741
- decode_chunk_duration: Optional[float] = None,
742
  decode_kwargs: Optional[dict] = None,
743
  device: Optional[torch.device] = None,
744
  ):
745
  self.codec = codec
746
  self.chunk_frames = chunk_frames
747
  self.overlap_frames = overlap_frames
748
- self.initial_chunk_frames = initial_chunk_frames
749
- self.decode_chunk_duration = decode_chunk_duration
750
  self.decode_kwargs = decode_kwargs or {}
751
  self.device = device
752
 
753
  self._buffer: list[torch.Tensor] = []
754
  self._buffer_len = 0
755
  self._prev_tail: Optional[torch.Tensor] = None
756
- self._chunks_emitted = 0
757
 
758
  def push_tokens(self, audio_tokens: np.ndarray | torch.Tensor):
759
  if isinstance(audio_tokens, np.ndarray):
@@ -763,17 +707,10 @@ class AudioStreamDecoder:
763
  self._buffer.append(audio_tokens)
764
  self._buffer_len += audio_tokens.shape[0]
765
 
766
- @property
767
- def _active_chunk_frames(self) -> int:
768
- if self.initial_chunk_frames is not None:
769
- return min(self.initial_chunk_frames + self._chunks_emitted, self.chunk_frames)
770
- return self.chunk_frames
771
-
772
  def audio_chunks(self) -> Iterable[torch.Tensor]:
773
- while self._buffer_len >= self._active_chunk_frames:
774
- chunk_tokens = self._consume_frames(self._active_chunk_frames)
775
- wav = self._decode(chunk_tokens)
776
- self._chunks_emitted += 1
777
  yield self._apply_crossfade(wav)
778
 
779
  def flush(self) -> Optional[torch.Tensor]:
@@ -799,7 +736,7 @@ class AudioStreamDecoder:
799
  self._buffer_len -= num_frames - remaining
800
  return torch.cat(frames, dim=0)
801
 
802
- def _decode(self, tokens: torch.Tensor) -> torch.Tensor:
803
  device = self.device
804
  if device is None:
805
  if hasattr(self.codec, "device"):
@@ -812,8 +749,22 @@ class AudioStreamDecoder:
812
  if device is not None:
813
  tokens = tokens.to(device)
814
  tokens_t = tokens.permute(1, 0)
 
815
  decode_kwargs = dict(self.decode_kwargs) if self.decode_kwargs else {}
816
- decoded = self.codec.decode(tokens_t, chunk_duration=self.decode_chunk_duration, **decode_kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
817
  if isinstance(decoded, dict):
818
  wav = decoded["audio"][0]
819
  else:
@@ -858,6 +809,7 @@ class AudioStreamDecoder:
858
  class TextDeltaTokenizer:
859
  """
860
  Convert LLM streaming text (delta) into “incremental token IDs”.
 
861
  Notes:
862
  - The input is a delta that is progressively appended to the same string
863
  (consistent with the common delta output behavior in vLLM).
@@ -939,6 +891,7 @@ def _maybe_codec_streaming(codec, *, batch_size: int):
939
  class MossTTSRealtimeTextStreamBridge:
940
  """
941
  Bridge: external LLM streaming text (delta) -> TTS streaming audio chunks.
 
942
  Usage overview:
943
  - First configure `MossTTSRealtimeStreamingSession` (especially `prefill_text_len=12`).
944
  - Provide an `AudioStreamDecoder`, then continuously feed the LLM delta text via
@@ -972,6 +925,7 @@ class MossTTSRealtimeTextStreamBridge:
972
  def push_text_delta(self, delta: str) -> Iterator[torch.Tensor]:
973
  """
974
  Push a chunk of incremental text output from the LLM and return newly generated WAV chunks.
 
975
  Internally, this directly calls `session.push_text()`, which segments the text
976
  based on punctuation/length and then tokenizes the *entire segment* at once,
977
  avoiding the prefix instability issues of incremental BPE tokenization.
 
23
  import torch
24
  import torch.nn.functional as F
25
 
26
+ from transformers.cache_utils import StaticCache
27
  from transformers.utils import is_torchaudio_available, requires_backends
28
  from transformers.utils.import_utils import requires
29
 
 
34
  @requires(backends=("torch",))
35
  class MossTTSRealtimeInference:
36
  """Step-wise inference wrapper for MossTTSRealtime.
37
+
38
  This class mirrors the non-streaming inference logic but exposes a
39
  prefill/step/finish API for streaming usage.
40
  """
 
67
  self._is_stopping = None
68
  self._last_audio_tokens = None
69
  self._step_idx = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  @property
72
  def device(self):
 
76
  def is_finished(self) -> bool:
77
  return self._is_stopping is not None and bool(self._is_stopping.all())
78
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  def reset_generation_state(self, keep_cache: bool = True):
80
  if not keep_cache:
81
  self.past_key_values = None
 
299
  steps_left -= 1
300
  return outputs
301
 
302
+ @torch.compile(fullgraph=True)
303
  def generate_local_transformer(
304
  self,
305
  hidden_states: torch.Tensor,
 
311
  repetition_window: Optional[int],
312
  generated_tokens: Optional[torch.Tensor],
313
  gen_step: int,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
  ) -> torch.Tensor:
315
  batch_size = hidden_states.shape[0]
316
+ device = hidden_states.device
317
  local_inputs = hidden_states.reshape(-1, 1, self.model.config.local_config.hidden_size)
318
+ output_token = torch.empty(batch_size, self.channels, dtype=torch.long, device=device)
319
 
320
+ past_key_values = StaticCache(config=self.model.local_transformer.config, max_cache_len=self.channels)
321
  local_token = None
322
 
323
+ cache_pos_t = torch.zeros(1, dtype=torch.long, device=device)
324
 
325
  for i in range(self.channels):
326
  cache_pos_t.fill_(i)
 
479
 
480
  def set_voice_prompt(self, audio, sample_rate: Optional[int] = None):
481
  """Set voice prompt from either audio tokens or waveform.
482
+
483
  If `audio` is a 2D array whose shape matches the codebook channels, it is
484
  treated as audio tokens. Otherwise a codec is required to encode waveform
485
  prompts into tokens.
 
686
  codec,
687
  chunk_frames: int = 40,
688
  overlap_frames: int = 4,
 
 
689
  decode_kwargs: Optional[dict] = None,
690
  device: Optional[torch.device] = None,
691
  ):
692
  self.codec = codec
693
  self.chunk_frames = chunk_frames
694
  self.overlap_frames = overlap_frames
 
 
695
  self.decode_kwargs = decode_kwargs or {}
696
  self.device = device
697
 
698
  self._buffer: list[torch.Tensor] = []
699
  self._buffer_len = 0
700
  self._prev_tail: Optional[torch.Tensor] = None
 
701
 
702
  def push_tokens(self, audio_tokens: np.ndarray | torch.Tensor):
703
  if isinstance(audio_tokens, np.ndarray):
 
707
  self._buffer.append(audio_tokens)
708
  self._buffer_len += audio_tokens.shape[0]
709
 
 
 
 
 
 
 
710
  def audio_chunks(self) -> Iterable[torch.Tensor]:
711
+ while self._buffer_len >= self.chunk_frames:
712
+ chunk_tokens = self._consume_frames(self.chunk_frames)
713
+ wav = self._decode(chunk_tokens, chunk_duration=0.32)
 
714
  yield self._apply_crossfade(wav)
715
 
716
  def flush(self) -> Optional[torch.Tensor]:
 
736
  self._buffer_len -= num_frames - remaining
737
  return torch.cat(frames, dim=0)
738
 
739
+ def _decode(self, tokens: torch.Tensor, chunk_duration: float = 0.32) -> torch.Tensor:
740
  device = self.device
741
  if device is None:
742
  if hasattr(self.codec, "device"):
 
749
  if device is not None:
750
  tokens = tokens.to(device)
751
  tokens_t = tokens.permute(1, 0)
752
+ # allow callers to override decode settings (e.g. chunk_duration=-1 to disable internal streaming)
753
  decode_kwargs = dict(self.decode_kwargs) if self.decode_kwargs else {}
754
+ if "chunk_duration" in decode_kwargs:
755
+ override = decode_kwargs.pop("chunk_duration")
756
+ if override is None:
757
+ chunk_duration_arg = None
758
+ else:
759
+ try:
760
+ override_f = float(override)
761
+ except Exception:
762
+ override_f = None
763
+ chunk_duration_arg = None if override_f is None or override_f <= 0 else override_f
764
+ else:
765
+ chunk_duration_arg = chunk_duration
766
+
767
+ decoded = self.codec.decode(tokens_t, chunk_duration=chunk_duration_arg, **decode_kwargs)
768
  if isinstance(decoded, dict):
769
  wav = decoded["audio"][0]
770
  else:
 
809
  class TextDeltaTokenizer:
810
  """
811
  Convert LLM streaming text (delta) into “incremental token IDs”.
812
+
813
  Notes:
814
  - The input is a delta that is progressively appended to the same string
815
  (consistent with the common delta output behavior in vLLM).
 
891
  class MossTTSRealtimeTextStreamBridge:
892
  """
893
  Bridge: external LLM streaming text (delta) -> TTS streaming audio chunks.
894
+
895
  Usage overview:
896
  - First configure `MossTTSRealtimeStreamingSession` (especially `prefill_text_len=12`).
897
  - Provide an `AudioStreamDecoder`, then continuously feed the LLM delta text via
 
925
  def push_text_delta(self, delta: str) -> Iterator[torch.Tensor]:
926
  """
927
  Push a chunk of incremental text output from the LLM and return newly generated WAV chunks.
928
+
929
  Internally, this directly calls `session.push_text()`, which segments the text
930
  based on punctuation/length and then tokenizes the *entire segment* at once,
931
  avoiding the prefix instability issues of incremental BPE tokenization.