| |
|
|
| from __future__ import annotations |
|
|
| import json |
| import logging |
| import time |
| from pathlib import Path |
| from typing import Any, Dict, Iterable, List |
|
|
| import numpy as np |
|
|
| from scripts.zipvoice_runtime import AxeSession |
|
|
|
|
class Decoder4ZipVoiceBoardRuntime:
    """Runs encoder_core_nolog.axmodel and fm_decoder_part0..part3.axmodel.

    The encoder and the decoder parts are listed in
    ``decoder4_split_manifest.json``; each one is loaded into an
    ``AxeSession``. ``sample`` runs the encoder once, expands the token
    embeddings to frame level, and then integrates the decoder velocity
    field for ``num_step`` Euler steps.
    """

    def __init__(
        self,
        config_dir: str | Path,
        models_dir: str | Path,
        max_feat_len: int = 1024,
        max_tokens: int = 384,
        num_step: int = 16,
        t_shift: float = 0.5,
    ) -> None:
        """Load configuration, the split manifest and all model sessions.

        Args:
            config_dir: Fallback directory for ``runtime_config.json`` and
                ``decoder4_split_manifest.json``.
            models_dir: Directory containing the model files; it is searched
                before ``config_dir`` for the JSON files.
            max_feat_len: Upper bound on the number of feature frames.
            max_tokens: Stored and logged; not otherwise used in this class.
            num_step: Number of Euler integration steps in ``sample``.
            t_shift: Shift applied to the time-step schedule.

        Raises:
            FileNotFoundError: If the split manifest cannot be found.
        """
        self.config_dir = Path(config_dir)
        self.models_dir = Path(models_dir)
        self.max_feat_len = int(max_feat_len)
        self.max_tokens = int(max_tokens)
        self.num_step = int(num_step)
        self.t_shift = float(t_shift)

        self._load_config()
        self._load_manifest()
        # Maps the manifest "name" of each model to its loaded session.
        self.sessions: Dict[str, AxeSession] = {}
        self._load_models()
        self._load_decoder_input_metadata()

    def _find_file(self, filename: str) -> Path | None:
        """Return ``filename`` from models_dir, else config_dir, else None."""
        for directory in (self.models_dir, self.config_dir):
            candidate = directory / filename
            if candidate.exists():
                return candidate
        return None

    def _load_config(self) -> None:
        """Read ``runtime_config.json`` (optional) and set audio/model settings."""
        config_path = self._find_file("runtime_config.json")
        config: Dict[str, Any] = (
            json.loads(config_path.read_text(encoding="utf-8"))
            if config_path is not None
            else {}
        )
        self.feat_dim = int(config.get("feat_dim", 100))
        self.sampling_rate = int(config.get("sampling_rate", 24000))
        self.hop_length = int(config.get("hop_length", 256))
        self.model_type = str(config.get("model_type", "zipvoice_decoder4"))
        logging.debug(
            "Decoder4 runtime: max_tokens=%d, max_feat_len=%d, feat_dim=%d, num_step=%d",
            self.max_tokens,
            self.max_feat_len,
            self.feat_dim,
            self.num_step,
        )

    def _load_manifest(self) -> None:
        """Read ``decoder4_split_manifest.json`` (required).

        Raises:
            FileNotFoundError: If the manifest is in neither directory.
        """
        manifest_path = self._find_file("decoder4_split_manifest.json")
        if manifest_path is None:
            raise FileNotFoundError(
                "decoder4_split_manifest.json not found in "
                f"{self.models_dir} or {self.config_dir}"
            )
        self.manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
        # The manifest's model_type, when present, overrides runtime_config.json.
        self.model_type = str(self.manifest.get("model_type", self.model_type))
        self.encoder_info = self.manifest["encoder"]
        self.decoder_parts = self.manifest["decoder_parts"]

    def _load_models(self) -> None:
        """Create one AxeSession per manifest entry (encoder first)."""
        for info in (self.encoder_info, *self.decoder_parts):
            name = info["name"]
            path = self.models_dir / info["file"]
            logging.debug("Loading %s from %s", name, path)
            self.sessions[name] = AxeSession(path)
        logging.debug("Loaded encoder + %d decoder4 parts", len(self.decoder_parts))

    @staticmethod
    def _input_info(sess: Any, index: int) -> Any | None:
        """Return the session's input metadata at ``index``, or None if absent.

        Reads the private ``_inputs`` attribute of the session object.
        """
        inputs = getattr(sess, "_inputs", None)
        if inputs is None or index >= len(inputs):
            return None
        return inputs[index]

    def _load_decoder_input_metadata(self) -> None:
        """Inspect decoder part 0 to learn its sequence length and mask input.

        If input ``x`` declares an integer dimension 1, that value becomes
        ``decoder_seq_len``; otherwise ``max_feat_len`` is used.
        """
        sess = self.sessions[self.decoder_parts[0]["name"]]
        input_names = sess.input_names
        self.decoder_has_padding_mask = "padding_mask" in input_names
        self.decoder_seq_len = self.max_feat_len
        if "x" in input_names:
            input_info = self._input_info(sess, input_names.index("x"))
            shape = getattr(input_info, "shape", None) if input_info is not None else None
            if shape is not None and len(shape) >= 2 and isinstance(shape[1], (int, np.integer)):
                self.decoder_seq_len = int(shape[1])

        if self.decoder_seq_len != self.max_feat_len:
            logging.debug(
                "decoder x seq_len=%d differs from configured max_feat_len=%d; "
                "using model seq_len for decoder feeds",
                self.decoder_seq_len,
                self.max_feat_len,
            )
        logging.debug(
            "Decoder4 model metadata: seq_len=%d, has_padding_mask=%s",
            self.decoder_seq_len,
            self.decoder_has_padding_mask,
        )

    @staticmethod
    def _coerce_input_dtype(value: np.ndarray, input_info: Any | None) -> np.ndarray:
        """Cast ``value`` to the dtype named by ``input_info`` (contiguous copy).

        The expected dtype is read from ``input_info.dtype`` or, failing that,
        ``input_info.type`` and matched by substring. Unrecognised dtypes only
        make the array contiguous; missing metadata returns ``value`` as-is.
        """
        if input_info is None:
            return value

        expected_dtype = getattr(input_info, "dtype", None)
        if expected_dtype is None:
            expected_dtype = getattr(input_info, "type", None)
        if expected_dtype is None:
            return value

        dtype_text = str(expected_dtype).lower()
        # Unsigned names are checked first because e.g. "uint32" also
        # contains the substring "int32".
        dtype_by_token = (
            ("uint8", np.uint8),
            ("uint32", np.uint32),
            ("uint64", np.uint64),
            ("float32", np.float32),
            ("int32", np.int32),
            ("int64", np.int64),
            ("bool", np.bool_),
        )
        for token, dtype in dtype_by_token:
            if token in dtype_text:
                return np.ascontiguousarray(value, dtype=dtype)
        return np.ascontiguousarray(value)

    def _run_model(
        self,
        name: str,
        expected_inputs: Iterable[str],
        expected_outputs: Iterable[str],
        values: Dict[str, np.ndarray],
    ) -> Dict[str, np.ndarray]:
        """Run session ``name`` and return its outputs keyed by manifest names.

        Each session input is fed from ``values`` by its actual name, or else
        by the manifest name at the same position. Outputs are mapped back the
        same way.

        Raises:
            KeyError: If an input cannot be supplied or an output is missing.
        """
        sess = self.sessions[name]
        expected_inputs = list(expected_inputs)
        expected_outputs = list(expected_outputs)

        feed: Dict[str, np.ndarray] = {}
        for index, actual_name in enumerate(sess.input_names):
            input_info = self._input_info(sess, index)
            expected = expected_inputs[index] if index < len(expected_inputs) else None
            if actual_name in values:
                source_name = actual_name
            elif expected is not None and expected in values:
                source_name = expected
            else:
                raise KeyError(
                    f"Missing input for {name}: actual={actual_name!r}, expected={expected!r}"
                )
            feed[actual_name] = self._coerce_input_dtype(values[source_name], input_info)

        raw_outputs = sess.run(feed)
        mapped: Dict[str, np.ndarray] = {}
        for index, expected_name in enumerate(expected_outputs):
            if expected_name in raw_outputs:
                mapped[expected_name] = raw_outputs[expected_name]
                continue
            if index < len(sess.output_names) and sess.output_names[index] in raw_outputs:
                mapped[expected_name] = raw_outputs[sess.output_names[index]]
                continue
            raise KeyError(f"Missing output for {name}: {expected_name!r}")
        return mapped

    def run_encoder(self, cat_tokens: np.ndarray) -> np.ndarray:
        """Run the text encoder on ``cat_tokens`` (cast to int32).

        Returns the encoder's first manifest output as float32.
        """
        cat_tokens = np.asarray(cat_tokens, dtype=np.int32)
        outputs = self._run_model(
            self.encoder_info["name"],
            self.encoder_info["inputs"],
            self.encoder_info["outputs"],
            {"cat_tokens": cat_tokens},
        )
        return outputs[self.encoder_info["outputs"][0]].astype(np.float32)

    def run_decoder(
        self,
        t: np.ndarray,
        x: np.ndarray,
        text_condition: np.ndarray,
        speech_condition: np.ndarray,
        guidance_scale: np.ndarray,
        padding_mask: np.ndarray | None = None,
    ) -> np.ndarray:
        """Run all decoder parts in manifest order and return the final output.

        Outputs of each part are added to the shared value pool so later parts
        can consume them. When ``padding_mask`` is None an all-False mask of
        length ``x.shape[1]`` is supplied; it is only fed to parts that
        declare a ``padding_mask`` input.
        """
        seq_len = x.shape[1]
        values: Dict[str, np.ndarray] = {
            "t": np.asarray(t, dtype=np.float32).reshape(1),
            "x": x.astype(np.float32),
            "text_condition": text_condition.astype(np.float32),
            "speech_condition": speech_condition.astype(np.float32),
            "guidance_scale": np.asarray(guidance_scale, dtype=np.float32).reshape(1),
            "padding_mask": padding_mask.astype(np.bool_)
            if padding_mask is not None
            else np.zeros((1, seq_len), dtype=np.bool_),
        }

        for part in self.decoder_parts:
            outputs = self._run_model(part["name"], part["inputs"], part["outputs"], values)
            values.update(outputs)

        final_output = self.decoder_parts[-1]["outputs"][0]
        return values[final_output].astype(np.float32)

    def duration_expand(
        self,
        encoded: np.ndarray,
        prompt_tokens_len: int,
        text_tokens_len: int,
        prompt_features_len: int,
        speed: float,
    ) -> tuple[np.ndarray, int]:
        """Expand token embeddings to frame level using an average duration.

        The frame count is ``ceil(prompt_features_len / prompt_tokens_len *
        total_tokens / speed)``. It is clamped to ``max_feat_len`` and, when
        known, to ``decoder_seq_len``. Each token covers
        ``features_len // total_tokens`` frames. Leftover frames use the
        embedding at position ``total_tokens`` (the first slot after the real
        tokens), or the last real token if ``encoded`` has no such slot.

        Args:
            encoded: Encoder output; row ``encoded[0, i]`` is token ``i``.
            prompt_tokens_len: Number of prompt tokens (must be > 0).
            text_tokens_len: Number of target-text tokens (must be >= 0).
            prompt_features_len: Number of prompt feature frames (>= 0).
            speed: Speaking-rate factor (must be > 0).

        Returns:
            ``(text_condition[1, features_len, dim] float32, features_len)``.

        Raises:
            ValueError: On non-positive token count/speed, a negative frame
                count, or more tokens than ``encoded`` holds.
        """
        if prompt_tokens_len <= 0:
            raise ValueError(f"prompt_tokens_len must be positive, got {prompt_tokens_len}")
        if text_tokens_len < 0:
            raise ValueError(f"text_tokens_len must be non-negative, got {text_tokens_len}")
        if prompt_features_len < 0:
            raise ValueError(
                f"prompt_features_len must be non-negative, got {prompt_features_len}"
            )
        if speed <= 0:
            raise ValueError(f"speed must be positive, got {speed}")

        total_tokens_len = prompt_tokens_len + text_tokens_len
        num_positions = encoded.shape[1]
        if total_tokens_len > num_positions:
            raise ValueError(
                f"total tokens {total_tokens_len} exceed encoder output length {num_positions}"
            )

        features_len = int(
            np.ceil(prompt_features_len / prompt_tokens_len * total_tokens_len / speed)
        )
        # Decoder feeds use the model's sequence length, so never produce more
        # frames than it can hold (sample() would otherwise reject them).
        limit = self.max_feat_len
        decoder_seq_len = getattr(self, "decoder_seq_len", None)
        if decoder_seq_len:
            limit = min(limit, int(decoder_seq_len))
        if features_len > limit:
            logging.debug("features_len=%d > limit=%d, clamping", features_len, limit)
            features_len = limit

        token_dur = features_len // total_tokens_len
        # Slot used for frames left over after the even split.
        pad_index = total_tokens_len if total_tokens_len < num_positions else total_tokens_len - 1
        if token_dur > 0:
            frame_to_token = np.minimum(np.arange(features_len) // token_dur, pad_index)
        else:
            frame_to_token = np.full(features_len, pad_index, dtype=np.int64)

        text_condition = encoded[0, frame_to_token, :]
        return text_condition[np.newaxis, :, :].astype(np.float32), features_len

    def _get_time_steps(self) -> np.ndarray:
        """Return ``num_step + 1`` shifted time points from 0 to 1 (float32).

        Applies ``ts * t / (1 + (ts - 1) * t)`` with ``ts = t_shift``.
        """
        t = np.linspace(0.0, 1.0, self.num_step + 1, dtype=np.float32)
        ts = self.t_shift
        return ts * t / (1.0 + (ts - 1.0) * t)

    def sample(
        self,
        cat_tokens: np.ndarray,
        prompt_tokens_len: int,
        text_tokens_len: int,
        prompt_features: np.ndarray,
        prompt_features_len: int,
        speed: float = 1.0,
        guidance_scale: float = 1.0,
        seed: int = 666,
    ) -> tuple[np.ndarray, Dict[str, Any]]:
        """Generate feature frames for the text following the prompt.

        Steps:
        1. Encode ``cat_tokens`` and expand the embeddings to frame level.
        2. Pad the text condition, the speech condition and seeded noise to
           the decoder sequence length.
        3. Integrate ``x += v * dt`` over the shifted time steps, keeping
           padded frames at zero.

        Returns:
            ``(features, timing)``. ``features`` is ``[1, frames, feat_dim]``
            float32 holding the frames after the prompt. If the total length
            does not exceed the prompt length, all frames are returned.
            ``timing`` reports stage durations and frame counts.

        Raises:
            ValueError: If the frame count exceeds the decoder length, or a
                decoder without a padding-mask input gets a non-exact length.
        """
        logging.debug(
            "sample: prompt_tokens=%d, text_tokens=%d, prompt_frames=%d, "
            "speed=%.2f, guidance_scale=%.2f, seed=%d",
            prompt_tokens_len,
            text_tokens_len,
            prompt_features_len,
            speed,
            guidance_scale,
            seed,
        )
        t_total_start = time.perf_counter()

        t_start = time.perf_counter()
        encoded = self.run_encoder(cat_tokens)
        t_enc = time.perf_counter() - t_start
        logging.debug("  encoder: %.3f s (output shape=%s)", t_enc, encoded.shape)

        t_start = time.perf_counter()
        text_condition, features_len = self.duration_expand(
            encoded,
            prompt_tokens_len,
            text_tokens_len,
            prompt_features_len,
            speed,
        )
        t_dur = time.perf_counter() - t_start
        logging.debug("  duration_expand: %.3f s (features_len=%d)", t_dur, features_len)

        seq_len = self.decoder_seq_len or self.max_feat_len
        if features_len > seq_len:
            raise ValueError(
                f"features_len={features_len} exceeds decoder sequence length {seq_len}"
            )
        if (
            self.decoder_seq_len is not None
            and not self.decoder_has_padding_mask
            and features_len != seq_len
        ):
            raise ValueError(
                "Fixed no-mask decoder requires exact feature length: "
                f"features_len={features_len}, decoder_seq_len={seq_len}"
            )

        text_cond_padded = np.zeros((1, seq_len, self.feat_dim), dtype=np.float32)
        text_cond_padded[0, :features_len] = text_condition[0, :features_len]

        # Only the first prompt_features_len frames are valid prompt; any
        # padding rows in prompt_features (or frames past features_len) stay
        # zero so they cannot leak into the speech condition.
        prompt_len = max(0, min(int(prompt_features_len), prompt_features.shape[1], features_len))
        speech_cond_padded = np.zeros((1, seq_len, self.feat_dim), dtype=np.float32)
        speech_cond_padded[0, :prompt_len] = prompt_features[0, :prompt_len].astype(np.float32)

        # True marks padded frames beyond features_len.
        padding_mask = np.zeros((1, seq_len), dtype=np.bool_)
        padding_mask[:, features_len:] = True

        rng = np.random.RandomState(seed)
        x = rng.randn(1, seq_len, self.feat_dim).astype(np.float32)
        x[:, features_len:, :] = 0.0

        timesteps = self._get_time_steps()
        gs = np.array([guidance_scale], dtype=np.float32)

        t_dec_total = 0.0
        for step in range(self.num_step):
            t_val = np.array([float(timesteps[step])], dtype=np.float32)
            t_start = time.perf_counter()
            v = self.run_decoder(
                t_val,
                x,
                text_cond_padded,
                speech_cond_padded,
                gs,
                padding_mask,
            )
            t_dec_total += time.perf_counter() - t_start
            dt = float(timesteps[step + 1] - timesteps[step])
            x = (x + v * dt).astype(np.float32)
            x[:, features_len:, :] = 0.0

        logging.debug(
            "  %s (NPU x%d): %.3f s total (avg %.3f ms/step)",
            getattr(self, "decoder_label", "decoder4"),
            self.num_step,
            t_dec_total,
            t_dec_total / max(self.num_step, 1) * 1000,
        )

        generated_frames = features_len - prompt_features_len
        if generated_frames <= 0:
            generated_frames = features_len
            pred_features = x[0, :features_len, :]
        else:
            pred_features = x[0, prompt_features_len:features_len, :]

        t_total = time.perf_counter() - t_total_start
        timing = {
            "encoder_time_sec": round(t_enc, 3),
            "duration_expand_time_sec": round(t_dur, 3),
            "decoder_time_sec": round(t_dec_total, 3),
            "total_time_sec": round(t_total, 3),
            "generated_frames": int(generated_frames),
            "features_len": int(features_len),
        }
        logging.debug("  total: %.3f s", t_total)
        return pred_features[np.newaxis, :, :].astype(np.float32), timing
|
|