happyme531 committed
Commit 7fc4eb4 · verified · 1 Parent(s): 21d4ecf

Upload 11 files
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ language_model_w8a8.rkllm filter=lfs diff=lfs merge=lfs -text
+ librkllmrt.so filter=lfs diff=lfs merge=lfs -text
+ test.jpg filter=lfs diff=lfs merge=lfs -text
+ vision_encoder.rknn filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,333 @@
- ---
- license: agpl-3.0
- ---
---
base_model:
- OpenGVLab/InternVL3_5-2B-HF
tags:
- rknn
- rkllm
- internvl
---
# InternVL3_5-2B-RKLLM

## (English README see below)

在RK3588上运行强大的InternVL3.5-2B视觉大模型!

- 推理速度(RK3588): 视觉编码器 2.1s(三核并行) + LLM 填充 1s (265 tokens / 261 tps) + 解码 12.1 tps
- 内存占用(RK3588, 上下文长度1024): 3.9GB

## 使用方法

1. 克隆或者下载此仓库到本地. 模型较大, 请确保有足够的磁盘空间.

2. 开发板的RKNPU2内核驱动版本必须>=0.9.6才能运行这么大的模型.
使用root权限运行以下命令检查驱动版本:
```bash
> cat /sys/kernel/debug/rknpu/version
RKNPU driver: v0.9.8
```
如果版本过低, 请更新驱动. 你可能需要更新内核, 或查找官方文档以获取帮助.

3. 安装依赖

```bash
pip install "numpy<2" opencv-python rknn-toolkit-lite2
```

4. 运行

```bash
python ./run_rkllm.py ./test.jpg ./vision_encoder.rknn ./language_model_w8a8.rkllm 512 1024 3
```

参数说明:
- `512`: max_new_tokens, 最大生成token数.
- `1024`: max_context_len, 最大上下文长度.
- `3`: npu_core_num, 使用的NPU核心数.

如果实测性能不理想, 可以调整CPU调度器让CPU始终运行在最高频率, 并把推理程序绑定到大核(`taskset -c 4-7 python ...`)

test.jpg:
![test.jpg](./test.jpg)

```
Initializing ONNX Runtime for vision encoder...
I rknn-toolkit2 version: 2.3.2
I target set by user is: rk3588
Vision encoder loaded successfully.
ONNX Input: pixel_values, ONNX Output: projected_features
Initializing RKLLM Runtime...
I rkllm: rkllm-runtime version: 1.2.2, rknpu driver version: 0.9.8, platform: RK3588
I rkllm: loading rkllm model from ./language_model_w8a8.rkllm
I rkllm: rkllm-toolkit version: 1.2.2, max_context_limit: 4096, npu_core_num: 3, target_platform: RK3588, model_dtype: W8A8
I rkllm: Enabled cpus: [4, 5, 6, 7]
I rkllm: Enabled cpus num: 4
RKLLM initialized successfully.
Preprocessing image...
Running vision encoder...
视觉编码器推理耗时: 2.0876 秒
Image encoded successfully.

**********************可输入以下问题对应序号获取回答/或自定义输入********************

[0] <image>What is in the image?
[1] <image>这张图片中有什么?

*************************************************************************


user: 0
<image>What is in the image?
robot: n_image_tokens: 256


This image depicts a cozy bedroom with a large window, several pieces of furniture, and various decorative items. The room has a vintage feel due to the wallpaper pattern and the wooden furniture.

The bed occupies the left side of the image, covered with a blue comforter or quilt. Next to the bed is a dresser with a round mirror above it. On top of the dresser are several small objects, including what appears to be a water bottle and some decorative items like plants.

In front of the window on the right side of the image, there is a chair with a checkered cushion. Behind this chair, there is a bookshelf filled with books and various other items, such as baskets and possibly some knick-knacks. The bookshelf has multiple levels, each holding an assortment of books and decorative objects.

The window allows natural light to enter the room, illuminating the space and highlighting the greenery outside. There are also potted plants placed around the room, adding a touch of nature and freshness to the interior decor.

Overall, this bedroom exudes a sense of comfort and personal style, with elements that suggest it is used regularly by someone who values both aesthetics and functionality in their living space.


I rkllm: --------------------------------------------------------------------------------------
I rkllm: Model init time (ms)  4314.30
I rkllm: --------------------------------------------------------------------------------------
I rkllm: Stage      Total Time (ms)   Tokens   Time per Token (ms)   Tokens per Second
I rkllm: --------------------------------------------------------------------------------------
I rkllm: Prefill    1013.32           265      3.82                  261.52
I rkllm: Generate   20155.65          244      82.61                 12.11
I rkllm: --------------------------------------------------------------------------------------
I rkllm: Peak Memory Usage (GB)
I rkllm: 3.45
I rkllm: --------------------------------------------------------------------------------------

user: 1
<image>这张图片中有什么?
robot: n_image_tokens: 256


这是一间温馨的卧室,房间内有一扇大窗户、几件家具和各种装饰物品。房间因壁纸图案和木质家具而显得复古。

床位于图像左侧,覆盖着蓝色被套或毯子。床旁边是一个带有圆形镜子的抽屉柜。在抽屉柜上摆放着一些小物件,包括水瓶和一些装饰品,如植物。

窗户右侧前方有一把带格子坐垫的椅子。椅子后面是一排书架,上面摆满了书籍和其他物品,如篮子和可能的一些小饰品。书架有多层,每层都放着各种书籍和装饰物。

窗外可以看到绿树,自然光透过窗户照进房间,照亮了空间,并突出了外面的绿色植物。房间里还摆放了一些盆栽植物,为室内增添了自然的气息和清新感。

总体而言,这间卧室给人一种舒适和个性的感觉,表明它经常被居住者使用,居住者重视生活空间中的美学和功能性。

I rkllm: --------------------------------------------------------------------------------------
I rkllm: Stage      Total Time (ms)   Tokens   Time per Token (ms)   Tokens per Second
I rkllm: --------------------------------------------------------------------------------------
I rkllm: Prefill    1287.65           264      4.88                  205.03
I rkllm: Generate   19852.10          204      97.31                 10.28
I rkllm: --------------------------------------------------------------------------------------
I rkllm: Peak Memory Usage (GB)
I rkllm: 3.45
I rkllm: --------------------------------------------------------------------------------------

user: ^C
Exiting...
Releasing resources...
RKLLM instance destroyed.
```

## 模型转换

#### 准备工作

1. 安装rknn-toolkit2以及rkllm-toolkit:
```bash
pip install -U rknn-toolkit2
```
rkllm-toolkit需要在这里手动下载: https://github.com/airockchip/rknn-llm/tree/main/rkllm-toolkit

2. 下载此仓库到本地, 但不需要下载`.rkllm`和`.rknn`结尾的模型文件.
3. 下载InternVL3.5-2B的huggingface模型仓库到本地. ( https://huggingface.co/OpenGVLab/InternVL3_5-2B-HF )

#### 转换LLM

将`rkllm-convert.py`拷贝到InternVL3_5-2B-HF的模型文件夹中,执行:
```bash
python rkllm-convert.py
```
默认是w8a8量化的,你可以自行打开脚本修改量化方式等。

#### 转换视觉编码器

1. 导出ONNX

将`export_vision_onnx.py`拷贝到InternVL3_5-2B-HF的模型文件夹根目录中,然后**在该根目录**下执行:
```bash
python ./export_vision_onnx.py
```
视觉编码器会导出到`vision_encoder.onnx`.

2. 转换rknn

```bash
python ./convert_vision_encoder.py
```

## 已知问题

- 由于RKLLM的多模态输入的限制, 在整个对话中只能加载一张图片.
- 没有实现多轮对话.
- RKLLM的w8a8量化貌似存在不小的精度损失.
- 没有实现原模型中的高清图像分块输入与视频输入功能. 原因是我懒得做了,以后可以考虑加上.

## 参考

- [OpenGVLab/InternVL3_5-2B-HF](https://huggingface.co/OpenGVLab/InternVL3_5-2B-HF)

----

# English README

Run the powerful InternVL3.5-2B large vision model on RK3588!

- Inference Speed (RK3588): Vision Encoder 2.1s (3-core parallel) + LLM Prefill 1s (265 tokens / 261 tps) + Decode 12.1 tps
- Memory Usage (RK3588, context length 1024): 3.9GB

## How to Use

1. Clone or download this repository locally. The model is large, so ensure you have enough disk space.

2. The RKNPU2 kernel driver version on your development board must be >=0.9.6 to run a model this large. Run the following command with root privileges to check the driver version:
```bash
> cat /sys/kernel/debug/rknpu/version
RKNPU driver: v0.9.8
```
If the version is too low, please update the driver. You may need to update the kernel or refer to the official documentation for help.
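As an illustration (this helper is not part of the repo's scripts), the driver version string can also be checked programmatically before attempting to load the model:

```python
# Parse "RKNPU driver: vX.Y.Z" and compare against the 0.9.6 minimum.
def parse_version(text: str) -> tuple:
    # take everything after the last 'v', e.g. "0.9.8" -> (0, 9, 8)
    ver = text.strip().rsplit("v", 1)[-1]
    return tuple(int(p) for p in ver.split("."))

def driver_ok(text: str, minimum=(0, 9, 6)) -> bool:
    return parse_version(text) >= minimum

print(driver_ok("RKNPU driver: v0.9.8"))  # True
```

On the board itself, the input would come from reading `/sys/kernel/debug/rknpu/version` as root.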

3. Install dependencies:

```bash
pip install "numpy<2" opencv-python rknn-toolkit-lite2
```

4. Run:

```bash
python ./run_rkllm.py ./test.jpg ./vision_encoder.rknn ./language_model_w8a8.rkllm 512 1024 3
```

Parameter description:
- `512`: `max_new_tokens`, the maximum number of tokens to generate.
- `1024`: `max_context_len`, the maximum context length.
- `3`: `npu_core_num`, the number of NPU cores to use.
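The positional-argument layout of the command above can be sketched as follows; note this is a hypothetical reconstruction for illustration, and the real `run_rkllm.py` may parse its arguments differently:

```python
import argparse

# Hypothetical mirror of the CLI shown above: three paths followed by three integers.
parser = argparse.ArgumentParser()
parser.add_argument("image_path")
parser.add_argument("vision_encoder_path")
parser.add_argument("rkllm_model_path")
parser.add_argument("max_new_tokens", type=int)
parser.add_argument("max_context_len", type=int)
parser.add_argument("npu_core_num", type=int)

args = parser.parse_args([
    "./test.jpg", "./vision_encoder.rknn", "./language_model_w8a8.rkllm",
    "512", "1024", "3",
])
print(args.max_new_tokens, args.max_context_len, args.npu_core_num)  # 512 1024 3
```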

If the performance is not ideal, you can adjust the CPU scheduler to keep the CPU at its highest frequency and bind the inference program to the big cores (`taskset -c 4-7 python ...`).
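The same pinning can be done from inside Python with `os.sched_setaffinity` (Linux only). A small sketch, not part of `run_rkllm.py`; it pins to the upper half of whatever CPUs are available instead of hard-coding cores 4-7:

```python
import os

# On RK3588 the big cores are 4-7, i.e. the upper half of the 8-core system;
# this mirrors what `taskset -c 4-7` does from the shell.
available = sorted(os.sched_getaffinity(0))
big_cores = set(available[len(available) // 2:]) or set(available)
os.sched_setaffinity(0, big_cores)  # restrict this process to the chosen cores
print(sorted(os.sched_getaffinity(0)))
```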

Example with `test.jpg`:
![test.jpg](./test.jpg)

```
Initializing ONNX Runtime for vision encoder...
I rknn-toolkit2 version: 2.3.2
I target set by user is: rk3588
Vision encoder loaded successfully.
ONNX Input: pixel_values, ONNX Output: projected_features
Initializing RKLLM Runtime...
I rkllm: rkllm-runtime version: 1.2.2, rknpu driver version: 0.9.8, platform: RK3588
I rkllm: loading rkllm model from ./language_model_w8a8.rkllm
I rkllm: rkllm-toolkit version: 1.2.2, max_context_limit: 4096, npu_core_num: 3, target_platform: RK3588, model_dtype: W8A8
I rkllm: Enabled cpus: [4, 5, 6, 7]
I rkllm: Enabled cpus num: 4
RKLLM initialized successfully.
Preprocessing image...
Running vision encoder...
视觉编码器推理耗时: 2.0876 秒
Image encoded successfully.

**********************可输入以下问题对应序号获取回答/或自定义输入********************

[0] <image>What is in the image?
[1] <image>这张图片中有什么?

*************************************************************************


user: 0
<image>What is in the image?
robot: n_image_tokens: 256


This image depicts a cozy bedroom with a large window, several pieces of furniture, and various decorative items. The room has a vintage feel due to the wallpaper pattern and the wooden furniture.

The bed occupies the left side of the image, covered with a blue comforter or quilt. Next to the bed is a dresser with a round mirror above it. On top of the dresser are several small objects, including what appears to be a water bottle and some decorative items like plants.

In front of the window on the right side of the image, there is a chair with a checkered cushion. Behind this chair, there is a bookshelf filled with books and various other items, such as baskets and possibly some knick-knacks. The bookshelf has multiple levels, each holding an assortment of books and decorative objects.

The window allows natural light to enter the room, illuminating the space and highlighting the greenery outside. There are also potted plants placed around the room, adding a touch of nature and freshness to the interior decor.

Overall, this bedroom exudes a sense of comfort and personal style, with elements that suggest it is used regularly by someone who values both aesthetics and functionality in their living space.


I rkllm: --------------------------------------------------------------------------------------
I rkllm: Model init time (ms)  4314.30
I rkllm: --------------------------------------------------------------------------------------
I rkllm: Stage      Total Time (ms)   Tokens   Time per Token (ms)   Tokens per Second
I rkllm: --------------------------------------------------------------------------------------
I rkllm: Prefill    1013.32           265      3.82                  261.52
I rkllm: Generate   20155.65          244      82.61                 12.11
I rkllm: --------------------------------------------------------------------------------------
I rkllm: Peak Memory Usage (GB)
I rkllm: 3.45
I rkllm: --------------------------------------------------------------------------------------

user: ^C
Exiting...
Releasing resources...
RKLLM instance destroyed.
```
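As a sanity check on the performance table above, the tokens-per-second figures follow directly from the total time and token count:

```python
# Reproduce the throughput numbers RKLLM reports (tokens / seconds).
def tokens_per_second(total_ms: float, tokens: int) -> float:
    return tokens / (total_ms / 1000.0)

print(round(tokens_per_second(1013.32, 265), 2))   # prefill  -> 261.52
print(round(tokens_per_second(20155.65, 244), 2))  # generate -> 12.11
```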

## Model Conversion

#### Prerequisites

1. Install `rknn-toolkit2` and `rkllm-toolkit`:
```bash
pip install -U rknn-toolkit2
```
`rkllm-toolkit` needs to be downloaded manually from here: https://github.com/airockchip/rknn-llm/tree/main/rkllm-toolkit

2. Download this repository locally, but you don't need the `.rkllm` and `.rknn` model files.
3. Download the InternVL3.5-2B huggingface model repository locally. ( https://huggingface.co/OpenGVLab/InternVL3_5-2B-HF )

#### Convert LLM

Copy `rkllm-convert.py` to the InternVL3_5-2B-HF model folder and run:
```bash
python rkllm-convert.py
```
The default quantization is w8a8; you can edit the script to change the quantization method and other settings.

#### Convert Vision Encoder

1. Export ONNX

Copy `export_vision_onnx.py` to the root directory of the InternVL3_5-2B-HF model folder, and then execute it **in that root directory**:
```bash
python ./export_vision_onnx.py
```
The vision encoder will be exported to `vision_encoder.onnx`.

2. Convert to RKNN

```bash
python ./convert_vision_encoder.py
```

## Known Issues

- Due to limitations in RKLLM's multimodal input, only one image can be loaded per conversation.
- Multi-turn conversation is not implemented.
- RKLLM's w8a8 quantization appears to introduce a noticeable accuracy loss.
- The high-resolution image tiling and video input features of the original model are not implemented; I haven't gotten around to them and may add them later.

## References

- [OpenGVLab/InternVL3_5-2B-HF](https://huggingface.co/OpenGVLab/InternVL3_5-2B-HF)
convert_vision_encoder.py ADDED
@@ -0,0 +1,52 @@
#!/usr/bin/env python3
# ztu_somemodelruntime_rknn2: vision_encoder

from rknn.api import RKNN

def main():
    # Create an RKNN instance
    rknn = RKNN(verbose=True)

    # Input ONNX model path
    ONNX_MODEL = "vision_encoder.onnx"
    # Output RKNN model path
    RKNN_MODEL = "vision_encoder.rknn"

    # Configure conversion parameters
    print("--> Config model")
    ret = rknn.config(target_platform="rk3588",
                      dynamic_input=None)
    if ret != 0:
        print('Config model failed!')
        exit(ret)

    # Load the ONNX model
    print("--> Loading model")
    ret = rknn.load_onnx(model=ONNX_MODEL,
                         inputs=['pixel_values'],
                         input_size_list=[[1, 3, 448, 448]])
    if ret != 0:
        print('Load model failed!')
        exit(ret)

    # Build the model
    print("--> Building model")
    ret = rknn.build(do_quantization=False)
    if ret != 0:
        print('Build model failed!')
        exit(ret)

    # Export the RKNN model
    print("--> Export RKNN model")
    ret = rknn.export_rknn(RKNN_MODEL)
    if ret != 0:
        print('Export RKNN model failed!')
        exit(ret)

    print(f'Done! The converted RKNN model has been saved to: {RKNN_MODEL}')
    rknn.release()

if __name__ == '__main__':
    main()
export_vision_onnx.py ADDED
@@ -0,0 +1,203 @@
import numpy as np
import torch
import torch.nn as nn
from transformers import AutoModel
from PIL import Image
import torchvision.transforms as T
from torchvision.transforms import InterpolationMode
from transformers.modeling_utils import PreTrainedModel

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio

def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # enumerate the candidate tiling grids
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images

def load_image(image_file, input_size=448, max_num=12):
    image = Image.open(image_file).convert('RGB')
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values

# Paths for the local model and outputs
path = '.'
save_path = 'vision_encoder.onnx'
image_file = 'test.jpg'

def export_vision_InternVL(model_path: str, save_path: str):
    """
    Export the vision encoder and projector of the InternVL3_5-2B-HF model to ONNX format
    """
    # Use float32 as the default dtype
    torch.set_default_dtype(torch.float32)

    vl_gpt = AutoModel.from_pretrained(model_path, torch_dtype=torch.float32, trust_remote_code=True)

    # Move model to CPU and convert to float32
    vl_gpt = vl_gpt.cpu().eval().float()  # make sure the model is float32

    # Create a wrapper class for vision encoder + projector
    class VisionWrapper(nn.Module):
        def __init__(self, model: PreTrainedModel):
            super().__init__()
            self.vision_model = model

        def forward(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
            # Delegate to the built-in helper so we stay consistent with Transformers' implementation.
            return self.vision_model.get_image_features(pixel_values=pixel_values)

    # Create wrapper instance and convert to float32
    vision_wrapper = VisionWrapper(vl_gpt)
    vision_wrapper.eval().float()  # make sure the wrapper is float32 as well

    # Create dummy input with float32
    batch_size = 1
    num_channels = 3
    height = 448  # InternVL default image size
    width = 448
    # dummy_input = load_image(image_file=image_file, max_num=12).to(torch.float32).cpu()
    dummy_input = torch.randn(batch_size, num_channels, height, width, dtype=torch.float32)
    # Export to ONNX with a recent opset version
    torch.onnx.export(
        vision_wrapper,
        dummy_input,
        save_path,
        export_params=True,
        opset_version=17,  # a recent opset is required for scaled_dot_product_attention
        do_constant_folding=True,
        input_names=['pixel_values'],
        output_names=['projected_features'],
        dynamic_axes={
            'pixel_values': {0: 'batch_size'},
            'projected_features': {0: 'batch_size'}
        },
        # extra export options
        # operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
        # training=torch.onnx.TrainingMode.EVAL,
        dynamo=True,
        verbose=False
    )

    print(f"Successfully exported vision components to {save_path}")

    # Verify the exported model
    import onnxruntime

    # Create inference session
    ort_session = onnxruntime.InferenceSession(save_path)

    # Run inference with dummy input
    ort_inputs = {
        'pixel_values': dummy_input.numpy()
    }
    ort_outputs = ort_session.run(None, ort_inputs)

    # Compare with PyTorch output
    torch_output = vision_wrapper(dummy_input)

    # Check numerical accuracy with relaxed tolerances
    np.testing.assert_allclose(
        torch_output.detach().numpy(),
        ort_outputs[0],
        rtol=1e-1,  # relaxed relative tolerance
        atol=1e-2   # relaxed absolute tolerance
    )

    print("ONNX model verification successful!")

    # Print some validation statistics
    torch_output_np = torch_output.detach().numpy()
    onnx_output_np = ort_outputs[0]

    abs_diff = np.abs(torch_output_np - onnx_output_np)
    rel_diff = np.abs((torch_output_np - onnx_output_np) / (torch_output_np + 1e-7))

    print("\nValidation Statistics:")
    print(f"Max absolute difference: {np.max(abs_diff):.6f}")
    print(f"Mean absolute difference: {np.mean(abs_diff):.6f}")
    print(f"Max relative difference: {np.max(rel_diff):.6f}")
    print(f"Mean relative difference: {np.mean(rel_diff):.6f}")

if __name__ == "__main__":
    try:
        import onnx
        try:
            onnx_version = onnx.__version__
        except AttributeError:
            try:
                onnx_version = onnx.version.version
            except AttributeError:
                onnx_version = "Unknown"
        print(f"ONNX version: {onnx_version}")
    except ImportError:
        print("ONNX not installed")

    import onnxruntime
    print(f"ONNX Runtime version: {onnxruntime.__version__}")

    export_vision_InternVL(path, save_path)
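The tile-count behaviour of `dynamic_preprocess` above can be sanity-checked without PIL. The following standalone replica of its ratio search (a sketch for illustration, not part of the repo) computes how many 448x448 tiles a given image size would produce:

```python
# Standalone replica of the tiling logic in dynamic_preprocess: count the
# number of image_size x image_size tiles (plus the optional thumbnail).
def count_tiles(width, height, min_num=1, max_num=12, image_size=448, use_thumbnail=True):
    aspect_ratio = width / height
    target_ratios = sorted(
        {(i, j) for n in range(min_num, max_num + 1)
                for i in range(1, n + 1) for j in range(1, n + 1)
                if min_num <= i * j <= max_num},
        key=lambda x: x[0] * x[1])
    best_diff, best = float("inf"), (1, 1)
    area = width * height
    for ratio in target_ratios:
        diff = abs(aspect_ratio - ratio[0] / ratio[1])
        if diff < best_diff:
            best_diff, best = diff, ratio
        elif diff == best_diff and area > 0.5 * image_size ** 2 * ratio[0] * ratio[1]:
            best = ratio
    blocks = best[0] * best[1]
    return blocks + 1 if use_thumbnail and blocks != 1 else blocks

print(count_tiles(800, 600))   # 4:3 image -> 4x3 grid plus thumbnail
print(count_tiles(448, 448))   # square image -> single tile, no thumbnail
```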
language_model_w8a8.rkllm ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8ea6dbafda740b717233228a91cbb2377d0905a8b00edba9e17489da65e9834e
size 2375017292
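Note that the block above is a Git LFS pointer, not the model weights themselves; its `key value` lines can be parsed trivially, e.g.:

```python
# Parse the Git LFS pointer text shown above into a dict.
pointer_text = """\
version https://git-lfs.github.com/spec/v1
oid sha256:8ea6dbafda740b717233228a91cbb2377d0905a8b00edba9e17489da65e9834e
size 2375017292
"""

fields = dict(line.split(" ", 1) for line in pointer_text.splitlines())
print(fields["size"])  # size in bytes of the real .rkllm file (~2.4 GB)
```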
librkllmrt.so ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:39d01912e67027de32c527be04684bf813e2a49c2d09ab8f6bcf47b34a43789d
size 7486400
rkllm-convert.py ADDED
@@ -0,0 +1,141 @@
import argparse
import shutil
from pathlib import Path
from typing import Dict

import torch
from safetensors.torch import load_file
from transformers import AutoConfig, AutoModelForCausalLM

from rkllm.api import RKLLM

TOKENIZER_FILES = [
    "tokenizer.json",
    "tokenizer_config.json",
    "special_tokens_map.json",
    "added_tokens.json",
    "vocab.json",
    "merges.txt",
    "chat_template.jinja",
]


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--source",
        type=Path,
        default=".",
        help="Path to the InternVL (HF-format) checkpoint directory, e.g. /path/to/InternVL3_5-2B-HF",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default="llm/",
        help="Directory where the extracted Qwen3 checkpoint will be written",
    )
    parser.add_argument(
        "--safe-serialization",
        action="store_true",
        default=True,
        help="Save the exported model using safetensors instead of PyTorch binaries.",
    )
    return parser.parse_args()


def extract_text_state_dict(full_state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    prefix = "language_model.model."
    lm_head_prefix = "language_model.lm_head."
    text_state: Dict[str, torch.Tensor] = {}

    for key, tensor in full_state.items():
        if key.startswith(prefix):
            text_key = "model." + key[len(prefix):]
        elif key.startswith(lm_head_prefix):
            text_key = "lm_head." + key[len(lm_head_prefix):]
        else:
            continue
        text_state[text_key] = tensor

    if not text_state:
        raise ValueError("Did not find any language_model weights in checkpoint; is this an InternVL model?")

    return text_state


def copy_tokenizer_files(source_dir: Path, output_dir: Path) -> None:
    for filename in TOKENIZER_FILES:
        src = source_dir / filename
        if src.exists():
            dst = output_dir / filename
            shutil.copyfile(src, dst)


def main() -> None:
    args = parse_args()
    source_dir = args.source.expanduser().resolve()
    output_dir = args.output.expanduser().resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    config = AutoConfig.from_pretrained(source_dir, trust_remote_code=True)
    text_config = config.text_config

    weights_path = source_dir / "model.safetensors"
    if not weights_path.exists():
        raise FileNotFoundError(f"Could not find {weights_path}; expected a safetensors checkpoint")

    all_weights = load_file(weights_path)
    text_state = extract_text_state_dict(all_weights)

    sample_tensor = next(iter(text_state.values()))
    target_dtype = sample_tensor.dtype

    text_model = AutoModelForCausalLM.from_config(text_config)
    text_model = text_model.to(dtype=target_dtype, device=torch.device("cpu"))
    missing, unexpected = text_model.load_state_dict(text_state, strict=False)
    if missing or unexpected:
        raise RuntimeError(
            "State dict mismatch when loading text weights: "
            f"missing={missing}, unexpected={unexpected}"
        )

    text_config.save_pretrained(output_dir)
    text_model.generation_config.save_pretrained(output_dir)
    text_model.save_pretrained(output_dir, safe_serialization=args.safe_serialization)

    copy_tokenizer_files(source_dir, output_dir)
    print(f"Exported Qwen3 model saved to {output_dir}")

    modelpath = output_dir
    llm = RKLLM()

    ret = llm.load_huggingface(model=modelpath, model_lora=None, device='cpu')
    if ret != 0:
        print('Load model failed!')
        exit(ret)

    qparams = None
    ret = llm.build(do_quantization=True, optimization_level=1, quantized_dtype='w8a8',
                    quantized_algorithm='normal', target_platform='rk3588', num_npu_core=3, extra_qparams=qparams)

    if ret != 0:
        print('Build model failed!')
        exit(ret)

    # Export the rkllm model
    ret = llm.export_rkllm("./language_model_w8a8.rkllm")
    if ret != 0:
        print('Export model failed!')
        exit(ret)


if __name__ == "__main__":
    main()
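The key remapping performed by `extract_text_state_dict` above can be demonstrated on dummy entries (a standalone sketch; the values stand in for tensors, since only the key handling matters):

```python
# The same prefix rewriting used by extract_text_state_dict, shown on dummy keys.
def remap_key(key):
    if key.startswith("language_model.model."):
        return "model." + key[len("language_model.model."):]
    if key.startswith("language_model.lm_head."):
        return "lm_head." + key[len("language_model.lm_head."):]
    return None  # non-LLM weight (e.g. vision tower): skipped

state = {
    "language_model.model.layers.0.self_attn.q_proj.weight": 1,
    "language_model.lm_head.weight": 2,
    "vision_tower.embeddings.patch_embedding.weight": 3,
}
text_state = {remap_key(k): v for k, v in state.items() if remap_key(k)}
print(sorted(text_state))  # ['lm_head.weight', 'model.layers.0.self_attn.q_proj.weight']
```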
rkllm_binding.py ADDED
@@ -0,0 +1,873 @@
import ctypes
import enum
import os

# Define constants from the header
CPU0 = (1 << 0)  # 0x01
CPU1 = (1 << 1)  # 0x02
CPU2 = (1 << 2)  # 0x04
CPU3 = (1 << 3)  # 0x08
CPU4 = (1 << 4)  # 0x10
CPU5 = (1 << 5)  # 0x20
CPU6 = (1 << 6)  # 0x40
CPU7 = (1 << 7)  # 0x80

# --- Enums ---
class LLMCallState(enum.IntEnum):
    RKLLM_RUN_NORMAL = 0
    RKLLM_RUN_WAITING = 1
    RKLLM_RUN_FINISH = 2
    RKLLM_RUN_ERROR = 3

class RKLLMInputType(enum.IntEnum):
    RKLLM_INPUT_PROMPT = 0
    RKLLM_INPUT_TOKEN = 1
    RKLLM_INPUT_EMBED = 2
    RKLLM_INPUT_MULTIMODAL = 3

class RKLLMInferMode(enum.IntEnum):
    RKLLM_INFER_GENERATE = 0
    RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1
    RKLLM_INFER_GET_LOGITS = 2

# --- Structures ---
class RKLLMExtendParam(ctypes.Structure):
    base_domain_id: ctypes.c_int32
    embed_flash: ctypes.c_int8
    enabled_cpus_num: ctypes.c_int8
    enabled_cpus_mask: ctypes.c_uint32
    n_batch: ctypes.c_uint8
    use_cross_attn: ctypes.c_int8
    reserved: ctypes.c_uint8 * 104

    _fields_ = [
        ("base_domain_id", ctypes.c_int32),      # Base domain ID
        ("embed_flash", ctypes.c_int8),          # Whether to look up word embeddings from flash (1 enable, 0 disable)
        ("enabled_cpus_num", ctypes.c_int8),     # Number of CPUs enabled for inference
        ("enabled_cpus_mask", ctypes.c_uint32),  # Bitmask indicating which CPUs are enabled
        ("n_batch", ctypes.c_uint8),             # Input samples processed concurrently in one forward pass; set > 1 to enable batch inference, default 1
        ("use_cross_attn", ctypes.c_int8),       # Whether cross-attention is enabled (non-zero enable, 0 disable)
        ("reserved", ctypes.c_uint8 * 104)       # Reserved field
    ]

class RKLLMParam(ctypes.Structure):
    model_path: ctypes.c_char_p
    max_context_len: ctypes.c_int32
    max_new_tokens: ctypes.c_int32
    top_k: ctypes.c_int32
    n_keep: ctypes.c_int32
    top_p: ctypes.c_float
    temperature: ctypes.c_float
    repeat_penalty: ctypes.c_float
    frequency_penalty: ctypes.c_float
    presence_penalty: ctypes.c_float
    mirostat: ctypes.c_int32
    mirostat_tau: ctypes.c_float
    mirostat_eta: ctypes.c_float
    skip_special_token: ctypes.c_bool
    is_async: ctypes.c_bool
    img_start: ctypes.c_char_p
    img_end: ctypes.c_char_p
    img_content: ctypes.c_char_p
    extend_param: RKLLMExtendParam

    _fields_ = [
        ("model_path", ctypes.c_char_p),        # Path to the model file
        ("max_context_len", ctypes.c_int32),    # Maximum number of tokens in the context window
        ("max_new_tokens", ctypes.c_int32),     # Maximum number of new tokens to generate
        ("top_k", ctypes.c_int32),              # Top-K sampling parameter
        ("n_keep", ctypes.c_int32),             # Number of KV-cache entries kept when the context window slides
        ("top_p", ctypes.c_float),              # Top-P (nucleus) sampling parameter
        ("temperature", ctypes.c_float),        # Sampling temperature; controls randomness of token selection
        ("repeat_penalty", ctypes.c_float),     # Penalty for repeated tokens
        ("frequency_penalty", ctypes.c_float),  # Penalty for frequent tokens
        ("presence_penalty", ctypes.c_float),   # Penalty for tokens already present in the input
        ("mirostat", ctypes.c_int32),           # Mirostat sampling strategy flag (0 disables it)
        ("mirostat_tau", ctypes.c_float),       # Mirostat sampling Tau parameter
        ("mirostat_eta", ctypes.c_float),       # Mirostat sampling Eta parameter
        ("skip_special_token", ctypes.c_bool),  # Whether to skip special tokens
        ("is_async", ctypes.c_bool),            # Whether inference runs asynchronously
        ("img_start", ctypes.c_char_p),         # Marker for the start of an image in multimodal input
        ("img_end", ctypes.c_char_p),           # Marker for the end of an image in multimodal input
        ("img_content", ctypes.c_char_p),       # Pointer to the image content
        ("extend_param", RKLLMExtendParam)      # Extended parameters
    ]
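As a quick illustration of how `enabled_cpus_mask` and `enabled_cpus_num` relate to the `CPU0`..`CPU7` bit constants defined at the top of this file, a mask pinning inference to four cores can be built like this. The assumption that CPU4–CPU7 are the RK3588's big (Cortex-A76) cores is conventional, not stated by the library itself:

```python
# Sketch: build a CPU affinity mask from per-core bit flags,
# as expected by RKLLMExtendParam.enabled_cpus_mask.
CPU4 = 1 << 4
CPU5 = 1 << 5
CPU6 = 1 << 6
CPU7 = 1 << 7

# On RK3588 the big cores are commonly CPU4-CPU7 (assumption).
big_core_mask = CPU4 | CPU5 | CPU6 | CPU7
enabled_num = bin(big_core_mask).count("1")  # number of set bits

print(hex(big_core_mask), enabled_num)
```

`big_core_mask` would then be assigned to `enabled_cpus_mask` and `enabled_num` to `enabled_cpus_num`.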

class RKLLMLoraAdapter(ctypes.Structure):
    lora_adapter_path: ctypes.c_char_p
    lora_adapter_name: ctypes.c_char_p
    scale: ctypes.c_float

    _fields_ = [
        ("lora_adapter_path", ctypes.c_char_p),
        ("lora_adapter_name", ctypes.c_char_p),
        ("scale", ctypes.c_float)
    ]

class RKLLMEmbedInput(ctypes.Structure):
    embed: ctypes.POINTER(ctypes.c_float)
    n_tokens: ctypes.c_size_t

    _fields_ = [
        ("embed", ctypes.POINTER(ctypes.c_float)),
        ("n_tokens", ctypes.c_size_t)
    ]

class RKLLMTokenInput(ctypes.Structure):
    input_ids: ctypes.POINTER(ctypes.c_int32)
    n_tokens: ctypes.c_size_t

    _fields_ = [
        ("input_ids", ctypes.POINTER(ctypes.c_int32)),
        ("n_tokens", ctypes.c_size_t)
    ]

class RKLLMMultiModelInput(ctypes.Structure):
    prompt: ctypes.c_char_p
    image_embed: ctypes.POINTER(ctypes.c_float)
    n_image_tokens: ctypes.c_size_t
    n_image: ctypes.c_size_t
    image_width: ctypes.c_size_t
    image_height: ctypes.c_size_t

    _fields_ = [
        ("prompt", ctypes.c_char_p),
        ("image_embed", ctypes.POINTER(ctypes.c_float)),
        ("n_image_tokens", ctypes.c_size_t),
        ("n_image", ctypes.c_size_t),
        ("image_width", ctypes.c_size_t),
        ("image_height", ctypes.c_size_t)
    ]

class RKLLMCrossAttnParam(ctypes.Structure):
    """
    Cross-attention parameter structure.

    Used when performing cross-attention in the decoder. It provides the
    encoder outputs (key/value caches), position indices, and the attention
    mask.

    - encoder_k_cache must be stored in contiguous memory with layout:
      [num_layers][num_tokens][num_kv_heads][head_dim]
    - encoder_v_cache must be stored in contiguous memory with layout:
      [num_layers][num_kv_heads][head_dim][num_tokens]
    """
    encoder_k_cache: ctypes.POINTER(ctypes.c_float)
    encoder_v_cache: ctypes.POINTER(ctypes.c_float)
    encoder_mask: ctypes.POINTER(ctypes.c_float)
    encoder_pos: ctypes.POINTER(ctypes.c_int32)
    num_tokens: ctypes.c_int

    _fields_ = [
        ("encoder_k_cache", ctypes.POINTER(ctypes.c_float)),  # Encoder key cache pointer (size: num_layers * num_tokens * num_kv_heads * head_dim)
        ("encoder_v_cache", ctypes.POINTER(ctypes.c_float)),  # Encoder value cache pointer (size: num_layers * num_kv_heads * head_dim * num_tokens)
        ("encoder_mask", ctypes.POINTER(ctypes.c_float)),     # Encoder attention mask pointer (array of size num_tokens)
        ("encoder_pos", ctypes.POINTER(ctypes.c_int32)),      # Encoder token position pointer (array of size num_tokens)
        ("num_tokens", ctypes.c_int)                          # Number of tokens in the encoder sequence
    ]

class RKLLMPerfStat(ctypes.Structure):
    """
    Performance statistics structure.

    Holds performance statistics for the prefill and generation stages.
    """
    prefill_time_ms: ctypes.c_float
    prefill_tokens: ctypes.c_int
    generate_time_ms: ctypes.c_float
    generate_tokens: ctypes.c_int
    memory_usage_mb: ctypes.c_float

    _fields_ = [
        ("prefill_time_ms", ctypes.c_float),   # Total prefill time (ms)
        ("prefill_tokens", ctypes.c_int),      # Number of tokens processed during prefill
        ("generate_time_ms", ctypes.c_float),  # Total generation time (ms)
        ("generate_tokens", ctypes.c_int),     # Number of tokens generated
        ("memory_usage_mb", ctypes.c_float)    # VmHWM resident memory usage during inference (MB)
    ]
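The perf fields above convert directly into throughput figures. A small helper along these lines (hypothetical, not part of the C API) turns a token count and elapsed milliseconds into tokens per second:

```python
def tokens_per_second(tokens: int, time_ms: float) -> float:
    """Convert a token count and elapsed milliseconds into tokens/s."""
    if time_ms <= 0:
        return 0.0
    return tokens * 1000.0 / time_ms

# 121 generated tokens over 10 s works out to 12.1 tokens/s,
# the same kind of decode figure quoted in this repo's README.
decode_tps = tokens_per_second(121, 10000.0)
```

In a real callback this would be fed `perf.generate_tokens` and `perf.generate_time_ms` from an `RKLLMPerfStat`.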

class _RKLLMInputUnion(ctypes.Union):
    prompt_input: ctypes.c_char_p
    embed_input: RKLLMEmbedInput
    token_input: RKLLMTokenInput
    multimodal_input: RKLLMMultiModelInput

    _fields_ = [
        ("prompt_input", ctypes.c_char_p),
        ("embed_input", RKLLMEmbedInput),
        ("token_input", RKLLMTokenInput),
        ("multimodal_input", RKLLMMultiModelInput)
    ]

class RKLLMInput(ctypes.Structure):
    """
    LLM input structure.

    Represents the different kinds of LLM input through a union.
    """
    role: ctypes.c_char_p
    enable_thinking: ctypes.c_bool
    input_type: ctypes.c_int
    _union_data: _RKLLMInputUnion

    _fields_ = [
        ("role", ctypes.c_char_p),           # Message role: "user" (user input) or "tool" (function result)
        ("enable_thinking", ctypes.c_bool),  # Controls whether "thinking mode" is enabled for Qwen3 models
        ("input_type", ctypes.c_int),        # Enum specifying the input type (prompt, token, embed, multimodal)
        ("_union_data", _RKLLMInputUnion)    # Union payload
    ]

    # Properties to make accessing union members easier
    @property
    def prompt_input(self) -> bytes:  # c_char_p maps to bytes
        if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT:
            return self._union_data.prompt_input
        raise AttributeError("Not a prompt input")

    @prompt_input.setter
    def prompt_input(self, value: bytes):  # c_char_p maps to bytes
        if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT:
            self._union_data.prompt_input = value
        else:
            raise AttributeError("Not a prompt input")

    @property
    def embed_input(self) -> RKLLMEmbedInput:
        if self.input_type == RKLLMInputType.RKLLM_INPUT_EMBED:
            return self._union_data.embed_input
        raise AttributeError("Not an embed input")

    @embed_input.setter
    def embed_input(self, value: RKLLMEmbedInput):
        if self.input_type == RKLLMInputType.RKLLM_INPUT_EMBED:
            self._union_data.embed_input = value
        else:
            raise AttributeError("Not an embed input")

    @property
    def token_input(self) -> RKLLMTokenInput:
        if self.input_type == RKLLMInputType.RKLLM_INPUT_TOKEN:
            return self._union_data.token_input
        raise AttributeError("Not a token input")

    @token_input.setter
    def token_input(self, value: RKLLMTokenInput):
        if self.input_type == RKLLMInputType.RKLLM_INPUT_TOKEN:
            self._union_data.token_input = value
        else:
            raise AttributeError("Not a token input")

    @property
    def multimodal_input(self) -> RKLLMMultiModelInput:
        if self.input_type == RKLLMInputType.RKLLM_INPUT_MULTIMODAL:
            return self._union_data.multimodal_input
        raise AttributeError("Not a multimodal input")

    @multimodal_input.setter
    def multimodal_input(self, value: RKLLMMultiModelInput):
        if self.input_type == RKLLMInputType.RKLLM_INPUT_MULTIMODAL:
            self._union_data.multimodal_input = value
        else:
            raise AttributeError("Not a multimodal input")

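The `_union_data` pattern above relies on how `ctypes.Union` lays out memory: all members share the same storage, so the union occupies only as much space as its largest member, and only one input type can be carried at a time. A minimal standalone demonstration (simplified member types, not the real RKLLM structs):

```python
import ctypes

class Small(ctypes.Structure):
    _fields_ = [("a", ctypes.c_int32)]

class Big(ctypes.Structure):
    _fields_ = [("p", ctypes.POINTER(ctypes.c_float)),
                ("n", ctypes.c_size_t)]

class Payload(ctypes.Union):
    # All members overlap in memory; writing one overwrites the others.
    _fields_ = [("small", Small), ("big", Big)]

# The union is exactly as large as its largest member.
print(ctypes.sizeof(Small), ctypes.sizeof(Big), ctypes.sizeof(Payload))
```

This is why `RKLLMInput` pairs the union with an `input_type` discriminant: the union itself cannot tell you which member is currently valid.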
class RKLLMLoraParam(ctypes.Structure):  # For inference
    lora_adapter_name: ctypes.c_char_p

    _fields_ = [
        ("lora_adapter_name", ctypes.c_char_p)
    ]

class RKLLMPromptCacheParam(ctypes.Structure):  # For inference
    save_prompt_cache: ctypes.c_int  # bool-like
    prompt_cache_path: ctypes.c_char_p

    _fields_ = [
        ("save_prompt_cache", ctypes.c_int),  # bool-like
        ("prompt_cache_path", ctypes.c_char_p)
    ]

class RKLLMInferParam(ctypes.Structure):
    mode: ctypes.c_int
    lora_params: ctypes.POINTER(RKLLMLoraParam)
    prompt_cache_params: ctypes.POINTER(RKLLMPromptCacheParam)
    keep_history: ctypes.c_int  # bool-like

    _fields_ = [
        ("mode", ctypes.c_int),  # RKLLMInferMode enum value, passed as a C int
        ("lora_params", ctypes.POINTER(RKLLMLoraParam)),
        ("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam)),
        ("keep_history", ctypes.c_int)  # bool-like
    ]

class RKLLMResultLastHiddenLayer(ctypes.Structure):
    hidden_states: ctypes.POINTER(ctypes.c_float)
    embd_size: ctypes.c_int
    num_tokens: ctypes.c_int

    _fields_ = [
        ("hidden_states", ctypes.POINTER(ctypes.c_float)),
        ("embd_size", ctypes.c_int),
        ("num_tokens", ctypes.c_int)
    ]

class RKLLMResultLogits(ctypes.Structure):
    logits: ctypes.POINTER(ctypes.c_float)
    vocab_size: ctypes.c_int
    num_tokens: ctypes.c_int

    _fields_ = [
        ("logits", ctypes.POINTER(ctypes.c_float)),
        ("vocab_size", ctypes.c_int),
        ("num_tokens", ctypes.c_int)
    ]

class RKLLMResult(ctypes.Structure):
    """
    LLM inference result structure.

    Represents the result of an LLM inference call: the generated text,
    token ID, hidden states, logits, and performance statistics.
    """
    text: ctypes.c_char_p
    token_id: ctypes.c_int32
    last_hidden_layer: RKLLMResultLastHiddenLayer
    logits: RKLLMResultLogits
    perf: RKLLMPerfStat

    _fields_ = [
        ("text", ctypes.c_char_p),                          # Generated text result
        ("token_id", ctypes.c_int32),                       # Generated token ID
        ("last_hidden_layer", RKLLMResultLastHiddenLayer),  # Hidden states of the last layer (if requested)
        ("logits", RKLLMResultLogits),                      # Logits output by the model
        ("perf", RKLLMPerfStat)                             # Performance statistics (prefill and generation)
    ]

# --- Typedefs ---
LLMHandle = ctypes.c_void_p

# --- Callback Function Type ---
LLMResultCallback = ctypes.CFUNCTYPE(
    ctypes.c_int,                 # Return type: int, indicating the handling status
    ctypes.POINTER(RKLLMResult),  # Pointer to the LLM result
    ctypes.c_void_p,              # User data pointer
    ctypes.c_int                  # LLM call state (LLMCallState enum value)
)
"""
Callback function type.

Callback used to handle LLM results.

Parameters:
- result: pointer to the LLM result
- userdata: user data pointer for the callback
- state: LLM call state (e.g. finished, error)

Return value:
- 0: continue inference normally
- 1: pause inference. If the user wants to modify or intervene in the result
  (e.g. edit the output, inject a new prompt), return 1 to pause the current
  inference. Later, call rkllm_run with the updated content to resume.
"""
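Because `LLMResultCallback` objects are C function pointers created by `ctypes.CFUNCTYPE`, the Python-side wrapper must stay referenced for as long as the C library might invoke it (the `RKLLMRuntime` class below stores it in `self._c_callback` for exactly this reason). A CFUNCTYPE wrapper is also directly callable from Python, which makes it easy to smoke-test. A self-contained sketch with a simplified signature (plain `void*` instead of `POINTER(RKLLMResult)`):

```python
import ctypes

# Simplified signature: int cb(void* result, void* userdata, int state)
SimpleCallback = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p,
                                  ctypes.c_void_p, ctypes.c_int)

def py_cb(result, userdata, state):
    # Per the contract above: return 0 to continue, 1 to pause inference.
    return 0

c_cb = SimpleCallback(py_cb)  # keep this object alive while the C side may call it
rc = c_cb(None, None, 2)      # also callable from Python for testing
```

If the wrapper object is garbage-collected while the native side still holds the pointer, the process typically crashes, so holding a reference is not optional.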

class RKLLMRuntime:
    def __init__(self, library_path="./librkllmrt.so"):
        try:
            self.lib = ctypes.CDLL(library_path)
        except OSError as e:
            raise OSError(f"Failed to load RKLLM library from {library_path}. "
                          f"Ensure it's in your LD_LIBRARY_PATH or provide the full path. Error: {e}")
        self._setup_functions()
        self.llm_handle = LLMHandle()
        self._c_callback = None  # To keep the callback object alive

    def _setup_functions(self):
        # RKLLMParam rkllm_createDefaultParam();
        self.lib.rkllm_createDefaultParam.restype = RKLLMParam
        self.lib.rkllm_createDefaultParam.argtypes = []

        # int rkllm_init(LLMHandle* handle, RKLLMParam* param, LLMResultCallback callback);
        self.lib.rkllm_init.restype = ctypes.c_int
        self.lib.rkllm_init.argtypes = [
            ctypes.POINTER(LLMHandle),
            ctypes.POINTER(RKLLMParam),
            LLMResultCallback
        ]

        # int rkllm_load_lora(LLMHandle handle, RKLLMLoraAdapter* lora_adapter);
        self.lib.rkllm_load_lora.restype = ctypes.c_int
        self.lib.rkllm_load_lora.argtypes = [LLMHandle, ctypes.POINTER(RKLLMLoraAdapter)]

        # int rkllm_load_prompt_cache(LLMHandle handle, const char* prompt_cache_path);
        self.lib.rkllm_load_prompt_cache.restype = ctypes.c_int
        self.lib.rkllm_load_prompt_cache.argtypes = [LLMHandle, ctypes.c_char_p]

        # int rkllm_release_prompt_cache(LLMHandle handle);
        self.lib.rkllm_release_prompt_cache.restype = ctypes.c_int
        self.lib.rkllm_release_prompt_cache.argtypes = [LLMHandle]

        # int rkllm_destroy(LLMHandle handle);
        self.lib.rkllm_destroy.restype = ctypes.c_int
        self.lib.rkllm_destroy.argtypes = [LLMHandle]

        # int rkllm_run(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
        self.lib.rkllm_run.restype = ctypes.c_int
        self.lib.rkllm_run.argtypes = [
            LLMHandle,
            ctypes.POINTER(RKLLMInput),
            ctypes.POINTER(RKLLMInferParam),
            ctypes.c_void_p  # userdata
        ]

        # int rkllm_run_async(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
        # Assuming async also takes userdata for the callback context
        self.lib.rkllm_run_async.restype = ctypes.c_int
        self.lib.rkllm_run_async.argtypes = [
            LLMHandle,
            ctypes.POINTER(RKLLMInput),
            ctypes.POINTER(RKLLMInferParam),
            ctypes.c_void_p  # userdata
        ]

        # int rkllm_abort(LLMHandle handle);
        self.lib.rkllm_abort.restype = ctypes.c_int
        self.lib.rkllm_abort.argtypes = [LLMHandle]

        # int rkllm_is_running(LLMHandle handle);
        self.lib.rkllm_is_running.restype = ctypes.c_int  # 0 if running, non-zero otherwise
        self.lib.rkllm_is_running.argtypes = [LLMHandle]

        # int rkllm_clear_kv_cache(LLMHandle handle, int keep_system_prompt, int* start_pos, int* end_pos);
        self.lib.rkllm_clear_kv_cache.restype = ctypes.c_int
        self.lib.rkllm_clear_kv_cache.argtypes = [
            LLMHandle,
            ctypes.c_int,
            ctypes.POINTER(ctypes.c_int),  # start_pos
            ctypes.POINTER(ctypes.c_int)   # end_pos
        ]

        # int rkllm_get_kv_cache_size(LLMHandle handle, int* cache_sizes);
        self.lib.rkllm_get_kv_cache_size.restype = ctypes.c_int
        self.lib.rkllm_get_kv_cache_size.argtypes = [LLMHandle, ctypes.POINTER(ctypes.c_int)]

        # int rkllm_set_chat_template(LLMHandle handle, const char* system_prompt, const char* prompt_prefix, const char* prompt_postfix);
        self.lib.rkllm_set_chat_template.restype = ctypes.c_int
        self.lib.rkllm_set_chat_template.argtypes = [
            LLMHandle,
            ctypes.c_char_p,
            ctypes.c_char_p,
            ctypes.c_char_p
        ]

        # int rkllm_set_function_tools(LLMHandle handle, const char* system_prompt, const char* tools, const char* tool_response_str);
        self.lib.rkllm_set_function_tools.restype = ctypes.c_int
        self.lib.rkllm_set_function_tools.argtypes = [
            LLMHandle,
            ctypes.c_char_p,  # system_prompt
            ctypes.c_char_p,  # tools
            ctypes.c_char_p   # tool_response_str
        ]

        # int rkllm_set_cross_attn_params(LLMHandle handle, RKLLMCrossAttnParam* cross_attn_params);
        self.lib.rkllm_set_cross_attn_params.restype = ctypes.c_int
        self.lib.rkllm_set_cross_attn_params.argtypes = [LLMHandle, ctypes.POINTER(RKLLMCrossAttnParam)]

    def create_default_param(self) -> RKLLMParam:
        """Creates a default RKLLMParam structure."""
        return self.lib.rkllm_createDefaultParam()

    def init(self, param: RKLLMParam, callback_func) -> int:
        """
        Initializes the LLM.
        :param param: RKLLMParam structure.
        :param callback_func: A Python function that matches the signature:
                              def my_callback(result_ptr, userdata_ptr, state_enum):
                                  result = result_ptr.contents  # RKLLMResult
                                  # Process result
                                  # userdata can be retrieved if passed during run, or ignored
                                  # state = LLMCallState(state_enum)
        :return: 0 for success, non-zero for failure.
        """
        if not callable(callback_func):
            raise ValueError("callback_func must be a callable Python function.")

        # Keep a reference to the ctypes callback object to prevent it from being garbage collected
        self._c_callback = LLMResultCallback(callback_func)

        ret = self.lib.rkllm_init(ctypes.byref(self.llm_handle), ctypes.byref(param), self._c_callback)
        if ret != 0:
            raise RuntimeError(f"rkllm_init failed with error code {ret}")
        return ret

    def load_lora(self, lora_adapter: RKLLMLoraAdapter) -> int:
        """Loads a Lora adapter."""
        ret = self.lib.rkllm_load_lora(self.llm_handle, ctypes.byref(lora_adapter))
        if ret != 0:
            raise RuntimeError(f"rkllm_load_lora failed with error code {ret}")
        return ret

    def load_prompt_cache(self, prompt_cache_path: str) -> int:
        """Loads a prompt cache from a file."""
        c_path = prompt_cache_path.encode('utf-8')
        ret = self.lib.rkllm_load_prompt_cache(self.llm_handle, c_path)
        if ret != 0:
            raise RuntimeError(f"rkllm_load_prompt_cache failed for {prompt_cache_path} with error code {ret}")
        return ret

    def release_prompt_cache(self) -> int:
        """Releases the prompt cache from memory."""
        ret = self.lib.rkllm_release_prompt_cache(self.llm_handle)
        if ret != 0:
            raise RuntimeError(f"rkllm_release_prompt_cache failed with error code {ret}")
        return ret

    def destroy(self) -> int:
        """Destroys the LLM instance and releases resources."""
        if self.llm_handle and self.llm_handle.value:  # Check if handle is not NULL
            ret = self.lib.rkllm_destroy(self.llm_handle)
            self.llm_handle = LLMHandle()  # Reset handle
            if ret != 0:
                # Don't raise here as it might be called in __del__
                print(f"Warning: rkllm_destroy failed with error code {ret}")
            return ret
        return 0  # Already destroyed or not initialized

    def run(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
        """Runs an LLM inference task synchronously."""
        # userdata can be a ctypes.py_object if you want to pass Python objects,
        # then cast to c_void_p. Or simply None.
        if userdata is not None:
            # Store the userdata object to keep it alive during the call
            self._userdata_ref = userdata
            c_userdata = ctypes.cast(ctypes.pointer(ctypes.py_object(userdata)), ctypes.c_void_p)
        else:
            c_userdata = None
        ret = self.lib.rkllm_run(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata)
        if ret != 0:
            raise RuntimeError(f"rkllm_run failed with error code {ret}")
        return ret

    def run_async(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
        """Runs an LLM inference task asynchronously."""
        if userdata is not None:
            # Store the userdata object to keep it alive during the call
            self._userdata_ref = userdata
            c_userdata = ctypes.cast(ctypes.pointer(ctypes.py_object(userdata)), ctypes.c_void_p)
        else:
            c_userdata = None
        ret = self.lib.rkllm_run_async(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata)
        if ret != 0:
            raise RuntimeError(f"rkllm_run_async failed with error code {ret}")
        return ret

    def abort(self) -> int:
        """Aborts an ongoing LLM task."""
        ret = self.lib.rkllm_abort(self.llm_handle)
        if ret != 0:
            raise RuntimeError(f"rkllm_abort failed with error code {ret}")
        return ret

    def is_running(self) -> bool:
        """Checks if an LLM task is currently running. Returns True if running."""
        # The C API returns 0 if running, non-zero otherwise.
        # This is a bit counter-intuitive for a boolean "is_running".
        return self.lib.rkllm_is_running(self.llm_handle) == 0

    def clear_kv_cache(self, keep_system_prompt: bool, start_pos: list = None, end_pos: list = None) -> int:
        """
        Clears part or all of the key/value cache.

        Parameters:
        - keep_system_prompt: whether to keep the system prompt in the cache
          (True keeps it, False clears it). Ignored when an explicit range
          [start_pos, end_pos) is provided.
        - start_pos: array of start positions (inclusive) of the KV-cache
          ranges to clear, one per batch.
        - end_pos: array of end positions (exclusive) of the KV-cache
          ranges to clear, one per batch.
          If both start_pos and end_pos are None, the entire cache is cleared
          and keep_system_prompt takes effect. If start_pos[i] < end_pos[i],
          only the specified range is cleared and keep_system_prompt is ignored.

        Note: start_pos/end_pos only take effect when keep_history == 0 and
        generation has been paused by returning 1 from the callback.

        Returns 0 if the cache was cleared successfully, non-zero on failure.
        """
        # Prepare the C array arguments
        c_start_pos = None
        c_end_pos = None

        if start_pos is not None and end_pos is not None:
            if len(start_pos) != len(end_pos):
                raise ValueError("start_pos and end_pos must have the same length")

            # Create the C arrays
            c_start_pos = (ctypes.c_int * len(start_pos))(*start_pos)
            c_end_pos = (ctypes.c_int * len(end_pos))(*end_pos)

        ret = self.lib.rkllm_clear_kv_cache(
            self.llm_handle,
            ctypes.c_int(1 if keep_system_prompt else 0),
            c_start_pos,
            c_end_pos
        )
        if ret != 0:
            raise RuntimeError(f"rkllm_clear_kv_cache failed with error code {ret}")
        return ret

    def set_chat_template(self, system_prompt: str, prompt_prefix: str, prompt_postfix: str) -> int:
        """Sets the chat template for the LLM."""
        c_system = system_prompt.encode('utf-8') if system_prompt else b""
        c_prefix = prompt_prefix.encode('utf-8') if prompt_prefix else b""
        c_postfix = prompt_postfix.encode('utf-8') if prompt_postfix else b""

        ret = self.lib.rkllm_set_chat_template(self.llm_handle, c_system, c_prefix, c_postfix)
        if ret != 0:
            raise RuntimeError(f"rkllm_set_chat_template failed with error code {ret}")
        return ret

    def get_kv_cache_size(self, n_batch: int) -> list:
        """
        Gets the current size of the key/value cache for this LLM handle.

        Returns the total number of positions currently stored in the model's
        KV cache.

        Parameters:
        - n_batch: number of batches; determines the size of the returned list.

        Returns:
        - list: cache size for each batch.
        """
        # Pre-allocate an array to hold the cache size for each batch
        cache_sizes = (ctypes.c_int * n_batch)()

        ret = self.lib.rkllm_get_kv_cache_size(self.llm_handle, cache_sizes)
        if ret != 0:
            raise RuntimeError(f"rkllm_get_kv_cache_size failed with error code {ret}")

        # Convert to a Python list
        return [cache_sizes[i] for i in range(n_batch)]

    def set_function_tools(self, system_prompt: str, tools: str, tool_response_str: str) -> int:
        """
        Sets the function-calling configuration for the LLM: the system prompt,
        the tool definitions, and the tool-response token.

        Parameters:
        - system_prompt: system prompt defining the model's context or behavior.
        - tools: JSON-formatted string defining the available functions,
          including their names, descriptions, and parameters.
        - tool_response_str: unique tag identifying function-call results in the
          conversation. It acts as a marker tag that lets the tokenizer
          recognize tool output separately from normal conversation turns.

        Returns 0 if the configuration was set successfully, non-zero on error.
        """
        c_system = system_prompt.encode('utf-8') if system_prompt else b""
        c_tools = tools.encode('utf-8') if tools else b""
        c_tool_response = tool_response_str.encode('utf-8') if tool_response_str else b""

        ret = self.lib.rkllm_set_function_tools(self.llm_handle, c_system, c_tools, c_tool_response)
        if ret != 0:
            raise RuntimeError(f"rkllm_set_function_tools failed with error code {ret}")
        return ret

    def set_cross_attn_params(self, cross_attn_params: RKLLMCrossAttnParam) -> int:
        """
        Sets the cross-attention parameters for the LLM decoder.

        Parameters:
        - cross_attn_params: structure containing the encoder-side input data
          used for cross-attention (see RKLLMCrossAttnParam for details).

        Returns 0 if the parameters were set successfully, non-zero on error.
        """
        ret = self.lib.rkllm_set_cross_attn_params(self.llm_handle, ctypes.byref(cross_attn_params))
        if ret != 0:
            raise RuntimeError(f"rkllm_set_cross_attn_params failed with error code {ret}")
        return ret

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.destroy()

    def __del__(self):
        self.destroy()  # Ensure resources are freed if object is garbage collected

# --- Example Usage (Illustrative) ---
if __name__ == "__main__":
    # This is a placeholder for how you might use it.
    # You'll need a valid .rkllm model and librkllmrt.so in your path.

    # Global list to store results from callback for demonstration
    results_buffer = []

    def my_python_callback(result_ptr, userdata_ptr, state_enum):
        """
        Callback invoked by the C library to handle LLM results.

        Parameters:
        - result_ptr: pointer to the LLM result
        - userdata_ptr: user data pointer
        - state_enum: LLM call state enum value

        Returns:
        - 0: continue inference
        - 1: pause inference
        """
        global results_buffer
        state = LLMCallState(state_enum)
        result = result_ptr.contents

        current_text = ""
        if result.text:  # Check that the char_p is not NULL
            current_text = result.text.decode('utf-8', errors='ignore')

        print(f"Callback: State={state.name}, TokenID={result.token_id}, Text='{current_text}'")

        # Show performance statistics
        if result.perf.prefill_tokens > 0 or result.perf.generate_tokens > 0:
            print(f"  Perf stats: prefill={result.perf.prefill_tokens}tokens/{result.perf.prefill_time_ms:.1f}ms, "
                  f"generate={result.perf.generate_tokens}tokens/{result.perf.generate_time_ms:.1f}ms, "
                  f"memory={result.perf.memory_usage_mb:.1f}MB")

        results_buffer.append(current_text)

        if state == LLMCallState.RKLLM_RUN_FINISH:
            print("Inference finished.")
        elif state == LLMCallState.RKLLM_RUN_ERROR:
            print("Inference error.")

        # Return 0 to continue inference, 1 to pause it
        return 0

    # --- Attempt to use the wrapper ---
    try:
        print("Initializing RKLLMRuntime...")
        # Adjust library_path if librkllmrt.so is not in default search paths
        # e.g., library_path="./path/to/librkllmrt.so"
        rk_llm = RKLLMRuntime()

        print("Creating default parameters...")
        params = rk_llm.create_default_param()

        # --- Configure parameters ---
        # THIS IS CRITICAL: model_path must point to an actual .rkllm file
        # For this example to run, you need a model file.
        # Let's assume a dummy path for now; this will fail at init if not valid.
        model_file = "dummy_model.rkllm"
        if not os.path.exists(model_file):
            print(f"Warning: Model file '{model_file}' does not exist. Init will likely fail.")
            # Create a dummy file for the example to proceed further, though init will still fail
            # with a real library unless it's a valid model.
            with open(model_file, "w") as f:
                f.write("dummy content")

        params.model_path = model_file.encode('utf-8')
        params.max_context_len = 512
        params.max_new_tokens = 128
        params.top_k = 1  # Greedy
        params.temperature = 0.7
        params.repeat_penalty = 1.1
        # ... set other params as needed

        print(f"Initializing LLM with model: {params.model_path.decode()}...")
        # This will likely fail if dummy_model.rkllm is not a valid model recognized by the library
        try:
            rk_llm.init(params, my_python_callback)
            print("LLM Initialized.")
        except RuntimeError as e:
            print(f"Error during LLM initialization: {e}")
            print("This is expected if 'dummy_model.rkllm' is not a valid model.")
            print("Replace 'dummy_model.rkllm' with a real model path to test further.")
            exit()

        # --- Prepare input ---
        print("Preparing input...")
        rk_input = RKLLMInput()
        rk_input.role = b"user"  # Set the message role to user input
        rk_input.enable_thinking = False  # Disable thinking mode (for Qwen3 models)
        rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT

        # Demo prompt (in Chinese): "Translate the following English text into Chinese: 'Hello, world!'"
        prompt_text = "将以下英文文本翻译成中文:'Hello, world!'"
        c_prompt = prompt_text.encode('utf-8')
        rk_input._union_data.prompt_input = c_prompt  # Access the union member directly

        # --- Prepare inference parameters ---
        print("Preparing inference parameters...")
        infer_params = RKLLMInferParam()
        infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
        infer_params.keep_history = 1  # True
        # infer_params.lora_params = None  # or set up RKLLMLoraParam if using LoRA
        # infer_params.prompt_cache_params = None  # or set up RKLLMPromptCacheParam

        # --- Run inference ---
        print(f"Running inference with prompt: '{prompt_text}'")
        results_buffer.clear()
        try:
            rk_llm.run(rk_input, infer_params)  # Userdata is None by default
            print("\n--- Full Response ---")
            print("".join(results_buffer))
            print("---------------------\n")
        except RuntimeError as e:
            print(f"Error during LLM run: {e}")

        # --- Example: Set chat template (if model supports it) ---
        # print("Setting chat template...")
        # try:
        #     rk_llm.set_chat_template("You are a helpful assistant.", "<user>: ", "<assistant>: ")
        #     print("Chat template set.")
        # except RuntimeError as e:
        #     print(f"Error setting chat template: {e}")

        # --- Example: Clear KV Cache ---
        # print("Clearing KV cache (keeping system prompt if any)...")
        # try:
        #     rk_llm.clear_kv_cache(keep_system_prompt=True)
        #     print("KV cache cleared.")
        # except RuntimeError as e:
        #     print(f"Error clearing KV cache: {e}")

        # --- Example: Get KV cache size ---
        # print("Getting KV cache size...")
        # try:
        #     cache_sizes = rk_llm.get_kv_cache_size(n_batch=1)  # Assuming batch size 1
        #     print(f"Current KV cache size: {cache_sizes}")
        # except RuntimeError as e:
        #     print(f"Error getting KV cache size: {e}")

        # --- Example: Set function tools ---
        # print("Setting function-calling tools...")
        # try:
        #     system_prompt = "You are a helpful assistant that can call the provided functions to help the user."
        #     tools = '''[{
        #         "name": "get_weather",
        #         "description": "Get weather information for the specified city",
        #         "parameters": {
        #             "type": "object",
        #             "properties": {
        #                 "city": {"type": "string", "description": "City name"}
        #             },
        #             "required": ["city"]
        #         }
        #     }]'''
        #     tool_response_str = "<tool_response>"
        #     rk_llm.set_function_tools(system_prompt, tools, tool_response_str)
        #     print("Function tools set successfully.")
        # except RuntimeError as e:
        #     print(f"Error setting function tools: {e}")

        # --- Example: Clear KV cache with range parameters ---
        # print("Clearing KV cache with range parameters...")
        # try:
        #     # Clear cache positions 10 to 20
        #     start_positions = [10]  # Start position for batch 0
        #     end_positions = [20]    # End position for batch 0
        #     rk_llm.clear_kv_cache(keep_system_prompt=True, start_pos=start_positions, end_pos=end_positions)
        #     print("Ranged KV cache clear completed.")
        # except RuntimeError as e:
        #     print(f"Error clearing ranged KV cache: {e}")

    except OSError as e:
        print(f"OSError: {e}. Could not load the RKLLM library.")
        print("Please ensure 'librkllmrt.so' is in your LD_LIBRARY_PATH or provide the full path.")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
    finally:
        if 'rk_llm' in locals() and rk_llm.llm_handle and rk_llm.llm_handle.value:
            print("Destroying LLM instance...")
            rk_llm.destroy()
            print("LLM instance destroyed.")
        if os.path.exists(model_file) and model_file == "dummy_model.rkllm":
            os.remove(model_file)  # Clean up dummy file

    print("Example finished.")
run_rkllm.py ADDED
@@ -0,0 +1,243 @@
+ import faulthandler
+ faulthandler.enable()
+ import sys
+ import os
+ os.environ["RKLLM_LOG_LEVEL"] = "1"
+ import time
+ import ctypes
+ import argparse
+ import cv2
+ import numpy as np
+ import ztu_somemodelruntime_rknnlite2 as ort
+ from rkllm_binding import (
+     RKLLMRuntime,
+     RKLLMParam,
+     RKLLMInput,
+     RKLLMInferParam,
+     LLMCallState,
+     RKLLMInputType,
+     RKLLMInferMode,
+     RKLLMResult
+ )
+ 
+ # Constants aligned with the InternVL config
+ IMAGE_HEIGHT = 448
+ IMAGE_WIDTH = 448
+ IMAGE_SEQ_LENGTH = 256
+ MULTIMODAL_HIDDEN_DIM = 2048
+ IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
+ IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
+ 
+ def expand2square(img, background_color):
+     """
+     Expand the image into a square filled with the specified background color.
+     """
+     height, width, _ = img.shape
+     if width == height:
+         return img.copy()
+ 
+     size = max(width, height)
+     square_img = np.full((size, size, 3), background_color, dtype=np.uint8)
+ 
+     x_offset = (size - width) // 2
+     y_offset = (size - height) // 2
+ 
+     square_img[y_offset:y_offset+height, x_offset:x_offset+width] = img
+     return square_img
+ 
+ def llm_callback(result_ptr, userdata_ptr, state_enum):
+     """
+     Callback function that streams LLM results.
+     """
+     state = LLMCallState(state_enum)
+     result = result_ptr.contents
+ 
+     if state == LLMCallState.RKLLM_RUN_NORMAL:
+         if result.text:
+             print(result.text.decode('utf-8', errors='ignore'), end='', flush=True)
+     elif state == LLMCallState.RKLLM_RUN_FINISH:
+         print("\n", flush=True)
+     elif state == LLMCallState.RKLLM_RUN_ERROR:
+         print("\nrun error", flush=True)
+ 
+     return 0
+ 
+ def main():
+     parser = argparse.ArgumentParser(
+         description="Run RKLLM visual language model inference based on the C++ example."
+     )
+     parser.add_argument("image_path", type=str, help="Path to the input image.")
+     parser.add_argument("encoder_model_path", type=str, help="Path to the ONNX vision encoder model.")
+     parser.add_argument("llm_model_path", type=str, help="Path to the .rkllm language model.")
+     parser.add_argument("max_new_tokens", type=int, help="Maximum number of new tokens to generate.")
+     parser.add_argument("max_context_len", type=int, help="Maximum context length.")
+     # rknn_core_num is not used by onnxruntime in the same way, but we keep it
+     # for API consistency with the C++ example. ONNX Runtime manages its own
+     # threading and execution providers.
+     parser.add_argument("rknn_core_num", type=int, help="Number of NPU cores used by the vision encoder.")
+ 
+     args = parser.parse_args()
+ 
+     # --- 1. Initialize Image Encoder (ONNX Runtime) ---
+     print("Initializing ONNX Runtime for vision encoder...")
+     try:
+         sess_options = ort.SessionOptions()
+         sess_options.intra_op_num_threads = args.rknn_core_num
+         ort_session = ort.InferenceSession(args.encoder_model_path, sess_options=sess_options)
+     except Exception as e:
+         print(f"Failed to load ONNX model: {e}")
+         sys.exit(1)
+     print("Vision encoder loaded successfully.")
+ 
+     input_name = ort_session.get_inputs()[0].name
+     output_name = ort_session.get_outputs()[0].name
+     print(f"ONNX Input: {input_name}, ONNX Output: {output_name}")
+ 
+     # --- 2. Initialize LLM ---
+     print("Initializing RKLLM Runtime...")
+     rk_llm = RKLLMRuntime()
+     param = rk_llm.create_default_param()
+ 
+     param.model_path = args.llm_model_path.encode('utf-8')
+     param.top_k = 1
+     param.max_new_tokens = args.max_new_tokens
+     param.max_context_len = args.max_context_len
+     param.skip_special_token = True
+     param.img_start = b"<img>"
+     param.img_end = b"</img>\n"
+     param.img_content = b""
+     param.extend_param.base_domain_id = 1
+ 
+     try:
+         rk_llm.init(param, llm_callback)
+         print("RKLLM initialized successfully.")
+     except RuntimeError as e:
+         print(f"RKLLM init failed: {e}")
+         sys.exit(1)
+ 
+     # --- 3. Image Preprocessing ---
+     print("Preprocessing image...")
+     img = cv2.imread(args.image_path)
+     if img is None:
+         print(f"Failed to read image from {args.image_path}")
+         sys.exit(1)
+ 
+     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ 
+     background_color = (127.5, 127.5, 127.5)  # Keep close to the official preprocessing
+     square_img = expand2square(img, background_color)
+     resized_img = cv2.resize(square_img, (IMAGE_WIDTH, IMAGE_HEIGHT), interpolation=cv2.INTER_LINEAR)
+ 
+     # Normalize and prepare for the ONNX model
+     input_tensor = resized_img.astype(np.float32)
+     # Normalize using InternVL vision config statistics
+     input_tensor = (input_tensor / 255.0 - IMAGENET_MEAN) / IMAGENET_STD
+     # Convert to NCHW format
+     input_tensor = np.transpose(input_tensor, (2, 0, 1))  # HWC -> CHW
+     input_tensor = np.expand_dims(input_tensor, axis=0)   # Add batch dimension -> (1, 3, 448, 448)
+ 
+     # --- 4. Run Image Encoder ---
+     print("Running vision encoder...")
+     start_time = time.time()
+     try:
+         img_vec_output = ort_session.run([output_name], {input_name: input_tensor.astype(np.float32)})[0]
+         if img_vec_output.ndim != 3:
+             raise RuntimeError(f"Unexpected encoder output shape {img_vec_output.shape}, expected (batch, tokens, hidden)")
+         if img_vec_output.shape[-1] != MULTIMODAL_HIDDEN_DIM:
+             print(f"Warning: hidden dim {img_vec_output.shape[-1]} differs from expected {MULTIMODAL_HIDDEN_DIM}")
+         if img_vec_output.shape[1] != IMAGE_SEQ_LENGTH:
+             print(f"Warning: token count {img_vec_output.shape[1]} differs from expected {IMAGE_SEQ_LENGTH}")
+         elapsed_time = time.time() - start_time
+         print(f"Vision encoder inference time: {elapsed_time:.4f} s")
+         # The C++ example passes a flat float array, so flatten the ONNX output.
+         img_vec = img_vec_output.flatten().astype(np.float32)
+ 
+     except Exception as e:
+         print(f"Failed to run vision encoder inference: {e}")
+         rk_llm.destroy()
+         sys.exit(1)
+ 
+     print("Image encoded successfully.")
+ 
+     # --- 5. Interactive Chat Loop ---
+     rkllm_infer_params = RKLLMInferParam()
+     rkllm_infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
+     rkllm_infer_params.keep_history = 1
+ 
+     # Set chat template.
+     # The default template parsed by RKLLM seems to give better results than
+     # this one; the reason is unclear.
+     # rk_llm.set_chat_template(
+     #     system_prompt="<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n",
+     #     prompt_prefix="<|im_start|>user\n",
+     #     prompt_postfix="<|im_end|>\n<|im_start|>assistant\n"
+     # )
+ 
+     pre_input = [
+         "<image>What is in the image?",
+         "<image>这张图片中有什么?"
+     ]
+     print("\n********** Enter the index of a preset question below, or type a custom prompt **********\n")
+     for i, p in enumerate(pre_input):
+         print(f"[{i}] {p}")
+     print("\n*****************************************************************************************\n")
+ 
+     try:
+         while True:
+             print("\nuser: ", end="", flush=True)
+             input_str = sys.stdin.readline().strip()
+ 
+             if not input_str:
+                 continue
+             if input_str == "exit":
+                 break
+             if input_str == "clear":
+                 try:
+                     rk_llm.clear_kv_cache(keep_system_prompt=True)
+                     print("KV cache cleared.")
+                 except RuntimeError as e:
+                     print(f"Failed to clear KV cache: {e}")
+                 continue
+ 
+             try:
+                 idx = int(input_str)
+                 if 0 <= idx < len(pre_input):
+                     input_str = pre_input[idx]
+                     print(input_str)
+             except (ValueError, IndexError):
+                 pass  # Use the raw string if not a valid index
+ 
+             rkllm_input = RKLLMInput()
+             rkllm_input.role = b"user"
+ 
+             print("robot: ", end="", flush=True)
+ 
+             if "<image>" in input_str:
+                 rkllm_input.input_type = RKLLMInputType.RKLLM_INPUT_MULTIMODAL
+ 
+                 # Set up multimodal input
+                 rkllm_input.multimodal_input.prompt = input_str.encode('utf-8')
+                 rkllm_input.multimodal_input.image_embed = img_vec.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
+                 rkllm_input.multimodal_input.n_image_tokens = img_vec_output.shape[1]
+                 print("n_image_tokens: ", rkllm_input.multimodal_input.n_image_tokens)
+                 rkllm_input.multimodal_input.n_image = 1
+                 rkllm_input.multimodal_input.image_height = IMAGE_HEIGHT
+                 rkllm_input.multimodal_input.image_width = IMAGE_WIDTH
+             else:
+                 rkllm_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT
+                 rkllm_input.prompt_input = input_str.encode('utf-8')
+ 
+             try:
+                 rk_llm.run(rkllm_input, rkllm_infer_params)
+             except RuntimeError as e:
+                 print(f"\nError during rkllm_run: {e}")
+ 
+     except KeyboardInterrupt:
+         print("\nExiting...")
+     finally:
+         print("Releasing resources...")
+         rk_llm.destroy()
+         print("RKLLM instance destroyed.")
+ 
+ if __name__ == "__main__":
+     main()
test.jpg ADDED

Git LFS Details

  • SHA256: a4cd7f45ac1ce27eaafb254b23af7c0b18a064be08870ceaaf03b2147f2ce550
  • Pointer size: 131 Bytes
  • Size of remote file: 156 kB
vision_encoder.rknn ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:69670e8e48938fcd2543c5cae22e7789d5783b525a34a7013edeba724744c461
+ size 674706120
ztu_somemodelruntime_rknnlite2.py ADDED
@@ -0,0 +1,1195 @@
+ # Module-level constants and functions
+ import os
+ import re
+ import ctypes
+ import warnings
+ import logging
+ from typing import List, Dict, Union, Optional
+ 
+ import numpy as np
+ from rknnlite.api import RKNNLite
+ 
+ try:
+     import onnxruntime as ort
+     HAS_ORT = True
+ except ImportError:
+     HAS_ORT = False
+     warnings.warn("onnxruntime is not installed; only the RKNN backend is available", ImportWarning)
+ 
+ # Configure logging
+ logger = logging.getLogger("somemodelruntime_rknnlite2")
+ logger.setLevel(logging.ERROR)  # Only log errors by default
+ if not logger.handlers:
+     handler = logging.StreamHandler()
+     handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
+     logger.addHandler(handler)
+ 
+ # Map ONNX Runtime log levels to Python logging levels
+ _LOGGING_LEVEL_MAP = {
+     0: logging.DEBUG,     # Verbose
+     1: logging.INFO,      # Info
+     2: logging.WARNING,   # Warning
+     3: logging.ERROR,     # Error
+     4: logging.CRITICAL   # Fatal
+ }
+ 
+ # Honor the log level set via environment variable, if any
+ try:
+     env_log_level = os.getenv('ZTU_MODELRT_RKNNL2_LOG_LEVEL')
+     if env_log_level is not None:
+         log_level = int(env_log_level)
+         if log_level in _LOGGING_LEVEL_MAP:
+             logger.setLevel(_LOGGING_LEVEL_MAP[log_level])
+             logger.info(f"Log level set from environment variable: {log_level}")
+         else:
+             logger.warning(f"Invalid value for ZTU_MODELRT_RKNNL2_LOG_LEVEL: {log_level}, expected an integer from 0 to 4")
+ except ValueError:
+     logger.warning(f"Invalid value for ZTU_MODELRT_RKNNL2_LOG_LEVEL: {env_log_level}, expected an integer from 0 to 4")
+ 
+ 
+ def set_default_logger_severity(level: int) -> None:
+     """
+     Sets the default logging severity. 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal
+ 
+     Args:
+         level: log level (0-4)
+     """
+     if level not in _LOGGING_LEVEL_MAP:
+         raise ValueError(f"Invalid log level: {level}, expected an integer from 0 to 4")
+     logger.setLevel(_LOGGING_LEVEL_MAP[level])
+ 
+ def set_default_logger_verbosity(level: int) -> None:
+     """
+     Sets the default logging verbosity level. To activate the verbose log,
+     you need to set the default logging severity to 0:Verbose level.
+ 
+     Args:
+         level: log level (0-4)
+     """
+     set_default_logger_severity(level)
+ 
+ # Map RKNN tensor types to numpy dtypes
+ RKNN_DTYPE_MAP = {
+     0: np.float32,  # RKNN_TENSOR_FLOAT32
+     1: np.float16,  # RKNN_TENSOR_FLOAT16
+     2: np.int8,     # RKNN_TENSOR_INT8
+     3: np.uint8,    # RKNN_TENSOR_UINT8
+     4: np.int16,    # RKNN_TENSOR_INT16
+     5: np.uint16,   # RKNN_TENSOR_UINT16
+     6: np.int32,    # RKNN_TENSOR_INT32
+     7: np.uint32,   # RKNN_TENSOR_UINT32
+     8: np.int64,    # RKNN_TENSOR_INT64
+     9: bool,        # RKNN_TENSOR_BOOL
+     10: np.int8,    # RKNN_TENSOR_INT4 (represented as int8)
+ }
+ 
+ def get_available_providers() -> List[str]:
+     """
+     Get the list of available execution providers (placeholder kept for interface compatibility).
+ 
+     Returns:
+         list: always ["CPUExecutionProvider", "somemodelruntime_rknnlite2_ExecutionProvider"]
+     """
+     return ["CPUExecutionProvider", "somemodelruntime_rknnlite2_ExecutionProvider"]
+ 
+ 
+ def get_device() -> str:
+     """
+     Get the current device.
+ 
+     Returns:
+         str: the current device
+     """
+     return "RKNN2"
+ 
+ def get_version_info() -> Dict[str, str]:
+     """
+     Get version information.
+ 
+     Returns:
+         dict: API and driver version information
+     """
+     runtime = RKNNLite()
+     version = runtime.get_sdk_version()
+     return {
+         "api_version": version.split('\n')[2].split(': ')[1].split(' ')[0],
+         "driver_version": version.split('\n')[3].split(': ')[1]
+     }
+ 
+ class IOTensor:
+     """Wrapper holding input/output tensor information"""
+     def __init__(self, name, shape, type=None):
+         self.name = name.decode() if isinstance(name, bytes) else name
+         self.shape = shape
+         self.type = type
+ 
+     def __str__(self):
+         return f"IOTensor(name='{self.name}', shape={self.shape}, type={self.type})"
+ 
+ class SessionOptions:
+     """Session options"""
+     def __init__(self):
+         self.enable_profiling = False   # Enable performance profiling
+         self.intra_op_num_threads = 1   # Number of RKNN NPU cores to use (maps to RKNN core_mask)
+         self.log_severity_level = -1    # Alternative way to set the log level
+         self.log_verbosity_level = -1   # Alternative way to set the log level
+ 
+ class InferenceSession:
+     """
+     RKNNLite runtime wrapper with an ONNX Runtime-like API
+     """
+ 
+     def __new__(cls, model_path: str, sess_options: Optional[SessionOptions] = None, **kwargs):
+         processed_path = InferenceSession._process_model_path(model_path, sess_options)
+         if isinstance(processed_path, str) and processed_path.lower().endswith('.onnx'):
+             logger.info("Loading the model with ONNX Runtime")
+             if not HAS_ORT:
+                 raise RuntimeError("onnxruntime is not installed; cannot load an ONNX model")
+             return ort.InferenceSession(processed_path, sess_options=sess_options, **kwargs)
+         else:
+             # Not an ONNX model: create an InferenceSession instance via the parent __new__
+             instance = super().__new__(cls)
+             # Remember the processed path
+             instance._processed_path = processed_path
+             return instance
+ 
+     def __init__(self, model_path: str, sess_options: Optional[SessionOptions] = None, **kwargs):
+         """
+         Initialize the runtime and load the model.
+ 
+         Args:
+             model_path: model file path (.rknn or .onnx)
+             sess_options: session options
+             **kwargs: other initialization arguments
+         """
+         options = sess_options or SessionOptions()
+ 
+         # Only use the log level from SessionOptions when the environment variable is unset
+         if os.getenv('ZTU_MODELRT_RKNNL2_LOG_LEVEL') is None:
+             if options.log_severity_level != -1:
+                 set_default_logger_severity(options.log_severity_level)
+             if options.log_verbosity_level != -1:
+                 set_default_logger_verbosity(options.log_verbosity_level)
+ 
+         # Use the path processed in __new__
+         model_path = getattr(self, '_processed_path', model_path)
+         if isinstance(model_path, str) and model_path.lower().endswith('.onnx'):
+             # Avoid re-loading an ONNX model
+             return
+ 
+         # RKNN model loading and initialization
+         self.model_path = model_path
+         if not os.path.exists(self.model_path):
+             logger.error(f"Model file does not exist: {self.model_path}")
+             raise FileNotFoundError(f"Model file does not exist: {self.model_path}")
+ 
+         self.runtime = RKNNLite(verbose=options.enable_profiling)
+ 
+         logger.debug(f"Loading model: {self.model_path}")
+         ret = self.runtime.load_rknn(self.model_path)
+         if ret != 0:
+             logger.error(f"Failed to load RKNN model: {self.model_path}")
+             raise RuntimeError(f'Failed to load RKNN model: {self.model_path}')
+         logger.debug("Model loaded successfully")
+ 
+         if options.intra_op_num_threads == 1:
+             core_mask = RKNNLite.NPU_CORE_AUTO
+         elif options.intra_op_num_threads == 2:
+             core_mask = RKNNLite.NPU_CORE_0_1
+         elif options.intra_op_num_threads == 3:
+             core_mask = RKNNLite.NPU_CORE_0_1_2
+         else:
+             raise ValueError(f"Invalid value for intra_op_num_threads: {options.intra_op_num_threads}, must be 1, 2 or 3")
+ 
+         logger.debug("Initializing the runtime environment")
+         ret = self.runtime.init_runtime(core_mask=core_mask)
+         if ret != 0:
+             logger.error("Failed to initialize the runtime environment")
+             raise RuntimeError('Failed to initialize the runtime environment')
+ 
+         logger.debug("Runtime environment initialized successfully")
+ 
+         # After runtime init, auto-register custom op plugin libraries from environment variables
+         try:
+             # Register user-specified plugins (comma/semicolon/colon separated)
+             env_custom = os.getenv('ZTU_MODELRT_RKNN2_REG_CUSTOM_OP_LIB', '').strip()
+             if env_custom:
+                 paths = [seg.strip() for seg in re.split(r"[,;:]", env_custom) if seg.strip()]
+                 ok = 0
+                 for p in paths:
+                     if self.register_custom_op_lib(p):
+                         ok += 1
+                 if ok > 0:
+                     logger.info(f"Registered {ok}/{len(paths)} custom op plugins")
+             # Register plugins found in the system directory
+             if os.getenv('ZTU_MODELRT_RKNN2_REG_SYSTEM_CUSTOM_OP_LIB', '1') == '1':
+                 cnt = self.register_system_custom_op_lib()
+                 if cnt > 0:
+                     logger.info(f"Registered {cnt} custom op plugins from the system directory")
+         except Exception as e:
+             logger.warning(f"Automatic registration of custom op plugins failed: {e}")
+ 
+         # Optional: register the built-in (Python-based) bundled ops via environment variable
+         if os.getenv('ZTU_MODELRT_RKNN2_REG_BUNDLED_OPS', '0') == '1':
+             logger.info("Registering bundled ops as requested by the environment variable")
+             self.register_bundled_ops()
+ 
+         self._init_io_info()
+         self.options = options
+ 
+     def get_performance_info(self) -> Dict[str, float]:
+         """
+         Get performance information.
+ 
+         Returns:
+             dict: performance information
+         """
+         if not self.options.enable_profiling:
+             raise RuntimeError("Profiling is not enabled; set enable_profiling=True in SessionOptions")
+ 
+         perf = self.runtime.rknn_runtime.get_run_perf()
+         return {
+             "run_duration": perf.run_duration / 1000.0  # convert to milliseconds
+         }
+ 
+     def set_core_mask(self, core_mask: int) -> None:
+         """
+         Set the NPU core usage mode.
+ 
+         Args:
+             core_mask: NPU core mask, use the NPU_CORE_* constants
+         """
+         ret = self.runtime.rknn_runtime.set_core_mask(core_mask)
+         if ret != 0:
+             raise RuntimeError("Failed to set the NPU core mode")
+ 
+     @staticmethod
+     def _process_model_path(model_path, sess_options):
+         """
+         Process the model path; supports .onnx and .rknn files.
+ 
+         Args:
+             model_path: model file path
+         """
+         # For ONNX files, check whether a matching RKNN model should be loaded instead
+         if model_path.lower().endswith('.onnx'):
+             logger.info("Detected an ONNX model file")
+ 
+             # Get the list of models that should skip automatic RKNN loading
+             skip_models = os.getenv('ZTU_MODELRT_RKNNL2_SKIP', '').strip()
+             if skip_models:
+                 skip_list = [m.strip() for m in skip_models.split(',')]
+                 # Match on the file name (without the path)
+                 model_name = os.path.basename(model_path)
+                 if model_name.lower() in [m.lower() for m in skip_list]:
+                     logger.info(f"Model {model_name} is in the skip list; using ONNX Runtime")
+                     return model_path
+ 
+             # Build the RKNN file path
+             rknn_path = os.path.splitext(model_path)[0] + '.rknn'
+             if os.path.exists(rknn_path):
+                 logger.info(f"Found a matching RKNN model, using RKNN: {rknn_path}")
+                 return rknn_path
+             else:
+                 logger.info("No matching RKNN model found; using ONNX Runtime")
+                 return model_path
+ 
+         return model_path
+ 
+     def _convert_nhwc_to_nchw(self, shape):
+         """Convert an NHWC-format shape to NCHW"""
+         if len(shape) == 4:
+             # NHWC -> NCHW
+             n, h, w, c = shape
+             return [n, c, h, w]
+         return shape
+ 
+     def _init_io_info(self):
+         """Initialize the model's input/output information"""
+         runtime = self.runtime.rknn_runtime
+ 
+         # Get the number of inputs and outputs
+         n_input, n_output = runtime.get_in_out_num()
+ 
+         # Collect input information
+         self.input_tensors = []
+         for i in range(n_input):
+             attr = runtime.get_tensor_attr(i)
+             shape = [attr.dims[j] for j in range(attr.n_dims)]
+             # Convert 4-D inputs from NHWC to NCHW
+             shape = self._convert_nhwc_to_nchw(shape)
+             # Look up the dtype
+             dtype = RKNN_DTYPE_MAP.get(attr.type, None)
+             tensor = IOTensor(attr.name, shape, dtype)
+             self.input_tensors.append(tensor)
+ 
+         # Collect output information
+         self.output_tensors = []
+         for i in range(n_output):
+             attr = runtime.get_tensor_attr(i, is_output=True)
+             shape = runtime.get_output_shape(i)
+             # Look up the dtype
+             dtype = RKNN_DTYPE_MAP.get(attr.type, None)
+             tensor = IOTensor(attr.name, shape, dtype)
+             self.output_tensors.append(tensor)
+ 
+     def get_inputs(self):
+         """
+         Get the model's input information.
+ 
+         Returns:
+             list: input information list
+         """
+         return self.input_tensors
+ 
+     def get_outputs(self):
+         """
+         Get the model's output information.
+ 
+         Returns:
+             list: output information list
+         """
+         return self.output_tensors
+ 
+     def run(self, output_names=None, input_feed=None, data_format="nchw", **kwargs):
+         """
+         Run model inference.
+ 
+         Args:
+             output_names: output node names specifying which outputs to return
+             input_feed: input data dict or list
+             data_format: input data format, "nchw" or "nhwc"
+             **kwargs: other runtime arguments
+ 
+         Returns:
+             list: model outputs; only the named outputs if output_names is given
+         """
+         if input_feed is None:
+             logger.error("input_feed must not be None")
+             raise ValueError("input_feed must not be None")
+ 
+         # Prepare the input data
+         if isinstance(input_feed, dict):
+             # For a dict, order the inputs to match the model's input order
+             inputs = []
+             for tensor in self.input_tensors:
+                 if tensor.name not in input_feed:
+                     raise ValueError(f"Missing input: {tensor.name}")
+                 inputs.append(input_feed[tensor.name])
+         elif isinstance(input_feed, (list, tuple)):
+             # For a list, make sure the length matches
+             if len(input_feed) != len(self.input_tensors):
+                 raise ValueError(f"Input count mismatch: expected {len(self.input_tensors)}, got {len(input_feed)}")
+             inputs = list(input_feed)
+         else:
+             logger.error("input_feed must be a dict or a list")
+             raise ValueError("input_feed must be a dict or a list")
+ 
+         # Run inference
+         try:
+             logger.debug("Starting inference")
+             all_outputs = self.runtime.inference(inputs=inputs, data_format=data_format)
+ 
+             # Return all outputs when output_names is not given
+             if output_names is None:
+                 return all_outputs
+ 
+             # Select the requested outputs
+             output_map = {tensor.name: i for i, tensor in enumerate(self.output_tensors)}
+             selected_outputs = []
+             for name in output_names:
+                 if name not in output_map:
+                     raise ValueError(f"Output node not found: {name}")
+                 selected_outputs.append(all_outputs[output_map[name]])
+ 
+             return selected_outputs
+ 
+         except Exception as e:
+             logger.error(f"Inference failed: {str(e)}")
+             raise RuntimeError(f"Inference failed: {str(e)}")
+ 
+     def close(self):
+         """
+         Close the session and release resources.
+         """
+         if self.runtime is not None:
+             logger.info("Releasing runtime resources")
+             self.runtime.release()
+             self.runtime = None
+ 
+     def __enter__(self):
+         return self
+ 
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         self.close()
+ 
+     def end_profiling(self) -> Optional[str]:
+         """
+         Stub for ending profiling.
+ 
+         Returns:
+             Optional[str]: None
+         """
+         warnings.warn("end_profiling() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return None
+ 
+     def get_profiling_start_time_ns(self) -> int:
+         """
+         Stub for getting the profiling start time.
+ 
+         Returns:
+             int: 0
+         """
+         warnings.warn("get_profiling_start_time_ns() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return 0
+ 
+     def get_modelmeta(self) -> Dict[str, str]:
+         """
+         Stub for getting model metadata.
+ 
+         Returns:
+             Dict[str, str]: empty dict
+         """
+         warnings.warn("get_modelmeta() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return {}
+ 
+     def get_session_options(self) -> SessionOptions:
+         """
+         Get the session options.
+ 
+         Returns:
+             SessionOptions: current session options
+         """
+         return self.options
+ 
+     def get_providers(self) -> List[str]:
+         """
+         Stub for getting the providers in use.
+ 
+         Returns:
+             List[str]: ["CPUExecutionProvider"]
+         """
+         warnings.warn("get_providers() is a stub and always returns CPUExecutionProvider", RuntimeWarning, stacklevel=2)
+         return ["CPUExecutionProvider"]
+ 
+     def get_provider_options(self) -> Dict[str, Dict[str, str]]:
+         """
+         Stub for getting provider options.
+ 
+         Returns:
+             Dict[str, Dict[str, str]]: empty dict
+         """
+         warnings.warn("get_provider_options() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return {}
+ 
+     def get_session_config(self) -> Dict[str, str]:
+         """
+         Stub for getting the session config.
+ 
+         Returns:
+             Dict[str, str]: empty dict
+         """
+         warnings.warn("get_session_config() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return {}
+ 
+     def get_session_state(self) -> Dict[str, str]:
+         """
+         Stub for getting the session state.
+ 
+         Returns:
+             Dict[str, str]: empty dict
+         """
+         warnings.warn("get_session_state() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return {}
+ 
+     def set_session_config(self, config: Dict[str, str]) -> None:
+         """
+         Stub for setting the session config.
+ 
+         Args:
+             config: session config dict
+         """
+         warnings.warn("set_session_config() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+ 
+     def get_memory_info(self) -> Dict[str, int]:
+         """
+         Stub for getting memory usage information.
+ 
+         Returns:
+             Dict[str, int]: empty dict
+         """
+         warnings.warn("get_memory_info() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return {}
+ 
+     def set_memory_pattern(self, enable: bool) -> None:
+         """
+         Stub for setting the memory pattern.
+ 
+         Args:
+             enable: whether to enable the memory pattern
+         """
+         warnings.warn("set_memory_pattern() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+ 
+     def disable_memory_pattern(self) -> None:
+         """
+         Stub for disabling the memory pattern.
+         """
+         warnings.warn("disable_memory_pattern() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+ 
+     def get_optimization_level(self) -> int:
+         """
+         Stub for getting the optimization level.
+ 
+         Returns:
+             int: 0
+         """
+         warnings.warn("get_optimization_level() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return 0
+ 
+     def set_optimization_level(self, level: int) -> None:
+         """
+         Stub for setting the optimization level.
+ 
+         Args:
+             level: optimization level
+         """
+         warnings.warn("set_optimization_level() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+ 
+     def get_model_metadata(self) -> Dict[str, str]:
+         """
+         Stub for getting model metadata (a different interface from get_modelmeta).
+ 
+         Returns:
+             Dict[str, str]: empty dict
+         """
+         warnings.warn("get_model_metadata() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return {}
+ 
+     def get_model_path(self) -> str:
+         """
+         Get the model path.
+ 
+         Returns:
+             str: model file path
+         """
+         return self.model_path
+ 
+     def get_input_type_info(self) -> List[Dict[str, str]]:
+         """
+         Stub for getting input type information.
+ 
+         Returns:
+             List[Dict[str, str]]: empty list
+         """
+         warnings.warn("get_input_type_info() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return []
+ 
+     def get_output_type_info(self) -> List[Dict[str, str]]:
+         """
+         Stub for getting output type information.
+ 
+         Returns:
+             List[Dict[str, str]]: empty list
+         """
+         warnings.warn("get_output_type_info() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return []
+ 
+ ################### 自定义算子 ###################
598
+
599
+ def _init_custom_op_types(self):
600
+ """初始化自定义算子的类型定义"""
601
+ # 常量
602
+ self._RKNN_TENSOR_FLOAT32 = 0
603
+ self._RKNN_TENSOR_UINT8 = 3
604
+ self._RKNN_TENSOR_INT64 = 8
605
+ self._RKNN_TARGET_TYPE_CPU = 1
606
+
607
+ # 结构体定义
608
+ class RKNN_TensorAttr(ctypes.Structure):
609
+ _fields_ = [
610
+ ("index", ctypes.c_uint32),
611
+ ("n_dims", ctypes.c_uint32),
612
+ ("dims", ctypes.c_uint32 * RKNN_MAX_DIMS),
613
+ ("name", ctypes.c_char * RKNN_MAX_NAME_LEN),
614
+ ("n_elems", ctypes.c_uint32),
615
+ ("size", ctypes.c_uint32),
616
+                ("fmt", ctypes.c_int),
+                ("type", ctypes.c_int),
+                ("qnt_type", ctypes.c_int),
+                ("fl", ctypes.c_int8),
+                ("zp", ctypes.c_int32),
+                ("scale", ctypes.c_float),
+                ("w_stride", ctypes.c_uint32),
+                ("size_with_stride", ctypes.c_uint32),
+                ("pass_through", ctypes.c_uint8),
+                ("h_stride", ctypes.c_uint32),
+            ]
+
+        class RKNN_TensorMem(ctypes.Structure):
+            _fields_ = [
+                ("virt_addr", ctypes.c_void_p),
+                ("phys_addr", ctypes.c_uint64),
+                ("fd", ctypes.c_int32),
+                ("offset", ctypes.c_int32),
+                ("size", ctypes.c_uint32),
+                ("flags", ctypes.c_uint32),
+                ("priv_data", ctypes.c_void_p),
+            ]
+
+        class RKNN_CustomOpTensor(ctypes.Structure):
+            _fields_ = [
+                ("attr", RKNN_TensorAttr),
+                ("mem", RKNN_TensorMem),
+            ]
+
+        class RKNN_GPUOpContext(ctypes.Structure):
+            _fields_ = [
+                ("cl_context", ctypes.c_void_p),
+                ("cl_command_queue", ctypes.c_void_p),
+                ("cl_kernel", ctypes.c_void_p),
+            ]
+
+        InternalCtxType = (
+            ctypes.c_uint64 if ctypes.sizeof(ctypes.c_void_p) == 8 else ctypes.c_uint32
+        )
+
+        class RKNN_CustomOpContext(ctypes.Structure):
+            _fields_ = [
+                ("target", ctypes.c_int),
+                ("internal_ctx", InternalCtxType),
+                ("gpu_ctx", RKNN_GPUOpContext),
+                ("priv_data", ctypes.c_void_p),
+            ]
+
+        class RKNN_CustomOpAttr(ctypes.Structure):
+            _fields_ = [
+                ("name", ctypes.c_char * RKNN_MAX_NAME_LEN),
+                ("dtype", ctypes.c_int),
+                ("n_elems", ctypes.c_uint32),
+                ("data", ctypes.c_void_p),
+            ]
+
+        CB_SIG = ctypes.CFUNCTYPE(
+            ctypes.c_int,
+            ctypes.POINTER(RKNN_CustomOpContext),
+            ctypes.POINTER(RKNN_CustomOpTensor),
+            ctypes.c_uint32,
+            ctypes.POINTER(RKNN_CustomOpTensor),
+            ctypes.c_uint32,
+        )
+
+        DESTROY_SIG = ctypes.CFUNCTYPE(
+            ctypes.c_int, ctypes.POINTER(RKNN_CustomOpContext)
+        )
+
+        class RKNN_CustomOp(ctypes.Structure):
+            _fields_ = [
+                ("version", ctypes.c_uint32),
+                ("target", ctypes.c_int),
+                ("op_type", ctypes.c_char * RKNN_MAX_NAME_LEN),
+                ("cl_kernel_name", ctypes.c_char * RKNN_MAX_NAME_LEN),
+                ("cl_kernel_source", ctypes.c_char_p),
+                ("cl_source_size", ctypes.c_uint64),
+                ("cl_build_options", ctypes.c_char * RKNN_MAX_NAME_LEN),
+                ("init", CB_SIG),
+                ("prepare", CB_SIG),
+                ("compute", CB_SIG),
+                ("compute_native", CB_SIG),
+                ("destroy", DESTROY_SIG),
+            ]
+
+        # Save the type definitions for later use
+        self._RKNN_TensorAttr = RKNN_TensorAttr
+        self._RKNN_TensorMem = RKNN_TensorMem
+        self._RKNN_CustomOpTensor = RKNN_CustomOpTensor
+        self._RKNN_CustomOpContext = RKNN_CustomOpContext
+        self._RKNN_CustomOpAttr = RKNN_CustomOpAttr
+        self._RKNN_CustomOp = RKNN_CustomOp
+        self._CB_SIG = CB_SIG
+        self._DESTROY_SIG = DESTROY_SIG
+
+    def _create_attr_readers(self, get_op_attr):
+        """Create helpers that read typed attributes from a custom-op context."""
+        def read_attr_int64(op_ctx_ptr, key: str, default: int = 0) -> int:
+            attr = self._RKNN_CustomOpAttr()
+            get_op_attr(op_ctx_ptr, key.encode("utf-8"), ctypes.byref(attr))
+            if attr.n_elems == 1 and attr.dtype == self._RKNN_TENSOR_INT64 and attr.data:
+                return ctypes.c_int64.from_address(attr.data).value
+            return default
+
+        def read_attr_float32(op_ctx_ptr, key: str, default: float = 0.0) -> float:
+            attr = self._RKNN_CustomOpAttr()
+            get_op_attr(op_ctx_ptr, key.encode("utf-8"), ctypes.byref(attr))
+            if attr.n_elems == 1 and attr.dtype == self._RKNN_TENSOR_FLOAT32 and attr.data:
+                return ctypes.c_float.from_address(attr.data).value
+            return default
+
+        def read_attr_str(op_ctx_ptr, key: str, default: str = "") -> str:
+            attr = self._RKNN_CustomOpAttr()
+            get_op_attr(op_ctx_ptr, key.encode("utf-8"), ctypes.byref(attr))
+            if attr.n_elems > 0 and attr.dtype == self._RKNN_TENSOR_UINT8 and attr.data:
+                buf = (ctypes.c_ubyte * attr.n_elems).from_address(attr.data)
+                try:
+                    return bytes(buf).decode("utf-8", errors="ignore").strip('"')
+                except Exception:
+                    return default
+            return default
+
+        return read_attr_int64, read_attr_str, read_attr_float32
+
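The `from_address` trick these readers rely on can be seen in isolation. The following standalone sketch (the variable names are illustrative, not part of this module) reinterprets a raw address as a typed scalar and decodes a UINT8 buffer the way `read_attr_str` does:

```python
import ctypes

# A C-side value whose raw address stands in for attr.data.
val = ctypes.c_int64(7)
addr = ctypes.addressof(val)

# Reinterpret the bare address as a typed scalar, as read_attr_int64 does.
out = ctypes.c_int64.from_address(addr).value
print(out)  # 7

# String attributes arrive as UINT8 buffers; decode as read_attr_str does.
raw = b'"bilinear"'
buf = (ctypes.c_ubyte * len(raw)).from_buffer_copy(raw)
text = bytes(buf).decode("utf-8", errors="ignore").strip('"')
print(text)  # bilinear
```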
+    def _build_py_custom_op(self,
+                            op_type: str,
+                            n_inputs: int,
+                            n_outputs: int,
+                            on_init,
+                            on_compute):
+        """Generic builder for Python-implemented custom ops.
+
+        Args:
+            op_type: operator type name (string)
+            n_inputs: number of inputs
+            n_outputs: number of outputs
+            on_init: callback, signature on_init(op_ctx_p, read_attr_int64, read_attr_str, read_attr_float32) -> state
+            on_compute: callback, signature on_compute(op_ctx_p, inputs_p, outputs_p, state) -> int (0 on success)
+        Returns:
+            (RKNN_CustomOp instance, tuple of callbacks)
+        """
+        @self._CB_SIG
+        def _py_init(op_ctx_p, inputs_p, n_inputs_p, outputs_p, n_outputs_p):
+            try:
+                # Attributes are read lazily via the reader helpers, so no
+                # up-front attribute parsing is required here.
+                runtime = self.runtime.rknn_base.rknn_runtime
+                read_attr_int64, read_attr_str, read_attr_float32 = self._create_attr_readers(runtime.lib.rknn_custom_op_get_op_attr)
+                user_state = on_init(op_ctx_p, read_attr_int64, read_attr_str, read_attr_float32)
+                # Allocate a unique ID for this instance and record it in priv_data
+                if not hasattr(self, "_custom_op_states"):
+                    self._custom_op_states = {}
+                if not hasattr(self, "_next_custom_op_id"):
+                    self._next_custom_op_id = 1
+                inst_id = int(self._next_custom_op_id)
+                self._next_custom_op_id += 1
+                # Keep the Python-side state keyed by instance ID
+                self._custom_op_states[inst_id] = user_state
+                # Write the instance ID into priv_data
+                try:
+                    op_ctx_p.contents.priv_data = ctypes.c_void_p(inst_id)
+                except Exception:
+                    # Fallback: assign the plain integer
+                    op_ctx_p.contents.priv_data = inst_id
+                return 0
+            except Exception as e:
+                logger.error(f"{op_type} init failed: {e}")
+                return -1
+
+        @self._CB_SIG
+        def _py_prepare(op_ctx_p, inputs_p, n_inputs_p, outputs_p, n_outputs_p):
+            return 0
+
+        @self._CB_SIG
+        def _py_compute(op_ctx_p, inputs_p, n_inputs_p, outputs_p, n_outputs_p):
+            try:
+                if n_inputs_p != n_inputs or n_outputs_p != n_outputs:
+                    return -1
+                # Recover this instance's state via priv_data
+                try:
+                    inst_id = int(op_ctx_p.contents.priv_data) if op_ctx_p.contents.priv_data else 0
+                except Exception:
+                    inst_id = 0
+                user_state = None
+                if hasattr(self, "_custom_op_states") and inst_id in self._custom_op_states:
+                    user_state = self._custom_op_states.get(inst_id)
+                else:
+                    logger.error(f"{op_type} compute failed: no state found for inst_id={inst_id}")
+                    return -1
+                return on_compute(op_ctx_p, inputs_p, outputs_p, user_state)
+            except Exception:
+                import traceback
+                logger.error(f"{op_type} compute failed: {traceback.format_exc()}")
+                return -1
+
+        @self._DESTROY_SIG
+        def _py_destroy(op_ctx_p):
+            try:
+                # Drop this instance's state
+                try:
+                    inst_id = int(op_ctx_p.contents.priv_data) if op_ctx_p.contents.priv_data else 0
+                except Exception:
+                    inst_id = 0
+                if hasattr(self, "_custom_op_states") and inst_id in self._custom_op_states:
+                    del self._custom_op_states[inst_id]
+                # Clear priv_data
+                try:
+                    op_ctx_p.contents.priv_data = ctypes.c_void_p(0)
+                except Exception:
+                    op_ctx_p.contents.priv_data = 0
+                return 0
+            except Exception:
+                return -1
+
+        op = self._RKNN_CustomOp()
+        op.version = 1
+        op.target = self._RKNN_TARGET_TYPE_CPU
+        op.op_type = op_type.encode("utf-8")
+        op.cl_kernel_name = b""
+        op.cl_kernel_source = None
+        op.cl_source_size = 0
+        op.cl_build_options = b""
+        op.init = _py_init
+        op.prepare = _py_prepare
+        op.compute = _py_compute
+        op.compute_native = self._CB_SIG()  # NULL callback slot
+        op.destroy = _py_destroy
+
+        return op, (_py_init, _py_prepare, _py_compute, _py_destroy)
+
+    def _tensor_to_numpy(self, rknn_tensor):
+        """Wrap an RKNN_CustomOpTensor as a zero-copy numpy array view."""
+        # Map the RKNN tensor type to a ctypes/numpy dtype pair.
+        # Extend this mapping as additional types are needed.
+        dtype_map = {
+            self._RKNN_TENSOR_FLOAT32: (ctypes.c_float, np.float32),
+            self._RKNN_TENSOR_UINT8: (ctypes.c_uint8, np.uint8),
+            self._RKNN_TENSOR_INT64: (ctypes.c_int64, np.int64),
+        }
+        c_type, np_dtype = dtype_map.get(rknn_tensor.attr.type, (None, None))
+        if c_type is None:
+            raise TypeError(f"Unsupported RKNN tensor type: {rknn_tensor.attr.type}")
+
+        # Compute the memory address and shape
+        addr = (rknn_tensor.mem.virt_addr or 0) + int(rknn_tensor.mem.offset)
+        ptr = ctypes.cast(addr, ctypes.POINTER(c_type))
+        shape = tuple(rknn_tensor.attr.dims[i] for i in range(rknn_tensor.attr.n_dims))
+
+        # Build the numpy view (no copy is made)
+        return np.ctypeslib.as_array(ptr, shape=shape)
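The zero-copy view technique used by `_tensor_to_numpy` can be exercised without any RKNN hardware. A minimal standalone sketch (the buffer here stands in for a tensor's `virt_addr`; names are illustrative) of the same `ctypes.cast` + `np.ctypeslib.as_array` pattern:

```python
import ctypes
import numpy as np

# Simulate a C-owned buffer, such as a tensor's virt_addr region.
buf = (ctypes.c_float * 6)(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)
addr = ctypes.addressof(buf)

# Cast the raw address to a typed pointer, then wrap it as a numpy
# view with the desired shape -- no data is copied.
ptr = ctypes.cast(addr, ctypes.POINTER(ctypes.c_float))
view = np.ctypeslib.as_array(ptr, shape=(2, 3))

# Writes through the view land in the original C buffer.
view[1, 2] = 42.0
print(buf[5])  # 42.0
```

Because the view aliases the C memory, output tensors can be filled in place with `output_nps[i][...] = ...` as the compute callbacks above do.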
+
+    def _create_onnxscript_op_creator(self,
+                                      op_type: str,
+                                      # receives a "function template builder"
+                                      onnxscript_func_builder,
+                                      n_inputs: int,
+                                      n_outputs: int,
+                                      attributes: dict = {},
+                                      constants: dict = {}):
+        """
+        Higher-order factory that creates ONNXScript-backed custom-op constructors.
+        The final onnxscript compute function is generated dynamically during on_init.
+
+        Args:
+            op_type (str): operator type name.
+            onnxscript_func_builder: a function that takes all attributes and
+                constants as keyword arguments and returns a compiled
+                onnxscript function. For example:
+                    def builder(mean, scale):
+                        @onnxscript.script()
+                        def compute(like):
+                            return opset.RandomNormalLike(like, mean=mean, scale=scale)
+                        return compute
+            attributes (dict): attributes read from the model.
+            constants (dict): compile-time constants.
+            n_inputs (int): number of inputs.
+            n_outputs (int): number of outputs.
+        """
+
+        def creator_func():
+            def on_init(op_ctx_p, read_i64, read_s, read_f32):
+                # 1. Read all dynamic attributes
+                attr_values = {}
+                for name, (attr_type, default) in attributes.items():
+                    if attr_type == 'int64':
+                        attr_values[name] = read_i64(op_ctx_p, name, default)
+                    elif attr_type == 'str':
+                        attr_values[name] = read_s(op_ctx_p, name, default)
+                    elif attr_type == 'float32':
+                        attr_values[name] = read_f32(op_ctx_p, name, default)
+                    else:
+                        raise ValueError(f"Unsupported attribute type: {attr_type}")
+
+                # 2. Merge constants and attributes
+                final_kwargs = {**constants, **attr_values}
+
+                # 3. Build the onnxscript function dynamically; this ensures all
+                #    attribute values are captured by the closure as constants
+                compute_func = onnxscript_func_builder(**final_kwargs)
+
+                # 4. Store the compiled function in the per-instance state
+                return {"compute_func": compute_func}
+
+            def on_compute(op_ctx_p, inputs_p, outputs_p, state):
+                compute_func = state["compute_func"]
+
+                input_nps = [self._tensor_to_numpy(inputs_p[i]) for i in range(n_inputs)]
+                output_nps = [self._tensor_to_numpy(outputs_p[i]) for i in range(n_outputs)]
+
+                results = compute_func(*input_nps)
+
+                if n_outputs == 1:
+                    result_val = results[0] if isinstance(results, tuple) else results
+                    output_nps[0][...] = result_val
+                else:
+                    for i in range(n_outputs):
+                        output_nps[i][...] = results[i]
+
+                return 0
+
+            return self._build_py_custom_op(
+                op_type=op_type,
+                n_inputs=n_inputs,
+                n_outputs=n_outputs,
+                on_init=on_init,
+                on_compute=on_compute
+            )
+
+        return creator_func
+
+    def _create_gridsample_op(self):
+        import onnxscript
+        from onnxscript import opset17 as opset
+
+        def grid_sample_builder(align_corners, mode, padding_mode):
+            @onnxscript.script()
+            def grid_sample_compute(X, G):
+                return opset.GridSample(X, G, align_corners=align_corners, mode=mode, padding_mode=padding_mode)
+            return grid_sample_compute
+
+        grid_sample_creator = self._create_onnxscript_op_creator(
+            op_type="GridSample",
+            onnxscript_func_builder=grid_sample_builder,  # pass in the builder
+            attributes={
+                "align_corners": ("int64", 0),
+                "mode": ("str", "bilinear"),
+                "padding_mode": ("str", "zeros"),
+            },
+            n_inputs=2,
+            n_outputs=1
+        )
+        return grid_sample_creator
+
+    def _create_scatterelements_op(self):
+        import onnxscript
+        from onnxscript import opset17 as opset
+
+        @onnxscript.script()
+        def scatter_elements_compute(data, indices, updates):
+            indices_i64 = opset.Cast(indices, to=onnxscript.INT64.dtype)
+            return opset.ScatterElements(data, indices_i64, updates)
+
+        scatter_elements_creator = self._create_onnxscript_op_creator(
+            op_type="ScatterElements",
+            onnxscript_func_builder=lambda: scatter_elements_compute,
+            n_inputs=3,
+            n_outputs=1
+        )
+        return scatter_elements_creator
+
+    def _create_randomnormallike_op(self):
+        import onnxscript
+        from onnxscript import opset17 as opset
+
+        def random_normal_like_builder(mean, scale):
+            @onnxscript.script()
+            def random_normal_like_compute(like):
+                return opset.RandomNormalLike(like, mean=mean, scale=scale)
+            return random_normal_like_compute
+
+        # Use the generic onnxscript factory
+        random_normal_like_creator = self._create_onnxscript_op_creator(
+            op_type="RandomNormalLike",
+            onnxscript_func_builder=random_normal_like_builder,  # pass in the builder
+            attributes={
+                "mean": ("float32", 0.0),
+                "scale": ("float32", 1.0),
+            },
+            n_inputs=1,
+            n_outputs=1
+        )
+        return random_normal_like_creator
+
+    def _create_einsum_op(self):
+        import onnxscript
+        from onnxscript import opset17 as opset
+
+        def einsum_builder(equation):
+            @onnxscript.script()
+            def einsum_compute(in1, in2):
+                return opset.Einsum(in1, in2, equation=equation)
+            return einsum_compute
+
+        # Use the generic onnxscript factory
+        einsum_creator = self._create_onnxscript_op_creator(
+            op_type="Einsum",
+            onnxscript_func_builder=einsum_builder,  # pass in the builder
+            attributes={
+                "equation": ("str", ""),
+            },
+            n_inputs=2,
+            n_outputs=1
+        )
+        return einsum_creator
+
+    def register_bundled_ops(self) -> None:
+        """Register the bundled Python custom ops."""
+        if getattr(self, "_custom_ops_registered", False):
+            return
+
+        runtime = self.runtime.rknn_base.rknn_runtime
+        lib = runtime.lib
+        ctx = runtime.context
+
+        try:
+            _ = lib.rknn_register_custom_ops
+            _ = lib.rknn_custom_op_get_op_attr
+        except AttributeError as e:
+            logger.debug(f"SDK does not support custom op registration: {e}")
+            return
+
+        self._init_custom_op_types()
+
+        # Note: plugin-library registration is driven by environment variables
+        # after model load and is intentionally not triggered again here.
+
+        # List of op creator factories
+        op_creator_factories = [
+            self._create_gridsample_op,
+            self._create_scatterelements_op,
+            self._create_randomnormallike_op,
+            self._create_einsum_op,
+            # self._create_my_custom_add_op,  # adding a new op is this simple
+        ]
+
+        ops_to_register = []
+        all_callbacks = []
+
+        for factory in op_creator_factories:
+            try:
+                # Call the factory to obtain the actual constructor
+                creator_func = factory()
+                # Call the constructor to build the op instance
+                op, callbacks = creator_func()
+                ops_to_register.append(op)
+                all_callbacks.extend(callbacks)
+                logger.debug(f"Created custom op: {op.op_type.decode()}")
+            except Exception as e:
+                logger.warning(f"Failed to create custom op: {e}", exc_info=True)
+
+        if not ops_to_register:
+            logger.debug("No custom ops to register")
+            return
+
+        # Pack all ops into a ctypes array and register them in one call
+        num_ops = len(ops_to_register)
+        op_array = (self._RKNN_CustomOp * num_ops)(*ops_to_register)
+        ret = lib.rknn_register_custom_ops(ctx, op_array, num_ops)
+        if ret != 0:
+            logger.error(f"rknn_register_custom_ops returned ret={ret} (possibly spurious, continuing...)")
+            # raise RuntimeError(f"rknn_register_custom_ops failed, ret={ret}")
+
+        logger.info(f"Registered {len(ops_to_register)} custom ops")
+
+        self._custom_ops_registered = True
+        self._registered_ops = ops_to_register
+        self._op_callbacks = all_callbacks
+
+    def _load_and_register_plugin_op(self, so_path: str) -> bool:
+        """Load a single plugin library and register the custom op it exports.
+
+        The plugin must implement get_rknn_custom_op(), returning an
+        rknn_custom_op*. That C pointer is passed straight to
+        rknn_register_custom_ops, avoiding a copy.
+        """
+        if not os.path.isfile(so_path):
+            logger.warning(f"Plugin library not found: {so_path}")
+            return False
+
+        runtime = self.runtime.rknn_base.rknn_runtime
+        lib = runtime.lib
+        ctx = runtime.context
+
+        # Pick the ctypes type of rknn_context based on platform pointer width
+        ContextCType = ctypes.c_uint64 if ctypes.sizeof(ctypes.c_void_p) == 8 else ctypes.c_uint32
+        # Declare rknn_register_custom_ops(ctx, op_ptr, num). The second argument
+        # is passed as void* to sidestep struct-layout mismatches.
+        try:
+            lib.rknn_register_custom_ops.argtypes = [ContextCType, ctypes.c_void_p, ctypes.c_uint32]
+            lib.rknn_register_custom_ops.restype = ctypes.c_int
+        except Exception:
+            pass
+
+        # Load the plugin
+        try:
+            handle = ctypes.CDLL(so_path)
+        except Exception as e:
+            logger.error(f"dlopen failed: {so_path}, err={e}")
+            return False
+
+        # Resolve the get_rknn_custom_op symbol
+        try:
+            get_sym = getattr(handle, "get_rknn_custom_op")
+        except AttributeError:
+            logger.error(f"Plugin is missing symbol get_rknn_custom_op: {so_path}")
+            return False
+
+        # Use void* as the return type so Python never has to parse the
+        # third-party struct layout
+        try:
+            get_sym.argtypes = []
+        except Exception:
+            pass
+        get_sym.restype = ctypes.c_void_p
+
+        op_void_ptr = get_sym()
+        if not op_void_ptr:
+            logger.error(f"get_rknn_custom_op returned a NULL pointer: {so_path}")
+            return False
+
+        # Register with the native pointer directly (zero copy)
+        ctx_val = ContextCType(runtime.context)
+        ret = lib.rknn_register_custom_ops(ctx_val, ctypes.c_void_p(op_void_ptr), 1)
+        if ret != 0:
+            logger.error(f"rknn_register_custom_ops returned ret={ret}, so={so_path} (possibly spurious, continuing...)")
+            # return False
+
+        # Keep the handle alive so the library is not unloaded by GC
+        if not hasattr(self, "_plugin_handles"):
+            self._plugin_handles = []
+        self._plugin_handles.append(handle)
+        logger.info(f"Registered plugin custom op: {so_path}")
+        return True
+
+    def register_plugin_ops(self, plugin_paths: List[str]) -> int:
+        """Register custom ops from the given plugin library paths. Returns the success count."""
+        if not plugin_paths:
+            return 0
+        success = 0
+        for path in plugin_paths:
+            try:
+                if self._load_and_register_plugin_op(path):
+                    success += 1
+            except Exception as e:
+                logger.error(f"Failed to register plugin: {path}, err={e}")
+        return success
+
+    # Public API: register a single custom-op plugin library
+    def register_custom_op_lib(self, path: str) -> bool:
+        return self._load_and_register_plugin_op(path)
+
+    # Public API: scan the Linux system directory and register every plugin
+    # library found there (Android is not handled)
+    def register_system_custom_op_lib(self) -> int:
+        if os.name != 'posix':
+            return 0
+        # Linux only: the official RKNN default directory
+        system_dir = "/usr/lib/rknpu/op_plugins/"
+        if not os.path.isdir(system_dir):
+            return 0
+        try:
+            entries = os.listdir(system_dir)
+        except Exception:
+            return 0
+        so_list = []
+        for name in entries:
+            # The official convention requires filenames to start with librkcst_
+            if name.startswith("librkcst_") and name.endswith('.so'):
+                so_list.append(os.path.join(system_dir, name))
+        return self.register_plugin_ops(so_list)
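The directory-scan filter above can be tested in isolation. A minimal standalone sketch (the function name `find_plugin_libs` is illustrative, not part of this module) of the same prefix/suffix filtering against a throwaway directory:

```python
import os
import tempfile

def find_plugin_libs(system_dir):
    """Collect RKNN custom-op plugin libraries (librkcst_*.so) from a directory."""
    if not os.path.isdir(system_dir):
        return []
    return sorted(
        os.path.join(system_dir, name)
        for name in os.listdir(system_dir)
        if name.startswith("librkcst_") and name.endswith(".so")
    )

# Exercise the filter: only librkcst_*.so files should survive.
with tempfile.TemporaryDirectory() as d:
    for name in ("librkcst_gridsample.so", "libother.so", "librkcst_einsum.txt"):
        open(os.path.join(d, name), "w").close()
    libs = find_plugin_libs(d)
    print([os.path.basename(p) for p in libs])  # ['librkcst_gridsample.so']
```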