| """
|
| Streaming ONNX inference for Pocket-TTS (Mixed Precision).
|
|
|
| This script uses INT8 quantized KevinAHM models for the backbone and flow,
|
| but keeps the Mimi decoder in FP32 to avoid quantization artifacts/errors.
|
|
|
| Components:
|
| - Text Conditioner: INT8
|
| - Flow LM (Main): INT8
|
| - Flow LM (Flow): INT8
|
| - Mimi Decoder: FP32 (Required for audio quality/correctness)
|
|
|
| Usage:
|
| python final_inference_scripts/inference_mixed_int8.py --text "Hello world"
|
| """
|
|
|
| import sys
|
| import os
|
| import json
|
| import time
|
| import argparse
|
| import queue
|
| import threading
|
| from pathlib import Path
|
| from typing import Optional, Generator, Union
|
|
|
| import numpy as np
|
|
|
# Optional dependencies: audio file I/O and resampling are only needed for
# saving/loading audio, so their absence is recorded rather than fatal.
try:
    import soundfile as sf
except ImportError:
    HAS_SOUNDFILE = False
else:
    HAS_SOUNDFILE = True

try:
    import scipy.signal
except ImportError:
    HAS_SCIPY = False
else:
    HAS_SCIPY = True
|
|
|
|
|
class PocketTTSStreamingONNX:
    """Streaming ONNX inference engine for Pocket-TTS.

    Uses KV cache states for proper streaming inference that works
    with both base and merged models.

    Pipeline: text -> SentencePiece tokens -> text conditioner embeddings ->
    autoregressive backbone (flow_lm_main, one 32-dim latent frame per step,
    KV-cache carried explicitly as ONNX inputs/outputs) -> Euler-integrated
    flow model (flow_lm_flow) -> Mimi decoder (latents -> 24 kHz waveform).
    """

    # One latent frame decodes to 1920 samples at 24 kHz (80 ms per frame).
    SAMPLE_RATE = 24000
    SAMPLES_PER_FRAME = 1920
    FRAME_DURATION = SAMPLES_PER_FRAME / SAMPLE_RATE

    def __init__(
        self,
        models_dir: str = "onnx_kevinahmm_int8",
        tokenizer_path: Optional[str] = None,
        use_int8: bool = True,
        temperature: float = 0.7,
        lsd_steps: int = 10,
    ):
        """Load tokenizer, model config, and all ONNX sessions.

        Args:
            models_dir: Directory holding the .onnx files, tokenizer.model,
                and model_config.json.
            tokenizer_path: Override for the SentencePiece model; defaults to
                ``<models_dir>/tokenizer.model``.
            use_int8: Prefer ``<name>_int8.onnx`` files when they exist
                (falls back to the plain name otherwise).
            temperature: Sampling temperature for the initial flow noise
                (std = sqrt(temperature); 0 disables noise).
            lsd_steps: Number of Euler steps for flow integration.
        """
        self.models_dir = Path(models_dir)
        self.use_int8 = use_int8
        self.temperature = temperature
        self.lsd_steps = lsd_steps

        # Imported lazily so the module can be imported without these deps.
        import onnxruntime as ort
        import sentencepiece as spm

        # Prefer CUDA when available; always keep CPU as fallback provider.
        self.providers = ["CPUExecutionProvider"]
        available = ort.get_available_providers()
        if "CUDAExecutionProvider" in available:
            self.providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]

        # Cap intra-op threads at 4; single inter-op thread since the graph
        # is run step-by-step anyway.
        sess_opts = ort.SessionOptions()
        sess_opts.intra_op_num_threads = min(os.cpu_count() or 4, 4)
        sess_opts.inter_op_num_threads = 1

        tokenizer_path = tokenizer_path or self.models_dir / "tokenizer.model"
        self.tokenizer = spm.SentencePieceProcessor()
        self.tokenizer.Load(str(tokenizer_path))

        with open(self.models_dir / "model_config.json", "r") as f:
            self.config = json.load(f)

        self._precompute_flow_buffers()

        suffix = "_int8" if use_int8 else ""

        def get_path(base):
            # Prefer the quantized variant, fall back to the plain name.
            p = self.models_dir / f"{base}{suffix}.onnx"
            if p.exists():
                return str(p)
            return str(self.models_dir / f"{base}.onnx")

        def get_path_multi(bases):
            # Same as get_path but tries several base names in order
            # (different exports name the backbone differently).
            for base in bases:
                p = self.models_dir / f"{base}{suffix}.onnx"
                if p.exists():
                    return str(p)
                p = self.models_dir / f"{base}.onnx"
                if p.exists():
                    return str(p)
            return str(self.models_dir / f"{bases[0]}.onnx")

        print(f"Loading models from {self.models_dir}...")

        self.text_conditioner = ort.InferenceSession(
            get_path("text_conditioner"), sess_opts, providers=self.providers
        )

        self.flow_lm_main = ort.InferenceSession(
            get_path_multi(["backbone", "flow_lm_main"]), sess_opts, providers=self.providers
        )

        self.flow_lm_flow = ort.InferenceSession(
            get_path("flow_lm_flow"), sess_opts, providers=self.providers
        )

        # Mimi decoder is deliberately loaded in FP32 (see module docstring);
        # only fall back to the suffixed path if the FP32 file is absent.
        mimi_dec_path = self.models_dir / "mimi_decoder.onnx"
        if not mimi_dec_path.exists():
            print(f"Warning: FP32 mimi_decoder.onnx not found at {mimi_dec_path}")
            mimi_dec_path = Path(get_path("mimi_decoder"))

        print(f" Mimi Decoder: {mimi_dec_path}")
        self.mimi_decoder = ort.InferenceSession(
            str(mimi_dec_path), sess_opts, providers=self.providers
        )

        # Encoder is optional: only required for cloning a voice from audio.
        encoder_path = get_path("mimi_encoder")
        if os.path.exists(encoder_path):
            self.mimi_encoder = ort.InferenceSession(
                encoder_path, sess_opts, providers=self.providers
            )
        else:
            self.mimi_encoder = None
            print(" Note: mimi_encoder not found, voice cloning unavailable")

        # Probe the flow model's "c" (conditioning) input rank to detect the
        # export format: rank-2 c => "kevinahmm" export, else "standard".
        flow_inputs = {inp.name: inp.shape for inp in self.flow_lm_flow.get_inputs()}
        if "c" in flow_inputs:
            c_shape = flow_inputs["c"]
            if len(c_shape) == 2:
                self._flow_format = "kevinahmm"
            else:
                self._flow_format = "standard"
        else:
            self._flow_format = "standard"

        print(f" Flow format: {self._flow_format}")
        print(" Models loaded.")

    def _precompute_flow_buffers(self):
        """Precompute the (s, t) time-pair arrays for each Euler step.

        Each pair covers [j/lsd_steps, j/lsd_steps + dt] as (1, 1) float32
        arrays, reused every frame by _run_flow_lm.
        """
        dt = 1.0 / self.lsd_steps
        self._st_buffers = []
        for j in range(self.lsd_steps):
            s = j / self.lsd_steps
            t = s + dt
            self._st_buffers.append((
                np.array([[s]], dtype=np.float32),
                np.array([[t]], dtype=np.float32)
            ))

    def _init_backbone_state(self) -> dict:
        """Build zero-filled KV-cache inputs for the backbone session.

        Detects one of two export conventions from the input names and
        records it in self._backbone_format:
          - "named": explicit "step" counter plus past_key_*/past_value_*
            tensors (dynamic dims padded to batch 1 / length 500).
          - "state_indexed": opaque numbered state_* tensors (dynamic dim 0
            -> 1, dim 2 -> 500; dtype taken from the ONNX input type).

        Returns:
            Mapping of input name -> initial numpy array.
        """
        state = {}
        inputs = [inp.name for inp in self.flow_lm_main.get_inputs()]

        if "step" in inputs:
            self._backbone_format = "named"
            state["step"] = np.zeros(1, dtype=np.int64)
            for inp in self.flow_lm_main.get_inputs():
                if inp.name.startswith("past_key_") or inp.name.startswith("past_value_"):
                    shape = list(inp.shape)
                    for i, d in enumerate(shape):
                        # Symbolic/dynamic dims: assume batch at dim 0,
                        # max sequence length 500 elsewhere.
                        if isinstance(d, str) or d is None:
                            shape[i] = 1 if i == 0 else 500
                    state[inp.name] = np.zeros(shape, dtype=np.float32)
        else:
            self._backbone_format = "state_indexed"
            for inp in self.flow_lm_main.get_inputs():
                name = inp.name
                if name.startswith("state_"):
                    shape = list(inp.shape)
                    for i, d in enumerate(shape):
                        if isinstance(d, str) or d is None:
                            if i == 0:
                                shape[i] = 1
                            elif i == 2:
                                shape[i] = 500

                    dtype = np.float32
                    if "tensor(int64)" in str(inp.type):
                        dtype = np.int64

                    state[name] = np.zeros(shape, dtype=dtype)
        return state

    def _init_mimi_state(self) -> dict:
        """Build initial state_* inputs for a stateful Mimi decoder export.

        Dynamic dims default to 1 (dim 0) or 500; dtype is inferred from the
        ONNX input type string. Rank-1 states are initialized to ones rather
        than zeros — presumably scale/flag vectors in this export; the
        higher-rank buffers start at zero.
        """
        state = {}
        for inp in self.mimi_decoder.get_inputs():
            name = inp.name
            if name.startswith("state_"):
                shape = list(inp.shape)
                for i, d in enumerate(shape):
                    if isinstance(d, str) or d is None:
                        shape[i] = 1 if i == 0 else 500

                dtype_str = str(inp.type).lower()
                if "int64" in dtype_str:
                    dtype = np.int64
                elif "bool" in dtype_str:
                    dtype = np.bool_
                else:
                    dtype = np.float32

                if len(shape) == 1:
                    state[name] = np.ones(shape, dtype=dtype)
                else:
                    state[name] = np.zeros(shape, dtype=dtype)
        return state

    def _update_state_from_outputs(self, state: dict, result: list, session):
        """Copy a session run's present/out_state outputs back into `state`.

        Args:
            state: Mutable state dict produced by _init_backbone_state.
            result: Output list from session.run (positional, matching
                session.get_outputs()).
            session: The session whose output metadata names the results.
        """
        if self._backbone_format == "named":
            for i, out in enumerate(session.get_outputs()):
                name = out.name
                if name.startswith("present_key_"):
                    layer_idx = name.replace("present_key_", "")
                    state[f"past_key_{layer_idx}"] = result[i]
                elif name.startswith("present_value_"):
                    layer_idx = name.replace("present_value_", "")
                    state[f"past_value_{layer_idx}"] = result[i]
            # Advance the step counter by the number of new positions.
            # NOTE(review): this reads result[0].shape[0]; if the first
            # output is batch-first (1, T, D) the increment would be the
            # batch size, not T — confirm the export's output layout.
            seq_len = result[0].shape[0] if len(result[0].shape) > 0 else 1
            state["step"] = np.array([int(state["step"][0]) + seq_len], dtype=np.int64)
        else:
            for i, out in enumerate(session.get_outputs()):
                name = out.name
                if name.startswith("out_state_"):
                    idx = int(name.replace("out_state_", ""))
                    state[f"state_{idx}"] = result[i]

    def _tokenize(self, text: str) -> np.ndarray:
        """Normalize text and encode it to a (1, n_tokens) int64 array.

        Normalization: strip whitespace, append "." if the text ends in an
        alphanumeric character, capitalize the first character.

        Raises:
            ValueError: If the text is empty after stripping.
        """
        text = text.strip()
        if not text:
            raise ValueError("Text cannot be empty")
        if text[-1].isalnum():
            text = text + "."
        if not text[0].isupper():
            text = text[0].upper() + text[1:]
        token_ids = self.tokenizer.Encode(text)
        return np.array(token_ids, dtype=np.int64).reshape(1, -1)

    def _load_audio(self, path: Union[str, Path]) -> np.ndarray:
        """Load an audio file as float32 shaped (1, 1, n_samples).

        Downmixes multi-channel to mono (mean), resamples to SAMPLE_RATE via
        scipy when needed, and peak-normalizes only if |audio| exceeds 1.0.

        Raises:
            ImportError: If soundfile (or scipy, when resampling is needed)
                is not installed.
        """
        if not HAS_SOUNDFILE:
            raise ImportError("soundfile required. Install with: pip install soundfile")

        audio, sr = sf.read(str(path))

        if len(audio.shape) > 1:
            audio = audio.mean(axis=1)

        if sr != self.SAMPLE_RATE:
            if not HAS_SCIPY:
                raise ImportError("scipy required for resampling. Install with: pip install scipy")
            num_samples = int(len(audio) * self.SAMPLE_RATE / sr)
            audio = scipy.signal.resample(audio, num_samples)

        audio = audio.astype(np.float32)
        if np.abs(audio).max() > 1.0:
            audio = audio / np.abs(audio).max()

        return audio.reshape(1, 1, -1)

    def encode_voice(self, audio_path: Union[str, Path]) -> np.ndarray:
        """Encode a reference audio file into voice-conditioning embeddings.

        Returns a rank-3 float32 array (squeezed/expanded to 3 dims as
        needed). Falls back to zeros of shape (1, 1, 1024) with a warning
        when no mimi_encoder session is available.
        """
        if self.mimi_encoder is None:
            print(" Warning: mimi_encoder not available, using zeros")
            return np.zeros((1, 1, 1024), dtype=np.float32)

        audio = self._load_audio(audio_path)
        embeddings = self.mimi_encoder.run(None, {"audio": audio})[0]

        # Coerce whatever rank the export produces to exactly 3 dims.
        while embeddings.ndim > 3:
            embeddings = embeddings.squeeze(0)
        if embeddings.ndim < 3:
            embeddings = embeddings[None]

        return embeddings.astype(np.float32)

    def load_predefined_voice(self, voice_name: str) -> np.ndarray:
        """Load a bundled voice's "audio_prompt" tensor from voices/<name>.safetensors.

        Raises:
            ValueError: If the named voice file does not exist (message lists
                the available voices).
        """
        # Lazy import: safetensors/torch only needed for predefined voices.
        import safetensors.torch

        voices_dir = Path("voices")
        voice_path = voices_dir / f"{voice_name}.safetensors"

        if not voice_path.exists():
            available = [f.stem for f in voices_dir.glob("*.safetensors")]
            raise ValueError(
                f"Voice '{voice_name}' not found. Available: {available}"
            )

        st = safetensors.torch.load_file(str(voice_path))
        tensor = st["audio_prompt"]

        return tensor.numpy().astype(np.float32)

    # Voice names shipped under voices/*.safetensors.
    PREDEFINED_VOICES = ["alba", "marius", "javert", "jean", "fantine", "cosette", "eponine", "azelma"]

    def _run_flow_lm(
        self,
        voice_embeddings: Optional[np.ndarray],
        text_ids: np.ndarray,
        max_frames: int = 500,
        frames_after_eos: int = 3,
    ) -> Generator[np.ndarray, None, None]:
        """Autoregressively generate latent frames, yielding one at a time.

        Prefills the backbone KV cache with optional voice embeddings and the
        text embeddings, then loops: backbone step -> EOS check -> flow-match
        the conditioning into a latent via Euler integration -> yield the
        (1, 1, 32) latent and feed it back as the next input.

        Args:
            voice_embeddings: Optional rank-3 conditioning prefix, or None.
            text_ids: (1, n_tokens) int64 token ids.
            max_frames: Hard cap on generated frames.
            frames_after_eos: Extra frames emitted after EOS fires (tail
                padding for the decoder).

        Yields:
            float32 latents of shape (1, 1, 32).
        """
        text_emb = self.text_conditioner.run(None, {"token_ids": text_ids})[0]
        if text_emb.ndim == 2:
            text_emb = text_emb[None]

        state = self._init_backbone_state()
        # Zero-length sequence: prefill steps feed only embeddings.
        empty_seq = np.zeros((1, 0, 32), dtype=np.float32)

        def run_backbone(sequence, text_emb_arg):
            # Dispatch on the export's state convention (see
            # _init_backbone_state); both pass the full cached state through.
            if self._backbone_format == "named":
                inputs = {
                    "sequence": sequence,
                    "text_embeddings": text_emb_arg,
                    "step": state["step"],
                }
                for k, v in state.items():
                    if k.startswith("past_"):
                        inputs[k] = v
                return self.flow_lm_main.run(None, inputs)
            else:
                return self.flow_lm_main.run(None, {
                    "sequence": sequence,
                    "text_embeddings": text_emb_arg,
                    **state
                })

        # Prefill: voice prompt first (if any), then the text embeddings.
        if voice_embeddings is not None:
            res_voice = run_backbone(empty_seq, voice_embeddings)
            self._update_state_from_outputs(state, res_voice, self.flow_lm_main)

        res_text = run_backbone(empty_seq, text_emb)
        self._update_state_from_outputs(state, res_text, self.flow_lm_main)

        # NaN frame marks "no previous latent" for the first decode step;
        # empty_text (dim 1024) means no new text during generation.
        curr = np.full((1, 1, 32), np.nan, dtype=np.float32)
        empty_text = np.zeros((1, 0, 1024), dtype=np.float32)
        eos_step = None

        for step in range(max_frames):
            res_step = run_backbone(curr, empty_text)

            # res_step[0]: conditioning vector, res_step[1]: EOS logit.
            conditioning = res_step[0]
            conditioning_for_flow = conditioning
            # Reshape conditioning to the rank the flow export expects:
            # kevinahmm wants rank 2 (B, D), standard wants rank 3 (B, 1, D).
            if self._flow_format == "kevinahmm":
                if conditioning.ndim == 3:
                    conditioning_for_flow = conditioning[:, 0, :]
            else:
                if conditioning.ndim == 2:
                    conditioning_for_flow = conditioning[:, None, :]
            eos_logit = res_step[1]
            self._update_state_from_outputs(state, res_step, self.flow_lm_main)

            # Extract a scalar EOS logit whatever rank the export returns.
            if eos_logit.ndim == 3:
                eos_val = float(eos_logit[0, 0, 0])
            elif eos_logit.ndim == 2:
                eos_val = float(eos_logit[0, 0])
            else:
                eos_val = float(eos_logit[0])
            # EOS threshold of -4.0 (logit space); latch the first trigger.
            if eos_val > -4.0 and eos_step is None:
                eos_step = step

            if eos_step is not None and step >= eos_step + frames_after_eos:
                break

            # Initial flow sample: Gaussian noise scaled by sqrt(temperature)
            # (deterministic zeros when temperature == 0). Shape depends on
            # the flow export format.
            std = np.sqrt(self.temperature) if self.temperature > 0 else 0.0
            if std > 0:
                if self._flow_format == "kevinahmm":
                    x = np.random.normal(0, std, (1, 32)).astype(np.float32)
                else:
                    x = np.random.normal(0, std, (1, 1, 32)).astype(np.float32)
            else:
                if self._flow_format == "kevinahmm":
                    x = np.zeros((1, 32), dtype=np.float32)
                else:
                    x = np.zeros((1, 1, 32), dtype=np.float32)

            # Euler integration over the precomputed (s, t) grid:
            # x <- x + v(c, s, t, x) * dt.
            for s_arr, t_arr in self._st_buffers:
                flow_out = self.flow_lm_flow.run(None, {
                    "c": conditioning_for_flow,
                    "s": s_arr,
                    "t": t_arr,
                    "x": x
                })
                x = x + flow_out[0] * (t_arr[0, 0] - s_arr[0, 0])

            latent = x.reshape(1, 1, 32)
            yield latent
            # Feed the generated latent back as the next backbone input.
            curr = latent

    def _decode_latents(self, latents: list) -> np.ndarray:
        """Decode a list of (1, 1, 32) latents to a 1-D waveform (blocking).

        Stateful decoder exports (state_* inputs) are run frame-by-frame,
        threading state between calls; stateless exports decode all frames in
        a single call via the "normalized_latents" input.
        """
        mimi_inputs = [inp.name for inp in self.mimi_decoder.get_inputs()]
        has_states = any(name.startswith("state_") for name in mimi_inputs)

        if has_states:
            state = self._init_mimi_state()
            audio_chunks = []

            for latent in latents:
                inputs = {"latent": latent}
                inputs.update(state)

                result = self.mimi_decoder.run(None, inputs)
                audio_chunks.append(result[0].flatten())

                for i, out in enumerate(self.mimi_decoder.get_outputs()):
                    if out.name.startswith("out_state_"):
                        idx = int(out.name.replace("out_state_", ""))
                        state[f"state_{idx}"] = result[i]

            return np.concatenate(audio_chunks)
        else:
            all_latents = np.concatenate(latents, axis=1)
            result = self.mimi_decoder.run(None, {"normalized_latents": all_latents})
            return result[0].flatten()

    def _decode_worker(self, latent_queue: queue.Queue, audio_chunks: list):
        """Queue-driven decoder loop run on a background thread.

        Consumes latents from latent_queue until a None sentinel, appending
        decoded float32 chunks to audio_chunks (mutated in place). Stateful
        exports decode per frame as items arrive; stateless exports collect
        all latents and decode once after the sentinel.
        """
        mimi_inputs = [inp.name for inp in self.mimi_decoder.get_inputs()]
        has_states = any(name.startswith("state_") for name in mimi_inputs)

        if has_states:
            mimi_state = self._init_mimi_state()

            while True:
                item = latent_queue.get()
                if item is None:
                    break

                inputs = {"latent": item}
                inputs.update(mimi_state)

                result = self.mimi_decoder.run(None, inputs)
                audio_chunks.append(result[0].flatten())

                for i, out in enumerate(self.mimi_decoder.get_outputs()):
                    if out.name.startswith("out_state_"):
                        idx = int(out.name.replace("out_state_", ""))
                        mimi_state[f"state_{idx}"] = result[i]
        else:
            all_latents = []
            while True:
                item = latent_queue.get()
                if item is None:
                    break
                all_latents.append(item)

            if all_latents:
                stacked = np.concatenate(all_latents, axis=1)
                result = self.mimi_decoder.run(None, {"normalized_latents": stacked})
                audio_chunks.append(result[0].flatten())

    def generate(
        self,
        text: str,
        voice: Optional[Union[str, Path, np.ndarray]] = None,
        max_frames: int = 500,
    ) -> np.ndarray:
        """Synthesize speech for `text`, decoding concurrently with generation.

        Args:
            text: Input text (normalized by _tokenize).
            voice: None, a precomputed embedding array, a predefined voice
                name (see PREDEFINED_VOICES), or a path to a reference audio
                file to encode.
            max_frames: Cap on generated latent frames.

        Returns:
            1-D float32 waveform at SAMPLE_RATE.

        NOTE(review): if latent generation raises, the None sentinel is never
        queued and the daemon decoder thread is abandoned; also an empty
        audio_chunks list would make np.concatenate raise — confirm whether
        zero-frame generation is possible in practice.
        """
        voice_emb = None
        if voice is not None:
            if isinstance(voice, np.ndarray):
                voice_emb = voice
            elif isinstance(voice, str):
                if voice in self.PREDEFINED_VOICES:
                    print(f" Using predefined voice: {voice}")
                    voice_emb = self.load_predefined_voice(voice)
                else:
                    voice_emb = self.encode_voice(voice)
            else:
                # Path-like (or other) values are treated as audio files.
                voice_emb = self.encode_voice(voice)

        text_ids = self._tokenize(text)

        # Decode on a background thread so Mimi decoding overlaps with the
        # autoregressive latent generation (streaming pipeline).
        latent_queue = queue.Queue()
        audio_chunks = []
        decoder = threading.Thread(
            target=self._decode_worker,
            args=(latent_queue, audio_chunks),
            daemon=True,
        )
        decoder.start()

        for latent in self._run_flow_lm(voice_emb, text_ids, max_frames):
            latent_queue.put(latent)
        # None sentinel tells the worker to finish.
        latent_queue.put(None)

        decoder.join()
        return np.concatenate(audio_chunks)

    def save_audio(self, audio: np.ndarray, path: Union[str, Path]):
        """Write a waveform to `path` at SAMPLE_RATE via soundfile.

        Raises:
            ImportError: If soundfile is not installed.
        """
        if not HAS_SOUNDFILE:
            raise ImportError("soundfile required. Install with: pip install soundfile")
        sf.write(str(path), audio, self.SAMPLE_RATE)
|
|
|
|
|
def main():
    """CLI entry point: parse arguments, load the engine, synthesize, save."""
    parser = argparse.ArgumentParser(
        description="Streaming ONNX inference for Pocket-TTS (Mixed Precision)"
    )

    # Declarative argument table keeps the CLI surface in one place.
    arg_specs = [
        (["--text"],
         dict(type=str, default="Hello, world!", help="Text to synthesize")),
        (["--output"],
         dict(type=str, default="output_mixed.wav", help="Output WAV file path")),
        (["--models_dir"],
         dict(type=str, default="onnx_kevinahmm_int8", help="Directory containing ONNX models")),
        (["--voice"],
         dict(type=str, default="cosette",
              help="Voice name (alba, marius, javert, jean, fantine, cosette, eponine, azelma) or path to audio file")),
        # store_false with default=True is equivalent to the usual
        # store_false + set_defaults(int8=True) pairing.
        (["--no-int8"],
         dict(dest="int8", action="store_false", default=True,
              help="Disable INT8 quantized models")),
        (["--temperature"],
         dict(type=float, default=0.7, help="Sampling temperature")),
        (["--lsd_steps"],
         dict(type=int, default=10, help="Flow matching steps")),
        (["--max_frames"],
         dict(type=int, default=500, help="Maximum latent frames to generate")),
        (["--seed"],
         dict(type=int, default=None, help="Random seed for reproducibility")),
    ]
    for flags, kwargs in arg_specs:
        parser.add_argument(*flags, **kwargs)

    args = parser.parse_args()

    # Optional reproducibility: seed numpy's global RNG (used by the flow noise).
    if args.seed is not None:
        np.random.seed(args.seed)
        print(f"Random seed: {args.seed}")

    print(f"\nLoading models (INT8={args.int8}, Mixed Precision)...")
    started = time.time()
    tts = PocketTTSStreamingONNX(
        models_dir=args.models_dir,
        use_int8=args.int8,
        temperature=args.temperature,
        lsd_steps=args.lsd_steps,
    )
    load_time = time.time() - started
    print(f" Loaded in {load_time:.2f}s")

    print(f"\nGenerating speech...")
    print(f" Text: {args.text}")
    # Guard against an explicitly empty --voice "" value.
    voice_arg = args.voice if args.voice else "cosette"
    print(f" Voice: {voice_arg}")

    started = time.time()
    audio = tts.generate(args.text, voice=voice_arg, max_frames=args.max_frames)
    gen_time = time.time() - started

    # Real-time factor: wall-clock seconds per second of audio produced.
    duration = len(audio) / tts.SAMPLE_RATE
    rtf = gen_time / max(duration, 0.01)

    print(f" Generated {duration:.2f}s audio in {gen_time:.2f}s (RTF: {rtf:.2f}x)")

    tts.save_audio(audio, args.output)
    print(f" Saved to: {args.output}")
|
|
|