| """ |
| Streaming ONNX inference for Pocket-TTS. |
| |
| This script uses the streaming ONNX export with KV cache states, |
| which correctly handles the base model inference. |
| |
| Pipeline: |
| 1. Tokenize text |
| 2. (Optional) Encode voice reference audio |
| 3. Voice conditioning pass (update KV cache) |
| 4. Text conditioning pass (update KV cache) |
| 5. Autoregressive generation (single latent + state update per step) |
| 6. Mimi streaming decode (chunk-by-chunk with state) |
| |
| Usage: |
| # Basic generation |
| python final_inference_scripts/inference_onnx_streaming.py --text "Hello world" |
| |
| # With voice cloning |
| python final_inference_scripts/inference_onnx_streaming.py --text "Hello world" --voice reference.wav |
| |
| # With INT8 models |
| python final_inference_scripts/inference_onnx_streaming.py --text "Hello world" --int8 |
| """ |

import os
import json
import time
import argparse
import queue
import threading
from pathlib import Path
from typing import Optional, Generator, Union

import numpy as np

try:
    import soundfile as sf
    HAS_SOUNDFILE = True
except ImportError:
    HAS_SOUNDFILE = False

try:
    import scipy.signal
    HAS_SCIPY = True
except ImportError:
    HAS_SCIPY = False


class PocketTTSStreamingONNX:
    """Streaming ONNX inference engine for Pocket-TTS.

    Uses KV-cache states for proper streaming inference that works
    with both base and merged models.
    """

    SAMPLE_RATE = 24000
    SAMPLES_PER_FRAME = 1920
    FRAME_DURATION = SAMPLES_PER_FRAME / SAMPLE_RATE
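    # 1920 samples at 24 kHz is 80 ms of audio per generated latent frame.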

    def __init__(
        self,
        models_dir: str = "onnx_export_glm",
        tokenizer_path: Optional[str] = None,
        use_int8: bool = False,
        temperature: float = 0.7,
        lsd_steps: int = 10,
    ):
        self.models_dir = Path(models_dir)
        self.use_int8 = use_int8
        self.temperature = temperature
        self.lsd_steps = lsd_steps

        import onnxruntime as ort
        import sentencepiece as spm

        # Prefer CUDA when available, keeping CPU as a fallback.
        self.providers = ["CPUExecutionProvider"]
        available = ort.get_available_providers()
        if "CUDAExecutionProvider" in available:
            self.providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]

        sess_opts = ort.SessionOptions()
        sess_opts.intra_op_num_threads = min(os.cpu_count() or 4, 4)
        sess_opts.inter_op_num_threads = 1

        tokenizer_path = tokenizer_path or self.models_dir / "tokenizer.model"
        self.tokenizer = spm.SentencePieceProcessor()
        self.tokenizer.Load(str(tokenizer_path))

        with open(self.models_dir / "model_config.json", "r") as f:
            self.config = json.load(f)

        self._precompute_flow_buffers()

        suffix = "_int8" if use_int8 else ""

        def get_path(base):
            # Prefer the quantized variant when requested; fall back to FP32.
            p = self.models_dir / f"{base}{suffix}.onnx"
            if p.exists():
                return str(p)
            return str(self.models_dir / f"{base}.onnx")

        def get_path_multi(bases):
            # Try each candidate base name (quantized variant first), in order.
            for base in bases:
                p = self.models_dir / f"{base}{suffix}.onnx"
                if p.exists():
                    return str(p)
                p = self.models_dir / f"{base}.onnx"
                if p.exists():
                    return str(p)
            return str(self.models_dir / f"{bases[0]}.onnx")

        print(f"Loading models from {self.models_dir}...")

        self.text_conditioner = ort.InferenceSession(
            get_path("text_conditioner"), sess_opts, providers=self.providers
        )
        self.flow_lm_main = ort.InferenceSession(
            get_path_multi(["backbone", "flow_lm_main"]), sess_opts, providers=self.providers
        )
        self.flow_lm_flow = ort.InferenceSession(
            get_path("flow_lm_flow"), sess_opts, providers=self.providers
        )
        self.mimi_decoder = ort.InferenceSession(
            get_path("mimi_decoder"), sess_opts, providers=self.providers
        )

        # The Mimi encoder is only needed for voice cloning from raw audio.
        encoder_path = get_path("mimi_encoder")
        if os.path.exists(encoder_path):
            self.mimi_encoder = ort.InferenceSession(
                encoder_path, sess_opts, providers=self.providers
            )
        else:
            self.mimi_encoder = None
            print(" Note: mimi_encoder not found, voice cloning unavailable")

        # Detect the flow head's expected conditioning rank: a 2-D "c" input
        # marks the "kevinahmm" export; anything else uses the standard layout.
        flow_inputs = {inp.name: inp.shape for inp in self.flow_lm_flow.get_inputs()}
        if "c" in flow_inputs and len(flow_inputs["c"]) == 2:
            self._flow_format = "kevinahmm"
        else:
            self._flow_format = "standard"

        print(f" Flow format: {self._flow_format}")
        print(" Models loaded.")

    def _precompute_flow_buffers(self):
        # Precompute the (s, t) time-pair tensors for the fixed Euler schedule
        # so they are not re-allocated on every generation step.
        dt = 1.0 / self.lsd_steps
        self._st_buffers = []
        for j in range(self.lsd_steps):
            s = j / self.lsd_steps
            t = s + dt
            self._st_buffers.append((
                np.array([[s]], dtype=np.float32),
                np.array([[t]], dtype=np.float32),
            ))
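    # For example, with lsd_steps=4 the buffers hold the time pairs
    # (0.00, 0.25), (0.25, 0.50), (0.50, 0.75), (0.75, 1.00).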

    def _init_backbone_state(self) -> dict:
        """Build zero-initialized KV-cache inputs for the backbone session."""
        state = {}
        inputs = [inp.name for inp in self.flow_lm_main.get_inputs()]

        if "step" in inputs:
            # Named export: explicit past_key_*/past_value_* inputs plus a
            # scalar step counter.
            self._backbone_format = "named"
            state["step"] = np.zeros(1, dtype=np.int64)
            for inp in self.flow_lm_main.get_inputs():
                if inp.name.startswith("past_key_") or inp.name.startswith("past_value_"):
                    shape = list(inp.shape)
                    for i, d in enumerate(shape):
                        if isinstance(d, str) or d is None:
                            # Dynamic dims: batch -> 1, anything else is
                            # treated as the cache length and sized to 1000.
                            shape[i] = 1 if i == 0 else 1000
                    state[inp.name] = np.zeros(shape, dtype=np.float32)
        else:
            # Indexed export: opaque state_N inputs mirrored by out_state_N outputs.
            self._backbone_format = "state_indexed"
            for inp in self.flow_lm_main.get_inputs():
                name = inp.name
                if name.startswith("state_"):
                    shape = list(inp.shape)
                    for i, d in enumerate(shape):
                        if isinstance(d, str) or d is None:
                            # Dynamic dims: leading dims -> 1, anything else is
                            # treated as the cache length and sized to 1000.
                            shape[i] = 1 if i < 2 else 1000

                    dtype = np.float32
                    if "tensor(int64)" in str(inp.type):
                        dtype = np.int64

                    state[name] = np.zeros(shape, dtype=dtype)
        return state

    def _init_mimi_state(self) -> dict:
        """Build zero-initialized streaming-state inputs for the Mimi decoder."""
        state = {}
        for inp in self.mimi_decoder.get_inputs():
            name = inp.name
            if name.startswith("state_"):
                shape = list(inp.shape)
                for i, d in enumerate(shape):
                    if isinstance(d, str) or d is None:
                        # Heuristics for dynamic dims: paired conv states
                        # (leading dim of 2) get batch 1 in dim 1; other
                        # leading dims are batch; anything else is treated
                        # as a length dimension and sized to 1000.
                        if i == 1 and shape[0] == 2:
                            shape[i] = 1
                        elif i == 0:
                            shape[i] = 1
                        else:
                            shape[i] = 1000

                dtype_str = str(inp.type).lower()
                if "int64" in dtype_str:
                    dtype = np.int64
                elif "bool" in dtype_str:
                    dtype = np.bool_
                else:
                    dtype = np.float32

                state[name] = np.zeros(shape, dtype=dtype)
        return state

    def _update_state_from_outputs(self, state: dict, result: list, session):
        """Copy present_*/out_state_* outputs back into the input state dict."""
        if self._backbone_format == "named":
            for i, out in enumerate(session.get_outputs()):
                name = out.name
                if name.startswith("present_key_"):
                    layer_idx = name.replace("present_key_", "")
                    state[f"past_key_{layer_idx}"] = result[i]
                elif name.startswith("present_value_"):
                    layer_idx = name.replace("present_value_", "")
                    state[f"past_value_{layer_idx}"] = result[i]
            # Advance the position counter by the number of frames just
            # processed (taken from the first output's leading dimension).
            seq_len = result[0].shape[0] if len(result[0].shape) > 0 else 1
            state["step"] = np.array([int(state["step"][0]) + seq_len], dtype=np.int64)
        else:
            for i, out in enumerate(session.get_outputs()):
                name = out.name
                if name.startswith("out_state_"):
                    idx = int(name.replace("out_state_", ""))
                    state[f"state_{idx}"] = result[i]

    def _tokenize(self, text: str) -> np.ndarray:
        # Normalize: ensure terminal punctuation and a leading capital.
        text = text.strip()
        if not text:
            raise ValueError("Text cannot be empty")
        if text[-1].isalnum():
            text = text + "."
        if not text[0].isupper():
            text = text[0].upper() + text[1:]
        token_ids = self.tokenizer.Encode(text)
        return np.array(token_ids, dtype=np.int64).reshape(1, -1)
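    # Example of the normalization above: "hello world" is encoded as
    # "Hello world." (capitalized, with a final period appended).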

    def _load_audio(self, path: Union[str, Path]) -> np.ndarray:
        if not HAS_SOUNDFILE:
            raise ImportError("soundfile required. Install with: pip install soundfile")

        audio, sr = sf.read(str(path))

        # Downmix to mono and resample to the model rate if needed.
        if len(audio.shape) > 1:
            audio = audio.mean(axis=1)

        if sr != self.SAMPLE_RATE:
            if not HAS_SCIPY:
                raise ImportError("scipy required for resampling. Install with: pip install scipy")
            num_samples = int(len(audio) * self.SAMPLE_RATE / sr)
            audio = scipy.signal.resample(audio, num_samples)

        # Peak-normalize only when the signal clips.
        audio = audio.astype(np.float32)
        if np.abs(audio).max() > 1.0:
            audio = audio / np.abs(audio).max()

        return audio.reshape(1, 1, -1)

    def encode_voice(self, audio_path: Union[str, Path]) -> np.ndarray:
        if self.mimi_encoder is None:
            print(" Warning: mimi_encoder not available, using zeros")
            return np.zeros((1, 1, 1024), dtype=np.float32)

        audio = self._load_audio(audio_path)
        embeddings = self.mimi_encoder.run(None, {"audio": audio})[0]

        # Coerce the output to (batch, frames, dim).
        while embeddings.ndim > 3:
            embeddings = embeddings.squeeze(0)
        if embeddings.ndim < 3:
            embeddings = embeddings[None]

        return embeddings.astype(np.float32)

    def load_predefined_voice(self, voice_name: str) -> np.ndarray:
        import safetensors.torch

        voices_dir = Path("voices")
        voice_path = voices_dir / f"{voice_name}.safetensors"

        if not voice_path.exists():
            available = [f.stem for f in voices_dir.glob("*.safetensors")]
            raise ValueError(
                f"Voice '{voice_name}' not found. Available: {available}"
            )

        st = safetensors.torch.load_file(str(voice_path))
        tensor = st["audio_prompt"]

        return tensor.numpy().astype(np.float32)

    PREDEFINED_VOICES = ["alba", "marius", "javert", "jean", "fantine", "cosette", "eponine", "azelma"]

    def _run_flow_lm(
        self,
        voice_embeddings: Optional[np.ndarray],
        text_ids: np.ndarray,
        max_frames: int = 500,
        frames_after_eos: int = 3,
    ) -> Generator[np.ndarray, None, None]:
        """Yield one (1, 1, 32) latent frame per autoregressive step."""
        text_emb = self.text_conditioner.run(None, {"token_ids": text_ids})[0]
        if text_emb.ndim == 2:
            text_emb = text_emb[None]

        state = self._init_backbone_state()
        empty_seq = np.zeros((1, 0, 32), dtype=np.float32)
        empty_text = np.zeros((1, 0, 1024), dtype=np.float32)

        def run_backbone(sequence, text_emb_arg):
            if self._backbone_format == "named":
                inputs = {
                    "sequence": sequence,
                    "text_embeddings": text_emb_arg,
                    "step": state["step"],
                }
                for k, v in state.items():
                    if k.startswith("past_"):
                        inputs[k] = v
                return self.flow_lm_main.run(None, inputs)
            else:
                return self.flow_lm_main.run(None, {
                    "sequence": sequence,
                    "text_embeddings": text_emb_arg,
                    **state,
                })

        # Conditioning passes: prime the KV cache with voice, then text.
        if voice_embeddings is not None:
            res_voice = run_backbone(empty_seq, voice_embeddings)
            self._update_state_from_outputs(state, res_voice, self.flow_lm_main)

        res_text = run_backbone(empty_seq, text_emb)
        self._update_state_from_outputs(state, res_text, self.flow_lm_main)

        # A NaN latent marks the begin-of-sequence input for the first step.
        curr = np.full((1, 1, 32), np.nan, dtype=np.float32)
        eos_step = None

        for step in range(max_frames):
            res_step = run_backbone(curr, empty_text)

            # Match the conditioning rank to what the flow head expects.
            conditioning = res_step[0]
            conditioning_for_flow = conditioning
            if self._flow_format == "kevinahmm":
                if conditioning.ndim == 3:
                    conditioning_for_flow = conditioning[:, 0, :]
            else:
                if conditioning.ndim == 2:
                    conditioning_for_flow = conditioning[:, None, :]
            eos_logit = res_step[1]
            self._update_state_from_outputs(state, res_step, self.flow_lm_main)

            if eos_logit.ndim == 3:
                eos_val = float(eos_logit[0, 0, 0])
            elif eos_logit.ndim == 2:
                eos_val = float(eos_logit[0, 0])
            else:
                eos_val = float(eos_logit[0])
            if eos_val > -4.0 and eos_step is None:
                eos_step = step

            # Keep generating a few frames past EOS so the audio tail decays.
            if eos_step is not None and step >= eos_step + frames_after_eos:
                break

            # Sample the initial noise x ~ N(0, temperature * I).
            std = np.sqrt(self.temperature) if self.temperature > 0 else 0.0
            shape = (1, 32) if self._flow_format == "kevinahmm" else (1, 1, 32)
            if std > 0:
                x = np.random.normal(0, std, shape).astype(np.float32)
            else:
                x = np.zeros(shape, dtype=np.float32)

            # Forward-Euler integration of the flow: x <- x + v * (t - s).
            for s_arr, t_arr in self._st_buffers:
                flow_out = self.flow_lm_flow.run(None, {
                    "c": conditioning_for_flow,
                    "s": s_arr,
                    "t": t_arr,
                    "x": x,
                })

                # Reconcile the velocity's rank with x before the update.
                res = flow_out[0]
                if res.ndim == 3 and x.ndim == 2:
                    res = res.squeeze(1)
                elif res.ndim == 2 and x.ndim == 3:
                    res = res[:, None, :]

                x = x + res * (t_arr[0, 0] - s_arr[0, 0])

            latent = x.reshape(1, 1, 32)
            yield latent
            curr = latent
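
    # The generator above integrates the flow head by forward Euler: starting
    # from noise x ~ N(0, temperature * I), each of the lsd_steps updates
    # applies x <- x + v(c, s, x) * (t - s), i.e. one uniform step of size
    # 1 / lsd_steps along the learned velocity field.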

    def _decode_latents(self, latents: list) -> np.ndarray:
        mimi_inputs = [inp.name for inp in self.mimi_decoder.get_inputs()]
        has_states = any(name.startswith("state_") for name in mimi_inputs)

        if has_states:
            # Streaming decoder: feed one latent at a time, carrying state.
            state = self._init_mimi_state()
            audio_chunks = []

            for latent in latents:
                inputs = {"latent": latent}
                inputs.update(state)

                result = self.mimi_decoder.run(None, inputs)
                audio_chunks.append(result[0].flatten())

                for i, out in enumerate(self.mimi_decoder.get_outputs()):
                    if out.name.startswith("out_state_"):
                        idx = int(out.name.replace("out_state_", ""))
                        state[f"state_{idx}"] = result[i]

            return np.concatenate(audio_chunks)
        else:
            # Stateless decoder: decode the whole sequence in one call.
            all_latents = np.concatenate(latents, axis=1)
            result = self.mimi_decoder.run(None, {"normalized_latents": all_latents})
            return result[0].flatten()

    def _decode_worker(self, latent_queue: queue.Queue, audio_chunks: list):
        """Consume latents from the queue and decode them as they arrive."""
        mimi_inputs = [inp.name for inp in self.mimi_decoder.get_inputs()]
        has_states = any(name.startswith("state_") for name in mimi_inputs)

        if has_states:
            mimi_state = self._init_mimi_state()

            while True:
                item = latent_queue.get()
                if item is None:  # sentinel: generation finished
                    break

                inputs = {"latent": item}
                inputs.update(mimi_state)

                result = self.mimi_decoder.run(None, inputs)
                audio_chunks.append(result[0].flatten())

                for i, out in enumerate(self.mimi_decoder.get_outputs()):
                    if out.name.startswith("out_state_"):
                        idx = int(out.name.replace("out_state_", ""))
                        mimi_state[f"state_{idx}"] = result[i]
        else:
            # Stateless decoder: buffer everything, then decode once.
            all_latents = []
            while True:
                item = latent_queue.get()
                if item is None:
                    break
                all_latents.append(item)

            if all_latents:
                stacked = np.concatenate(all_latents, axis=1)
                result = self.mimi_decoder.run(None, {"normalized_latents": stacked})
                audio_chunks.append(result[0].flatten())

    def generate(
        self,
        text: str,
        voice: Optional[Union[str, Path, np.ndarray]] = None,
        max_frames: int = 500,
    ) -> np.ndarray:
        """Synthesize `text`; returns a float32 waveform at SAMPLE_RATE."""
        voice_emb = None
        if voice is not None and str(voice).lower() != "none":
            if isinstance(voice, np.ndarray):
                voice_emb = voice
            elif isinstance(voice, str) and voice in self.PREDEFINED_VOICES:
                print(f" Using predefined voice: {voice}")
                voice_emb = self.load_predefined_voice(voice)
            else:
                voice_emb = self.encode_voice(voice)

        text_ids = self._tokenize(text)

        # Decode in a background thread so Mimi runs while the next latent
        # is being generated (producer-consumer over a queue).
        latent_queue = queue.Queue()
        audio_chunks = []
        decoder = threading.Thread(
            target=self._decode_worker,
            args=(latent_queue, audio_chunks),
            daemon=True,
        )
        decoder.start()

        for latent in self._run_flow_lm(voice_emb, text_ids, max_frames):
            latent_queue.put(latent)
        latent_queue.put(None)  # sentinel: no more latents

        decoder.join()
        return np.concatenate(audio_chunks)

    def save_audio(self, audio: np.ndarray, path: Union[str, Path]):
        if not HAS_SOUNDFILE:
            raise ImportError("soundfile required. Install with: pip install soundfile")
        sf.write(str(path), audio, self.SAMPLE_RATE)


def main():
    parser = argparse.ArgumentParser(
        description="Streaming ONNX inference for Pocket-TTS"
    )

    parser.add_argument("--text", type=str, default="Hello, world!",
                        help="Text to synthesize")
    parser.add_argument("--output", type=str, default="output_streaming.wav",
                        help="Output WAV file path")
    parser.add_argument("--models_dir", type=str, default="onnx_streaming",
                        help="Directory containing ONNX models")
    parser.add_argument("--voice", type=str, default="cosette",
                        help="Voice name (alba, marius, javert, jean, fantine, "
                             "cosette, eponine, azelma) or path to audio file")
    parser.add_argument("--int8", action="store_true",
                        help="Use INT8 quantized models")
    parser.add_argument("--temperature", type=float, default=0.7,
                        help="Sampling temperature")
    parser.add_argument("--lsd_steps", type=int, default=10,
                        help="Flow matching steps")
    parser.add_argument("--max_frames", type=int, default=500,
                        help="Maximum latent frames to generate")
    parser.add_argument("--seed", type=int, default=None,
                        help="Random seed for reproducibility")

    args = parser.parse_args()

    if args.seed is not None:
        np.random.seed(args.seed)
        print(f"Random seed: {args.seed}")

    print(f"\nLoading models (INT8={args.int8})...")
    t0 = time.time()

    tts = PocketTTSStreamingONNX(
        models_dir=args.models_dir,
        use_int8=args.int8,
        temperature=args.temperature,
        lsd_steps=args.lsd_steps,
    )

    load_time = time.time() - t0
    print(f" Loaded in {load_time:.2f}s")

    print("\nGenerating speech...")
    print(f" Text: {args.text}")
    print(f" Voice: {args.voice}")

    t0 = time.time()
    audio = tts.generate(args.text, voice=args.voice, max_frames=args.max_frames)
    gen_time = time.time() - t0

    duration = len(audio) / tts.SAMPLE_RATE
    rtf = gen_time / max(duration, 0.01)

    print(f" Generated {duration:.2f}s audio in {gen_time:.2f}s (RTF: {rtf:.2f}x)")

    tts.save_audio(audio, args.output)
    print(f" Saved to: {args.output}")


if __name__ == "__main__":
    main()