| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | """End-to-end streaming TTS test script using ONNX Runtime. |
| | |
| | This script demonstrates the full MOSS-TTS-Realtime ONNX pipeline by: |
| | |
| | 1. Loading four ONNX models (backbone LLM, local transformer, codec encoder, |
| | codec decoder) into ONNX Runtime ``InferenceSession`` instances. |
| | 2. Encoding a reference audio prompt for voice cloning. |
| | 3. Simulating a streaming LLM text source (character-by-character deltas). |
| | 4. Running the streaming TTS pipeline to produce audio chunks. |
| | 5. Writing the concatenated audio to a WAV file. |
| | |
Usage (with the INT8-quantized codec decoder)::
| | |
| | python test_basic_streaming-onnx.py \ |
| | --tokenizer_vocab_path tokenizers/tokenizer.json \ |
| | --tokenizer_config_path tokenizers/tokenizer_config.json \ |
| | --backbone_llm_path onnx_models/backbone_f32/backbone_f32.onnx \ |
| | --backbone_local_path onnx_models/local_transformer_f32/local_transformer_f32.onnx \ |
| | --codec_decoder_path onnx_models_quantized/codec_decoder_int8/codec_decoder_int8.onnx \ |
| | --codec_encoder_path onnx_models/codec_encoder/codec_encoder.onnx \ |
| | --backbone_config_path configs/config_backbone.json \ |
| | --codec_config_path configs/config_codec.json \ |
| | --prompt_wav audio_ref/speaker.[wav|flac|mp3] \ |
| | --out_wav output.wav |
| | """ |
| |
|
| | import argparse |
| | import json |
| | import time |
| | import wave |
| | from pathlib import Path |
| | from typing import Iterator, Tuple |
| |
|
| | import numpy as np |
| | import numpy.typing as npt |
| | import onnxruntime as ort |
| |
|
| | from inferencer_onnx import MossTTSRealtimeInferenceONNX |
| | from moss_text_tokenizer import MOSSTextTokenizer |
| |
|
# Type aliases used in the signatures below.
NDArrayInt = npt.NDArray[np.int64]
NDArrayFloat = npt.NDArray[np.floating]
# Native output sample rate of the codec decoder, in Hz.
CODEC_SAMPLE_RATE = 24000
| |
|
| |
|
def fake_llm_text_stream(
    text: str,
    chunk_chars: int = 1,
    delay_s: float = 0.0,
) -> Iterator[str]:
    """Emulate an LLM that streams text deltas.

    Walks through ``text`` and yields slices of ``chunk_chars`` characters,
    optionally sleeping ``delay_s`` seconds between consecutive deltas.  In
    production this generator would be replaced by a real streaming client
    (OpenAI, vLLM, ...).

    Parameters
    ----------
    text : str
        Complete text to emit as a sequence of deltas.
    chunk_chars : int, optional
        Characters per delta; values below 1 are treated as 1 (default ``1``).
    delay_s : float, optional
        Artificial pause between deltas in seconds (default ``0.0``).

    Yields
    ------
    str
        The next delta, at most ``chunk_chars`` characters long.
    """
    if not text:
        return
    stride = chunk_chars if chunk_chars > 1 else 1
    total = len(text)
    pos = 0
    while pos < total:
        # Never delay before the very first delta.
        if pos and delay_s > 0:
            time.sleep(delay_s)
        yield text[pos : pos + stride]
        pos += stride
| |
|
| |
|
def write_wav(out_path: Path, sample_rate: int, chunks: Iterator[np.ndarray]) -> None:
    """Collect audio chunks and write them to a mono 16-bit PCM WAV file.

    Parameters
    ----------
    out_path : Path
        Output file path; parent directories are created as needed.
    sample_rate : int
        Sample rate in Hz.
    chunks : Iterator[np.ndarray]
        Iterator of float32 audio chunks in the ``[-1, 1]`` range; any shape
        is accepted and flattened.

    Raises
    ------
    RuntimeError
        If the iterator yields no chunks at all.
    """
    all_chunks = [chunk.astype(np.float32).reshape(-1) for chunk in chunks]
    if not all_chunks:
        raise RuntimeError("No audio chunks produced.")

    audio = np.concatenate(all_chunks)
    audio = np.clip(audio, -1.0, 1.0)
    # Round to the nearest integer rather than casting directly: a bare
    # astype(np.int16) truncates toward zero, which biases every sample
    # toward silence by up to one LSB.
    pcm16 = np.round(audio * 32767.0).astype(np.int16)

    out_path.parent.mkdir(parents=True, exist_ok=True)
    with wave.open(str(out_path), "wb") as wf:
        wf.setnchannels(1)   # mono
        wf.setsampwidth(2)   # 16-bit samples
        wf.setframerate(int(sample_rate))
        wf.writeframes(pcm16.tobytes())
| |
|
| |
|
| | def _sanitize_tokens( |
| | tokens: NDArrayInt, |
| | codebook_size: int, |
| | eos_audio_id: int, |
| | ) -> Tuple[NDArrayInt, bool]: |
| | """Validate and truncate audio tokens at EOS or invalid code boundaries. |
| | |
| | Parameters |
| | ---------- |
| | tokens : NDArrayInt |
| | Audio token array of shape ``(T,)`` or ``(T, C)``. |
| | codebook_size : int |
| | Valid code range is ``[0, codebook_size)``. |
| | eos_audio_id : int |
| | End-of-sequence audio token ID. |
| | |
| | Returns |
| | ------- |
| | tuple[NDArrayInt, bool] |
| | Sanitized tokens and a flag indicating whether truncation occurred. |
| | """ |
| | |
| | if tokens.ndim == 1: |
| | tokens = np.expand_dims(tokens, axis=0) |
| |
|
| | if tokens.size == 0: |
| | return tokens, False |
| |
|
| | |
| | eos_rows = np.nonzero(tokens[:, 0] == eos_audio_id)[0] |
| |
|
| | |
| | invalid_rows = ((tokens < 0) | (tokens >= codebook_size)).any(axis=1) |
| | invalid_rows_idx = np.nonzero(invalid_rows)[0] |
| |
|
| | stop_idx = None |
| | if eos_rows.size > 0: |
| | stop_idx = int(eos_rows[0]) |
| |
|
| | if invalid_rows_idx.size > 0: |
| | invalid_idx = int(invalid_rows_idx[0]) |
| | stop_idx = invalid_idx if stop_idx is None else min(stop_idx, invalid_idx) |
| |
|
| | if stop_idx is not None: |
| | tokens = tokens[:stop_idx] |
| | return tokens, True |
| |
|
| | return tokens, False |
| |
|
| |
|
def decode_audio_frames(
    audio_frames: list[NDArrayInt],
    inferencer: MossTTSRealtimeInferenceONNX,
    codebook_size: int,
    eos_audio_id: int,
) -> Iterator[np.ndarray]:
    """Sanitize audio token frames, push them to the codec, and yield audio.

    Parameters
    ----------
    audio_frames : list[NDArrayInt]
        Audio token arrays produced by the backbone.  A bare ``np.ndarray``
        is also accepted and treated as a single frame.
    inferencer : MossTTSRealtimeInferenceONNX
        Inference engine providing ``push_tokens`` and ``audio_chunks``.
    codebook_size : int
        Valid code range passed to the sanitizer.
    eos_audio_id : int
        End-of-sequence audio token ID.

    Yields
    ------
    np.ndarray
        Flattened decoded waveform segments.
    """
    frames = [audio_frames] if isinstance(audio_frames, np.ndarray) else audio_frames
    for raw in frames:
        # Drop a leading batch dimension if present: [1, T, C] -> [T, C].
        tokens = raw[0] if raw.ndim == 3 else raw
        if tokens.ndim != 2:
            raise ValueError(f"Expected [T, C] audio tokens, got {tuple(tokens.shape)}")
        print(f"tokens before sanitize {tokens} {tokens.shape}")
        tokens, _ = _sanitize_tokens(tokens, codebook_size, eos_audio_id)
        print(f"tokens after sanitize {tokens} {tokens.shape}")
        if tokens.size == 0:
            continue
        inferencer.push_tokens(tokens)
        for wav in inferencer.audio_chunks():
            if wav.size:
                print(f"decode_audio_frames wav {wav} {wav.shape}")
                yield wav.reshape(-1)
| |
|
| |
|
def flush_decoder(inferencer: MossTTSRealtimeInferenceONNX) -> Iterator[np.ndarray]:
    """Flush the codec decoder's internal buffer.

    Parameters
    ----------
    inferencer : MossTTSRealtimeInferenceONNX
        The ONNX inference engine whose ``flush`` method is called once.

    Yields
    ------
    np.ndarray
        The flattened trailing waveform segment, if the decoder held one.
    """
    remainder = inferencer.flush()
    if remainder is None or remainder.size == 0:
        return
    print(f"final_chunk flush {remainder} {remainder.shape}")
    yield remainder.reshape(-1)
| |
|
| |
|
| | |
def run_streaming_tts(
    inferencer: MossTTSRealtimeInferenceONNX,
    text_deltas: Iterator[str],
) -> Iterator[np.ndarray]:
    """Turn streaming text deltas into playable audio chunks in real time.

    Mirrors the Gradio demo pipeline:
    codec.streaming → push_text → decode_frames → end_text → drain → flush

    Parameters
    ----------
    inferencer : MossTTSRealtimeInferenceONNX
        A fully initialized ONNX inferencer with ``reset_turn`` already called.
    text_deltas : Iterator[str]
        Text deltas, e.g. from :func:`fake_llm_text_stream`.

    Yields
    ------
    np.ndarray
        Decoded waveform chunks suitable for playback or concatenation.
    """
    codebook_size = inferencer.codebook_size
    eos_audio_id = inferencer.eos_audio_id

    # Stage 1: feed deltas as they arrive, decoding any frames produced.
    for delta in text_deltas:
        print(f"delta {delta}")
        frames = inferencer.push_text(delta)
        if len(frames) > 0:
            print(f"audio_frames {frames} {len(frames)} {frames[0].shape}")
            yield from decode_audio_frames(frames, inferencer, codebook_size, eos_audio_id)

    # Stage 2: close the text stream and decode frames emitted on closure.
    frames = inferencer.end_text()
    if len(frames) > 0:
        print(f"audio_frames end_text {frames} {len(frames)} {frames[0].shape}")
        yield from decode_audio_frames(frames, inferencer, codebook_size, eos_audio_id)

    # Stage 3: drain remaining frames one step at a time until the backbone
    # stops producing or reports that it is finished.
    while True:
        frames = inferencer.drain(max_steps=1)
        if not frames:
            break
        print(f"audio_frames drain {frames} {len(frames)} {frames[0].shape}")
        yield from decode_audio_frames(frames, inferencer, codebook_size, eos_audio_id)
        if inferencer.is_finished:
            break

    # Stage 4: flush whatever the codec decoder still buffers.
    yield from flush_decoder(inferencer)
| |
|
| |
|
def main() -> None:
    """Entry point: parse arguments, load models, run streaming TTS, write WAV."""
    parser = argparse.ArgumentParser(description="Simulated LLM streaming text → TTS streaming audio。")
    # All model/config/tokenizer paths are mandatory string arguments.
    for flag in (
        "--tokenizer_vocab_path",
        "--tokenizer_config_path",
        "--backbone_llm_path",
        "--backbone_local_path",
        "--codec_decoder_path",
        "--codec_encoder_path",
        "--backbone_config_path",
        "--codec_config_path",
        "--prompt_wav",
    ):
        parser.add_argument(flag, type=str, required=True)
    parser.add_argument("--out_wav", type=str, default="out_streaming.wav")
    parser.add_argument("--sample_rate", type=int, default=CODEC_SAMPLE_RATE)
    # Sampling hyperparameters for the backbone LLM.
    parser.add_argument("--temperature", type=float, default=0.725)
    parser.add_argument("--top_p", type=float, default=0.6)
    parser.add_argument("--top_k", type=int, default=34)
    parser.add_argument("--repetition_penalty", type=float, default=1.9)
    parser.add_argument("--repetition_window", type=int, default=50)
    parser.add_argument("--max_length", type=int, default=5000)
    # Controls for the simulated LLM text stream.
    parser.add_argument(
        "--delta_chunk_chars", type=int, default=1, help="Number of characters to output at each delta (1 = verbatim)"
    )
    parser.add_argument(
        "--delta_delay_s", type=float, default=0.0, help="Simulated delay in seconds between deltas, let 0 = no delay"
    )
    parser.add_argument(
        "--assistant_text",
        type=str,
        default=(
            "в зависимости от времени не только точность, но и низкая задержка. Если это не мгновенно, то человеческое взаимодействие теряется. Мы наконец-то достигаем момента, когда технология достаточно быстра для того, чтобы люди просто общались, и это является огромным сдвигом для глобального бизнеса."
        ),
    )

    args = parser.parse_args()

    tokenizer = MOSSTextTokenizer(args.tokenizer_vocab_path, args.tokenizer_config_path)
    print(f"tokenizer {tokenizer} {args.tokenizer_vocab_path} {args.tokenizer_config_path}")

    def _load_session(label: str, model_path: str) -> ort.InferenceSession:
        """Create a CPU-only ONNX Runtime session and log it."""
        session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
        print(f"{label} {session} {model_path}")
        return session

    def _load_json(label: str, config_path: str) -> dict:
        """Read a JSON config file and log its contents."""
        with open(config_path, "r") as fh:
            cfg = json.load(fh)
        print(f"{label} {cfg} {config_path}")
        return cfg

    backbone_llm = _load_session("backbone_llm", args.backbone_llm_path)
    backbone_local = _load_session("backbone_local", args.backbone_local_path)
    codec_decoder = _load_session("codec_decoder", args.codec_decoder_path)
    codec_encoder = _load_session("codec_encoder", args.codec_encoder_path)
    backbone_config = _load_json("backbone_config", args.backbone_config_path)
    codec_config = _load_json("codec_config", args.codec_config_path)

    inferencer = MossTTSRealtimeInferenceONNX(
        tokenizer,
        backbone_llm,
        backbone_local,
        codec_decoder,
        codec_encoder,
        backbone_config,
        codec_config,
        max_length=args.max_length,
        codec_sample_rate=CODEC_SAMPLE_RATE,
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k,
        repetition_penalty=args.repetition_penalty,
        repetition_window=args.repetition_window,
    )
    print("Inferencer loaded.")

    # Encode the reference audio prompt for voice cloning.
    print("Extracting audio prompt...")
    prompt_tokens = inferencer._encode_reference_audio(args.prompt_wav)
    print(f"prompt_tokens {prompt_tokens} {prompt_tokens.shape}")

    # Build the initial input ids from the prompt tokens and start a turn.
    print("Loading input ids...")
    input_ids = inferencer.processor.make_ensemble(prompt_tokens.squeeze(1))
    print(f"input_ids {input_ids} {input_ids.shape}")
    inferencer.reset_turn(
        input_ids=input_ids,
        include_system_prompt=False,
        reset_cache=True,
    )
    print("Input ids loaded.")

    text_deltas = fake_llm_text_stream(
        args.assistant_text,
        chunk_chars=args.delta_chunk_chars,
        delay_s=args.delta_delay_s,
    )

    # Both the delta source and the TTS pipeline are lazy generators; the
    # actual work happens when write_wav consumes wav_chunks below.
    print("Running streaming tts simulation...")
    wav_chunks = run_streaming_tts(
        inferencer=inferencer,
        text_deltas=text_deltas,
    )
    print("Done.")

    out_path = Path(args.out_wav).expanduser()
    write_wav(out_path, args.sample_rate, wav_chunks)
    print(f"\n[OK] Write complete: {out_path}")
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|