Add CFG parallel inference with new library

- README.md +72 -191
- onnx_infer-rknn2.py +26 -48
- onnx_infer.py +14 -22
README.md
CHANGED
@@ -25,7 +25,7 @@ VoxCPM 是一种创新的无分词器文本转语音(TTS)系统,重新定
 
 
 
-- 推理速度(RKNN2):RK3588上RTF约
+- 推理速度(RKNN2):RK3588上RTF约4.5(生成10s音频需要推理45s)
 - 大致内存占用(RKNN2):约3.3GB
 
 ## 使用方法
@@ -35,7 +35,7 @@ VoxCPM 是一种创新的无分词器文本转语音(TTS)系统,重新定
 2. 安装依赖
 
 ```bash
-pip install
+pip install numpy scipy soundfile tqdm transformers sentencepiece ztu-somemodelruntime-ez-rknn-async
 ```
 
 3. 运行
@@ -69,101 +69,42 @@ I rkllm: loading rkllm model from ./residual_lm.rkllm
 I rkllm: rkllm-toolkit version: 1.2.2, max_context_limit: 4096, npu_core_num: 3, target_platform: RK3588, model_dtype: FP16
 I rkllm: Enabled cpus: [4, 5, 6, 7]
 I rkllm: Enabled cpus num: 4
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-I RKNN: [18:58:27.194] RKNN Driver Information, version: 0.9.8
-I RKNN: [18:58:27.194] RKNN Model Information, version: 6, toolkit version: 2.3.2(compiler version: 2.3.2 (e045de294f@2025-04-07T19:48:25)), target: RKNPU v2, target platform: rk3588, framework name: ONNX, framework layout: NCHW, model inference type: static_shape
-W RKNN: [18:58:27.317] query RKNN_QUERY_INPUT_DYNAMIC_RANGE error, rknn model is static shape type, please export rknn with dynamic_shapes
-W Query dynamic range failed. Ret code: RKNN_ERR_MODEL_INVALID. (If it is a static shape RKNN model, please ignore the above warning message.)
-W rknn-toolkit-lite2 version: 2.3.2
-I RKNN: [18:58:27.431] RKNN Runtime Information, librknnrt version: 2.3.2 (429f97ae6b@2025-04-09T09:09:27)
-I RKNN: [18:58:27.431] RKNN Driver Information, version: 0.9.8
-I RKNN: [18:58:27.431] RKNN Model Information, version: 6, toolkit version: 2.3.2(compiler version: 2.3.2 (@2025-04-03T08:26:16)), target: RKNPU v2, target platform: rk3588, framework name: ONNX, framework layout: NCHW, model inference type: dynamic_shape
-W rknn-toolkit-lite2 version: 2.3.2
-I RKNN: [18:58:27.547] RKNN Runtime Information, librknnrt version: 2.3.2 (429f97ae6b@2025-04-09T09:09:27)
-I RKNN: [18:58:27.547] RKNN Driver Information, version: 0.9.8
-I RKNN: [18:58:27.547] RKNN Model Information, version: 6, toolkit version: 2.3.2(compiler version: 2.3.2 (e045de294f@2025-04-07T19:48:25)), target: RKNPU v2, target platform: rk3588, framework name: ONNX, framework layout: NCHW, model inference type: static_shape
-W RKNN: [18:58:27.549] query RKNN_QUERY_INPUT_DYNAMIC_RANGE error, rknn model is static shape type, please export rknn with dynamic_shapes
-W Query dynamic range failed. Ret code: RKNN_ERR_MODEL_INVALID. (If it is a static shape RKNN model, please ignore the above warning message.)
-W rknn-toolkit-lite2 version: 2.3.2
-I RKNN: [18:58:27.728] RKNN Runtime Information, librknnrt version: 2.3.2 (429f97ae6b@2025-04-09T09:09:27)
-I RKNN: [18:58:27.728] RKNN Driver Information, version: 0.9.8
-I RKNN: [18:58:27.728] RKNN Model Information, version: 6, toolkit version: 2.3.2(compiler version: 2.3.2 (e045de294f@2025-04-07T19:48:25)), target: RKNPU v2, target platform: rk3588, framework name: ONNX, framework layout: NCHW, model inference type: static_shape
-W RKNN: [18:58:27.819] query RKNN_QUERY_INPUT_DYNAMIC_RANGE error, rknn model is static shape type, please export rknn with dynamic_shapes
-W Query dynamic range failed. Ret code: RKNN_ERR_MODEL_INVALID. (If it is a static shape RKNN model, please ignore the above warning message.)
-W rknn-toolkit-lite2 version: 2.3.2
-I RKNN: [18:58:27.937] RKNN Runtime Information, librknnrt version: 2.3.2 (429f97ae6b@2025-04-09T09:09:27)
-I RKNN: [18:58:27.937] RKNN Driver Information, version: 0.9.8
-I RKNN: [18:58:27.937] RKNN Model Information, version: 6, toolkit version: 2.3.0(compiler version: 2.3.0 (@2024-11-07T08:11:34)), target: RKNPU v2, target platform: rk3588, framework name: ONNX, framework layout: NCHW, model inference type: static_shape
-W RKNN: [18:58:27.940] query RKNN_QUERY_INPUT_DYNAMIC_RANGE error, rknn model is static shape type, please export rknn with dynamic_shapes
-W Query dynamic range failed. Ret code: RKNN_ERR_MODEL_INVALID. (If it is a static shape RKNN model, please ignore the above warning message.)
-W rknn-toolkit-lite2 version: 2.3.2
-I RKNN: [18:58:28.058] RKNN Runtime Information, librknnrt version: 2.3.2 (429f97ae6b@2025-04-09T09:09:27)
-I RKNN: [18:58:28.058] RKNN Driver Information, version: 0.9.8
-I RKNN: [18:58:28.058] RKNN Model Information, version: 6, toolkit version: 2.3.0(compiler version: 2.3.0 (@2024-11-07T08:11:34)), target: RKNPU v2, target platform: rk3588, framework name: ONNX, framework layout: NCHW, model inference type: static_shape
-W RKNN: [18:58:28.060] query RKNN_QUERY_INPUT_DYNAMIC_RANGE error, rknn model is static shape type, please export rknn with dynamic_shapes
-W Query dynamic range failed. Ret code: RKNN_ERR_MODEL_INVALID. (If it is a static shape RKNN model, please ignore the above warning message.)
-[time] vae_encode_0: 1601.56 ms
-[time] vae_encode_38400: 1605.46 ms
-[time] vae_encode_76800: 1591.07 ms
-W The input[0] need NHWC data format, but NCHW set, the data format and data buffer will be changed to NHWC.
-[time] locenc_0: 819.49 ms
-W The input[0] need NHWC data format, but NCHW set, the data format and data buffer will be changed to NHWC.
-[time] locenc_64: 818.33 ms
-W The input[0] need NHWC data format, but NCHW set, the data format and data buffer will be changed to NHWC.
-[time] locenc_128: 819.09 ms
-[time] base_lm initial: 579.08 ms
-[time] fsq_init_0: 2.54 ms
-[time] fsq_init_64: 1.86 ms
-[time] fsq_init_128: 1.79 ms
-[time] residual_lm initial: 139.10 ms
-gen_loop: 0%| | 0/2000 [00:00<?, ?it/s][time] lm_to_dit: 0.82 ms
-[time] res_to_dit: 0.56 ms
-100%|█████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 33.32it/s]
-W The input[0] need NHWC data format, but NCHW set, the data format and data buffer will be changed to NHWC.4it/s]
-[time] locenc_step: 16.32 ms
-gen_loop: 0%| | 1/2000 [00:00<14:30, 2.30it/s][time] lm_to_dit: 0.57 ms
-[time] res_to_dit: 0.44 ms
-100%|█████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 33.10it/s]
-W The input[0] need NHWC data format, but NCHW set, the data format and data buffer will be changed to NHWC.1it/s]
-[time] locenc_step: 15.84 ms
-gen_loop: 0%| | 2/2000 [00:00<14:27, 2.30it/s][time] lm_to_dit: 0.56 ms
-[time] res_to_dit: 0.50 ms
-100%|█████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 31.93it/s]
+[time] vae_encode_0: 1502.91 ms
+[time] vae_encode_38400: 1443.79 ms
+[time] vae_encode_76800: 1418.36 ms
+[time] locenc_0: 820.25 ms
+[time] locenc_64: 814.78 ms
+[time] locenc_128: 815.60 ms
+[time] base_lm initial: 549.21 ms
+[time] fsq_init_0: 5.34 ms
+[time] fsq_init_64: 3.95 ms
+[time] fsq_init_128: 4.17 ms
+[time] residual_lm initial: 131.22 ms
+gen_loop: 0%| | 0/2000 [00:00<?, ?it/s][time] lm_to_dit: 1.26 ms
+[time] res_to_dit: 1.01 ms
+100%|████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 60.13it/s]
+[time] locenc_step: 16.43 ms████████████████████████████████████▍ | 7/10 [00:00<00:00, 61.20it/s]
+gen_loop: 0%| | 1/2000 [00:00<09:43, 3.42it/s][time] lm_to_dit: 0.75 ms
+[time] res_to_dit: 0.55 ms
+100%|████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 63.99it/s]
+[time] locenc_step: 15.93 ms████████████████████████████████████▍ | 7/10 [00:00<00:00, 67.27it/s]
+gen_loop: 0%| | 2/2000 [00:00<09:25, 3.53it/s][time] lm_to_dit: 0.74 ms
 
 ...
 
-
-
-
-[time]
-
-
-[time] locenc_step: 15.
-gen_loop: 6%|███
-[time] vae_decode_0:
-[time] vae_decode_60:
-[time] vae_decode_120:
-[time] vae_decode_180:
-[time] vae_decode_240:
+[time] res_to_dit: 0.59 ms
+100%|████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 64.19it/s]
+[time] locenc_step: 15.73 ms████████████████████████████████████▍ | 7/10 [00:00<00:00, 67.34it/s]
+gen_loop: 6%|████▎ | 123/2000 [00:34<08:47, 3.56it/s][time] lm_to_dit: 0.76 ms
+[time] res_to_dit: 0.56 ms
+100%|████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 64.08it/s]
+[time] locenc_step: 15.82 ms████████████████████████████████████▍ | 7/10 [00:00<00:00, 67.43it/s]
+gen_loop: 6%|████▎ | 123/2000 [00:34<08:47, 3.56it/s]
+[time] vae_decode_0: 1153.02 ms
+[time] vae_decode_60: 1102.36 ms
+[time] vae_decode_120: 1105.00 ms
+[time] vae_decode_180: 1105.60 ms
+[time] vae_decode_240: 1082.36 ms
 Saved: rknn_output.wav
 ```
 
@@ -176,7 +117,7 @@ Saved: rknn_output.wav
 - 某些情况下语音生成可能陷入死循环,原项目似乎有检测死循环的机制,但我这里没有实现。
 - 由于RKNN工具链的内部问题,locenc模型没有办法在一个模型里配置两种输入长度的两组shape,因此只能单独转换两个模型。
 - 由于RKLLM工具链/运行时的内部问题,两个LLM的输出张量的数值都只有正确结果的四分之一,手动乘4之后可以得到正确结果。
-- 由于RKNN工具链目前不支持非4维输入模型多batch使用多NPU核的数据并行推理,脚本中CFG是分两次单独进行的,速度较慢。
+- ~~由于RKNN工具链目前不支持非4维输入模型多batch使用多NPU核的数据并行推理,脚本中CFG是分两次单独进行的,速度较慢。~~(已修复)
 
 
 ## 参考
@@ -190,7 +131,7 @@ VoxCPM is an innovative tokenizer-free Text-to-Speech (TTS) system that redefine
 
 Unlike mainstream approaches that convert speech into discrete tokens, VoxCPM adopts an end-to-end diffusion autoregressive architecture that directly generates continuous speech representations from text. Built on the MiniCPM-4 backbone, it achieves implicit semantic-acoustic decoupling through hierarchical language modeling and FSQ constraints, greatly enhancing expressiveness and generation stability.
 
-- Inference speed (RKNN2): RTF approximately
+- Inference speed (RKNN2): RTF approximately 4.5 on RK3588 (45s inference time to generate 10s audio)
 - Approximate memory usage (RKNN2): ~3.3GB
 
 ## Usage
@@ -200,7 +141,7 @@ Unlike mainstream approaches that convert speech into discrete tokens, VoxCPM ad
 2. Install dependencies
 
 ```bash
-pip install
+pip install numpy scipy soundfile tqdm transformers sentencepiece ztu-somemodelruntime-ez-rknn-async
 ```
 
 3. Run
@@ -234,104 +175,44 @@ I rkllm: loading rkllm model from ./residual_lm.rkllm
 I rkllm: rkllm-toolkit version: 1.2.2, max_context_limit: 4096, npu_core_num: 3, target_platform: RK3588, model_dtype: FP16
 I rkllm: Enabled cpus: [4, 5, 6, 7]
 I rkllm: Enabled cpus num: 4
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-I RKNN: [18:58:27.194] RKNN Driver Information, version: 0.9.8
-I RKNN: [18:58:27.194] RKNN Model Information, version: 6, toolkit version: 2.3.2(compiler version: 2.3.2 (e045de294f@2025-04-07T19:48:25)), target: RKNPU v2, target platform: rk3588, framework name: ONNX, framework layout: NCHW, model inference type: static_shape
-W RKNN: [18:58:27.317] query RKNN_QUERY_INPUT_DYNAMIC_RANGE error, rknn model is static shape type, please export rknn with dynamic_shapes
-W Query dynamic range failed. Ret code: RKNN_ERR_MODEL_INVALID. (If it is a static shape RKNN model, please ignore the above warning message.)
-W rknn-toolkit-lite2 version: 2.3.2
-I RKNN: [18:58:27.431] RKNN Runtime Information, librknnrt version: 2.3.2 (429f97ae6b@2025-04-09T09:09:27)
-I RKNN: [18:58:27.431] RKNN Driver Information, version: 0.9.8
-I RKNN: [18:58:27.431] RKNN Model Information, version: 6, toolkit version: 2.3.2(compiler version: 2.3.2 (@2025-04-03T08:26:16)), target: RKNPU v2, target platform: rk3588, framework name: ONNX, framework layout: NCHW, model inference type: dynamic_shape
-W rknn-toolkit-lite2 version: 2.3.2
-I RKNN: [18:58:27.547] RKNN Runtime Information, librknnrt version: 2.3.2 (429f97ae6b@2025-04-09T09:09:27)
-I RKNN: [18:58:27.547] RKNN Driver Information, version: 0.9.8
-I RKNN: [18:58:27.547] RKNN Model Information, version: 6, toolkit version: 2.3.2(compiler version: 2.3.2 (e045de294f@2025-04-07T19:48:25)), target: RKNPU v2, target platform: rk3588, framework name: ONNX, framework layout: NCHW, model inference type: static_shape
-W RKNN: [18:58:27.549] query RKNN_QUERY_INPUT_DYNAMIC_RANGE error, rknn model is static shape type, please export rknn with dynamic_shapes
-W Query dynamic range failed. Ret code: RKNN_ERR_MODEL_INVALID. (If it is a static shape RKNN model, please ignore the above warning message.)
-W rknn-toolkit-lite2 version: 2.3.2
-I RKNN: [18:58:27.728] RKNN Runtime Information, librknnrt version: 2.3.2 (429f97ae6b@2025-04-09T09:09:27)
-I RKNN: [18:58:27.728] RKNN Driver Information, version: 0.9.8
-I RKNN: [18:58:27.728] RKNN Model Information, version: 6, toolkit version: 2.3.2(compiler version: 2.3.2 (e045de294f@2025-04-07T19:48:25)), target: RKNPU v2, target platform: rk3588, framework name: ONNX, framework layout: NCHW, model inference type: static_shape
-W RKNN: [18:58:27.819] query RKNN_QUERY_INPUT_DYNAMIC_RANGE error, rknn model is static shape type, please export rknn with dynamic_shapes
-W Query dynamic range failed. Ret code: RKNN_ERR_MODEL_INVALID. (If it is a static shape RKNN model, please ignore the above warning message.)
-W rknn-toolkit-lite2 version: 2.3.2
-I RKNN: [18:58:27.937] RKNN Runtime Information, librknnrt version: 2.3.2 (429f97ae6b@2025-04-09T09:09:27)
-I RKNN: [18:58:27.937] RKNN Driver Information, version: 0.9.8
-I RKNN: [18:58:27.937] RKNN Model Information, version: 6, toolkit version: 2.3.0(compiler version: 2.3.0 (@2024-11-07T08:11:34)), target: RKNPU v2, target platform: rk3588, framework name: ONNX, framework layout: NCHW, model inference type: static_shape
-W RKNN: [18:58:27.940] query RKNN_QUERY_INPUT_DYNAMIC_RANGE error, rknn model is static shape type, please export rknn with dynamic_shapes
-W Query dynamic range failed. Ret code: RKNN_ERR_MODEL_INVALID. (If it is a static shape RKNN model, please ignore the above warning message.)
-W rknn-toolkit-lite2 version: 2.3.2
-I RKNN: [18:58:28.058] RKNN Runtime Information, librknnrt version: 2.3.2 (429f97ae6b@2025-04-09T09:09:27)
-I RKNN: [18:58:28.058] RKNN Driver Information, version: 0.9.8
-I RKNN: [18:58:28.058] RKNN Model Information, version: 6, toolkit version: 2.3.0(compiler version: 2.3.0 (@2024-11-07T08:11:34)), target: RKNPU v2, target platform: rk3588, framework name: ONNX, framework layout: NCHW, model inference type: static_shape
-W RKNN: [18:58:28.060] query RKNN_QUERY_INPUT_DYNAMIC_RANGE error, rknn model is static shape type, please export rknn with dynamic_shapes
-W Query dynamic range failed. Ret code: RKNN_ERR_MODEL_INVALID. (If it is a static shape RKNN model, please ignore the above warning message.)
-[time] vae_encode_0: 1601.56 ms
-[time] vae_encode_38400: 1605.46 ms
-[time] vae_encode_76800: 1591.07 ms
-W The input[0] need NHWC data format, but NCHW set, the data format and data buffer will be changed to NHWC.
-[time] locenc_0: 819.49 ms
-W The input[0] need NHWC data format, but NCHW set, the data format and data buffer will be changed to NHWC.
-[time] locenc_64: 818.33 ms
-W The input[0] need NHWC data format, but NCHW set, the data format and data buffer will be changed to NHWC.
-[time] locenc_128: 819.09 ms
-[time] base_lm initial: 579.08 ms
-[time] fsq_init_0: 2.54 ms
-[time] fsq_init_64: 1.86 ms
-[time] fsq_init_128: 1.79 ms
-[time] residual_lm initial: 139.10 ms
-gen_loop: 0%| | 0/2000 [00:00<?, ?it/s][time] lm_to_dit: 0.82 ms
-[time] res_to_dit: 0.56 ms
-100%|█████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 33.32it/s]
-W The input[0] need NHWC data format, but NCHW set, the data format and data buffer will be changed to NHWC.4it/s]
-[time] locenc_step: 16.32 ms
-gen_loop: 0%| | 1/2000 [00:00<14:30, 2.30it/s][time] lm_to_dit: 0.57 ms
-[time] res_to_dit: 0.44 ms
-100%|█████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 33.10it/s]
-W The input[0] need NHWC data format, but NCHW set, the data format and data buffer will be changed to NHWC.1it/s]
-[time] locenc_step: 15.84 ms
-gen_loop: 0%| | 2/2000 [00:00<14:27, 2.30it/s][time] lm_to_dit: 0.56 ms
-[time] res_to_dit: 0.50 ms
-100%|█████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 31.93it/s]
+[time] vae_encode_0: 1502.91 ms
+[time] vae_encode_38400: 1443.79 ms
+[time] vae_encode_76800: 1418.36 ms
+[time] locenc_0: 820.25 ms
+[time] locenc_64: 814.78 ms
+[time] locenc_128: 815.60 ms
+[time] base_lm initial: 549.21 ms
+[time] fsq_init_0: 5.34 ms
+[time] fsq_init_64: 3.95 ms
+[time] fsq_init_128: 4.17 ms
+[time] residual_lm initial: 131.22 ms
+gen_loop: 0%| | 0/2000 [00:00<?, ?it/s][time] lm_to_dit: 1.26 ms
+[time] res_to_dit: 1.01 ms
+100%|████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 60.13it/s]
+[time] locenc_step: 16.43 ms████████████████████████████████████▍ | 7/10 [00:00<00:00, 61.20it/s]
+gen_loop: 0%| | 1/2000 [00:00<09:43, 3.42it/s][time] lm_to_dit: 0.75 ms
+[time] res_to_dit: 0.55 ms
+100%|████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 63.99it/s]
+[time] locenc_step: 15.93 ms████████████████████████████████████▍ | 7/10 [00:00<00:00, 67.27it/s]
+gen_loop: 0%| | 2/2000 [00:00<09:25, 3.53it/s][time] lm_to_dit: 0.74 ms
 
 ...
 
-
-
-
-[time]
-
-
-[time] locenc_step: 15.
-gen_loop: 6%|███
-[time] vae_decode_0:
-[time] vae_decode_60:
-[time] vae_decode_120:
-[time] vae_decode_180:
-[time] vae_decode_240:
+[time] res_to_dit: 0.59 ms
+100%|████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 64.19it/s]
+[time] locenc_step: 15.73 ms████████████████████████████████████▍ | 7/10 [00:00<00:00, 67.34it/s]
+gen_loop: 6%|████▎ | 123/2000 [00:34<08:47, 3.56it/s][time] lm_to_dit: 0.76 ms
+[time] res_to_dit: 0.56 ms
+100%|████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 64.08it/s]
+[time] locenc_step: 15.82 ms████████████████████████████████████▍ | 7/10 [00:00<00:00, 67.43it/s]
+gen_loop: 6%|████▎ | 123/2000 [00:34<08:47, 3.56it/s]
+[time] vae_decode_0: 1153.02 ms
+[time] vae_decode_60: 1102.36 ms
+[time] vae_decode_120: 1105.00 ms
+[time] vae_decode_180: 1105.60 ms
+[time] vae_decode_240: 1082.36 ms
 Saved: rknn_output.wav
 ```
-
 ## Model Conversion
 
 #### TODO: Documentation to be added
@@ -341,7 +222,7 @@ Saved: rknn_output.wav
 - In some cases, speech generation may fall into an infinite loop. The original project seems to have a mechanism to detect infinite loops, but it is not implemented here.
 - Due to internal issues with the RKNN toolchain, the locenc model cannot configure two sets of shapes for two different input lengths in a single model, so two separate models must be converted.
 - Due to internal issues with the RKLLM toolchain/runtime, the output tensor values of both LLMs are only one-quarter of the correct result. Multiplying by 4 manually yields the correct result.
-- Since the RKNN toolchain currently does not support data-parallel inference using multiple NPU cores for non-4D input models with multiple batches, CFG in the script is performed separately in two passes, which is relatively slow.
+- ~~Since the RKNN toolchain currently does not support data-parallel inference using multiple NPU cores for non-4D input models with multiple batches, CFG in the script is performed separately in two passes, which is relatively slow.~~ (Solved)
 
 ## References
 - [openbmb/VoxCPM-0.5B](https://huggingface.co/openbmb/VoxCPM-0.5B)
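The core of the change is easiest to see outside the diff noise: instead of running the DiT once for the conditional input and once for the unconditional (CFG) input, both samples are now stacked along the batch axis and evaluated in a single call, which the new runtime can spread across NPU cores. The sketch below condenses that pattern; it mirrors the numpy path in the onnx_infer-rknn2.py diff that follows. `run_dit` and `cfg_step_batched` are hypothetical stand-ins for the script's `run_ort(dit_sess, ...)` call, and the final guidance line uses the plain CFG formula, whereas the scripts also support a cfg-zero-star variant not shown here.

```python
import numpy as np

def cfg_step_batched(run_dit, x, mu, cond, t_in, dt_in, cfg_scale):
    """One guidance-scaled DiT evaluation done as a single batched call.

    The conditional and unconditional inputs are stacked along the batch axis,
    evaluated together (so the runtime can dispatch the two samples to
    different NPU cores), then split back apart.
    """
    b = x.shape[0]
    batch = {
        "x": np.concatenate([x, x], axis=0),
        "mu": np.concatenate([mu, np.zeros_like(mu)], axis=0),
        "t": np.concatenate([t_in, t_in], axis=0),
        "cond": np.concatenate([cond, np.zeros_like(cond)], axis=0),
        "dt": np.concatenate([dt_in, np.zeros_like(dt_in)], axis=0),
    }
    dphi_dt = run_dit(batch)                   # shape: (2 * b, ...)
    pos, neg = np.split(dphi_dt, [b], axis=0)  # conditional / unconditional halves
    # Plain classifier-free guidance combination (illustrative only).
    return neg + cfg_scale * (pos - neg)
```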
onnx_infer-rknn2.py
CHANGED
@@ -12,7 +12,7 @@ from rkllm_binding import *
 
 from transformers import AutoTokenizer
 
-import
+import ztu_somemodelruntime_ez_rknn_async as ort
 
 def mask_multichar_chinese_tokens(tokenizer):
     # Pre-compute multi-character tokens (length >= 2, pure Chinese characters)
@@ -97,12 +97,6 @@ def mask_multichar_chinese_tokens(tokenizer):
 
     return CharTokenizerWrapper(tokenizer)
 
-
-def load_rknn(path: str, providers):
-    if not os.path.exists(path):
-        raise FileNotFoundError(f"ONNX file not found: {path}")
-    return ort.InferenceSession(path, providers=providers)
-
 def ensure_numpy(arr, dtype=None):
     np_arr = np.asarray(arr)
     if dtype is not None:
@@ -110,10 +104,10 @@ def ensure_numpy(arr, dtype=None):
     return np_arr
 
 
-def run_ort(session: ort.InferenceSession, inputs: dict, name: str = None):
+def run_ort(session: ort.InferenceSession, inputs: dict, name: str = None, run_options=None):
     start = time.perf_counter()
     ort_inputs = {k: ensure_numpy(v) for k, v in inputs.items()}
-    outputs = session.run(None, ort_inputs)  # noqa: SLF001
+    outputs = session.run(None, ort_inputs, run_options=run_options)  # noqa: SLF001
     if name:
         elapsed_ms = (time.perf_counter() - start) * 1000
         print(f"[time] {name}: {elapsed_ms:.2f} ms")
@@ -249,35 +243,27 @@ def cfm_euler_with_onnx_step(
         t_in = np.full((b,), t, dtype=dtype)
         dt_in = np.full((b,), dt if mean_mode else 0.0, dtype=dtype)
 
-
-
-
-
-
-                "x": x,
-                "mu": mu,
-                "t": t_in,
-                "cond": cond,
-                "dt": dt_in,
-            },
-        ),
-            dtype=dtype,
-        )
+        x_batch = np.concatenate([x, x], axis=0)
+        mu_batch = np.concatenate([mu, np.zeros_like(mu)], axis=0)
+        t_batch = np.concatenate([t_in, t_in], axis=0)
+        cond_batch = np.concatenate([cond, np.zeros_like(cond)], axis=0)
+        dt_batch = np.concatenate([dt_in, np.zeros_like(dt_in)], axis=0)
 
-
-        dphi_dt_neg = np.asarray(
+        dphi_dt_batch = np.asarray(
             run_ort(
                 dit_sess,
                 {
-                    "x":
-                    "mu":
-                    "t":
-                    "cond":
-                    "dt":
+                    "x": x_batch,
+                    "mu": mu_batch,
+                    "t": t_batch,
+                    "cond": cond_batch,
+                    "dt": dt_batch,
                 },
+                run_options={"ztu_modelrt_dispatch_batch": True}
            ),
            dtype=dtype,
        )
+        dphi_dt_pos, dphi_dt_neg = np.split(dphi_dt_batch, [b], axis=0)
 
         if use_cfg_zero_star:
             positive_flat = dphi_dt_pos.reshape(b, -1)
@@ -361,16 +347,8 @@ def main():
     parser.add_argument("--min-len", type=int, default=2, help="Minimum generated patch count before stop allowed.")
     parser.add_argument("--max-len", type=int, default=2000, help="Maximum generated patch count.")
     parser.add_argument("--seed", type=int, default=None, help="Random seed for reproducibility.")
-    parser.add_argument(
-        "--providers",
-        nargs="+",
-        default=None,
-        help="ONNX Runtime providers (e.g., CUDAExecutionProvider CPUExecutionProvider).",
-    )
-    args = parser.parse_args()
-
-    providers = args.providers or ["CUDAExecutionProvider", "CPUExecutionProvider"]
 
+    args = parser.parse_args()
     # Seed
     if args.seed is not None:
         random.seed(args.seed)
@@ -401,15 +379,15 @@ def main():
     fixed_seq_len = 64  # target platform prefers fixed seq len for locenc/fsq
 
     # Load ONNX sessions
-    vae_encode_sess =
-    vae_decode_sess =
-    locenc_64_sess =
-    locenc_1_sess =
-    fsq_sess =
-    stop_sess =
-    dit_step_sess =
-    lm_to_dit_sess =
-    res_to_dit_sess =
+    vae_encode_sess = ort.InferenceSession(os.path.join(args.onnx_dir, "audio_vae_encode.rknn"))
+    vae_decode_sess = ort.InferenceSession(os.path.join(args.onnx_dir, "audio_vae_decode.rknn"))
+    locenc_64_sess = ort.InferenceSession(os.path.join(args.onnx_dir, "locenc_64.rknn"))
+    locenc_1_sess = ort.InferenceSession(os.path.join(args.onnx_dir, "locenc_1.rknn"))
+    fsq_sess = ort.InferenceSession(os.path.join(args.onnx_dir, "fsq_layer.rknn"))
+    stop_sess = ort.InferenceSession(os.path.join(args.onnx_dir, "stop_head.rknn"))
+    dit_step_sess = ort.InferenceSession(os.path.join(args.onnx_dir, "dit_step.rknn"), provider_options=[{"schedule": [0,1]}])
+    lm_to_dit_sess = ort.InferenceSession(os.path.join(args.onnx_dir, "lm_to_dit_proj.rknn"))
+    res_to_dit_sess = ort.InferenceSession(os.path.join(args.onnx_dir, "res_to_dit_proj.rknn"))
 
     # Build text/audio features
     if args.prompt_audio:
onnx_infer.py
CHANGED
@@ -153,33 +153,25 @@ def cfm_euler_with_onnx_step(
         if not mean_mode:
             dt_in = torch.zeros_like(dt_in)
 
-
-
-
-
-
-                "mu": mu,
-                "t": t_in,
-                "cond": cond,
-                "dt": dt_in,
-            },
-            name=f"dit_step_pos_{step}",
-        )
-        dphi_dt_pos = torch.from_numpy(dphi_dt_pos).to(device=device, dtype=dtype)
+        x_batch = torch.cat([x, x], dim=0)
+        mu_batch = torch.cat([mu, torch.zeros_like(mu)], dim=0)
+        t_batch = torch.cat([t_in, t_in], dim=0)
+        cond_batch = torch.cat([cond, torch.zeros_like(cond)], dim=0)
+        dt_batch = torch.cat([dt_in, torch.zeros_like(dt_in)], dim=0)
 
-
-        dphi_dt_neg = run_ort(
+        dphi_dt_batch = run_ort(
             dit_sess,
             {
-                "x":
-                "mu":
-                "t":
-                "cond":
-                "dt":
+                "x": x_batch,
+                "mu": mu_batch,
+                "t": t_batch,
+                "cond": cond_batch,
+                "dt": dt_batch,
             },
-            name=f"
+            name=f"dit_step_b2_{step}",
         )
-
+        dphi_dt_batch = torch.from_numpy(dphi_dt_batch).to(device=device, dtype=dtype)
+        dphi_dt_pos, dphi_dt_neg = torch.split(dphi_dt_batch, [b, b], dim=0)
 
         if use_cfg_zero_star:
             positive_flat = dphi_dt_pos.view(b, -1)