MOSS-TTS-Nano.AXERA / scripts /tts_runtime.py
HY-2012's picture
First commit
b3a7ca2 verified
Raw
History Blame Contribute Delete
61.2 kB
from __future__ import annotations
import json
import logging
import time
import unicodedata
import wave
from collections import defaultdict
from pathlib import Path
from typing import Any, Callable, Optional, Sequence
import numpy as np
import sentencepiece as spm
from scripts.axe_session import AxeSession
try:
import onnxruntime as ort
_HAVE_ORT = True
except ImportError:
_HAVE_ORT = False
# Sampling strategies accepted in generation_defaults["sample_mode"].
SAMPLE_MODE_GREEDY = "greedy"
SAMPLE_MODE_FIXED = "fixed"
SAMPLE_MODE_FULL = "full"

# Punctuation tables used by the text chunking helpers. Every non-ASCII
# character is written as an escape so the fullwidth CJK forms cannot be
# silently folded into their ASCII look-alikes by text normalization.
# Sentence terminators: . ! ? ideographic full stop, fullwidth ! and ?
SENTENCE_END_PUNCTUATION = set(".!?\u3002\uff01\uff1f")
# Clause separators: , ; : ideographic comma, fullwidth , ; :
CLAUSE_SPLIT_PUNCTUATION = set(",;:\u3001\uff0c\uff1b\uff1a")
# Closing quotes/brackets that stay attached to the preceding sentence.
CLOSING_PUNCTUATION = set(
    [
        '"', "'", "\u2019", "\u201d",
        ")", "]", "}",
        "\uff09",  # fullwidth right parenthesis
        "\u3011",  # right black lenticular bracket
        "\u300b",  # right double angle bracket
        "\u300d",  # right corner bracket
        "\u300f",  # right white corner bracket
    ]
)

# File names probed (in order) inside the config directory for the manifest.
_MANIFEST_CANDIDATES = ("browser_poc_manifest.json",)
_FULL8_IO_MAP: dict[str, tuple[list[str], list[str]]] = {
"prefill": (
["input_ids", "attention_mask"],
["global_hidden",
"present_key_0", "present_value_0", "present_key_1", "present_value_1",
"present_key_2", "present_value_2", "present_key_3", "present_value_3",
"present_key_4", "present_value_4", "present_key_5", "present_value_5",
"present_key_6", "present_value_6", "present_key_7", "present_value_7",
"present_key_8", "present_value_8", "present_key_9", "present_value_9",
"present_key_10", "present_value_10", "present_key_11", "present_value_11"],
),
"decode": (
["input_ids", "past_valid_lengths",
"past_key_0", "past_value_0", "past_key_1", "past_value_1",
"past_key_2", "past_value_2", "past_key_3", "past_value_3",
"past_key_4", "past_value_4", "past_key_5", "past_value_5",
"past_key_6", "past_value_6", "past_key_7", "past_value_7",
"past_key_8", "past_value_8", "past_key_9", "past_value_9",
"past_key_10", "past_value_10", "past_key_11", "past_value_11"],
["global_hidden",
"present_key_0", "present_value_0", "present_key_1", "present_value_1",
"present_key_2", "present_value_2", "present_key_3", "present_value_3",
"present_key_4", "present_value_4", "present_key_5", "present_value_5",
"present_key_6", "present_value_6", "present_key_7", "present_value_7",
"present_key_8", "present_value_8", "present_key_9", "present_value_9",
"present_key_10", "present_value_10", "present_key_11", "present_value_11"],
),
"local_decoder": (
["global_hidden", "text_token_id", "audio_prefix_token_ids"],
["text_logits", "audio_logits"],
),
"local_fixed_sampled_frame": (
["global_hidden", "repetition_seen_mask", "assistant_random_u", "audio_random_u"],
["should_continue", "frame_token_ids"],
),
"codec_encode": (
["waveform", "input_lengths"],
["audio_codes", "audio_code_lengths"],
),
"codec_decode": (
["audio_codes", "audio_code_lengths"],
["audio", "audio_lengths"],
),
}
_AXMODEL_FILE_MAP: dict[str, str] = {
"prefill": "tts_prefill.axmodel",
"decode": "tts_decode_step.axmodel",
"local_decoder": "tts_local_decoder.axmodel",
"local_fixed_sampled_frame": "tts_local_fixed_sampled_frame.axmodel",
"codec_encode": "codec_encode.axmodel",
"codec_decode": "codec_decode.axmodel",
}
_ONNX_FILE_MAP: dict[str, str] = {
"decode": "tts_decode_step.onnx",
}
def _argmax(values: np.ndarray) -> int:
return int(np.argmax(values))
def _softmax(values: np.ndarray) -> np.ndarray:
max_v = float(np.max(values))
shifted = np.asarray(values - max_v, dtype=np.float64)
exps = np.exp(shifted)
return exps / np.sum(exps)
def _apply_repetition_penalty(
values: np.ndarray, previous_token_ids: list[int], penalty: float
) -> np.ndarray:
if not previous_token_ids or penalty == 1.0:
return values
result = values.copy()
for tid in set(int(t) for t in previous_token_ids):
if 0 <= tid < result.shape[0]:
result[tid] = result[tid] * penalty if result[tid] < 0 else result[tid] / penalty
return result
def _argmax_with_repetition_penalty(
values: np.ndarray, previous_token_set: set[int], penalty: float
) -> int:
best_idx, best_val = 0, float("-inf")
apply_penalty = bool(previous_token_set) and penalty != 1.0
for idx, v in enumerate(values):
score = float(v)
if apply_penalty and idx in previous_token_set:
score = score * penalty if score < 0 else score / penalty
if score > best_val:
best_val = score
best_idx = idx
return int(best_idx)
def _sample_from_scores(
    values: np.ndarray,
    *,
    do_sample: bool,
    temperature: float,
    top_k: int,
    top_p: float,
    rng: np.random.Generator,
) -> int:
    """Pick one index from a 1-D score vector.

    Args:
        values: Unnormalised scores (logits).
        do_sample: When False the argmax of ``values`` is returned and every
            other argument is ignored.
        temperature: Divisor applied to the scores before filtering.
        top_k: Keep only scores at least as large as the k-th largest one;
            applied only when ``0 < top_k < len(values)``.
        top_p: Nucleus threshold; applied only when ``0 < top_p < 1``.
        rng: Source of the single uniform draw used for sampling.

    Returns:
        The selected index as a Python int.

    Raises:
        ValueError: If ``do_sample`` is True and ``temperature <= 0``.
    """
    if not do_sample:
        return _argmax(values)
    if temperature <= 0:
        raise ValueError("temperature must be positive when do_sample=True")
    # Work on a float32 copy so the caller's array is never modified.
    scores = np.asarray(values, dtype=np.float32).copy() / float(temperature)
    if top_k > 0 and top_k < scores.shape[0]:
        # Scores tied with the k-th largest value are kept as well.
        threshold = float(np.sort(scores)[::-1][top_k - 1])
        scores[scores < threshold] = float("-inf")
    if 0 < top_p < 1:
        # (original index, score) pairs ordered from best to worst.
        indexed = sorted(enumerate(scores.tolist()), key=lambda x: x[1], reverse=True)
        sorted_scores = np.asarray([x[1] for x in indexed], dtype=np.float32)
        sorted_probs = _softmax(sorted_scores)
        remove = [False] * len(indexed)
        cumulative = 0.0
        for i, p in enumerate(sorted_probs):
            cumulative += float(p)
            if cumulative > float(top_p):
                remove[i] = True
        # Shift the marks one place to the right so the entry that pushes the
        # cumulative mass past top_p is itself kept.
        for i in range(len(remove) - 1, 0, -1):
            remove[i] = remove[i - 1]
        # The best-scoring entry is never removed.
        if remove:
            remove[0] = False
        for i, should_remove in enumerate(remove):
            if should_remove:
                scores[indexed[i][0]] = float("-inf")
    probs = _softmax(scores)
    # Inverse-CDF sampling: walk the probabilities until the draw is used up.
    rv = float(rng.random())
    for i, p in enumerate(probs):
        rv -= float(p)
        if rv <= 0:
            return int(i)
    # Reached only if the loop never consumed the draw; fall back to the argmax.
    return _argmax(scores)
def _sample_assistant_text_token(
    text_logits: np.ndarray,
    manifest: dict[str, Any],
    generation_defaults: dict[str, Any],
    rng: np.random.Generator,
) -> int:
    """Choose between the assistant audio-slot token and the audio-end token.

    Only those two ids (read from ``manifest["tts_config"]``) are considered.
    The choice is deterministic: sampling is disabled, so the higher-scoring
    candidate wins. The text sampling settings are still read from
    ``generation_defaults`` and passed through.
    """
    tts_config = manifest["tts_config"]
    slot_id = int(tts_config["audio_assistant_slot_token_id"])
    end_id = int(tts_config["audio_end_token_id"])
    candidate_ids = np.asarray([slot_id, end_id], dtype=np.int32)
    candidate_scores = text_logits[candidate_ids]

    n_candidates = int(candidate_scores.shape[0])
    temperature = float(generation_defaults["text_temperature"])
    top_k = min(int(generation_defaults["text_top_k"]), n_candidates)
    top_p = float(generation_defaults["text_top_p"])

    picked = _sample_from_scores(
        candidate_scores,
        do_sample=False,  # deterministic: always pick the higher-scoring token
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        rng=rng,
    )
    return int(candidate_ids[picked])
def _sample_audio_token(
    audio_logits: np.ndarray,
    previous_token_ids: list[int],
    previous_token_set: set[int],
    generation_defaults: dict[str, Any],
    rng: np.random.Generator,
) -> int:
    """Pick one audio token id for a single channel.

    With ``generation_defaults["do_sample"]`` false the penalised argmax is
    returned. Otherwise the repetition penalty is applied to the logits and a
    token is sampled with the audio temperature / top-k / top-p settings.
    """
    penalty = float(generation_defaults["audio_repetition_penalty"])
    sampling_enabled = bool(generation_defaults["do_sample"])
    if sampling_enabled:
        adjusted = _apply_repetition_penalty(audio_logits, previous_token_ids, penalty)
        return _sample_from_scores(
            adjusted,
            do_sample=True,
            temperature=float(generation_defaults["audio_temperature"]),
            top_k=int(generation_defaults["audio_top_k"]),
            top_p=float(generation_defaults["audio_top_p"]),
            rng=rng,
        )
    return _argmax_with_repetition_penalty(audio_logits, previous_token_set, penalty)
def _flatten3d_int32(nested: list[list[list[int]]]) -> tuple[np.ndarray, list[int]]:
d0, d1, d2 = len(nested), len(nested[0]), len(nested[0][0])
data = np.zeros((d0 * d1 * d2,), dtype=np.int32)
offset = 0
for i in range(d0):
for j in range(d1):
for k in range(d2):
data[offset] = int(nested[i][j][k])
offset += 1
return data, [d0, d1, d2]
def _flatten2d_int32(nested: list[list[int]]) -> tuple[np.ndarray, list[int]]:
d0, d1 = len(nested), len(nested[0])
data = np.zeros((d0 * d1,), dtype=np.int32)
offset = 0
for i in range(d0):
for j in range(d1):
data[offset] = int(nested[i][j])
offset += 1
return data, [d0, d1]
def _extract_last_hidden(hidden: np.ndarray) -> np.ndarray:
if hidden.ndim == 2:
return hidden.astype(np.float32, copy=False)
if hidden.ndim != 3 or hidden.shape[0] != 1:
raise ValueError(f"Unexpected global_hidden shape: {hidden.shape}")
return hidden[:, -1, :].astype(np.float32, copy=False)
def _slice_channel_major_audio(
audio: np.ndarray, start_sample: int = 0, end_sample: int | None = None
) -> list[np.ndarray]:
if audio.ndim != 3 or audio.shape[0] != 1:
raise ValueError(f"Unexpected audio tensor shape: {audio.shape}")
channels = int(audio.shape[1])
total = int(audio.shape[2])
start = max(0, int(start_sample))
end = total if end_sample is None else max(start, min(int(end_sample), total))
return [audio[0, c, start:end].astype(np.float32, copy=False) for c in range(channels)]
def _merge_audio_channels(channel_arrays: list[np.ndarray]) -> np.ndarray:
if not channel_arrays:
return np.zeros((0, 1), dtype=np.float32)
if len(channel_arrays) == 1:
return np.asarray(channel_arrays[0], dtype=np.float32).reshape(-1, 1)
min_len = min(int(c.shape[0]) for c in channel_arrays)
return np.stack([np.asarray(c[:min_len], dtype=np.float32) for c in channel_arrays], axis=1)
def _concat_waveforms(waveforms: list[np.ndarray]) -> np.ndarray:
if not waveforms:
return np.zeros((0, 1), dtype=np.float32)
non_empty = [w for w in waveforms if w.size > 0]
if not non_empty:
n_ch = int(waveforms[0].shape[1]) if waveforms[0].ndim == 2 else 1
return np.zeros((0, n_ch), dtype=np.float32)
return np.concatenate(non_empty, axis=0)
def _write_waveform_to_wav(path: str | Path, waveform: np.ndarray, sample_rate: int) -> Path:
output_path = Path(path).expanduser().resolve()
output_path.parent.mkdir(parents=True, exist_ok=True)
audio = np.asarray(waveform, dtype=np.float32)
if audio.ndim == 1:
audio = audio.reshape(-1, 1)
clipped = np.clip(audio, -1.0, 1.0)
pcm16 = np.round(clipped * 32767.0).astype(np.int16)
with wave.open(str(output_path), "wb") as wf:
wf.setnchannels(int(pcm16.shape[1]))
wf.setsampwidth(2)
wf.setframerate(int(sample_rate))
wf.writeframes(pcm16.tobytes())
return output_path
def _load_audio_numpy(path: str | Path, target_sample_rate: int, target_channels: int) -> np.ndarray:
path = Path(path).expanduser().resolve()
waveform: np.ndarray | None = None
src_sr: int = 0
try:
import soundfile as sf
data, src_sr = sf.read(str(path), dtype="float32", always_2d=True) # (samples, channels)
waveform = data.T # (channels, samples)
except Exception:
pass
if waveform is None:
with wave.open(str(path), "rb") as wf:
src_sr = wf.getframerate()
n_ch = wf.getnchannels()
sw = wf.getsampwidth()
raw = wf.readframes(wf.getnframes())
if sw == 2:
pcm = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0
elif sw == 4:
pcm = np.frombuffer(raw, dtype=np.int32).astype(np.float32) / 2147483648.0
elif sw == 1:
pcm = (np.frombuffer(raw, dtype=np.uint8).astype(np.float32) - 128.0) / 128.0
else:
raise ValueError(f"Unsupported sample width: {sw}")
waveform = pcm.reshape(-1, n_ch).T # (channels, samples)
assert waveform is not None
src_channels = int(waveform.shape[0])
if src_channels != target_channels:
if src_channels == 1 and target_channels > 1:
waveform = np.repeat(waveform, target_channels, axis=0)
elif src_channels > 1 and target_channels == 1:
waveform = waveform.mean(axis=0, keepdims=True)
else:
raise ValueError(f"Unsupported channel conversion: {src_channels}{target_channels}")
if src_sr != target_sample_rate:
try:
from scipy.signal import resample_poly
from math import gcd
g = gcd(int(target_sample_rate), int(src_sr))
up = int(target_sample_rate) // g
down = int(src_sr) // g
resampled = []
for ch in range(waveform.shape[0]):
resampled.append(resample_poly(waveform[ch], up, down).astype(np.float32))
waveform = np.stack(resampled, axis=0)
except ImportError:
logging.warning("scipy not available, using linear interpolation for audio resampling")
old_len = int(waveform.shape[1])
new_len = int(round(old_len * target_sample_rate / src_sr))
old_idx = np.linspace(0, old_len - 1, new_len)
resampled = []
for ch in range(waveform.shape[0]):
resampled.append(np.interp(old_idx, np.arange(old_len), waveform[ch]).astype(np.float32))
waveform = np.stack(resampled, axis=0)
return waveform[np.newaxis, :, :] # (1, channels, samples)
def _contains_cjk(text: str) -> bool:
for ch in str(text or ""):
if (
"\u4e00" <= ch <= "\u9fff"
or "\u3400" <= ch <= "\u4dbf"
or "\u3040" <= ch <= "\u30ff"
or "\uac00" <= ch <= "\ud7af"
):
return True
return False
def _normalize_punctuation(text: str) -> str:
    """Collapse each run of punctuation into a single mark.

    Sentence-ending characters are kept as they are; every other punctuation
    character (Unicode category ``P*``) becomes ``","``. Inside a run only
    one mark survives, and a sentence-ending character replaces whatever mark
    the run has produced so far. Apostrophes and hyphen/en dash are treated
    as ordinary characters.
    """
    word_internal = frozenset(["'", "\u2019", "-", "\u2013"])  # don't touch these
    out: list[str] = []
    in_punct_run = False
    for ch in text:
        ends_sentence = ch in SENTENCE_END_PUNCTUATION
        punct = ends_sentence or (
            ch not in word_internal and unicodedata.category(ch).startswith("P")
        )
        if not punct:
            out.append(ch)
            in_punct_run = False
            continue
        if not in_punct_run:
            # First mark of a run.
            out.append(ch if ends_sentence else ",")
        elif ends_sentence:
            # A sentence end overrides the mark already emitted for this run.
            out[-1] = ch
        in_punct_run = True
    return "".join(out)
def _prepare_text_for_sentence_chunking(text: str) -> str:
    """Clean up a prompt before it is split into sentences.

    Steps: strip, turn line breaks into spaces, collapse runs of spaces,
    normalise punctuation, and make sure the text ends with a sentence mark.
    Non-CJK text is additionally capitalised and, when shorter than five
    words, prefixed with a space.

    Raises:
        ValueError: If the text is empty after stripping.
    """
    t = str(text or "").strip()
    if not t:
        raise ValueError("Text prompt cannot be empty.")
    t = t.replace("\r", " ").replace("\n", " ")
    # Collapse every run of spaces into one. Splitting and re-joining always
    # terminates, unlike a replace loop whose search and replacement strings
    # are the same single space.
    t = " ".join(piece for piece in t.split(" ") if piece)
    t = _normalize_punctuation(t)
    while ", ," in t:
        t = t.replace(", ,", ",")
    if _contains_cjk(t):
        if t[-1] not in SENTENCE_END_PUNCTUATION:
            t += "\u3002"  # ideographic full stop
        return t
    if t[:1].islower():
        t = t[:1].upper() + t[1:]
    if t[-1].isalnum():
        t += "."
    if len(t.split()) < 5:
        t = f" {t}"
    return t
def _split_text_by_punctuation(text: str, punctuation: set[str]) -> list[str]:
    """Split ``text`` after each character contained in ``punctuation``.

    Closing quotes/brackets that directly follow a split character stay with
    the preceding piece, and whitespace after the split is dropped. Pieces
    are stripped, and empty pieces are omitted. Any trailing text without a
    final split character becomes the last piece.
    """
    source = str(text or "")
    total = len(source)
    pieces: list[str] = []
    buffer: list[str] = []
    pos = 0
    while pos < total:
        ch = source[pos]
        buffer.append(ch)
        pos += 1
        if ch not in punctuation:
            continue
        # Absorb closing punctuation into the current piece.
        while pos < total and source[pos] in CLOSING_PUNCTUATION:
            buffer.append(source[pos])
            pos += 1
        finished = "".join(buffer).strip()
        if finished:
            pieces.append(finished)
        buffer.clear()
        # Skip the whitespace separating this piece from the next.
        while pos < total and source[pos].isspace():
            pos += 1
    leftover = "".join(buffer).strip()
    if leftover:
        pieces.append(leftover)
    return pieces
def _join_sentence_parts(left: str, right: str) -> str:
    """Join two text fragments.

    An empty side yields the other side unchanged. Fragments are joined
    without a separator when either contains CJK characters, and with a
    single space otherwise.
    """
    if not left:
        return right
    if not right:
        return left
    separator = "" if (_contains_cjk(left) or _contains_cjk(right)) else " "
    return f"{left}{separator}{right}"
class AxTtsRuntime:
def __init__(
    self,
    config_dir: str | Path,
    axmodel_dir: str | Path,
    *,
    onnx_dir: str | Path | None = None,
    use_onnx_decode: bool = False,
    max_new_frames: int | None = None,
    do_sample: bool = True,
    sample_mode: str | None = None,
) -> None:
    """Load the manifest, metadata, tokenizer and model sessions.

    Args:
        config_dir: Directory containing the manifest JSON.
        axmodel_dir: Directory containing the ``*.axmodel`` files.
        onnx_dir: Directory with ONNX models; defaults to ``onnxmodels``
            next to ``axmodel_dir``.
        use_onnx_decode: Run the decode step through onnxruntime instead of
            the axmodel.
        max_new_frames: Overrides ``generation_defaults["max_new_frames"]``
            when given.
        do_sample: Overrides ``generation_defaults["do_sample"]`` unless None.
        sample_mode: Overrides ``generation_defaults["sample_mode"]`` when
            given.

    Raises:
        FileNotFoundError: If the manifest or a required model is missing.
    """
    self.config_dir = Path(config_dir).expanduser().resolve()
    self.axmodel_dir = Path(axmodel_dir).expanduser().resolve()
    self.use_onnx_decode = bool(use_onnx_decode)
    if onnx_dir is not None:
        self.onnx_dir = Path(onnx_dir).expanduser().resolve()
    else:
        self.onnx_dir = self.axmodel_dir.parent / "onnxmodels"

    self.manifest_path = self._find_manifest()
    self.manifest_dir = self.manifest_path.parent
    self.manifest: dict[str, Any] = json.loads(
        self.manifest_path.read_text(encoding="utf-8")
    )

    # Constructor arguments override the manifest's generation defaults.
    gen = self.manifest["generation_defaults"]
    if max_new_frames is not None:
        gen["max_new_frames"] = int(max_new_frames)
    if do_sample is not None:
        gen["do_sample"] = bool(do_sample)
    raw_mode = sample_mode if sample_mode is not None else gen.get("sample_mode", "fixed")
    gen["sample_mode"] = self._normalize_sample_mode(raw_mode, bool(gen["do_sample"]))
    # Keep do_sample consistent with the resolved mode.
    gen["do_sample"] = gen["sample_mode"] != SAMPLE_MODE_GREEDY

    self.tts_meta: dict[str, Any] = json.loads(
        self._resolve_path(self.manifest["model_files"]["tts_meta"]).read_text("utf-8")
    )
    self.codec_meta: dict[str, Any] = json.loads(
        self._resolve_path(self.manifest["model_files"]["codec_meta"]).read_text("utf-8")
    )
    self.tts_static_shapes = dict(self.tts_meta.get("static_shapes", {}))
    self.codec_static_shapes = dict(self.codec_meta.get("static_shapes", {}))

    tok_rel = str(self.manifest["model_files"].get("tokenizer_model", "tokenizer.model"))
    self.sp_model = spm.SentencePieceProcessor(model_file=str(self._resolve_path(tok_rel)))
    self.rng = np.random.default_rng(1234)  # fixed seed

    self._model_display_names: dict[str, str] = {}
    self.sessions: dict = self._create_sessions()
    # Initialise the timing counters AND the used-model tracking state, so
    # _run_session works even if _reset_timing_stats() is never called first.
    self._reset_timing_stats()
    logging.info("AxTtsRuntime ready: %d sessions loaded", len(self.sessions))
def _find_manifest(self) -> Path:
    """Return the first manifest candidate that exists in ``config_dir``.

    Raises:
        FileNotFoundError: If none of the candidate files exists.
    """
    candidates = ((self.config_dir / name).resolve() for name in _MANIFEST_CANDIDATES)
    found = next((candidate for candidate in candidates if candidate.is_file()), None)
    if found is None:
        raise FileNotFoundError(
            f"browser_poc_manifest.json not found in {self.config_dir}"
        )
    return found
def _resolve_path(self, rel: str | Path) -> Path:
p = (self.manifest_dir / Path(rel)).resolve()
if p.exists():
return p
fallback = (self.config_dir / Path(rel).name).resolve()
if fallback.exists():
return fallback
return p
@staticmethod
def _normalize_sample_mode(raw: str | None, do_sample: bool = True) -> str:
    """Map a raw sample-mode string onto one of the ``SAMPLE_MODE_*`` constants.

    A recognised mode is returned as is. Anything else, including the
    ``"mixed3"`` value and empty/None, becomes the fixed mode when
    ``do_sample`` is true and the greedy mode otherwise.
    """
    mode = str(raw or "").strip()
    recognised = (SAMPLE_MODE_GREEDY, SAMPLE_MODE_FIXED, SAMPLE_MODE_FULL)
    if mode in recognised:
        return mode
    # "mixed3" and unknown values share the same fallback.
    if do_sample:
        return SAMPLE_MODE_FIXED
    return SAMPLE_MODE_GREEDY
def _resolve_axmodel_path(self, key: str, axmodel_name: str) -> Path:
flat = self.axmodel_dir / axmodel_name
if key != "local_fixed_sampled_frame":
return flat
if flat.exists():
return flat
nested = self.axmodel_dir / "build-tts_local_fixed_sampled_frame" / axmodel_name
return nested
def _create_sessions(self) -> dict:
    """Create every inference session and return them keyed by session name.

    The decode session is mandatory: it is loaded from ONNX (onnxruntime,
    CPU provider) when ``use_onnx_decode`` is set, otherwise from the decode
    axmodel. All remaining sessions are loaded from axmodels; a missing file
    only produces a warning and the session is left out of the result.

    Also fills ``self._model_display_names`` and ``self._session_input_names``.

    Raises:
        FileNotFoundError: If the selected decode model file is missing.
        RuntimeError: If ONNX decode is requested but onnxruntime is not
            installed.
    """
    sessions: dict[str, Any] = {}
    self._session_input_names: dict[str, list[str]] = {}
    if self.use_onnx_decode:
        onnx_name = _ONNX_FILE_MAP["decode"]
        onnx_path = self.onnx_dir / onnx_name
        if not onnx_path.exists():
            raise FileNotFoundError(f"ONNX not found: {onnx_path}")
        if not _HAVE_ORT:
            raise RuntimeError("onnxruntime not installed; cannot load ONNX decode_step.")
        sess = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
        sessions["decode"] = sess
        self._model_display_names["decode"] = f"{onnx_path.name} [onnxruntime CPU]"
        self._session_input_names["decode"] = [inp.name for inp in sess.get_inputs()]
        logging.info("loaded decode → %s (onnxruntime CPU)", onnx_path.name)
    else:
        axmodel_name = _AXMODEL_FILE_MAP["decode"]
        axmodel_path = self.axmodel_dir / axmodel_name
        if not axmodel_path.exists():
            raise FileNotFoundError(f"axmodel not found: {axmodel_path}")
        input_names, output_names = _FULL8_IO_MAP["decode"]
        sess = AxeSession(axmodel_path, input_names, output_names)
        sessions["decode"] = sess
        self._model_display_names["decode"] = f"{axmodel_path.name} [axmodel]"
        self._session_input_names["decode"] = [inp.name for inp in sess.get_inputs()]
        logging.info("loaded decode → %s (axmodel)", axmodel_path.name)
    # Remaining sessions: a missing axmodel is skipped with a warning here,
    # unlike the decode model above.
    for key in ("prefill", "local_decoder", "local_fixed_sampled_frame", "codec_decode", "codec_encode"):
        if key not in _AXMODEL_FILE_MAP:
            continue
        axmodel_name = _AXMODEL_FILE_MAP[key]
        axmodel_path = self._resolve_axmodel_path(key, axmodel_name)
        if not axmodel_path.exists():
            logging.warning("axmodel not found, skipping session %s: %s", key, axmodel_path)
            continue
        input_names, output_names = _FULL8_IO_MAP[key]
        sess = AxeSession(axmodel_path, input_names, output_names)
        sessions[key] = sess
        self._model_display_names[key] = f"{axmodel_path.name} [axmodel]"
        self._session_input_names[key] = [inp.name for inp in sess.get_inputs()]
        logging.info("loaded %s → %s (axmodel)", key, axmodel_path.name)
    if "codec_encode" not in sessions:
        logging.info("codec_encode axmodel not found; voice clone (reference audio) is disabled.")
    return sessions
def _reset_timing_stats(self) -> None:
self._model_time_stats = defaultdict(float)
self._model_call_stats = defaultdict(int)
self._used_model_keys: list[str] = []
self._used_model_key_set: set[str] = set()
def _mark_model_used(self, key: str) -> None:
if key in self._used_model_key_set:
return
self._used_model_key_set.add(key)
self._used_model_keys.append(key)
logging.info(
"[usage] first_use session=%s model=%s",
key,
self._model_display_names.get(key, key),
)
def _log_used_model_summary(self) -> None:
if not getattr(self, "_used_model_keys", None):
logging.info("[usage] used_models=none")
return
logging.info("[usage] used_models_count=%d", len(self._used_model_keys))
for idx, key in enumerate(self._used_model_keys, start=1):
logging.info(
"[usage] used_model[%d]=session=%s model=%s calls=%d",
idx,
key,
self._model_display_names.get(key, key),
self._model_call_stats.get(key, 0),
)
def _run_session(
self,
key: str,
input_feed: dict[str, np.ndarray],
output_names: list[str] | None = None,
) -> list[np.ndarray]:
sess = self.sessions[key]
actual_input_names = [inp.name for inp in sess.get_inputs()]
actual_input_name_set = set(actual_input_names)
filtered_input_feed = {
name: value for name, value in input_feed.items() if name in actual_input_name_set
}
missing_inputs = [name for name in actual_input_names if name not in filtered_input_feed]
if missing_inputs:
raise RuntimeError(
f"session={key} missing required inputs: {missing_inputs}; "
f"provided={sorted(input_feed.keys())}"
)
extra_inputs = sorted(name for name in input_feed if name not in actual_input_name_set)
self._mark_model_used(key)
start = time.perf_counter()
try:
outputs = sess.run(output_names, filtered_input_feed)
except Exception as exc:
input_summary = {
name: {"shape": list(np.asarray(value).shape), "dtype": str(np.asarray(value).dtype)}
for name, value in filtered_input_feed.items()
}
raise RuntimeError(
f"session={key} model={self._model_display_names.get(key, key)} run failed; "
f"used_models_so_far={self._used_model_keys}; "
f"expected_inputs={actual_input_names}; extra_inputs={extra_inputs}; "
f"input_summary={input_summary}"
) from exc
elapsed = time.perf_counter() - start
self._model_time_stats[key] += elapsed
self._model_call_stats[key] += 1
return outputs
def encode_text(self, text: str) -> list[int]:
return [int(t) for t in self.sp_model.encode(str(text or ""), out_type=int)]
def count_text_tokens(self, text: str) -> int:
return len(self.encode_text(text))
def split_text_by_token_budget(self, text: str, max_tokens: int) -> list[str]:
    """Split ``text`` into pieces that each fit into ``max_tokens`` tokens.

    For every piece the longest prefix within the budget is located by
    binary search on the character count; the cut is then moved back to the
    closest clause/sentence punctuation mark or space near the end of that
    prefix, if there is one. At least one character is consumed per piece,
    so a piece can exceed the budget only when even a single character does.

    Returns:
        The stripped pieces in order; an empty list for empty input.
    """
    remaining = str(text or "").strip()
    if not remaining:
        return []
    pieces: list[str] = []
    preferred_boundary = set(CLAUSE_SPLIT_PUNCTUATION) | set(SENTENCE_END_PUNCTUATION) | {" "}
    while remaining:
        if self.count_text_tokens(remaining) <= max_tokens:
            pieces.append(remaining)
            break
        # Binary search for the largest prefix length whose stripped text
        # fits the budget. NOTE(review): assumes the token count does not
        # decrease as the prefix grows - confirm for the tokenizer in use.
        lo, hi, best = 1, len(remaining), 1
        while lo <= hi:
            mid = (lo + hi) // 2
            cand = remaining[:mid].strip()
            if not cand:
                lo = mid + 1
                continue
            if self.count_text_tokens(cand) <= max_tokens:
                best = mid
                lo = mid + 1
            else:
                hi = mid - 1
        cut = best
        prefix = remaining[:best]
        preferred = -1
        # Look backwards over (at most) the tail of the prefix for a boundary
        # character and cut right after it.
        scan_min = max(-1, len(prefix) - 25)
        for si in range(len(prefix) - 1, scan_min, -1):
            if prefix[si] in preferred_boundary:
                preferred = si + 1
                break
        if preferred > 0:
            cut = preferred
        piece = remaining[:cut].strip()
        if not piece:
            # The preferred cut left only whitespace; use the raw prefix.
            piece = remaining[:best].strip()
            cut = best
        pieces.append(piece)
        remaining = remaining[cut:].strip()
    return pieces
def split_voice_clone_text(self, text: str, max_tokens: int = 75) -> list[str]:
    """Split ``text`` into chunks of at most ``max_tokens`` tokens for voice cloning.

    The text is prepared, split into sentences, and sentences that are too
    long are split further into clauses and then by token budget. The
    resulting units are packed greedily, in order, into chunks that stay
    within the budget. Returns an empty list for empty input.
    """
    stripped = str(text or "").strip()
    if not stripped:
        return []
    budget = max(1, int(max_tokens))
    prepared = _prepare_text_for_sentence_chunking(stripped)
    sentences = _split_text_by_punctuation(prepared, SENTENCE_END_PUNCTUATION)
    if not sentences:
        sentences = [prepared.strip()]

    # Phase 1: break the text into (token count, text) units that each fit.
    units: list[tuple[int, str]] = []
    for raw_sentence in sentences:
        sentence = raw_sentence.strip()
        if not sentence:
            continue
        sentence_tokens = self.count_text_tokens(sentence)
        if sentence_tokens <= budget:
            units.append((sentence_tokens, sentence))
            continue
        clauses = _split_text_by_punctuation(sentence, CLAUSE_SPLIT_PUNCTUATION)
        if len(clauses) <= 1:
            clauses = [sentence]
        for raw_clause in clauses:
            clause = raw_clause.strip()
            if not clause:
                continue
            clause_tokens = self.count_text_tokens(clause)
            if clause_tokens <= budget:
                units.append((clause_tokens, clause))
                continue
            for raw_piece in self.split_text_by_token_budget(clause, budget):
                piece = raw_piece.strip()
                if piece:
                    units.append((self.count_text_tokens(piece), piece))

    # Phase 2: pack consecutive units into chunks without exceeding the budget.
    chunks: list[str] = []
    pending = ""
    pending_tokens = 0
    for unit_tokens, unit_text in units:
        if not pending:
            pending, pending_tokens = unit_text, unit_tokens
        elif pending_tokens + unit_tokens > budget:
            chunks.append(pending.strip())
            pending, pending_tokens = unit_text, unit_tokens
        else:
            pending = _join_sentence_parts(pending, unit_text)
            # Re-count: the joined text may tokenize differently from the sum.
            pending_tokens = self.count_text_tokens(pending)
    if pending:
        chunks.append(pending.strip())
    return chunks or [stripped]
def _split_text_by_sentence_punctuation(self, text: str) -> list[str]:
    """Split ``text`` into sentence-level chunks that fit a token budget.

    The budget is derived from the codec's static decode length when that is
    known and is 80 tokens otherwise. Sentences within the budget are kept
    whole. Longer ones are split into clauses, neighbouring clauses are
    merged while they fit, and anything still too long is split by token
    budget. Returns an empty list for empty input.
    """
    t = str(text or "").strip()
    if not t:
        return []
    codec_limit = self._static_decode_code_length()
    if codec_limit is not None:
        # NOTE(review): heuristic; presumably about 3 codec frames per text
        # token with 10% headroom - confirm against the model export.
        token_budget = max(1, int(codec_limit / 3 * 0.9))
    else:
        token_budget = 80
    prepared = _prepare_text_for_sentence_chunking(t)
    sents = _split_text_by_punctuation(prepared, SENTENCE_END_PUNCTUATION)
    if not sents:
        sents = [prepared.strip()]
    result: list[str] = []
    for sent in sents:
        s = sent.strip()
        if not s:
            continue
        if self.count_text_tokens(s) <= token_budget:
            result.append(s)
            continue
        clauses = _split_text_by_punctuation(s, CLAUSE_SPLIT_PUNCTUATION)
        if len(clauses) <= 1:
            clauses = [s]
        merged: list[str] = []
        buf = ""
        buf_tokens = 0
        for clause in clauses:
            c = clause.strip()
            if not c:
                continue
            ccnt = self.count_text_tokens(c)
            if ccnt >= token_budget:
                # Oversized clause: flush what is buffered, then hard-split it.
                if buf:
                    merged.append(buf.strip())
                    buf = ""
                    buf_tokens = 0
                merged.extend(
                    p.strip() for p in self.split_text_by_token_budget(c, token_budget) if p.strip()
                )
            elif buf_tokens > 0 and buf_tokens + ccnt > token_budget:
                merged.append(buf.strip())
                buf = c
                buf_tokens = ccnt
            else:
                # NOTE(review): clauses are joined with a space even for CJK
                # text, unlike split_voice_clone_text - confirm this is intended.
                buf = (buf + " " + c) if buf else c
                # Sum of per-clause counts, not a re-count of the joined text.
                buf_tokens += ccnt
        if buf:
            merged.append(buf.strip())
        # The summed counts are only an estimate, so re-check each merged chunk.
        for m in merged:
            if self.count_text_tokens(m) <= token_budget:
                result.append(m)
            else:
                result.extend(
                    p.strip() for p in self.split_text_by_token_budget(m, token_budget) if p.strip()
                )
    return result or [t]
def build_text_rows(self, token_ids: list[int]) -> list[list[int]]:
row_width = int(self.manifest["tts_config"]["n_vq"]) + 1
rows: list[list[int]] = []
for tid in token_ids:
row = [int(self.manifest["tts_config"]["audio_pad_token_id"])] * row_width
row[0] = int(tid)
rows.append(row)
return rows
def build_audio_prefix_rows(
self, prompt_audio_codes: list[list[int]], slot_token_id: int | None = None
) -> list[list[int]]:
row_width = int(self.manifest["tts_config"]["n_vq"]) + 1
slot = int(
self.manifest["tts_config"]["audio_user_slot_token_id"]
if slot_token_id is None else slot_token_id
)
rows: list[list[int]] = []
for code_row in prompt_audio_codes:
row = [int(self.manifest["tts_config"]["audio_pad_token_id"])] * row_width
row[0] = slot
for i in range(min(len(code_row), int(self.manifest["tts_config"]["n_vq"]))):
row[i + 1] = int(code_row[i])
rows.append(row)
return rows
def build_voice_clone_request_rows(
self, prompt_audio_codes: list[list[int]], text_token_ids: list[int]
) -> dict[str, list[list[int]]]:
prefix_ids = [
*self.manifest["prompt_templates"]["user_prompt_prefix_token_ids"],
int(self.manifest["tts_config"]["audio_start_token_id"]),
]
suffix_ids = [
int(self.manifest["tts_config"]["audio_end_token_id"]),
*self.manifest["prompt_templates"]["user_prompt_after_reference_token_ids"],
*text_token_ids,
*self.manifest["prompt_templates"]["assistant_prompt_prefix_token_ids"],
int(self.manifest["tts_config"]["audio_start_token_id"]),
]
rows = [
*self.build_text_rows(prefix_ids),
*self.build_audio_prefix_rows(prompt_audio_codes),
*self.build_text_rows(suffix_ids),
]
return {"inputIds": rows, "attentionMask": [[1 for _ in rows]]}
def _static_prefill_length(self) -> int | None:
v = self.tts_static_shapes.get("prefill_seq")
return None if v is None else int(v)
def _static_decode_code_length(self) -> int | None:
v = self.codec_static_shapes.get("decode_code_length")
return None if v is None else int(v)
def _static_encode_waveform_length(self) -> int | None:
v = self.codec_static_shapes.get("waveform_length")
return None if v is None else int(v)
def _pad_request_rows(
self, request_rows: dict[str, list[list[int]]]
) -> tuple[dict[str, list[list[int]]], int]:
static_len = self._static_prefill_length()
actual_len = len(request_rows["inputIds"])
if static_len is None:
return request_rows, actual_len
if actual_len > static_len:
raise ValueError(
f"static prefill_seq={static_len} < request length={actual_len}. "
"Reduce prompt/text length or re-export with larger --sample-seq-len."
)
if actual_len == static_len:
return request_rows, actual_len
row_width = int(self.manifest["tts_config"]["n_vq"]) + 1
pad_row = [int(self.manifest["tts_config"]["audio_pad_token_id"])] * row_width
pad_row[0] = int(self.manifest["tts_config"]["pad_token_id"])
padded_rows = [list(r) for r in request_rows["inputIds"]]
padded_rows.extend([list(pad_row) for _ in range(static_len - actual_len)])
padded_mask = [1] * actual_len + [0] * (static_len - actual_len)
return {"inputIds": padded_rows, "attentionMask": [padded_mask]}, actual_len
def _pad_kv(
self, kv_dict: dict[str, np.ndarray], target_len: int
) -> dict[str, np.ndarray]:
result: dict[str, np.ndarray] = {}
for key, val in kv_dict.items():
cur = val.shape[1]
if cur > target_len:
padded = val[:, :target_len, :, :].astype(np.float32)
elif cur < target_len:
pad = np.zeros((1, target_len - cur, *val.shape[2:]), dtype=np.float32)
padded = np.concatenate([val.astype(np.float32), pad], axis=1)
else:
padded = val.astype(np.float32)
result[key] = padded
return result
def run_local_fixed_sampled_frame(
self,
global_hidden: np.ndarray,
*,
previous_token_sets_by_channel: list[set[int]],
) -> tuple[bool, list[int]]:
n_vq = int(self.manifest["tts_config"]["n_vq"])
codebook_size = int(self.tts_meta["model_config"]["audio_codebook_sizes"][0])
rep_mask = np.zeros((1, n_vq, codebook_size), dtype=np.int32)
for ci, token_set in enumerate(previous_token_sets_by_channel):
for tid in token_set:
if 0 <= tid < codebook_size:
rep_mask[0, ci, tid] = 1
sess = self.sessions["local_fixed_sampled_frame"]
session_input_names = {inp.name for inp in sess.get_inputs()}
input_feed: dict = {
"global_hidden": global_hidden.astype(np.float32, copy=False),
"repetition_seen_mask": rep_mask,
}
asst_u: float | None = None
if "assistant_random_u" in session_input_names:
asst_u = float(self.rng.random())
input_feed["assistant_random_u"] = np.array([asst_u], dtype=np.float32)
if "audio_random_u" in session_input_names:
input_feed["audio_random_u"] = np.asarray(
[[float(self.rng.random()) for _ in range(n_vq)]],
dtype=np.float32,
)
outputs = self._run_session("local_fixed_sampled_frame", input_feed)
out_names = [o.name for o in sess.get_outputs()]
named = dict(zip(out_names, outputs, strict=True))
frame_ids = np.asarray(named["frame_token_ids"]).reshape(-1).astype(np.int32).tolist()
should_continue = bool(int(np.asarray(named["should_continue"]).reshape(-1)[0]))
logging.debug(
"[fixed_frame] asst_u=%.4f should_continue=%s tokens=%s",
asst_u if asst_u is not None else -1.0,
should_continue,
frame_ids[:4],
)
return should_continue, [int(x) for x in frame_ids]
def _can_use_local_fixed_sampled_frame(self) -> bool:
sess = self.sessions.get("local_fixed_sampled_frame")
if sess is None:
return False
session_input_names = {inp.name for inp in sess.get_inputs()}
return "audio_random_u" in session_input_names
def run_local_decoder(
self,
global_hidden: np.ndarray,
text_token_id: int,
frame_prefix: list[int],
) -> tuple[np.ndarray, np.ndarray]:
n_vq = int(self.manifest["tts_config"]["n_vq"])
audio_pad = int(self.manifest["tts_config"]["audio_pad_token_id"])
padded_prefix = np.full((1, n_vq - 1), audio_pad, dtype=np.int32)
for i in range(min(len(frame_prefix), n_vq - 1)):
padded_prefix[0, i] = int(frame_prefix[i])
outputs = self._run_session(
"local_decoder",
{
"global_hidden": global_hidden.astype(np.float32, copy=False),
"text_token_id": np.asarray([int(text_token_id)], dtype=np.int32),
"audio_prefix_token_ids": padded_prefix,
},
)
out_names = [o.name for o in self.sessions["local_decoder"].get_outputs()]
named = dict(zip(out_names, outputs, strict=True))
return named["text_logits"].reshape(-1), named["audio_logits"]
def slice_audio_channel_logits(self, audio_logits: np.ndarray, channel_index: int) -> np.ndarray:
per_ch = int(audio_logits.shape[-1])
flat = audio_logits.reshape(-1)
return flat[channel_index * per_ch : (channel_index + 1) * per_ch]
def _infer_global_kv_len(self) -> int | None:
if "decode" in self.sessions:
dec_sess = self.sessions["decode"]
if hasattr(dec_sess, "get_inputs"):
for inp in dec_sess.get_inputs():
if inp.name == "past_key_0":
dims = getattr(inp, "shape", []) or []
if len(dims) >= 2 and isinstance(dims[1], str):
return None
break
v = self.tts_static_shapes.get("kv_seq") or self.tts_static_shapes.get("global_kv_len")
if v is not None:
return int(v)
return 512
def generate_audio_frames(
    self,
    request_rows: dict[str, list[list[int]]],
    on_frame: Callable | None = None,
) -> list[list[int]]:
    """Pad the request rows and generate audio frames from them.

    Args:
        request_rows: Request rows handed to ``self._pad_request_rows``.
        on_frame: Optional callback forwarded to the frame generator.

    Returns:
        The frames returned by ``self._generate_frames_with_prefill``.
    """
    rows, prefill_length = self._pad_request_rows(request_rows)
    frames = self._generate_frames_with_prefill(
        rows,
        actual_prefill_length=prefill_length,
        on_frame=on_frame,
    )
    return frames
def _generate_frames_with_prefill(
    self,
    request_rows: dict[str, list[list[int]]],
    *,
    actual_prefill_length: int,
    on_frame: Callable | None = None,
) -> list[list[int]]:
    """Run prefill once, then generate audio-code frames step by step.

    Each step produces one frame (either through the
    ``local_fixed_sampled_frame`` session or by sampling channel by channel
    with ``local_decoder``), feeds it to the ``decode`` session, and rolls the
    KV cache forward. Generation stops when ``max_new_frames`` is reached,
    when the model signals it should stop, when an empty frame is produced,
    or after 10 consecutive all-pad frames (those trailing pad frames are
    removed from the result).

    Args:
        request_rows: Dict with ``"inputIds"`` and ``"attentionMask"`` rows
            for the prefill session.
        actual_prefill_length: Number of valid prefill positions; only that
            many positions of the prefill outputs are kept.
        on_frame: Optional callback invoked after every decode step as
            ``on_frame(generated_frames, step, frame)``.

    Returns:
        The list of generated frames, each a list of token ids.

    Raises:
        RuntimeError: If neither a usable ``local_fixed_sampled_frame``
            session nor a ``local_decoder`` session is available.
    """
    global_kv_len = self._infer_global_kv_len()
    generation_defaults = self.manifest["generation_defaults"]
    tts_config = self.manifest["tts_config"]
    n_vq = int(tts_config["n_vq"])
    row_width = n_vq + 1

    # ---- Prefill -----------------------------------------------------------
    prefill_ids, prefill_dims = _flatten3d_int32([request_rows["inputIds"]])
    prefill_mask, prefill_mask_dims = _flatten2d_int32(request_rows["attentionMask"])
    prefill_outputs = self._run_session(
        "prefill",
        {
            "input_ids": prefill_ids.reshape(prefill_dims),
            "attention_mask": prefill_mask.reshape(prefill_mask_dims),
        },
    )
    prefill_out_names = [o.name for o in self.sessions["prefill"].get_outputs()]
    named_prefill = dict(zip(prefill_out_names, prefill_outputs, strict=True))
    global_hidden = named_prefill["global_hidden"][:, actual_prefill_length - 1, :].astype(np.float32)
    past_valid_length = int(actual_prefill_length)

    onnx_meta = self.tts_meta["onnx"]
    raw_past: dict[str, np.ndarray] = {
        out.replace("present_", "past_"): named_prefill[out][:, :actual_prefill_length, :, :]
        for out in onnx_meta["prefill_output_names"][1:]
    }
    use_dynamic_kv = global_kv_len is None
    if use_dynamic_kv:
        # Dynamic cache: keep only the valid positions; it grows by one per step.
        past_by_name = raw_past
    else:
        # Static cache: pad every tensor to the fixed length the model expects.
        past_by_name = self._pad_kv(raw_past, global_kv_len)

    # ---- Loop-invariant state (hoisted out of the per-frame loop) ----------
    audio_pad_id = int(tts_config["audio_pad_token_id"])
    assistant_slot_id = int(tts_config["audio_assistant_slot_token_id"])
    sample_mode = generation_defaults["sample_mode"]
    use_fixed_frame = (
        sample_mode == SAMPLE_MODE_FIXED and self._can_use_local_fixed_sampled_frame()
    )
    has_local_decoder = "local_decoder" in self.sessions
    decode_output_names: list[str] | None = None  # resolved after the first decode run

    generated_frames: list[list[int]] = []
    prev_tokens: list[list[int]] = [[] for _ in range(n_vq)]
    prev_tok_sets: list[set] = [set() for _ in range(n_vq)]
    consecutive_pad_frames = 0
    # Per the original note: ~0.8 s of silence forces a stop (guard against
    # a quantized AXModel that never emits the stop signal).
    _MAX_CONSECUTIVE_PAD_FRAMES = 10

    for step in range(int(generation_defaults["max_new_frames"])):
        frame: list[int] = []
        if use_fixed_frame:
            should_continue, frame = self.run_local_fixed_sampled_frame(
                global_hidden, previous_token_sets_by_channel=prev_tok_sets
            )
            if not should_continue:
                logging.info("[gen] frame %d: should_continue=False → 停止生成", step)
                break
            for ci, tok in enumerate(frame):
                prev_tokens[ci].append(tok)
                prev_tok_sets[ci].add(tok)
        elif has_local_decoder:
            if (
                sample_mode == SAMPLE_MODE_FIXED
                and "local_fixed_sampled_frame" in self.sessions
                and step == 0
            ):
                logging.warning(
                    "local_fixed_sampled_frame 缺少 audio_random_u 输入,回退到 local_decoder 随机采样路径"
                )
            local_text_logits, _ = self.run_local_decoder(global_hidden, 0, [])
            next_text = _sample_assistant_text_token(
                local_text_logits, self.manifest, generation_defaults, self.rng
            )
            # Anything other than the assistant audio slot ends generation.
            if next_text != assistant_slot_id:
                break
            for ci in range(n_vq):
                # Each channel is conditioned on the tokens already picked for this frame.
                _, audio_logits = self.run_local_decoder(global_hidden, next_text, frame)
                ch_logits = self.slice_audio_channel_logits(audio_logits, ci).astype(np.float32)
                tok = _sample_audio_token(
                    ch_logits, prev_tokens[ci], prev_tok_sets[ci], generation_defaults, self.rng
                )
                frame.append(tok)
                prev_tokens[ci].append(tok)
                prev_tok_sets[ci].add(tok)
        else:
            raise RuntimeError("Neither local_fixed_sampled_frame nor local_decoder session found")

        if not frame:
            break

        if all(tok == audio_pad_id for tok in frame):
            consecutive_pad_frames += 1
            if consecutive_pad_frames >= _MAX_CONSECUTIVE_PAD_FRAMES:
                logging.info(
                    "[gen] %d consecutive all-pad frames → force stop (trim %d silent frames)",
                    consecutive_pad_frames, consecutive_pad_frames,
                )
                # The current frame was never appended, so only the earlier
                # pad frames of this run need to be removed. A new list is
                # bound (not mutated in place) because on_frame may still
                # hold a reference to the old one.
                already_appended = consecutive_pad_frames - 1
                keep = max(0, len(generated_frames) - already_appended)
                generated_frames = generated_frames[:keep]
                break
        else:
            consecutive_pad_frames = 0
        generated_frames.append(frame)

        # ---- Decode step: feed the new frame, advance the KV cache ---------
        next_row = np.full((1, 1, row_width), audio_pad_id, dtype=np.int32)
        next_row[0, 0, 0] = assistant_slot_id
        for idx, tok in enumerate(frame):
            next_row[0, 0, idx + 1] = int(tok)
        decode_feeds: dict[str, np.ndarray] = {
            "input_ids": next_row,
            "past_valid_lengths": np.asarray([past_valid_length], dtype=np.int32),
        }
        for inp_name in onnx_meta["decode_input_names"][2:]:
            decode_feeds[inp_name] = past_by_name[inp_name]
        dec_outputs = self._run_session("decode", decode_feeds)
        if decode_output_names is None:
            decode_output_names = [o.name for o in self.sessions["decode"].get_outputs()]
        named_dec = dict(zip(decode_output_names, dec_outputs, strict=True))
        global_hidden = _extract_last_hidden(named_dec["global_hidden"])
        past_valid_length += 1

        present_names = onnx_meta["decode_output_names"][1:]
        if use_dynamic_kv:
            past_by_name = {
                out_name.replace("present_", "past_"): named_dec[out_name]
                for out_name in present_names
            }
        else:
            # Logical (0-indexed) position of the token that was just decoded.
            new_pos = past_valid_length - 1
            new_padded: dict[str, np.ndarray] = {}
            for out_name in present_names:
                v = np.asarray(named_dec[out_name])
                # Size the buffer from the decode output instead of assuming
                # a fixed head count / head dimension. The new token's entry
                # is read from index ``global_kv_len`` of the present tensor.
                buf = np.zeros((v.shape[0], global_kv_len, *v.shape[2:]), dtype=np.float32)
                if new_pos < global_kv_len:
                    buf[:, :new_pos, :, :] = v[:, :new_pos, :, :]
                    buf[:, new_pos, :, :] = v[:, global_kv_len, :, :]
                else:
                    # Cache is full: drop the oldest position and append the new one.
                    buf[:, : global_kv_len - 1, :, :] = v[:, 1:global_kv_len, :, :]
                    buf[:, global_kv_len - 1, :, :] = v[:, global_kv_len, :, :]
                new_padded[out_name.replace("present_", "past_")] = buf
            past_by_name = new_padded
            if past_valid_length > global_kv_len:
                past_valid_length = global_kv_len

        if on_frame is not None:
            on_frame(generated_frames, step, frame)

    logging.info("[gen] 生成完毕: 共 %d 帧 (max_new_frames=%s)", len(generated_frames), generation_defaults["max_new_frames"])
    return generated_frames
def decode_full_audio(self, generated_frames: list[list[int]]) -> tuple[list[np.ndarray], int]:
if not generated_frames:
return [], 0
static_len = self._static_decode_code_length()
n_vq = int(self.manifest["tts_config"]["n_vq"])
frame_count = len(generated_frames)
if static_len is None or frame_count <= static_len:
return self._decode_one_segment(generated_frames, frame_count, n_vq, static_len)
all_audio: list[np.ndarray] = []
total_samples = 0
for seg_start in range(0, frame_count, static_len):
seg_end = min(seg_start + static_len, frame_count)
seg_frames = generated_frames[seg_start:seg_end]
seg_count = seg_end - seg_start
channel_arrays, seg_samples = self._decode_one_segment(seg_frames, seg_count, n_vq, static_len)
all_audio.extend(channel_arrays)
total_samples += seg_samples
return all_audio, total_samples
def _decode_one_segment(
    self,
    frames: list[list[int]],
    frame_count: int,
    n_vq: int,
    static_len: int | None,
) -> tuple[list[np.ndarray], int]:
    """Decode one batch of frames with the ``codec_decode`` session.

    Args:
        frames: Frames of token ids; a frame shorter than ``n_vq`` is
            completed with zeros.
        frame_count: Number of valid frames, reported to the session as
            ``audio_code_lengths``.
        n_vq: Number of code channels per frame.
        static_len: Fixed code length of the session input, or ``None`` to
            size the input to ``frame_count``.

    Returns:
        ``(channel_arrays, audio_length)`` where ``audio_length`` is the
        sample count reported by the session and ``channel_arrays`` is the
        result of slicing the session's audio output to that length.
    """
    code_capacity = frame_count if static_len is None else static_len
    codes = np.zeros((1, code_capacity, n_vq), dtype=np.int32)
    for row, frame in enumerate(frames):
        available = len(frame)
        for col in range(n_vq):
            codes[0, row, col] = int(frame[col]) if col < available else 0

    feeds = {
        "audio_codes": codes,
        "audio_code_lengths": np.asarray([frame_count], dtype=np.int32),
    }
    raw_outputs = self._run_session("codec_decode", feeds)

    output_names = [spec.name for spec in self.sessions["codec_decode"].get_outputs()]
    by_name = dict(zip(output_names, raw_outputs, strict=True))
    sample_count = int(by_name["audio_lengths"].reshape(-1)[0])
    channel_arrays = _slice_channel_major_audio(by_name["audio"], 0, sample_count)
    return channel_arrays, sample_count
def encode_reference_audio(self, reference_audio_path: str | Path) -> list[list[int]]:
    """Encode a reference audio file into codec token rows.

    The audio is loaded at the codec's sample rate and channel count,
    optionally zero-padded to a multiple of ``downsample_rate`` (when
    ``pad_reference_audio_to_downsample_rate`` is set in the codec config),
    then passed through the ``codec_encode`` session. If the session has a
    static waveform length, the audio is encoded in windows of that length
    with the last window zero-padded.

    Args:
        reference_audio_path: Path of the audio file to encode.

    Returns:
        One list of ``num_quantizers`` token ids per encoded frame.
    """
    codec_cfg = self.codec_meta["codec_config"]
    sample_rate = int(codec_cfg["sample_rate"])
    channel_count = int(codec_cfg["channels"])
    waveform = _load_audio_numpy(reference_audio_path, sample_rate, channel_count)
    total_samples = int(waveform.shape[-1])

    if bool(codec_cfg.get("pad_reference_audio_to_downsample_rate", False)):
        hop = int(codec_cfg["downsample_rate"])
        leftover = total_samples % hop
        if leftover:
            missing = hop - leftover
            tail = np.zeros((1, channel_count, missing), dtype=np.float32)
            waveform = np.concatenate([waveform, tail], axis=2)
            total_samples += missing

    window = self._static_encode_waveform_length()
    quantizer_count = int(codec_cfg["num_quantizers"])
    code_rows: list[list[int]] = []
    output_names = [spec.name for spec in self.sessions["codec_encode"].get_outputs()]

    def encode_piece(samples: np.ndarray, valid_length: int) -> None:
        # Run the encoder on one waveform piece and collect its valid frames.
        raw_outputs = self._run_session(
            "codec_encode",
            {
                "waveform": samples,
                "input_lengths": np.asarray([valid_length], dtype=np.int32),
            },
        )
        by_name = dict(zip(output_names, raw_outputs, strict=True))
        codes = np.asarray(by_name["audio_codes"], dtype=np.int32)
        valid_frames = int(np.asarray(by_name["audio_code_lengths"]).reshape(-1)[0])
        for frame_idx in range(valid_frames):
            code_rows.append(
                [int(codes[0, frame_idx, q_idx]) for q_idx in range(quantizer_count)]
            )

    if not (window and window > 0):
        # No static window: encode the whole waveform in a single call.
        encode_piece(waveform, total_samples)
        return code_rows

    for offset in range(0, total_samples, window):
        piece = waveform[..., offset : offset + window]
        piece_len = int(piece.shape[-1])
        if piece_len <= 0:
            continue
        if piece_len < window:
            # Zero-pad the final short window; its true length is still reported.
            padded = np.zeros((1, channel_count, window), dtype=np.float32)
            padded[..., :piece_len] = piece
            piece = padded
        encode_piece(piece, piece_len)
    return code_rows
def resolve_prompt_audio_codes(
    self, *, voice: str | None, prompt_audio_path: str | Path | None
) -> list[list[int]]:
    """Pick the prompt audio codes for a request.

    Args:
        voice: Name of a built-in voice; when falsy, the first built-in
            voice of the manifest is used.
        prompt_audio_path: Reference audio file. When truthy it takes
            precedence and is encoded with ``encode_reference_audio``.

    Returns:
        The encoded reference audio, or a shallow copy of the matching
        built-in voice's ``prompt_audio_codes``.

    Raises:
        ValueError: If no built-in voice has the requested name.
    """
    if prompt_audio_path:
        return self.encode_reference_audio(prompt_audio_path)

    builtin_voices = self.manifest["builtin_voices"]
    wanted = str(voice or builtin_voices[0]["voice"])
    match = next((entry for entry in builtin_voices if entry["voice"] == wanted), None)
    if match is None:
        raise ValueError(f"Built-in voice not found: {wanted}")
    return list(match["prompt_audio_codes"])
def list_builtin_voices(self) -> list[dict[str, Any]]:
    """Return a new list holding the manifest's built-in voice entries.

    The list is a shallow copy: the entries themselves are shared with the
    manifest.
    """
    voices = self.manifest["builtin_voices"]
    return [*voices]
def synthesize(
    self,
    *,
    text: str,
    voice: str | None = None,
    prompt_audio_path: str | Path | None = None,
    output_audio_path: str | Path | None = None,
    sample_mode: str | None = None,
    do_sample: bool = True,
    streaming: bool = True,
    max_new_frames: int | None = None,
    voice_clone_max_text_tokens: int = 75,
    seed: int | None = None,
) -> dict[str, Any]:
    """Synthesize speech for ``text`` and write it to a WAV file.

    The text is split into chunks; each chunk is encoded, turned into audio
    frames, decoded to a waveform and appended to the output, with a short
    silent pause inserted between chunks (0.40 s after a chunk of at most
    four words, otherwise 0.24 s).

    Per-call settings are passed to the generator through
    ``manifest["generation_defaults"]``. That dict is restored to its
    previous content before returning, so overrides from one call do not
    leak into the next.

    Args:
        text: Text to speak.
        voice: Built-in voice name (ignored when ``prompt_audio_path`` is set).
        prompt_audio_path: Reference audio used as the voice prompt.
        output_audio_path: Destination WAV path; defaults to ``output.wav``
            in the current directory.
        sample_mode: Sampling mode, normalized together with ``do_sample``.
        do_sample: Sampling flag passed to the mode normalization.
        streaming: When true, PCM16 audio is written chunk by chunk while
            generating; otherwise the whole waveform is written at the end.
        max_new_frames: Per-call override of the frame limit. It is capped
            at the codec's static decode length when one exists.
        voice_clone_max_text_tokens: Accepted for interface compatibility;
            not used by this method.
        seed: Seed for a fresh random generator stored on ``self.rng``.

    Returns:
        A dict with the audio path, the final waveform, sample rate,
        generated token ids, text chunks, sampling settings, the streaming
        flag and a ``"timing"`` sub-dict with timing/call statistics.
    """
    self._reset_timing_stats()
    gen = self.manifest["generation_defaults"]
    # Snapshot the shared defaults: the overrides below are only meant for
    # this call and are rolled back in the outer ``finally``.
    gen_snapshot = dict(gen)
    try:
        if max_new_frames is not None:
            gen["max_new_frames"] = int(max_new_frames)
        static_codec_len = self._static_decode_code_length()
        if static_codec_len is not None and gen.get("max_new_frames", float("inf")) > static_codec_len:
            gen["max_new_frames"] = static_codec_len
            logging.info("capped max_new_frames=%d (static codec limit)", static_codec_len)
        normalized_mode = self._normalize_sample_mode(sample_mode, do_sample)
        gen["sample_mode"] = normalized_mode
        gen["do_sample"] = normalized_mode != SAMPLE_MODE_GREEDY
        if seed is not None:
            self.rng = np.random.default_rng(int(seed))
        infer_start_time = time.perf_counter()
        try:
            text_chunks = list(self._split_text_by_sentence_punctuation(text))
            total_chunks = len(text_chunks)
            prompt_audio_codes = self.resolve_prompt_audio_codes(
                voice=voice, prompt_audio_path=prompt_audio_path
            )
            sample_rate = int(self.codec_meta["codec_config"]["sample_rate"])
            channels = int(self.codec_meta["codec_config"]["channels"])
            out_path = (
                Path(output_audio_path).expanduser().resolve()
                if output_audio_path
                else (Path("output.wav")).resolve()
            )
            all_waveforms: list[np.ndarray] = []
            all_generated_frames: list[list[int]] = []
            final_text_chunks: list[str] = []
            wav_file = None
            if streaming:
                out_path.parent.mkdir(parents=True, exist_ok=True)
                wav_file = wave.open(str(out_path), "wb")
                wav_file.setnchannels(channels)
                wav_file.setsampwidth(2)  # 16-bit PCM
                wav_file.setframerate(sample_rate)
            try:
                for chunk_idx, chunk_text in enumerate(text_chunks, start=1):
                    logging.info("chunk %d/%d: %r", chunk_idx, total_chunks, chunk_text[:40])
                    text_token_ids = self.encode_text(chunk_text)
                    request_rows = self.build_voice_clone_request_rows(prompt_audio_codes, text_token_ids)
                    generated_frames = self.generate_audio_frames(request_rows)
                    channel_arrays, _ = self.decode_full_audio(generated_frames)
                    waveform = np.asarray(_merge_audio_channels(channel_arrays), dtype=np.float32)
                    all_waveforms.append(waveform)
                    all_generated_frames.extend(generated_frames)
                    final_text_chunks.append(chunk_text)
                    if wav_file is not None and waveform.size > 0:
                        clipped = np.clip(waveform, -1.0, 1.0)
                        pcm16 = np.round(clipped * 32767.0).astype(np.int16)
                        wav_file.writeframes(pcm16.tobytes())
                        logging.info(
                            "stream emitted chunk %d: samples=%d duration=%.3fs",
                            len(final_text_chunks),
                            int(waveform.shape[0]),
                            float(waveform.shape[0]) / float(sample_rate),
                        )
                    if chunk_idx < total_chunks:
                        # Insert a silent gap before the next chunk; short
                        # chunks get a slightly longer pause.
                        pause_words = len(chunk_text.split())
                        pause_sec = 0.40 if pause_words <= 4 else 0.24
                        pause_samples = max(0, int(round(sample_rate * pause_sec)))
                        if pause_samples > 0:
                            pause_waveform = np.zeros((pause_samples, channels), dtype=np.float32)
                            all_waveforms.append(pause_waveform)
                            if wav_file is not None:
                                wav_file.writeframes(
                                    np.zeros((pause_samples, channels), dtype=np.int16).tobytes()
                                )
            finally:
                if wav_file is not None:
                    wav_file.close()
            final_waveform = _concat_waveforms(all_waveforms)
            total_infer_time = time.perf_counter() - infer_start_time
            audio_duration_sec = (
                float(final_waveform.shape[0]) / float(sample_rate)
                if sample_rate > 0 and final_waveform.size > 0
                else 0.0
            )
            # Real-time factor; infinite when no audio was produced.
            rtf = (total_infer_time / audio_duration_sec) if audio_duration_sec > 0 else float("inf")
            model_time_total = sum(self._model_time_stats.values())
            # In streaming mode the file has already been written chunk by chunk.
            audio_path = out_path if streaming else _write_waveform_to_wav(out_path, final_waveform, sample_rate)
            logging.info("已保存 %s sample_rate=%s frames=%s", audio_path, sample_rate, len(all_generated_frames))
            logging.info(
                "[timing] total_infer_time=%.3fs audio_duration=%.3fs rtf=%.4f model_time_total=%.3fs",
                total_infer_time,
                audio_duration_sec,
                rtf,
                model_time_total,
            )
            for key in sorted(self._model_display_names):
                model_name = self._model_display_names[key]
                logging.info(
                    "[timing] session=%s model=%s calls=%d total=%.3fs",
                    key,
                    model_name,
                    self._model_call_stats.get(key, 0),
                    self._model_time_stats.get(key, 0.0),
                )
            return {
                "audio_path": str(audio_path),
                "waveform": final_waveform,
                "sample_rate": sample_rate,
                "audio_token_ids": np.asarray(all_generated_frames, dtype=np.int32),
                "text_chunks": final_text_chunks,
                "sample_mode": normalized_mode,
                "do_sample": normalized_mode != SAMPLE_MODE_GREEDY,
                "streaming": bool(streaming),
                "timing": {
                    "total_infer_time_sec": total_infer_time,
                    "audio_duration_sec": audio_duration_sec,
                    "rtf": rtf,
                    "model_time_total_sec": model_time_total,
                    "per_model_time_sec": dict(self._model_time_stats),
                    "per_model_calls": dict(self._model_call_stats),
                    "per_model_display_name": dict(self._model_display_names),
                    "used_model_keys": list(self._used_model_keys),
                },
            }
        finally:
            self._log_used_model_summary()
    finally:
        # Roll the shared generation defaults back to their pre-call content.
        gen.clear()
        gen.update(gen_snapshot)