from __future__ import annotations import json import logging import time import unicodedata import wave from collections import defaultdict from pathlib import Path from typing import Any, Callable, Optional, Sequence import numpy as np import sentencepiece as spm from scripts.axe_session import AxeSession try: import onnxruntime as ort _HAVE_ORT = True except ImportError: _HAVE_ORT = False SAMPLE_MODE_GREEDY = "greedy" SAMPLE_MODE_FIXED = "fixed" SAMPLE_MODE_FULL = "full" SENTENCE_END_PUNCTUATION = set(".!?。!?") CLAUSE_SPLIT_PUNCTUATION = set(",,、;;::") CLOSING_PUNCTUATION = set(['"', "'", "\u2019", "\u201d", ")", "]", "}", ")", "】", "》", "」", "』"]) _MANIFEST_CANDIDATES = ("browser_poc_manifest.json",) _FULL8_IO_MAP: dict[str, tuple[list[str], list[str]]] = { "prefill": ( ["input_ids", "attention_mask"], ["global_hidden", "present_key_0", "present_value_0", "present_key_1", "present_value_1", "present_key_2", "present_value_2", "present_key_3", "present_value_3", "present_key_4", "present_value_4", "present_key_5", "present_value_5", "present_key_6", "present_value_6", "present_key_7", "present_value_7", "present_key_8", "present_value_8", "present_key_9", "present_value_9", "present_key_10", "present_value_10", "present_key_11", "present_value_11"], ), "decode": ( ["input_ids", "past_valid_lengths", "past_key_0", "past_value_0", "past_key_1", "past_value_1", "past_key_2", "past_value_2", "past_key_3", "past_value_3", "past_key_4", "past_value_4", "past_key_5", "past_value_5", "past_key_6", "past_value_6", "past_key_7", "past_value_7", "past_key_8", "past_value_8", "past_key_9", "past_value_9", "past_key_10", "past_value_10", "past_key_11", "past_value_11"], ["global_hidden", "present_key_0", "present_value_0", "present_key_1", "present_value_1", "present_key_2", "present_value_2", "present_key_3", "present_value_3", "present_key_4", "present_value_4", "present_key_5", "present_value_5", "present_key_6", "present_value_6", "present_key_7", "present_value_7", "present_key_8", "present_value_8", "present_key_9", "present_value_9", "present_key_10", "present_value_10", "present_key_11", "present_value_11"], ), "local_decoder": ( ["global_hidden", "text_token_id", "audio_prefix_token_ids"], ["text_logits", "audio_logits"], ), "local_fixed_sampled_frame": ( ["global_hidden", "repetition_seen_mask", "assistant_random_u", "audio_random_u"], ["should_continue", "frame_token_ids"], ), "codec_encode": ( ["waveform", "input_lengths"], ["audio_codes", "audio_code_lengths"], ), "codec_decode": ( ["audio_codes", "audio_code_lengths"], ["audio", "audio_lengths"], ), } _AXMODEL_FILE_MAP: dict[str, str] = { "prefill": "tts_prefill.axmodel", "decode": "tts_decode_step.axmodel", "local_decoder": "tts_local_decoder.axmodel", "local_fixed_sampled_frame": "tts_local_fixed_sampled_frame.axmodel", "codec_encode": "codec_encode.axmodel", "codec_decode": "codec_decode.axmodel", } _ONNX_FILE_MAP: dict[str, str] = { "decode": "tts_decode_step.onnx", } def _argmax(values: np.ndarray) -> int: return int(np.argmax(values)) def _softmax(values: np.ndarray) -> np.ndarray: max_v = float(np.max(values)) shifted = np.asarray(values - max_v, dtype=np.float64) exps = np.exp(shifted) return exps / np.sum(exps) def _apply_repetition_penalty( values: np.ndarray, previous_token_ids: list[int], penalty: float ) -> np.ndarray: if not previous_token_ids or penalty == 1.0: return values result = values.copy() for tid in set(int(t) for t in previous_token_ids): if 0 <= tid < result.shape[0]: result[tid] = result[tid] * penalty if result[tid] < 0 else result[tid] / penalty return result def _argmax_with_repetition_penalty( values: np.ndarray, previous_token_set: set[int], penalty: float ) -> int: best_idx, best_val = 0, float("-inf") apply_penalty = bool(previous_token_set) and penalty != 1.0 for idx, v in enumerate(values): score = float(v) if apply_penalty and idx in previous_token_set: score = score * penalty if score < 0 else score / penalty if score > best_val: best_val = score best_idx = idx return int(best_idx) def _sample_from_scores( values: np.ndarray, *, do_sample: bool, temperature: float, top_k: int, top_p: float, rng: np.random.Generator, ) -> int: if not do_sample: return _argmax(values) if temperature <= 0: raise ValueError("temperature must be positive when do_sample=True") scores = np.asarray(values, dtype=np.float32).copy() / float(temperature) if top_k > 0 and top_k < scores.shape[0]: threshold = float(np.sort(scores)[::-1][top_k - 1]) scores[scores < threshold] = float("-inf") if 0 < top_p < 1: indexed = sorted(enumerate(scores.tolist()), key=lambda x: x[1], reverse=True) sorted_scores = np.asarray([x[1] for x in indexed], dtype=np.float32) sorted_probs = _softmax(sorted_scores) remove = [False] * len(indexed) cumulative = 0.0 for i, p in enumerate(sorted_probs): cumulative += float(p) if cumulative > float(top_p): remove[i] = True for i in range(len(remove) - 1, 0, -1): remove[i] = remove[i - 1] if remove: remove[0] = False for i, should_remove in enumerate(remove): if should_remove: scores[indexed[i][0]] = float("-inf") probs = _softmax(scores) rv = float(rng.random()) for i, p in enumerate(probs): rv -= float(p) if rv <= 0: return int(i) return _argmax(scores) def _sample_assistant_text_token( text_logits: np.ndarray, manifest: dict[str, Any], generation_defaults: dict[str, Any], rng: np.random.Generator, ) -> int: candidate_ids = np.asarray( [int(manifest["tts_config"]["audio_assistant_slot_token_id"]), int(manifest["tts_config"]["audio_end_token_id"])], dtype=np.int32, ) candidate_scores = text_logits[candidate_ids] sampled_idx = _sample_from_scores( candidate_scores, do_sample=False, # deterministic: always pick the higher-scoring token temperature=float(generation_defaults["text_temperature"]), top_k=min(int(generation_defaults["text_top_k"]), int(candidate_scores.shape[0])), top_p=float(generation_defaults["text_top_p"]), rng=rng, ) return int(candidate_ids[sampled_idx]) def _sample_audio_token( audio_logits: np.ndarray, previous_token_ids: list[int], previous_token_set: set[int], generation_defaults: dict[str, Any], rng: np.random.Generator, ) -> int: penalty = float(generation_defaults["audio_repetition_penalty"]) if not bool(generation_defaults["do_sample"]): return _argmax_with_repetition_penalty(audio_logits, previous_token_set, penalty) penalized = _apply_repetition_penalty(audio_logits, previous_token_ids, penalty) return _sample_from_scores( penalized, do_sample=True, temperature=float(generation_defaults["audio_temperature"]), top_k=int(generation_defaults["audio_top_k"]), top_p=float(generation_defaults["audio_top_p"]), rng=rng, ) def _flatten3d_int32(nested: list[list[list[int]]]) -> tuple[np.ndarray, list[int]]: d0, d1, d2 = len(nested), len(nested[0]), len(nested[0][0]) data = np.zeros((d0 * d1 * d2,), dtype=np.int32) offset = 0 for i in range(d0): for j in range(d1): for k in range(d2): data[offset] = int(nested[i][j][k]) offset += 1 return data, [d0, d1, d2] def _flatten2d_int32(nested: list[list[int]]) -> tuple[np.ndarray, list[int]]: d0, d1 = len(nested), len(nested[0]) data = np.zeros((d0 * d1,), dtype=np.int32) offset = 0 for i in range(d0): for j in range(d1): data[offset] = int(nested[i][j]) offset += 1 return data, [d0, d1] def _extract_last_hidden(hidden: np.ndarray) -> np.ndarray: if hidden.ndim == 2: return hidden.astype(np.float32, copy=False) if hidden.ndim != 3 or hidden.shape[0] != 1: raise ValueError(f"Unexpected global_hidden shape: {hidden.shape}") return hidden[:, -1, :].astype(np.float32, copy=False) def _slice_channel_major_audio( audio: np.ndarray, start_sample: int = 0, end_sample: int | None = None ) -> list[np.ndarray]: if audio.ndim != 3 or audio.shape[0] != 1: raise ValueError(f"Unexpected audio tensor shape: {audio.shape}") channels = int(audio.shape[1]) total = int(audio.shape[2]) start = max(0, int(start_sample)) end = total if end_sample is None else max(start, min(int(end_sample), total)) return [audio[0, c, start:end].astype(np.float32, copy=False) for c in range(channels)] def _merge_audio_channels(channel_arrays: list[np.ndarray]) -> np.ndarray: if not channel_arrays: return np.zeros((0, 1), dtype=np.float32) if len(channel_arrays) == 1: return np.asarray(channel_arrays[0], dtype=np.float32).reshape(-1, 1) min_len = min(int(c.shape[0]) for c in channel_arrays) return np.stack([np.asarray(c[:min_len], dtype=np.float32) for c in channel_arrays], axis=1) def _concat_waveforms(waveforms: list[np.ndarray]) -> np.ndarray: if not waveforms: return np.zeros((0, 1), dtype=np.float32) non_empty = [w for w in waveforms if w.size > 0] if not non_empty: n_ch = int(waveforms[0].shape[1]) if waveforms[0].ndim == 2 else 1 return np.zeros((0, n_ch), dtype=np.float32) return np.concatenate(non_empty, axis=0) def _write_waveform_to_wav(path: str | Path, waveform: np.ndarray, sample_rate: int) -> Path: output_path = Path(path).expanduser().resolve() output_path.parent.mkdir(parents=True, exist_ok=True) audio = np.asarray(waveform, dtype=np.float32) if audio.ndim == 1: audio = audio.reshape(-1, 1) clipped = np.clip(audio, -1.0, 1.0) pcm16 = np.round(clipped * 32767.0).astype(np.int16) with wave.open(str(output_path), "wb") as wf: wf.setnchannels(int(pcm16.shape[1])) wf.setsampwidth(2) wf.setframerate(int(sample_rate)) wf.writeframes(pcm16.tobytes()) return output_path def _load_audio_numpy(path: str | Path, target_sample_rate: int, target_channels: int) -> np.ndarray: path = Path(path).expanduser().resolve() waveform: np.ndarray | None = None src_sr: int = 0 try: import soundfile as sf data, src_sr = sf.read(str(path), dtype="float32", always_2d=True) # (samples, channels) waveform = data.T # (channels, samples) except Exception: pass if waveform is None: with wave.open(str(path), "rb") as wf: src_sr = wf.getframerate() n_ch = wf.getnchannels() sw = wf.getsampwidth() raw = wf.readframes(wf.getnframes()) if sw == 2: pcm = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0 elif sw == 4: pcm = np.frombuffer(raw, dtype=np.int32).astype(np.float32) / 2147483648.0 elif sw == 1: pcm = (np.frombuffer(raw, dtype=np.uint8).astype(np.float32) - 128.0) / 128.0 else: raise ValueError(f"Unsupported sample width: {sw}") waveform = pcm.reshape(-1, n_ch).T # (channels, samples) assert waveform is not None src_channels = int(waveform.shape[0]) if src_channels != target_channels: if src_channels == 1 and target_channels > 1: waveform = np.repeat(waveform, target_channels, axis=0) elif src_channels > 1 and target_channels == 1: waveform = waveform.mean(axis=0, keepdims=True) else: raise ValueError(f"Unsupported channel conversion: {src_channels} → {target_channels}") if src_sr != target_sample_rate: try: from scipy.signal import resample_poly from math import gcd g = gcd(int(target_sample_rate), int(src_sr)) up = int(target_sample_rate) // g down = int(src_sr) // g resampled = [] for ch in range(waveform.shape[0]): resampled.append(resample_poly(waveform[ch], up, down).astype(np.float32)) waveform = np.stack(resampled, axis=0) except ImportError: logging.warning("scipy not available, using linear interpolation for audio resampling") old_len = int(waveform.shape[1]) new_len = int(round(old_len * target_sample_rate / src_sr)) old_idx = np.linspace(0, old_len - 1, new_len) resampled = [] for ch in range(waveform.shape[0]): resampled.append(np.interp(old_idx, np.arange(old_len), waveform[ch]).astype(np.float32)) waveform = np.stack(resampled, axis=0) return waveform[np.newaxis, :, :] # (1, channels, samples) def _contains_cjk(text: str) -> bool: for ch in str(text or ""): if ( "\u4e00" <= ch <= "\u9fff" or "\u3400" <= ch <= "\u4dbf" or "\u3040" <= ch <= "\u30ff" or "\uac00" <= ch <= "\ud7af" ): return True return False def _normalize_punctuation(text: str) -> str: _WORD_INTERNAL = frozenset(["'", "\u2019", "-", "\u2013"]) # don't touch these result: list[str] = [] prev_is_punct: bool = False for ch in text: is_sent_end = ch in SENTENCE_END_PUNCTUATION is_punct = is_sent_end or ( ch not in _WORD_INTERNAL and unicodedata.category(ch).startswith("P") ) if is_punct: if prev_is_punct: if is_sent_end: result[-1] = ch else: result.append(ch if is_sent_end else ",") prev_is_punct = True else: result.append(ch) prev_is_punct = False return "".join(result) def _prepare_text_for_sentence_chunking(text: str) -> str: t = str(text or "").strip() if not t: raise ValueError("Text prompt cannot be empty.") t = t.replace("\r", " ").replace("\n", " ") while " " in t: t = t.replace(" ", " ") t = _normalize_punctuation(t) while ", ," in t: t = t.replace(", ,", ",") if _contains_cjk(t): if t[-1] not in SENTENCE_END_PUNCTUATION: t += "。" return t if t[:1].islower(): t = t[:1].upper() + t[1:] if t[-1].isalnum(): t += "." if len([x for x in t.split() if x]) < 5: t = f" {t}" return t def _split_text_by_punctuation(text: str, punctuation: set[str]) -> list[str]: sentences: list[str] = [] current: list[str] = [] idx = 0 t = str(text or "") while idx < len(t): ch = t[idx] current.append(ch) if ch in punctuation: lookahead = idx + 1 while lookahead < len(t) and t[lookahead] in CLOSING_PUNCTUATION: current.append(t[lookahead]) lookahead += 1 sentence = "".join(current).strip() if sentence: sentences.append(sentence) current.clear() while lookahead < len(t) and t[lookahead].isspace(): lookahead += 1 idx = lookahead continue idx += 1 tail = "".join(current).strip() if tail: sentences.append(tail) return sentences def _join_sentence_parts(left: str, right: str) -> str: if not left: return right if not right: return left if _contains_cjk(left) or _contains_cjk(right): return left + right return f"{left} {right}" class AxTtsRuntime: def __init__( self, config_dir: str | Path, axmodel_dir: str | Path, *, onnx_dir: str | Path | None = None, use_onnx_decode: bool = False, max_new_frames: int | None = None, do_sample: bool = True, sample_mode: str | None = None, ) -> None: self.config_dir = Path(config_dir).expanduser().resolve() self.axmodel_dir = Path(axmodel_dir).expanduser().resolve() self.use_onnx_decode = bool(use_onnx_decode) if onnx_dir is not None: self.onnx_dir = Path(onnx_dir).expanduser().resolve() else: self.onnx_dir = self.axmodel_dir.parent / "onnxmodels" self.manifest_path = self._find_manifest() self.manifest_dir = self.manifest_path.parent self.manifest: dict[str, Any] = json.loads( self.manifest_path.read_text(encoding="utf-8") ) gen = self.manifest["generation_defaults"] if max_new_frames is not None: gen["max_new_frames"] = int(max_new_frames) if do_sample is not None: gen["do_sample"] = bool(do_sample) raw_mode = sample_mode if sample_mode is not None else gen.get("sample_mode", "fixed") gen["sample_mode"] = self._normalize_sample_mode(raw_mode, bool(gen["do_sample"])) gen["do_sample"] = gen["sample_mode"] != SAMPLE_MODE_GREEDY self.tts_meta: dict[str, Any] = json.loads( self._resolve_path(self.manifest["model_files"]["tts_meta"]).read_text("utf-8") ) self.codec_meta: dict[str, Any] = json.loads( self._resolve_path(self.manifest["model_files"]["codec_meta"]).read_text("utf-8") ) self.tts_static_shapes = dict(self.tts_meta.get("static_shapes", {})) self.codec_static_shapes = dict(self.codec_meta.get("static_shapes", {})) tok_rel = str(self.manifest["model_files"].get("tokenizer_model", "tokenizer.model")) self.sp_model = spm.SentencePieceProcessor(model_file=str(self._resolve_path(tok_rel))) self.rng = np.random.default_rng(1234) self._model_display_names: dict[str, str] = {} self.sessions: dict = self._create_sessions() self._model_time_stats: dict[str, float] = defaultdict(float) self._model_call_stats: dict[str, int] = defaultdict(int) logging.info("AxTtsRuntime ready: %d sessions loaded", len(self.sessions)) def _find_manifest(self) -> Path: for name in _MANIFEST_CANDIDATES: p = (self.config_dir / name).resolve() if p.is_file(): return p raise FileNotFoundError( f"browser_poc_manifest.json not found in {self.config_dir}" ) def _resolve_path(self, rel: str | Path) -> Path: p = (self.manifest_dir / Path(rel)).resolve() if p.exists(): return p fallback = (self.config_dir / Path(rel).name).resolve() if fallback.exists(): return fallback return p @staticmethod def _normalize_sample_mode(raw: str | None, do_sample: bool = True) -> str: s = str(raw or "").strip() if s in {SAMPLE_MODE_GREEDY, SAMPLE_MODE_FIXED, SAMPLE_MODE_FULL}: return s if s == "mixed3": return SAMPLE_MODE_FIXED if do_sample else SAMPLE_MODE_GREEDY return SAMPLE_MODE_GREEDY if not do_sample else SAMPLE_MODE_FIXED def _resolve_axmodel_path(self, key: str, axmodel_name: str) -> Path: flat = self.axmodel_dir / axmodel_name if key != "local_fixed_sampled_frame": return flat if flat.exists(): return flat nested = self.axmodel_dir / "build-tts_local_fixed_sampled_frame" / axmodel_name return nested def _create_sessions(self) -> dict: sessions: dict[str, Any] = {} self._session_input_names: dict[str, list[str]] = {} if self.use_onnx_decode: onnx_name = _ONNX_FILE_MAP["decode"] onnx_path = self.onnx_dir / onnx_name if not onnx_path.exists(): raise FileNotFoundError(f"ONNX not found: {onnx_path}") if not _HAVE_ORT: raise RuntimeError("onnxruntime not installed; cannot load ONNX decode_step.") sess = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"]) sessions["decode"] = sess self._model_display_names["decode"] = f"{onnx_path.name} [onnxruntime CPU]" self._session_input_names["decode"] = [inp.name for inp in sess.get_inputs()] logging.info("loaded decode → %s (onnxruntime CPU)", onnx_path.name) else: axmodel_name = _AXMODEL_FILE_MAP["decode"] axmodel_path = self.axmodel_dir / axmodel_name if not axmodel_path.exists(): raise FileNotFoundError(f"axmodel not found: {axmodel_path}") input_names, output_names = _FULL8_IO_MAP["decode"] sess = AxeSession(axmodel_path, input_names, output_names) sessions["decode"] = sess self._model_display_names["decode"] = f"{axmodel_path.name} [axmodel]" self._session_input_names["decode"] = [inp.name for inp in sess.get_inputs()] logging.info("loaded decode → %s (axmodel)", axmodel_path.name) for key in ("prefill", "local_decoder", "local_fixed_sampled_frame", "codec_decode", "codec_encode"): if key not in _AXMODEL_FILE_MAP: continue axmodel_name = _AXMODEL_FILE_MAP[key] axmodel_path = self._resolve_axmodel_path(key, axmodel_name) if not axmodel_path.exists(): logging.warning("axmodel not found, skipping session %s: %s", key, axmodel_path) continue input_names, output_names = _FULL8_IO_MAP[key] sess = AxeSession(axmodel_path, input_names, output_names) sessions[key] = sess self._model_display_names[key] = f"{axmodel_path.name} [axmodel]" self._session_input_names[key] = [inp.name for inp in sess.get_inputs()] logging.info("loaded %s → %s (axmodel)", key, axmodel_path.name) if "codec_encode" not in sessions: logging.info("codec_encode axmodel not found; voice clone (reference audio) is disabled.") return sessions def _reset_timing_stats(self) -> None: self._model_time_stats = defaultdict(float) self._model_call_stats = defaultdict(int) self._used_model_keys: list[str] = [] self._used_model_key_set: set[str] = set() def _mark_model_used(self, key: str) -> None: if key in self._used_model_key_set: return self._used_model_key_set.add(key) self._used_model_keys.append(key) logging.info( "[usage] first_use session=%s model=%s", key, self._model_display_names.get(key, key), ) def _log_used_model_summary(self) -> None: if not getattr(self, "_used_model_keys", None): logging.info("[usage] used_models=none") return logging.info("[usage] used_models_count=%d", len(self._used_model_keys)) for idx, key in enumerate(self._used_model_keys, start=1): logging.info( "[usage] used_model[%d]=session=%s model=%s calls=%d", idx, key, self._model_display_names.get(key, key), self._model_call_stats.get(key, 0), ) def _run_session( self, key: str, input_feed: dict[str, np.ndarray], output_names: list[str] | None = None, ) -> list[np.ndarray]: sess = self.sessions[key] actual_input_names = [inp.name for inp in sess.get_inputs()] actual_input_name_set = set(actual_input_names) filtered_input_feed = { name: value for name, value in input_feed.items() if name in actual_input_name_set } missing_inputs = [name for name in actual_input_names if name not in filtered_input_feed] if missing_inputs: raise RuntimeError( f"session={key} missing required inputs: {missing_inputs}; " f"provided={sorted(input_feed.keys())}" ) extra_inputs = sorted(name for name in input_feed if name not in actual_input_name_set) self._mark_model_used(key) start = time.perf_counter() try: outputs = sess.run(output_names, filtered_input_feed) except Exception as exc: input_summary = { name: {"shape": list(np.asarray(value).shape), "dtype": str(np.asarray(value).dtype)} for name, value in filtered_input_feed.items() } raise RuntimeError( f"session={key} model={self._model_display_names.get(key, key)} run failed; " f"used_models_so_far={self._used_model_keys}; " f"expected_inputs={actual_input_names}; extra_inputs={extra_inputs}; " f"input_summary={input_summary}" ) from exc elapsed = time.perf_counter() - start self._model_time_stats[key] += elapsed self._model_call_stats[key] += 1 return outputs def encode_text(self, text: str) -> list[int]: return [int(t) for t in self.sp_model.encode(str(text or ""), out_type=int)] def count_text_tokens(self, text: str) -> int: return len(self.encode_text(text)) def split_text_by_token_budget(self, text: str, max_tokens: int) -> list[str]: remaining = str(text or "").strip() if not remaining: return [] pieces: list[str] = [] preferred_boundary = set(CLAUSE_SPLIT_PUNCTUATION) | set(SENTENCE_END_PUNCTUATION) | {" "} while remaining: if self.count_text_tokens(remaining) <= max_tokens: pieces.append(remaining) break lo, hi, best = 1, len(remaining), 1 while lo <= hi: mid = (lo + hi) // 2 cand = remaining[:mid].strip() if not cand: lo = mid + 1 continue if self.count_text_tokens(cand) <= max_tokens: best = mid lo = mid + 1 else: hi = mid - 1 cut = best prefix = remaining[:best] preferred = -1 scan_min = max(-1, len(prefix) - 25) for si in range(len(prefix) - 1, scan_min, -1): if prefix[si] in preferred_boundary: preferred = si + 1 break if preferred > 0: cut = preferred piece = remaining[:cut].strip() if not piece: piece = remaining[:best].strip() cut = best pieces.append(piece) remaining = remaining[cut:].strip() return pieces def split_voice_clone_text(self, text: str, max_tokens: int = 75) -> list[str]: t = str(text or "").strip() if not t: return [] safe_max = max(1, int(max_tokens)) prepared = _prepare_text_for_sentence_chunking(t) sents = _split_text_by_punctuation(prepared, SENTENCE_END_PUNCTUATION) or [prepared.strip()] slices: list[tuple[int, str]] = [] for sent in sents: s = sent.strip() if not s: continue cnt = self.count_text_tokens(s) if cnt <= safe_max: slices.append((cnt, s)) continue clauses = _split_text_by_punctuation(s, CLAUSE_SPLIT_PUNCTUATION) if len(clauses) <= 1: clauses = [s] for clause in clauses: c = clause.strip() if not c: continue ccnt = self.count_text_tokens(c) if ccnt <= safe_max: slices.append((ccnt, c)) continue for piece in self.split_text_by_token_budget(c, safe_max): p = piece.strip() if p: slices.append((self.count_text_tokens(p), p)) chunks: list[str] = [] cur_chunk, cur_cnt = "", 0 for cnt, s in slices: if not cur_chunk: cur_chunk, cur_cnt = s, cnt continue if cur_cnt + cnt > safe_max: chunks.append(cur_chunk.strip()) cur_chunk, cur_cnt = s, cnt else: cur_chunk = _join_sentence_parts(cur_chunk, s) cur_cnt = self.count_text_tokens(cur_chunk) if cur_chunk: chunks.append(cur_chunk.strip()) return chunks or [t] def _split_text_by_sentence_punctuation(self, text: str) -> list[str]: t = str(text or "").strip() if not t: return [] codec_limit = self._static_decode_code_length() if codec_limit is not None: _safe_text_tokens = max(1, int(codec_limit / 3 * 0.9)) _min_tokens = max(1, int(_safe_text_tokens * 0.25)) else: _safe_text_tokens = 80 _min_tokens = 20 prepared = _prepare_text_for_sentence_chunking(t) sents = _split_text_by_punctuation(prepared, SENTENCE_END_PUNCTUATION) if not sents: sents = [prepared.strip()] result: list[str] = [] for sent in sents: s = sent.strip() if not s: continue cnt = self.count_text_tokens(s) if cnt <= _safe_text_tokens: result.append(s) continue clauses = _split_text_by_punctuation(s, CLAUSE_SPLIT_PUNCTUATION) if len(clauses) <= 1: clauses = [s] merged: list[str] = [] buf = "" buf_tokens = 0 for clause in clauses: c = clause.strip() if not c: continue ccnt = self.count_text_tokens(c) if ccnt >= _safe_text_tokens: if buf: merged.append(buf.strip()) buf = "" buf_tokens = 0 merged.extend( p.strip() for p in self.split_text_by_token_budget(c, _safe_text_tokens) if p.strip() ) elif buf_tokens > 0 and buf_tokens + ccnt > _safe_text_tokens: merged.append(buf.strip()) buf = c buf_tokens = ccnt else: buf = (buf + " " + c) if buf else c buf_tokens += ccnt if buf: merged.append(buf.strip()) for m in merged: mcnt = self.count_text_tokens(m) if mcnt <= _safe_text_tokens: result.append(m) else: result.extend( p.strip() for p in self.split_text_by_token_budget(m, _safe_text_tokens) if p.strip() ) return result or [t] def build_text_rows(self, token_ids: list[int]) -> list[list[int]]: row_width = int(self.manifest["tts_config"]["n_vq"]) + 1 rows: list[list[int]] = [] for tid in token_ids: row = [int(self.manifest["tts_config"]["audio_pad_token_id"])] * row_width row[0] = int(tid) rows.append(row) return rows def build_audio_prefix_rows( self, prompt_audio_codes: list[list[int]], slot_token_id: int | None = None ) -> list[list[int]]: row_width = int(self.manifest["tts_config"]["n_vq"]) + 1 slot = int( self.manifest["tts_config"]["audio_user_slot_token_id"] if slot_token_id is None else slot_token_id ) rows: list[list[int]] = [] for code_row in prompt_audio_codes: row = [int(self.manifest["tts_config"]["audio_pad_token_id"])] * row_width row[0] = slot for i in range(min(len(code_row), int(self.manifest["tts_config"]["n_vq"]))): row[i + 1] = int(code_row[i]) rows.append(row) return rows def build_voice_clone_request_rows( self, prompt_audio_codes: list[list[int]], text_token_ids: list[int] ) -> dict[str, list[list[int]]]: prefix_ids = [ *self.manifest["prompt_templates"]["user_prompt_prefix_token_ids"], int(self.manifest["tts_config"]["audio_start_token_id"]), ] suffix_ids = [ int(self.manifest["tts_config"]["audio_end_token_id"]), *self.manifest["prompt_templates"]["user_prompt_after_reference_token_ids"], *text_token_ids, *self.manifest["prompt_templates"]["assistant_prompt_prefix_token_ids"], int(self.manifest["tts_config"]["audio_start_token_id"]), ] rows = [ *self.build_text_rows(prefix_ids), *self.build_audio_prefix_rows(prompt_audio_codes), *self.build_text_rows(suffix_ids), ] return {"inputIds": rows, "attentionMask": [[1 for _ in rows]]} def _static_prefill_length(self) -> int | None: v = self.tts_static_shapes.get("prefill_seq") return None if v is None else int(v) def _static_decode_code_length(self) -> int | None: v = self.codec_static_shapes.get("decode_code_length") return None if v is None else int(v) def _static_encode_waveform_length(self) -> int | None: v = self.codec_static_shapes.get("waveform_length") return None if v is None else int(v) def _pad_request_rows( self, request_rows: dict[str, list[list[int]]] ) -> tuple[dict[str, list[list[int]]], int]: static_len = self._static_prefill_length() actual_len = len(request_rows["inputIds"]) if static_len is None: return request_rows, actual_len if actual_len > static_len: raise ValueError( f"static prefill_seq={static_len} < request length={actual_len}. " "Reduce prompt/text length or re-export with larger --sample-seq-len." ) if actual_len == static_len: return request_rows, actual_len row_width = int(self.manifest["tts_config"]["n_vq"]) + 1 pad_row = [int(self.manifest["tts_config"]["audio_pad_token_id"])] * row_width pad_row[0] = int(self.manifest["tts_config"]["pad_token_id"]) padded_rows = [list(r) for r in request_rows["inputIds"]] padded_rows.extend([list(pad_row) for _ in range(static_len - actual_len)]) padded_mask = [1] * actual_len + [0] * (static_len - actual_len) return {"inputIds": padded_rows, "attentionMask": [padded_mask]}, actual_len def _pad_kv( self, kv_dict: dict[str, np.ndarray], target_len: int ) -> dict[str, np.ndarray]: result: dict[str, np.ndarray] = {} for key, val in kv_dict.items(): cur = val.shape[1] if cur > target_len: padded = val[:, :target_len, :, :].astype(np.float32) elif cur < target_len: pad = np.zeros((1, target_len - cur, *val.shape[2:]), dtype=np.float32) padded = np.concatenate([val.astype(np.float32), pad], axis=1) else: padded = val.astype(np.float32) result[key] = padded return result def run_local_fixed_sampled_frame( self, global_hidden: np.ndarray, *, previous_token_sets_by_channel: list[set[int]], ) -> tuple[bool, list[int]]: n_vq = int(self.manifest["tts_config"]["n_vq"]) codebook_size = int(self.tts_meta["model_config"]["audio_codebook_sizes"][0]) rep_mask = np.zeros((1, n_vq, codebook_size), dtype=np.int32) for ci, token_set in enumerate(previous_token_sets_by_channel): for tid in token_set: if 0 <= tid < codebook_size: rep_mask[0, ci, tid] = 1 sess = self.sessions["local_fixed_sampled_frame"] session_input_names = {inp.name for inp in sess.get_inputs()} input_feed: dict = { "global_hidden": global_hidden.astype(np.float32, copy=False), "repetition_seen_mask": rep_mask, } asst_u: float | None = None if "assistant_random_u" in session_input_names: asst_u = float(self.rng.random()) input_feed["assistant_random_u"] = np.array([asst_u], dtype=np.float32) if "audio_random_u" in session_input_names: input_feed["audio_random_u"] = np.asarray( [[float(self.rng.random()) for _ in range(n_vq)]], dtype=np.float32, ) outputs = self._run_session("local_fixed_sampled_frame", input_feed) out_names = [o.name for o in sess.get_outputs()] named = dict(zip(out_names, outputs, strict=True)) frame_ids = np.asarray(named["frame_token_ids"]).reshape(-1).astype(np.int32).tolist() should_continue = bool(int(np.asarray(named["should_continue"]).reshape(-1)[0])) logging.debug( "[fixed_frame] asst_u=%.4f should_continue=%s tokens=%s", asst_u if asst_u is not None else -1.0, should_continue, frame_ids[:4], ) return should_continue, [int(x) for x in frame_ids] def _can_use_local_fixed_sampled_frame(self) -> bool: sess = self.sessions.get("local_fixed_sampled_frame") if sess is None: return False session_input_names = {inp.name for inp in sess.get_inputs()} return "audio_random_u" in session_input_names def run_local_decoder( self, global_hidden: np.ndarray, text_token_id: int, frame_prefix: list[int], ) -> tuple[np.ndarray, np.ndarray]: n_vq = int(self.manifest["tts_config"]["n_vq"]) audio_pad = int(self.manifest["tts_config"]["audio_pad_token_id"]) padded_prefix = np.full((1, n_vq - 1), audio_pad, dtype=np.int32) for i in range(min(len(frame_prefix), n_vq - 1)): padded_prefix[0, i] = int(frame_prefix[i]) outputs = self._run_session( "local_decoder", { "global_hidden": global_hidden.astype(np.float32, copy=False), "text_token_id": np.asarray([int(text_token_id)], dtype=np.int32), "audio_prefix_token_ids": padded_prefix, }, ) out_names = [o.name for o in self.sessions["local_decoder"].get_outputs()] named = dict(zip(out_names, outputs, strict=True)) return named["text_logits"].reshape(-1), named["audio_logits"] def slice_audio_channel_logits(self, audio_logits: np.ndarray, channel_index: int) -> np.ndarray: per_ch = int(audio_logits.shape[-1]) flat = audio_logits.reshape(-1) return flat[channel_index * per_ch : (channel_index + 1) * per_ch] def _infer_global_kv_len(self) -> int | None: if "decode" in self.sessions: dec_sess = self.sessions["decode"] if hasattr(dec_sess, "get_inputs"): for inp in dec_sess.get_inputs(): if inp.name == "past_key_0": dims = getattr(inp, "shape", []) or [] if len(dims) >= 2 and isinstance(dims[1], str): return None break v = self.tts_static_shapes.get("kv_seq") or self.tts_static_shapes.get("global_kv_len") if v is not None: return int(v) return 512 def generate_audio_frames( self, request_rows: dict[str, list[list[int]]], on_frame: Callable | None = None, ) -> list[list[int]]: padded_rows, actual_prefill_len = self._pad_request_rows(request_rows) return self._generate_frames_with_prefill( padded_rows, actual_prefill_length=actual_prefill_len, on_frame=on_frame, ) def _generate_frames_with_prefill( self, request_rows: dict[str, list[list[int]]], *, actual_prefill_length: int, on_frame: Callable | None = None, ) -> list[list[int]]: global_kv_len = self._infer_global_kv_len() generation_defaults = self.manifest["generation_defaults"] row_width = int(self.manifest["tts_config"]["n_vq"]) + 1 prefill_ids, prefill_dims = _flatten3d_int32([request_rows["inputIds"]]) prefill_mask, prefill_mask_dims = _flatten2d_int32(request_rows["attentionMask"]) prefill_outputs = self._run_session( "prefill", { "input_ids": prefill_ids.reshape(prefill_dims), "attention_mask": prefill_mask.reshape(prefill_mask_dims), }, ) prefill_out_names = [o.name for o in self.sessions["prefill"].get_outputs()] named_prefill = dict(zip(prefill_out_names, prefill_outputs, strict=True)) global_hidden = named_prefill["global_hidden"][:, actual_prefill_length - 1, :].astype(np.float32) past_valid_length = int(actual_prefill_length) raw_past: dict[str, np.ndarray] = { out.replace("present_", "past_"): named_prefill[out][:, :actual_prefill_length, :, :] for out in self.tts_meta["onnx"]["prefill_output_names"][1:] } use_dynamic_kv = global_kv_len is None if use_dynamic_kv: past_by_name = raw_past # shape: (1, actual_len, 12, 64),步进增长 else: past_by_name = self._pad_kv(raw_past, global_kv_len) # shape: (1, 512, 12, 64) generated_frames: list[list[int]] = [] prev_tokens = [[] for _ in range(int(self.manifest["tts_config"]["n_vq"]))] prev_tok_sets = [set() for _ in range(int(self.manifest["tts_config"]["n_vq"]))] audio_pad_id = int(self.manifest["tts_config"]["audio_pad_token_id"]) consecutive_pad_frames = 0 _MAX_CONSECUTIVE_PAD_FRAMES = 10 # 0.8s silence → force stop (guard against quantized AXModel) for step in range(int(generation_defaults["max_new_frames"])): frame: list[int] = [] if ( generation_defaults["sample_mode"] == SAMPLE_MODE_FIXED and self._can_use_local_fixed_sampled_frame() ): should_continue, frame = self.run_local_fixed_sampled_frame( global_hidden, previous_token_sets_by_channel=prev_tok_sets ) if not should_continue: logging.info("[gen] frame %d: should_continue=False → 停止生成", step) break for ci, tok in enumerate(frame): prev_tokens[ci].append(tok) prev_tok_sets[ci].add(tok) elif "local_decoder" in self.sessions: if ( generation_defaults["sample_mode"] == SAMPLE_MODE_FIXED and "local_fixed_sampled_frame" in self.sessions and step == 0 ): logging.warning( "local_fixed_sampled_frame 缺少 audio_random_u 输入,回退到 local_decoder 随机采样路径" ) local_text_logits, _ = self.run_local_decoder(global_hidden, 0, []) next_text = _sample_assistant_text_token( local_text_logits, self.manifest, generation_defaults, self.rng ) if next_text != int(self.manifest["tts_config"]["audio_assistant_slot_token_id"]): break for ci in range(int(self.manifest["tts_config"]["n_vq"])): _, audio_logits = self.run_local_decoder(global_hidden, next_text, frame) ch_logits = self.slice_audio_channel_logits(audio_logits, ci).astype(np.float32) tok = _sample_audio_token( ch_logits, prev_tokens[ci], prev_tok_sets[ci], generation_defaults, self.rng ) frame.append(tok) prev_tokens[ci].append(tok) prev_tok_sets[ci].add(tok) else: raise RuntimeError("Neither local_fixed_sampled_frame nor local_decoder session found") if not frame: break if all(tok == audio_pad_id for tok in frame): consecutive_pad_frames += 1 if consecutive_pad_frames >= _MAX_CONSECUTIVE_PAD_FRAMES: logging.info( "[gen] %d consecutive all-pad frames → force stop (trim %d silent frames)", consecutive_pad_frames, consecutive_pad_frames, ) generated_frames = generated_frames[:-(consecutive_pad_frames - 1)] if len(generated_frames) >= consecutive_pad_frames else [] break else: consecutive_pad_frames = 0 generated_frames.append(frame) next_row = np.full((1, 1, row_width), int(self.manifest["tts_config"]["audio_pad_token_id"]), dtype=np.int32) next_row[0, 0, 0] = int(self.manifest["tts_config"]["audio_assistant_slot_token_id"]) for idx, tok in enumerate(frame): next_row[0, 0, idx + 1] = int(tok) decode_feeds: dict[str, np.ndarray] = { "input_ids": next_row, "past_valid_lengths": np.asarray([past_valid_length], dtype=np.int32), } for inp_name in self.tts_meta["onnx"]["decode_input_names"][2:]: decode_feeds[inp_name] = past_by_name[inp_name] dec_outputs = self._run_session("decode", decode_feeds) dec_out_names = [o.name for o in self.sessions["decode"].get_outputs()] named_dec = dict(zip(dec_out_names, dec_outputs, strict=True)) global_hidden = _extract_last_hidden(named_dec["global_hidden"]) past_valid_length += 1 if use_dynamic_kv: past_by_name = { out_name.replace("present_", "past_"): named_dec[out_name] for out_name in self.tts_meta["onnx"]["decode_output_names"][1:] } else: new_pos = past_valid_length - 1 # 新 token 的逻辑位置(0-indexed) new_padded: dict[str, np.ndarray] = {} for out_name in self.tts_meta["onnx"]["decode_output_names"][1:]: past_name = out_name.replace("present_", "past_") v = named_dec[out_name] buf = np.zeros((1, global_kv_len, 12, 64), dtype=np.float32) if new_pos < global_kv_len: buf[:, :new_pos, :, :] = v[:, :new_pos, :, :] buf[:, new_pos, :, :] = v[:, global_kv_len, :, :] else: buf[:, :global_kv_len - 1, :, :] = v[:, 1:global_kv_len, :, :] buf[:, global_kv_len - 1, :, :] = v[:, global_kv_len, :, :] new_padded[past_name] = buf past_by_name = new_padded if past_valid_length > global_kv_len: past_valid_length = global_kv_len if on_frame is not None: on_frame(generated_frames, step, frame) logging.info("[gen] 生成完毕: 共 %d 帧 (max_new_frames=%s)", len(generated_frames), generation_defaults["max_new_frames"]) return generated_frames def decode_full_audio(self, generated_frames: list[list[int]]) -> tuple[list[np.ndarray], int]: if not generated_frames: return [], 0 static_len = self._static_decode_code_length() n_vq = int(self.manifest["tts_config"]["n_vq"]) frame_count = len(generated_frames) if static_len is None or frame_count <= static_len: return self._decode_one_segment(generated_frames, frame_count, n_vq, static_len) all_audio: list[np.ndarray] = [] total_samples = 0 for seg_start in range(0, frame_count, static_len): seg_end = min(seg_start + static_len, frame_count) seg_frames = generated_frames[seg_start:seg_end] seg_count = seg_end - seg_start channel_arrays, seg_samples = self._decode_one_segment(seg_frames, seg_count, n_vq, static_len) all_audio.extend(channel_arrays) total_samples += seg_samples return all_audio, total_samples def _decode_one_segment( self, frames: list[list[int]], frame_count: int, n_vq: int, static_len: int | None, ) -> tuple[list[np.ndarray], int]: max_len = static_len if static_len is not None else frame_count audio_codes = np.zeros((1, max_len, n_vq), dtype=np.int32) for fi, frame in enumerate(frames): for ci in range(n_vq): audio_codes[0, fi, ci] = int(frame[ci] if ci < len(frame) else 0) outputs = self._run_session( "codec_decode", { "audio_codes": audio_codes, "audio_code_lengths": np.asarray([frame_count], dtype=np.int32), }, ) out_names = [o.name for o in self.sessions["codec_decode"].get_outputs()] named = dict(zip(out_names, outputs, strict=True)) audio_length = int(named["audio_lengths"].reshape(-1)[0]) return _slice_channel_major_audio(named["audio"], 0, audio_length), audio_length def encode_reference_audio(self, reference_audio_path: str | Path) -> list[list[int]]: codec_cfg = self.codec_meta["codec_config"] target_sr = int(codec_cfg["sample_rate"]) target_ch = int(codec_cfg["channels"]) waveform = _load_audio_numpy(reference_audio_path, target_sr, target_ch) waveform_length = int(waveform.shape[-1]) if bool(codec_cfg.get("pad_reference_audio_to_downsample_rate", False)): ds_rate = int(codec_cfg["downsample_rate"]) remainder = waveform_length % ds_rate if remainder: pad_len = ds_rate - remainder waveform = np.concatenate( [waveform, np.zeros((1, target_ch, pad_len), dtype=np.float32)], axis=2 ) waveform_length += pad_len static_wl = self._static_encode_waveform_length() num_quantizers = int(codec_cfg["num_quantizers"]) prompt_audio_codes: list[list[int]] = [] out_names = [o.name for o in self.sessions["codec_encode"].get_outputs()] if static_wl and static_wl > 0: for start in range(0, waveform_length, static_wl): chunk = waveform[..., start : start + static_wl] chunk_len = int(chunk.shape[-1]) if chunk_len <= 0: continue if chunk_len < static_wl: padded = np.zeros((1, target_ch, static_wl), dtype=np.float32) padded[..., :chunk_len] = chunk chunk = padded outputs = self._run_session( "codec_encode", { "waveform": chunk, "input_lengths": np.asarray([chunk_len], dtype=np.int32), }, ) named = dict(zip(out_names, outputs, strict=True)) codes = np.asarray(named["audio_codes"], dtype=np.int32) code_len = int(np.asarray(named["audio_code_lengths"]).reshape(-1)[0]) for fi in range(code_len): prompt_audio_codes.append([int(codes[0, fi, qi]) for qi in range(num_quantizers)]) else: outputs = self._run_session( "codec_encode", { "waveform": waveform, "input_lengths": np.asarray([waveform_length], dtype=np.int32), }, ) named = dict(zip(out_names, outputs, strict=True)) codes = np.asarray(named["audio_codes"], dtype=np.int32) code_len = int(np.asarray(named["audio_code_lengths"]).reshape(-1)[0]) for fi in range(code_len): prompt_audio_codes.append([int(codes[0, fi, qi]) for qi in range(num_quantizers)]) return prompt_audio_codes def resolve_prompt_audio_codes( self, *, voice: str | None, prompt_audio_path: str | Path | None ) -> list[list[int]]: if prompt_audio_path: return self.encode_reference_audio(prompt_audio_path) resolved_voice = str(voice or self.manifest["builtin_voices"][0]["voice"]) for v in self.manifest["builtin_voices"]: if v["voice"] == resolved_voice: return list(v["prompt_audio_codes"]) raise ValueError(f"Built-in voice not found: {resolved_voice}") def list_builtin_voices(self) -> list[dict[str, Any]]: return list(self.manifest["builtin_voices"]) def synthesize( self, *, text: str, voice: str | None = None, prompt_audio_path: str | Path | None = None, output_audio_path: str | Path | None = None, sample_mode: str | None = None, do_sample: bool = True, streaming: bool = True, max_new_frames: int | None = None, voice_clone_max_text_tokens: int = 75, seed: int | None = None, ) -> dict[str, Any]: self._reset_timing_stats() gen = self.manifest["generation_defaults"] if max_new_frames is not None: gen["max_new_frames"] = int(max_new_frames) static_codec_len = self._static_decode_code_length() if static_codec_len is not None and gen.get("max_new_frames", float("inf")) > static_codec_len: gen["max_new_frames"] = static_codec_len logging.info("capped max_new_frames=%d (static codec limit)", static_codec_len) normalized_mode = self._normalize_sample_mode(sample_mode, do_sample) gen["sample_mode"] = normalized_mode gen["do_sample"] = normalized_mode != SAMPLE_MODE_GREEDY if seed is not None: self.rng = np.random.default_rng(int(seed)) infer_start_time = time.perf_counter() try: text_chunks = self._split_text_by_sentence_punctuation(text) prompt_audio_codes = self.resolve_prompt_audio_codes( voice=voice, prompt_audio_path=prompt_audio_path ) sample_rate = int(self.codec_meta["codec_config"]["sample_rate"]) channels = int(self.codec_meta["codec_config"]["channels"]) static_decode_code_length = self._static_decode_code_length() out_path = ( Path(output_audio_path).expanduser().resolve() if output_audio_path else (Path("output.wav")).resolve() ) all_waveforms: list[np.ndarray] = [] all_generated_frames: list[list[int]] = [] final_text_chunks: list[str] = [] pending_text_chunks = list(text_chunks) wav_file = None if streaming: out_path.parent.mkdir(parents=True, exist_ok=True) wav_file = wave.open(str(out_path), "wb") wav_file.setnchannels(channels) wav_file.setsampwidth(2) wav_file.setframerate(sample_rate) try: while pending_text_chunks: chunk_idx = len(final_text_chunks) + 1 total_chunks = len(final_text_chunks) + len(pending_text_chunks) chunk_text = pending_text_chunks.pop(0) logging.info("chunk %d/%d: %r", chunk_idx, total_chunks, chunk_text[:40]) text_token_ids = self.encode_text(chunk_text) request_rows = self.build_voice_clone_request_rows(prompt_audio_codes, text_token_ids) generated_frames = self.generate_audio_frames(request_rows) channel_arrays, _ = self.decode_full_audio(generated_frames) waveform = np.asarray(_merge_audio_channels(channel_arrays), dtype=np.float32) all_waveforms.append(waveform) all_generated_frames.extend(generated_frames) final_text_chunks.append(chunk_text) if wav_file is not None and waveform.size > 0: clipped = np.clip(waveform, -1.0, 1.0) pcm16 = np.round(clipped * 32767.0).astype(np.int16) wav_file.writeframes(pcm16.tobytes()) logging.info( "stream emitted chunk %d: samples=%d duration=%.3fs", len(final_text_chunks), int(waveform.shape[0]), float(waveform.shape[0]) / float(sample_rate), ) if pending_text_chunks: pause_words = len([w for w in chunk_text.strip().split() if w]) pause_sec = 0.40 if pause_words <= 4 else 0.24 pause_samples = max(0, int(round(sample_rate * pause_sec))) if pause_samples > 0: pause_waveform = np.zeros((pause_samples, channels), dtype=np.float32) all_waveforms.append(pause_waveform) if wav_file is not None: wav_file.writeframes(np.zeros((pause_samples, channels), dtype=np.int16).tobytes()) finally: if wav_file is not None: wav_file.close() final_waveform = _concat_waveforms(all_waveforms) total_infer_time = time.perf_counter() - infer_start_time audio_duration_sec = ( float(final_waveform.shape[0]) / float(sample_rate) if sample_rate > 0 and final_waveform.size > 0 else 0.0 ) rtf = (total_infer_time / audio_duration_sec) if audio_duration_sec > 0 else float("inf") model_time_total = sum(self._model_time_stats.values()) audio_path = out_path if streaming else _write_waveform_to_wav(out_path, final_waveform, sample_rate) logging.info("已保存 %s sample_rate=%s frames=%s", audio_path, sample_rate, len(all_generated_frames)) logging.info( "[timing] total_infer_time=%.3fs audio_duration=%.3fs rtf=%.4f model_time_total=%.3fs", total_infer_time, audio_duration_sec, rtf, model_time_total, ) for key in sorted(self._model_display_names): model_name = self._model_display_names[key] logging.info( "[timing] session=%s model=%s calls=%d total=%.3fs", key, model_name, self._model_call_stats.get(key, 0), self._model_time_stats.get(key, 0.0), ) return { "audio_path": str(audio_path), "waveform": final_waveform, "sample_rate": sample_rate, "audio_token_ids": np.asarray(all_generated_frames, dtype=np.int32), "text_chunks": final_text_chunks, "sample_mode": normalized_mode, "do_sample": normalized_mode != SAMPLE_MODE_GREEDY, "streaming": bool(streaming), "timing": { "total_infer_time_sec": total_infer_time, "audio_duration_sec": audio_duration_sec, "rtf": rtf, "model_time_total_sec": model_time_total, "per_model_time_sec": dict(self._model_time_stats), "per_model_calls": dict(self._model_call_stats), "per_model_display_name": dict(self._model_display_names), "used_model_keys": list(self._used_model_keys), }, } finally: self._log_used_model_summary()