ZipVoice.AXERA / scripts /zipvoice_decoder4_runtime.py
HY-2012's picture
Add ZipVoice_distill model
a17e6b8 verified
Raw
History Blame Contribute Delete
14.4 kB
#!/usr/bin/env python3
from __future__ import annotations
import json
import logging
import time
from pathlib import Path
from typing import Any, Dict, Iterable, List
import numpy as np
from scripts.zipvoice_runtime import AxeSession
class Decoder4ZipVoiceBoardRuntime:
    """Runs encoder_core_nolog.axmodel and fm_decoder_part0..part3.axmodel.

    Model files, their input/output names and the order of the decoder parts
    come from ``decoder4_split_manifest.json``. One ``AxeSession`` is created
    per model. The decoder parts are executed in manifest order, and each
    part's outputs are added to the feed dictionary seen by the parts after it.
    """

    # Substring -> numpy dtype used to cast a feed to what a session input
    # reports. Matching is done on the lower-cased dtype text and the first hit
    # wins. The "tensor(float)" entry covers onnxruntime-style type strings,
    # where a 32-bit float input is not spelled "float32".
    _DTYPE_TOKENS = (
        ("float32", np.float32),
        ("tensor(float)", np.float32),
        ("int32", np.int32),
        ("int64", np.int64),
        ("uint8", np.uint8),
        ("bool", np.bool_),
    )

    def __init__(
        self,
        config_dir: str | Path,
        models_dir: str | Path,
        max_feat_len: int = 1024,
        max_tokens: int = 384,
        num_step: int = 16,
        t_shift: float = 0.5,
    ) -> None:
        """Load the runtime config, the split manifest and all model sessions.

        Args:
            config_dir: Fallback directory for ``runtime_config.json`` and
                ``decoder4_split_manifest.json``.
            models_dir: Directory holding the model files named in the
                manifest. It is also searched first for both JSON files.
            max_feat_len: Upper bound on the number of feature frames.
                ``duration_expand`` clamps to it, and it is the default
                decoder sequence length when the model does not report one.
            max_tokens: Token limit. Within this class it is only reported in
                a debug log (presumably the encoder's padded token length).
            num_step: Number of Euler flow-matching steps; must be >= 1.
            t_shift: Warping factor applied to the time grid by
                ``_get_time_steps``.

        Raises:
            ValueError: If ``num_step`` < 1 or the manifest lists no decoder
                parts.
            FileNotFoundError: If ``decoder4_split_manifest.json`` is missing.
        """
        self.config_dir = Path(config_dir)
        self.models_dir = Path(models_dir)
        self.max_feat_len = int(max_feat_len)
        self.max_tokens = int(max_tokens)
        self.num_step = int(num_step)
        if self.num_step < 1:
            # Zero steps would return the initial noise and divide by zero in
            # the per-step timing log; reject it up front.
            raise ValueError(f"num_step must be >= 1, got {self.num_step}")
        self.t_shift = float(t_shift)
        self._load_config()
        self._load_manifest()
        self.sessions: Dict[str, AxeSession] = {}
        self._load_models()
        self._load_decoder_input_metadata()

    def _load_config(self) -> None:
        """Read ``runtime_config.json`` (models_dir first, then config_dir).

        A missing file, or missing keys, fall back to the defaults below.
        """
        config_path = self.models_dir / "runtime_config.json"
        if not config_path.exists():
            config_path = self.config_dir / "runtime_config.json"
        config: Dict[str, Any] = (
            json.loads(config_path.read_text(encoding="utf-8"))
            if config_path.exists()
            else {}
        )
        self.feat_dim = int(config.get("feat_dim", 100))
        self.sampling_rate = int(config.get("sampling_rate", 24000))
        self.hop_length = int(config.get("hop_length", 256))
        self.model_type = str(config.get("model_type", "zipvoice_decoder4"))
        logging.debug(
            "Decoder4 runtime: max_tokens=%d, max_feat_len=%d, feat_dim=%d, num_step=%d",
            self.max_tokens,
            self.max_feat_len,
            self.feat_dim,
            self.num_step,
        )

    def _load_manifest(self) -> None:
        """Read ``decoder4_split_manifest.json`` (models_dir, then config_dir).

        Sets ``manifest``, ``encoder_info`` and ``decoder_parts``. A
        ``model_type`` key in the manifest overrides the one from the config.

        Raises:
            FileNotFoundError: If the manifest is in neither directory.
            ValueError: If the manifest's ``decoder_parts`` list is empty.
        """
        manifest_path = self.models_dir / "decoder4_split_manifest.json"
        if not manifest_path.exists():
            manifest_path = self.config_dir / "decoder4_split_manifest.json"
        if not manifest_path.exists():
            raise FileNotFoundError(f"decoder4_split_manifest.json not found: {manifest_path}")
        self.manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
        self.model_type = str(self.manifest.get("model_type", self.model_type))
        self.encoder_info = self.manifest["encoder"]
        self.decoder_parts = self.manifest["decoder_parts"]
        if not self.decoder_parts:
            # Later code indexes decoder_parts[0] and [-1]; fail with context.
            raise ValueError(f"{manifest_path} lists no decoder_parts")

    def _load_models(self) -> None:
        """Create one ``AxeSession`` per manifest entry, keyed by its name."""
        for info in (self.encoder_info, *self.decoder_parts):
            name = info["name"]
            path = self.models_dir / info["file"]
            logging.debug("Loading %s from %s", name, path)
            self.sessions[name] = AxeSession(path)
        logging.debug("Loaded encoder + %d decoder4 parts", len(self.decoder_parts))

    @staticmethod
    def _input_info(sess: Any, index: int) -> Any | None:
        """Return the session's metadata object for input ``index``, or None.

        Reads the session's private ``_inputs`` list. It returns None when
        that list is absent or too short, so callers can skip dtype/shape
        handling.
        """
        inputs = getattr(sess, "_inputs", None)
        if inputs is None or index >= len(inputs):
            return None
        return inputs[index]

    def _load_decoder_input_metadata(self) -> None:
        """Inspect the first decoder part's inputs.

        Sets ``decoder_has_padding_mask`` (whether an input named
        ``padding_mask`` exists). Also sets ``decoder_seq_len``: the static,
        positive size of dim 1 of input ``x`` when the model reports one, and
        ``max_feat_len`` otherwise.
        """
        part0 = self.decoder_parts[0]
        sess = self.sessions[part0["name"]]
        input_names = list(sess.input_names)
        self.decoder_has_padding_mask = "padding_mask" in input_names
        self.decoder_seq_len = self.max_feat_len
        if "x" in input_names:
            input_info = self._input_info(sess, input_names.index("x"))
            shape = getattr(input_info, "shape", None)
            # Dynamic dims may be reported as strings or non-positive ints
            # (e.g. -1); only a concrete positive length is trusted.
            if (
                shape is not None
                and len(shape) >= 2
                and isinstance(shape[1], (int, np.integer))
                and int(shape[1]) > 0
            ):
                self.decoder_seq_len = int(shape[1])
        if self.decoder_seq_len != self.max_feat_len:
            logging.debug(
                "decoder x seq_len=%d differs from configured max_feat_len=%d; "
                "using model seq_len for decoder feeds",
                self.decoder_seq_len,
                self.max_feat_len,
            )
        logging.debug(
            "Decoder4 model metadata: seq_len=%d, has_padding_mask=%s",
            self.decoder_seq_len,
            self.decoder_has_padding_mask,
        )

    @staticmethod
    def _coerce_input_dtype(value: np.ndarray, input_info: Any | None) -> np.ndarray:
        """Cast ``value`` to the dtype reported by ``input_info``.

        The reported dtype is read from ``input_info.dtype``, falling back to
        ``input_info.type``. It is matched against ``_DTYPE_TOKENS``.

        Returns:
            ``value`` unchanged if no dtype is reported; otherwise a
            C-contiguous array, cast when a known dtype matched.
        """
        if input_info is None:
            return value
        expected_dtype = getattr(input_info, "dtype", None)
        if expected_dtype is None:
            expected_dtype = getattr(input_info, "type", None)
        if expected_dtype is None:
            return value
        dtype_text = str(expected_dtype).lower()
        for token, np_dtype in Decoder4ZipVoiceBoardRuntime._DTYPE_TOKENS:
            if token in dtype_text:
                return np.ascontiguousarray(value, dtype=np_dtype)
        return np.ascontiguousarray(value)

    def _run_model(
        self,
        name: str,
        expected_inputs: Iterable[str],
        expected_outputs: Iterable[str],
        values: Dict[str, np.ndarray],
    ) -> Dict[str, np.ndarray]:
        """Run session ``name`` and return its outputs under manifest names.

        Each actual session input is fed from ``values`` by its own name, or,
        failing that, by the manifest name at the same position. Outputs are
        mapped back the same way.

        Raises:
            KeyError: If an input cannot be fed or an expected output is
                missing from the session result.
        """
        sess = self.sessions[name]
        expected_inputs = list(expected_inputs)
        expected_outputs = list(expected_outputs)
        feed: Dict[str, np.ndarray] = {}
        for index, actual_name in enumerate(sess.input_names):
            input_info = self._input_info(sess, index)
            if actual_name in values:
                feed[actual_name] = self._coerce_input_dtype(values[actual_name], input_info)
                continue
            if index < len(expected_inputs) and expected_inputs[index] in values:
                feed[actual_name] = self._coerce_input_dtype(
                    values[expected_inputs[index]], input_info
                )
                continue
            expected = expected_inputs[index] if index < len(expected_inputs) else None
            raise KeyError(
                f"Missing input for {name}: actual={actual_name!r}, expected={expected!r}"
            )
        raw_outputs = sess.run(feed)
        mapped: Dict[str, np.ndarray] = {}
        for index, expected_name in enumerate(expected_outputs):
            if expected_name in raw_outputs:
                mapped[expected_name] = raw_outputs[expected_name]
                continue
            if index < len(sess.output_names) and sess.output_names[index] in raw_outputs:
                mapped[expected_name] = raw_outputs[sess.output_names[index]]
                continue
            raise KeyError(f"Missing output for {name}: {expected_name!r}")
        return mapped

    def run_encoder(self, cat_tokens: np.ndarray) -> np.ndarray:
        """Run the text encoder on ``cat_tokens``; return its first output as float32."""
        # Pulsar2/AXEngine exposes the quantized encoder token input as int32
        # even though the reference ONNX path uses int64 token IDs.
        cat_tokens = np.asarray(cat_tokens, dtype=np.int32)
        outputs = self._run_model(
            self.encoder_info["name"],
            self.encoder_info["inputs"],
            self.encoder_info["outputs"],
            {"cat_tokens": cat_tokens},
        )
        return outputs[self.encoder_info["outputs"][0]].astype(np.float32)

    def run_decoder(
        self,
        t: np.ndarray,
        x: np.ndarray,
        text_condition: np.ndarray,
        speech_condition: np.ndarray,
        guidance_scale: np.ndarray,
        padding_mask: np.ndarray | None = None,
    ) -> np.ndarray:
        """Run all decoder parts once and return the last part's first output.

        Args:
            t: Scalar time value (reshaped to shape ``(1,)``).
            x: Current flow state; ``x.shape[1]`` is the sequence length
                used for the default mask.
            text_condition: Text conditioning, fed as float32.
            speech_condition: Prompt/speech conditioning, fed as float32.
            guidance_scale: Scalar guidance scale (reshaped to ``(1,)``).
            padding_mask: Boolean mask, True on padded frames. Defaults to
                all-False. It is only consumed if a part declares that input.

        Returns:
            The final decoder output as float32.
        """
        seq_len = x.shape[1]
        values: Dict[str, np.ndarray] = {
            "t": np.asarray(t, dtype=np.float32).reshape(1),
            "x": x.astype(np.float32),
            "text_condition": text_condition.astype(np.float32),
            "speech_condition": speech_condition.astype(np.float32),
            "guidance_scale": np.asarray(guidance_scale, dtype=np.float32).reshape(1),
            "padding_mask": padding_mask.astype(np.bool_)
            if padding_mask is not None
            else np.zeros((1, seq_len), dtype=np.bool_),
        }
        for part in self.decoder_parts:
            outputs = self._run_model(
                part["name"],
                part["inputs"],
                part["outputs"],
                values,
            )
            # Later parts consume earlier parts' outputs by name.
            values.update(outputs)
        final_output = self.decoder_parts[-1]["outputs"][0]
        return values[final_output].astype(np.float32)

    def duration_expand(
        self,
        encoded: np.ndarray,
        prompt_tokens_len: int,
        text_tokens_len: int,
        prompt_features_len: int,
        speed: float,
    ) -> tuple[np.ndarray, int]:
        """Upsample per-token encoder rows to per-frame text conditioning.

        The frame count is
        ``ceil(prompt_features_len / prompt_tokens_len * total / speed)``,
        clamped to ``max_feat_len``, where ``total`` is
        ``prompt_tokens_len + text_tokens_len``. Every token row of
        ``encoded[0]`` is repeated ``features_len // total`` times. Leftover
        frames are filled with row ``encoded[0, total]``, the first row after
        the real tokens.

        Args:
            encoded: Encoder output indexed as ``encoded[0, token, :]``.
            prompt_tokens_len: Number of prompt tokens; must be > 0.
            text_tokens_len: Number of target-text tokens; must be >= 0.
            prompt_features_len: Number of prompt feature frames.
            speed: Speaking-rate factor; must be > 0.

        Returns:
            ``(text_condition, features_len)`` where ``text_condition`` has
            shape ``(1, features_len, encoded.shape[2])`` and dtype float32.

        Raises:
            ValueError: On non-positive ``prompt_tokens_len``/``speed``,
                negative ``text_tokens_len``, or when ``encoded`` lacks the
                token rows (or the filler row) that are needed.
        """
        if prompt_tokens_len <= 0:
            raise ValueError(f"prompt_tokens_len must be positive, got {prompt_tokens_len}")
        if text_tokens_len < 0:
            raise ValueError(f"text_tokens_len must be non-negative, got {text_tokens_len}")
        if speed <= 0:
            raise ValueError(f"speed must be positive, got {speed}")
        total_tokens_len = prompt_tokens_len + text_tokens_len
        num_rows = encoded.shape[1]
        if total_tokens_len > num_rows:
            raise ValueError(
                f"encoder output has {num_rows} rows but {total_tokens_len} tokens were given"
            )
        features_len = int(
            np.ceil(prompt_features_len / prompt_tokens_len * total_tokens_len / speed)
        )
        if features_len > self.max_feat_len:
            logging.debug(
                "features_len=%d > max_feat_len=%d, clamping",
                features_len,
                self.max_feat_len,
            )
            features_len = self.max_feat_len
        token_dur = features_len // total_tokens_len
        embed_no_pad = encoded[0, :total_tokens_len, :]
        text_condition = np.repeat(embed_no_pad, token_dur, axis=0)
        residual = features_len - text_condition.shape[0]
        if residual > 0:
            if total_tokens_len >= num_rows:
                # Without a row after the real tokens there is nothing to fill
                # the leftover frames with; the result would be too short.
                raise ValueError(
                    f"encoder output has no row after the {total_tokens_len} tokens "
                    f"to fill {residual} leftover frames"
                )
            last_embed = encoded[0, total_tokens_len : total_tokens_len + 1, :]
            text_condition = np.concatenate(
                [text_condition, np.repeat(last_embed, residual, axis=0)],
                axis=0,
            )
        text_condition = text_condition[:features_len, :]
        return text_condition[np.newaxis, :, :].astype(np.float32), features_len

    def _get_time_steps(self) -> np.ndarray:
        """Return ``num_step + 1`` float32 times in [0, 1], warped by ``t_shift``.

        Each point of ``linspace(0, 1)`` is mapped via
        ``t_shift * t / (1 + (t_shift - 1) * t)``. ``t_shift == 1`` leaves the
        grid uniform.
        """
        t = np.linspace(0.0, 1.0, self.num_step + 1, dtype=np.float32)
        ts = self.t_shift
        return ts * t / (1.0 + (ts - 1.0) * t)

    def sample(
        self,
        cat_tokens: np.ndarray,
        prompt_tokens_len: int,
        text_tokens_len: int,
        prompt_features: np.ndarray,
        prompt_features_len: int,
        speed: float = 1.0,
        guidance_scale: float = 1.0,
        seed: int = 666,
    ) -> tuple[np.ndarray, Dict[str, Any]]:
        """Generate features by Euler integration of the decoder's velocity.

        Steps:
            1. Encode ``cat_tokens``.
            2. Expand the encoding to ``features_len`` frames.
            3. Pad the text and speech conditions to the decoder sequence
               length.
            4. Start from seeded Gaussian noise and take ``num_step`` steps of
               ``x += v * dt``.
            5. Zero frames past ``features_len`` after every step.

        Returns:
            ``(pred_features, timing)``. ``pred_features`` has shape
            ``(1, frames, feat_dim)`` and holds frames
            ``prompt_features_len:features_len``, or all ``features_len``
            frames when that range is empty. ``timing`` holds the stage
            durations in seconds plus the frame counts.

        Raises:
            ValueError: If the lengths do not fit the decoder sequence length,
                a no-mask decoder gets a non-exact length, or
                ``duration_expand`` rejects its inputs.
        """
        logging.debug(
            "sample: prompt_tokens=%d, text_tokens=%d, prompt_frames=%d, "
            "speed=%.2f, guidance_scale=%.2f, seed=%d",
            prompt_tokens_len,
            text_tokens_len,
            prompt_features_len,
            speed,
            guidance_scale,
            seed,
        )
        t_total_start = time.perf_counter()

        t_start = time.perf_counter()
        encoded = self.run_encoder(cat_tokens)
        t_enc = time.perf_counter() - t_start
        logging.debug("  encoder: %.3f s (output shape=%s)", t_enc, encoded.shape)

        t_start = time.perf_counter()
        text_condition, features_len = self.duration_expand(
            encoded,
            prompt_tokens_len,
            text_tokens_len,
            prompt_features_len,
            speed,
        )
        t_dur = time.perf_counter() - t_start
        logging.debug("  duration_expand: %.3f s (features_len=%d)", t_dur, features_len)

        seq_len = self.decoder_seq_len or self.max_feat_len
        if features_len > seq_len:
            raise ValueError(
                f"features_len={features_len} exceeds decoder sequence length {seq_len}"
            )
        if (
            self.decoder_seq_len is not None
            and not self.decoder_has_padding_mask
            and features_len != seq_len
        ):
            # Without a mask input the decoder cannot ignore padded frames.
            raise ValueError(
                "Fixed no-mask decoder requires exact feature length: "
                f"features_len={features_len}, decoder_seq_len={seq_len}"
            )
        if prompt_features.shape[1] > seq_len:
            raise ValueError(
                f"prompt feature length {prompt_features.shape[1]} exceeds "
                f"decoder sequence length {seq_len}"
            )

        text_cond_padded = np.zeros((1, seq_len, self.feat_dim), dtype=np.float32)
        text_cond_padded[0, :features_len] = text_condition[0, :features_len]

        speech_cond_padded = np.zeros((1, seq_len, self.feat_dim), dtype=np.float32)
        # Copy only the valid prompt frames. Frames past prompt_features_len
        # (padding inside prompt_features) and frames past features_len
        # (outside the utterance) stay zero.
        prompt_copy_len = max(
            0, min(prompt_features.shape[1], int(prompt_features_len), features_len)
        )
        speech_cond_padded[0, :prompt_copy_len] = prompt_features[0, :prompt_copy_len]

        padding_mask = np.zeros((1, seq_len), dtype=np.bool_)
        padding_mask[:, features_len:] = True

        rng = np.random.RandomState(seed)
        x = rng.randn(1, seq_len, self.feat_dim).astype(np.float32)
        x[:, features_len:, :] = 0.0

        timesteps = self._get_time_steps()
        gs = np.array([guidance_scale], dtype=np.float32)
        t_dec_total = 0.0
        for step in range(self.num_step):
            t_val = np.array([float(timesteps[step])], dtype=np.float32)
            t_start = time.perf_counter()
            v = self.run_decoder(
                t_val,
                x,
                text_cond_padded,
                speech_cond_padded,
                gs,
                padding_mask,
            )
            t_dec_total += time.perf_counter() - t_start
            dt = float(timesteps[step + 1] - timesteps[step])
            x = (x + v * dt).astype(np.float32)
            x[:, features_len:, :] = 0.0
        logging.debug(
            "  %s (NPU x%d): %.3f s total (avg %.3f ms/step)",
            getattr(self, "decoder_label", "decoder4"),
            self.num_step,
            t_dec_total,
            t_dec_total / self.num_step * 1000,
        )

        generated_frames = features_len - prompt_features_len
        if generated_frames <= 0:
            generated_frames = features_len
            pred_features = x[0, :features_len, :]
        else:
            pred_features = x[0, prompt_features_len:features_len, :]

        t_total = time.perf_counter() - t_total_start
        timing = {
            "encoder_time_sec": round(t_enc, 3),
            "duration_expand_time_sec": round(t_dur, 3),
            "decoder_time_sec": round(t_dec_total, 3),
            "total_time_sec": round(t_total, 3),
            "generated_frames": int(generated_frames),
            "features_len": int(features_len),
        }
        logging.debug("  total: %.3f s", t_total)
        return pred_features[np.newaxis, :, :].astype(np.float32), timing