"""Hybrid ONNX/PyTorch inference for VoxCPM (non-streaming).

The two MiniCPM language models run in PyTorch with KV caches; the audio VAE,
local encoder, FSQ layer, stop head, DiT step, and projection layers run
through ONNX Runtime.
"""
import argparse
import os
import random
import time

import numpy as np
import onnxruntime as ort
import torch
import torchaudio
from tqdm import tqdm

from transformers import AutoTokenizer
from transformers.cache_utils import DynamicCache

from modeling_minicpm import MiniCPMModel


def mask_multichar_chinese_tokens(tokenizer):
    """Wrap `tokenizer` so multi-character Chinese tokens are split into single characters."""
    multichar_tokens = {
        token for token in tokenizer.vocab.keys()
        if len(token) >= 2 and all("\u4e00" <= c <= "\u9fff" for c in token)
    }

    class CharTokenizerWrapper:
        """Wrapper class for tokenizers that handles multi-character Chinese tokens.

        This wrapper automatically splits multi-character Chinese tokens into
        individual characters while preserving the original tokenizer's interface.
        """

        def __init__(self, base_tokenizer) -> None:
            """Initialize the wrapper with a base tokenizer.

            Args:
                base_tokenizer: The tokenizer to wrap.
            """
            self.tokenizer = base_tokenizer
            self.multichar_tokens = multichar_tokens

        def tokenize(self, text: str, **kwargs):
            """Tokenize text and split multi-character Chinese tokens into single characters.

            Args:
                text: Input text to tokenize.
                **kwargs: Additional arguments passed to the base tokenizer.

            Returns:
                List of processed tokens with multi-character Chinese tokens split.

            Example:
                >>> wrapper = CharTokenizerWrapper(tokenizer)
                >>> tokens = wrapper.tokenize("你好世界")
                >>> # Returns ["你", "好", "世", "界"] instead of ["你好", "世界"]
            """
            if not isinstance(text, str):
                raise TypeError(f"Expected string input, got {type(text)}")

            tokens = self.tokenizer.tokenize(text, **kwargs)
            processed = []

            for token in tokens:
                # Strip the SentencePiece word-boundary marker before the lookup.
                clean_token = token.replace("▁", "")

                if clean_token in self.multichar_tokens:
                    processed.extend(list(clean_token))
                else:
                    processed.append(token)

            return processed

        def __call__(self, text: str, **kwargs):
            """Call the tokenizer and return token IDs.

            This method provides the same interface as the original tokenizer
            but with multi-character Chinese token handling.

            Args:
                text: Input text to tokenize.
                **kwargs: Additional arguments passed to the base tokenizer.

            Returns:
                List of token IDs.

            Raises:
                TypeError: If input is not a string.
                ValueError: If tokenization fails.
            """
            try:
                tokens = self.tokenize(text, **kwargs)
                return self.tokenizer.convert_tokens_to_ids(tokens)
            except Exception as e:
                raise ValueError(f"Tokenization failed: {e}") from e

    return CharTokenizerWrapper(tokenizer)
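
# Illustrative usage of the wrapper (the checkpoint path follows the
# --tokenizer-dir example below and is an assumption):
#   tok = mask_multichar_chinese_tokens(AutoTokenizer.from_pretrained("VoxCPM-0.5B"))
#   tok.tokenize("你好世界")  # -> ["你", "好", "世", "界"] rather than ["你好", "世界"]
#   tok("你好世界")           # -> the corresponding token ids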


def load_onnx(path: str, providers):
    """Create an ONNX Runtime session, failing early if the file is missing."""
    if not os.path.exists(path):
        raise FileNotFoundError(f"ONNX file not found: {path}")
    return ort.InferenceSession(path, providers=providers)


def to_numpy(t: torch.Tensor):
    return t.detach().cpu().numpy()


def run_ort(session: ort.InferenceSession, inputs: dict, name: str | None = None):
    """Run a session on torch or numpy inputs and return its first output.

    If `name` is given, the elapsed time is printed for profiling.
    """
    start = time.perf_counter()
    ort_inputs = {k: (v if isinstance(v, np.ndarray) else to_numpy(v)) for k, v in inputs.items()}
    outputs = session.run(None, ort_inputs)
    if name:
        elapsed_ms = (time.perf_counter() - start) * 1000
        print(f"[time] {name}: {elapsed_ms:.2f} ms")
    return outputs[0]

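
# The solver below mirrors UnifiedCFM.solve_euler from the PyTorch model, with
# every velocity evaluation delegated to the exported ONNX DiT. The conditional
# (mu, cond) and unconditional (all-zeros) branches are batched into a single
# session call, then combined with classifier-free guidance:
#
#     v = s* * v_neg + cfg_value * (v_pos - s* * v_neg)
#
# With CFG-Zero*-style guidance, s* = <v_pos, v_neg> / (||v_neg||^2 + 1e-8)
# instead of the plain s* = 1, and the first ~4% of steps use zero velocity.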
def cfm_euler_with_onnx_step(
    dit_sess: ort.InferenceSession,
    x: torch.Tensor,
    mu: torch.Tensor,
    cond: torch.Tensor,
    n_timesteps: int,
    cfg_value: float,
    use_cfg_zero_star: bool = True,
    mean_mode: bool = False,
):
    """
    Re-implementation of UnifiedCFM.solve_euler using the ONNX DiT single step.

    Shapes:
        x: [B, C, P], mu: [B, H_dit], cond: [B, C, P]
    """
    device = x.device
    dtype = x.dtype
    # Cosine-warped timestep schedule from t=1 (noise) down to t=0 (data).
    t_span = torch.linspace(1.0, 0.0, n_timesteps + 1, device=device, dtype=dtype)
    t_span = t_span + 1.0 * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span)

    t = t_span[0]
    dt = t_span[0] - t_span[1]
    # CFG-Zero*: skip the model entirely for the first ~4% of steps (zero-init).
    zero_init_steps = max(1, int(len(t_span) * 0.04))

    for step in range(1, len(t_span)):
        if use_cfg_zero_star and step <= zero_init_steps:
            dphi_dt = torch.zeros_like(x)
        else:
            b = x.size(0)
            t_in = t.expand(b)
            dt_in = dt.expand(b)
            if not mean_mode:
                dt_in = torch.zeros_like(dt_in)

            # Batch the conditional (mu, cond) and unconditional (zeros)
            # branches into one ONNX call.
            x_batch = torch.cat([x, x], dim=0)
            mu_batch = torch.cat([mu, torch.zeros_like(mu)], dim=0)
            t_batch = torch.cat([t_in, t_in], dim=0)
            cond_batch = torch.cat([cond, torch.zeros_like(cond)], dim=0)
            dt_batch = torch.cat([dt_in, torch.zeros_like(dt_in)], dim=0)

            dphi_dt_batch = run_ort(
                dit_sess,
                {
                    "x": x_batch,
                    "mu": mu_batch,
                    "t": t_batch,
                    "cond": cond_batch,
                    "dt": dt_batch,
                },
                name=f"dit_step_b2_{step}",
            )
            dphi_dt_batch = torch.from_numpy(dphi_dt_batch).to(device=device, dtype=dtype)
            dphi_dt_pos, dphi_dt_neg = torch.split(dphi_dt_batch, [b, b], dim=0)

            if use_cfg_zero_star:
                # Rescale the unconditional branch by the projection of the
                # conditional velocity onto it.
                positive_flat = dphi_dt_pos.view(b, -1)
                negative_flat = dphi_dt_neg.view(b, -1)
                st_star = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) / (
                    torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
                )
                st_star = st_star.view(b, *([1] * (dphi_dt_pos.ndim - 1)))
            else:
                st_star = 1.0

            dphi_dt = dphi_dt_neg * st_star + cfg_value * (dphi_dt_pos - dphi_dt_neg * st_star)

        # Euler update; t runs from 1 down to 0, hence the subtraction.
        x = x - dt * dphi_dt
        t = t - dt
        if step < len(t_span) - 1:
            dt = t - t_span[step + 1]

    return x

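
# Prompt audio is VAE-encoded into latents and regrouped into patches of
# `patch_size` frames: [1, C, T] -> [T / patch_size, patch_size, C]. Assuming
# the VAE downsamples by `chunk_size` (640 samples = 40 ms at 16 kHz, per the
# constants in main()), each patch then covers patch_size * 40 ms of audio.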
def prepare_audio_features(
    audio_path: str,
    sample_rate: int,
    patch_size: int,
    chunk_size: int,
    vae_encode_sess: ort.InferenceSession,
):
    audio, sr = torchaudio.load(audio_path)
    if audio.size(0) > 1:
        audio = audio.mean(dim=0, keepdim=True)
    if sr != sample_rate:
        audio = torchaudio.functional.resample(audio, sr, sample_rate)

    # [1, T] -> [1, 1, T] for the VAE encoder.
    if audio.ndim == 2:
        audio = audio.unsqueeze(0)

    # Right-pad so the waveform is a whole number of patches long.
    patch_len = patch_size * chunk_size
    t = audio.size(-1)
    pad_right = (patch_len - t % patch_len) % patch_len
    if pad_right > 0:
        audio = torch.nn.functional.pad(audio, (0, pad_right))

    latent = run_ort(vae_encode_sess, {"audio_wave": audio}, name="vae_encode")
    latent = torch.from_numpy(latent)
    latent_dim = latent.shape[1]
    t_latent = latent.shape[2]
    if t_latent % patch_size != 0:
        raise ValueError(f"Encoded latent length {t_latent} not divisible by patch_size={patch_size}")

    # [1, C, T] -> [T / patch_size, patch_size, C]; drop the last patch.
    audio_feat = latent.view(latent_dim, -1, patch_size).permute(1, 2, 0)
    audio_feat = audio_feat[:-1, ...]
    return audio_feat

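
# End-to-end flow in main():
#   text (+ optional prompt transcript) -> tokenizer -> base LM hiddens
#   -> FSQ quantization -> (+ audio embeddings) -> residual LM hiddens
#   -> projections -> DiT condition -> Euler flow matching -> latent patch
#   -> local encoder -> next LM input, looping until the stop head fires;
#   finally all patches -> audio VAE decoder -> waveform.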
def main():
    parser = argparse.ArgumentParser(description="Hybrid ONNX/PyTorch inference for VoxCPM (non-streaming).")
    parser.add_argument("--tokenizer-dir", required=True, help="Path to tokenizer (e.g., VoxCPM-0.5B).")
    parser.add_argument("--base-hf-dir", required=True, help="Path to transformers-formatted base MiniCPM.")
    parser.add_argument("--residual-hf-dir", required=True, help="Path to transformers-formatted residual MiniCPM.")
    parser.add_argument("--onnx-dir", required=True, help="Directory containing exported ONNX files.")
    parser.add_argument("--text", required=True, help="Target text to synthesize.")
    parser.add_argument("--prompt-audio", default=None, help="Optional prompt audio path.")
    parser.add_argument("--prompt-text", default=None, help="Transcript of the prompt audio (required with --prompt-audio).")
    parser.add_argument("--output", default="onnx_output.wav", help="Output wav path.")
    parser.add_argument("--device", default=None, help="torch device; defaults to cuda if available, else cpu.")
    parser.add_argument("--cfg-value", type=float, default=2.0, help="CFG value for diffusion.")
    parser.add_argument("--inference-timesteps", type=int, default=10, help="Diffusion steps.")
    parser.add_argument("--min-len", type=int, default=2, help="Minimum generated patch count before stopping is allowed.")
    parser.add_argument("--max-len", type=int, default=2000, help="Maximum generated patch count.")
    parser.add_argument("--force-fp32", action="store_true", help="Force model dtype to float32 for consistency with ONNX.")
    parser.add_argument("--seed", type=int, default=None, help="Random seed for reproducibility.")
    parser.add_argument(
        "--providers",
        nargs="+",
        default=None,
        help="ONNX Runtime providers (e.g., CUDAExecutionProvider CPUExecutionProvider).",
    )
    args = parser.parse_args()

    device = torch.device(args.device) if args.device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
    providers = args.providers or ["CUDAExecutionProvider", "CPUExecutionProvider"]

    # Seed every RNG involved so generation is reproducible.
    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(args.seed)

    with torch.inference_mode():
        # Tokenizer plus the two PyTorch language models (base + residual).
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)
        tokenizer = mask_multichar_chinese_tokens(tokenizer)
        base_model = MiniCPMModel.from_pretrained(args.base_hf_dir).to(device).eval()
        residual_model = MiniCPMModel.from_pretrained(args.residual_hf_dir).to(device).eval()
        if args.force_fp32:
            dtype = torch.float32
        else:
            dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

        base_model = base_model.to(dtype)
        residual_model = residual_model.to(dtype)

        # Fixed model constants; these must match the exported checkpoint.
        patch_size = 2
        feat_dim = 64
        latent_dim = 64
        chunk_size = 640
        sample_rate = 16000
        audio_start_token = 101

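        # Exported ONNX graphs, one per sub-module (roles inferred from how
        # they are invoked below):
        #   audio_vae_encode / audio_vae_decode: waveform <-> latent codec
        #   locenc: local encoder over latent patches
        #   fsq_layer: finite scalar quantization of LM hidden states
        #   stop_head: binary stop classifier
        #   dit_step: single DiT velocity evaluation for the CFM solver
        #   enc_to_lm_proj, lm_to_dit_proj, res_to_dit_proj: dimension adapters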
        vae_encode_sess = load_onnx(os.path.join(args.onnx_dir, "audio_vae_encode.onnx"), providers)
        vae_decode_sess = load_onnx(os.path.join(args.onnx_dir, "audio_vae_decode.onnx"), providers)
        locenc_sess = load_onnx(os.path.join(args.onnx_dir, "locenc.onnx"), providers)
        fsq_sess = load_onnx(os.path.join(args.onnx_dir, "fsq_layer.onnx"), providers)
        stop_sess = load_onnx(os.path.join(args.onnx_dir, "stop_head.onnx"), providers)
        dit_step_sess = load_onnx(os.path.join(args.onnx_dir, "dit_step.onnx"), providers)
        enc_to_lm_sess = load_onnx(os.path.join(args.onnx_dir, "enc_to_lm_proj.onnx"), providers)
        lm_to_dit_sess = load_onnx(os.path.join(args.onnx_dir, "lm_to_dit_proj.onnx"), providers)
        res_to_dit_sess = load_onnx(os.path.join(args.onnx_dir, "res_to_dit_proj.onnx"), providers)

        # With a prompt, the LM prefix is prompt transcript + target text.
        if args.prompt_audio:
            if not args.prompt_text:
                raise ValueError("--prompt-text is required when --prompt-audio is provided.")
            text = args.prompt_text + args.text
        else:
            text = args.text

        tokenized = tokenizer(text, add_special_tokens=False)
        # The char-level wrapper returns a plain list of ids, while a raw HF
        # tokenizer would return a dict; accept both.
        if isinstance(tokenized, dict):
            text_token = tokenized["input_ids"]
        else:
            text_token = tokenized
        text_token = torch.LongTensor(text_token)
        # Append the audio-start token that separates text from audio positions.
        text_token = torch.cat([text_token, torch.tensor([audio_start_token], dtype=torch.int64)], dim=-1)
        text_length = text_token.shape[0]

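        # Sequence layout: all text positions first, then one position per
        # audio patch. text_mask / audio_mask select, per position, whether
        # the LM input comes from the text embedding table or from the encoded
        # audio features, so both streams share one
        # [B, text_length + audio_length] sequence.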
        if args.prompt_audio:
            audio_feat = prepare_audio_features(
                args.prompt_audio, sample_rate, patch_size, chunk_size, vae_encode_sess
            )
            audio_length = audio_feat.size(0)

            # Zero token ids over the audio positions...
            text_pad_token = torch.zeros(audio_length, dtype=torch.int64)
            text_token = torch.cat([text_token, text_pad_token])

            # ...and zero audio features over the text positions.
            audio_pad_feat = torch.zeros(
                (text_length, patch_size, latent_dim),
                dtype=torch.float32,
            )
            audio_feat = torch.cat([audio_pad_feat, audio_feat], dim=0)

            text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32)
            audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32)
        else:
            audio_feat = torch.zeros(
                (text_length, patch_size, latent_dim),
                dtype=torch.float32,
            )
            text_mask = torch.ones(text_length).type(torch.int32)
            audio_mask = torch.zeros(text_length).type(torch.int32)

        # Add the batch dimension and move everything to the target device.
        text_token = text_token.unsqueeze(0).to(device)
        text_mask = text_mask.unsqueeze(0).to(device)
        audio_feat = audio_feat.unsqueeze(0).to(device)
        audio_mask = audio_mask.unsqueeze(0).to(device)

        # Audio patches -> LM-space embeddings (locenc followed by enc_to_lm_proj).
        feat_embed_np = run_ort(locenc_sess, {"x": audio_feat.float()}, name="locenc")
        feat_embed = torch.from_numpy(feat_embed_np).to(device=device, dtype=dtype)
        feat_embed = run_ort(enc_to_lm_sess, {"input": feat_embed.float()}, name="enc_to_lm_init")
        feat_embed = torch.from_numpy(feat_embed).to(device=device, dtype=dtype)

        # Text embeddings from the base LM's embedding table (no extra scaling).
        scale_emb = 1.0
        text_embed = base_model.embed_tokens(text_token) * scale_emb
        np.save("text_embed_ref.npy", text_embed.cpu().numpy())  # debug dump for parity checks

        # Merge streams: text embeddings on text positions, audio elsewhere.
        combined_embed = text_mask.unsqueeze(-1) * text_embed + audio_mask.unsqueeze(-1) * feat_embed
        attn_mask = torch.ones((combined_embed.size(0), combined_embed.size(1)), device=device, dtype=torch.long)
        np.save("combined_embed_ref.npy", combined_embed.cpu().numpy())  # debug dump

        # Prefill the base LM over the full prefix and keep its KV cache.
        start = time.perf_counter()
        base_cache = DynamicCache()
        base_outputs = base_model(
            inputs_embeds=combined_embed,
            attention_mask=attn_mask,
            past_key_values=base_cache,
            use_cache=True,
            return_dict=True,
        )
        enc_outputs = base_outputs.last_hidden_state
        base_cache = base_outputs.past_key_values
        print(f"[time] base_lm initial: {(time.perf_counter() - start) * 1000:.2f} ms")
        np.save("enc_outputs_ref.npy", enc_outputs.cpu().numpy())  # debug dump

        # FSQ-quantize the hiddens; audio positions keep the quantized values,
        # text positions keep the raw hiddens.
        enc_outputs_fsq_np = run_ort(fsq_sess, {"hidden": enc_outputs.float()}, name="fsq_init")
        enc_outputs_fsq = torch.from_numpy(enc_outputs_fsq_np).to(device=device, dtype=dtype)
        np.save("enc_outputs_fsq_ref.npy", enc_outputs_fsq.cpu().numpy())  # debug dump

        enc_outputs = enc_outputs_fsq * audio_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)

| np.save("audio_mask_ref.npy", audio_mask.cpu().numpy()) |
| np.save("feat_embed_ref.npy", feat_embed.cpu().numpy()) |
| residual_inputs = enc_outputs + audio_mask.unsqueeze(-1) * feat_embed |
| start = time.perf_counter() |
| residual_cache = DynamicCache() |
| res_outputs = residual_model( |
| inputs_embeds=residual_inputs, |
| attention_mask=attn_mask, |
| past_key_values=residual_cache, |
| use_cache=True, |
| return_dict=True, |
| ) |
| residual_outputs = res_outputs.last_hidden_state |
| residual_cache = res_outputs.past_key_values |
| print(f"[time] residual_lm initial: {(time.perf_counter() - start)*1000:.2f} ms") |
| np.save("residual_inputs_ref.npy", residual_inputs.cpu().numpy()) |
| np.save("residual_outputs_ref.npy", residual_outputs.cpu().numpy()) |
| np.save("residual_inputs_ref.npy", residual_inputs.cpu().numpy()) |
| np.save("residual_outputs_ref.npy", residual_outputs.cpu().numpy()) |
|
|

        # Running state for the autoregressive loop: last hidden state of each
        # LM and the most recent audio patch (used as the DiT condition).
        lm_hidden = enc_outputs[:, -1, :].to(dtype)
        res_hidden = residual_outputs[:, -1, :].to(dtype)

        prefix_feat_cond = audio_feat[:, -1, :, :]
        pred_feat_seq = []

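        # Autoregressive patch loop. Each iteration:
        #   1. project the base/residual hiddens into the DiT condition vector,
        #   2. sample the next latent patch via Euler flow matching, conditioned
        #      on the previous patch,
        #   3. re-encode the patch (locenc -> enc_to_lm) as the next LM input,
        #   4. query the stop head on the current base hidden,
        #   5. advance both LM KV caches by one position.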
        for step_idx in tqdm(range(args.max_len), desc="gen_loop"):
            dit_hidden_lm = run_ort(lm_to_dit_sess, {"input": lm_hidden.float()}, name="lm_to_dit")
            dit_hidden_res = run_ort(res_to_dit_sess, {"input": res_hidden.float()}, name="res_to_dit")
            dit_hidden = torch.from_numpy(dit_hidden_lm + dit_hidden_res).to(device=device, dtype=dtype)
            cond = prefix_feat_cond.transpose(1, 2).contiguous()

            # Sample the next latent patch from Gaussian noise with the ONNX
            # DiT step inside the Euler CFM solver.
            x0 = torch.randn_like(prefix_feat_cond.transpose(1, 2))
            pred_feat = cfm_euler_with_onnx_step(
                dit_step_sess,
                x0,
                dit_hidden,
                cond,
                n_timesteps=args.inference_timesteps,
                cfg_value=args.cfg_value,
                use_cfg_zero_star=True,
            ).transpose(1, 2)

            pred_feat_seq.append(pred_feat.unsqueeze(1))
            prefix_feat_cond = pred_feat

            # Re-encode the generated patch as the next LM input embedding.
            locenc_step_np = run_ort(locenc_sess, {"x": pred_feat.unsqueeze(1).float()}, name="locenc_step")
            curr_embed = torch.from_numpy(locenc_step_np).to(device=device, dtype=dtype)
            curr_embed = run_ort(enc_to_lm_sess, {"input": curr_embed.float()}, name="enc_to_lm_step")
            curr_embed = torch.from_numpy(curr_embed).to(device=device, dtype=dtype)

            # Stop decision from the current base-LM hidden state.
            stop_logits_np = run_ort(stop_sess, {"hidden": lm_hidden.float()}, name="stop_head")
            stop_logits = torch.from_numpy(stop_logits_np)
            stop_flag = stop_logits.argmax(dim=-1)[0].item()
            if step_idx > args.min_len and stop_flag == 1:
                break

            # Advance the base LM by one position on its KV cache.
            attn_mask = torch.cat(
                [attn_mask, torch.ones((attn_mask.size(0), 1), device=device, dtype=torch.long)], dim=1
            )
            np.save("curr_embed_ref.npy", curr_embed.cpu().numpy())  # debug dump
            base_step = base_model(
                inputs_embeds=curr_embed,
                attention_mask=attn_mask,
                past_key_values=base_cache,
                use_cache=True,
                return_dict=True,
            )
            lm_hidden_step = base_step.last_hidden_state
            base_cache = base_step.past_key_values
            lm_hidden = lm_hidden_step.squeeze(1).to(dtype)

            # FSQ-quantize the new hidden before feeding the residual LM.
            lm_hidden_step = lm_hidden.unsqueeze(1)
            lm_hidden_fsq_np = run_ort(fsq_sess, {"hidden": lm_hidden_step.float()}, name="fsq_step")
            lm_hidden_fsq = torch.from_numpy(lm_hidden_fsq_np).to(device=device, dtype=dtype).squeeze(1)

            res_step_inputs = (lm_hidden_fsq + curr_embed[:, 0, :]).unsqueeze(1)
            res_step = residual_model(
                inputs_embeds=res_step_inputs,
                attention_mask=attn_mask,
                past_key_values=residual_cache,
                use_cache=True,
                return_dict=True,
            )
            residual_cache = res_step.past_key_values
            res_hidden = res_step.last_hidden_state.squeeze(1).to(dtype)
            lm_hidden = lm_hidden_fsq

        if len(pred_feat_seq) == 0:
            raise RuntimeError("Generation produced zero patches.")

        # [B, N, P, C] -> [B, C, N * P]: flatten patches into one latent sequence.
        pred_feat_seq = torch.cat(pred_feat_seq, dim=1)
        feat_pred = pred_feat_seq.permute(0, 3, 1, 2).reshape(
            pred_feat_seq.size(0), pred_feat_seq.size(3), -1
        )

        # Decode latents to waveform; trim chunk_size (= 640) samples from each end.
        audio_np = run_ort(vae_decode_sess, {"latent": feat_pred.float()}, name="vae_decode")
        audio = torch.from_numpy(audio_np)
        audio = audio[..., chunk_size:-chunk_size]

        wav = audio.squeeze(0).squeeze(0)
        torchaudio.save(args.output, wav.unsqueeze(0).cpu(), sample_rate)
        print(f"Saved: {args.output}")

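
# Example invocation (illustrative paths and script name):
#   python voxcpm_onnx_infer.py \
#       --tokenizer-dir ./VoxCPM-0.5B \
#       --base-hf-dir ./base_minicpm_hf \
#       --residual-hf-dir ./residual_minicpm_hf \
#       --onnx-dir ./onnx_export \
#       --text "Text to synthesize." \
#       --prompt-audio prompt.wav --prompt-text "Transcript of prompt.wav." \
#       --output onnx_output.wav --seed 0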

if __name__ == "__main__":
    main()