happyme531 committed
Commit 621e4aa · verified · 1 parent: 5e33de2

Upload 34 files

.gitattributes CHANGED
@@ -33,3 +33,18 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ audio_vae_decode.rknn filter=lfs diff=lfs merge=lfs -text
37
+ audio_vae_encode.rknn filter=lfs diff=lfs merge=lfs -text
38
+ base_lm.rkllm filter=lfs diff=lfs merge=lfs -text
39
+ basic_ref_zh.wav filter=lfs diff=lfs merge=lfs -text
40
+ dit_step.rknn filter=lfs diff=lfs merge=lfs -text
41
+ fsq_layer.rknn filter=lfs diff=lfs merge=lfs -text
42
+ librkllmrt.so filter=lfs diff=lfs merge=lfs -text
43
+ lm_to_dit_proj.rknn filter=lfs diff=lfs merge=lfs -text
44
+ locenc_1.rknn filter=lfs diff=lfs merge=lfs -text
45
+ locenc_64.rknn filter=lfs diff=lfs merge=lfs -text
46
+ model_structure.jpg filter=lfs diff=lfs merge=lfs -text
47
+ res_to_dit_proj.rknn filter=lfs diff=lfs merge=lfs -text
48
+ residual_lm.rkllm filter=lfs diff=lfs merge=lfs -text
49
+ rknn_output.wav filter=lfs diff=lfs merge=lfs -text
50
+ stop_head.rknn filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,349 @@
1
- ---
2
- license: agpl-3.0
3
- ---
1
+ ---
2
+ license: agpl-3.0
3
+ language:
4
+ - en
5
+ - zh
6
+ base_model:
7
+ - openbmb/VoxCPM-0.5B
8
+ pipeline_tag: text-to-speech
9
+ tags:
10
+ - rknn
11
+ - rkllm
12
+ - text-to-speech
13
+ - speech
14
+ - speech generation
15
+ - voice cloning
16
+ ---
17
+
18
+ # VoxCPM-0.5B-RKNN2
19
+
20
+ ### (For the English README, see below)
21
+
22
+ VoxCPM 是一种创新的无分词器文本转语音(TTS)系统,重新定义了语音合成的真实感。通过在连续空间中建模语音,它克服了离散标记化的局限,并实现了两项核心能力:上下文感知的语音生成和逼真的零样本语音克隆。
23
+ 不同于将语音转换为离散标记的主流方法,VoxCPM 采用端到端的扩散自回归架构,直接从文本生成连续的语音表示。它基于 MiniCPM-4 主干构建,通过分层语言建模和 FSQ 约束实现了隐式的语义-声学解耦,极大地提升了表现力和生成稳定性。
24
+
25
+ ![模型架构](model_structure.jpg)
26
+
27
+
28
+ - 推理速度(RKNN2):RK3588上RTF约8(生成10s音频需要推理80s)
29
+ - 大致内存占用(RKNN2):约3.3GB
30
+
31
+ ## 使用方法
32
+
33
+ 1. 克隆项目到本地
34
+
35
+ 2. 安装依赖
36
+
37
+ ```bash
38
+ pip install "numpy<2" scipy soundfile tqdm transformers sentencepiece rknn-toolkit-lite2
39
+ ```
40
+
41
+ 3. 运行
42
+
43
+ ```bash
44
+ python onnx_infer-rknn2.py --onnx-dir . --tokenizer-dir . --base-hf-dir . --residual-hf-dir . --text "哇, 这个模型居然在RK3588这个辣鸡SoC上也能完美运行!" --prompt-audio basic_ref_zh.wav --prompt-text "对,这就是我,万人敬仰的太乙真人。" --output rknn_output.wav --cfg-value 2.0 --inference-timesteps 10 --seed 1234
45
+ ```
46
+
47
+ 可选参数:
48
+ - `--text`: 要生成的文本
49
+ - `--prompt-audio`: 参考音频路径(用于语音克隆)
50
+ - `--prompt-text`: 参考音频对应的文本(使用参考音频时必填)
51
+ - `--cfg-value`: CFG引导强度,默认2.0
52
+ - `--inference-timesteps`: 扩散步数,默认10
53
+ - `--seed`: 随机种子
54
+ - `--output`: 输出音频路径
55
+
56
+ ## 运行效果
57
+
58
+
59
+ ```log
60
+ > python onnx_infer-rknn2.py --onnx-dir . --tokenizer-dir . --base-hf-dir . --residual-hf-dir . --text "哇, 这个模型居然在RK3588这个辣鸡SoC上也能完美运行!" --prompt-audio basic_ref_zh.wav --prompt-text "对,这就是我,万人敬仰的太乙真人。" --output rknn_output.wav --cfg-value 2.0 --inference-timesteps 10 --seed 1234
61
+
62
+ I rkllm: rkllm-runtime version: 1.2.3, rknpu driver version: 0.9.8, platform: RK3588
63
+ I rkllm: loading rkllm model from ./base_lm.rkllm
64
+ I rkllm: rkllm-toolkit version: 1.2.3, max_context_limit: 4096, npu_core_num: 1, target_platform: RK3588, model_dtype: FP16
65
+ I rkllm: Enabled cpus: [4, 5, 6, 7]
66
+ I rkllm: Enabled cpus num: 4
67
+ I rkllm: rkllm-runtime version: 1.2.3, rknpu driver version: 0.9.8, platform: RK3588
68
+ I rkllm: loading rkllm model from ./residual_lm.rkllm
69
+ I rkllm: rkllm-toolkit version: 1.2.2, max_context_limit: 4096, npu_core_num: 3, target_platform: RK3588, model_dtype: FP16
70
+ I rkllm: Enabled cpus: [4, 5, 6, 7]
71
+ I rkllm: Enabled cpus num: 4
72
+ W rknn-toolkit-lite2 version: 2.3.2
73
+ I RKNN: [18:58:26.264] RKNN Runtime Information, librknnrt version: 2.3.2 (429f97ae6b@2025-04-09T09:09:27)
74
+ I RKNN: [18:58:26.264] RKNN Driver Information, version: 0.9.8
75
+ I RKNN: [18:58:26.265] RKNN Model Information, version: 6, toolkit version: 2.3.0(compiler version: 2.3.0 (@2024-11-07T08:11:34)), target: RKNPU v2, target platform: rk3588, framework name: ONNX, framework layout: NCHW, model inference type: static_shape
76
+ W RKNN: [18:58:26.404] query RKNN_QUERY_INPUT_DYNAMIC_RANGE error, rknn model is static shape type, please export rknn with dynamic_shapes
77
+ W Query dynamic range failed. Ret code: RKNN_ERR_MODEL_INVALID. (If it is a static shape RKNN model, please ignore the above warning message.)
78
+ W rknn-toolkit-lite2 version: 2.3.2
79
+ I RKNN: [18:58:26.537] RKNN Runtime Information, librknnrt version: 2.3.2 (429f97ae6b@2025-04-09T09:09:27)
80
+ I RKNN: [18:58:26.537] RKNN Driver Information, version: 0.9.8
81
+ I RKNN: [18:58:26.537] RKNN Model Information, version: 6, toolkit version: 2.3.0(compiler version: 2.3.0 (@2024-11-07T08:11:34)), target: RKNPU v2, target platform: rk3588, framework name: ONNX, framework layout: NCHW, model inference type: static_shape
82
+ W RKNN: [18:58:26.616] query RKNN_QUERY_INPUT_DYNAMIC_RANGE error, rknn model is static shape type, please export rknn with dynamic_shapes
83
+ W Query dynamic range failed. Ret code: RKNN_ERR_MODEL_INVALID. (If it is a static shape RKNN model, please ignore the above warning message.)
84
+ W rknn-toolkit-lite2 version: 2.3.2
85
+ I RKNN: [18:58:26.795] RKNN Runtime Information, librknnrt version: 2.3.2 (429f97ae6b@2025-04-09T09:09:27)
86
+ I RKNN: [18:58:26.795] RKNN Driver Information, version: 0.9.8
87
+ I RKNN: [18:58:26.795] RKNN Model Information, version: 6, toolkit version: 2.3.2(compiler version: 2.3.2 (e045de294f@2025-04-07T19:48:25)), target: RKNPU v2, target platform: rk3588, framework name: ONNX, framework layout: NCHW, model inference type: static_shape
88
+ W RKNN: [18:58:27.020] query RKNN_QUERY_INPUT_DYNAMIC_RANGE error, rknn model is static shape type, please export rknn with dynamic_shapes
89
+ W Query dynamic range failed. Ret code: RKNN_ERR_MODEL_INVALID. (If it is a static shape RKNN model, please ignore the above warning message.)
90
+ W rknn-toolkit-lite2 version: 2.3.2
91
+ I RKNN: [18:58:27.194] RKNN Runtime Information, librknnrt version: 2.3.2 (429f97ae6b@2025-04-09T09:09:27)
92
+ I RKNN: [18:58:27.194] RKNN Driver Information, version: 0.9.8
93
+ I RKNN: [18:58:27.194] RKNN Model Information, version: 6, toolkit version: 2.3.2(compiler version: 2.3.2 (e045de294f@2025-04-07T19:48:25)), target: RKNPU v2, target platform: rk3588, framework name: ONNX, framework layout: NCHW, model inference type: static_shape
94
+ W RKNN: [18:58:27.317] query RKNN_QUERY_INPUT_DYNAMIC_RANGE error, rknn model is static shape type, please export rknn with dynamic_shapes
95
+ W Query dynamic range failed. Ret code: RKNN_ERR_MODEL_INVALID. (If it is a static shape RKNN model, please ignore the above warning message.)
96
+ W rknn-toolkit-lite2 version: 2.3.2
97
+ I RKNN: [18:58:27.431] RKNN Runtime Information, librknnrt version: 2.3.2 (429f97ae6b@2025-04-09T09:09:27)
98
+ I RKNN: [18:58:27.431] RKNN Driver Information, version: 0.9.8
99
+ I RKNN: [18:58:27.431] RKNN Model Information, version: 6, toolkit version: 2.3.2(compiler version: 2.3.2 (@2025-04-03T08:26:16)), target: RKNPU v2, target platform: rk3588, framework name: ONNX, framework layout: NCHW, model inference type: dynamic_shape
100
+ W rknn-toolkit-lite2 version: 2.3.2
101
+ I RKNN: [18:58:27.547] RKNN Runtime Information, librknnrt version: 2.3.2 (429f97ae6b@2025-04-09T09:09:27)
102
+ I RKNN: [18:58:27.547] RKNN Driver Information, version: 0.9.8
103
+ I RKNN: [18:58:27.547] RKNN Model Information, version: 6, toolkit version: 2.3.2(compiler version: 2.3.2 (e045de294f@2025-04-07T19:48:25)), target: RKNPU v2, target platform: rk3588, framework name: ONNX, framework layout: NCHW, model inference type: static_shape
104
+ W RKNN: [18:58:27.549] query RKNN_QUERY_INPUT_DYNAMIC_RANGE error, rknn model is static shape type, please export rknn with dynamic_shapes
105
+ W Query dynamic range failed. Ret code: RKNN_ERR_MODEL_INVALID. (If it is a static shape RKNN model, please ignore the above warning message.)
106
+ W rknn-toolkit-lite2 version: 2.3.2
107
+ I RKNN: [18:58:27.728] RKNN Runtime Information, librknnrt version: 2.3.2 (429f97ae6b@2025-04-09T09:09:27)
108
+ I RKNN: [18:58:27.728] RKNN Driver Information, version: 0.9.8
109
+ I RKNN: [18:58:27.728] RKNN Model Information, version: 6, toolkit version: 2.3.2(compiler version: 2.3.2 (e045de294f@2025-04-07T19:48:25)), target: RKNPU v2, target platform: rk3588, framework name: ONNX, framework layout: NCHW, model inference type: static_shape
110
+ W RKNN: [18:58:27.819] query RKNN_QUERY_INPUT_DYNAMIC_RANGE error, rknn model is static shape type, please export rknn with dynamic_shapes
111
+ W Query dynamic range failed. Ret code: RKNN_ERR_MODEL_INVALID. (If it is a static shape RKNN model, please ignore the above warning message.)
112
+ W rknn-toolkit-lite2 version: 2.3.2
113
+ I RKNN: [18:58:27.937] RKNN Runtime Information, librknnrt version: 2.3.2 (429f97ae6b@2025-04-09T09:09:27)
114
+ I RKNN: [18:58:27.937] RKNN Driver Information, version: 0.9.8
115
+ I RKNN: [18:58:27.937] RKNN Model Information, version: 6, toolkit version: 2.3.0(compiler version: 2.3.0 (@2024-11-07T08:11:34)), target: RKNPU v2, target platform: rk3588, framework name: ONNX, framework layout: NCHW, model inference type: static_shape
116
+ W RKNN: [18:58:27.940] query RKNN_QUERY_INPUT_DYNAMIC_RANGE error, rknn model is static shape type, please export rknn with dynamic_shapes
117
+ W Query dynamic range failed. Ret code: RKNN_ERR_MODEL_INVALID. (If it is a static shape RKNN model, please ignore the above warning message.)
118
+ W rknn-toolkit-lite2 version: 2.3.2
119
+ I RKNN: [18:58:28.058] RKNN Runtime Information, librknnrt version: 2.3.2 (429f97ae6b@2025-04-09T09:09:27)
120
+ I RKNN: [18:58:28.058] RKNN Driver Information, version: 0.9.8
121
+ I RKNN: [18:58:28.058] RKNN Model Information, version: 6, toolkit version: 2.3.0(compiler version: 2.3.0 (@2024-11-07T08:11:34)), target: RKNPU v2, target platform: rk3588, framework name: ONNX, framework layout: NCHW, model inference type: static_shape
122
+ W RKNN: [18:58:28.060] query RKNN_QUERY_INPUT_DYNAMIC_RANGE error, rknn model is static shape type, please export rknn with dynamic_shapes
123
+ W Query dynamic range failed. Ret code: RKNN_ERR_MODEL_INVALID. (If it is a static shape RKNN model, please ignore the above warning message.)
124
+ [time] vae_encode_0: 1601.56 ms
125
+ [time] vae_encode_38400: 1605.46 ms
126
+ [time] vae_encode_76800: 1591.07 ms
127
+ W The input[0] need NHWC data format, but NCHW set, the data format and data buffer will be changed to NHWC.
128
+ [time] locenc_0: 819.49 ms
129
+ W The input[0] need NHWC data format, but NCHW set, the data format and data buffer will be changed to NHWC.
130
+ [time] locenc_64: 818.33 ms
131
+ W The input[0] need NHWC data format, but NCHW set, the data format and data buffer will be changed to NHWC.
132
+ [time] locenc_128: 819.09 ms
133
+ [time] base_lm initial: 579.08 ms
134
+ [time] fsq_init_0: 2.54 ms
135
+ [time] fsq_init_64: 1.86 ms
136
+ [time] fsq_init_128: 1.79 ms
137
+ [time] residual_lm initial: 139.10 ms
138
+ gen_loop: 0%| | 0/2000 [00:00<?, ?it/s][time] lm_to_dit: 0.82 ms
139
+ [time] res_to_dit: 0.56 ms
140
+ 100%|█████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 33.32it/s]
141
+ W The input[0] need NHWC data format, but NCHW set, the data format and data buffer will be changed to NHWC.4it/s]
142
+ [time] locenc_step: 16.32 ms
143
+ gen_loop: 0%| | 1/2000 [00:00<14:30, 2.30it/s][time] lm_to_dit: 0.57 ms
144
+ [time] res_to_dit: 0.44 ms
145
+ 100%|█████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 33.10it/s]
146
+ W The input[0] need NHWC data format, but NCHW set, the data format and data buffer will be changed to NHWC.1it/s]
147
+ [time] locenc_step: 15.84 ms
148
+ gen_loop: 0%| | 2/2000 [00:00<14:27, 2.30it/s][time] lm_to_dit: 0.56 ms
149
+ [time] res_to_dit: 0.50 ms
150
+ 100%|█████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 31.93it/s]
151
+
152
+ ...
153
+
154
+ W The input[0] need NHWC data format, but NCHW set, the data format and data buffer will be changed to NHWC.5it/s]
155
+ [time] locenc_step: 15.88 ms
156
+ gen_loop: 6%|███▉ | 123/2000 [00:53<13:35, 2.30it/s][time] lm_to_dit: 0.57 ms
157
+ [time] res_to_dit: 0.49 ms
158
+ 100%|█████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 32.94it/s]
159
+ W The input[0] need NHWC data format, but NCHW set, the data format and data buffer will be changed to NHWC.6it/s]
160
+ [time] locenc_step: 15.84 ms
161
+ gen_loop: 6%|███▉ | 123/2000 [00:54<13:44, 2.28it/s]
162
+ [time] vae_decode_0: 1044.00 ms
163
+ [time] vae_decode_60: 1018.03 ms
164
+ [time] vae_decode_120: 1020.72 ms
165
+ [time] vae_decode_180: 1021.19 ms
166
+ [time] vae_decode_240: 1006.85 ms
167
+ Saved: rknn_output.wav
168
+ ```
169
+
170
+ ## 模型转换
171
+
172
+ #### 懒得写了,待补充
173
+
174
+ ## 已知问题
175
+
176
+ - 某些情况下语音生成可能陷入死循环,原项目似乎有检测死循环的机制,但我这里没有实现。
177
+ - 由于RKNN工具链的内部问题,locenc模型没有办法在一个模型里配置两种输入长度的两组shape,因此只能单独转换两个模型。
178
+ - 由于RKLLM工具链/运行时的内部问题,两个LLM的输出张量的数值都只有正确结果的四分之一,手动乘4之后可以得到正确结果。
179
+ - 由于RKNN工具链目前不支持非4维输入模型多batch使用多NPU核的数据并行推理,脚本中CFG是分两次单独进行的,速度较慢。
180
+
181
+
182
+ ## 参考
183
+ - [openbmb/VoxCPM-0.5B](https://huggingface.co/openbmb/VoxCPM-0.5B)
184
+ - [0seba/VoxCPMANE](https://github.com/0seba/VoxCPMANE)
185
+ - [bluryar/VoxCPM-ONNX](https://github.com/bluryar/VoxCPM-ONNX)
186
+
187
+ # English README
188
+
189
+ VoxCPM is an innovative tokenizer-free Text-to-Speech (TTS) system that redefines realism in speech synthesis. By modeling speech in continuous space, it overcomes the limitations of discrete tokenization and achieves two core capabilities: context-aware speech generation and realistic zero-shot voice cloning.
190
+
191
+ Unlike mainstream approaches that convert speech into discrete tokens, VoxCPM adopts an end-to-end diffusion autoregressive architecture that directly generates continuous speech representations from text. Built on the MiniCPM-4 backbone, it achieves implicit semantic-acoustic decoupling through hierarchical language modeling and FSQ constraints, greatly enhancing expressiveness and generation stability.
192
+
193
+ - Inference speed (RKNN2): RTF approximately 8 on RK3588 (80s inference time to generate 10s audio)
194
+ - Approximate memory usage (RKNN2): ~3.3GB
195
+
196
+ ## Usage
197
+
198
+ 1. Clone the project locally
199
+
200
+ 2. Install dependencies
201
+
202
+ ```bash
203
+ pip install "numpy<2" scipy soundfile tqdm transformers sentencepiece rknn-toolkit-lite2
204
+ ```
205
+
206
+ 3. Run
207
+
208
+ ```bash
209
+ python onnx_infer-rknn2.py --onnx-dir . --tokenizer-dir . --base-hf-dir . --residual-hf-dir . --text "Wow, this model actually runs perfectly on the RK3588 SoC!" --prompt-audio basic_ref_zh.wav --prompt-text "对,这就是我,万人敬仰的太乙真人。" --output rknn_output.wav --cfg-value 2.0 --inference-timesteps 10 --seed 1234
210
+ ```
211
+
212
+ Optional parameters (a scripted usage sketch follows this list):
213
+ - `--text`: Text to generate
214
+ - `--prompt-audio`: Reference audio path (for voice cloning)
215
+ - `--prompt-text`: Text corresponding to the reference audio (required when using reference audio)
216
+ - `--cfg-value`: CFG guidance strength, default 2.0
217
+ - `--inference-timesteps`: Number of diffusion steps, default 10
218
+ - `--seed`: Random seed
219
+ - `--output`: Output audio path
220
+
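+ For scripted or batch use, the same flags can be driven from Python. The sketch below is illustrative only: it wraps the documented CLI with `subprocess` and assumes `--prompt-audio`/`--prompt-text` can be omitted for plain (non-cloning) synthesis, as the parameter list above implies.
+ 
+ ```python
+ import subprocess
+ 
+ # Hypothetical batch driver around the documented onnx_infer-rknn2.py flags.
+ texts = [
+     "Hello from the RK3588 NPU!",
+     "This is a second test sentence.",
+ ]
+ for i, text in enumerate(texts):
+     subprocess.run(
+         [
+             "python", "onnx_infer-rknn2.py",
+             "--onnx-dir", ".", "--tokenizer-dir", ".",
+             "--base-hf-dir", ".", "--residual-hf-dir", ".",
+             "--text", text,
+             "--output", f"out_{i}.wav",
+             "--cfg-value", "2.0",
+             "--inference-timesteps", "10",
+             "--seed", "1234",
+         ],
+         check=True,  # stop on the first failed run
+     )
+ ```
+ 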
221
+ ## Performance
222
+
223
+
224
+ ```log
225
+ > python onnx_infer-rknn2.py --onnx-dir . --tokenizer-dir . --base-hf-dir . --residual-hf-dir . --text "哇, 这个模型居然在RK3588这个辣鸡SoC上也能完美运行!" --prompt-audio basic_ref_zh.wav --prompt-text "对,这就是我,万人敬仰的太乙真人。" --output rknn_output.wav --cfg-value 2.0 --inference-timesteps 10 --seed 1234
226
+
227
+ I rkllm: rkllm-runtime version: 1.2.3, rknpu driver version: 0.9.8, platform: RK3588
228
+ I rkllm: loading rkllm model from ./base_lm.rkllm
229
+ I rkllm: rkllm-toolkit version: 1.2.3, max_context_limit: 4096, npu_core_num: 1, target_platform: RK3588, model_dtype: FP16
230
+ I rkllm: Enabled cpus: [4, 5, 6, 7]
231
+ I rkllm: Enabled cpus num: 4
232
+ I rkllm: rkllm-runtime version: 1.2.3, rknpu driver version: 0.9.8, platform: RK3588
233
+ I rkllm: loading rkllm model from ./residual_lm.rkllm
234
+ I rkllm: rkllm-toolkit version: 1.2.2, max_context_limit: 4096, npu_core_num: 3, target_platform: RK3588, model_dtype: FP16
235
+ I rkllm: Enabled cpus: [4, 5, 6, 7]
236
+ I rkllm: Enabled cpus num: 4
237
+ W rknn-toolkit-lite2 version: 2.3.2
238
+ I RKNN: [18:58:26.264] RKNN Runtime Information, librknnrt version: 2.3.2 (429f97ae6b@2025-04-09T09:09:27)
239
+ I RKNN: [18:58:26.264] RKNN Driver Information, version: 0.9.8
240
+ I RKNN: [18:58:26.265] RKNN Model Information, version: 6, toolkit version: 2.3.0(compiler version: 2.3.0 (@2024-11-07T08:11:34)), target: RKNPU v2, target platform: rk3588, framework name: ONNX, framework layout: NCHW, model inference type: static_shape
241
+ W RKNN: [18:58:26.404] query RKNN_QUERY_INPUT_DYNAMIC_RANGE error, rknn model is static shape type, please export rknn with dynamic_shapes
242
+ W Query dynamic range failed. Ret code: RKNN_ERR_MODEL_INVALID. (If it is a static shape RKNN model, please ignore the above warning message.)
243
+ W rknn-toolkit-lite2 version: 2.3.2
244
+ I RKNN: [18:58:26.537] RKNN Runtime Information, librknnrt version: 2.3.2 (429f97ae6b@2025-04-09T09:09:27)
245
+ I RKNN: [18:58:26.537] RKNN Driver Information, version: 0.9.8
246
+ I RKNN: [18:58:26.537] RKNN Model Information, version: 6, toolkit version: 2.3.0(compiler version: 2.3.0 (@2024-11-07T08:11:34)), target: RKNPU v2, target platform: rk3588, framework name: ONNX, framework layout: NCHW, model inference type: static_shape
247
+ W RKNN: [18:58:26.616] query RKNN_QUERY_INPUT_DYNAMIC_RANGE error, rknn model is static shape type, please export rknn with dynamic_shapes
248
+ W Query dynamic range failed. Ret code: RKNN_ERR_MODEL_INVALID. (If it is a static shape RKNN model, please ignore the above warning message.)
249
+ W rknn-toolkit-lite2 version: 2.3.2
250
+ I RKNN: [18:58:26.795] RKNN Runtime Information, librknnrt version: 2.3.2 (429f97ae6b@2025-04-09T09:09:27)
251
+ I RKNN: [18:58:26.795] RKNN Driver Information, version: 0.9.8
252
+ I RKNN: [18:58:26.795] RKNN Model Information, version: 6, toolkit version: 2.3.2(compiler version: 2.3.2 (e045de294f@2025-04-07T19:48:25)), target: RKNPU v2, target platform: rk3588, framework name: ONNX, framework layout: NCHW, model inference type: static_shape
253
+ W RKNN: [18:58:27.020] query RKNN_QUERY_INPUT_DYNAMIC_RANGE error, rknn model is static shape type, please export rknn with dynamic_shapes
254
+ W Query dynamic range failed. Ret code: RKNN_ERR_MODEL_INVALID. (If it is a static shape RKNN model, please ignore the above warning message.)
255
+ W rknn-toolkit-lite2 version: 2.3.2
256
+ I RKNN: [18:58:27.194] RKNN Runtime Information, librknnrt version: 2.3.2 (429f97ae6b@2025-04-09T09:09:27)
257
+ I RKNN: [18:58:27.194] RKNN Driver Information, version: 0.9.8
258
+ I RKNN: [18:58:27.194] RKNN Model Information, version: 6, toolkit version: 2.3.2(compiler version: 2.3.2 (e045de294f@2025-04-07T19:48:25)), target: RKNPU v2, target platform: rk3588, framework name: ONNX, framework layout: NCHW, model inference type: static_shape
259
+ W RKNN: [18:58:27.317] query RKNN_QUERY_INPUT_DYNAMIC_RANGE error, rknn model is static shape type, please export rknn with dynamic_shapes
260
+ W Query dynamic range failed. Ret code: RKNN_ERR_MODEL_INVALID. (If it is a static shape RKNN model, please ignore the above warning message.)
261
+ W rknn-toolkit-lite2 version: 2.3.2
262
+ I RKNN: [18:58:27.431] RKNN Runtime Information, librknnrt version: 2.3.2 (429f97ae6b@2025-04-09T09:09:27)
263
+ I RKNN: [18:58:27.431] RKNN Driver Information, version: 0.9.8
264
+ I RKNN: [18:58:27.431] RKNN Model Information, version: 6, toolkit version: 2.3.2(compiler version: 2.3.2 (@2025-04-03T08:26:16)), target: RKNPU v2, target platform: rk3588, framework name: ONNX, framework layout: NCHW, model inference type: dynamic_shape
265
+ W rknn-toolkit-lite2 version: 2.3.2
266
+ I RKNN: [18:58:27.547] RKNN Runtime Information, librknnrt version: 2.3.2 (429f97ae6b@2025-04-09T09:09:27)
267
+ I RKNN: [18:58:27.547] RKNN Driver Information, version: 0.9.8
268
+ I RKNN: [18:58:27.547] RKNN Model Information, version: 6, toolkit version: 2.3.2(compiler version: 2.3.2 (e045de294f@2025-04-07T19:48:25)), target: RKNPU v2, target platform: rk3588, framework name: ONNX, framework layout: NCHW, model inference type: static_shape
269
+ W RKNN: [18:58:27.549] query RKNN_QUERY_INPUT_DYNAMIC_RANGE error, rknn model is static shape type, please export rknn with dynamic_shapes
270
+ W Query dynamic range failed. Ret code: RKNN_ERR_MODEL_INVALID. (If it is a static shape RKNN model, please ignore the above warning message.)
271
+ W rknn-toolkit-lite2 version: 2.3.2
272
+ I RKNN: [18:58:27.728] RKNN Runtime Information, librknnrt version: 2.3.2 (429f97ae6b@2025-04-09T09:09:27)
273
+ I RKNN: [18:58:27.728] RKNN Driver Information, version: 0.9.8
274
+ I RKNN: [18:58:27.728] RKNN Model Information, version: 6, toolkit version: 2.3.2(compiler version: 2.3.2 (e045de294f@2025-04-07T19:48:25)), target: RKNPU v2, target platform: rk3588, framework name: ONNX, framework layout: NCHW, model inference type: static_shape
275
+ W RKNN: [18:58:27.819] query RKNN_QUERY_INPUT_DYNAMIC_RANGE error, rknn model is static shape type, please export rknn with dynamic_shapes
276
+ W Query dynamic range failed. Ret code: RKNN_ERR_MODEL_INVALID. (If it is a static shape RKNN model, please ignore the above warning message.)
277
+ W rknn-toolkit-lite2 version: 2.3.2
278
+ I RKNN: [18:58:27.937] RKNN Runtime Information, librknnrt version: 2.3.2 (429f97ae6b@2025-04-09T09:09:27)
279
+ I RKNN: [18:58:27.937] RKNN Driver Information, version: 0.9.8
280
+ I RKNN: [18:58:27.937] RKNN Model Information, version: 6, toolkit version: 2.3.0(compiler version: 2.3.0 (@2024-11-07T08:11:34)), target: RKNPU v2, target platform: rk3588, framework name: ONNX, framework layout: NCHW, model inference type: static_shape
281
+ W RKNN: [18:58:27.940] query RKNN_QUERY_INPUT_DYNAMIC_RANGE error, rknn model is static shape type, please export rknn with dynamic_shapes
282
+ W Query dynamic range failed. Ret code: RKNN_ERR_MODEL_INVALID. (If it is a static shape RKNN model, please ignore the above warning message.)
283
+ W rknn-toolkit-lite2 version: 2.3.2
284
+ I RKNN: [18:58:28.058] RKNN Runtime Information, librknnrt version: 2.3.2 (429f97ae6b@2025-04-09T09:09:27)
285
+ I RKNN: [18:58:28.058] RKNN Driver Information, version: 0.9.8
286
+ I RKNN: [18:58:28.058] RKNN Model Information, version: 6, toolkit version: 2.3.0(compiler version: 2.3.0 (@2024-11-07T08:11:34)), target: RKNPU v2, target platform: rk3588, framework name: ONNX, framework layout: NCHW, model inference type: static_shape
287
+ W RKNN: [18:58:28.060] query RKNN_QUERY_INPUT_DYNAMIC_RANGE error, rknn model is static shape type, please export rknn with dynamic_shapes
288
+ W Query dynamic range failed. Ret code: RKNN_ERR_MODEL_INVALID. (If it is a static shape RKNN model, please ignore the above warning message.)
289
+ [time] vae_encode_0: 1601.56 ms
290
+ [time] vae_encode_38400: 1605.46 ms
291
+ [time] vae_encode_76800: 1591.07 ms
292
+ W The input[0] need NHWC data format, but NCHW set, the data format and data buffer will be changed to NHWC.
293
+ [time] locenc_0: 819.49 ms
294
+ W The input[0] need NHWC data format, but NCHW set, the data format and data buffer will be changed to NHWC.
295
+ [time] locenc_64: 818.33 ms
296
+ W The input[0] need NHWC data format, but NCHW set, the data format and data buffer will be changed to NHWC.
297
+ [time] locenc_128: 819.09 ms
298
+ [time] base_lm initial: 579.08 ms
299
+ [time] fsq_init_0: 2.54 ms
300
+ [time] fsq_init_64: 1.86 ms
301
+ [time] fsq_init_128: 1.79 ms
302
+ [time] residual_lm initial: 139.10 ms
303
+ gen_loop: 0%| | 0/2000 [00:00<?, ?it/s][time] lm_to_dit: 0.82 ms
304
+ [time] res_to_dit: 0.56 ms
305
+ 100%|█████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 33.32it/s]
306
+ W The input[0] need NHWC data format, but NCHW set, the data format and data buffer will be changed to NHWC.4it/s]
307
+ [time] locenc_step: 16.32 ms
308
+ gen_loop: 0%| | 1/2000 [00:00<14:30, 2.30it/s][time] lm_to_dit: 0.57 ms
309
+ [time] res_to_dit: 0.44 ms
310
+ 100%|█████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 33.10it/s]
311
+ W The input[0] need NHWC data format, but NCHW set, the data format and data buffer will be changed to NHWC.1it/s]
312
+ [time] locenc_step: 15.84 ms
313
+ gen_loop: 0%| | 2/2000 [00:00<14:27, 2.30it/s][time] lm_to_dit: 0.56 ms
314
+ [time] res_to_dit: 0.50 ms
315
+ 100%|█████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 31.93it/s]
316
+
317
+ ...
318
+
319
+ W The input[0] need NHWC data format, but NCHW set, the data format and data buffer will be changed to NHWC.5it/s]
320
+ [time] locenc_step: 15.88 ms
321
+ gen_loop: 6%|███▉ | 123/2000 [00:53<13:35, 2.30it/s][time] lm_to_dit: 0.57 ms
322
+ [time] res_to_dit: 0.49 ms
323
+ 100%|█████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 32.94it/s]
324
+ W The input[0] need NHWC data format, but NCHW set, the data format and data buffer will be changed to NHWC.6it/s]
325
+ [time] locenc_step: 15.84 ms
326
+ gen_loop: 6%|███▉ | 123/2000 [00:54<13:44, 2.28it/s]
327
+ [time] vae_decode_0: 1044.00 ms
328
+ [time] vae_decode_60: 1018.03 ms
329
+ [time] vae_decode_120: 1020.72 ms
330
+ [time] vae_decode_180: 1021.19 ms
331
+ [time] vae_decode_240: 1006.85 ms
332
+ Saved: rknn_output.wav
333
+ ```
334
+
335
+ ## Model Conversion
336
+
337
+ #### TODO: Documentation to be added
338
+
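+ Until full documentation is written, the overall flow can be read from the scripts shipped in this repo: `export_onnx.py` exports the submodules to ONNX, the `convert_*.py` scripts convert each ONNX graph with rknn-toolkit2 on a PC, and the two language models are converted separately to `.rkllm` with the RKLLM toolkit. A minimal sketch of the RKNN step, mirroring `convert_stop_head.py` (error checks omitted):
+ 
+ ```python
+ from rknn.api import RKNN
+ 
+ # Mirrors the convert_*.py scripts in this repo; shown for stop_head.onnx,
+ # the other submodules differ only in input names and shapes.
+ rknn = RKNN(verbose=True)
+ rknn.config(target_platform="rk3588", dynamic_input=None)
+ rknn.load_onnx(model="stop_head.onnx",
+                inputs=["hidden"],
+                input_size_list=[[1, 1024]])
+ rknn.build(do_quantization=False)  # keep FP16, no quantization
+ rknn.export_rknn("stop_head.rknn")
+ rknn.release()
+ ```
+ 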
339
+ ## Known Issues
340
+
341
+ - In some cases, speech generation may fall into an infinite loop. The original project seems to have a mechanism to detect infinite loops, but it is not implemented here.
342
+ - Due to internal issues with the RKNN toolchain, the locenc model cannot be configured with two input-shape sets (one per input length) in a single RKNN model, so it is converted into two separate models (`locenc_1.rknn` and `locenc_64.rknn`).
343
+ - Due to internal issues with the RKLLM toolchain/runtime, the output tensors of both LLMs come out at one quarter of their correct values; manually multiplying them by 4 recovers the correct result (see the sketch after this list).
344
+ - Because the RKNN toolchain currently does not support data-parallel inference across multiple NPU cores for multi-batch models with non-4D inputs, the two CFG passes are run separately in the script, which is relatively slow.
345
+
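+ For reference, the quarter-scale workaround mentioned above is just an element-wise rescale of the raw RKLLM hidden-state output before it is fed to the downstream RKNN modules. A minimal sketch (variable and function names here are illustrative, not the actual API of `onnx_infer-rknn2.py`):
+ 
+ ```python
+ import numpy as np
+ 
+ RKLLM_OUTPUT_SCALE = 4.0  # rkllm currently returns hidden states at 1/4 of the correct magnitude
+ 
+ def fix_rkllm_hidden(raw_hidden: np.ndarray) -> np.ndarray:
+     """Rescale a raw RKLLM output before passing it to lm_to_dit_proj / res_to_dit_proj."""
+     return raw_hidden.astype(np.float32) * RKLLM_OUTPUT_SCALE
+ ```
+ 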
346
+ ## References
347
+ - [openbmb/VoxCPM-0.5B](https://huggingface.co/openbmb/VoxCPM-0.5B)
348
+ - [0seba/VoxCPMANE](https://github.com/0seba/VoxCPMANE)
349
+ - [bluryar/VoxCPM-ONNX](https://github.com/bluryar/VoxCPM-ONNX)
audio_vae_decode.rknn ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ec4181f54e0ec9e3de9d4380beb79855f3fdb45b1562bcb89d403e2f5f130f5
3
+ size 58597553
audio_vae_encode.rknn ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:14f0cadac33b0bd5c722edd09df758eb3f4204738af112c584d4f4292ac781f5
3
+ size 102299829
base_lm.rkllm ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d75d343623ac22402175e1aaa935c243338524a19e9cb1dae4955e643df15964
3
+ size 1028092444
basic_ref_zh.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:96724a113240d1f82c6ded1334122f0176b96c9226ccd3c919e625bcfd2a3ede
3
+ size 324558
config.json ADDED
@@ -0,0 +1,52 @@
1
+ {
2
+ "architecture": "voxcpm",
3
+ "lm_config": {
4
+ "bos_token_id": 1,
5
+ "eos_token_id": 2,
6
+ "hidden_size": 1024,
7
+ "intermediate_size": 4096,
8
+ "max_position_embeddings": 32768,
9
+ "num_attention_heads": 16,
10
+ "num_hidden_layers": 24,
11
+ "num_key_value_heads": 2,
12
+ "rms_norm_eps": 1e-05,
13
+ "rope_theta": 10000,
14
+ "rope_scaling": {
15
+ "type": "longrope",
16
+ "long_factor": [1.0004360675811768, 1.0668443441390991, 1.1631425619125366, 1.3025742769241333, 1.5040205717086792, 1.7941505908966064, 2.2101221084594727, 2.802666664123535, 3.6389970779418945, 4.804192543029785, 6.39855432510376, 8.527148246765137, 11.277542114257812, 14.684998512268066, 18.69317054748535, 23.13019371032715, 27.72362518310547, 32.1606559753418, 36.168827056884766, 39.57627868652344, 42.32667541503906, 44.45526885986328, 46.04962921142578, 47.21482849121094, 48.05115509033203, 48.64370346069336, 49.05967712402344, 49.34980392456055, 49.551246643066406, 49.69068145751953, 49.78697967529297, 49.85338592529297],
17
+ "short_factor": [1.0004360675811768, 1.0668443441390991, 1.1631425619125366, 1.3025742769241333, 1.5040205717086792, 1.7941505908966064, 2.2101221084594727, 2.802666664123535, 3.6389970779418945, 4.804192543029785, 6.39855432510376, 8.527148246765137, 11.277542114257812, 14.684998512268066, 18.69317054748535, 23.13019371032715, 27.72362518310547, 32.1606559753418, 36.168827056884766, 39.57627868652344, 42.32667541503906, 44.45526885986328, 46.04962921142578, 47.21482849121094, 48.05115509033203, 48.64370346069336, 49.05967712402344, 49.34980392456055, 49.551246643066406, 49.69068145751953, 49.78697967529297, 49.85338592529297],
18
+ "original_max_position_embeddings": 32768
19
+ },
20
+ "vocab_size": 73448,
21
+ "scale_emb": 12,
22
+ "dim_model_base": 256,
23
+ "scale_depth": 1.4,
24
+ "use_mup": false
25
+ },
26
+ "patch_size": 2,
27
+ "feat_dim": 64,
28
+ "scalar_quantization_latent_dim": 256,
29
+ "scalar_quantization_scale": 9,
30
+ "residual_lm_num_layers": 6,
31
+ "encoder_config": {
32
+ "hidden_dim": 1024,
33
+ "ffn_dim": 4096,
34
+ "num_heads": 16,
35
+ "num_layers": 4
36
+ },
37
+ "dit_config": {
38
+ "hidden_dim": 1024,
39
+ "ffn_dim": 4096,
40
+ "num_heads": 16,
41
+ "num_layers": 4,
42
+ "cfm_config": {
43
+ "sigma_min": 1e-06,
44
+ "solver": "euler",
45
+ "t_scheduler": "log-norm",
46
+ "inference_cfg_rate": 2.0
47
+ }
48
+ },
49
+ "max_length": 4096,
50
+ "device": "cuda",
51
+ "dtype": "bfloat16"
52
+ }
convert_audio_vae_decode.py ADDED
@@ -0,0 +1,59 @@
1
+ #!/usr/bin/env python3
2
+ # ztu_somemodelruntime_rknn2: audio_vae_decode
3
+
4
+ import faulthandler
5
+ faulthandler.enable()
6
+ import os
7
+ os.chdir(os.path.dirname(os.path.abspath(__file__)))
8
+ from rknn.api import RKNN
9
+ import numpy as np
10
+
11
+ def main():
12
+ # 创建RKNN实例
13
+ rknn = RKNN(verbose=True)
14
+
15
+ # ONNX模型路径
16
+ ONNX_MODEL = "audio_vae_decode.onnx"
17
+ # 输出RKNN模型路径
18
+ RKNN_MODEL = "audio_vae_decode.rknn"
19
+
20
+ # 配置参数
21
+ print("--> Config model")
22
+ ret = rknn.config(target_platform="rk3588",
23
+ dynamic_input=None)
24
+ if ret != 0:
25
+ print('Config model failed!')
26
+ exit(ret)
27
+
28
+ # 加载ONNX模型
29
+ print("--> Loading model")
30
+ ret = rknn.load_onnx(model=ONNX_MODEL,
31
+ inputs=['latent'],
32
+ input_size_list=[[1, 64, 64]])
33
+ if ret != 0:
34
+ print('Load model failed!')
35
+ exit(ret)
36
+
37
+ # 构建模型
38
+ print("--> Building model")
39
+ ret = rknn.build(do_quantization=False)
40
+ if ret != 0:
41
+ print('Build model failed!')
42
+ exit(ret)
43
+
44
+ # 导出RKNN模型
45
+ print("--> Export RKNN model")
46
+ ret = rknn.export_rknn(RKNN_MODEL)
47
+ if ret != 0:
48
+ print('Export RKNN model failed!')
49
+ exit(ret)
50
+
51
+ print(f'Done! The converted RKNN model has been saved to: ' + RKNN_MODEL)
52
+
53
+ # 精度分析(可选)
54
+ # rknn.accuracy_analysis(inputs=["dumps_audio_vae_decode/latent.npy"], target="rk3588", device_id=None)
55
+
56
+ rknn.release()
57
+
58
+ if __name__ == '__main__':
59
+ main()
convert_audio_vae_encode.py ADDED
@@ -0,0 +1,59 @@
1
+ #!/usr/bin/env python3
2
+ # ztu_somemodelruntime_rknn2: audio_vae_encode
3
+
4
+ import faulthandler
5
+ faulthandler.enable()
6
+ import os
7
+ os.chdir(os.path.dirname(os.path.abspath(__file__)))
8
+ from rknn.api import RKNN
9
+ import numpy as np
10
+
11
+ def main():
12
+ # 创建RKNN实例
13
+ rknn = RKNN(verbose=True)
14
+
15
+ # ONNX模型路径
16
+ ONNX_MODEL = "audio_vae_encode.onnx"
17
+ # 输出RKNN模型路径
18
+ RKNN_MODEL = "audio_vae_encode.rknn"
19
+
20
+ # 配置参数
21
+ print("--> Config model")
22
+ ret = rknn.config(target_platform="rk3588",
23
+ dynamic_input=None)
24
+ if ret != 0:
25
+ print('Config model failed!')
26
+ exit(ret)
27
+
28
+ # 加载ONNX模型
29
+ print("--> Loading model")
30
+ ret = rknn.load_onnx(model=ONNX_MODEL,
31
+ inputs=['audio_wave'],
32
+ input_size_list=[[1, 1, 40960]])
33
+ if ret != 0:
34
+ print('Load model failed!')
35
+ exit(ret)
36
+
37
+ # 构建模型
38
+ print("--> Building model")
39
+ ret = rknn.build(do_quantization=False)
40
+ if ret != 0:
41
+ print('Build model failed!')
42
+ exit(ret)
43
+
44
+ # 导出RKNN模型
45
+ print("--> Export RKNN model")
46
+ ret = rknn.export_rknn(RKNN_MODEL)
47
+ if ret != 0:
48
+ print('Export RKNN model failed!')
49
+ exit(ret)
50
+
51
+ print(f'Done! The converted RKNN model has been saved to: ' + RKNN_MODEL)
52
+
53
+ # 精度分析(可选)
54
+ # rknn.accuracy_analysis(inputs=["dumps_audio_vae_encode/audio_wave.npy"], target="rk3588", device_id=None)
55
+
56
+ rknn.release()
57
+
58
+ if __name__ == '__main__':
59
+ main()
convert_dit_step.py ADDED
@@ -0,0 +1,59 @@
1
+ #!/usr/bin/env python3
2
+ # ztu_somemodelruntime_rknn2: dit_step
3
+
4
+ import faulthandler
5
+ faulthandler.enable()
6
+ import os
7
+ os.chdir(os.path.dirname(os.path.abspath(__file__)))
8
+ from rknn.api import RKNN
9
+ import numpy as np
10
+
11
+ def main():
12
+ # 创建RKNN实例
13
+ rknn = RKNN(verbose=True)
14
+
15
+ # ONNX模型路径
16
+ ONNX_MODEL = "dit_step.onnx"
17
+ # 输出RKNN模型路径
18
+ RKNN_MODEL = "dit_step.rknn"
19
+
20
+ # 配置参数
21
+ print("--> Config model")
22
+ ret = rknn.config(target_platform="rk3588",
23
+ dynamic_input=None)
24
+ if ret != 0:
25
+ print('Config model failed!')
26
+ exit(ret)
27
+
28
+ # 加载ONNX模型
29
+ print("--> Loading model")
30
+ ret = rknn.load_onnx(model=ONNX_MODEL,
31
+ inputs=['x', 'mu', 't', 'cond', 'dt'],
32
+ input_size_list=[[1, 64, 2], [1, 1024], [1], [1, 64, 2], [1]])
33
+ if ret != 0:
34
+ print('Load model failed!')
35
+ exit(ret)
36
+
37
+ # 构建模型
38
+ print("--> Building model")
39
+ ret = rknn.build(do_quantization=False,)
40
+ if ret != 0:
41
+ print('Build model failed!')
42
+ exit(ret)
43
+
44
+ # 导出RKNN模型
45
+ print("--> Export RKNN model")
46
+ ret = rknn.export_rknn(RKNN_MODEL)
47
+ if ret != 0:
48
+ print('Export RKNN model failed!')
49
+ exit(ret)
50
+
51
+ print(f'Done! The converted RKNN model has been saved to: ' + RKNN_MODEL)
52
+
53
+ # 精度分析(可选)
54
+ # rknn.accuracy_analysis(inputs=["dumps_dit_step/x.npy", "dumps_dit_step/mu.npy", "dumps_dit_step/t.npy", "dumps_dit_step/cond.npy", "dumps_dit_step/dt.npy"], target="rk3588", device_id=None)
55
+
56
+ rknn.release()
57
+
58
+ if __name__ == '__main__':
59
+ main()
convert_fsq_layer.py ADDED
@@ -0,0 +1,59 @@
1
+ #!/usr/bin/env python3
2
+ # ztu_somemodelruntime_rknn2: fsq_layer
3
+
4
+ import faulthandler
5
+ faulthandler.enable()
6
+ import os
7
+ os.chdir(os.path.dirname(os.path.abspath(__file__)))
8
+ from rknn.api import RKNN
9
+ import numpy as np
10
+
11
+ def main():
12
+ # 创建RKNN实例
13
+ rknn = RKNN(verbose=True)
14
+
15
+ # ONNX模型路径
16
+ ONNX_MODEL = "fsq_layer.onnx"
17
+ # 输出RKNN模型路径
18
+ RKNN_MODEL = "fsq_layer.rknn"
19
+
20
+ # 配置参数
21
+ print("--> Config model")
22
+ ret = rknn.config(target_platform="rk3588",
23
+ dynamic_input=[[[1, 64, 1024]], [[1, 1, 1024]]])
24
+ if ret != 0:
25
+ print('Config model failed!')
26
+ exit(ret)
27
+
28
+ # 加载ONNX模型
29
+ print("--> Loading model")
30
+ ret = rknn.load_onnx(model=ONNX_MODEL,
31
+ inputs=None,
32
+ input_size_list=None)
33
+ if ret != 0:
34
+ print('Load model failed!')
35
+ exit(ret)
36
+
37
+ # 构建模型
38
+ print("--> Building model")
39
+ ret = rknn.build(do_quantization=False)
40
+ if ret != 0:
41
+ print('Build model failed!')
42
+ exit(ret)
43
+
44
+ # 导出RKNN模型
45
+ print("--> Export RKNN model")
46
+ ret = rknn.export_rknn(RKNN_MODEL)
47
+ if ret != 0:
48
+ print('Export RKNN model failed!')
49
+ exit(ret)
50
+
51
+ print(f'Done! The converted RKNN model has been saved to: ' + RKNN_MODEL)
52
+
53
+ # 精度分析(可选)
54
+ # rknn.accuracy_analysis(inputs=["dumps_fsq_layer/hidden.npy"], target="rk3588", device_id=None)
55
+
56
+ rknn.release()
57
+
58
+ if __name__ == '__main__':
59
+ main()
convert_lm_to_dit_proj.py ADDED
@@ -0,0 +1,59 @@
1
+ #!/usr/bin/env python3
2
+ # ztu_somemodelruntime_rknn2: lm_to_dit_proj
3
+
4
+ import faulthandler
5
+ faulthandler.enable()
6
+ import os
7
+ os.chdir(os.path.dirname(os.path.abspath(__file__)))
8
+ from rknn.api import RKNN
9
+ import numpy as np
10
+
11
+ def main():
12
+ # 创建RKNN实例
13
+ rknn = RKNN(verbose=True)
14
+
15
+ # ONNX模型路径
16
+ ONNX_MODEL = "lm_to_dit_proj.onnx"
17
+ # 输出RKNN模型路径
18
+ RKNN_MODEL = "lm_to_dit_proj.rknn"
19
+
20
+ # 配置参数
21
+ print("--> Config model")
22
+ ret = rknn.config(target_platform="rk3588",
23
+ dynamic_input=None)
24
+ if ret != 0:
25
+ print('Config model failed!')
26
+ exit(ret)
27
+
28
+ # 加载ONNX模型
29
+ print("--> Loading model")
30
+ ret = rknn.load_onnx(model=ONNX_MODEL,
31
+ inputs=['input'],
32
+ input_size_list=[[1, 1024]])
33
+ if ret != 0:
34
+ print('Load model failed!')
35
+ exit(ret)
36
+
37
+ # 构建模型
38
+ print("--> Building model")
39
+ ret = rknn.build(do_quantization=False)
40
+ if ret != 0:
41
+ print('Build model failed!')
42
+ exit(ret)
43
+
44
+ # 导出RKNN模型
45
+ print("--> Export RKNN model")
46
+ ret = rknn.export_rknn(RKNN_MODEL)
47
+ if ret != 0:
48
+ print('Export RKNN model failed!')
49
+ exit(ret)
50
+
51
+ print(f'Done! The converted RKNN model has been saved to: ' + RKNN_MODEL)
52
+
53
+ # 精度分析(可选)
54
+ # rknn.accuracy_analysis(inputs=["dumps_lm_to_dit_proj/input.npy"], target="rk3588", device_id=None)
55
+
56
+ rknn.release()
57
+
58
+ if __name__ == '__main__':
59
+ main()
convert_locenc.py ADDED
@@ -0,0 +1,107 @@
1
+ #!/usr/bin/env python3
2
+ # ztu_somemodelruntime_rknn2: locenc
3
+
4
+ import faulthandler
5
+ faulthandler.enable()
6
+ import os
7
+ os.chdir(os.path.dirname(os.path.abspath(__file__)))
8
+ from rknn.api import RKNN
9
+ import numpy as np
10
+
11
+ def main():
12
+ # 创建RKNN实例
13
+ rknn = RKNN(verbose=True)
14
+
15
+ # ONNX模型路径
16
+ ONNX_MODEL = "locenc.onnx"
17
+ # 输出RKNN模型路径
18
+ RKNN_MODEL = "locenc_64.rknn"
19
+
20
+ # 配置参数
21
+ print("--> Config model")
22
+ ret = rknn.config(target_platform="rk3588",
23
+ dynamic_input=None)
24
+ if ret != 0:
25
+ print('Config model failed!')
26
+ exit(ret)
27
+
28
+ # 加载ONNX模型
29
+ print("--> Loading model")
30
+ ret = rknn.load_onnx(model=ONNX_MODEL,
31
+ inputs=['x'],
32
+ input_size_list=[[1, 64, 2, 64]])
33
+ if ret != 0:
34
+ print('Load model failed!')
35
+ exit(ret)
36
+
37
+ # 构建模型
38
+ print("--> Building model")
39
+ ret = rknn.build(do_quantization=False)
40
+ if ret != 0:
41
+ print('Build model failed!')
42
+ exit(ret)
43
+
44
+ # 导出RKNN模型
45
+ print("--> Export RKNN model")
46
+ ret = rknn.export_rknn(RKNN_MODEL)
47
+ if ret != 0:
48
+ print('Export RKNN model failed!')
49
+ exit(ret)
50
+
51
+ print(f'Done! The converted RKNN model has been saved to: ' + RKNN_MODEL)
52
+
53
+ # 精度分析(可选)
54
+ # rknn.accuracy_analysis(inputs=["dumps_locenc/x.npy"], target="rk3588", device_id=None)
55
+
56
+ rknn.release()
57
+
58
+ def main2():
59
+ # 创建RKNN实例
60
+ rknn = RKNN(verbose=True)
61
+
62
+ # ONNX模型路径
63
+ ONNX_MODEL = "locenc.onnx"
64
+ # 输出RKNN模型路径
65
+ RKNN_MODEL = "locenc_1.rknn"
66
+
67
+ # 配置参数
68
+ print("--> Config model")
69
+ ret = rknn.config(target_platform="rk3588",
70
+ dynamic_input=None)
71
+ if ret != 0:
72
+ print('Config model failed!')
73
+ exit(ret)
74
+
75
+ # 加载ONNX模型
76
+ print("--> Loading model")
77
+ ret = rknn.load_onnx(model=ONNX_MODEL,
78
+ inputs=['x'],
79
+ input_size_list=[[1, 1, 2, 64]])
80
+ if ret != 0:
81
+ print('Load model failed!')
82
+ exit(ret)
83
+
84
+ # 构建模型
85
+ print("--> Building model")
86
+ ret = rknn.build(do_quantization=False)
87
+ if ret != 0:
88
+ print('Build model failed!')
89
+ exit(ret)
90
+
91
+ # 导出RKNN模型
92
+ print("--> Export RKNN model")
93
+ ret = rknn.export_rknn(RKNN_MODEL)
94
+ if ret != 0:
95
+ print('Export RKNN model failed!')
96
+ exit(ret)
97
+
98
+ print(f'Done! The converted RKNN model has been saved to: ' + RKNN_MODEL)
99
+
100
+ # 精度分析(可选)
101
+ # rknn.accuracy_analysis(inputs=["dumps_locenc/x.npy"], target="rk3588", device_id=None)
102
+
103
+ rknn.release()
104
+
105
+ if __name__ == '__main__':
106
+ main()
107
+ main2()
convert_res_to_dit_proj.py ADDED
@@ -0,0 +1,59 @@
1
+ #!/usr/bin/env python3
2
+ # ztu_somemodelruntime_rknn2: res_to_dit_proj
3
+
4
+ import faulthandler
5
+ faulthandler.enable()
6
+ import os
7
+ os.chdir(os.path.dirname(os.path.abspath(__file__)))
8
+ from rknn.api import RKNN
9
+ import numpy as np
10
+
11
+ def main():
12
+ # 创建RKNN实例
13
+ rknn = RKNN(verbose=True)
14
+
15
+ # ONNX模型路径
16
+ ONNX_MODEL = "res_to_dit_proj.onnx"
17
+ # 输出RKNN模型路径
18
+ RKNN_MODEL = "res_to_dit_proj.rknn"
19
+
20
+ # 配置参数
21
+ print("--> Config model")
22
+ ret = rknn.config(target_platform="rk3588",
23
+ dynamic_input=None)
24
+ if ret != 0:
25
+ print('Config model failed!')
26
+ exit(ret)
27
+
28
+ # 加载ONNX模型
29
+ print("--> Loading model")
30
+ ret = rknn.load_onnx(model=ONNX_MODEL,
31
+ inputs=['input'],
32
+ input_size_list=[[1, 1024]])
33
+ if ret != 0:
34
+ print('Load model failed!')
35
+ exit(ret)
36
+
37
+ # 构建模型
38
+ print("--> Building model")
39
+ ret = rknn.build(do_quantization=False)
40
+ if ret != 0:
41
+ print('Build model failed!')
42
+ exit(ret)
43
+
44
+ # 导出RKNN模型
45
+ print("--> Export RKNN model")
46
+ ret = rknn.export_rknn(RKNN_MODEL)
47
+ if ret != 0:
48
+ print('Export RKNN model failed!')
49
+ exit(ret)
50
+
51
+ print(f'Done! The converted RKNN model has been saved to: ' + RKNN_MODEL)
52
+
53
+ # 精度分析(可选)
54
+ # rknn.accuracy_analysis(inputs=["dumps_res_to_dit_proj/input.npy"], target="rk3588", device_id=None)
55
+
56
+ rknn.release()
57
+
58
+ if __name__ == '__main__':
59
+ main()
convert_stop_head.py ADDED
@@ -0,0 +1,59 @@
1
+ #!/usr/bin/env python3
2
+ # ztu_somemodelruntime_rknn2: stop_head
3
+
4
+ import faulthandler
5
+ faulthandler.enable()
6
+ import os
7
+ os.chdir(os.path.dirname(os.path.abspath(__file__)))
8
+ from rknn.api import RKNN
9
+ import numpy as np
10
+
11
+ def main():
12
+ # 创建RKNN实例
13
+ rknn = RKNN(verbose=True)
14
+
15
+ # ONNX模型路径
16
+ ONNX_MODEL = "stop_head.onnx"
17
+ # 输出RKNN模型路径
18
+ RKNN_MODEL = "stop_head.rknn"
19
+
20
+ # 配置参数
21
+ print("--> Config model")
22
+ ret = rknn.config(target_platform="rk3588",
23
+ dynamic_input=None)
24
+ if ret != 0:
25
+ print('Config model failed!')
26
+ exit(ret)
27
+
28
+ # 加载ONNX模型
29
+ print("--> Loading model")
30
+ ret = rknn.load_onnx(model=ONNX_MODEL,
31
+ inputs=['hidden'],
32
+ input_size_list=[[1, 1024]])
33
+ if ret != 0:
34
+ print('Load model failed!')
35
+ exit(ret)
36
+
37
+ # 构建模型
38
+ print("--> Building model")
39
+ ret = rknn.build(do_quantization=False)
40
+ if ret != 0:
41
+ print('Build model failed!')
42
+ exit(ret)
43
+
44
+ # 导出RKNN模型
45
+ print("--> Export RKNN model")
46
+ ret = rknn.export_rknn(RKNN_MODEL)
47
+ if ret != 0:
48
+ print('Export RKNN model failed!')
49
+ exit(ret)
50
+
51
+ print(f'Done! The converted RKNN model has been saved to: ' + RKNN_MODEL)
52
+
53
+ # 精度分析(可选)
54
+ # rknn.accuracy_analysis(inputs=["dumps_stop_head/hidden.npy"], target="rk3588", device_id=None)
55
+
56
+ rknn.release()
57
+
58
+ if __name__ == '__main__':
59
+ main()
convert_vox_minicpm_to_hf.py ADDED
@@ -0,0 +1,115 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import sys
5
+ import torch
6
+ import math
7
+
8
+
9
+ def load_vox_configs(vox_config_path: str) -> tuple[dict, dict]:
10
+ """Return (base_lm_cfg, residual_cfg)."""
11
+ with open(vox_config_path, "r") as f:
12
+ data = json.load(f)
13
+
14
+ base = data["lm_config"]
15
+ rope = base.get("rope_scaling")
16
+ if rope:
17
+ rope = dict(rope)
18
+ # Vox config uses "type", transformers expects "rope_type"
19
+ if "type" in rope and "rope_type" not in rope:
20
+ rope["rope_type"] = rope.pop("type")
21
+ base["rope_scaling"] = rope
22
+
23
+ residual = dict(base)
24
+ residual["num_hidden_layers"] = data.get("residual_lm_num_layers", residual["num_hidden_layers"])
25
+ # keep vocab_size for easier loading; Vox sets 0 because inputs_embeds are provided
26
+ residual.setdefault("vocab_size", base.get("vocab_size"))
27
+
28
+ # Align transformers residual scaling with Vox (no scaling when use_mup=False)
29
+ if not base.get("use_mup", True):
30
+ base["scale_depth"] = math.sqrt(base["num_hidden_layers"])
31
+ residual["scale_depth"] = math.sqrt(residual["num_hidden_layers"])
32
+ return base, residual
33
+
34
+
35
+ def build_hf_config(lm_cfg: dict, minicpm_dir: str):
36
+ sys.path.insert(0, minicpm_dir)
37
+ from configuration_minicpm import MiniCPMConfig
38
+
39
+ return MiniCPMConfig(**lm_cfg)
40
+
41
+
42
+ def convert_state_dict(vox_state_path: str, lm_prefix: str) -> dict:
43
+ raw = torch.load(vox_state_path, map_location="cpu")
44
+ sd = raw["state_dict"] if isinstance(raw, dict) and "state_dict" in raw else raw
45
+
46
+ out = {}
47
+ prefix = f"{lm_prefix}."
48
+ for k, v in sd.items():
49
+ if not k.startswith(prefix):
50
+ continue
51
+ new_k = "model." + k[len(prefix) :]
52
+ out[new_k] = v
53
+
54
+ # Tie lm_head to embeddings for MiniCPMForCausalLM
55
+ if "model.embed_tokens.weight" in out:
56
+ out["lm_head.weight"] = out["model.embed_tokens.weight"]
57
+ return out
58
+
59
+
60
+ def main():
61
+ parser = argparse.ArgumentParser(description="Convert VoxCPM MiniCPM weights to transformers format")
62
+ parser.add_argument(
63
+ "--vox-config",
64
+ default="VoxCPM-0.5B/config.json",
65
+ help="Path to VoxCPM config.json (used to read lm_config)",
66
+ )
67
+ parser.add_argument(
68
+ "--vox-state",
69
+ default="VoxCPM-0.5B/pytorch_model.bin",
70
+ help="Path to VoxCPM checkpoint containing base_lm weights",
71
+ )
72
+ parser.add_argument(
73
+ "--minicpm-dir",
74
+ default="MiniCPM4-0.5B",
75
+ help="Path to local MiniCPM4-0.5B directory (provides configuration_minicpm.py)",
76
+ )
77
+ parser.add_argument(
78
+ "--out-dir",
79
+ default="converted-minicpm-hf",
80
+ help="Output directory for base LM transformers-style checkpoint",
81
+ )
82
+ parser.add_argument(
83
+ "--out-residual-dir",
84
+ default="converted-minicpm-residual-hf",
85
+ help="Output directory for residual LM checkpoint",
86
+ )
87
+ args = parser.parse_args()
88
+
89
+ os.makedirs(args.out_dir, exist_ok=True)
90
+ os.makedirs(args.out_residual_dir, exist_ok=True)
91
+
92
+ base_cfg, residual_cfg = load_vox_configs(args.vox_config)
93
+
94
+ hf_config = build_hf_config(base_cfg, args.minicpm_dir)
95
+ hf_config.save_pretrained(args.out_dir)
96
+
97
+ print("Loaded Vox lm_config and wrote transformers config to", args.out_dir)
98
+
99
+ hf_state = convert_state_dict(args.vox_state, lm_prefix="base_lm")
100
+ out_path = os.path.join(args.out_dir, "pytorch_model.bin")
101
+ torch.save(hf_state, out_path)
102
+ print("Saved base LM weights to", out_path)
103
+
104
+ residual_hf_config = build_hf_config(residual_cfg, args.minicpm_dir)
105
+ residual_hf_config.save_pretrained(args.out_residual_dir)
106
+ residual_state = convert_state_dict(args.vox_state, lm_prefix="residual_lm")
107
+ residual_out_path = os.path.join(args.out_residual_dir, "pytorch_model.bin")
108
+ torch.save(residual_state, residual_out_path)
109
+ print("Saved residual LM weights to", residual_out_path)
110
+
111
+ print("Load with MiniCPMForCausalLM.from_pretrained(...) or MiniCPMModel.from_pretrained(...).")
112
+
113
+
114
+ if __name__ == "__main__":
115
+ main()
dit_step.rknn ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1b7881815a088cfa90b06c804d4fd8a997930a00903facd84836592b01cef4ad
3
+ size 137852823
embed_tokens.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8a13e0da6a95e7eb2562468f65e8d64dc412f628057cd9387aaa68cbd4a4cc8f
3
+ size 300843136
export_onnx.py ADDED
@@ -0,0 +1,290 @@
1
+ import argparse
2
+ import os
3
+ import torch
4
+ from torch import nn
5
+
6
+ from voxcpm.model.voxcpm import VoxCPMModel
7
+
8
+
9
+ def remove_weight_norm(module: nn.Module):
10
+ """Strip weight_norm wrappers for cleaner ONNX graphs."""
11
+ for name, child in module.named_children():
12
+ remove_weight_norm(child)
13
+ if isinstance(child, (nn.Conv1d, nn.ConvTranspose1d)):
14
+ try:
15
+ torch.nn.utils.remove_weight_norm(child)
16
+ except ValueError:
17
+ # not wrapped, skip
18
+ pass
19
+
20
+
21
+ class VAEEncodeWrapper(nn.Module):
22
+ def __init__(self, audio_vae: nn.Module):
23
+ super().__init__()
24
+ self.audio_vae = audio_vae
25
+
26
+ def forward(self, audio_wave: torch.Tensor):
27
+ return self.audio_vae.encode(audio_wave, self.audio_vae.sample_rate)
28
+
29
+
30
+ class VAEDecodeWrapper(nn.Module):
31
+ def __init__(self, audio_vae: nn.Module):
32
+ super().__init__()
33
+ self.audio_vae = audio_vae
34
+
35
+ def forward(self, latent: torch.Tensor):
36
+ return self.audio_vae.decode(latent)
37
+
38
+
39
+ class LocEncWrapper(nn.Module):
40
+ def __init__(self, locenc: nn.Module):
41
+ super().__init__()
42
+ self.locenc = locenc
43
+
44
+ def forward(self, x: torch.Tensor):
45
+ # x: [B, T, P, D]
46
+ return self.locenc(x)
47
+
48
+
49
+ class LocEncLmWrapper(nn.Module):
50
+ """LocEnc with enc_to_lm projection fused in a single graph."""
51
+
52
+ def __init__(self, locenc: nn.Module, proj: nn.Module):
53
+ super().__init__()
54
+ self.locenc = locenc
55
+ self.proj = proj
56
+
57
+ def forward(self, x: torch.Tensor):
58
+ # x: [B, T, P, D]
59
+ hidden = self.locenc(x)
60
+ return self.proj(hidden)
61
+
62
+
63
+ class FSQWrapper(nn.Module):
64
+ def __init__(self, fsq: nn.Module):
65
+ super().__init__()
66
+ self.fsq = fsq
67
+
68
+ def forward(self, hidden: torch.Tensor):
69
+ return self.fsq(hidden)
70
+
71
+
72
+ class StopHeadWrapper(nn.Module):
73
+     def __init__(self, stop_proj: nn.Linear, stop_actn: nn.Module, stop_head: nn.Linear):
+         super().__init__()
+         self.stop_proj = stop_proj
+         self.stop_actn = stop_actn
+         self.stop_head = stop_head
+
+     def forward(self, hidden: torch.Tensor):
+         hidden = self.stop_proj(hidden)
+         hidden = self.stop_actn(hidden)
+         return self.stop_head(hidden)
+
+
+ class CFMWrapper(nn.Module):
+     """
+     Wrapper for one diffusion step block.
+
+     Note: the number of diffusion steps (n_timesteps) is fixed at export time.
+     """
+
+     def __init__(self, cfm: nn.Module, patch_size: int, n_timesteps: int, cfg_value: float):
+         super().__init__()
+         self.cfm = cfm
+         self.patch_size = patch_size
+         self.n_timesteps = n_timesteps
+         self.cfg_value = cfg_value
+
+     def forward(self, mu: torch.Tensor, cond: torch.Tensor):
+         # mu: [B, H_dit], cond: [B, D_feat, P]
+         return self.cfm(
+             mu=mu,
+             n_timesteps=self.n_timesteps,
+             patch_size=self.patch_size,
+             cond=cond,
+             cfg_value=self.cfg_value,
+         )
+
+
+ class DiTStepWrapper(nn.Module):
+     """
+     Wrapper for a single VoxCPMLocDiT forward (one diffusion score estimation step).
+     Inputs match VoxCPMLocDiT.forward: x, mu, t, cond, dt.
+     """
+
+     def __init__(self, dit: nn.Module):
+         super().__init__()
+         self.dit = dit
+
+     def forward(self, x: torch.Tensor, mu: torch.Tensor, t: torch.Tensor, cond: torch.Tensor, dt: torch.Tensor):
+         return self.dit(x, mu, t, cond, dt)
+
+
+ def export(model: nn.Module, inputs, path: str, dynamic_axes: dict, opset: int):
+     os.makedirs(os.path.dirname(path), exist_ok=True)
+     torch.onnx.export(
+         model,
+         inputs,
+         path,
+         opset_version=opset,
+         do_constant_folding=True,
+         input_names=list(dynamic_axes.keys()),
+         output_names=["output"],
+         dynamic_axes=dynamic_axes,
+     )
+     print(f"Saved: {path}")
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Export VoxCPM submodules to ONNX (LLM excluded).")
+     parser.add_argument("--model-dir", required=True, help="Path to VoxCPM model directory (config/weights).")
+     parser.add_argument("--out-dir", default="onnx_exports", help="Output directory for ONNX files.")
+     parser.add_argument("--opset", type=int, default=18, help="ONNX opset version.")
+     parser.add_argument("--audio-samples", type=int, default=1280, help="Dummy audio length for encoder export.")
+     parser.add_argument("--latent-steps", type=int, default=6, help="Dummy latent steps for decoder export.")
+     parser.add_argument("--seq-len", type=int, default=4, help="Dummy sequence length for LocEnc/FSQ export.")
+     parser.add_argument("--dit-step-t", type=float, default=0.5, help="Dummy diffusion time for DiT step export.")
+     parser.add_argument("--force-fp32", action="store_true", help="Force submodules to float32 for ONNX export.")
+     parser.add_argument("--dump-embeddings", action="store_true", help="Dump base_lm.embed_tokens weights to npy.")
+     args = parser.parse_args()
+
+     device = torch.device("cpu")
+     # Load full model once, then peel submodules; keep optimize disabled.
+     full_model = VoxCPMModel.from_local(args.model_dir, optimize=False).to(device).eval()
+     if args.force_fp32 or full_model.config.dtype != "float32":
+         full_model.config.dtype = "float32"
+         full_model = full_model.to(torch.float32)
+         full_model.audio_vae = full_model.audio_vae.to(torch.float32)
+     remove_weight_norm(full_model)
+
+     # Audio VAE encode
+     vae_enc = VAEEncodeWrapper(full_model.audio_vae).to(device).eval()
+     dummy_audio = torch.randn(1, 1, args.audio_samples, device=device)
+     export(
+         vae_enc,
+         dummy_audio,
+         os.path.join(args.out_dir, "audio_vae_encode.onnx"),
+         dynamic_axes={"audio_wave": {0: "batch", 2: "samples"}},
+         opset=args.opset,
+     )
+
+     # Audio VAE decode
+     vae_dec = VAEDecodeWrapper(full_model.audio_vae).to(device).eval()
+     dummy_latent = torch.randn(1, full_model.audio_vae.latent_dim, args.latent_steps, device=device)
+     export(
+         vae_dec,
+         dummy_latent,
+         os.path.join(args.out_dir, "audio_vae_decode.onnx"),
+         dynamic_axes={"latent": {0: "batch", 2: "latent_steps"}},
+         opset=args.opset,
+     )
+
+     # LocEnc with enc_to_lm projection fused
+     locenc = LocEncLmWrapper(full_model.feat_encoder, full_model.enc_to_lm_proj).to(device).eval()
+     dummy_seq = torch.randn(1, args.seq_len, full_model.patch_size, full_model.feat_dim, device=device)
+     export(
+         locenc,
+         dummy_seq,
+         os.path.join(args.out_dir, "locenc.onnx"),
+         dynamic_axes={"x": {0: "batch", 1: "seq_len"}},
+         opset=args.opset,
+     )
+
+     # FSQ layer
+     fsq = FSQWrapper(full_model.fsq_layer).to(device).eval()
+     hidden_size = full_model.config.lm_config.hidden_size
+     dummy_hidden = torch.randn(1, args.seq_len, hidden_size, device=device)
+     export(
+         fsq,
+         dummy_hidden,
+         os.path.join(args.out_dir, "fsq_layer.onnx"),
+         dynamic_axes={"hidden": {0: "batch", 1: "seq_len"}},
+         opset=args.opset,
+     )
+
+     # Stop head
+     stop = StopHeadWrapper(full_model.stop_proj, full_model.stop_actn, full_model.stop_head).to(device).eval()
+     dummy_stop_inp = torch.randn(1, hidden_size, device=device)
+     export(
+         stop,
+         dummy_stop_inp,
+         os.path.join(args.out_dir, "stop_head.onnx"),
+         dynamic_axes={"hidden": {0: "batch"}},
+         opset=args.opset,
+     )
+
+     # Projection layers
+     # export(
+     #     full_model.enc_to_lm_proj,
+     #     dummy_hidden,
+     #     os.path.join(args.out_dir, "enc_to_lm_proj.onnx"),
+     #     dynamic_axes={"input": {0: "batch", 1: "seq_len"}},
+     #     opset=args.opset,
+     # )
+     lm_hidden = torch.randn(1, full_model.config.lm_config.hidden_size, device=device)
+     export(
+         full_model.lm_to_dit_proj,
+         lm_hidden,
+         os.path.join(args.out_dir, "lm_to_dit_proj.onnx"),
+         dynamic_axes={"input": {0: "batch"}},
+         opset=args.opset,
+     )
+     export(
+         full_model.res_to_dit_proj,
+         lm_hidden,
+         os.path.join(args.out_dir, "res_to_dit_proj.onnx"),
+         dynamic_axes={"input": {0: "batch"}},
+         opset=args.opset,
+     )
+
+     # VoxCPMLocDiT single step (score function)
+     dit_step = DiTStepWrapper(full_model.feat_decoder.estimator).to(device).eval()
+     dummy_x = torch.randn(1, full_model.feat_dim, full_model.patch_size, device=device)
+     dummy_mu = torch.randn(1, full_model.config.dit_config.hidden_dim, device=device)
+     dummy_t = torch.full((1,), args.dit_step_t, device=device)
+     dummy_dt = torch.full((1,), 0.0, device=device)
+     dummy_cond = torch.randn(1, full_model.feat_dim, full_model.patch_size, device=device)
+     export(
+         dit_step,
+         (dummy_x, dummy_mu, dummy_t, dummy_cond, dummy_dt),
+         os.path.join(args.out_dir, "dit_step.onnx"),
+         dynamic_axes={
+             "x": {0: "batch"},
+             "mu": {0: "batch"},
+             "t": {0: "batch"},
+             "cond": {0: "batch"},
+             "dt": {0: "batch"},
+         },
+         opset=args.opset,
+     )
+
+     # # UnifiedCFM + VoxCPMLocDiT (single-step sampler unrolled with fixed n_timesteps)
+     # cfm = CFMWrapper(
+     #     full_model.feat_decoder,
+     #     patch_size=full_model.patch_size,
+     #     n_timesteps=args.cfm_steps,
+     #     cfg_value=args.cfg_value,
+     # ).to(device).eval()
+     # dummy_mu = torch.randn(1, full_model.config.dit_config.hidden_dim, device=device)
+     # dummy_cond = torch.randn(1, full_model.feat_dim, full_model.patch_size, device=device)
+     # export(
+     #     cfm,
+     #     (dummy_mu, dummy_cond),
+     #     os.path.join(args.out_dir, "cfm_step.onnx"),
+     #     dynamic_axes={"mu": {0: "batch"}, "cond": {0: "batch"}},
+     #     opset=args.opset,
+     # )
+
+     if args.dump_embeddings and hasattr(full_model.base_lm, "embed_tokens"):
+         import numpy as np
+         emb = full_model.base_lm.embed_tokens.weight.detach().cpu().numpy()
+         os.makedirs(args.out_dir, exist_ok=True)
+         np.save(os.path.join(args.out_dir, "embed_tokens.npy"), emb)
+         print(f"Saved: {os.path.join(args.out_dir, 'embed_tokens.npy')}")
+
+     print("Done.")
+
+
+ if __name__ == "__main__":
+     main()
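
For reference, a minimal invocation sketch for the exporter above. The script's filename is not shown in this diff, so `export_onnx.py` and the model path are placeholders; all flags are taken from the argparse block above:

```bash
# Hypothetical filename for the export script shown above
python export_onnx.py \
    --model-dir ./VoxCPM-0.5B \
    --out-dir onnx_exports \
    --opset 18 \
    --force-fp32 \
    --dump-embeddings
```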
fsq_layer.rknn ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cfcb8153c4bdc1e9ba992302cf4339d47bdd258a620ef4048d2c002bec8b2867
+ size 1282675
librkllmrt.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bbcf28a8666b9fbf7361d6aad892b957920f6ea92400c074899b48f4c5b2c96f
+ size 7543744
lm_to_dit_proj.rknn ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a0d5c01b463bd6ca197b4cb6cdd890d5ba0438ea8dda7cd4e88723f5c68d5ef5
+ size 2139756
locenc_1.rknn ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:482c78a7b9e31bab61886fde2d0463c6314e4fbfb26356316f7c6b2223716a29
+ size 131148908
locenc_64.rknn ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:780d8e0db5a009595930585c3f2c7999900e7ba941d6ee29156dc256ed1ad909
+ size 179347310
model_structure.jpg ADDED

Git LFS Details

  • SHA256: 5a2f12f2d322035ecfb45db948e4c2a767dd4d25cc31c932810253874f7d3f98
  • Pointer size: 131 Bytes
  • Size of remote file: 219 kB
onnx_infer-rknn2.py ADDED
@@ -0,0 +1,580 @@
1
+ import argparse
2
+ import os
3
+ import math
4
+ import random
5
+ import numpy as np
6
+ import time
7
+ from tqdm import tqdm
8
+ import soundfile as sf
9
+ from scipy import signal
10
+
11
+ from rkllm_binding import *
12
+
13
+ from transformers import AutoTokenizer
14
+
15
+ import ztu_somemodelruntime_rknnlite2 as ort
16
+
17
+ def mask_multichar_chinese_tokens(tokenizer):
18
+ # Pre-compute multi-character tokens (length >= 2, pure Chinese characters)
19
+ multichar_tokens = {
20
+ token for token in tokenizer.vocab.keys()
21
+ if len(token) >= 2 and all("\u4e00" <= c <= "\u9fff" for c in token)
22
+ }
23
+
24
+ class CharTokenizerWrapper:
25
+ """Wrapper class for tokenizers that handles multi-character Chinese tokens.
26
+
27
+ This wrapper automatically splits multi-character Chinese tokens into
28
+ individual characters while preserving the original tokenizer's interface.
29
+ """
30
+
31
+ def __init__(self, base_tokenizer) -> None:
32
+ """Initialize the wrapper with a base tokenizer.
33
+
34
+ Args:
35
+ base_tokenizer: The tokenizer to wrap
36
+ """
37
+ self.tokenizer = base_tokenizer
38
+ self.multichar_tokens = multichar_tokens
39
+
40
+ def tokenize(self, text: str, **kwargs):
41
+ """Tokenize text and split multi-character Chinese tokens into single characters.
42
+
43
+ Args:
44
+ text: Input text to tokenize
45
+ **kwargs: Additional arguments passed to the base tokenizer
46
+
47
+ Returns:
48
+ List of processed tokens with multi-character Chinese tokens split
49
+
50
+ Example:
51
+ >>> wrapper = CharTokenizerWrapper(tokenizer)
52
+ >>> tokens = wrapper.tokenize("你好世界")
53
+ >>> # Returns ["你", "好", "世", "界"] instead of ["你好", "世界"]
54
+ """
55
+ if not isinstance(text, str):
56
+ raise TypeError(f"Expected string input, got {type(text)}")
57
+
58
+ tokens = self.tokenizer.tokenize(text, **kwargs)
59
+ processed = []
60
+
61
+ for token in tokens:
62
+ # Remove possible subword prefix
63
+ clean_token = token.replace("▁", "")
64
+
65
+ if clean_token in self.multichar_tokens:
66
+ # Split multi-character token into single characters
67
+ chars = list(clean_token)
68
+ processed.extend(chars)
69
+ else:
70
+ processed.append(token)
71
+
72
+ return processed
73
+
74
+ def __call__(self, text: str, **kwargs):
75
+ """Call the tokenizer and return token IDs.
76
+
77
+ This method provides the same interface as the original tokenizer
78
+ but with multi-character Chinese token handling.
79
+
80
+ Args:
81
+ text: Input text to tokenize
82
+ **kwargs: Additional arguments passed to the base tokenizer
83
+
84
+ Returns:
85
+ List of token IDs
86
+
87
+ Raises:
88
+ TypeError: If input is not a string
89
+ ValueError: If tokenization fails
90
+ """
91
+ try:
92
+ tokens = self.tokenize(text, **kwargs)
93
+ result = self.tokenizer.convert_tokens_to_ids(tokens)
94
+ return result
95
+ except Exception as e:
96
+ raise ValueError(f"Tokenization failed: {str(e)}") from e
97
+
98
+ return CharTokenizerWrapper(tokenizer)
99
+
100
+
101
+ def load_rknn(path: str, providers):
102
+ if not os.path.exists(path):
103
+ raise FileNotFoundError(f"ONNX file not found: {path}")
104
+ return ort.InferenceSession(path, providers=providers)
105
+
106
+ def ensure_numpy(arr, dtype=None):
107
+ np_arr = np.asarray(arr)
108
+ if dtype is not None:
109
+ np_arr = np_arr.astype(dtype, copy=False)
110
+ return np_arr
111
+
112
+
113
+ def run_ort(session: ort.InferenceSession, inputs: dict, name: str = None):
114
+ start = time.perf_counter()
115
+ ort_inputs = {k: ensure_numpy(v) for k, v in inputs.items()}
116
+ outputs = session.run(None, ort_inputs) # noqa: SLF001
117
+ if name:
118
+ elapsed_ms = (time.perf_counter() - start) * 1000
119
+ print(f"[time] {name}: {elapsed_ms:.2f} ms")
120
+ return np.asarray(outputs[0])
121
+
122
+
123
+ def run_locenc_fixed_len(x: np.ndarray, sess: ort.InferenceSession, max_len: int, name: str):
124
+ """Run locenc in fixed-length chunks to avoid dynamic seq len requirements."""
125
+ b, t, p, d = x.shape
126
+ outputs = []
127
+ for start in range(0, t, max_len):
128
+ end = min(start + max_len, t)
129
+ chunk = x[:, start:end, :, :]
130
+ if chunk.shape[1] < max_len:
131
+ pad = np.zeros((b, max_len - chunk.shape[1], p, d), dtype=chunk.dtype)
132
+ chunk = np.concatenate([chunk, pad], axis=1)
133
+ out = run_ort(sess, {"x": chunk.astype(np.float32)}, name=f"{name}_{start}")
134
+ outputs.append(out[:, : end - start, :])
135
+ return np.concatenate(outputs, axis=1)
136
+
137
+
138
+ def run_fsq_fixed_len(hidden: np.ndarray, sess: ort.InferenceSession, max_len: int, name: str):
139
+ """Run FSQ in fixed-length chunks to avoid dynamic seq len requirements."""
140
+ b, t, h = hidden.shape
141
+ outputs = []
142
+ for start in range(0, t, max_len):
143
+ end = min(start + max_len, t)
144
+ chunk = hidden[:, start:end, :]
145
+ if chunk.shape[1] < max_len:
146
+ pad = np.zeros((b, max_len - chunk.shape[1], h), dtype=chunk.dtype)
147
+ chunk = np.concatenate([chunk, pad], axis=1)
148
+ out = run_ort(sess, {"hidden": chunk.astype(np.float32)}, name=f"{name}_{start}")
149
+ outputs.append(out[:, : end - start, :])
150
+ return np.concatenate(outputs, axis=1)
151
+
152
+
153
+ def run_vae_encode_chunked(
154
+ audio: np.ndarray,
155
+ sess: ort.InferenceSession,
156
+ chunk_size: int,
157
+ block_latent_len: int = 64,
158
+ overlap_latent: int = 4,
159
+ name: str = "vae_encode",
160
+ ):
161
+ """Encode audio with fixed-length blocks (latent steps) and overlap to avoid dynamic shapes."""
162
+ block_samples = block_latent_len * chunk_size
163
+ overlap_samples = overlap_latent * chunk_size
164
+ stride_samples = block_samples - overlap_samples
165
+ outputs = []
166
+ t = audio.shape[-1]
167
+ for start in range(0, t, stride_samples):
168
+ end = min(start + block_samples, t)
169
+ chunk = audio[..., start:end]
170
+ chunk_len = chunk.shape[-1]
171
+ if chunk_len < block_samples:
172
+ chunk = np.pad(chunk, ((0, 0), (0, 0), (0, block_samples - chunk_len)), mode="constant")
173
+ out = run_ort(sess, {"audio_wave": chunk.astype(np.float32)}, name=f"{name}_{start}")
174
+ valid_latent = math.ceil(chunk_len / chunk_size)
175
+ start_keep = overlap_latent if start > 0 else 0
176
+ if valid_latent <= start_keep:
177
+ start_keep = max(valid_latent - 1, 0)
178
+ end_keep = min(valid_latent, block_latent_len)
179
+ if end_keep > start_keep:
180
+ outputs.append(out[:, :, start_keep:end_keep])
181
+ if not outputs:
182
+ return np.zeros((audio.shape[0], 0, 0), dtype=audio.dtype)
183
+ return np.concatenate(outputs, axis=2)
184
+
185
+
186
+ def run_vae_decode_chunked(
187
+ latent: np.ndarray,
188
+ sess: ort.InferenceSession,
189
+ chunk_size: int,
190
+ block_latent_len: int = 64,
191
+ overlap_latent: int = 4,
192
+ name: str = "vae_decode",
193
+ ):
194
+ """Decode latent in fixed-length blocks with overlap-add style stitching."""
195
+ block_len = block_latent_len
196
+ overlap = overlap_latent
197
+ stride = block_len - overlap
198
+ outputs = []
199
+ total = latent.shape[2]
200
+ for start in range(0, total, stride):
201
+ end = min(start + block_len, total)
202
+ chunk = latent[:, :, start:end]
203
+ valid_latent = end - start
204
+ if chunk.shape[2] < block_len:
205
+ pad = np.zeros(
206
+ (latent.shape[0], latent.shape[1], block_len - chunk.shape[2]),
207
+ dtype=latent.dtype,
208
+ )
209
+ chunk = np.concatenate([chunk, pad], axis=2)
210
+ out = run_ort(sess, {"latent": chunk.astype(np.float32)}, name=f"{name}_{start}")
211
+ start_keep = overlap * chunk_size if start > 0 else 0
212
+ valid_audio = valid_latent * chunk_size
213
+ if valid_audio <= start_keep:
214
+ start_keep = max(valid_audio - chunk_size, 0)
215
+ end_keep = min(out.shape[-1], start_keep + valid_audio)
216
+ if end_keep > start_keep:
217
+ outputs.append(out[..., start_keep:end_keep])
218
+ return np.concatenate(outputs, axis=-1)
219
+
220
+
221
+ def cfm_euler_with_onnx_step(
222
+ dit_sess: ort.InferenceSession,
223
+ x: np.ndarray,
224
+ mu: np.ndarray,
225
+ cond: np.ndarray,
226
+ n_timesteps: int,
227
+ cfg_value: float,
228
+ use_cfg_zero_star: bool = True,
229
+ mean_mode: bool = False,
230
+ ):
231
+ """
232
+ Re-implementation of UnifiedCFM.solve_euler using ONNX DiT single step.
233
+ Shapes:
234
+ x: [B, C, P], mu: [B, H_dit], cond: [B, C, P]
235
+ """
236
+ dtype = x.dtype
237
+ t_span = np.linspace(1.0, 0.0, n_timesteps + 1, dtype=dtype)
238
+ t_span = t_span + 1.0 * (np.cos(np.pi / 2 * t_span) - 1 + t_span) # sway sampling
239
+
240
+ t = t_span[0]
241
+ dt = t_span[0] - t_span[1]
242
+ zero_init_steps = max(1, int(len(t_span) * 0.04))
243
+
244
+ for step in tqdm(range(1, len(t_span))):
245
+ if use_cfg_zero_star and step <= zero_init_steps:
246
+ dphi_dt = np.zeros_like(x, dtype=dtype)
247
+ else:
248
+ b = x.shape[0]
249
+ t_in = np.full((b,), t, dtype=dtype)
250
+ dt_in = np.full((b,), dt if mean_mode else 0.0, dtype=dtype)
251
+
252
+ # run conditional branch (pos)
253
+ dphi_dt_pos = np.asarray(
254
+ run_ort(
255
+ dit_sess,
256
+ {
257
+ "x": x,
258
+ "mu": mu,
259
+ "t": t_in,
260
+ "cond": cond,
261
+ "dt": dt_in,
262
+ },
263
+ ),
264
+ dtype=dtype,
265
+ )
266
+
267
+ # run "negative" branch (unconditional: mu=0, cond=0)
268
+ dphi_dt_neg = np.asarray(
269
+ run_ort(
270
+ dit_sess,
271
+ {
272
+ "x": x,
273
+ "mu": np.zeros_like(mu),
274
+ "t": t_in,
275
+ "cond": np.zeros_like(cond),
276
+ "dt": np.zeros_like(dt_in),
277
+ },
278
+ ),
279
+ dtype=dtype,
280
+ )
281
+
282
+ if use_cfg_zero_star:
283
+ positive_flat = dphi_dt_pos.reshape(b, -1)
284
+ negative_flat = dphi_dt_neg.reshape(b, -1)
285
+ st_star = np.sum(positive_flat * negative_flat, axis=1, keepdims=True) / (
286
+ np.sum(negative_flat ** 2, axis=1, keepdims=True) + 1e-8
287
+ )
288
+ st_star = st_star.reshape(b, *([1] * (dphi_dt_pos.ndim - 1)))
289
+ else:
290
+ st_star = 1.0
291
+
292
+ dphi_dt = dphi_dt_neg * st_star + cfg_value * (dphi_dt_pos - dphi_dt_neg * st_star)
293
+
294
+ x = x - dt * dphi_dt
295
+ t = t - dt
296
+ if step < len(t_span) - 1:
297
+ dt = t - t_span[step + 1]
298
+
299
+ return x
300
+
301
+
302
+ def prepare_audio_features(
303
+ audio_path: str,
304
+ sample_rate: int,
305
+ patch_size: int,
306
+ chunk_size: int,
307
+ vae_encode_sess: ort.InferenceSession,
308
+ ):
309
+ audio, sr = sf.read(audio_path, always_2d=False)
310
+ if audio.ndim > 1:
311
+ audio = audio.mean(axis=1)
312
+ audio = audio.astype(np.float32)
313
+ if sr != sample_rate:
314
+ gcd = math.gcd(int(sr), int(sample_rate))
315
+ up = sample_rate // gcd
316
+ down = sr // gcd
317
+ audio = signal.resample_poly(audio, up, down)
318
+
319
+ # Expect shape [B, 1, T] for ONNX VAE encoder
320
+ if audio.ndim == 1:
321
+ audio = audio[np.newaxis, np.newaxis, :]
322
+ else:
323
+ audio = audio[np.newaxis, ...]
324
+
325
+ patch_len = patch_size * chunk_size
326
+ t = audio.shape[-1]
327
+ pad_right = (patch_len - t % patch_len) % patch_len
328
+ if pad_right > 0:
329
+ audio = np.pad(audio, ((0, 0), (0, 0), (0, pad_right)), mode="constant")
330
+
331
+ latent = run_vae_encode_chunked(
332
+ audio,
333
+ vae_encode_sess,
334
+ chunk_size=chunk_size,
335
+ block_latent_len=64,
336
+ overlap_latent=4,
337
+ name="vae_encode",
338
+ )
339
+ latent_dim = latent.shape[1]
340
+ t_latent = latent.shape[2]
341
+ if t_latent % patch_size != 0:
342
+ raise ValueError(f"Encoded latent length {t_latent} not divisible by patch_size={patch_size}")
343
+
344
+ audio_feat = latent.reshape(latent_dim, -1, patch_size).transpose(1, 2, 0)
345
+ audio_feat = audio_feat[:-1, ...] # remove last padding token
346
+ return audio_feat
347
+
348
+
349
+ def main():
350
+ parser = argparse.ArgumentParser(description="Hybrid ONNX inference for VoxCPM (non-streaming).")
351
+ parser.add_argument("--tokenizer-dir", required=True, help="Path to tokenizer (e.g., VoxCPM-0.5B).")
352
+ parser.add_argument("--base-hf-dir", required=True, help="Path to transformers-formatted base MiniCPM.")
353
+ parser.add_argument("--residual-hf-dir", required=True, help="Path to transformers-formatted residual MiniCPM.")
354
+ parser.add_argument("--onnx-dir", required=True, help="Directory containing exported ONNX files.")
355
+ parser.add_argument("--text", required=True, help="Target text to synthesize.")
356
+ parser.add_argument("--prompt-audio", default=None, help="Optional prompt audio path.")
357
+ parser.add_argument("--prompt-text", default=None, help="Text transcript of prompt audio (required if prompt-audio).")
358
+ parser.add_argument("--output", default="onnx_output.wav", help="Output wav path.")
359
+ parser.add_argument("--cfg-value", type=float, default=2.0, help="CFG value for diffusion.")
360
+ parser.add_argument("--inference-timesteps", type=int, default=10, help="Diffusion steps.")
361
+ parser.add_argument("--min-len", type=int, default=2, help="Minimum generated patch count before stop allowed.")
362
+ parser.add_argument("--max-len", type=int, default=2000, help="Maximum generated patch count.")
363
+ parser.add_argument("--seed", type=int, default=None, help="Random seed for reproducibility.")
364
+ parser.add_argument(
365
+ "--providers",
366
+ nargs="+",
367
+ default=None,
368
+ help="ONNX Runtime providers (e.g., CUDAExecutionProvider CPUExecutionProvider).",
369
+ )
370
+ args = parser.parse_args()
371
+
372
+ providers = args.providers or ["CUDAExecutionProvider", "CPUExecutionProvider"]
373
+
374
+ # Seed
375
+ if args.seed is not None:
376
+ random.seed(args.seed)
377
+ np.random.seed(args.seed)
378
+
379
+ # Load tokenizer and HF MiniCPM models
380
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)
381
+ tokenizer = mask_multichar_chinese_tokens(tokenizer)
382
+ base_model = RKLLMRuntime()
383
+ base_model_params = base_model.create_default_param()
384
+ base_model_params.model_path = (args.base_hf_dir + "/base_lm.rkllm").encode("utf-8")
385
+ base_model_params.max_context_len = 1024
386
+ base_model.init(base_model_params, lambda: print("base_model init"))
387
+ residual_model = RKLLMRuntime()
388
+ residual_model_params = residual_model.create_default_param()
389
+ residual_model_params.model_path = (args.residual_hf_dir + "/residual_lm.rkllm").encode("utf-8")
390
+ residual_model_params.max_context_len = 1024
391
+ residual_model.init(residual_model_params, lambda: print("residual_model init"))
392
+
393
+ dtype = np.float32
394
+
395
+ # constants tied to exported ONNX
396
+ patch_size = 2
397
+ latent_dim = 64
398
+ chunk_size = 640
399
+ sample_rate = 16000
400
+ audio_start_token = 101
401
+ fixed_seq_len = 64 # target platform prefers fixed seq len for locenc/fsq
402
+
403
+ # Load ONNX sessions
404
+ vae_encode_sess = load_rknn(os.path.join(args.onnx_dir, "audio_vae_encode.rknn"), providers)
405
+ vae_decode_sess = load_rknn(os.path.join(args.onnx_dir, "audio_vae_decode.rknn"), providers)
406
+ locenc_64_sess = load_rknn(os.path.join(args.onnx_dir, "locenc_64.rknn"), providers)
407
+ locenc_1_sess = load_rknn(os.path.join(args.onnx_dir, "locenc_1.rknn"), providers)
408
+ fsq_sess = load_rknn(os.path.join(args.onnx_dir, "fsq_layer.rknn"), providers)
409
+ stop_sess = load_rknn(os.path.join(args.onnx_dir, "stop_head.rknn"), providers)
410
+ dit_step_sess = load_rknn(os.path.join(args.onnx_dir, "dit_step.rknn"), providers)
411
+ lm_to_dit_sess = load_rknn(os.path.join(args.onnx_dir, "lm_to_dit_proj.rknn"), providers)
412
+ res_to_dit_sess = load_rknn(os.path.join(args.onnx_dir, "res_to_dit_proj.rknn"), providers)
413
+
414
+ # Build text/audio features
415
+ if args.prompt_audio:
416
+ if not args.prompt_text:
417
+ raise ValueError("prompt-text is required when prompt-audio is provided.")
418
+ text = args.prompt_text + args.text
419
+ else:
420
+ text = args.text
421
+
422
+ tokenized = tokenizer(text, add_special_tokens=False)
423
+ text_token = tokenized["input_ids"] if isinstance(tokenized, dict) else tokenized
424
+ text_token = np.asarray(text_token, dtype=np.int64)
425
+ text_token = np.concatenate([text_token, np.array([audio_start_token], dtype=np.int64)], axis=0)
426
+ text_length = text_token.shape[0]
427
+
428
+ if args.prompt_audio:
429
+ audio_feat = prepare_audio_features(
430
+ args.prompt_audio, sample_rate, patch_size, chunk_size, vae_encode_sess
431
+ )
432
+ audio_length = audio_feat.shape[0]
433
+
434
+ text_pad_token = np.zeros(audio_length, dtype=np.int64)
435
+ text_token = np.concatenate([text_token, text_pad_token])
436
+
437
+ audio_pad_feat = np.zeros(
438
+ (text_length, patch_size, latent_dim),
439
+ dtype=np.float32,
440
+ )
441
+ audio_feat = np.concatenate([audio_pad_feat, audio_feat], axis=0)
442
+
443
+ text_mask = np.concatenate([np.ones(text_length, dtype=np.int32), np.zeros(audio_length, dtype=np.int32)])
444
+ audio_mask = np.concatenate([np.zeros(text_length, dtype=np.int32), np.ones(audio_length, dtype=np.int32)])
445
+ else:
446
+ audio_feat = np.zeros(
447
+ (text_length, patch_size, latent_dim),
448
+ dtype=np.float32,
449
+ )
450
+ text_mask = np.ones(text_length, dtype=np.int32)
451
+ audio_mask = np.zeros(text_length, dtype=np.int32)
452
+
453
+ text_token = text_token[np.newaxis, ...]
454
+ text_mask = text_mask[np.newaxis, ...]
455
+ audio_feat = audio_feat[np.newaxis, ...]
456
+ audio_mask = audio_mask[np.newaxis, ...]
457
+
458
+ # LocEnc (ONNX) with enc_to_lm projection fused
459
+ feat_embed = run_locenc_fixed_len(
460
+ audio_feat.astype(np.float32), locenc_64_sess, fixed_seq_len, name="locenc"
461
+ )
462
+ feat_embed = np.asarray(feat_embed, dtype=dtype)
463
+
464
+ # Text embed using exported embedding npy
465
+ scale_emb = 1.0
466
+ embedding_path = os.path.join(args.onnx_dir, "embed_tokens.npy")
467
+ if not os.path.exists(embedding_path):
468
+ raise FileNotFoundError(f"Embedding npy file not found: {embedding_path}")
469
+ embedding_weight = np.load(embedding_path).astype(dtype)
470
+ text_embed = embedding_weight[text_token] * scale_emb
471
+
472
+ combined_embed = text_mask[..., None] * text_embed + audio_mask[..., None] * feat_embed
473
+
474
+ # Base LM forward
475
+ start = time.perf_counter()
476
+ enc_outputs = np.asarray(base_model.forward_embed(combined_embed.astype(np.float32), keep_history=True), dtype=dtype)
477
+ enc_outputs = enc_outputs * 4
478
+ print(f"[time] base_lm initial: {(time.perf_counter() - start)*1000:.2f} ms")
479
+
480
+ # FSQ on audio positions
481
+ enc_outputs_fsq = run_fsq_fixed_len(enc_outputs.astype(np.float32), fsq_sess, fixed_seq_len, name="fsq_init")
482
+
483
+ enc_outputs = enc_outputs_fsq * audio_mask[..., None] + enc_outputs * text_mask[..., None]
484
+
485
+ # Residual LM forward
486
+ residual_inputs = enc_outputs + audio_mask[..., None] * feat_embed
487
+ start = time.perf_counter()
488
+ residual_outputs = np.asarray(
489
+ residual_model.forward_embed(residual_inputs.astype(np.float32), keep_history=True),
490
+ dtype=dtype,
491
+ )
492
+ residual_outputs = residual_outputs * 4
493
+ print(f"[time] residual_lm initial: {(time.perf_counter() - start)*1000:.2f} ms")
494
+
495
+ lm_hidden = enc_outputs[:, -1, :].astype(dtype)
496
+ res_hidden = residual_outputs[:, -1, :].astype(dtype)
497
+
498
+ prefix_feat_cond = audio_feat[:, -1, :, :].astype(dtype) # [B, P, D]
499
+ pred_feat_seq = []
500
+
501
+ # Generation loop
502
+ for step_idx in tqdm(range(args.max_len), desc="gen_loop"):
503
+ dit_hidden_lm = run_ort(lm_to_dit_sess, {"input": lm_hidden}, name="lm_to_dit")
504
+ dit_hidden_res = run_ort(res_to_dit_sess, {"input": res_hidden}, name="res_to_dit")
505
+ dit_hidden = np.asarray(dit_hidden_lm + dit_hidden_res, dtype=dtype)
506
+ cond = np.transpose(prefix_feat_cond, (0, 2, 1)) # [B, D, P]
507
+
508
+ # Sample next patch via ONNX DiT
509
+ x0 = np.random.randn(*cond.shape).astype(dtype) # [B, D, P]
510
+ pred_feat = cfm_euler_with_onnx_step(
511
+ dit_step_sess,
512
+ x0,
513
+ dit_hidden,
514
+ cond,
515
+ n_timesteps=args.inference_timesteps,
516
+ cfg_value=args.cfg_value,
517
+ use_cfg_zero_star=True,
518
+ ).transpose(0, 2, 1) # -> [B, P, D]
519
+
520
+ pred_feat_seq.append(pred_feat[:, np.newaxis, :, :]) # keep time dimension
521
+ prefix_feat_cond = pred_feat
522
+
523
+ # Encode new patch for next step (ONNX locenc)
524
+ curr_embed = run_ort(locenc_1_sess, {"x": pred_feat[:, np.newaxis, :, :]}, name="locenc_step")
525
+ curr_embed = np.asarray(curr_embed, dtype=dtype)
526
+
527
+ # Stop check (use lm_hidden BEFORE update, consistent with original)
528
+ stop_logits_np = run_ort(stop_sess, {"hidden": lm_hidden})
529
+ stop_flag = int(np.argmax(stop_logits_np, axis=-1)[0])
530
+ if step_idx > args.min_len and stop_flag == 1:
531
+ break
532
+
533
+ lm_hidden_step = np.asarray(
534
+ base_model.forward_embed(curr_embed.astype(np.float32), keep_history=True),
535
+ dtype=dtype,
536
+ )
537
+ lm_hidden_step = lm_hidden_step * 4
538
+ lm_hidden = lm_hidden_step[:, -1, :].astype(dtype)
539
+
540
+ # FSQ expects [B, T, H]; expand step dimension then squeeze back
541
+ lm_hidden_step_for_fsq = lm_hidden[:, np.newaxis, :]
542
+ lm_hidden_fsq_np = run_ort(fsq_sess, {"hidden": lm_hidden_step_for_fsq.astype(np.float32)})
543
+ lm_hidden_fsq = np.asarray(lm_hidden_fsq_np, dtype=dtype)[:, 0, :]
544
+
545
+ res_step_inputs = (lm_hidden_fsq + curr_embed[:, 0, :]).astype(dtype)
546
+ res_step_inputs = res_step_inputs[:, np.newaxis, :]
547
+ res_step = np.asarray(
548
+ residual_model.forward_embed(res_step_inputs.astype(np.float32), keep_history=True),
549
+ dtype=dtype,
550
+ )
551
+ res_step = res_step * 4
552
+ res_hidden = res_step[:, -1, :].astype(dtype)
553
+ lm_hidden = lm_hidden_fsq
554
+
555
+ if len(pred_feat_seq) == 0:
556
+ raise RuntimeError("Generation produced zero patches.")
557
+
558
+ pred_feat_seq = np.concatenate(pred_feat_seq, axis=1) # [B, T_gen, P, D]
559
+ feat_pred = pred_feat_seq.transpose(0, 3, 1, 2).reshape(
560
+ pred_feat_seq.shape[0], pred_feat_seq.shape[3], -1
561
+ ) # [B, D, T_gen*P]
562
+
563
+ # Decode audio via ONNX
564
+ audio = run_vae_decode_chunked(
565
+ feat_pred,
566
+ vae_decode_sess,
567
+ chunk_size=chunk_size,
568
+ block_latent_len=64,
569
+ overlap_latent=4,
570
+ name="vae_decode",
571
+ )
572
+ audio = audio[..., 640:-640] # trim start/end
573
+
574
+ wav = audio.squeeze()
575
+ sf.write(args.output, wav, sample_rate)
576
+ print(f"Saved: {args.output}")
577
+
578
+
579
+ if __name__ == "__main__":
580
+ main()
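
The Euler loop in `cfm_euler_with_onnx_step` does not march through uniform timesteps: it warps a uniform grid with the sway-sampling transform `t ← t + (cos(πt/2) − 1 + t)`, which places points more densely near t = 1 (the high-noise end) and more sparsely near t = 0. A standalone sketch of that schedule, computed the same way as in the function above, for the default 10 diffusion steps:

```python
import numpy as np

def sway_schedule(n_timesteps: int) -> np.ndarray:
    # Uniform grid from t=1 down to t=0, then the sway-sampling warp
    # used by cfm_euler_with_onnx_step above.
    t_span = np.linspace(1.0, 0.0, n_timesteps + 1, dtype=np.float32)
    return t_span + 1.0 * (np.cos(np.pi / 2 * t_span) - 1 + t_span)

# Step sizes are finer near t=1 and coarser near t=0.
print(sway_schedule(10))
```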
onnx_infer.py ADDED
@@ -0,0 +1,507 @@
1
+ import argparse
2
+ import os
3
+ import random
4
+ import numpy as np
5
+ import torch
6
+ import torchaudio
7
+ import sys
8
+ import pathlib
9
+ import onnxruntime as ort
10
+ import time
11
+ from tqdm import tqdm
12
+
13
+ from transformers import AutoTokenizer
14
+ from transformers.cache_utils import DynamicCache
15
+
16
+ from modeling_minicpm import MiniCPMModel # noqa: E402
17
+
18
+ def mask_multichar_chinese_tokens(tokenizer):
19
+ # Pre-compute multi-character tokens (length >= 2, pure Chinese characters)
20
+ multichar_tokens = {
21
+ token for token in tokenizer.vocab.keys()
22
+ if len(token) >= 2 and all("\u4e00" <= c <= "\u9fff" for c in token)
23
+ }
24
+
25
+ class CharTokenizerWrapper:
26
+ """Wrapper class for tokenizers that handles multi-character Chinese tokens.
27
+
28
+ This wrapper automatically splits multi-character Chinese tokens into
29
+ individual characters while preserving the original tokenizer's interface.
30
+ """
31
+
32
+ def __init__(self, base_tokenizer) -> None:
33
+ """Initialize the wrapper with a base tokenizer.
34
+
35
+ Args:
36
+ base_tokenizer: The tokenizer to wrap
37
+ """
38
+ self.tokenizer = base_tokenizer
39
+ self.multichar_tokens = multichar_tokens
40
+
41
+ def tokenize(self, text: str, **kwargs):
42
+ """Tokenize text and split multi-character Chinese tokens into single characters.
43
+
44
+ Args:
45
+ text: Input text to tokenize
46
+ **kwargs: Additional arguments passed to the base tokenizer
47
+
48
+ Returns:
49
+ List of processed tokens with multi-character Chinese tokens split
50
+
51
+ Example:
52
+ >>> wrapper = CharTokenizerWrapper(tokenizer)
53
+ >>> tokens = wrapper.tokenize("你好世界")
54
+ >>> # Returns ["你", "好", "世", "界"] instead of ["你好", "世界"]
55
+ """
56
+ if not isinstance(text, str):
57
+ raise TypeError(f"Expected string input, got {type(text)}")
58
+
59
+ tokens = self.tokenizer.tokenize(text, **kwargs)
60
+ processed = []
61
+
62
+ for token in tokens:
63
+ # Remove possible subword prefix
64
+ clean_token = token.replace("▁", "")
65
+
66
+ if clean_token in self.multichar_tokens:
67
+ # Split multi-character token into single characters
68
+ chars = list(clean_token)
69
+ processed.extend(chars)
70
+ else:
71
+ processed.append(token)
72
+
73
+ return processed
74
+
75
+ def __call__(self, text: str, **kwargs):
76
+ """Call the tokenizer and return token IDs.
77
+
78
+ This method provides the same interface as the original tokenizer
79
+ but with multi-character Chinese token handling.
80
+
81
+ Args:
82
+ text: Input text to tokenize
83
+ **kwargs: Additional arguments passed to the base tokenizer
84
+
85
+ Returns:
86
+ List of token IDs
87
+
88
+ Raises:
89
+ TypeError: If input is not a string
90
+ ValueError: If tokenization fails
91
+ """
92
+ try:
93
+ tokens = self.tokenize(text, **kwargs)
94
+ result = self.tokenizer.convert_tokens_to_ids(tokens)
95
+ return result
96
+ except Exception as e:
97
+ raise ValueError(f"Tokenization failed: {str(e)}") from e
98
+
99
+ return CharTokenizerWrapper(tokenizer)
100
+
101
+
102
+ def load_onnx(path: str, providers):
103
+ if not os.path.exists(path):
104
+ raise FileNotFoundError(f"ONNX file not found: {path}")
105
+ return ort.InferenceSession(path, providers=providers)
106
+
107
+
108
+ def to_numpy(t: torch.Tensor):
109
+ return t.detach().cpu().numpy()
110
+
111
+
112
+ def run_ort(session: ort.InferenceSession, inputs: dict, name: str = None):
113
+ start = time.perf_counter()
114
+ ort_inputs = {k: (v if isinstance(v, np.ndarray) else to_numpy(v)) for k, v in inputs.items()}
115
+ outputs = session.run(None, ort_inputs)
116
+ if name:
117
+ elapsed_ms = (time.perf_counter() - start) * 1000
118
+ print(f"[time] {name}: {elapsed_ms:.2f} ms")
119
+ return outputs[0]
120
+
121
+
122
+ def cfm_euler_with_onnx_step(
123
+ dit_sess: ort.InferenceSession,
124
+ x: torch.Tensor,
125
+ mu: torch.Tensor,
126
+ cond: torch.Tensor,
127
+ n_timesteps: int,
128
+ cfg_value: float,
129
+ use_cfg_zero_star: bool = True,
130
+ mean_mode: bool = False,
131
+ ):
132
+ """
133
+ Re-implementation of UnifiedCFM.solve_euler using ONNX DiT single step.
134
+ Shapes:
135
+ x: [B, C, P], mu: [B, H_dit], cond: [B, C, P]
136
+ """
137
+ device = x.device
138
+ dtype = x.dtype
139
+ t_span = torch.linspace(1.0, 0.0, n_timesteps + 1, device=device, dtype=dtype)
140
+ t_span = t_span + 1.0 * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span) # sway sampling
141
+
142
+ t = t_span[0]
143
+ dt = t_span[0] - t_span[1]
144
+ zero_init_steps = max(1, int(len(t_span) * 0.04))
145
+
146
+ for step in range(1, len(t_span)):
147
+ if use_cfg_zero_star and step <= zero_init_steps:
148
+ dphi_dt = torch.zeros_like(x)
149
+ else:
150
+ b = x.size(0)
151
+ t_in = t.expand(b)
152
+ dt_in = dt.expand(b)
153
+ if not mean_mode:
154
+ dt_in = torch.zeros_like(dt_in)
155
+
156
+ # run conditional branch (pos)
157
+ dphi_dt_pos = run_ort(
158
+ dit_sess,
159
+ {
160
+ "x": x,
161
+ "mu": mu,
162
+ "t": t_in,
163
+ "cond": cond,
164
+ "dt": dt_in,
165
+ },
166
+ name=f"dit_step_pos_{step}",
167
+ )
168
+ dphi_dt_pos = torch.from_numpy(dphi_dt_pos).to(device=device, dtype=dtype)
169
+
170
+ # run "negative" branch (unconditional: mu=0, cond=0)
171
+ dphi_dt_neg = run_ort(
172
+ dit_sess,
173
+ {
174
+ "x": x,
175
+ "mu": torch.zeros_like(mu),
176
+ "t": t_in,
177
+ "cond": torch.zeros_like(cond),
178
+ "dt": torch.zeros_like(dt_in),
179
+ },
180
+ name=f"dit_step_neg_{step}",
181
+ )
182
+ dphi_dt_neg = torch.from_numpy(dphi_dt_neg).to(device=device, dtype=dtype)
183
+
184
+ if use_cfg_zero_star:
185
+ positive_flat = dphi_dt_pos.view(b, -1)
186
+ negative_flat = dphi_dt_neg.view(b, -1)
187
+ st_star = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) / (
188
+ torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
189
+ )
190
+ st_star = st_star.view(b, *([1] * (dphi_dt_pos.ndim - 1)))
191
+ else:
192
+ st_star = 1.0
193
+
194
+ dphi_dt = dphi_dt_neg * st_star + cfg_value * (dphi_dt_pos - dphi_dt_neg * st_star)
195
+
196
+ x = x - dt * dphi_dt
197
+ t = t - dt
198
+ if step < len(t_span) - 1:
199
+ dt = t - t_span[step + 1]
200
+
201
+ return x
202
+
203
+
204
+ def prepare_audio_features(
205
+ audio_path: str,
206
+ sample_rate: int,
207
+ patch_size: int,
208
+ chunk_size: int,
209
+ vae_encode_sess: ort.InferenceSession,
210
+ ):
211
+ audio, sr = torchaudio.load(audio_path)
212
+ if audio.size(0) > 1:
213
+ audio = audio.mean(dim=0, keepdim=True)
214
+ if sr != sample_rate:
215
+ audio = torchaudio.functional.resample(audio, sr, sample_rate)
216
+
217
+ # Expect shape [B, 1, T] for ONNX VAE encoder
218
+ if audio.ndim == 2:
219
+ audio = audio.unsqueeze(0)
220
+
221
+ patch_len = patch_size * chunk_size
222
+ t = audio.size(-1)
223
+ pad_right = (patch_len - t % patch_len) % patch_len
224
+ if pad_right > 0:
225
+ audio = torch.nn.functional.pad(audio, (0, pad_right))
226
+
227
+ latent = run_ort(vae_encode_sess, {"audio_wave": audio}, name="vae_encode")
228
+ latent = torch.from_numpy(latent)
229
+ latent_dim = latent.shape[1]
230
+ t_latent = latent.shape[2]
231
+ if t_latent % patch_size != 0:
232
+ raise ValueError(f"Encoded latent length {t_latent} not divisible by patch_size={patch_size}")
233
+
234
+ audio_feat = latent.view(latent_dim, -1, patch_size).permute(1, 2, 0)
235
+ audio_feat = audio_feat[:-1, ...] # remove last padding token
236
+ return audio_feat
237
+
238
+
239
+ def main():
240
+ parser = argparse.ArgumentParser(description="Hybrid ONNX/PyTorch inference for VoxCPM (non-streaming).")
241
+ parser.add_argument("--tokenizer-dir", required=True, help="Path to tokenizer (e.g., VoxCPM-0.5B).")
242
+ parser.add_argument("--base-hf-dir", required=True, help="Path to transformers-formatted base MiniCPM.")
243
+ parser.add_argument("--residual-hf-dir", required=True, help="Path to transformers-formatted residual MiniCPM.")
244
+ parser.add_argument("--onnx-dir", required=True, help="Directory containing exported ONNX files.")
245
+ parser.add_argument("--text", required=True, help="Target text to synthesize.")
246
+ parser.add_argument("--prompt-audio", default=None, help="Optional prompt audio path.")
247
+ parser.add_argument("--prompt-text", default=None, help="Text transcript of prompt audio (required if prompt-audio).")
248
+ parser.add_argument("--output", default="onnx_output.wav", help="Output wav path.")
249
+ parser.add_argument("--device", default=None, help="torch device; default auto (cuda if available else cpu).")
250
+ parser.add_argument("--cfg-value", type=float, default=2.0, help="CFG value for diffusion.")
251
+ parser.add_argument("--inference-timesteps", type=int, default=10, help="Diffusion steps.")
252
+ parser.add_argument("--min-len", type=int, default=2, help="Minimum generated patch count before stop allowed.")
253
+ parser.add_argument("--max-len", type=int, default=2000, help="Maximum generated patch count.")
254
+ parser.add_argument("--force-fp32", action="store_true", help="Force model dtype to float32 for consistency with ONNX.")
255
+ parser.add_argument("--seed", type=int, default=None, help="Random seed for reproducibility.")
256
+ parser.add_argument(
257
+ "--providers",
258
+ nargs="+",
259
+ default=None,
260
+ help="ONNX Runtime providers (e.g., CUDAExecutionProvider CPUExecutionProvider).",
261
+ )
262
+ args = parser.parse_args()
263
+
264
+ device = torch.device(args.device) if args.device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
265
+ providers = args.providers or ["CUDAExecutionProvider", "CPUExecutionProvider"]
266
+
267
+ # Seed
268
+ if args.seed is not None:
269
+ random.seed(args.seed)
270
+ np.random.seed(args.seed)
271
+ torch.manual_seed(args.seed)
272
+ if torch.cuda.is_available():
273
+ torch.cuda.manual_seed_all(args.seed)
274
+
275
+ # Inference with no grad to avoid graph retention / memory growth
276
+ with torch.inference_mode():
277
+ # Load tokenizer and HF MiniCPM models
278
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)
279
+ tokenizer = mask_multichar_chinese_tokens(tokenizer)
280
+ base_model = MiniCPMModel.from_pretrained(args.base_hf_dir).to(device).eval()
281
+ residual_model = MiniCPMModel.from_pretrained(args.residual_hf_dir).to(device).eval()
282
+ if args.force_fp32:
283
+ dtype = torch.float32
284
+ else:
285
+ dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
286
+
287
+ base_model = base_model.to(dtype)
288
+ residual_model = residual_model.to(dtype)
289
+
290
+ # constants tied to exported ONNX
291
+ patch_size = 2
292
+ feat_dim = 64
293
+ latent_dim = 64
294
+ chunk_size = 640
295
+ sample_rate = 16000
296
+ audio_start_token = 101
297
+
298
+ # Load ONNX sessions
299
+ vae_encode_sess = load_onnx(os.path.join(args.onnx_dir, "audio_vae_encode.onnx"), providers)
300
+ vae_decode_sess = load_onnx(os.path.join(args.onnx_dir, "audio_vae_decode.onnx"), providers)
301
+ locenc_sess = load_onnx(os.path.join(args.onnx_dir, "locenc.onnx"), providers)
302
+ fsq_sess = load_onnx(os.path.join(args.onnx_dir, "fsq_layer.onnx"), providers)
303
+ stop_sess = load_onnx(os.path.join(args.onnx_dir, "stop_head.onnx"), providers)
304
+ dit_step_sess = load_onnx(os.path.join(args.onnx_dir, "dit_step.onnx"), providers)
305
+ enc_to_lm_sess = load_onnx(os.path.join(args.onnx_dir, "enc_to_lm_proj.onnx"), providers)
306
+ lm_to_dit_sess = load_onnx(os.path.join(args.onnx_dir, "lm_to_dit_proj.onnx"), providers)
307
+ res_to_dit_sess = load_onnx(os.path.join(args.onnx_dir, "res_to_dit_proj.onnx"), providers)
308
+
309
+ # Build text/audio features
310
+ if args.prompt_audio:
311
+ if not args.prompt_text:
312
+ raise ValueError("prompt-text is required when prompt-audio is provided.")
313
+ text = args.prompt_text + args.text
314
+ else:
315
+ text = args.text
316
+
317
+ tokenized = tokenizer(text, add_special_tokens=False)
318
+ if isinstance(tokenized, dict):
319
+ text_token = tokenized["input_ids"]
320
+ else:
321
+ text_token = tokenized
322
+ text_token = torch.LongTensor(text_token)
323
+ text_token = torch.cat([text_token, torch.tensor([audio_start_token], dtype=torch.int64)], dim=-1)
324
+ text_length = text_token.shape[0]
325
+
326
+ if args.prompt_audio:
327
+ audio_feat = prepare_audio_features(
328
+ args.prompt_audio, sample_rate, patch_size, chunk_size, vae_encode_sess
329
+ )
330
+ audio_length = audio_feat.size(0)
331
+
332
+ text_pad_token = torch.zeros(audio_length, dtype=torch.int64)
333
+ text_token = torch.cat([text_token, text_pad_token])
334
+
335
+ audio_pad_feat = torch.zeros(
336
+ (text_length, patch_size, latent_dim),
337
+ dtype=torch.float32,
338
+ )
339
+ audio_feat = torch.cat([audio_pad_feat, audio_feat], dim=0)
340
+
341
+ text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32)
342
+ audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32)
343
+ else:
344
+ audio_feat = torch.zeros(
345
+ (text_length, patch_size, latent_dim),
346
+ dtype=torch.float32,
347
+ )
348
+ text_mask = torch.ones(text_length).type(torch.int32)
349
+ audio_mask = torch.zeros(text_length).type(torch.int32)
350
+
351
+ text_token = text_token.unsqueeze(0).to(device)
352
+ text_mask = text_mask.unsqueeze(0).to(device)
353
+ audio_feat = audio_feat.unsqueeze(0).to(device)
354
+ audio_mask = audio_mask.unsqueeze(0).to(device)
355
+
356
+ # LocEnc (ONNX)
357
+ feat_embed_np = run_ort(locenc_sess, {"x": audio_feat.float()}, name="locenc")
358
+ feat_embed = torch.from_numpy(feat_embed_np).to(device=device, dtype=dtype)
359
+ feat_embed = run_ort(enc_to_lm_sess, {"input": feat_embed.float()}, name="enc_to_lm_init")
360
+ feat_embed = torch.from_numpy(feat_embed).to(device=device, dtype=dtype)
361
+
362
+ # Text embed
363
+ scale_emb = 1.0
364
+ text_embed = base_model.embed_tokens(text_token) * scale_emb
365
+ np.save("text_embed_ref.npy", text_embed.cpu().numpy())
366
+
367
+ combined_embed = text_mask.unsqueeze(-1) * text_embed + audio_mask.unsqueeze(-1) * feat_embed
368
+ attn_mask = torch.ones((combined_embed.size(0), combined_embed.size(1)), device=device, dtype=torch.long)
369
+ np.save("combined_embed_ref.npy", combined_embed.cpu().numpy())
370
+
371
+ # Base LM forward
372
+ start = time.perf_counter()
373
+ base_cache = DynamicCache()
374
+ base_outputs = base_model(
375
+ inputs_embeds=combined_embed,
376
+ attention_mask=attn_mask,
377
+ past_key_values=base_cache,
378
+ use_cache=True,
379
+ return_dict=True,
380
+ )
381
+ enc_outputs = base_outputs.last_hidden_state
382
+ base_cache = base_outputs.past_key_values
383
+ print(f"[time] base_lm initial: {(time.perf_counter() - start)*1000:.2f} ms")
384
+ np.save("enc_outputs_ref.npy", enc_outputs.cpu().numpy())
385
+
386
+ # FSQ on audio positions
387
+ enc_outputs_fsq_np = run_ort(fsq_sess, {"hidden": enc_outputs.float()}, name="fsq_init")
388
+ enc_outputs_fsq = torch.from_numpy(enc_outputs_fsq_np).to(device=device, dtype=dtype)
389
+ np.save("enc_outputs_fsq_ref.npy", enc_outputs_fsq.cpu().numpy())
390
+
391
+ enc_outputs = enc_outputs_fsq * audio_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
392
+
393
+ # Residual LM forward
394
+ np.save("audio_mask_ref.npy", audio_mask.cpu().numpy())
395
+ np.save("feat_embed_ref.npy", feat_embed.cpu().numpy())
396
+ residual_inputs = enc_outputs + audio_mask.unsqueeze(-1) * feat_embed
397
+ start = time.perf_counter()
398
+ residual_cache = DynamicCache()
399
+ res_outputs = residual_model(
400
+ inputs_embeds=residual_inputs,
401
+ attention_mask=attn_mask,
402
+ past_key_values=residual_cache,
403
+ use_cache=True,
404
+ return_dict=True,
405
+ )
406
+ residual_outputs = res_outputs.last_hidden_state
407
+ residual_cache = res_outputs.past_key_values
408
+ print(f"[time] residual_lm initial: {(time.perf_counter() - start)*1000:.2f} ms")
409
+ np.save("residual_inputs_ref.npy", residual_inputs.cpu().numpy())
410
+ np.save("residual_outputs_ref.npy", residual_outputs.cpu().numpy())
413
+
414
+ lm_hidden = enc_outputs[:, -1, :].to(dtype)
415
+ res_hidden = residual_outputs[:, -1, :].to(dtype)
416
+
417
+ prefix_feat_cond = audio_feat[:, -1, :, :] # [B, P, D]
418
+ pred_feat_seq = []
419
+
420
+ # Generation loop
421
+ for step_idx in tqdm(range(args.max_len), desc="gen_loop"):
422
+ dit_hidden_lm = run_ort(lm_to_dit_sess, {"input": lm_hidden.float()}, name="lm_to_dit")
423
+ dit_hidden_res = run_ort(res_to_dit_sess, {"input": res_hidden.float()}, name="res_to_dit")
424
+ dit_hidden = torch.from_numpy(dit_hidden_lm + dit_hidden_res).to(device=device, dtype=dtype)
425
+ cond = prefix_feat_cond.transpose(1, 2).contiguous() # [B, D, P]
426
+
427
+ # Sample next patch via ONNX DiT
428
+ x0 = torch.randn_like(prefix_feat_cond.transpose(1, 2)) # [B, D, P]
429
+ pred_feat = cfm_euler_with_onnx_step(
430
+ dit_step_sess,
431
+ x0,
432
+ dit_hidden,
433
+ cond,
434
+ n_timesteps=args.inference_timesteps,
435
+ cfg_value=args.cfg_value,
436
+ use_cfg_zero_star=True,
437
+ ).transpose(1, 2) # -> [B, P, D]
438
+
439
+ pred_feat_seq.append(pred_feat.unsqueeze(1)) # keep time dimension
440
+ prefix_feat_cond = pred_feat
441
+
442
+ # Encode new patch for next step (ONNX locenc)
443
+ locenc_step_np = run_ort(locenc_sess, {"x": pred_feat.unsqueeze(1).float()}, name="locenc_step")
444
+ curr_embed = torch.from_numpy(locenc_step_np).to(device=device, dtype=dtype)
445
+ curr_embed = run_ort(enc_to_lm_sess, {"input": curr_embed.float()}, name="enc_to_lm_step")
446
+ curr_embed = torch.from_numpy(curr_embed).to(device=device, dtype=dtype)
447
+
448
+ # Stop check (use lm_hidden BEFORE update, consistent with original)
449
+ stop_logits_np = run_ort(stop_sess, {"hidden": lm_hidden.float()})
450
+ stop_logits = torch.from_numpy(stop_logits_np)
451
+ stop_flag = stop_logits.argmax(dim=-1)[0].item()
452
+ if step_idx > args.min_len and stop_flag == 1:
453
+ break
454
+
455
+ # Update LMs using transformers cache API
456
+ attn_mask = torch.cat(
457
+ [attn_mask, torch.ones((attn_mask.size(0), 1), device=device, dtype=torch.long)], dim=1
458
+ )
459
+ np.save("curr_embed_ref.npy", curr_embed.cpu().numpy())
460
+ base_step = base_model(
461
+ inputs_embeds=curr_embed,
462
+ attention_mask=attn_mask,
463
+ past_key_values=base_cache,
464
+ use_cache=True,
465
+ return_dict=True,
466
+ )
467
+ lm_hidden_step = base_step.last_hidden_state # [B, 1, H]
468
+ base_cache = base_step.past_key_values
469
+ lm_hidden = lm_hidden_step.squeeze(1).to(dtype)
470
+
471
+ # FSQ expects [B, T, H]; expand step dimension then squeeze back
472
+ lm_hidden_step = lm_hidden.unsqueeze(1)
473
+ lm_hidden_fsq_np = run_ort(fsq_sess, {"hidden": lm_hidden_step.float()})
474
+ lm_hidden_fsq = torch.from_numpy(lm_hidden_fsq_np).to(device=device, dtype=dtype).squeeze(1)
475
+
476
+ res_step_inputs = (lm_hidden_fsq + curr_embed[:, 0, :]).unsqueeze(1)
477
+ res_step = residual_model(
478
+ inputs_embeds=res_step_inputs,
479
+ attention_mask=attn_mask,
480
+ past_key_values=residual_cache,
481
+ use_cache=True,
482
+ return_dict=True,
483
+ )
484
+ residual_cache = res_step.past_key_values
485
+ res_hidden = res_step.last_hidden_state.squeeze(1).to(dtype)
486
+ lm_hidden = lm_hidden_fsq
487
+
488
+ if len(pred_feat_seq) == 0:
489
+ raise RuntimeError("Generation produced zero patches.")
490
+
491
+ pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # [B, T_gen, P, D]
492
+ feat_pred = pred_feat_seq.permute(0, 3, 1, 2).reshape(
493
+ pred_feat_seq.size(0), pred_feat_seq.size(3), -1
494
+ ) # [B, D, T_gen*P]
495
+
496
+ # Decode audio via ONNX
497
+ audio_np = run_ort(vae_decode_sess, {"latent": feat_pred.float()}, name="vae_decode")
498
+ audio = torch.from_numpy(audio_np)
499
+ audio = audio[..., 640:-640] # trim start/end
500
+
501
+ wav = audio.squeeze(0).squeeze(0).cpu().numpy()
502
+ torchaudio.save(args.output, torch.from_numpy(wav).unsqueeze(0), sample_rate)
503
+ print(f"Saved: {args.output}")
504
+
505
+
506
+ if __name__ == "__main__":
507
+ main()
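
`onnx_infer.py` writes several intermediate tensors (`combined_embed_ref.npy`, `enc_outputs_ref.npy`, `residual_outputs_ref.npy`, ...) purely as debugging references. A hedged sketch of how such a dump could be checked against the RKNN pipeline; the `*_rknn.npy` file is an assumption (`onnx_infer-rknn2.py` does not write it by default, you would add a matching `np.save` yourself):

```python
import numpy as np

# Reference tensor written by onnx_infer.py above.
ref = np.load("enc_outputs_ref.npy")
# Assumed: a matching dump added manually to onnx_infer-rknn2.py.
rknn = np.load("enc_outputs_rknn.npy")

err = np.abs(ref.astype(np.float32) - rknn.astype(np.float32))
print("max abs err :", err.max())
print("mean abs err:", err.mean())
print("allclose(atol=1e-2):", np.allclose(ref, rknn, atol=1e-2))
```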
res_to_dit_proj.rknn ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:66e9971b7a59007bc0042430a1fdf0d021f4cf2d4875d77e7241f8a600d3a002
+ size 2139756
residual_lm.rkllm ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3770af5abc4159e0f7c88da567b0d7cc56f7571502756492f2d991dd93deb4b6
+ size 497735540
rkllm_binding.py ADDED
@@ -0,0 +1,1277 @@
1
+ import argparse
2
+ import ctypes
3
+ import enum
4
+ import os
5
+ import threading
6
+ from typing import Optional, Sequence, Tuple
7
+
8
+ import numpy as np
9
+
10
+ # Define constants from the header
11
+ CPU0 = (1 << 0) # 0x01
12
+ CPU1 = (1 << 1) # 0x02
13
+ CPU2 = (1 << 2) # 0x04
14
+ CPU3 = (1 << 3) # 0x08
15
+ CPU4 = (1 << 4) # 0x10
16
+ CPU5 = (1 << 5) # 0x20
17
+ CPU6 = (1 << 6) # 0x40
18
+ CPU7 = (1 << 7) # 0x80
19
+
20
+ # --- Enums ---
21
+ class LLMCallState(enum.IntEnum):
22
+ RKLLM_RUN_NORMAL = 0
23
+ RKLLM_RUN_WAITING = 1
24
+ RKLLM_RUN_FINISH = 2
25
+ RKLLM_RUN_ERROR = 3
26
+
27
+ class RKLLMInputType(enum.IntEnum):
28
+ RKLLM_INPUT_PROMPT = 0
29
+ RKLLM_INPUT_TOKEN = 1
30
+ RKLLM_INPUT_EMBED = 2
31
+ RKLLM_INPUT_MULTIMODAL = 3
32
+
33
+ class RKLLMInferMode(enum.IntEnum):
34
+ RKLLM_INFER_GENERATE = 0
35
+ RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1
36
+ RKLLM_INFER_GET_LOGITS = 2
37
+
38
+ # --- Structures ---
39
+ class RKLLMExtendParam(ctypes.Structure):
40
+ base_domain_id: ctypes.c_int32
41
+ embed_flash: ctypes.c_int8
42
+ enabled_cpus_num: ctypes.c_int8
43
+ enabled_cpus_mask: ctypes.c_uint32
44
+ n_batch: ctypes.c_uint8
45
+ use_cross_attn: ctypes.c_int8
46
+ reserved: ctypes.c_uint8 * 104
47
+
48
+ _fields_ = [
49
+ ("base_domain_id", ctypes.c_int32), # 基础域ID
50
+ ("embed_flash", ctypes.c_int8), # 是否从闪存查询词嵌入向量(1启用,0禁用)
51
+ ("enabled_cpus_num", ctypes.c_int8), # 推理启用的CPU数量
52
+ ("enabled_cpus_mask", ctypes.c_uint32), # 指示启用哪些CPU的位掩码
53
+ ("n_batch", ctypes.c_uint8), # 一次前向传播中并发处理的输入样本数,设置>1启用批量推理,默认为1
54
+ ("use_cross_attn", ctypes.c_int8), # 是否启用交叉注意力(非零启用,0禁用)
55
+ ("reserved", ctypes.c_uint8 * 104) # 保留字段
56
+ ]
57
+
58
+ class RKLLMParam(ctypes.Structure):
59
+ model_path: ctypes.c_char_p
60
+ max_context_len: ctypes.c_int32
61
+ max_new_tokens: ctypes.c_int32
62
+ top_k: ctypes.c_int32
63
+ n_keep: ctypes.c_int32
64
+ top_p: ctypes.c_float
65
+ temperature: ctypes.c_float
66
+ repeat_penalty: ctypes.c_float
67
+ frequency_penalty: ctypes.c_float
68
+ presence_penalty: ctypes.c_float
69
+ mirostat: ctypes.c_int32
70
+ mirostat_tau: ctypes.c_float
71
+ mirostat_eta: ctypes.c_float
72
+ skip_special_token: ctypes.c_bool
73
+ is_async: ctypes.c_bool
74
+ img_start: ctypes.c_char_p
75
+ img_end: ctypes.c_char_p
76
+ img_content: ctypes.c_char_p
77
+ extend_param: RKLLMExtendParam
78
+
79
+ _fields_ = [
80
+ ("model_path", ctypes.c_char_p), # 模型文件路径
81
+ ("max_context_len", ctypes.c_int32), # 上下文窗口最大token数
82
+ ("max_new_tokens", ctypes.c_int32), # 最大生成新token数
83
+ ("top_k", ctypes.c_int32), # Top-K采样参数
84
+ ("n_keep", ctypes.c_int32), # 上下文窗口移动时保留的kv缓存数量
85
+ ("top_p", ctypes.c_float), # Top-P(nucleus)采样参数
86
+ ("temperature", ctypes.c_float), # 采样温度,影响token选择的随机性
87
+ ("repeat_penalty", ctypes.c_float), # 重复token惩罚
88
+ ("frequency_penalty", ctypes.c_float), # 频繁token惩罚
89
+ ("presence_penalty", ctypes.c_float), # 输入中已存在token的惩罚
90
+ ("mirostat", ctypes.c_int32), # Mirostat采样策略标志(0表示禁用)
91
+ ("mirostat_tau", ctypes.c_float), # Mirostat采样Tau参数
92
+ ("mirostat_eta", ctypes.c_float), # Mirostat采样Eta参数
93
+ ("skip_special_token", ctypes.c_bool), # 是否跳过特殊token
94
+ ("is_async", ctypes.c_bool), # 是否异步推理
95
+ ("img_start", ctypes.c_char_p), # 多模态输入中图像的起始位置
96
+ ("img_end", ctypes.c_char_p), # 多模态输入中图像的结束位置
97
+ ("img_content", ctypes.c_char_p), # 图像内容指针
98
+ ("extend_param", RKLLMExtendParam) # 扩展参数
99
+ ]
100
+
101
+ class RKLLMLoraAdapter(ctypes.Structure):
102
+ lora_adapter_path: ctypes.c_char_p
103
+ lora_adapter_name: ctypes.c_char_p
104
+ scale: ctypes.c_float
105
+
106
+ _fields_ = [
107
+ ("lora_adapter_path", ctypes.c_char_p),
108
+ ("lora_adapter_name", ctypes.c_char_p),
109
+ ("scale", ctypes.c_float)
110
+ ]
111
+
112
+ class RKLLMEmbedInput(ctypes.Structure):
113
+ embed: ctypes.POINTER(ctypes.c_float)
114
+ n_tokens: ctypes.c_size_t
115
+
116
+ _fields_ = [
117
+ ("embed", ctypes.POINTER(ctypes.c_float)),
118
+ ("n_tokens", ctypes.c_size_t)
119
+ ]
120
+
121
+ class RKLLMTokenInput(ctypes.Structure):
122
+ input_ids: ctypes.POINTER(ctypes.c_int32)
123
+ n_tokens: ctypes.c_size_t
124
+
125
+ _fields_ = [
126
+ ("input_ids", ctypes.POINTER(ctypes.c_int32)),
127
+ ("n_tokens", ctypes.c_size_t)
128
+ ]
129
+
130
+ class RKLLMMultiModelInput(ctypes.Structure):
131
+ prompt: ctypes.c_char_p
132
+ image_embed: ctypes.POINTER(ctypes.c_float)
133
+ n_image_tokens: ctypes.c_size_t
134
+ n_image: ctypes.c_size_t
135
+ image_width: ctypes.c_size_t
136
+ image_height: ctypes.c_size_t
137
+
138
+ _fields_ = [
139
+ ("prompt", ctypes.c_char_p),
140
+ ("image_embed", ctypes.POINTER(ctypes.c_float)),
141
+ ("n_image_tokens", ctypes.c_size_t),
142
+ ("n_image", ctypes.c_size_t),
143
+ ("image_width", ctypes.c_size_t),
144
+ ("image_height", ctypes.c_size_t)
145
+ ]
146
+
147
+ class RKLLMCrossAttnParam(ctypes.Structure):
148
+ """
149
+ 交叉注意力参数结构体
150
+
151
+ 该结构体用于在解码器中执行交叉注意力时使用。
152
+ 它提供编码器输出(键/值缓存)、位置索引和注意力掩码。
153
+
154
+ - encoder_k_cache必须存储在连续内存中,布局为:
155
+ [num_layers][num_tokens][num_kv_heads][head_dim]
156
+ - encoder_v_cache必须存储在连续内存中,布局为:
157
+ [num_layers][num_kv_heads][head_dim][num_tokens]
158
+ """
159
+ encoder_k_cache: ctypes.POINTER(ctypes.c_float)
160
+ encoder_v_cache: ctypes.POINTER(ctypes.c_float)
161
+ encoder_mask: ctypes.POINTER(ctypes.c_float)
162
+ encoder_pos: ctypes.POINTER(ctypes.c_int32)
163
+ num_tokens: ctypes.c_int
164
+
165
+ _fields_ = [
166
+ ("encoder_k_cache", ctypes.POINTER(ctypes.c_float)), # 编码器键缓存指针(大小:num_layers * num_tokens * num_kv_heads * head_dim)
167
+ ("encoder_v_cache", ctypes.POINTER(ctypes.c_float)), # 编码器值缓存指针(大小:num_layers * num_kv_heads * head_dim * num_tokens)
168
+ ("encoder_mask", ctypes.POINTER(ctypes.c_float)), # 编码器注意力掩码指针(大小:num_tokens的数组)
169
+ ("encoder_pos", ctypes.POINTER(ctypes.c_int32)), # 编码器token位置指针(大小:num_tokens的数组)
170
+ ("num_tokens", ctypes.c_int) # 编码器序列中的token数量
171
+ ]
172
+
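+ # 示意用法(仅供参考,num_layers/num_tokens/num_kv_heads/head_dim 均为假设的编码器参数):
+ # k = np.zeros((num_layers, num_tokens, num_kv_heads, head_dim), dtype=np.float32)
+ # v = np.zeros((num_layers, num_kv_heads, head_dim, num_tokens), dtype=np.float32)
+ # p = RKLLMCrossAttnParam()
+ # p.encoder_k_cache = k.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
+ # p.encoder_v_cache = v.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
+ # p.num_tokens = num_tokens
+ # 注意:k、v 等 numpy 数组须在推理期间保持引用,否则指针会悬空。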
173
+ class RKLLMPerfStat(ctypes.Structure):
174
+ """
175
+ 性能统计结构体
176
+
177
+ 用于保存预填充和生成阶段的性能统计信息。
178
+ """
179
+ prefill_time_ms: ctypes.c_float
180
+ prefill_tokens: ctypes.c_int
181
+ generate_time_ms: ctypes.c_float
182
+ generate_tokens: ctypes.c_int
183
+ memory_usage_mb: ctypes.c_float
184
+
185
+ _fields_ = [
186
+ ("prefill_time_ms", ctypes.c_float), # 预填充阶段总耗时(毫秒)
187
+ ("prefill_tokens", ctypes.c_int), # 预填充阶段处理的token数量
188
+ ("generate_time_ms", ctypes.c_float), # 生成阶段总耗时(毫秒)
189
+ ("generate_tokens", ctypes.c_int), # 生成阶段处理的token数量
190
+ ("memory_usage_mb", ctypes.c_float) # 推理期间VmHWM常驻内存使用量(MB)
191
+ ]
192
+
193
+ class _RKLLMInputUnion(ctypes.Union):
194
+ prompt_input: ctypes.c_char_p
195
+ embed_input: RKLLMEmbedInput
196
+ token_input: RKLLMTokenInput
197
+ multimodal_input: RKLLMMultiModelInput
198
+
199
+ _fields_ = [
200
+ ("prompt_input", ctypes.c_char_p),
201
+ ("embed_input", RKLLMEmbedInput),
202
+ ("token_input", RKLLMTokenInput),
203
+ ("multimodal_input", RKLLMMultiModelInput)
204
+ ]
205
+
206
+ class RKLLMInput(ctypes.Structure):
207
+ """
208
+ LLM输入结构体
209
+
210
+ 通过联合体表示不同类型的LLM输入。
211
+ """
212
+ role: ctypes.c_char_p
213
+ enable_thinking: ctypes.c_bool
214
+ input_type: ctypes.c_int
215
+ _union_data: _RKLLMInputUnion
216
+
217
+ _fields_ = [
218
+ ("role", ctypes.c_char_p), # 消息角色:"user"(用户输入)、"tool"(函数结果)
219
+ ("enable_thinking", ctypes.c_bool), # 控制Qwen3模型是否启用"思考模式"
220
+ ("input_type", ctypes.c_int), # 枚举类型,指定输入类型(如prompt、token、embed、multimodal)
221
+ ("_union_data", _RKLLMInputUnion) # 联合体数据
222
+ ]
223
+ # Properties to make accessing union members easier
224
+ @property
225
+ def prompt_input(self) -> bytes: # Assuming c_char_p maps to bytes
226
+ if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT:
227
+ return self._union_data.prompt_input
228
+ raise AttributeError("Not a prompt input")
229
+ @prompt_input.setter
230
+ def prompt_input(self, value: bytes): # Assuming c_char_p maps to bytes
231
+ if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT:
232
+ self._union_data.prompt_input = value
233
+ else:
234
+ raise AttributeError("Not a prompt input")
235
+ @property
236
+ def embed_input(self) -> RKLLMEmbedInput:
237
+ if self.input_type == RKLLMInputType.RKLLM_INPUT_EMBED:
238
+ return self._union_data.embed_input
239
+ raise AttributeError("Not an embed input")
240
+ @embed_input.setter
241
+ def embed_input(self, value: RKLLMEmbedInput):
242
+ if self.input_type == RKLLMInputType.RKLLM_INPUT_EMBED:
243
+ self._union_data.embed_input = value
244
+ else:
245
+ raise AttributeError("Not an embed input")
246
+
247
+ @property
248
+ def token_input(self) -> RKLLMTokenInput:
249
+ if self.input_type == RKLLMInputType.RKLLM_INPUT_TOKEN:
250
+ return self._union_data.token_input
251
+ raise AttributeError("Not a token input")
252
+ @token_input.setter
253
+ def token_input(self, value: RKLLMTokenInput):
254
+ if self.input_type == RKLLMInputType.RKLLM_INPUT_TOKEN:
255
+ self._union_data.token_input = value
256
+ else:
257
+ raise AttributeError("Not a token input")
258
+
259
+ @property
260
+ def multimodal_input(self) -> RKLLMMultiModelInput:
261
+ if self.input_type == RKLLMInputType.RKLLM_INPUT_MULTIMODAL:
262
+ return self._union_data.multimodal_input
263
+ raise AttributeError("Not a multimodal input")
264
+ @multimodal_input.setter
265
+ def multimodal_input(self, value: RKLLMMultiModelInput):
266
+ if self.input_type == RKLLMInputType.RKLLM_INPUT_MULTIMODAL:
267
+ self._union_data.multimodal_input = value
268
+ else:
269
+ raise AttributeError("Not a multimodal input")
270
+
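+ # 示意用法(仅供参考):先设置 input_type,再通过上述属性写入对应的联合体成员。
+ # inp = RKLLMInput()
+ # inp.role = b"user"
+ # inp.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT
+ # inp.prompt_input = "你好".encode("utf-8")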
271
+ class RKLLMLoraParam(ctypes.Structure): # For inference
272
+ lora_adapter_name: ctypes.c_char_p
273
+
274
+ _fields_ = [
275
+ ("lora_adapter_name", ctypes.c_char_p)
276
+ ]
277
+
278
+ class RKLLMPromptCacheParam(ctypes.Structure): # For inference
279
+ save_prompt_cache: ctypes.c_int # bool-like
280
+ prompt_cache_path: ctypes.c_char_p
281
+
282
+ _fields_ = [
283
+ ("save_prompt_cache", ctypes.c_int), # bool-like
284
+ ("prompt_cache_path", ctypes.c_char_p)
285
+ ]
286
+
287
+ class RKLLMInferParam(ctypes.Structure):
288
+ mode: ctypes.c_int
289
+ lora_params: ctypes.POINTER(RKLLMLoraParam)
290
+ prompt_cache_params: ctypes.POINTER(RKLLMPromptCacheParam)
291
+ keep_history: ctypes.c_int # bool-like
292
+
293
+ _fields_ = [
294
+ ("mode", ctypes.c_int), # Enum will be passed as int, changed RKLLMInferMode to ctypes.c_int
295
+ ("lora_params", ctypes.POINTER(RKLLMLoraParam)),
296
+ ("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam)),
297
+ ("keep_history", ctypes.c_int) # bool-like
298
+ ]
299
+
300
+ class RKLLMResultLastHiddenLayer(ctypes.Structure):
301
+ hidden_states: ctypes.POINTER(ctypes.c_float)
302
+ embd_size: ctypes.c_int
303
+ num_tokens: ctypes.c_int
304
+
305
+ _fields_ = [
306
+ ("hidden_states", ctypes.POINTER(ctypes.c_float)),
307
+ ("embd_size", ctypes.c_int),
308
+ ("num_tokens", ctypes.c_int)
309
+ ]
310
+
311
+ class RKLLMResultLogits(ctypes.Structure):
312
+ logits: ctypes.POINTER(ctypes.c_float)
313
+ vocab_size: ctypes.c_int
314
+ num_tokens: ctypes.c_int
315
+
316
+ _fields_ = [
317
+ ("logits", ctypes.POINTER(ctypes.c_float)),
318
+ ("vocab_size", ctypes.c_int),
319
+ ("num_tokens", ctypes.c_int)
320
+ ]
321
+
322
+ class RKLLMResult(ctypes.Structure):
323
+ """
324
+ LLM推理结果结构体
325
+
326
+ 表示LLM推理的结果,包含生成的文本、token ID、隐藏层状态、logits和性能统计。
327
+ """
328
+ text: ctypes.c_char_p
329
+ token_id: ctypes.c_int32
330
+ last_hidden_layer: RKLLMResultLastHiddenLayer
331
+ logits: RKLLMResultLogits
332
+ perf: RKLLMPerfStat
333
+
334
+ _fields_ = [
335
+ ("text", ctypes.c_char_p), # 生成的文本结果
336
+ ("token_id", ctypes.c_int32), # 生成的token ID
337
+ ("last_hidden_layer", RKLLMResultLastHiddenLayer), # 最后一层的隐藏状态(如果请求的话)
338
+ ("logits", RKLLMResultLogits), # 模型输出的logits
339
+ ("perf", RKLLMPerfStat) # 性能统计(预填充和生成)
340
+ ]
341
+
342
+ # --- Typedefs ---
343
+ LLMHandle = ctypes.c_void_p
344
+
345
+ # --- Callback Function Type ---
346
+ LLMResultCallback = ctypes.CFUNCTYPE(
347
+ ctypes.c_int, # 返回类型:int,表示处理状态
348
+ ctypes.POINTER(RKLLMResult), # LLM结果指针
349
+ ctypes.c_void_p, # 用户数据指针
350
+ ctypes.c_int # LLM调用状态(LLMCallState枚举值)
351
+ )
352
+ """
353
+ 回调函数类型定义
354
+
355
+ 用于处理LLM结果的回调函数。
356
+
357
+ 参数:
358
+ - result: 指向LLM结果的指针
359
+ - userdata: 回调的用户数据指针
360
+ - state: LLM调用状态(例如:完成、错误)
361
+
362
+ 返回值:
363
+ - 0: 正常继续推理
364
+ - 1: 暂停推理。如果用户想要修改或干预结果(例如编辑输出、注入新提示),
365
+ 返回1以暂停当前推理。稍后,使用更新的内容调用rkllm_run来恢复推理。
366
+ """
367
+
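+ # 示意回调(仅供参考):按上述约定返回0继续推理,返回1暂停推理。
+ # def my_callback(result_ptr, userdata_ptr, state_enum):
+ #     state = LLMCallState(state_enum)
+ #     if state == LLMCallState.RKLLM_RUN_ERROR:
+ #         return 0
+ #     result = result_ptr.contents
+ #     if result.text:
+ #         print(result.text.decode("utf-8", errors="ignore"), end="", flush=True)
+ #     return 0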
368
+ class RKLLMRuntime:
369
+ def __init__(self, library_path="./librkllmrt.so"):
370
+ try:
371
+ self.lib = ctypes.CDLL(library_path)
372
+ except OSError as e:
373
+ raise OSError(f"Failed to load RKLLM library from {library_path}. "
374
+ f"Ensure it's in your LD_LIBRARY_PATH or provide the full path. Error: {e}")
375
+ self._setup_functions()
376
+ self.llm_handle = LLMHandle()
377
+ self._c_callback = None # To keep the callback object alive
378
+ self._user_callback = None
379
+
380
+ def _setup_functions(self):
381
+ # RKLLMParam rkllm_createDefaultParam();
382
+ self.lib.rkllm_createDefaultParam.restype = RKLLMParam
383
+ self.lib.rkllm_createDefaultParam.argtypes = []
384
+
385
+ # int rkllm_init(LLMHandle* handle, RKLLMParam* param, LLMResultCallback callback);
386
+ self.lib.rkllm_init.restype = ctypes.c_int
387
+ self.lib.rkllm_init.argtypes = [
388
+ ctypes.POINTER(LLMHandle),
389
+ ctypes.POINTER(RKLLMParam),
390
+ LLMResultCallback
391
+ ]
392
+
393
+ # int rkllm_load_lora(LLMHandle handle, RKLLMLoraAdapter* lora_adapter);
394
+ self.lib.rkllm_load_lora.restype = ctypes.c_int
395
+ self.lib.rkllm_load_lora.argtypes = [LLMHandle, ctypes.POINTER(RKLLMLoraAdapter)]
396
+
397
+ # int rkllm_load_prompt_cache(LLMHandle handle, const char* prompt_cache_path);
398
+ self.lib.rkllm_load_prompt_cache.restype = ctypes.c_int
399
+ self.lib.rkllm_load_prompt_cache.argtypes = [LLMHandle, ctypes.c_char_p]
400
+
401
+ # int rkllm_release_prompt_cache(LLMHandle handle);
402
+ self.lib.rkllm_release_prompt_cache.restype = ctypes.c_int
403
+ self.lib.rkllm_release_prompt_cache.argtypes = [LLMHandle]
404
+
405
+ # int rkllm_destroy(LLMHandle handle);
406
+ self.lib.rkllm_destroy.restype = ctypes.c_int
407
+ self.lib.rkllm_destroy.argtypes = [LLMHandle]
408
+
409
+ # int rkllm_run(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
410
+ self.lib.rkllm_run.restype = ctypes.c_int
411
+ self.lib.rkllm_run.argtypes = [
412
+ LLMHandle,
413
+ ctypes.POINTER(RKLLMInput),
414
+ ctypes.POINTER(RKLLMInferParam),
415
+ ctypes.c_void_p # userdata
416
+ ]
417
+
418
+ # int rkllm_run_async(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
419
+ # Assuming async also takes userdata for the callback context
420
+ self.lib.rkllm_run_async.restype = ctypes.c_int
421
+ self.lib.rkllm_run_async.argtypes = [
422
+ LLMHandle,
423
+ ctypes.POINTER(RKLLMInput),
424
+ ctypes.POINTER(RKLLMInferParam),
425
+ ctypes.c_void_p # userdata
426
+ ]
427
+
428
+ # int rkllm_abort(LLMHandle handle);
429
+ self.lib.rkllm_abort.restype = ctypes.c_int
430
+ self.lib.rkllm_abort.argtypes = [LLMHandle]
431
+
432
+ # int rkllm_is_running(LLMHandle handle);
433
+ self.lib.rkllm_is_running.restype = ctypes.c_int # 0 if running, non-zero otherwise
434
+ self.lib.rkllm_is_running.argtypes = [LLMHandle]
435
+
436
+ # int rkllm_clear_kv_cache(LLMHandle handle, int keep_system_prompt, int* start_pos, int* end_pos);
437
+ self.lib.rkllm_clear_kv_cache.restype = ctypes.c_int
438
+ self.lib.rkllm_clear_kv_cache.argtypes = [
439
+ LLMHandle,
440
+ ctypes.c_int,
441
+ ctypes.POINTER(ctypes.c_int), # start_pos
442
+ ctypes.POINTER(ctypes.c_int) # end_pos
443
+ ]
444
+
445
+ # int rkllm_get_kv_cache_size(LLMHandle handle, int* cache_sizes);
446
+ self.lib.rkllm_get_kv_cache_size.restype = ctypes.c_int
447
+ self.lib.rkllm_get_kv_cache_size.argtypes = [LLMHandle, ctypes.POINTER(ctypes.c_int)]
448
+
449
+ # int rkllm_set_chat_template(LLMHandle handle, const char* system_prompt, const char* prompt_prefix, const char* prompt_postfix);
450
+ self.lib.rkllm_set_chat_template.restype = ctypes.c_int
451
+ self.lib.rkllm_set_chat_template.argtypes = [
452
+ LLMHandle,
453
+ ctypes.c_char_p,
454
+ ctypes.c_char_p,
455
+ ctypes.c_char_p
456
+ ]
457
+
458
+ # int rkllm_set_function_tools(LLMHandle handle, const char* system_prompt, const char* tools, const char* tool_response_str);
459
+ self.lib.rkllm_set_function_tools.restype = ctypes.c_int
460
+ self.lib.rkllm_set_function_tools.argtypes = [
461
+ LLMHandle,
462
+ ctypes.c_char_p, # system_prompt
463
+ ctypes.c_char_p, # tools
464
+ ctypes.c_char_p # tool_response_str
465
+ ]
466
+
467
+ # int rkllm_set_cross_attn_params(LLMHandle handle, RKLLMCrossAttnParam* cross_attn_params);
468
+ self.lib.rkllm_set_cross_attn_params.restype = ctypes.c_int
469
+ self.lib.rkllm_set_cross_attn_params.argtypes = [LLMHandle, ctypes.POINTER(RKLLMCrossAttnParam)]
470
+
471
+ def create_default_param(self) -> RKLLMParam:
472
+ """Creates a default RKLLMParam structure."""
473
+ return self.lib.rkllm_createDefaultParam()
474
+
475
+ def init(self, param: RKLLMParam, callback_func) -> int:
476
+ """
477
+ Initializes the LLM.
478
+ :param param: RKLLMParam structure.
479
+ :param callback_func: A Python function that matches the signature:
480
+ def my_callback(result_ptr, userdata_ptr, state_enum):
481
+ result = result_ptr.contents # RKLLMResult
482
+ # Process result
483
+ # userdata can be retrieved if passed during run, or ignored
484
+ # state = LLMCallState(state_enum)
485
+ :return: 0 for success, non-zero for failure.
486
+ """
487
+ if not callable(callback_func):
488
+ raise ValueError("callback_func must be a callable Python function.")
489
+
490
+ self._user_callback = callback_func
491
+
492
+ # Keep a reference to the ctypes callback object to prevent it from being garbage collected.
493
+ # Always register a trampoline so we can swap the Python-level handler when needed.
494
+ self._c_callback = LLMResultCallback(self._callback_trampoline)
495
+
496
+ ret = self.lib.rkllm_init(ctypes.byref(self.llm_handle), ctypes.byref(param), self._c_callback)
497
+ if ret != 0:
498
+ raise RuntimeError(f"rkllm_init failed with error code {ret}")
499
+ return ret
500
+
501
+ def load_lora(self, lora_adapter: RKLLMLoraAdapter) -> int:
502
+ """Loads a Lora adapter."""
503
+ ret = self.lib.rkllm_load_lora(self.llm_handle, ctypes.byref(lora_adapter))
504
+ if ret != 0:
505
+ raise RuntimeError(f"rkllm_load_lora failed with error code {ret}")
506
+ return ret
507
+
508
+ def load_prompt_cache(self, prompt_cache_path: str) -> int:
509
+ """Loads a prompt cache from a file."""
510
+ c_path = prompt_cache_path.encode('utf-8')
511
+ ret = self.lib.rkllm_load_prompt_cache(self.llm_handle, c_path)
512
+ if ret != 0:
513
+ raise RuntimeError(f"rkllm_load_prompt_cache failed for {prompt_cache_path} with error code {ret}")
514
+ return ret
515
+
516
+ def release_prompt_cache(self) -> int:
517
+ """Releases the prompt cache from memory."""
518
+ ret = self.lib.rkllm_release_prompt_cache(self.llm_handle)
519
+ if ret != 0:
520
+ raise RuntimeError(f"rkllm_release_prompt_cache failed with error code {ret}")
521
+ return ret
522
+
523
+ def destroy(self) -> int:
524
+ """Destroys the LLM instance and releases resources."""
525
+ if self.llm_handle and self.llm_handle.value: # Check if handle is not NULL
526
+ ret = self.lib.rkllm_destroy(self.llm_handle)
527
+ self.llm_handle = LLMHandle() # Reset handle
528
+ if ret != 0:
529
+ # Don't raise here as it might be called in __del__
530
+ print(f"Warning: rkllm_destroy failed with error code {ret}")
531
+ return ret
532
+ return 0 # Already destroyed or not initialized
533
+
534
+ def run(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
535
+ """Runs an LLM inference task synchronously."""
536
+ # userdata can be a ctypes.py_object if you want to pass Python objects,
537
+ # then cast to c_void_p. Or simply None.
538
+ if userdata is not None:
539
+ # Store the userdata object to keep it alive during the call
540
+ self._userdata_ref = userdata
541
+ c_userdata = ctypes.cast(ctypes.pointer(ctypes.py_object(userdata)), ctypes.c_void_p)
542
+ else:
543
+ c_userdata = None
544
+ ret = self.lib.rkllm_run(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata)
545
+ if ret != 0:
546
+ raise RuntimeError(f"rkllm_run failed with error code {ret}")
547
+ return ret
548
+
549
+ def run_async(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
550
+ """Runs an LLM inference task asynchronously."""
551
+ if userdata is not None:
552
+ # Store the userdata object to keep it alive during the call
553
+ self._userdata_ref = userdata
554
+ c_userdata = ctypes.cast(ctypes.pointer(ctypes.py_object(userdata)), ctypes.c_void_p)
555
+ else:
556
+ c_userdata = None
557
+ ret = self.lib.rkllm_run_async(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata)
558
+ if ret != 0:
559
+ raise RuntimeError(f"rkllm_run_async failed with error code {ret}")
560
+ return ret
561
+
562
+ def abort(self) -> int:
563
+ """Aborts an ongoing LLM task."""
564
+ ret = self.lib.rkllm_abort(self.llm_handle)
565
+ if ret != 0:
566
+ raise RuntimeError(f"rkllm_abort failed with error code {ret}")
567
+ return ret
568
+
569
+ def is_running(self) -> bool:
570
+ """Checks if an LLM task is currently running. Returns True if running."""
571
+ # The C API returns 0 if running, non-zero otherwise.
572
+ # This is a bit counter-intuitive for a boolean "is_running".
573
+ return self.lib.rkllm_is_running(self.llm_handle) == 0
574
+
575
+ def clear_kv_cache(self, keep_system_prompt: bool, start_pos: list = None, end_pos: list = None) -> int:
576
+ """
577
+ 清除键值缓存
578
+
579
+ 此函数用于清除部分或全部KV缓存。
580
+
581
+ 参数:
582
+ - keep_system_prompt: 是否在缓存中保留系统提示(True保留,False清除)
583
+ 如果提供了特定范围[start_pos, end_pos),此标志将被忽略
584
+ - start_pos: 要清除的KV缓存范围的起始位置数组(包含),每个批次一个
585
+ - end_pos: 要清除的KV缓存范围的结束位置数组(不包含),每个批次一个
586
+ 如果start_pos和end_pos都设置为None,将清除整个缓存,keep_system_prompt将生效
587
+ 如果start_pos[i] < end_pos[i],只有指定的范围会被清除,keep_system_prompt将被忽略
588
+
589
+ 注意:start_pos或end_pos只有在keep_history == 0且生成已通过在回调中返回1暂停时才有效
590
+
591
+ 返回:0表示缓存清除成功,非零表示失败
592
+ """
593
+ # 准备C数组参数
594
+ c_start_pos = None
595
+ c_end_pos = None
596
+
597
+ if start_pos is not None and end_pos is not None:
598
+ if len(start_pos) != len(end_pos):
599
+ raise ValueError("start_pos和end_pos数组长度必须相同")
600
+
601
+ # 创建C数组
602
+ c_start_pos = (ctypes.c_int * len(start_pos))(*start_pos)
603
+ c_end_pos = (ctypes.c_int * len(end_pos))(*end_pos)
604
+
605
+ ret = self.lib.rkllm_clear_kv_cache(
606
+ self.llm_handle,
607
+ ctypes.c_int(1 if keep_system_prompt else 0),
608
+ c_start_pos,
609
+ c_end_pos
610
+ )
611
+ if ret != 0:
612
+ raise RuntimeError(f"rkllm_clear_kv_cache失败,错误代码:{ret}")
613
+ return ret
614
+
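+ # 示意用法(仅供参考,rt 为已完成 init 的 RKLLMRuntime 实例):
+ # rt.clear_kv_cache(keep_system_prompt=True)              # 清空整个缓存,保留系统提示
+ # rt.clear_kv_cache(False, start_pos=[10], end_pos=[20])  # 仅清除第一个批次位置 [10, 20) 的缓存
+ # 按上文说明,范围清除要求 keep_history == 0 且生成已通过回调返回1暂停。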
615
+ def set_chat_template(self, system_prompt: str, prompt_prefix: str, prompt_postfix: str) -> int:
616
+ """Sets the chat template for the LLM."""
617
+ c_system = system_prompt.encode('utf-8') if system_prompt else b""
618
+ c_prefix = prompt_prefix.encode('utf-8') if prompt_prefix else b""
619
+ c_postfix = prompt_postfix.encode('utf-8') if prompt_postfix else b""
620
+
621
+ ret = self.lib.rkllm_set_chat_template(self.llm_handle, c_system, c_prefix, c_postfix)
622
+ if ret != 0:
623
+ raise RuntimeError(f"rkllm_set_chat_template failed with error code {ret}")
624
+ return ret
625
+
626
+ def get_kv_cache_size(self, n_batch: int) -> list:
627
+ """
628
+ 获取给定LLM句柄的键值缓存当前大小
629
+
630
+ 此函数返回当前存储在模型KV缓存中的位置总数。
631
+
632
+ 参数:
633
+ - n_batch: 批次数量,用于确定返回数组的大小
634
+
635
+ 返回:
636
+ - list: 每个批次的缓存大小列表
637
+ """
638
+ # 预分配数组以存储每个批次的缓存大小
639
+ cache_sizes = (ctypes.c_int * n_batch)()
640
+
641
+ ret = self.lib.rkllm_get_kv_cache_size(self.llm_handle, cache_sizes)
642
+ if ret != 0:
643
+ raise RuntimeError(f"rkllm_get_kv_cache_size失败,错误代码:{ret}")
644
+
645
+ # 转换为Python列表
646
+ return [cache_sizes[i] for i in range(n_batch)]
647
+
648
+ def set_function_tools(self, system_prompt: str, tools: str, tool_response_str: str) -> int:
649
+ """
650
+ 为LLM设置函数调用配置,包括系统提示、工具定义和工具响应token
651
+
652
+ 参数:
653
+ - system_prompt: 定义语言模型上下文或行为的系统提示
654
+ - tools: JSON格式的字符串,定义可用的函数,包括它们的名称、描述和参数
655
+ - tool_response_str: 用于识别对话中函数调用结果的唯一标签。它作为标记标签,
656
+ 允许分词器将工具输出与正常对话轮次分开识别
657
+
658
+ 返回:0表示配置设置成功,非零表示错误
659
+ """
660
+ c_system = system_prompt.encode('utf-8') if system_prompt else b""
661
+ c_tools = tools.encode('utf-8') if tools else b""
662
+ c_tool_response = tool_response_str.encode('utf-8') if tool_response_str else b""
663
+
664
+ ret = self.lib.rkllm_set_function_tools(self.llm_handle, c_system, c_tools, c_tool_response)
665
+ if ret != 0:
666
+ raise RuntimeError(f"rkllm_set_function_tools失败,错误代码:{ret}")
667
+ return ret
668
+
669
+ def set_cross_attn_params(self, cross_attn_params: RKLLMCrossAttnParam) -> int:
670
+ """
671
+ 为LLM解码器设置交叉注意力参数
672
+
673
+ 参数:
674
+ - cross_attn_params: 包含用于交叉注意力的编码器相关输入数据的结构体
675
+ (详见RKLLMCrossAttnParam说明)
676
+
677
+ 返回:0表示参数设置成功,非零表示错误
678
+ """
679
+ ret = self.lib.rkllm_set_cross_attn_params(self.llm_handle, ctypes.byref(cross_attn_params))
680
+ if ret != 0:
681
+ raise RuntimeError(f"rkllm_set_cross_attn_params失败,错误代码:{ret}")
682
+ return ret
683
+
684
+ def __enter__(self):
685
+ return self
686
+
687
+ def __exit__(self, exc_type, exc_val, exc_tb):
688
+ self.destroy()
689
+
690
+ def __del__(self):
691
+ self.destroy() # Ensure resources are freed if object is garbage collected
692
+
693
+ def _callback_trampoline(self, result_ptr, userdata_ptr, state_enum):
694
+ """
695
+ Bridge callback that forwards to the currently active Python handler.
696
+ This keeps the C callback pointer stable while allowing per-call overrides.
697
+ """
698
+ handler = self._user_callback
699
+ if handler is None:
700
+ return 0
701
+ try:
702
+ return handler(result_ptr, userdata_ptr, state_enum)
703
+ except Exception as exc:
704
+ # Avoid propagating exceptions through the C callback boundary.
705
+ print(f"[rkllm_binding] Callback raised an exception: {exc}")
706
+ return 0
707
+
708
+ def forward_embed(
709
+ self,
710
+ embeds: np.ndarray,
711
+ *,
712
+ keep_history: bool = False,
713
+ timeout: Optional[float] = None,
714
+ return_last_only: bool = False,
715
+ ) -> np.ndarray:
716
+ """
717
+ Run a single forward pass with embedding input and return the last hidden layer.
718
+
719
+ Args:
720
+ embeds: Float32 embeddings shaped (T, H) or (1, T, H). Batch>1 is not supported.
721
+ keep_history: When False, KV cache will be cleared after the call. When True,
722
+ cache is kept; call clear_kv_cache() manually if needed.
723
+ timeout: Optional timeout (seconds) for waiting on the callback.
724
+ return_last_only: If True, return the last token vector shape (H,).
725
+
726
+ Returns:
727
+ np.ndarray with hidden states shaped (1, T, H), or the last token vector (H,) when return_last_only is True.
728
+ """
729
+ if embeds is None:
730
+ raise ValueError("embeds must not be None.")
731
+
732
+ np_embeds = np.asarray(embeds, dtype=np.float32)
733
+ if np_embeds.ndim == 3:
734
+ if np_embeds.shape[0] != 1:
735
+ raise ValueError("Only batch size 1 is supported for forward_embed.")
736
+ num_tokens = np_embeds.shape[1]
737
+ flat = np_embeds.reshape(-1)
738
+ elif np_embeds.ndim == 2:
739
+ num_tokens = np_embeds.shape[0]
740
+ flat = np_embeds.reshape(-1)
741
+ else:
742
+ raise ValueError("embeds must have shape (T, H) or (1, T, H).")
743
+
744
+ flat = np.ascontiguousarray(flat, dtype=np.float32)
745
+ embed_buffer = (ctypes.c_float * flat.size)(*flat)
746
+
747
+ rk_input = RKLLMInput()
748
+ rk_input.input_type = RKLLMInputType.RKLLM_INPUT_EMBED
749
+ embed_input = RKLLMEmbedInput()
750
+ embed_input.embed = embed_buffer
751
+ embed_input.n_tokens = num_tokens
752
+ rk_input._union_data.embed_input = embed_input
753
+
754
+ infer_params = RKLLMInferParam()
755
+ infer_params.mode = RKLLMInferMode.RKLLM_INFER_GET_LAST_HIDDEN_LAYER
756
+ infer_params.keep_history = 1 if keep_history else 0
757
+ infer_params.lora_params = None
758
+ infer_params.prompt_cache_params = None
759
+
760
+ done = threading.Event()
761
+ result_holder = {"hidden": None, "error": None}
762
+
763
+ def _capture_hidden(result_ptr, userdata_ptr, state_enum):
764
+ state = LLMCallState(state_enum)
765
+ if state == LLMCallState.RKLLM_RUN_ERROR:
766
+ result_holder["error"] = "RKLLM reported an error state."
767
+ done.set()
768
+ return 0
769
+
770
+ if not result_ptr:
771
+ result_holder["error"] = "Empty result pointer received."
772
+ done.set()
773
+ return 0
774
+
775
+ result = result_ptr.contents
776
+ if result.last_hidden_layer.hidden_states and result.last_hidden_layer.embd_size > 0:
777
+ hidden = np.ctypeslib.as_array(
778
+ result.last_hidden_layer.hidden_states,
779
+ shape=(1, result.last_hidden_layer.num_tokens, result.last_hidden_layer.embd_size),
780
+ ).copy()
781
+ result_holder["hidden"] = hidden[0, -1].copy() if return_last_only else hidden
782
+ done.set()
783
+ return 1 # Pause further work; we already have the hidden states.
784
+
785
+ if state == LLMCallState.RKLLM_RUN_FINISH:
786
+ done.set()
787
+ return 0
788
+
789
+ previous_callback = self._user_callback
790
+ self._user_callback = _capture_hidden
791
+ try:
792
+ self.run(rk_input, infer_params)
793
+ if not done.wait(timeout):
794
+ raise TimeoutError("forward_embed timed out waiting for hidden states.")
795
+ finally:
796
+ self._user_callback = previous_callback
797
+
798
+ if result_holder["error"]:
799
+ raise RuntimeError(result_holder["error"])
800
+ if result_holder["hidden"] is None:
801
+ raise RuntimeError("forward_embed did not receive hidden states.")
802
+
803
+ try:
804
+ if not keep_history:
805
+ self.clear_kv_cache(True)
806
+ except Exception:
807
+ # Cache clearing best-effort; keep the forward result usable even if clearing fails.
808
+ pass
809
+
810
+ return result_holder["hidden"]
811
+
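+ # 示意用法(仅供参考,嵌入维度 H=1024 为假设值,需与模型一致;rt 为已完成 init 的实例):
+ # embeds = np.random.randn(16, 1024).astype(np.float32)       # (T, H)
+ # hidden = rt.forward_embed(embeds, timeout=30.0)             # 返回 (1, T, H)
+ # last_vec = rt.forward_embed(embeds, return_last_only=True)  # 返回最后一个 token 的向量 (H,)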
812
+ # --- Demo CLI ---
813
+ def _cli_parse_arguments() -> argparse.Namespace:
814
+ parser = argparse.ArgumentParser(
815
+ description="Demo application showcasing rkllm_binding usage."
816
+ )
817
+ parser.add_argument(
818
+ "model",
819
+ help="Path to the .rkllm model file used for inference."
820
+ )
821
+ parser.add_argument(
822
+ "--lib",
823
+ default="./librkllmrt.so",
824
+ help="Path to librkllmrt.so. Defaults to ./librkllmrt.so."
825
+ )
826
+
827
+ # Core generation parameters
828
+ parser.add_argument("--max-context-len", type=int, default=512, help="Maximum context length.")
829
+ parser.add_argument("--max-new-tokens", type=int, default=128, help="Maximum number of new tokens to generate.")
830
+ parser.add_argument("--top-k", type=int, default=1, help="Top-K sampling parameter.")
831
+ parser.add_argument("--top-p", type=float, default=0.0, help="Top-P (nucleus) sampling parameter.")
832
+ parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature.")
833
+ parser.add_argument("--repeat-penalty", type=float, default=1.1, help="Penalty applied to repeated tokens.")
834
+ parser.add_argument("--n-keep", type=int, default=0, help="Number of tokens to keep when context slides.")
835
+ parser.add_argument("--mirostat", type=int, default=0, help="Enable Mirostat sampling (0 disables).")
836
+ parser.add_argument("--mirostat-tau", type=float, default=5.0, help="Mirostat tau parameter.")
837
+ parser.add_argument("--mirostat-eta", type=float, default=0.1, help="Mirostat eta parameter.")
838
+ parser.add_argument(
839
+ "--skip-special-token",
840
+ action="store_true",
841
+ help="Skip special tokens when generating output."
842
+ )
843
+
844
+ # Input management
845
+ parser.add_argument(
846
+ "--input-type",
847
+ choices=("prompt", "token", "multimodal"),
848
+ default="prompt",
849
+ help="Select prompt, raw token, or multimodal (image + prompt) input."
850
+ )
851
+ parser.add_argument("--prompt", help="Prompt text to send to the model.")
852
+ parser.add_argument("--prompt-file", help="Path to a UTF-8 text file containing the prompt.")
853
+ parser.add_argument(
854
+ "--token-ids",
855
+ type=int,
856
+ nargs="+",
857
+ help="Raw token IDs (space separated). Only valid when --input-type token."
858
+ )
859
+ parser.add_argument("--role", default="user", help="Role metadata for the input message (e.g., user/system).")
860
+ parser.add_argument(
861
+ "--enable-thinking",
862
+ action="store_true",
863
+ help="Enable thinking mode for supported models."
864
+ )
865
+ parser.add_argument("--image", help="Path to an image file used when --input-type multimodal.")
866
+ parser.add_argument("--vision-encoder", help="Path to the ONNX vision encoder model.")
867
+ parser.add_argument(
868
+ "--encoder-provider",
869
+ help="Comma separated ONNX Runtime providers (e.g., 'CPUExecutionProvider')."
870
+ )
871
+ parser.add_argument(
872
+ "--encoder-threads",
873
+ type=int,
874
+ help="Thread count hint for ONNX Runtime session."
875
+ )
876
+ parser.add_argument(
877
+ "--encoder-input-shape",
878
+ help="Override encoder input spatial size as HxW or H,W (e.g., 392x392)."
879
+ )
880
+ parser.add_argument(
881
+ "--norm",
882
+ choices=("imagenet", "divide_255", "divide_128_sub_1"),
883
+ default="imagenet",
884
+ help="Image normalization preset."
885
+ )
886
+ parser.add_argument(
887
+ "--norm-mean",
888
+ type=float,
889
+ nargs=3,
890
+ metavar=("R", "G", "B"),
891
+ help="Override normalization mean (RGB order)."
892
+ )
893
+ parser.add_argument(
894
+ "--norm-std",
895
+ type=float,
896
+ nargs=3,
897
+ metavar=("R", "G", "B"),
898
+ help="Override normalization std (RGB order)."
899
+ )
900
+ parser.add_argument(
901
+ "--image-background",
902
+ type=int,
903
+ nargs=3,
904
+ metavar=("R", "G", "B"),
905
+ default=(128, 128, 128),
906
+ help="Background color used when padding image to target size."
907
+ )
908
+ parser.add_argument("--img-start-token", help="Override image start token string passed to the model.")
909
+ parser.add_argument("--img-end-token", help="Override image end token string passed to the model.")
910
+ parser.add_argument("--img-content-token", help="Override image content token string passed to the model.")
911
+
912
+ # Inference options
913
+ parser.add_argument(
914
+ "--mode",
915
+ choices=("generate", "hidden", "logits"),
916
+ default="generate",
917
+ help="Inference mode: generate tokens, return last hidden layer, or logits."
918
+ )
919
+ parser.add_argument(
920
+ "--no-keep-history",
921
+ action="store_true",
922
+ help="Do not keep dialogue history on the device."
923
+ )
924
+
925
+ # Output options
926
+ parser.add_argument(
927
+ "--stream",
928
+ action="store_true",
929
+ default=True,
930
+ help="Stream tokens to stdout as they arrive from the callback."
931
+ )
932
+ parser.add_argument(
933
+ "--hide-stats",
934
+ action="store_true",
935
+ help="Suppress performance statistics after inference."
936
+ )
937
+
938
+ args = parser.parse_args()
939
+
940
+ if args.prompt and args.prompt_file:
941
+ parser.error("Arguments --prompt and --prompt-file cannot be used together.")
942
+
943
+ if args.input_type == "prompt":
944
+ if not args.prompt and not args.prompt_file:
945
+ parser.error("Provide --prompt or --prompt-file when --input-type is prompt.")
946
+ if args.token_ids:
947
+ parser.error("--token-ids is only valid when --input-type token.")
948
+ elif args.input_type == "token":
949
+ if not args.token_ids:
950
+ parser.error("--token-ids is required when --input-type token.")
951
+ if args.prompt or args.prompt_file:
952
+ parser.error("--prompt/--prompt-file cannot be combined with --input-type token.")
953
+ else: # multimodal
954
+ if args.token_ids:
955
+ parser.error("--token-ids cannot be used with --input-type multimodal.")
956
+ if not args.prompt and not args.prompt_file:
957
+ parser.error("Provide --prompt or --prompt-file when --input-type is multimodal.")
958
+ if not args.image:
959
+ parser.error("--image is required when --input-type multimodal.")
960
+ if not args.vision_encoder:
961
+ parser.error("--vision-encoder is required when --input-type multimodal.")
962
+
963
+ if args.image_background:
964
+ for component in args.image_background:
965
+ if component < 0 or component > 255:
966
+ parser.error("--image-background values must be in the range [0, 255].")
967
+
968
+ return args
969
+
970
+
971
+ def _load_prompt_from_args(args: argparse.Namespace) -> str:
972
+ if args.prompt:
973
+ return args.prompt
974
+ if args.prompt_file:
975
+ try:
976
+ with open(args.prompt_file, "r", encoding="utf-8") as fp:
977
+ return fp.read()
978
+ except OSError as exc:
979
+ raise RuntimeError(f"Failed to read prompt file '{args.prompt_file}': {exc}") from exc
980
+ raise RuntimeError("Prompt text is required but not provided.")
981
+
982
+
983
+ def _mode_to_enum(mode: str) -> int:
984
+ mapping = {
985
+ "generate": RKLLMInferMode.RKLLM_INFER_GENERATE,
986
+ "hidden": RKLLMInferMode.RKLLM_INFER_GET_LAST_HIDDEN_LAYER,
987
+ "logits": RKLLMInferMode.RKLLM_INFER_GET_LOGITS,
988
+ }
989
+ return mapping[mode]
990
+
991
+
992
+ def _parse_hw_string(value: str) -> Tuple[int, int]:
993
+ separators = ("x", "X", ",", " ")
994
+ token = value.strip()
995
+ for sep in separators:
996
+ if sep in token:
997
+ parts = [p for p in token.split(sep) if p]
998
+ break
999
+ else:
1000
+ parts = [token]
1001
+ if len(parts) != 2:
1002
+ raise ValueError(f"Unable to parse height/width from '{value}'. Expected format like 392x392.")
1003
+ try:
1004
+ height = int(parts[0])
1005
+ width = int(parts[1])
1006
+ except ValueError as exc:
1007
+ raise ValueError(f"Height/width must be integers, got '{value}'.") from exc
1008
+ if height <= 0 or width <= 0:
1009
+ raise ValueError("Height and width must be positive integers.")
1010
+ return height, width
1011
+
1012
+
1013
+ def _infer_hw_from_onnx_shape(shape: Sequence) -> Tuple[Optional[int], Optional[int]]:
1014
+ if shape is None or len(shape) < 4:
1015
+ return None, None
1016
+ height = shape[-2]
1017
+ width = shape[-1]
1018
+ if isinstance(height, str) or height is None:
1019
+ height = None
1020
+ if isinstance(width, str) or width is None:
1021
+ width = None
1022
+ return height, width
1023
+
1024
+
1025
+ def _parse_providers(provider_str: Optional[str]) -> Optional[list]:
1026
+ if not provider_str:
1027
+ return None
1028
+ providers = [item.strip() for item in provider_str.split(",") if item.strip()]
1029
+ return providers or None
1030
+
1031
+
1032
+ def _load_vision_encoder_session(encoder_path: str, providers: Optional[list], threads: Optional[int]):
1033
+ try:
1034
+ import onnxruntime as ort
1035
+ except ImportError as exc:
1036
+ raise RuntimeError("onnxruntime is required for multimodal inference. Please install onnxruntime.") from exc
1037
+
1038
+ sess_options = ort.SessionOptions()
1039
+ if threads and threads > 0:
1040
+ sess_options.intra_op_num_threads = threads
1041
+ try:
1042
+ if providers:
1043
+ session = ort.InferenceSession(encoder_path, sess_options=sess_options, providers=providers)
1044
+ else:
1045
+ session = ort.InferenceSession(encoder_path, sess_options=sess_options)
1046
+ except Exception as exc:
1047
+ raise RuntimeError(f"Failed to load vision encoder '{encoder_path}': {exc}") from exc
1048
+ return session
1049
+
1050
+
1051
+ def _letterbox_resize(image, target_hw: Tuple[int, int], background_color: Sequence[int]):
1052
+ try:
1053
+ import cv2
1054
+ import numpy as np
1055
+ except ImportError as exc:
1056
+ raise RuntimeError("OpenCV (cv2) and numpy are required for multimodal preprocessing.") from exc
1057
+
1058
+ target_h, target_w = target_hw
1059
+ if image.ndim != 3 or image.shape[2] != 3:
1060
+ raise RuntimeError("Expected RGB image with 3 channels.")
1061
+
1062
+ src_h, src_w = image.shape[:2]
1063
+ if src_h == 0 or src_w == 0:
1064
+ raise RuntimeError("Loaded image has invalid dimensions.")
1065
+
1066
+ scale = min(target_w / src_w, target_h / src_h)
1067
+ resized_w = max(1, int(round(src_w * scale)))
1068
+ resized_h = max(1, int(round(src_h * scale)))
1069
+ resized = cv2.resize(image, (resized_w, resized_h), interpolation=cv2.INTER_LINEAR)
1070
+
1071
+ canvas = np.full((target_h, target_w, 3), background_color, dtype=resized.dtype)
1072
+ top = (target_h - resized_h) // 2
1073
+ left = (target_w - resized_w) // 2
1074
+ canvas[top:top + resized_h, left:left + resized_w] = resized
1075
+ return canvas, resized_h, resized_w
1076
+
1077
+
1078
+ def _normalize_image(image, method: str, mean: Optional[Sequence[float]], std: Optional[Sequence[float]]):
1079
+ import numpy as np
1080
+
1081
+ img = image.astype(np.float32)
1082
+ mean_arr = np.array(mean, dtype=np.float32) if mean else None
1083
+ std_arr = np.array(std, dtype=np.float32) if std else None
1084
+
1085
+ if method == "imagenet":
1086
+ img = img / 255.0
1087
+ if mean_arr is None:
1088
+ mean_arr = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
1089
+ if std_arr is None:
1090
+ std_arr = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)
1091
+ img = (img - mean_arr) / std_arr
1092
+ elif method == "divide_255":
1093
+ img = img / 255.0
1094
+ if mean_arr is not None:
1095
+ img = img - mean_arr
1096
+ if std_arr is not None:
1097
+ img = img / std_arr
1098
+ elif method == "divide_128_sub_1":
1099
+ img = img / 128.0 - 1.0
1100
+ if mean_arr is not None:
1101
+ img = img - mean_arr
1102
+ if std_arr is not None:
1103
+ img = img / std_arr
1104
+ else:
1105
+ raise RuntimeError(f"Unsupported normalization method '{method}'.")
1106
+
1107
+ return img
1108
+
1109
+
1110
+ def _encode_image_to_embedding(
1111
+ session,
1112
+ image_path: str,
1113
+ input_name: str,
1114
+ output_name: str,
1115
+ target_hw: Tuple[int, int],
1116
+ background_color: Sequence[int],
1117
+ norm_method: str,
1118
+ norm_mean: Optional[Sequence[float]],
1119
+ norm_std: Optional[Sequence[float]]
1120
+ ):
1121
+ try:
1122
+ import cv2
1123
+ import numpy as np
1124
+ except ImportError as exc:
1125
+ raise RuntimeError("OpenCV (cv2) and numpy are required for multimodal preprocessing.") from exc
1126
+
1127
+ image = cv2.imread(image_path, cv2.IMREAD_COLOR)
1128
+ if image is None:
1129
+ raise RuntimeError(f"Failed to read image from '{image_path}'.")
1130
+
1131
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
1132
+ padded, resized_h, resized_w = _letterbox_resize(image, target_hw, background_color)
1133
+
1134
+ normalized = _normalize_image(padded, norm_method, norm_mean, norm_std)
1135
+ tensor = np.transpose(normalized, (2, 0, 1)) # HWC -> CHW
1136
+ tensor = np.expand_dims(tensor, axis=0) # Add batch dimension
1137
+ tensor = np.ascontiguousarray(tensor, dtype=np.float32)
1138
+
1139
+ try:
1140
+ output_list = session.run([output_name], {input_name: tensor})
1141
+ except Exception as exc:
1142
+ raise RuntimeError(f"Vision encoder inference failed: {exc}") from exc
1143
+
1144
+ if not output_list:
1145
+ raise RuntimeError("Vision encoder returned no outputs.")
1146
+
1147
+ embedding = output_list[0]
1148
+ if embedding.ndim == 3:
1149
+ if embedding.shape[0] != 1:
1150
+ raise RuntimeError("Vision encoder output batch dimension must be 1 for a single image.")
1151
+ n_tokens = embedding.shape[1]
1152
+ elif embedding.ndim == 2:
1153
+ n_tokens = embedding.shape[0]
1154
+ else:
1155
+ raise RuntimeError(f"Unsupported vision encoder output shape {embedding.shape}.")
1156
+
1157
+ flat_embedding = embedding.reshape(-1).astype(np.float32, copy=False)
1158
+ flat_embedding = np.ascontiguousarray(flat_embedding)
1159
+
1160
+ return flat_embedding, n_tokens, target_hw
1161
+
1162
+ if __name__ == "__main__":
1163
+ import os
1164
+ os.environ["RKLLM_LOG_LEVEL"] = "1"
1165
+ args = _cli_parse_arguments()
1166
+
1167
+ prompt_text = None
1168
+ if args.input_type == "prompt":
1169
+ prompt_text = _load_prompt_from_args(args)
1170
+
1171
+ token_id_array = None
1172
+ token_input_struct = None
1173
+
1174
+ generated_chunks = []
1175
+ perf_snapshot = {
1176
+ "prefill_tokens": 0,
1177
+ "prefill_time_ms": 0.0,
1178
+ "generate_tokens": 0,
1179
+ "generate_time_ms": 0.0,
1180
+ "memory_usage_mb": 0.0,
1181
+ }
1182
+
1183
+ def demo_callback(result_ptr, userdata_ptr, state_enum):
1184
+ state = LLMCallState(state_enum)
1185
+ result = result_ptr.contents
1186
+
1187
+ current_text = ""
1188
+ if result.text:
1189
+ current_text = result.text.decode("utf-8", errors="ignore")
1190
+ generated_chunks.append(current_text)
1191
+ if args.stream and current_text:
1192
+ print(current_text, end="", flush=True)
1193
+
1194
+ perf_snapshot.update(
1195
+ prefill_tokens=result.perf.prefill_tokens,
1196
+ prefill_time_ms=result.perf.prefill_time_ms,
1197
+ generate_tokens=result.perf.generate_tokens,
1198
+ generate_time_ms=result.perf.generate_time_ms,
1199
+ memory_usage_mb=result.perf.memory_usage_mb,
1200
+ )
1201
+
1202
+ if state == LLMCallState.RKLLM_RUN_ERROR:
1203
+ print("\n[Callback] 推理过程中出现错误。")
1204
+
1205
+ return 0
1206
+
1207
+ try:
1208
+ with RKLLMRuntime(library_path=args.lib) as rk_llm:
1209
+ params = rk_llm.create_default_param()
1210
+ params.model_path = os.path.abspath(args.model).encode("utf-8")
1211
+ params.max_context_len = args.max_context_len
1212
+ params.max_new_tokens = args.max_new_tokens
1213
+ params.top_k = args.top_k
1214
+ params.top_p = float(args.top_p)
1215
+ params.temperature = float(args.temperature)
1216
+ params.repeat_penalty = float(args.repeat_penalty)
1217
+ params.n_keep = args.n_keep
1218
+ params.mirostat = args.mirostat
1219
+ params.mirostat_tau = float(args.mirostat_tau)
1220
+ params.mirostat_eta = float(args.mirostat_eta)
1221
+ params.skip_special_token = bool(args.skip_special_token)
1222
+ params.is_async = False
1223
+
1224
+ rk_llm.init(params, demo_callback)
1225
+
1226
+ rk_input = RKLLMInput()
1227
+ rk_input.role = args.role.encode("utf-8")
1228
+ rk_input.enable_thinking = bool(args.enable_thinking)
1229
+
1230
+ if args.input_type == "prompt":
1231
+ rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT
1232
+ rk_input._union_data.prompt_input = prompt_text.encode("utf-8")
1233
+ else:  # token 输入;multimodal 相关的视觉编码辅助函数已定义,但尚未接入本演示主流程
1234
+ rk_input.input_type = RKLLMInputType.RKLLM_INPUT_TOKEN
1235
+ token_id_array = (ctypes.c_int32 * len(args.token_ids))(*args.token_ids)
1236
+ token_input_struct = RKLLMTokenInput()
1237
+ token_input_struct.input_ids = token_id_array
1238
+ token_input_struct.n_tokens = len(args.token_ids)
1239
+ rk_input._union_data.token_input = token_input_struct
1240
+
1241
+ infer_params = RKLLMInferParam()
1242
+ infer_params.mode = _mode_to_enum(args.mode)
1243
+ infer_params.keep_history = 0 if args.no_keep_history else 1
1244
+ infer_params.lora_params = None
1245
+ infer_params.prompt_cache_params = None
1246
+
1247
+ if args.stream:
1248
+ print("=== Streaming Output ===")
1249
+
1250
+ rk_llm.run(rk_input, infer_params)
1251
+
1252
+ except OSError as exc:
1253
+ print(f"无法加载 RKLLM 运行时库:{exc}")
1254
+ except RuntimeError as exc:
1255
+ print(f"推理失败:{exc}")
1256
+ except Exception as exc:
1257
+ print(f"发生未预期的错误:{exc}")
1258
+ else:
1259
+ if args.stream:
1260
+ print() # Ensure newline after streaming output
1261
+
1262
+ final_text = "".join(generated_chunks)
1263
+ if final_text:
1264
+ print("=== 生成结果 ===")
1265
+ print(final_text)
1266
+ else:
1267
+ print("未收到生成文本。")
1268
+
1269
+ if not args.hide_stats:
1270
+ print("=== 性能统计 ===")
1271
+ print(
1272
+ f"预填充: {perf_snapshot['prefill_tokens']} tokens / {perf_snapshot['prefill_time_ms']:.2f} ms"
1273
+ )
1274
+ print(
1275
+ f"生成: {perf_snapshot['generate_tokens']} tokens / {perf_snapshot['generate_time_ms']:.2f} ms"
1276
+ )
1277
+ print(f"最大常驻内存: {perf_snapshot['memory_usage_mb']:.2f} MB")
rknn_output.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d58c10f7fa051c0181f4752c14433610c9af9577ebe3306e750865c7a76a2dfa
3
+ size 320044
stop_head.rknn ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:72d604bcb8531a27d8a2d678d6a51839eff0217cf0ccfd722e61f18f3be9f782
3
+ size 2156714
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,212 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<unk>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<s>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "</s>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "101": {
30
+ "content": "<|audio_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "102": {
38
+ "content": "<|audio_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "103": {
46
+ "content": "<|audio_prompt_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "104": {
54
+ "content": "<|audio_prompt_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "105": {
62
+ "content": "<|background|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "106": {
70
+ "content": "<|/background|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "107": {
78
+ "content": "<|characters|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "108": {
86
+ "content": "<|/characters|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "109": {
94
+ "content": "<|speaker_id|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "110": {
102
+ "content": "<|/speaker_id|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "111": {
110
+ "content": "<|span|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "112": {
118
+ "content": "<|/span|>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": true
124
+ },
125
+ "73440": {
126
+ "content": "<|im_end|>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": true
132
+ },
133
+ "73441": {
134
+ "content": "<|im_start|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": true
140
+ },
141
+ "73442": {
142
+ "content": "<|tool_call|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": true
148
+ },
149
+ "73443": {
150
+ "content": "<|execute_start|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": true
156
+ },
157
+ "73444": {
158
+ "content": "<|execute_end|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": true
164
+ },
165
+ "73445": {
166
+ "content": "<|fim_prefix|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": true
172
+ },
173
+ "73446": {
174
+ "content": "<|fim_middle|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": true
180
+ },
181
+ "73447": {
182
+ "content": "<|fim_suffix|>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": true
188
+ }
189
+ },
190
+ "additional_special_tokens": [
191
+ "<|im_end|>",
192
+ "<|im_start|>",
193
+ "<|tool_call|>",
194
+ "<|execute_start|>",
195
+ "<|execute_end|>",
196
+ "<|fim_prefix|>",
197
+ "<|fim_middle|>",
198
+ "<|fim_suffix|>"
199
+ ],
200
+ "bos_token": "<s>",
201
+ "clean_up_tokenization_spaces": false,
202
+ "eos_token": "<|im_end|>",
203
+ "legacy": true,
204
+ "model_max_length": 1000000000000000019884624838656,
205
+ "pad_token": null,
206
+ "sp_model_kwargs": {},
207
+ "spaces_between_special_tokens": false,
208
+ "tokenizer_class": "LlamaTokenizer",
209
+ "unk_token": "<unk>",
210
+ "use_default_system_prompt": false,
211
+ "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
212
+ }
ztu_somemodelruntime_rknnlite2.py ADDED
@@ -0,0 +1,1334 @@
1
+ # 模块级常量和函数
2
+ from rknnlite.api import RKNNLite
3
+ import numpy as np
4
+ import os
5
+ import warnings
6
+ import logging
7
+ from typing import List, Dict, Union, Optional, Tuple
8
+ import ctypes
+ import re  # get_current_platform 中的 re.search 需要此导入
9
+
10
+ try:
11
+ import onnxruntime as ort
12
+ HAS_ORT = True
13
+ except ImportError:
14
+ HAS_ORT = False
15
+ warnings.warn("onnxruntime未安装,只能使用RKNN后端", ImportWarning)
16
+
17
+ # 常量定义
18
+ RKNN_MAX_DIMS = 16
19
+ RKNN_MAX_NAME_LEN = 256
20
+ RKNN_MAX_DYNAMIC_SHAPE_NUM = 512
21
+ RKNN_QUERY_CUSTOM_STRING = 7
22
+ RKNN_QUERY_INPUT_DYNAMIC_RANGE = 13
23
+
24
+ class RKNNInputRange(ctypes.Structure):
25
+ _fields_ = [
26
+ ('index', ctypes.c_uint32),
27
+ ('shape_number', ctypes.c_uint32),
28
+ ('fmt', ctypes.c_int),
29
+ ('name', ctypes.c_char * RKNN_MAX_NAME_LEN),
30
+ ('dyn_range', (ctypes.c_uint32 * RKNN_MAX_DIMS) * RKNN_MAX_DYNAMIC_SHAPE_NUM),
31
+ ('n_dims', ctypes.c_uint32)
32
+ ]
33
+
34
+ class RKNNCustomString(ctypes.Structure):
35
+ _fields_ = [
36
+ ('string', ctypes.c_char * 1024)
37
+ ]
38
+
39
+ # 支持的平台列表
40
+ SUPPORTED_PLATFORMS = [
41
+ "rv1103", "rv1103b", "rv1106", "rv1106b",
42
+ "rk2118", "rk3562", "rk3566", "rk3568", "rk3576", "rk3588"
43
+ ]
44
+
45
+ def get_current_platform() -> Optional[str]:
46
+ """
47
+ 获取当前运行平台
48
+
49
+ Returns:
50
+ Optional[str]: 平台名称,如果文件不存在返回None
51
+
52
+ Raises:
53
+ RuntimeError: 如果平台不在支持列表中
54
+ """
55
+ platform_file = "/proc/device-tree/compatible"
56
+
57
+ if not os.path.exists(platform_file):
58
+ logger.debug(f"平台信息文件不存在: {platform_file}")
59
+ return None
60
+
61
+ try:
62
+ with open(platform_file, 'r') as f:
63
+ content = f.read()
64
+
65
+ # 使用正则匹配r[kv]\d{4}[b]?格式
66
+ match = re.search(r'r[kv]\d{4}[b]?', content.lower())
67
+ if not match:
68
+ raise RuntimeError(f"无法从{platform_file}解析平台信息")
69
+
70
+ platform = match.group()
71
+ if platform not in SUPPORTED_PLATFORMS:
72
+ raise RuntimeError(f"不支持的平台: {platform}")
73
+
74
+ logger.debug(f"当前平台: {platform}")
75
+ return platform
76
+
77
+ except Exception as e:
78
+ if not isinstance(e, RuntimeError):
79
+ logger.error(f"读取平台信息时发生错误: {str(e)}")
80
+ raise RuntimeError(f"读取平台信息失败: {str(e)}")
81
+ raise
82
+
83
+ # 配置日志
84
+ logger = logging.getLogger("somemodelruntime_rknnlite2")
85
+ logger.setLevel(logging.ERROR) # 默认只输出错误信息
86
+ if not logger.handlers:
87
+ handler = logging.StreamHandler()
88
+ handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
89
+ logger.addHandler(handler)
90
+
91
+ _TRUE_VALUES = {"1", "true", "yes", "y", "on", "t"}
92
+ _FALSE_VALUES = {"0", "false", "no", "n", "off", "f"}
93
+
94
+ def _get_env_bool(name: str, default: bool = False) -> bool:
95
+ """
96
+ 解析环境变量布尔值,支持1/0/true/false/yes/no/y/n/on/off.
97
+ """
98
+ raw = os.getenv(name)
99
+ if raw is None:
100
+ return default
101
+ value = raw.strip().lower()
102
+ if value in _TRUE_VALUES:
103
+ return True
104
+ if value in _FALSE_VALUES:
105
+ return False
106
+ logger.warning(f"环境变量{name}的值无效: {raw}, 应该是1/0或true/false/yes/no/y/n/on/off")
107
+ return default
108
+
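+ # 示意用法(仅供参考,环境变量名为假设):
+ # os.environ["ZTU_MODELRT_RKNNL2_DEBUG"] = "yes"
+ # _get_env_bool("ZTU_MODELRT_RKNNL2_DEBUG")  # -> True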
109
+ def _parse_model_list(list_str: str) -> Tuple[bool, List[str]]:
110
+ """
111
+ 解析模型列表,支持以^开头表示反选.
112
+ 返回(是否反选, 过滤后的模型列表).
113
+ """
114
+ if not list_str:
115
+ return False, []
116
+ items = [item.strip() for item in list_str.split(',') if item.strip()]
117
+ if not items:
118
+ return False, []
119
+ negated = any(item.startswith('^') for item in items)
120
+ cleaned = []
121
+ for item in items:
122
+ if item.startswith('^'):
123
+ item = item[1:].strip()
124
+ if item:
125
+ cleaned.append(item)
126
+ return negated, cleaned
127
+
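+ # 解析示意(假设性示例, 文件名仅作演示):
+ #   _parse_model_list("a.onnx, b.onnx")  -> (False, ["a.onnx", "b.onnx"])
+ #   _parse_model_list("^a.onnx")         -> (True,  ["a.onnx"])   # 第一个返回值表示是否反选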
128
+ # ONNX Runtime日志级别到Python logging级别的映射
129
+ _LOGGING_LEVEL_MAP = {
130
+ 0: logging.DEBUG, # Verbose
131
+ 1: logging.INFO, # Info
132
+ 2: logging.WARNING, # Warning
133
+ 3: logging.ERROR, # Error
134
+ 4: logging.CRITICAL # Fatal
135
+ }
136
+
137
+ # 检查环境变量中的日志级别设置
138
+ try:
139
+ env_log_level = os.getenv('ZTU_MODELRT_RKNNL2_LOG_LEVEL')
140
+ if env_log_level is not None:
141
+ log_level = int(env_log_level)
142
+ if log_level in _LOGGING_LEVEL_MAP:
143
+ logger.setLevel(_LOGGING_LEVEL_MAP[log_level])
144
+ logger.info(f"从环境变量设置日志级别: {log_level}")
145
+ else:
146
+ logger.warning(f"环境变量ZTU_MODELRT_RKNNL2_LOG_LEVEL的值无效: {log_level}, 应该是0-4之间的整数")
147
+ except ValueError:
148
+ logger.warning(f"环境变量ZTU_MODELRT_RKNNL2_LOG_LEVEL的值无效: {env_log_level}, 应该是0-4之间的整数")
149
+
150
+
151
+ def set_default_logger_severity(level: int) -> None:
152
+ """
153
+ Sets the default logging severity. 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal
154
+
155
+ Args:
156
+ level: 日志级别(0-4)
157
+ """
158
+ if level not in _LOGGING_LEVEL_MAP:
159
+ raise ValueError(f"无效的日志级别: {level}, 应该是0-4之间的整数")
160
+ logger.setLevel(_LOGGING_LEVEL_MAP[level])
161
+
162
+ def set_default_logger_verbosity(level: int) -> None:
163
+ """
164
+ Sets the default logging verbosity level. To activate the verbose log,
165
+ you need to set the default logging severity to 0:Verbose level.
166
+
167
+ Args:
168
+ level: 日志级别(0-4)
169
+ """
170
+ set_default_logger_severity(level)
171
+
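+ # 日志级别设置示意(假设性示例):
+ #   export ZTU_MODELRT_RKNNL2_LOG_LEVEL=1   # 0:Verbose 1:Info 2:Warning 3:Error 4:Fatal
+ #   或在代码中调用 set_default_logger_severity(1)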
172
+ # RKNN tensor type到numpy dtype的映射
173
+ RKNN_DTYPE_MAP = {
174
+ 0: np.float32, # RKNN_TENSOR_FLOAT32
175
+ 1: np.float16, # RKNN_TENSOR_FLOAT16
176
+ 2: np.int8, # RKNN_TENSOR_INT8
177
+ 3: np.uint8, # RKNN_TENSOR_UINT8
178
+ 4: np.int16, # RKNN_TENSOR_INT16
179
+ 5: np.uint16, # RKNN_TENSOR_UINT16
180
+ 6: np.int32, # RKNN_TENSOR_INT32
181
+ 7: np.uint32, # RKNN_TENSOR_UINT32
182
+ 8: np.int64, # RKNN_TENSOR_INT64
183
+ 9: bool, # RKNN_TENSOR_BOOL
184
+ 10: np.int8, # RKNN_TENSOR_INT4 (用int8表示)
185
+ }
186
+
187
+ def get_available_providers() -> List[str]:
188
+ """
189
+ 获取可用的设备提供者列表(为保持接口兼容性的占位函数)
190
+
191
+ Returns:
192
+ list: 可用的设备提供者列表,总是返回["CPUExecutionProvider", "somemodelruntime_rknnlite2_ExecutionProvider"]
193
+ """
194
+ return ["CPUExecutionProvider", "somemodelruntime_rknnlite2_ExecutionProvider"]
195
+
196
+
197
+ def get_device() -> str:
198
+ """
199
+ 获取当前设备
200
+
201
+ Returns:
202
+ str: 当前设备
203
+ """
204
+ return "RKNN2"
205
+
206
+ def get_version_info() -> Dict[str, str]:
207
+ """
208
+ 获取版本信息
209
+
210
+ Returns:
211
+ dict: 包含API和驱动版本信息的字典
212
+ """
213
+ runtime = RKNNLite()
214
+ version = runtime.get_sdk_version()
215
+ return {
216
+ "api_version": version.split('\n')[2].split(': ')[1].split(' ')[0],
217
+ "driver_version": version.split('\n')[3].split(': ')[1]
218
+ }
219
+
220
+ class IOTensor:
221
+ """输入/输出张量的信息封装类"""
222
+ def __init__(self, name, shape, type=None):
223
+ self.name = name.decode() if isinstance(name, bytes) else name
224
+ self.shape = shape
225
+ self.type = type
226
+
227
+ def __str__(self):
228
+ return f"IOTensor(name='{self.name}', shape={self.shape}, type={self.type})"
229
+
230
+ class SessionOptions:
231
+ """会话选项类"""
232
+ def __init__(self):
233
+ self.enable_profiling = False # 是否使用性能分析
234
+ self.intra_op_num_threads = 1 # 设置RKNN的线程数, 对应rknn的core_mask
235
+ self.log_severity_level = -1 # 另一个设置日志级别的参数
236
+ self.log_verbosity_level = -1 # 另一个设置日志级别的参数
237
+
238
+
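+ # 会话选项使用示意(假设性示例, 模型文件名为假设):
+ #   opts = SessionOptions()
+ #   opts.intra_op_num_threads = 3      # 1 -> NPU_CORE_AUTO, 2 -> NPU_CORE_0_1, 3 -> NPU_CORE_0_1_2
+ #   opts.enable_profiling = True       # 以 verbose 模式创建 RKNNLite
+ #   sess = InferenceSession("model.rknn", sess_options=opts)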
239
+ class InferenceSession:
240
+ """
241
+ RKNNLite运行时封装类,API风格类似ONNX Runtime
242
+ """
243
+
244
+ def __new__(cls, model_path: str, sess_options: Optional[SessionOptions] = None, **kwargs):
245
+ processed_path = InferenceSession._process_model_path(model_path, sess_options)
246
+ if isinstance(processed_path, str) and processed_path.lower().endswith('.onnx'):
247
+ logger.info("使用ONNX Runtime加载模型")
248
+ if not HAS_ORT:
249
+ raise RuntimeError("未安装onnxruntime,无法加载ONNX模型")
250
+ return ort.InferenceSession(processed_path, sess_options=sess_options, **kwargs)
251
+ else:
252
+ # 如果不是 ONNX 模型,则调用父类的 __new__ 创建 InferenceSession 实例
253
+ instance = super().__new__(cls)
254
+ # 保存处理后的路径
255
+ instance._processed_path = processed_path
256
+ return instance
257
+
258
+ def __init__(self, model_path: str, sess_options: Optional[SessionOptions] = None, **kwargs):
259
+ """
260
+ 初始化运行时并加载模型
261
+
262
+ Args:
263
+ model_path: 模型文件路径(.rknn或.onnx)
264
+ sess_options: 会话选项
265
+ **kwargs: 其他初始化参数
266
+ """
267
+ options = sess_options or SessionOptions()
268
+
269
+ # 只在未设置环境变量时使用SessionOptions中的日志级别
270
+ if os.getenv('ZTU_MODELRT_RKNNL2_LOG_LEVEL') is None:
271
+ if options.log_severity_level != -1:
272
+ set_default_logger_severity(options.log_severity_level)
273
+ if options.log_verbosity_level != -1:
274
+ set_default_logger_verbosity(options.log_verbosity_level)
275
+
276
+ # 使用__new__中处理好的路径
277
+ model_path = getattr(self, '_processed_path', model_path)
278
+ if isinstance(model_path, str) and model_path.lower().endswith('.onnx'):
279
+ # 避免重复加载 ONNX 模型
280
+ return
281
+
282
+ # RKNN 模型加载与运行时初始化
283
+ self.model_path = model_path
284
+ if not os.path.exists(self.model_path):
285
+ logger.error(f"模型文件不存在: {self.model_path}")
286
+ raise FileNotFoundError(f"模型文件不存在: {self.model_path}")
287
+
288
+ self.runtime = RKNNLite(verbose=options.enable_profiling)
289
+
290
+ logger.debug(f"正在加载模型: {self.model_path}")
291
+ ret = self.runtime.load_rknn(self.model_path)
292
+ if ret != 0:
293
+ logger.error(f"加载RKNN模型失败: {self.model_path}")
294
+ raise RuntimeError(f'加载RKNN模型失败: {self.model_path}')
295
+ logger.debug("模型加载成功")
296
+
297
+
298
+ if options.intra_op_num_threads == 1:
299
+ core_mask = RKNNLite.NPU_CORE_AUTO
300
+ elif options.intra_op_num_threads == 2:
301
+ core_mask = RKNNLite.NPU_CORE_0_1
302
+ elif options.intra_op_num_threads == 3:
303
+ core_mask = RKNNLite.NPU_CORE_0_1_2
304
+ else:
305
+ raise ValueError(f"intra_op_num_threads的值无效: {options.intra_op_num_threads}, 只能是1,2或3")
306
+
307
+ logger.debug("正在初始化运行时环境")
308
+ ret = self.runtime.init_runtime(core_mask=core_mask)
309
+ if ret != 0:
310
+ logger.error("初始化运行时环境失败")
311
+ raise RuntimeError('初始化运行时环境失败')
312
+
313
+ logger.debug("运行时环境初始化成功")
314
+
315
+ # 在 runtime 初始化后,按环境变量自动注册自定义算子插件库
316
+ try:
317
+ # 注册用户指定路径插件(逗号/分号分隔)
318
+ env_custom = os.getenv('ZTU_MODELRT_RKNN2_REG_CUSTOM_OP_LIB', '').strip()
319
+ if env_custom:
320
+ paths = [seg.strip() for seg in re.split(r"[,;:]", env_custom) if seg.strip()]
321
+ ok = 0
322
+ for p in paths:
323
+ if self.register_custom_op_lib(p):
324
+ ok += 1
325
+ if ok > 0:
326
+ logger.info(f"已注册 {ok}/{len(paths)} 个自定义算子插件")
327
+ # 注册系统目录下插件
328
+ if _get_env_bool('ZTU_MODELRT_RKNN2_REG_SYSTEM_CUSTOM_OP_LIB', True):
329
+ cnt = self.register_system_custom_op_lib()
330
+ if cnt > 0:
331
+ logger.info(f"已从系统目录注册 {cnt} 个自定义算子插件")
332
+ except Exception as e:
333
+ logger.warning(f"自动注册自定义算子插件失败: {e}")
334
+
335
+ # 可选:按环境变量注册内置(基于Python)捆绑算子
336
+ if _get_env_bool('ZTU_MODELRT_RKNN2_REG_BUNDLED_OPS', False):
337
+ logger.info("根据环境变量注册捆绑算子")
338
+ self.register_bundled_ops()
339
+
340
+ self._init_io_info()
341
+ self.options = options
342
+
343
+ def get_performance_info(self) -> Dict[str, float]:
344
+ """
345
+ 获取性能信息
346
+
347
+ Returns:
348
+ dict: 包含性能信息的字典
349
+ """
350
+ if not self.options.enable_profiling:
351
+ raise RuntimeError("性能分析未启用,请在SessionOptions中设置enable_profiling=True")
352
+
353
+ perf = self.runtime.rknn_runtime.get_run_perf()
354
+ return {
355
+ "run_duration": perf.run_duration / 1000.0 # 转换为毫秒
356
+ }
357
+
358
+ def set_core_mask(self, core_mask: int) -> None:
359
+ """
360
+ 设置NPU核心使用模式
361
+
362
+ Args:
363
+ core_mask: NPU核心掩码,使用NPU_CORE_*常量
364
+ """
365
+ ret = self.runtime.rknn_runtime.set_core_mask(core_mask)
366
+ if ret != 0:
367
+ raise RuntimeError("设置NPU核心模式失败")
368
+
369
+ @staticmethod
370
+ def _process_model_path(model_path, sess_options):
371
+ """
372
+ 处理模型路径,支持.onnx和.rknn文件
373
+
374
+ Args:
375
+ model_path: 模型文件路径
376
+ """
377
+ # 如果是ONNX文件,检查是否需要自动加载RKNN
378
+ if model_path.lower().endswith('.onnx'):
379
+ logger.info("检测到ONNX模型文件")
380
+
381
+ # 获取需要跳过自动加载的模型列表
382
+ skip_models = os.getenv('ZTU_MODELRT_RKNNL2_SKIP', '').strip()
383
+ if skip_models:
384
+ # 获取模型文件名(不含路径)用于匹配
385
+ model_name = os.path.basename(model_path)
386
+ model_name_lower = model_name.lower()
387
+ negated, skip_list = _parse_model_list(skip_models)
388
+ if skip_list:
389
+ skip_set = {m.lower() for m in skip_list}
390
+ should_skip = model_name_lower not in skip_set if negated else model_name_lower in skip_set
391
+ else:
392
+ should_skip = False
393
+ if should_skip:
394
+ logger.info(f"模型{model_name}在跳过列表中,将使用ONNX Runtime")
395
+ return model_path
396
+
397
+ # 构造RKNN文件路径
398
+ rknn_path = os.path.splitext(model_path)[0] + '.rknn'
399
+ if os.path.exists(rknn_path):
400
+ logger.info(f"找到对应的RKNN模型,将使用RKNN: {rknn_path}")
401
+ return rknn_path
402
+ else:
403
+ logger.info("未找到对应的RKNN模型,将使用ONNX Runtime")
404
+ return model_path
405
+
406
+ return model_path
407
+
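+ # 模型路径处理示意(假设性示例, 文件名仅作演示):
+ #   ZTU_MODELRT_RKNNL2_SKIP="a.onnx,b.onnx"  -> 这两个模型强制走 ONNX Runtime
+ #   ZTU_MODELRT_RKNNL2_SKIP="^a.onnx"        -> 只有 a.onnx 尝试使用同名 .rknn, 其余都走 ONNX Runtime
+ #   传入 model.onnx 且存在同目录 model.rknn 时, 自动改用 model.rknn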
408
+ def _convert_nhwc_to_nchw(self, shape):
409
+ """将NHWC格式的shape转换为NCHW格式"""
410
+ if len(shape) == 4:
411
+ # NHWC -> NCHW
412
+ n, h, w, c = shape
413
+ return [n, c, h, w]
414
+ return shape
415
+
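+ # 形状转换示意: 仅对 4 维输入按 NHWC -> NCHW 处理, 例如 [1, 224, 224, 3] -> [1, 3, 224, 224];
+ # 其他维数的 shape 原样返回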
416
+ def _init_io_info(self):
417
+ """初始化模型的输入输出信息"""
418
+ runtime = self.runtime.rknn_runtime
419
+
420
+ # 获取输入输出数量
421
+ n_input, n_output = runtime.get_in_out_num()
422
+
423
+ # 获取输入信息
424
+ self.input_tensors = []
425
+ for i in range(n_input):
426
+ attr = runtime.get_tensor_attr(i)
427
+ shape = [attr.dims[j] for j in range(attr.n_dims)]
428
+ # 对四维输入进行NHWC到NCHW的转换
429
+ shape = self._convert_nhwc_to_nchw(shape)
430
+ # 获取dtype
431
+ dtype = RKNN_DTYPE_MAP.get(attr.type, None)
432
+ tensor = IOTensor(attr.name, shape, dtype)
433
+ self.input_tensors.append(tensor)
434
+
435
+ # 获取输出信息
436
+ self.output_tensors = []
437
+ for i in range(n_output):
438
+ attr = runtime.get_tensor_attr(i, is_output=True)
439
+ shape = runtime.get_output_shape(i)
440
+ # 获取dtype
441
+ dtype = RKNN_DTYPE_MAP.get(attr.type, None)
442
+ tensor = IOTensor(attr.name, shape, dtype)
443
+ self.output_tensors.append(tensor)
444
+
445
+ def get_inputs(self):
446
+ """
447
+ 获取模型输入信息
448
+
449
+ Returns:
450
+ list: 包含输入信息的列表
451
+ """
452
+ return self.input_tensors
453
+
454
+ def get_outputs(self):
455
+ """
456
+ 获取模型输出信息
457
+
458
+ Returns:
459
+ list: 包含输出信息的列表
460
+ """
461
+ return self.output_tensors
462
+
463
+ def run(self, output_names=None, input_feed=None, data_format="nchw", **kwargs):
464
+ """
465
+ 执行模型推理
466
+
467
+ Args:
468
+ output_names: 输出节点名称列表,指定需要返回哪些输出
469
+ input_feed: 输入数据字典或列表
470
+ data_format: 输入数据格式,"nchw"或"nhwc"
471
+ **kwargs: 其他运行时参数
472
+
473
+ Returns:
474
+ list: 模型输出结果列表,如果指定了output_names则只返回指定的输出
475
+ """
476
+ if input_feed is None:
477
+ logger.error("input_feed不能为None")
478
+ raise ValueError("input_feed不能为None")
479
+
480
+ # 准备输入数据
481
+ if isinstance(input_feed, dict):
482
+ # 如果是字典,按照模型输入顺序排列
483
+ inputs = []
484
+ input_map = {tensor.name: i for i, tensor in enumerate(self.input_tensors)}
485
+ for tensor in self.input_tensors:
486
+ if tensor.name not in input_feed:
487
+ raise ValueError(f"缺少输入: {tensor.name}")
488
+ inputs.append(input_feed[tensor.name])
489
+ elif isinstance(input_feed, (list, tuple)):
490
+ # 如果是列表,确保长度匹配
491
+ if len(input_feed) != len(self.input_tensors):
492
+ raise ValueError(f"输入数量不匹配: 期望{len(self.input_tensors)}, 实际{len(input_feed)}")
493
+ inputs = list(input_feed)
494
+ else:
495
+ logger.error("input_feed必须是字典或列表类型")
496
+ raise ValueError("input_feed必须是字典或列表类型")
497
+
498
+ # 执行推理
499
+ try:
500
+ logger.debug("开始执行推理")
501
+ all_outputs = self.runtime.inference(inputs=inputs, data_format=data_format)
502
+
503
+ # 如果没有指定output_names,返回所有输出
504
+ if output_names is None:
505
+ return all_outputs
506
+
507
+ # 获取指定的输出
508
+ output_map = {tensor.name: i for i, tensor in enumerate(self.output_tensors)}
509
+ selected_outputs = []
510
+ for name in output_names:
511
+ if name not in output_map:
512
+ raise ValueError(f"未找到输出节点: {name}")
513
+ selected_outputs.append(all_outputs[output_map[name]])
514
+
515
+ return selected_outputs
516
+
517
+ except Exception as e:
518
+ logger.error(f"推理执行失败: {str(e)}")
519
+ raise RuntimeError(f"推理执行失败: {str(e)}")
520
+
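+ # 端到端使用示意(假设性示例, 模型与输入均为假设):
+ #   sess = InferenceSession("encoder.rknn")
+ #   inp = sess.get_inputs()[0]
+ #   x = np.zeros(inp.shape, dtype=inp.type or np.float32)
+ #   outputs = sess.run(None, {inp.name: x})   # 调用方式与 onnxruntime 保持一致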
521
+ def close(self):
522
+ """
523
+ 关闭会话,释放资源
524
+ """
525
+ if self.runtime is not None:
526
+ logger.info("正在释放运行时资源")
527
+ self.runtime.release()
528
+ self.runtime = None
529
+
530
+ def __enter__(self):
531
+ return self
532
+
533
+ def __exit__(self, exc_type, exc_val, exc_tb):
534
+ self.close()
535
+
536
+ def end_profiling(self) -> Optional[str]:
537
+ """
538
+ 结束性能分析的存根方法
539
+
540
+ Returns:
541
+ Optional[str]: None
542
+ """
543
+ warnings.warn("end_profiling()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
544
+ return None
545
+
546
+ def get_profiling_start_time_ns(self) -> int:
547
+ """
548
+ 获取性能分析开始时间的存根方法
549
+
550
+ Returns:
551
+ int: 0
552
+ """
553
+ warnings.warn("get_profiling_start_time_ns()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
554
+ return 0
555
+
556
+ def get_modelmeta(self) -> Dict[str, str]:
557
+ """
558
+ 获取模型元数据的存根方法
559
+
560
+ Returns:
561
+ Dict[str, str]: 空字典
562
+ """
563
+ warnings.warn("get_modelmeta()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
564
+ return {}
565
+
566
+ def get_session_options(self) -> SessionOptions:
567
+ """
568
+ 获取会话选项
569
+
570
+ Returns:
571
+ SessionOptions: 当前会话选项
572
+ """
573
+ return self.options
574
+
575
+ def get_providers(self) -> List[str]:
576
+ """
577
+ 获取当前使用的providers的存根方法
578
+
579
+ Returns:
580
+ List[str]: ["CPUExecutionProvider"]
581
+ """
582
+ warnings.warn("get_providers()是存根方法,始终返回CPUExecutionProvider", RuntimeWarning, stacklevel=2)
583
+ return ["CPUExecutionProvider"]
584
+
585
+ def get_provider_options(self) -> Dict[str, Dict[str, str]]:
586
+ """
587
+ 获取provider选项的存根方法
588
+
589
+ Returns:
590
+ Dict[str, Dict[str, str]]: 空字典
591
+ """
592
+ warnings.warn("get_provider_options()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
593
+ return {}
594
+
595
+ def get_session_config(self) -> Dict[str, str]:
596
+ """
597
+ 获取会话配置的存根方法
598
+
599
+ Returns:
600
+ Dict[str, str]: 空字典
601
+ """
602
+ warnings.warn("get_session_config()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
603
+ return {}
604
+
605
+ def get_session_state(self) -> Dict[str, str]:
606
+ """
607
+ 获取会话状态的存根方法
608
+
609
+ Returns:
610
+ Dict[str, str]: 空字典
611
+ """
612
+ warnings.warn("get_session_state()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
613
+ return {}
614
+
615
+ def set_session_config(self, config: Dict[str, str]) -> None:
616
+ """
617
+ 设置会话配置的存根方法
618
+
619
+ Args:
620
+ config: 会话配置字典
621
+ """
622
+ warnings.warn("set_session_config()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
623
+
624
+ def get_memory_info(self) -> Dict[str, int]:
625
+ """
626
+ 获取内存使用信息的存根方法
627
+
628
+ Returns:
629
+ Dict[str, int]: 空字典
630
+ """
631
+ warnings.warn("get_memory_info()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
632
+ return {}
633
+
634
+ def set_memory_pattern(self, enable: bool) -> None:
635
+ """
636
+ 设置内存模式的存根方法
637
+
638
+ Args:
639
+ enable: 是否启用内存模式
640
+ """
641
+ warnings.warn("set_memory_pattern()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
642
+
643
+ def disable_memory_pattern(self) -> None:
644
+ """
645
+ 禁用内存模式的存根方法
646
+ """
647
+ warnings.warn("disable_memory_pattern()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
648
+
649
+ def get_optimization_level(self) -> int:
650
+ """
651
+ 获取优化级别的存根方法
652
+
653
+ Returns:
654
+ int: 0
655
+ """
656
+ warnings.warn("get_optimization_level()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
657
+ return 0
658
+
659
+ def set_optimization_level(self, level: int) -> None:
660
+ """
661
+ 设置优化级别的存根方法
662
+
663
+ Args:
664
+ level: 优化级别
665
+ """
666
+ warnings.warn("set_optimization_level()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
667
+
668
+ def get_model_metadata(self) -> Dict[str, str]:
669
+ """
670
+ 获取模型元数据的存根方法(与get_modelmeta不同的接口)
671
+
672
+ Returns:
673
+ Dict[str, str]: 空字典
674
+ """
675
+ warnings.warn("get_model_metadata()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
676
+ return {}
677
+
678
+ def get_model_path(self) -> str:
679
+ """
680
+ 获取模型路径
681
+
682
+ Returns:
683
+ str: 模型文件路径
684
+ """
685
+ return self.model_path
686
+
687
+ def get_input_type_info(self) -> List[Dict[str, str]]:
688
+ """
689
+ 获取输入类型信息的存根方法
690
+
691
+ Returns:
692
+ List[Dict[str, str]]: 空列表
693
+ """
694
+ warnings.warn("get_input_type_info()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
695
+ return []
696
+
697
+ def get_output_type_info(self) -> List[Dict[str, str]]:
698
+ """
699
+ 获取输出类型信息的存根方法
700
+
701
+ Returns:
702
+ List[Dict[str, str]]: 空列表
703
+ """
704
+ warnings.warn("get_output_type_info()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
705
+ return []
706
+
707
+ ################### 自定义算子 ###################
708
+
709
+ def _init_custom_op_types(self):
710
+ """初始化自定义算子的类型定义"""
711
+ # 常量
712
+ self._RKNN_TENSOR_FLOAT32 = 0
713
+ self._RKNN_TENSOR_UINT8 = 3
714
+ self._RKNN_TENSOR_INT64 = 8
715
+ self._RKNN_TARGET_TYPE_CPU = 1
716
+
717
+ # 结构体定义
718
+ class RKNN_TensorAttr(ctypes.Structure):
719
+ _fields_ = [
720
+ ("index", ctypes.c_uint32),
721
+ ("n_dims", ctypes.c_uint32),
722
+ ("dims", ctypes.c_uint32 * RKNN_MAX_DIMS),
723
+ ("name", ctypes.c_char * RKNN_MAX_NAME_LEN),
724
+ ("n_elems", ctypes.c_uint32),
725
+ ("size", ctypes.c_uint32),
726
+ ("fmt", ctypes.c_int),
727
+ ("type", ctypes.c_int),
728
+ ("qnt_type", ctypes.c_int),
729
+ ("fl", ctypes.c_int8),
730
+ ("zp", ctypes.c_int32),
731
+ ("scale", ctypes.c_float),
732
+ ("w_stride", ctypes.c_uint32),
733
+ ("size_with_stride", ctypes.c_uint32),
734
+ ("pass_through", ctypes.c_uint8),
735
+ ("h_stride", ctypes.c_uint32),
736
+ ]
737
+
738
+ class RKNN_TensorMem(ctypes.Structure):
739
+ _fields_ = [
740
+ ("virt_addr", ctypes.c_void_p),
741
+ ("phys_addr", ctypes.c_uint64),
742
+ ("fd", ctypes.c_int32),
743
+ ("offset", ctypes.c_int32),
744
+ ("size", ctypes.c_uint32),
745
+ ("flags", ctypes.c_uint32),
746
+ ("priv_data", ctypes.c_void_p),
747
+ ]
748
+
749
+ class RKNN_CustomOpTensor(ctypes.Structure):
750
+ _fields_ = [
751
+ ("attr", RKNN_TensorAttr),
752
+ ("mem", RKNN_TensorMem),
753
+ ]
754
+
755
+ class RKNN_GPUOpContext(ctypes.Structure):
756
+ _fields_ = [
757
+ ("cl_context", ctypes.c_void_p),
758
+ ("cl_command_queue", ctypes.c_void_p),
759
+ ("cl_kernel", ctypes.c_void_p),
760
+ ]
761
+
762
+ InternalCtxType = (
763
+ ctypes.c_uint64 if ctypes.sizeof(ctypes.c_void_p) == 8 else ctypes.c_uint32
764
+ )
765
+
766
+ class RKNN_CustomOpContext(ctypes.Structure):
767
+ _fields_ = [
768
+ ("target", ctypes.c_int),
769
+ ("internal_ctx", InternalCtxType),
770
+ ("gpu_ctx", RKNN_GPUOpContext),
771
+ ("priv_data", ctypes.c_void_p),
772
+ ]
773
+
774
+ class RKNN_CustomOpAttr(ctypes.Structure):
775
+ _fields_ = [
776
+ ("name", ctypes.c_char * RKNN_MAX_NAME_LEN),
777
+ ("dtype", ctypes.c_int),
778
+ ("n_elems", ctypes.c_uint32),
779
+ ("data", ctypes.c_void_p),
780
+ ]
781
+
782
+ CB_SIG = ctypes.CFUNCTYPE(
783
+ ctypes.c_int,
784
+ ctypes.POINTER(RKNN_CustomOpContext),
785
+ ctypes.POINTER(RKNN_CustomOpTensor),
786
+ ctypes.c_uint32,
787
+ ctypes.POINTER(RKNN_CustomOpTensor),
788
+ ctypes.c_uint32,
789
+ )
790
+
791
+ DESTROY_SIG = ctypes.CFUNCTYPE(
792
+ ctypes.c_int, ctypes.POINTER(RKNN_CustomOpContext)
793
+ )
794
+
795
+ class RKNN_CustomOp(ctypes.Structure):
796
+ _fields_ = [
797
+ ("version", ctypes.c_uint32),
798
+ ("target", ctypes.c_int),
799
+ ("op_type", ctypes.c_char * RKNN_MAX_NAME_LEN),
800
+ ("cl_kernel_name", ctypes.c_char * RKNN_MAX_NAME_LEN),
801
+ ("cl_kernel_source", ctypes.c_char_p),
802
+ ("cl_source_size", ctypes.c_uint64),
803
+ ("cl_build_options", ctypes.c_char * RKNN_MAX_NAME_LEN),
804
+ ("init", CB_SIG),
805
+ ("prepare", CB_SIG),
806
+ ("compute", CB_SIG),
807
+ ("compute_native", CB_SIG),
808
+ ("destroy", DESTROY_SIG),
809
+ ]
810
+
811
+ # 保存类型定义
812
+ self._RKNN_TensorAttr = RKNN_TensorAttr
813
+ self._RKNN_TensorMem = RKNN_TensorMem
814
+ self._RKNN_CustomOpTensor = RKNN_CustomOpTensor
815
+ self._RKNN_CustomOpContext = RKNN_CustomOpContext
816
+ self._RKNN_CustomOpAttr = RKNN_CustomOpAttr
817
+ self._RKNN_CustomOp = RKNN_CustomOp
818
+ self._CB_SIG = CB_SIG
819
+ self._DESTROY_SIG = DESTROY_SIG
820
+
821
+ def _create_attr_readers(self, get_op_attr):
822
+ """创建属性读取函数"""
823
+ def read_attr_int64(op_ctx_ptr, key: str, default: int = 0) -> int:
824
+ attr = self._RKNN_CustomOpAttr()
825
+ get_op_attr(op_ctx_ptr, key.encode("utf-8"), ctypes.byref(attr))
826
+ if attr.n_elems == 1 and attr.dtype == self._RKNN_TENSOR_INT64 and attr.data:
827
+ return ctypes.c_int64.from_address(attr.data).value
828
+ return default
829
+
830
+ def read_attr_float32(op_ctx_ptr, key: str, default: float = 0) -> float:
831
+ attr = self._RKNN_CustomOpAttr()
832
+ get_op_attr(op_ctx_ptr, key.encode("utf-8"), ctypes.byref(attr))
833
+ if attr.n_elems == 1 and attr.dtype == self._RKNN_TENSOR_FLOAT32 and attr.data:
834
+ return ctypes.c_float.from_address(attr.data).value
835
+ return default
836
+
837
+ def read_attr_str(op_ctx_ptr, key: str, default: str = "") -> str:
838
+ attr = self._RKNN_CustomOpAttr()
839
+ get_op_attr(op_ctx_ptr, key.encode("utf-8"), ctypes.byref(attr))
840
+ if attr.n_elems > 0 and attr.dtype == self._RKNN_TENSOR_UINT8 and attr.data:
841
+ buf = (ctypes.c_ubyte * attr.n_elems).from_address(attr.data)
842
+ try:
843
+ return bytes(buf).decode("utf-8", errors="ignore").strip('"')
844
+ except Exception:
845
+ return default
846
+ return default
847
+
848
+
849
+ return read_attr_int64, read_attr_str, read_attr_float32
850
+
851
+ def _build_py_custom_op(self,
852
+ op_type: str,
853
+ n_inputs: int,
854
+ n_outputs: int,
855
+ on_init,
856
+ on_compute):
857
+ """通用的Python自定义算子构造器
858
+
859
+ Args:
860
+ op_type: 算子类型名(字符串)
861
+ n_inputs: 输入个数
862
+ n_outputs: 输出个数
863
+ on_init: 回调,签名 on_init(op_ctx_p, read_attr_int64, read_attr_str, read_attr_float32) -> state
864
+ on_compute: 回调,签名 on_compute(op_ctx_p, inputs_p, outputs_p, state) -> int(0成功)
865
+ Returns:
866
+ (RKNN_CustomOp对象, 回调tuple)
867
+ """
868
+ @self._CB_SIG
869
+ def _py_init(op_ctx_p, inputs_p, n_inputs_p, outputs_p, n_outputs_p):
870
+ try:
871
+ # 允许无需提前读取属性
872
+ runtime = self.runtime.rknn_runtime
873
+ read_attr_int64, read_attr_str, read_attr_float32 = self._create_attr_readers(runtime.lib.rknn_custom_op_get_op_attr)
874
+ user_state = on_init(op_ctx_p, read_attr_int64, read_attr_str, read_attr_float32)
875
+ # 为该实例分配唯一ID, 并写入priv_data
876
+ if not hasattr(self, "_custom_op_states"):
877
+ self._custom_op_states = {}
878
+ if not hasattr(self, "_next_custom_op_id"):
879
+ self._next_custom_op_id = 1
880
+ inst_id = int(self._next_custom_op_id)
881
+ self._next_custom_op_id += 1
882
+ # 保存Python侧状态
883
+ self._custom_op_states[inst_id] = user_state
884
+ # 将实例ID写入priv_data
885
+ try:
886
+ op_ctx_p.contents.priv_data = ctypes.c_void_p(inst_id)
887
+ except Exception:
888
+ # 回退: 直接写入整数
889
+ op_ctx_p.contents.priv_data = inst_id
890
+ return 0
891
+ except Exception as e:
892
+ logger.error(f"{op_type} init失败: {e}")
893
+ return -1
894
+
895
+ @self._CB_SIG
896
+ def _py_prepare(op_ctx_p, inputs_p, n_inputs_p, outputs_p, n_outputs_p):
897
+ return 0
898
+
899
+ @self._CB_SIG
900
+ def _py_compute(op_ctx_p, inputs_p, n_inputs_p, outputs_p, n_outputs_p):
901
+ try:
902
+ if n_inputs_p != n_inputs or n_outputs_p != n_outputs:
903
+ return -1
904
+ # 通过priv_data取回该实例的状态
905
+ try:
906
+ inst_id = int(op_ctx_p.contents.priv_data) if op_ctx_p.contents.priv_data else 0
907
+ except Exception:
908
+ inst_id = 0
909
+ user_state = None
910
+ if hasattr(self, "_custom_op_states") and inst_id in self._custom_op_states:
911
+ user_state = self._custom_op_states.get(inst_id)
912
+ else:
913
+ logger.error(f"{op_type} compute失败: 找不到实例状态, inst_id={inst_id}")
914
+ return -1
915
+ return on_compute(op_ctx_p, inputs_p, outputs_p, user_state)
916
+ except Exception as e:
917
+ logger.error(f"{op_type} compute失败: {e}")
918
+ import traceback
919
+ logger.error(f"{op_type} compute失败: {traceback.format_exc()}")
920
+ return -1
921
+
922
+ @self._DESTROY_SIG
923
+ def _py_destroy(op_ctx_p):
924
+ try:
925
+ # 清理该实例的状态
926
+ try:
927
+ inst_id = int(op_ctx_p.contents.priv_data) if op_ctx_p.contents.priv_data else 0
928
+ except Exception:
929
+ inst_id = 0
930
+ if hasattr(self, "_custom_op_states") and inst_id in self._custom_op_states:
931
+ del self._custom_op_states[inst_id]
932
+ # 将priv_data清空
933
+ try:
934
+ op_ctx_p.contents.priv_data = ctypes.c_void_p(0)
935
+ except Exception:
936
+ op_ctx_p.contents.priv_data = 0
937
+ return 0
938
+ except Exception:
939
+ return -1
940
+
941
+ op = self._RKNN_CustomOp()
942
+ op.version = 1
943
+ op.target = self._RKNN_TARGET_TYPE_CPU
944
+ op.op_type = op_type.encode("utf-8")
945
+ op.cl_kernel_name = b""
946
+ op.cl_kernel_source = None
947
+ op.cl_source_size = 0
948
+ op.cl_build_options = b""
949
+ op.init = _py_init
950
+ op.prepare = _py_prepare
951
+ op.compute = _py_compute
952
+ op.compute_native = self._CB_SIG() # NULL
953
+ op.destroy = _py_destroy
954
+
955
+ return op, (_py_init, _py_prepare, _py_compute, _py_destroy)
956
+
957
+
958
+ def _tensor_to_numpy(self, rknn_tensor):
959
+ """将 RKNN_CustomOpTensor 转换为 Numpy 数组视图"""
960
+ # 确定Numpy数据类型
961
+ # 您可以扩展这个映射
962
+ dtype_map = {
963
+ self._RKNN_TENSOR_FLOAT32: (ctypes.c_float, np.float32),
964
+ self._RKNN_TENSOR_UINT8: (ctypes.c_uint8, np.uint8),
965
+ self._RKNN_TENSOR_INT64: (ctypes.c_int64, np.int64),
966
+ }
967
+ c_type, np_dtype = dtype_map.get(rknn_tensor.attr.type, (None, None))
968
+ if c_type is None:
969
+ raise TypeError(f"不支持的RKNN张量类型: {rknn_tensor.attr.type}")
970
+
971
+ # 获取内存地址和形状
972
+ addr = (rknn_tensor.mem.virt_addr or 0) + int(rknn_tensor.mem.offset)
973
+ ptr = ctypes.cast(addr, ctypes.POINTER(c_type))
974
+ shape = tuple(rknn_tensor.attr.dims[i] for i in range(rknn_tensor.attr.n_dims))
975
+
976
+ # 创建Numpy数组视图
977
+ return np.ctypeslib.as_array(ptr, shape=shape)
978
+
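+ # 说明: 上面返回的是直接映射到 RKNN 张量内存的 numpy 视图(零拷贝),
+ # 对视图的原地写入(如 output_nps[i][...] = ...)即写回算子的输出缓冲区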
979
+
980
+ def _create_onnxscript_op_creator(self,
981
+ op_type: str,
982
+ # 现在接收一个"函数模板构造器"
983
+ onnxscript_func_builder,
984
+ n_inputs: int,
985
+ n_outputs: int,
986
+ attributes: dict = {},
987
+ constants: dict = {}):
988
+ """
989
+ 一个高阶工厂函数,用于创建基于ONNXScript的自定义算子构造器。
990
+ 它在 on_init 阶段动态生成最终的 onnxscript 计算函数。
991
+
992
+ Args:
993
+ op_type (str): 算子类型名。
994
+ onnxscript_func_builder: 一个函数,它接收所有属性和常量作为关键字参数,
995
+ 并返回一个编译好的 onnxscript 函数。
996
+ 例如: def builder(mean, scale):
997
+ @onnxscript.script()
998
+ def compute(like):
999
+ return opset.RandomNormalLike(like, mean=mean, scale=scale)
1000
+ return compute
1001
+ attributes (dict): 从模型中读取的属性字典。
1002
+ constants (dict): 编译时常量字典。
1003
+ n_inputs (int): 输入个数。
1004
+ n_outputs (int): 输出个数。
1005
+ """
1006
+
1007
+ def creator_func():
1008
+ def on_init(op_ctx_p, read_i64, read_s, read_f32):
1009
+ # 1. 读取所有动态属性
1010
+ attr_values = {}
1011
+ for name, (attr_type, default) in attributes.items():
1012
+ if attr_type == 'int64':
1013
+ attr_values[name] = read_i64(op_ctx_p, name, default)
1014
+ elif attr_type == 'str':
1015
+ attr_values[name] = read_s(op_ctx_p, name, default)
1016
+ elif attr_type == 'float32':
1017
+ attr_values[name] = read_f32(op_ctx_p, name, default)
1018
+ else:
1019
+ raise ValueError(f"不支持的属性类型: {attr_type}")
1020
+
1021
+ # 2. 合并常量和属性
1022
+ final_kwargs = {**constants, **attr_values}
1023
+
1024
+ # 3. 动态构建 onnxscript 函数
1025
+ # 这确保了所有属性值都作为常量被闭包捕获
1026
+ compute_func = onnxscript_func_builder(**final_kwargs)
1027
+
1028
+ # 4. 将最终生成的、已编译的函数存入 state
1029
+ return {"compute_func": compute_func}
1030
+
1031
+ def on_compute(op_ctx_p, inputs_p, outputs_p, state):
1032
+ compute_func = state["compute_func"]
1033
+
1034
+ input_nps = [self._tensor_to_numpy(inputs_p[i]) for i in range(n_inputs)]
1035
+ output_nps = [self._tensor_to_numpy(outputs_p[i]) for i in range(n_outputs)]
1036
+
1037
+ results = compute_func(*input_nps)
1038
+
1039
+ if n_outputs == 1:
1040
+ result_val = results[0] if isinstance(results, tuple) else results
1041
+ output_nps[0][...] = result_val
1042
+ else:
1043
+ for i in range(n_outputs):
1044
+ output_nps[i][...] = results[i]
1045
+
1046
+ return 0
1047
+
1048
+ return self._build_py_custom_op(
1049
+ op_type=op_type,
1050
+ n_inputs=n_inputs,
1051
+ n_outputs=n_outputs,
1052
+ on_init=on_init,
1053
+ on_compute=on_compute
1054
+ )
1055
+
1056
+ return creator_func
1057
+
1058
+ def _create_gridsample_op(self):
1059
+ import onnxscript
1060
+ from onnxscript import opset17 as opset
1061
+
1062
+ def grid_sample_builder(align_corners, mode, padding_mode):
1063
+ @onnxscript.script()
1064
+ def grid_sample_compute(X, G):
1065
+ return opset.GridSample(X, G, align_corners=align_corners, mode=mode, padding_mode=padding_mode)
1066
+ return grid_sample_compute
1067
+
1068
+ grid_sample_creator = self._create_onnxscript_op_creator(
1069
+ op_type="GridSample",
1070
+ onnxscript_func_builder=grid_sample_builder, # << 传入 builder
1071
+ attributes={
1072
+ "align_corners": ("int64", 0),
1073
+ "mode": ("str", "bilinear"),
1074
+ "padding_mode": ("str", "zeros"),
1075
+ },
1076
+ n_inputs = 2,
1077
+ n_outputs = 1
1078
+ )
1079
+ return grid_sample_creator
1080
+
1081
+ def _create_scatterelements_op(self):
1082
+ import onnxscript
1083
+ from onnxscript import opset17 as opset
1084
+
1085
+ @onnxscript.script()
1086
+ def scatter_elements_compute(data, indices, updates):
1087
+ indices_i64 = opset.Cast(indices, to=onnxscript.INT64.dtype)
1088
+ return opset.ScatterElements(data, indices_i64, updates)
1089
+
1090
+ scatter_elements_creator = self._create_onnxscript_op_creator(
1091
+ op_type="ScatterElements",
1092
+ onnxscript_func_builder=lambda: scatter_elements_compute,
1093
+ n_inputs = 3,
1094
+ n_outputs = 1
1095
+ )
1096
+ return scatter_elements_creator
1097
+
1098
+ def _create_randomnormallike_op(self):
1099
+ import onnxscript
1100
+ from onnxscript import opset17 as opset
1101
+
1102
+ def random_normal_like_builder(mean, scale):
1103
+ @onnxscript.script()
1104
+ def random_normal_like_compute(like):
1105
+ return opset.RandomNormalLike(like, mean=mean, scale=scale)
1106
+
1107
+ return random_normal_like_compute
1108
+
1109
+ # 3. 使用新的工厂函数
1110
+ random_normal_like_creator = self._create_onnxscript_op_creator(
1111
+ op_type="RandomNormalLike",
1112
+ onnxscript_func_builder=random_normal_like_builder, # << 传入 builder
1113
+ attributes={
1114
+ "mean": ("float32", 0.0),
1115
+ "scale": ("float32", 1.0),
1116
+ },
1117
+ n_inputs = 1,
1118
+ n_outputs = 1
1119
+ )
1120
+ return random_normal_like_creator
1121
+
1122
+ def _create_einsum_op(self):
1123
+ import onnxscript
1124
+ from onnxscript import opset17 as opset
1125
+
1126
+ def einsum_builder(equation):
1127
+
1128
+ @onnxscript.script()
1129
+ def einsum_compute(in1, in2):
1130
+ return opset.Einsum(in1, in2, equation=equation)
1131
+
1132
+ return einsum_compute
1133
+
1134
+ # 3. 使用新的工厂函数
1135
+ einsum_creator = self._create_onnxscript_op_creator(
1136
+ op_type="Einsum",
1137
+ onnxscript_func_builder=einsum_builder, # << 传入 builder
1138
+ attributes={
1139
+ "equation": ("str", ""),
1140
+ },
1141
+ n_inputs = 2,
1142
+ n_outputs = 1
1143
+ )
1144
+ return einsum_creator
1145
+
1146
+ def _create_reducel1_op(self):
1147
+ import onnxscript
1148
+ from onnxscript import opset17 as opset
1149
+
1150
+ def reduce_l1_builder(keepdims, noop_with_empty_axes):
1151
+ @onnxscript.script()
1152
+ def reduce_l1_compute(data, axes):
1153
+ data_abs = opset.Abs(data)
1154
+ return opset.ReduceSum(
1155
+ data_abs,
1156
+ axes,
1157
+ keepdims=keepdims,
1158
+ noop_with_empty_axes=noop_with_empty_axes,
1159
+ )
1160
+ return reduce_l1_compute
1161
+
1162
+ reduce_l1_creator = self._create_onnxscript_op_creator(
1163
+ op_type="ReduceL1",
1164
+ onnxscript_func_builder=reduce_l1_builder,
1165
+ attributes={
1166
+ "keepdims": ("int64", 1),
1167
+ "noop_with_empty_axes": ("int64", 0),
1168
+ },
1169
+ n_inputs=2,
1170
+ n_outputs=1,
1171
+ )
1172
+ return reduce_l1_creator
1173
+
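+ # 新增捆绑算子示意(假设性草图, MyCustomAdd 仅作演示, 未在下方注册列表中启用):
+ #   def _create_my_custom_add_op(self):
+ #       import onnxscript
+ #       from onnxscript import opset17 as opset
+ #       @onnxscript.script()
+ #       def add_compute(a, b):
+ #           return opset.Add(a, b)
+ #       return self._create_onnxscript_op_creator(
+ #           op_type="MyCustomAdd", onnxscript_func_builder=lambda: add_compute,
+ #           n_inputs=2, n_outputs=1)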
1174
+ def register_bundled_ops(self) -> None:
1175
+ """注册自定义操作"""
1176
+ if getattr(self, "_custom_ops_registered", False):
1177
+ return
1178
+
1179
+ runtime = self.runtime.rknn_runtime
1180
+ lib = runtime.lib
1181
+ ctx = runtime.context
1182
+
1183
+ try:
1184
+ _ = lib.rknn_register_custom_ops
1185
+ _ = lib.rknn_custom_op_get_op_attr
1186
+ except AttributeError as e:
1187
+ logger.debug(f"SDK不支持自定义算子注册: {e}")
1188
+ return
1189
+
1190
+ self._init_custom_op_types()
1191
+
1192
+ # 注意:插件库注册已在模型加载后由环境变量控制,不在此处重复触发
1193
+
1194
+ # 算子创建函数的列表现在更加清晰
1195
+ op_creator_factories = [
1196
+ self._create_gridsample_op,
1197
+ self._create_scatterelements_op,
1198
+ self._create_randomnormallike_op,
1199
+ self._create_einsum_op,
1200
+ self._create_reducel1_op,
1201
+ # self._create_my_custom_add_op, # 添加新算子非常简单
1202
+ ]
1203
+
1204
+ ops_to_register = []
1205
+ all_callbacks = []
1206
+
1207
+ for factory in op_creator_factories:
1208
+ try:
1209
+ # 调用工厂获得真正的构造器
1210
+ creator_func = factory()
1211
+ # 调用构造器生成算子实例
1212
+ op, callbacks = creator_func()
1213
+ ops_to_register.append(op)
1214
+ all_callbacks.extend(callbacks)
1215
+ logger.debug(f"成功创建自定义算子: {op.op_type.decode()}")
1216
+ except Exception as e:
1217
+ logger.warning(f"创建自定义算子失败: {e}", exc_info=True)
1218
+
1219
+ if not ops_to_register:
1220
+ logger.debug("没有可注册的自定义算子")
1221
+ return
1222
+
1223
+ # 创建一个ctypes数组以包含所有要注册的算子, 然后一次性注册
1224
+ num_ops = len(ops_to_register)
1225
+ op_array = (self._RKNN_CustomOp * num_ops)(*ops_to_register)
1226
+ ret = lib.rknn_register_custom_ops(ctx, op_array, num_ops)
1227
+ if ret != 0:
1228
+ logger.error(f"注册自定义算子失败, ret={ret} (可能是误报, 继续执行...)")
1229
+ # raise RuntimeError(f"rknn_register_custom_ops 失败, ret={ret}")
1230
+
1231
+ logger.info(f"成功注册 {len(ops_to_register)} 个自定义算子")
1232
+
1233
+ self._custom_ops_registered = True
1234
+ self._registered_ops = ops_to_register
1235
+ self._op_callbacks = all_callbacks
1236
+
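+ # 捆绑算子启用示意(假设性示例):
+ #   export ZTU_MODELRT_RKNN2_REG_BUNDLED_OPS=1   # 加载模型后自动注册上述 Python 实现的算子
+ #   或在创建会话后手动调用 sess.register_bundled_ops()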
1237
+ def _load_and_register_plugin_op(self, so_path: str) -> bool:
1238
+ """加载单个插件库并注册其中的自定义算子。
1239
+
1240
+ 要求插件实现 get_rknn_custom_op(),返回 rknn_custom_op*。
1241
+ 我们将该 C 指针直接传递给 rknn_register_custom_ops,避免复制。
1242
+ """
1243
+ if not os.path.isfile(so_path):
1244
+ logger.warning(f"插件库不存在: {so_path}")
1245
+ return False
1246
+
1247
+ runtime = self.runtime.rknn_runtime
1248
+ lib = runtime.lib
1249
+ ctx = runtime.context
1250
+
1251
+ # 根据平台位宽设置 rknn_context 的 ctypes 类型
1252
+ ContextCType = ctypes.c_uint64 if ctypes.sizeof(ctypes.c_void_p) == 8 else ctypes.c_uint32
1253
+ # 设置 rknn_register_custom_ops(ctx, op_ptr, num) 签名。第二参数按 void* 传递,避免结构体布局不一致
1254
+ try:
1255
+ lib.rknn_register_custom_ops.argtypes = [ContextCType, ctypes.c_void_p, ctypes.c_uint32]
1256
+ lib.rknn_register_custom_ops.restype = ctypes.c_int
1257
+ except Exception:
1258
+ pass
1259
+
1260
+ # 加载插件
1261
+ try:
1262
+ handle = ctypes.CDLL(so_path)
1263
+ except Exception as e:
1264
+ logger.error(f"dlopen 失败: {so_path}, err={e}")
1265
+ return False
1266
+
1267
+ # 获取 get_rknn_custom_op 符号
1268
+ try:
1269
+ get_sym = getattr(handle, "get_rknn_custom_op")
1270
+ except AttributeError:
1271
+ logger.error(f"插件缺少符号 get_rknn_custom_op: {so_path}")
1272
+ return False
1273
+
1274
+ # 返回类型直接使用 void*,避免 Python 解析第三方结构体
1275
+ try:
1276
+ get_sym.argtypes = []
1277
+ except Exception:
1278
+ pass
1279
+ get_sym.restype = ctypes.c_void_p
1280
+
1281
+ op_void_ptr = get_sym()
1282
+ if not op_void_ptr:
1283
+ logger.error(f"get_rknn_custom_op 返回空指针: {so_path}")
1284
+ return False
1285
+
1286
+ # 直接使用原生指针注册(零拷贝)
1287
+ ctx_val = ContextCType(runtime.context)
1288
+ ret = lib.rknn_register_custom_ops(ctx_val, ctypes.c_void_p(op_void_ptr), 1)
1289
+ if ret != 0:
1290
+ logger.error(f"rknn_register_custom_ops 失败, ret={ret}, so={so_path} (可能是误报, 继续执行...)")
1291
+ # return False
1292
+
1293
+ # 保留句柄,避免被垃圾回收卸载
1294
+ if not hasattr(self, "_plugin_handles"):
1295
+ self._plugin_handles = []
1296
+ self._plugin_handles.append(handle)
1297
+ logger.info(f"成功注册插件自定义算子: {so_path}")
1298
+ return True
1299
+
1300
+ def register_plugin_ops(self, plugin_paths: List[str]) -> int:
1301
+ """按给定路径列表注册插件库中的自定义算子。返回成功数量。"""
1302
+ if not plugin_paths:
1303
+ return 0
1304
+ success = 0
1305
+ for path in plugin_paths:
1306
+ try:
1307
+ if self._load_and_register_plugin_op(path):
1308
+ success += 1
1309
+ except Exception as e:
1310
+ logger.error(f"注册插件失败: {path}, err={e}")
1311
+ return success
1312
+
1313
+ # 对外API:注册单个自定义算子插件库
1314
+ def register_custom_op_lib(self, path: str) -> bool:
1315
+ return self._load_and_register_plugin_op(path)
1316
+
1317
+ # 对外API:扫描并注册 Linux 系统目录下所有插件库(Android 不处理)
1318
+ def register_system_custom_op_lib(self) -> int:
1319
+ if os.name != 'posix':
1320
+ return 0
1321
+ # 仅 Linux:RKNN 官方默认目录
1322
+ system_dir = "/usr/lib/rknpu/op_plugins/"
1323
+ if not os.path.isdir(system_dir):
1324
+ return 0
1325
+ try:
1326
+ entries = os.listdir(system_dir)
1327
+ except Exception:
1328
+ return 0
1329
+ so_list = []
1330
+ for name in entries:
1331
+ # 官方要求文件名以 librkcst_ 开头
1332
+ if name.startswith("librkcst_") and name.endswith('.so'):
1333
+ so_list.append(os.path.join(system_dir, name))
1334
+ return self.register_plugin_ops(so_list)
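+ # 插件注册示意(假设性示例, 插件路径为假设):
+ #   export ZTU_MODELRT_RKNN2_REG_CUSTOM_OP_LIB=/usr/lib/rknpu/op_plugins/librkcst_myop.so
+ #   或: sess.register_custom_op_lib("/path/to/librkcst_myop.so")
+ #   默认还会扫描系统目录 /usr/lib/rknpu/op_plugins/ 下以 librkcst_ 开头的 .so 并自动注册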