# MOSS-TTS-Realtime-ONNX / test_basic_streaming-onnx.py
# Uploaded by pltobing — commit "Add main scripts" (41938cf)
# Copyright 2026 Patrick Lumbantobing, Vertox-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""End-to-end streaming TTS test script using ONNX Runtime.
This script demonstrates the full MOSS-TTS-Realtime ONNX pipeline by:
1. Loading four ONNX models (backbone LLM, local transformer, codec encoder,
codec decoder) into ONNX Runtime ``InferenceSession`` instances.
2. Encoding a reference audio prompt for voice cloning.
3. Simulating a streaming LLM text source (character-by-character deltas).
4. Running the streaming TTS pipeline to produce audio chunks.
5. Writing the concatenated audio to a WAV file.
Usage[with INT8 codec decoder]::
python test_basic_streaming-onnx.py \
--tokenizer_vocab_path tokenizers/tokenizer.json \
--tokenizer_config_path tokenizers/tokenizer_config.json \
--backbone_llm_path onnx_models/backbone_f32/backbone_f32.onnx \
--backbone_local_path onnx_models/local_transformer_f32/local_transformer_f32.onnx \
--codec_decoder_path onnx_models_quantized/codec_decoder_int8/codec_decoder_int8.onnx \
--codec_encoder_path onnx_models/codec_encoder/codec_encoder.onnx \
--backbone_config_path configs/config_backbone.json \
--codec_config_path configs/config_codec.json \
--prompt_wav audio_ref/speaker.[wav|flac|mp3] \
--out_wav output.wav
"""
import argparse
import json
import time
import wave
from pathlib import Path
from typing import Iterator, Tuple
import numpy as np
import numpy.typing as npt
import onnxruntime as ort
from inferencer_onnx import MossTTSRealtimeInferenceONNX
from moss_text_tokenizer import MOSSTextTokenizer
NDArrayInt = npt.NDArray[np.int64]
NDArrayFloat = npt.NDArray[np.floating]
CODEC_SAMPLE_RATE = 24000
def fake_llm_text_stream(
    text: str,
    chunk_chars: int = 1,
    delay_s: float = 0.0,
) -> Iterator[str]:
    """Emit *text* in small deltas, mimicking a streaming LLM response.

    Each yielded piece contains up to ``chunk_chars`` characters; an optional
    ``delay_s`` pause is inserted before every delta except the first. In a
    real deployment this generator would be replaced by streaming output from
    a model server (e.g. OpenAI or vLLM).

    Parameters
    ----------
    text : str
        The full text to split into deltas.
    chunk_chars : int, optional
        Characters per delta; values below 1 are treated as 1 (default ``1``).
    delay_s : float, optional
        Seconds to sleep between consecutive deltas (default ``0.0``).

    Yields
    ------
    str
        The next delta, at most ``chunk_chars`` characters long.
    """
    if not text:
        return
    size = chunk_chars if chunk_chars > 1 else 1
    emitted_any = False
    for start in range(0, len(text), size):
        # Only pause between deltas, never before the very first one.
        if emitted_any and delay_s > 0:
            time.sleep(delay_s)
        emitted_any = True
        yield text[start : start + size]
def write_wav(out_path: Path, sample_rate: int, chunks: Iterator[np.ndarray]) -> None:
    """Drain *chunks* and persist them as a mono 16-bit PCM WAV file.

    Parameters
    ----------
    out_path : Path
        Destination file; parent directories are created as needed.
    sample_rate : int
        Sample rate in Hz written into the WAV header.
    chunks : Iterator[np.ndarray]
        Float32 audio segments in the ``[-1, 1]`` range.

    Raises
    ------
    RuntimeError
        If the iterator yields no chunks at all.
    """
    collected = [np.reshape(piece.astype(np.float32), -1) for piece in chunks]
    if not collected:
        raise RuntimeError("No audio chunks produced.")
    samples = np.concatenate(collected)
    # Clamp to the valid range, then scale float32 → int16 PCM.
    samples = np.clip(samples, -1.0, 1.0)
    pcm = (samples * 32767.0).astype(np.int16)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with wave.open(str(out_path), "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)
        wav_file.setframerate(int(sample_rate))
        wav_file.writeframes(pcm.tobytes())
def _sanitize_tokens(
tokens: NDArrayInt,
codebook_size: int,
eos_audio_id: int,
) -> Tuple[NDArrayInt, bool]:
"""Validate and truncate audio tokens at EOS or invalid code boundaries.
Parameters
----------
tokens : NDArrayInt
Audio token array of shape ``(T,)`` or ``(T, C)``.
codebook_size : int
Valid code range is ``[0, codebook_size)``.
eos_audio_id : int
End-of-sequence audio token ID.
Returns
-------
tuple[NDArrayInt, bool]
Sanitized tokens and a flag indicating whether truncation occurred.
"""
# Make sure tokens is 2D: (rows, codes)
if tokens.ndim == 1:
tokens = np.expand_dims(tokens, axis=0) # same as tokens[None, :]
if tokens.size == 0:
return tokens, False
# Rows whose first element is eos_audio_id
eos_rows = np.nonzero(tokens[:, 0] == eos_audio_id)[0] # 1D index array
# Rows that contain any invalid code
invalid_rows = ((tokens < 0) | (tokens >= codebook_size)).any(axis=1) # axis instead of dim
invalid_rows_idx = np.nonzero(invalid_rows)[0]
stop_idx = None
if eos_rows.size > 0:
stop_idx = int(eos_rows[0])
if invalid_rows_idx.size > 0:
invalid_idx = int(invalid_rows_idx[0])
stop_idx = invalid_idx if stop_idx is None else min(stop_idx, invalid_idx)
if stop_idx is not None:
tokens = tokens[:stop_idx]
return tokens, True
return tokens, False
def decode_audio_frames(
    audio_frames: list[NDArrayInt],
    inferencer: MossTTSRealtimeInferenceONNX,
    codebook_size: int,
    eos_audio_id: int,
) -> Iterator[np.ndarray]:
    """Turn backbone audio-token frames into decoded waveform chunks.

    Every frame is normalized to ``[T, C]``, sanitized (truncated at EOS or
    out-of-range codes), pushed into the codec decoder, and whatever audio the
    decoder releases is yielded as a flat array.

    Parameters
    ----------
    audio_frames : list[NDArrayInt]
        Audio token arrays; a bare ndarray is treated as a one-frame list.
    inferencer : MossTTSRealtimeInferenceONNX
        Engine exposing ``push_tokens`` and ``audio_chunks``.
    codebook_size : int
        Valid code range for sanitization.
    eos_audio_id : int
        End-of-sequence audio token ID.

    Yields
    ------
    np.ndarray
        Flattened waveform segments.
    """
    frames = [audio_frames] if isinstance(audio_frames, np.ndarray) else audio_frames
    for raw in frames:
        # Drop a leading batch axis if present; anything else is malformed.
        tokens = raw[0] if raw.ndim == 3 else raw
        if tokens.ndim != 2:
            raise ValueError(f"Expected [T, C] audio tokens, got {tuple(tokens.shape)}")
        print(f"tokens before sanitize {tokens} {tokens.shape}")
        tokens, _ = _sanitize_tokens(tokens, codebook_size, eos_audio_id)
        print(f"tokens after sanitize {tokens} {tokens.shape}")
        if tokens.size == 0:
            continue  # frame was truncated away entirely
        inferencer.push_tokens(tokens)
        for wav in inferencer.audio_chunks():
            if wav.size == 0:
                continue
            print(f"decode_audio_frames wav {wav} {wav.shape}")
            yield wav.reshape(-1)
def flush_decoder(inferencer: MossTTSRealtimeInferenceONNX) -> Iterator[np.ndarray]:
    """Flush the codec decoder buffer and yield any leftover audio.

    Parameters
    ----------
    inferencer : MossTTSRealtimeInferenceONNX
        Engine exposing ``flush``.

    Yields
    ------
    np.ndarray
        The final flattened waveform segment, if the decoder had one buffered.
    """
    remainder = inferencer.flush()
    # Nothing buffered — yield nothing.
    if remainder is None or remainder.size == 0:
        return
    print(f"final_chunk flush {remainder} {remainder.shape}")
    yield remainder.reshape(-1)
# Core pipeline: text delta → push_text → audio frames → decoded waveform.
def run_streaming_tts(
    inferencer: MossTTSRealtimeInferenceONNX,
    text_deltas: Iterator[str],
) -> Iterator[np.ndarray]:
    """Consume streaming text deltas and yield playable waveform chunks.

    Stage order matches the Gradio demo:
    ``push_text → decode_frames → end_text → drain → flush``.

    Parameters
    ----------
    inferencer : MossTTSRealtimeInferenceONNX
        A fully initialized ONNX inferencer with ``reset_turn`` already called.
    text_deltas : Iterator[str]
        Text deltas, e.g. from :func:`fake_llm_text_stream`.

    Yields
    ------
    np.ndarray
        Decoded waveform chunks suitable for playback or concatenation.
    """
    cb_size = inferencer.codebook_size
    eos_id = inferencer.eos_audio_id
    # Feed text incrementally; decode any audio frames produced along the way.
    for delta in text_deltas:
        print(f"delta {delta}")
        frames = inferencer.push_text(delta)
        if len(frames) > 0:
            print(f"audio_frames {frames} {len(frames)} {frames[0].shape}")
            yield from decode_audio_frames(frames, inferencer, cb_size, eos_id)
    # Signal end of text and decode whatever that releases.
    frames = inferencer.end_text()
    if len(frames) > 0:
        print(f"audio_frames end_text {frames} {len(frames)} {frames[0].shape}")
        yield from decode_audio_frames(frames, inferencer, cb_size, eos_id)
    # Drain remaining generation one step at a time until exhausted or finished.
    while True:
        frames = inferencer.drain(max_steps=1)
        if not frames:
            break
        print(f"audio_frames drain {frames} {len(frames)} {frames[0].shape}")
        yield from decode_audio_frames(frames, inferencer, cb_size, eos_id)
        if inferencer.is_finished:
            break
    # Emit any audio still buffered inside the codec decoder.
    yield from flush_decoder(inferencer)
def main() -> None:
    """Entry point: parse arguments, load models, run streaming TTS, write WAV.

    Pipeline:
      1. Build the text tokenizer and four CPU ONNX Runtime sessions
         (backbone LLM, local transformer, codec decoder, codec encoder).
      2. Load the backbone and codec JSON configs.
      3. Encode the reference prompt audio for voice cloning.
      4. Stream the assistant text through the TTS pipeline and write the
         concatenated audio to a 16-bit PCM WAV file.
    """
    p = argparse.ArgumentParser(description="Simulated LLM streaming text → TTS streaming audio。")
    # Required model / tokenizer / config paths.
    p.add_argument("--tokenizer_vocab_path", type=str, required=True)
    p.add_argument("--tokenizer_config_path", type=str, required=True)
    p.add_argument("--backbone_llm_path", type=str, required=True)
    p.add_argument("--backbone_local_path", type=str, required=True)
    p.add_argument("--codec_decoder_path", type=str, required=True)
    p.add_argument("--codec_encoder_path", type=str, required=True)
    p.add_argument("--backbone_config_path", type=str, required=True)
    p.add_argument("--codec_config_path", type=str, required=True)
    p.add_argument("--prompt_wav", type=str, required=True)
    # Output and sampling options.
    p.add_argument("--out_wav", type=str, default="out_streaming.wav")
    p.add_argument("--sample_rate", type=int, default=CODEC_SAMPLE_RATE)
    p.add_argument("--temperature", type=float, default=0.725)
    p.add_argument("--top_p", type=float, default=0.6)
    p.add_argument("--top_k", type=int, default=34)
    p.add_argument("--repetition_penalty", type=float, default=1.9)
    p.add_argument("--repetition_window", type=int, default=50)
    p.add_argument("--max_length", type=int, default=5000)
    # Simulated LLM streaming parameters.
    p.add_argument(
        "--delta_chunk_chars", type=int, default=1, help="Number of characters to output at each delta (1 = verbatim)"
    )
    p.add_argument(
        "--delta_delay_s", type=float, default=0.0, help="Simulated delay in seconds between deltas, let 0 = no delay"
    )
    # Default demo text is Russian; any language the model supports can be passed.
    p.add_argument(
        "--assistant_text",
        type=str,
        default=(
            "в зависимости от времени не только точность, но и низкая задержка. Если это не мгновенно, то человеческое взаимодействие теряется. Мы наконец-то достигаем момента, когда технология достаточно быстра для того, чтобы люди просто общались, и это является огромным сдвигом для глобального бизнеса."
        ),
    )
    args = p.parse_args()
    tokenizer = MOSSTextTokenizer(args.tokenizer_vocab_path, args.tokenizer_config_path)
    print(f"tokenizer {tokenizer} {args.tokenizer_vocab_path} {args.tokenizer_config_path}")
    # All four ONNX sessions run on CPU only.
    backbone_llm = ort.InferenceSession(
        args.backbone_llm_path,
        providers=["CPUExecutionProvider"],
    )
    print(f"backbone_llm {backbone_llm} {args.backbone_llm_path}")
    backbone_local = ort.InferenceSession(
        args.backbone_local_path,
        providers=["CPUExecutionProvider"],
    )
    print(f"backbone_local {backbone_local} {args.backbone_local_path}")
    codec_decoder = ort.InferenceSession(
        args.codec_decoder_path,
        providers=["CPUExecutionProvider"],
    )
    print(f"codec_decoder {codec_decoder} {args.codec_decoder_path}")
    codec_encoder = ort.InferenceSession(
        args.codec_encoder_path,
        providers=["CPUExecutionProvider"],
    )
    print(f"codec_encoder {codec_encoder} {args.codec_encoder_path}")
    with open(args.backbone_config_path, "r") as f:
        backbone_config = json.load(f)
    print(f"backbone_config {backbone_config} {args.backbone_config_path}")
    with open(args.codec_config_path, "r") as f:
        codec_config = json.load(f)
    print(f"codec_config {codec_config} {args.codec_config_path}")
    inferencer = MossTTSRealtimeInferenceONNX(
        tokenizer,
        backbone_llm,
        backbone_local,
        codec_decoder,
        codec_encoder,
        backbone_config,
        codec_config,
        max_length=args.max_length,
        codec_sample_rate=CODEC_SAMPLE_RATE,
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k,
        repetition_penalty=args.repetition_penalty,
        repetition_window=args.repetition_window,
    )
    print("Inferencer loaded.")
    print("Extracting audio prompt...")
    # NOTE(review): calls a private method of the inferencer — confirm this is
    # the intended public entry point for prompt encoding.
    prompt_tokens = inferencer._encode_reference_audio(args.prompt_wav)
    print(f"prompt_tokens {prompt_tokens} {prompt_tokens.shape}")
    # ── Build input_ids without the user turn: system_prompt + assistant prefix ──
    print("Loading input ids...")
    # squeeze(1) drops what is presumably a singleton batch/channel axis — TODO confirm
    # against the processor's expected shape.
    input_ids = inferencer.processor.make_ensemble(prompt_tokens.squeeze(1))
    print(f"input_ids {input_ids} {input_ids.shape}")
    inferencer.reset_turn(
        input_ids=input_ids,
        include_system_prompt=False,
        reset_cache=True,
    )
    print("Input ids loaded.")
    text = args.assistant_text
    text_deltas = fake_llm_text_stream(
        text,
        chunk_chars=args.delta_chunk_chars,
        delay_s=args.delta_delay_s,
    )
    print("Running streaming tts simulation...")
    # run_streaming_tts returns a lazy generator: synthesis actually happens
    # while write_wav drains it below, so "Done." prints before any audio is
    # generated.
    wav_chunks = run_streaming_tts(
        inferencer=inferencer,
        text_deltas=text_deltas,
    )
    print("Done.")
    out_path = Path(args.out_wav).expanduser()
    write_wav(out_path, args.sample_rate, wav_chunks)
    print(f"\n[OK] Write complete: {out_path}")
if __name__ == "__main__":
    main()