File size: 41,843 Bytes
6766eda c04851d 6766eda |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 |
"""
VoxCPM: A Tokenizer-free speech generation model
This module contains the main VoxCPM model implementation, including configuration classes
and the core VoxCPMModel for text-to-speech generation.
Copyright 2025 OpenBMB
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from typing import Tuple, Union, Generator, List
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import warnings
from einops import rearrange
from pydantic import BaseModel
from tqdm import tqdm
from transformers import LlamaTokenizerFast
from ..modules.audiovae import AudioVAE
from ..modules.layers import ScalarQuantizationLayer
from ..modules.layers.lora import apply_lora_to_named_linear_modules
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
from ..modules.locenc import VoxCPMLocEnc
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
from .utils import get_dtype, mask_multichar_chinese_tokens
class VoxCPMEncoderConfig(BaseModel):
    """Size hyper-parameters of the local audio-feature encoder (VoxCPMLocEnc).

    ``VoxCPMModel.__init__`` copies the base LM config and overrides its
    ``hidden_size``, ``intermediate_size``, ``num_attention_heads``,
    ``num_hidden_layers`` and ``kv_channels`` with these values.
    """

    hidden_dim: int = 1024
    ffn_dim: int = 4096
    num_heads: int = 16
    num_layers: int = 4
    # None is the default and a legal value, so the annotation must admit it;
    # a bare ``int`` rejects configs that spell the field out as ``null``.
    kv_channels: Union[int, None] = None
class VoxCPMDitConfig(BaseModel):
    """Size hyper-parameters of the local diffusion transformer (VoxCPMLocDiT).

    ``VoxCPMModel.__init__`` copies the base LM config and overrides its size
    fields with these values; ``cfm_config`` is handed to ``UnifiedCFM`` as
    ``cfm_params``.
    """

    hidden_dim: int = 1024
    ffn_dim: int = 4096
    num_heads: int = 16
    num_layers: int = 4
    # None is the default and a legal value, so the annotation must admit it;
    # a bare ``int`` rejects configs that spell the field out as ``null``.
    kv_channels: Union[int, None] = None
    cfm_config: CfmConfig
class VoxCPMConfig(BaseModel):
    """Top-level VoxCPM hyper-parameters, loaded from ``config.json`` by ``VoxCPMModel.from_local``."""
    # Config of the text-semantic LM; also the template for the residual LM,
    # local encoder and local DiT configs.
    lm_config: MiniCPM4Config
    # Number of VAE latent frames grouped into one LM step (patch).
    patch_size: int = 2
    # Latent channel count fed to the local encoder / DiT.
    feat_dim: int = 64
    # Layer count of the residual acoustic LM.
    residual_lm_num_layers: int = 6
    # Positional arguments 3 and 4 of ScalarQuantizationLayer.
    scalar_quantization_latent_dim: int = 256
    scalar_quantization_scale: int = 9
    encoder_config: VoxCPMEncoderConfig
    dit_config: VoxCPMDitConfig
    # Length passed to setup_cache() of both LMs.
    max_length: int = 4096
    # Preferred device; VoxCPMModel falls back to "mps"/"cpu" when CUDA is unavailable.
    device: str = "cuda"
    # Name converted to a torch dtype by get_dtype().
    dtype: str = "bfloat16"
    # Passed to UnifiedCFM as ``mean_mode``.
    dit_mean_mode: bool = False
class LoRAConfig(BaseModel):
    """Selects which VoxCPMModel layers receive LoRA adapters, and the adapter hyper-parameters."""
    enable_lm: bool = False  # add LoRA to base_lm + residual_lm
    enable_dit: bool = False  # add LoRA to VoxCPMLocDiT
    enable_proj: bool = False  # add LoRA to several projection Linear layers
    # LoRA rank, scaling numerator and dropout, forwarded to every adapter.
    r: int = 8
    alpha: int = 16
    dropout: float = 0.0
    # Target linear-layer names inside the LM and DiT (matched by attribute name)
    target_modules_lm: list[str] = ["q_proj", "v_proj"]
    target_modules_dit: list[str] = ["q_proj", "v_proj"]
    # Projection-layer attribute names, looked up directly on VoxCPMModel
    target_proj_modules: list[str] = ["enc_to_lm_proj", "lm_to_dit_proj", "res_to_dit_proj"]
# Rebuild VoxCPMConfig's pydantic schema now that every config class above exists.
# NOTE(review): presumably guards against nested annotations that were unresolved
# when the class was first created — confirm whether it is still required.
VoxCPMConfig.model_rebuild()
class VoxCPMModel(nn.Module):
    """Tokenizer-free text-to-speech model.

    Sub-modules built in ``__init__``:
      * ``base_lm``      - text-semantic MiniCPM LM over text tokens + audio patches.
      * ``residual_lm``  - vocabulary-free MiniCPM refining the acoustic stream.
      * ``feat_encoder`` - local encoder embedding patches of VAE latents.
      * ``feat_decoder`` - UnifiedCFM (VoxCPMLocDiT estimator) sampling the next patch.
      * ``stop_*``       - two-class stop predictor on the base-LM hidden state.
      * ``audio_vae``    - waveform <-> latent codec.

    Audio is processed as patches of ``patch_size`` latent frames; one patch covers
    ``patch_size * audio_vae.chunk_size`` waveform samples.
    """

    def __init__(
        self,
        config: VoxCPMConfig,
        tokenizer: LlamaTokenizerFast,
        audio_vae: AudioVAE,
        lora_config: LoRAConfig = None,
    ):
        """Build every sub-module from ``config``.

        Args:
            config: Model hyper-parameters.
            tokenizer: Text tokenizer, wrapped by ``mask_multichar_chinese_tokens``.
            audio_vae: Audio codec; its ``chunk_size`` / ``sample_rate`` are cached here.
            lora_config: Optional LoRA settings. Only the module structure is changed
                here; freezing non-LoRA weights is done by ``from_local(training=True)``.
        """
        super().__init__()
        self.config = config
        self.lora_config = lora_config
        self.feat_dim = config.feat_dim
        self.patch_size = config.patch_size

        # Prefer the configured device, falling back to MPS and then CPU without CUDA.
        self.device = config.device
        if not torch.cuda.is_available():
            self.device = "mps" if torch.backends.mps.is_available() else "cpu"
        print(f"Running on device: {self.device}, dtype: {self.config.dtype}")
        cache_dtype = get_dtype(self.config.dtype)

        # Text-Semantic LM
        self.base_lm = MiniCPMModel(config.lm_config)
        self.base_lm.setup_cache(1, config.max_length, self.device, cache_dtype)
        self.text_tokenizer = mask_multichar_chinese_tokens(tokenizer)
        self.audio_start_token = 101
        self.audio_end_token = 102

        # Residual Acoustic LM: base-LM architecture with fewer layers and no vocabulary.
        residual_lm_config = config.lm_config.model_copy(deep=True)
        residual_lm_config.num_hidden_layers = config.residual_lm_num_layers
        residual_lm_config.vocab_size = 0
        self.residual_lm = MiniCPMModel(residual_lm_config)
        self.residual_lm.setup_cache(1, config.max_length, self.device, cache_dtype)

        # Local Encoder
        encoder_config = self._derive_local_config(config.lm_config, config.encoder_config)
        self.feat_encoder = VoxCPMLocEnc(encoder_config, input_dim=config.feat_dim)

        # Local DiT
        decoder_config = self._derive_local_config(config.lm_config, config.dit_config)
        self.feat_decoder = UnifiedCFM(
            in_channels=config.feat_dim,
            cfm_params=config.dit_config.cfm_config,
            estimator=VoxCPMLocDiT(decoder_config, in_channels=config.feat_dim),
            mean_mode=config.dit_mean_mode,
        )

        # Projection layers
        self.fsq_layer = ScalarQuantizationLayer(
            config.lm_config.hidden_size,
            config.lm_config.hidden_size,
            config.scalar_quantization_latent_dim,
            config.scalar_quantization_scale,
        )
        self.enc_to_lm_proj = nn.Linear(config.encoder_config.hidden_dim, config.lm_config.hidden_size)
        self.lm_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
        self.res_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)

        # Stop Predictor
        self.stop_proj = nn.Linear(config.lm_config.hidden_size, config.lm_config.hidden_size)
        self.stop_actn = nn.SiLU()
        self.stop_head = nn.Linear(config.lm_config.hidden_size, 2, bias=False)
        self.stop_loss = nn.CrossEntropyLoss(reduction="none")

        # Audio VAE
        self.audio_vae = audio_vae
        self.chunk_size = audio_vae.chunk_size
        self.sample_rate = audio_vae.sample_rate

        # Optional LoRA injection (structure only; see docstring for freezing).
        if self.lora_config is not None:
            self._inject_lora(self.lora_config)

    @staticmethod
    def _derive_local_config(lm_config, local_config):
        """Copy ``lm_config`` and override its size fields with an encoder/DiT config."""
        derived = lm_config.model_copy(deep=True)
        derived.hidden_size = local_config.hidden_dim
        derived.intermediate_size = local_config.ffn_dim
        derived.num_attention_heads = local_config.num_heads
        derived.num_hidden_layers = local_config.num_layers
        derived.kv_channels = local_config.kv_channels
        derived.vocab_size = 0
        return derived

    def _inject_lora(self, lora_config: LoRAConfig):
        """Wrap the Linear layers selected by ``lora_config`` with LoRA adapters."""
        lora_kwargs = dict(r=lora_config.r, alpha=lora_config.alpha, dropout=lora_config.dropout)
        # LM: base_lm + residual_lm
        if lora_config.enable_lm:
            for lm in (self.base_lm, self.residual_lm):
                apply_lora_to_named_linear_modules(
                    lm,
                    target_submodule_names=lora_config.target_modules_lm,
                    **lora_kwargs,
                )
        # DiT: VoxCPMLocDiT (feat_decoder.estimator)
        if lora_config.enable_dit:
            apply_lora_to_named_linear_modules(
                self.feat_decoder.estimator,
                target_submodule_names=lora_config.target_modules_dit,
                **lora_kwargs,
            )
        # Projection layers: look up Linear attributes on this model by name and replace them.
        if lora_config.enable_proj:
            from ..modules.layers.lora import LoRALinear

            for attr_name in lora_config.target_proj_modules:
                module = getattr(self, attr_name, None)
                if isinstance(module, nn.Linear):
                    setattr(self, attr_name, LoRALinear(base=module, **lora_kwargs))

    def optimize(self, disable: bool = False):
        """Optionally wrap the per-step hot paths with ``torch.compile``.

        Always makes sure ``feat_encoder_step`` exists, because ``_inference`` calls it.
        Compilation is attempted only on CUDA with triton installed; on any failure the
        original callables are restored and a warning is printed.

        Args:
            disable: Skip compilation entirely.

        Returns:
            ``self``, for chaining.
        """
        if not hasattr(self, "feat_encoder_step"):
            self.feat_encoder_step = self.feat_encoder
        if disable:
            return self

        original_base_step = self.base_lm.forward_step
        original_residual_step = self.residual_lm.forward_step
        original_estimator = self.feat_decoder.estimator
        try:
            if self.device != "cuda":
                raise ValueError("VoxCPMModel can only be optimized on CUDA device")
            try:
                import triton  # noqa: F401  (availability check only)
            except ImportError:
                raise ValueError("triton is not installed") from None
            self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
            self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True)
            self.feat_encoder_step = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
            self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
        except Exception as e:
            print(f"Error: {e}")
            print("Warning: VoxCPMModel can not be optimized by torch.compile, using original forward_step functions")
            # Actually roll back anything wrapped before the failure.
            self.base_lm.forward_step = original_base_step
            self.residual_lm.forward_step = original_residual_step
            self.feat_encoder_step = self.feat_encoder
            self.feat_decoder.estimator = original_estimator
        return self

    def _text_embed_scale(self):
        """Return ``lm_config.scale_emb`` when muP is enabled, else 1.0."""
        if getattr(self.config.lm_config, "use_mup", False):
            return getattr(self.config.lm_config, "scale_emb", 1.0)
        return 1.0

    def forward(
        self,
        text_tokens: torch.Tensor,
        text_mask: torch.Tensor,
        audio_feats: torch.Tensor,
        audio_mask: torch.Tensor,
        loss_mask: torch.Tensor,
        position_ids: torch.Tensor,
        labels: torch.Tensor,
        *,
        progress: float = 0.0,
        sample_generate: bool = False,
        sample_timesteps: int = 10,
    ):
        """Training forward pass: diffusion loss on audio patches plus stop loss.

        Args:
            text_tokens: Token ids, indexed as (batch, time).
            text_mask / audio_mask: Per-position masks selecting text vs. audio inputs.
            audio_feats: Latent patches, unpacked as (B, T, P, D).
            loss_mask: Per-position weights for both losses.
            position_ids: Currently unused.
            labels: Stop-predictor class targets (0/1) per position.
            progress: Forwarded to ``feat_decoder.compute_loss``.
            sample_generate: Also sample predicted features for inspection.
            sample_timesteps: Number of diffusion steps used when ``sample_generate``.

        Returns:
            dict with ``"loss/diff"``, ``"loss/stop"``, ``"feat_gt"`` (b, d, t*p) and
            ``"feat_pred"`` (same layout, or None unless ``sample_generate``).
        """
        del position_ids  # not used yet
        target_dtype = self._dtype()
        text_tokens = text_tokens.to(self.device, dtype=torch.long)
        text_mask = text_mask.to(self.device, dtype=target_dtype)
        audio_feats = audio_feats.to(self.device, dtype=target_dtype)
        audio_mask = audio_mask.to(self.device, dtype=target_dtype)
        loss_mask = loss_mask.to(self.device, dtype=target_dtype)
        labels = labels.to(self.device, dtype=torch.long)
        B, _, _, _ = audio_feats.shape

        feat_embed = self.enc_to_lm_proj(self.feat_encoder(audio_feats))
        text_embed = self.base_lm.embed_tokens(text_tokens) * self._text_embed_scale()
        combined_embed = text_mask.unsqueeze(-1) * text_embed + audio_mask.unsqueeze(-1) * feat_embed

        enc_outputs, _ = self.base_lm(inputs_embeds=combined_embed, is_causal=True)
        enc_outputs = enc_outputs.to(target_dtype)
        enc_outputs = self.fsq_layer(enc_outputs) * audio_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
        # Shift right by one step: position t is conditioned on outputs up to t-1.
        lm_hidden = torch.cat((torch.zeros_like(enc_outputs[:, 0:1, :]), enc_outputs[:, :-1, :]), dim=1)

        residual_inputs = enc_outputs + audio_mask.unsqueeze(-1) * feat_embed
        residual_outputs, _ = self.residual_lm(inputs_embeds=residual_inputs, is_causal=True)
        residual_outputs = residual_outputs.to(target_dtype)
        residual_hidden = torch.cat(
            (torch.zeros_like(residual_outputs[:, 0:1, :]), residual_outputs[:, :-1, :]),
            dim=1,
        )

        dit_hidden = self.lm_to_dit_proj(lm_hidden) + self.res_to_dit_proj(residual_hidden)
        dit_hidden = rearrange(dit_hidden, "b t c -> (b t) c")

        # Keep diffusion inputs in the same dtype as the model (e.g., bfloat16)
        feat_gt = rearrange(audio_feats, "b t p d -> (b t) p d")
        feat_cond = torch.cat(
            (torch.zeros_like(audio_feats[:, 0:1, ...]), audio_feats[:, :-1, ...]),
            dim=1,
        )
        feat_cond = rearrange(feat_cond, "b t p d -> (b t) p d")
        loss_seq_mask = loss_mask.unsqueeze(-1).repeat(1, 1, self.patch_size)
        loss_seq_mask = rearrange(loss_seq_mask, "b t p -> (b t) p 1").to(target_dtype)

        diff_loss = self.feat_decoder.compute_loss(
            feat_gt.transpose(1, 2).contiguous(),
            dit_hidden,
            cond=feat_cond.transpose(1, 2).contiguous(),
            tgt_mask=loss_seq_mask.transpose(1, 2).contiguous(),
            progress=progress,
        )

        stop_logits = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden)))
        stop_losses = self.stop_loss(stop_logits.transpose(1, 2), labels)
        denom = torch.clamp(loss_mask.sum(), min=1.0)
        stop_loss = (stop_losses * loss_mask).sum() / denom

        feat_pred = None
        if sample_generate:
            # n_timesteps must be an integer step count (cfm_config.inference_cfg_rate
            # is a guidance rate). The decoder returns latents in the same [N, d, p]
            # layout as ``cond`` (``_inference`` transposes its output to [b, p, d]),
            # so they are flattened directly without a transpose.
            feat_pred_seq = self.feat_decoder(
                mu=dit_hidden,
                patch_size=self.patch_size,
                cond=feat_cond.transpose(1, 2).contiguous(),
                n_timesteps=sample_timesteps,
            )
            feat_pred = rearrange(feat_pred_seq, "(b t) d p -> b d (t p)", b=B, p=self.patch_size)

        feat_gt_tensor = rearrange(feat_gt, "(b t) p d -> b d (t p)", b=B, p=self.patch_size)
        return {
            "loss/diff": diff_loss,
            "loss/stop": stop_loss,
            "feat_gt": feat_gt_tensor,
            "feat_pred": feat_pred,
        }

    def _dtype(self):
        """Return the torch dtype named by ``config.dtype``."""
        return get_dtype(self.config.dtype)

    def generate(self, *args, **kwargs) -> torch.Tensor:
        """Non-streaming synthesis; see ``_generate`` for arguments. Returns the waveform."""
        return next(self._generate(*args, streaming=False, **kwargs))

    def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
        """Streaming synthesis; yields one decoded waveform chunk per generated patch."""
        return self._generate(*args, streaming=True, **kwargs)

    def _load_prompt_audio_feat(self, prompt_wav_path: str) -> torch.Tensor:
        """Load a prompt wav and encode it into VAE latent patches.

        The audio is down-mixed to mono, resampled to ``self.sample_rate`` and
        right-padded to a whole number of patches before encoding.

        Returns:
            CPU tensor of shape (num_patches, patch_size, latent_dim), with the final
            (padding) patch dropped.
        """
        audio, sr = torchaudio.load(prompt_wav_path)
        if audio.size(0) > 1:
            audio = audio.mean(dim=0, keepdim=True)
        if sr != self.sample_rate:
            audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
        patch_len = self.patch_size * self.chunk_size
        if audio.size(1) % patch_len != 0:
            audio = F.pad(audio, (0, patch_len - audio.size(1) % patch_len))
        audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()
        # (D, T * P) -> (D, T, P) -> (T, P, D)
        audio_feat = audio_feat.view(self.audio_vae.latent_dim, -1, self.patch_size).permute(1, 2, 0)
        return audio_feat[:-1, ...]  # trick: remove the last padding token

    def _build_inference_inputs(self, text_token: torch.Tensor, prompt_audio_feat: torch.Tensor):
        """Assemble the batched `_inference` inputs from text ids and prompt latents.

        The sequence is ``text ids + [audio_start_token]`` followed by one slot per
        prompt patch. Text slots carry zero audio features; audio slots carry pad id 0.

        Args:
            text_token: 1-D token ids (without the audio start token).
            prompt_audio_feat: (num_patches, patch_size, latent_dim); may have 0 patches.

        Returns:
            (text_token, text_mask, audio_feat, audio_mask), each with a leading batch
            dim of 1 and moved to ``self.device`` (audio_feat also cast to the model dtype).
        """
        text_token = torch.cat(
            [
                text_token,
                torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device),
            ],
            dim=-1,
        )
        text_length = text_token.shape[0]
        audio_length = prompt_audio_feat.size(0)
        text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
        audio_pad_feat = torch.zeros(
            (text_length, self.patch_size, self.audio_vae.latent_dim),
            dtype=torch.float32,
            device=text_token.device,
        )
        text_token = torch.cat([text_token, text_pad_token])
        audio_feat = torch.cat([audio_pad_feat, prompt_audio_feat], dim=0)
        text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
        audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
        return (
            text_token.unsqueeze(0).to(self.device),
            text_mask.unsqueeze(0).to(self.device),
            audio_feat.unsqueeze(0).to(self.device).to(self._dtype()),
            audio_mask.unsqueeze(0).to(self.device),
        )

    @torch.inference_mode()
    def _decode_generation(
        self,
        text_token: torch.Tensor,
        text_mask: torch.Tensor,
        audio_feat: torch.Tensor,
        audio_mask: torch.Tensor,
        target_text_length: int,
        *,
        min_len: int,
        max_len: int,
        inference_timesteps: int,
        cfg_value: float,
        retry_badcase: bool,
        retry_badcase_max_times: int,
        retry_badcase_ratio_threshold: float,
        streaming: bool,
    ):
        """Run `_inference` (with optional bad-case retries) and decode latents to audio.

        Yields ``(decoded_audio, pred_audio_feat)``: per patch when ``streaming``
        (pred_audio_feat is the running list of patches), otherwise once with the
        trimmed full waveform and the concatenated (T, P, D) features. At least one
        attempt is always made, even if ``retry_badcase_max_times`` is not positive.
        """
        if retry_badcase:
            # Cap the length so a runaway generation is detected as a bad case.
            max_len = int(target_text_length * retry_badcase_ratio_threshold + 10)
        patch_len = self.patch_size * self.chunk_size
        for _ in range(max(1, retry_badcase_max_times)):
            inference_result = self._inference(
                text_token,
                text_mask,
                audio_feat,
                audio_mask,
                min_len=min_len,
                max_len=max_len,
                inference_timesteps=inference_timesteps,
                cfg_value=cfg_value,
                streaming=streaming,
            )
            if streaming:
                for latent_chunk, pred_audio_feat in inference_result:
                    decode_audio = self.audio_vae.decode(latent_chunk.to(torch.float32))
                    # Earlier patches in the chunk are decoding context; emit only the newest one.
                    decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
                    yield decode_audio, pred_audio_feat
                return
            latent_pred, pred_audio_feat = next(inference_result)
            if not retry_badcase:
                break
            if pred_audio_feat.shape[0] < target_text_length * retry_badcase_ratio_threshold:
                break
            print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
        decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
        decode_audio = decode_audio[..., 640:-640]  # trick: trim the start and end of the audio
        yield decode_audio, pred_audio_feat

    @torch.inference_mode()
    def _generate(
        self,
        target_text: str,
        prompt_text: str = "",
        prompt_wav_path: str = "",
        min_len: int = 2,
        max_len: int = 2000,
        inference_timesteps: int = 10,
        cfg_value: float = 2.0,
        retry_badcase: bool = False,
        retry_badcase_max_times: int = 3,
        retry_badcase_ratio_threshold: float = 6.0,  # setting acceptable ratio of audio length to text length (for badcase detection)
        streaming: bool = False,
    ) -> Generator[torch.Tensor, None, None]:
        """Synthesize ``target_text``, optionally cloning the voice of a prompt recording.

        When ``prompt_wav_path`` is given, ``prompt_text + target_text`` is tokenized
        together and the prompt audio latents follow the text in the input sequence.
        Bad-case retries are disabled in streaming mode.

        Yields:
            Decoded waveform chunks (streaming) or the single final waveform.
        """
        if retry_badcase and streaming:
            warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
            retry_badcase = False
        if not prompt_wav_path:
            text_token = torch.LongTensor(self.text_tokenizer(target_text))
            prompt_audio_feat = torch.empty((0, self.patch_size, self.audio_vae.latent_dim), dtype=torch.float32)
        else:
            text_token = torch.LongTensor(self.text_tokenizer(prompt_text + target_text))
            prompt_audio_feat = self._load_prompt_audio_feat(prompt_wav_path)
        inputs = self._build_inference_inputs(text_token, prompt_audio_feat)
        target_text_length = len(self.text_tokenizer(target_text))
        for decode_audio, _ in self._decode_generation(
            *inputs,
            target_text_length,
            min_len=min_len,
            max_len=max_len,
            inference_timesteps=inference_timesteps,
            cfg_value=cfg_value,
            retry_badcase=retry_badcase,
            retry_badcase_max_times=retry_badcase_max_times,
            retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
            streaming=streaming,
        ):
            yield decode_audio

    @torch.inference_mode()
    def build_prompt_cache(
        self,
        prompt_text: str,
        prompt_wav_path: str,
    ):
        """
        Build prompt cache for subsequent fast generation.

        Args:
            prompt_text: prompt text (required)
            prompt_wav_path: prompt audio path (required)

        Returns:
            prompt_cache: dict with ``"text_token"`` (LongTensor of prompt ids) and
            ``"audio_feat"`` (CPU tensor of shape (num_patches, patch_size, latent_dim)).

        Raises:
            ValueError: if either argument is empty.
        """
        if not prompt_text or not prompt_wav_path:
            raise ValueError("prompt_text and prompt_wav_path are required")
        text_token = torch.LongTensor(self.text_tokenizer(prompt_text))
        audio_feat = self._load_prompt_audio_feat(prompt_wav_path)
        return {
            "text_token": text_token,
            "audio_feat": audio_feat,
        }

    def merge_prompt_cache(
        self,
        original_cache: dict,
        new_text_token: torch.Tensor,
        new_audio_feat: torch.Tensor,
    ):
        """
        Merge original prompt cache with newly generated content to stabilize voice.

        Args:
            original_cache: original prompt cache (may be None)
            new_text_token: newly generated text tokens
            new_audio_feat: newly generated audio features; presumably the concatenated
                (T, P, D) features from non-streaming ``generate_with_prompt_cache``

        Returns:
            merged_cache: cache whose entries are the originals concatenated with the
            new ones along dim 0 (or just the new ones if ``original_cache`` is None).
        """
        if original_cache is None:
            return {
                "text_token": new_text_token,
                "audio_feat": new_audio_feat,
            }
        return {
            "text_token": torch.cat([original_cache["text_token"], new_text_token], dim=0),
            "audio_feat": torch.cat([original_cache["audio_feat"], new_audio_feat], dim=0),
        }

    def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Non-streaming variant of ``_generate_with_prompt_cache``."""
        return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))

    def generate_with_prompt_cache_streaming(
        self, *args, **kwargs
    ) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]:
        """Streaming variant of ``_generate_with_prompt_cache``."""
        return self._generate_with_prompt_cache(*args, streaming=True, **kwargs)

    @torch.inference_mode()
    def _generate_with_prompt_cache(
        self,
        target_text: str,
        prompt_cache: dict,
        min_len: int = 2,
        max_len: int = 2000,
        inference_timesteps: int = 10,
        cfg_value: float = 2.0,
        retry_badcase: bool = False,
        retry_badcase_max_times: int = 3,
        retry_badcase_ratio_threshold: float = 6.0,
        streaming: bool = False,
    ) -> Generator[Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
        """
        Generate audio using pre-built prompt cache.

        Args:
            target_text: Text to convert to speech
            prompt_cache: Cache built by build_prompt_cache (can be None)
            min_len: Minimum audio length to avoid very short audio
            max_len: Maximum audio length
            inference_timesteps: Number of diffusion sampling steps
            cfg_value: Classifier-free guidance value
            retry_badcase: Whether to retry on bad cases
            retry_badcase_max_times: Maximum retry attempts (at least one attempt is made)
            retry_badcase_ratio_threshold: Threshold for audio-to-text ratio
            streaming: Whether to return a generator of audio chunks

        Returns:
            Generator of Tuple containing:
                - Decoded audio tensor for the current step if ``streaming=True``, else final decoded audio tensor
                - Tensor of new text tokens
                - New audio features up to the current step as a List if ``streaming=True``, else as a concatenated Tensor
        """
        if retry_badcase and streaming:
            warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
            retry_badcase = False
        if prompt_cache is None:
            prompt_text_token = torch.empty(0, dtype=torch.int32)
            prompt_audio_feat = torch.empty((0, self.patch_size, self.audio_vae.latent_dim), dtype=torch.float32)
        else:
            prompt_text_token = prompt_cache["text_token"]
            prompt_audio_feat = prompt_cache["audio_feat"]
        target_text_token = torch.LongTensor(self.text_tokenizer(target_text))
        text_token = torch.cat([prompt_text_token, target_text_token], dim=0)
        inputs = self._build_inference_inputs(text_token, prompt_audio_feat)
        target_text_length = target_text_token.shape[0]
        for decode_audio, pred_audio_feat in self._decode_generation(
            *inputs,
            target_text_length,
            min_len=min_len,
            max_len=max_len,
            inference_timesteps=inference_timesteps,
            cfg_value=cfg_value,
            retry_badcase=retry_badcase,
            retry_badcase_max_times=retry_badcase_max_times,
            retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
            streaming=streaming,
        ):
            yield (decode_audio, target_text_token, pred_audio_feat)

    def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Non-streaming variant of ``_inference``."""
        return next(self._inference(*args, streaming=False, **kwargs))

    def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
        """Streaming variant of ``_inference``."""
        return self._inference(*args, streaming=True, **kwargs)

    @torch.inference_mode()
    def _inference(
        self,
        text: torch.Tensor,
        text_mask: torch.Tensor,
        feat: torch.Tensor,
        feat_mask: torch.Tensor,
        min_len: int = 2,
        max_len: int = 2000,
        inference_timesteps: int = 10,
        cfg_value: float = 2.0,
        streaming: bool = False,
    ) -> Generator[Tuple[torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
        """Core autoregressive loop: LM prefill, then one diffusion-sampled patch per step.

        Requires ``optimize()`` to have been called (it defines ``feat_encoder_step``).
        The stop decision reads batch element 0 only.

        Args:
            text: Input text tokens
            text_mask: Mask for text tokens
            feat: Input audio features, unpacked as (B, T, P, D)
            feat_mask: Mask for audio features
            min_len: Stopping is ignored until the step index exceeds this
            max_len: Maximum number of generated patches
            inference_timesteps: Number of diffusion steps
            cfg_value: Classifier-free guidance value
            streaming: Whether to yield each step latent feature or just the final result

        Returns:
            Generator of Tuple containing:
                - Predicted latent feature at the current step if ``streaming=True``, else final latent features
                - Predicted audio feature sequence so far as a List if ``streaming=True``, else as a concatenated Tensor
        """
        B, _, _, _ = feat.shape
        feat_embed = self.enc_to_lm_proj(self.feat_encoder(feat))  # [b, t, h_lm]
        text_embed = self.base_lm.embed_tokens(text) * self._text_embed_scale()
        combined_embed = text_mask.unsqueeze(-1) * text_embed + feat_mask.unsqueeze(-1) * feat_embed
        prefix_feat_cond = feat[:, -1, ...]  # b, p, d
        pred_feat_seq = []  # list of [b, 1, p, d]

        # Prefill both LMs over the whole prompt and seed their KV caches.
        enc_outputs, kv_cache_tuple = self.base_lm(
            inputs_embeds=combined_embed,
            is_causal=True,
        )
        self.base_lm.kv_cache.fill_caches(kv_cache_tuple)
        enc_outputs = self.fsq_layer(enc_outputs) * feat_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
        lm_hidden = enc_outputs[:, -1, :]
        residual_enc_outputs, residual_kv_cache_tuple = self.residual_lm(
            inputs_embeds=enc_outputs + feat_mask.unsqueeze(-1) * feat_embed,
            is_causal=True,
        )
        self.residual_lm.kv_cache.fill_caches(residual_kv_cache_tuple)
        residual_hidden = residual_enc_outputs[:, -1, :]

        for i in tqdm(range(max_len)):
            dit_hidden_1 = self.lm_to_dit_proj(lm_hidden)  # [b, h_dit]
            dit_hidden_2 = self.res_to_dit_proj(residual_hidden)  # [b, h_dit]
            dit_hidden = dit_hidden_1 + dit_hidden_2  # [b, h_dit]
            pred_feat = self.feat_decoder(
                mu=dit_hidden,
                patch_size=self.patch_size,
                cond=prefix_feat_cond.transpose(1, 2).contiguous(),
                n_timesteps=inference_timesteps,
                cfg_value=cfg_value,
            ).transpose(1, 2)  # [b, p, d]
            curr_embed = self.feat_encoder_step(pred_feat.unsqueeze(1))  # b, 1, c
            curr_embed = self.enc_to_lm_proj(curr_embed)
            pred_feat_seq.append(pred_feat.unsqueeze(1))  # b, 1, p, d
            prefix_feat_cond = pred_feat
            if streaming:
                # return the last three predicted latent features to provide enough context for smooth decoding
                pred_feat_chunk = torch.cat(pred_feat_seq[-3:], dim=1)
                feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
                yield feat_pred, pred_feat_seq
            stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
            if i > min_len and stop_flag == 1:
                break
            # Advance both LMs by one step with the embedding of the patch just generated.
            lm_hidden = self.base_lm.forward_step(
                curr_embed[:, 0, :], torch.tensor([self.base_lm.kv_cache.step()], device=curr_embed.device)
            ).clone()
            lm_hidden = self.fsq_layer(lm_hidden)
            residual_hidden = self.residual_lm.forward_step(
                lm_hidden + curr_embed[:, 0, :], torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device)
            ).clone()

        if not streaming:
            pred_feat_seq = torch.cat(pred_feat_seq, dim=1)  # b, t, p, d
            feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
            yield feat_pred, pred_feat_seq.squeeze(0).cpu()

    @classmethod
    def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
        """Load config, tokenizer, AudioVAE and weights from a local checkpoint directory.

        Expects ``config.json``, tokenizer files, ``audiovae.pth`` and
        ``pytorch_model.bin`` under ``path``.

        Args:
            path: Checkpoint directory.
            optimize: Attempt ``torch.compile`` via ``optimize()`` (inference only).
            training: Return the model for training: VAE frozen and, with
                ``lora_config``, every non-LoRA parameter frozen too.
            lora_config: Optional LoRA settings passed to the constructor.
        """
        with open(os.path.join(path, "config.json"), encoding="utf-8") as f:
            config = VoxCPMConfig.model_validate_json(f.read())
        tokenizer = LlamaTokenizerFast.from_pretrained(path)
        audio_vae = AudioVAE()
        vae_state_dict = torch.load(
            os.path.join(path, "audiovae.pth"),
            map_location="cpu",
            weights_only=True,
        )["state_dict"]
        model = cls(config, tokenizer, audio_vae, lora_config)
        if not training:
            lm_dtype = get_dtype(model.config.dtype)
            model = model.to(lm_dtype)
        else:  # training mode
            for name, param in model.named_parameters():
                if "audio_vae" in name:  # freeze VAE weights
                    param.requires_grad = False
                    continue
                if lora_config is not None and "lora" not in name:  # freeze non-LoRA weights
                    param.requires_grad = False
        model.audio_vae = model.audio_vae.to(torch.float32)
        model_state_dict = torch.load(
            os.path.join(path, "pytorch_model.bin"),
            map_location="cpu",
            weights_only=True,
        )["state_dict"]
        for kw, val in vae_state_dict.items():
            model_state_dict[f"audio_vae.{kw}"] = val
        # Per the original authors, LoRALinear holds weight/bias directly, so its
        # state_dict keys match nn.Linear and no key conversion is needed. strict=False
        # because the pretrained weights contain no lora_A/lora_B entries.
        model.load_state_dict(model_state_dict, strict=False)
        if training:
            return model
        return model.to(model.device).eval().optimize(disable=not optimize)

    # ------------------------------------------------------------------ #
    # LoRA weight management
    # ------------------------------------------------------------------ #
    def load_lora_weights(self, lora_path: str, device: str = None):
        """
        Load LoRA weights from a file; safe to call after torch.compile.

        Implementation note:
            Weights are written through named_parameters() rather than
            load_state_dict(). torch.compile wraps modules in OptimizedModule, which
            changes state_dict key paths (e.g. module.weight -> module._orig_mod.weight).
            named_parameters() yields references to the parameter objects themselves,
            so copying into ``.data`` updates values in place without recompilation and
            allows hot-swapping LoRA weights after compilation.

        Args:
            lora_path: LoRA checkpoint directory (containing generator.pth) or a .pth file
            device: Device to load onto; defaults to the model's device

        Returns:
            tuple: (loaded_keys, skipped_keys)

        Raises:
            FileNotFoundError: if the checkpoint file does not exist.
        """
        from pathlib import Path

        if device is None:
            device = self.device
        # Accept either a directory or a file.
        lora_path = Path(lora_path)
        ckpt_file = lora_path / "generator.pth" if lora_path.is_dir() else lora_path
        if not ckpt_file.exists():
            raise FileNotFoundError(f"LoRA checkpoint not found: {ckpt_file}")
        # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects; only
        # load checkpoints from trusted sources.
        ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
        state_dict = ckpt.get("state_dict", ckpt)

        # Load through named_parameters() for torch.compile compatibility.
        model_params = dict(self.named_parameters())

        # Map checkpoint keys onto compiled-module keys:
        #   checkpoint key: feat_decoder.estimator.decoder.layers...
        #   compiled key:   feat_decoder.estimator._orig_mod.decoder.layers...
        key_mapping = {}
        for model_key in model_params:
            normalized_key = model_key.replace("._orig_mod.", ".")
            if normalized_key != model_key:
                key_mapping[normalized_key] = model_key

        loaded_keys = []
        skipped_keys = []
        for key, value in state_dict.items():
            # Prefer an exact match, then the _orig_mod mapping.
            target_key = key if key in model_params else key_mapping.get(key)
            if target_key is None:
                skipped_keys.append(key)
                continue
            model_params[target_key].data.copy_(value.to(device))
            loaded_keys.append(key)
        return loaded_keys, skipped_keys

    def set_lora_enabled(self, enabled: bool):
        """
        Enable/disable every LoRA layer at runtime (via each layer's ``set_enabled``).

        Args:
            enabled: True enables LoRA; False uses only the base weights
        """
        from ..modules.layers.lora import LoRALinear

        for module in self.modules():
            if isinstance(module, LoRALinear):
                module.set_enabled(enabled)

    def reset_lora_weights(self):
        """
        Reset every LoRA layer's parameters via its ``reset_lora_parameters``.

        Per the original authors this re-initializes A (kaiming) and zeroes B, so the
        LoRA contribution becomes zero, effectively "unloading" the adapter.
        """
        from ..modules.layers.lora import LoRALinear

        for module in self.modules():
            if isinstance(module, LoRALinear):
                module.reset_lora_parameters()

    def get_lora_state_dict(self) -> dict:
        """
        Collect every LoRA parameter in the model.

        Returns:
            dict: parameter name -> cloned tensor, for names containing lora_A or lora_B
        """
        return {
            name: param.data.clone()
            for name, param in self.named_parameters()
            if "lora_A" in name or "lora_B" in name
        }
|