# VoxCPM-0.5B-RKNN2 / onnx_infer.py
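"""Hybrid ONNX/PyTorch inference for VoxCPM (non-streaming).

The two MiniCPM language models run in PyTorch (transformers); the audio VAE,
LocEnc, FSQ layer, stop head, DiT step, and projection heads run through
ONNX Runtime sessions loaded from --onnx-dir.

Example invocation (all paths are placeholders):

    python onnx_infer.py \
        --tokenizer-dir ./VoxCPM-0.5B \
        --base-hf-dir ./base_hf \
        --residual-hf-dir ./residual_hf \
        --onnx-dir ./onnx \
        --text "Hello from VoxCPM." \
        --output onnx_output.wav
"""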
import argparse
import os
import random
import numpy as np
import torch
import torchaudio
import sys
import pathlib
import onnxruntime as ort
import time
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers.cache_utils import DynamicCache
sys.path.append(str(pathlib.Path(__file__).resolve().parent))  # make local modeling_minicpm importable
from modeling_minicpm import MiniCPMModel  # noqa: E402
def mask_multichar_chinese_tokens(tokenizer):
# Pre-compute multi-character tokens (length >= 2, pure Chinese characters)
multichar_tokens = {
token for token in tokenizer.vocab.keys()
if len(token) >= 2 and all("\u4e00" <= c <= "\u9fff" for c in token)
}
class CharTokenizerWrapper:
"""Wrapper class for tokenizers that handles multi-character Chinese tokens.
This wrapper automatically splits multi-character Chinese tokens into
individual characters while preserving the original tokenizer's interface.
"""
def __init__(self, base_tokenizer) -> None:
"""Initialize the wrapper with a base tokenizer.
Args:
base_tokenizer: The tokenizer to wrap
"""
self.tokenizer = base_tokenizer
self.multichar_tokens = multichar_tokens
def tokenize(self, text: str, **kwargs):
"""Tokenize text and split multi-character Chinese tokens into single characters.
Args:
text: Input text to tokenize
**kwargs: Additional arguments passed to the base tokenizer
Returns:
List of processed tokens with multi-character Chinese tokens split
Example:
>>> wrapper = CharTokenizerWrapper(tokenizer)
>>> tokens = wrapper.tokenize("你好世界")
>>> # Returns ["你", "好", "世", "界"] instead of ["你好", "世界"]
"""
if not isinstance(text, str):
raise TypeError(f"Expected string input, got {type(text)}")
tokens = self.tokenizer.tokenize(text, **kwargs)
processed = []
for token in tokens:
# Remove possible subword prefix
clean_token = token.replace("▁", "")
if clean_token in self.multichar_tokens:
# Split multi-character token into single characters
chars = list(clean_token)
processed.extend(chars)
else:
processed.append(token)
return processed
def __call__(self, text: str, **kwargs):
"""Call the tokenizer and return token IDs.
This method provides the same interface as the original tokenizer
but with multi-character Chinese token handling.
Args:
text: Input text to tokenize
**kwargs: Additional arguments passed to the base tokenizer
Returns:
List of token IDs
Raises:
TypeError: If input is not a string
ValueError: If tokenization fails
"""
try:
tokens = self.tokenize(text, **kwargs)
result = self.tokenizer.convert_tokens_to_ids(tokens)
return result
except Exception as e:
raise ValueError(f"Tokenization failed: {str(e)}") from e
return CharTokenizerWrapper(tokenizer)
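# Example (hypothetical tokenizer path):
#     tok = mask_multichar_chinese_tokens(AutoTokenizer.from_pretrained("./VoxCPM-0.5B"))
#     tok("你好世界")  # -> four per-character IDs instead of merged multi-character tokens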
def load_onnx(path: str, providers):
if not os.path.exists(path):
raise FileNotFoundError(f"ONNX file not found: {path}")
return ort.InferenceSession(path, providers=providers)
def to_numpy(t: torch.Tensor):
return t.detach().cpu().numpy()
def run_ort(session: ort.InferenceSession, inputs: dict, name: str | None = None):
    """Run an ONNX session on torch/numpy inputs and return its first output."""
start = time.perf_counter()
ort_inputs = {k: (v if isinstance(v, np.ndarray) else to_numpy(v)) for k, v in inputs.items()}
outputs = session.run(None, ort_inputs)
if name:
elapsed_ms = (time.perf_counter() - start) * 1000
print(f"[time] {name}: {elapsed_ms:.2f} ms")
return outputs[0]
def cfm_euler_with_onnx_step(
dit_sess: ort.InferenceSession,
x: torch.Tensor,
mu: torch.Tensor,
cond: torch.Tensor,
n_timesteps: int,
cfg_value: float,
use_cfg_zero_star: bool = True,
mean_mode: bool = False,
):
"""
Re-implementation of UnifiedCFM.solve_euler using ONNX DiT single step.
Shapes:
x: [B, C, P], mu: [B, H_dit], cond: [B, C, P]
"""
device = x.device
dtype = x.dtype
t_span = torch.linspace(1.0, 0.0, n_timesteps + 1, device=device, dtype=dtype)
t_span = t_span + 1.0 * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span) # sway sampling
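    # The remap keeps the endpoints fixed (f(1)=1, f(0)=0) but shrinks the local
    # step size near t=1, so more Euler steps are spent at the noisy end.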
t = t_span[0]
dt = t_span[0] - t_span[1]
zero_init_steps = max(1, int(len(t_span) * 0.04))
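    # CFG-Zero*-style zero-init: for the first ~4% of steps the velocity is
    # forced to zero instead of querying the DiT (see the branch below).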
for step in range(1, len(t_span)):
if use_cfg_zero_star and step <= zero_init_steps:
dphi_dt = torch.zeros_like(x)
else:
b = x.size(0)
t_in = t.expand(b)
dt_in = dt.expand(b)
if not mean_mode:
dt_in = torch.zeros_like(dt_in)
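            # CFG in one batched call: the conditional branch and the unconditional
            # branch (zeroed mu/cond/dt) are stacked into a 2B batch so a single
            # ONNX inference covers both passes.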
x_batch = torch.cat([x, x], dim=0)
mu_batch = torch.cat([mu, torch.zeros_like(mu)], dim=0)
t_batch = torch.cat([t_in, t_in], dim=0)
cond_batch = torch.cat([cond, torch.zeros_like(cond)], dim=0)
dt_batch = torch.cat([dt_in, torch.zeros_like(dt_in)], dim=0)
dphi_dt_batch = run_ort(
dit_sess,
{
"x": x_batch,
"mu": mu_batch,
"t": t_batch,
"cond": cond_batch,
"dt": dt_batch,
},
name=f"dit_step_b2_{step}",
)
dphi_dt_batch = torch.from_numpy(dphi_dt_batch).to(device=device, dtype=dtype)
dphi_dt_pos, dphi_dt_neg = torch.split(dphi_dt_batch, [b, b], dim=0)
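            # CFG-Zero* guidance:
            #   s* = <v_pos, v_neg> / (||v_neg||^2 + eps)          (per sample)
            #   v  = s* * v_neg + cfg_value * (v_pos - s* * v_neg)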
if use_cfg_zero_star:
positive_flat = dphi_dt_pos.view(b, -1)
negative_flat = dphi_dt_neg.view(b, -1)
st_star = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) / (
torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
)
st_star = st_star.view(b, *([1] * (dphi_dt_pos.ndim - 1)))
else:
st_star = 1.0
dphi_dt = dphi_dt_neg * st_star + cfg_value * (dphi_dt_pos - dphi_dt_neg * st_star)
x = x - dt * dphi_dt
t = t - dt
if step < len(t_span) - 1:
dt = t - t_span[step + 1]
return x
def prepare_audio_features(
audio_path: str,
sample_rate: int,
patch_size: int,
chunk_size: int,
vae_encode_sess: ort.InferenceSession,
):
audio, sr = torchaudio.load(audio_path)
if audio.size(0) > 1:
audio = audio.mean(dim=0, keepdim=True)
if sr != sample_rate:
audio = torchaudio.functional.resample(audio, sr, sample_rate)
# Expect shape [B, 1, T] for ONNX VAE encoder
if audio.ndim == 2:
audio = audio.unsqueeze(0)
patch_len = patch_size * chunk_size
t = audio.size(-1)
pad_right = (patch_len - t % patch_len) % patch_len
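    # Pad the waveform to a multiple of patch_size * chunk_size
    # (2 * 640 = 1280 samples, i.e. 80 ms at 16 kHz) so the encoded latents
    # split evenly into patches.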
if pad_right > 0:
audio = torch.nn.functional.pad(audio, (0, pad_right))
latent = run_ort(vae_encode_sess, {"audio_wave": audio}, name="vae_encode")
latent = torch.from_numpy(latent)
latent_dim = latent.shape[1]
t_latent = latent.shape[2]
if t_latent % patch_size != 0:
raise ValueError(f"Encoded latent length {t_latent} not divisible by patch_size={patch_size}")
audio_feat = latent.view(latent_dim, -1, patch_size).permute(1, 2, 0)
    audio_feat = audio_feat[:-1, ...]  # drop the last patch, which covers the zero-padding added above
return audio_feat
def main():
parser = argparse.ArgumentParser(description="Hybrid ONNX/PyTorch inference for VoxCPM (non-streaming).")
parser.add_argument("--tokenizer-dir", required=True, help="Path to tokenizer (e.g., VoxCPM-0.5B).")
parser.add_argument("--base-hf-dir", required=True, help="Path to transformers-formatted base MiniCPM.")
parser.add_argument("--residual-hf-dir", required=True, help="Path to transformers-formatted residual MiniCPM.")
parser.add_argument("--onnx-dir", required=True, help="Directory containing exported ONNX files.")
parser.add_argument("--text", required=True, help="Target text to synthesize.")
parser.add_argument("--prompt-audio", default=None, help="Optional prompt audio path.")
parser.add_argument("--prompt-text", default=None, help="Text transcript of prompt audio (required if prompt-audio).")
parser.add_argument("--output", default="onnx_output.wav", help="Output wav path.")
parser.add_argument("--device", default=None, help="torch device; default auto (cuda if available else cpu).")
parser.add_argument("--cfg-value", type=float, default=2.0, help="CFG value for diffusion.")
parser.add_argument("--inference-timesteps", type=int, default=10, help="Diffusion steps.")
parser.add_argument("--min-len", type=int, default=2, help="Minimum generated patch count before stop allowed.")
parser.add_argument("--max-len", type=int, default=2000, help="Maximum generated patch count.")
parser.add_argument("--force-fp32", action="store_true", help="Force model dtype to float32 for consistency with ONNX.")
parser.add_argument("--seed", type=int, default=None, help="Random seed for reproducibility.")
parser.add_argument(
"--providers",
nargs="+",
default=None,
help="ONNX Runtime providers (e.g., CUDAExecutionProvider CPUExecutionProvider).",
)
args = parser.parse_args()
device = torch.device(args.device) if args.device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
providers = args.providers or ["CUDAExecutionProvider", "CPUExecutionProvider"]
# Seed
if args.seed is not None:
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(args.seed)
# Inference with no grad to avoid graph retention / memory growth
with torch.inference_mode():
# Load tokenizer and HF MiniCPM models
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)
tokenizer = mask_multichar_chinese_tokens(tokenizer)
base_model = MiniCPMModel.from_pretrained(args.base_hf_dir).to(device).eval()
residual_model = MiniCPMModel.from_pretrained(args.residual_hf_dir).to(device).eval()
if args.force_fp32:
dtype = torch.float32
else:
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
base_model = base_model.to(dtype)
residual_model = residual_model.to(dtype)
# constants tied to exported ONNX
patch_size = 2
feat_dim = 64
latent_dim = 64
chunk_size = 640
sample_rate = 16000
audio_start_token = 101
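        # With these values one generated patch covers 2 latent frames, i.e.
        # 2 * 640 = 1280 waveform samples (80 ms at 16 kHz), assuming the VAE
        # maps chunk_size samples to one latent frame.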
# Load ONNX sessions
vae_encode_sess = load_onnx(os.path.join(args.onnx_dir, "audio_vae_encode.onnx"), providers)
vae_decode_sess = load_onnx(os.path.join(args.onnx_dir, "audio_vae_decode.onnx"), providers)
locenc_sess = load_onnx(os.path.join(args.onnx_dir, "locenc.onnx"), providers)
fsq_sess = load_onnx(os.path.join(args.onnx_dir, "fsq_layer.onnx"), providers)
stop_sess = load_onnx(os.path.join(args.onnx_dir, "stop_head.onnx"), providers)
dit_step_sess = load_onnx(os.path.join(args.onnx_dir, "dit_step.onnx"), providers)
enc_to_lm_sess = load_onnx(os.path.join(args.onnx_dir, "enc_to_lm_proj.onnx"), providers)
lm_to_dit_sess = load_onnx(os.path.join(args.onnx_dir, "lm_to_dit_proj.onnx"), providers)
res_to_dit_sess = load_onnx(os.path.join(args.onnx_dir, "res_to_dit_proj.onnx"), providers)
# Build text/audio features
if args.prompt_audio:
if not args.prompt_text:
raise ValueError("prompt-text is required when prompt-audio is provided.")
text = args.prompt_text + args.text
else:
text = args.text
tokenized = tokenizer(text, add_special_tokens=False)
if isinstance(tokenized, dict):
text_token = tokenized["input_ids"]
else:
text_token = tokenized
text_token = torch.LongTensor(text_token)
text_token = torch.cat([text_token, torch.tensor([audio_start_token], dtype=torch.int64)], dim=-1)
text_length = text_token.shape[0]
if args.prompt_audio:
audio_feat = prepare_audio_features(
args.prompt_audio, sample_rate, patch_size, chunk_size, vae_encode_sess
)
audio_length = audio_feat.size(0)
text_pad_token = torch.zeros(audio_length, dtype=torch.int64)
text_token = torch.cat([text_token, text_pad_token])
audio_pad_feat = torch.zeros(
(text_length, patch_size, latent_dim),
dtype=torch.float32,
)
audio_feat = torch.cat([audio_pad_feat, audio_feat], dim=0)
text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32)
audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32)
else:
audio_feat = torch.zeros(
(text_length, patch_size, latent_dim),
dtype=torch.float32,
)
text_mask = torch.ones(text_length).type(torch.int32)
audio_mask = torch.zeros(text_length).type(torch.int32)
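        # Layout: the token stream is [text tokens | zero pad] and the audio
        # stream is [zero patches | prompt patches] (all zeros without a prompt);
        # text_mask/audio_mask pick which embedding feeds each position when the
        # two streams are merged below.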
text_token = text_token.unsqueeze(0).to(device)
text_mask = text_mask.unsqueeze(0).to(device)
audio_feat = audio_feat.unsqueeze(0).to(device)
audio_mask = audio_mask.unsqueeze(0).to(device)
# LocEnc (ONNX)
feat_embed_np = run_ort(locenc_sess, {"x": audio_feat.float()}, name="locenc")
feat_embed = torch.from_numpy(feat_embed_np).to(device=device, dtype=dtype)
feat_embed = run_ort(enc_to_lm_sess, {"input": feat_embed.float()}, name="enc_to_lm_init")
feat_embed = torch.from_numpy(feat_embed).to(device=device, dtype=dtype)
# Text embed
scale_emb = 1.0
text_embed = base_model.embed_tokens(text_token) * scale_emb
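        # The *_ref.npy dumps below are debug reference tensors (presumably kept
        # for cross-checking the RKNN port); they can be removed for plain inference.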
np.save("text_embed_ref.npy", text_embed.cpu().numpy())
combined_embed = text_mask.unsqueeze(-1) * text_embed + audio_mask.unsqueeze(-1) * feat_embed
attn_mask = torch.ones((combined_embed.size(0), combined_embed.size(1)), device=device, dtype=torch.long)
np.save("combined_embed_ref.npy", combined_embed.cpu().numpy())
# Base LM forward
start = time.perf_counter()
base_cache = DynamicCache()
base_outputs = base_model(
inputs_embeds=combined_embed,
attention_mask=attn_mask,
past_key_values=base_cache,
use_cache=True,
return_dict=True,
)
enc_outputs = base_outputs.last_hidden_state
base_cache = base_outputs.past_key_values
print(f"[time] base_lm initial: {(time.perf_counter() - start)*1000:.2f} ms")
np.save("enc_outputs_ref.npy", enc_outputs.cpu().numpy())
# FSQ on audio positions
enc_outputs_fsq_np = run_ort(fsq_sess, {"hidden": enc_outputs.float()}, name="fsq_init")
enc_outputs_fsq = torch.from_numpy(enc_outputs_fsq_np).to(device=device, dtype=dtype)
np.save("enc_outputs_fsq_ref.npy", enc_outputs_fsq.cpu().numpy())
enc_outputs = enc_outputs_fsq * audio_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
# Residual LM forward
np.save("audio_mask_ref.npy", audio_mask.cpu().numpy())
np.save("feat_embed_ref.npy", feat_embed.cpu().numpy())
residual_inputs = enc_outputs + audio_mask.unsqueeze(-1) * feat_embed
start = time.perf_counter()
residual_cache = DynamicCache()
res_outputs = residual_model(
inputs_embeds=residual_inputs,
attention_mask=attn_mask,
past_key_values=residual_cache,
use_cache=True,
return_dict=True,
)
residual_outputs = res_outputs.last_hidden_state
residual_cache = res_outputs.past_key_values
print(f"[time] residual_lm initial: {(time.perf_counter() - start)*1000:.2f} ms")
np.save("residual_inputs_ref.npy", residual_inputs.cpu().numpy())
np.save("residual_outputs_ref.npy", residual_outputs.cpu().numpy())
np.save("residual_inputs_ref.npy", residual_inputs.cpu().numpy())
np.save("residual_outputs_ref.npy", residual_outputs.cpu().numpy())
lm_hidden = enc_outputs[:, -1, :].to(dtype)
res_hidden = residual_outputs[:, -1, :].to(dtype)
prefix_feat_cond = audio_feat[:, -1, :, :] # [B, P, D]
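        # The last input patch (prompt tail, or zeros without a prompt) conditions
        # the first diffusion step; afterwards each generated patch conditions the next.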
pred_feat_seq = []
        # Generation loop: each iteration samples one patch (patch_size latent
        # frames) with the ONNX DiT, then feeds it back through LocEnc and both LMs.
for step_idx in tqdm(range(args.max_len), desc="gen_loop"):
dit_hidden_lm = run_ort(lm_to_dit_sess, {"input": lm_hidden.float()}, name="lm_to_dit")
dit_hidden_res = run_ort(res_to_dit_sess, {"input": res_hidden.float()}, name="res_to_dit")
dit_hidden = torch.from_numpy(dit_hidden_lm + dit_hidden_res).to(device=device, dtype=dtype)
cond = prefix_feat_cond.transpose(1, 2).contiguous() # [B, D, P]
# Sample next patch via ONNX DiT
x0 = torch.randn_like(prefix_feat_cond.transpose(1, 2)) # [B, D, P]
pred_feat = cfm_euler_with_onnx_step(
dit_step_sess,
x0,
dit_hidden,
cond,
n_timesteps=args.inference_timesteps,
cfg_value=args.cfg_value,
use_cfg_zero_star=True,
).transpose(1, 2) # -> [B, P, D]
pred_feat_seq.append(pred_feat.unsqueeze(1)) # keep time dimension
prefix_feat_cond = pred_feat
# Encode new patch for next step (ONNX locenc)
locenc_step_np = run_ort(locenc_sess, {"x": pred_feat.unsqueeze(1).float()}, name="locenc_step")
curr_embed = torch.from_numpy(locenc_step_np).to(device=device, dtype=dtype)
curr_embed = run_ort(enc_to_lm_sess, {"input": curr_embed.float()}, name="enc_to_lm_step")
curr_embed = torch.from_numpy(curr_embed).to(device=device, dtype=dtype)
# Stop check (use lm_hidden BEFORE update, consistent with original)
            stop_logits_np = run_ort(stop_sess, {"hidden": lm_hidden.float()}, name="stop_head")
stop_logits = torch.from_numpy(stop_logits_np)
stop_flag = stop_logits.argmax(dim=-1)[0].item()
if step_idx > args.min_len and stop_flag == 1:
break
# Update LMs using transformers cache API
attn_mask = torch.cat(
[attn_mask, torch.ones((attn_mask.size(0), 1), device=device, dtype=torch.long)], dim=1
)
np.save("curr_embed_ref.npy", curr_embed.cpu().numpy())
base_step = base_model(
inputs_embeds=curr_embed,
attention_mask=attn_mask,
past_key_values=base_cache,
use_cache=True,
return_dict=True,
)
lm_hidden_step = base_step.last_hidden_state # [B, 1, H]
base_cache = base_step.past_key_values
lm_hidden = lm_hidden_step.squeeze(1).to(dtype)
# FSQ expects [B, T, H]; expand step dimension then squeeze back
lm_hidden_step = lm_hidden.unsqueeze(1)
            lm_hidden_fsq_np = run_ort(fsq_sess, {"hidden": lm_hidden_step.float()}, name="fsq_step")
lm_hidden_fsq = torch.from_numpy(lm_hidden_fsq_np).to(device=device, dtype=dtype).squeeze(1)
res_step_inputs = (lm_hidden_fsq + curr_embed[:, 0, :]).unsqueeze(1)
res_step = residual_model(
inputs_embeds=res_step_inputs,
attention_mask=attn_mask,
past_key_values=residual_cache,
use_cache=True,
return_dict=True,
)
residual_cache = res_step.past_key_values
res_hidden = res_step.last_hidden_state.squeeze(1).to(dtype)
lm_hidden = lm_hidden_fsq
if len(pred_feat_seq) == 0:
raise RuntimeError("Generation produced zero patches.")
pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # [B, T_gen, P, D]
feat_pred = pred_feat_seq.permute(0, 3, 1, 2).reshape(
pred_feat_seq.size(0), pred_feat_seq.size(3), -1
) # [B, D, T_gen*P]
# Decode audio via ONNX
audio_np = run_ort(vae_decode_sess, {"latent": feat_pred.float()}, name="vae_decode")
audio = torch.from_numpy(audio_np)
        audio = audio[..., 640:-640]  # trim chunk_size samples (40 ms) from each end
        wav = audio.squeeze()  # -> [T]
        torchaudio.save(args.output, wav.unsqueeze(0).cpu(), sample_rate)
print(f"Saved: {args.output}")
if __name__ == "__main__":
main()