| from __future__ import annotations |
|
|
| import json |
| import logging |
| import time |
| import unicodedata |
| import wave |
| from collections import defaultdict |
| from pathlib import Path |
| from typing import Any, Callable, Optional, Sequence |
|
|
| import numpy as np |
| import sentencepiece as spm |
|
|
| from scripts.axe_session import AxeSession |
|
|
| try: |
| import onnxruntime as ort |
| _HAVE_ORT = True |
| except ImportError: |
| _HAVE_ORT = False |
|
|
# Sampling strategies accepted by AxTtsRuntime._normalize_sample_mode.
SAMPLE_MODE_GREEDY = "greedy"
SAMPLE_MODE_FIXED = "fixed"
SAMPLE_MODE_FULL = "full"


# Punctuation tables used by the text-chunking helpers.  The full-width (CJK)
# forms are written as escapes so that Unicode normalisation of this file
# cannot silently fold them into their ASCII look-alikes.
# ASCII . ! ? plus ideographic full stop and full-width ! ?
SENTENCE_END_PUNCTUATION = set(".!?\u3002\uff01\uff1f")
# ASCII , ; : plus ideographic comma and full-width , ; :
CLAUSE_SPLIT_PUNCTUATION = set(",;:\u3001\uff0c\uff1b\uff1a")
# Quotes and brackets that may directly follow a sentence/clause terminator
# and should stay attached to the text before the split.
CLOSING_PUNCTUATION = {
    '"', "'", "\u2019", "\u201d", ")", "]", "}",
    "\uff09", "\u3011", "\u300b", "\u300d", "\u300f",
}


# File names probed (in order) inside the config directory for the manifest.
_MANIFEST_CANDIDATES = ("browser_poc_manifest.json",)
|
|
| _FULL8_IO_MAP: dict[str, tuple[list[str], list[str]]] = { |
| "prefill": ( |
| ["input_ids", "attention_mask"], |
| ["global_hidden", |
| "present_key_0", "present_value_0", "present_key_1", "present_value_1", |
| "present_key_2", "present_value_2", "present_key_3", "present_value_3", |
| "present_key_4", "present_value_4", "present_key_5", "present_value_5", |
| "present_key_6", "present_value_6", "present_key_7", "present_value_7", |
| "present_key_8", "present_value_8", "present_key_9", "present_value_9", |
| "present_key_10", "present_value_10", "present_key_11", "present_value_11"], |
| ), |
| "decode": ( |
| ["input_ids", "past_valid_lengths", |
| "past_key_0", "past_value_0", "past_key_1", "past_value_1", |
| "past_key_2", "past_value_2", "past_key_3", "past_value_3", |
| "past_key_4", "past_value_4", "past_key_5", "past_value_5", |
| "past_key_6", "past_value_6", "past_key_7", "past_value_7", |
| "past_key_8", "past_value_8", "past_key_9", "past_value_9", |
| "past_key_10", "past_value_10", "past_key_11", "past_value_11"], |
| ["global_hidden", |
| "present_key_0", "present_value_0", "present_key_1", "present_value_1", |
| "present_key_2", "present_value_2", "present_key_3", "present_value_3", |
| "present_key_4", "present_value_4", "present_key_5", "present_value_5", |
| "present_key_6", "present_value_6", "present_key_7", "present_value_7", |
| "present_key_8", "present_value_8", "present_key_9", "present_value_9", |
| "present_key_10", "present_value_10", "present_key_11", "present_value_11"], |
| ), |
| "local_decoder": ( |
| ["global_hidden", "text_token_id", "audio_prefix_token_ids"], |
| ["text_logits", "audio_logits"], |
| ), |
| "local_fixed_sampled_frame": ( |
| ["global_hidden", "repetition_seen_mask", "assistant_random_u", "audio_random_u"], |
| ["should_continue", "frame_token_ids"], |
| ), |
| "codec_encode": ( |
| ["waveform", "input_lengths"], |
| ["audio_codes", "audio_code_lengths"], |
| ), |
| "codec_decode": ( |
| ["audio_codes", "audio_code_lengths"], |
| ["audio", "audio_lengths"], |
| ), |
| } |
|
|
| _AXMODEL_FILE_MAP: dict[str, str] = { |
| "prefill": "tts_prefill.axmodel", |
| "decode": "tts_decode_step.axmodel", |
| "local_decoder": "tts_local_decoder.axmodel", |
| "local_fixed_sampled_frame": "tts_local_fixed_sampled_frame.axmodel", |
| "codec_encode": "codec_encode.axmodel", |
| "codec_decode": "codec_decode.axmodel", |
| } |
|
|
| _ONNX_FILE_MAP: dict[str, str] = { |
| "decode": "tts_decode_step.onnx", |
| } |
|
|
|
|
| def _argmax(values: np.ndarray) -> int: |
| return int(np.argmax(values)) |
|
|
|
|
| def _softmax(values: np.ndarray) -> np.ndarray: |
| max_v = float(np.max(values)) |
| shifted = np.asarray(values - max_v, dtype=np.float64) |
| exps = np.exp(shifted) |
| return exps / np.sum(exps) |
|
|
|
|
| def _apply_repetition_penalty( |
| values: np.ndarray, previous_token_ids: list[int], penalty: float |
| ) -> np.ndarray: |
| if not previous_token_ids or penalty == 1.0: |
| return values |
| result = values.copy() |
| for tid in set(int(t) for t in previous_token_ids): |
| if 0 <= tid < result.shape[0]: |
| result[tid] = result[tid] * penalty if result[tid] < 0 else result[tid] / penalty |
| return result |
|
|
|
|
| def _argmax_with_repetition_penalty( |
| values: np.ndarray, previous_token_set: set[int], penalty: float |
| ) -> int: |
| best_idx, best_val = 0, float("-inf") |
| apply_penalty = bool(previous_token_set) and penalty != 1.0 |
| for idx, v in enumerate(values): |
| score = float(v) |
| if apply_penalty and idx in previous_token_set: |
| score = score * penalty if score < 0 else score / penalty |
| if score > best_val: |
| best_val = score |
| best_idx = idx |
| return int(best_idx) |
|
|
|
|
def _sample_from_scores(
    values: np.ndarray,
    *,
    do_sample: bool,
    temperature: float,
    top_k: int,
    top_p: float,
    rng: np.random.Generator,
) -> int:
    """Pick an index from ``values`` greedily or by filtered sampling.

    With ``do_sample`` false this is a plain argmax.  Otherwise the scores are
    divided by ``temperature``, optionally restricted to the ``top_k`` best
    entries and to the smallest prefix of ranked entries whose probability mass
    exceeds ``top_p``, and one index is drawn from the resulting distribution
    using a single ``rng.random()`` call.

    Raises:
        ValueError: if sampling is requested with a non-positive temperature.
    """
    if not do_sample:
        return _argmax(values)
    if temperature <= 0:
        raise ValueError("temperature must be positive when do_sample=True")

    scaled = np.asarray(values, dtype=np.float32).copy() / float(temperature)

    # Top-k: everything strictly below the k-th best score is masked out.
    if 0 < top_k < scaled.shape[0]:
        kth_best = float(np.sort(scaled)[::-1][top_k - 1])
        scaled[scaled < kth_best] = float("-inf")

    # Top-p (nucleus): rank by score, accumulate probability mass, and mask
    # every entry that comes after the one which crossed the threshold.
    if 0 < top_p < 1:
        ranked = sorted(enumerate(scaled.tolist()), key=lambda pair: pair[1], reverse=True)
        ranked_probs = _softmax(np.asarray([score for _, score in ranked], dtype=np.float32))
        crossed: list[bool] = []
        running_mass = 0.0
        for prob in ranked_probs:
            running_mass += float(prob)
            crossed.append(running_mass > float(top_p))
        # Pairing rank i with crossed[i - 1] keeps the entry that first crosses
        # the threshold; rank 0 is never masked.
        for (token_index, _), crossed_before in zip(ranked[1:], crossed):
            if crossed_before:
                scaled[token_index] = float("-inf")

    # Inverse-CDF draw over the filtered distribution.
    probs = _softmax(scaled)
    remaining = float(rng.random())
    for index, prob in enumerate(probs):
        remaining -= float(prob)
        if remaining <= 0:
            return int(index)
    # Rounding left a sliver of mass unassigned: fall back to the best score.
    return _argmax(scaled)
|
|
|
|
def _sample_assistant_text_token(
    text_logits: np.ndarray,
    manifest: dict[str, Any],
    generation_defaults: dict[str, Any],
    rng: np.random.Generator,
) -> int:
    """Choose between the assistant audio-slot token and the audio-end token.

    Only those two ids (read from ``manifest["tts_config"]``) are considered.
    The helper is invoked with ``do_sample=False``, so the higher-scoring of
    the two is returned; the text temperature / top-k / top-p settings are
    still read from ``generation_defaults`` and forwarded.
    """
    tts_config = manifest["tts_config"]
    candidate_ids = np.asarray(
        [
            int(tts_config["audio_assistant_slot_token_id"]),
            int(tts_config["audio_end_token_id"]),
        ],
        dtype=np.int32,
    )
    candidate_scores = text_logits[candidate_ids]
    choice = _sample_from_scores(
        candidate_scores,
        do_sample=False,
        temperature=float(generation_defaults["text_temperature"]),
        top_k=min(int(generation_defaults["text_top_k"]), int(candidate_scores.shape[0])),
        top_p=float(generation_defaults["text_top_p"]),
        rng=rng,
    )
    return int(candidate_ids[choice])
|
|
|
|
def _sample_audio_token(
    audio_logits: np.ndarray,
    previous_token_ids: list[int],
    previous_token_set: set[int],
    generation_defaults: dict[str, Any],
    rng: np.random.Generator,
) -> int:
    """Pick one audio token id from ``audio_logits`` with repetition penalty.

    When ``generation_defaults["do_sample"]`` is false the penalised argmax is
    returned; otherwise the penalised scores are sampled using the audio
    temperature / top-k / top-p settings.
    """
    repetition_penalty = float(generation_defaults["audio_repetition_penalty"])
    sampling_enabled = bool(generation_defaults["do_sample"])
    if not sampling_enabled:
        return _argmax_with_repetition_penalty(
            audio_logits, previous_token_set, repetition_penalty
        )

    penalised_logits = _apply_repetition_penalty(
        audio_logits, previous_token_ids, repetition_penalty
    )
    return _sample_from_scores(
        penalised_logits,
        do_sample=True,
        temperature=float(generation_defaults["audio_temperature"]),
        top_k=int(generation_defaults["audio_top_k"]),
        top_p=float(generation_defaults["audio_top_p"]),
        rng=rng,
    )
|
|
|
|
| def _flatten3d_int32(nested: list[list[list[int]]]) -> tuple[np.ndarray, list[int]]: |
| d0, d1, d2 = len(nested), len(nested[0]), len(nested[0][0]) |
| data = np.zeros((d0 * d1 * d2,), dtype=np.int32) |
| offset = 0 |
| for i in range(d0): |
| for j in range(d1): |
| for k in range(d2): |
| data[offset] = int(nested[i][j][k]) |
| offset += 1 |
| return data, [d0, d1, d2] |
|
|
|
|
| def _flatten2d_int32(nested: list[list[int]]) -> tuple[np.ndarray, list[int]]: |
| d0, d1 = len(nested), len(nested[0]) |
| data = np.zeros((d0 * d1,), dtype=np.int32) |
| offset = 0 |
| for i in range(d0): |
| for j in range(d1): |
| data[offset] = int(nested[i][j]) |
| offset += 1 |
| return data, [d0, d1] |
|
|
|
|
| def _extract_last_hidden(hidden: np.ndarray) -> np.ndarray: |
| if hidden.ndim == 2: |
| return hidden.astype(np.float32, copy=False) |
| if hidden.ndim != 3 or hidden.shape[0] != 1: |
| raise ValueError(f"Unexpected global_hidden shape: {hidden.shape}") |
| return hidden[:, -1, :].astype(np.float32, copy=False) |
|
|
|
|
| def _slice_channel_major_audio( |
| audio: np.ndarray, start_sample: int = 0, end_sample: int | None = None |
| ) -> list[np.ndarray]: |
| if audio.ndim != 3 or audio.shape[0] != 1: |
| raise ValueError(f"Unexpected audio tensor shape: {audio.shape}") |
| channels = int(audio.shape[1]) |
| total = int(audio.shape[2]) |
| start = max(0, int(start_sample)) |
| end = total if end_sample is None else max(start, min(int(end_sample), total)) |
| return [audio[0, c, start:end].astype(np.float32, copy=False) for c in range(channels)] |
|
|
|
|
| def _merge_audio_channels(channel_arrays: list[np.ndarray]) -> np.ndarray: |
| if not channel_arrays: |
| return np.zeros((0, 1), dtype=np.float32) |
| if len(channel_arrays) == 1: |
| return np.asarray(channel_arrays[0], dtype=np.float32).reshape(-1, 1) |
| min_len = min(int(c.shape[0]) for c in channel_arrays) |
| return np.stack([np.asarray(c[:min_len], dtype=np.float32) for c in channel_arrays], axis=1) |
|
|
|
|
| def _concat_waveforms(waveforms: list[np.ndarray]) -> np.ndarray: |
| if not waveforms: |
| return np.zeros((0, 1), dtype=np.float32) |
| non_empty = [w for w in waveforms if w.size > 0] |
| if not non_empty: |
| n_ch = int(waveforms[0].shape[1]) if waveforms[0].ndim == 2 else 1 |
| return np.zeros((0, n_ch), dtype=np.float32) |
| return np.concatenate(non_empty, axis=0) |
|
|
|
|
| def _write_waveform_to_wav(path: str | Path, waveform: np.ndarray, sample_rate: int) -> Path: |
| output_path = Path(path).expanduser().resolve() |
| output_path.parent.mkdir(parents=True, exist_ok=True) |
| audio = np.asarray(waveform, dtype=np.float32) |
| if audio.ndim == 1: |
| audio = audio.reshape(-1, 1) |
| clipped = np.clip(audio, -1.0, 1.0) |
| pcm16 = np.round(clipped * 32767.0).astype(np.int16) |
| with wave.open(str(output_path), "wb") as wf: |
| wf.setnchannels(int(pcm16.shape[1])) |
| wf.setsampwidth(2) |
| wf.setframerate(int(sample_rate)) |
| wf.writeframes(pcm16.tobytes()) |
| return output_path |
|
|
|
|
| def _load_audio_numpy(path: str | Path, target_sample_rate: int, target_channels: int) -> np.ndarray: |
| path = Path(path).expanduser().resolve() |
|
|
| waveform: np.ndarray | None = None |
| src_sr: int = 0 |
|
|
| try: |
| import soundfile as sf |
| data, src_sr = sf.read(str(path), dtype="float32", always_2d=True) |
| waveform = data.T |
| except Exception: |
| pass |
|
|
| if waveform is None: |
| with wave.open(str(path), "rb") as wf: |
| src_sr = wf.getframerate() |
| n_ch = wf.getnchannels() |
| sw = wf.getsampwidth() |
| raw = wf.readframes(wf.getnframes()) |
| if sw == 2: |
| pcm = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0 |
| elif sw == 4: |
| pcm = np.frombuffer(raw, dtype=np.int32).astype(np.float32) / 2147483648.0 |
| elif sw == 1: |
| pcm = (np.frombuffer(raw, dtype=np.uint8).astype(np.float32) - 128.0) / 128.0 |
| else: |
| raise ValueError(f"Unsupported sample width: {sw}") |
| waveform = pcm.reshape(-1, n_ch).T |
|
|
| assert waveform is not None |
|
|
| src_channels = int(waveform.shape[0]) |
| if src_channels != target_channels: |
| if src_channels == 1 and target_channels > 1: |
| waveform = np.repeat(waveform, target_channels, axis=0) |
| elif src_channels > 1 and target_channels == 1: |
| waveform = waveform.mean(axis=0, keepdims=True) |
| else: |
| raise ValueError(f"Unsupported channel conversion: {src_channels} → {target_channels}") |
|
|
| if src_sr != target_sample_rate: |
| try: |
| from scipy.signal import resample_poly |
| from math import gcd |
| g = gcd(int(target_sample_rate), int(src_sr)) |
| up = int(target_sample_rate) // g |
| down = int(src_sr) // g |
| resampled = [] |
| for ch in range(waveform.shape[0]): |
| resampled.append(resample_poly(waveform[ch], up, down).astype(np.float32)) |
| waveform = np.stack(resampled, axis=0) |
| except ImportError: |
| logging.warning("scipy not available, using linear interpolation for audio resampling") |
| old_len = int(waveform.shape[1]) |
| new_len = int(round(old_len * target_sample_rate / src_sr)) |
| old_idx = np.linspace(0, old_len - 1, new_len) |
| resampled = [] |
| for ch in range(waveform.shape[0]): |
| resampled.append(np.interp(old_idx, np.arange(old_len), waveform[ch]).astype(np.float32)) |
| waveform = np.stack(resampled, axis=0) |
|
|
| return waveform[np.newaxis, :, :] |
|
|
|
|
| def _contains_cjk(text: str) -> bool: |
| for ch in str(text or ""): |
| if ( |
| "\u4e00" <= ch <= "\u9fff" |
| or "\u3400" <= ch <= "\u4dbf" |
| or "\u3040" <= ch <= "\u30ff" |
| or "\uac00" <= ch <= "\ud7af" |
| ): |
| return True |
| return False |
|
|
|
|
def _normalize_punctuation(text: str) -> str:
    """Collapse every run of punctuation in ``text`` to a single mark.

    A run becomes its last sentence-ending character if it contains one,
    otherwise a single ASCII comma.  Apostrophes and hyphen/en-dash are treated
    as word-internal and left untouched.
    """
    word_internal = frozenset(["'", "\u2019", "-", "\u2013"])

    pieces: list[str] = []
    in_punct_run = False

    for ch in text:
        ends_sentence = ch in SENTENCE_END_PUNCTUATION
        is_plain_char = not ends_sentence and (
            ch in word_internal or not unicodedata.category(ch).startswith("P")
        )
        if is_plain_char:
            pieces.append(ch)
            in_punct_run = False
            continue

        if not in_punct_run:
            # First mark of a run: keep sentence enders, turn the rest into ",".
            pieces.append(ch if ends_sentence else ",")
            in_punct_run = True
        elif ends_sentence:
            # A later sentence ender overrides whatever the run produced so far.
            pieces[-1] = ch

    return "".join(pieces)
|
|
|
|
def _prepare_text_for_sentence_chunking(text: str) -> str:
    """Clean up ``text`` so it can be split into sentences.

    Line breaks become spaces, runs of spaces are collapsed, punctuation is
    normalised, and the text is given a terminator: an ideographic full stop
    for text containing CJK characters, otherwise a period when the text ends
    in an alphanumeric character.  Non-CJK text also gets its first letter
    upper-cased and, when shorter than five words, a leading space.

    Raises:
        ValueError: if ``text`` is empty or whitespace only.
    """
    t = str(text or "").strip()
    if not t:
        raise ValueError("Text prompt cannot be empty.")
    t = t.replace("\r", " ").replace("\n", " ")

    # Collapse runs of spaces.  The search pattern must be TWO spaces:
    # looking for a single space and replacing it with itself never changes
    # the string, so the loop would spin forever on any text with a space.
    double_space = " " * 2
    while double_space in t:
        t = t.replace(double_space, " ")

    t = _normalize_punctuation(t)
    while ", ," in t:
        t = t.replace(", ,", ",")

    if _contains_cjk(t):
        if t[-1] not in SENTENCE_END_PUNCTUATION:
            t += "。"
        return t

    if t[:1].islower():
        t = t[:1].upper() + t[1:]
    if t[-1].isalnum():
        t += "."
    if len([word for word in t.split() if word]) < 5:
        # NOTE(review): short prompts get a leading space of padding; confirm
        # the intended amount of padding against the reference implementation.
        t = f" {t}"
    return t
|
|
|
|
def _split_text_by_punctuation(text: str, punctuation: set[str]) -> list[str]:
    """Split ``text`` after every character found in ``punctuation``.

    Closing quotes/brackets that directly follow a split character stay with
    the preceding piece, and whitespace after the split point is dropped.
    Pieces are stripped; empty pieces are omitted.  Any trailing text without
    a terminator becomes the last piece.
    """
    source = str(text or "")
    length = len(source)
    pieces: list[str] = []
    buffer: list[str] = []
    pos = 0

    while pos < length:
        ch = source[pos]
        buffer.append(ch)
        pos += 1
        if ch not in punctuation:
            continue

        # Keep closers such as ")" or a closing quote with this piece.
        while pos < length and source[pos] in CLOSING_PUNCTUATION:
            buffer.append(source[pos])
            pos += 1

        piece = "".join(buffer).strip()
        if piece:
            pieces.append(piece)
        buffer.clear()

        # Skip the whitespace that separates this piece from the next.
        while pos < length and source[pos].isspace():
            pos += 1

    remainder = "".join(buffer).strip()
    if remainder:
        pieces.append(remainder)
    return pieces
|
|
|
|
def _join_sentence_parts(left: str, right: str) -> str:
    """Join two text fragments, using a space only when neither contains CJK.

    If either fragment is empty the other one is returned unchanged.
    """
    if not left or not right:
        return left or right
    separator = "" if (_contains_cjk(left) or _contains_cjk(right)) else " "
    return f"{left}{separator}{right}"
|
|
|
|
| class AxTtsRuntime: |
|
|
| def __init__( |
| self, |
| config_dir: str | Path, |
| axmodel_dir: str | Path, |
| *, |
| onnx_dir: str | Path | None = None, |
| use_onnx_decode: bool = False, |
| max_new_frames: int | None = None, |
| do_sample: bool = True, |
| sample_mode: str | None = None, |
| ) -> None: |
| self.config_dir = Path(config_dir).expanduser().resolve() |
| self.axmodel_dir = Path(axmodel_dir).expanduser().resolve() |
| self.use_onnx_decode = bool(use_onnx_decode) |
| if onnx_dir is not None: |
| self.onnx_dir = Path(onnx_dir).expanduser().resolve() |
| else: |
| self.onnx_dir = self.axmodel_dir.parent / "onnxmodels" |
|
|
| self.manifest_path = self._find_manifest() |
| self.manifest_dir = self.manifest_path.parent |
| self.manifest: dict[str, Any] = json.loads( |
| self.manifest_path.read_text(encoding="utf-8") |
| ) |
|
|
| gen = self.manifest["generation_defaults"] |
| if max_new_frames is not None: |
| gen["max_new_frames"] = int(max_new_frames) |
| if do_sample is not None: |
| gen["do_sample"] = bool(do_sample) |
|
|
| raw_mode = sample_mode if sample_mode is not None else gen.get("sample_mode", "fixed") |
| gen["sample_mode"] = self._normalize_sample_mode(raw_mode, bool(gen["do_sample"])) |
| gen["do_sample"] = gen["sample_mode"] != SAMPLE_MODE_GREEDY |
|
|
| self.tts_meta: dict[str, Any] = json.loads( |
| self._resolve_path(self.manifest["model_files"]["tts_meta"]).read_text("utf-8") |
| ) |
| self.codec_meta: dict[str, Any] = json.loads( |
| self._resolve_path(self.manifest["model_files"]["codec_meta"]).read_text("utf-8") |
| ) |
| self.tts_static_shapes = dict(self.tts_meta.get("static_shapes", {})) |
| self.codec_static_shapes = dict(self.codec_meta.get("static_shapes", {})) |
|
|
| tok_rel = str(self.manifest["model_files"].get("tokenizer_model", "tokenizer.model")) |
| self.sp_model = spm.SentencePieceProcessor(model_file=str(self._resolve_path(tok_rel))) |
|
|
| self.rng = np.random.default_rng(1234) |
|
|
| self._model_display_names: dict[str, str] = {} |
| self.sessions: dict = self._create_sessions() |
| self._model_time_stats: dict[str, float] = defaultdict(float) |
| self._model_call_stats: dict[str, int] = defaultdict(int) |
| logging.info("AxTtsRuntime ready: %d sessions loaded", len(self.sessions)) |
|
|
|
|
| def _find_manifest(self) -> Path: |
| for name in _MANIFEST_CANDIDATES: |
| p = (self.config_dir / name).resolve() |
| if p.is_file(): |
| return p |
| raise FileNotFoundError( |
| f"browser_poc_manifest.json not found in {self.config_dir}" |
| ) |
|
|
| def _resolve_path(self, rel: str | Path) -> Path: |
| p = (self.manifest_dir / Path(rel)).resolve() |
| if p.exists(): |
| return p |
| fallback = (self.config_dir / Path(rel).name).resolve() |
| if fallback.exists(): |
| return fallback |
| return p |
|
|
| @staticmethod |
| def _normalize_sample_mode(raw: str | None, do_sample: bool = True) -> str: |
| s = str(raw or "").strip() |
| if s in {SAMPLE_MODE_GREEDY, SAMPLE_MODE_FIXED, SAMPLE_MODE_FULL}: |
| return s |
| if s == "mixed3": |
| return SAMPLE_MODE_FIXED if do_sample else SAMPLE_MODE_GREEDY |
| return SAMPLE_MODE_GREEDY if not do_sample else SAMPLE_MODE_FIXED |
|
|
| def _resolve_axmodel_path(self, key: str, axmodel_name: str) -> Path: |
| flat = self.axmodel_dir / axmodel_name |
| if key != "local_fixed_sampled_frame": |
| return flat |
| if flat.exists(): |
| return flat |
| nested = self.axmodel_dir / "build-tts_local_fixed_sampled_frame" / axmodel_name |
| return nested |
|
|
| def _create_sessions(self) -> dict: |
| sessions: dict[str, Any] = {} |
| self._session_input_names: dict[str, list[str]] = {} |
|
|
| if self.use_onnx_decode: |
| onnx_name = _ONNX_FILE_MAP["decode"] |
| onnx_path = self.onnx_dir / onnx_name |
| if not onnx_path.exists(): |
| raise FileNotFoundError(f"ONNX not found: {onnx_path}") |
| if not _HAVE_ORT: |
| raise RuntimeError("onnxruntime not installed; cannot load ONNX decode_step.") |
| sess = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"]) |
| sessions["decode"] = sess |
| self._model_display_names["decode"] = f"{onnx_path.name} [onnxruntime CPU]" |
| self._session_input_names["decode"] = [inp.name for inp in sess.get_inputs()] |
| logging.info("loaded decode → %s (onnxruntime CPU)", onnx_path.name) |
| else: |
| axmodel_name = _AXMODEL_FILE_MAP["decode"] |
| axmodel_path = self.axmodel_dir / axmodel_name |
| if not axmodel_path.exists(): |
| raise FileNotFoundError(f"axmodel not found: {axmodel_path}") |
| input_names, output_names = _FULL8_IO_MAP["decode"] |
| sess = AxeSession(axmodel_path, input_names, output_names) |
| sessions["decode"] = sess |
| self._model_display_names["decode"] = f"{axmodel_path.name} [axmodel]" |
| self._session_input_names["decode"] = [inp.name for inp in sess.get_inputs()] |
| logging.info("loaded decode → %s (axmodel)", axmodel_path.name) |
|
|
| for key in ("prefill", "local_decoder", "local_fixed_sampled_frame", "codec_decode", "codec_encode"): |
| if key not in _AXMODEL_FILE_MAP: |
| continue |
| axmodel_name = _AXMODEL_FILE_MAP[key] |
| axmodel_path = self._resolve_axmodel_path(key, axmodel_name) |
| if not axmodel_path.exists(): |
| logging.warning("axmodel not found, skipping session %s: %s", key, axmodel_path) |
| continue |
| input_names, output_names = _FULL8_IO_MAP[key] |
| sess = AxeSession(axmodel_path, input_names, output_names) |
| sessions[key] = sess |
| self._model_display_names[key] = f"{axmodel_path.name} [axmodel]" |
| self._session_input_names[key] = [inp.name for inp in sess.get_inputs()] |
| logging.info("loaded %s → %s (axmodel)", key, axmodel_path.name) |
|
|
| if "codec_encode" not in sessions: |
| logging.info("codec_encode axmodel not found; voice clone (reference audio) is disabled.") |
|
|
| return sessions |
|
|
| def _reset_timing_stats(self) -> None: |
| self._model_time_stats = defaultdict(float) |
| self._model_call_stats = defaultdict(int) |
| self._used_model_keys: list[str] = [] |
| self._used_model_key_set: set[str] = set() |
|
|
| def _mark_model_used(self, key: str) -> None: |
| if key in self._used_model_key_set: |
| return |
| self._used_model_key_set.add(key) |
| self._used_model_keys.append(key) |
| logging.info( |
| "[usage] first_use session=%s model=%s", |
| key, |
| self._model_display_names.get(key, key), |
| ) |
|
|
| def _log_used_model_summary(self) -> None: |
| if not getattr(self, "_used_model_keys", None): |
| logging.info("[usage] used_models=none") |
| return |
| logging.info("[usage] used_models_count=%d", len(self._used_model_keys)) |
| for idx, key in enumerate(self._used_model_keys, start=1): |
| logging.info( |
| "[usage] used_model[%d]=session=%s model=%s calls=%d", |
| idx, |
| key, |
| self._model_display_names.get(key, key), |
| self._model_call_stats.get(key, 0), |
| ) |
|
|
| def _run_session( |
| self, |
| key: str, |
| input_feed: dict[str, np.ndarray], |
| output_names: list[str] | None = None, |
| ) -> list[np.ndarray]: |
| sess = self.sessions[key] |
| actual_input_names = [inp.name for inp in sess.get_inputs()] |
| actual_input_name_set = set(actual_input_names) |
| filtered_input_feed = { |
| name: value for name, value in input_feed.items() if name in actual_input_name_set |
| } |
| missing_inputs = [name for name in actual_input_names if name not in filtered_input_feed] |
| if missing_inputs: |
| raise RuntimeError( |
| f"session={key} missing required inputs: {missing_inputs}; " |
| f"provided={sorted(input_feed.keys())}" |
| ) |
| extra_inputs = sorted(name for name in input_feed if name not in actual_input_name_set) |
| self._mark_model_used(key) |
| start = time.perf_counter() |
| try: |
| outputs = sess.run(output_names, filtered_input_feed) |
| except Exception as exc: |
| input_summary = { |
| name: {"shape": list(np.asarray(value).shape), "dtype": str(np.asarray(value).dtype)} |
| for name, value in filtered_input_feed.items() |
| } |
| raise RuntimeError( |
| f"session={key} model={self._model_display_names.get(key, key)} run failed; " |
| f"used_models_so_far={self._used_model_keys}; " |
| f"expected_inputs={actual_input_names}; extra_inputs={extra_inputs}; " |
| f"input_summary={input_summary}" |
| ) from exc |
| elapsed = time.perf_counter() - start |
| self._model_time_stats[key] += elapsed |
| self._model_call_stats[key] += 1 |
| return outputs |
|
|
|
|
| def encode_text(self, text: str) -> list[int]: |
| return [int(t) for t in self.sp_model.encode(str(text or ""), out_type=int)] |
|
|
| def count_text_tokens(self, text: str) -> int: |
| return len(self.encode_text(text)) |
|
|
| def split_text_by_token_budget(self, text: str, max_tokens: int) -> list[str]: |
| remaining = str(text or "").strip() |
| if not remaining: |
| return [] |
| pieces: list[str] = [] |
| preferred_boundary = set(CLAUSE_SPLIT_PUNCTUATION) | set(SENTENCE_END_PUNCTUATION) | {" "} |
| while remaining: |
| if self.count_text_tokens(remaining) <= max_tokens: |
| pieces.append(remaining) |
| break |
| lo, hi, best = 1, len(remaining), 1 |
| while lo <= hi: |
| mid = (lo + hi) // 2 |
| cand = remaining[:mid].strip() |
| if not cand: |
| lo = mid + 1 |
| continue |
| if self.count_text_tokens(cand) <= max_tokens: |
| best = mid |
| lo = mid + 1 |
| else: |
| hi = mid - 1 |
| cut = best |
| prefix = remaining[:best] |
| preferred = -1 |
| scan_min = max(-1, len(prefix) - 25) |
| for si in range(len(prefix) - 1, scan_min, -1): |
| if prefix[si] in preferred_boundary: |
| preferred = si + 1 |
| break |
| if preferred > 0: |
| cut = preferred |
| piece = remaining[:cut].strip() |
| if not piece: |
| piece = remaining[:best].strip() |
| cut = best |
| pieces.append(piece) |
| remaining = remaining[cut:].strip() |
| return pieces |
|
|
| def split_voice_clone_text(self, text: str, max_tokens: int = 75) -> list[str]: |
| t = str(text or "").strip() |
| if not t: |
| return [] |
| safe_max = max(1, int(max_tokens)) |
| prepared = _prepare_text_for_sentence_chunking(t) |
| sents = _split_text_by_punctuation(prepared, SENTENCE_END_PUNCTUATION) or [prepared.strip()] |
| slices: list[tuple[int, str]] = [] |
| for sent in sents: |
| s = sent.strip() |
| if not s: |
| continue |
| cnt = self.count_text_tokens(s) |
| if cnt <= safe_max: |
| slices.append((cnt, s)) |
| continue |
| clauses = _split_text_by_punctuation(s, CLAUSE_SPLIT_PUNCTUATION) |
| if len(clauses) <= 1: |
| clauses = [s] |
| for clause in clauses: |
| c = clause.strip() |
| if not c: |
| continue |
| ccnt = self.count_text_tokens(c) |
| if ccnt <= safe_max: |
| slices.append((ccnt, c)) |
| continue |
| for piece in self.split_text_by_token_budget(c, safe_max): |
| p = piece.strip() |
| if p: |
| slices.append((self.count_text_tokens(p), p)) |
| chunks: list[str] = [] |
| cur_chunk, cur_cnt = "", 0 |
| for cnt, s in slices: |
| if not cur_chunk: |
| cur_chunk, cur_cnt = s, cnt |
| continue |
| if cur_cnt + cnt > safe_max: |
| chunks.append(cur_chunk.strip()) |
| cur_chunk, cur_cnt = s, cnt |
| else: |
| cur_chunk = _join_sentence_parts(cur_chunk, s) |
| cur_cnt = self.count_text_tokens(cur_chunk) |
| if cur_chunk: |
| chunks.append(cur_chunk.strip()) |
| return chunks or [t] |
|
|
| def _split_text_by_sentence_punctuation(self, text: str) -> list[str]: |
| t = str(text or "").strip() |
| if not t: |
| return [] |
|
|
| codec_limit = self._static_decode_code_length() |
| if codec_limit is not None: |
| _safe_text_tokens = max(1, int(codec_limit / 3 * 0.9)) |
| _min_tokens = max(1, int(_safe_text_tokens * 0.25)) |
| else: |
| _safe_text_tokens = 80 |
| _min_tokens = 20 |
|
|
| prepared = _prepare_text_for_sentence_chunking(t) |
| sents = _split_text_by_punctuation(prepared, SENTENCE_END_PUNCTUATION) |
| if not sents: |
| sents = [prepared.strip()] |
|
|
| result: list[str] = [] |
| for sent in sents: |
| s = sent.strip() |
| if not s: |
| continue |
| cnt = self.count_text_tokens(s) |
| if cnt <= _safe_text_tokens: |
| result.append(s) |
| continue |
| clauses = _split_text_by_punctuation(s, CLAUSE_SPLIT_PUNCTUATION) |
| if len(clauses) <= 1: |
| clauses = [s] |
|
|
| merged: list[str] = [] |
| buf = "" |
| buf_tokens = 0 |
| for clause in clauses: |
| c = clause.strip() |
| if not c: |
| continue |
| ccnt = self.count_text_tokens(c) |
| if ccnt >= _safe_text_tokens: |
| if buf: |
| merged.append(buf.strip()) |
| buf = "" |
| buf_tokens = 0 |
| merged.extend( |
| p.strip() for p in self.split_text_by_token_budget(c, _safe_text_tokens) if p.strip() |
| ) |
| elif buf_tokens > 0 and buf_tokens + ccnt > _safe_text_tokens: |
| merged.append(buf.strip()) |
| buf = c |
| buf_tokens = ccnt |
| else: |
| buf = (buf + " " + c) if buf else c |
| buf_tokens += ccnt |
| if buf: |
| merged.append(buf.strip()) |
|
|
| for m in merged: |
| mcnt = self.count_text_tokens(m) |
| if mcnt <= _safe_text_tokens: |
| result.append(m) |
| else: |
| result.extend( |
| p.strip() for p in self.split_text_by_token_budget(m, _safe_text_tokens) if p.strip() |
| ) |
| return result or [t] |
|
|
|
|
| def build_text_rows(self, token_ids: list[int]) -> list[list[int]]: |
| row_width = int(self.manifest["tts_config"]["n_vq"]) + 1 |
| rows: list[list[int]] = [] |
| for tid in token_ids: |
| row = [int(self.manifest["tts_config"]["audio_pad_token_id"])] * row_width |
| row[0] = int(tid) |
| rows.append(row) |
| return rows |
|
|
| def build_audio_prefix_rows( |
| self, prompt_audio_codes: list[list[int]], slot_token_id: int | None = None |
| ) -> list[list[int]]: |
| row_width = int(self.manifest["tts_config"]["n_vq"]) + 1 |
| slot = int( |
| self.manifest["tts_config"]["audio_user_slot_token_id"] |
| if slot_token_id is None else slot_token_id |
| ) |
| rows: list[list[int]] = [] |
| for code_row in prompt_audio_codes: |
| row = [int(self.manifest["tts_config"]["audio_pad_token_id"])] * row_width |
| row[0] = slot |
| for i in range(min(len(code_row), int(self.manifest["tts_config"]["n_vq"]))): |
| row[i + 1] = int(code_row[i]) |
| rows.append(row) |
| return rows |
|
|
| def build_voice_clone_request_rows( |
| self, prompt_audio_codes: list[list[int]], text_token_ids: list[int] |
| ) -> dict[str, list[list[int]]]: |
| prefix_ids = [ |
| *self.manifest["prompt_templates"]["user_prompt_prefix_token_ids"], |
| int(self.manifest["tts_config"]["audio_start_token_id"]), |
| ] |
| suffix_ids = [ |
| int(self.manifest["tts_config"]["audio_end_token_id"]), |
| *self.manifest["prompt_templates"]["user_prompt_after_reference_token_ids"], |
| *text_token_ids, |
| *self.manifest["prompt_templates"]["assistant_prompt_prefix_token_ids"], |
| int(self.manifest["tts_config"]["audio_start_token_id"]), |
| ] |
| rows = [ |
| *self.build_text_rows(prefix_ids), |
| *self.build_audio_prefix_rows(prompt_audio_codes), |
| *self.build_text_rows(suffix_ids), |
| ] |
| return {"inputIds": rows, "attentionMask": [[1 for _ in rows]]} |
|
|
|
|
| def _static_prefill_length(self) -> int | None: |
| v = self.tts_static_shapes.get("prefill_seq") |
| return None if v is None else int(v) |
|
|
| def _static_decode_code_length(self) -> int | None: |
| v = self.codec_static_shapes.get("decode_code_length") |
| return None if v is None else int(v) |
|
|
| def _static_encode_waveform_length(self) -> int | None: |
| v = self.codec_static_shapes.get("waveform_length") |
| return None if v is None else int(v) |
|
|
| def _pad_request_rows( |
| self, request_rows: dict[str, list[list[int]]] |
| ) -> tuple[dict[str, list[list[int]]], int]: |
| static_len = self._static_prefill_length() |
| actual_len = len(request_rows["inputIds"]) |
| if static_len is None: |
| return request_rows, actual_len |
| if actual_len > static_len: |
| raise ValueError( |
| f"static prefill_seq={static_len} < request length={actual_len}. " |
| "Reduce prompt/text length or re-export with larger --sample-seq-len." |
| ) |
| if actual_len == static_len: |
| return request_rows, actual_len |
| row_width = int(self.manifest["tts_config"]["n_vq"]) + 1 |
| pad_row = [int(self.manifest["tts_config"]["audio_pad_token_id"])] * row_width |
| pad_row[0] = int(self.manifest["tts_config"]["pad_token_id"]) |
| padded_rows = [list(r) for r in request_rows["inputIds"]] |
| padded_rows.extend([list(pad_row) for _ in range(static_len - actual_len)]) |
| padded_mask = [1] * actual_len + [0] * (static_len - actual_len) |
| return {"inputIds": padded_rows, "attentionMask": [padded_mask]}, actual_len |
|
|
|
|
def _pad_kv(
    self, kv_dict: dict[str, np.ndarray], target_len: int
) -> dict[str, np.ndarray]:
    """Resize every KV tensor along axis 1 to exactly ``target_len`` entries.

    Tensors longer than ``target_len`` keep their first ``target_len``
    entries; shorter ones are zero-padded at the end. Every returned array
    is a fresh ``float32`` copy, so the inputs are never aliased.

    The padding block takes its batch size and trailing dimensions from the
    tensor itself instead of assuming a batch of 1, so batched or
    differently shaped caches pad correctly.
    """
    result: dict[str, np.ndarray] = {}
    for key, val in kv_dict.items():
        cur = val.shape[1]
        if cur > target_len:
            padded = val[:, :target_len].astype(np.float32)
        elif cur < target_len:
            pad_shape = (val.shape[0], target_len - cur, *val.shape[2:])
            padded = np.concatenate(
                [val.astype(np.float32), np.zeros(pad_shape, dtype=np.float32)],
                axis=1,
            )
        else:
            padded = val.astype(np.float32)
        result[key] = padded
    return result
|
|
|
|
| def run_local_fixed_sampled_frame( |
| self, |
| global_hidden: np.ndarray, |
| *, |
| previous_token_sets_by_channel: list[set[int]], |
| ) -> tuple[bool, list[int]]: |
| n_vq = int(self.manifest["tts_config"]["n_vq"]) |
| codebook_size = int(self.tts_meta["model_config"]["audio_codebook_sizes"][0]) |
| rep_mask = np.zeros((1, n_vq, codebook_size), dtype=np.int32) |
| for ci, token_set in enumerate(previous_token_sets_by_channel): |
| for tid in token_set: |
| if 0 <= tid < codebook_size: |
| rep_mask[0, ci, tid] = 1 |
| sess = self.sessions["local_fixed_sampled_frame"] |
| session_input_names = {inp.name for inp in sess.get_inputs()} |
| input_feed: dict = { |
| "global_hidden": global_hidden.astype(np.float32, copy=False), |
| "repetition_seen_mask": rep_mask, |
| } |
| asst_u: float | None = None |
| if "assistant_random_u" in session_input_names: |
| asst_u = float(self.rng.random()) |
| input_feed["assistant_random_u"] = np.array([asst_u], dtype=np.float32) |
| if "audio_random_u" in session_input_names: |
| input_feed["audio_random_u"] = np.asarray( |
| [[float(self.rng.random()) for _ in range(n_vq)]], |
| dtype=np.float32, |
| ) |
| outputs = self._run_session("local_fixed_sampled_frame", input_feed) |
| out_names = [o.name for o in sess.get_outputs()] |
| named = dict(zip(out_names, outputs, strict=True)) |
| frame_ids = np.asarray(named["frame_token_ids"]).reshape(-1).astype(np.int32).tolist() |
| should_continue = bool(int(np.asarray(named["should_continue"]).reshape(-1)[0])) |
| logging.debug( |
| "[fixed_frame] asst_u=%.4f should_continue=%s tokens=%s", |
| asst_u if asst_u is not None else -1.0, |
| should_continue, |
| frame_ids[:4], |
| ) |
| return should_continue, [int(x) for x in frame_ids] |
|
|
def _can_use_local_fixed_sampled_frame(self) -> bool:
    """Tell whether a ``local_fixed_sampled_frame`` session exists and accepts ``audio_random_u``."""
    session = self.sessions.get("local_fixed_sampled_frame")
    if session is None:
        return False
    return any(inp.name == "audio_random_u" for inp in session.get_inputs())
|
|
| def run_local_decoder( |
| self, |
| global_hidden: np.ndarray, |
| text_token_id: int, |
| frame_prefix: list[int], |
| ) -> tuple[np.ndarray, np.ndarray]: |
| n_vq = int(self.manifest["tts_config"]["n_vq"]) |
| audio_pad = int(self.manifest["tts_config"]["audio_pad_token_id"]) |
| padded_prefix = np.full((1, n_vq - 1), audio_pad, dtype=np.int32) |
| for i in range(min(len(frame_prefix), n_vq - 1)): |
| padded_prefix[0, i] = int(frame_prefix[i]) |
| outputs = self._run_session( |
| "local_decoder", |
| { |
| "global_hidden": global_hidden.astype(np.float32, copy=False), |
| "text_token_id": np.asarray([int(text_token_id)], dtype=np.int32), |
| "audio_prefix_token_ids": padded_prefix, |
| }, |
| ) |
| out_names = [o.name for o in self.sessions["local_decoder"].get_outputs()] |
| named = dict(zip(out_names, outputs, strict=True)) |
| return named["text_logits"].reshape(-1), named["audio_logits"] |
|
|
def slice_audio_channel_logits(self, audio_logits: np.ndarray, channel_index: int) -> np.ndarray:
    """Return the logits of one channel from the flattened ``audio_logits``.

    The size of the last axis is taken as the per-channel width, and the
    slice ``[channel_index * width, (channel_index + 1) * width)`` of the
    flattened array is returned.
    """
    width = int(audio_logits.shape[-1])
    start = channel_index * width
    return audio_logits.reshape(-1)[start : start + width]
|
|
|
|
def _infer_global_kv_len(self) -> int | None:
    """Work out the fixed KV-cache length, or ``None`` for a dynamic cache.

    If the ``decode`` session exposes a ``past_key_0`` input whose second
    dimension is a string (a symbolic axis), the cache is dynamic and
    ``None`` is returned. Otherwise the length comes from
    ``tts_static_shapes["kv_seq"]``, then ``["global_kv_len"]``, and finally
    defaults to 512.
    """
    decode_session = self.sessions["decode"] if "decode" in self.sessions else None
    if decode_session is not None and hasattr(decode_session, "get_inputs"):
        first_past_key = next(
            (inp for inp in decode_session.get_inputs() if inp.name == "past_key_0"),
            None,
        )
        if first_past_key is not None:
            dims = getattr(first_past_key, "shape", []) or []
            if len(dims) >= 2 and isinstance(dims[1], str):
                return None

    shapes = self.tts_static_shapes
    configured = shapes.get("kv_seq") or shapes.get("global_kv_len")
    return 512 if configured is None else int(configured)
|
|
def generate_audio_frames(
    self,
    request_rows: dict[str, list[list[int]]],
    on_frame: Callable | None = None,
) -> list[list[int]]:
    """Pad the request to the static prefill length, then generate audio frames.

    ``on_frame`` is forwarded unchanged to ``_generate_frames_with_prefill``.
    """
    rows_for_prefill, true_length = self._pad_request_rows(request_rows)
    frames = self._generate_frames_with_prefill(
        rows_for_prefill,
        actual_prefill_length=true_length,
        on_frame=on_frame,
    )
    return frames
|
|
def _generate_frames_with_prefill(
    self,
    request_rows: dict[str, list[list[int]]],
    *,
    actual_prefill_length: int,
    on_frame: Callable | None = None,
) -> list[list[int]]:
    """Run prefill once, then autoregressively generate audio frames.

    Args:
        request_rows: ``{"inputIds": rows, "attentionMask": [mask]}``,
            already padded to any static prefill length.
        actual_prefill_length: number of real (unpadded) rows; selects the
            last valid hidden state and the initial KV valid length.
        on_frame: optional callback invoked after each decode step as
            ``on_frame(generated_frames, step, frame)``.

    Returns:
        The generated frames, one list of ``n_vq`` token ids per frame.
        A trailing run of all-pad frames is trimmed once it reaches the
        consecutive-pad limit.

    Raises:
        RuntimeError: if neither a usable ``local_fixed_sampled_frame``
            session nor a ``local_decoder`` session is available.
    """
    tts_cfg = self.manifest["tts_config"]
    generation_defaults = self.manifest["generation_defaults"]
    n_vq = int(tts_cfg["n_vq"])
    row_width = n_vq + 1
    audio_pad_id = int(tts_cfg["audio_pad_token_id"])
    assistant_slot_id = int(tts_cfg["audio_assistant_slot_token_id"])
    global_kv_len = self._infer_global_kv_len()

    # ---- Prefill -------------------------------------------------------
    prefill_ids, prefill_dims = _flatten3d_int32([request_rows["inputIds"]])
    prefill_mask, prefill_mask_dims = _flatten2d_int32(request_rows["attentionMask"])
    prefill_outputs = self._run_session(
        "prefill",
        {
            "input_ids": prefill_ids.reshape(prefill_dims),
            "attention_mask": prefill_mask.reshape(prefill_mask_dims),
        },
    )
    prefill_out_names = [o.name for o in self.sessions["prefill"].get_outputs()]
    named_prefill = dict(zip(prefill_out_names, prefill_outputs, strict=True))
    # Hidden state of the last real row; padded rows after it are ignored.
    global_hidden = named_prefill["global_hidden"][:, actual_prefill_length - 1, :].astype(np.float32)
    past_valid_length = int(actual_prefill_length)

    raw_past: dict[str, np.ndarray] = {
        out.replace("present_", "past_"): named_prefill[out][:, :actual_prefill_length, :, :]
        for out in self.tts_meta["onnx"]["prefill_output_names"][1:]
    }

    use_dynamic_kv = global_kv_len is None
    # NOTE(review): with a static cache, a prefill longer than global_kv_len
    # is truncated by _pad_kv while past_valid_length stays larger than the
    # cache. Confirm that the export guarantees prefill_seq <= kv length.
    past_by_name = raw_past if use_dynamic_kv else self._pad_kv(raw_past, global_kv_len)

    # ---- Loop invariants, resolved once instead of per frame ------------
    fixed_mode = generation_defaults["sample_mode"] == SAMPLE_MODE_FIXED
    use_fixed_session = fixed_mode and self._can_use_local_fixed_sampled_frame()
    has_local_decoder = "local_decoder" in self.sessions
    decode_kv_input_names = self.tts_meta["onnx"]["decode_input_names"][2:]
    decode_kv_output_names = self.tts_meta["onnx"]["decode_output_names"][1:]
    dec_out_names: list[str] | None = None  # filled after the first decode run

    generated_frames: list[list[int]] = []
    prev_tokens: list[list[int]] = [[] for _ in range(n_vq)]
    prev_tok_sets: list[set[int]] = [set() for _ in range(n_vq)]
    consecutive_pad_frames = 0
    _MAX_CONSECUTIVE_PAD_FRAMES = 10

    for step in range(int(generation_defaults["max_new_frames"])):
        frame: list[int] = []

        if use_fixed_session:
            should_continue, frame = self.run_local_fixed_sampled_frame(
                global_hidden, previous_token_sets_by_channel=prev_tok_sets
            )
            if not should_continue:
                logging.info("[gen] frame %d: should_continue=False → 停止生成", step)
                break
            for ci, tok in enumerate(frame):
                prev_tokens[ci].append(tok)
                prev_tok_sets[ci].add(tok)
        elif has_local_decoder:
            if fixed_mode and "local_fixed_sampled_frame" in self.sessions and step == 0:
                logging.warning(
                    "local_fixed_sampled_frame 缺少 audio_random_u 输入,回退到 local_decoder 随机采样路径"
                )
            local_text_logits, _ = self.run_local_decoder(global_hidden, 0, [])
            next_text = _sample_assistant_text_token(
                local_text_logits, self.manifest, generation_defaults, self.rng
            )
            # Any text token other than the assistant audio slot ends generation.
            if next_text != assistant_slot_id:
                break
            for ci in range(n_vq):
                # Each channel is conditioned on the channels sampled so far.
                _, audio_logits = self.run_local_decoder(global_hidden, next_text, frame)
                ch_logits = self.slice_audio_channel_logits(audio_logits, ci).astype(np.float32)
                tok = _sample_audio_token(
                    ch_logits, prev_tokens[ci], prev_tok_sets[ci], generation_defaults, self.rng
                )
                frame.append(tok)
                prev_tokens[ci].append(tok)
                prev_tok_sets[ci].add(tok)
        else:
            raise RuntimeError("Neither local_fixed_sampled_frame nor local_decoder session found")

        if not frame:
            break

        if all(tok == audio_pad_id for tok in frame):
            consecutive_pad_frames += 1
            if consecutive_pad_frames >= _MAX_CONSECUTIVE_PAD_FRAMES:
                logging.info(
                    "[gen] %d consecutive all-pad frames → force stop (trim %d silent frames)",
                    consecutive_pad_frames, consecutive_pad_frames,
                )
                # The current pad frame was never appended, so only the
                # previous (count - 1) pad frames sit at the tail. A new list
                # is bound so lists already handed to on_frame stay intact.
                trailing = consecutive_pad_frames - 1
                if len(generated_frames) >= consecutive_pad_frames:
                    generated_frames = generated_frames[: len(generated_frames) - trailing]
                else:
                    generated_frames = []
                break
        else:
            consecutive_pad_frames = 0

        generated_frames.append(frame)

        # ---- Global decode step for the frame just produced --------------
        next_row = np.full((1, 1, row_width), audio_pad_id, dtype=np.int32)
        next_row[0, 0, 0] = assistant_slot_id
        for idx, tok in enumerate(frame):
            next_row[0, 0, idx + 1] = int(tok)

        decode_feeds: dict[str, np.ndarray] = {
            "input_ids": next_row,
            "past_valid_lengths": np.asarray([past_valid_length], dtype=np.int32),
        }
        for inp_name in decode_kv_input_names:
            decode_feeds[inp_name] = past_by_name[inp_name]

        dec_outputs = self._run_session("decode", decode_feeds)
        if dec_out_names is None:
            dec_out_names = [o.name for o in self.sessions["decode"].get_outputs()]
        named_dec = dict(zip(dec_out_names, dec_outputs, strict=True))
        global_hidden = _extract_last_hidden(named_dec["global_hidden"])
        past_valid_length += 1

        if use_dynamic_kv:
            past_by_name = {
                out_name.replace("present_", "past_"): named_dec[out_name]
                for out_name in decode_kv_output_names
            }
        else:
            # The new entry is read from index global_kv_len of each present
            # tensor, i.e. right after the fixed-size past block.
            new_pos = past_valid_length - 1
            new_padded: dict[str, np.ndarray] = {}
            for out_name in decode_kv_output_names:
                present = named_dec[out_name]
                # Buffer shape follows the model output instead of a
                # hard-coded (1, kv_len, 12, 64).
                buf = np.zeros(
                    (present.shape[0], global_kv_len, *present.shape[2:]),
                    dtype=np.float32,
                )
                if new_pos < global_kv_len:
                    buf[:, :new_pos] = present[:, :new_pos]
                    buf[:, new_pos] = present[:, global_kv_len]
                else:
                    # Cache full: drop the oldest entry and append the new one.
                    buf[:, : global_kv_len - 1] = present[:, 1:global_kv_len]
                    buf[:, global_kv_len - 1] = present[:, global_kv_len]
                new_padded[out_name.replace("present_", "past_")] = buf
            past_by_name = new_padded
            if past_valid_length > global_kv_len:
                past_valid_length = global_kv_len

        if on_frame is not None:
            on_frame(generated_frames, step, frame)

    logging.info("[gen] 生成完毕: 共 %d 帧 (max_new_frames=%s)", len(generated_frames), generation_defaults["max_new_frames"])
    return generated_frames
|
|
|
|
def decode_full_audio(self, generated_frames: list[list[int]]) -> tuple[list[np.ndarray], int]:
    """Decode generated frames to audio, splitting into segments if needed.

    When a static codec length is set and the frame count exceeds it, the
    frames are decoded in consecutive segments of at most that length.

    Returns:
        ``(channel_arrays, total_samples)``: one array per audio channel and
        the summed sample count. For multi-segment input the per-segment
        audio is concatenated per channel along the last axis, so the result
        has the same structure as the single-segment case. An empty input
        yields ``([], 0)``.
    """
    if not generated_frames:
        return [], 0
    static_len = self._static_decode_code_length()
    n_vq = int(self.manifest["tts_config"]["n_vq"])
    frame_count = len(generated_frames)

    if static_len is None or frame_count <= static_len:
        return self._decode_one_segment(generated_frames, frame_count, n_vq, static_len)

    # One bucket of segment pieces per channel, created from the first segment.
    pieces_by_channel: list[list[np.ndarray]] = []
    total_samples = 0
    for seg_start in range(0, frame_count, static_len):
        seg_frames = generated_frames[seg_start : seg_start + static_len]
        channel_arrays, seg_samples = self._decode_one_segment(
            seg_frames, len(seg_frames), n_vq, static_len
        )
        if not pieces_by_channel:
            pieces_by_channel = [[] for _ in channel_arrays]
        for bucket, channel_audio in zip(pieces_by_channel, channel_arrays):
            bucket.append(np.asarray(channel_audio))
        total_samples += seg_samples

    all_audio = [np.concatenate(pieces, axis=-1) for pieces in pieces_by_channel]
    return all_audio, total_samples
|
|
def _decode_one_segment(
    self,
    frames: list[list[int]],
    frame_count: int,
    n_vq: int,
    static_len: int | None,
) -> tuple[list[np.ndarray], int]:
    """Decode one segment of frames through the ``codec_decode`` session.

    The code tensor has shape ``(1, L, n_vq)`` where ``L`` is ``static_len``
    when given, else ``frame_count``. Unused positions and channels missing
    from a short frame are left at 0. ``frame_count`` is passed to the
    session as the valid code length.

    Returns:
        ``(channel arrays sliced to the reported audio length, audio length)``.
    """
    code_len = frame_count if static_len is None else static_len
    audio_codes = np.zeros((1, code_len, n_vq), dtype=np.int32)
    for row_idx, frame in enumerate(frames):
        row_values = [int(frame[ch] if ch < len(frame) else 0) for ch in range(n_vq)]
        if row_values:
            audio_codes[0, row_idx, :] = row_values

    feed = {
        "audio_codes": audio_codes,
        "audio_code_lengths": np.asarray([frame_count], dtype=np.int32),
    }
    raw_outputs = self._run_session("codec_decode", feed)
    output_names = [out.name for out in self.sessions["codec_decode"].get_outputs()]
    named = dict(zip(output_names, raw_outputs, strict=True))
    audio_length = int(named["audio_lengths"].reshape(-1)[0])
    channel_arrays = _slice_channel_major_audio(named["audio"], 0, audio_length)
    return channel_arrays, audio_length
|
|
|
|
def encode_reference_audio(self, reference_audio_path: str | Path) -> list[list[int]]:
    """Encode a reference audio file into per-frame codec token rows.

    The waveform is loaded at the codec's sample rate and channel count.
    If ``pad_reference_audio_to_downsample_rate`` is set in the codec
    config, it is zero-padded to a multiple of ``downsample_rate``. With a
    positive static waveform length the audio is encoded in fixed-size
    chunks (the last one zero-padded, with its true length passed as
    ``input_lengths``); otherwise it is encoded in one call.

    Returns:
        One list of ``num_quantizers`` token ids per encoded frame.
    """
    codec_cfg = self.codec_meta["codec_config"]
    sample_rate = int(codec_cfg["sample_rate"])
    n_channels = int(codec_cfg["channels"])
    waveform = _load_audio_numpy(reference_audio_path, sample_rate, n_channels)
    total_len = int(waveform.shape[-1])

    if bool(codec_cfg.get("pad_reference_audio_to_downsample_rate", False)):
        hop = int(codec_cfg["downsample_rate"])
        leftover = total_len % hop
        if leftover:
            extra = hop - leftover
            tail = np.zeros((1, n_channels, extra), dtype=np.float32)
            waveform = np.concatenate([waveform, tail], axis=2)
            total_len += extra

    static_wl = self._static_encode_waveform_length()
    num_quantizers = int(codec_cfg["num_quantizers"])
    prompt_audio_codes: list[list[int]] = []
    output_names = [out.name for out in self.sessions["codec_encode"].get_outputs()]

    def encode_span(samples: np.ndarray, valid_len: int) -> None:
        # Run the encoder on one span and append its valid frames.
        raw_outputs = self._run_session(
            "codec_encode",
            {
                "waveform": samples,
                "input_lengths": np.asarray([valid_len], dtype=np.int32),
            },
        )
        named = dict(zip(output_names, raw_outputs, strict=True))
        codes = np.asarray(named["audio_codes"], dtype=np.int32)
        n_frames = int(np.asarray(named["audio_code_lengths"]).reshape(-1)[0])
        for frame_idx in range(n_frames):
            prompt_audio_codes.append(
                [int(codes[0, frame_idx, q]) for q in range(num_quantizers)]
            )

    if not (static_wl and static_wl > 0):
        encode_span(waveform, total_len)
        return prompt_audio_codes

    for start in range(0, total_len, static_wl):
        piece = waveform[..., start : start + static_wl]
        piece_len = int(piece.shape[-1])
        if piece_len <= 0:
            continue
        if piece_len < static_wl:
            padded = np.zeros((1, n_channels, static_wl), dtype=np.float32)
            padded[..., :piece_len] = piece
            piece = padded
        encode_span(piece, piece_len)

    return prompt_audio_codes
|
|
|
|
def resolve_prompt_audio_codes(
    self, *, voice: str | None, prompt_audio_path: str | Path | None
) -> list[list[int]]:
    """Pick the prompt audio codes for a synthesis request.

    A truthy ``prompt_audio_path`` wins and is encoded via
    ``encode_reference_audio``. Otherwise the built-in voice named
    ``voice`` is looked up in the manifest, falling back to the first
    built-in voice when ``voice`` is falsy. A shallow copy of that voice's
    codes is returned.

    Raises:
        ValueError: if no built-in voice has the requested name.
    """
    if prompt_audio_path:
        return self.encode_reference_audio(prompt_audio_path)

    builtin_voices = self.manifest["builtin_voices"]
    resolved_voice = str(voice) if voice else str(builtin_voices[0]["voice"])
    match = next(
        (entry for entry in builtin_voices if entry["voice"] == resolved_voice),
        None,
    )
    if match is None:
        raise ValueError(f"Built-in voice not found: {resolved_voice}")
    return list(match["prompt_audio_codes"])
|
|
def list_builtin_voices(self) -> list[dict[str, Any]]:
    """Return a shallow copy of the manifest's ``builtin_voices`` list."""
    voices = self.manifest["builtin_voices"]
    return [*voices]
|
|
|
|
def synthesize(
    self,
    *,
    text: str,
    voice: str | None = None,
    prompt_audio_path: str | Path | None = None,
    output_audio_path: str | Path | None = None,
    sample_mode: str | None = None,
    do_sample: bool = True,
    streaming: bool = True,
    max_new_frames: int | None = None,
    voice_clone_max_text_tokens: int = 75,
    seed: int | None = None,
) -> dict[str, Any]:
    """Synthesize ``text`` to a WAV file, one sentence chunk at a time.

    Args:
        text: input text; split into chunks by sentence punctuation.
        voice: built-in voice name; the first built-in voice when falsy.
        prompt_audio_path: optional reference audio, takes priority over
            ``voice``.
        output_audio_path: destination WAV; defaults to ``output.wav`` in
            the current working directory.
        sample_mode / do_sample: passed to ``_normalize_sample_mode``.
        streaming: when true, 16-bit PCM is written to the WAV as each chunk
            finishes; otherwise the file is written once at the end.
        max_new_frames: per-call override of the frame budget; also capped
            to the static codec length when one is set.
        voice_clone_max_text_tokens: accepted for interface compatibility;
            not used by this method.
        seed: when given, re-seeds ``self.rng``.

    Returns:
        A dict with ``audio_path``, ``waveform``, ``sample_rate``,
        ``audio_token_ids``, ``text_chunks``, ``sample_mode``,
        ``do_sample``, ``streaming`` and a ``timing`` sub-dict.

    Per-call overrides are written into ``manifest["generation_defaults"]``
    because the generation code reads them from there; the original values
    are restored on exit so one call cannot change the defaults of the next.
    """
    self._reset_timing_stats()
    gen = self.manifest["generation_defaults"]
    saved_defaults = dict(gen)
    try:
        if max_new_frames is not None:
            gen["max_new_frames"] = int(max_new_frames)

        static_codec_len = self._static_decode_code_length()
        if static_codec_len is not None and gen.get("max_new_frames", float("inf")) > static_codec_len:
            gen["max_new_frames"] = static_codec_len
            logging.info("capped max_new_frames=%d (static codec limit)", static_codec_len)

        normalized_mode = self._normalize_sample_mode(sample_mode, do_sample)
        gen["sample_mode"] = normalized_mode
        gen["do_sample"] = normalized_mode != SAMPLE_MODE_GREEDY
        if seed is not None:
            self.rng = np.random.default_rng(int(seed))

        infer_start_time = time.perf_counter()

        try:
            text_chunks = list(self._split_text_by_sentence_punctuation(text))
            total_chunks = len(text_chunks)
            prompt_audio_codes = self.resolve_prompt_audio_codes(
                voice=voice, prompt_audio_path=prompt_audio_path
            )
            codec_cfg = self.codec_meta["codec_config"]
            sample_rate = int(codec_cfg["sample_rate"])
            channels = int(codec_cfg["channels"])
            out_path = (
                Path(output_audio_path).expanduser().resolve()
                if output_audio_path
                else Path("output.wav").resolve()
            )

            all_waveforms: list[np.ndarray] = []
            all_generated_frames: list[list[int]] = []
            final_text_chunks: list[str] = []

            wav_file = None
            try:
                if streaming:
                    # Header setup sits inside the try so the handle is
                    # closed even if configuring it fails.
                    out_path.parent.mkdir(parents=True, exist_ok=True)
                    wav_file = wave.open(str(out_path), "wb")
                    wav_file.setnchannels(channels)
                    wav_file.setsampwidth(2)
                    wav_file.setframerate(sample_rate)

                for chunk_idx, chunk_text in enumerate(text_chunks, start=1):
                    has_more_chunks = chunk_idx < total_chunks
                    logging.info("chunk %d/%d: %r", chunk_idx, total_chunks, chunk_text[:40])
                    text_token_ids = self.encode_text(chunk_text)

                    request_rows = self.build_voice_clone_request_rows(prompt_audio_codes, text_token_ids)
                    generated_frames = self.generate_audio_frames(request_rows)

                    channel_arrays, _ = self.decode_full_audio(generated_frames)
                    waveform = np.asarray(_merge_audio_channels(channel_arrays), dtype=np.float32)
                    all_waveforms.append(waveform)
                    all_generated_frames.extend(generated_frames)
                    final_text_chunks.append(chunk_text)

                    if wav_file is not None and waveform.size > 0:
                        clipped = np.clip(waveform, -1.0, 1.0)
                        pcm16 = np.round(clipped * 32767.0).astype(np.int16)
                        wav_file.writeframes(pcm16.tobytes())
                        logging.info(
                            "stream emitted chunk %d: samples=%d duration=%.3fs",
                            len(final_text_chunks),
                            int(waveform.shape[0]),
                            float(waveform.shape[0]) / float(sample_rate),
                        )

                    if has_more_chunks:
                        # Short chunks (<= 4 whitespace-separated words) get a
                        # longer pause before the next chunk.
                        pause_words = len([w for w in chunk_text.strip().split() if w])
                        pause_sec = 0.40 if pause_words <= 4 else 0.24
                        pause_samples = max(0, int(round(sample_rate * pause_sec)))
                        if pause_samples > 0:
                            all_waveforms.append(
                                np.zeros((pause_samples, channels), dtype=np.float32)
                            )
                            if wav_file is not None:
                                wav_file.writeframes(
                                    np.zeros((pause_samples, channels), dtype=np.int16).tobytes()
                                )
            finally:
                if wav_file is not None:
                    wav_file.close()

            final_waveform = _concat_waveforms(all_waveforms)
            total_infer_time = time.perf_counter() - infer_start_time
            audio_duration_sec = (
                float(final_waveform.shape[0]) / float(sample_rate)
                if sample_rate > 0 and final_waveform.size > 0
                else 0.0
            )
            rtf = (total_infer_time / audio_duration_sec) if audio_duration_sec > 0 else float("inf")
            model_time_total = sum(self._model_time_stats.values())
            audio_path = out_path if streaming else _write_waveform_to_wav(out_path, final_waveform, sample_rate)
            logging.info("已保存 %s sample_rate=%s frames=%s", audio_path, sample_rate, len(all_generated_frames))
            logging.info(
                "[timing] total_infer_time=%.3fs audio_duration=%.3fs rtf=%.4f model_time_total=%.3fs",
                total_infer_time,
                audio_duration_sec,
                rtf,
                model_time_total,
            )
            for key in sorted(self._model_display_names):
                model_name = self._model_display_names[key]
                logging.info(
                    "[timing] session=%s model=%s calls=%d total=%.3fs",
                    key,
                    model_name,
                    self._model_call_stats.get(key, 0),
                    self._model_time_stats.get(key, 0.0),
                )
            return {
                "audio_path": str(audio_path),
                "waveform": final_waveform,
                "sample_rate": sample_rate,
                "audio_token_ids": np.asarray(all_generated_frames, dtype=np.int32),
                "text_chunks": final_text_chunks,
                "sample_mode": normalized_mode,
                "do_sample": normalized_mode != SAMPLE_MODE_GREEDY,
                "streaming": bool(streaming),
                "timing": {
                    "total_infer_time_sec": total_infer_time,
                    "audio_duration_sec": audio_duration_sec,
                    "rtf": rtf,
                    "model_time_total_sec": model_time_total,
                    "per_model_time_sec": dict(self._model_time_stats),
                    "per_model_calls": dict(self._model_call_stats),
                    "per_model_display_name": dict(self._model_display_names),
                    "used_model_keys": list(self._used_model_keys),
                },
            }
        finally:
            self._log_used_model_summary()
    finally:
        # Undo the per-call overrides written into the shared manifest.
        gen.clear()
        gen.update(saved_defaults)
|
|