# Copyright 2026 Patrick Lumbantobing, Vertox-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""End-to-end streaming TTS test script using ONNX Runtime.

This script demonstrates the full MOSS-TTS-Realtime ONNX pipeline by:

1. Loading four ONNX models (backbone LLM, local transformer, codec encoder,
   codec decoder) into ONNX Runtime ``InferenceSession`` instances.
2. Encoding a reference audio prompt for voice cloning.
3. Simulating a streaming LLM text source (character-by-character deltas).
4. Running the streaming TTS pipeline to produce audio chunks.
5. Writing the concatenated audio to a WAV file.

Usage (with INT8 codec decoder)::

    python test_basic_streaming-onnx.py \
        --tokenizer_vocab_path tokenizers/tokenizer.json \
        --tokenizer_config_path tokenizers/tokenizer_config.json \
        --backbone_llm_path onnx_models/backbone_f32/backbone_f32.onnx \
        --backbone_local_path onnx_models/local_transformer_f32/local_transformer_f32.onnx \
        --codec_decoder_path onnx_models_quantized/codec_decoder_int8/codec_decoder_int8.onnx \
        --codec_encoder_path onnx_models/codec_encoder/codec_encoder.onnx \
        --backbone_config_path configs/config_backbone.json \
        --codec_config_path configs/config_codec.json \
        --prompt_wav audio_ref/speaker.[wav|flac|mp3] \
        --out_wav output.wav
"""

import argparse
import json
import time
import wave
from pathlib import Path
from typing import Iterator, Tuple

import numpy as np
import numpy.typing as npt
import onnxruntime as ort

from inferencer_onnx import MossTTSRealtimeInferenceONNX
from moss_text_tokenizer import MOSSTextTokenizer

NDArrayInt = npt.NDArray[np.int64]
NDArrayFloat = npt.NDArray[np.floating]

# Native sample rate of the codec decoder output, in Hz.
CODEC_SAMPLE_RATE = 24000


def fake_llm_text_stream(
    text: str,
    chunk_chars: int = 1,
    delay_s: float = 0.0,
) -> Iterator[str]:
    """Simulate streaming text deltas from an LLM.

    Each iteration yields ``chunk_chars`` characters with a delay of
    ``delay_s`` seconds. In real-world usage, this can be replaced with
    streaming responses from models such as OpenAI or vLLM.

    Parameters
    ----------
    text : str
        Full text to stream character-by-character.
    chunk_chars : int, optional
        Number of characters per delta (default ``1``).
    delay_s : float, optional
        Simulated delay in seconds between deltas (default ``0.0``).

    Yields
    ------
    str
        A text delta of up to ``chunk_chars`` characters.
    """
    if not text:
        return
    step = max(1, chunk_chars)
    for idx in range(0, len(text), step):
        # No sleep before the very first delta: the delay simulates
        # inter-token latency, not time-to-first-token.
        if delay_s > 0 and idx > 0:
            time.sleep(delay_s)
        yield text[idx : idx + step]


def write_wav(out_path: Path, sample_rate: int, chunks: Iterator[np.ndarray]) -> None:
    """Collect audio chunks and write them to a 16-bit PCM WAV file.

    Parameters
    ----------
    out_path : Path
        Output file path.
    sample_rate : int
        Sample rate in Hz.
    chunks : Iterator[np.ndarray]
        Iterator of float32 audio chunks in ``[-1, 1]`` range. If ``chunks``
        is a lazy generator, the whole upstream pipeline runs here.

    Raises
    ------
    RuntimeError
        If the iterator yields no audio chunks at all.
    """
    all_chunks: list[np.ndarray] = []
    for chunk in chunks:
        all_chunks.append(chunk.astype(np.float32).reshape(-1))
    if not all_chunks:
        raise RuntimeError("No audio chunks produced.")
    audio = np.concatenate(all_chunks)

    # float32 in [-1, 1] → int16 PCM; clip first so out-of-range samples
    # saturate instead of wrapping around during the integer cast.
    audio = np.clip(audio, -1.0, 1.0)
    pcm16 = (audio * 32767.0).astype(np.int16)

    out_path.parent.mkdir(parents=True, exist_ok=True)
    with wave.open(str(out_path), "wb") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)  # 2 bytes per sample == 16-bit PCM
        wf.setframerate(int(sample_rate))
        wf.writeframes(pcm16.tobytes())


def _sanitize_tokens(
    tokens: NDArrayInt,
    codebook_size: int,
    eos_audio_id: int,
) -> tuple[NDArrayInt, bool]:
    """Validate and truncate audio tokens at EOS or invalid code boundaries.

    Parameters
    ----------
    tokens : NDArrayInt
        Audio token array of shape ``(T,)`` or ``(T, C)``.
    codebook_size : int
        Valid code range is ``[0, codebook_size)``.
    eos_audio_id : int
        End-of-sequence audio token ID.

    Returns
    -------
    tuple[NDArrayInt, bool]
        Sanitized tokens and a flag indicating whether truncation occurred.
    """
    # Make sure tokens is 2D: (rows, codes).
    # NOTE(review): axis=0 turns a (T,) input into (1, T) — one row of T
    # codes — whereas the documented (T, C) layout suggests (T, 1) may have
    # been intended. The in-file caller (decode_audio_frames) only ever
    # passes 2-D arrays, so this branch is currently unexercised — confirm
    # the intended 1-D semantics before relying on it.
    if tokens.ndim == 1:
        tokens = np.expand_dims(tokens, axis=0)  # same as tokens[None, :]

    if tokens.size == 0:
        return tokens, False

    # Rows whose first code is the EOS marker (1-D index array).
    eos_rows = np.nonzero(tokens[:, 0] == eos_audio_id)[0]

    # Rows that contain any out-of-range code.
    invalid_rows = ((tokens < 0) | (tokens >= codebook_size)).any(axis=1)
    invalid_rows_idx = np.nonzero(invalid_rows)[0]

    # Truncate at the earliest of the two boundaries, if either exists.
    stop_idx = None
    if eos_rows.size > 0:
        stop_idx = int(eos_rows[0])
    if invalid_rows_idx.size > 0:
        invalid_idx = int(invalid_rows_idx[0])
        stop_idx = invalid_idx if stop_idx is None else min(stop_idx, invalid_idx)

    if stop_idx is not None:
        return tokens[:stop_idx], True
    return tokens, False


def decode_audio_frames(
    audio_frames: list[NDArrayInt],
    inferencer: MossTTSRealtimeInferenceONNX,
    codebook_size: int,
    eos_audio_id: int,
) -> Iterator[np.ndarray]:
    """Sanitize, buffer, and decode audio token frames into waveform chunks.

    Parameters
    ----------
    audio_frames : list[NDArrayInt]
        List of audio token arrays from the backbone.
    inferencer : MossTTSRealtimeInferenceONNX
        The ONNX inference engine (used for ``push_tokens`` / ``audio_chunks``).
    codebook_size : int
        Valid code range for sanitization.
    eos_audio_id : int
        End-of-sequence audio token ID.

    Yields
    ------
    np.ndarray
        Decoded waveform segments, flattened to 1-D.
    """
    # Tolerate a bare ndarray instead of a list of ndarrays.
    if isinstance(audio_frames, np.ndarray):
        audio_frames = [audio_frames]

    for frame in audio_frames:
        tokens = frame
        # Drop a leading batch dimension if present: (1, T, C) → (T, C).
        if tokens.ndim == 3:
            tokens = tokens[0]
        if tokens.ndim != 2:
            raise ValueError(f"Expected [T, C] audio tokens, got {tuple(tokens.shape)}")

        print(f"tokens before sanitize {tokens} {tokens.shape}")
        tokens, _ = _sanitize_tokens(tokens, codebook_size, eos_audio_id)
        print(f"tokens after sanitize {tokens} {tokens.shape}")
        if tokens.size == 0:
            continue

        # Feed sanitized tokens to the codec decoder and yield whatever
        # waveform it has ready so far.
        inferencer.push_tokens(tokens)
        for wav in inferencer.audio_chunks():
            if wav.size == 0:
                continue
            print(f"decode_audio_frames wav {wav} {wav.shape}")
            yield wav.reshape(-1)


def flush_decoder(inferencer: MossTTSRealtimeInferenceONNX) -> Iterator[np.ndarray]:
    """Flush the codec decoder buffer and yield any remaining audio.

    Parameters
    ----------
    inferencer : MossTTSRealtimeInferenceONNX
        The ONNX inference engine.

    Yields
    ------
    np.ndarray
        Final waveform segment, if any.
    """
    final_chunk = inferencer.flush()
    if final_chunk is not None and final_chunk.size > 0:
        print(f"final_chunk flush {final_chunk} {final_chunk.shape}")
        yield final_chunk.reshape(-1)


# Core: Streaming generation: text delta → push_text → audio
def run_streaming_tts(
    inferencer: MossTTSRealtimeInferenceONNX,
    text_deltas: Iterator[str],
) -> Iterator[np.ndarray]:
    """Receive streaming text deltas and produce playable WAV chunks in real time.

    The pipeline matches the Gradio demo:

        codec.streaming → push_text → decode_frames → end_text → drain → flush

    Parameters
    ----------
    inferencer : MossTTSRealtimeInferenceONNX
        A fully initialized ONNX inferencer with ``reset_turn`` already called.
    text_deltas : Iterator[str]
        An iterator of text deltas (simulating LLM streaming output).

    Yields
    ------
    np.ndarray
        Decoded waveform chunks suitable for playback or concatenation.
    """
    codebook_size = inferencer.codebook_size
    eos_audio_id = inferencer.eos_audio_id

    # Phase 1: feed each text delta and decode any frames it produced.
    for delta in text_deltas:
        # print(delta, end="", flush=True)
        print(f"delta {delta}")
        audio_frames = inferencer.push_text(delta)
        if len(audio_frames) > 0:
            print(f"audio_frames {audio_frames} {len(audio_frames)} {audio_frames[0].shape}")
        yield from decode_audio_frames(audio_frames, inferencer, codebook_size, eos_audio_id)

    # Phase 2: signal end of text and decode the frames that releases.
    audio_frames = inferencer.end_text()
    if len(audio_frames) > 0:
        print(f"audio_frames end_text {audio_frames} {len(audio_frames)} {audio_frames[0].shape}")
    yield from decode_audio_frames(audio_frames, inferencer, codebook_size, eos_audio_id)

    # Phase 3: drain remaining frames one step at a time until the
    # inferencer reports completion or stops producing frames.
    while True:
        audio_frames = inferencer.drain(max_steps=1)
        if not audio_frames:
            break
        print(f"audio_frames drain {audio_frames} {len(audio_frames)} {audio_frames[0].shape}")
        yield from decode_audio_frames(audio_frames, inferencer, codebook_size, eos_audio_id)
        if inferencer.is_finished:
            break

    # Phase 4: flush whatever is still buffered in the codec decoder.
    yield from flush_decoder(inferencer)


def main() -> None:
    """Entry point: parse arguments, load models, run streaming TTS, write WAV."""
    p = argparse.ArgumentParser(description="Simulated LLM streaming text → TTS streaming audio。")
    p.add_argument("--tokenizer_vocab_path", type=str, required=True)
    p.add_argument("--tokenizer_config_path", type=str, required=True)
    p.add_argument("--backbone_llm_path", type=str, required=True)
    p.add_argument("--backbone_local_path", type=str, required=True)
    p.add_argument("--codec_decoder_path", type=str, required=True)
    p.add_argument("--codec_encoder_path", type=str, required=True)
    p.add_argument("--backbone_config_path", type=str, required=True)
    p.add_argument("--codec_config_path", type=str, required=True)
    p.add_argument("--prompt_wav", type=str, required=True)
    p.add_argument("--out_wav", type=str, default="out_streaming.wav")
    p.add_argument("--sample_rate", type=int, default=CODEC_SAMPLE_RATE)
    p.add_argument("--temperature", type=float, default=0.725)
    p.add_argument("--top_p", type=float, default=0.6)
    p.add_argument("--top_k", type=int, default=34)
    p.add_argument("--repetition_penalty", type=float, default=1.9)
    p.add_argument("--repetition_window", type=int, default=50)
    p.add_argument("--max_length", type=int, default=5000)
    # Simulated LLM streaming parameters.
    p.add_argument(
        "--delta_chunk_chars",
        type=int,
        default=1,
        help="Number of characters to output at each delta (1 = verbatim)",
    )
    p.add_argument(
        "--delta_delay_s",
        type=float,
        default=0.0,
        help="Simulated delay in seconds between deltas, let 0 = no delay",
    )
    p.add_argument(
        "--assistant_text",
        type=str,
        default=(
            "в зависимости от времени не только точность, но и низкая задержка. Если это не мгновенно, то человеческое взаимодействие теряется. Мы наконец-то достигаем момента, когда технология достаточно быстра для того, чтобы люди просто общались, и это является огромным сдвигом для глобального бизнеса."
        ),
    )
    args = p.parse_args()

    tokenizer = MOSSTextTokenizer(args.tokenizer_vocab_path, args.tokenizer_config_path)
    print(f"tokenizer {tokenizer} {args.tokenizer_vocab_path} {args.tokenizer_config_path}")

    # All four ONNX sessions run on CPU.
    backbone_llm = ort.InferenceSession(
        args.backbone_llm_path,
        providers=["CPUExecutionProvider"],
    )
    print(f"backbone_llm {backbone_llm} {args.backbone_llm_path}")
    backbone_local = ort.InferenceSession(
        args.backbone_local_path,
        providers=["CPUExecutionProvider"],
    )
    print(f"backbone_local {backbone_local} {args.backbone_local_path}")
    codec_decoder = ort.InferenceSession(
        args.codec_decoder_path,
        providers=["CPUExecutionProvider"],
    )
    print(f"codec_decoder {codec_decoder} {args.codec_decoder_path}")
    codec_encoder = ort.InferenceSession(
        args.codec_encoder_path,
        providers=["CPUExecutionProvider"],
    )
    print(f"codec_encoder {codec_encoder} {args.codec_encoder_path}")

    # JSON is UTF-8 by specification; pin the encoding so the script does
    # not depend on the platform's locale default.
    with open(args.backbone_config_path, "r", encoding="utf-8") as f:
        backbone_config = json.load(f)
    print(f"backbone_config {backbone_config} {args.backbone_config_path}")
    with open(args.codec_config_path, "r", encoding="utf-8") as f:
        codec_config = json.load(f)
    print(f"codec_config {codec_config} {args.codec_config_path}")

    inferencer = MossTTSRealtimeInferenceONNX(
        tokenizer,
        backbone_llm,
        backbone_local,
        codec_decoder,
        codec_encoder,
        backbone_config,
        codec_config,
        max_length=args.max_length,
        codec_sample_rate=CODEC_SAMPLE_RATE,
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k,
        repetition_penalty=args.repetition_penalty,
        repetition_window=args.repetition_window,
    )
    print("Inferencer loaded.")

    print("Extracting audio prompt...")
    prompt_tokens = inferencer._encode_reference_audio(args.prompt_wav)
    print(f"prompt_tokens {prompt_tokens} {prompt_tokens.shape}")

    # ── Build input_ids without the user turn: system_prompt + assistant prefix ──
    print("Loading input ids...")
    input_ids = inferencer.processor.make_ensemble(prompt_tokens.squeeze(1))
    print(f"input_ids {input_ids} {input_ids.shape}")
    inferencer.reset_turn(
        input_ids=input_ids,
        include_system_prompt=False,
        reset_cache=True,
    )
    print("Input ids loaded.")

    text = args.assistant_text
    text_deltas = fake_llm_text_stream(
        text,
        chunk_chars=args.delta_chunk_chars,
        delay_s=args.delta_delay_s,
    )

    print("Running streaming tts simulation...")
    wav_chunks = run_streaming_tts(
        inferencer=inferencer,
        text_deltas=text_deltas,
    )
    # NOTE(review): run_streaming_tts is a lazy generator, so no synthesis
    # has actually happened yet — "Done." prints before generation, which
    # runs inside write_wav below when the generator is consumed.
    print("Done.")

    out_path = Path(args.out_wav).expanduser()
    write_wav(out_path, args.sample_rate, wav_chunks)
    print(f"\n[OK] Write complete: {out_path}")


if __name__ == "__main__":
    main()