"""
Streaming ONNX inference for Pocket-TTS (Mixed Precision).
This script uses INT8 quantized KevinAHM models for the backbone and flow,
but keeps the Mimi decoder in FP32 to avoid quantization artifacts/errors.
Components:
- Text Conditioner: INT8
- Flow LM (Main): INT8
- Flow LM (Flow): INT8
- Mimi Decoder: FP32 (Required for audio quality/correctness)
Usage:
python final_inference_scripts/inference_mixed_int8.py --text "Hello world"
"""
import sys
import os
import json
import time
import argparse
import queue
import threading
from pathlib import Path
from typing import Optional, Generator, Union
import numpy as np
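# Optional dependencies: soundfile is required for audio I/O, scipy only for
# resampling reference audio to 24 kHz.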
try:
import soundfile as sf
HAS_SOUNDFILE = True
except ImportError:
HAS_SOUNDFILE = False
try:
import scipy.signal
HAS_SCIPY = True
except ImportError:
HAS_SCIPY = False
class PocketTTSStreamingONNX:
"""Streaming ONNX inference engine for Pocket-TTS.
Uses KV cache states for proper streaming inference that works
with both base and merged models.
"""
SAMPLE_RATE = 24000
SAMPLES_PER_FRAME = 1920
FRAME_DURATION = SAMPLES_PER_FRAME / SAMPLE_RATE
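    # 1920 samples at 24 kHz -> 80 ms of audio per latent frame (12.5 frames/s).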
def __init__(
self,
models_dir: str = "onnx_kevinahmm_int8",
tokenizer_path: Optional[str] = None,
use_int8: bool = True,
temperature: float = 0.7,
lsd_steps: int = 10,
):
self.models_dir = Path(models_dir)
self.use_int8 = use_int8
self.temperature = temperature
self.lsd_steps = lsd_steps
import onnxruntime as ort
import sentencepiece as spm
self.providers = ["CPUExecutionProvider"]
available = ort.get_available_providers()
if "CUDAExecutionProvider" in available:
self.providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
sess_opts = ort.SessionOptions()
sess_opts.intra_op_num_threads = min(os.cpu_count() or 4, 4)
sess_opts.inter_op_num_threads = 1
tokenizer_path = tokenizer_path or self.models_dir / "tokenizer.model"
self.tokenizer = spm.SentencePieceProcessor()
self.tokenizer.Load(str(tokenizer_path))
with open(self.models_dir / "model_config.json", "r") as f:
self.config = json.load(f)
self._precompute_flow_buffers()
suffix = "_int8" if use_int8 else ""
def get_path(base):
p = self.models_dir / f"{base}{suffix}.onnx"
if p.exists():
return str(p)
return str(self.models_dir / f"{base}.onnx")
def get_path_multi(bases):
for base in bases:
p = self.models_dir / f"{base}{suffix}.onnx"
if p.exists():
return str(p)
p = self.models_dir / f"{base}.onnx"
if p.exists():
return str(p)
return str(self.models_dir / f"{bases[0]}.onnx")
print(f"Loading models from {self.models_dir}...")
# Text Conditioner (Follows int8 flag)
self.text_conditioner = ort.InferenceSession(
get_path("text_conditioner"), sess_opts, providers=self.providers
)
# Flow LM Main (Follows int8 flag)
self.flow_lm_main = ort.InferenceSession(
get_path_multi(["backbone", "flow_lm_main"]), sess_opts, providers=self.providers
)
# Flow LM Flow (Follows int8 flag)
self.flow_lm_flow = ort.InferenceSession(
get_path("flow_lm_flow"), sess_opts, providers=self.providers
)
# Mimi Decoder - FORCE FP32
# We explicitly look for mimi_decoder.onnx, ignoring the int8 flag and suffix
mimi_dec_path = self.models_dir / "mimi_decoder.onnx"
        if not mimi_dec_path.exists():
            # Fall back to the suffix-aware lookup, but warn: a quantized
            # decoder may reintroduce the artifacts this script is meant to avoid.
            print(f"Warning: FP32 mimi_decoder.onnx not found at {mimi_dec_path}")
            mimi_dec_path = Path(get_path("mimi_decoder"))
print(f" Mimi Decoder: {mimi_dec_path}")
self.mimi_decoder = ort.InferenceSession(
str(mimi_dec_path), sess_opts, providers=self.providers
)
# Mimi Encoder (Optional)
encoder_path = get_path("mimi_encoder")
if os.path.exists(encoder_path):
self.mimi_encoder = ort.InferenceSession(
encoder_path, sess_opts, providers=self.providers
)
else:
self.mimi_encoder = None
print(" Note: mimi_encoder not found, voice cloning unavailable")
flow_inputs = {inp.name: inp.shape for inp in self.flow_lm_flow.get_inputs()}
if "c" in flow_inputs:
c_shape = flow_inputs["c"]
if len(c_shape) == 2:
self._flow_format = "kevinahmm" # 2D: [batch, 1024]
else:
self._flow_format = "standard" # 3D: [batch, time, 1024]
else:
self._flow_format = "standard"
print(f" Flow format: {self._flow_format}")
print(" Models loaded.")
def _precompute_flow_buffers(self):
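        """Precompute (s, t) time pairs for Euler integration of the flow ODE.

        With lsd_steps=10 this yields (0.0, 0.1), (0.1, 0.2), ..., (0.9, 1.0).
        """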
dt = 1.0 / self.lsd_steps
self._st_buffers = []
for j in range(self.lsd_steps):
s = j / self.lsd_steps
t = s + dt
self._st_buffers.append((
np.array([[s]], dtype=np.float32),
np.array([[t]], dtype=np.float32)
))
def _init_backbone_state(self) -> dict:
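        """Allocate the backbone's streaming state, detecting the export format.

        Exports with a "step" input use named past_key_*/past_value_* caches
        plus an absolute position counter; otherwise the export takes indexed
        state_* tensors. Dynamic dimensions are pre-allocated to 500 frames.
        """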
state = {}
inputs = [inp.name for inp in self.flow_lm_main.get_inputs()]
if "step" in inputs:
self._backbone_format = "named"
state["step"] = np.zeros(1, dtype=np.int64)
for inp in self.flow_lm_main.get_inputs():
if inp.name.startswith("past_key_") or inp.name.startswith("past_value_"):
shape = list(inp.shape)
for i, d in enumerate(shape):
if isinstance(d, str) or d is None:
shape[i] = 1 if i == 0 else 500 # Pre-allocated max sequence
state[inp.name] = np.zeros(shape, dtype=np.float32)
else:
self._backbone_format = "state_indexed"
for inp in self.flow_lm_main.get_inputs():
name = inp.name
if name.startswith("state_"):
shape = list(inp.shape)
for i, d in enumerate(shape):
if isinstance(d, str) or d is None:
if i == 0:
shape[i] = 1
elif i == 2:
shape[i] = 500 # Pre-allocated max sequence
dtype = np.float32
if "tensor(int64)" in str(inp.type):
dtype = np.int64
state[name] = np.zeros(shape, dtype=dtype)
return state
def _init_mimi_state(self) -> dict:
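        """Allocate the Mimi decoder's streaming state.

        1-D states are initialized to ones and higher-rank states to zeros,
        with int64/bool dtypes honored where the export declares them.
        """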
state = {}
for inp in self.mimi_decoder.get_inputs():
name = inp.name
if name.startswith("state_"):
shape = list(inp.shape)
for i, d in enumerate(shape):
if isinstance(d, str) or d is None:
shape[i] = 1 if i == 0 else 500 # Pre-allocated max sequence
dtype_str = str(inp.type).lower()
if "int64" in dtype_str:
dtype = np.int64
elif "bool" in dtype_str:
dtype = np.bool_
else:
dtype = np.float32
if len(shape) == 1:
state[name] = np.ones(shape, dtype=dtype)
else:
state[name] = np.zeros(shape, dtype=dtype)
return state
def _update_state_from_outputs(self, state: dict, result: list, session):
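        """Write a session's cache outputs back into the input state dict.

        Named format: present_key_*/present_value_* map back to
        past_key_*/past_value_* and the step counter advances; indexed
        format: out_state_i maps back to state_i.
        """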
if self._backbone_format == "named":
for i, out in enumerate(session.get_outputs()):
name = out.name
if name.startswith("present_key_"):
layer_idx = name.replace("present_key_", "")
state[f"past_key_{layer_idx}"] = result[i]
elif name.startswith("present_value_"):
layer_idx = name.replace("present_value_", "")
state[f"past_value_{layer_idx}"] = result[i]
seq_len = result[0].shape[0] if len(result[0].shape) > 0 else 1
state["step"] = np.array([int(state["step"][0]) + seq_len], dtype=np.int64)
else:
for i, out in enumerate(session.get_outputs()):
name = out.name
if name.startswith("out_state_"):
idx = int(name.replace("out_state_", ""))
state[f"state_{idx}"] = result[i]
def _tokenize(self, text: str) -> np.ndarray:
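        """Normalize text (terminal punctuation, capitalized first letter) and
        encode it with SentencePiece to a [1, T] int64 array."""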
text = text.strip()
if not text:
raise ValueError("Text cannot be empty")
if text[-1].isalnum():
text = text + "."
if not text[0].isupper():
text = text[0].upper() + text[1:]
token_ids = self.tokenizer.Encode(text)
return np.array(token_ids, dtype=np.int64).reshape(1, -1)
def _load_audio(self, path: Union[str, Path]) -> np.ndarray:
if not HAS_SOUNDFILE:
raise ImportError("soundfile required. Install with: pip install soundfile")
audio, sr = sf.read(str(path))
if len(audio.shape) > 1:
audio = audio.mean(axis=1)
if sr != self.SAMPLE_RATE:
if not HAS_SCIPY:
raise ImportError("scipy required for resampling. Install with: pip install scipy")
num_samples = int(len(audio) * self.SAMPLE_RATE / sr)
audio = scipy.signal.resample(audio, num_samples)
audio = audio.astype(np.float32)
if np.abs(audio).max() > 1.0:
audio = audio / np.abs(audio).max()
return audio.reshape(1, 1, -1)
def encode_voice(self, audio_path: Union[str, Path]) -> np.ndarray:
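        """Encode a reference audio file into voice embeddings of shape
        [1, T, 1024], or a (1, 1, 1024) zero placeholder if no encoder
        model is available."""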
if self.mimi_encoder is None:
print(" Warning: mimi_encoder not available, using zeros")
return np.zeros((1, 1, 1024), dtype=np.float32)
audio = self._load_audio(audio_path)
embeddings = self.mimi_encoder.run(None, {"audio": audio})[0]
while embeddings.ndim > 3:
embeddings = embeddings.squeeze(0)
if embeddings.ndim < 3:
embeddings = embeddings[None]
return embeddings.astype(np.float32)
def load_predefined_voice(self, voice_name: str) -> np.ndarray:
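        """Load a precomputed voice embedding from voices/<name>.safetensors."""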
import safetensors.torch
voices_dir = Path("voices")
voice_path = voices_dir / f"{voice_name}.safetensors"
if not voice_path.exists():
available = [f.stem for f in voices_dir.glob("*.safetensors")]
raise ValueError(
f"Voice '{voice_name}' not found. Available: {available}"
)
st = safetensors.torch.load_file(str(voice_path))
tensor = st["audio_prompt"]
return tensor.numpy().astype(np.float32)
PREDEFINED_VOICES = ["alba", "marius", "javert", "jean", "fantine", "cosette", "eponine", "azelma"]
def _run_flow_lm(
self,
voice_embeddings: Optional[np.ndarray],
text_ids: np.ndarray,
max_frames: int = 500,
frames_after_eos: int = 3,
) -> Generator[np.ndarray, None, None]:
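        """Autoregressively generate 32-dim latent frames, one per audio frame.

        The backbone is prefilled with the voice embeddings (if any) and the
        text embeddings; it is then stepped one frame at a time, and the flow
        network turns each step's conditioning vector into a latent. Generation
        stops frames_after_eos frames after the EOS logit first fires.
        """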
text_emb = self.text_conditioner.run(None, {"token_ids": text_ids})[0]
if text_emb.ndim == 2:
text_emb = text_emb[None]
state = self._init_backbone_state()
empty_seq = np.zeros((1, 0, 32), dtype=np.float32)
def run_backbone(sequence, text_emb_arg):
if self._backbone_format == "named":
inputs = {
"sequence": sequence,
"text_embeddings": text_emb_arg,
"step": state["step"],
}
for k, v in state.items():
if k.startswith("past_"):
inputs[k] = v
return self.flow_lm_main.run(None, inputs)
else:
return self.flow_lm_main.run(None, {
"sequence": sequence,
"text_embeddings": text_emb_arg,
**state
})
if voice_embeddings is not None:
res_voice = run_backbone(empty_seq, voice_embeddings)
self._update_state_from_outputs(state, res_voice, self.flow_lm_main)
res_text = run_backbone(empty_seq, text_emb)
self._update_state_from_outputs(state, res_text, self.flow_lm_main)
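        # The first input frame is all-NaN, which this export appears to use as
        # the begin-of-generation marker (there is no previous latent yet).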
curr = np.full((1, 1, 32), np.nan, dtype=np.float32)
empty_text = np.zeros((1, 0, 1024), dtype=np.float32)
eos_step = None
for step in range(max_frames):
res_step = run_backbone(curr, empty_text)
conditioning = res_step[0]
conditioning_for_flow = conditioning
if self._flow_format == "kevinahmm":
if conditioning.ndim == 3:
conditioning_for_flow = conditioning[:, 0, :] # [batch, 1024]
else:
if conditioning.ndim == 2:
conditioning_for_flow = conditioning[:, None, :] # [batch, 1, 1024]
eos_logit = res_step[1]
self._update_state_from_outputs(state, res_step, self.flow_lm_main)
if eos_logit.ndim == 3:
eos_val = float(eos_logit[0, 0, 0])
elif eos_logit.ndim == 2:
eos_val = float(eos_logit[0, 0])
else:
eos_val = float(eos_logit[0])
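            # A logit above -4.0 (sigmoid ~ 0.018) counts as end-of-speech;
            # a few extra frames are generated so the audio tail is not clipped.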
if eos_val > -4.0 and eos_step is None:
eos_step = step
if eos_step is not None and step >= eos_step + frames_after_eos:
break
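            # Initial noise for the flow: x0 ~ N(0, temperature * I), so the
            # standard deviation is sqrt(temperature); temperature 0 makes
            # generation deterministic.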
            shape = (1, 32) if self._flow_format == "kevinahmm" else (1, 1, 32)
            if self.temperature > 0:
                std = np.sqrt(self.temperature)
                x = np.random.normal(0, std, shape).astype(np.float32)
            else:
                x = np.zeros(shape, dtype=np.float32)
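            # Euler integration of the flow ODE: each step queries the flow
            # network for a velocity at time s and advances x by velocity * (t - s).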
for s_arr, t_arr in self._st_buffers:
flow_out = self.flow_lm_flow.run(None, {
"c": conditioning_for_flow,
"s": s_arr,
"t": t_arr,
"x": x
})
x = x + flow_out[0] * (t_arr[0, 0] - s_arr[0, 0])
latent = x.reshape(1, 1, 32)
yield latent
curr = latent
def _decode_latents(self, latents: list) -> np.ndarray:
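        """Decode latents to audio: frame-by-frame when the decoder exposes
        streaming state_* inputs, otherwise in one batched call."""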
mimi_inputs = [inp.name for inp in self.mimi_decoder.get_inputs()]
has_states = any(name.startswith("state_") for name in mimi_inputs)
if has_states:
state = self._init_mimi_state()
audio_chunks = []
for latent in latents:
inputs = {"latent": latent}
inputs.update(state)
result = self.mimi_decoder.run(None, inputs)
audio_chunks.append(result[0].flatten())
for i, out in enumerate(self.mimi_decoder.get_outputs()):
if out.name.startswith("out_state_"):
idx = int(out.name.replace("out_state_", ""))
state[f"state_{idx}"] = result[i]
return np.concatenate(audio_chunks)
else:
all_latents = np.concatenate(latents, axis=1)
result = self.mimi_decoder.run(None, {"normalized_latents": all_latents})
return result[0].flatten()
def _decode_worker(self, latent_queue: queue.Queue, audio_chunks: list):
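        """Decoder-thread body: consume latents from latent_queue (None is the
        stop sentinel) and append decoded audio into audio_chunks in place."""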
mimi_inputs = [inp.name for inp in self.mimi_decoder.get_inputs()]
has_states = any(name.startswith("state_") for name in mimi_inputs)
if has_states:
mimi_state = self._init_mimi_state()
while True:
item = latent_queue.get()
if item is None:
break
inputs = {"latent": item}
inputs.update(mimi_state)
result = self.mimi_decoder.run(None, inputs)
audio_chunks.append(result[0].flatten())
for i, out in enumerate(self.mimi_decoder.get_outputs()):
if out.name.startswith("out_state_"):
idx = int(out.name.replace("out_state_", ""))
mimi_state[f"state_{idx}"] = result[i]
else:
all_latents = []
while True:
item = latent_queue.get()
if item is None:
break
all_latents.append(item)
if all_latents:
stacked = np.concatenate(all_latents, axis=1)
result = self.mimi_decoder.run(None, {"normalized_latents": stacked})
audio_chunks.append(result[0].flatten())
def generate(
self,
text: str,
voice: Optional[Union[str, Path, np.ndarray]] = None,
max_frames: int = 500,
) -> np.ndarray:
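        """Synthesize speech for text, optionally conditioned on a voice.

        Latent generation (this thread) and Mimi decoding (a worker thread)
        run concurrently through a queue, so decoding overlaps generation.
        """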
voice_emb = None
if voice is not None:
if isinstance(voice, np.ndarray):
voice_emb = voice
elif isinstance(voice, str):
if voice in self.PREDEFINED_VOICES:
print(f" Using predefined voice: {voice}")
voice_emb = self.load_predefined_voice(voice)
else:
voice_emb = self.encode_voice(voice)
else:
voice_emb = self.encode_voice(voice)
text_ids = self._tokenize(text)
latent_queue = queue.Queue()
audio_chunks = []
decoder = threading.Thread(
target=self._decode_worker,
args=(latent_queue, audio_chunks),
daemon=True,
)
decoder.start()
for latent in self._run_flow_lm(voice_emb, text_ids, max_frames):
latent_queue.put(latent)
latent_queue.put(None)
decoder.join()
return np.concatenate(audio_chunks)
def save_audio(self, audio: np.ndarray, path: Union[str, Path]):
if not HAS_SOUNDFILE:
raise ImportError("soundfile required. Install with: pip install soundfile")
sf.write(str(path), audio, self.SAMPLE_RATE)
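# A minimal programmatic-usage sketch (assumes the default onnx_kevinahmm_int8/
# model layout and the bundled voices/ directory used by main() below):
#
#     tts = PocketTTSStreamingONNX(models_dir="onnx_kevinahmm_int8")
#     audio = tts.generate("Hello, world!", voice="cosette")
#     tts.save_audio(audio, "output.wav")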
def main():
parser = argparse.ArgumentParser(
description="Streaming ONNX inference for Pocket-TTS (Mixed Precision)"
)
parser.add_argument("--text", type=str, default="Hello, world!",
help="Text to synthesize")
parser.add_argument("--output", type=str, default="output_mixed.wav",
help="Output WAV file path")
parser.add_argument("--models_dir", type=str, default="onnx_kevinahmm_int8",
help="Directory containing ONNX models")
parser.add_argument("--voice", type=str, default="cosette",
help="Voice name (alba, marius, javert, jean, fantine, cosette, eponine, azelma) or path to audio file")
    # INT8 is enabled by default in this script.
parser.add_argument("--no-int8", dest="int8", action="store_false",
help="Disable INT8 quantized models")
parser.set_defaults(int8=True)
parser.add_argument("--temperature", type=float, default=0.7,
help="Sampling temperature")
parser.add_argument("--lsd_steps", type=int, default=10,
help="Flow matching steps")
parser.add_argument("--max_frames", type=int, default=500,
help="Maximum latent frames to generate")
parser.add_argument("--seed", type=int, default=None,
help="Random seed for reproducibility")
args = parser.parse_args()
if args.seed is not None:
np.random.seed(args.seed)
print(f"Random seed: {args.seed}")
print(f"\nLoading models (INT8={args.int8}, Mixed Precision)...")
t0 = time.time()
tts = PocketTTSStreamingONNX(
models_dir=args.models_dir,
use_int8=args.int8,
temperature=args.temperature,
lsd_steps=args.lsd_steps,
)
load_time = time.time() - t0
print(f" Loaded in {load_time:.2f}s")
print(f"\nGenerating speech...")
print(f" Text: {args.text}")
voice_arg = args.voice if args.voice else "cosette"
print(f" Voice: {voice_arg}")
t0 = time.time()
audio = tts.generate(args.text, voice=voice_arg, max_frames=args.max_frames)
gen_time = time.time() - t0
duration = len(audio) / tts.SAMPLE_RATE
rtf = gen_time / max(duration, 0.01)
print(f" Generated {duration:.2f}s audio in {gen_time:.2f}s (RTF: {rtf:.2f}x)")
tts.save_audio(audio, args.output)
print(f" Saved to: {args.output}")
if __name__ == "__main__":
main()