| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """Streaming inference utilities for MossTTSRealtime.""" |
|
|
| from __future__ import annotations |
|
|
| import contextlib |
| import re |
| from typing import Iterable, Iterator, List, Optional, Sequence |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
|
|
| from transformers.cache_utils import StaticCache |
| from transformers.utils import is_torchaudio_available, requires_backends |
| from transformers.utils.import_utils import requires |
|
|
| if is_torchaudio_available(): |
| import torchaudio |
|
|
|
|
@requires(backends=("torch",))
class MossTTSRealtimeInference:
    """Step-wise inference wrapper for MossTTSRealtime.

    This class mirrors the non-streaming inference logic but exposes a
    prefill/step/finish API for streaming usage:

    * :meth:`prefill` runs the backbone over the prompt rows followed by a
      prefix of text tokens and samples the first audio frame.
    * :meth:`step` feeds one text token (or ``text_pad_id``) together with the
      previous audio frame and samples the next audio frame.
    * :meth:`finish` keeps stepping with pad tokens until every batch entry has
      emitted ``audio_eos_token`` or the step budget runs out.

    Backbone input rows have ``channels + 1`` columns: column 0 carries a text
    token and columns ``1..channels`` carry audio tokens.
    """

    def __init__(
        self,
        model,
        tokenizer,
        max_length: int = 1000,
        channels: int = 16,
        audio_channel_pad: int = 1024,
        audio_bos_token: int = 1025,
        audio_eos_token: int = 1026,
        text_pad_id: int = 151655,
        aud_pad_id: int = 151654,
    ):
        self.model = model
        self.tokenizer = tokenizer
        # Default step budget used by `finish()` when `max_steps` is not given.
        self.max_length = max_length
        # Number of audio columns per backbone row / codebooks sampled per frame.
        self.channels = channels
        self.audio_channel_pad = audio_channel_pad
        self.audio_bos_token = audio_bos_token
        self.audio_eos_token = audio_eos_token
        self.text_pad_id = text_pad_id
        self.aud_pad_id = aud_pad_id

        # Backbone KV cache and the attention mask covering every cached position.
        self.past_key_values = None
        self.attention_mask = None
        # One [B, channels] audio frame per generated step; repetition-penalty history.
        self._generated_tokens: List[torch.Tensor] = []
        # Bool tensor [B]; True once a batch entry has emitted audio EOS.
        self._is_stopping = None
        self._last_audio_tokens = None
        self._step_idx = 0

    @property
    def device(self):
        """Device of the model's first parameter (default device for inputs)."""
        return next(self.model.parameters()).device

    @property
    def is_finished(self) -> bool:
        """True once every batch entry has produced ``audio_eos_token`` (False before prefill)."""
        return self._is_stopping is not None and bool(self._is_stopping.all())

    def reset_generation_state(self, keep_cache: bool = True):
        """Clear per-utterance sampling state.

        Args:
            keep_cache: When False, also drop the backbone KV cache and its
                attention mask so the next :meth:`prefill` starts from scratch.
        """
        if not keep_cache:
            self.past_key_values = None
            self.attention_mask = None

        self._generated_tokens = []
        self._is_stopping = None
        self._last_audio_tokens = None
        self._step_idx = 0

    def _normalize_input_ids(self, input_ids) -> list[np.ndarray]:
        """Convert ``input_ids`` into a list of per-sample 2-D numpy arrays.

        Accepts a single 2-D array/tensor, a 3-D ``[B, T, C]`` array/tensor, or a
        list/tuple of per-sample arrays/tensors (lengths may differ).

        Raises:
            ValueError: If the input has an unsupported type or rank.
        """
        if isinstance(input_ids, torch.Tensor):
            input_ids = input_ids.detach().cpu().numpy()
        if isinstance(input_ids, np.ndarray):
            if input_ids.ndim == 2:
                return [input_ids]
            if input_ids.ndim == 3:
                return list(input_ids)
        if isinstance(input_ids, (list, tuple)):
            # Tensors may live on an accelerator; bring them to host memory first.
            return [
                item.detach().cpu().numpy() if isinstance(item, torch.Tensor) else np.array(item)
                for item in input_ids
            ]
        raise ValueError("input_ids must be a list/array/tensor of shape [T, C] or [B, T, C].")

    def _normalize_text_prefix(self, text_prefix_ids, batch_size: int) -> list[list[int]]:
        """Return one list of text token ids per batch entry.

        A flat sequence of ids (or a single nested sequence) is shared by every
        batch entry; otherwise exactly ``batch_size`` sequences are required.

        Raises:
            ValueError: If ``text_prefix_ids`` is missing, of an unsupported type,
                or has a batch size that cannot be matched to ``batch_size``.
        """
        if text_prefix_ids is None:
            raise ValueError("text_prefix_ids must be provided for prefill.")
        if isinstance(text_prefix_ids, torch.Tensor):
            text_prefix_ids = text_prefix_ids.detach().cpu().tolist()
        if isinstance(text_prefix_ids, np.ndarray):
            text_prefix_ids = text_prefix_ids.tolist()
        if isinstance(text_prefix_ids, tuple):
            text_prefix_ids = list(text_prefix_ids)
        if isinstance(text_prefix_ids, list):
            if len(text_prefix_ids) == 0:
                return [[] for _ in range(batch_size)]
            if isinstance(text_prefix_ids[0], (int, np.integer)):
                # A flat prefix is treated as a single sequence (broadcast below).
                text_prefix_ids = [text_prefix_ids]
            if len(text_prefix_ids) == 1 and batch_size > 1:
                return [list(text_prefix_ids[0]) for _ in range(batch_size)]
            if len(text_prefix_ids) != batch_size:
                raise ValueError(
                    f"text_prefix_ids batch size mismatch: got {len(text_prefix_ids)}, expected {batch_size}."
                )
            return [list(item) for item in text_prefix_ids]
        raise ValueError("text_prefix_ids must be list-like or tensor-like.")

    def _build_prefill_batch(self, input_ids_list, text_prefix_list, max_prefill_len, device):
        """Append text-prefix rows to every prompt and left-pad into one batch.

        Each prefix token becomes a row whose audio columns hold
        ``audio_channel_pad``; the last prefix row gets ``audio_bos_token`` in
        the first audio column. Padding rows use the tokenizer pad id (or
        ``text_pad_id``) and are masked out.

        Returns:
            ``(input_ids, attention_mask)`` tensors on ``device``.

        Raises:
            ValueError: If a (possibly truncated) prefix is empty.
        """
        rows = []
        for prompt, prefix in zip(input_ids_list, text_prefix_list):
            if max_prefill_len is not None:
                prefix = prefix[:max_prefill_len]
            if len(prefix) == 0:
                raise ValueError("Prefill requires at least one text token.")

            text_seg = np.full((len(prefix), self.channels + 1), self.audio_channel_pad, dtype=np.int64)
            text_seg[:, 0] = np.array(prefix, dtype=np.int64)
            text_seg[-1, 1] = self.audio_bos_token
            rows.append(np.concatenate([prompt, text_seg], axis=0))

        max_len = max(row.shape[0] for row in rows)
        pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.text_pad_id

        padded_input_ids, padded_attns = [], []
        for row in rows:
            pad_len = max_len - row.shape[0]
            input_pad = np.full((pad_len, self.channels + 1), self.audio_channel_pad, dtype=np.int64)
            input_pad[:, 0] = pad_token_id
            padded_input_ids.append(np.concatenate([input_pad, row]))
            # Left padding: padded positions are masked out, real rows attended.
            padded_attns.append(
                np.concatenate([np.zeros(pad_len, dtype=np.bool_), np.ones(row.shape[0], dtype=np.bool_)])
            )

        input_ids_t = torch.from_numpy(np.stack(padded_input_ids)).to(device)
        attention_t = torch.from_numpy(np.stack(padded_attns)).to(device)
        return input_ids_t, attention_t

    def _repetition_history(self, repetition_window: Optional[int]) -> Optional[torch.Tensor]:
        """Stack the frames the repetition penalty can see into ``[B, T, channels]``.

        Only the last ``repetition_window`` frames are stacked when the window is
        positive (the penalty ignores older frames anyway), which keeps this
        O(window) per step and the tensor shape stable for the compiled sampler.
        """
        if not self._generated_tokens:
            return None
        frames = self._generated_tokens
        if repetition_window is not None and repetition_window > 0:
            frames = frames[-repetition_window:]
        return torch.stack(frames, dim=1)

    def _normalize_step_tokens(self, text_token, batch_size: int) -> list[int]:
        """Turn ``step()``'s ``text_token`` argument into one id per batch entry.

        ``None`` means "no new text" and maps to ``text_pad_id``. Tensors and
        arrays of any rank (including 0-d) are flattened.

        Raises:
            ValueError: If the number of ids does not match ``batch_size``.
        """
        if text_token is None:
            text_tokens = [self.text_pad_id] * batch_size
        elif isinstance(text_token, torch.Tensor):
            # reshape(-1) also handles 0-d tensors, whose `.tolist()` is a bare int.
            text_tokens = text_token.detach().reshape(-1).cpu().tolist()
        elif isinstance(text_token, np.ndarray):
            text_tokens = text_token.reshape(-1).tolist()
        elif isinstance(text_token, (list, tuple)):
            text_tokens = list(text_token)
        else:
            text_tokens = [int(text_token)]

        if len(text_tokens) != batch_size:
            raise ValueError(f"text_token batch size mismatch: got {len(text_tokens)}, expected {batch_size}.")
        return text_tokens

    @torch.inference_mode()
    def prefill(
        self,
        input_ids,
        text_prefix_ids,
        max_prefill_len: Optional[int] = None,
        past_key_values=None,
        device: Optional[torch.device] = None,
        temperature: float = 0.8,
        top_p: float = 0.6,
        top_k: int = 30,
        do_sample: bool = True,
        repetition_penalty: Optional[float] = 1.1,
        repetition_window: Optional[int] = 50,
    ) -> torch.Tensor:
        """Run the backbone over prompt + text prefix and sample the first audio frame.

        Args:
            input_ids: Prompt rows per sample (see :meth:`_normalize_input_ids`).
            text_prefix_ids: Text ids appended after each prompt; a single
                sequence is shared across the batch.
            max_prefill_len: Optional cap on the number of prefix tokens used.
            past_key_values: Optional backbone cache; replaces the stored cache.
            device: Device for the inputs (defaults to the model's device).
            temperature, top_p, top_k, do_sample: Sampling controls.
            repetition_penalty, repetition_window: Forwarded to the sampler, but
                no history exists yet so no penalty is applied here.

        Returns:
            The sampled audio frame of shape ``[B, channels]``.

        Raises:
            ValueError: If inputs are malformed or a prefix is empty.
        """
        if device is None:
            device = self.device

        if past_key_values is not None:
            self.past_key_values = past_key_values

        input_ids_list = self._normalize_input_ids(input_ids)
        text_prefix_list = self._normalize_text_prefix(text_prefix_ids, len(input_ids_list))
        current_input_ids, current_attention_mask = self._build_prefill_batch(
            input_ids_list, text_prefix_list, max_prefill_len, device
        )

        # When continuing from a cache, the mask must also cover cached positions.
        if self.attention_mask is not None and self.past_key_values is not None:
            current_attention_mask = torch.cat([self.attention_mask, current_attention_mask], dim=-1)

        outputs = self.model(
            input_ids=current_input_ids,
            attention_mask=current_attention_mask,
            past_key_values=self.past_key_values,
            use_cache=True,
            return_dict=True,
        )
        self.past_key_values = outputs.past_key_values
        self.attention_mask = current_attention_mask

        backbone_hidden_states = outputs.last_hidden_state[:, -1:, :]
        audio_tokens = self.generate_local_transformer(
            hidden_states=backbone_hidden_states,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            do_sample=do_sample,
            repetition_penalty=repetition_penalty,
            repetition_window=repetition_window,
            generated_tokens=None,
            gen_step=0,
        )

        self._generated_tokens = [audio_tokens]
        self._last_audio_tokens = audio_tokens
        self._is_stopping = audio_tokens[:, 0] == self.audio_eos_token
        self._step_idx = 1
        return audio_tokens

    @torch.inference_mode()
    def step(
        self,
        text_token: Optional[Iterable[int] | torch.Tensor | int],
        temperature: float = 0.8,
        top_p: float = 0.6,
        top_k: int = 30,
        do_sample: bool = True,
        repetition_penalty: Optional[float] = 1.1,
        repetition_window: Optional[int] = 50,
    ) -> torch.Tensor:
        """Feed one text token per batch entry and sample the next audio frame.

        Args:
            text_token: One id per batch entry (scalar, sequence, array or
                tensor), or ``None`` to feed ``text_pad_id``.
            temperature, top_p, top_k, do_sample: Sampling controls.
            repetition_penalty, repetition_window: Penalty applied against the
                most recent generated frames.

        Returns:
            The new ``[B, channels]`` audio frame, or the last frame unchanged if
            generation already finished.

        Raises:
            ValueError: If called before :meth:`prefill` or with the wrong
                number of text tokens.
        """
        if self._last_audio_tokens is None or self.attention_mask is None:
            raise ValueError("You must call prefill() before step().")
        if self.is_finished:
            return self._last_audio_tokens

        batch_size = self._last_audio_tokens.shape[0]
        text_tokens = self._normalize_step_tokens(text_token, batch_size)

        device = self._last_audio_tokens.device
        text_t = torch.tensor(text_tokens, device=device, dtype=torch.long)
        # One new backbone row per entry: [text | previous audio frame].
        step_ids = torch.cat([text_t[:, None, None], self._last_audio_tokens.unsqueeze(1)], dim=2)
        # Entries that already stopped are masked out for the new position.
        self.attention_mask = torch.cat([self.attention_mask, (~self._is_stopping).unsqueeze(-1)], dim=-1)

        outputs = self.model(
            input_ids=step_ids,
            attention_mask=self.attention_mask,
            past_key_values=self.past_key_values,
            use_cache=True,
            return_dict=True,
        )
        self.past_key_values = outputs.past_key_values
        backbone_hidden_states = outputs.last_hidden_state[:, -1:, :]

        history = self._repetition_history(repetition_window)
        audio_tokens = self.generate_local_transformer(
            hidden_states=backbone_hidden_states,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            do_sample=do_sample,
            repetition_penalty=repetition_penalty,
            repetition_window=repetition_window,
            generated_tokens=history,
            gen_step=history.shape[1] if history is not None else self._step_idx,
        )

        self._generated_tokens.append(audio_tokens)
        self._last_audio_tokens = audio_tokens
        self._is_stopping |= audio_tokens[:, 0] == self.audio_eos_token
        self._step_idx += 1
        return audio_tokens

    @torch.inference_mode()
    def finish(
        self,
        max_steps: Optional[int] = None,
        temperature: float = 0.8,
        top_p: float = 0.6,
        top_k: int = 30,
        do_sample: bool = True,
        repetition_penalty: Optional[float] = 1.1,
        repetition_window: Optional[int] = 50,
    ) -> list[torch.Tensor]:
        """Step with pad text tokens until finished or the step budget is spent.

        Args:
            max_steps: Maximum number of steps (defaults to ``max_length``).

        Returns:
            The audio frames produced, in order.
        """
        outputs = []
        steps_left = max_steps if max_steps is not None else self.max_length
        while steps_left > 0 and not self.is_finished:
            outputs.append(
                self.step(
                    text_token=None,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    do_sample=do_sample,
                    repetition_penalty=repetition_penalty,
                    repetition_window=repetition_window,
                )
            )
            steps_left -= 1
        return outputs

    @torch.compile(fullgraph=True)
    def generate_local_transformer(
        self,
        hidden_states: torch.Tensor,
        temperature: float,
        top_p: float,
        top_k: int,
        do_sample: bool,
        repetition_penalty: Optional[float],
        repetition_window: Optional[int],
        generated_tokens: Optional[torch.Tensor],
        gen_step: int,
    ) -> torch.Tensor:
        """Sample one audio frame, codebook by codebook, with the local transformer.

        The backbone hidden state is the input embedding for codebook 0; every
        later codebook is conditioned on the token sampled for the previous one.

        Args:
            hidden_states: Backbone hidden state for the last position.
            generated_tokens: ``[B, T, channels]`` history for the repetition
                penalty, or ``None``.
            gen_step: Number of history rows (along dim 1) to consider.

        Returns:
            A ``[B, channels]`` tensor of sampled audio tokens.
        """
        batch_size = hidden_states.shape[0]
        device = hidden_states.device
        local_inputs = hidden_states.reshape(-1, 1, self.model.config.local_config.hidden_size)
        output_token = torch.empty(batch_size, self.channels, dtype=torch.long, device=device)

        past_key_values = StaticCache(config=self.model.local_transformer.config, max_cache_len=self.channels)
        local_token = None

        cache_pos_t = torch.zeros(1, dtype=torch.long, device=device)

        for i in range(self.channels):
            cache_pos_t.fill_(i)

            local_outputs = self.model.local_transformer(
                input_ids=local_token,
                inputs_embeds=local_inputs,
                past_key_values=past_key_values,
                cache_position=cache_pos_t,
                codebook_idx=i,
                use_cache=True,
                logits_to_keep=1,
            )
            logits = local_outputs.logits

            if repetition_penalty and repetition_penalty != 1.0 and generated_tokens is not None:
                logits = self.apply_repetition_penalty(
                    scores=logits,
                    history_tokens=generated_tokens[:, :gen_step, i],
                    penalty=float(repetition_penalty),
                    repetition_window=repetition_window,
                )

            local_token = self.sample_token(
                logits=logits,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                do_sample=do_sample,
            )
            # `apply_repetition_penalty` drops the singleton position dim, so the
            # sampled token may be [B] instead of [B, 1]; normalise it so the same
            # shape is fed back as `input_ids` on every path.
            local_token = local_token.reshape(batch_size, 1)
            output_token[:, i] = local_token[:, 0]

            # After codebook 0 the local transformer is driven by token ids only.
            if i == 0:
                local_inputs = None
        return output_token

    def apply_repetition_penalty(
        self,
        scores: torch.Tensor,
        history_tokens: torch.Tensor,
        penalty: float = 1.1,
        repetition_window: Optional[int] = None,
    ):
        """Penalise logits of tokens that occur in ``history_tokens``.

        Negative logits are multiplied by ``penalty`` and positive ones divided
        by it. Only the last ``repetition_window`` history columns are used when
        the window is positive.

        Note: operates on ``scores[:, 0, :]`` in place (``scores`` is modified
        through the view) and returns that 2-D view.
        """
        scores_ = scores[:, 0, :]
        ht = history_tokens

        if repetition_window is not None and repetition_window > 0:
            ht = ht[:, -repetition_window:]

        cur = scores_.gather(1, ht)
        new = torch.where(cur < 0, cur * penalty, cur / penalty)
        scores_.scatter_(1, ht, new)
        return scores_

    def sample_token(self, logits, temperature, top_p=0.6, top_k=30, do_sample=True):
        """Pick token ids from ``logits`` over the last dimension.

        Greedy (argmax) when ``do_sample`` is False or ``temperature == 0``;
        otherwise temperature scaling, optional top-k / top-p filtering and a
        multinomial draw. The result has ``logits.shape[:-1]``.
        """
        if not do_sample or temperature == 0:
            return torch.argmax(logits, dim=-1)
        logits = logits / temperature
        original_shape = logits.shape
        vocab_size = original_shape[-1]
        reshaped_logits = logits.reshape(-1, vocab_size)

        if top_k is not None:
            reshaped_logits = self.apply_top_k(reshaped_logits, top_k)

        if top_p is not None:
            reshaped_logits = self.apply_top_p(reshaped_logits, top_p)

        probs = F.softmax(reshaped_logits, dim=-1)
        next_tokens_flat = torch.multinomial(probs, num_samples=1)

        output_shape = original_shape[:-1]
        return next_tokens_flat.view(output_shape)

    def apply_top_k(self, logits, top_k, filter_value=float("-inf"), min_tokens_to_keep: int = 1):
        """Set every logit below the k-th largest (per row) to ``filter_value``.

        Raises:
            ValueError: If ``top_k`` is not a positive integer.
        """
        if not isinstance(top_k, int) or top_k <= 0:
            raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
        vocab_size = logits.shape[-1]
        top_k = min(max(top_k, min_tokens_to_keep), vocab_size)
        kth_largest = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        return logits.masked_fill(logits < kth_largest, filter_value)

    def apply_top_p(self, logits, top_p, filter_value=float("-inf"), min_tokens_to_keep: int = 1):
        """Nucleus filtering on 2-D ``logits``.

        Removes the lowest-probability tokens whose cumulative probability is at
        most ``1 - top_p``, always keeping the ``min_tokens_to_keep`` most likely.

        Raises:
            ValueError: If ``top_p`` is outside ``[0, 1]``.
        """
        top_p = float(top_p)
        if top_p < 0 or top_p > 1.0:
            raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")

        # Ascending order: the mass to drop accumulates from the least likely tokens.
        sorted_logits, sorted_indices = torch.sort(logits, descending=False)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

        sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
        sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
        indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter(1, sorted_indices, sorted_indices_to_remove)
        return logits.masked_fill(indices_to_remove, filter_value)
|
|
|
|
@requires(backends=("torch",))
class MossTTSRealtimeStreamingSession:
    """Manage text-to-audio streaming for a single conversation.

    Typical flow:

    1. Optionally set a voice prompt.
    2. Call :meth:`reset_turn` with the user prompt.
    3. Feed text with :meth:`push_text` or :meth:`push_text_tokens`.
    4. Call :meth:`end_text`.
    5. Call :meth:`drain` to collect the remaining audio frames.

    Every push returns the audio frames the inferencer produced as a result.

    Text tokens are buffered until ``prefill_text_len`` tokens are available (or
    the text has ended). That prefix is used for ``inferencer.prefill`` and each
    later token drives one ``inferencer.step``.
    """

    _split_pattern = re.compile(
        r"[。!?!?\.\u2026]\s*"
        r"|[,,;;::\u2014\u2013\-]\s*"
        r"|\)\s*|\]\s*"
        r"|\n"
    )

    def __init__(
        self,
        inferencer: MossTTSRealtimeInference,
        processor,
        codec=None,
        codec_sample_rate: int = 24000,
        codec_encode_kwargs: Optional[dict] = None,
        prefill_text_len: int = 12,
        text_buffer_size: int = 32,
        min_text_chunk_chars: int = 8,
        temperature: float = 0.8,
        top_p: float = 0.6,
        top_k: int = 30,
        do_sample: bool = True,
        repetition_penalty: Optional[float] = 1.1,
        repetition_window: Optional[int] = 50,
    ):
        self.inferencer = inferencer
        self.processor = processor
        self.tokenizer = processor.tokenizer
        self.codec = codec
        self.codec_sample_rate = codec_sample_rate
        self.codec_encode_kwargs = codec_encode_kwargs or {}

        # Number of text tokens to collect before the first prefill.
        self.prefill_text_len = prefill_text_len
        # Character count after which buffered text is cut even without punctuation.
        self.text_buffer_size = text_buffer_size
        # Minimum characters before a punctuation cut is allowed.
        self.min_text_chunk_chars = min_text_chunk_chars

        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.do_sample = do_sample
        self.repetition_penalty = repetition_penalty
        self.repetition_window = repetition_window

        self._voice_prompt_tokens = None
        self._turn_input_ids = None
        self._turn_idx = 0

        # Raw text not yet tokenized, and tokens not yet fed to the inferencer.
        self._text_cache = ""
        self._pending_tokens: list[int] = []
        self._prefilled = False
        self._text_ended = False

    def set_voice_prompt_tokens(self, audio_tokens: np.ndarray):
        """Use pre-computed audio tokens as the voice prompt."""
        self._voice_prompt_tokens = audio_tokens

    def set_voice_prompt(self, audio, sample_rate: Optional[int] = None):
        """Set voice prompt from either audio tokens or waveform.

        If `audio` is a 2D array/tensor with a dimension equal to the codebook
        channels, it is treated as audio tokens. Otherwise it is treated as a
        waveform (array, tensor, or a path/bytes loaded with torchaudio) and
        encoded with ``codec``, after resampling to ``codec_sample_rate`` when
        ``sample_rate`` differs.

        Raises:
            ValueError: If no codec is available, the audio type is unsupported,
                or the codec output has no audio codes.
        """
        if isinstance(audio, np.ndarray) and audio.ndim == 2:
            if self.processor.channels in audio.shape:
                self._voice_prompt_tokens = audio
                return
        if isinstance(audio, torch.Tensor) and audio.dim() == 2:
            if self.processor.channels in audio.shape:
                self._voice_prompt_tokens = audio.detach().cpu().numpy()
                return

        if self.codec is None:
            raise ValueError("codec is required to encode waveform prompts.")

        waveform = audio
        if isinstance(audio, (str, bytes)):
            requires_backends(self, ["torchaudio"])
            wav, sr = torchaudio.load(audio)
            # Down-mix multi-channel files to mono.
            if wav.shape[0] > 1:
                wav = wav.mean(dim=0, keepdim=True)
            waveform = wav.squeeze(0)
            sample_rate = sr

        if isinstance(waveform, np.ndarray):
            waveform = torch.from_numpy(waveform)
        if not isinstance(waveform, torch.Tensor):
            raise ValueError("Unsupported audio type for voice prompt.")

        if sample_rate is not None and sample_rate != self.codec_sample_rate:
            requires_backends(self, ["torchaudio"])
            waveform = torchaudio.functional.resample(waveform, sample_rate, self.codec_sample_rate)

        waveform = waveform.to(self.inferencer.device)
        encode_out = self.codec.encode([waveform], **self.codec_encode_kwargs)
        if isinstance(encode_out, dict):
            if "codes_list" in encode_out:
                tokens = encode_out["codes_list"][0]
            elif "audio_codes" in encode_out:
                tokens = encode_out["audio_codes"][0]
            else:
                raise ValueError("codec.encode output missing audio codes.")
        else:
            tokens = encode_out
        if isinstance(tokens, torch.Tensor):
            tokens = tokens.detach().cpu().numpy()
        self._voice_prompt_tokens = tokens

    def clear_voice_prompt(self):
        """Forget the current voice prompt."""
        self._voice_prompt_tokens = None

    def reset_turn(
        self,
        user_text: Optional[str] = None,
        user_audio_tokens: Optional[np.ndarray] = None,
        input_ids: Optional[np.ndarray] = None,
        include_system_prompt: Optional[bool] = None,
        reset_cache: bool = False,
    ):
        """Start a new turn and clear all streaming text state.

        Either pass ``input_ids`` directly or ``user_text`` + ``user_audio_tokens``
        to build them with the processor. The system prompt (built from the
        voice prompt) is included on the first turn by default.

        Raises:
            ValueError: If neither ``input_ids`` nor both user inputs are given.
        """
        if include_system_prompt is None:
            include_system_prompt = self._turn_idx == 0

        if input_ids is None:
            if user_text is None or user_audio_tokens is None:
                raise ValueError("user_text and user_audio_tokens are required when input_ids is not provided.")
            user_prompt = self.processor.make_user_prompt(user_text, user_audio_tokens)
            if include_system_prompt:
                system_prompt = self.processor.make_ensemble(self._voice_prompt_tokens)
                input_ids = np.concatenate([system_prompt, user_prompt], axis=0)
            else:
                input_ids = user_prompt

        self._turn_input_ids = input_ids
        self._turn_idx += 1

        self._text_cache = ""
        self._pending_tokens = []
        self._prefilled = False
        self._text_ended = False

        self.inferencer.reset_generation_state(keep_cache=not reset_cache)

    def push_text_tokens(self, tokens: Iterable[int]) -> list[torch.Tensor]:
        """Queue already-tokenized text and return any audio frames produced."""
        self._pending_tokens.extend([int(t) for t in tokens])
        return self._drain_pending_tokens()

    def push_text(self, text_fragment: str) -> list[torch.Tensor]:
        """Buffer raw text, tokenize complete segments, and return audio frames produced."""
        self._text_cache += text_fragment
        segments = self._extract_text_segments(force=False)
        for segment in segments:
            self._pending_tokens.extend(self._tokenize(segment))
        return self._drain_pending_tokens()

    def end_text(self) -> list[torch.Tensor]:
        """Mark the end of the text, tokenize any leftover text, and return frames produced."""
        self._text_ended = True
        if self._text_cache:
            self._pending_tokens.extend(self._tokenize(self._text_cache))
            self._text_cache = ""
        return self._drain_pending_tokens()

    def drain(self, max_steps: Optional[int] = None) -> list[torch.Tensor]:
        """Generate remaining audio frames with pad text (nothing before prefill)."""
        if not self._prefilled:
            return []
        return self.inferencer.finish(
            max_steps=max_steps,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            do_sample=self.do_sample,
            repetition_penalty=self.repetition_penalty,
            repetition_window=self.repetition_window,
        )

    def _tokenize(self, text: str) -> list[int]:
        return self.tokenizer.encode(text, add_special_tokens=False)

    def _find_cut_index(self, text: str) -> Optional[int]:
        """Return where to cut ``text`` into a segment, or None to keep buffering.

        Prefer the first punctuation boundary at or beyond
        ``min_text_chunk_chars``. Once ``text_buffer_size`` characters are
        buffered, fall back to the last space, and finally to a hard cut so that
        text without spaces or punctuation (e.g. CJK) cannot grow unbounded.
        """
        if len(text) >= self.min_text_chunk_chars:
            for match in self._split_pattern.finditer(text):
                if match.end() >= self.min_text_chunk_chars:
                    return match.end()
        if len(text) >= self.text_buffer_size:
            whitespace_idx = text.rfind(" ")
            if whitespace_idx != -1:
                return whitespace_idx + 1
            # max(1, ...) guarantees progress even for a non-positive buffer size.
            return max(1, self.text_buffer_size)
        return None

    def _extract_text_segments(self, force: bool) -> list[str]:
        """Split complete segments off the text buffer.

        With ``force`` the whole buffer is returned as one segment.
        """
        segments = []
        if force:
            if self._text_cache:
                segments.append(self._text_cache)
                self._text_cache = ""
            return segments

        while self._text_cache:
            cut_idx = self._find_cut_index(self._text_cache)
            if cut_idx is None:
                break
            segments.append(self._text_cache[:cut_idx])
            self._text_cache = self._text_cache[cut_idx:]
        return segments

    def _prefill_if_needed(self) -> list[torch.Tensor]:
        """Run the inferencer prefill once enough tokens (or the end of text) arrived.

        Raises:
            ValueError: If :meth:`reset_turn` was never called.
        """
        if self._prefilled:
            return []
        if not self._pending_tokens and not self._text_ended:
            return []
        if len(self._pending_tokens) < self.prefill_text_len and not self._text_ended:
            return []
        if self._turn_input_ids is None:
            raise ValueError("reset_turn must be called before streaming text.")

        if self._text_ended:
            prefill_len = len(self._pending_tokens)
        else:
            prefill_len = min(len(self._pending_tokens), self.prefill_text_len)

        if prefill_len == 0:
            return []

        # Slice once instead of repeated pop(0), which is quadratic.
        prefix_tokens = self._pending_tokens[:prefill_len]
        del self._pending_tokens[:prefill_len]
        audio_tokens = self.inferencer.prefill(
            input_ids=[self._turn_input_ids],
            text_prefix_ids=[prefix_tokens],
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            do_sample=self.do_sample,
            # No audio history exists yet, so the penalty is disabled for prefill.
            repetition_penalty=None,
            repetition_window=self.repetition_window,
        )
        self._prefilled = True
        return [audio_tokens]

    def _drain_pending_tokens(self) -> list[torch.Tensor]:
        """Prefill if possible, then step once per pending token until finished."""
        outputs: list[torch.Tensor] = []
        outputs.extend(self._prefill_if_needed())
        if not self._prefilled:
            return outputs

        # pop(0) is negligible next to a full model step per token.
        while self._pending_tokens and not self.inferencer.is_finished:
            token = self._pending_tokens.pop(0)
            outputs.append(
                self.inferencer.step(
                    token,
                    temperature=self.temperature,
                    top_p=self.top_p,
                    top_k=self.top_k,
                    do_sample=self.do_sample,
                    repetition_penalty=self.repetition_penalty,
                    repetition_window=self.repetition_window,
                )
            )
        return outputs
|
|
|
|
@requires(backends=("torch",))
class AudioStreamDecoder:
    """Decode audio tokens into waveform chunks with optional crossfade.

    Token frames pushed with :meth:`push_tokens` are buffered.
    :meth:`audio_chunks` decodes them ``chunk_frames`` frames at a time, and
    :meth:`flush` decodes whatever is left.

    When ``overlap_frames > 0``, the first samples of each chunk are linearly
    crossfaded with the tail of the previous chunk. The crossfade never changes
    the length of the chunk being emitted.
    """

    def __init__(
        self,
        codec,
        chunk_frames: int = 40,
        overlap_frames: int = 4,
        decode_kwargs: Optional[dict] = None,
        device: Optional[torch.device] = None,
    ):
        self.codec = codec
        # Number of token frames decoded per chunk by `audio_chunks()`.
        self.chunk_frames = chunk_frames
        # Crossfade length, as a fraction overlap_frames / chunk_frames of each chunk.
        self.overlap_frames = overlap_frames
        # Extra kwargs for `codec.decode`; a "chunk_duration" entry overrides the per-call value.
        self.decode_kwargs = decode_kwargs or {}
        # Device for tokens; resolved from the codec when None.
        self.device = device

        self._buffer: list[torch.Tensor] = []
        self._buffer_len = 0
        # Last samples of the previously emitted chunk, blended into the next one.
        self._prev_tail: Optional[torch.Tensor] = None

    def push_tokens(self, audio_tokens: np.ndarray | torch.Tensor):
        """Buffer a 2-D ``[T, C]`` block of audio tokens.

        Raises:
            ValueError: If ``audio_tokens`` is not 2-D.
        """
        if isinstance(audio_tokens, np.ndarray):
            audio_tokens = torch.from_numpy(audio_tokens)
        if audio_tokens.dim() != 2:
            raise ValueError(f"Expected [T, C] audio tokens, got {tuple(audio_tokens.shape)}")
        self._buffer.append(audio_tokens)
        self._buffer_len += audio_tokens.shape[0]

    def audio_chunks(self) -> Iterator[torch.Tensor]:
        """Yield one decoded, crossfaded waveform per full chunk of buffered frames."""
        while self._buffer_len >= self.chunk_frames:
            chunk_tokens = self._consume_frames(self.chunk_frames)
            wav = self._decode(chunk_tokens, chunk_duration=0.32)
            yield self._apply_crossfade(wav)

    def flush(self) -> Optional[torch.Tensor]:
        """Decode every remaining buffered frame, or return None if the buffer is empty."""
        if self._buffer_len == 0:
            return None
        chunk_tokens = self._consume_frames(self._buffer_len)
        wav = self._decode(chunk_tokens)
        return self._apply_crossfade(wav, final_chunk=True)

    def _consume_frames(self, num_frames: int) -> torch.Tensor:
        """Remove and concatenate up to ``num_frames`` frames from the buffer head."""
        frames = []
        remaining = num_frames
        while remaining > 0 and self._buffer:
            head = self._buffer[0]
            if head.shape[0] <= remaining:
                frames.append(head)
                remaining -= head.shape[0]
                self._buffer.pop(0)
            else:
                frames.append(head[:remaining])
                self._buffer[0] = head[remaining:]
                remaining = 0
        self._buffer_len -= num_frames - remaining
        return torch.cat(frames, dim=0)

    def _resolve_device(self):
        """Return the device tokens should be moved to, or None if unknown."""
        if self.device is not None:
            return self.device
        if hasattr(self.codec, "device"):
            return self.codec.device
        try:
            return next(self.codec.parameters()).device
        except Exception:
            # Best effort: leave tokens where they are if the codec exposes no device.
            return None

    @staticmethod
    def _resolve_chunk_duration(decode_kwargs: dict, default: float) -> Optional[float]:
        """Pop a user "chunk_duration" override from ``decode_kwargs``.

        Returns ``default`` when there is no override. Returns None for a
        ``None``, unparsable, or non-positive override. Otherwise returns the
        override as a float.
        """
        if "chunk_duration" not in decode_kwargs:
            return default
        override = decode_kwargs.pop("chunk_duration")
        if override is None:
            return None
        try:
            value = float(override)
        except Exception:
            return None
        return None if value <= 0 else value

    def _decode(self, tokens: torch.Tensor, chunk_duration: float = 0.32) -> torch.Tensor:
        """Decode ``[T, C]`` tokens with the codec and return the waveform tensor."""
        device = self._resolve_device()
        if device is not None:
            tokens = tokens.to(device)
        # The codec is called with channels first.
        tokens_t = tokens.permute(1, 0)

        decode_kwargs = dict(self.decode_kwargs) if self.decode_kwargs else {}
        chunk_duration_arg = self._resolve_chunk_duration(decode_kwargs, chunk_duration)

        decoded = self.codec.decode(tokens_t, chunk_duration=chunk_duration_arg, **decode_kwargs)
        wav = decoded["audio"][0] if isinstance(decoded, dict) else decoded
        if isinstance(wav, np.ndarray):
            wav = torch.from_numpy(wav)
        if wav.dim() > 1:
            wav = wav.squeeze(0)
        return wav

    def _apply_crossfade(self, wav: torch.Tensor, final_chunk: bool = False) -> torch.Tensor:
        """Blend the start of ``wav`` with the previous tail and remember the new tail.

        The previous tail was already emitted, so only its last ``n`` samples are
        blended into the first ``n`` samples of ``wav``. Nothing from the previous
        chunk is re-emitted, and the output is as long as ``wav``. Assumes a 1-D
        waveform (TODO confirm for multi-channel codecs).
        """
        if self.overlap_frames <= 0:
            return wav

        overlap = self._overlap_samples(wav)
        merged = wav
        prev_tail = self._prev_tail
        if prev_tail is not None:
            # Never blend more samples than either side has.
            n = min(overlap, prev_tail.numel(), wav.numel())
            if n > 0:
                fade_dtype = wav.dtype if wav.is_floating_point() else torch.float32
                fade_out = torch.linspace(1.0, 0.0, n, device=wav.device, dtype=fade_dtype)
                fade_in = 1.0 - fade_out
                cross = prev_tail[-n:].to(wav.device) * fade_out + wav[:n] * fade_in
                merged = torch.cat([cross, wav[n:]], dim=-1)

        # Remember the raw tail for the next chunk. No chunk follows a flush, and
        # a zero overlap must not leave a stale (or whole-chunk) tail behind.
        if final_chunk or overlap <= 0:
            self._prev_tail = None
        else:
            self._prev_tail = wav[-overlap:].clone()
        return merged

    def _overlap_samples(self, wav: torch.Tensor) -> int:
        """Number of samples in ``wav`` corresponding to ``overlap_frames``."""
        if self.chunk_frames <= 0:
            return 0
        return int(wav.numel() * (self.overlap_frames / self.chunk_frames))
|
|
|
|
class TextDeltaTokenizer:
    """
    Turn streamed LLM text deltas into newly available token IDs.

    Every delta is appended to one growing string. Each call re-encodes the
    *whole* string and returns only token IDs that were not handed out yet,
    so the IDs always agree with the tokenization of the full text rather
    than of isolated fragments. The last ``hold_back`` IDs are withheld until
    more text arrives (or :meth:`flush` is called), because the tokenizer may
    still merge them with upcoming characters.
    """

    def __init__(self, tokenizer, *, hold_back: int = 3):
        self.tokenizer = tokenizer
        requested = int(hold_back)
        # Negative values are clamped to zero (no hold-back).
        self.hold_back = requested if requested > 0 else 0
        self._text = ""
        self._all_ids: list[int] = []
        self._emitted_count: int = 0

    @property
    def text(self) -> str:
        """All text received so far."""
        return self._text

    @property
    def token_ids(self) -> list[int]:
        """A copy of the token IDs of the full text as of the last encode."""
        return list(self._all_ids)

    def _reencode(self):
        """Tokenize the whole accumulated text and cache the result."""
        self._all_ids = self.tokenizer.encode(self._text, add_special_tokens=False)
        return self._all_ids

    def push_delta(self, delta: str) -> list[int]:
        """Append a text delta and return newly stable token ids (may be empty)."""
        if not delta:
            return []
        self._text = self._text + str(delta)
        ids = self._reencode()

        start = self._emitted_count
        end = len(ids) - self.hold_back
        if end < start:
            # Never move backwards past IDs that were already emitted.
            end = start
        self._emitted_count = end
        return ids[start:end]

    def flush(self) -> list[int]:
        """Emit all remaining token ids at end of stream."""
        ids = self._reencode()
        start = self._emitted_count
        self._emitted_count = len(ids)
        return ids[start:]
|
|
|
|
def _sanitize_audio_tokens(
    tokens: torch.Tensor,
    *,
    codebook_size: int,
    audio_eos_token: int,
) -> tuple[torch.Tensor, bool]:
    """Cut ``tokens`` at the first EOS or out-of-range row.

    A 1-D input is treated as a single row. A row ends the stream when its first
    column equals ``audio_eos_token`` or when any entry lies outside
    ``[0, codebook_size)``.

    Returns:
        ``(kept_rows, stop)``: ``kept_rows`` are the rows before the first such
        row, and ``stop`` tells whether such a row was found.
    """
    rows = tokens.unsqueeze(0) if tokens.dim() == 1 else tokens
    if rows.numel() == 0:
        return rows, False

    cut_points = []

    eos_hits = (rows[:, 0] == audio_eos_token).nonzero(as_tuple=False)
    if eos_hits.numel() > 0:
        cut_points.append(int(eos_hits[0].item()))

    out_of_range = ((rows < 0) | (rows >= codebook_size)).any(dim=1)
    if out_of_range.any():
        cut_points.append(int(out_of_range.nonzero(as_tuple=False)[0].item()))

    if not cut_points:
        return rows, False
    return rows[: min(cut_points)], True
|
|
|
|
def _maybe_codec_streaming(codec, *, batch_size: int):
    """Return the codec's streaming context if it has one, else a no-op context.

    When ``codec`` defines a ``streaming`` attribute, the result of
    ``codec.streaming(batch_size=batch_size)`` is returned. Otherwise (including
    ``codec is None``) a ``contextlib.nullcontext()`` is returned.
    """
    if codec is not None and hasattr(codec, "streaming"):
        return codec.streaming(batch_size=batch_size)
    return contextlib.nullcontext()
|
|
|
|
@requires(backends=("torch",))
class MossTTSRealtimeTextStreamBridge:
    """
    Bridge: external LLM streaming text (delta) -> TTS streaming audio chunks.

    Usage overview:
    - First configure `MossTTSRealtimeStreamingSession` (especially `prefill_text_len=12`).
    - Provide an `AudioStreamDecoder`, then continuously feed the LLM delta text via
      `push_text_delta()`.
    - Once the accumulated token count reaches `prefill_text_len`, the session will
      start generating audio tokens; the bridge will immediately decode them into WAV
      chunks and yield them.
    - Call `finish()` (or use `stream_from_text_deltas()`) to drain and flush the
      remaining audio.

    Decoding currently supports a single batch entry only.
    """

    def __init__(
        self,
        session: MossTTSRealtimeStreamingSession,
        decoder: AudioStreamDecoder,
        *,
        codebook_size: Optional[int] = None,
        audio_eos_token: Optional[int] = None,
        batch_size: int = 1,
    ):
        self.session = session
        self.decoder = decoder
        self.batch_size = int(batch_size)

        # Fall back to the session codec / inferencer settings (or the defaults 1024 / 1026).
        if codebook_size is None:
            codebook_size = int(getattr(getattr(session, "codec", None), "codebook_size", 1024))
        if audio_eos_token is None:
            audio_eos_token = int(getattr(session.inferencer, "audio_eos_token", 1026))

        self.codebook_size = int(codebook_size)
        self.audio_eos_token = int(audio_eos_token)

    def push_text_delta(self, delta: str) -> Iterator[torch.Tensor]:
        """
        Push a chunk of incremental text output from the LLM and return newly generated WAV chunks.

        Internally, this directly calls `session.push_text()`, which segments the text
        based on punctuation/length and then tokenizes the *entire segment* at once,
        avoiding the prefix instability issues of incremental BPE tokenization.
        """
        audio_frames = self.session.push_text(delta)
        yield from self._decode_audio_frames(audio_frames)

    def push_text_tokens(self, token_ids: Sequence[int]) -> Iterator[torch.Tensor]:
        """Push token ids directly (for sources that stream token ids).

        Any iterable of ids is accepted, including numpy arrays and tensors.
        Empty input yields nothing.
        """
        # Materialise first: `not token_ids` is ambiguous for arrays/tensors and
        # always False for generators.
        token_ids = list(token_ids)
        if not token_ids:
            return
        audio_frames = self.session.push_text_tokens(token_ids)
        yield from self._decode_audio_frames(audio_frames)

    def finish(self, *, drain_step: int = 1) -> Iterator[torch.Tensor]:
        """Mark text stream end and emit all remaining audio chunks (including flush).

        Args:
            drain_step: Number of inferencer steps run per drain round before the
                resulting frames are decoded.
        """
        audio_frames = self.session.end_text()
        yield from self._decode_audio_frames(audio_frames)

        while True:
            more_frames = self.session.drain(max_steps=drain_step)
            if not more_frames:
                break
            yield from self._decode_audio_frames(more_frames)
            if self.session.inferencer.is_finished:
                break

        final = self.decoder.flush()
        if final is not None and final.numel() > 0:
            yield final.detach().cpu()

    def stream_from_text_deltas(self, deltas: Iterable[str], *, drain_step: int = 1) -> Iterator[torch.Tensor]:
        """Consume a full delta iterator and continuously yield waveform chunks."""
        with _maybe_codec_streaming(getattr(self.session, "codec", None), batch_size=self.batch_size):
            for delta in deltas:
                yield from self.push_text_delta(delta)
            yield from self.finish(drain_step=drain_step)

    def _decode_audio_frames(self, audio_frames: list[torch.Tensor]) -> Iterator[torch.Tensor]:
        """Sanitize frames, feed them to the decoder and yield finished WAV chunks on CPU.

        Stops at the first frame containing audio EOS or an out-of-range token.

        Raises:
            ValueError: If a frame is not 2-D (after dropping a leading batch dim
                of a 3-D tensor) or has more than one batch row.
        """
        for frame in audio_frames:
            tokens = frame
            if tokens.dim() == 3:
                tokens = tokens[0]
            if tokens.dim() != 2:
                raise ValueError(f"Expected [B, C] or [1, C] audio tokens, got {tuple(tokens.shape)}")
            if tokens.shape[0] != 1:
                raise ValueError(
                    f"This bridge currently supports batch_size=1 for decoding, got batch={tokens.shape[0]}."
                )

            tokens, stop = _sanitize_audio_tokens(
                tokens,
                codebook_size=self.codebook_size,
                audio_eos_token=self.audio_eos_token,
            )
            if tokens.numel() == 0:
                if stop:
                    break
                continue

            self.decoder.push_tokens(tokens.detach())
            for wav in self.decoder.audio_chunks():
                if wav.numel() == 0:
                    continue
                yield wav.detach().cpu()
            if stop:
                break
|
|
|
|
# Names exported by `from <module> import *`, kept in alphabetical order.
__all__ = [
    "AudioStreamDecoder",
    "MossTTSRealtimeInference",
    "MossTTSRealtimeStreamingSession",
    "MossTTSRealtimeTextStreamBridge",
    "TextDeltaTokenizer",
]
|
|