#!/usr/bin/env python3 from __future__ import annotations import json import logging import time from pathlib import Path from typing import Any, Dict, Iterable, List import numpy as np from scripts.zipvoice_runtime import AxeSession class Decoder4ZipVoiceBoardRuntime: """Runs encoder_core_nolog.axmodel and fm_decoder_part0..part3.axmodel.""" def __init__( self, config_dir: str | Path, models_dir: str | Path, max_feat_len: int = 1024, max_tokens: int = 384, num_step: int = 16, t_shift: float = 0.5, ) -> None: self.config_dir = Path(config_dir) self.models_dir = Path(models_dir) self.max_feat_len = int(max_feat_len) self.max_tokens = int(max_tokens) self.num_step = int(num_step) self.t_shift = float(t_shift) self._load_config() self._load_manifest() self.sessions: Dict[str, AxeSession] = {} self._load_models() self._load_decoder_input_metadata() def _load_config(self) -> None: config_path = self.models_dir / "runtime_config.json" if not config_path.exists(): config_path = self.config_dir / "runtime_config.json" config = json.loads(config_path.read_text()) if config_path.exists() else {} self.feat_dim = int(config.get("feat_dim", 100)) self.sampling_rate = int(config.get("sampling_rate", 24000)) self.hop_length = int(config.get("hop_length", 256)) self.model_type = str(config.get("model_type", "zipvoice_decoder4")) logging.debug( "Decoder4 runtime: max_tokens=%d, max_feat_len=%d, feat_dim=%d, num_step=%d", self.max_tokens, self.max_feat_len, self.feat_dim, self.num_step, ) def _load_manifest(self) -> None: manifest_path = self.models_dir / "decoder4_split_manifest.json" if not manifest_path.exists(): manifest_path = self.config_dir / "decoder4_split_manifest.json" if not manifest_path.exists(): raise FileNotFoundError(f"decoder4_split_manifest.json not found: {manifest_path}") self.manifest = json.loads(manifest_path.read_text()) self.model_type = str(self.manifest.get("model_type", self.model_type)) self.encoder_info = self.manifest["encoder"] self.decoder_parts = self.manifest["decoder_parts"] def _load_models(self) -> None: model_infos = [self.encoder_info, *self.decoder_parts] for info in model_infos: name = info["name"] path = self.models_dir / info["file"] logging.debug("Loading %s from %s", name, path) self.sessions[name] = AxeSession(path) logging.debug("Loaded encoder + %d decoder4 parts", len(self.decoder_parts)) def _load_decoder_input_metadata(self) -> None: part0 = self.decoder_parts[0] sess = self.sessions[part0["name"]] input_names = sess.input_names self.decoder_has_padding_mask = "padding_mask" in input_names self.decoder_seq_len = self.max_feat_len if "x" in input_names: index = input_names.index("x") input_info = sess._inputs[index] if index < len(sess._inputs) else None shape = getattr(input_info, "shape", None) if input_info is not None else None if shape is not None and len(shape) >= 2 and isinstance(shape[1], (int, np.integer)): self.decoder_seq_len = int(shape[1]) if self.decoder_seq_len != self.max_feat_len: logging.debug( "decoder x seq_len=%d differs from configured max_feat_len=%d; " "using model seq_len for decoder feeds", self.decoder_seq_len, self.max_feat_len, ) logging.debug( "Decoder4 model metadata: seq_len=%d, has_padding_mask=%s", self.decoder_seq_len, self.decoder_has_padding_mask, ) @staticmethod def _coerce_input_dtype(value: np.ndarray, input_info: Any | None) -> np.ndarray: if input_info is None: return value expected_dtype = getattr(input_info, "dtype", None) if expected_dtype is None: expected_dtype = getattr(input_info, "type", None) if expected_dtype is None: return value dtype_text = str(expected_dtype).lower() if "float32" in dtype_text: return np.ascontiguousarray(value, dtype=np.float32) if "int32" in dtype_text: return np.ascontiguousarray(value, dtype=np.int32) if "int64" in dtype_text: return np.ascontiguousarray(value, dtype=np.int64) if "uint8" in dtype_text: return np.ascontiguousarray(value, dtype=np.uint8) if "bool" in dtype_text: return np.ascontiguousarray(value, dtype=np.bool_) return np.ascontiguousarray(value) def _run_model( self, name: str, expected_inputs: Iterable[str], expected_outputs: Iterable[str], values: Dict[str, np.ndarray], ) -> Dict[str, np.ndarray]: sess = self.sessions[name] expected_inputs = list(expected_inputs) expected_outputs = list(expected_outputs) feed: Dict[str, np.ndarray] = {} for index, actual_name in enumerate(sess.input_names): input_info = sess._inputs[index] if index < len(sess._inputs) else None if actual_name in values: feed[actual_name] = self._coerce_input_dtype(values[actual_name], input_info) continue if index < len(expected_inputs) and expected_inputs[index] in values: feed[actual_name] = self._coerce_input_dtype( values[expected_inputs[index]], input_info ) continue expected = expected_inputs[index] if index < len(expected_inputs) else None raise KeyError( f"Missing input for {name}: actual={actual_name!r}, expected={expected!r}" ) raw_outputs = sess.run(feed) mapped: Dict[str, np.ndarray] = {} for index, expected_name in enumerate(expected_outputs): if expected_name in raw_outputs: mapped[expected_name] = raw_outputs[expected_name] continue if index < len(sess.output_names) and sess.output_names[index] in raw_outputs: mapped[expected_name] = raw_outputs[sess.output_names[index]] continue raise KeyError(f"Missing output for {name}: {expected_name!r}") return mapped def run_encoder(self, cat_tokens: np.ndarray) -> np.ndarray: # Pulsar2/AXEngine exposes the quantized encoder token input as int32 # even though the reference ONNX path uses int64 token IDs. cat_tokens = np.asarray(cat_tokens, dtype=np.int32) outputs = self._run_model( self.encoder_info["name"], self.encoder_info["inputs"], self.encoder_info["outputs"], {"cat_tokens": cat_tokens}, ) return outputs[self.encoder_info["outputs"][0]].astype(np.float32) def run_decoder( self, t: np.ndarray, x: np.ndarray, text_condition: np.ndarray, speech_condition: np.ndarray, guidance_scale: np.ndarray, padding_mask: np.ndarray | None = None, ) -> np.ndarray: seq_len = x.shape[1] values: Dict[str, np.ndarray] = { "t": np.asarray(t, dtype=np.float32).reshape(1), "x": x.astype(np.float32), "text_condition": text_condition.astype(np.float32), "speech_condition": speech_condition.astype(np.float32), "guidance_scale": np.asarray(guidance_scale, dtype=np.float32).reshape(1), "padding_mask": padding_mask.astype(np.bool_) if padding_mask is not None else np.zeros((1, seq_len), dtype=np.bool_), } for part in self.decoder_parts: outputs = self._run_model( part["name"], part["inputs"], part["outputs"], values, ) values.update(outputs) final_output = self.decoder_parts[-1]["outputs"][0] return values[final_output].astype(np.float32) def duration_expand( self, encoded: np.ndarray, prompt_tokens_len: int, text_tokens_len: int, prompt_features_len: int, speed: float, ) -> tuple[np.ndarray, int]: total_tokens_len = prompt_tokens_len + text_tokens_len features_len = int( np.ceil(prompt_features_len / prompt_tokens_len * total_tokens_len / speed) ) if features_len > self.max_feat_len: logging.debug( "features_len=%d > max_feat_len=%d, clamping", features_len, self.max_feat_len, ) features_len = self.max_feat_len token_dur = features_len // total_tokens_len embed_no_pad = encoded[0, :total_tokens_len, :] text_condition = np.repeat(embed_no_pad, token_dur, axis=0) residual = features_len - text_condition.shape[0] if residual > 0: last_embed = encoded[0, total_tokens_len : total_tokens_len + 1, :] text_condition = np.concatenate( [text_condition, np.repeat(last_embed, residual, axis=0)], axis=0, ) text_condition = text_condition[:features_len, :] return text_condition[np.newaxis, :, :].astype(np.float32), features_len def _get_time_steps(self) -> np.ndarray: t = np.linspace(0.0, 1.0, self.num_step + 1, dtype=np.float32) ts = self.t_shift return ts * t / (1.0 + (ts - 1.0) * t) def sample( self, cat_tokens: np.ndarray, prompt_tokens_len: int, text_tokens_len: int, prompt_features: np.ndarray, prompt_features_len: int, speed: float = 1.0, guidance_scale: float = 1.0, seed: int = 666, ) -> tuple[np.ndarray, Dict[str, Any]]: logging.debug( "sample: prompt_tokens=%d, text_tokens=%d, prompt_frames=%d, " "speed=%.2f, guidance_scale=%.2f, seed=%d", prompt_tokens_len, text_tokens_len, prompt_features_len, speed, guidance_scale, seed, ) t_total_start = time.perf_counter() t_start = time.perf_counter() encoded = self.run_encoder(cat_tokens) t_enc = time.perf_counter() - t_start logging.debug(" encoder: %.3f s (output shape=%s)", t_enc, encoded.shape) t_start = time.perf_counter() text_condition, features_len = self.duration_expand( encoded, prompt_tokens_len, text_tokens_len, prompt_features_len, speed, ) t_dur = time.perf_counter() - t_start logging.debug(" duration_expand: %.3f s (features_len=%d)", t_dur, features_len) seq_len = self.decoder_seq_len or self.max_feat_len if features_len > seq_len: raise ValueError( f"features_len={features_len} exceeds decoder sequence length {seq_len}" ) if ( self.decoder_seq_len is not None and not self.decoder_has_padding_mask and features_len != seq_len ): raise ValueError( "Fixed no-mask decoder requires exact feature length: " f"features_len={features_len}, decoder_seq_len={seq_len}" ) if prompt_features.shape[1] > seq_len: raise ValueError( f"prompt feature length {prompt_features.shape[1]} exceeds " f"decoder sequence length {seq_len}" ) text_cond_padded = np.zeros((1, seq_len, self.feat_dim), dtype=np.float32) text_cond_padded[0, :features_len] = text_condition[0, :features_len] speech_cond_padded = np.zeros((1, seq_len, self.feat_dim), dtype=np.float32) prompt_actual_len = prompt_features.shape[1] speech_cond_padded[0, :prompt_actual_len] = prompt_features[0].astype(np.float32) padding_mask = np.zeros((1, seq_len), dtype=np.bool_) padding_mask[:, features_len:] = True rng = np.random.RandomState(seed) x = rng.randn(1, seq_len, self.feat_dim).astype(np.float32) x[:, features_len:, :] = 0.0 timesteps = self._get_time_steps() gs = np.array([guidance_scale], dtype=np.float32) t_dec_total = 0.0 for step in range(self.num_step): t_val = np.array([float(timesteps[step])], dtype=np.float32) t_start = time.perf_counter() v = self.run_decoder( t_val, x, text_cond_padded, speech_cond_padded, gs, padding_mask, ) t_dec_total += time.perf_counter() - t_start dt = float(timesteps[step + 1] - timesteps[step]) x = (x + v * dt).astype(np.float32) x[:, features_len:, :] = 0.0 logging.debug( " %s (NPU x%d): %.3f s total (avg %.3f ms/step)", getattr(self, "decoder_label", "decoder4"), self.num_step, t_dec_total, t_dec_total / self.num_step * 1000, ) generated_frames = features_len - prompt_features_len if generated_frames <= 0: generated_frames = features_len pred_features = x[0, :features_len, :] else: pred_features = x[0, prompt_features_len:features_len, :] t_total = time.perf_counter() - t_total_start timing = { "encoder_time_sec": round(t_enc, 3), "duration_expand_time_sec": round(t_dur, 3), "decoder_time_sec": round(t_dec_total, 3), "total_time_sec": round(t_total, 3), "generated_frames": int(generated_frames), "features_len": int(features_len), } logging.debug(" total: %.3f s", t_total) return pred_features[np.newaxis, :, :].astype(np.float32), timing