|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Streaming inference utilities for MossTTSRealtime.""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import contextlib |
|
|
import re |
|
|
from typing import Iterable, Iterator, List, Optional, Sequence |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from transformers.cache_utils import StaticCache |
|
|
from transformers.utils import is_torchaudio_available, requires_backends |
|
|
from transformers.utils.import_utils import requires |
|
|
|
|
|
if is_torchaudio_available(): |
|
|
import torchaudio |
|
|
|
|
|
|
|
|
@requires(backends=("torch",)) |
|
|
class MossTTSRealtimeInference: |
|
|
"""Step-wise inference wrapper for MossTTSRealtime. |
|
|
|
|
|
This class mirrors the non-streaming inference logic but exposes a |
|
|
prefill/step/finish API for streaming usage. |
|
|
""" |
|
|
|
|
|
def __init__(
    self,
    model,
    tokenizer,
    max_length: int = 1000,
    channels: int = 16,
    audio_channel_pad: int = 1024,
    audio_bos_token: int = 1025,
    audio_eos_token: int = 1026,
    text_pad_id: int = 151655,
    aud_pad_id: int = 151654,
):
    """Keep references to the model/tokenizer and start with empty streaming state.

    Args:
        model: Backbone invoked by `prefill` and `step`.
        tokenizer: Tokenizer whose `pad_token_id` is read during `prefill`.
        max_length: Default step budget used by `finish` when `max_steps` is None.
        channels: Number of audio codebook channels sampled per step.
        audio_channel_pad: Id written into the audio channels of text-only rows.
        audio_bos_token: Id placed in the first audio channel of the last prefix row.
        audio_eos_token: Id in the first audio channel that marks a row as finished.
        text_pad_id: Text id fed by `step` when no text token is supplied; also the
            padding fallback in `prefill` when the tokenizer has no pad token.
        aud_pad_id: Stored as given; this class does not read it afterwards.
    """
    # Collaborators and token-id configuration.
    self.model, self.tokenizer = model, tokenizer
    self.max_length, self.channels = max_length, channels
    self.audio_channel_pad = audio_channel_pad
    self.audio_bos_token, self.audio_eos_token = audio_bos_token, audio_eos_token
    self.text_pad_id, self.aud_pad_id = text_pad_id, aud_pad_id

    # Backbone cache and the attention mask that grows alongside it.
    self.past_key_values = self.attention_mask = None

    # Per-generation bookkeeping, cleared by `reset_generation_state`.
    self._generated_tokens: List[torch.Tensor] = []
    self._is_stopping = self._last_audio_tokens = None
    self._step_idx = 0
|
|
|
|
|
@property
def device(self):
    """Device of the wrapped model, read from its first parameter."""
    first_param = next(self.model.parameters())
    return first_param.device
|
|
|
|
|
@property
def is_finished(self) -> bool:
    """True once a stop mask exists and every batch row in it is set."""
    stop_mask = self._is_stopping
    if stop_mask is None:
        return False
    return bool(stop_mask.all())
|
|
|
|
|
def reset_generation_state(self, keep_cache: bool = True):
    """Clear per-generation bookkeeping; optionally drop the cache as well.

    Args:
        keep_cache: When False, `past_key_values` and `attention_mask` are
            also reset to None. When True they are left untouched.
    """
    if not keep_cache:
        self.past_key_values = self.attention_mask = None

    self._generated_tokens, self._step_idx = [], 0
    self._is_stopping = self._last_audio_tokens = None
|
|
|
|
|
def _normalize_input_ids(self, input_ids):
    """Turn the supported `input_ids` containers into a list of per-sample arrays.

    A tensor is moved to CPU and converted to NumPy first. A 2-D array becomes
    a one-element list, a 3-D array is split along its first axis, and a
    list/tuple has each item converted with `np.array`.

    Raises:
        ValueError: For any other type or array rank.
    """
    if isinstance(input_ids, (list, tuple)):
        return [np.array(sample) for sample in input_ids]

    array = input_ids.detach().cpu().numpy() if isinstance(input_ids, torch.Tensor) else input_ids
    if isinstance(array, np.ndarray):
        if array.ndim == 2:
            return [array]
        if array.ndim == 3:
            # Iterating a 3-D array yields one [T, C] slice per batch entry.
            return list(array)

    raise ValueError("input_ids must be a list/array/tensor of shape [T, C] or [B, T, C].")
|
|
|
|
|
def _normalize_text_prefix(self, text_prefix_ids, batch_size: int) -> list[list[int]]:
    """Normalize text prefix ids into one list of ids per batch row.

    Accepted inputs are a tensor, a NumPy array, or a list/tuple that is
    either flat (one prefix) or nested (one prefix per row). A single prefix,
    flat or nested, is repeated for every row when `batch_size > 1`.

    Args:
        text_prefix_ids: Prefix token ids in any of the forms above.
        batch_size: Number of rows the caller expects.

    Returns:
        A list holding an independent list of ids for each row. An empty
        input yields `batch_size` empty lists.

    Raises:
        ValueError: If the input is None, is not list-like/tensor-like, or is
            nested with a length that is neither 1 nor `batch_size`.
    """
    if text_prefix_ids is None:
        raise ValueError("text_prefix_ids must be provided for prefill.")
    if isinstance(text_prefix_ids, torch.Tensor):
        text_prefix_ids = text_prefix_ids.detach().cpu().tolist()
    if isinstance(text_prefix_ids, np.ndarray):
        text_prefix_ids = text_prefix_ids.tolist()
    if isinstance(text_prefix_ids, tuple):
        # Tuples are list-like too; treat them exactly like lists.
        text_prefix_ids = list(text_prefix_ids)
    if not isinstance(text_prefix_ids, list):
        raise ValueError("text_prefix_ids must be list-like or tensor-like.")

    if len(text_prefix_ids) == 0:
        return [[] for _ in range(batch_size)]

    if isinstance(text_prefix_ids[0], (int, np.integer)):
        # A flat prefix is shared by every row, matching the nested
        # single-row case below. max(1, ...) keeps the previous result for
        # batch_size <= 1.
        return [list(text_prefix_ids) for _ in range(max(1, batch_size))]

    if len(text_prefix_ids) == 1 and batch_size > 1:
        return [list(text_prefix_ids[0]) for _ in range(batch_size)]
    if len(text_prefix_ids) != batch_size:
        raise ValueError(
            f"text_prefix_ids batch size mismatch: got {len(text_prefix_ids)}, expected {batch_size}."
        )
    return [list(item) for item in text_prefix_ids]
|
|
|
|
|
@torch.inference_mode()
def prefill(
    self,
    input_ids,
    text_prefix_ids,
    max_prefill_len: Optional[int] = None,
    past_key_values=None,
    device: Optional[torch.device] = None,
    temperature: float = 0.8,
    top_p: float = 0.6,
    top_k: int = 30,
    do_sample: bool = True,
    repetition_penalty: Optional[float] = 1.1,
    repetition_window: Optional[int] = 50,
) -> torch.Tensor:
    """Run the backbone over context + text prefix and sample the first audio frame.

    For every batch row the text prefix (optionally cut to `max_prefill_len`)
    is expanded into `[len(prefix), channels + 1]` rows: text ids in column 0,
    `audio_channel_pad` elsewhere, and `audio_bos_token` in column 1 of the
    last row. These rows are appended to the row's `input_ids`, the batch is
    left-padded to a common length, and the result is fed to the model.

    Args:
        input_ids: Context ids, normalized by `_normalize_input_ids`.
        text_prefix_ids: Text prefix, normalized by `_normalize_text_prefix`.
        max_prefill_len: Optional cap on the number of prefix tokens used.
        past_key_values: Optional cache that replaces the stored one.
        device: Target device; defaults to `self.device`.
        temperature, top_p, top_k, do_sample, repetition_penalty,
        repetition_window: Forwarded to `generate_local_transformer`.

    Returns:
        The sampled audio tokens for the first step.

    Raises:
        ValueError: If a row's prefix is empty after truncation.
    """
    target_device = self.device if device is None else device
    if past_key_values is not None:
        self.past_key_values = past_key_values

    contexts = self._normalize_input_ids(input_ids)
    batch_size = len(contexts)
    prefixes = self._normalize_text_prefix(text_prefix_ids, batch_size)
    row_width = self.channels + 1

    # Build context + text-prefix rows for each sample.
    sequences = []
    for row in range(batch_size):
        prefix = prefixes[row]
        if max_prefill_len is not None:
            prefix = prefix[:max_prefill_len]
        if len(prefix) == 0:
            raise ValueError("Prefill requires at least one text token.")

        text_rows = np.full((len(prefix), row_width), self.audio_channel_pad, dtype=np.int64)
        text_rows[:, 0] = np.array(prefix, dtype=np.int64)
        # The last prefix row carries the audio BOS in the first audio channel.
        text_rows[-1, 1] = self.audio_bos_token
        sequences.append(np.concatenate([contexts[row], text_rows], axis=0))

    longest = max(seq.shape[0] for seq in sequences)
    pad_token_id = self.tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = self.text_pad_id

    # Left-pad ids and masks so that the newest position is aligned across rows.
    padded_ids, padded_masks = [], []
    for seq in sequences:
        seq_len = seq.shape[0]
        pad_len = longest - seq_len
        left_pad = np.full((pad_len, row_width), self.audio_channel_pad, dtype=np.int64)
        left_pad[:, 0] = pad_token_id
        padded_ids.append(np.concatenate([left_pad, seq]))
        padded_masks.append(
            np.concatenate([np.zeros(pad_len, dtype=np.bool_), np.ones(seq_len, dtype=np.bool_)])
        )

    current_input_ids = torch.from_numpy(np.stack(padded_ids)).to(target_device)
    current_attention_mask = torch.from_numpy(np.stack(padded_masks)).to(target_device)

    # When continuing from a kept cache, the mask must also cover the cached positions.
    if self.attention_mask is not None and self.past_key_values is not None:
        current_attention_mask = torch.cat([self.attention_mask, current_attention_mask], dim=-1)

    outputs = self.model(
        input_ids=current_input_ids,
        attention_mask=current_attention_mask,
        past_key_values=self.past_key_values,
        use_cache=True,
        return_dict=True,
    )
    self.past_key_values = outputs.past_key_values
    self.attention_mask = current_attention_mask

    audio_tokens = self.generate_local_transformer(
        hidden_states=outputs.last_hidden_state[:, -1:, :],
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        do_sample=do_sample,
        repetition_penalty=repetition_penalty,
        repetition_window=repetition_window,
        generated_tokens=None,
        gen_step=0,
    )

    self._generated_tokens = [audio_tokens]
    self._last_audio_tokens = audio_tokens
    self._is_stopping = audio_tokens[:, 0] == self.audio_eos_token
    self._step_idx = 1
    return audio_tokens
|
|
|
|
|
@torch.inference_mode()
def step(
    self,
    text_token: Optional[Iterable[int] | torch.Tensor | int],
    temperature: float = 0.8,
    top_p: float = 0.6,
    top_k: int = 30,
    do_sample: bool = True,
    repetition_penalty: Optional[float] = 1.1,
    repetition_window: Optional[int] = 50,
) -> torch.Tensor:
    """Advance generation by one frame.

    The new backbone input row is the given text token(s) joined with the
    audio tokens sampled at the previous step. The attention mask gains one
    column that is False for rows that have already stopped.

    Args:
        text_token: One id per batch row (tensor, list, tuple or array), a
            single int, or None to feed `text_pad_id` to every row.
        temperature, top_p, top_k, do_sample, repetition_penalty,
        repetition_window: Forwarded to `generate_local_transformer`.

    Returns:
        The newly sampled audio tokens, or the previous ones unchanged when
        every row has already finished.

    Raises:
        ValueError: If `prefill` has not been called, or the number of text
            tokens differs from the batch size.
    """
    if self._last_audio_tokens is None or self.attention_mask is None:
        raise ValueError("You must call prefill() before step().")
    prev_audio = self._last_audio_tokens
    if self.is_finished:
        return prev_audio

    batch_size = prev_audio.shape[0]
    if text_token is None:
        text_tokens = [self.text_pad_id] * batch_size
    elif isinstance(text_token, torch.Tensor):
        text_tokens = text_token.detach().cpu().tolist()
    elif isinstance(text_token, (list, tuple, np.ndarray)):
        text_tokens = list(text_token)
    else:
        text_tokens = [int(text_token)]

    if len(text_tokens) != batch_size:
        raise ValueError(f"text_token batch size mismatch: got {len(text_tokens)}, expected {batch_size}.")

    text_column = torch.tensor(text_tokens, device=prev_audio.device, dtype=torch.long)
    # One row per sample: the text id followed by the previous audio tokens.
    step_ids = torch.cat([text_column[:, None, None], prev_audio.unsqueeze(1)], dim=2)
    still_running = (~self._is_stopping).unsqueeze(-1)
    self.attention_mask = torch.cat([self.attention_mask, still_running], dim=-1)

    outputs = self.model(
        input_ids=step_ids,
        attention_mask=self.attention_mask,
        past_key_values=self.past_key_values,
        use_cache=True,
        return_dict=True,
    )
    self.past_key_values = outputs.past_key_values

    history = torch.stack(self._generated_tokens, dim=1) if self._generated_tokens else None
    audio_tokens = self.generate_local_transformer(
        hidden_states=outputs.last_hidden_state[:, -1:, :],
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        do_sample=do_sample,
        repetition_penalty=repetition_penalty,
        repetition_window=repetition_window,
        generated_tokens=history,
        gen_step=self._step_idx,
    )

    self._generated_tokens.append(audio_tokens)
    self._last_audio_tokens = audio_tokens
    self._is_stopping |= audio_tokens[:, 0] == self.audio_eos_token
    self._step_idx += 1
    return audio_tokens
|
|
|
|
|
@torch.inference_mode()
def finish(
    self,
    max_steps: Optional[int] = None,
    temperature: float = 0.8,
    top_p: float = 0.6,
    top_k: int = 30,
    do_sample: bool = True,
    repetition_penalty: Optional[float] = 1.1,
    repetition_window: Optional[int] = 50,
) -> list[torch.Tensor]:
    """Keep stepping without new text until finished or the step budget runs out.

    Args:
        max_steps: Maximum number of steps; `self.max_length` when None.
        temperature, top_p, top_k, do_sample, repetition_penalty,
        repetition_window: Forwarded to every `step` call.

    Returns:
        The audio tokens returned by each executed step, in order.
    """
    budget = self.max_length if max_steps is None else max_steps
    sampling = {
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "do_sample": do_sample,
        "repetition_penalty": repetition_penalty,
        "repetition_window": repetition_window,
    }

    frames = []
    while budget > 0 and not self.is_finished:
        frames.append(self.step(text_token=None, **sampling))
        budget -= 1
    return frames
|
|
|
|
|
@torch.compile(fullgraph=True)
def generate_local_transformer(
    self,
    hidden_states: torch.Tensor,
    temperature: float,
    top_p: float,
    top_k: int,
    do_sample: bool,
    repetition_penalty: Optional[float],
    repetition_window: Optional[int],
    generated_tokens: Optional[torch.Tensor],
    gen_step: int,
) -> torch.Tensor:
    """Sample one audio token per channel with the local transformer.

    `self.model.local_transformer` is called once per channel with a fresh
    `StaticCache`. The first call receives the backbone hidden state as
    `inputs_embeds`; each later call receives the token sampled for the
    previous channel as `input_ids`.

    Args:
        hidden_states: Backbone output for the current step; reshaped to
            `(-1, 1, local_config.hidden_size)`.
        temperature, top_p, top_k, do_sample: Forwarded to `sample_token`.
        repetition_penalty: Applied only when truthy, not 1.0, and history exists.
        repetition_window: Forwarded to `apply_repetition_penalty`.
        generated_tokens: Earlier audio tokens, indexed as
            `[:, :gen_step, channel]`, or None when there is no history.
        gen_step: Number of history steps to consider for the penalty.

    Returns:
        A long tensor of shape `[batch, channels]`.
    """
    device = hidden_states.device
    num_rows = hidden_states.shape[0]
    frame = torch.empty(num_rows, self.channels, dtype=torch.long, device=device)
    step_embeds = hidden_states.reshape(-1, 1, self.model.config.local_config.hidden_size)

    local_cache = StaticCache(config=self.model.local_transformer.config, max_cache_len=self.channels)
    cache_position = torch.zeros(1, dtype=torch.long, device=device)

    # The condition does not change between channels, so evaluate it once.
    penalize = bool(repetition_penalty) and repetition_penalty != 1.0 and generated_tokens is not None

    prev_token = None
    for codebook in range(self.channels):
        cache_position.fill_(codebook)

        logits = self.model.local_transformer(
            input_ids=prev_token,
            inputs_embeds=step_embeds,
            past_key_values=local_cache,
            cache_position=cache_position,
            codebook_idx=codebook,
            use_cache=True,
            logits_to_keep=1,
        ).logits

        if penalize:
            logits = self.apply_repetition_penalty(
                scores=logits,
                history_tokens=generated_tokens[:, :gen_step, codebook],
                penalty=float(repetition_penalty),
                repetition_window=repetition_window,
            )

        prev_token = self.sample_token(
            logits=logits,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            do_sample=do_sample,
        )
        frame[:, codebook] = prev_token.squeeze(-1)

        # Only the first channel is driven by embeddings; later ones use token ids.
        step_embeds = None
    return frame
|
|
|
|
|
def apply_repetition_penalty(
    self,
    scores: torch.Tensor,
    history_tokens: torch.Tensor,
    penalty: float = 1.1,
    repetition_window: Optional[int] = None,
):
    """Penalize the scores of tokens that appear in the recent history.

    Negative scores are multiplied by `penalty` and non-negative scores are
    divided by it. The update is written in place into `scores[:, 0, :]`,
    and that 2-D view is what gets returned.

    Args:
        scores: Scores indexed as `[batch, 1, vocab]`.
        history_tokens: Previously generated token ids, `[batch, steps]`.
        penalty: Penalty factor.
        repetition_window: When positive, only the last this-many history
            steps are penalized; otherwise the whole history is used.
    """
    step_scores = scores[:, 0, :]

    use_window = repetition_window is not None and repetition_window > 0
    recent = history_tokens[:, -repetition_window:] if use_window else history_tokens

    picked = step_scores.gather(1, recent)
    penalized = torch.where(picked < 0, picked * penalty, picked / penalty)
    step_scores.scatter_(1, recent, penalized)
    return step_scores
|
|
|
|
|
def sample_token(self, logits, temperature, top_p=0.6, top_k=30, do_sample=True):
    """Pick the next token id(s) from `logits`.

    With sampling disabled or a temperature of 0 this is an argmax over the
    last dimension. Otherwise logits are divided by `temperature`, flattened
    to 2-D, filtered with top-k and then top-p (each skipped when None), and
    one id per row is drawn with `torch.multinomial`.

    Returns:
        Token ids shaped like `logits` without its last dimension.
    """
    greedy = (not do_sample) or temperature == 0
    if greedy:
        return torch.argmax(logits, dim=-1)

    scaled = logits / temperature
    leading_shape = scaled.shape[:-1]
    flat = scaled.reshape(-1, scaled.shape[-1])

    if top_k is not None:
        flat = self.apply_top_k(flat, top_k)
    if top_p is not None:
        flat = self.apply_top_p(flat, top_p)

    drawn = torch.multinomial(F.softmax(flat, dim=-1), num_samples=1)
    return drawn.view(leading_shape)
|
|
|
|
|
def apply_top_k(self, logits, top_k, filter_value=float("-inf"), min_tokens_to_keep: int = 1):
    """Replace every logit below the k-th largest of its row with `filter_value`.

    Args:
        logits: 2-D tensor `[rows, vocab]`.
        top_k: Number of highest logits to keep per row; raised to at least
            `min_tokens_to_keep` and capped at the vocabulary size.
        filter_value: Value written over removed logits.
        min_tokens_to_keep: Lower bound for the number of kept logits.

    Raises:
        ValueError: If `top_k` is not a strictly positive int.
    """
    if not (isinstance(top_k, int) and top_k > 0):
        raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")

    _, vocab_size = logits.shape
    keep = min(max(top_k, min_tokens_to_keep), vocab_size)
    # Smallest value that still belongs to the top `keep`, one per row.
    kth_best = torch.topk(logits, keep, dim=-1).values[..., -1, None]
    return logits.masked_fill(logits < kth_best, filter_value)
|
|
|
|
|
def apply_top_p(self, logits, top_p, filter_value=float("-inf"), min_tokens_to_keep: int = 1):
    """Nucleus filtering over the last dimension of 2-D `logits`.

    Logits are sorted in ascending order; tokens whose cumulative
    probability is at most `1 - top_p` are replaced with `filter_value`.
    The `min_tokens_to_keep` highest logits of each row are never removed.

    Raises:
        ValueError: If `top_p` lies outside `[0, 1]`.
    """
    top_p = float(top_p)
    if top_p < 0 or top_p > 1.0:
        raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")

    ascending, order = torch.sort(logits, descending=False)
    tail_mass = ascending.softmax(dim=-1).cumsum(dim=-1)

    drop_sorted = tail_mass <= (1 - top_p)
    # Protect the most likely tokens, which sit at the end of the ascending sort.
    drop_sorted[..., -min_tokens_to_keep:] = 0

    # Map the decision from sorted positions back to vocabulary positions.
    drop = torch.zeros_like(logits, dtype=torch.bool).scatter(1, order, drop_sorted)
    return logits.masked_fill(drop, filter_value)
|
|
|
|
|
|
|
|
@requires(backends=("torch",)) |
|
|
class MossTTSRealtimeStreamingSession: |
|
|
"""Manage text-to-audio streaming for a single conversation.""" |
|
|
|
|
|
# Boundaries at which buffered text may be cut into a segment. Each boundary
# also consumes the whitespace that follows it. Full-width CJK punctuation is
# written as \uXXXX escapes so it survives encoding normalization of this
# file: U+3002 (ideographic full stop), U+FF01 / U+FF1F (full-width ! and ?),
# U+FF0C / U+FF1B / U+FF1A (full-width comma, semicolon, colon).
_split_pattern = re.compile(
    # Sentence enders: full stop variants, ! and ?, ellipsis (U+2026).
    r"[\u3002\uff01\uff1f!?\.\u2026]\s*"
    # Clause separators: commas, semicolons, colons, em/en dash, hyphen.
    r"|[\uff0c,\uff1b;\uff1a:\u2014\u2013\-]\s*"
    # Closing brackets.
    r"|\)\s*|\]\s*"
    # Line breaks.
    r"|\n"
)
|
|
|
|
|
def __init__(
    self,
    inferencer: MossTTSRealtimeInference,
    processor,
    codec=None,
    codec_sample_rate: int = 24000,
    codec_encode_kwargs: Optional[dict] = None,
    prefill_text_len: int = 12,
    text_buffer_size: int = 32,
    min_text_chunk_chars: int = 8,
    temperature: float = 0.8,
    top_p: float = 0.6,
    top_k: int = 30,
    do_sample: bool = True,
    repetition_penalty: Optional[float] = 1.1,
    repetition_window: Optional[int] = 50,
):
    """Configure a streaming session around an inferencer and a processor.

    Args:
        inferencer: Step-wise inference wrapper that generates audio tokens.
        processor: Provides `tokenizer` and the prompt builders used by `reset_turn`.
        codec: Optional codec used to encode waveform voice prompts.
        codec_sample_rate: Sample rate waveforms are resampled to before encoding.
        codec_encode_kwargs: Extra keyword arguments for `codec.encode`.
        prefill_text_len: Number of text tokens to collect before prefill starts.
        text_buffer_size: Buffered character count at which text may be cut
            at a space when no punctuation boundary is found.
        min_text_chunk_chars: Minimum length of a punctuation-delimited segment.
        temperature, top_p, top_k, do_sample, repetition_penalty,
        repetition_window: Sampling settings forwarded to the inferencer.
    """
    # Collaborators.
    self.inferencer, self.processor = inferencer, processor
    self.tokenizer = processor.tokenizer
    self.codec, self.codec_sample_rate = codec, codec_sample_rate
    self.codec_encode_kwargs = codec_encode_kwargs or {}

    # Text chunking thresholds.
    self.prefill_text_len = prefill_text_len
    self.text_buffer_size = text_buffer_size
    self.min_text_chunk_chars = min_text_chunk_chars

    # Sampling configuration.
    self.temperature, self.top_p, self.top_k = temperature, top_p, top_k
    self.do_sample = do_sample
    self.repetition_penalty = repetition_penalty
    self.repetition_window = repetition_window

    # Conversation-level state.
    self._voice_prompt_tokens = self._turn_input_ids = None
    self._turn_idx = 0

    # Per-turn text streaming state.
    self._text_cache = ""
    self._pending_tokens: list[int] = []
    self._prefilled = self._text_ended = False
|
|
|
|
|
def set_voice_prompt_tokens(self, audio_tokens: np.ndarray):
    """Store already-encoded voice prompt tokens without any validation."""
    self._voice_prompt_tokens = audio_tokens
|
|
|
|
|
def set_voice_prompt(self, audio, sample_rate: Optional[int] = None):
    """Set the voice prompt from audio tokens or from a waveform.

    A 2-D array/tensor with `processor.channels` among its dimensions is
    stored directly as audio tokens (tensors are converted to NumPy).
    Anything else is treated as a waveform: a `str`/`bytes` value is loaded
    with torchaudio and down-mixed to mono, the waveform is resampled to
    `codec_sample_rate` when `sample_rate` differs, and the result is
    encoded with the codec.

    Args:
        audio: Audio tokens, a waveform (array or tensor), or a value
            passed to `torchaudio.load`.
        sample_rate: Sample rate of a waveform given as array/tensor. For
            loaded files the rate reported by torchaudio is used instead.

    Raises:
        ValueError: If a waveform is given without a codec, the audio type
            is unsupported, or the codec output dict contains no codes.
    """
    # Fast path: the caller already holds codec tokens.
    if isinstance(audio, np.ndarray) and audio.ndim == 2 and self.processor.channels in audio.shape:
        self._voice_prompt_tokens = audio
        return
    if isinstance(audio, torch.Tensor) and audio.dim() == 2 and self.processor.channels in audio.shape:
        self._voice_prompt_tokens = audio.detach().cpu().numpy()
        return

    if self.codec is None:
        raise ValueError("codec is required to encode waveform prompts.")

    waveform = audio
    if isinstance(audio, (str, bytes)):
        requires_backends(self, ["torchaudio"])
        wav, sr = torchaudio.load(audio)
        if wav.shape[0] > 1:
            # Average the channels to obtain a mono signal.
            wav = wav.mean(dim=0, keepdim=True)
        waveform, sample_rate = wav.squeeze(0), sr

    if isinstance(waveform, np.ndarray):
        waveform = torch.from_numpy(waveform)
    if not isinstance(waveform, torch.Tensor):
        raise ValueError("Unsupported audio type for voice prompt.")

    if sample_rate is not None and sample_rate != self.codec_sample_rate:
        requires_backends(self, ["torchaudio"])
        waveform = torchaudio.functional.resample(waveform, sample_rate, self.codec_sample_rate)

    waveform = waveform.to(self.inferencer.device)
    encode_out = self.codec.encode([waveform], **self.codec_encode_kwargs)

    if isinstance(encode_out, dict):
        # Use the first known key that is present; its first entry is this prompt.
        for key in ("codes_list", "audio_codes"):
            if key in encode_out:
                tokens = encode_out[key][0]
                break
        else:
            raise ValueError("codec.encode output missing audio codes.")
    else:
        tokens = encode_out

    if isinstance(tokens, torch.Tensor):
        tokens = tokens.detach().cpu().numpy()
    self._voice_prompt_tokens = tokens
|
|
|
|
|
def clear_voice_prompt(self):
    """Forget any previously stored voice prompt tokens."""
    self._voice_prompt_tokens = None
|
|
|
|
|
def reset_turn(
    self,
    user_text: Optional[str] = None,
    user_audio_tokens: Optional[np.ndarray] = None,
    input_ids: Optional[np.ndarray] = None,
    include_system_prompt: Optional[bool] = None,
    reset_cache: bool = False,
):
    """Start a new turn: set its input ids and clear all streaming state.

    When `input_ids` is not given, it is built from the user prompt
    (`processor.make_user_prompt`), preceded by the system prompt
    (`processor.make_ensemble` on the stored voice prompt) if requested.

    Args:
        user_text: User text for the prompt; required without `input_ids`.
        user_audio_tokens: User audio tokens; required without `input_ids`.
        input_ids: Ready-made turn input; used as-is when provided.
        include_system_prompt: Whether to prepend the system prompt.
            Defaults to True only on the first turn of the session.
        reset_cache: When True the inferencer also drops its cache.

    Raises:
        ValueError: If neither `input_ids` nor both user inputs are given.
    """
    if include_system_prompt is None:
        include_system_prompt = self._turn_idx == 0

    if input_ids is None:
        if user_text is None or user_audio_tokens is None:
            raise ValueError("user_text and user_audio_tokens are required when input_ids is not provided.")
        input_ids = self.processor.make_user_prompt(user_text, user_audio_tokens)
        if include_system_prompt:
            system_prompt = self.processor.make_ensemble(self._voice_prompt_tokens)
            input_ids = np.concatenate([system_prompt, input_ids], axis=0)

    self._turn_input_ids = input_ids
    self._turn_idx += 1

    # Nothing from the previous turn's text stream carries over.
    self._text_cache, self._pending_tokens = "", []
    self._prefilled = self._text_ended = False

    self.inferencer.reset_generation_state(keep_cache=not reset_cache)
|
|
|
|
|
def push_text_tokens(self, tokens: Iterable[int]) -> list[torch.Tensor]:
    """Queue already-tokenized text and generate whatever audio is now possible.

    Returns:
        The audio token frames produced while draining the queue.
    """
    # Convert everything first so a bad id leaves the queue untouched.
    converted = list(map(int, tokens))
    self._pending_tokens.extend(converted)
    return self._drain_pending_tokens()
|
|
|
|
|
def push_text(self, text_fragment: str) -> list[torch.Tensor]:
    """Append raw text, tokenize the segments that are complete, and generate audio.

    Text that does not yet form a complete segment stays buffered in
    `_text_cache` until more text arrives or `end_text` is called.

    Returns:
        The audio token frames produced while draining the queue.
    """
    self._text_cache += text_fragment
    for ready_segment in self._extract_text_segments(force=False):
        segment_ids = self._tokenize(ready_segment)
        self._pending_tokens.extend(segment_ids)
    return self._drain_pending_tokens()
|
|
|
|
|
def end_text(self) -> list[torch.Tensor]:
    """Mark the text stream as complete and flush any buffered text.

    Whatever is left in `_text_cache` is tokenized as one final segment.

    Returns:
        The audio token frames produced while draining the queue.
    """
    self._text_ended = True
    leftover = self._text_cache
    if leftover:
        self._pending_tokens.extend(self._tokenize(leftover))
        self._text_cache = ""
    return self._drain_pending_tokens()
|
|
|
|
|
def drain(self, max_steps: Optional[int] = None) -> list[torch.Tensor]:
    """Continue generation without new text, using the session's sampling settings.

    Args:
        max_steps: Step budget handed to `inferencer.finish`.

    Returns:
        The frames produced by `inferencer.finish`, or an empty list when
        prefill has not happened yet.
    """
    if self._prefilled:
        return self.inferencer.finish(
            max_steps=max_steps,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            do_sample=self.do_sample,
            repetition_penalty=self.repetition_penalty,
            repetition_window=self.repetition_window,
        )
    return []
|
|
|
|
|
def _tokenize(self, text: str) -> list[int]:
    """Encode `text` with the session tokenizer, without special tokens."""
    token_ids = self.tokenizer.encode(text, add_special_tokens=False)
    return token_ids
|
|
|
|
|
def _extract_text_segments(self, force: bool) -> list[str]:
    """Cut complete segments off the front of the text buffer.

    A segment ends at the first `_split_pattern` match whose end lies at or
    beyond `min_text_chunk_chars`. If no such match exists and the buffer
    holds at least `text_buffer_size` characters, it is cut after its last
    space instead. Text that cannot be cut stays in the buffer.

    Args:
        force: When True, the whole buffer is returned as a single segment.

    Returns:
        The extracted segments in order; `_text_cache` keeps the remainder.
    """
    if force:
        flushed = []
        if self._text_cache:
            flushed.append(self._text_cache)
            self._text_cache = ""
        return flushed

    segments = []
    min_chars = self.min_text_chunk_chars
    while self._text_cache:
        pending = self._text_cache
        cut = None

        if len(pending) >= min_chars:
            # First boundary that leaves a segment of at least `min_chars`.
            boundary_ends = (found.end() for found in self._split_pattern.finditer(pending))
            cut = next((end for end in boundary_ends if end >= min_chars), None)

        if cut is None and len(pending) >= self.text_buffer_size:
            # No usable punctuation: fall back to the last word boundary.
            last_space = pending.rfind(" ")
            if last_space != -1:
                cut = last_space + 1

        if cut is None:
            break
        segments.append(pending[:cut])
        self._text_cache = pending[cut:]
    return segments
|
|
|
|
|
def _prefill_if_needed(self) -> list[torch.Tensor]:
    """Run the inferencer prefill once enough text tokens are queued.

    Prefill is postponed until `prefill_text_len` tokens are pending, unless
    the text stream has ended, in which case all pending tokens are used.
    The consumed tokens are removed from the front of the queue. Prefill is
    invoked with `repetition_penalty=None`.

    Returns:
        A one-element list with the first audio frame, or an empty list
        when prefill already happened or cannot start yet.

    Raises:
        ValueError: If `reset_turn` has not provided the turn input ids.
    """
    if self._prefilled:
        return []

    pending = self._pending_tokens
    if not self._text_ended and (not pending or len(pending) < self.prefill_text_len):
        # Keep waiting for more text.
        return []
    if self._turn_input_ids is None:
        raise ValueError("reset_turn must be called before streaming text.")

    prefill_len = len(pending) if self._text_ended else min(len(pending), self.prefill_text_len)
    if prefill_len == 0:
        return []

    # Take the prefix off the front of the queue, in place.
    prefix_tokens = pending[:prefill_len]
    del pending[:prefill_len]

    audio_tokens = self.inferencer.prefill(
        input_ids=[self._turn_input_ids],
        text_prefix_ids=[prefix_tokens],
        temperature=self.temperature,
        top_p=self.top_p,
        top_k=self.top_k,
        do_sample=self.do_sample,
        repetition_penalty=None,
        repetition_window=self.repetition_window,
    )
    self._prefilled = True
    return [audio_tokens]
|
|
|
|
|
def _drain_pending_tokens(self) -> list[torch.Tensor]:
    """Feed queued text tokens to the inferencer, one step per token.

    Prefill is attempted first. Stepping then continues until the queue is
    empty or the inferencer reports that it has finished; tokens that were
    not consumed stay in the queue.

    Returns:
        The prefill frame (if prefill ran now) followed by one frame per step.
    """
    frames: list[torch.Tensor] = list(self._prefill_if_needed())
    if not self._prefilled:
        return frames

    while self._pending_tokens and not self.inferencer.is_finished:
        next_token = self._pending_tokens.pop(0)
        frame = self.inferencer.step(
            next_token,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            do_sample=self.do_sample,
            repetition_penalty=self.repetition_penalty,
            repetition_window=self.repetition_window,
        )
        frames.append(frame)
    return frames
|
|
|
|
|
|
|
|
@requires(backends=("torch",)) |
|
|
class AudioStreamDecoder: |
|
|
"""Decode audio tokens into waveform chunks with optional crossfade.""" |
|
|
|
|
|
def __init__(
    self,
    codec,
    chunk_frames: int = 40,
    overlap_frames: int = 4,
    decode_kwargs: Optional[dict] = None,
    device: Optional[torch.device] = None,
):
    """Set up a decoder with an empty token buffer.

    Args:
        codec: Object whose `decode` turns audio tokens into a waveform.
        chunk_frames: Number of token frames decoded per emitted chunk.
        overlap_frames: Frames' worth of audio blended between consecutive
            chunks; a value of 0 or less disables crossfading.
        decode_kwargs: Extra keyword arguments for `codec.decode`.
        device: Device tokens are moved to before decoding; when None it is
            looked up on the codec at decode time.
    """
    self.codec, self.device = codec, device
    self.chunk_frames, self.overlap_frames = chunk_frames, overlap_frames
    self.decode_kwargs = decode_kwargs or {}

    # Buffered [T, C] token tensors and their total frame count.
    self._buffer: list[torch.Tensor] = []
    self._buffer_len = 0
    # Tail of the previously emitted waveform, kept for crossfading.
    self._prev_tail: Optional[torch.Tensor] = None
|
|
|
|
|
def push_tokens(self, audio_tokens: np.ndarray | torch.Tensor):
    """Append a `[T, C]` block of audio tokens to the buffer.

    NumPy input is wrapped as a tensor first.

    Raises:
        ValueError: If the tokens are not 2-D.
    """
    frames = torch.from_numpy(audio_tokens) if isinstance(audio_tokens, np.ndarray) else audio_tokens
    if frames.dim() != 2:
        raise ValueError(f"Expected [T, C] audio tokens, got {tuple(frames.shape)}")
    self._buffer.append(frames)
    self._buffer_len += frames.shape[0]
|
|
|
|
|
def audio_chunks(self) -> Iterable[torch.Tensor]:
    """Yield a decoded, crossfaded waveform for every full chunk in the buffer.

    Frames beyond the last full chunk stay buffered for a later call or
    for `flush`.
    """
    frames_per_chunk = self.chunk_frames
    while self._buffer_len >= frames_per_chunk:
        decoded = self._decode(self._consume_frames(frames_per_chunk), chunk_duration=0.32)
        yield self._apply_crossfade(decoded)
|
|
|
|
|
def flush(self) -> Optional[torch.Tensor]:
    """Decode everything still buffered as the final chunk.

    Returns:
        The crossfaded waveform, or None when the buffer is empty.
    """
    remaining_frames = self._buffer_len
    if remaining_frames == 0:
        return None
    tail_tokens = self._consume_frames(remaining_frames)
    return self._apply_crossfade(self._decode(tail_tokens), final_chunk=True)
|
|
|
|
|
def _consume_frames(self, num_frames: int) -> torch.Tensor:
    """Remove up to `num_frames` frames from the front of the buffer.

    Whole buffered tensors are popped while they fit; the last one is split
    when it is larger than what is still needed. `_buffer_len` is reduced
    by the number of frames actually taken.

    Returns:
        The taken frames concatenated along dim 0.
    """
    taken = []
    needed = num_frames
    while needed > 0 and self._buffer:
        head = self._buffer[0]
        available = head.shape[0]
        if available > needed:
            # Split the head: take the front, keep the rest buffered.
            taken.append(head[:needed])
            self._buffer[0] = head[needed:]
            needed = 0
        else:
            taken.append(self._buffer.pop(0))
            needed -= available

    self._buffer_len -= num_frames - needed
    return torch.cat(taken, dim=0)
|
|
|
|
|
def _decode(self, tokens: torch.Tensor, chunk_duration: float = 0.32) -> torch.Tensor:
    """Decode `[T, C]` tokens into a waveform with the codec.

    Tokens are moved to the configured device, or to one discovered on the
    codec (`codec.device`, else its first parameter), and passed to
    `codec.decode` transposed to `[C, T]`. A `"chunk_duration"` entry in
    `decode_kwargs` overrides the argument; None, non-numeric, zero and
    negative overrides are passed to the codec as None.

    Returns:
        The decoded waveform. A dict result is read from `["audio"][0]`,
        NumPy output is wrapped as a tensor, and `squeeze(0)` is applied
        when the result has more than one dimension.
    """
    target = self.device
    if target is None:
        codec = self.codec
        if hasattr(codec, "device"):
            target = codec.device
        else:
            try:
                target = next(codec.parameters()).device
            except Exception:
                # Best effort: leave the tokens where they are.
                target = None
    if target is not None:
        tokens = tokens.to(target)
    codes = tokens.permute(1, 0)

    # Work on a copy so the configured kwargs are never mutated.
    extra = dict(self.decode_kwargs) if self.decode_kwargs else {}
    duration = chunk_duration
    if "chunk_duration" in extra:
        override = extra.pop("chunk_duration")
        duration = None
        if override is not None:
            try:
                as_float = float(override)
            except Exception:
                as_float = None
            if not (as_float is None or as_float <= 0):
                duration = as_float

    decoded = self.codec.decode(codes, chunk_duration=duration, **extra)
    wav = decoded["audio"][0] if isinstance(decoded, dict) else decoded
    if isinstance(wav, np.ndarray):
        wav = torch.from_numpy(wav)
    if wav.dim() > 1:
        wav = wav.squeeze(0)
    return wav
|
|
|
|
|
def _apply_crossfade(self, wav: torch.Tensor, final_chunk: bool = False) -> torch.Tensor:
    """Blend the start of `wav` with the stored tail of the previous chunk.

    The overlap length comes from `_overlap_samples(wav)`, limited to the
    length of the stored tail. Within the overlap the previous tail fades
    out linearly while the new chunk fades in. The tail of `wav` is then
    stored for the next call, unless this is the final chunk.

    Args:
        wav: Decoded waveform of the current chunk.
        final_chunk: When True no tail is kept afterwards.

    Returns:
        `wav` unchanged when crossfading is disabled, when there is no
        previous tail, or when the overlap is empty; otherwise the blended
        waveform.
    """
    if self.overlap_frames <= 0:
        return wav

    overlap = self._overlap_samples(wav)

    if self._prev_tail is None:
        # Nothing to blend with yet. Only remember a tail when it is
        # non-empty: `wav[-0:]` would select the *whole* chunk, and the
        # next call would then emit that audio a second time.
        if final_chunk or overlap == 0:
            self._prev_tail = None
        else:
            self._prev_tail = wav[-overlap:].clone()
        return wav

    if overlap == 0:
        return wav

    prev_tail = self._prev_tail
    if prev_tail.numel() < overlap:
        # The stored tail is shorter than the requested overlap.
        overlap = prev_tail.numel()
        if overlap == 0:
            return wav

    fade_out = torch.linspace(1.0, 0.0, overlap, device=wav.device)
    fade_in = 1.0 - fade_out
    cross = prev_tail[-overlap:] * fade_out + wav[:overlap] * fade_in
    merged = torch.cat([prev_tail[:-overlap], cross, wav[overlap:]], dim=-1)

    self._prev_tail = None if final_chunk else wav[-overlap:].clone()
    return merged
|
|
|
|
|
def _overlap_samples(self, wav: torch.Tensor) -> int:
    """Number of samples of `wav` that correspond to `overlap_frames`.

    This is the `overlap_frames / chunk_frames` share of the sample count
    of `wav`, truncated to an int; 0 when `chunk_frames` is not positive.
    """
    if self.chunk_frames <= 0:
        return 0
    overlap_ratio = self.overlap_frames / self.chunk_frames
    return int(wav.numel() * overlap_ratio)
|
|
|
|
|
|
|
|
class TextDeltaTokenizer:
    """
    Turn streamed text deltas into incrementally emitted token ids.

    How it works:
    - Every delta is appended to one growing string.
    - After each delta the *full* text is re-encoded, and only ids that have
      not been emitted before are returned.
    - The last `hold_back` ids are withheld until `flush()`, because the end
      of the tokenization may still change when more text arrives.
    """

    def __init__(self, tokenizer, *, hold_back: int = 3):
        """Create a tokenizer wrapper.

        Args:
            tokenizer: Object with `encode(text, add_special_tokens=False)`.
            hold_back: Number of trailing ids to withhold; negative values
                are treated as 0.
        """
        self.tokenizer = tokenizer
        self.hold_back = max(0, int(hold_back))
        self._text = ""
        self._all_ids: list[int] = []
        self._emitted_count: int = 0

    @property
    def text(self) -> str:
        """All text received so far."""
        return self._text

    @property
    def token_ids(self) -> list[int]:
        """A copy of the ids from the most recent full encoding."""
        return list(self._all_ids)

    def _encode_all(self) -> list[int]:
        """Re-encode the accumulated text and remember the result."""
        self._all_ids = self.tokenizer.encode(self._text, add_special_tokens=False)
        return self._all_ids

    def push_delta(self, delta: str) -> list[int]:
        """Append a text delta and return newly stable token ids (may be empty)."""
        if not delta:
            return []
        self._text += str(delta)
        ids = self._encode_all()

        start = self._emitted_count
        # Never move backwards, even if the re-encoding became shorter.
        stop = max(start, len(ids) - self.hold_back)
        self._emitted_count = stop
        return ids[start:stop]

    def flush(self) -> list[int]:
        """Emit all remaining token ids at end of stream."""
        ids = self._encode_all()
        start = self._emitted_count
        self._emitted_count = len(ids)
        return ids[start:]
|
|
|
|
|
|
|
|
def _sanitize_audio_tokens(
    tokens: torch.Tensor,
    *,
    codebook_size: int,
    audio_eos_token: int,
) -> tuple[torch.Tensor, bool]:
    """Trim rows after EOS/invalid tokens and return whether decoding should stop.

    A row stops the stream when its first column equals `audio_eos_token`
    or when any of its values lies outside `[0, codebook_size)`. The first
    such row and everything after it are dropped.

    Args:
        tokens: Audio tokens, `[T, C]`; a 1-D tensor is treated as one row.
        codebook_size: Exclusive upper bound for valid token ids.
        audio_eos_token: Id that marks end of audio in the first column.

    Returns:
        `(kept_rows, stop)`, where `stop` is True when a row was cut off.
    """
    if tokens.dim() == 1:
        tokens = tokens.unsqueeze(0)
    if tokens.numel() == 0:
        return tokens, False

    eos_mask = tokens[:, 0] == audio_eos_token
    invalid_mask = ((tokens < 0) | (tokens >= codebook_size)).any(dim=1)

    # `nonzero` lists indices in ascending order, so entry 0 is the earliest stop row.
    stop_rows = (eos_mask | invalid_mask).nonzero(as_tuple=False)
    if stop_rows.numel() == 0:
        return tokens, False
    first_stop = int(stop_rows[0].item())
    return tokens[:first_stop], True
|
|
|
|
|
|
|
|
def _maybe_codec_streaming(codec, *, batch_size: int):
    """Return the codec's streaming context, or a no-op context as a fallback.

    Args:
        codec: Codec object, or None.
        batch_size: Passed to `codec.streaming`.

    Returns:
        `codec.streaming(batch_size=batch_size)` when the codec has a
        `streaming` attribute, otherwise `contextlib.nullcontext()`.
    """
    if codec is not None and hasattr(codec, "streaming"):
        return codec.streaming(batch_size=batch_size)
    return contextlib.nullcontext()
|
|
|
|
|
|
|
|
@requires(backends=("torch",)) |
|
|
class MossTTSRealtimeTextStreamBridge: |
|
|
""" |
|
|
Bridge: external LLM streaming text (delta) -> TTS streaming audio chunks. |
|
|
|
|
|
Usage overview: |
|
|
- First configure `MossTTSRealtimeStreamingSession` (especially `prefill_text_len=12`). |
|
|
- Provide an `AudioStreamDecoder`, then continuously feed the LLM delta text via |
|
|
`push_text_delta()`. |
|
|
- Once the accumulated token count reaches `prefill_text_len`, the session will |
|
|
start generating audio tokens; the bridge will immediately decode them into WAV |
|
|
chunks and yield them. |
|
|
""" |
|
|
|
|
|
def __init__(
    self,
    session: MossTTSRealtimeStreamingSession,
    decoder: AudioStreamDecoder,
    *,
    codebook_size: Optional[int] = None,
    audio_eos_token: Optional[int] = None,
    batch_size: int = 1,
):
    """Connect a streaming session to an audio decoder.

    Args:
        session: Produces audio token frames from streamed text.
        decoder: Turns audio token frames into waveform chunks.
        codebook_size: Exclusive upper bound for valid audio token ids.
            Defaults to `session.codec.codebook_size`, or 1024 when that
            is not available.
        audio_eos_token: End-of-audio id. Defaults to
            `session.inferencer.audio_eos_token`, or 1026 when missing.
        batch_size: Passed to the codec streaming context.
    """
    self.session, self.decoder = session, decoder
    self.batch_size = int(batch_size)

    if codebook_size is None:
        session_codec = getattr(session, "codec", None)
        codebook_size = int(getattr(session_codec, "codebook_size", 1024))
    if audio_eos_token is None:
        audio_eos_token = int(getattr(session.inferencer, "audio_eos_token", 1026))

    self.codebook_size = int(codebook_size)
    self.audio_eos_token = int(audio_eos_token)
|
|
|
|
|
def push_text_delta(self, delta: str) -> Iterator[torch.Tensor]:
    """
    Feed one piece of streamed LLM text and yield the WAV chunks it unlocks.

    The text goes to `session.push_text()`, and the audio token frames it
    returns are decoded by `_decode_audio_frames()`. Being a generator,
    nothing runs until the result is iterated.
    """
    yield from self._decode_audio_frames(self.session.push_text(delta))
|
|
|
|
|
def push_text_tokens(self, token_ids: Sequence[int]) -> Iterator[torch.Tensor]:
    """Push token ids directly (for sources that stream token ids).

    Yields the WAV chunks decoded from the frames the session returns.
    An empty `token_ids` yields nothing and does not touch the session.
    """
    if token_ids:
        yield from self._decode_audio_frames(self.session.push_text_tokens(token_ids))
|
|
|
|
|
def finish(self, *, drain_step: int = 1) -> Iterator[torch.Tensor]:
    """Mark text stream end and emit all remaining audio chunks (including flush).

    After `session.end_text()`, the session is drained `drain_step` steps
    at a time until a drain returns no frames or the inferencer reports
    that it has finished. The decoder is flushed last, and a non-empty
    remainder is yielded on the CPU.
    """
    session = self.session
    yield from self._decode_audio_frames(session.end_text())

    more_frames = session.drain(max_steps=drain_step)
    while more_frames:
        yield from self._decode_audio_frames(more_frames)
        if session.inferencer.is_finished:
            break
        more_frames = session.drain(max_steps=drain_step)

    final = self.decoder.flush()
    if final is not None and final.numel() > 0:
        yield final.detach().cpu()
|
|
|
|
|
def stream_from_text_deltas(self, deltas: Iterable[str], *, drain_step: int = 1) -> Iterator[torch.Tensor]:
    """Consume a full delta iterator and continuously yield waveform chunks.

    Everything runs inside the codec's streaming context when the session
    codec provides one. Each delta is pushed in turn, and `finish()` is
    called once the deltas are exhausted.
    """
    session_codec = getattr(self.session, "codec", None)
    with _maybe_codec_streaming(session_codec, batch_size=self.batch_size):
        for text_delta in deltas:
            yield from self.push_text_delta(text_delta)
        yield from self.finish(drain_step=drain_step)
|
|
|
|
|
def _decode_audio_frames(self, audio_frames: list[torch.Tensor]) -> Iterator[torch.Tensor]:
    """Decode audio token frames into waveform chunks on the CPU.

    Each frame must hold exactly one row (`[1, C]`, or 3-D with that as its
    first entry). Frames are sanitized, pushed to the decoder, and every
    non-empty chunk the decoder has ready is yielded. Processing of the
    given frames ends at the first frame that signals a stop.

    Raises:
        ValueError: If a frame is not 2-D after unwrapping, or holds more
            than one row.
    """
    for frame in audio_frames:
        tokens = frame[0] if frame.dim() == 3 else frame
        if tokens.dim() != 2:
            raise ValueError(f"Expected [B, C] or [1, C] audio tokens, got {tuple(tokens.shape)}")
        if tokens.shape[0] != 1:
            raise ValueError(
                f"This bridge currently supports batch_size=1 for decoding, got batch={tokens.shape[0]}."
            )

        tokens, stop = _sanitize_audio_tokens(
            tokens,
            codebook_size=self.codebook_size,
            audio_eos_token=self.audio_eos_token,
        )

        if tokens.numel() > 0:
            self.decoder.push_tokens(tokens.detach())
            for wav in self.decoder.audio_chunks():
                if wav.numel() != 0:
                    yield wav.detach().cpu()

        if stop:
            break
|
|
|
|
|
|
|
|
# Public API of this module.
__all__ = [
    "AudioStreamDecoder",
    "MossTTSRealtimeInference",
    "MossTTSRealtimeStreamingSession",
    "MossTTSRealtimeTextStreamBridge",
    "TextDeltaTokenizer",
]
|
|
|