happyme531 commited on
Commit
8b1176e
·
verified ·
1 Parent(s): aa3001b

Upload 36 files

Browse files
Files changed (36) hide show
  1. convert/MiniCPM4-0.5B/added_tokens.json +10 -0
  2. convert/MiniCPM4-0.5B/config.json +37 -0
  3. convert/MiniCPM4-0.5B/configuration_minicpm.py +203 -0
  4. convert/MiniCPM4-0.5B/generation_config.json +12 -0
  5. convert/MiniCPM4-0.5B/modeling_minicpm.py +1514 -0
  6. convert/MiniCPM4-0.5B/special_tokens_map.json +33 -0
  7. convert/MiniCPM4-0.5B/tokenizer.json +0 -0
  8. convert/MiniCPM4-0.5B/tokenizer.model +3 -0
  9. convert/MiniCPM4-0.5B/tokenizer_config.json +117 -0
  10. convert/README.md +53 -0
  11. convert/scripts/build_rk3588_pipeline.py +283 -0
  12. convert/scripts/convert_vox_minicpm_to_hf.py +115 -0
  13. convert/scripts/export_onnx.py +297 -0
  14. convert/scripts/export_rkllm.py +65 -0
  15. convert/src/voxcpm/__init__.py +5 -0
  16. convert/src/voxcpm/cli.py +299 -0
  17. convert/src/voxcpm/core.py +195 -0
  18. convert/src/voxcpm/model/__init__.py +3 -0
  19. convert/src/voxcpm/model/utils.py +122 -0
  20. convert/src/voxcpm/model/voxcpm.py +690 -0
  21. convert/src/voxcpm/modules/__init__.py +0 -0
  22. convert/src/voxcpm/modules/audiovae/__init__.py +1 -0
  23. convert/src/voxcpm/modules/audiovae/audio_vae.py +359 -0
  24. convert/src/voxcpm/modules/layers/__init__.py +1 -0
  25. convert/src/voxcpm/modules/layers/scalar_quantization_layer.py +26 -0
  26. convert/src/voxcpm/modules/locdit/__init__.py +2 -0
  27. convert/src/voxcpm/modules/locdit/local_dit.py +114 -0
  28. convert/src/voxcpm/modules/locdit/unified_cfm.py +137 -0
  29. convert/src/voxcpm/modules/locenc/__init__.py +1 -0
  30. convert/src/voxcpm/modules/locenc/local_encoder.py +30 -0
  31. convert/src/voxcpm/modules/minicpm4/__init__.py +3 -0
  32. convert/src/voxcpm/modules/minicpm4/cache.py +47 -0
  33. convert/src/voxcpm/modules/minicpm4/config.py +29 -0
  34. convert/src/voxcpm/modules/minicpm4/model.py +473 -0
  35. convert/src/voxcpm/utils/text_normalize.py +185 -0
  36. convert/src/voxcpm/zipenhancer.py +76 -0
convert/MiniCPM4-0.5B/added_tokens.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "<|execute_end|>": 73444,
3
+ "<|execute_start|>": 73443,
4
+ "<|fim_middle|>": 73446,
5
+ "<|fim_prefix|>": 73445,
6
+ "<|fim_suffix|>": 73447,
7
+ "<|im_end|>": 73440,
8
+ "<|im_start|>": 73441,
9
+ "<|tool_call|>": 73442
10
+ }
convert/MiniCPM4-0.5B/config.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "openbmb/MiniCPM4-0.5B",
3
+ "architectures": [
4
+ "MiniCPMForCausalLM"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_minicpm.MiniCPMConfig",
8
+ "AutoModel": "modeling_minicpm.MiniCPMModel",
9
+ "AutoModelForCausalLM": "modeling_minicpm.MiniCPMForCausalLM",
10
+ "AutoModelForSeq2SeqLM": "modeling_minicpm.MiniCPMForCausalLM",
11
+ "AutoModelForSequenceClassification": "modeling_minicpm.MiniCPMForSequenceClassification"
12
+ },
13
+ "bos_token_id": 1,
14
+ "eos_token_id": [2, 73440],
15
+ "hidden_act": "silu",
16
+ "hidden_size": 1024,
17
+ "initializer_range": 0.1,
18
+ "intermediate_size": 4096,
19
+ "max_position_embeddings": 32768,
20
+ "num_attention_heads": 16,
21
+ "num_hidden_layers": 24,
22
+ "num_key_value_heads": 2,
23
+ "rms_norm_eps": 1e-05,
24
+ "rope_scaling": {
25
+ "rope_type": "longrope",
26
+ "long_factor": [1.0004360675811768, 1.0668443441390991, 1.1631425619125366, 1.3025742769241333, 1.5040205717086792, 1.7941505908966064, 2.2101221084594727, 2.802666664123535, 3.6389970779418945, 4.804192543029785, 6.39855432510376, 8.527148246765137, 11.277542114257812, 14.684998512268066, 18.69317054748535, 23.13019371032715, 27.72362518310547, 32.1606559753418, 36.168827056884766, 39.57627868652344, 42.32667541503906, 44.45526885986328, 46.04962921142578, 47.21482849121094, 48.05115509033203, 48.64370346069336, 49.05967712402344, 49.34980392456055, 49.551246643066406, 49.69068145751953, 49.78697967529297, 49.85338592529297],
27
+ "short_factor": [1.0004360675811768, 1.0668443441390991, 1.1631425619125366, 1.3025742769241333, 1.5040205717086792, 1.7941505908966064, 2.2101221084594727, 2.802666664123535, 3.6389970779418945, 4.804192543029785, 6.39855432510376, 8.527148246765137, 11.277542114257812, 14.684998512268066, 18.69317054748535, 23.13019371032715, 27.72362518310547, 32.1606559753418, 36.168827056884766, 39.57627868652344, 42.32667541503906, 44.45526885986328, 46.04962921142578, 47.21482849121094, 48.05115509033203, 48.64370346069336, 49.05967712402344, 49.34980392456055, 49.551246643066406, 49.69068145751953, 49.78697967529297, 49.85338592529297],
28
+ "original_max_position_embeddings": 32768
29
+ },
30
+ "torch_dtype": "bfloat16",
31
+ "transformers_version": "4.46.3",
32
+ "use_cache": true,
33
+ "vocab_size": 73448,
34
+ "scale_emb": 12,
35
+ "dim_model_base": 256,
36
+ "scale_depth": 1.4
37
+ }
convert/MiniCPM4-0.5B/configuration_minicpm.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 The OpenBMB Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ MiniCPM model configuration"""
16
+
17
+ from transformers.configuration_utils import PretrainedConfig
18
+ from transformers.utils import logging
19
+
20
+ logger = logging.get_logger(__name__)
21
+
22
+ MINICPM_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
23
+
24
+
25
+ class MiniCPMConfig(PretrainedConfig):
26
+ r"""
27
+ This is the configuration class to store the configuration of a [`MiniCPMModel`]. It is used to instantiate an MiniCPM
28
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
29
+ defaults will yield a similar configuration to that of the MiniCPM-7B.
30
+
31
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
32
+ documentation from [`PretrainedConfig`] for more information.
33
+
34
+
35
+ Args:
36
+ vocab_size (`int`, *optional*, defaults to 32000):
37
+ Vocabulary size of the MiniCPM model. Defines the number of different tokens that can be represented by the
38
+ `inputs_ids` passed when calling [`MiniCPMModel`]
39
+ hidden_size (`int`, *optional*, defaults to 4096):
40
+ Dimension of the hidden representations.
41
+ intermediate_size (`int`, *optional*, defaults to 11008):
42
+ Dimension of the MLP representations.
43
+ num_hidden_layers (`int`, *optional*, defaults to 32):
44
+ Number of hidden layers in the Transformer decoder.
45
+ num_attention_heads (`int`, *optional*, defaults to 32):
46
+ Number of attention heads for each attention layer in the Transformer decoder.
47
+ num_key_value_heads (`int`, *optional*):
48
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
49
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
50
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
51
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
52
+ by meanpooling all the original heads within that group. For more details checkout [this
53
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
54
+ `num_attention_heads`.
55
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
56
+ The non-linear activation function (function or string) in the decoder.
57
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
58
+ The maximum sequence length that this model might ever be used with. MiniCPM 1 supports up to 2048 tokens,
59
+ MiniCPM 2 up to 4096, CodeMiniCPM up to 16384.
60
+ initializer_range (`float`, *optional*, defaults to 0.02):
61
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
62
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
63
+ The epsilon used by the rms normalization layers.
64
+ use_cache (`bool`, *optional*, defaults to `True`):
65
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
66
+ relevant if `config.is_decoder=True`.
67
+ pad_token_id (`int`, *optional*):
68
+ Padding token id.
69
+ bos_token_id (`int`, *optional*, defaults to 1):
70
+ Beginning of stream token id.
71
+ eos_token_id (`int`, *optional*, defaults to 2):
72
+ End of stream token id.
73
+ pretraining_tp (`int`, *optional*, defaults to 1):
74
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
75
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
76
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
77
+ issue](https://github.com/pytorch/pytorch/issues/76232).
78
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
79
+ Whether to tie weight embeddings
80
+ rope_theta (`float`, *optional*, defaults to 10000.0):
81
+ The base period of the RoPE embeddings.
82
+ rope_scaling (`Dict`, *optional*):
83
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
84
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
85
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
86
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
87
+ these scaling strategies behave:
88
+ https://www.reddit.com/r/LocalMiniCPM/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
89
+ experimental feature, subject to breaking API changes in future versions.
90
+ attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
91
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
92
+ attention_dropout (`float`, *optional*, defaults to 0.0):
93
+ The dropout ratio for the attention probabilities.
94
+
95
+ ```python
96
+ >>> from transformers import MiniCPMModel, MiniCPMConfig
97
+
98
+ >>> # Initializing a MiniCPM minicpm-7b style configuration
99
+ >>> configuration = MiniCPMConfig()
100
+
101
+ >>> # Initializing a model from the minicpm-7b style configuration
102
+ >>> model = MiniCPMModel(configuration)
103
+
104
+ >>> # Accessing the model configuration
105
+ >>> configuration = model.config
106
+ ```"""
107
+
108
+ model_type = 'minicpm'
109
+ keys_to_ignore_at_inference = ['past_key_values']
110
+
111
+ def __init__(
112
+ self,
113
+ vocab_size=32000,
114
+ hidden_size=4096,
115
+ intermediate_size=11008,
116
+ num_hidden_layers=32,
117
+ num_attention_heads=32,
118
+ num_key_value_heads=None,
119
+ hidden_act='silu',
120
+ max_position_embeddings=2048,
121
+ initializer_range=0.02,
122
+ rms_norm_eps=1e-6,
123
+ use_cache=True,
124
+ pad_token_id=None,
125
+ bos_token_id=1,
126
+ eos_token_id=2,
127
+ pretraining_tp=1,
128
+ tie_word_embeddings=True,
129
+ rope_theta=10000.0,
130
+ rope_scaling=None,
131
+ attention_bias=False,
132
+ attention_dropout=0.0,
133
+ scale_emb=1,
134
+ dim_model_base=1,
135
+ scale_depth=1,
136
+ mup_denominator=None,
137
+ sparse_config=None,
138
+ **kwargs):
139
+
140
+ self.vocab_size = vocab_size
141
+ self.max_position_embeddings = max_position_embeddings
142
+ self.hidden_size = hidden_size
143
+ self.intermediate_size = intermediate_size
144
+ self.num_hidden_layers = num_hidden_layers
145
+ self.num_attention_heads = num_attention_heads
146
+
147
+ # for backward compatibility
148
+ if num_key_value_heads is None:
149
+ num_key_value_heads = num_attention_heads
150
+
151
+ self.num_key_value_heads = num_key_value_heads
152
+ self.hidden_act = hidden_act
153
+ self.initializer_range = initializer_range
154
+ self.rms_norm_eps = rms_norm_eps
155
+ self.pretraining_tp = pretraining_tp
156
+ self.use_cache = use_cache
157
+ self.rope_theta = rope_theta
158
+ self.rope_scaling = rope_scaling
159
+ # self._rope_scaling_validation()
160
+ self.attention_bias = attention_bias
161
+ self.attention_dropout = attention_dropout
162
+ self.scale_emb = scale_emb
163
+ self.dim_model_base = dim_model_base
164
+ self.scale_depth = scale_depth
165
+ # only used for Eagle Head
166
+ self.mup_denominator = mup_denominator
167
+
168
+ # sparse config
169
+ self.sparse_config = sparse_config
170
+
171
+ super().__init__(
172
+ pad_token_id=pad_token_id,
173
+ bos_token_id=bos_token_id,
174
+ eos_token_id=eos_token_id,
175
+ tie_word_embeddings=tie_word_embeddings,
176
+ **kwargs,
177
+ )
178
+ try:
179
+ import flash_attn
180
+ self._attn_implementation = 'flash_attention_2'
181
+ except:
182
+ pass
183
+
184
+ def _rope_scaling_validation(self):
185
+ """
186
+ Validate the `rope_scaling` configuration.
187
+ """
188
+ if self.rope_scaling is None:
189
+ return
190
+
191
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
192
+ raise ValueError(
193
+ '`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, '
194
+ f'got {self.rope_scaling}'
195
+ )
196
+ rope_scaling_type = self.rope_scaling.get('type', None)
197
+ rope_scaling_factor = self.rope_scaling.get('factor', None)
198
+ if rope_scaling_type is None or rope_scaling_type not in ['linear', 'dynamic']:
199
+ raise ValueError(
200
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
201
+ )
202
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
203
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
convert/MiniCPM4-0.5B/generation_config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 1,
3
+ "do_sample": true,
4
+ "eos_token_id": [
5
+ 2,
6
+ 73440
7
+ ],
8
+ "pad_token_id": 2,
9
+ "temperature": 0.8,
10
+ "top_p": 0.8,
11
+ "transformers_version": "4.46.1"
12
+ }
convert/MiniCPM4-0.5B/modeling_minicpm.py ADDED
@@ -0,0 +1,1514 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 The OpenBMB Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch MiniCPM model."""
16
+ import math
17
+ import re
18
+ import warnings
19
+ from typing import Any, Dict, List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
26
+ from transformers.activations import ACT2FN
27
+ from transformers.cache_utils import Cache, DynamicCache, CacheLayerMixin, DynamicLayer
28
+ from transformers.modeling_attn_mask_utils import (
29
+ AttentionMaskConverter,
30
+ _prepare_4d_attention_mask,
31
+ _prepare_4d_causal_attention_mask,
32
+ _prepare_4d_causal_attention_mask_for_sdpa,
33
+ )
34
+ from transformers.modeling_outputs import (
35
+ BaseModelOutputWithPast,
36
+ CausalLMOutputWithPast,
37
+ SequenceClassifierOutputWithPast,
38
+ )
39
+ from transformers.modeling_utils import PreTrainedModel
40
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
41
+ from transformers.utils import (
42
+ add_start_docstrings,
43
+ add_start_docstrings_to_model_forward,
44
+ is_flash_attn_greater_or_equal_2_10,
45
+ logging,
46
+ replace_return_docstrings,
47
+ )
48
+ from transformers.utils.import_utils import is_torch_fx_available
49
+
50
+ from .configuration_minicpm import MiniCPMConfig
51
+
52
+ try:
53
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
54
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
55
+ except:
56
+ pass
57
+
58
+
59
+
60
+ # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
61
+ # It means that the function will not be traced through and simply appear as a node in the graph.
62
+ if is_torch_fx_available():
63
+ if not is_torch_greater_or_equal_than_1_13:
64
+ import torch.fx
65
+
66
+ _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
67
+
68
+
69
+ logger = logging.get_logger(__name__)
70
+
71
+ _CONFIG_FOR_DOC = 'MiniCPMConfig'
72
+
73
+
74
+ def _get_unpad_data(attention_mask):
75
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
76
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
77
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
78
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
79
+ return (
80
+ indices,
81
+ cu_seqlens,
82
+ max_seqlen_in_batch,
83
+ )
84
+
85
+
86
+
87
+
88
+ # @torch.jit.script # type: ignore
89
+ def rms_layernorm(hidden: torch.Tensor, weight: torch.Tensor, eps: float):
90
+ old_dtype = hidden.dtype
91
+ variance = hidden.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
92
+ hidden = (hidden * torch.rsqrt(variance + eps)).to(old_dtype)
93
+ return hidden * weight
94
+
95
+
96
+ class MiniCPMRMSNorm(nn.Module):
97
+ def __init__(self, hidden_size, eps=1e-6):
98
+ """
99
+ MiniCPMRMSNorm is equivalent to T5LayerNorm
100
+ """
101
+ super().__init__()
102
+ self.weight = nn.Parameter(torch.ones(hidden_size))
103
+ self.variance_epsilon = eps
104
+
105
+ def forward(self, hidden_states):
106
+ return rms_layernorm(hidden_states, self.weight, self.variance_epsilon)
107
+
108
+
109
+ ALL_LAYERNORM_LAYERS.append(MiniCPMRMSNorm)
110
+
111
+
112
+ class MiniCPMRotaryEmbedding(nn.Module):
113
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
114
+ super().__init__()
115
+
116
+ self.dim = dim
117
+ self.max_position_embeddings = max_position_embeddings
118
+ self.base = base
119
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
120
+ self.register_buffer('inv_freq', inv_freq, persistent=False)
121
+
122
+ # Build here to make `torch.jit.trace` work.
123
+ self._set_cos_sin_cache(
124
+ # seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
125
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32
126
+ )
127
+
128
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
129
+ self.max_seq_len_cached = seq_len
130
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
131
+ freqs = torch.outer(t, self.inv_freq)
132
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
133
+ emb = torch.cat((freqs, freqs), dim=-1)
134
+
135
+ self.register_buffer('cos_cached', emb.cos().to(dtype), persistent=False)
136
+ self.register_buffer('sin_cached', emb.sin().to(dtype), persistent=False)
137
+
138
+ def forward(self, x, seq_len=None):
139
+ # x: [bs, num_attention_heads, seq_len, head_size]
140
+ if seq_len > self.max_seq_len_cached:
141
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
142
+
143
+ return (
144
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
145
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
146
+ )
147
+
148
+
149
+ class MiniCPMLongRoPE(MiniCPMRotaryEmbedding):
150
+ """MiniCPMRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
151
+
152
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, short_factor=None, long_factor=None, original_max_position_embeddings=None):
153
+ self.short_factor = short_factor
154
+ self.long_factor = long_factor
155
+ self.original_max_position_embeddings = original_max_position_embeddings
156
+ scale = (max_position_embeddings / self.original_max_position_embeddings)
157
+ self.scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
158
+ super().__init__(dim, max_position_embeddings, base, device)
159
+
160
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
161
+ self.max_seq_len_cached = seq_len
162
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
163
+ if seq_len > self.original_max_position_embeddings:
164
+ ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=device)
165
+ else:
166
+ ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=device)
167
+
168
+ freqs = torch.mul(
169
+ torch.outer(t, 1.0 / ext_factors).to(device=device),
170
+ self.inv_freq.to(device=device).to(dtype)
171
+ )
172
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
173
+ emb = torch.cat((freqs, freqs), dim=-1)
174
+ self.register_buffer('cos_cached', emb.cos().to(dtype) * self.scaling_factor, persistent=False)
175
+ self.register_buffer('sin_cached', emb.sin().to(dtype) * self.scaling_factor, persistent=False)
176
+
177
+
178
+ class MiniCPMLinearScalingRotaryEmbedding(MiniCPMRotaryEmbedding):
179
+ """MiniCPMRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
180
+
181
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
182
+ self.scaling_factor = scaling_factor
183
+ super().__init__(dim, max_position_embeddings, base, device)
184
+
185
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
186
+ self.max_seq_len_cached = seq_len
187
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
188
+ t = t / self.scaling_factor
189
+
190
+ freqs = torch.outer(t, self.inv_freq)
191
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
192
+ emb = torch.cat((freqs, freqs), dim=-1)
193
+ self.register_buffer('cos_cached', emb.cos().to(dtype), persistent=False)
194
+ self.register_buffer('sin_cached', emb.sin().to(dtype), persistent=False)
195
+
196
+
197
+ class MiniCPMDynamicNTKScalingRotaryEmbedding(MiniCPMRotaryEmbedding):
198
+ """MiniCPMRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
199
+
200
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
201
+ self.scaling_factor = scaling_factor
202
+ super().__init__(dim, max_position_embeddings, base, device)
203
+
204
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
205
+ self.max_seq_len_cached = seq_len
206
+
207
+ if seq_len > self.max_position_embeddings:
208
+ base = self.base * (
209
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
210
+ ) ** (self.dim / (self.dim - 2))
211
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
212
+ self.register_buffer('inv_freq', inv_freq, persistent=False)
213
+
214
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
215
+
216
+ freqs = torch.outer(t, self.inv_freq)
217
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
218
+ emb = torch.cat((freqs, freqs), dim=-1)
219
+
220
+ self.register_buffer('cos_cached', emb.cos().to(dtype), persistent=False)
221
+ self.register_buffer('sin_cached', emb.sin().to(dtype), persistent=False)
222
+
223
+
224
+ def rotate_half(x):
225
+ """Rotates half the hidden dims of the input."""
226
+ x1 = x[..., : x.shape[-1] // 2]
227
+ x2 = x[..., x.shape[-1] // 2:]
228
+ return torch.cat((-x2, x1), dim=-1)
229
+
230
+
231
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
232
+ """Applies Rotary Position Embedding to the query and key tensors.
233
+
234
+ Args:
235
+ q (`torch.Tensor`): The query tensor.
236
+ k (`torch.Tensor`): The key tensor.
237
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
238
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
239
+ position_ids (`torch.Tensor`):
240
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
241
+ used to pass offsetted position ids when working with a KV-cache.
242
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
243
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
244
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
245
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
246
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
247
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
248
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
249
+ Returns:
250
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
251
+ """
252
+ # cos = cos[position_ids].unsqueeze(unsqueeze_dim)
253
+ # sin = sin[position_ids].unsqueeze(unsqueeze_dim)
254
+ # q_embed = (q * cos) + (rotate_half(q) * sin)
255
+ # k_embed = (k * cos) + (rotate_half(k) * sin)
256
+ orig_dtype = k.dtype
257
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim]
258
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim]
259
+ q_fp32 = q.to(dtype=torch.float32, device=q.device)
260
+ k_fp32 = k.to(dtype=torch.float32, device=k.device)
261
+ q_embed = (q_fp32 * cos) + (rotate_half(q_fp32) * sin)
262
+ k_embed = (k_fp32 * cos) + (rotate_half(k_fp32) * sin)
263
+ return q_embed.to(dtype=orig_dtype), k_embed.to(dtype=orig_dtype)
264
+
265
+
266
+ class MiniCPMMLP(nn.Module):
267
+ def __init__(self, config):
268
+ super().__init__()
269
+ self.config = config
270
+ self.hidden_size = config.hidden_size
271
+ self.intermediate_size = config.intermediate_size
272
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
273
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
274
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
275
+ self.act_fn = ACT2FN[config.hidden_act]
276
+
277
+ def forward(self, x):
278
+ if self.config.pretraining_tp > 1:
279
+ slice = self.intermediate_size // self.config.pretraining_tp
280
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
281
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
282
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
283
+
284
+ gate_proj = torch.cat(
285
+ [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
286
+ )
287
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
288
+
289
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
290
+ down_proj = [
291
+ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
292
+ ]
293
+ down_proj = sum(down_proj)
294
+ else:
295
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
296
+
297
+ return down_proj
298
+
299
+ def _unpad_one_tensor(hidden_states, attention_mask):
300
+ # Unpad the hidden states using the indices
301
+ indices, cu_seqlens, max_seqlen_in_batch = _get_unpad_data(attention_mask)
302
+ batch_size, seq_len = hidden_states.shape[:2]
303
+
304
+ # Get the remaining dimensions
305
+ remaining_dims = hidden_states.shape[2:]
306
+
307
+ # Reshape to (batch_size * seq_len, *remaining_dims)
308
+ reshaped_states = hidden_states.reshape(batch_size * seq_len, *remaining_dims)
309
+
310
+ # Apply unpadding using indices
311
+ unpadded_states = index_first_axis(reshaped_states, indices)
312
+
313
+ return unpadded_states, indices, cu_seqlens, max_seqlen_in_batch
314
+
315
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
316
+ """
317
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
318
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
319
+ """
320
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
321
+ if n_rep == 1:
322
+ return hidden_states
323
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
324
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
325
+
326
+
327
+ class MiniCPMAttention(nn.Module):
328
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
329
+
330
+ def __init__(self, config: MiniCPMConfig, layer_idx: Optional[int] = None):
331
+ super().__init__()
332
+ self.config = config
333
+ self.layer_idx = layer_idx
334
+ if layer_idx is None:
335
+ logger.warning_once(
336
+ f'Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will '
337
+ 'to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` '
338
+ 'when creating this class.'
339
+ )
340
+
341
+ self.attention_dropout = config.attention_dropout
342
+ self.hidden_size = config.hidden_size
343
+ self.num_heads = config.num_attention_heads
344
+ self.head_dim = self.hidden_size // self.num_heads
345
+ self.num_key_value_heads = config.num_key_value_heads
346
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
347
+ self.max_position_embeddings = config.max_position_embeddings
348
+ self.rope_theta = config.rope_theta
349
+ self.is_causal = True
350
+
351
+ if (self.head_dim * self.num_heads) != self.hidden_size:
352
+ raise ValueError(
353
+ f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}'
354
+ f' and `num_heads`: {self.num_heads}).'
355
+ )
356
+
357
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
358
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
359
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
360
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
361
+ self._init_rope()
362
+
363
+ def _init_rope(self):
364
+ if self.config.rope_scaling is None:
365
+ self.rotary_emb = MiniCPMRotaryEmbedding(
366
+ self.head_dim,
367
+ max_position_embeddings=self.max_position_embeddings,
368
+ base=self.rope_theta,
369
+ )
370
+ else:
371
+ scaling_type = self.config.rope_scaling['rope_type']
372
+ scaling_factor = self.config.rope_scaling.get('factor', None)
373
+ if scaling_type == 'linear':
374
+ self.rotary_emb = MiniCPMLinearScalingRotaryEmbedding(
375
+ self.head_dim,
376
+ max_position_embeddings=self.max_position_embeddings,
377
+ scaling_factor=scaling_factor,
378
+ base=self.rope_theta,
379
+ )
380
+ elif scaling_type == 'dynamic':
381
+ self.rotary_emb = MiniCPMDynamicNTKScalingRotaryEmbedding(
382
+ self.head_dim,
383
+ max_position_embeddings=self.max_position_embeddings,
384
+ scaling_factor=scaling_factor,
385
+ base=self.rope_theta,
386
+ )
387
+ elif scaling_type == 'longrope':
388
+ self.rotary_emb = MiniCPMLongRoPE(
389
+ self.head_dim,
390
+ max_position_embeddings=self.max_position_embeddings,
391
+ short_factor=self.config.rope_scaling['short_factor'],
392
+ long_factor=self.config.rope_scaling['long_factor'],
393
+ base=self.rope_theta,
394
+ original_max_position_embeddings=self.config.rope_scaling['original_max_position_embeddings']
395
+ )
396
+ else:
397
+ raise ValueError(f'Unknown RoPE scaling type {scaling_type}')
398
+
399
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
400
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
401
+
402
+ def forward(
403
+ self,
404
+ hidden_states: torch.Tensor,
405
+ attention_mask: Optional[torch.Tensor] = None,
406
+ position_ids: Optional[torch.LongTensor] = None,
407
+ past_key_value: Optional[Cache] = None,
408
+ output_attentions: bool = False,
409
+ use_cache: bool = False,
410
+ **kwargs,
411
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
412
+ if 'padding_mask' in kwargs:
413
+ warnings.warn(
414
+ 'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`'
415
+ )
416
+
417
+ bsz, q_len, _ = hidden_states.size()
418
+
419
+ if self.config.pretraining_tp > 1:
420
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
421
+ query_slices = self.q_proj.weight.split(
422
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
423
+ )
424
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
425
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
426
+
427
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
428
+ query_states = torch.cat(query_states, dim=-1)
429
+
430
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
431
+ key_states = torch.cat(key_states, dim=-1)
432
+
433
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
434
+ value_states = torch.cat(value_states, dim=-1)
435
+
436
+ else:
437
+ query_states = self.q_proj(hidden_states)
438
+ key_states = self.k_proj(hidden_states)
439
+ value_states = self.v_proj(hidden_states)
440
+
441
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
442
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
443
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
444
+
445
+ kv_seq_len = position_ids.max().item() + 1
446
+ cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)
447
+
448
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
449
+
450
+ if past_key_value is not None:
451
+ cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
452
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
453
+
454
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
455
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
456
+
457
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
458
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
459
+ raise ValueError(
460
+ f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is'
461
+ f' {attn_weights.size()}'
462
+ )
463
+
464
+ if attention_mask is not None:
465
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
466
+ raise ValueError(
467
+ f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
468
+ )
469
+ attn_weights = attn_weights + attention_mask
470
+
471
+ # upcast attention to fp32
472
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
473
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
474
+ attn_output = torch.matmul(attn_weights, value_states)
475
+
476
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
477
+ raise ValueError(
478
+ f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
479
+ f' {attn_output.size()}'
480
+ )
481
+
482
+ attn_output = attn_output.transpose(1, 2).contiguous()
483
+
484
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
485
+
486
+ if self.config.pretraining_tp > 1:
487
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
488
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
489
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
490
+ else:
491
+ attn_output = self.o_proj(attn_output)
492
+
493
+ if not output_attentions:
494
+ attn_weights = None
495
+
496
+ return attn_output, attn_weights, past_key_value
497
+
498
+
499
+ class MiniCPMFlashAttention2(MiniCPMAttention):
500
+ """
501
+ MiniCPM flash attention module. This module inherits from `MiniCPMAttention` as the weights of the module stays
502
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
503
+ flash attention and deal with padding tokens in case the input contains any of them.
504
+ """
505
+
506
+ def __init__(self, *args, **kwargs):
507
+ super().__init__(*args, **kwargs)
508
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
509
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
510
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
511
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
512
+
513
+ def forward(
514
+ self,
515
+ hidden_states: torch.Tensor,
516
+ attention_mask: Optional[torch.LongTensor] = None,
517
+ position_ids: Optional[torch.LongTensor] = None,
518
+ past_key_value: Optional[Cache] = None,
519
+ output_attentions: bool = False,
520
+ use_cache: bool = False,
521
+ **kwargs,
522
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
523
+ # MiniCPMFlashAttention2 attention does not support output_attentions
524
+ if 'padding_mask' in kwargs:
525
+ warnings.warn(
526
+ 'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`'
527
+ )
528
+
529
+ # overwrite attention_mask with padding_mask
530
+ attention_mask = kwargs.pop('padding_mask')
531
+
532
+ output_attentions = False
533
+
534
+ bsz, q_len, _ = hidden_states.size()
535
+
536
+ query_states = self.q_proj(hidden_states)
537
+ key_states = self.k_proj(hidden_states)
538
+ value_states = self.v_proj(hidden_states)
539
+
540
+ # Flash attention requires the input to have the shape
541
+ # batch_size x seq_length x head_dim x hidden_dim
542
+ # therefore we just need to keep the original shape
543
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
544
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
545
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
546
+
547
+ kv_seq_len = position_ids.max().item() + 1
548
+ cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)
549
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
550
+
551
+ if past_key_value is not None:
552
+ cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
553
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
554
+
555
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
556
+ # to be able to avoid many of these transpose/reshape/view.
557
+ query_states = query_states.transpose(1, 2)
558
+ key_states = key_states.transpose(1, 2)
559
+ value_states = value_states.transpose(1, 2)
560
+
561
+ dropout_rate = self.attention_dropout if self.training else 0.0
562
+
563
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
564
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
565
+ # cast them back in the correct dtype just to be sure everything works as expected.
566
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
567
+ # in fp32. (MiniCPMRMSNorm handles it correctly)
568
+
569
+ input_dtype = query_states.dtype
570
+ if input_dtype == torch.float32:
571
+ # Handle the case where the model is quantized
572
+ if hasattr(self.config, '_pre_quantization_dtype'):
573
+ target_dtype = self.config._pre_quantization_dtype
574
+ else:
575
+ target_dtype = self.q_proj.weight.dtype
576
+
577
+ logger.warning_once(
578
+ f'The input hidden states seems to be silently casted in float32, this might be related to'
579
+ f' the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in'
580
+ f' {target_dtype}.'
581
+ )
582
+
583
+ query_states = query_states.to(target_dtype)
584
+ key_states = key_states.to(target_dtype)
585
+ value_states = value_states.to(target_dtype)
586
+
587
+ attn_output = self._flash_attention_forward(
588
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
589
+ )
590
+
591
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
592
+ attn_output = self.o_proj(attn_output)
593
+
594
+ if not output_attentions:
595
+ attn_weights = None
596
+
597
+ return attn_output, attn_weights, past_key_value
598
+
599
+ def _flash_attention_forward(
600
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
601
+ ):
602
+ """
603
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
604
+ first unpad the input, then computes the attention scores and pad the final attention scores.
605
+
606
+ Args:
607
+ query_states (`torch.Tensor`):
608
+ Input query states to be passed to Flash Attention API
609
+ key_states (`torch.Tensor`):
610
+ Input key states to be passed to Flash Attention API
611
+ value_states (`torch.Tensor`):
612
+ Input value states to be passed to Flash Attention API
613
+ attention_mask (`torch.Tensor`):
614
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
615
+ position of padding tokens and 1 for the position of non-padding tokens.
616
+ dropout (`int`, *optional*):
617
+ Attention dropout
618
+ softmax_scale (`float`, *optional*):
619
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
620
+ """
621
+ if not self._flash_attn_uses_top_left_mask:
622
+ causal = self.is_causal
623
+ else:
624
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in MiniCPMFlashAttention2 __init__.
625
+ causal = self.is_causal and query_length != 1
626
+ # Contains at least one padding token in the sequence
627
+ if attention_mask is not None:
628
+ batch_size = query_states.shape[0]
629
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
630
+ query_states, key_states, value_states, attention_mask, query_length
631
+ )
632
+
633
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
634
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
635
+ attn_output_unpad = flash_attn_varlen_func(
636
+ query_states,
637
+ key_states,
638
+ value_states,
639
+ cu_seqlens_q=cu_seqlens_q,
640
+ cu_seqlens_k=cu_seqlens_k,
641
+ max_seqlen_q=max_seqlen_in_batch_q,
642
+ max_seqlen_k=max_seqlen_in_batch_k,
643
+ dropout_p=dropout,
644
+ softmax_scale=softmax_scale,
645
+ causal=causal,
646
+ )
647
+
648
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
649
+ else:
650
+ attn_output = flash_attn_func(
651
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
652
+ )
653
+
654
+ return attn_output
655
+
656
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
657
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
658
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
659
+
660
+ key_layer = index_first_axis(
661
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
662
+ )
663
+ value_layer = index_first_axis(
664
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
665
+ )
666
+ if query_length == kv_seq_len:
667
+ query_layer = index_first_axis(
668
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
669
+ )
670
+ cu_seqlens_q = cu_seqlens_k
671
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
672
+ indices_q = indices_k
673
+ elif query_length == 1:
674
+ max_seqlen_in_batch_q = 1
675
+ cu_seqlens_q = torch.arange(
676
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
677
+ ) # There is a memcpy here, that is very bad.
678
+ indices_q = cu_seqlens_q[:-1]
679
+ query_layer = query_layer.squeeze(1)
680
+ else:
681
+ # The -q_len: slice assumes left padding.
682
+ attention_mask = attention_mask[:, -query_length:]
683
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
684
+
685
+ return (
686
+ query_layer,
687
+ key_layer,
688
+ value_layer,
689
+ indices_q,
690
+ (cu_seqlens_q, cu_seqlens_k),
691
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
692
+ )
693
+
694
+
695
+ class MiniCPMSdpaAttention(MiniCPMAttention):
696
+ """
697
+ MiniCPM attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
698
+ `MiniCPMAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
699
+ SDPA API.
700
+ """
701
+
702
+ # Adapted from MiniCPMAttention.forward
703
+ def forward(
704
+ self,
705
+ hidden_states: torch.Tensor,
706
+ attention_mask: Optional[torch.Tensor] = None,
707
+ position_ids: Optional[torch.LongTensor] = None,
708
+ past_key_value: Optional[Cache] = None,
709
+ output_attentions: bool = False,
710
+ use_cache: bool = False,
711
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
712
+ if output_attentions:
713
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
714
+ logger.warning_once(
715
+ 'MiniCPMModel is using MiniCPMSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, '
716
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
717
+ )
718
+ return super().forward(
719
+ hidden_states=hidden_states,
720
+ attention_mask=attention_mask,
721
+ position_ids=position_ids,
722
+ past_key_value=past_key_value,
723
+ output_attentions=output_attentions,
724
+ use_cache=use_cache,
725
+ )
726
+
727
+ bsz, q_len, _ = hidden_states.size()
728
+
729
+ query_states = self.q_proj(hidden_states)
730
+ key_states = self.k_proj(hidden_states)
731
+ value_states = self.v_proj(hidden_states)
732
+
733
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
734
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
735
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
736
+
737
+ kv_seq_len = position_ids.max().item() + 1
738
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
739
+
740
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
741
+
742
+ if past_key_value is not None:
743
+ cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
744
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
745
+
746
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
747
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
748
+
749
+ if attention_mask is not None:
750
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
751
+ raise ValueError(
752
+ f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
753
+ )
754
+
755
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
756
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
757
+ if query_states.device.type == 'cuda' and attention_mask is not None:
758
+ query_states = query_states.contiguous()
759
+ key_states = key_states.contiguous()
760
+ value_states = value_states.contiguous()
761
+
762
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
763
+ query_states,
764
+ key_states,
765
+ value_states,
766
+ attn_mask=attention_mask,
767
+ dropout_p=self.attention_dropout if self.training else 0.0,
768
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
769
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
770
+ )
771
+
772
+ attn_output = attn_output.transpose(1, 2).contiguous()
773
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
774
+
775
+ attn_output = self.o_proj(attn_output)
776
+
777
+ return attn_output, None, past_key_value
778
+
779
+
780
+ MINICPM_ATTENTION_CLASSES = {
781
+ 'eager': MiniCPMAttention,
782
+ 'flash_attention_2': MiniCPMFlashAttention2,
783
+ 'sdpa': MiniCPMSdpaAttention,
784
+ }
785
+
786
+
787
+ class MiniCPMDecoderLayer(nn.Module):
788
+ def __init__(self, config: MiniCPMConfig, layer_idx: int):
789
+ super().__init__()
790
+ self.hidden_size = config.hidden_size
791
+ self.self_attn = MINICPM_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
792
+
793
+ self.mlp = MiniCPMMLP(config)
794
+ self.input_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
795
+ self.post_attention_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
796
+
797
+ self.scale_depth = config.scale_depth
798
+ self.num_hidden_layers = config.num_hidden_layers
799
+
800
+ def forward(
801
+ self,
802
+ hidden_states: torch.Tensor,
803
+ attention_mask: Optional[torch.Tensor] = None,
804
+ position_ids: Optional[torch.LongTensor] = None,
805
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
806
+ output_attentions: Optional[bool] = False,
807
+ use_cache: Optional[bool] = False,
808
+ **kwargs,
809
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
810
+ """
811
+ Args:
812
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
813
+ attention_mask (`torch.FloatTensor`, *optional*):
814
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
815
+ query_sequence_length, key_sequence_length)` if default attention is used.
816
+ output_attentions (`bool`, *optional*):
817
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
818
+ returned tensors for more detail.
819
+ use_cache (`bool`, *optional*):
820
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
821
+ (see `past_key_values`).
822
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
823
+ """
824
+ if 'padding_mask' in kwargs:
825
+ warnings.warn(
826
+ 'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`'
827
+ )
828
+
829
+ residual = hidden_states
830
+ hidden_states = self.input_layernorm(hidden_states)
831
+ # Self Attention
832
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
833
+ hidden_states=hidden_states,
834
+ attention_mask=attention_mask,
835
+ position_ids=position_ids,
836
+ past_key_value=past_key_value,
837
+ output_attentions=output_attentions,
838
+ use_cache=use_cache,
839
+ **kwargs,
840
+ )
841
+
842
+ hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
843
+
844
+ # Fully Connected
845
+ residual = hidden_states
846
+ hidden_states = self.post_attention_layernorm(hidden_states)
847
+
848
+ hidden_states = self.mlp(hidden_states)
849
+ hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
850
+
851
+ outputs = (hidden_states,)
852
+
853
+ if output_attentions:
854
+ outputs += (self_attn_weights,)
855
+
856
+ if use_cache:
857
+ outputs += (present_key_value,)
858
+
859
+ return outputs
860
+
861
+
862
+ MINICPM_START_DOCSTRING = r"""
863
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
864
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
865
+ etc.)
866
+
867
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
868
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
869
+ and behavior.
870
+
871
+ Parameters:
872
+ config ([`MiniCPMConfig`]):
873
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
874
+ load the weights associated with the model, only the configuration. Check out the
875
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
876
+ """
877
+
878
+
879
+ @add_start_docstrings(
880
+ 'The bare MiniCPM Model outputting raw hidden-states without any specific head on top.',
881
+ MINICPM_START_DOCSTRING,
882
+ )
883
+ class MiniCPMPreTrainedModel(PreTrainedModel):
884
+ config_class = MiniCPMConfig
885
+ base_model_prefix = 'model'
886
+ supports_gradient_checkpointing = True
887
+ _no_split_modules = ['MiniCPMDecoderLayer']
888
+ _skip_keys_device_placement = 'past_key_values'
889
+ _supports_flash_attn_2 = True
890
+ _supports_sdpa = True
891
+ _supports_cache_class = True
892
+
893
+ def _init_weights(self, module):
894
+ std = self.config.initializer_range
895
+ if isinstance(module, nn.Linear):
896
+ module.weight.data.normal_(mean=0.0, std=std)
897
+ if module.bias is not None:
898
+ module.bias.data.zero_()
899
+ elif isinstance(module, nn.Embedding):
900
+ module.weight.data.normal_(mean=0.0, std=std)
901
+ if module.padding_idx is not None:
902
+ module.weight.data[module.padding_idx].zero_()
903
+
904
+
905
+ MINICPM_INPUTS_DOCSTRING = r"""
906
+ Args:
907
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
908
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
909
+ it.
910
+
911
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
912
+ [`PreTrainedTokenizer.__call__`] for details.
913
+
914
+ [What are input IDs?](../glossary#input-ids)
915
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
916
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
917
+
918
+ - 1 for tokens that are **not masked**,
919
+ - 0 for tokens that are **masked**.
920
+
921
+ [What are attention masks?](../glossary#attention-mask)
922
+
923
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
924
+ [`PreTrainedTokenizer.__call__`] for details.
925
+
926
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
927
+ `past_key_values`).
928
+
929
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
930
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
931
+ information on the default strategy.
932
+
933
+ - 1 indicates the head is **not masked**,
934
+ - 0 indicates the head is **masked**.
935
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
936
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
937
+ config.n_positions - 1]`.
938
+
939
+ [What are position IDs?](../glossary#position-ids)
940
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
941
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
942
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
943
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
944
+
945
+ Two formats are allowed:
946
+ - a [`~cache_utils.Cache`] instance;
947
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
948
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
949
+ cache format.
950
+
951
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
952
+ legacy cache format will be returned.
953
+
954
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
955
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
956
+ of shape `(batch_size, sequence_length)`.
957
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
958
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
959
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
960
+ model's internal embedding lookup matrix.
961
+ use_cache (`bool`, *optional*):
962
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
963
+ `past_key_values`).
964
+ output_attentions (`bool`, *optional*):
965
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
966
+ tensors for more detail.
967
+ output_hidden_states (`bool`, *optional*):
968
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
969
+ more detail.
970
+ return_dict (`bool`, *optional*):
971
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
972
+ """
973
+
974
+
975
+ @add_start_docstrings(
976
+ 'The bare MiniCPM Model outputting raw hidden-states without any specific head on top.',
977
+ MINICPM_START_DOCSTRING,
978
+ )
979
+ class MiniCPMModel(MiniCPMPreTrainedModel):
980
+ """
981
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MiniCPMDecoderLayer`]
982
+
983
+ Args:
984
+ config: MiniCPMConfig
985
+ """
986
+
987
+ def __init__(self, config: MiniCPMConfig):
988
+ super().__init__(config)
989
+ self.padding_idx = config.pad_token_id
990
+ self.vocab_size = config.vocab_size
991
+
992
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
993
+ self.layers = nn.ModuleList(
994
+ [MiniCPMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
995
+ )
996
+ self._use_sdpa = config._attn_implementation == 'sdpa'
997
+ self._use_flash_attention_2 = config._attn_implementation == 'flash_attention_2'
998
+
999
+ self.norm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1000
+
1001
+ self.gradient_checkpointing = False
1002
+ # Initialize weights and apply final processing
1003
+ self.post_init()
1004
+
1005
+ def get_input_embeddings(self):
1006
+ return self.embed_tokens
1007
+
1008
+ def set_input_embeddings(self, value):
1009
+ self.embed_tokens = value
1010
+
1011
+ @add_start_docstrings_to_model_forward(MINICPM_INPUTS_DOCSTRING)
1012
+ def forward(
1013
+ self,
1014
+ input_ids: torch.LongTensor = None,
1015
+ attention_mask: Optional[torch.Tensor] = None,
1016
+ position_ids: Optional[torch.LongTensor] = None,
1017
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1018
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1019
+ use_cache: Optional[bool] = None,
1020
+ output_attentions: Optional[bool] = None,
1021
+ output_hidden_states: Optional[bool] = None,
1022
+ return_dict: Optional[bool] = None,
1023
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1024
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1025
+ output_hidden_states = (
1026
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1027
+ )
1028
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1029
+
1030
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1031
+
1032
+ # retrieve input_ids and inputs_embeds
1033
+ if input_ids is not None and inputs_embeds is not None:
1034
+ raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time')
1035
+ elif input_ids is not None:
1036
+ batch_size, seq_length = input_ids.shape[:2]
1037
+ elif inputs_embeds is not None:
1038
+ batch_size, seq_length = inputs_embeds.shape[:2]
1039
+ else:
1040
+ raise ValueError('You have to specify either input_ids or inputs_embeds')
1041
+
1042
+ if self.gradient_checkpointing and self.training:
1043
+ if use_cache:
1044
+ logger.warning_once(
1045
+ '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
1046
+ )
1047
+ use_cache = False
1048
+
1049
+ past_key_values_length = 0
1050
+
1051
+ if use_cache:
1052
+ use_legacy_cache = not isinstance(past_key_values, Cache)
1053
+ if use_legacy_cache:
1054
+ raise ValueError(
1055
+ 'You must use the new past_key_values format, such as the Cache class, instead of the old tuple format.'
1056
+ )
1057
+
1058
+ # Calculate the usable length of past key values
1059
+ past_key_values_length = past_key_values.get_seq_length() if isinstance(past_key_values, Cache) else 0
1060
+
1061
+
1062
+ if position_ids is None:
1063
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1064
+ position_ids = torch.arange(
1065
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1066
+ )
1067
+ position_ids = position_ids.unsqueeze(0)
1068
+
1069
+ if inputs_embeds is None:
1070
+ inputs_embeds = self.embed_tokens(input_ids) * self.config.scale_emb
1071
+
1072
+ if self._use_flash_attention_2:
1073
+ # 2d mask is passed through the layers
1074
+ # attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1075
+ if attention_mask is None:
1076
+ raise ValueError(
1077
+ f'need attention_mask for flash attention, but got {attention_mask}.'
1078
+ )
1079
+ elif self._use_sdpa and not output_attentions:
1080
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
1081
+ # the manual implementation that requires a 4D causal mask in all cases.
1082
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1083
+ attention_mask,
1084
+ (batch_size, seq_length),
1085
+ inputs_embeds,
1086
+ past_key_values_length,
1087
+ )
1088
+ else:
1089
+ # 4d mask is passed through the layers
1090
+ attention_mask = _prepare_4d_causal_attention_mask(
1091
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
1092
+ )
1093
+
1094
+ # embed positions
1095
+ hidden_states = inputs_embeds
1096
+
1097
+ # decoder layers
1098
+ all_hidden_states = () if output_hidden_states else None
1099
+ all_self_attns = () if output_attentions else None
1100
+ next_decoder_cache = None
1101
+
1102
+ for decoder_layer in self.layers:
1103
+ if output_hidden_states:
1104
+ all_hidden_states += (hidden_states,)
1105
+
1106
+ if self.gradient_checkpointing and self.training:
1107
+ layer_outputs = self._gradient_checkpointing_func(
1108
+ decoder_layer.__call__,
1109
+ hidden_states,
1110
+ attention_mask,
1111
+ position_ids,
1112
+ past_key_values,
1113
+ output_attentions,
1114
+ use_cache,
1115
+ )
1116
+ else:
1117
+ layer_outputs = decoder_layer(
1118
+ hidden_states,
1119
+ attention_mask=attention_mask,
1120
+ position_ids=position_ids,
1121
+ past_key_value=past_key_values,
1122
+ output_attentions=output_attentions,
1123
+ use_cache=use_cache,
1124
+ )
1125
+
1126
+ hidden_states = layer_outputs[0]
1127
+
1128
+ if use_cache:
1129
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1130
+
1131
+ if output_attentions:
1132
+ all_self_attns += (layer_outputs[1],)
1133
+
1134
+ hidden_states = self.norm(hidden_states)
1135
+
1136
+ # add hidden states from the last decoder layer
1137
+ if output_hidden_states:
1138
+ all_hidden_states += (hidden_states,)
1139
+
1140
+ next_cache = None
1141
+ if use_cache:
1142
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1143
+ if not return_dict:
1144
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1145
+ return BaseModelOutputWithPast(
1146
+ last_hidden_state=hidden_states,
1147
+ past_key_values=next_cache,
1148
+ hidden_states=all_hidden_states,
1149
+ attentions=all_self_attns,
1150
+ )
1151
+
1152
+
1153
+ class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
1154
+ _tied_weights_keys = ['lm_head.weight']
1155
+
1156
+ def __init__(self, config):
1157
+ super().__init__(config)
1158
+ self.model = MiniCPMModel(config)
1159
+ self.vocab_size = config.vocab_size
1160
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1161
+
1162
+ # Initialize weights and apply final processing
1163
+ self.post_init()
1164
+
1165
+ def get_input_embeddings(self):
1166
+ return self.model.embed_tokens
1167
+
1168
+ def set_input_embeddings(self, value):
1169
+ self.model.embed_tokens = value
1170
+
1171
+ def get_output_embeddings(self):
1172
+ return self.lm_head
1173
+
1174
+ def set_output_embeddings(self, new_embeddings):
1175
+ self.lm_head = new_embeddings
1176
+
1177
+ def set_decoder(self, decoder):
1178
+ self.model = decoder
1179
+
1180
+ def get_decoder(self):
1181
+ return self.model
1182
+
1183
+ @add_start_docstrings_to_model_forward(MINICPM_INPUTS_DOCSTRING)
1184
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1185
+ def forward(
1186
+ self,
1187
+ input_ids: torch.LongTensor = None,
1188
+ attention_mask: Optional[torch.Tensor] = None,
1189
+ position_ids: Optional[torch.LongTensor] = None,
1190
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1191
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1192
+ labels: Optional[torch.LongTensor] = None,
1193
+ use_cache: Optional[bool] = None,
1194
+ output_attentions: Optional[bool] = None,
1195
+ output_hidden_states: Optional[bool] = None,
1196
+ return_dict: Optional[bool] = None,
1197
+ logits_to_keep: Union[int, torch.Tensor] = 0,
1198
+ **kwargs,
1199
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1200
+ r"""
1201
+ Args:
1202
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1203
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1204
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1205
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1206
+
1207
+ Returns:
1208
+
1209
+ Example:
1210
+
1211
+ ```python
1212
+ >>> from transformers import AutoTokenizer, MiniCPMForCausalLM
1213
+
1214
+ >>> model = MiniCPMForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1215
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1216
+
1217
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1218
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1219
+
1220
+ >>> # Generate
1221
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1222
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1223
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1224
+ ```"""
1225
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1226
+ output_hidden_states = (
1227
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1228
+ )
1229
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1230
+
1231
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1232
+ outputs = self.model(
1233
+ input_ids=input_ids,
1234
+ attention_mask=attention_mask,
1235
+ position_ids=position_ids,
1236
+ past_key_values=past_key_values,
1237
+ inputs_embeds=inputs_embeds,
1238
+ use_cache=use_cache,
1239
+ output_attentions=output_attentions,
1240
+ output_hidden_states=output_hidden_states,
1241
+ return_dict=return_dict,
1242
+ )
1243
+
1244
+ hidden_states = outputs[0]
1245
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1246
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1247
+ hidden_states = hidden_states[:, slice_indices, :].contiguous()
1248
+ if self.config.pretraining_tp > 1:
1249
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
1250
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
1251
+ logits = torch.cat(logits, dim=-1)
1252
+ else:
1253
+ logits = self.lm_head(hidden_states / (self.config.hidden_size / self.config.dim_model_base))
1254
+ logits = logits.float()
1255
+
1256
+ loss = None
1257
+ if labels is not None:
1258
+ # Shift so that tokens < n predict n
1259
+ shift_logits = logits[..., :-1, :].contiguous()
1260
+ shift_labels = labels[..., 1:].contiguous()
1261
+ # Flatten the tokens
1262
+ loss_fct = CrossEntropyLoss()
1263
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1264
+ shift_labels = shift_labels.view(-1)
1265
+ # Enable model parallelism
1266
+ shift_labels = shift_labels.to(shift_logits.device)
1267
+ loss = loss_fct(shift_logits, shift_labels)
1268
+
1269
+ if not return_dict:
1270
+ output = (logits,) + outputs[1:]
1271
+ return (loss,) + output if loss is not None else output
1272
+
1273
+ return CausalLMOutputWithPast(
1274
+ loss=loss,
1275
+ logits=logits,
1276
+ past_key_values=outputs.past_key_values,
1277
+ hidden_states=outputs.hidden_states,
1278
+ attentions=outputs.attentions,
1279
+ )
1280
+
1281
+ def prepare_inputs_for_generation(
1282
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1283
+ ):
1284
+ if past_key_values is not None:
1285
+ if isinstance(past_key_values, Cache):
1286
+ # Use the new Cache class methods
1287
+ cache_length = past_key_values.get_seq_length()
1288
+
1289
+
1290
+ past_length = cache_length
1291
+ max_cache_length = None
1292
+ else:
1293
+ raise ValueError(
1294
+ 'You must use the new past_key_values format, such as the Cache class, instead of the old tuple format.'
1295
+ )
1296
+
1297
+ # Keep only the unprocessed tokens:
1298
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1299
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1300
+ # input)
1301
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1302
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
1303
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1304
+ # input_ids based on the past_length.
1305
+ elif past_length < input_ids.shape[1]:
1306
+ input_ids = input_ids[:, past_length:]
1307
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1308
+
1309
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1310
+ if (
1311
+ max_cache_length is not None
1312
+ and attention_mask is not None
1313
+ and cache_length + input_ids.shape[1] > max_cache_length
1314
+ ):
1315
+ attention_mask = attention_mask[:, -max_cache_length:]
1316
+
1317
+ position_ids = kwargs.get('position_ids', None)
1318
+ if attention_mask is not None and position_ids is None:
1319
+ # create position_ids on the fly for batch generation
1320
+ position_ids = attention_mask.long().cumsum(-1) - 1
1321
+ position_ids.masked_fill_(attention_mask == 0, 1)
1322
+ if past_key_values:
1323
+ position_ids = position_ids[:, -input_ids.shape[1]:]
1324
+
1325
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1326
+ if inputs_embeds is not None and past_key_values is None:
1327
+ model_inputs = {'inputs_embeds': inputs_embeds}
1328
+ else:
1329
+ model_inputs = {'input_ids': input_ids}
1330
+
1331
+ model_inputs.update(
1332
+ {
1333
+ 'position_ids': position_ids,
1334
+ 'past_key_values': past_key_values,
1335
+ 'use_cache': kwargs.get('use_cache'),
1336
+ 'attention_mask': attention_mask,
1337
+ }
1338
+ )
1339
+ # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
1340
+ for key, value in kwargs.items():
1341
+ if key not in model_inputs:
1342
+ model_inputs[key] = value
1343
+ return model_inputs
1344
+
1345
+ @staticmethod
1346
+ def _reorder_cache(past_key_values, beam_idx):
1347
+ reordered_past = ()
1348
+ for layer_past in past_key_values:
1349
+ reordered_past += (
1350
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1351
+ )
1352
+ return reordered_past
1353
+
1354
+ @torch.inference_mode()
1355
+ def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = 'user',
1356
+ max_length: int = 4096, num_beams=1, do_sample=True, top_p=0.8, temperature=0.3, logits_processor=None,
1357
+ **kwargs):
1358
+ if history is None:
1359
+ history = []
1360
+ if logits_processor:
1361
+ gen_kwargs = {
1362
+ 'max_length': max_length,
1363
+ 'num_beams': num_beams,
1364
+ 'do_sample': do_sample,
1365
+ 'top_p': top_p,
1366
+ 'temperature': temperature,
1367
+ 'logits_processor': logits_processor,
1368
+ **kwargs
1369
+ }
1370
+ else:
1371
+ gen_kwargs = {
1372
+ 'max_length': max_length,
1373
+ 'num_beams': num_beams,
1374
+ 'do_sample': do_sample,
1375
+ 'top_p': top_p,
1376
+ 'temperature': temperature,
1377
+ 'logits_processor': logits_processor,
1378
+ **kwargs
1379
+ }
1380
+
1381
+ history.append({'role': role, 'content': query})
1382
+ history_str = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=False)
1383
+ inputs = tokenizer(history_str, return_tensors='pt').to(self.device)
1384
+ outputs = self.generate(**inputs, **gen_kwargs)
1385
+ outputs = outputs.tolist()[0][len(inputs['input_ids'][0]):-1]
1386
+ response = tokenizer.decode(outputs)
1387
+ pattern = re.compile(r'.*?(?=<AI>|<用户>)', re.DOTALL)
1388
+ matches = pattern.findall(response)
1389
+ if len(matches) > 0:
1390
+ response = matches[0]
1391
+ history.append({'role': 'assistant', 'content': response})
1392
+ return response, history
1393
+
1394
+
1395
+ @add_start_docstrings(
1396
+ """
1397
+ The MiniCPM Model transformer with a sequence classification head on top (linear layer).
1398
+
1399
+ [`MiniCPMForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1400
+ (e.g. GPT-2) do.
1401
+
1402
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1403
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1404
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1405
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1406
+ each row of the batch).
1407
+ """,
1408
+ MINICPM_START_DOCSTRING,
1409
+ )
1410
+ class MiniCPMForSequenceClassification(MiniCPMPreTrainedModel):
1411
+ def __init__(self, config):
1412
+ super().__init__(config)
1413
+ self.num_labels = config.num_labels
1414
+ self.model = MiniCPMModel(config)
1415
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1416
+
1417
+ # Initialize weights and apply final processing
1418
+ self.post_init()
1419
+
1420
+ def get_input_embeddings(self):
1421
+ return self.model.embed_tokens
1422
+
1423
+ def set_input_embeddings(self, value):
1424
+ self.model.embed_tokens = value
1425
+
1426
+ @add_start_docstrings_to_model_forward(MINICPM_INPUTS_DOCSTRING)
1427
+ def forward(
1428
+ self,
1429
+ input_ids: torch.LongTensor = None,
1430
+ attention_mask: Optional[torch.Tensor] = None,
1431
+ position_ids: Optional[torch.LongTensor] = None,
1432
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1433
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1434
+ labels: Optional[torch.LongTensor] = None,
1435
+ use_cache: Optional[bool] = None,
1436
+ output_attentions: Optional[bool] = None,
1437
+ output_hidden_states: Optional[bool] = None,
1438
+ return_dict: Optional[bool] = None,
1439
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1440
+ r"""
1441
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1442
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1443
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1444
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1445
+ """
1446
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1447
+
1448
+ transformer_outputs = self.model(
1449
+ input_ids,
1450
+ attention_mask=attention_mask,
1451
+ position_ids=position_ids,
1452
+ past_key_values=past_key_values,
1453
+ inputs_embeds=inputs_embeds,
1454
+ use_cache=use_cache,
1455
+ output_attentions=output_attentions,
1456
+ output_hidden_states=output_hidden_states,
1457
+ return_dict=return_dict,
1458
+ )
1459
+ hidden_states = transformer_outputs[0]
1460
+ logits = self.score(hidden_states)
1461
+
1462
+ if input_ids is not None:
1463
+ batch_size = input_ids.shape[0]
1464
+ else:
1465
+ batch_size = inputs_embeds.shape[0]
1466
+
1467
+ if self.config.pad_token_id is None and batch_size != 1:
1468
+ raise ValueError('Cannot handle batch sizes > 1 if no padding token is defined.')
1469
+ if self.config.pad_token_id is None:
1470
+ sequence_lengths = -1
1471
+ else:
1472
+ if input_ids is not None:
1473
+ sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
1474
+ logits.device
1475
+ )
1476
+ else:
1477
+ sequence_lengths = -1
1478
+
1479
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1480
+
1481
+ loss = None
1482
+ if labels is not None:
1483
+ labels = labels.to(logits.device)
1484
+ if self.config.problem_type is None:
1485
+ if self.num_labels == 1:
1486
+ self.config.problem_type = 'regression'
1487
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1488
+ self.config.problem_type = 'single_label_classification'
1489
+ else:
1490
+ self.config.problem_type = 'multi_label_classification'
1491
+
1492
+ if self.config.problem_type == 'regression':
1493
+ loss_fct = MSELoss()
1494
+ if self.num_labels == 1:
1495
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1496
+ else:
1497
+ loss = loss_fct(pooled_logits, labels)
1498
+ elif self.config.problem_type == 'single_label_classification':
1499
+ loss_fct = CrossEntropyLoss()
1500
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1501
+ elif self.config.problem_type == 'multi_label_classification':
1502
+ loss_fct = BCEWithLogitsLoss()
1503
+ loss = loss_fct(pooled_logits, labels)
1504
+ if not return_dict:
1505
+ output = (pooled_logits,) + transformer_outputs[1:]
1506
+ return ((loss,) + output) if loss is not None else output
1507
+
1508
+ return SequenceClassifierOutputWithPast(
1509
+ loss=loss,
1510
+ logits=pooled_logits,
1511
+ past_key_values=transformer_outputs.past_key_values,
1512
+ hidden_states=transformer_outputs.hidden_states,
1513
+ attentions=transformer_outputs.attentions,
1514
+ )
convert/MiniCPM4-0.5B/special_tokens_map.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_end|>",
4
+ "<|im_start|>",
5
+ "<|tool_call|>",
6
+ "<|execute_start|>",
7
+ "<|execute_end|>",
8
+ "<|fim_prefix|>",
9
+ "<|fim_middle|>",
10
+ "<|fim_suffix|>"
11
+ ],
12
+ "bos_token": {
13
+ "content": "<s>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false
18
+ },
19
+ "eos_token": {
20
+ "content": "<|im_end|>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false
25
+ },
26
+ "unk_token": {
27
+ "content": "<unk>",
28
+ "lstrip": false,
29
+ "normalized": false,
30
+ "rstrip": false,
31
+ "single_word": false
32
+ }
33
+ }
convert/MiniCPM4-0.5B/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
convert/MiniCPM4-0.5B/tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb74d51116831c3bf65db812c553f94ab0c88dcf97a5bbb37e3504f6d359c530
3
+ size 1181204
convert/MiniCPM4-0.5B/tokenizer_config.json ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": null,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "<unk>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "1": {
15
+ "content": "<s>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "2": {
23
+ "content": "</s>",
24
+ "lstrip": false,
25
+ "normalized": false,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": true
29
+ },
30
+ "73440": {
31
+ "content": "<|im_end|>",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": true
37
+ },
38
+ "73441": {
39
+ "content": "<|im_start|>",
40
+ "lstrip": false,
41
+ "normalized": false,
42
+ "rstrip": false,
43
+ "single_word": false,
44
+ "special": true
45
+ },
46
+ "73442": {
47
+ "content": "<|tool_call|>",
48
+ "lstrip": false,
49
+ "normalized": false,
50
+ "rstrip": false,
51
+ "single_word": false,
52
+ "special": true
53
+ },
54
+ "73443": {
55
+ "content": "<|execute_start|>",
56
+ "lstrip": false,
57
+ "normalized": false,
58
+ "rstrip": false,
59
+ "single_word": false,
60
+ "special": true
61
+ },
62
+ "73444": {
63
+ "content": "<|execute_end|>",
64
+ "lstrip": false,
65
+ "normalized": false,
66
+ "rstrip": false,
67
+ "single_word": false,
68
+ "special": true
69
+ },
70
+ "73445": {
71
+ "content": "<|fim_prefix|>",
72
+ "lstrip": false,
73
+ "normalized": false,
74
+ "rstrip": false,
75
+ "single_word": false,
76
+ "special": true
77
+ },
78
+ "73446": {
79
+ "content": "<|fim_middle|>",
80
+ "lstrip": false,
81
+ "normalized": false,
82
+ "rstrip": false,
83
+ "single_word": false,
84
+ "special": true
85
+ },
86
+ "73447": {
87
+ "content": "<|fim_suffix|>",
88
+ "lstrip": false,
89
+ "normalized": false,
90
+ "rstrip": false,
91
+ "single_word": false,
92
+ "special": true
93
+ }
94
+ },
95
+ "additional_special_tokens": [
96
+ "<|im_end|>",
97
+ "<|im_start|>",
98
+ "<|tool_call|>",
99
+ "<|execute_start|>",
100
+ "<|execute_end|>",
101
+ "<|fim_prefix|>",
102
+ "<|fim_middle|>",
103
+ "<|fim_suffix|>"
104
+ ],
105
+ "bos_token": "<s>",
106
+ "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
107
+ "clean_up_tokenization_spaces": false,
108
+ "eos_token": "<|im_end|>",
109
+ "legacy": true,
110
+ "model_max_length": 1000000000000000019884624838656,
111
+ "pad_token": null,
112
+ "sp_model_kwargs": {},
113
+ "spaces_between_special_tokens": false,
114
+ "tokenizer_class": "LlamaTokenizer",
115
+ "unk_token": "<unk>",
116
+ "use_default_system_prompt": false
117
+ }
convert/README.md ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 模型转换
2
+
3
+ 1. 测试可用的依赖版本如下:
4
+
5
+ ```
6
+ torch==2.10.0
7
+ transformers==4.57.6
8
+ onnx==1.18.0
9
+ onnxruntime==1.22.0
10
+ einops==0.8.2
11
+ rknn-toolkit2==2.3.2
12
+ rkllm-toolkit==1.2.3
13
+ ```
14
+
15
+ 2. 下载模型
16
+
17
+ 从`https://huggingface.co/openbmb/VoxCPM-0.5B`下载模型,保存到`./VoxCPM-0.5B`文件夹。
18
+
19
+ 3. 转换模型
20
+
21
+ ```bash
22
+ python scripts/build_rk3588_pipeline.py
23
+ ```
24
+
25
+ 转换后的模型会放置在`build/rk3588/final_models/`.
26
+
27
+ ---
28
+
29
+ # Model Conversion
30
+
31
+ 1. Tested dependency versions:
32
+
33
+ ```
34
+ torch==2.10.0
35
+ transformers==4.57.6
36
+ onnx==1.18.0
37
+ onnxruntime==1.22.0
38
+ einops==0.8.2
39
+ rknn-toolkit2==2.3.2
40
+ rkllm-toolkit==1.2.3
41
+ ```
42
+
43
+ 2. Download the model
44
+
45
+ Download the model from `https://huggingface.co/openbmb/VoxCPM-0.5B` and save it to the `./VoxCPM-0.5B` directory.
46
+
47
+ 3. Convert the model
48
+
49
+ ```bash
50
+ python scripts/build_rk3588_pipeline.py
51
+ ```
52
+
53
+ The converted models will be placed in `build/rk3588/final_models/`.
convert/scripts/build_rk3588_pipeline.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import contextlib
3
+ import json
4
+ import os
5
+ from pathlib import Path
6
+ import shutil
7
+ import subprocess
8
+
9
+ from rknn.api import RKNN
10
+
11
+
12
+ REPO_ROOT = Path(__file__).resolve().parent.parent
13
+ SRC_DIR = REPO_ROOT / "src"
14
+
15
+ TOKENIZER_SUPPORT_FILES = [
16
+ "tokenizer.json",
17
+ "tokenizer_config.json",
18
+ "tokenizer.model",
19
+ "special_tokens_map.json",
20
+ "added_tokens.json",
21
+ "generation_config.json",
22
+ "README.md",
23
+ "modeling_minicpm.py",
24
+ "configuration_minicpm.py",
25
+ ]
26
+
27
+ RKNN_SPECS = [
28
+ ("audio_vae_encode.onnx", "audio_vae_encode.rknn", ["audio_wave"], [[1, 1, 40960]], None),
29
+ ("audio_vae_decode.onnx", "audio_vae_decode.rknn", ["latent"], [[1, 64, 64]], None),
30
+ ("locenc.onnx", "locenc_64.rknn", ["x"], [[1, 64, 2, 64]], None),
31
+ ("locenc.onnx", "locenc_1.rknn", ["x"], [[1, 1, 2, 64]], None),
32
+ ("fsq_layer.onnx", "fsq_layer.rknn", ["hidden"], [[1, 64, 1024]], [[[1, 64, 1024]], [[1, 1, 1024]]]),
33
+ ("stop_head.onnx", "stop_head.rknn", ["hidden"], [[1, 1024]], None),
34
+ ("lm_to_dit_proj.onnx", "lm_to_dit_proj.rknn", ["input"], [[1, 1024]], None),
35
+ ("res_to_dit_proj.onnx", "res_to_dit_proj.rknn", ["input"], [[1, 1024]], None),
36
+ ("dit_step.onnx", "dit_step.rknn", ["x", "mu", "t", "cond", "dt"], [[1, 64, 2], [1, 1024], [1], [1, 64, 2], [1]], None),
37
+ ]
38
+
39
+
40
+ def run(cmd: list[str], *, cwd: Path | None = None, env: dict[str, str] | None = None):
41
+ print("+", " ".join(cmd))
42
+ subprocess.run(cmd, cwd=cwd, env=env, check=True)
43
+
44
+
45
+ @contextlib.contextmanager
46
+ def pushd(path: Path):
47
+ prev = Path.cwd()
48
+ os.chdir(path)
49
+ try:
50
+ yield
51
+ finally:
52
+ os.chdir(prev)
53
+
54
+
55
+ def ensure_dir(path: Path):
56
+ path.mkdir(parents=True, exist_ok=True)
57
+
58
+
59
+ def copy_if_exists(src: Path, dst: Path):
60
+ if src.exists():
61
+ shutil.copy2(src, dst)
62
+
63
+
64
+ def sync_hf_support_files(minicpm_dir: Path, target_dir: Path):
65
+ ensure_dir(target_dir)
66
+ metadata_json = target_dir / "configuration.json"
67
+ if metadata_json.exists():
68
+ metadata_json.unlink()
69
+ for name in TOKENIZER_SUPPORT_FILES:
70
+ copy_if_exists(minicpm_dir / name, target_dir / name)
71
+
72
+
73
+ def patch_hf_config(reference_config_path: Path, target_config_path: Path, architecture: str):
74
+ reference = json.loads(reference_config_path.read_text())
75
+ target = json.loads(target_config_path.read_text())
76
+ if "auto_map" in reference:
77
+ target["auto_map"] = reference["auto_map"]
78
+ target["architectures"] = [architecture]
79
+ target_config_path.write_text(json.dumps(target, indent=2, ensure_ascii=False) + "\n")
80
+
81
+
82
+ def export_onnx(model_dir: Path, onnx_dir: Path):
83
+ ensure_dir(onnx_dir)
84
+ env = os.environ.copy()
85
+ env["PYTHONPATH"] = str(SRC_DIR) + (os.pathsep + env["PYTHONPATH"] if env.get("PYTHONPATH") else "")
86
+ run(
87
+ [
88
+ "python",
89
+ str(REPO_ROOT / "scripts" / "export_onnx.py"),
90
+ "--model-dir",
91
+ str(model_dir),
92
+ "--out-dir",
93
+ str(onnx_dir),
94
+ "--dump-embeddings",
95
+ ],
96
+ cwd=REPO_ROOT,
97
+ env=env,
98
+ )
99
+
100
+
101
+ def convert_one_rknn(
102
+ onnx_dir: Path,
103
+ rknn_dir: Path,
104
+ spec: tuple[str, str, list[str], list[list[int]] | None, list[list[list[int]]] | None],
105
+ target_platform: str,
106
+ ):
107
+ onnx_name, rknn_name, inputs, input_size_list, dynamic_input = spec
108
+ onnx_path = onnx_dir / onnx_name
109
+ out_path = rknn_dir / rknn_name
110
+ ensure_dir(rknn_dir)
111
+
112
+ if not onnx_path.exists():
113
+ raise FileNotFoundError(f"Missing ONNX file: {onnx_path}")
114
+
115
+ rknn = RKNN(verbose=False)
116
+ ret = rknn.config(target_platform=target_platform, dynamic_input=dynamic_input)
117
+ if ret != 0:
118
+ raise RuntimeError(f"RKNN config failed for {onnx_name}, ret={ret}")
119
+
120
+ load_kwargs = {"model": str(onnx_path)}
121
+ if input_size_list is not None:
122
+ load_kwargs["inputs"] = inputs
123
+ load_kwargs["input_size_list"] = input_size_list
124
+
125
+ ret = rknn.load_onnx(**load_kwargs)
126
+ if ret != 0:
127
+ raise RuntimeError(f"RKNN load_onnx failed for {onnx_name}, ret={ret}")
128
+
129
+ ret = rknn.build(do_quantization=False)
130
+ if ret != 0:
131
+ raise RuntimeError(f"RKNN build failed for {onnx_name}, ret={ret}")
132
+
133
+ ret = rknn.export_rknn(str(out_path))
134
+ if ret != 0:
135
+ raise RuntimeError(f"RKNN export failed for {out_path}, ret={ret}")
136
+ rknn.release()
137
+
138
+
139
+ def export_rknn(onnx_dir: Path, rknn_dir: Path, target_platform: str):
140
+ ensure_dir(rknn_dir)
141
+ copy_if_exists(onnx_dir / "embed_tokens.npy", rknn_dir / "embed_tokens.npy")
142
+ with pushd(rknn_dir):
143
+ for spec in RKNN_SPECS:
144
+ convert_one_rknn(onnx_dir, rknn_dir, spec, target_platform)
145
+
146
+
147
+ def collect_final_models(build_dir: Path):
148
+ final_dir = build_dir / "final_models"
149
+ ensure_dir(final_dir)
150
+
151
+ for name in [
152
+ "audio_vae_encode.rknn",
153
+ "audio_vae_decode.rknn",
154
+ "locenc_64.rknn",
155
+ "locenc_1.rknn",
156
+ "fsq_layer.rknn",
157
+ "stop_head.rknn",
158
+ "lm_to_dit_proj.rknn",
159
+ "res_to_dit_proj.rknn",
160
+ "dit_step.rknn",
161
+ "embed_tokens.npy",
162
+ ]:
163
+ copy_if_exists(build_dir / "rknn" / name, final_dir / name)
164
+
165
+ copy_if_exists(build_dir / "rkllm" / "base" / "language_model.rkllm", final_dir / "base_lm.rkllm")
166
+ copy_if_exists(build_dir / "rkllm" / "residual" / "language_model.rkllm", final_dir / "residual_lm.rkllm")
167
+
168
+
169
+ def convert_vox_to_hf(vox_config: Path, vox_state: Path, minicpm_dir: Path, base_out: Path, residual_out: Path):
170
+ ensure_dir(base_out)
171
+ ensure_dir(residual_out)
172
+ run(
173
+ [
174
+ "python",
175
+ str(REPO_ROOT / "scripts" / "convert_vox_minicpm_to_hf.py"),
176
+ "--vox-config",
177
+ str(vox_config),
178
+ "--vox-state",
179
+ str(vox_state),
180
+ "--minicpm-dir",
181
+ str(minicpm_dir),
182
+ "--out-dir",
183
+ str(base_out),
184
+ "--out-residual-dir",
185
+ str(residual_out),
186
+ ],
187
+ cwd=REPO_ROOT,
188
+ )
189
+ sync_hf_support_files(minicpm_dir, base_out)
190
+ sync_hf_support_files(minicpm_dir, residual_out)
191
+ patch_hf_config(minicpm_dir / "config.json", base_out / "config.json", "MiniCPMForCausalLM")
192
+ patch_hf_config(minicpm_dir / "config.json", residual_out / "config.json", "MiniCPMModel")
193
+
194
+
195
+ def export_rkllm(hf_dir: Path, out_path: Path, target_platform: str, num_npu_core: int):
196
+ hf_home = out_path.parent.parent.parent / "cache" / "huggingface"
197
+ ensure_dir(hf_home)
198
+ env = os.environ.copy()
199
+ env["HF_HOME"] = str(hf_home)
200
+ env["HUGGINGFACE_HUB_CACHE"] = str(hf_home / "hub")
201
+ env["TRANSFORMERS_CACHE"] = str(hf_home / "transformers")
202
+ run(
203
+ [
204
+ "python",
205
+ str(REPO_ROOT / "scripts" / "export_rkllm.py"),
206
+ "--model-dir",
207
+ str(hf_dir),
208
+ "--output",
209
+ str(out_path),
210
+ "--target-platform",
211
+ target_platform,
212
+ "--num-npu-core",
213
+ str(num_npu_core),
214
+ "--hf-home",
215
+ str(hf_home),
216
+ ],
217
+ cwd=REPO_ROOT,
218
+ env=env,
219
+ )
220
+
221
+
222
+ def write_manifest(build_dir: Path, model_dir: Path, minicpm_dir: Path):
223
+ manifest = {
224
+ "model_dir": str(model_dir),
225
+ "minicpm_dir": str(minicpm_dir),
226
+ "onnx_dir": str(build_dir / "onnx"),
227
+ "rknn_dir": str(build_dir / "rknn"),
228
+ "hf_base_dir": str(build_dir / "hf" / "base"),
229
+ "hf_residual_dir": str(build_dir / "hf" / "residual"),
230
+ "rkllm_base_model": str(build_dir / "rkllm" / "base" / "language_model.rkllm"),
231
+ "rkllm_residual_model": str(build_dir / "rkllm" / "residual" / "language_model.rkllm"),
232
+ "output_dir": str(build_dir / "output"),
233
+ }
234
+ ensure_dir(build_dir)
235
+ (build_dir / "build_manifest.json").write_text(json.dumps(manifest, indent=2, ensure_ascii=False) + "\n")
236
+
237
+
238
+ def main():
239
+ parser = argparse.ArgumentParser(description="Rebuild the VoxCPM RK3588 deployment artifacts from scratch.")
240
+ parser.add_argument("--model-dir", default="VoxCPM-0.5B", help="Path to the original VoxCPM-0.5B model directory.")
241
+ parser.add_argument("--minicpm-dir", default="MiniCPM4-0.5B", help="Path to the reference MiniCPM4-0.5B directory.")
242
+ parser.add_argument("--build-dir", default="build/rk3588", help="Output root for rebuilt artifacts.")
243
+ parser.add_argument("--target-platform", default="rk3588", help="RK target platform.")
244
+ parser.add_argument("--skip-onnx", action="store_true", help="Skip ONNX export.")
245
+ parser.add_argument("--skip-rknn", action="store_true", help="Skip RKNN conversion.")
246
+ parser.add_argument("--skip-hf", action="store_true", help="Skip Vox->HF conversion.")
247
+ parser.add_argument("--skip-rkllm", action="store_true", help="Skip RKLLM export.")
248
+ args = parser.parse_args()
249
+
250
+ model_dir = (REPO_ROOT / args.model_dir).resolve()
251
+ minicpm_dir = (REPO_ROOT / args.minicpm_dir).resolve()
252
+ build_dir = (REPO_ROOT / args.build_dir).resolve()
253
+ onnx_dir = build_dir / "onnx"
254
+ rknn_dir = build_dir / "rknn"
255
+ hf_base_dir = build_dir / "hf" / "base"
256
+ hf_residual_dir = build_dir / "hf" / "residual"
257
+ rkllm_base_path = build_dir / "rkllm" / "base" / "language_model.rkllm"
258
+ rkllm_residual_path = build_dir / "rkllm" / "residual" / "language_model.rkllm"
259
+ ensure_dir(build_dir / "output")
260
+
261
+ if not args.skip_onnx:
262
+ export_onnx(model_dir, onnx_dir)
263
+ if not args.skip_rknn:
264
+ export_rknn(onnx_dir, rknn_dir, args.target_platform)
265
+ if not args.skip_hf:
266
+ convert_vox_to_hf(
267
+ vox_config=model_dir / "config.json",
268
+ vox_state=model_dir / "pytorch_model.bin",
269
+ minicpm_dir=minicpm_dir,
270
+ base_out=hf_base_dir,
271
+ residual_out=hf_residual_dir,
272
+ )
273
+ if not args.skip_rkllm:
274
+ export_rkllm(hf_base_dir, rkllm_base_path, args.target_platform, num_npu_core=1)
275
+ export_rkllm(hf_residual_dir, rkllm_residual_path, args.target_platform, num_npu_core=3)
276
+
277
+ collect_final_models(build_dir)
278
+ write_manifest(build_dir, model_dir, minicpm_dir)
279
+ print(f"Saved: {build_dir / 'build_manifest.json'}")
280
+
281
+
282
+ if __name__ == "__main__":
283
+ main()
convert/scripts/convert_vox_minicpm_to_hf.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ import sys
5
+ import torch
6
+ import math
7
+
8
+
9
+ def load_vox_configs(vox_config_path: str) -> tuple[dict, dict]:
10
+ """Return (base_lm_cfg, residual_cfg)."""
11
+ with open(vox_config_path, "r") as f:
12
+ data = json.load(f)
13
+
14
+ base = data["lm_config"]
15
+ rope = base.get("rope_scaling")
16
+ if rope:
17
+ rope = dict(rope)
18
+ # Vox config uses "type", transformers expects "rope_type"
19
+ if "type" in rope and "rope_type" not in rope:
20
+ rope["rope_type"] = rope.pop("type")
21
+ base["rope_scaling"] = rope
22
+
23
+ residual = dict(base)
24
+ residual["num_hidden_layers"] = data.get("residual_lm_num_layers", residual["num_hidden_layers"])
25
+ # keep vocab_size for easier loading; Vox sets 0 because inputs_embeds are provided
26
+ residual.setdefault("vocab_size", base.get("vocab_size"))
27
+
28
+ # Align transformers residual scaling with Vox (no scaling when use_mup=False)
29
+ if not base.get("use_mup", True):
30
+ base["scale_depth"] = math.sqrt(base["num_hidden_layers"])
31
+ residual["scale_depth"] = math.sqrt(residual["num_hidden_layers"])
32
+ return base, residual
33
+
34
+
35
+ def build_hf_config(lm_cfg: dict, minicpm_dir: str):
36
+ sys.path.insert(0, minicpm_dir)
37
+ from configuration_minicpm import MiniCPMConfig
38
+
39
+ return MiniCPMConfig(**lm_cfg)
40
+
41
+
42
+ def convert_state_dict(vox_state_path: str, lm_prefix: str) -> dict:
43
+ raw = torch.load(vox_state_path, map_location="cpu")
44
+ sd = raw["state_dict"] if isinstance(raw, dict) and "state_dict" in raw else raw
45
+
46
+ out = {}
47
+ prefix = f"{lm_prefix}."
48
+ for k, v in sd.items():
49
+ if not k.startswith(prefix):
50
+ continue
51
+ new_k = "model." + k[len(prefix) :]
52
+ out[new_k] = v
53
+
54
+ # Tie lm_head to embeddings for MiniCPMForCausalLM
55
+ if "model.embed_tokens.weight" in out:
56
+ out["lm_head.weight"] = out["model.embed_tokens.weight"]
57
+ return out
58
+
59
+
60
+ def main():
61
+ parser = argparse.ArgumentParser(description="Convert VoxCPM MiniCPM weights to transformers format")
62
+ parser.add_argument(
63
+ "--vox-config",
64
+ default="VoxCPM-0.5B/config.json",
65
+ help="Path to VoxCPM config.json (used to read lm_config)",
66
+ )
67
+ parser.add_argument(
68
+ "--vox-state",
69
+ default="VoxCPM-0.5B/pytorch_model.bin",
70
+ help="Path to VoxCPM checkpoint containing base_lm weights",
71
+ )
72
+ parser.add_argument(
73
+ "--minicpm-dir",
74
+ default="MiniCPM4-0.5B",
75
+ help="Path to local MiniCPM4-0.5B directory (provides configuration_minicpm.py)",
76
+ )
77
+ parser.add_argument(
78
+ "--out-dir",
79
+ default="converted-minicpm-hf",
80
+ help="Output directory for base LM transformers-style checkpoint",
81
+ )
82
+ parser.add_argument(
83
+ "--out-residual-dir",
84
+ default="converted-minicpm-residual-hf",
85
+ help="Output directory for residual LM checkpoint",
86
+ )
87
+ args = parser.parse_args()
88
+
89
+ os.makedirs(args.out_dir, exist_ok=True)
90
+ os.makedirs(args.out_residual_dir, exist_ok=True)
91
+
92
+ base_cfg, residual_cfg = load_vox_configs(args.vox_config)
93
+
94
+ hf_config = build_hf_config(base_cfg, args.minicpm_dir)
95
+ hf_config.save_pretrained(args.out_dir)
96
+
97
+ print("Loaded Vox lm_config and wrote transformers config to", args.out_dir)
98
+
99
+ hf_state = convert_state_dict(args.vox_state, lm_prefix="base_lm")
100
+ out_path = os.path.join(args.out_dir, "pytorch_model.bin")
101
+ torch.save(hf_state, out_path)
102
+ print("Saved base LM weights to", out_path)
103
+
104
+ residual_hf_config = build_hf_config(residual_cfg, args.minicpm_dir)
105
+ residual_hf_config.save_pretrained(args.out_residual_dir)
106
+ residual_state = convert_state_dict(args.vox_state, lm_prefix="residual_lm")
107
+ residual_out_path = os.path.join(args.out_residual_dir, "pytorch_model.bin")
108
+ torch.save(residual_state, residual_out_path)
109
+ print("Saved residual LM weights to", residual_out_path)
110
+
111
+ print("Load with MiniCPMForCausalLM.from_pretrained(...) or MiniCPMModel.from_pretrained(...).")
112
+
113
+
114
+ if __name__ == "__main__":
115
+ main()
convert/scripts/export_onnx.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+ import torch
5
+ from torch import nn
6
+
7
+ REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
8
+ SRC_DIR = os.path.join(REPO_ROOT, "src")
9
+ if SRC_DIR not in sys.path:
10
+ sys.path.insert(0, SRC_DIR)
11
+
12
+ from voxcpm.model.voxcpm import VoxCPMModel
13
+
14
+
15
+ def remove_weight_norm(module: nn.Module):
16
+ """Strip weight_norm wrappers for cleaner ONNX graphs."""
17
+ for name, child in module.named_children():
18
+ remove_weight_norm(child)
19
+ if isinstance(child, (nn.Conv1d, nn.ConvTranspose1d)):
20
+ try:
21
+ torch.nn.utils.remove_weight_norm(child)
22
+ except ValueError:
23
+ # not wrapped, skip
24
+ pass
25
+
26
+
27
+ class VAEEncodeWrapper(nn.Module):
28
+ def __init__(self, audio_vae: nn.Module):
29
+ super().__init__()
30
+ self.audio_vae = audio_vae
31
+
32
+ def forward(self, audio_wave: torch.Tensor):
33
+ return self.audio_vae.encode(audio_wave, self.audio_vae.sample_rate)
34
+
35
+
36
+ class VAEDecodeWrapper(nn.Module):
37
+ def __init__(self, audio_vae: nn.Module):
38
+ super().__init__()
39
+ self.audio_vae = audio_vae
40
+
41
+ def forward(self, latent: torch.Tensor):
42
+ return self.audio_vae.decode(latent)
43
+
44
+
45
+ class LocEncWrapper(nn.Module):
46
+ def __init__(self, locenc: nn.Module):
47
+ super().__init__()
48
+ self.locenc = locenc
49
+
50
+ def forward(self, x: torch.Tensor):
51
+ # x: [B, T, P, D]
52
+ return self.locenc(x)
53
+
54
+
55
+ class LocEncLmWrapper(nn.Module):
56
+ """LocEnc with enc_to_lm projection fused in a single graph."""
57
+
58
+ def __init__(self, locenc: nn.Module, proj: nn.Module):
59
+ super().__init__()
60
+ self.locenc = locenc
61
+ self.proj = proj
62
+
63
+ def forward(self, x: torch.Tensor):
64
+ # x: [B, T, P, D]
65
+ hidden = self.locenc(x)
66
+ return self.proj(hidden)
67
+
68
+
69
+ class FSQWrapper(nn.Module):
70
+ def __init__(self, fsq: nn.Module):
71
+ super().__init__()
72
+ self.fsq = fsq
73
+
74
+ def forward(self, hidden: torch.Tensor):
75
+ return self.fsq(hidden)
76
+
77
+
78
+ class StopHeadWrapper(nn.Module):
79
+ def __init__(self, stop_proj: nn.Linear, stop_actn: nn.Module, stop_head: nn.Linear):
80
+ super().__init__()
81
+ self.stop_proj = stop_proj
82
+ self.stop_actn = stop_actn
83
+ self.stop_head = stop_head
84
+
85
+ def forward(self, hidden: torch.Tensor):
86
+ hidden = self.stop_proj(hidden)
87
+ hidden = self.stop_actn(hidden)
88
+ return self.stop_head(hidden)
89
+
90
+
91
+ class CFMWrapper(nn.Module):
92
+ """
93
+ Wrapper for one diffusion step block.
94
+
95
+ Note: the number of diffusion steps (n_timesteps) is fixed at export time.
96
+ """
97
+
98
+ def __init__(self, cfm: nn.Module, patch_size: int, n_timesteps: int, cfg_value: float):
99
+ super().__init__()
100
+ self.cfm = cfm
101
+ self.patch_size = patch_size
102
+ self.n_timesteps = n_timesteps
103
+ self.cfg_value = cfg_value
104
+
105
+ def forward(self, mu: torch.Tensor, cond: torch.Tensor):
106
+ # mu: [B, H_dit], cond: [B, D_feat, P]
107
+ return self.cfm(
108
+ mu=mu,
109
+ n_timesteps=self.n_timesteps,
110
+ patch_size=self.patch_size,
111
+ cond=cond,
112
+ cfg_value=self.cfg_value,
113
+ )
114
+
115
+
116
+ class DiTStepWrapper(nn.Module):
117
+ """
118
+ Wrapper for a single VoxCPMLocDiT forward (one diffusion score estimation step).
119
+ Inputs match VoxCPMLocDiT.forward: x, mu, t, cond, dt.
120
+ """
121
+
122
+ def __init__(self, dit: nn.Module):
123
+ super().__init__()
124
+ self.dit = dit
125
+
126
+ def forward(self, x: torch.Tensor, mu: torch.Tensor, t: torch.Tensor, cond: torch.Tensor, dt: torch.Tensor):
127
+ return self.dit(x, mu, t, cond, dt)
128
+
129
+
130
+ def export(model: nn.Module, inputs, path: str, dynamic_axes: dict, opset: int):
131
+ os.makedirs(os.path.dirname(path), exist_ok=True)
132
+ torch.onnx.export(
133
+ model,
134
+ inputs,
135
+ path,
136
+ opset_version=opset,
137
+ dynamo=True,
138
+ do_constant_folding=True,
139
+ input_names=list(dynamic_axes.keys()),
140
+ output_names=["output"],
141
+ dynamic_axes=dynamic_axes,
142
+ )
143
+ print(f"Saved: {path}")
144
+
145
+
146
+ def main():
147
+ parser = argparse.ArgumentParser(description="Export VoxCPM submodules to ONNX (LLM excluded).")
148
+ parser.add_argument("--model-dir", required=True, help="Path to VoxCPM model directory (config/weights).")
149
+ parser.add_argument("--out-dir", default="onnx_exports", help="Output directory for ONNX files.")
150
+ parser.add_argument("--opset", type=int, default=18, help="ONNX opset version.")
151
+ parser.add_argument("--audio-samples", type=int, default=1280, help="Dummy audio length for encoder export.")
152
+ parser.add_argument("--latent-steps", type=int, default=6, help="Dummy latent steps for decoder export.")
153
+ parser.add_argument("--seq-len", type=int, default=4, help="Dummy sequence length for LocEnc/FSQ export.")
154
+ parser.add_argument("--dit-step-t", type=float, default=0.5, help="Dummy diffusion time for DiT step export.")
155
+ parser.add_argument("--force-fp32", action="store_true", help="Force submodules to float32 for ONNX export.")
156
+ parser.add_argument("--dump-embeddings", action="store_true", help="Dump base_lm.embed_tokens weights to npy.")
157
+ args = parser.parse_args()
158
+
159
+ device = torch.device("cpu")
160
+ # Load full model once, then peel submodules; keep optimize disabled.
161
+ full_model = VoxCPMModel.from_local(args.model_dir, optimize=False).to(device).eval()
162
+ if args.force_fp32 or full_model.config.dtype != "float32":
163
+ full_model.config.dtype = "float32"
164
+ full_model = full_model.to(torch.float32)
165
+ full_model.audio_vae = full_model.audio_vae.to(torch.float32)
166
+ remove_weight_norm(full_model)
167
+
168
+ # Audio VAE encode
169
+ vae_enc = VAEEncodeWrapper(full_model.audio_vae).to(device).eval()
170
+ dummy_audio = torch.randn(1, 1, args.audio_samples, device=device)
171
+ export(
172
+ vae_enc,
173
+ dummy_audio,
174
+ os.path.join(args.out_dir, "audio_vae_encode.onnx"),
175
+ dynamic_axes={"audio_wave": {0: "batch", 2: "samples"}},
176
+ opset=args.opset,
177
+ )
178
+
179
+ # Audio VAE decode
180
+ vae_dec = VAEDecodeWrapper(full_model.audio_vae).to(device).eval()
181
+ dummy_latent = torch.randn(1, full_model.audio_vae.latent_dim, args.latent_steps, device=device)
182
+ export(
183
+ vae_dec,
184
+ dummy_latent,
185
+ os.path.join(args.out_dir, "audio_vae_decode.onnx"),
186
+ dynamic_axes={"latent": {0: "batch", 2: "latent_steps"}},
187
+ opset=args.opset,
188
+ )
189
+
190
+ # LocEnc with enc_to_lm projection fused
191
+ locenc = LocEncLmWrapper(full_model.feat_encoder, full_model.enc_to_lm_proj).to(device).eval()
192
+ dummy_seq = torch.randn(1, args.seq_len, full_model.patch_size, full_model.feat_dim, device=device)
193
+ export(
194
+ locenc,
195
+ dummy_seq,
196
+ os.path.join(args.out_dir, "locenc.onnx"),
197
+ dynamic_axes={"x": {0: "batch", 1: "seq_len"}},
198
+ opset=args.opset,
199
+ )
200
+
201
+ # FSQ layer
202
+ fsq = FSQWrapper(full_model.fsq_layer).to(device).eval()
203
+ hidden_size = full_model.config.lm_config.hidden_size
204
+ dummy_hidden = torch.randn(1, args.seq_len, hidden_size, device=device)
205
+ export(
206
+ fsq,
207
+ dummy_hidden,
208
+ os.path.join(args.out_dir, "fsq_layer.onnx"),
209
+ dynamic_axes={"hidden": {0: "batch", 1: "seq_len"}},
210
+ opset=args.opset,
211
+ )
212
+
213
+ # Stop head
214
+ stop = StopHeadWrapper(full_model.stop_proj, full_model.stop_actn, full_model.stop_head).to(device).eval()
215
+ dummy_stop_inp = torch.randn(1, hidden_size, device=device)
216
+ export(
217
+ stop,
218
+ dummy_stop_inp,
219
+ os.path.join(args.out_dir, "stop_head.onnx"),
220
+ dynamic_axes={"hidden": {0: "batch"}},
221
+ opset=args.opset,
222
+ )
223
+
224
+ # Projection layers
225
+ # export(
226
+ # full_model.enc_to_lm_proj,
227
+ # dummy_hidden,
228
+ # os.path.join(args.out_dir, "enc_to_lm_proj.onnx"),
229
+ # dynamic_axes={"input": {0: "batch", 1: "seq_len"}},
230
+ # opset=args.opset,
231
+ # )
232
+ lm_hidden = torch.randn(1, full_model.config.lm_config.hidden_size, device=device)
233
+ export(
234
+ full_model.lm_to_dit_proj,
235
+ lm_hidden,
236
+ os.path.join(args.out_dir, "lm_to_dit_proj.onnx"),
237
+ dynamic_axes={"input": {0: "batch"}},
238
+ opset=args.opset,
239
+ )
240
+ export(
241
+ full_model.res_to_dit_proj,
242
+ lm_hidden,
243
+ os.path.join(args.out_dir, "res_to_dit_proj.onnx"),
244
+ dynamic_axes={"input": {0: "batch"}},
245
+ opset=args.opset,
246
+ )
247
+
248
+ # VoxCPMLocDiT single step (score function)
249
+ dit_step = DiTStepWrapper(full_model.feat_decoder.estimator).to(device).eval()
250
+ dummy_x = torch.randn(1, full_model.feat_dim, full_model.patch_size, device=device)
251
+ dummy_mu = torch.randn(1, full_model.config.dit_config.hidden_dim, device=device)
252
+ dummy_t = torch.full((1,), args.dit_step_t, device=device)
253
+ dummy_dt = torch.full((1,), 0.0, device=device)
254
+ dummy_cond = torch.randn(1, full_model.feat_dim, full_model.patch_size, device=device)
255
+ export(
256
+ dit_step,
257
+ (dummy_x, dummy_mu, dummy_t, dummy_cond, dummy_dt),
258
+ os.path.join(args.out_dir, "dit_step.onnx"),
259
+ dynamic_axes={
260
+ "x": {0: "batch"},
261
+ "mu": {0: "batch"},
262
+ "t": {0: "batch"},
263
+ "cond": {0: "batch"},
264
+ "dt": {0: "batch"},
265
+ },
266
+ opset=args.opset,
267
+ )
268
+
269
+ # # UnifiedCFM + VoxCPMLocDiT (single-step sampler unrolled with fixed n_timesteps)
270
+ # cfm = CFMWrapper(
271
+ # full_model.feat_decoder,
272
+ # patch_size=full_model.patch_size,
273
+ # n_timesteps=args.cfm_steps,
274
+ # cfg_value=args.cfg_value,
275
+ # ).to(device).eval()
276
+ # dummy_mu = torch.randn(1, full_model.config.dit_config.hidden_dim, device=device)
277
+ # dummy_cond = torch.randn(1, full_model.feat_dim, full_model.patch_size, device=device)
278
+ # export(
279
+ # cfm,
280
+ # (dummy_mu, dummy_cond),
281
+ # os.path.join(args.out_dir, "cfm_step.onnx"),
282
+ # dynamic_axes={"mu": {0: "batch"}, "cond": {0: "batch"}},
283
+ # opset=args.opset,
284
+ # )
285
+
286
+ if args.dump_embeddings and hasattr(full_model.base_lm, "embed_tokens"):
287
+ import numpy as np
288
+ emb = full_model.base_lm.embed_tokens.weight.detach().cpu().numpy()
289
+ os.makedirs(args.out_dir, exist_ok=True)
290
+ np.save(os.path.join(args.out_dir, "embed_tokens.npy"), emb)
291
+ print(f"Saved: {os.path.join(args.out_dir, 'embed_tokens.npy')}")
292
+
293
+ print("Done.")
294
+
295
+
296
+ if __name__ == "__main__":
297
+ main()
convert/scripts/export_rkllm.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ from pathlib import Path
4
+
5
+ from rkllm.api import RKLLM
6
+
7
+
8
+ def export_rkllm(
9
+ model_dir: Path,
10
+ output_path: Path,
11
+ target_platform: str,
12
+ num_npu_core: int,
13
+ optimization_level: int,
14
+ ):
15
+ llm = RKLLM()
16
+ ret = llm.load_huggingface(model=str(model_dir), model_lora=None, device="cpu")
17
+ if ret != 0:
18
+ raise RuntimeError(f"load_huggingface failed for {model_dir}, ret={ret}")
19
+
20
+ ret = llm.build(
21
+ do_quantization=False,
22
+ optimization_level=optimization_level,
23
+ quantized_dtype="w8a8",
24
+ quantized_algorithm="normal",
25
+ target_platform=target_platform,
26
+ num_npu_core=num_npu_core,
27
+ extra_qparams=None,
28
+ )
29
+ if ret != 0:
30
+ raise RuntimeError(f"RKLLM build failed for {model_dir}, ret={ret}")
31
+
32
+ output_path.parent.mkdir(parents=True, exist_ok=True)
33
+ ret = llm.export_rkllm(str(output_path))
34
+ if ret != 0:
35
+ raise RuntimeError(f"export_rkllm failed for {output_path}, ret={ret}")
36
+
37
+
38
+ def main():
39
+ parser = argparse.ArgumentParser(description="Export a HuggingFace-format MiniCPM model to RKLLM.")
40
+ parser.add_argument("--model-dir", required=True, help="Input HuggingFace model directory.")
41
+ parser.add_argument("--output", required=True, help="Output .rkllm path.")
42
+ parser.add_argument("--target-platform", default="rk3588", help="RK target platform.")
43
+ parser.add_argument("--num-npu-core", type=int, default=1, help="NPU cores for RKLLM build.")
44
+ parser.add_argument("--optimization-level", type=int, default=1, help="RKLLM optimization level.")
45
+ parser.add_argument("--hf-home", default=None, help="Optional writable Hugging Face cache root.")
46
+ args = parser.parse_args()
47
+
48
+ if args.hf_home:
49
+ hf_home = str(Path(args.hf_home).resolve())
50
+ os.environ["HF_HOME"] = hf_home
51
+ os.environ["HUGGINGFACE_HUB_CACHE"] = str(Path(hf_home) / "hub")
52
+ os.environ["TRANSFORMERS_CACHE"] = str(Path(hf_home) / "transformers")
53
+
54
+ export_rkllm(
55
+ model_dir=Path(args.model_dir),
56
+ output_path=Path(args.output),
57
+ target_platform=args.target_platform,
58
+ num_npu_core=args.num_npu_core,
59
+ optimization_level=args.optimization_level,
60
+ )
61
+ print(f"Saved: {args.output}")
62
+
63
+
64
+ if __name__ == "__main__":
65
+ main()
convert/src/voxcpm/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .core import VoxCPM
2
+
3
+ __all__ = [
4
+ "VoxCPM",
5
+ ]
convert/src/voxcpm/cli.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ VoxCPM Command Line Interface
4
+
5
+ Unified CLI for voice cloning, direct TTS synthesis, and batch processing.
6
+
7
+ Usage examples:
8
+ # Direct synthesis (single sample)
9
+ voxcpm --text "Hello world" --output output.wav
10
+
11
+ # Voice cloning (with reference audio and text)
12
+ voxcpm --text "Hello world" --prompt-audio voice.wav --prompt-text "reference text" --output output.wav --denoise
13
+
14
+ # Batch processing (each line in the file is one sample)
15
+ voxcpm --input texts.txt --output-dir ./outputs/
16
+ """
17
+
18
+ import argparse
19
+ import os
20
+ import sys
21
+ from pathlib import Path
22
+ from typing import Optional, List
23
+ import soundfile as sf
24
+
25
+ from voxcpm.core import VoxCPM
26
+
27
+
28
+ def validate_file_exists(file_path: str, file_type: str = "file") -> Path:
29
+ """Validate that a file exists."""
30
+ path = Path(file_path)
31
+ if not path.exists():
32
+ raise FileNotFoundError(f"{file_type} '{file_path}' does not exist")
33
+ return path
34
+
35
+
36
+ def validate_output_path(output_path: str) -> Path:
37
+ """Validate the output path and create parent directories if needed."""
38
+ path = Path(output_path)
39
+ path.parent.mkdir(parents=True, exist_ok=True)
40
+ return path
41
+
42
+
43
+ def load_model(args) -> VoxCPM:
44
+ """Load VoxCPM model.
45
+
46
+ Prefer --model-path if provided; otherwise use from_pretrained (Hub).
47
+ """
48
+ print("Loading VoxCPM model...")
49
+
50
+ # 兼容旧参数:ZIPENHANCER_MODEL_PATH 环境变量作为默认
51
+ zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get(
52
+ "ZIPENHANCER_MODEL_PATH", None
53
+ )
54
+
55
+ # Load from local path if provided
56
+ if getattr(args, "model_path", None):
57
+ try:
58
+ model = VoxCPM(
59
+ voxcpm_model_path=args.model_path,
60
+ zipenhancer_model_path=zipenhancer_path,
61
+ enable_denoiser=not getattr(args, "no_denoiser", False),
62
+ )
63
+ print("Model loaded (local).")
64
+ return model
65
+ except Exception as e:
66
+ print(f"Failed to load model (local): {e}")
67
+ sys.exit(1)
68
+
69
+ # Otherwise, try from_pretrained (Hub); exit on failure
70
+ try:
71
+ model = VoxCPM.from_pretrained(
72
+ hf_model_id=getattr(args, "hf_model_id", "openbmb/VoxCPM-0.5B"),
73
+ load_denoiser=not getattr(args, "no_denoiser", False),
74
+ zipenhancer_model_id=zipenhancer_path,
75
+ cache_dir=getattr(args, "cache_dir", None),
76
+ local_files_only=getattr(args, "local_files_only", False),
77
+ )
78
+ print("Model loaded (from_pretrained).")
79
+ return model
80
+ except Exception as e:
81
+ print(f"Failed to load model (from_pretrained): {e}")
82
+ sys.exit(1)
83
+
84
+
85
+ def cmd_clone(args):
86
+ """Voice cloning command."""
87
+ # Validate inputs
88
+ if not args.text:
89
+ print("Error: Please provide text to synthesize (--text)")
90
+ sys.exit(1)
91
+
92
+ if not args.prompt_audio:
93
+ print("Error: Voice cloning requires a reference audio (--prompt-audio)")
94
+ sys.exit(1)
95
+
96
+ if not args.prompt_text:
97
+ print("Error: Voice cloning requires a reference text (--prompt-text)")
98
+ sys.exit(1)
99
+
100
+ # Validate files
101
+ prompt_audio_path = validate_file_exists(args.prompt_audio, "reference audio file")
102
+ output_path = validate_output_path(args.output)
103
+
104
+ # Load model
105
+ model = load_model(args)
106
+
107
+ # Generate audio
108
+ print(f"Synthesizing text: {args.text}")
109
+ print(f"Reference audio: {prompt_audio_path}")
110
+ print(f"Reference text: {args.prompt_text}")
111
+
112
+ audio_array = model.generate(
113
+ text=args.text,
114
+ prompt_wav_path=str(prompt_audio_path),
115
+ prompt_text=args.prompt_text,
116
+ cfg_value=args.cfg_value,
117
+ inference_timesteps=args.inference_timesteps,
118
+ normalize=args.normalize,
119
+ denoise=args.denoise
120
+ )
121
+
122
+ # Save audio
123
+ sf.write(str(output_path), audio_array, 16000)
124
+ print(f"Saved audio to: {output_path}")
125
+
126
+ # Stats
127
+ duration = len(audio_array) / 16000
128
+ print(f"Duration: {duration:.2f}s")
129
+
130
+
131
+ def cmd_synthesize(args):
132
+ """Direct TTS synthesis command."""
133
+ # Validate inputs
134
+ if not args.text:
135
+ print("Error: Please provide text to synthesize (--text)")
136
+ sys.exit(1)
137
+ # Validate output path
138
+ output_path = validate_output_path(args.output)
139
+ # Load model
140
+ model = load_model(args)
141
+ # Generate audio
142
+ print(f"Synthesizing text: {args.text}")
143
+
144
+ audio_array = model.generate(
145
+ text=args.text,
146
+ prompt_wav_path=None,
147
+ prompt_text=None,
148
+ cfg_value=args.cfg_value,
149
+ inference_timesteps=args.inference_timesteps,
150
+ normalize=args.normalize,
151
+ denoise=False # 无参考音频时不需要降噪
152
+ )
153
+
154
+ # Save audio
155
+ sf.write(str(output_path), audio_array, 16000)
156
+ print(f"Saved audio to: {output_path}")
157
+
158
+ # Stats
159
+ duration = len(audio_array) / 16000
160
+ print(f"Duration: {duration:.2f}s")
161
+
162
+
163
+ def cmd_batch(args):
164
+ """Batch synthesis command."""
165
+ # Validate input file
166
+ input_file = validate_file_exists(args.input, "input file")
167
+ output_dir = Path(args.output_dir)
168
+ output_dir.mkdir(parents=True, exist_ok=True)
169
+
170
+ try:
171
+ with open(input_file, 'r', encoding='utf-8') as f:
172
+ texts = [line.strip() for line in f if line.strip()]
173
+ except Exception as e:
174
+ print(f"Failed to read input file: {e}")
175
+ sys.exit(1)
176
+ if not texts:
177
+ print("Error: Input file is empty or contains no valid lines")
178
+ sys.exit(1)
179
+ print(f"Found {len(texts)} lines to process")
180
+
181
+ model = load_model(args)
182
+ prompt_audio_path = None
183
+ if args.prompt_audio:
184
+ prompt_audio_path = str(validate_file_exists(args.prompt_audio, "reference audio file"))
185
+
186
+ success_count = 0
187
+ for i, text in enumerate(texts, 1):
188
+ print(f"\nProcessing {i}/{len(texts)}: {text[:50]}...")
189
+
190
+ try:
191
+ audio_array = model.generate(
192
+ text=text,
193
+ prompt_wav_path=prompt_audio_path,
194
+ prompt_text=args.prompt_text,
195
+ cfg_value=args.cfg_value,
196
+ inference_timesteps=args.inference_timesteps,
197
+ normalize=args.normalize,
198
+ denoise=args.denoise and prompt_audio_path is not None
199
+ )
200
+ output_file = output_dir / f"output_{i:03d}.wav"
201
+ sf.write(str(output_file), audio_array, 16000)
202
+
203
+ duration = len(audio_array) / 16000
204
+ print(f" Saved: {output_file} ({duration:.2f}s)")
205
+ success_count += 1
206
+
207
+ except Exception as e:
208
+ print(f" Failed: {e}")
209
+ continue
210
+
211
+ print(f"\nBatch finished: {success_count}/{len(texts)} succeeded")
212
+
213
+ def _build_unified_parser():
214
+ """Build unified argument parser (no subcommands, route by args)."""
215
+ parser = argparse.ArgumentParser(
216
+ description="VoxCPM CLI (single parser) - voice cloning, direct TTS, and batch processing",
217
+ formatter_class=argparse.RawDescriptionHelpFormatter,
218
+ epilog="""
219
+ Examples:
220
+ # Direct synthesis (single sample)
221
+ voxcpm --text "Hello world" --output out.wav
222
+
223
+ # Voice cloning (reference audio + text)
224
+ voxcpm --text "Hello world" --prompt-audio voice.wav --prompt-text "reference text" --output out.wav --denoise
225
+
226
+ # Batch processing
227
+ voxcpm --input texts.txt --output-dir ./outs
228
+
229
+ # Select model (from Hub)
230
+ voxcpm --text "Hello" --output out.wav --hf-model-id openbmb/VoxCPM-0.5B
231
+ """
232
+ )
233
+
234
+ # Task selection (automatic routing by presence of args)
235
+ parser.add_argument("--input", "-i", help="Input text file (one line per sample)")
236
+ parser.add_argument("--output-dir", "-od", help="Output directory (for batch mode)")
237
+ parser.add_argument("--text", "-t", help="Text to synthesize (single-sample mode)")
238
+ parser.add_argument("--output", "-o", help="Output audio file path (single-sample mode)")
239
+
240
+ # Prompt audio (for voice cloning)
241
+ parser.add_argument("--prompt-audio", "-pa", help="Reference audio file path")
242
+ parser.add_argument("--prompt-text", "-pt", help="Reference text corresponding to the audio")
243
+ parser.add_argument("--prompt-file", "-pf", help="Reference text file corresponding to the audio")
244
+ parser.add_argument("--denoise", action="store_true", help="Enable prompt speech enhancement (denoising)")
245
+
246
+ # Generation parameters
247
+ parser.add_argument("--cfg-value", type=float, default=2.0, help="CFG guidance scale (default: 2.0)")
248
+ parser.add_argument("--inference-timesteps", type=int, default=10, help="Inference steps (default: 10)")
249
+ parser.add_argument("--normalize", action="store_true", help="Enable text normalization")
250
+
251
+ # Model loading parameters
252
+ parser.add_argument("--model-path", type=str, help="Local VoxCPM model path (overrides Hub download)")
253
+ parser.add_argument("--hf-model-id", type=str, default="openbmb/VoxCPM-0.5B", help="Hugging Face repo id (e.g., openbmb/VoxCPM-0.5B)")
254
+ parser.add_argument("--cache-dir", type=str, help="Cache directory for Hub downloads")
255
+ parser.add_argument("--local-files-only", action="store_true", help="Use only local files (no network)")
256
+ parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading")
257
+ parser.add_argument("--zipenhancer-path", type=str, default="iic/speech_zipenhancer_ans_multiloss_16k_base", help="ZipEnhancer model id or local path (default reads from env)")
258
+
259
+ return parser
260
+
261
+
262
+ def main():
263
+ """Unified CLI entrypoint: route by provided arguments."""
264
+ parser = _build_unified_parser()
265
+ args = parser.parse_args()
266
+
267
+ # Routing: prefer batch → single (clone/direct)
268
+ if args.input:
269
+ if not args.output_dir:
270
+ print("Error: Batch mode requires --output-dir")
271
+ parser.print_help()
272
+ sys.exit(1)
273
+ return cmd_batch(args)
274
+
275
+ # Single-sample mode
276
+ if not args.text or not args.output:
277
+ print("Error: Single-sample mode requires --text and --output")
278
+ parser.print_help()
279
+ sys.exit(1)
280
+
281
+ # If prompt audio+text provided → voice cloning
282
+ if args.prompt_audio or args.prompt_text:
283
+ if not args.prompt_text and args.prompt_file:
284
+ assert os.path.isfile(args.prompt_file), "Prompt file does not exist or is not accessible."
285
+
286
+ with open(args.prompt_file, 'r', encoding='utf-8') as f:
287
+ args.prompt_text = f.read()
288
+
289
+ if not args.prompt_audio or not args.prompt_text:
290
+ print("Error: Voice cloning requires both --prompt-audio and --prompt-text")
291
+ sys.exit(1)
292
+ return cmd_clone(args)
293
+
294
+ # Otherwise → direct synthesis
295
+ return cmd_synthesize(args)
296
+
297
+
298
+ if __name__ == "__main__":
299
+ main()
convert/src/voxcpm/core.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import tempfile
4
+ import numpy as np
5
+ from typing import Generator
6
+ from huggingface_hub import snapshot_download
7
+ from .model.voxcpm import VoxCPMModel
8
+
9
+ class VoxCPM:
10
+ def __init__(self,
11
+ voxcpm_model_path : str,
12
+ zipenhancer_model_path : str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
13
+ enable_denoiser : bool = True,
14
+ optimize: bool = True,
15
+ ):
16
+ """Initialize VoxCPM TTS pipeline.
17
+
18
+ Args:
19
+ voxcpm_model_path: Local filesystem path to the VoxCPM model assets
20
+ (weights, configs, etc.). Typically the directory returned by
21
+ a prior download step.
22
+ zipenhancer_model_path: ModelScope acoustic noise suppression model
23
+ id or local path. If None, denoiser will not be initialized.
24
+ enable_denoiser: Whether to initialize the denoiser pipeline.
25
+ optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
26
+ """
27
+ print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}")
28
+ self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize)
29
+ self.text_normalizer = None
30
+ if enable_denoiser and zipenhancer_model_path is not None:
31
+ from .zipenhancer import ZipEnhancer
32
+ self.denoiser = ZipEnhancer(zipenhancer_model_path)
33
+ else:
34
+ self.denoiser = None
35
+ print("Warm up VoxCPMModel...")
36
+ self.tts_model.generate(
37
+ target_text="Hello, this is the first test sentence.",
38
+ max_len=10,
39
+ )
40
+
41
+ @classmethod
42
+ def from_pretrained(cls,
43
+ hf_model_id: str = "openbmb/VoxCPM-0.5B",
44
+ load_denoiser: bool = True,
45
+ zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
46
+ cache_dir: str = None,
47
+ local_files_only: bool = False,
48
+ **kwargs,
49
+ ):
50
+ """Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.
51
+
52
+ Args:
53
+ hf_model_id: Explicit Hugging Face repository id (e.g. "org/repo") or local path.
54
+ load_denoiser: Whether to initialize the denoiser pipeline.
55
+ zipenhancer_model_id: Denoiser model id or path for ModelScope
56
+ acoustic noise suppression.
57
+ cache_dir: Custom cache directory for the snapshot.
58
+ local_files_only: If True, only use local files and do not attempt
59
+ to download.
60
+ Kwargs:
61
+ Additional keyword arguments passed to the ``VoxCPM`` constructor.
62
+
63
+ Returns:
64
+ VoxCPM: Initialized instance whose ``voxcpm_model_path`` points to
65
+ the downloaded snapshot directory.
66
+
67
+ Raises:
68
+ ValueError: If neither a valid ``hf_model_id`` nor a resolvable
69
+ ``hf_model_id`` is provided.
70
+ """
71
+ repo_id = hf_model_id
72
+ if not repo_id:
73
+ raise ValueError("You must provide hf_model_id")
74
+
75
+ # Load from local path if provided
76
+ if os.path.isdir(repo_id):
77
+ local_path = repo_id
78
+ else:
79
+ # Otherwise, try from_pretrained (Hub); exit on failure
80
+ local_path = snapshot_download(
81
+ repo_id=repo_id,
82
+ cache_dir=cache_dir,
83
+ local_files_only=local_files_only,
84
+ )
85
+
86
+ return cls(
87
+ voxcpm_model_path=local_path,
88
+ zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
89
+ enable_denoiser=load_denoiser,
90
+ **kwargs,
91
+ )
92
+
93
+ def generate(self, *args, **kwargs) -> np.ndarray:
94
+ return next(self._generate(*args, streaming=False, **kwargs))
95
+
96
+ def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]:
97
+ return self._generate(*args, streaming=True, **kwargs)
98
+
99
+ def _generate(self,
100
+ text : str,
101
+ prompt_wav_path : str = None,
102
+ prompt_text : str = None,
103
+ cfg_value : float = 2.0,
104
+ inference_timesteps : int = 10,
105
+ max_length : int = 4096,
106
+ normalize : bool = True,
107
+ denoise : bool = True,
108
+ retry_badcase : bool = True,
109
+ retry_badcase_max_times : int = 3,
110
+ retry_badcase_ratio_threshold : float = 6.0,
111
+ streaming: bool = False,
112
+ ) -> Generator[np.ndarray, None, None]:
113
+ """Synthesize speech for the given text and return a single waveform.
114
+
115
+ This method optionally builds and reuses a prompt cache. If an external
116
+ prompt (``prompt_wav_path`` + ``prompt_text``) is provided, it will be
117
+ used for all sub-sentences. Otherwise, the prompt cache is built from
118
+ the first generated result and reused for the remaining text chunks.
119
+
120
+ Args:
121
+ text: Input text. Can include newlines; each non-empty line is
122
+ treated as a sub-sentence.
123
+ prompt_wav_path: Path to a reference audio file for prompting.
124
+ prompt_text: Text content corresponding to the prompt audio.
125
+ cfg_value: Guidance scale for the generation model.
126
+ inference_timesteps: Number of inference steps.
127
+ max_length: Maximum token length during generation.
128
+ normalize: Whether to run text normalization before generation.
129
+ denoise: Whether to denoise the prompt audio if a denoiser is
130
+ available.
131
+ retry_badcase: Whether to retry badcase.
132
+ retry_badcase_max_times: Maximum number of times to retry badcase.
133
+ retry_badcase_ratio_threshold: Threshold for audio-to-text ratio.
134
+ streaming: Whether to return a generator of audio chunks.
135
+ Returns:
136
+ Generator of numpy.ndarray: 1D waveform array (float32) on CPU.
137
+ Yields audio chunks for each generations step if ``streaming=True``,
138
+ otherwise yields a single array containing the final audio.
139
+ """
140
+ if not text.strip() or not isinstance(text, str):
141
+ raise ValueError("target text must be a non-empty string")
142
+
143
+ if prompt_wav_path is not None:
144
+ if not os.path.exists(prompt_wav_path):
145
+ raise FileNotFoundError(f"prompt_wav_path does not exist: {prompt_wav_path}")
146
+
147
+ if (prompt_wav_path is None) != (prompt_text is None):
148
+ raise ValueError("prompt_wav_path and prompt_text must both be provided or both be None")
149
+
150
+ text = text.replace("\n", " ")
151
+ text = re.sub(r'\s+', ' ', text)
152
+ temp_prompt_wav_path = None
153
+
154
+ try:
155
+ if prompt_wav_path is not None and prompt_text is not None:
156
+ if denoise and self.denoiser is not None:
157
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
158
+ temp_prompt_wav_path = tmp_file.name
159
+ self.denoiser.enhance(prompt_wav_path, output_path=temp_prompt_wav_path)
160
+ prompt_wav_path = temp_prompt_wav_path
161
+ fixed_prompt_cache = self.tts_model.build_prompt_cache(
162
+ prompt_wav_path=prompt_wav_path,
163
+ prompt_text=prompt_text
164
+ )
165
+ else:
166
+ fixed_prompt_cache = None # will be built from the first inference
167
+
168
+ if normalize:
169
+ if self.text_normalizer is None:
170
+ from .utils.text_normalize import TextNormalizer
171
+ self.text_normalizer = TextNormalizer()
172
+ text = self.text_normalizer.normalize(text)
173
+
174
+ generate_result = self.tts_model._generate_with_prompt_cache(
175
+ target_text=text,
176
+ prompt_cache=fixed_prompt_cache,
177
+ min_len=2,
178
+ max_len=max_length,
179
+ inference_timesteps=inference_timesteps,
180
+ cfg_value=cfg_value,
181
+ retry_badcase=retry_badcase,
182
+ retry_badcase_max_times=retry_badcase_max_times,
183
+ retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
184
+ streaming=streaming,
185
+ )
186
+
187
+ for wav, _, _ in generate_result:
188
+ yield wav.squeeze(0).cpu().numpy()
189
+
190
+ finally:
191
+ if temp_prompt_wav_path and os.path.exists(temp_prompt_wav_path):
192
+ try:
193
+ os.unlink(temp_prompt_wav_path)
194
+ except OSError:
195
+ pass
convert/src/voxcpm/model/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .voxcpm import VoxCPMModel
2
+
3
+ __all__ = ["VoxCPMModel"]
convert/src/voxcpm/model/utils.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ import torch
3
+ from transformers import PreTrainedTokenizer
4
+
5
+
6
+ def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
7
+ """Create a tokenizer wrapper that converts multi-character Chinese tokens to single characters.
8
+
9
+ This function creates a wrapper around the provided tokenizer that automatically
10
+ splits multi-character Chinese tokens into individual characters. This is useful
11
+ for ensuring consistent tokenization of Chinese text.
12
+
13
+ Args:
14
+ tokenizer: The base tokenizer to wrap
15
+
16
+ Returns:
17
+ A CharTokenizerWrapper instance that handles multi-character Chinese tokens
18
+
19
+ Example:
20
+ >>> from transformers import LlamaTokenizerFast
21
+ >>> tokenizer = LlamaTokenizerFast.from_pretrained("path/to/tokenizer")
22
+ >>> wrapped_tokenizer = mask_multichar_chinese_tokens(tokenizer)
23
+ >>> tokens = wrapped_tokenizer("你好世界")
24
+ """
25
+ # Pre-compute multi-character tokens (length >= 2, pure Chinese characters)
26
+ multichar_tokens = {
27
+ token for token in tokenizer.vocab.keys()
28
+ if len(token) >= 2 and all("\u4e00" <= c <= "\u9fff" for c in token)
29
+ }
30
+
31
+ class CharTokenizerWrapper:
32
+ """Wrapper class for tokenizers that handles multi-character Chinese tokens.
33
+
34
+ This wrapper automatically splits multi-character Chinese tokens into
35
+ individual characters while preserving the original tokenizer's interface.
36
+ """
37
+
38
+ def __init__(self, base_tokenizer: PreTrainedTokenizer) -> None:
39
+ """Initialize the wrapper with a base tokenizer.
40
+
41
+ Args:
42
+ base_tokenizer: The tokenizer to wrap
43
+ """
44
+ self.tokenizer = base_tokenizer
45
+ self.multichar_tokens = multichar_tokens
46
+
47
+ def tokenize(self, text: str, **kwargs) -> List[str]:
48
+ """Tokenize text and split multi-character Chinese tokens into single characters.
49
+
50
+ Args:
51
+ text: Input text to tokenize
52
+ **kwargs: Additional arguments passed to the base tokenizer
53
+
54
+ Returns:
55
+ List of processed tokens with multi-character Chinese tokens split
56
+
57
+ Example:
58
+ >>> wrapper = CharTokenizerWrapper(tokenizer)
59
+ >>> tokens = wrapper.tokenize("你好世界")
60
+ >>> # Returns ["你", "好", "世", "界"] instead of ["你好", "世界"]
61
+ """
62
+ if not isinstance(text, str):
63
+ raise TypeError(f"Expected string input, got {type(text)}")
64
+
65
+ tokens = self.tokenizer.tokenize(text, **kwargs)
66
+ processed = []
67
+
68
+ for token in tokens:
69
+ # Remove possible subword prefix
70
+ clean_token = token.replace("▁", "")
71
+
72
+ if clean_token in self.multichar_tokens:
73
+ # Split multi-character token into single characters
74
+ chars = list(clean_token)
75
+ processed.extend(chars)
76
+ else:
77
+ processed.append(token)
78
+
79
+ return processed
80
+
81
+ def __call__(self, text: str, **kwargs) -> List[int]:
82
+ """Call the tokenizer and return token IDs.
83
+
84
+ This method provides the same interface as the original tokenizer
85
+ but with multi-character Chinese token handling.
86
+
87
+ Args:
88
+ text: Input text to tokenize
89
+ **kwargs: Additional arguments passed to the base tokenizer
90
+
91
+ Returns:
92
+ List of token IDs
93
+
94
+ Raises:
95
+ TypeError: If input is not a string
96
+ ValueError: If tokenization fails
97
+ """
98
+ try:
99
+ tokens = self.tokenize(text, **kwargs)
100
+ result = self.tokenizer.convert_tokens_to_ids(tokens)
101
+ return result
102
+ except Exception as e:
103
+ raise ValueError(f"Tokenization failed: {str(e)}") from e
104
+
105
+ return CharTokenizerWrapper(tokenizer)
106
+
107
+
108
+ def get_dtype(dtype: str):
109
+ if dtype == "bfloat16":
110
+ return torch.bfloat16
111
+ elif dtype == "bf16":
112
+ return torch.bfloat16
113
+ elif dtype == "float16":
114
+ return torch.float16
115
+ elif dtype == "fp16":
116
+ return torch.float16
117
+ elif dtype == "float32":
118
+ return torch.float32
119
+ elif dtype == "fp32":
120
+ return torch.float32
121
+ else:
122
+ raise ValueError(f"Unsupported dtype: {dtype}")
convert/src/voxcpm/model/voxcpm.py ADDED
@@ -0,0 +1,690 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ VoxCPM: A Tokenizer-free speech generation model
3
+
4
+ This module contains the main VoxCPM model implementation, including configuration classes
5
+ and the core VoxCPMModel for text-to-speech generation.
6
+
7
+ Copyright 2025 OpenBMB
8
+ Licensed under the Apache License, Version 2.0 (the "License");
9
+ you may not use this file except in compliance with the License.
10
+ You may obtain a copy of the License at
11
+
12
+ http://www.apache.org/licenses/LICENSE-2.0
13
+
14
+ Unless required by applicable law or agreed to in writing, software
15
+ distributed under the License is distributed on an "AS IS" BASIS,
16
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ See the License for the specific language governing permissions and
18
+ limitations under the License.
19
+ """
20
+
21
+ import os
22
+ from typing import Tuple, Union, Generator, List
23
+
24
+ import torch
25
+ import torch.nn as nn
26
+ import torchaudio
27
+ import warnings
28
+ from einops import rearrange
29
+ from pydantic import BaseModel
30
+ from tqdm import tqdm
31
+ from transformers import LlamaTokenizerFast
32
+
33
+ from ..modules.audiovae import AudioVAE
34
+ from ..modules.layers import ScalarQuantizationLayer
35
+ from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
36
+ from ..modules.locenc import VoxCPMLocEnc
37
+ from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
38
+ from .utils import get_dtype, mask_multichar_chinese_tokens
39
+
40
+
41
+ class VoxCPMEncoderConfig(BaseModel):
42
+ hidden_dim: int = 1024
43
+ ffn_dim: int = 4096
44
+ num_heads: int = 16
45
+ num_layers: int = 4
46
+ kv_channels: int = None
47
+
48
+
49
+ class VoxCPMDitConfig(BaseModel):
50
+ hidden_dim: int = 1024
51
+ ffn_dim: int = 4096
52
+ num_heads: int = 16
53
+ num_layers: int = 4
54
+ kv_channels: int = None
55
+
56
+ cfm_config: CfmConfig
57
+
58
+
59
+ class VoxCPMConfig(BaseModel):
60
+ lm_config: MiniCPM4Config
61
+ patch_size: int = 2
62
+ feat_dim: int = 64
63
+ residual_lm_num_layers: int = 6
64
+ scalar_quantization_latent_dim: int = 256
65
+ scalar_quantization_scale: int = 9
66
+
67
+ encoder_config: VoxCPMEncoderConfig
68
+ dit_config: VoxCPMDitConfig
69
+
70
+ max_length: int = 4096
71
+ device: str = "cuda"
72
+ dtype: str = "bfloat16"
73
+
74
+
75
+ class VoxCPMModel(nn.Module):
76
+ def __init__(
77
+ self,
78
+ config: VoxCPMConfig,
79
+ tokenizer: LlamaTokenizerFast,
80
+ audio_vae: AudioVAE,
81
+ ):
82
+ super().__init__()
83
+ self.config = config
84
+ self.feat_dim = config.feat_dim
85
+ self.patch_size = config.patch_size
86
+ self.device = config.device
87
+ if not torch.cuda.is_available():
88
+ if torch.backends.mps.is_available():
89
+ self.device = "mps"
90
+ else:
91
+ self.device = "cpu"
92
+ print(f"Running on device: {self.device}, dtype: {self.config.dtype}")
93
+
94
+ # Text-Semantic LM
95
+ self.base_lm = MiniCPMModel(config.lm_config)
96
+ self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))
97
+
98
+ self.text_tokenizer = mask_multichar_chinese_tokens(tokenizer)
99
+ self.audio_start_token = 101
100
+ self.audio_end_token = 102
101
+
102
+ # Residual Acoustic LM
103
+ residual_lm_config = config.lm_config.model_copy(deep=True)
104
+ residual_lm_config.num_hidden_layers = config.residual_lm_num_layers
105
+ residual_lm_config.vocab_size = 0
106
+ self.residual_lm = MiniCPMModel(residual_lm_config)
107
+ self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))
108
+
109
+ # Local Encoder
110
+ encoder_config = config.lm_config.model_copy(deep=True)
111
+ encoder_config.hidden_size = config.encoder_config.hidden_dim
112
+ encoder_config.intermediate_size = config.encoder_config.ffn_dim
113
+ encoder_config.num_attention_heads = config.encoder_config.num_heads
114
+ encoder_config.num_hidden_layers = config.encoder_config.num_layers
115
+ encoder_config.kv_channels = config.encoder_config.kv_channels
116
+ encoder_config.vocab_size = 0
117
+ self.feat_encoder = VoxCPMLocEnc(encoder_config, input_dim=config.feat_dim)
118
+
119
+ # Local DiT
120
+ decoder_config = config.lm_config.model_copy(deep=True)
121
+ decoder_config.hidden_size = config.dit_config.hidden_dim
122
+ decoder_config.intermediate_size = config.dit_config.ffn_dim
123
+ decoder_config.num_attention_heads = config.dit_config.num_heads
124
+ decoder_config.num_hidden_layers = config.dit_config.num_layers
125
+ decoder_config.kv_channels = config.dit_config.kv_channels
126
+ decoder_config.vocab_size = 0
127
+ self.feat_decoder = UnifiedCFM(
128
+ in_channels=config.feat_dim,
129
+ cfm_params=config.dit_config.cfm_config,
130
+ estimator=VoxCPMLocDiT(decoder_config, in_channels=config.feat_dim),
131
+ )
132
+
133
+ # Projection layers
134
+ self.fsq_layer = ScalarQuantizationLayer(
135
+ config.lm_config.hidden_size,
136
+ config.lm_config.hidden_size,
137
+ config.scalar_quantization_latent_dim,
138
+ config.scalar_quantization_scale
139
+ )
140
+ self.enc_to_lm_proj = nn.Linear(config.encoder_config.hidden_dim, config.lm_config.hidden_size)
141
+ self.lm_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
142
+ self.res_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
143
+
144
+ # Stop Predictor
145
+ self.stop_proj = nn.Linear(config.lm_config.hidden_size, config.lm_config.hidden_size)
146
+ self.stop_actn = nn.SiLU()
147
+ self.stop_head = nn.Linear(config.lm_config.hidden_size, 2, bias=False)
148
+
149
+ # Audio VAE
150
+ self.audio_vae = audio_vae
151
+ self.chunk_size = audio_vae.chunk_size
152
+ self.sample_rate = audio_vae.sample_rate
153
+
154
+
155
+ def optimize(self, disable: bool = False):
156
+ try:
157
+ if disable:
158
+ raise ValueError("Optimization disabled by user")
159
+ if self.device != "cuda":
160
+ raise ValueError("VoxCPMModel can only be optimized on CUDA device")
161
+ try:
162
+ import triton
163
+ except:
164
+ raise ValueError("triton is not installed")
165
+ self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
166
+ self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True)
167
+ self.feat_encoder_step = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
168
+ self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
169
+ except Exception as e:
170
+ print(f"Error: {e}")
171
+ print("Warning: VoxCPMModel can not be optimized by torch.compile, using original forward_step functions")
172
+ self.base_lm.forward_step = self.base_lm.forward_step
173
+ self.residual_lm.forward_step = self.residual_lm.forward_step
174
+ self.feat_encoder_step = self.feat_encoder
175
+ self.feat_decoder.estimator = self.feat_decoder.estimator
176
+ return self
177
+
178
+
179
+ def generate(self, *args, **kwargs) -> torch.Tensor:
180
+ return next(self._generate(*args, streaming=False, **kwargs))
181
+
182
+ def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
183
+ return self._generate(*args, streaming=True, **kwargs)
184
+
185
+ @torch.inference_mode()
186
+ def _generate(
187
+ self,
188
+ target_text: str,
189
+ prompt_text: str = "",
190
+ prompt_wav_path: str = "",
191
+ min_len: int = 2,
192
+ max_len: int = 2000,
193
+ inference_timesteps: int = 10,
194
+ cfg_value: float = 2.0,
195
+ retry_badcase: bool = False,
196
+ retry_badcase_max_times: int = 3,
197
+ retry_badcase_ratio_threshold: float = 6.0, # setting acceptable ratio of audio length to text length (for badcase detection)
198
+ streaming: bool = False,
199
+ ) -> Generator[torch.Tensor, None, None]:
200
+ if retry_badcase and streaming:
201
+ warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
202
+ retry_badcase = False
203
+ if len(prompt_wav_path) == 0:
204
+ text = target_text
205
+ text_token = torch.LongTensor(self.text_tokenizer(text))
206
+ text_token = torch.cat(
207
+ [
208
+ text_token,
209
+ torch.tensor(
210
+ [self.audio_start_token],
211
+ dtype=torch.int32,
212
+ device=text_token.device,
213
+ ),
214
+ ],
215
+ dim=-1,
216
+ )
217
+ text_length = text_token.shape[0]
218
+
219
+ audio_feat = torch.zeros(
220
+ (text_length, self.patch_size, self.audio_vae.latent_dim),
221
+ dtype=torch.float32,
222
+ device=text_token.device,
223
+ )
224
+ text_mask = torch.ones(text_length).type(torch.int32).to(text_token.device)
225
+ audio_mask = torch.zeros(text_length).type(torch.int32).to(text_token.device)
226
+
227
+ else:
228
+ text = prompt_text + target_text
229
+ text_token = torch.LongTensor(self.text_tokenizer(text))
230
+ text_token = torch.cat(
231
+ [
232
+ text_token,
233
+ torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device),
234
+ ],
235
+ dim=-1,
236
+ )
237
+ text_length = text_token.shape[0]
238
+
239
+ audio, sr = torchaudio.load(prompt_wav_path)
240
+ if audio.size(0) > 1:
241
+ audio = audio.mean(dim=0, keepdim=True)
242
+
243
+ if sr != self.sample_rate:
244
+ audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
245
+
246
+ patch_len = self.patch_size * self.chunk_size
247
+
248
+ if audio.size(1) % patch_len != 0:
249
+ audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len))
250
+
251
+ # (B, D, T)
252
+ audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()
253
+
254
+ audio_feat = audio_feat.view(
255
+ self.audio_vae.latent_dim,
256
+ -1,
257
+ self.patch_size,
258
+ ).permute(1, 2, 0)
259
+ audio_feat = audio_feat[:-1, ...] # trick: remove the last padding token
260
+ audio_length = audio_feat.size(0)
261
+ text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
262
+ text_token = torch.cat([text_token, text_pad_token])
263
+ audio_pad_feat = torch.zeros(
264
+ (text_length, self.patch_size, self.audio_vae.latent_dim),
265
+ dtype=torch.float32,
266
+ device=text_token.device,
267
+ )
268
+ audio_feat = torch.cat([audio_pad_feat, audio_feat], dim=0)
269
+ text_mask = (
270
+ torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
271
+ )
272
+ audio_mask = (
273
+ torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
274
+ )
275
+
276
+ text_token = text_token.unsqueeze(0).to(self.device)
277
+ text_mask = text_mask.unsqueeze(0).to(self.device)
278
+ audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
279
+ audio_mask = audio_mask.unsqueeze(0).to(self.device)
280
+
281
+ target_text_length = len(self.text_tokenizer(target_text))
282
+
283
+ retry_badcase_times = 0
284
+ while retry_badcase_times < retry_badcase_max_times:
285
+ inference_result = self._inference(
286
+ text_token,
287
+ text_mask,
288
+ audio_feat,
289
+ audio_mask,
290
+ min_len=min_len,
291
+ max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
292
+ inference_timesteps=inference_timesteps,
293
+ cfg_value=cfg_value,
294
+ streaming=streaming,
295
+ )
296
+ if streaming:
297
+ patch_len = self.patch_size * self.chunk_size
298
+ for latent_pred, _ in inference_result:
299
+ decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
300
+ decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
301
+ yield decode_audio
302
+ break
303
+ else:
304
+ latent_pred, pred_audio_feat = next(inference_result)
305
+ if retry_badcase:
306
+ if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
307
+ print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
308
+ retry_badcase_times += 1
309
+ continue
310
+ else:
311
+ break
312
+ else:
313
+ break
314
+
315
+ if not streaming:
316
+ decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
317
+ decode_audio = decode_audio[..., 640:-640] # trick: trim the start and end of the audio
318
+ yield decode_audio
319
+
320
+ @torch.inference_mode()
321
+ def build_prompt_cache(
322
+ self,
323
+ prompt_text: str,
324
+ prompt_wav_path: str,
325
+ ):
326
+ """
327
+ Build prompt cache for subsequent fast generation.
328
+
329
+ Args:
330
+ prompt_text: prompt text (required)
331
+ prompt_wav_path: prompt audio path (required)
332
+
333
+ Returns:
334
+ prompt_cache: dict with text tokens and audio features
335
+ """
336
+ if not prompt_text or not prompt_wav_path:
337
+ raise ValueError("prompt_text and prompt_wav_path are required")
338
+
339
+ # build text tokens
340
+ text_token = torch.LongTensor(self.text_tokenizer(prompt_text))
341
+
342
+ # load audio
343
+ audio, sr = torchaudio.load(prompt_wav_path)
344
+ if audio.size(0) > 1:
345
+ audio = audio.mean(dim=0, keepdim=True)
346
+
347
+ if sr != self.sample_rate:
348
+ audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
349
+
350
+ patch_len = self.patch_size * self.chunk_size
351
+
352
+ if audio.size(1) % patch_len != 0:
353
+ audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len))
354
+
355
+ # extract audio features
356
+ audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()
357
+
358
+ audio_feat = audio_feat.view(
359
+ self.audio_vae.latent_dim,
360
+ -1,
361
+ self.patch_size,
362
+ ).permute(1, 2, 0) # (D, T, P)
363
+ audio_feat = audio_feat[:-1, ...] # trick: remove the last padding token
364
+ # build prompt cache
365
+ prompt_cache = {
366
+ "text_token": text_token,
367
+ "audio_feat": audio_feat,
368
+ }
369
+
370
+ return prompt_cache
371
+
372
+
373
+ def merge_prompt_cache(
374
+ self,
375
+ original_cache: dict,
376
+ new_text_token: torch.Tensor,
377
+ new_audio_feat: torch.Tensor,
378
+ ):
379
+ """
380
+ Merge original prompt cache with newly generated content to stabilize voice.
381
+
382
+ Args:
383
+ original_cache: original prompt cache
384
+ new_text_token: newly generated text tokens
385
+ new_audio_feat: newly generated audio features
386
+
387
+ Returns:
388
+ merged_cache: merged cache
389
+ """
390
+ if original_cache is None:
391
+ return {
392
+ "text_token": new_text_token,
393
+ "audio_feat": new_audio_feat,
394
+ }
395
+ original_text_token = original_cache["text_token"]
396
+ original_audio_feat = original_cache["audio_feat"]
397
+ merged_text_token = torch.cat([original_text_token, new_text_token], dim=0)
398
+ merged_audio_feat = torch.cat([original_audio_feat, new_audio_feat], dim=0)
399
+
400
+ # build new cache
401
+ merged_cache = {
402
+ "text_token": merged_text_token,
403
+ "audio_feat": merged_audio_feat,
404
+ }
405
+
406
+ return merged_cache
407
+
408
+ def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
409
+ return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
410
+
411
+ def generate_with_prompt_cache_streaming(
412
+ self, *args, **kwargs
413
+ ) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]:
414
+ return self._generate_with_prompt_cache(*args, streaming=True, **kwargs)
415
+
416
+ @torch.inference_mode()
417
+ def _generate_with_prompt_cache(
418
+ self,
419
+ target_text: str,
420
+ prompt_cache: dict,
421
+ min_len: int = 2,
422
+ max_len: int = 2000,
423
+ inference_timesteps: int = 10,
424
+ cfg_value: float = 2.0,
425
+ retry_badcase: bool = False,
426
+ retry_badcase_max_times: int = 3,
427
+ retry_badcase_ratio_threshold: float = 6.0,
428
+ streaming: bool = False,
429
+ ) -> Generator[Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
430
+ """
431
+ Generate audio using pre-built prompt cache.
432
+
433
+ Args:
434
+ target_text: Text to convert to speech
435
+ prompt_cache: Cache built by build_prompt_cache (can be None)
436
+ min_len: Minimum audio length to avoid very short audio
437
+ max_len: Maximum audio length
438
+ inference_timesteps: Number of diffusion sampling steps
439
+ cfg_value: Classifier-free guidance value
440
+ retry_badcase: Whether to retry on bad cases
441
+ retry_badcase_max_times: Maximum retry attempts
442
+ retry_badcase_ratio_threshold: Threshold for audio-to-text ratio
443
+ streaming: Whether to return a generator of audio chunks
444
+
445
+ Returns:
446
+ Generator of Tuple containing:
447
+ - Decoded audio tensor for the current step if ``streaming=True``, else final decoded audio tensor
448
+ - Tensor of new text tokens
449
+ - New audio features up to the current step as a List if ``streaming=True``, else as a concatenated Tensor
450
+ """
451
+ if retry_badcase and streaming:
452
+ warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
453
+ retry_badcase = False
454
+ # get prompt from cache
455
+ if prompt_cache is None:
456
+ prompt_text_token = torch.empty(0, dtype=torch.int32)
457
+ prompt_audio_feat = torch.empty((0, self.patch_size, self.audio_vae.latent_dim), dtype=torch.float32)
458
+ else:
459
+ prompt_text_token = prompt_cache["text_token"]
460
+ prompt_audio_feat = prompt_cache["audio_feat"]
461
+ # build target text tokens
462
+ target_text_token = torch.LongTensor(self.text_tokenizer(target_text))
463
+ text_token = torch.cat([prompt_text_token, target_text_token], dim=0)
464
+ text_token = torch.cat(
465
+ [
466
+ text_token,
467
+ torch.tensor(
468
+ [self.audio_start_token],
469
+ dtype=torch.int32,
470
+ device=text_token.device,
471
+ ),
472
+ ],
473
+ dim=-1,
474
+ )
475
+
476
+ audio_length = prompt_audio_feat.size(0)
477
+ text_length = text_token.shape[0]
478
+ text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
479
+ audio_pad_feat = torch.zeros(
480
+ (text_token.shape[0], self.patch_size, self.audio_vae.latent_dim),
481
+ dtype=torch.float32,
482
+ device=text_token.device,
483
+ )
484
+ text_token = torch.cat([text_token, text_pad_token])
485
+ audio_feat = torch.cat([audio_pad_feat, prompt_audio_feat], dim=0)
486
+ text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
487
+ audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
488
+
489
+ text_token = text_token.unsqueeze(0).to(self.device)
490
+ text_mask = text_mask.unsqueeze(0).to(self.device)
491
+ audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
492
+ audio_mask = audio_mask.unsqueeze(0).to(self.device)
493
+
494
+ # run inference
495
+ target_text_length = len(self.text_tokenizer(target_text))
496
+ retry_badcase_times = 0
497
+ while retry_badcase_times < retry_badcase_max_times:
498
+ inference_result = self._inference(
499
+ text_token,
500
+ text_mask,
501
+ audio_feat,
502
+ audio_mask,
503
+ min_len=min_len,
504
+ max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
505
+ inference_timesteps=inference_timesteps,
506
+ cfg_value=cfg_value,
507
+ streaming=streaming,
508
+ )
509
+ if streaming:
510
+ patch_len = self.patch_size * self.chunk_size
511
+ for latent_pred, pred_audio_feat in inference_result:
512
+ decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
513
+ decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
514
+ yield (
515
+ decode_audio,
516
+ target_text_token,
517
+ pred_audio_feat
518
+ )
519
+ break
520
+ else:
521
+ latent_pred, pred_audio_feat = next(inference_result)
522
+ if retry_badcase:
523
+ if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
524
+ print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
525
+ retry_badcase_times += 1
526
+ continue
527
+ else:
528
+ break
529
+ else:
530
+ break
531
+ if not streaming:
532
+ decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
533
+ decode_audio = decode_audio[..., 640:-640] # trick: trim the start and end of the audio
534
+
535
+ yield (
536
+ decode_audio,
537
+ target_text_token,
538
+ pred_audio_feat
539
+ )
540
+
541
+ def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
542
+ return next(self._inference(*args, streaming=False, **kwargs))
543
+
544
+ def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
545
+ return self._inference(*args, streaming=True, **kwargs)
546
+
547
+ @torch.inference_mode()
548
+ def _inference(
549
+ self,
550
+ text: torch.Tensor,
551
+ text_mask: torch.Tensor,
552
+ feat: torch.Tensor,
553
+ feat_mask: torch.Tensor,
554
+ min_len: int = 2,
555
+ max_len: int = 2000,
556
+ inference_timesteps: int = 10,
557
+ cfg_value: float = 2.0,
558
+ streaming: bool = False,
559
+ ) -> Generator[Tuple[torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
560
+ """Core inference method for audio generation.
561
+
562
+ This is the main inference loop that generates audio features
563
+ using the language model and diffusion transformer.
564
+
565
+ Args:
566
+ text: Input text tokens
567
+ text_mask: Mask for text tokens
568
+ feat: Input audio features
569
+ feat_mask: Mask for audio features
570
+ min_len: Minimum generation length
571
+ max_len: Maximum generation length
572
+ inference_timesteps: Number of diffusion steps
573
+ cfg_value: Classifier-free guidance value
574
+ streaming: Whether to yield each step latent feature or just the final result
575
+
576
+ Returns:
577
+ Generator of Tuple containing:
578
+ - Predicted latent feature at the current step if ``streaming=True``, else final latent features
579
+ - Predicted audio feature sequence so far as a List if ``streaming=True``, else as a concatenated Tensor
580
+ """
581
+ B, T, P, D = feat.shape
582
+
583
+ feat_embed = self.feat_encoder(feat) # [b, t, h_feat]
584
+ feat_embed = self.enc_to_lm_proj(feat_embed)
585
+
586
+ if self.config.lm_config.use_mup:
587
+ scale_emb = self.config.lm_config.scale_emb
588
+ else:
589
+ scale_emb = 1.0
590
+
591
+ text_embed = self.base_lm.embed_tokens(text) * scale_emb
592
+ combined_embed = text_mask.unsqueeze(-1) * text_embed + feat_mask.unsqueeze(-1) * feat_embed
593
+
594
+ prefix_feat_cond = feat[:, -1, ...] # b, p, d
595
+ pred_feat_seq = [] # b, t, p, d
596
+ curr_embed = None
597
+
598
+ enc_outputs, kv_cache_tuple = self.base_lm(
599
+ inputs_embeds=combined_embed,
600
+ is_causal=True,
601
+ )
602
+ self.base_lm.kv_cache.fill_caches(kv_cache_tuple)
603
+
604
+ enc_outputs = self.fsq_layer(enc_outputs) * feat_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
605
+ lm_hidden = enc_outputs[:, -1, :]
606
+
607
+
608
+ residual_enc_outputs, residual_kv_cache_tuple = self.residual_lm(
609
+ inputs_embeds=enc_outputs + feat_mask.unsqueeze(-1) * feat_embed,
610
+ is_causal=True,
611
+ )
612
+ self.residual_lm.kv_cache.fill_caches(residual_kv_cache_tuple)
613
+ residual_hidden = residual_enc_outputs[:, -1, :]
614
+
615
+
616
+ for i in tqdm(range(max_len)):
617
+ dit_hidden_1 = self.lm_to_dit_proj(lm_hidden) # [b, h_dit]
618
+ dit_hidden_2 = self.res_to_dit_proj(residual_hidden) # [b, h_dit]
619
+ dit_hidden = dit_hidden_1 + dit_hidden_2 # [b, h_dit]
620
+
621
+ pred_feat = self.feat_decoder(
622
+ mu=dit_hidden,
623
+ patch_size=self.patch_size,
624
+ cond=prefix_feat_cond.transpose(1, 2).contiguous(),
625
+ n_timesteps=inference_timesteps,
626
+ cfg_value=cfg_value,
627
+ ).transpose(
628
+ 1, 2
629
+ ) # [b, p, d]
630
+
631
+ curr_embed = self.feat_encoder_step(pred_feat.unsqueeze(1)) # b, 1, c
632
+ curr_embed = self.enc_to_lm_proj(curr_embed)
633
+
634
+ pred_feat_seq.append(pred_feat.unsqueeze(1)) # b, 1, p, d
635
+ prefix_feat_cond = pred_feat
636
+
637
+ if streaming:
638
+ # return the last three predicted latent features to provide enough context for smooth decoding
639
+ pred_feat_chunk = torch.cat(pred_feat_seq[-3:], dim=1)
640
+ feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
641
+ yield feat_pred, pred_feat_seq
642
+
643
+ stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
644
+ if i > min_len and stop_flag == 1:
645
+ break
646
+
647
+ lm_hidden = self.base_lm.forward_step(
648
+ curr_embed[:, 0, :], torch.tensor([self.base_lm.kv_cache.step()], device=curr_embed.device)
649
+ ).clone()
650
+
651
+
652
+ lm_hidden = self.fsq_layer(lm_hidden)
653
+ residual_hidden = self.residual_lm.forward_step(
654
+ lm_hidden + curr_embed[:, 0, :], torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device)
655
+ ).clone()
656
+
657
+ if not streaming:
658
+ pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
659
+
660
+ feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
661
+ yield feat_pred, pred_feat_seq.squeeze(0).cpu()
662
+
663
+ @classmethod
664
+ def from_local(cls, path: str, optimize: bool = True):
665
+ config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
666
+
667
+ tokenizer = LlamaTokenizerFast.from_pretrained(path)
668
+
669
+ audio_vae = AudioVAE()
670
+ vae_state_dict = torch.load(
671
+ os.path.join(path, "audiovae.pth"),
672
+ map_location="cpu",
673
+ weights_only=True,
674
+ )["state_dict"]
675
+
676
+ model = cls(config, tokenizer, audio_vae)
677
+ lm_dtype = get_dtype(model.config.dtype)
678
+ model = model.to(lm_dtype)
679
+ model.audio_vae = model.audio_vae.to(torch.float32)
680
+
681
+ model_state_dict = torch.load(
682
+ os.path.join(path, "pytorch_model.bin"),
683
+ map_location="cpu",
684
+ weights_only=True,
685
+ )["state_dict"]
686
+
687
+ for kw, val in vae_state_dict.items():
688
+ model_state_dict[f"audio_vae.{kw}"] = val
689
+ model.load_state_dict(model_state_dict, strict=True)
690
+ return model.to(model.device).eval().optimize(disable=not optimize)
convert/src/voxcpm/modules/__init__.py ADDED
File without changes
convert/src/voxcpm/modules/audiovae/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .audio_vae import AudioVAE
convert/src/voxcpm/modules/audiovae/audio_vae.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ from torch import nn
7
+ import torch.nn.functional as F
8
+ from torch.nn.utils import weight_norm
9
+
10
+
11
+ def WNConv1d(*args, **kwargs):
12
+ return weight_norm(nn.Conv1d(*args, **kwargs))
13
+
14
+
15
+ def WNConvTranspose1d(*args, **kwargs):
16
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
17
+
18
+
19
+ class CausalConv1d(nn.Conv1d):
20
+ def __init__(self, *args, padding: int = 0, **kwargs):
21
+ super().__init__(*args, **kwargs)
22
+ self.__padding = padding
23
+
24
+ def forward(self, x):
25
+ x_pad = F.pad(x, (self.__padding * 2, 0))
26
+ return super().forward(x_pad)
27
+
28
+
29
+ class CausalTransposeConv1d(nn.ConvTranspose1d):
30
+ def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
31
+ super().__init__(*args, **kwargs)
32
+ self.__padding = padding
33
+ self.__output_padding = output_padding
34
+
35
+ def forward(self, x):
36
+ return super().forward(x)[..., : -(self.__padding * 2 - self.__output_padding)]
37
+
38
+
39
+ def WNCausalConv1d(*args, **kwargs):
40
+ return weight_norm(CausalConv1d(*args, **kwargs))
41
+
42
+
43
+ def WNCausalTransposeConv1d(*args, **kwargs):
44
+ return weight_norm(CausalTransposeConv1d(*args, **kwargs))
45
+
46
+
47
+ # Scripting this brings model speed up 1.4x
48
+ @torch.jit.script
49
+ def snake(x, alpha):
50
+ shape = x.shape
51
+ x = x.reshape(shape[0], shape[1], -1)
52
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
53
+ x = x.reshape(shape)
54
+ return x
55
+
56
+
57
+ class Snake1d(nn.Module):
58
+ def __init__(self, channels):
59
+ super().__init__()
60
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
61
+
62
+ def forward(self, x):
63
+ return snake(x, self.alpha)
64
+
65
+
66
+ def init_weights(m):
67
+ if isinstance(m, nn.Conv1d):
68
+ nn.init.trunc_normal_(m.weight, std=0.02)
69
+ if m.bias is not None:
70
+ nn.init.constant_(m.bias, 0)
71
+
72
+
73
+ class CausalResidualUnit(nn.Module):
74
+ def __init__(self, dim: int = 16, dilation: int = 1, kernel: int = 7, groups: int = 1):
75
+ super().__init__()
76
+ pad = ((7 - 1) * dilation) // 2
77
+ self.block = nn.Sequential(
78
+ Snake1d(dim),
79
+ WNCausalConv1d(
80
+ dim,
81
+ dim,
82
+ kernel_size=kernel,
83
+ dilation=dilation,
84
+ padding=pad,
85
+ groups=groups,
86
+ ),
87
+ Snake1d(dim),
88
+ WNCausalConv1d(dim, dim, kernel_size=1),
89
+ )
90
+
91
+ def forward(self, x):
92
+ y = self.block(x)
93
+ pad = (x.shape[-1] - y.shape[-1]) // 2
94
+ assert pad == 0
95
+ if pad > 0:
96
+ x = x[..., pad:-pad]
97
+ return x + y
98
+
99
+
100
+ class CausalEncoderBlock(nn.Module):
101
+ def __init__(self, output_dim: int = 16, input_dim=None, stride: int = 1, groups=1):
102
+ super().__init__()
103
+ input_dim = input_dim or output_dim // 2
104
+ self.block = nn.Sequential(
105
+ CausalResidualUnit(input_dim, dilation=1, groups=groups),
106
+ CausalResidualUnit(input_dim, dilation=3, groups=groups),
107
+ CausalResidualUnit(input_dim, dilation=9, groups=groups),
108
+ Snake1d(input_dim),
109
+ WNCausalConv1d(
110
+ input_dim,
111
+ output_dim,
112
+ kernel_size=2 * stride,
113
+ stride=stride,
114
+ padding=math.ceil(stride / 2),
115
+ ),
116
+ )
117
+
118
+ def forward(self, x):
119
+ return self.block(x)
120
+
121
+
122
+ class CausalEncoder(nn.Module):
123
+ def __init__(
124
+ self,
125
+ d_model: int = 64,
126
+ latent_dim: int = 32,
127
+ strides: list = [2, 4, 8, 8],
128
+ depthwise: bool = False,
129
+ ):
130
+ super().__init__()
131
+ # Create first convolution
132
+ self.block = [WNCausalConv1d(1, d_model, kernel_size=7, padding=3)]
133
+
134
+ # Create EncoderBlocks that double channels as they downsample by `stride`
135
+ for stride in strides:
136
+ d_model *= 2
137
+ groups = d_model // 2 if depthwise else 1
138
+ self.block += [CausalEncoderBlock(output_dim=d_model, stride=stride, groups=groups)]
139
+
140
+ groups = d_model if depthwise else 1
141
+
142
+ # Create two convolution, for mu and logvar
143
+ self.fc_mu = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
144
+ self.fc_logvar = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
145
+
146
+ # Wrap black into nn.Sequential
147
+ self.block = nn.Sequential(*self.block)
148
+ self.enc_dim = d_model
149
+
150
+ def forward(self, x):
151
+ hidden_state = self.block(x)
152
+ return {
153
+ "hidden_state": hidden_state,
154
+ "mu": self.fc_mu(hidden_state),
155
+ "logvar": self.fc_logvar(hidden_state),
156
+ }
157
+
158
+
159
+ class NoiseBlock(nn.Module):
160
+ def __init__(self, dim):
161
+ super().__init__()
162
+ self.linear = WNCausalConv1d(dim, dim, kernel_size=1, bias=False)
163
+
164
+ def forward(self, x):
165
+ B, C, T = x.shape
166
+ noise = torch.randn((B, 1, T), device=x.device, dtype=x.dtype)
167
+ h = self.linear(x)
168
+ n = noise * h
169
+ x = x + n
170
+ return x
171
+
172
+
173
+ class CausalDecoderBlock(nn.Module):
174
+ def __init__(
175
+ self,
176
+ input_dim: int = 16,
177
+ output_dim: int = 8,
178
+ stride: int = 1,
179
+ groups=1,
180
+ use_noise_block: bool = False,
181
+ ):
182
+ super().__init__()
183
+ layers = [
184
+ Snake1d(input_dim),
185
+ WNCausalTransposeConv1d(
186
+ input_dim,
187
+ output_dim,
188
+ kernel_size=2 * stride,
189
+ stride=stride,
190
+ padding=math.ceil(stride / 2),
191
+ output_padding=stride % 2,
192
+ ),
193
+ ]
194
+ if use_noise_block:
195
+ layers.append(NoiseBlock(output_dim))
196
+ layers.extend(
197
+ [
198
+ CausalResidualUnit(output_dim, dilation=1, groups=groups),
199
+ CausalResidualUnit(output_dim, dilation=3, groups=groups),
200
+ CausalResidualUnit(output_dim, dilation=9, groups=groups),
201
+ ]
202
+ )
203
+ self.block = nn.Sequential(*layers)
204
+
205
+ def forward(self, x):
206
+ return self.block(x)
207
+
208
+
209
+ class TransposeLastTwoDim(torch.nn.Module):
210
+ def forward(self, x):
211
+ return torch.transpose(x, -1, -2)
212
+
213
+
214
+ class CausalDecoder(nn.Module):
215
+ def __init__(
216
+ self,
217
+ input_channel,
218
+ channels,
219
+ rates,
220
+ depthwise: bool = False,
221
+ d_out: int = 1,
222
+ use_noise_block: bool = False,
223
+ ):
224
+ super().__init__()
225
+
226
+ # Add first conv layer
227
+ if depthwise:
228
+ layers = [
229
+ WNCausalConv1d(
230
+ input_channel,
231
+ input_channel,
232
+ kernel_size=7,
233
+ padding=3,
234
+ groups=input_channel,
235
+ ),
236
+ WNCausalConv1d(input_channel, channels, kernel_size=1),
237
+ ]
238
+ else:
239
+ layers = [WNCausalConv1d(input_channel, channels, kernel_size=7, padding=3)]
240
+
241
+ # Add upsampling + MRF blocks
242
+ for i, stride in enumerate(rates):
243
+ input_dim = channels // 2**i
244
+ output_dim = channels // 2 ** (i + 1)
245
+ groups = output_dim if depthwise else 1
246
+ layers += [
247
+ CausalDecoderBlock(
248
+ input_dim,
249
+ output_dim,
250
+ stride,
251
+ groups=groups,
252
+ use_noise_block=use_noise_block,
253
+ )
254
+ ]
255
+
256
+ # Add final conv layer
257
+ layers += [
258
+ Snake1d(output_dim),
259
+ WNCausalConv1d(output_dim, d_out, kernel_size=7, padding=3),
260
+ nn.Tanh(),
261
+ ]
262
+
263
+ self.model = nn.Sequential(*layers)
264
+
265
+ def forward(self, x):
266
+ return self.model(x)
267
+
268
+
269
+ class AudioVAE(nn.Module):
270
+ """
271
+ Args:
272
+ """
273
+
274
+ def __init__(
275
+ self,
276
+ encoder_dim: int = 128,
277
+ encoder_rates: List[int] = [2, 5, 8, 8],
278
+ latent_dim: int = 64,
279
+ decoder_dim: int = 1536,
280
+ decoder_rates: List[int] = [8, 8, 5, 2],
281
+ depthwise: bool = True,
282
+ sample_rate: int = 16000,
283
+ use_noise_block: bool = False,
284
+ ):
285
+ super().__init__()
286
+
287
+ self.encoder_dim = encoder_dim
288
+ self.encoder_rates = encoder_rates
289
+ self.decoder_dim = decoder_dim
290
+ self.decoder_rates = decoder_rates
291
+ self.depthwise = depthwise
292
+
293
+ self.use_noise_block = use_noise_block
294
+
295
+ if latent_dim is None:
296
+ latent_dim = encoder_dim * (2 ** len(encoder_rates))
297
+
298
+ self.latent_dim = latent_dim
299
+ self.hop_length = np.prod(encoder_rates)
300
+ self.encoder = CausalEncoder(
301
+ encoder_dim,
302
+ latent_dim,
303
+ encoder_rates,
304
+ depthwise=depthwise,
305
+ )
306
+
307
+ self.decoder = CausalDecoder(
308
+ latent_dim,
309
+ decoder_dim,
310
+ decoder_rates,
311
+ depthwise=depthwise,
312
+ use_noise_block=use_noise_block,
313
+ )
314
+ self.sample_rate = sample_rate
315
+ self.chunk_size = math.prod(encoder_rates)
316
+
317
+ def preprocess(self, audio_data, sample_rate):
318
+ if sample_rate is None:
319
+ sample_rate = self.sample_rate
320
+ assert sample_rate == self.sample_rate
321
+ pad_to = self.hop_length
322
+ length = audio_data.shape[-1]
323
+ right_pad = math.ceil(length / pad_to) * pad_to - length
324
+ audio_data = nn.functional.pad(audio_data, (0, right_pad))
325
+
326
+ return audio_data
327
+
328
+ def decode(self, z: torch.Tensor):
329
+ """Decode given latent codes and return audio data
330
+
331
+ Parameters
332
+ ----------
333
+ z : Tensor[B x D x T]
334
+ Quantized continuous representation of input
335
+ length : int, optional
336
+ Number of samples in output audio, by default None
337
+
338
+ Returns
339
+ -------
340
+ dict
341
+ A dictionary with the following keys:
342
+ "audio" : Tensor[B x 1 x length]
343
+ Decoded audio data.
344
+ """
345
+ return self.decoder(z)
346
+
347
+ def encode(self, audio_data: torch.Tensor, sample_rate: int):
348
+ """
349
+ Args:
350
+ audio_data: Tensor[B x 1 x T]
351
+ sample_rate: int
352
+ Returns:
353
+ z: Tensor[B x D x T]
354
+ """
355
+ if audio_data.ndim == 2:
356
+ audio_data = audio_data.unsqueeze(1)
357
+
358
+ audio_data = self.preprocess(audio_data, sample_rate)
359
+ return self.encoder(audio_data)["mu"]
convert/src/voxcpm/modules/layers/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .scalar_quantization_layer import ScalarQuantizationLayer
convert/src/voxcpm/modules/layers/scalar_quantization_layer.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class ScalarQuantizationLayer(nn.Module):
6
+ def __init__(self, in_dim, out_dim, latent_dim: int = 64, scale: int = 9):
7
+ super().__init__()
8
+ self.in_dim = in_dim
9
+ self.out_dim = out_dim
10
+ self.latent_dim = latent_dim
11
+ self.scale = scale
12
+
13
+ self.in_proj = nn.Linear(in_dim, latent_dim)
14
+ self.out_proj = nn.Linear(latent_dim, out_dim)
15
+
16
+ def forward(self, hidden):
17
+ hidden = self.in_proj(hidden)
18
+ hidden = torch.tanh(hidden)
19
+
20
+ if self.training:
21
+ quantized = torch.round(hidden * self.scale) / self.scale
22
+ hidden = hidden + (quantized - hidden).detach()
23
+ else:
24
+ hidden = torch.round(hidden * self.scale) / self.scale
25
+
26
+ return self.out_proj(hidden)
convert/src/voxcpm/modules/locdit/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .unified_cfm import UnifiedCFM, CfmConfig
2
+ from .local_dit import VoxCPMLocDiT
convert/src/voxcpm/modules/locdit/local_dit.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from ..minicpm4 import MiniCPMModel, MiniCPM4Config
3
+ import torch.nn as nn
4
+ import math
5
+
6
+
7
+ class SinusoidalPosEmb(torch.nn.Module):
8
+ def __init__(self, dim):
9
+ super().__init__()
10
+ self.dim = dim
11
+ assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
12
+
13
+ def forward(self, x, scale=1000):
14
+ if x.ndim < 1:
15
+ x = x.unsqueeze(0)
16
+ device = x.device
17
+ half_dim = self.dim // 2
18
+ emb = math.log(10000) / (half_dim - 1)
19
+ emb = torch.exp(torch.arange(half_dim, dtype=x.dtype, device=device) * -emb)
20
+ emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
21
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
22
+ return emb
23
+
24
+
25
+ class TimestepEmbedding(nn.Module):
26
+ def __init__(
27
+ self,
28
+ in_channels: int,
29
+ time_embed_dim: int,
30
+ out_dim: int = None,
31
+ ):
32
+ super().__init__()
33
+
34
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=True)
35
+ self.act = nn.SiLU()
36
+ if out_dim is not None:
37
+ time_embed_dim_out = out_dim
38
+ else:
39
+ time_embed_dim_out = time_embed_dim
40
+
41
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, bias=True)
42
+
43
+ def forward(self, sample):
44
+ sample = self.linear_1(sample)
45
+ sample = self.act(sample)
46
+ sample = self.linear_2(sample)
47
+ return sample
48
+
49
+
50
+ class VoxCPMLocDiT(nn.Module):
51
+ """
52
+ Diffusion model with a Transformer backbone.
53
+ """
54
+
55
+ def __init__(
56
+ self,
57
+ config: MiniCPM4Config,
58
+ in_channels: int = 64,
59
+ ):
60
+ super().__init__()
61
+ self.in_channels = in_channels
62
+ self.out_channels = in_channels
63
+ self.config = config
64
+
65
+ self.in_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
66
+ self.cond_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
67
+ self.out_proj = nn.Linear(config.hidden_size, self.out_channels, bias=True)
68
+
69
+ self.time_embeddings = SinusoidalPosEmb(config.hidden_size)
70
+ self.time_mlp = TimestepEmbedding(
71
+ in_channels=config.hidden_size,
72
+ time_embed_dim=config.hidden_size,
73
+ )
74
+ self.delta_time_mlp = TimestepEmbedding(
75
+ in_channels=config.hidden_size,
76
+ time_embed_dim=config.hidden_size,
77
+ )
78
+
79
+ assert config.vocab_size == 0, "vocab_size must be 0 for local DiT"
80
+ self.decoder = MiniCPMModel(config)
81
+
82
+ def forward(
83
+ self,
84
+ x: torch.Tensor,
85
+ mu: torch.Tensor,
86
+ t: torch.Tensor,
87
+ cond: torch.Tensor,
88
+ dt: torch.Tensor,
89
+ ):
90
+ """
91
+ Forward pass of DiT.
92
+ x: (N, C, T) tensor of inputs
93
+ mu: (N, C) tensor of hidden embedding
94
+ t: (N,) tensor of diffusion timesteps
95
+ cond: (N, C, T') tensor of prefix conditions
96
+ dt: (N,) used for mean velocity (may be supported in the future...)
97
+ """
98
+ x = self.in_proj(x.transpose(1, 2).contiguous())
99
+
100
+ cond = self.cond_proj(cond.transpose(1, 2).contiguous())
101
+ prefix = cond.size(1)
102
+
103
+ t = self.time_embeddings(t).to(x.dtype)
104
+ t = self.time_mlp(t)
105
+ dt = self.time_embeddings(dt).to(x.dtype)
106
+ dt = self.delta_time_mlp(dt)
107
+ t = t + dt
108
+
109
+ x = torch.cat([(mu + t).unsqueeze(1), cond, x], dim=1)
110
+ hidden, _ = self.decoder(x, is_causal=False)
111
+ hidden = hidden[:, prefix + 1 :, :]
112
+ hidden = self.out_proj(hidden)
113
+
114
+ return hidden.transpose(1, 2).contiguous()
convert/src/voxcpm/modules/locdit/unified_cfm.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import List
3
+ from .local_dit import VoxCPMLocDiT
4
+ import math
5
+ from pydantic import BaseModel
6
+
7
+
8
+ class CfmConfig(BaseModel):
9
+ sigma_min: float = 1e-06
10
+ solver: str = "euler"
11
+ t_scheduler: str = "log-norm"
12
+
13
+
14
+ class UnifiedCFM(torch.nn.Module):
15
+ def __init__(
16
+ self,
17
+ in_channels,
18
+ cfm_params: CfmConfig,
19
+ estimator: VoxCPMLocDiT,
20
+ mean_mode: bool = False,
21
+ ):
22
+ super().__init__()
23
+ self.solver = cfm_params.solver
24
+ self.sigma_min = cfm_params.sigma_min
25
+ self.t_scheduler = cfm_params.t_scheduler
26
+ self.in_channels = in_channels
27
+ self.mean_mode = mean_mode
28
+
29
+ # Just change the architecture of the estimator here
30
+ self.estimator = estimator
31
+
32
+ @torch.inference_mode()
33
+ def forward(
34
+ self,
35
+ mu: torch.Tensor,
36
+ n_timesteps: int,
37
+ patch_size: int,
38
+ cond: torch.Tensor,
39
+ temperature: float = 1.0,
40
+ cfg_value: float = 1.0,
41
+ sway_sampling_coef: float = 1.0,
42
+ use_cfg_zero_star: bool = True,
43
+ ):
44
+ """Forward diffusion
45
+
46
+ Args:
47
+ mu (torch.Tensor): output of encoder
48
+ shape: (batch_size, n_feats)
49
+ n_timesteps (int): number of diffusion steps
50
+ cond: Not used but kept for future purposes
51
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
52
+
53
+ Returns:
54
+ sample: generated mel-spectrogram
55
+ shape: (batch_size, n_feats, mel_timesteps)
56
+ """
57
+ b, c = mu.shape
58
+ t = patch_size
59
+ z = torch.randn((b, self.in_channels, t), device=mu.device, dtype=mu.dtype) * temperature
60
+
61
+ t_span = torch.linspace(1, 0, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
62
+ # Sway sampling strategy
63
+ t_span = t_span + sway_sampling_coef * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span)
64
+
65
+ return self.solve_euler(z, t_span=t_span, mu=mu, cond=cond, cfg_value=cfg_value, use_cfg_zero_star=use_cfg_zero_star)
66
+
67
+ def optimized_scale(self, positive_flat, negative_flat):
68
+ dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
69
+ squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
70
+
71
+ st_star = dot_product / squared_norm
72
+ return st_star
73
+
74
+ def solve_euler(
75
+ self,
76
+ x: torch.Tensor,
77
+ t_span: torch.Tensor,
78
+ mu: torch.Tensor,
79
+ cond: torch.Tensor,
80
+ cfg_value: float = 1.0,
81
+ use_cfg_zero_star: bool = True,
82
+ ):
83
+ """
84
+ Fixed euler solver for ODEs.
85
+ Args:
86
+ x (torch.Tensor): random noise
87
+ t_span (torch.Tensor): n_timesteps interpolated
88
+ shape: (n_timesteps + 1,)
89
+ mu (torch.Tensor): output of encoder
90
+ shape: (batch_size, n_feats)
91
+ cond: condition -- prefix prompt
92
+ cfg_value (float, optional): cfg value for guidance. Defaults to 1.0.
93
+ """
94
+ t, _, dt = t_span[0], t_span[-1], t_span[0] - t_span[1]
95
+
96
+ sol = []
97
+ zero_init_steps = max(1, int(len(t_span) * 0.04))
98
+ for step in range(1, len(t_span)):
99
+ if use_cfg_zero_star and step <= zero_init_steps:
100
+ dphi_dt = 0.
101
+ else:
102
+ # Classifier-Free Guidance inference introduced in VoiceBox
103
+ b = x.size(0)
104
+ x_in = torch.zeros([2 * b, self.in_channels, x.size(2)], device=x.device, dtype=x.dtype)
105
+ mu_in = torch.zeros([2 * b, mu.size(1)], device=x.device, dtype=x.dtype)
106
+ t_in = torch.zeros([2 * b], device=x.device, dtype=x.dtype)
107
+ dt_in = torch.zeros([2 * b], device=x.device, dtype=x.dtype)
108
+ cond_in = torch.zeros([2 * b, self.in_channels, x.size(2)], device=x.device, dtype=x.dtype)
109
+ x_in[:b], x_in[b:] = x, x
110
+ mu_in[:b] = mu
111
+ t_in[:b], t_in[b:] = t.unsqueeze(0), t.unsqueeze(0)
112
+ dt_in[:b], dt_in[b:] = dt.unsqueeze(0), dt.unsqueeze(0)
113
+ # not used now
114
+ if not self.mean_mode:
115
+ dt_in = torch.zeros_like(dt_in)
116
+ cond_in[:b], cond_in[b:] = cond, cond
117
+
118
+ dphi_dt = self.estimator(x_in, mu_in, t_in, cond_in, dt_in)
119
+ dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
120
+
121
+ if use_cfg_zero_star:
122
+ positive_flat = dphi_dt.view(b, -1)
123
+ negative_flat = cfg_dphi_dt.view(b, -1)
124
+ st_star = self.optimized_scale(positive_flat, negative_flat)
125
+ st_star = st_star.view(b, *([1] * (len(dphi_dt.shape) - 1)))
126
+ else:
127
+ st_star = 1.0
128
+
129
+ dphi_dt = cfg_dphi_dt * st_star + cfg_value * (dphi_dt - cfg_dphi_dt * st_star)
130
+
131
+ x = x - dt * dphi_dt
132
+ t = t - dt
133
+ sol.append(x)
134
+ if step < len(t_span) - 1:
135
+ dt = t - t_span[step + 1]
136
+
137
+ return sol[-1]
convert/src/voxcpm/modules/locenc/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .local_encoder import VoxCPMLocEnc
convert/src/voxcpm/modules/locenc/local_encoder.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from ..minicpm4 import MiniCPMModel, MiniCPM4Config
4
+ from einops import rearrange
5
+
6
+
7
+ class VoxCPMLocEnc(nn.Module):
8
+ def __init__(self, config: MiniCPM4Config, input_dim: int = 64):
9
+ super().__init__()
10
+ self.config = config
11
+ self.special_token = nn.Parameter(torch.randn(1, 1, 1, config.hidden_size))
12
+ self.in_proj = nn.Linear(input_dim, config.hidden_size, bias=True)
13
+
14
+ assert config.vocab_size == 0, "vocab_size must be 0 for local encoder"
15
+ self.encoder = MiniCPMModel(config)
16
+
17
+ def forward(self, x):
18
+ """
19
+ x: [B, T, P, D]
20
+ """
21
+ B, T, P, D = x.shape
22
+
23
+ x = self.in_proj(x)
24
+ special_tokens = self.special_token.expand(B, T, 1, -1)
25
+ x = torch.cat([special_tokens, x], dim=2)
26
+ x = rearrange(x, "b t p c -> (b t) p c")
27
+ outputs, _ = self.encoder(x, is_causal=False)
28
+ cls_output = outputs[:, 0, :]
29
+
30
+ return rearrange(cls_output, "(b t) c -> b t c", b=B)
convert/src/voxcpm/modules/minicpm4/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .config import MiniCPM4Config
2
+ from .model import MiniCPMModel
3
+ from .cache import StaticKVCache
convert/src/voxcpm/modules/minicpm4/cache.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple
2
+ import torch
3
+
4
+
5
+ class StaticKVCache:
6
+ def __init__(
7
+ self,
8
+ num_layers: int,
9
+ num_kv_heads: int,
10
+ dim_kv_head: int,
11
+ batch_size: int,
12
+ device: torch.device,
13
+ dtype: torch.dtype,
14
+ max_length: int = 8192,
15
+ ):
16
+ self.max_length = max_length
17
+ self.num_layers = num_layers
18
+
19
+ self.kv_cache = torch.zeros(
20
+ 2,
21
+ num_layers,
22
+ batch_size,
23
+ num_kv_heads,
24
+ max_length,
25
+ dim_kv_head,
26
+ device=device,
27
+ dtype=dtype,
28
+ )
29
+ self.current_length = 0
30
+
31
+ def get_layer_cache(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
32
+ return self.kv_cache[0, layer_idx], self.kv_cache[1, layer_idx]
33
+
34
+ def step(self) -> int:
35
+ if self.current_length >= self.max_length:
36
+ raise ValueError("KV cache is full")
37
+
38
+ ret = self.current_length
39
+ self.current_length += 1
40
+ return ret
41
+
42
+ def fill_caches(self, kv_caches: List[Tuple[torch.Tensor, torch.Tensor]]):
43
+ self.current_length = kv_caches[0][0].size(2)
44
+ self.kv_cache.zero_()
45
+ for i in range(self.num_layers):
46
+ self.kv_cache[0, i, :, :, : self.current_length, :] = kv_caches[i][0]
47
+ self.kv_cache[1, i, :, :, : self.current_length, :] = kv_caches[i][1]
convert/src/voxcpm/modules/minicpm4/config.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+ from typing import List
3
+
4
+
5
+ class RopeScalingConfig(BaseModel):
6
+ type: str
7
+ long_factor: List[float]
8
+ short_factor: List[float]
9
+ original_max_position_embeddings: int
10
+
11
+
12
+ class MiniCPM4Config(BaseModel):
13
+ bos_token_id: int
14
+ eos_token_id: int
15
+ hidden_size: int
16
+ intermediate_size: int
17
+ max_position_embeddings: int
18
+ num_attention_heads: int
19
+ num_hidden_layers: int
20
+ num_key_value_heads: int
21
+ rms_norm_eps: float
22
+ rope_scaling: RopeScalingConfig
23
+ vocab_size: int
24
+ use_mup: bool = True
25
+ scale_emb: float
26
+ dim_model_base: int
27
+ scale_depth: float
28
+ rope_theta: float
29
+ kv_channels: int = None
convert/src/voxcpm/modules/minicpm4/model.py ADDED
@@ -0,0 +1,473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .config import MiniCPM4Config
2
+ import torch
3
+ import torch.nn as nn
4
+ from typing import List, Tuple
5
+ import math
6
+ from .cache import StaticKVCache
7
+
8
+
9
+ def rms_layernorm(hidden: torch.Tensor, weight: torch.Tensor, eps: float):
10
+ old_dtype = hidden.dtype
11
+ variance = hidden.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
12
+ hidden = (hidden * torch.rsqrt(variance + eps)).to(old_dtype)
13
+ return hidden * weight
14
+
15
+
16
+ class MiniCPMRMSNorm(nn.Module):
17
+ def __init__(self, hidden_size, eps=1e-6):
18
+ """
19
+ MiniCPMRMSNorm is equivalent to T5LayerNorm
20
+ """
21
+ super().__init__()
22
+ self.weight = nn.Parameter(torch.ones(hidden_size))
23
+ self.variance_epsilon = eps
24
+
25
+ def forward(self, hidden_states):
26
+ return rms_layernorm(hidden_states, self.weight, self.variance_epsilon)
27
+
28
+
29
+ def rotate_half(x):
30
+ """Rotates half the hidden dims of the input."""
31
+ x1, x2 = x.chunk(2, dim=-1)
32
+ return torch.cat((-x2, x1), dim=-1)
33
+
34
+
35
+ def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
36
+ """
37
+ Args:
38
+ q: Tensor(batch_size, num_heads, seq_len, head_dim)
39
+ k: Tensor(batch_size, num_key_value_heads, seq_len, head_dim)
40
+ cos: Tensor(seq_len, head_dim)
41
+ sin: Tensor(seq_len, head_dim)
42
+ Returns:
43
+ Tensor(batch_size, num_heads, seq_len, head_dim), Tensor(batch_size, num_key_value_heads, seq_len, head_dim)
44
+ """
45
+ orig_dtype = q.dtype
46
+ q = q.to(torch.float32)
47
+ k = k.to(torch.float32)
48
+ q_embed = (q * cos) + (rotate_half(q) * sin)
49
+ k_embed = (k * cos) + (rotate_half(k) * sin)
50
+ return q_embed.to(orig_dtype), k_embed.to(orig_dtype)
51
+
52
+
53
+ def scaled_dot_product_attention_gqa_compat(
54
+ query: torch.Tensor,
55
+ key: torch.Tensor,
56
+ value: torch.Tensor,
57
+ *,
58
+ attn_mask: torch.Tensor | None = None,
59
+ is_causal: bool = False,
60
+ enable_gqa: bool = False,
61
+ ) -> torch.Tensor:
62
+ """ONNX-export friendly fallback for scaled_dot_product_attention(enable_gqa=True)."""
63
+ orig_dtype = query.dtype
64
+ query = query.to(torch.float32)
65
+ key = key.to(torch.float32)
66
+ value = value.to(torch.float32)
67
+
68
+ if enable_gqa and query.shape[-3] != key.shape[-3]:
69
+ repeat_factor = query.shape[-3] // key.shape[-3]
70
+ key = key.repeat_interleave(repeat_factor, dim=-3)
71
+ value = value.repeat_interleave(repeat_factor, dim=-3)
72
+
73
+ scale = 1.0 / math.sqrt(query.size(-1))
74
+ attn_scores = torch.matmul(query, key.transpose(-2, -1)) * scale
75
+
76
+ if is_causal:
77
+ q_len = query.size(-2)
78
+ k_len = key.size(-2)
79
+ q_pos = torch.arange(q_len, device=query.device).unsqueeze(-1)
80
+ k_pos = torch.arange(k_len, device=query.device).unsqueeze(0)
81
+ causal_mask = k_pos <= (q_pos + k_len - q_len)
82
+ attn_scores = attn_scores.masked_fill(~causal_mask, torch.finfo(attn_scores.dtype).min)
83
+
84
+ if attn_mask is not None:
85
+ if attn_mask.dtype == torch.bool:
86
+ while attn_mask.ndim < attn_scores.ndim:
87
+ attn_mask = attn_mask.unsqueeze(0)
88
+ attn_scores = attn_scores.masked_fill(~attn_mask, torch.finfo(attn_scores.dtype).min)
89
+ else:
90
+ attn_scores = attn_scores + attn_mask.to(attn_scores.dtype)
91
+
92
+ attn_probs = torch.softmax(attn_scores, dim=-1)
93
+ return torch.matmul(attn_probs, value).to(orig_dtype)
94
+
95
+
96
+ class MiniCPMLongRoPE(nn.Module):
97
+ """MiniCPMRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
98
+
99
+ def __init__(self, config: MiniCPM4Config):
100
+ super().__init__()
101
+ self.config = config
102
+ self.dim = config.kv_channels if config.kv_channels else config.hidden_size // config.num_attention_heads
103
+ self.base = config.rope_theta
104
+ self.max_position_embeddings = config.max_position_embeddings
105
+
106
+ self.short_factor = config.rope_scaling.short_factor
107
+ self.long_factor = config.rope_scaling.long_factor
108
+ self.original_max_position_embeddings = config.rope_scaling.original_max_position_embeddings
109
+
110
+ scale = (self.max_position_embeddings / self.original_max_position_embeddings)
111
+ self.scaling_factor = math.sqrt(
112
+ 1 + math.log(scale) / math.log(self.original_max_position_embeddings)
113
+ )
114
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
115
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
116
+
117
+ self.max_seq_len_cached = 0
118
+
119
+ self.register_buffer("cos_cached", torch.empty(0), persistent=False)
120
+ self.register_buffer("sin_cached", torch.empty(0), persistent=False)
121
+
122
+ self._set_cos_sin_cache(
123
+ seq_len=self.max_position_embeddings,
124
+ device=self.inv_freq.device,
125
+ dtype=torch.float32
126
+ )
127
+
128
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
129
+ """设置cos和sin缓存"""
130
+ self.max_seq_len_cached = seq_len
131
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
132
+
133
+ if seq_len > self.original_max_position_embeddings:
134
+ ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=device)
135
+ else:
136
+ ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=device)
137
+
138
+ freqs = torch.mul(
139
+ torch.outer(t, 1.0 / ext_factors).to(device=device),
140
+ self.inv_freq.to(device=device).to(dtype)
141
+ )
142
+
143
+ # 创建embeddings
144
+ emb = torch.cat((freqs, freqs), dim=-1)
145
+
146
+ self.cos_cached = emb.cos().to(dtype) * self.scaling_factor
147
+ self.sin_cached = emb.sin().to(dtype) * self.scaling_factor
148
+
149
+ def forward(self, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
150
+ """
151
+ Args:
152
+ position_ids: Tensor(seq_len) 或 Tensor(batch_size, seq_len)
153
+ Returns:
154
+ Tensor(seq_len, head_dim), Tensor(seq_len, head_dim)
155
+ """
156
+ cos = self.cos_cached[position_ids]
157
+ sin = self.sin_cached[position_ids]
158
+
159
+ return cos, sin
160
+
161
+
162
+ class MiniCPMAttention(nn.Module):
163
+ def __init__(self, config: MiniCPM4Config, layer_idx: int):
164
+ super().__init__()
165
+ self.config = config
166
+ self.layer_idx = layer_idx
167
+ self.hidden_size = config.hidden_size
168
+ self.num_heads = config.num_attention_heads
169
+ self.head_dim = config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
170
+ self.num_key_value_heads = config.num_key_value_heads
171
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
172
+ self.max_position_embeddings = config.max_position_embeddings
173
+ self.rope_theta = 10000.0
174
+
175
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
176
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
177
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
178
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
179
+
180
+ def forward(
181
+ self,
182
+ hidden_states: torch.Tensor,
183
+ position_emb: Tuple[torch.Tensor, torch.Tensor],
184
+ is_causal: bool,
185
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
186
+ bsz, q_len, _ = hidden_states.size()
187
+
188
+ query_states = self.q_proj(hidden_states)
189
+ key_states = self.k_proj(hidden_states)
190
+ value_states = self.v_proj(hidden_states)
191
+
192
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
193
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
194
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
195
+
196
+ cos, sin = position_emb
197
+
198
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
199
+
200
+ # ref: https://github.com/pytorch/pytorch/issues/163597
201
+ # there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
202
+ query_states = query_states.contiguous()
203
+ key_states = key_states.contiguous()
204
+ value_states = value_states.contiguous()
205
+ if torch.onnx.is_in_onnx_export():
206
+ attn_output = scaled_dot_product_attention_gqa_compat(
207
+ query_states,
208
+ key_states,
209
+ value_states,
210
+ is_causal=is_causal,
211
+ enable_gqa=True,
212
+ )
213
+ else:
214
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
215
+ query_states,
216
+ key_states,
217
+ value_states,
218
+ is_causal=is_causal,
219
+ enable_gqa=True,
220
+ )
221
+
222
+ attn_output = attn_output.transpose(1, 2).contiguous()
223
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
224
+
225
+ attn_output = self.o_proj(attn_output)
226
+
227
+ past_key_value = (key_states, value_states)
228
+ return attn_output, past_key_value
229
+
230
+ def forward_step(
231
+ self,
232
+ hidden_states: torch.Tensor,
233
+ position_emb: Tuple[torch.Tensor, torch.Tensor],
234
+ position_id: int,
235
+ kv_cache: Tuple[torch.Tensor, torch.Tensor],
236
+ ) -> torch.Tensor:
237
+ bsz, _ = hidden_states.size()
238
+
239
+ query_states = self.q_proj(hidden_states)
240
+ key_states = self.k_proj(hidden_states)
241
+ value_states = self.v_proj(hidden_states)
242
+
243
+ query_states = query_states.view(bsz, 1, self.num_heads, self.head_dim).transpose(1, 2)
244
+ key_states = key_states.view(bsz, 1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
245
+ value_states = value_states.view(bsz, 1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
246
+
247
+ cos, sin = position_emb
248
+
249
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
250
+
251
+ key_cache, value_cache = kv_cache
252
+
253
+ key_cache[:, :, position_id, :] = key_states
254
+ value_cache[:, :, position_id, :] = value_states
255
+
256
+ attn_mask = torch.arange(key_cache.size(2), device=key_cache.device) <= position_id
257
+
258
+ # ref: https://github.com/pytorch/pytorch/issues/163597
259
+ # there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
260
+ query_states = query_states.unsqueeze(0)
261
+ key_cache = key_cache.unsqueeze(0)
262
+ value_cache = value_cache.unsqueeze(0)
263
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
264
+ query_states,
265
+ key_cache,
266
+ value_cache,
267
+ attn_mask=attn_mask,
268
+ enable_gqa=True,
269
+ )
270
+
271
+ attn_output = attn_output.transpose(1, 2).contiguous()
272
+ attn_output = attn_output.reshape(bsz, self.num_heads * self.head_dim)
273
+ attn_output = self.o_proj(attn_output)
274
+
275
+ return attn_output
276
+
277
+
278
+ class MiniCPMMLP(nn.Module):
279
+ def __init__(self, config):
280
+ super().__init__()
281
+ self.config = config
282
+ self.hidden_size = config.hidden_size
283
+ self.intermediate_size = config.intermediate_size
284
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
285
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
286
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
287
+ self.act_fn = nn.SiLU()
288
+
289
+ def forward(self, x):
290
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
291
+
292
+
293
+ class MiniCPMDecoderLayer(nn.Module):
294
+ def __init__(self, config: MiniCPM4Config, layer_idx: int):
295
+ super().__init__()
296
+ self.hidden_size = config.hidden_size
297
+ self.self_attn = MiniCPMAttention(config=config, layer_idx=layer_idx)
298
+
299
+ self.mlp = MiniCPMMLP(config)
300
+ self.input_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
301
+ self.post_attention_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
302
+
303
+ self.scale_depth = config.scale_depth
304
+ self.num_hidden_layers = config.num_hidden_layers
305
+ self.use_mup = config.use_mup
306
+
307
+ def forward(
308
+ self,
309
+ hidden_states: torch.Tensor,
310
+ position_emb: Tuple[torch.Tensor, torch.Tensor],
311
+ is_causal: bool,
312
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
313
+ """
314
+ Args:
315
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
316
+ position_ids (`torch.LongTensor`): position ids of shape `(batch_size, seq_len)`
317
+ is_causal (`bool`): whether the attention mask is causal
318
+ """
319
+ residual = hidden_states
320
+ hidden_states = self.input_layernorm(hidden_states)
321
+ # Self Attention
322
+ hidden_states, present_key_value = self.self_attn(
323
+ hidden_states=hidden_states,
324
+ position_emb=position_emb,
325
+ is_causal=is_causal,
326
+ )
327
+
328
+ if self.use_mup:
329
+ hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
330
+ else:
331
+ hidden_states = residual + hidden_states
332
+
333
+ # Fully Connected
334
+ residual = hidden_states
335
+ hidden_states = self.post_attention_layernorm(hidden_states)
336
+
337
+ hidden_states = self.mlp(hidden_states)
338
+ if self.use_mup:
339
+ hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
340
+ else:
341
+ hidden_states = residual + hidden_states
342
+
343
+ return hidden_states, present_key_value
344
+
345
+ def forward_step(
346
+ self,
347
+ hidden_states: torch.Tensor,
348
+ position_emb: Tuple[torch.Tensor, torch.Tensor],
349
+ position_id: torch.Tensor,
350
+ kv_cache: Tuple[torch.Tensor, torch.Tensor],
351
+ ) -> torch.Tensor:
352
+ residual = hidden_states
353
+ hidden_states = self.input_layernorm(hidden_states)
354
+ # Self Attention
355
+ hidden_states = self.self_attn.forward_step(
356
+ hidden_states=hidden_states,
357
+ position_emb=position_emb,
358
+ position_id=position_id,
359
+ kv_cache=kv_cache,
360
+ )
361
+
362
+ if self.use_mup:
363
+ hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
364
+ else:
365
+ hidden_states = residual + hidden_states
366
+
367
+ # Fully Connected
368
+ residual = hidden_states
369
+ hidden_states = self.post_attention_layernorm(hidden_states)
370
+
371
+ hidden_states = self.mlp(hidden_states)
372
+ if self.use_mup:
373
+ hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
374
+ else:
375
+ hidden_states = residual + hidden_states
376
+
377
+ return hidden_states
378
+
379
+
380
+ class MiniCPMModel(nn.Module):
381
+ """
382
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MiniCPMDecoderLayer`]
383
+
384
+ Args:
385
+ config: MiniCPMConfig
386
+ """
387
+
388
+ def __init__(self, config: MiniCPM4Config):
389
+ super().__init__()
390
+ self.vocab_size = config.vocab_size
391
+ self.config = config
392
+
393
+ if config.vocab_size > 0:
394
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
395
+ else:
396
+ self.embed_tokens = nn.Identity()
397
+
398
+ self.layers = nn.ModuleList(
399
+ [MiniCPMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
400
+ )
401
+
402
+ self.norm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
403
+ self.rope_emb = MiniCPMLongRoPE(config)
404
+
405
+ self.kv_cache = None
406
+
407
+ def forward(
408
+ self,
409
+ inputs_embeds: torch.Tensor,
410
+ is_causal: bool = True,
411
+ ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
412
+ """
413
+ Args:
414
+ inputs_embeds: Tensor(batch_size, seq_length, hidden_size)
415
+ is_causal: bool, whether the attention mask is causal
416
+ Returns:
417
+ hidden_states: Tensor(batch_size, seq_length, hidden_size)
418
+ next_decoder_cache: List[(batch_size, num_heads, seq_length, head_dim), (batch_size, num_heads, seq_length, head_dim)]
419
+ """
420
+ position_ids = torch.arange(0, inputs_embeds.size(1), dtype=torch.long, device=inputs_embeds.device)
421
+ position_emb = self.rope_emb(position_ids)
422
+ hidden_states = inputs_embeds
423
+
424
+ next_decoder_cache = []
425
+
426
+ for decoder_layer in self.layers:
427
+
428
+ hidden_states, this_cache = decoder_layer(
429
+ hidden_states,
430
+ position_emb,
431
+ is_causal,
432
+ )
433
+ next_decoder_cache.append(this_cache)
434
+ hidden_states = self.norm(hidden_states)
435
+ return hidden_states, next_decoder_cache
436
+
437
+ def forward_step(
438
+ self,
439
+ inputs_embeds: torch.Tensor,
440
+ position_id: torch.Tensor,
441
+ ) -> torch.Tensor:
442
+ """
443
+ Args:
444
+ inputs_embeds: Tensor(batch_size, hidden_size)
445
+ Returns:
446
+ hidden_states: Tensor(batch_size, hidden_size)
447
+ """
448
+ assert self.kv_cache is not None, "KV cache is not setup"
449
+
450
+ position_emb = self.rope_emb(position_id)
451
+ hidden_states = inputs_embeds
452
+
453
+ for i, decoder_layer in enumerate(self.layers):
454
+ hidden_states = decoder_layer.forward_step(
455
+ hidden_states,
456
+ position_emb,
457
+ position_id,
458
+ self.kv_cache.get_layer_cache(i),
459
+ )
460
+
461
+ hidden_states = self.norm(hidden_states)
462
+ return hidden_states
463
+
464
+ def setup_cache(self, batch_size: int, max_length: int, device, dtype: torch.dtype):
465
+ self.kv_cache = StaticKVCache(
466
+ num_layers=self.config.num_hidden_layers,
467
+ num_kv_heads=self.config.num_key_value_heads,
468
+ dim_kv_head=self.config.hidden_size // self.config.num_attention_heads if self.config.kv_channels is None else self.config.kv_channels,
469
+ batch_size=batch_size,
470
+ device=device,
471
+ dtype=dtype,
472
+ max_length=max_length,
473
+ )
convert/src/voxcpm/utils/text_normalize.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # some functions are copied from https://github.com/FunAudioLLM/CosyVoice/blob/main/cosyvoice/utils/frontend_utils.py
2
+ import re
3
+ import regex
4
+ import inflect
5
+ from functools import partial
6
+ from wetext import Normalizer
7
+
8
+ chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
9
+
10
+ # whether contain chinese character
11
+ def contains_chinese(text):
12
+ return bool(chinese_char_pattern.search(text))
13
+
14
+
15
+ # replace special symbol
16
+ def replace_corner_mark(text):
17
+ text = text.replace('²', '平方')
18
+ text = text.replace('³', '立方')
19
+ text = text.replace('√', '根号')
20
+ text = text.replace('≈', '约等于')
21
+ text = text.replace('<', '小于')
22
+ return text
23
+
24
+
25
+ # remove meaningless symbol
26
+ def remove_bracket(text):
27
+ text = text.replace('(', ' ').replace(')', ' ')
28
+ text = text.replace('【', ' ').replace('】', ' ')
29
+ text = text.replace('`', '').replace('`', '')
30
+ text = text.replace("——", " ")
31
+ return text
32
+
33
+
34
+ # spell Arabic numerals
35
+ def spell_out_number(text: str, inflect_parser):
36
+ new_text = []
37
+ st = None
38
+ for i, c in enumerate(text):
39
+ if not c.isdigit():
40
+ if st is not None:
41
+ num_str = inflect_parser.number_to_words(text[st: i])
42
+ new_text.append(num_str)
43
+ st = None
44
+ new_text.append(c)
45
+ else:
46
+ if st is None:
47
+ st = i
48
+ if st is not None and st < len(text):
49
+ num_str = inflect_parser.number_to_words(text[st:])
50
+ new_text.append(num_str)
51
+ return ''.join(new_text)
52
+
53
+
54
+ # split paragrah logic:
55
+ # 1. per sentence max len token_max_n, min len token_min_n, merge if last sentence len less than merge_len
56
+ # 2. cal sentence len according to lang
57
+ # 3. split sentence according to puncatation
58
+ def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False):
59
+ def calc_utt_length(_text: str):
60
+ if lang == "zh":
61
+ return len(_text)
62
+ else:
63
+ return len(tokenize(_text))
64
+
65
+ def should_merge(_text: str):
66
+ if lang == "zh":
67
+ return len(_text) < merge_len
68
+ else:
69
+ return len(tokenize(_text)) < merge_len
70
+
71
+ if lang == "zh":
72
+ pounc = ['。', '?', '!', ';', ':', '、', '.', '?', '!', ';']
73
+ else:
74
+ pounc = ['.', '?', '!', ';', ':']
75
+ if comma_split:
76
+ pounc.extend([',', ','])
77
+ st = 0
78
+ utts = []
79
+ for i, c in enumerate(text):
80
+ if c in pounc:
81
+ if len(text[st: i]) > 0:
82
+ utts.append(text[st: i] + c)
83
+ if i + 1 < len(text) and text[i + 1] in ['"', '”']:
84
+ tmp = utts.pop(-1)
85
+ utts.append(tmp + text[i + 1])
86
+ st = i + 2
87
+ else:
88
+ st = i + 1
89
+ if len(utts) == 0:
90
+ if lang == "zh":
91
+ utts.append(text + '。')
92
+ else:
93
+ utts.append(text + '.')
94
+ final_utts = []
95
+ cur_utt = ""
96
+ for utt in utts:
97
+ if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
98
+ final_utts.append(cur_utt)
99
+ cur_utt = ""
100
+ cur_utt = cur_utt + utt
101
+ if len(cur_utt) > 0:
102
+ if should_merge(cur_utt) and len(final_utts) != 0:
103
+ final_utts[-1] = final_utts[-1] + cur_utt
104
+ else:
105
+ final_utts.append(cur_utt)
106
+
107
+ return final_utts
108
+
109
+
110
+ # remove blank between chinese character
111
+ def replace_blank(text: str):
112
+ out_str = []
113
+ for i, c in enumerate(text):
114
+ if c == " ":
115
+ if ((text[i + 1].isascii() and text[i + 1] != " ") and
116
+ (text[i - 1].isascii() and text[i - 1] != " ")):
117
+ out_str.append(c)
118
+ else:
119
+ out_str.append(c)
120
+ return "".join(out_str)
121
+
122
+ def clean_markdown(md_text: str) -> str:
123
+ # 去除代码块 ``` ```(包括多行)
124
+ md_text = re.sub(r"```.*?```", "", md_text, flags=re.DOTALL)
125
+
126
+ # 去除内联代码 `code`
127
+ md_text = re.sub(r"`[^`]*`", "", md_text)
128
+
129
+ # 去除图片语法 ![alt](url)
130
+ md_text = re.sub(r"!\[[^\]]*\]\([^\)]+\)", "", md_text)
131
+
132
+ # 去除链接但保留文本 [text](url) -> text
133
+ md_text = re.sub(r"\[([^\]]+)\]\([^)]+\)", r"\1", md_text)
134
+
135
+ # 替换无序列表符号
136
+ md_text = re.sub(r'^(\s*)-\s+', r'\1', md_text, flags=re.MULTILINE)
137
+
138
+ # 去除HTML标签
139
+ md_text = re.sub(r"<[^>]+>", "", md_text)
140
+
141
+ # 去除标题符号(#)
142
+ md_text = re.sub(r"^#{1,6}\s*", "", md_text, flags=re.MULTILINE)
143
+
144
+ # 去除多余空格和空行
145
+ md_text = re.sub(r"\n\s*\n", "\n", md_text) # 多余空行
146
+ md_text = md_text.strip()
147
+
148
+ return md_text
149
+
150
+
151
+ def clean_text(text):
152
+ # 去除 Markdown 语法
153
+ text = clean_markdown(text)
154
+ # 匹配并移除表情符号
155
+ text = regex.compile(r'\p{Emoji_Presentation}|\p{Emoji}\uFE0F', flags=regex.UNICODE).sub("",text)
156
+ # 去除换行符
157
+ text = text.replace("\n", " ")
158
+ text = text.replace("\t", " ")
159
+ text = text.replace('"', "\“")
160
+ return text
161
+
162
+ class TextNormalizer:
163
+ def __init__(self, tokenizer=None):
164
+ self.tokenizer = tokenizer
165
+ self.zh_tn_model = Normalizer(lang="zh", operator="tn", remove_erhua=True)
166
+ self.en_tn_model = Normalizer(lang="en", operator="tn")
167
+ self.inflect_parser = inflect.engine()
168
+
169
+ def normalize(self, text, split=False):
170
+ # 去除 Markdown 语法,去除表情符号,去除换行符
171
+ lang = "zh" if contains_chinese(text) else "en"
172
+ text = clean_text(text)
173
+ if lang == "zh":
174
+ text = text.replace("=", "等于") # 修复 ”550 + 320 等于 870 千卡。“ 被错误正则为 ”五百五十加三百二十等于八七十千卡.“
175
+ if re.search(r'([\d$%^*_+≥≤≠×÷?=])', text): # 避免 英文连字符被错误正则为减
176
+ text = re.sub(r'(?<=[a-zA-Z0-9])-(?=\d)', ' - ', text) # 修复 x-2 被正则为 x负2
177
+ text = self.zh_tn_model.normalize(text)
178
+ text = replace_blank(text)
179
+ text = replace_corner_mark(text)
180
+ text = remove_bracket(text)
181
+ else:
182
+ text = self.en_tn_model.normalize(text)
183
+ text = spell_out_number(text, self.inflect_parser)
184
+ if split is False:
185
+ return text
convert/src/voxcpm/zipenhancer.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ZipEnhancer Module - Audio Denoising Enhancer
3
+
4
+ Provides on-demand import ZipEnhancer functionality for audio denoising processing.
5
+ Related dependencies are imported only when denoising functionality is needed.
6
+ """
7
+
8
+ import os
9
+ import tempfile
10
+ from typing import Optional, Union
11
+ import torchaudio
12
+ import torch
13
+ from modelscope.pipelines import pipeline
14
+ from modelscope.utils.constant import Tasks
15
+
16
+
17
+ class ZipEnhancer:
18
+ """ZipEnhancer Audio Denoising Enhancer"""
19
+ def __init__(self, model_path: str = "iic/speech_zipenhancer_ans_multiloss_16k_base"):
20
+ """
21
+ Initialize ZipEnhancer
22
+ Args:
23
+ model_path: ModelScope model path or local path
24
+ """
25
+ self.model_path = model_path
26
+ self._pipeline = pipeline(
27
+ Tasks.acoustic_noise_suppression,
28
+ model=self.model_path
29
+ )
30
+
31
+ def _normalize_loudness(self, wav_path: str):
32
+ """
33
+ Audio loudness normalization
34
+
35
+ Args:
36
+ wav_path: Audio file path
37
+ """
38
+ audio, sr = torchaudio.load(wav_path)
39
+ loudness = torchaudio.functional.loudness(audio, sr)
40
+ normalized_audio = torchaudio.functional.gain(audio, -20-loudness)
41
+ torchaudio.save(wav_path, normalized_audio, sr)
42
+
43
+ def enhance(self, input_path: str, output_path: Optional[str] = None,
44
+ normalize_loudness: bool = True) -> str:
45
+ """
46
+ Audio denoising enhancement
47
+ Args:
48
+ input_path: Input audio file path
49
+ output_path: Output audio file path (optional, creates temp file by default)
50
+ normalize_loudness: Whether to perform loudness normalization
51
+ Returns:
52
+ str: Output audio file path
53
+ Raises:
54
+ RuntimeError: If pipeline is not initialized or processing fails
55
+ """
56
+ if not os.path.exists(input_path):
57
+ raise FileNotFoundError(f"Input audio file does not exist: {input_path}")
58
+ # Create temporary file if no output path is specified
59
+ if output_path is None:
60
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
61
+ output_path = tmp_file.name
62
+ try:
63
+ # Perform denoising processing
64
+ self._pipeline(input_path, output_path=output_path)
65
+ # Loudness normalization
66
+ if normalize_loudness:
67
+ self._normalize_loudness(output_path)
68
+ return output_path
69
+ except Exception as e:
70
+ # Clean up possibly created temporary files
71
+ if output_path and os.path.exists(output_path):
72
+ try:
73
+ os.unlink(output_path)
74
+ except OSError:
75
+ pass
76
+ raise RuntimeError(f"Audio denoising processing failed: {e}")