"""
VoxCPM: A Tokenizer-free speech generation model
This module contains the main VoxCPM model implementation, including configuration classes
and the core VoxCPMModel for text-to-speech generation.
Copyright 2025 OpenBMB
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from typing import Tuple, Union, Generator, List, Optional
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import warnings
from einops import rearrange
from pydantic import BaseModel
from tqdm import tqdm
from transformers import LlamaTokenizerFast
from ..modules.audiovae import AudioVAE
from ..modules.layers import ScalarQuantizationLayer
from ..modules.layers.lora import apply_lora_to_named_linear_modules
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
from ..modules.locenc import VoxCPMLocEnc
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
from .utils import get_dtype, mask_multichar_chinese_tokens
class VoxCPMEncoderConfig(BaseModel):
hidden_dim: int = 1024
ffn_dim: int = 4096
num_heads: int = 16
num_layers: int = 4
kv_channels: Optional[int] = None
class VoxCPMDitConfig(BaseModel):
hidden_dim: int = 1024
ffn_dim: int = 4096
num_heads: int = 16
num_layers: int = 4
kv_channels: Optional[int] = None
cfm_config: CfmConfig
class VoxCPMConfig(BaseModel):
lm_config: MiniCPM4Config
patch_size: int = 2
feat_dim: int = 64
residual_lm_num_layers: int = 6
scalar_quantization_latent_dim: int = 256
scalar_quantization_scale: int = 9
encoder_config: VoxCPMEncoderConfig
dit_config: VoxCPMDitConfig
max_length: int = 4096
device: str = "cuda"
dtype: str = "bfloat16"
dit_mean_mode: bool = False
class LoRAConfig(BaseModel):
enable_lm: bool = False  # apply LoRA to base_lm + residual_lm
enable_dit: bool = False  # apply LoRA to VoxCPMLocDiT
enable_proj: bool = False  # apply LoRA to selected projection Linear layers
r: int = 8
alpha: int = 16
dropout: float = 0.0
# LM & DiT 目标线性层名(以属性名匹配)
target_modules_lm: list[str] = ["q_proj", "v_proj"]
target_modules_dit: list[str] = ["q_proj", "v_proj"]
# Projection-layer attribute names, looked up on VoxCPMModel
target_proj_modules: list[str] = ["enc_to_lm_proj", "lm_to_dit_proj", "res_to_dit_proj"]
VoxCPMConfig.model_rebuild()
class VoxCPMModel(nn.Module):
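"""
Core VoxCPM text-to-speech model.
Combines a text-semantic LM (MiniCPM), a residual acoustic LM, a local encoder,
and a local DiT-based flow-matching decoder operating on AudioVAE latents.
Example (a minimal usage sketch; the checkpoint path and prompt files are placeholders):
>>> model = VoxCPMModel.from_local("/path/to/voxcpm_checkpoint")
>>> wav = model.generate("Hello, world.")  # waveform tensor at model.sample_rate
>>> wav_cloned = model.generate(
...     "Hello again.",
...     prompt_text="transcript of the prompt audio",
...     prompt_wav_path="/path/to/prompt.wav",
... )
"""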
def __init__(
self,
config: VoxCPMConfig,
tokenizer: LlamaTokenizerFast,
audio_vae: AudioVAE,
lora_config: LoRAConfig = None,
):
super().__init__()
self.config = config
self.lora_config = lora_config
self.feat_dim = config.feat_dim
self.patch_size = config.patch_size
self.device = config.device
if not torch.cuda.is_available():
if torch.backends.mps.is_available():
self.device = "mps"
else:
self.device = "cpu"
print(f"Running on device: {self.device}, dtype: {self.config.dtype}")
# Text-Semantic LM
self.base_lm = MiniCPMModel(config.lm_config)
self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))
self.text_tokenizer = mask_multichar_chinese_tokens(tokenizer)
self.audio_start_token = 101
self.audio_end_token = 102
# Residual Acoustic LM
residual_lm_config = config.lm_config.model_copy(deep=True)
residual_lm_config.num_hidden_layers = config.residual_lm_num_layers
residual_lm_config.vocab_size = 0
self.residual_lm = MiniCPMModel(residual_lm_config)
self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))
# Local Encoder
encoder_config = config.lm_config.model_copy(deep=True)
encoder_config.hidden_size = config.encoder_config.hidden_dim
encoder_config.intermediate_size = config.encoder_config.ffn_dim
encoder_config.num_attention_heads = config.encoder_config.num_heads
encoder_config.num_hidden_layers = config.encoder_config.num_layers
encoder_config.kv_channels = config.encoder_config.kv_channels
encoder_config.vocab_size = 0
self.feat_encoder = VoxCPMLocEnc(encoder_config, input_dim=config.feat_dim)
# Local DiT
decoder_config = config.lm_config.model_copy(deep=True)
decoder_config.hidden_size = config.dit_config.hidden_dim
decoder_config.intermediate_size = config.dit_config.ffn_dim
decoder_config.num_attention_heads = config.dit_config.num_heads
decoder_config.num_hidden_layers = config.dit_config.num_layers
decoder_config.kv_channels = config.dit_config.kv_channels
decoder_config.vocab_size = 0
self.feat_decoder = UnifiedCFM(
in_channels=config.feat_dim,
cfm_params=config.dit_config.cfm_config,
estimator=VoxCPMLocDiT(decoder_config, in_channels=config.feat_dim),
mean_mode=config.dit_mean_mode,
)
# Projection layers
self.fsq_layer = ScalarQuantizationLayer(
config.lm_config.hidden_size,
config.lm_config.hidden_size,
config.scalar_quantization_latent_dim,
config.scalar_quantization_scale
)
self.enc_to_lm_proj = nn.Linear(config.encoder_config.hidden_dim, config.lm_config.hidden_size)
self.lm_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
self.res_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
# Stop Predictor
self.stop_proj = nn.Linear(config.lm_config.hidden_size, config.lm_config.hidden_size)
self.stop_actn = nn.SiLU()
self.stop_head = nn.Linear(config.lm_config.hidden_size, 2, bias=False)
self.stop_loss = nn.CrossEntropyLoss(reduction="none")
# Audio VAE
self.audio_vae = audio_vae
self.chunk_size = audio_vae.chunk_size
self.sample_rate = audio_vae.sample_rate
# ------------------------------------------------------------------ #
# Optional: inject LoRA into the LM / DiT / projection layers at construction time
# (structure only; no freezing logic here). Freezing of non-LoRA parameters is
# handled in from_local when a lora_config is provided.
# ------------------------------------------------------------------ #
if self.lora_config is not None:
# LM: base_lm + residual_lm
if self.lora_config.enable_lm:
apply_lora_to_named_linear_modules(
self.base_lm,
target_submodule_names=self.lora_config.target_modules_lm,
r=self.lora_config.r,
alpha=self.lora_config.alpha,
dropout=self.lora_config.dropout,
)
apply_lora_to_named_linear_modules(
self.residual_lm,
target_submodule_names=self.lora_config.target_modules_lm,
r=self.lora_config.r,
alpha=self.lora_config.alpha,
dropout=self.lora_config.dropout,
)
# DiT: VoxCPMLocDiT(feat_decoder.estimator)
if self.lora_config.enable_dit:
apply_lora_to_named_linear_modules(
self.feat_decoder.estimator,
target_submodule_names=self.lora_config.target_modules_dit,
r=self.lora_config.r,
alpha=self.lora_config.alpha,
dropout=self.lora_config.dropout,
)
# Projection layers: look up Linear attributes on this model by name and replace them
if self.lora_config.enable_proj:
for attr_name in self.lora_config.target_proj_modules:
if hasattr(self, attr_name):
module = getattr(self, attr_name)
if isinstance(module, nn.Linear):
from ..modules.layers.lora import LoRALinear
setattr(
self,
attr_name,
LoRALinear(
base=module,
r=self.lora_config.r,
alpha=self.lora_config.alpha,
dropout=self.lora_config.dropout,
),
)
def optimize(self, disable: bool = False):
# These attributes must be set whether or not we compile (generate and related methods rely on them)
if not hasattr(self, 'feat_encoder_step'):
self.feat_encoder_step = self.feat_encoder
if disable:
# Skip torch.compile but still set up the required attributes
self.base_lm.forward_step = self.base_lm.forward_step
self.residual_lm.forward_step = self.residual_lm.forward_step
return self
try:
if self.device != "cuda":
raise ValueError("VoxCPMModel can only be optimized on CUDA device")
try:
import triton
except ImportError:
raise ValueError("triton is not installed")
self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True)
self.feat_encoder_step = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
except Exception as e:
print(f"Error: {e}")
print("Warning: VoxCPMModel can not be optimized by torch.compile, using original forward_step functions")
self.base_lm.forward_step = self.base_lm.forward_step
self.residual_lm.forward_step = self.residual_lm.forward_step
self.feat_encoder_step = self.feat_encoder
self.feat_decoder.estimator = self.feat_decoder.estimator
return self
def forward(
self,
text_tokens: torch.Tensor,
text_mask: torch.Tensor,
audio_feats: torch.Tensor,
audio_mask: torch.Tensor,
loss_mask: torch.Tensor,
position_ids: torch.Tensor,
labels: torch.Tensor,
*,
progress: float = 0.0,
sample_generate: bool = False,
):
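"""
Training forward pass.
Shapes below are inferred from how the tensors are used in this method
(B = batch, T = sequence length, P = patch_size, D = feat_dim):
text_tokens: [B, T] token ids
text_mask / audio_mask / loss_mask: [B, T], 1 where a position is text / audio / supervised
audio_feats: [B, T, P, D] AudioVAE latent patches
position_ids: currently unused
labels: [B, T] stop-predictor targets
Returns a dict with "loss/diff", "loss/stop", "feat_gt", and "feat_pred"
(the latter is populated only when ``sample_generate=True``).
"""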
del position_ids # not used yet
text_tokens = text_tokens.to(self.device, dtype=torch.long)
text_mask = text_mask.to(self.device, dtype=self._dtype())
audio_feats = audio_feats.to(self.device, dtype=self._dtype())
audio_mask = audio_mask.to(self.device, dtype=self._dtype())
loss_mask = loss_mask.to(self.device, dtype=self._dtype())
labels = labels.to(self.device, dtype=torch.long)
B, T, P, D = audio_feats.shape
feat_embed = self.feat_encoder(audio_feats)
feat_embed = self.enc_to_lm_proj(feat_embed)
scale_emb = getattr(self.config.lm_config, "scale_emb", 1.0)
if not getattr(self.config.lm_config, "use_mup", False):
scale_emb = 1.0
text_embed = self.base_lm.embed_tokens(text_tokens) * scale_emb
combined_embed = text_mask.unsqueeze(-1) * text_embed + audio_mask.unsqueeze(-1) * feat_embed
enc_outputs, _ = self.base_lm(inputs_embeds=combined_embed, is_causal=True)
enc_outputs = enc_outputs.to(self._dtype())
enc_outputs = self.fsq_layer(enc_outputs) * audio_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
lm_hidden = torch.cat((torch.zeros_like(enc_outputs[:, 0:1, :]), enc_outputs[:, :-1, :]), dim=1)
residual_inputs = enc_outputs + audio_mask.unsqueeze(-1) * feat_embed
residual_outputs, _ = self.residual_lm(inputs_embeds=residual_inputs, is_causal=True)
residual_outputs = residual_outputs.to(self._dtype())
residual_hidden = torch.cat(
(torch.zeros_like(residual_outputs[:, 0:1, :]), residual_outputs[:, :-1, :]),
dim=1,
)
dit_hidden = self.lm_to_dit_proj(lm_hidden) + self.res_to_dit_proj(residual_hidden)
dit_hidden = rearrange(dit_hidden, "b t c -> (b t) c")
# Keep diffusion inputs in the same dtype as the model (e.g., bfloat16)
target_dtype = self._dtype()
feat_gt = rearrange(audio_feats.to(target_dtype), "b t p d -> (b t) p d")
feat_cond = torch.cat(
(torch.zeros_like(audio_feats[:, 0:1, ...]), audio_feats[:, :-1, ...]),
dim=1,
)
feat_cond = rearrange(feat_cond.to(target_dtype), "b t p d -> (b t) p d")
loss_seq_mask = loss_mask.unsqueeze(-1).repeat(1, 1, self.patch_size)
loss_seq_mask = rearrange(loss_seq_mask, "b t p -> (b t) p 1").to(target_dtype)
diff_loss = self.feat_decoder.compute_loss(
feat_gt.transpose(1, 2).contiguous(),
dit_hidden,
cond=feat_cond.transpose(1, 2).contiguous(),
tgt_mask=loss_seq_mask.transpose(1, 2).contiguous(),
progress=progress,
)
stop_logits = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden)))
stop_losses = self.stop_loss(stop_logits.transpose(1, 2), labels)
denom = torch.clamp(loss_mask.sum(), min=1.0)
stop_loss = (stop_losses * loss_mask).sum() / denom
feat_pred = None
if sample_generate:
feat_cond_for_sample = feat_cond.transpose(1, 2).contiguous()
feat_pred_seq = self.feat_decoder(
mu=dit_hidden,
patch_size=self.patch_size,
cond=feat_cond_for_sample,
n_timesteps=self.config.dit_config.cfm_config.inference_cfg_rate
if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
else 10,
)
feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size)
feat_gt_tensor = rearrange(feat_gt, "(b t) p d -> b d (t p)", b=B, p=self.patch_size)
return {
"loss/diff": diff_loss,
"loss/stop": stop_loss,
"feat_gt": feat_gt_tensor,
"feat_pred": feat_pred,
}
def _dtype(self):
return get_dtype(self.config.dtype)
def generate(self, *args, **kwargs) -> torch.Tensor:
return next(self._generate(*args, streaming=False, **kwargs))
def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
return self._generate(*args, streaming=True, **kwargs)
@torch.inference_mode()
def _generate(
self,
target_text: str,
prompt_text: str = "",
prompt_wav_path: str = "",
min_len: int = 2,
max_len: int = 2000,
inference_timesteps: int = 10,
cfg_value: float = 2.0,
retry_badcase: bool = False,
retry_badcase_max_times: int = 3,
retry_badcase_ratio_threshold: float = 6.0,  # acceptable ratio of audio length to text length (used for bad-case detection)
streaming: bool = False,
) -> Generator[torch.Tensor, None, None]:
if retry_badcase and streaming:
warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
retry_badcase = False
if len(prompt_wav_path) == 0:
text = target_text
text_token = torch.LongTensor(self.text_tokenizer(text))
text_token = torch.cat(
[
text_token,
torch.tensor(
[self.audio_start_token],
dtype=torch.int32,
device=text_token.device,
),
],
dim=-1,
)
text_length = text_token.shape[0]
audio_feat = torch.zeros(
(text_length, self.patch_size, self.audio_vae.latent_dim),
dtype=torch.float32,
device=text_token.device,
)
text_mask = torch.ones(text_length).type(torch.int32).to(text_token.device)
audio_mask = torch.zeros(text_length).type(torch.int32).to(text_token.device)
else:
text = prompt_text + target_text
text_token = torch.LongTensor(self.text_tokenizer(text))
text_token = torch.cat(
[
text_token,
torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device),
],
dim=-1,
)
text_length = text_token.shape[0]
audio, sr = torchaudio.load(prompt_wav_path)
if audio.size(0) > 1:
audio = audio.mean(dim=0, keepdim=True)
if sr != self.sample_rate:
audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
patch_len = self.patch_size * self.chunk_size
if audio.size(1) % patch_len != 0:
audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len))
# (B, D, T)
audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()
audio_feat = audio_feat.view(
self.audio_vae.latent_dim,
-1,
self.patch_size,
).permute(1, 2, 0)
audio_feat = audio_feat[:-1, ...] # trick: remove the last padding token
audio_length = audio_feat.size(0)
text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
text_token = torch.cat([text_token, text_pad_token])
audio_pad_feat = torch.zeros(
(text_length, self.patch_size, self.audio_vae.latent_dim),
dtype=torch.float32,
device=text_token.device,
)
audio_feat = torch.cat([audio_pad_feat, audio_feat], dim=0)
text_mask = (
torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
)
audio_mask = (
torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
)
text_token = text_token.unsqueeze(0).to(self.device)
text_mask = text_mask.unsqueeze(0).to(self.device)
audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
audio_mask = audio_mask.unsqueeze(0).to(self.device)
target_text_length = len(self.text_tokenizer(target_text))
retry_badcase_times = 0
while retry_badcase_times < retry_badcase_max_times:
inference_result = self._inference(
text_token,
text_mask,
audio_feat,
audio_mask,
min_len=min_len,
max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
inference_timesteps=inference_timesteps,
cfg_value=cfg_value,
streaming=streaming,
)
if streaming:
patch_len = self.patch_size * self.chunk_size
for latent_pred, _ in inference_result:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
yield decode_audio
break
else:
latent_pred, pred_audio_feat = next(inference_result)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
retry_badcase_times += 1
continue
else:
break
else:
break
if not streaming:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
decode_audio = decode_audio[..., 640:-640] # trick: trim the start and end of the audio
yield decode_audio
@torch.inference_mode()
def build_prompt_cache(
self,
prompt_text: str,
prompt_wav_path: str,
):
"""
Build prompt cache for subsequent fast generation.
Args:
prompt_text: prompt text (required)
prompt_wav_path: prompt audio path (required)
Returns:
prompt_cache: dict with text tokens and audio features
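Example (a minimal sketch; the wav path and transcript are placeholders):
>>> cache = model.build_prompt_cache(
...     prompt_text="transcript of the prompt audio",
...     prompt_wav_path="/path/to/prompt.wav",
... )
>>> # cache["text_token"]: [N_text], cache["audio_feat"]: [N_audio, patch_size, latent_dim]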
"""
if not prompt_text or not prompt_wav_path:
raise ValueError("prompt_text and prompt_wav_path are required")
# build text tokens
text_token = torch.LongTensor(self.text_tokenizer(prompt_text))
# load audio
audio, sr = torchaudio.load(prompt_wav_path)
if audio.size(0) > 1:
audio = audio.mean(dim=0, keepdim=True)
if sr != self.sample_rate:
audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
patch_len = self.patch_size * self.chunk_size
if audio.size(1) % patch_len != 0:
audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len))
# extract audio features
audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()
audio_feat = audio_feat.view(
self.audio_vae.latent_dim,
-1,
self.patch_size,
).permute(1, 2, 0)  # (T, P, D)
audio_feat = audio_feat[:-1, ...] # trick: remove the last padding token
# build prompt cache
prompt_cache = {
"text_token": text_token,
"audio_feat": audio_feat,
}
return prompt_cache
def merge_prompt_cache(
self,
original_cache: dict,
new_text_token: torch.Tensor,
new_audio_feat: torch.Tensor,
):
"""
Merge original prompt cache with newly generated content to stabilize voice.
Args:
original_cache: original prompt cache
new_text_token: newly generated text tokens
new_audio_feat: newly generated audio features
Returns:
merged_cache: merged cache
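Example (a sketch of chaining consecutive generations; ``model`` and an existing ``cache`` are assumed):
>>> wav, new_tokens, new_feat = model.generate_with_prompt_cache("First sentence.", prompt_cache=cache)
>>> cache = model.merge_prompt_cache(cache, new_tokens, new_feat)
>>> wav2, _, _ = model.generate_with_prompt_cache("Second sentence.", prompt_cache=cache)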
"""
if original_cache is None:
return {
"text_token": new_text_token,
"audio_feat": new_audio_feat,
}
original_text_token = original_cache["text_token"]
original_audio_feat = original_cache["audio_feat"]
merged_text_token = torch.cat([original_text_token, new_text_token], dim=0)
merged_audio_feat = torch.cat([original_audio_feat, new_audio_feat], dim=0)
# build new cache
merged_cache = {
"text_token": merged_text_token,
"audio_feat": merged_audio_feat,
}
return merged_cache
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
def generate_with_prompt_cache_streaming(
self, *args, **kwargs
) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]:
return self._generate_with_prompt_cache(*args, streaming=True, **kwargs)
@torch.inference_mode()
def _generate_with_prompt_cache(
self,
target_text: str,
prompt_cache: dict,
min_len: int = 2,
max_len: int = 2000,
inference_timesteps: int = 10,
cfg_value: float = 2.0,
retry_badcase: bool = False,
retry_badcase_max_times: int = 3,
retry_badcase_ratio_threshold: float = 6.0,
streaming: bool = False,
) -> Generator[Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
"""
Generate audio using pre-built prompt cache.
Args:
target_text: Text to convert to speech
prompt_cache: Cache built by build_prompt_cache (can be None)
min_len: Minimum audio length to avoid very short audio
max_len: Maximum audio length
inference_timesteps: Number of diffusion sampling steps
cfg_value: Classifier-free guidance value
retry_badcase: Whether to retry on bad cases
retry_badcase_max_times: Maximum retry attempts
retry_badcase_ratio_threshold: Threshold for audio-to-text ratio
streaming: Whether to return a generator of audio chunks
Returns:
Generator of Tuple containing:
- Decoded audio tensor for the current step if ``streaming=True``, else final decoded audio tensor
- Tensor of new text tokens
- New audio features up to the current step as a List if ``streaming=True``, else as a concatenated Tensor
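Example (a minimal sketch; ``model`` and a previously built ``cache`` are assumed):
>>> # one-shot
>>> wav, text_tok, feat = model.generate_with_prompt_cache("Hello.", prompt_cache=cache)
>>> # streaming
>>> for chunk, text_tok, feats in model.generate_with_prompt_cache_streaming("Hello.", prompt_cache=cache):
...     play(chunk)  # ``play`` is a placeholder for your audio sink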
"""
if retry_badcase and streaming:
warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
retry_badcase = False
# get prompt from cache
if prompt_cache is None:
prompt_text_token = torch.empty(0, dtype=torch.int32)
prompt_audio_feat = torch.empty((0, self.patch_size, self.audio_vae.latent_dim), dtype=torch.float32)
else:
prompt_text_token = prompt_cache["text_token"]
prompt_audio_feat = prompt_cache["audio_feat"]
# build target text tokens
target_text_token = torch.LongTensor(self.text_tokenizer(target_text))
text_token = torch.cat([prompt_text_token, target_text_token], dim=0)
text_token = torch.cat(
[
text_token,
torch.tensor(
[self.audio_start_token],
dtype=torch.int32,
device=text_token.device,
),
],
dim=-1,
)
audio_length = prompt_audio_feat.size(0)
text_length = text_token.shape[0]
text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
audio_pad_feat = torch.zeros(
(text_token.shape[0], self.patch_size, self.audio_vae.latent_dim),
dtype=torch.float32,
device=text_token.device,
)
text_token = torch.cat([text_token, text_pad_token])
audio_feat = torch.cat([audio_pad_feat, prompt_audio_feat], dim=0)
text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
text_token = text_token.unsqueeze(0).to(self.device)
text_mask = text_mask.unsqueeze(0).to(self.device)
audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
audio_mask = audio_mask.unsqueeze(0).to(self.device)
# run inference
target_text_length = len(self.text_tokenizer(target_text))
retry_badcase_times = 0
while retry_badcase_times < retry_badcase_max_times:
inference_result = self._inference(
text_token,
text_mask,
audio_feat,
audio_mask,
min_len=min_len,
max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
inference_timesteps=inference_timesteps,
cfg_value=cfg_value,
streaming=streaming,
)
if streaming:
patch_len = self.patch_size * self.chunk_size
for latent_pred, pred_audio_feat in inference_result:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
yield (
decode_audio,
target_text_token,
pred_audio_feat
)
break
else:
latent_pred, pred_audio_feat = next(inference_result)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
retry_badcase_times += 1
continue
else:
break
else:
break
if not streaming:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
decode_audio = decode_audio[..., 640:-640] # trick: trim the start and end of the audio
yield (
decode_audio,
target_text_token,
pred_audio_feat
)
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
return next(self._inference(*args, streaming=False, **kwargs))
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
return self._inference(*args, streaming=True, **kwargs)
@torch.inference_mode()
def _inference(
self,
text: torch.Tensor,
text_mask: torch.Tensor,
feat: torch.Tensor,
feat_mask: torch.Tensor,
min_len: int = 2,
max_len: int = 2000,
inference_timesteps: int = 10,
cfg_value: float = 2.0,
streaming: bool = False,
) -> Generator[Tuple[torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
"""Core inference method for audio generation.
This is the main inference loop that generates audio features
using the language model and diffusion transformer.
Args:
text: Input text tokens
text_mask: Mask for text tokens
feat: Input audio features
feat_mask: Mask for audio features
min_len: Minimum generation length
max_len: Maximum generation length
inference_timesteps: Number of diffusion steps
cfg_value: Classifier-free guidance value
streaming: Whether to yield each step latent feature or just the final result
Returns:
Generator of Tuple containing:
- Predicted latent feature at the current step if ``streaming=True``, else final latent features
- Predicted audio feature sequence so far as a List if ``streaming=True``, else as a concatenated Tensor
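Example (a sketch via the non-streaming wrapper; inputs are assumed to be prepared as in ``_generate``):
>>> latent_pred, feat_seq = model.inference(text_token, text_mask, audio_feat, audio_mask)
>>> wav = model.audio_vae.decode(latent_pred.to(torch.float32))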
"""
B, T, P, D = feat.shape
feat_embed = self.feat_encoder(feat) # [b, t, h_feat]
feat_embed = self.enc_to_lm_proj(feat_embed)
if self.config.lm_config.use_mup:
scale_emb = self.config.lm_config.scale_emb
else:
scale_emb = 1.0
text_embed = self.base_lm.embed_tokens(text) * scale_emb
combined_embed = text_mask.unsqueeze(-1) * text_embed + feat_mask.unsqueeze(-1) * feat_embed
prefix_feat_cond = feat[:, -1, ...] # b, p, d
pred_feat_seq = [] # b, t, p, d
curr_embed = None
enc_outputs, kv_cache_tuple = self.base_lm(
inputs_embeds=combined_embed,
is_causal=True,
)
self.base_lm.kv_cache.fill_caches(kv_cache_tuple)
enc_outputs = self.fsq_layer(enc_outputs) * feat_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
lm_hidden = enc_outputs[:, -1, :]
residual_enc_outputs, residual_kv_cache_tuple = self.residual_lm(
inputs_embeds=enc_outputs + feat_mask.unsqueeze(-1) * feat_embed,
is_causal=True,
)
self.residual_lm.kv_cache.fill_caches(residual_kv_cache_tuple)
residual_hidden = residual_enc_outputs[:, -1, :]
for i in tqdm(range(max_len)):
dit_hidden_1 = self.lm_to_dit_proj(lm_hidden) # [b, h_dit]
dit_hidden_2 = self.res_to_dit_proj(residual_hidden) # [b, h_dit]
dit_hidden = dit_hidden_1 + dit_hidden_2 # [b, h_dit]
pred_feat = self.feat_decoder(
mu=dit_hidden,
patch_size=self.patch_size,
cond=prefix_feat_cond.transpose(1, 2).contiguous(),
n_timesteps=inference_timesteps,
cfg_value=cfg_value,
).transpose(
1, 2
) # [b, p, d]
curr_embed = self.feat_encoder_step(pred_feat.unsqueeze(1)) # b, 1, c
curr_embed = self.enc_to_lm_proj(curr_embed)
pred_feat_seq.append(pred_feat.unsqueeze(1)) # b, 1, p, d
prefix_feat_cond = pred_feat
if streaming:
# return the last three predicted latent features to provide enough context for smooth decoding
pred_feat_chunk = torch.cat(pred_feat_seq[-3:], dim=1)
feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
yield feat_pred, pred_feat_seq
stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
if i > min_len and stop_flag == 1:
break
lm_hidden = self.base_lm.forward_step(
curr_embed[:, 0, :], torch.tensor([self.base_lm.kv_cache.step()], device=curr_embed.device)
).clone()
lm_hidden = self.fsq_layer(lm_hidden)
residual_hidden = self.residual_lm.forward_step(
lm_hidden + curr_embed[:, 0, :], torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device)
).clone()
if not streaming:
pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
yield feat_pred, pred_feat_seq.squeeze(0).cpu()
@classmethod
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
tokenizer = LlamaTokenizerFast.from_pretrained(path)
audio_vae = AudioVAE()
vae_state_dict = torch.load(
os.path.join(path, "audiovae.pth"),
map_location="cpu",
weights_only=True,
)["state_dict"]
model = cls(config, tokenizer, audio_vae, lora_config)
if not training:
lm_dtype = get_dtype(model.config.dtype)
model = model.to(lm_dtype)
else: # training mode
for name, param in model.named_parameters():
if "audio_vae" in name: # freeze VAE weights
param.requires_grad = False
continue
if lora_config is not None:
if "lora" not in name: # freeze non-LoRA weights
param.requires_grad = False
model.audio_vae = model.audio_vae.to(torch.float32)
model_state_dict = torch.load(
os.path.join(path, "pytorch_model.bin"),
map_location="cpu",
weights_only=True,
)["state_dict"]
for kw, val in vae_state_dict.items():
model_state_dict[f"audio_vae.{kw}"] = val
# LoRALinear holds weight/bias directly, so its state_dict keys match nn.Linear's
# and no key remapping is needed. strict=False is used because the pretrained
# weights do not contain lora_A/lora_B.
model.load_state_dict(model_state_dict, strict=False)
if training:
return model
return model.to(model.device).eval().optimize(disable=not optimize)
# ------------------------------------------------------------------ #
# LoRA weight management interface
# ------------------------------------------------------------------ #
def load_lora_weights(self, lora_path: str, device: str = None):
"""
Load LoRA weights from a file; safe to call after torch.compile.
Implementation notes:
Weights are loaded via named_parameters() rather than load_state_dict().
torch.compile wraps modules in OptimizedModule, which changes the
state_dict key paths (e.g. module.weight -> module._orig_mod.weight), so
load_state_dict() can no longer match the right parameters.
named_parameters() instead returns references to the parameter objects,
which are unaffected by the compile wrapper; updating them in place with
.data.copy_() does not trigger recompilation and allows hot-swapping
different LoRA weights after compilation.
Args:
lora_path: LoRA checkpoint path (a directory containing generator.pth, or a .pth file)
device: device to load onto; defaults to the model's current device
Returns:
tuple: (loaded_keys, skipped_keys)
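Example (a minimal sketch; the checkpoint path is a placeholder):
>>> loaded, skipped = model.load_lora_weights("/path/to/lora_ckpt")
>>> print(len(loaded), "params loaded,", len(skipped), "keys skipped")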
"""
from pathlib import Path
if device is None:
device = self.device
# Accept either a directory or a single file
lora_path = Path(lora_path)
if lora_path.is_dir():
ckpt_file = lora_path / "generator.pth"
else:
ckpt_file = lora_path
if not ckpt_file.exists():
raise FileNotFoundError(f"LoRA checkpoint not found: {ckpt_file}")
ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
state_dict = ckpt.get("state_dict", ckpt)
# Load via named_parameters() to stay compatible with torch.compile
model_params = {name: param for name, param in self.named_parameters()}
# Build a key mapping to handle the _orig_mod prefix introduced by torch.compile
# checkpoint key:      feat_decoder.estimator.decoder.layers...
# key after compile:   feat_decoder.estimator._orig_mod.decoder.layers...
# The mapping lets checkpoint-style keys resolve to the compiled parameter names.
key_mapping = {}
for model_key in model_params.keys():
# Strip _orig_mod to recover the original key
normalized_key = model_key.replace("._orig_mod.", ".")
if normalized_key != model_key:
key_mapping[normalized_key] = model_key
loaded_keys = []
skipped_keys = []
for key, value in state_dict.items():
# Prefer a direct key match
if key in model_params:
model_params[key].data.copy_(value.to(device))
loaded_keys.append(key)
# Otherwise try the mapping (handles _orig_mod)
elif key in key_mapping:
mapped_key = key_mapping[key]
model_params[mapped_key].data.copy_(value.to(device))
loaded_keys.append(key)
else:
skipped_keys.append(key)
return loaded_keys, skipped_keys
def set_lora_enabled(self, enabled: bool):
"""
Dynamically enable/disable all LoRA layers (controlled via their scaling factor; compatible with torch.compile).
Args:
enabled: True = enable LoRA, False = disable (use only the base weights)
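Example (a minimal sketch):
>>> model.set_lora_enabled(False)  # temporarily fall back to the base weights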
"""
from ..modules.layers.lora import LoRALinear
for module in self.modules():
if isinstance(module, LoRALinear):
module.set_enabled(enabled)
def reset_lora_weights(self):
"""
Reset all LoRA weights to their initial state (A: kaiming, B: zeros).
With B = 0 the LoRA output is 0, which effectively "unloads" the LoRA.
"""
from ..modules.layers.lora import LoRALinear
for module in self.modules():
if isinstance(module, LoRALinear):
module.reset_lora_parameters()
def get_lora_state_dict(self) -> dict:
"""
Return a state_dict of all LoRA parameters in the current model.
Returns:
dict: a dict containing only lora_A / lora_B parameters
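Example (a sketch of saving in a layout that ``load_lora_weights`` accepts; the path is a placeholder):
>>> torch.save({"state_dict": model.get_lora_state_dict()}, "/path/to/lora_ckpt/generator.pth")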
"""
return {name: param.data.clone()
for name, param in self.named_parameters()
if "lora_A" in name or "lora_B" in name}