Update mossttsrealtime/streaming_mossttsrealtime.py
Browse files
mossttsrealtime/streaming_mossttsrealtime.py
CHANGED
|
@@ -23,7 +23,7 @@ import numpy as np
|
|
| 23 |
import torch
|
| 24 |
import torch.nn.functional as F
|
| 25 |
|
| 26 |
-
from transformers.cache_utils import
|
| 27 |
from transformers.utils import is_torchaudio_available, requires_backends
|
| 28 |
from transformers.utils.import_utils import requires
|
| 29 |
|
|
@@ -34,6 +34,7 @@ if is_torchaudio_available():
|
|
| 34 |
@requires(backends=("torch",))
|
| 35 |
class MossTTSRealtimeInference:
|
| 36 |
"""Step-wise inference wrapper for MossTTSRealtime.
|
|
|
|
| 37 |
This class mirrors the non-streaming inference logic but exposes a
|
| 38 |
prefill/step/finish API for streaming usage.
|
| 39 |
"""
|
|
@@ -66,24 +67,6 @@ class MossTTSRealtimeInference:
|
|
| 66 |
self._is_stopping = None
|
| 67 |
self._last_audio_tokens = None
|
| 68 |
self._step_idx = 0
|
| 69 |
-
attn_impl = ""
|
| 70 |
-
for cfg in (
|
| 71 |
-
getattr(getattr(self.model, "local_transformer", None), "config", None),
|
| 72 |
-
getattr(getattr(self.model, "config", None), "local_config", None),
|
| 73 |
-
getattr(self.model, "config", None),
|
| 74 |
-
):
|
| 75 |
-
if cfg is None:
|
| 76 |
-
continue
|
| 77 |
-
for name in ("_attn_implementation", "attn_implementation"):
|
| 78 |
-
candidate = getattr(cfg, name, None)
|
| 79 |
-
if isinstance(candidate, str) and candidate.strip():
|
| 80 |
-
attn_impl = candidate.strip().lower()
|
| 81 |
-
break
|
| 82 |
-
if attn_impl:
|
| 83 |
-
break
|
| 84 |
-
self._use_dynamic_local_cache = attn_impl == "flash_attention_2"
|
| 85 |
-
self._should_compile_local_transformer = not self._use_dynamic_local_cache
|
| 86 |
-
self._compiled_local_transformer = None
|
| 87 |
|
| 88 |
@property
|
| 89 |
def device(self):
|
|
@@ -93,18 +76,6 @@ class MossTTSRealtimeInference:
|
|
| 93 |
def is_finished(self) -> bool:
|
| 94 |
return self._is_stopping is not None and bool(self._is_stopping.all())
|
| 95 |
|
| 96 |
-
def _build_local_past_key_values(self):
|
| 97 |
-
if self._use_dynamic_local_cache:
|
| 98 |
-
return DynamicCache()
|
| 99 |
-
return StaticCache(config=self.model.local_transformer.config, max_cache_len=self.channels)
|
| 100 |
-
|
| 101 |
-
def _get_local_transformer_runner(self):
|
| 102 |
-
if not self._should_compile_local_transformer:
|
| 103 |
-
return self._generate_local_transformer_impl
|
| 104 |
-
if self._compiled_local_transformer is None:
|
| 105 |
-
self._compiled_local_transformer = torch.compile(self._generate_local_transformer_impl, fullgraph=False)
|
| 106 |
-
return self._compiled_local_transformer
|
| 107 |
-
|
| 108 |
def reset_generation_state(self, keep_cache: bool = True):
|
| 109 |
if not keep_cache:
|
| 110 |
self.past_key_values = None
|
|
@@ -328,6 +299,7 @@ class MossTTSRealtimeInference:
|
|
| 328 |
steps_left -= 1
|
| 329 |
return outputs
|
| 330 |
|
|
|
|
| 331 |
def generate_local_transformer(
|
| 332 |
self,
|
| 333 |
hidden_states: torch.Tensor,
|
|
@@ -339,40 +311,16 @@ class MossTTSRealtimeInference:
|
|
| 339 |
repetition_window: Optional[int],
|
| 340 |
generated_tokens: Optional[torch.Tensor],
|
| 341 |
gen_step: int,
|
| 342 |
-
) -> torch.Tensor:
|
| 343 |
-
runner = self._get_local_transformer_runner()
|
| 344 |
-
return runner(
|
| 345 |
-
hidden_states=hidden_states,
|
| 346 |
-
temperature=temperature,
|
| 347 |
-
top_p=top_p,
|
| 348 |
-
top_k=top_k,
|
| 349 |
-
do_sample=do_sample,
|
| 350 |
-
repetition_penalty=repetition_penalty,
|
| 351 |
-
repetition_window=repetition_window,
|
| 352 |
-
generated_tokens=generated_tokens,
|
| 353 |
-
gen_step=gen_step,
|
| 354 |
-
)
|
| 355 |
-
|
| 356 |
-
def _generate_local_transformer_impl(
|
| 357 |
-
self,
|
| 358 |
-
hidden_states: torch.Tensor,
|
| 359 |
-
temperature: float,
|
| 360 |
-
top_p: float,
|
| 361 |
-
top_k: int,
|
| 362 |
-
do_sample: bool,
|
| 363 |
-
repetition_penalty: Optional[float],
|
| 364 |
-
repetition_window: Optional[int],
|
| 365 |
-
generated_tokens: Optional[torch.Tensor],
|
| 366 |
-
gen_step: int,
|
| 367 |
) -> torch.Tensor:
|
| 368 |
batch_size = hidden_states.shape[0]
|
|
|
|
| 369 |
local_inputs = hidden_states.reshape(-1, 1, self.model.config.local_config.hidden_size)
|
| 370 |
-
output_token = torch.empty(batch_size, self.channels, dtype=torch.long)
|
| 371 |
|
| 372 |
-
past_key_values = self.
|
| 373 |
local_token = None
|
| 374 |
|
| 375 |
-
cache_pos_t = torch.zeros(1, dtype=torch.long)
|
| 376 |
|
| 377 |
for i in range(self.channels):
|
| 378 |
cache_pos_t.fill_(i)
|
|
@@ -531,6 +479,7 @@ class MossTTSRealtimeStreamingSession:
|
|
| 531 |
|
| 532 |
def set_voice_prompt(self, audio, sample_rate: Optional[int] = None):
|
| 533 |
"""Set voice prompt from either audio tokens or waveform.
|
|
|
|
| 534 |
If `audio` is a 2D array whose shape matches the codebook channels, it is
|
| 535 |
treated as audio tokens. Otherwise a codec is required to encode waveform
|
| 536 |
prompts into tokens.
|
|
@@ -737,23 +686,18 @@ class AudioStreamDecoder:
|
|
| 737 |
codec,
|
| 738 |
chunk_frames: int = 40,
|
| 739 |
overlap_frames: int = 4,
|
| 740 |
-
initial_chunk_frames: Optional[int] = None,
|
| 741 |
-
decode_chunk_duration: Optional[float] = None,
|
| 742 |
decode_kwargs: Optional[dict] = None,
|
| 743 |
device: Optional[torch.device] = None,
|
| 744 |
):
|
| 745 |
self.codec = codec
|
| 746 |
self.chunk_frames = chunk_frames
|
| 747 |
self.overlap_frames = overlap_frames
|
| 748 |
-
self.initial_chunk_frames = initial_chunk_frames
|
| 749 |
-
self.decode_chunk_duration = decode_chunk_duration
|
| 750 |
self.decode_kwargs = decode_kwargs or {}
|
| 751 |
self.device = device
|
| 752 |
|
| 753 |
self._buffer: list[torch.Tensor] = []
|
| 754 |
self._buffer_len = 0
|
| 755 |
self._prev_tail: Optional[torch.Tensor] = None
|
| 756 |
-
self._chunks_emitted = 0
|
| 757 |
|
| 758 |
def push_tokens(self, audio_tokens: np.ndarray | torch.Tensor):
|
| 759 |
if isinstance(audio_tokens, np.ndarray):
|
|
@@ -763,17 +707,10 @@ class AudioStreamDecoder:
|
|
| 763 |
self._buffer.append(audio_tokens)
|
| 764 |
self._buffer_len += audio_tokens.shape[0]
|
| 765 |
|
| 766 |
-
@property
|
| 767 |
-
def _active_chunk_frames(self) -> int:
|
| 768 |
-
if self.initial_chunk_frames is not None:
|
| 769 |
-
return min(self.initial_chunk_frames + self._chunks_emitted, self.chunk_frames)
|
| 770 |
-
return self.chunk_frames
|
| 771 |
-
|
| 772 |
def audio_chunks(self) -> Iterable[torch.Tensor]:
|
| 773 |
-
while self._buffer_len >= self.
|
| 774 |
-
chunk_tokens = self._consume_frames(self.
|
| 775 |
-
wav = self._decode(chunk_tokens)
|
| 776 |
-
self._chunks_emitted += 1
|
| 777 |
yield self._apply_crossfade(wav)
|
| 778 |
|
| 779 |
def flush(self) -> Optional[torch.Tensor]:
|
|
@@ -799,7 +736,7 @@ class AudioStreamDecoder:
|
|
| 799 |
self._buffer_len -= num_frames - remaining
|
| 800 |
return torch.cat(frames, dim=0)
|
| 801 |
|
| 802 |
-
def _decode(self, tokens: torch.Tensor) -> torch.Tensor:
|
| 803 |
device = self.device
|
| 804 |
if device is None:
|
| 805 |
if hasattr(self.codec, "device"):
|
|
@@ -812,8 +749,22 @@ class AudioStreamDecoder:
|
|
| 812 |
if device is not None:
|
| 813 |
tokens = tokens.to(device)
|
| 814 |
tokens_t = tokens.permute(1, 0)
|
|
|
|
| 815 |
decode_kwargs = dict(self.decode_kwargs) if self.decode_kwargs else {}
|
| 816 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 817 |
if isinstance(decoded, dict):
|
| 818 |
wav = decoded["audio"][0]
|
| 819 |
else:
|
|
@@ -858,6 +809,7 @@ class AudioStreamDecoder:
|
|
| 858 |
class TextDeltaTokenizer:
|
| 859 |
"""
|
| 860 |
Convert LLM streaming text (delta) into “incremental token IDs”.
|
|
|
|
| 861 |
Notes:
|
| 862 |
- The input is a delta that is progressively appended to the same string
|
| 863 |
(consistent with the common delta output behavior in vLLM).
|
|
@@ -939,6 +891,7 @@ def _maybe_codec_streaming(codec, *, batch_size: int):
|
|
| 939 |
class MossTTSRealtimeTextStreamBridge:
|
| 940 |
"""
|
| 941 |
Bridge: external LLM streaming text (delta) -> TTS streaming audio chunks.
|
|
|
|
| 942 |
Usage overview:
|
| 943 |
- First configure `MossTTSRealtimeStreamingSession` (especially `prefill_text_len=12`).
|
| 944 |
- Provide an `AudioStreamDecoder`, then continuously feed the LLM delta text via
|
|
@@ -972,6 +925,7 @@ class MossTTSRealtimeTextStreamBridge:
|
|
| 972 |
def push_text_delta(self, delta: str) -> Iterator[torch.Tensor]:
|
| 973 |
"""
|
| 974 |
Push a chunk of incremental text output from the LLM and return newly generated WAV chunks.
|
|
|
|
| 975 |
Internally, this directly calls `session.push_text()`, which segments the text
|
| 976 |
based on punctuation/length and then tokenizes the *entire segment* at once,
|
| 977 |
avoiding the prefix instability issues of incremental BPE tokenization.
|
|
|
|
| 23 |
import torch
|
| 24 |
import torch.nn.functional as F
|
| 25 |
|
| 26 |
+
from transformers.cache_utils import StaticCache
|
| 27 |
from transformers.utils import is_torchaudio_available, requires_backends
|
| 28 |
from transformers.utils.import_utils import requires
|
| 29 |
|
|
|
|
| 34 |
@requires(backends=("torch",))
|
| 35 |
class MossTTSRealtimeInference:
|
| 36 |
"""Step-wise inference wrapper for MossTTSRealtime.
|
| 37 |
+
|
| 38 |
This class mirrors the non-streaming inference logic but exposes a
|
| 39 |
prefill/step/finish API for streaming usage.
|
| 40 |
"""
|
|
|
|
| 67 |
self._is_stopping = None
|
| 68 |
self._last_audio_tokens = None
|
| 69 |
self._step_idx = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
@property
|
| 72 |
def device(self):
|
|
|
|
| 76 |
def is_finished(self) -> bool:
|
| 77 |
return self._is_stopping is not None and bool(self._is_stopping.all())
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
def reset_generation_state(self, keep_cache: bool = True):
|
| 80 |
if not keep_cache:
|
| 81 |
self.past_key_values = None
|
|
|
|
| 299 |
steps_left -= 1
|
| 300 |
return outputs
|
| 301 |
|
| 302 |
+
@torch.compile(fullgraph=True)
|
| 303 |
def generate_local_transformer(
|
| 304 |
self,
|
| 305 |
hidden_states: torch.Tensor,
|
|
|
|
| 311 |
repetition_window: Optional[int],
|
| 312 |
generated_tokens: Optional[torch.Tensor],
|
| 313 |
gen_step: int,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
) -> torch.Tensor:
|
| 315 |
batch_size = hidden_states.shape[0]
|
| 316 |
+
device = hidden_states.device
|
| 317 |
local_inputs = hidden_states.reshape(-1, 1, self.model.config.local_config.hidden_size)
|
| 318 |
+
output_token = torch.empty(batch_size, self.channels, dtype=torch.long, device=device)
|
| 319 |
|
| 320 |
+
past_key_values = StaticCache(config=self.model.local_transformer.config, max_cache_len=self.channels)
|
| 321 |
local_token = None
|
| 322 |
|
| 323 |
+
cache_pos_t = torch.zeros(1, dtype=torch.long, device=device)
|
| 324 |
|
| 325 |
for i in range(self.channels):
|
| 326 |
cache_pos_t.fill_(i)
|
|
|
|
| 479 |
|
| 480 |
def set_voice_prompt(self, audio, sample_rate: Optional[int] = None):
|
| 481 |
"""Set voice prompt from either audio tokens or waveform.
|
| 482 |
+
|
| 483 |
If `audio` is a 2D array whose shape matches the codebook channels, it is
|
| 484 |
treated as audio tokens. Otherwise a codec is required to encode waveform
|
| 485 |
prompts into tokens.
|
|
|
|
| 686 |
codec,
|
| 687 |
chunk_frames: int = 40,
|
| 688 |
overlap_frames: int = 4,
|
|
|
|
|
|
|
| 689 |
decode_kwargs: Optional[dict] = None,
|
| 690 |
device: Optional[torch.device] = None,
|
| 691 |
):
|
| 692 |
self.codec = codec
|
| 693 |
self.chunk_frames = chunk_frames
|
| 694 |
self.overlap_frames = overlap_frames
|
|
|
|
|
|
|
| 695 |
self.decode_kwargs = decode_kwargs or {}
|
| 696 |
self.device = device
|
| 697 |
|
| 698 |
self._buffer: list[torch.Tensor] = []
|
| 699 |
self._buffer_len = 0
|
| 700 |
self._prev_tail: Optional[torch.Tensor] = None
|
|
|
|
| 701 |
|
| 702 |
def push_tokens(self, audio_tokens: np.ndarray | torch.Tensor):
|
| 703 |
if isinstance(audio_tokens, np.ndarray):
|
|
|
|
| 707 |
self._buffer.append(audio_tokens)
|
| 708 |
self._buffer_len += audio_tokens.shape[0]
|
| 709 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 710 |
def audio_chunks(self) -> Iterable[torch.Tensor]:
|
| 711 |
+
while self._buffer_len >= self.chunk_frames:
|
| 712 |
+
chunk_tokens = self._consume_frames(self.chunk_frames)
|
| 713 |
+
wav = self._decode(chunk_tokens, chunk_duration=0.32)
|
|
|
|
| 714 |
yield self._apply_crossfade(wav)
|
| 715 |
|
| 716 |
def flush(self) -> Optional[torch.Tensor]:
|
|
|
|
| 736 |
self._buffer_len -= num_frames - remaining
|
| 737 |
return torch.cat(frames, dim=0)
|
| 738 |
|
| 739 |
+
def _decode(self, tokens: torch.Tensor, chunk_duration: float = 0.32) -> torch.Tensor:
|
| 740 |
device = self.device
|
| 741 |
if device is None:
|
| 742 |
if hasattr(self.codec, "device"):
|
|
|
|
| 749 |
if device is not None:
|
| 750 |
tokens = tokens.to(device)
|
| 751 |
tokens_t = tokens.permute(1, 0)
|
| 752 |
+
# allow callers to override decode settings (e.g. chunk_duration=-1 to disable internal streaming)
|
| 753 |
decode_kwargs = dict(self.decode_kwargs) if self.decode_kwargs else {}
|
| 754 |
+
if "chunk_duration" in decode_kwargs:
|
| 755 |
+
override = decode_kwargs.pop("chunk_duration")
|
| 756 |
+
if override is None:
|
| 757 |
+
chunk_duration_arg = None
|
| 758 |
+
else:
|
| 759 |
+
try:
|
| 760 |
+
override_f = float(override)
|
| 761 |
+
except Exception:
|
| 762 |
+
override_f = None
|
| 763 |
+
chunk_duration_arg = None if override_f is None or override_f <= 0 else override_f
|
| 764 |
+
else:
|
| 765 |
+
chunk_duration_arg = chunk_duration
|
| 766 |
+
|
| 767 |
+
decoded = self.codec.decode(tokens_t, chunk_duration=chunk_duration_arg, **decode_kwargs)
|
| 768 |
if isinstance(decoded, dict):
|
| 769 |
wav = decoded["audio"][0]
|
| 770 |
else:
|
|
|
|
| 809 |
class TextDeltaTokenizer:
|
| 810 |
"""
|
| 811 |
Convert LLM streaming text (delta) into “incremental token IDs”.
|
| 812 |
+
|
| 813 |
Notes:
|
| 814 |
- The input is a delta that is progressively appended to the same string
|
| 815 |
(consistent with the common delta output behavior in vLLM).
|
|
|
|
| 891 |
class MossTTSRealtimeTextStreamBridge:
|
| 892 |
"""
|
| 893 |
Bridge: external LLM streaming text (delta) -> TTS streaming audio chunks.
|
| 894 |
+
|
| 895 |
Usage overview:
|
| 896 |
- First configure `MossTTSRealtimeStreamingSession` (especially `prefill_text_len=12`).
|
| 897 |
- Provide an `AudioStreamDecoder`, then continuously feed the LLM delta text via
|
|
|
|
| 925 |
def push_text_delta(self, delta: str) -> Iterator[torch.Tensor]:
|
| 926 |
"""
|
| 927 |
Push a chunk of incremental text output from the LLM and return newly generated WAV chunks.
|
| 928 |
+
|
| 929 |
Internally, this directly calls `session.push_text()`, which segments the text
|
| 930 |
based on punctuation/length and then tokenizes the *entire segment* at once,
|
| 931 |
avoiding the prefix instability issues of incremental BPE tokenization.
|