|
|
""" |
|
|
VoxCPM: A tokenizer-free speech generation model
|
|
|
|
|
This module contains the main VoxCPM model implementation, including configuration classes |
|
|
and the core VoxCPMModel for text-to-speech generation. |
|
|
|
|
|
Copyright 2025 OpenBMB |
|
|
Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
you may not use this file except in compliance with the License. |
|
|
You may obtain a copy of the License at |
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software |
|
|
distributed under the License is distributed on an "AS IS" BASIS, |
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
See the License for the specific language governing permissions and |
|
|
limitations under the License. |
|
|
""" |
|
|
|
|
|
import os |
|
|
from typing import Tuple, Union, Generator, List, Optional
|
|
import re |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import torchaudio |
|
|
import warnings |
|
|
from einops import rearrange |
|
|
from pydantic import BaseModel |
|
|
from tqdm import tqdm |
|
|
from transformers import LlamaTokenizerFast |
|
|
|
|
|
from ..modules.audiovae import AudioVAE |
|
|
from ..modules.layers import ScalarQuantizationLayer |
|
|
from ..modules.layers.lora import apply_lora_to_named_linear_modules |
|
|
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT |
|
|
from ..modules.locenc import VoxCPMLocEnc |
|
|
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel |
|
|
from .utils import get_dtype, mask_multichar_chinese_tokens |
|
|
|
|
|
|
|
|
class VoxCPMEncoderConfig(BaseModel): |
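    """Configuration for the local acoustic feature encoder (VoxCPMLocEnc)."""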
|
|
hidden_dim: int = 1024 |
|
|
ffn_dim: int = 4096 |
|
|
num_heads: int = 16 |
|
|
num_layers: int = 4 |
|
|
    kv_channels: Optional[int] = None
|
|
|
|
|
|
|
|
class VoxCPMDitConfig(BaseModel): |
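    """Configuration for the local diffusion decoder (VoxCPMLocDiT used inside UnifiedCFM)."""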
|
|
hidden_dim: int = 1024 |
|
|
ffn_dim: int = 4096 |
|
|
num_heads: int = 16 |
|
|
num_layers: int = 4 |
|
|
    kv_channels: Optional[int] = None
|
|
|
|
|
cfm_config: CfmConfig |
|
|
|
|
|
|
|
|
class VoxCPMConfig(BaseModel): |
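    """Top-level configuration for VoxCPMModel."""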
|
|
lm_config: MiniCPM4Config |
|
|
patch_size: int = 2 |
|
|
feat_dim: int = 64 |
|
|
residual_lm_num_layers: int = 6 |
|
|
scalar_quantization_latent_dim: int = 256 |
|
|
scalar_quantization_scale: int = 9 |
|
|
|
|
|
encoder_config: VoxCPMEncoderConfig |
|
|
dit_config: VoxCPMDitConfig |
|
|
|
|
|
max_length: int = 4096 |
|
|
device: str = "cuda" |
|
|
dtype: str = "bfloat16" |
|
|
dit_mean_mode: bool = False |
|
|
|
|
|
|
|
|
class LoRAConfig(BaseModel): |
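    """Configuration for optional LoRA adapters on the LM, DiT, and projection layers."""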
|
|
enable_lm: bool = False |
|
|
enable_dit: bool = False |
|
|
enable_proj: bool = False |
|
|
|
|
|
r: int = 8 |
|
|
alpha: int = 16 |
|
|
dropout: float = 0.0 |
|
|
|
|
|
|
|
|
target_modules_lm: list[str] = ["q_proj", "v_proj"] |
|
|
target_modules_dit: list[str] = ["q_proj", "v_proj"] |
|
|
|
|
|
target_proj_modules: list[str] = ["enc_to_lm_proj", "lm_to_dit_proj", "res_to_dit_proj"] |
|
|
|
|
|
|
|
|
VoxCPMConfig.model_rebuild() |
|
|
|
|
|
|
|
|
class VoxCPMModel(nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
config: VoxCPMConfig, |
|
|
tokenizer: LlamaTokenizerFast, |
|
|
audio_vae: AudioVAE, |
|
|
lora_config: LoRAConfig = None, |
|
|
): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
self.lora_config = lora_config |
|
|
self.feat_dim = config.feat_dim |
|
|
self.patch_size = config.patch_size |
|
|
self.device = config.device |
|
|
if not torch.cuda.is_available(): |
|
|
if torch.backends.mps.is_available(): |
|
|
self.device = "mps" |
|
|
else: |
|
|
self.device = "cpu" |
|
|
print(f"Running on device: {self.device}, dtype: {self.config.dtype}") |
|
|
|
|
|
|
|
|
self.base_lm = MiniCPMModel(config.lm_config) |
|
|
self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype)) |
|
|
|
|
|
self.text_tokenizer = mask_multichar_chinese_tokens(tokenizer) |
|
|
self.audio_start_token = 101 |
|
|
self.audio_end_token = 102 |
|
|
|
|
|
|
|
|
residual_lm_config = config.lm_config.model_copy(deep=True) |
|
|
residual_lm_config.num_hidden_layers = config.residual_lm_num_layers |
|
|
residual_lm_config.vocab_size = 0 |
|
|
self.residual_lm = MiniCPMModel(residual_lm_config) |
|
|
self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype)) |
|
|
|
|
|
|
|
|
encoder_config = config.lm_config.model_copy(deep=True) |
|
|
encoder_config.hidden_size = config.encoder_config.hidden_dim |
|
|
encoder_config.intermediate_size = config.encoder_config.ffn_dim |
|
|
encoder_config.num_attention_heads = config.encoder_config.num_heads |
|
|
encoder_config.num_hidden_layers = config.encoder_config.num_layers |
|
|
encoder_config.kv_channels = config.encoder_config.kv_channels |
|
|
encoder_config.vocab_size = 0 |
|
|
self.feat_encoder = VoxCPMLocEnc(encoder_config, input_dim=config.feat_dim) |
|
|
|
|
|
|
|
|
decoder_config = config.lm_config.model_copy(deep=True) |
|
|
decoder_config.hidden_size = config.dit_config.hidden_dim |
|
|
decoder_config.intermediate_size = config.dit_config.ffn_dim |
|
|
decoder_config.num_attention_heads = config.dit_config.num_heads |
|
|
decoder_config.num_hidden_layers = config.dit_config.num_layers |
|
|
decoder_config.kv_channels = config.dit_config.kv_channels |
|
|
decoder_config.vocab_size = 0 |
|
|
self.feat_decoder = UnifiedCFM( |
|
|
in_channels=config.feat_dim, |
|
|
cfm_params=config.dit_config.cfm_config, |
|
|
estimator=VoxCPMLocDiT(decoder_config, in_channels=config.feat_dim), |
|
|
mean_mode=config.dit_mean_mode, |
|
|
) |
|
|
|
|
|
|
|
|
self.fsq_layer = ScalarQuantizationLayer( |
|
|
config.lm_config.hidden_size, |
|
|
config.lm_config.hidden_size, |
|
|
config.scalar_quantization_latent_dim, |
|
|
config.scalar_quantization_scale |
|
|
) |
|
|
self.enc_to_lm_proj = nn.Linear(config.encoder_config.hidden_dim, config.lm_config.hidden_size) |
|
|
self.lm_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim) |
|
|
self.res_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim) |
|
|
|
|
|
|
|
|
self.stop_proj = nn.Linear(config.lm_config.hidden_size, config.lm_config.hidden_size) |
|
|
self.stop_actn = nn.SiLU() |
|
|
self.stop_head = nn.Linear(config.lm_config.hidden_size, 2, bias=False) |
|
|
self.stop_loss = nn.CrossEntropyLoss(reduction="none") |
|
|
|
|
|
|
|
|
self.audio_vae = audio_vae |
|
|
self.chunk_size = audio_vae.chunk_size |
|
|
self.sample_rate = audio_vae.sample_rate |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.lora_config is not None: |
|
|
|
|
|
if self.lora_config.enable_lm: |
|
|
apply_lora_to_named_linear_modules( |
|
|
self.base_lm, |
|
|
target_submodule_names=self.lora_config.target_modules_lm, |
|
|
r=self.lora_config.r, |
|
|
alpha=self.lora_config.alpha, |
|
|
dropout=self.lora_config.dropout, |
|
|
) |
|
|
apply_lora_to_named_linear_modules( |
|
|
self.residual_lm, |
|
|
target_submodule_names=self.lora_config.target_modules_lm, |
|
|
r=lora_config.r, |
|
|
alpha=self.lora_config.alpha, |
|
|
dropout=self.lora_config.dropout, |
|
|
) |
|
|
|
|
|
|
|
|
if self.lora_config.enable_dit: |
|
|
apply_lora_to_named_linear_modules( |
|
|
self.feat_decoder.estimator, |
|
|
target_submodule_names=self.lora_config.target_modules_dit, |
|
|
r=self.lora_config.r, |
|
|
alpha=self.lora_config.alpha, |
|
|
dropout=self.lora_config.dropout, |
|
|
) |
|
|
|
|
|
|
|
|
if self.lora_config.enable_proj: |
|
|
for attr_name in lora_config.target_proj_modules: |
|
|
if hasattr(self, attr_name): |
|
|
module = getattr(self, attr_name) |
|
|
if isinstance(module, nn.Linear): |
|
|
from ..modules.layers.lora import LoRALinear |
|
|
|
|
|
setattr( |
|
|
self, |
|
|
attr_name, |
|
|
LoRALinear( |
|
|
base=module, |
|
|
r=self.lora_config.r, |
|
|
alpha=self.lora_config.alpha, |
|
|
dropout=self.lora_config.dropout, |
|
|
), |
|
|
) |
|
|
|
|
|
|
|
|
def optimize(self, disable: bool = False): |
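        """Optionally speed up inference by compiling the hot-path submodules
        (base LM / residual LM step functions, feature encoder, DiT estimator)
        with ``torch.compile``. Requires CUDA and Triton; otherwise, or when
        ``disable=True``, the original eager implementations are kept.
        """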
|
|
|
|
|
if not hasattr(self, 'feat_encoder_step'): |
|
|
self.feat_encoder_step = self.feat_encoder |
|
|
|
|
|
        if disable:
            # Keep the original (eager) implementations.
            return self
|
|
|
|
|
try: |
|
|
if self.device != "cuda": |
|
|
raise ValueError("VoxCPMModel can only be optimized on CUDA device") |
|
|
try: |
|
|
import triton |
|
|
            except ImportError:
|
|
raise ValueError("triton is not installed") |
|
|
self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True) |
|
|
self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True) |
|
|
self.feat_encoder_step = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True) |
|
|
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True) |
|
|
except Exception as e: |
|
|
print(f"Error: {e}") |
|
|
print("Warning: VoxCPMModel can not be optimized by torch.compile, using original forward_step functions") |
|
|
self.base_lm.forward_step = self.base_lm.forward_step |
|
|
self.residual_lm.forward_step = self.residual_lm.forward_step |
|
|
self.feat_encoder_step = self.feat_encoder |
|
|
self.feat_decoder.estimator = self.feat_decoder.estimator |
|
|
return self |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
text_tokens: torch.Tensor, |
|
|
text_mask: torch.Tensor, |
|
|
audio_feats: torch.Tensor, |
|
|
audio_mask: torch.Tensor, |
|
|
loss_mask: torch.Tensor, |
|
|
position_ids: torch.Tensor, |
|
|
labels: torch.Tensor, |
|
|
*, |
|
|
progress: float = 0.0, |
|
|
sample_generate: bool = False, |
|
|
): |
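        """Training forward pass.

        Encodes patched audio features and text tokens, runs the base and residual
        language models, and computes the flow-matching (diffusion) loss together
        with the stop-prediction loss. When ``sample_generate`` is True, features
        are also sampled from the decoder for inspection.

        Returns:
            dict with ``loss/diff``, ``loss/stop``, ``feat_gt`` and ``feat_pred``
            (``feat_pred`` is None unless ``sample_generate=True``).
        """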
|
|
del position_ids |
|
|
|
|
|
text_tokens = text_tokens.to(self.device, dtype=torch.long) |
|
|
text_mask = text_mask.to(self.device, dtype=self._dtype()) |
|
|
audio_feats = audio_feats.to(self.device, dtype=self._dtype()) |
|
|
audio_mask = audio_mask.to(self.device, dtype=self._dtype()) |
|
|
loss_mask = loss_mask.to(self.device, dtype=self._dtype()) |
|
|
labels = labels.to(self.device, dtype=torch.long) |
|
|
|
|
|
B, T, P, D = audio_feats.shape |
|
|
|
|
|
feat_embed = self.feat_encoder(audio_feats) |
|
|
feat_embed = self.enc_to_lm_proj(feat_embed) |
|
|
|
|
|
scale_emb = getattr(self.config.lm_config, "scale_emb", 1.0) |
|
|
if not getattr(self.config.lm_config, "use_mup", False): |
|
|
scale_emb = 1.0 |
|
|
text_embed = self.base_lm.embed_tokens(text_tokens) * scale_emb |
|
|
combined_embed = text_mask.unsqueeze(-1) * text_embed + audio_mask.unsqueeze(-1) * feat_embed |
|
|
|
|
|
enc_outputs, _ = self.base_lm(inputs_embeds=combined_embed, is_causal=True) |
|
|
enc_outputs = enc_outputs.to(self._dtype()) |
|
|
enc_outputs = self.fsq_layer(enc_outputs) * audio_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1) |
|
|
lm_hidden = torch.cat((torch.zeros_like(enc_outputs[:, 0:1, :]), enc_outputs[:, :-1, :]), dim=1) |
|
|
|
|
|
residual_inputs = enc_outputs + audio_mask.unsqueeze(-1) * feat_embed |
|
|
residual_outputs, _ = self.residual_lm(inputs_embeds=residual_inputs, is_causal=True) |
|
|
residual_outputs = residual_outputs.to(self._dtype()) |
|
|
residual_hidden = torch.cat( |
|
|
(torch.zeros_like(residual_outputs[:, 0:1, :]), residual_outputs[:, :-1, :]), |
|
|
dim=1, |
|
|
) |
|
|
|
|
|
dit_hidden = self.lm_to_dit_proj(lm_hidden) + self.res_to_dit_proj(residual_hidden) |
|
|
dit_hidden = rearrange(dit_hidden, "b t c -> (b t) c") |
|
|
|
|
|
|
|
|
target_dtype = self._dtype() |
|
|
|
|
|
feat_gt = rearrange(audio_feats.to(target_dtype), "b t p d -> (b t) p d") |
|
|
feat_cond = torch.cat( |
|
|
(torch.zeros_like(audio_feats[:, 0:1, ...]), audio_feats[:, :-1, ...]), |
|
|
dim=1, |
|
|
) |
|
|
feat_cond = rearrange(feat_cond.to(target_dtype), "b t p d -> (b t) p d") |
|
|
|
|
|
loss_seq_mask = loss_mask.unsqueeze(-1).repeat(1, 1, self.patch_size) |
|
|
loss_seq_mask = rearrange(loss_seq_mask, "b t p -> (b t) p 1").to(target_dtype) |
|
|
|
|
|
diff_loss = self.feat_decoder.compute_loss( |
|
|
feat_gt.transpose(1, 2).contiguous(), |
|
|
dit_hidden, |
|
|
cond=feat_cond.transpose(1, 2).contiguous(), |
|
|
tgt_mask=loss_seq_mask.transpose(1, 2).contiguous(), |
|
|
progress=progress, |
|
|
) |
|
|
|
|
|
stop_logits = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))) |
|
|
stop_losses = self.stop_loss(stop_logits.transpose(1, 2), labels) |
|
|
denom = torch.clamp(loss_mask.sum(), min=1.0) |
|
|
stop_loss = (stop_losses * loss_mask).sum() / denom |
|
|
|
|
|
feat_pred = None |
|
|
if sample_generate: |
|
|
feat_cond_for_sample = feat_cond.transpose(1, 2).contiguous() |
|
|
feat_pred_seq = self.feat_decoder( |
|
|
mu=dit_hidden, |
|
|
patch_size=self.patch_size, |
|
|
cond=feat_cond_for_sample, |
|
|
n_timesteps=self.config.dit_config.cfm_config.inference_cfg_rate |
|
|
if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate") |
|
|
else 10, |
|
|
) |
|
|
feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size) |
|
|
|
|
|
feat_gt_tensor = rearrange(feat_gt, "(b t) p d -> b d (t p)", b=B, p=self.patch_size) |
|
|
|
|
|
return { |
|
|
"loss/diff": diff_loss, |
|
|
"loss/stop": stop_loss, |
|
|
"feat_gt": feat_gt_tensor, |
|
|
"feat_pred": feat_pred, |
|
|
} |
|
|
|
|
|
def _dtype(self): |
|
|
return get_dtype(self.config.dtype) |
|
|
|
|
|
|
|
|
def generate(self, *args, **kwargs) -> torch.Tensor: |
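        """Generate a waveform for ``target_text`` (non-streaming).

        Thin wrapper around ``_generate``; see it for the full argument list
        (prompt_text, prompt_wav_path, inference_timesteps, cfg_value, ...).

        Example (a minimal sketch; paths and text are hypothetical)::

            model = VoxCPMModel.from_local("/path/to/checkpoint")
            wav = model.generate(
                target_text="Hello, world.",
                prompt_text="",
                prompt_wav_path="",
                inference_timesteps=10,
                cfg_value=2.0,
            )
            torchaudio.save("out.wav", wav, model.sample_rate)
        """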
|
|
return next(self._generate(*args, streaming=False, **kwargs)) |
|
|
|
|
|
def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]: |
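        """Generate audio chunk by chunk for ``target_text`` (streaming).

        Thin wrapper around ``_generate`` with ``streaming=True``; yields decoded
        audio chunks as they become available.

        Example (illustrative)::

            for chunk in model.generate_streaming(target_text="Hello, world."):
                ...  # play or buffer each chunk
        """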
|
|
return self._generate(*args, streaming=True, **kwargs) |
|
|
|
|
|
@torch.inference_mode() |
|
|
def _generate( |
|
|
self, |
|
|
target_text: str, |
|
|
prompt_text: str = "", |
|
|
prompt_wav_path: str = "", |
|
|
min_len: int = 2, |
|
|
max_len: int = 2000, |
|
|
inference_timesteps: int = 10, |
|
|
cfg_value: float = 2.0, |
|
|
retry_badcase: bool = False, |
|
|
retry_badcase_max_times: int = 3, |
|
|
retry_badcase_ratio_threshold: float = 6.0, |
|
|
streaming: bool = False, |
|
|
) -> Generator[torch.Tensor, None, None]: |
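        """Tokenize the text (optionally with an acoustic prompt), run the
        inference loop, and decode audio with the VAE.

        Args:
            target_text: Text to synthesize.
            prompt_text: Transcript of the prompt audio (prepended to target_text).
            prompt_wav_path: Path to the prompt waveform for voice cloning; an
                empty string disables prompting.
            min_len: Minimum number of generated feature patches.
            max_len: Maximum number of generated feature patches.
            inference_timesteps: Number of diffusion sampling steps.
            cfg_value: Classifier-free guidance strength.
            retry_badcase: Retry when the audio/text length ratio exceeds
                ``retry_badcase_ratio_threshold`` (non-streaming only).
            streaming: Yield audio chunks instead of a single final waveform.
        """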
|
|
if retry_badcase and streaming: |
|
|
warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.") |
|
|
retry_badcase = False |
|
|
|
|
|
if len(prompt_wav_path) == 0: |
|
|
text = target_text |
|
|
text_token = torch.LongTensor(self.text_tokenizer(text)) |
|
|
text_token = torch.cat( |
|
|
[ |
|
|
text_token, |
|
|
torch.tensor( |
|
|
[self.audio_start_token], |
|
|
dtype=torch.int32, |
|
|
device=text_token.device, |
|
|
), |
|
|
], |
|
|
dim=-1, |
|
|
) |
|
|
text_length = text_token.shape[0] |
|
|
|
|
|
audio_feat = torch.zeros( |
|
|
(text_length, self.patch_size, self.audio_vae.latent_dim), |
|
|
dtype=torch.float32, |
|
|
device=text_token.device, |
|
|
) |
|
|
text_mask = torch.ones(text_length).type(torch.int32).to(text_token.device) |
|
|
audio_mask = torch.zeros(text_length).type(torch.int32).to(text_token.device) |
|
|
|
|
|
else: |
|
|
text = prompt_text + target_text |
|
|
text_token = torch.LongTensor(self.text_tokenizer(text)) |
|
|
text_token = torch.cat( |
|
|
[ |
|
|
text_token, |
|
|
torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device), |
|
|
], |
|
|
dim=-1, |
|
|
) |
|
|
text_length = text_token.shape[0] |
|
|
|
|
|
audio, sr = torchaudio.load(prompt_wav_path) |
|
|
if audio.size(0) > 1: |
|
|
audio = audio.mean(dim=0, keepdim=True) |
|
|
|
|
|
if sr != self.sample_rate: |
|
|
audio = torchaudio.functional.resample(audio, sr, self.sample_rate) |
|
|
|
|
|
patch_len = self.patch_size * self.chunk_size |
|
|
|
|
|
if audio.size(1) % patch_len != 0: |
|
|
audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len)) |
|
|
|
|
|
|
|
|
audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu() |
|
|
|
|
|
audio_feat = audio_feat.view( |
|
|
self.audio_vae.latent_dim, |
|
|
-1, |
|
|
self.patch_size, |
|
|
).permute(1, 2, 0) |
|
|
audio_feat = audio_feat[:-1, ...] |
|
|
audio_length = audio_feat.size(0) |
|
|
text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device) |
|
|
text_token = torch.cat([text_token, text_pad_token]) |
|
|
audio_pad_feat = torch.zeros( |
|
|
(text_length, self.patch_size, self.audio_vae.latent_dim), |
|
|
dtype=torch.float32, |
|
|
device=text_token.device, |
|
|
) |
|
|
audio_feat = torch.cat([audio_pad_feat, audio_feat], dim=0) |
|
|
text_mask = ( |
|
|
torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device) |
|
|
) |
|
|
audio_mask = ( |
|
|
torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device) |
|
|
) |
|
|
|
|
|
text_token = text_token.unsqueeze(0).to(self.device) |
|
|
text_mask = text_mask.unsqueeze(0).to(self.device) |
|
|
audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype)) |
|
|
audio_mask = audio_mask.unsqueeze(0).to(self.device) |
|
|
|
|
|
target_text_length = len(self.text_tokenizer(target_text)) |
|
|
|
|
|
retry_badcase_times = 0 |
|
|
while retry_badcase_times < retry_badcase_max_times: |
|
|
inference_result = self._inference( |
|
|
text_token, |
|
|
text_mask, |
|
|
audio_feat, |
|
|
audio_mask, |
|
|
min_len=min_len, |
|
|
max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len, |
|
|
inference_timesteps=inference_timesteps, |
|
|
cfg_value=cfg_value, |
|
|
streaming=streaming, |
|
|
) |
|
|
if streaming: |
|
|
patch_len = self.patch_size * self.chunk_size |
|
|
for latent_pred, _ in inference_result: |
|
|
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)) |
|
|
decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu() |
|
|
yield decode_audio |
|
|
break |
|
|
else: |
|
|
latent_pred, pred_audio_feat = next(inference_result) |
|
|
if retry_badcase: |
|
|
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold: |
|
|
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...") |
|
|
retry_badcase_times += 1 |
|
|
continue |
|
|
else: |
|
|
break |
|
|
else: |
|
|
break |
|
|
|
|
|
if not streaming: |
|
|
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu() |
|
|
decode_audio = decode_audio[..., 640:-640] |
|
|
yield decode_audio |
|
|
|
|
|
@torch.inference_mode() |
|
|
def build_prompt_cache( |
|
|
self, |
|
|
prompt_text: str, |
|
|
prompt_wav_path: str, |
|
|
): |
|
|
""" |
|
|
Build prompt cache for subsequent fast generation. |
|
|
|
|
|
Args: |
|
|
prompt_text: prompt text (required) |
|
|
prompt_wav_path: prompt audio path (required) |
|
|
|
|
|
Returns: |
|
|
prompt_cache: dict with text tokens and audio features |
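
        Example (a minimal sketch; the wav path is hypothetical)::

            cache = model.build_prompt_cache(
                prompt_text="Reference transcript.",
                prompt_wav_path="/path/to/prompt.wav",
            )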
|
|
""" |
|
|
if not prompt_text or not prompt_wav_path: |
|
|
raise ValueError("prompt_text and prompt_wav_path are required") |
|
|
|
|
|
|
|
|
text_token = torch.LongTensor(self.text_tokenizer(prompt_text)) |
|
|
|
|
|
|
|
|
audio, sr = torchaudio.load(prompt_wav_path) |
|
|
if audio.size(0) > 1: |
|
|
audio = audio.mean(dim=0, keepdim=True) |
|
|
|
|
|
if sr != self.sample_rate: |
|
|
audio = torchaudio.functional.resample(audio, sr, self.sample_rate) |
|
|
|
|
|
patch_len = self.patch_size * self.chunk_size |
|
|
|
|
|
if audio.size(1) % patch_len != 0: |
|
|
audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len)) |
|
|
|
|
|
|
|
|
audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu() |
|
|
|
|
|
audio_feat = audio_feat.view( |
|
|
self.audio_vae.latent_dim, |
|
|
-1, |
|
|
self.patch_size, |
|
|
).permute(1, 2, 0) |
|
|
audio_feat = audio_feat[:-1, ...] |
|
|
|
|
|
prompt_cache = { |
|
|
"text_token": text_token, |
|
|
"audio_feat": audio_feat, |
|
|
} |
|
|
|
|
|
return prompt_cache |
|
|
|
|
|
|
|
|
def merge_prompt_cache( |
|
|
self, |
|
|
original_cache: dict, |
|
|
new_text_token: torch.Tensor, |
|
|
new_audio_feat: torch.Tensor, |
|
|
): |
|
|
""" |
|
|
Merge original prompt cache with newly generated content to stabilize voice. |
|
|
|
|
|
Args: |
|
|
original_cache: original prompt cache |
|
|
new_text_token: newly generated text tokens |
|
|
new_audio_feat: newly generated audio features |
|
|
|
|
|
Returns: |
|
|
merged_cache: merged cache |
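
        Example (illustrative; ``cache`` comes from build_prompt_cache)::

            audio, new_tokens, new_feats = model.generate_with_prompt_cache(
                target_text="First sentence.",
                prompt_cache=cache,
            )
            cache = model.merge_prompt_cache(cache, new_tokens, new_feats)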
|
|
""" |
|
|
if original_cache is None: |
|
|
return { |
|
|
"text_token": new_text_token, |
|
|
"audio_feat": new_audio_feat, |
|
|
} |
|
|
original_text_token = original_cache["text_token"] |
|
|
original_audio_feat = original_cache["audio_feat"] |
|
|
merged_text_token = torch.cat([original_text_token, new_text_token], dim=0) |
|
|
merged_audio_feat = torch.cat([original_audio_feat, new_audio_feat], dim=0) |
|
|
|
|
|
|
|
|
merged_cache = { |
|
|
"text_token": merged_text_token, |
|
|
"audio_feat": merged_audio_feat, |
|
|
} |
|
|
|
|
|
return merged_cache |
|
|
|
|
|
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
|
return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs)) |
|
|
|
|
|
def generate_with_prompt_cache_streaming( |
|
|
self, *args, **kwargs |
|
|
) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]: |
|
|
return self._generate_with_prompt_cache(*args, streaming=True, **kwargs) |
|
|
|
|
|
@torch.inference_mode() |
|
|
def _generate_with_prompt_cache( |
|
|
self, |
|
|
target_text: str, |
|
|
prompt_cache: dict, |
|
|
min_len: int = 2, |
|
|
max_len: int = 2000, |
|
|
inference_timesteps: int = 10, |
|
|
cfg_value: float = 2.0, |
|
|
retry_badcase: bool = False, |
|
|
retry_badcase_max_times: int = 3, |
|
|
retry_badcase_ratio_threshold: float = 6.0, |
|
|
streaming: bool = False, |
|
|
) -> Generator[Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]: |
|
|
""" |
|
|
Generate audio using pre-built prompt cache. |
|
|
|
|
|
Args: |
|
|
target_text: Text to convert to speech |
|
|
prompt_cache: Cache built by build_prompt_cache (can be None) |
|
|
min_len: Minimum audio length to avoid very short audio |
|
|
max_len: Maximum audio length |
|
|
inference_timesteps: Number of diffusion sampling steps |
|
|
cfg_value: Classifier-free guidance value |
|
|
retry_badcase: Whether to retry on bad cases |
|
|
retry_badcase_max_times: Maximum retry attempts |
|
|
retry_badcase_ratio_threshold: Threshold for audio-to-text ratio |
|
|
streaming: Whether to return a generator of audio chunks |
|
|
|
|
|
Returns: |
|
|
Generator of Tuple containing: |
|
|
- Decoded audio tensor for the current step if ``streaming=True``, else final decoded audio tensor |
|
|
- Tensor of new text tokens |
|
|
- New audio features up to the current step as a List if ``streaming=True``, else as a concatenated Tensor |
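
        Example (illustrative streaming use; ``cache`` comes from build_prompt_cache)::

            for chunk, tokens, feats in model.generate_with_prompt_cache_streaming(
                target_text="Hello there.",
                prompt_cache=cache,
            ):
                ...  # consume each decoded audio chunk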
|
|
""" |
|
|
if retry_badcase and streaming: |
|
|
warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.") |
|
|
retry_badcase = False |
|
|
|
|
|
if prompt_cache is None: |
|
|
prompt_text_token = torch.empty(0, dtype=torch.int32) |
|
|
prompt_audio_feat = torch.empty((0, self.patch_size, self.audio_vae.latent_dim), dtype=torch.float32) |
|
|
else: |
|
|
prompt_text_token = prompt_cache["text_token"] |
|
|
prompt_audio_feat = prompt_cache["audio_feat"] |
|
|
|
|
|
target_text_token = torch.LongTensor(self.text_tokenizer(target_text)) |
|
|
text_token = torch.cat([prompt_text_token, target_text_token], dim=0) |
|
|
text_token = torch.cat( |
|
|
[ |
|
|
text_token, |
|
|
torch.tensor( |
|
|
[self.audio_start_token], |
|
|
dtype=torch.int32, |
|
|
device=text_token.device, |
|
|
), |
|
|
], |
|
|
dim=-1, |
|
|
) |
|
|
|
|
|
audio_length = prompt_audio_feat.size(0) |
|
|
text_length = text_token.shape[0] |
|
|
text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device) |
|
|
audio_pad_feat = torch.zeros( |
|
|
(text_token.shape[0], self.patch_size, self.audio_vae.latent_dim), |
|
|
dtype=torch.float32, |
|
|
device=text_token.device, |
|
|
) |
|
|
text_token = torch.cat([text_token, text_pad_token]) |
|
|
audio_feat = torch.cat([audio_pad_feat, prompt_audio_feat], dim=0) |
|
|
text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device) |
|
|
audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device) |
|
|
|
|
|
text_token = text_token.unsqueeze(0).to(self.device) |
|
|
text_mask = text_mask.unsqueeze(0).to(self.device) |
|
|
audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype)) |
|
|
audio_mask = audio_mask.unsqueeze(0).to(self.device) |
|
|
|
|
|
|
|
|
target_text_length = len(self.text_tokenizer(target_text)) |
|
|
retry_badcase_times = 0 |
|
|
while retry_badcase_times < retry_badcase_max_times: |
|
|
inference_result = self._inference( |
|
|
text_token, |
|
|
text_mask, |
|
|
audio_feat, |
|
|
audio_mask, |
|
|
min_len=min_len, |
|
|
max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len, |
|
|
inference_timesteps=inference_timesteps, |
|
|
cfg_value=cfg_value, |
|
|
streaming=streaming, |
|
|
) |
|
|
if streaming: |
|
|
patch_len = self.patch_size * self.chunk_size |
|
|
for latent_pred, pred_audio_feat in inference_result: |
|
|
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)) |
|
|
decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu() |
|
|
yield ( |
|
|
decode_audio, |
|
|
target_text_token, |
|
|
pred_audio_feat |
|
|
) |
|
|
break |
|
|
else: |
|
|
latent_pred, pred_audio_feat = next(inference_result) |
|
|
if retry_badcase: |
|
|
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold: |
|
|
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...") |
|
|
retry_badcase_times += 1 |
|
|
continue |
|
|
else: |
|
|
break |
|
|
else: |
|
|
break |
|
|
if not streaming: |
|
|
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu() |
|
|
decode_audio = decode_audio[..., 640:-640] |
|
|
|
|
|
yield ( |
|
|
decode_audio, |
|
|
target_text_token, |
|
|
pred_audio_feat |
|
|
) |
|
|
|
|
|
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
return next(self._inference(*args, streaming=False, **kwargs)) |
|
|
|
|
|
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]: |
|
|
return self._inference(*args, streaming=True, **kwargs) |
|
|
|
|
|
@torch.inference_mode() |
|
|
def _inference( |
|
|
self, |
|
|
text: torch.Tensor, |
|
|
text_mask: torch.Tensor, |
|
|
feat: torch.Tensor, |
|
|
feat_mask: torch.Tensor, |
|
|
min_len: int = 2, |
|
|
max_len: int = 2000, |
|
|
inference_timesteps: int = 10, |
|
|
cfg_value: float = 2.0, |
|
|
streaming: bool = False, |
|
|
) -> Generator[Tuple[torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]: |
|
|
"""Core inference method for audio generation. |
|
|
|
|
|
This is the main inference loop that generates audio features |
|
|
using the language model and diffusion transformer. |
|
|
|
|
|
Args: |
|
|
text: Input text tokens |
|
|
text_mask: Mask for text tokens |
|
|
feat: Input audio features |
|
|
feat_mask: Mask for audio features |
|
|
min_len: Minimum generation length |
|
|
max_len: Maximum generation length |
|
|
inference_timesteps: Number of diffusion steps |
|
|
cfg_value: Classifier-free guidance value |
|
|
streaming: Whether to yield each step latent feature or just the final result |
|
|
|
|
|
Returns: |
|
|
Generator of Tuple containing: |
|
|
- Predicted latent feature at the current step if ``streaming=True``, else final latent features |
|
|
- Predicted audio feature sequence so far as a List if ``streaming=True``, else as a concatenated Tensor |
|
|
""" |
|
|
B, T, P, D = feat.shape |
|
|
|
|
|
feat_embed = self.feat_encoder(feat) |
|
|
feat_embed = self.enc_to_lm_proj(feat_embed) |
|
|
|
|
|
if self.config.lm_config.use_mup: |
|
|
scale_emb = self.config.lm_config.scale_emb |
|
|
else: |
|
|
scale_emb = 1.0 |
|
|
|
|
|
text_embed = self.base_lm.embed_tokens(text) * scale_emb |
|
|
combined_embed = text_mask.unsqueeze(-1) * text_embed + feat_mask.unsqueeze(-1) * feat_embed |
|
|
|
|
|
prefix_feat_cond = feat[:, -1, ...] |
|
|
pred_feat_seq = [] |
|
|
curr_embed = None |
|
|
|
|
|
enc_outputs, kv_cache_tuple = self.base_lm( |
|
|
inputs_embeds=combined_embed, |
|
|
is_causal=True, |
|
|
) |
|
|
self.base_lm.kv_cache.fill_caches(kv_cache_tuple) |
|
|
|
|
|
enc_outputs = self.fsq_layer(enc_outputs) * feat_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1) |
|
|
lm_hidden = enc_outputs[:, -1, :] |
|
|
|
|
|
|
|
|
residual_enc_outputs, residual_kv_cache_tuple = self.residual_lm( |
|
|
inputs_embeds=enc_outputs + feat_mask.unsqueeze(-1) * feat_embed, |
|
|
is_causal=True, |
|
|
) |
|
|
self.residual_lm.kv_cache.fill_caches(residual_kv_cache_tuple) |
|
|
residual_hidden = residual_enc_outputs[:, -1, :] |
|
|
|
|
|
|
|
|
for i in tqdm(range(max_len)): |
|
|
dit_hidden_1 = self.lm_to_dit_proj(lm_hidden) |
|
|
dit_hidden_2 = self.res_to_dit_proj(residual_hidden) |
|
|
dit_hidden = dit_hidden_1 + dit_hidden_2 |
|
|
|
|
|
pred_feat = self.feat_decoder( |
|
|
mu=dit_hidden, |
|
|
patch_size=self.patch_size, |
|
|
cond=prefix_feat_cond.transpose(1, 2).contiguous(), |
|
|
n_timesteps=inference_timesteps, |
|
|
cfg_value=cfg_value, |
|
|
).transpose( |
|
|
1, 2 |
|
|
) |
|
|
|
|
|
curr_embed = self.feat_encoder_step(pred_feat.unsqueeze(1)) |
|
|
curr_embed = self.enc_to_lm_proj(curr_embed) |
|
|
|
|
|
pred_feat_seq.append(pred_feat.unsqueeze(1)) |
|
|
prefix_feat_cond = pred_feat |
|
|
|
|
|
if streaming: |
|
|
|
|
|
pred_feat_chunk = torch.cat(pred_feat_seq[-3:], dim=1) |
|
|
feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size) |
|
|
yield feat_pred, pred_feat_seq |
|
|
|
|
|
stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item() |
|
|
if i > min_len and stop_flag == 1: |
|
|
break |
|
|
|
|
|
lm_hidden = self.base_lm.forward_step( |
|
|
curr_embed[:, 0, :], torch.tensor([self.base_lm.kv_cache.step()], device=curr_embed.device) |
|
|
).clone() |
|
|
|
|
|
|
|
|
lm_hidden = self.fsq_layer(lm_hidden) |
|
|
residual_hidden = self.residual_lm.forward_step( |
|
|
lm_hidden + curr_embed[:, 0, :], torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device) |
|
|
).clone() |
|
|
|
|
|
if not streaming: |
|
|
pred_feat_seq = torch.cat(pred_feat_seq, dim=1) |
|
|
|
|
|
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size) |
|
|
yield feat_pred, pred_feat_seq.squeeze(0).cpu() |
|
|
|
|
|
@classmethod |
|
|
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None): |
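        """Load a VoxCPM checkpoint from a local directory.

        Expects ``config.json``, tokenizer files, ``audiovae.pth`` and
        ``pytorch_model.bin`` under ``path``. When ``training`` is False, the model
        is cast to the configured dtype, moved to its device, set to eval mode, and
        (unless ``optimize=False``) compiled. When ``training`` is True, the VAE is
        frozen and, if ``lora_config`` is given, all non-LoRA parameters are frozen.
        """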
|
|
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read()) |
|
|
tokenizer = LlamaTokenizerFast.from_pretrained(path) |
|
|
audio_vae = AudioVAE() |
|
|
vae_state_dict = torch.load( |
|
|
os.path.join(path, "audiovae.pth"), |
|
|
map_location="cpu", |
|
|
weights_only=True, |
|
|
)["state_dict"] |
|
|
model = cls(config, tokenizer, audio_vae, lora_config) |
|
|
if not training: |
|
|
lm_dtype = get_dtype(model.config.dtype) |
|
|
model = model.to(lm_dtype) |
|
|
else: |
|
|
for name, param in model.named_parameters(): |
|
|
if "audio_vae" in name: |
|
|
param.requires_grad = False |
|
|
continue |
|
|
|
|
|
if lora_config is not None: |
|
|
if "lora" not in name: |
|
|
param.requires_grad = False |
|
|
model.audio_vae = model.audio_vae.to(torch.float32) |
|
|
model_state_dict = torch.load( |
|
|
os.path.join(path, "pytorch_model.bin"), |
|
|
map_location="cpu", |
|
|
weights_only=True, |
|
|
)["state_dict"] |
|
|
for kw, val in vae_state_dict.items(): |
|
|
model_state_dict[f"audio_vae.{kw}"] = val |
|
|
|
|
|
|
|
|
|
|
|
model.load_state_dict(model_state_dict, strict=False) |
|
|
if training: |
|
|
return model |
|
|
return model.to(model.device).eval().optimize(disable=not optimize) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_lora_weights(self, lora_path: str, device: str = None): |
|
|
""" |
|
|
从文件加载 LoRA 权重,支持在 torch.compile 之后调用。 |
|
|
|
|
|
实现说明: |
|
|
使用 named_parameters() 而非 load_state_dict() 来加载权重。 |
|
|
原因是 torch.compile 会将模块包装成 OptimizedModule,导致 |
|
|
state_dict 的 key 路径发生变化(如 module.weight -> module._orig_mod.weight), |
|
|
使得 load_state_dict() 无法匹配到正确的参数。 |
|
|
|
|
|
而 named_parameters() 返回的是参数对象的引用,不受 compile 包装影响, |
|
|
通过 .data.copy_() 可以直接修改参数值,既不会触发重编译, |
|
|
也支持在 compile 后热切换不同的 LoRA 权重。 |
|
|
|
|
|
Args: |
|
|
lora_path: LoRA checkpoint 路径(目录,内含 generator.pth)或直接的 .pth 文件 |
|
|
device: 加载到的设备,默认为模型当前设备 |
|
|
|
|
|
Returns: |
|
|
tuple: (loaded_keys, skipped_keys) |
|
|
""" |
|
|
from pathlib import Path |
|
|
|
|
|
if device is None: |
|
|
device = self.device |
|
|
|
|
|
|
|
|
lora_path = Path(lora_path) |
|
|
if lora_path.is_dir(): |
|
|
ckpt_file = lora_path / "generator.pth" |
|
|
else: |
|
|
ckpt_file = lora_path |
|
|
|
|
|
if not ckpt_file.exists(): |
|
|
raise FileNotFoundError(f"LoRA checkpoint not found: {ckpt_file}") |
|
|
|
|
|
ckpt = torch.load(ckpt_file, map_location=device, weights_only=False) |
|
|
state_dict = ckpt.get("state_dict", ckpt) |
|
|
|
|
|
|
|
|
model_params = {name: param for name, param in self.named_parameters()} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
key_mapping = {} |
|
|
for model_key in model_params.keys(): |
|
|
|
|
|
normalized_key = model_key.replace("._orig_mod.", ".") |
|
|
if normalized_key != model_key: |
|
|
key_mapping[normalized_key] = model_key |
|
|
|
|
|
loaded_keys = [] |
|
|
skipped_keys = [] |
|
|
for key, value in state_dict.items(): |
|
|
|
|
|
if key in model_params: |
|
|
model_params[key].data.copy_(value.to(device)) |
|
|
loaded_keys.append(key) |
|
|
|
|
|
elif key in key_mapping: |
|
|
mapped_key = key_mapping[key] |
|
|
model_params[mapped_key].data.copy_(value.to(device)) |
|
|
loaded_keys.append(key) |
|
|
else: |
|
|
skipped_keys.append(key) |
|
|
|
|
|
return loaded_keys, skipped_keys |
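
    # Illustrative LoRA workflow (a sketch; paths are hypothetical):
    #
    #     model = VoxCPMModel.from_local(
    #         "/path/to/base_checkpoint",
    #         lora_config=LoRAConfig(enable_lm=True, enable_dit=True),
    #     )
    #     loaded, skipped = model.load_lora_weights("/path/to/lora_run")  # dir with generator.pth
    #     model.set_lora_enabled(True)    # or False to fall back to base weights only
    #     model.reset_lora_weights()      # re-initialize adapters (A: Kaiming, B: zeros)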
|
|
|
|
|
def set_lora_enabled(self, enabled: bool): |
|
|
""" |
|
|
动态启用/禁用所有 LoRA 层(通过 scaling 控制,兼容 torch.compile)。 |
|
|
|
|
|
Args: |
|
|
enabled: True=启用 LoRA,False=禁用(仅使用基础权重) |
|
|
""" |
|
|
from ..modules.layers.lora import LoRALinear |
|
|
|
|
|
for module in self.modules(): |
|
|
if isinstance(module, LoRALinear): |
|
|
module.set_enabled(enabled) |
|
|
|
|
|
def reset_lora_weights(self): |
|
|
""" |
|
|
重置所有 LoRA 权重到初始状态(A: kaiming, B: zeros)。 |
|
|
|
|
|
B=0 时 LoRA 输出为 0,相当于"卸载" LoRA。 |
|
|
""" |
|
|
from ..modules.layers.lora import LoRALinear |
|
|
|
|
|
for module in self.modules(): |
|
|
if isinstance(module, LoRALinear): |
|
|
module.reset_lora_parameters() |
|
|
|
|
|
def get_lora_state_dict(self) -> dict: |
|
|
""" |
|
|
获取当前模型中所有 LoRA 参数的 state_dict。 |
|
|
|
|
|
Returns: |
|
|
dict: 仅包含 lora_A / lora_B 参数的字典 |
|
|
""" |
|
|
return {name: param.data.clone() |
|
|
for name, param in self.named_parameters() |
|
|
if "lora_A" in name or "lora_B" in name} |
|
|
|