Upload 11 files
Browse files- .gitattributes +4 -0
- README.md +333 -3
- convert_vision_encoder.py +52 -0
- export_vision_onnx.py +203 -0
- language_model_w8a8.rkllm +3 -0
- librkllmrt.so +3 -0
- rkllm-convert.py +141 -0
- rkllm_binding.py +873 -0
- run_rkllm.py +243 -0
- test.jpg +3 -0
- vision_encoder.rknn +3 -0
- ztu_somemodelruntime_rknnlite2.py +1195 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
language_model_w8a8.rkllm filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
librkllmrt.so filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
test.jpg filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
vision_encoder.rknn filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,3 +1,333 @@
|
|
| 1 |
-
---
|
| 2 |
-
|
| 3 |
-
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
base_model:
|
| 3 |
+
- OpenGVLab/InternVL3_5-2B-HF
|
| 4 |
+
tags:
|
| 5 |
+
- rknn
|
| 6 |
+
- rkllm
|
| 7 |
+
- internvl
|
| 8 |
+
---
|
| 9 |
+
# InternVL3_5-2B-RKLLM
|
| 10 |
+
|
| 11 |
+
## (English README see below)
|
| 12 |
+
|
| 13 |
+
在RK3588上运行强大的InternVL3.5-2B视觉大模型!
|
| 14 |
+
|
| 15 |
+
- 推理速度(RK3588): 视觉编码器 2.1s(三核并行) + LLM 填充 1s (265 tokens / 261 tps) + 解码 12.1 tps
|
| 16 |
+
- 内存占用(RK3588, 上下文长度1024): 3.9GB
|
| 17 |
+
|
| 18 |
+
## 使用方法
|
| 19 |
+
|
| 20 |
+
1. 克隆或者下载此仓库到本地. 模型较大, 请确保有足够的磁盘空间.
|
| 21 |
+
|
| 22 |
+
2. 开发板的RKNPU2内核驱动版本必须>=0.9.6才能运行这么大的模型.
|
| 23 |
+
使用root权限运行以下命令检查驱动版本:
|
| 24 |
+
```bash
|
| 25 |
+
> cat /sys/kernel/debug/rknpu/version
|
| 26 |
+
RKNPU driver: v0.9.8
|
| 27 |
+
```
|
| 28 |
+
如果版本过低, 请更新驱动. 你可能需要更新内核, 或查找官方文档以获取帮助.
|
| 29 |
+
|
| 30 |
+
3. 安装依赖
|
| 31 |
+
|
| 32 |
+
```bash
|
| 33 |
+
pip install "numpy<2" opencv-python rknn-toolkit-lite2
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
4. 运行
|
| 37 |
+
|
| 38 |
+
```bash
|
| 39 |
+
python ./run_rkllm.py ./test.jpg ./vision_encoder.rknn ./language_model_w8a8.rkllm 512 1024 3
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
参数说明:
|
| 43 |
+
- `512`: max_new_tokens, 最大生成token数.
|
| 44 |
+
- `1024`: max_context_len, 最大上下文长度.
|
| 45 |
+
- `3`: npu_core_num, 使用的NPU核心数.
|
| 46 |
+
|
| 47 |
+
如果实测性能不理想, 可以调整CPU调度器让CPU始终运行在最高频率, 并把推理程序绑定到大核(`taskset -c 4-7 python ...`)
|
| 48 |
+
|
| 49 |
+
test.jpg:
|
| 50 |
+

|
| 51 |
+
|
| 52 |
+
```
|
| 53 |
+
Initializing ONNX Runtime for vision encoder...
|
| 54 |
+
I rknn-toolkit2 version: 2.3.2
|
| 55 |
+
I target set by user is: rk3588
|
| 56 |
+
Vision encoder loaded successfully.
|
| 57 |
+
ONNX Input: pixel_values, ONNX Output: projected_features
|
| 58 |
+
Initializing RKLLM Runtime...
|
| 59 |
+
I rkllm: rkllm-runtime version: 1.2.2, rknpu driver version: 0.9.8, platform: RK3588
|
| 60 |
+
I rkllm: loading rkllm model from ./language_model_w8a8.rkllm
|
| 61 |
+
I rkllm: rkllm-toolkit version: 1.2.2, max_context_limit: 4096, npu_core_num: 3, target_platform: RK3588, model_dtype: W8A8
|
| 62 |
+
I rkllm: Enabled cpus: [4, 5, 6, 7]
|
| 63 |
+
I rkllm: Enabled cpus num: 4
|
| 64 |
+
RKLLM initialized successfully.
|
| 65 |
+
Preprocessing image...
|
| 66 |
+
Running vision encoder...
|
| 67 |
+
视觉编码器推理耗时: 2.0876 秒
|
| 68 |
+
Image encoded successfully.
|
| 69 |
+
|
| 70 |
+
**********************可输入以下问题对应序号获取回答/或自定义输入********************
|
| 71 |
+
|
| 72 |
+
[0] <image>What is in the image?
|
| 73 |
+
[1] <image>这张图片中有什么?
|
| 74 |
+
|
| 75 |
+
*************************************************************************
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
user: 0
|
| 79 |
+
<image>What is in the image?
|
| 80 |
+
robot: n_image_tokens: 256
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
This image depicts a cozy bedroom with a large window, several pieces of furniture, and various decorative items. The room has a vintage feel due to the wallpaper pattern and the wooden furniture.
|
| 84 |
+
|
| 85 |
+
The bed occupies the left side of the image, covered with a blue comforter or quilt. Next to the bed is a dresser with a round mirror above it. On top of the dresser are several small objects, including what appears to be a water bottle and some decorative items like plants.
|
| 86 |
+
|
| 87 |
+
In front of the window on the right side of the image, there is a chair with a checkered cushion. Behind this chair, there is a bookshelf filled with books and various other items, such as baskets and possibly some knick-knacks. The bookshelf has multiple levels, each holding an assortment of books and decorative objects.
|
| 88 |
+
|
| 89 |
+
The window allows natural light to enter the room, illuminating the space and highlighting the greenery outside. There are also potted plants placed around the room, adding a touch of nature and freshness to the interior decor.
|
| 90 |
+
|
| 91 |
+
Overall, this bedroom exudes a sense of comfort and personal style, with elements that suggest it is used regularly by someone who values both aesthetics and functionality in their living space.
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
I rkllm: --------------------------------------------------------------------------------------
|
| 95 |
+
I rkllm: Model init time (ms) 4314.30
|
| 96 |
+
I rkllm: --------------------------------------------------------------------------------------
|
| 97 |
+
I rkllm: Stage Total Time (ms) Tokens Time per Token (ms) Tokens per Second
|
| 98 |
+
I rkllm: --------------------------------------------------------------------------------------
|
| 99 |
+
I rkllm: Prefill 1013.32 265 3.82 261.52
|
| 100 |
+
I rkllm: Generate 20155.65 244 82.61 12.11
|
| 101 |
+
I rkllm: --------------------------------------------------------------------------------------
|
| 102 |
+
I rkllm: Peak Memory Usage (GB)
|
| 103 |
+
I rkllm: 3.45
|
| 104 |
+
I rkllm: --------------------------------------------------------------------------------------
|
| 105 |
+
|
| 106 |
+
user: 1
|
| 107 |
+
<image>这张图片中有什么?
|
| 108 |
+
robot: n_image_tokens: 256
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
这是一间温馨的卧室,房间内有一扇大窗户、几件家具和各种装饰物品。房间因壁纸图案和木质家具而显得复古。
|
| 112 |
+
|
| 113 |
+
床位于图像左侧,覆盖着蓝色被套或毯子。床旁边是一个带有圆形镜子的抽屉柜。在抽屉柜上摆放着一些小物件,���括水瓶和一些装饰品,如植物。
|
| 114 |
+
|
| 115 |
+
窗户右侧前方有一把带格子坐垫的椅子。椅子后面是一排书架,上面摆满了书籍和其他物品,如篮子和可能的一些小饰品。书架有多层,每层都放着各种书籍和装饰物。
|
| 116 |
+
|
| 117 |
+
窗外可以看到绿树,自然光透过窗户照进房间,照亮了空间,并突出了外面的绿色植物。房间里还摆放了一些盆栽植物,为室内增添了自然的气息和清新感。
|
| 118 |
+
|
| 119 |
+
总体而言,这间卧室给人一种舒适和个性的感觉,表明它经常被居住者使用,居住者重视生活空间中的美学和功能性。
|
| 120 |
+
|
| 121 |
+
I rkllm: --------------------------------------------------------------------------------------
|
| 122 |
+
I rkllm: Stage Total Time (ms) Tokens Time per Token (ms) Tokens per Second
|
| 123 |
+
I rkllm: --------------------------------------------------------------------------------------
|
| 124 |
+
I rkllm: Prefill 1287.65 264 4.88 205.03
|
| 125 |
+
I rkllm: Generate 19852.10 204 97.31 10.28
|
| 126 |
+
I rkllm: --------------------------------------------------------------------------------------
|
| 127 |
+
I rkllm: Peak Memory Usage (GB)
|
| 128 |
+
I rkllm: 3.45
|
| 129 |
+
I rkllm: --------------------------------------------------------------------------------------
|
| 130 |
+
|
| 131 |
+
user: ^C
|
| 132 |
+
Exiting...
|
| 133 |
+
Releasing resources...
|
| 134 |
+
RKLLM instance destroyed.
|
| 135 |
+
```
|
| 136 |
+
|
| 137 |
+
## 模型转换
|
| 138 |
+
|
| 139 |
+
#### 准备工作
|
| 140 |
+
|
| 141 |
+
1. 安装rknn-toolkit2以及rkllm-toolkit:
|
| 142 |
+
```bash
|
| 143 |
+
pip install -U rknn-toolkit2
|
| 144 |
+
```
|
| 145 |
+
rkllm-toolkit需要在这里手动下载: https://github.com/airockchip/rknn-llm/tree/main/rkllm-toolkit
|
| 146 |
+
|
| 147 |
+
2. 下载此仓库到本地, 但不需要下载`.rkllm`和`.rknn`结尾的模型文件.
|
| 148 |
+
3. 下载InternVL3.5-2B的huggingface模型仓库到本地. ( https://huggingface.co/OpenGVLab/InternVL3_5-2B-HF )
|
| 149 |
+
|
| 150 |
+
#### 转换LLM
|
| 151 |
+
|
| 152 |
+
将`rkllm-convert.py`拷贝到InternVL3_5-2B-HF的模型文件夹中,执行:
|
| 153 |
+
```bash
|
| 154 |
+
python rkllm-convert.py
|
| 155 |
+
```
|
| 156 |
+
默认是w8a8量化的,你可以自行打开脚本修改量化方式等。
|
| 157 |
+
|
| 158 |
+
#### 转换视觉编码器
|
| 159 |
+
|
| 160 |
+
1. 导出ONNX
|
| 161 |
+
|
| 162 |
+
将`export_vision_onnx.py`拷贝到InternVL3_5-2B-HF的模型文件夹根目录中,然后**在该根目录**下执行:
|
| 163 |
+
```bash
|
| 164 |
+
python ./export_vision_onnx.py
|
| 165 |
+
```
|
| 166 |
+
视觉编码器会导出到`vision_encoder.onnx`.
|
| 167 |
+
|
| 168 |
+
2. 转换rknn
|
| 169 |
+
|
| 170 |
+
```bash
|
| 171 |
+
python ./convert_vision_encoder.py
|
| 172 |
+
```
|
| 173 |
+
|
| 174 |
+
## 已知问题
|
| 175 |
+
|
| 176 |
+
- 由于RKLLM的多模态输入的限制, 在整个对话中只能加载一张图片.
|
| 177 |
+
- 没有实现多轮对话.
|
| 178 |
+
- RKLLM的w8a8量化貌似存在不小的精度损失.
|
| 179 |
+
- 没有实现原模型中的高清图像分块输入与视频输入功能. 原因是我懒得做了,以后可以考虑加上.
|
| 180 |
+
|
| 181 |
+
## 参考
|
| 182 |
+
|
| 183 |
+
- [OpenGVLab/InternVL3_5-2B-HF](https://huggingface.co/OpenGVLab/InternVL3_5-2B-HF)
|
| 184 |
+
|
| 185 |
+
----
|
| 186 |
+
|
| 187 |
+
# English README
|
| 188 |
+
|
| 189 |
+
Run the powerful InternVL3.5-2B large vision model on RK3588!
|
| 190 |
+
|
| 191 |
+
- Inference Speed (RK3588): Vision Encoder 2.1s (3-core parallel) + LLM Prefill 1s (265 tokens / 261 tps) + Decode 12.1 tps
|
| 192 |
+
- Memory Usage (RK3588, context length 1024): 3.9GB
|
| 193 |
+
|
| 194 |
+
## How to Use
|
| 195 |
+
|
| 196 |
+
1. Clone or download this repository locally. The model is large, so ensure you have enough disk space.
|
| 197 |
+
|
| 198 |
+
2. The RKNPU2 kernel driver version on your development board must be >=0.9.6 to run this model. Run the following command with root privileges to check the driver version:
|
| 199 |
+
```bash
|
| 200 |
+
> cat /sys/kernel/debug/rknpu/version
|
| 201 |
+
RKNPU driver: v0.9.8
|
| 202 |
+
```
|
| 203 |
+
If the version is too low, please update the driver. You may need to update the kernel or refer to the official documentation for help.
|
| 204 |
+
|
| 205 |
+
3. Install dependencies:
|
| 206 |
+
|
| 207 |
+
```bash
|
| 208 |
+
pip install "numpy<2" opencv-python rknn-toolkit-lite2
|
| 209 |
+
```
|
| 210 |
+
|
| 211 |
+
4. Run:
|
| 212 |
+
|
| 213 |
+
```bash
|
| 214 |
+
python ./run_rkllm.py ./test.jpg ./vision_encoder.rknn ./language_model_w8a8.rkllm 512 1024 3
|
| 215 |
+
```
|
| 216 |
+
|
| 217 |
+
Parameter description:
|
| 218 |
+
- `512`: `max_new_tokens`, the maximum number of tokens to generate.
|
| 219 |
+
- `1024`: `max_context_len`, the maximum context length.
|
| 220 |
+
- `3`: `npu_core_num`, the number of NPU cores to use.
|
| 221 |
+
|
| 222 |
+
If the performance is not ideal, you can adjust the CPU scheduler to keep the CPU at its highest frequency and bind the inference program to the big cores (`taskset -c 4-7 python ...`).
|
| 223 |
+
|
| 224 |
+
Example with `test.jpg`:
|
| 225 |
+

|
| 226 |
+
|
| 227 |
+
```
|
| 228 |
+
Initializing ONNX Runtime for vision encoder...
|
| 229 |
+
I rknn-toolkit2 version: 2.3.2
|
| 230 |
+
I target set by user is: rk3588
|
| 231 |
+
Vision encoder loaded successfully.
|
| 232 |
+
ONNX Input: pixel_values, ONNX Output: projected_features
|
| 233 |
+
Initializing RKLLM Runtime...
|
| 234 |
+
I rkllm: rkllm-runtime version: 1.2.2, rknpu driver version: 0.9.8, platform: RK3588
|
| 235 |
+
I rkllm: loading rkllm model from ./language_model_w8a8.rkllm
|
| 236 |
+
I rkllm: rkllm-toolkit version: 1.2.2, max_context_limit: 4096, npu_core_num: 3, target_platform: RK3588, model_dtype: W8A8
|
| 237 |
+
I rkllm: Enabled cpus: [4, 5, 6, 7]
|
| 238 |
+
I rkllm: Enabled cpus num: 4
|
| 239 |
+
RKLLM initialized successfully.
|
| 240 |
+
Preprocessing image...
|
| 241 |
+
Running vision encoder...
|
| 242 |
+
视觉编码器推理耗时: 2.0876 秒
|
| 243 |
+
Image encoded successfully.
|
| 244 |
+
|
| 245 |
+
**********************可输入以下问题对应序号获取回答/或自定义输入********************
|
| 246 |
+
|
| 247 |
+
[0] <image>What is in the image?
|
| 248 |
+
[1] <image>这张图片中有什么?
|
| 249 |
+
|
| 250 |
+
*************************************************************************
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
user: 0
|
| 254 |
+
<image>What is in the image?
|
| 255 |
+
robot: n_image_tokens: 256
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
This image depicts a cozy bedroom with a large window, several pieces of furniture, and various decorative items. The room has a vintage feel due to the wallpaper pattern and the wooden furniture.
|
| 259 |
+
|
| 260 |
+
The bed occupies the left side of the image, covered with a blue comforter or quilt. Next to the bed is a dresser with a round mirror above it. On top of the dresser are several small objects, including what appears to be a water bottle and some decorative items like plants.
|
| 261 |
+
|
| 262 |
+
In front of the window on the right side of the image, there is a chair with a checkered cushion. Behind this chair, there is a bookshelf filled with books and various other items, such as baskets and possibly some knick-knacks. The bookshelf has multiple levels, each holding an assortment of books and decorative objects.
|
| 263 |
+
|
| 264 |
+
The window allows natural light to enter the room, illuminating the space and highlighting the greenery outside. There are also potted plants placed around the room, adding a touch of nature and freshness to the interior decor.
|
| 265 |
+
|
| 266 |
+
Overall, this bedroom exudes a sense of comfort and personal style, with elements that suggest it is used regularly by someone who values both aesthetics and functionality in their living space.
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
I rkllm: --------------------------------------------------------------------------------------
|
| 270 |
+
I rkllm: Model init time (ms) 4314.30
|
| 271 |
+
I rkllm: --------------------------------------------------------------------------------------
|
| 272 |
+
I rkllm: Stage Total Time (ms) Tokens Time per Token (ms) Tokens per Second
|
| 273 |
+
I rkllm: --------------------------------------------------------------------------------------
|
| 274 |
+
I rkllm: Prefill 1013.32 265 3.82 261.52
|
| 275 |
+
I rkllm: Generate 20155.65 244 82.61 12.11
|
| 276 |
+
I rkllm: --------------------------------------------------------------------------------------
|
| 277 |
+
I rkllm: Peak Memory Usage (GB)
|
| 278 |
+
I rkllm: 3.45
|
| 279 |
+
I rkllm: --------------------------------------------------------------------------------------
|
| 280 |
+
|
| 281 |
+
user: ^C
|
| 282 |
+
Exiting...
|
| 283 |
+
Releasing resources...
|
| 284 |
+
RKLLM instance destroyed.
|
| 285 |
+
```
|
| 286 |
+
|
| 287 |
+
## Model Conversion
|
| 288 |
+
|
| 289 |
+
#### Prerequisites
|
| 290 |
+
|
| 291 |
+
1. Install `rknn-toolkit2` and `rkllm-toolkit`:
|
| 292 |
+
```bash
|
| 293 |
+
pip install -U rknn-toolkit2
|
| 294 |
+
```
|
| 295 |
+
`rkllm-toolkit` needs to be downloaded manually from here: https://github.com/airockchip/rknn-llm/tree/main/rkllm-toolkit
|
| 296 |
+
|
| 297 |
+
2. Download this repository locally, but you don't need the `.rkllm` and `.rknn` model files.
|
| 298 |
+
3. Download the InternVL3.5-2B huggingface model repository locally. ( https://huggingface.co/OpenGVLab/InternVL3_5-2B-HF )
|
| 299 |
+
|
| 300 |
+
#### Convert LLM
|
| 301 |
+
|
| 302 |
+
Copy `rkllm-convert.py` to the InternVL3_5-2B-HF model folder and run:
|
| 303 |
+
```bash
|
| 304 |
+
python rkllm-convert.py
|
| 305 |
+
```
|
| 306 |
+
The default quantization is w8a8. You can modify the script to change quantization methods.
|
| 307 |
+
|
| 308 |
+
#### Convert Vision Encoder
|
| 309 |
+
|
| 310 |
+
1. Export ONNX
|
| 311 |
+
|
| 312 |
+
Copy `export_vision_onnx.py` to the root directory of the InternVL3_5-2B-HF model folder, and then execute it **in that root directory**:
|
| 313 |
+
```bash
|
| 314 |
+
python ./export_vision_onnx.py
|
| 315 |
+
```
|
| 316 |
+
The vision encoder will be exported to `vision_encoder.onnx`.
|
| 317 |
+
|
| 318 |
+
2. Convert to RKNN
|
| 319 |
+
|
| 320 |
+
```bash
|
| 321 |
+
python ./convert_vision_encoder.py
|
| 322 |
+
```
|
| 323 |
+
|
| 324 |
+
## Known Issues
|
| 325 |
+
|
| 326 |
+
- Due to limitations in RKLLM's multimodal input, only one image can be loaded throughout the conversation.
|
| 327 |
+
- Multi-turn conversation is not implemented.
|
| 328 |
+
- RKLLM's w8a8 quantization appears to have significant precision loss.
|
| 329 |
+
- The high-resolution image tiling and video input features from the original model are not implemented. The reason is that I'm too lazy to do it, and it can be considered adding it later.
|
| 330 |
+
|
| 331 |
+
## References
|
| 332 |
+
|
| 333 |
+
- [OpenGVLab/InternVL3_5-2B-HF](https://huggingface.co/OpenGVLab/InternVL3_5-2B-HF)
|
convert_vision_encoder.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# ztu_somemodelruntime_rknn2: vision_encoder
|
| 3 |
+
|
| 4 |
+
from rknn.api import RKNN
|
| 5 |
+
import os
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
def main():
|
| 9 |
+
# 创建RKNN实例
|
| 10 |
+
rknn = RKNN(verbose=True)
|
| 11 |
+
|
| 12 |
+
# ONNX模型路径
|
| 13 |
+
ONNX_MODEL = "vision_encoder.onnx"
|
| 14 |
+
# 输出RKNN模型路径
|
| 15 |
+
RKNN_MODEL = "vision_encoder.rknn"
|
| 16 |
+
|
| 17 |
+
# 配置参数
|
| 18 |
+
print("--> Config model")
|
| 19 |
+
ret = rknn.config(target_platform="rk3588",
|
| 20 |
+
dynamic_input=None)
|
| 21 |
+
if ret != 0:
|
| 22 |
+
print('Config model failed!')
|
| 23 |
+
exit(ret)
|
| 24 |
+
|
| 25 |
+
# 加载ONNX模型
|
| 26 |
+
print("--> Loading model")
|
| 27 |
+
ret = rknn.load_onnx(model=ONNX_MODEL,
|
| 28 |
+
inputs=['pixel_values'],
|
| 29 |
+
input_size_list=[[1, 3, 448, 448]])
|
| 30 |
+
if ret != 0:
|
| 31 |
+
print('Load model failed!')
|
| 32 |
+
exit(ret)
|
| 33 |
+
|
| 34 |
+
# 构建模型
|
| 35 |
+
print("--> Building model")
|
| 36 |
+
ret = rknn.build(do_quantization=False)
|
| 37 |
+
if ret != 0:
|
| 38 |
+
print('Build model failed!')
|
| 39 |
+
exit(ret)
|
| 40 |
+
|
| 41 |
+
# 导出RKNN模型
|
| 42 |
+
print("--> Export RKNN model")
|
| 43 |
+
ret = rknn.export_rknn(RKNN_MODEL)
|
| 44 |
+
if ret != 0:
|
| 45 |
+
print('Export RKNN model failed!')
|
| 46 |
+
exit(ret)
|
| 47 |
+
|
| 48 |
+
print(f'Done! The converted RKNN model has been saved to: ' + RKNN_MODEL)
|
| 49 |
+
rknn.release()
|
| 50 |
+
|
| 51 |
+
if __name__ == '__main__':
|
| 52 |
+
main()
|
export_vision_onnx.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from transformers import AutoTokenizer, AutoModel
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import torchvision.transforms as T
|
| 9 |
+
from torchvision.transforms import InterpolationMode
|
| 10 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 11 |
+
|
| 12 |
+
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
| 13 |
+
IMAGENET_STD = (0.229, 0.224, 0.225)
|
| 14 |
+
|
| 15 |
+
def build_transform(input_size):
|
| 16 |
+
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
|
| 17 |
+
transform = T.Compose([
|
| 18 |
+
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
|
| 19 |
+
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
|
| 20 |
+
T.ToTensor(),
|
| 21 |
+
T.Normalize(mean=MEAN, std=STD)
|
| 22 |
+
])
|
| 23 |
+
return transform
|
| 24 |
+
|
| 25 |
+
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
|
| 26 |
+
best_ratio_diff = float('inf')
|
| 27 |
+
best_ratio = (1, 1)
|
| 28 |
+
area = width * height
|
| 29 |
+
for ratio in target_ratios:
|
| 30 |
+
target_aspect_ratio = ratio[0] / ratio[1]
|
| 31 |
+
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
|
| 32 |
+
if ratio_diff < best_ratio_diff:
|
| 33 |
+
best_ratio_diff = ratio_diff
|
| 34 |
+
best_ratio = ratio
|
| 35 |
+
elif ratio_diff == best_ratio_diff:
|
| 36 |
+
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
|
| 37 |
+
best_ratio = ratio
|
| 38 |
+
return best_ratio
|
| 39 |
+
|
| 40 |
+
def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
|
| 41 |
+
orig_width, orig_height = image.size
|
| 42 |
+
aspect_ratio = orig_width / orig_height
|
| 43 |
+
|
| 44 |
+
# calculate the existing image aspect ratio
|
| 45 |
+
target_ratios = set(
|
| 46 |
+
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
|
| 47 |
+
i * j <= max_num and i * j >= min_num)
|
| 48 |
+
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
| 49 |
+
|
| 50 |
+
# find the closest aspect ratio to the target
|
| 51 |
+
target_aspect_ratio = find_closest_aspect_ratio(
|
| 52 |
+
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
|
| 53 |
+
|
| 54 |
+
# calculate the target width and height
|
| 55 |
+
target_width = image_size * target_aspect_ratio[0]
|
| 56 |
+
target_height = image_size * target_aspect_ratio[1]
|
| 57 |
+
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
| 58 |
+
|
| 59 |
+
# resize the image
|
| 60 |
+
resized_img = image.resize((target_width, target_height))
|
| 61 |
+
processed_images = []
|
| 62 |
+
for i in range(blocks):
|
| 63 |
+
box = (
|
| 64 |
+
(i % (target_width // image_size)) * image_size,
|
| 65 |
+
(i // (target_width // image_size)) * image_size,
|
| 66 |
+
((i % (target_width // image_size)) + 1) * image_size,
|
| 67 |
+
((i // (target_width // image_size)) + 1) * image_size
|
| 68 |
+
)
|
| 69 |
+
# split the image
|
| 70 |
+
split_img = resized_img.crop(box)
|
| 71 |
+
processed_images.append(split_img)
|
| 72 |
+
assert len(processed_images) == blocks
|
| 73 |
+
if use_thumbnail and len(processed_images) != 1:
|
| 74 |
+
thumbnail_img = image.resize((image_size, image_size))
|
| 75 |
+
processed_images.append(thumbnail_img)
|
| 76 |
+
return processed_images
|
| 77 |
+
|
| 78 |
+
def load_image(image_file, input_size=448, max_num=12):
|
| 79 |
+
image = Image.open(image_file).convert('RGB')
|
| 80 |
+
transform = build_transform(input_size=input_size)
|
| 81 |
+
images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
|
| 82 |
+
pixel_values = [transform(image) for image in images]
|
| 83 |
+
pixel_values = torch.stack(pixel_values)
|
| 84 |
+
return pixel_values
|
| 85 |
+
|
| 86 |
+
# 加载本地模型
|
| 87 |
+
path = '.'
|
| 88 |
+
save_path = 'vision_encoder.onnx'
|
| 89 |
+
image_file = 'test.jpg'
|
| 90 |
+
|
| 91 |
+
def export_vision_InternVL(model_path: str, save_path: str):
|
| 92 |
+
"""
|
| 93 |
+
Export the vision encoder and projector of Janus-Pro-1B model to ONNX format
|
| 94 |
+
"""
|
| 95 |
+
# 设置默认数据类型为 float32
|
| 96 |
+
torch.set_default_dtype(torch.float32)
|
| 97 |
+
|
| 98 |
+
vl_gpt = AutoModel.from_pretrained(model_path,torch_dtype = torch.float32,trust_remote_code=True)
|
| 99 |
+
|
| 100 |
+
# Move model to CPU and convert to float32
|
| 101 |
+
vl_gpt = vl_gpt.cpu().eval().float() # 确保模型是 float32
|
| 102 |
+
|
| 103 |
+
# Create a wrapper class for vision encoder + projector
|
| 104 |
+
class VisionWrapper(nn.Module):
|
| 105 |
+
def __init__(self, model: PreTrainedModel):
|
| 106 |
+
super().__init__()
|
| 107 |
+
self.vision_model = model
|
| 108 |
+
|
| 109 |
+
def forward(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
|
| 110 |
+
# Delegate to the built-in helper so we stay consistent with Transformers' implementation.
|
| 111 |
+
return self.vision_model.get_image_features(pixel_values=pixel_values)
|
| 112 |
+
|
| 113 |
+
# Create wrapper instance and convert to float32
|
| 114 |
+
vision_wrapper = VisionWrapper(vl_gpt)
|
| 115 |
+
vision_wrapper.eval().float() # 确保包装器也是 float32
|
| 116 |
+
|
| 117 |
+
# Create dummy input with float32
|
| 118 |
+
batch_size = 1
|
| 119 |
+
num_channels = 3
|
| 120 |
+
height = 448 # InternVL2 default image size
|
| 121 |
+
width = 448
|
| 122 |
+
# dummy_input = load_image(image_file=image_file, max_num=12).to(torch.float32).cpu()
|
| 123 |
+
dummy_input = torch.randn(batch_size, num_channels, height, width, dtype=torch.float32)
|
| 124 |
+
# Export to ONNX with higher opset version
|
| 125 |
+
torch.onnx.export(
|
| 126 |
+
vision_wrapper,
|
| 127 |
+
dummy_input,
|
| 128 |
+
save_path,
|
| 129 |
+
export_params=True,
|
| 130 |
+
opset_version=17, # 使用高版本 opset 以支持 scaled_dot_product_attention
|
| 131 |
+
do_constant_folding=True,
|
| 132 |
+
input_names=['pixel_values'],
|
| 133 |
+
output_names=['projected_features'],
|
| 134 |
+
dynamic_axes={
|
| 135 |
+
'pixel_values': {0: 'batch_size'},
|
| 136 |
+
'projected_features': {0: 'batch_size'}
|
| 137 |
+
},
|
| 138 |
+
# 添加额外的配置
|
| 139 |
+
# operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
|
| 140 |
+
# training=torch.onnx.TrainingMode.EVAL,
|
| 141 |
+
dynamo=True,
|
| 142 |
+
verbose=False
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
print(f"Successfully exported vision components to {save_path}")
|
| 146 |
+
|
| 147 |
+
# Verify the exported model
|
| 148 |
+
import onnxruntime
|
| 149 |
+
|
| 150 |
+
# Create inference session
|
| 151 |
+
ort_session = onnxruntime.InferenceSession(save_path)
|
| 152 |
+
|
| 153 |
+
# Run inference with dummy input
|
| 154 |
+
ort_inputs = {
|
| 155 |
+
'pixel_values': dummy_input.numpy()
|
| 156 |
+
}
|
| 157 |
+
ort_outputs = ort_session.run(None, ort_inputs)
|
| 158 |
+
|
| 159 |
+
# Compare with PyTorch output
|
| 160 |
+
torch_output = vision_wrapper(dummy_input)
|
| 161 |
+
|
| 162 |
+
# Check numerical accuracy with更宽松的容忍度
|
| 163 |
+
import numpy as np
|
| 164 |
+
np.testing.assert_allclose(
|
| 165 |
+
torch_output.detach().numpy(),
|
| 166 |
+
ort_outputs[0],
|
| 167 |
+
rtol=1e-1, # 放宽相对误差容忍度
|
| 168 |
+
atol=1e-2 # 放宽绝对误差容忍度
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
print("ONNX model verification successful!")
|
| 172 |
+
|
| 173 |
+
# 打印一些统计信息
|
| 174 |
+
torch_output_np = torch_output.detach().numpy()
|
| 175 |
+
onnx_output_np = ort_outputs[0]
|
| 176 |
+
|
| 177 |
+
abs_diff = np.abs(torch_output_np - onnx_output_np)
|
| 178 |
+
rel_diff = np.abs((torch_output_np - onnx_output_np) / (torch_output_np + 1e-7))
|
| 179 |
+
|
| 180 |
+
print(f"\nValidation Statistics:")
|
| 181 |
+
print(f"Max absolute difference: {np.max(abs_diff):.6f}")
|
| 182 |
+
print(f"Mean absolute difference: {np.mean(abs_diff):.6f}")
|
| 183 |
+
print(f"Max relative difference: {np.max(rel_diff):.6f}")
|
| 184 |
+
print(f"Mean relative difference: {np.mean(rel_diff):.6f}")
|
| 185 |
+
|
| 186 |
+
if __name__ == "__main__":
|
| 187 |
+
try:
|
| 188 |
+
import onnx
|
| 189 |
+
try:
|
| 190 |
+
onnx_version = onnx.__version__
|
| 191 |
+
except AttributeError:
|
| 192 |
+
try:
|
| 193 |
+
onnx_version = onnx.version.version
|
| 194 |
+
except AttributeError:
|
| 195 |
+
onnx_version = "Unknown"
|
| 196 |
+
print(f"ONNX version: {onnx_version}")
|
| 197 |
+
except ImportError:
|
| 198 |
+
print("ONNX not installed")
|
| 199 |
+
|
| 200 |
+
import onnxruntime
|
| 201 |
+
print(f"ONNX Runtime version: {onnxruntime.__version__}")
|
| 202 |
+
|
| 203 |
+
export_vision_InternVL(path, save_path)
|
language_model_w8a8.rkllm
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8ea6dbafda740b717233228a91cbb2377d0905a8b00edba9e17489da65e9834e
|
| 3 |
+
size 2375017292
|
librkllmrt.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:39d01912e67027de32c527be04684bf813e2a49c2d09ab8f6bcf47b34a43789d
|
| 3 |
+
size 7486400
|
rkllm-convert.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
from transformers import AutoConfig, Qwen3ForCausalLM, AutoTokenizer
|
| 5 |
+
|
| 6 |
+
from rkllm.api import RKLLM
|
| 7 |
+
|
| 8 |
+
import argparse
|
| 9 |
+
import shutil
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Dict
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from safetensors.torch import load_file
|
| 15 |
+
from transformers import AutoConfig, AutoModelForCausalLM
|
| 16 |
+
|
| 17 |
+
# Tokenizer artifacts that should travel with the extracted language model.
# Only the files actually present in the source checkpoint get copied.
TOKENIZER_FILES = [
    "tokenizer.json",
    "tokenizer_config.json",
    "special_tokens_map.json",
    "added_tokens.json",
    "vocab.json",
    "merges.txt",
    "chat_template.jinja",
]
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def parse_args(argv=None) -> argparse.Namespace:
    """Parse command-line options for the InternVL -> Qwen3/RKLLM conversion.

    Args:
        argv: Optional explicit argument list (defaults to ``sys.argv[1:]``).
            Added so the parser can be driven programmatically and in tests;
            omitting it preserves the original CLI behavior.

    Returns:
        The parsed ``argparse.Namespace`` with ``source``, ``output`` and
        ``safe_serialization`` attributes.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--source",
        type=Path,
        default=".",
        help="Path to the InternVL (HF-format) checkpoint directory, e.g. /path/to/InternVL3_5-2B-HF",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default="llm/",
        help="Directory where the extracted Qwen3 checkpoint will be written",
    )
    # BUGFIX: the original used action="store_true" together with default=True,
    # making the flag a no-op (safetensors could never be disabled).
    # BooleanOptionalAction keeps --safe-serialization working unchanged and
    # adds --no-safe-serialization to turn it off.
    parser.add_argument(
        "--safe-serialization",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Save the exported model using safetensors instead of PyTorch binaries.",
    )
    return parser.parse_args(argv)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def extract_text_state_dict(full_state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Pull the language-model weights out of a full InternVL state dict.

    Keys under ``language_model.model.`` / ``language_model.lm_head.`` are
    remapped to ``model.`` / ``lm_head.``; every other key (vision tower,
    projector, ...) is dropped.

    Raises:
        ValueError: if no language_model weights are found at all.
    """
    rename_rules = (
        ("language_model.model.", "model."),
        ("language_model.lm_head.", "lm_head."),
    )
    text_state: Dict[str, torch.Tensor] = {}

    for key, tensor in full_state.items():
        for old_prefix, new_prefix in rename_rules:
            if key.startswith(old_prefix):
                text_state[new_prefix + key[len(old_prefix):]] = tensor
                break

    if not text_state:
        raise ValueError("Did not find any language_model weights in checkpoint; is this an InternVL model?")

    return text_state
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def copy_tokenizer_files(source_dir: Path, output_dir: Path) -> None:
    """Copy whichever known tokenizer files exist in *source_dir* into *output_dir*.

    Missing files are silently skipped, since checkpoints ship different
    subsets of the tokenizer artifacts.
    """
    for name in TOKENIZER_FILES:
        candidate = source_dir / name
        if candidate.exists():
            shutil.copyfile(candidate, output_dir / name)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def main() -> None:
    """Extract the Qwen3 text model from an InternVL checkpoint, then convert it to .rkllm."""
    args = parse_args()
    source_dir = args.source.expanduser().resolve()
    output_dir = args.output.expanduser().resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    _export_text_model(source_dir, output_dir, args.safe_serialization)
    _convert_to_rkllm(output_dir)


def _export_text_model(source_dir: Path, output_dir: Path, safe_serialization: bool) -> None:
    """Write the language-model weights/config/tokenizer of the InternVL checkpoint to *output_dir*.

    Raises:
        FileNotFoundError: if model.safetensors is missing from *source_dir*.
        RuntimeError: if the extracted weights do not exactly match the text config.
    """
    config = AutoConfig.from_pretrained(source_dir, trust_remote_code=True)
    text_config = config.text_config

    weights_path = source_dir / "model.safetensors"
    if not weights_path.exists():
        raise FileNotFoundError(f"Could not find {weights_path}; expected a safetensors checkpoint")

    all_weights = load_file(weights_path)
    text_state = extract_text_state_dict(all_weights)

    # Keep the exported model in the checkpoint's native dtype (e.g. bf16).
    target_dtype = next(iter(text_state.values())).dtype

    text_model = AutoModelForCausalLM.from_config(text_config)
    text_model = text_model.to(dtype=target_dtype, device=torch.device("cpu"))
    # strict=False so we can report BOTH missing and unexpected keys ourselves.
    missing, unexpected = text_model.load_state_dict(text_state, strict=False)
    if missing or unexpected:
        raise RuntimeError(
            "State dict mismatch when loading text weights: "
            f"missing={missing}, unexpected={unexpected}"
        )

    text_config.save_pretrained(output_dir)
    text_model.generation_config.save_pretrained(output_dir)
    text_model.save_pretrained(output_dir, safe_serialization=safe_serialization)

    copy_tokenizer_files(source_dir, output_dir)
    print(f"Exported Qwen3 model saved to {output_dir}")


def _convert_to_rkllm(model_dir: Path) -> None:
    """Quantize the exported HF model to w8a8 for rk3588 and write ./language_model_w8a8.rkllm.

    Exits the process with the RKLLM error code on failure
    (``raise SystemExit`` instead of the site-builtin ``exit``).
    """
    llm = RKLLM()

    ret = llm.load_huggingface(model=model_dir, model_lora=None, device='cpu')
    if ret != 0:
        print('Load model failed!')
        raise SystemExit(ret)

    ret = llm.build(do_quantization=True, optimization_level=1, quantized_dtype='w8a8',
                    quantized_algorithm='normal', target_platform='rk3588', num_npu_core=3, extra_qparams=None)
    if ret != 0:
        print('Build model failed!')
        raise SystemExit(ret)

    # Export rkllm model
    ret = llm.export_rkllm("./language_model_w8a8.rkllm")
    if ret != 0:
        print('Export model failed!')
        raise SystemExit(ret)


if __name__ == "__main__":
    main()
|
| 141 |
+
|
rkllm_binding.py
ADDED
|
@@ -0,0 +1,873 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ctypes
|
| 2 |
+
import enum
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
# Define constants from the header
|
| 6 |
+
# Per-core CPU bitmasks (CPU0 = 0x01 ... CPU7 = 0x80), used to build
# RKLLMExtendParam.enabled_cpus_mask.
CPU0, CPU1, CPU2, CPU3, CPU4, CPU5, CPU6, CPU7 = (1 << core for core in range(8))
|
| 14 |
+
|
| 15 |
+
# --- Enums ---
|
| 16 |
+
class LLMCallState(enum.IntEnum):
    """State passed to the result callback on each invocation."""
    RKLLM_RUN_NORMAL = 0
    RKLLM_RUN_WAITING = 1
    RKLLM_RUN_FINISH = 2
    RKLLM_RUN_ERROR = 3
|
| 21 |
+
|
| 22 |
+
class RKLLMInputType(enum.IntEnum):
    """Selects which member of the RKLLMInput union is active."""
    RKLLM_INPUT_PROMPT = 0
    RKLLM_INPUT_TOKEN = 1
    RKLLM_INPUT_EMBED = 2
    RKLLM_INPUT_MULTIMODAL = 3
|
| 27 |
+
|
| 28 |
+
class RKLLMInferMode(enum.IntEnum):
    """What an inference call should produce: generated text, the last hidden layer, or logits."""
    RKLLM_INFER_GENERATE = 0
    RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1
    RKLLM_INFER_GET_LOGITS = 2
|
| 32 |
+
|
| 33 |
+
# --- Structures ---
|
| 34 |
+
class RKLLMExtendParam(ctypes.Structure):
    """Extended runtime parameters (mirrors struct RKLLMExtendParam in rkllm.h)."""
    base_domain_id: ctypes.c_int32
    embed_flash: ctypes.c_int8
    enabled_cpus_num: ctypes.c_int8
    enabled_cpus_mask: ctypes.c_uint32
    n_batch: ctypes.c_uint8
    use_cross_attn: ctypes.c_int8
    reserved: ctypes.c_uint8 * 104

    _fields_ = [
        ("base_domain_id", ctypes.c_int32),      # base domain id
        ("embed_flash", ctypes.c_int8),          # query word embeddings from flash (1 enable, 0 disable)
        ("enabled_cpus_num", ctypes.c_int8),     # number of CPUs enabled for inference
        ("enabled_cpus_mask", ctypes.c_uint32),  # bitmask selecting which CPUs are enabled
        ("n_batch", ctypes.c_uint8),             # samples per forward pass; >1 enables batching, default 1
        ("use_cross_attn", ctypes.c_int8),       # enable cross attention (non-zero enable, 0 disable)
        ("reserved", ctypes.c_uint8 * 104),      # reserved
    ]
|
| 52 |
+
|
| 53 |
+
class RKLLMParam(ctypes.Structure):
    """Model/runtime configuration passed to rkllm_init (mirrors struct RKLLMParam)."""
    model_path: ctypes.c_char_p
    max_context_len: ctypes.c_int32
    max_new_tokens: ctypes.c_int32
    top_k: ctypes.c_int32
    n_keep: ctypes.c_int32
    top_p: ctypes.c_float
    temperature: ctypes.c_float
    repeat_penalty: ctypes.c_float
    frequency_penalty: ctypes.c_float
    presence_penalty: ctypes.c_float
    mirostat: ctypes.c_int32
    mirostat_tau: ctypes.c_float
    mirostat_eta: ctypes.c_float
    skip_special_token: ctypes.c_bool
    is_async: ctypes.c_bool
    img_start: ctypes.c_char_p
    img_end: ctypes.c_char_p
    img_content: ctypes.c_char_p
    extend_param: RKLLMExtendParam

    _fields_ = [
        ("model_path", ctypes.c_char_p),        # path to the model file
        ("max_context_len", ctypes.c_int32),    # maximum number of tokens in the context window
        ("max_new_tokens", ctypes.c_int32),     # maximum number of newly generated tokens
        ("top_k", ctypes.c_int32),              # top-k sampling parameter
        ("n_keep", ctypes.c_int32),             # kv-cache entries kept when the context window slides
        ("top_p", ctypes.c_float),              # top-p (nucleus) sampling parameter
        ("temperature", ctypes.c_float),        # sampling temperature (randomness of token choice)
        ("repeat_penalty", ctypes.c_float),     # penalty for repeated tokens
        ("frequency_penalty", ctypes.c_float),  # penalty for frequent tokens
        ("presence_penalty", ctypes.c_float),   # penalty for tokens already present in the input
        ("mirostat", ctypes.c_int32),           # mirostat sampling strategy flag (0 disables)
        ("mirostat_tau", ctypes.c_float),       # mirostat tau parameter
        ("mirostat_eta", ctypes.c_float),       # mirostat eta parameter
        ("skip_special_token", ctypes.c_bool),  # whether to skip special tokens
        ("is_async", ctypes.c_bool),            # whether inference runs asynchronously
        ("img_start", ctypes.c_char_p),         # marker for the start of an image in multimodal input
        ("img_end", ctypes.c_char_p),           # marker for the end of an image in multimodal input
        ("img_content", ctypes.c_char_p),       # pointer to the image content
        ("extend_param", RKLLMExtendParam),     # extended parameters
    ]
|
| 95 |
+
|
| 96 |
+
class RKLLMLoraAdapter(ctypes.Structure):
    """Description of a LoRA adapter to load: path, reference name and scale."""
    lora_adapter_path: ctypes.c_char_p
    lora_adapter_name: ctypes.c_char_p
    scale: ctypes.c_float

    _fields_ = [
        ("lora_adapter_path", ctypes.c_char_p),  # filesystem path of the adapter
        ("lora_adapter_name", ctypes.c_char_p),  # name used to select it at inference time
        ("scale", ctypes.c_float),               # adapter scale factor
    ]
|
| 106 |
+
|
| 107 |
+
class RKLLMEmbedInput(ctypes.Structure):
    """Embedding input: a flat float buffer covering n_tokens tokens."""
    embed: ctypes.POINTER(ctypes.c_float)
    n_tokens: ctypes.c_size_t

    _fields_ = [
        ("embed", ctypes.POINTER(ctypes.c_float)),  # pointer to the embedding floats
        ("n_tokens", ctypes.c_size_t),              # number of tokens represented
    ]
|
| 115 |
+
|
| 116 |
+
class RKLLMTokenInput(ctypes.Structure):
    """Token-id input: a buffer of n_tokens int32 token ids."""
    input_ids: ctypes.POINTER(ctypes.c_int32)
    n_tokens: ctypes.c_size_t

    _fields_ = [
        ("input_ids", ctypes.POINTER(ctypes.c_int32)),  # pointer to the token-id buffer
        ("n_tokens", ctypes.c_size_t),                  # number of ids in the buffer
    ]
|
| 124 |
+
|
| 125 |
+
class RKLLMMultiModelInput(ctypes.Structure):
    """Multimodal input: a text prompt plus pre-computed image embeddings."""
    prompt: ctypes.c_char_p
    image_embed: ctypes.POINTER(ctypes.c_float)
    n_image_tokens: ctypes.c_size_t
    n_image: ctypes.c_size_t
    image_width: ctypes.c_size_t
    image_height: ctypes.c_size_t

    _fields_ = [
        ("prompt", ctypes.c_char_p),                     # text prompt
        ("image_embed", ctypes.POINTER(ctypes.c_float)), # pointer to the image embedding floats
        ("n_image_tokens", ctypes.c_size_t),             # number of image tokens
        ("n_image", ctypes.c_size_t),                    # number of images
        ("image_width", ctypes.c_size_t),
        ("image_height", ctypes.c_size_t),
    ]
|
| 141 |
+
|
| 142 |
+
class RKLLMCrossAttnParam(ctypes.Structure):
    """Cross-attention parameters.

    Used when the decoder performs cross attention.  Supplies the encoder
    outputs (key/value caches), position indices and attention mask.

    - encoder_k_cache must be stored contiguously with layout
      [num_layers][num_tokens][num_kv_heads][head_dim]
    - encoder_v_cache must be stored contiguously with layout
      [num_layers][num_kv_heads][head_dim][num_tokens]
    """
    encoder_k_cache: ctypes.POINTER(ctypes.c_float)
    encoder_v_cache: ctypes.POINTER(ctypes.c_float)
    encoder_mask: ctypes.POINTER(ctypes.c_float)
    encoder_pos: ctypes.POINTER(ctypes.c_int32)
    num_tokens: ctypes.c_int

    _fields_ = [
        ("encoder_k_cache", ctypes.POINTER(ctypes.c_float)),  # encoder key cache (num_layers * num_tokens * num_kv_heads * head_dim)
        ("encoder_v_cache", ctypes.POINTER(ctypes.c_float)),  # encoder value cache (num_layers * num_kv_heads * head_dim * num_tokens)
        ("encoder_mask", ctypes.POINTER(ctypes.c_float)),     # encoder attention mask (array of num_tokens)
        ("encoder_pos", ctypes.POINTER(ctypes.c_int32)),      # encoder token positions (array of num_tokens)
        ("num_tokens", ctypes.c_int),                         # number of tokens in the encoder sequence
    ]
|
| 167 |
+
|
| 168 |
+
class RKLLMPerfStat(ctypes.Structure):
    """Performance statistics for the prefill and generation stages."""
    prefill_time_ms: ctypes.c_float
    prefill_tokens: ctypes.c_int
    generate_time_ms: ctypes.c_float
    generate_tokens: ctypes.c_int
    memory_usage_mb: ctypes.c_float

    _fields_ = [
        ("prefill_time_ms", ctypes.c_float),   # total prefill time (ms)
        ("prefill_tokens", ctypes.c_int),      # tokens processed during prefill
        ("generate_time_ms", ctypes.c_float),  # total generation time (ms)
        ("generate_tokens", ctypes.c_int),     # tokens processed during generation
        ("memory_usage_mb", ctypes.c_float),   # VmHWM resident memory during inference (MB)
    ]
|
| 187 |
+
|
| 188 |
+
class _RKLLMInputUnion(ctypes.Union):
    """Union payload of RKLLMInput; the active member is selected by input_type."""
    prompt_input: ctypes.c_char_p
    embed_input: RKLLMEmbedInput
    token_input: RKLLMTokenInput
    multimodal_input: RKLLMMultiModelInput

    _fields_ = [
        ("prompt_input", ctypes.c_char_p),
        ("embed_input", RKLLMEmbedInput),
        ("token_input", RKLLMTokenInput),
        ("multimodal_input", RKLLMMultiModelInput),
    ]
|
| 200 |
+
|
| 201 |
+
class RKLLMInput(ctypes.Structure):
    """LLM input structure.

    The payload lives in a C union; ``input_type`` selects which member is
    active.  The guarded properties below make sure the wrong union member
    cannot be read or written by accident.
    """
    role: ctypes.c_char_p
    enable_thinking: ctypes.c_bool
    input_type: ctypes.c_int
    _union_data: _RKLLMInputUnion

    _fields_ = [
        ("role", ctypes.c_char_p),           # message role: "user" (user input), "tool" (function result)
        ("enable_thinking", ctypes.c_bool),  # whether "thinking mode" is enabled for Qwen3 models
        ("input_type", ctypes.c_int),        # RKLLMInputType value selecting the active union member
        ("_union_data", _RKLLMInputUnion),   # union payload
    ]

    def _union_property(member, expected_type, error_message):
        # Build a guarded accessor pair for one union member.  This local
        # helper is removed from the class namespace after use (see `del`).
        def getter(self):
            if self.input_type == expected_type:
                return getattr(self._union_data, member)
            raise AttributeError(error_message)

        def setter(self, value):
            if self.input_type != expected_type:
                raise AttributeError(error_message)
            setattr(self._union_data, member, value)

        return property(getter, setter)

    prompt_input = _union_property("prompt_input", RKLLMInputType.RKLLM_INPUT_PROMPT, "Not a prompt input")
    embed_input = _union_property("embed_input", RKLLMInputType.RKLLM_INPUT_EMBED, "Not an embed input")
    token_input = _union_property("token_input", RKLLMInputType.RKLLM_INPUT_TOKEN, "Not a token input")
    multimodal_input = _union_property("multimodal_input", RKLLMInputType.RKLLM_INPUT_MULTIMODAL, "Not a multimodal input")

    del _union_property
|
| 265 |
+
|
| 266 |
+
class RKLLMLoraParam(ctypes.Structure):  # for inference
    """Inference-time LoRA selection: the name of the adapter to activate."""
    lora_adapter_name: ctypes.c_char_p

    _fields_ = [
        ("lora_adapter_name", ctypes.c_char_p),
    ]
|
| 272 |
+
|
| 273 |
+
class RKLLMPromptCacheParam(ctypes.Structure):  # for inference
    """Inference-time prompt-cache options."""
    save_prompt_cache: ctypes.c_int  # bool-like
    prompt_cache_path: ctypes.c_char_p

    _fields_ = [
        ("save_prompt_cache", ctypes.c_int),    # bool-like: non-zero saves the prompt cache
        ("prompt_cache_path", ctypes.c_char_p), # where to save/load the cache
    ]
|
| 281 |
+
|
| 282 |
+
class RKLLMInferParam(ctypes.Structure):
    """Per-call inference options handed to rkllm_run / rkllm_run_async."""
    mode: ctypes.c_int
    lora_params: ctypes.POINTER(RKLLMLoraParam)
    prompt_cache_params: ctypes.POINTER(RKLLMPromptCacheParam)
    keep_history: ctypes.c_int  # bool-like

    _fields_ = [
        ("mode", ctypes.c_int),  # RKLLMInferMode value (passed as a plain int)
        ("lora_params", ctypes.POINTER(RKLLMLoraParam)),
        ("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam)),
        ("keep_history", ctypes.c_int),  # bool-like
    ]
|
| 294 |
+
|
| 295 |
+
class RKLLMResultLastHiddenLayer(ctypes.Structure):
    """Last-layer hidden states: num_tokens rows of embd_size floats."""
    hidden_states: ctypes.POINTER(ctypes.c_float)
    embd_size: ctypes.c_int
    num_tokens: ctypes.c_int

    _fields_ = [
        ("hidden_states", ctypes.POINTER(ctypes.c_float)),
        ("embd_size", ctypes.c_int),
        ("num_tokens", ctypes.c_int),
    ]
|
| 305 |
+
|
| 306 |
+
class RKLLMResultLogits(ctypes.Structure):
    """Output logits: num_tokens rows of vocab_size floats."""
    logits: ctypes.POINTER(ctypes.c_float)
    vocab_size: ctypes.c_int
    num_tokens: ctypes.c_int

    _fields_ = [
        ("logits", ctypes.POINTER(ctypes.c_float)),
        ("vocab_size", ctypes.c_int),
        ("num_tokens", ctypes.c_int),
    ]
|
| 316 |
+
|
| 317 |
+
class RKLLMResult(ctypes.Structure):
    """One LLM inference result.

    Carries the generated text, the token id, the last hidden layer and
    logits (when requested), plus performance statistics.
    """
    text: ctypes.c_char_p
    token_id: ctypes.c_int32
    last_hidden_layer: RKLLMResultLastHiddenLayer
    logits: RKLLMResultLogits
    perf: RKLLMPerfStat

    _fields_ = [
        ("text", ctypes.c_char_p),                          # generated text
        ("token_id", ctypes.c_int32),                       # generated token id
        ("last_hidden_layer", RKLLMResultLastHiddenLayer),  # last-layer hidden states (if requested)
        ("logits", RKLLMResultLogits),                      # model output logits
        ("perf", RKLLMPerfStat),                            # prefill/generation performance stats
    ]
|
| 336 |
+
|
| 337 |
+
# --- Typedefs ---
|
| 338 |
+
LLMHandle = ctypes.c_void_p  # opaque handle to an RKLLM runtime instance
|
| 339 |
+
|
| 340 |
+
# --- Callback Function Type ---
|
| 341 |
+
# Result callback type: int callback(RKLLMResult* result, void* userdata, int state)
#
# Parameters:
#   result   - pointer to the LLM result
#   userdata - user data pointer supplied at run time
#   state    - LLM call state (an LLMCallState value, e.g. finished / error)
#
# Return value:
#   0 - continue inference normally
#   1 - pause inference.  If the user wants to modify or intervene in the
#       result (e.g. edit the output, inject a new prompt), return 1 to pause
#       the current run, then call rkllm_run again later with the updated
#       content to resume.
LLMResultCallback = ctypes.CFUNCTYPE(
    ctypes.c_int,                 # return type: processing status
    ctypes.POINTER(RKLLMResult),  # pointer to the LLM result
    ctypes.c_void_p,              # user data pointer
    ctypes.c_int,                 # call state (LLMCallState value)
)
|
| 362 |
+
|
| 363 |
+
class RKLLMRuntime:
|
| 364 |
+
def __init__(self, library_path="./librkllmrt.so"):
|
| 365 |
+
try:
|
| 366 |
+
self.lib = ctypes.CDLL(library_path)
|
| 367 |
+
except OSError as e:
|
| 368 |
+
raise OSError(f"Failed to load RKLLM library from {library_path}. "
|
| 369 |
+
f"Ensure it's in your LD_LIBRARY_PATH or provide the full path. Error: {e}")
|
| 370 |
+
self._setup_functions()
|
| 371 |
+
self.llm_handle = LLMHandle()
|
| 372 |
+
self._c_callback = None # To keep the callback object alive
|
| 373 |
+
|
| 374 |
+
def _setup_functions(self):
|
| 375 |
+
# RKLLMParam rkllm_createDefaultParam();
|
| 376 |
+
self.lib.rkllm_createDefaultParam.restype = RKLLMParam
|
| 377 |
+
self.lib.rkllm_createDefaultParam.argtypes = []
|
| 378 |
+
|
| 379 |
+
# int rkllm_init(LLMHandle* handle, RKLLMParam* param, LLMResultCallback callback);
|
| 380 |
+
self.lib.rkllm_init.restype = ctypes.c_int
|
| 381 |
+
self.lib.rkllm_init.argtypes = [
|
| 382 |
+
ctypes.POINTER(LLMHandle),
|
| 383 |
+
ctypes.POINTER(RKLLMParam),
|
| 384 |
+
LLMResultCallback
|
| 385 |
+
]
|
| 386 |
+
|
| 387 |
+
# int rkllm_load_lora(LLMHandle handle, RKLLMLoraAdapter* lora_adapter);
|
| 388 |
+
self.lib.rkllm_load_lora.restype = ctypes.c_int
|
| 389 |
+
self.lib.rkllm_load_lora.argtypes = [LLMHandle, ctypes.POINTER(RKLLMLoraAdapter)]
|
| 390 |
+
|
| 391 |
+
# int rkllm_load_prompt_cache(LLMHandle handle, const char* prompt_cache_path);
|
| 392 |
+
self.lib.rkllm_load_prompt_cache.restype = ctypes.c_int
|
| 393 |
+
self.lib.rkllm_load_prompt_cache.argtypes = [LLMHandle, ctypes.c_char_p]
|
| 394 |
+
|
| 395 |
+
# int rkllm_release_prompt_cache(LLMHandle handle);
|
| 396 |
+
self.lib.rkllm_release_prompt_cache.restype = ctypes.c_int
|
| 397 |
+
self.lib.rkllm_release_prompt_cache.argtypes = [LLMHandle]
|
| 398 |
+
|
| 399 |
+
# int rkllm_destroy(LLMHandle handle);
|
| 400 |
+
self.lib.rkllm_destroy.restype = ctypes.c_int
|
| 401 |
+
self.lib.rkllm_destroy.argtypes = [LLMHandle]
|
| 402 |
+
|
| 403 |
+
# int rkllm_run(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
|
| 404 |
+
self.lib.rkllm_run.restype = ctypes.c_int
|
| 405 |
+
self.lib.rkllm_run.argtypes = [
|
| 406 |
+
LLMHandle,
|
| 407 |
+
ctypes.POINTER(RKLLMInput),
|
| 408 |
+
ctypes.POINTER(RKLLMInferParam),
|
| 409 |
+
ctypes.c_void_p # userdata
|
| 410 |
+
]
|
| 411 |
+
|
| 412 |
+
# int rkllm_run_async(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
|
| 413 |
+
# Assuming async also takes userdata for the callback context
|
| 414 |
+
self.lib.rkllm_run_async.restype = ctypes.c_int
|
| 415 |
+
self.lib.rkllm_run_async.argtypes = [
|
| 416 |
+
LLMHandle,
|
| 417 |
+
ctypes.POINTER(RKLLMInput),
|
| 418 |
+
ctypes.POINTER(RKLLMInferParam),
|
| 419 |
+
ctypes.c_void_p # userdata
|
| 420 |
+
]
|
| 421 |
+
|
| 422 |
+
# int rkllm_abort(LLMHandle handle);
|
| 423 |
+
self.lib.rkllm_abort.restype = ctypes.c_int
|
| 424 |
+
self.lib.rkllm_abort.argtypes = [LLMHandle]
|
| 425 |
+
|
| 426 |
+
# int rkllm_is_running(LLMHandle handle);
|
| 427 |
+
self.lib.rkllm_is_running.restype = ctypes.c_int # 0 if running, non-zero otherwise
|
| 428 |
+
self.lib.rkllm_is_running.argtypes = [LLMHandle]
|
| 429 |
+
|
| 430 |
+
# int rkllm_clear_kv_cache(LLMHandle handle, int keep_system_prompt, int* start_pos, int* end_pos);
|
| 431 |
+
self.lib.rkllm_clear_kv_cache.restype = ctypes.c_int
|
| 432 |
+
self.lib.rkllm_clear_kv_cache.argtypes = [
|
| 433 |
+
LLMHandle,
|
| 434 |
+
ctypes.c_int,
|
| 435 |
+
ctypes.POINTER(ctypes.c_int), # start_pos
|
| 436 |
+
ctypes.POINTER(ctypes.c_int) # end_pos
|
| 437 |
+
]
|
| 438 |
+
|
| 439 |
+
# int rkllm_get_kv_cache_size(LLMHandle handle, int* cache_sizes);
|
| 440 |
+
self.lib.rkllm_get_kv_cache_size.restype = ctypes.c_int
|
| 441 |
+
self.lib.rkllm_get_kv_cache_size.argtypes = [LLMHandle, ctypes.POINTER(ctypes.c_int)]
|
| 442 |
+
|
| 443 |
+
# int rkllm_set_chat_template(LLMHandle handle, const char* system_prompt, const char* prompt_prefix, const char* prompt_postfix);
|
| 444 |
+
self.lib.rkllm_set_chat_template.restype = ctypes.c_int
|
| 445 |
+
self.lib.rkllm_set_chat_template.argtypes = [
|
| 446 |
+
LLMHandle,
|
| 447 |
+
ctypes.c_char_p,
|
| 448 |
+
ctypes.c_char_p,
|
| 449 |
+
ctypes.c_char_p
|
| 450 |
+
]
|
| 451 |
+
|
| 452 |
+
# int rkllm_set_function_tools(LLMHandle handle, const char* system_prompt, const char* tools, const char* tool_response_str);
|
| 453 |
+
self.lib.rkllm_set_function_tools.restype = ctypes.c_int
|
| 454 |
+
self.lib.rkllm_set_function_tools.argtypes = [
|
| 455 |
+
LLMHandle,
|
| 456 |
+
ctypes.c_char_p, # system_prompt
|
| 457 |
+
ctypes.c_char_p, # tools
|
| 458 |
+
ctypes.c_char_p # tool_response_str
|
| 459 |
+
]
|
| 460 |
+
|
| 461 |
+
# int rkllm_set_cross_attn_params(LLMHandle handle, RKLLMCrossAttnParam* cross_attn_params);
|
| 462 |
+
self.lib.rkllm_set_cross_attn_params.restype = ctypes.c_int
|
| 463 |
+
self.lib.rkllm_set_cross_attn_params.argtypes = [LLMHandle, ctypes.POINTER(RKLLMCrossAttnParam)]
|
| 464 |
+
|
| 465 |
+
def create_default_param(self) -> RKLLMParam:
|
| 466 |
+
"""Creates a default RKLLMParam structure."""
|
| 467 |
+
return self.lib.rkllm_createDefaultParam()
|
| 468 |
+
|
| 469 |
+
def init(self, param: RKLLMParam, callback_func) -> int:
|
| 470 |
+
"""
|
| 471 |
+
Initializes the LLM.
|
| 472 |
+
:param param: RKLLMParam structure.
|
| 473 |
+
:param callback_func: A Python function that matches the signature:
|
| 474 |
+
def my_callback(result_ptr, userdata_ptr, state_enum):
|
| 475 |
+
result = result_ptr.contents # RKLLMResult
|
| 476 |
+
# Process result
|
| 477 |
+
# userdata can be retrieved if passed during run, or ignored
|
| 478 |
+
# state = LLMCallState(state_enum)
|
| 479 |
+
:return: 0 for success, non-zero for failure.
|
| 480 |
+
"""
|
| 481 |
+
if not callable(callback_func):
|
| 482 |
+
raise ValueError("callback_func must be a callable Python function.")
|
| 483 |
+
|
| 484 |
+
# Keep a reference to the ctypes callback object to prevent it from being garbage collected
|
| 485 |
+
self._c_callback = LLMResultCallback(callback_func)
|
| 486 |
+
|
| 487 |
+
ret = self.lib.rkllm_init(ctypes.byref(self.llm_handle), ctypes.byref(param), self._c_callback)
|
| 488 |
+
if ret != 0:
|
| 489 |
+
raise RuntimeError(f"rkllm_init failed with error code {ret}")
|
| 490 |
+
return ret
|
| 491 |
+
|
| 492 |
+
def load_lora(self, lora_adapter: RKLLMLoraAdapter) -> int:
|
| 493 |
+
"""Loads a Lora adapter."""
|
| 494 |
+
ret = self.lib.rkllm_load_lora(self.llm_handle, ctypes.byref(lora_adapter))
|
| 495 |
+
if ret != 0:
|
| 496 |
+
raise RuntimeError(f"rkllm_load_lora failed with error code {ret}")
|
| 497 |
+
return ret
|
| 498 |
+
|
| 499 |
+
def load_prompt_cache(self, prompt_cache_path: str) -> int:
|
| 500 |
+
"""Loads a prompt cache from a file."""
|
| 501 |
+
c_path = prompt_cache_path.encode('utf-8')
|
| 502 |
+
ret = self.lib.rkllm_load_prompt_cache(self.llm_handle, c_path)
|
| 503 |
+
if ret != 0:
|
| 504 |
+
raise RuntimeError(f"rkllm_load_prompt_cache failed for {prompt_cache_path} with error code {ret}")
|
| 505 |
+
return ret
|
| 506 |
+
|
| 507 |
+
def release_prompt_cache(self) -> int:
    """Releases the prompt cache from memory.

    :return: 0 on success.
    :raises RuntimeError: if the underlying rkllm_release_prompt_cache call fails.
    """
    ret = self.lib.rkllm_release_prompt_cache(self.llm_handle)
    if ret != 0:
        raise RuntimeError(f"rkllm_release_prompt_cache failed with error code {ret}")
    return ret
|
| 513 |
+
|
| 514 |
+
def destroy(self) -> int:
    """Destroys the LLM instance and releases resources.

    Safe to call more than once: after the first successful call the handle
    is reset to NULL and subsequent calls are no-ops returning 0.
    :return: result of rkllm_destroy, or 0 if nothing needed destroying.
    """
    if self.llm_handle and self.llm_handle.value:  # Check if handle is not NULL
        ret = self.lib.rkllm_destroy(self.llm_handle)
        self.llm_handle = LLMHandle()  # Reset handle so repeated calls are no-ops
        if ret != 0:
            # Don't raise here as it might be called in __del__
            print(f"Warning: rkllm_destroy failed with error code {ret}")
        return ret
    return 0  # Already destroyed or not initialized
|
| 524 |
+
|
| 525 |
+
def run(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
    """Runs an LLM inference task synchronously.

    :param rkllm_input: RKLLMInput structure with the prompt/embedding data.
    :param rkllm_infer_params: RKLLMInferParam structure with inference settings.
    :param userdata: optional Python object forwarded to the callback as a void*.
    :return: 0 on success.
    :raises RuntimeError: if the underlying rkllm_run call fails.
    """
    if userdata is not None:
        # BUGFIX: keep BOTH the py_object wrapper and the ctypes pointer
        # alive on self. Previously only the raw Python object was stored,
        # so the temporary ctypes.pointer(ctypes.py_object(...)) could be
        # garbage collected and the C side would receive a dangling pointer.
        self._userdata_ref = ctypes.py_object(userdata)
        self._userdata_ptr = ctypes.pointer(self._userdata_ref)
        c_userdata = ctypes.cast(self._userdata_ptr, ctypes.c_void_p)
    else:
        c_userdata = None
    ret = self.lib.rkllm_run(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata)
    if ret != 0:
        raise RuntimeError(f"rkllm_run failed with error code {ret}")
    return ret
|
| 539 |
+
|
| 540 |
+
def run_async(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
    """Runs an LLM inference task asynchronously.

    :param rkllm_input: RKLLMInput structure with the prompt/embedding data.
    :param rkllm_infer_params: RKLLMInferParam structure with inference settings.
    :param userdata: optional Python object forwarded to the callback as a void*.
    :return: 0 on success (task submitted).
    :raises RuntimeError: if the underlying rkllm_run_async call fails.
    """
    if userdata is not None:
        # BUGFIX: keep BOTH the py_object wrapper and the ctypes pointer
        # alive on self — the async task may outlive this call, so the
        # previous code (storing only the raw Python object) left the C
        # side with a dangling pointer to a collected temporary.
        self._userdata_ref = ctypes.py_object(userdata)
        self._userdata_ptr = ctypes.pointer(self._userdata_ref)
        c_userdata = ctypes.cast(self._userdata_ptr, ctypes.c_void_p)
    else:
        c_userdata = None
    ret = self.lib.rkllm_run_async(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata)
    if ret != 0:
        raise RuntimeError(f"rkllm_run_async failed with error code {ret}")
    return ret
|
| 552 |
+
|
| 553 |
+
def abort(self) -> int:
    """Aborts an ongoing LLM task.

    :return: 0 on success.
    :raises RuntimeError: if the underlying rkllm_abort call fails.
    """
    ret = self.lib.rkllm_abort(self.llm_handle)
    if ret != 0:
        raise RuntimeError(f"rkllm_abort failed with error code {ret}")
    return ret
|
| 559 |
+
|
| 560 |
+
def is_running(self) -> bool:
    """Checks if an LLM task is currently running. Returns True if running."""
    # The C API returns 0 if running, non-zero otherwise.
    # This is a bit counter-intuitive for a boolean "is_running",
    # so the comparison below converts it into a natural Python boolean.
    return self.lib.rkllm_is_running(self.llm_handle) == 0
|
| 565 |
+
|
| 566 |
+
def clear_kv_cache(self, keep_system_prompt: bool, start_pos: list = None, end_pos: list = None) -> int:
    """Clear all or part of the key-value cache.

    :param keep_system_prompt: whether to keep the system prompt in the cache
        (True keeps it, False clears it). Ignored when an explicit range
        [start_pos, end_pos) is supplied.
    :param start_pos: per-batch array of inclusive start positions to clear.
    :param end_pos: per-batch array of exclusive end positions to clear.
        When both start_pos and end_pos are None the whole cache is cleared
        and keep_system_prompt takes effect; when start_pos[i] < end_pos[i]
        only that range is cleared and keep_system_prompt is ignored.

    Note: start_pos/end_pos are only valid when keep_history == 0 and
    generation has been paused by returning 1 from the callback.

    :return: 0 on success, non-zero on failure.
    :raises ValueError: if only one of start_pos/end_pos is given, or their
        lengths differ (previously a lone start_pos or end_pos was silently
        ignored and the whole cache was cleared).
    :raises RuntimeError: if the underlying C call fails.
    """
    if (start_pos is None) != (end_pos is None):
        raise ValueError("start_pos and end_pos must be provided together")

    c_start_pos = None
    c_end_pos = None

    if start_pos is not None:
        if len(start_pos) != len(end_pos):
            raise ValueError("start_pos和end_pos数组长度必须相同")

        # Marshal the Python lists into C int arrays
        c_start_pos = (ctypes.c_int * len(start_pos))(*start_pos)
        c_end_pos = (ctypes.c_int * len(end_pos))(*end_pos)

    ret = self.lib.rkllm_clear_kv_cache(
        self.llm_handle,
        ctypes.c_int(1 if keep_system_prompt else 0),
        c_start_pos,
        c_end_pos
    )
    if ret != 0:
        raise RuntimeError(f"rkllm_clear_kv_cache失败,错误代码:{ret}")
    return ret
|
| 605 |
+
|
| 606 |
+
def set_chat_template(self, system_prompt: str, prompt_prefix: str, prompt_postfix: str) -> int:
    """Sets the chat template for the LLM.

    :param system_prompt: system prompt text (an empty C string is sent when falsy).
    :param prompt_prefix: text inserted before each user prompt.
    :param prompt_postfix: text appended after each user prompt.
    :return: 0 on success.
    :raises RuntimeError: if the underlying rkllm_set_chat_template call fails.
    """
    c_system = system_prompt.encode('utf-8') if system_prompt else b""
    c_prefix = prompt_prefix.encode('utf-8') if prompt_prefix else b""
    c_postfix = prompt_postfix.encode('utf-8') if prompt_postfix else b""

    ret = self.lib.rkllm_set_chat_template(self.llm_handle, c_system, c_prefix, c_postfix)
    if ret != 0:
        raise RuntimeError(f"rkllm_set_chat_template failed with error code {ret}")
    return ret
|
| 616 |
+
|
| 617 |
+
def get_kv_cache_size(self, n_batch: int) -> list:
    """Return the current key-value cache size for the given LLM handle.

    The C API reports the total number of positions currently stored in the
    model's KV cache, one entry per batch.

    :param n_batch: number of batches; determines the length of the result.
    :return: list of per-batch cache sizes.
    :raises ValueError: if n_batch is not a positive integer (previously a
        non-positive value produced a zero-length C array handed to the
        library).
    :raises RuntimeError: if the underlying C call fails.
    """
    if n_batch <= 0:
        raise ValueError("n_batch must be a positive integer")

    # Pre-allocate a C array the library fills with per-batch sizes
    cache_sizes = (ctypes.c_int * n_batch)()

    ret = self.lib.rkllm_get_kv_cache_size(self.llm_handle, cache_sizes)
    if ret != 0:
        raise RuntimeError(f"rkllm_get_kv_cache_size失败,错误代码:{ret}")

    # Convert to a plain Python list
    return list(cache_sizes)
|
| 638 |
+
|
| 639 |
+
def set_function_tools(self, system_prompt: str, tools: str, tool_response_str: str) -> int:
    """
    Configure function calling for the LLM: system prompt, tool definitions,
    and the tool-response marker token.

    :param system_prompt: system prompt defining the model's context/behavior.
    :param tools: JSON string describing the available functions, including
        their names, descriptions, and parameters.
    :param tool_response_str: unique tag identifying function-call results in
        the conversation; it acts as a marker so the tokenizer can separate
        tool output from normal conversation turns.
    :return: 0 on success, non-zero on error.
    :raises RuntimeError: if the underlying C call fails.
    """
    c_system = system_prompt.encode('utf-8') if system_prompt else b""
    c_tools = tools.encode('utf-8') if tools else b""
    c_tool_response = tool_response_str.encode('utf-8') if tool_response_str else b""

    ret = self.lib.rkllm_set_function_tools(self.llm_handle, c_system, c_tools, c_tool_response)
    if ret != 0:
        raise RuntimeError(f"rkllm_set_function_tools失败,错误代码:{ret}")
    return ret
|
| 659 |
+
|
| 660 |
+
def set_cross_attn_params(self, cross_attn_params: RKLLMCrossAttnParam) -> int:
    """
    Set cross-attention parameters for the LLM decoder.

    :param cross_attn_params: structure containing the encoder-side input
        data used for cross attention (see RKLLMCrossAttnParam for details).
    :return: 0 on success, non-zero on error.
    :raises RuntimeError: if the underlying C call fails.
    """
    ret = self.lib.rkllm_set_cross_attn_params(self.llm_handle, ctypes.byref(cross_attn_params))
    if ret != 0:
        raise RuntimeError(f"rkllm_set_cross_attn_params失败,错误代码:{ret}")
    return ret
|
| 674 |
+
|
| 675 |
+
def __enter__(self):
    """Context-manager entry: returns the runtime itself."""
    return self
|
| 677 |
+
|
| 678 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
    """Context-manager exit: releases the LLM handle unconditionally."""
    self.destroy()
|
| 680 |
+
|
| 681 |
+
def __del__(self):
    # destroy() is idempotent and never raises, so it is safe during GC.
    self.destroy()  # Ensure resources are freed if object is garbage collected
|
| 683 |
+
|
| 684 |
+
# --- Example Usage (Illustrative) ---
if __name__ == "__main__":
    # This is a placeholder for how you might use it.
    # You'll need a valid .rkllm model and librkllmrt.so in your path.

    # Global list to store results from callback for demonstration
    results_buffer = []

    def my_python_callback(result_ptr, userdata_ptr, state_enum):
        """
        Callback invoked by the C library to deliver LLM results.

        :param result_ptr: pointer to the RKLLMResult structure.
        :param userdata_ptr: user data pointer (unused here).
        :param state_enum: integer LLMCallState value for this invocation.
        :return: 0 to continue inference, 1 to pause it.
        """
        global results_buffer
        state = LLMCallState(state_enum)
        result = result_ptr.contents

        current_text = ""
        if result.text:  # check the char_p is not NULL before decoding
            current_text = result.text.decode('utf-8', errors='ignore')

        print(f"回调: State={state.name}, TokenID={result.token_id}, Text='{current_text}'")

        # Show performance statistics when the library reports any
        if result.perf.prefill_tokens > 0 or result.perf.generate_tokens > 0:
            print(f"  性能统计: 预填充={result.perf.prefill_tokens}tokens/{result.perf.prefill_time_ms:.1f}ms, "
                  f"生成={result.perf.generate_tokens}tokens/{result.perf.generate_time_ms:.1f}ms, "
                  f"内存={result.perf.memory_usage_mb:.1f}MB")

        results_buffer.append(current_text)

        if state == LLMCallState.RKLLM_RUN_FINISH:
            print("推理完成。")
        elif state == LLMCallState.RKLLM_RUN_ERROR:
            print("推理错误。")

        # Return 0 to continue inference, 1 to pause it
        return 0

    # --- Attempt to use the wrapper ---
    try:
        print("Initializing RKLLMRuntime...")
        # Adjust library_path if librkllmrt.so is not in default search paths
        # e.g. library_path="./path/to/librkllmrt.so"
        rk_llm = RKLLMRuntime()

        print("Creating default parameters...")
        params = rk_llm.create_default_param()

        # --- Configure parameters ---
        # THIS IS CRITICAL: model_path must point to an actual .rkllm file
        # For this example to run, you need a model file.
        # Let's assume a dummy path for now, this will fail at init if not valid.
        model_file = "dummy_model.rkllm"
        if not os.path.exists(model_file):
            print(f"Warning: Model file '{model_file}' does not exist. Init will likely fail.")
            # Create a dummy file for the example to proceed further, though init will still fail
            # with a real library unless it's a valid model.
            with open(model_file, "w") as f:
                f.write("dummy content")

        params.model_path = model_file.encode('utf-8')
        params.max_context_len = 512
        params.max_new_tokens = 128
        params.top_k = 1  # Greedy
        params.temperature = 0.7
        params.repeat_penalty = 1.1
        # ... set other params as needed

        print(f"Initializing LLM with model: {params.model_path.decode()}...")
        # This will likely fail if dummy_model.rkllm is not a valid model recognized by the library
        try:
            rk_llm.init(params, my_python_callback)
            print("LLM Initialized.")
        except RuntimeError as e:
            print(f"Error during LLM initialization: {e}")
            print("This is expected if 'dummy_model.rkllm' is not a valid model.")
            print("Replace 'dummy_model.rkllm' with a real model path to test further.")
            exit()


        # --- Prepare input ---
        print("准备输入...")
        rk_input = RKLLMInput()
        rk_input.role = b"user"  # mark this turn as user input
        rk_input.enable_thinking = False  # disable thinking mode (for Qwen3 models)
        rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT

        prompt_text = "将以下英文文本翻译成中文:'Hello, world!'"
        c_prompt = prompt_text.encode('utf-8')
        rk_input._union_data.prompt_input = c_prompt  # access the union member directly

        # --- Prepare inference parameters ---
        print("Preparing inference parameters...")
        infer_params = RKLLMInferParam()
        infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
        infer_params.keep_history = 1  # True
        # infer_params.lora_params = None  # or set up RKLLMLoraParam if using LoRA
        # infer_params.prompt_cache_params = None  # or set up RKLLMPromptCacheParam

        # --- Run inference ---
        print(f"Running inference with prompt: '{prompt_text}'")
        results_buffer.clear()
        try:
            rk_llm.run(rk_input, infer_params)  # Userdata is None by default
            print("\n--- Full Response ---")
            print("".join(results_buffer))
            print("---------------------\n")
        except RuntimeError as e:
            print(f"Error during LLM run: {e}")


        # --- Example: Set chat template (if model supports it) ---
        # print("Setting chat template...")
        # try:
        #     rk_llm.set_chat_template("You are a helpful assistant.", "<user>: ", "<assistant>: ")
        #     print("Chat template set.")
        # except RuntimeError as e:
        #     print(f"Error setting chat template: {e}")

        # --- Example: Clear KV Cache ---
        # print("Clearing KV cache (keeping system prompt if any)...")
        # try:
        #     rk_llm.clear_kv_cache(keep_system_prompt=True)
        #     print("KV cache cleared.")
        # except RuntimeError as e:
        #     print(f"Error clearing KV cache: {e}")

        # --- Example: query the KV cache size ---
        # print("获取KV缓存大小...")
        # try:
        #     cache_sizes = rk_llm.get_kv_cache_size(n_batch=1)  # assume batch size 1
        #     print(f"当前KV缓存大小: {cache_sizes}")
        # except RuntimeError as e:
        #     print(f"获取KV缓存大小错误: {e}")

        # --- Example: configure function-calling tools ---
        # print("设置函数调用工具...")
        # try:
        #     system_prompt = "你是一个有用的助手,可以调用提供的函数来帮助用户。"
        #     tools = '''[{
        #         "name": "get_weather",
        #         "description": "获取指定城市的天气信息",
        #         "parameters": {
        #             "type": "object",
        #             "properties": {
        #                 "city": {"type": "string", "description": "城市名称"}
        #             },
        #             "required": ["city"]
        #         }
        #     }]'''
        #     tool_response_str = "<tool_response>"
        #     rk_llm.set_function_tools(system_prompt, tools, tool_response_str)
        #     print("函数工具设置成功。")
        # except RuntimeError as e:
        #     print(f"设置函数工具错误: {e}")

        # --- Example: clear the KV cache over an explicit position range ---
        # print("使用范围参数清除KV缓存...")
        # try:
        #     # clear cached positions 10..20
        #     start_positions = [10]  # start position for batch 0
        #     end_positions = [20]    # end position for batch 0
        #     rk_llm.clear_kv_cache(keep_system_prompt=True, start_pos=start_positions, end_pos=end_positions)
        #     print("范围KV缓存清除完成。")
        # except RuntimeError as e:
        #     print(f"清除范围KV缓存错误: {e}")

    except OSError as e:
        print(f"OSError: {e}. Could not load the RKLLM library.")
        print("Please ensure 'librkllmrt.so' is in your LD_LIBRARY_PATH or provide the full path.")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
    finally:
        if 'rk_llm' in locals() and rk_llm.llm_handle and rk_llm.llm_handle.value:
            print("Destroying LLM instance...")
            rk_llm.destroy()
            print("LLM instance destroyed.")
        if os.path.exists(model_file) and model_file == "dummy_model.rkllm":
            os.remove(model_file)  # Clean up dummy file

    print("Example finished.")
|
run_rkllm.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import faulthandler
|
| 2 |
+
faulthandler.enable()
|
| 3 |
+
import sys
|
| 4 |
+
import os
|
| 5 |
+
os.environ["RKLLM_LOG_LEVEL"] = "1"
|
| 6 |
+
import ctypes
|
| 7 |
+
import argparse
|
| 8 |
+
import cv2
|
| 9 |
+
import numpy as np
|
| 10 |
+
import ztu_somemodelruntime_rknnlite2 as ort
|
| 11 |
+
from rkllm_binding import (
|
| 12 |
+
RKLLMRuntime,
|
| 13 |
+
RKLLMParam,
|
| 14 |
+
RKLLMInput,
|
| 15 |
+
RKLLMInferParam,
|
| 16 |
+
LLMCallState,
|
| 17 |
+
RKLLMInputType,
|
| 18 |
+
RKLLMInferMode,
|
| 19 |
+
RKLLMResult
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
# Constants aligned with InternVL config
|
| 23 |
+
IMAGE_HEIGHT = 448
|
| 24 |
+
IMAGE_WIDTH = 448
|
| 25 |
+
IMAGE_SEQ_LENGTH = 256
|
| 26 |
+
MULTIMODAL_HIDDEN_DIM = 2048
|
| 27 |
+
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
|
| 28 |
+
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
|
| 29 |
+
|
| 30 |
+
def expand2square(img, background_color):
    """Pad *img* to a square canvas, centering it over *background_color*.

    :param img: H x W x 3 uint8 image array.
    :param background_color: per-channel fill value for the padding area.
    :return: a new square array of side max(H, W) with the input centered;
        an already-square input is returned as a copy.
    """
    h, w, _ = img.shape
    if w == h:
        return img.copy()

    side = max(w, h)
    canvas = np.full((side, side, 3), background_color, dtype=np.uint8)

    left = (side - w) // 2
    top = (side - h) // 2

    canvas[top:top + h, left:left + w] = img
    return canvas
|
| 46 |
+
|
| 47 |
+
def llm_callback(result_ptr, userdata_ptr, state_enum):
    """
    Callback function to handle LLM results.

    Streams generated text to stdout as it arrives.

    :param result_ptr: pointer to the RKLLMResult structure.
    :param userdata_ptr: user data pointer (unused).
    :param state_enum: integer LLMCallState value for this invocation.
    :return: 0 to continue inference.
    """
    state = LLMCallState(state_enum)
    result = result_ptr.contents

    if state == LLMCallState.RKLLM_RUN_NORMAL:
        if result.text:  # guard against a NULL char_p
            print(result.text.decode('utf-8', errors='ignore'), end='', flush=True)
    elif state == LLMCallState.RKLLM_RUN_FINISH:
        print("\n", flush=True)
    elif state == LLMCallState.RKLLM_RUN_ERROR:
        print("\nrun error", flush=True)

    return 0
|
| 63 |
+
|
| 64 |
+
def main():
    """CLI entry point: encode an image with the vision encoder, then run an
    interactive chat loop against an RKLLM language model, feeding the image
    embedding whenever the prompt contains the "<image>" placeholder.
    """
    parser = argparse.ArgumentParser(
        description="Run RKLLM visual language model inference based on the C++ example."
    )
    parser.add_argument("image_path", type=str, help="Path to the input image.")
    parser.add_argument("encoder_model_path", type=str, help="Path to the ONNX vision encoder model.")
    parser.add_argument("llm_model_path", type=str, help="Path to the .rkllm language model.")
    parser.add_argument("max_new_tokens", type=int, help="Maximum number of new tokens to generate.")
    parser.add_argument("max_context_len", type=int, help="Maximum context length.")
    # The rknn_core_num is not directly used by onnxruntime in the same way,
    # but we keep it for API consistency with the C++ example.
    # ONNX Runtime will manage its own threading and execution providers.
    parser.add_argument("rknn_core_num", type=int, help="Sets the number of npu cores used in vision encoder.")

    args = parser.parse_args()

    # --- 1. Initialize Image Encoder (ONNX Runtime) ---
    print("Initializing ONNX Runtime for vision encoder...")
    try:
        sess_options = ort.SessionOptions()
        sess_options.intra_op_num_threads = args.rknn_core_num
        ort_session = ort.InferenceSession(args.encoder_model_path, sess_options=sess_options)
    except Exception as e:
        print(f"Failed to load ONNX model: {e}")
        sys.exit(1)
    print("Vision encoder loaded successfully.")

    input_name = ort_session.get_inputs()[0].name
    output_name = ort_session.get_outputs()[0].name
    print(f"ONNX Input: {input_name}, ONNX Output: {output_name}")

    # --- 2. Initialize LLM ---
    print("Initializing RKLLM Runtime...")
    rk_llm = RKLLMRuntime()
    param = rk_llm.create_default_param()

    param.model_path = args.llm_model_path.encode('utf-8')
    param.top_k = 1
    param.max_new_tokens = args.max_new_tokens
    param.max_context_len = args.max_context_len
    param.skip_special_token = True
    # Image placeholder markers used by the multimodal chat template
    param.img_start = b"<img>"
    param.img_end = b"</img>\n"
    param.img_content = b""
    param.extend_param.base_domain_id = 1

    try:
        rk_llm.init(param, llm_callback)
        print("RKLLM initialized successfully.")
    except RuntimeError as e:
        print(f"RKLLM init failed: {e}")
        sys.exit(1)

    # --- 3. Image Preprocessing ---
    print("Preprocessing image...")
    img = cv2.imread(args.image_path)
    if img is None:
        print(f"Failed to read image from {args.image_path}")
        sys.exit(1)

    # OpenCV loads BGR; the encoder expects RGB
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    background_color = (127.5, 127.5, 127.5)  # Keep close to official preprocessing
    square_img = expand2square(img, background_color)
    resized_img = cv2.resize(square_img, (IMAGE_WIDTH, IMAGE_HEIGHT), interpolation=cv2.INTER_LINEAR)

    # Normalize and prepare for ONNX model
    input_tensor = resized_img.astype(np.float32)
    # Normalize using InternVL vision config statistics
    input_tensor = (input_tensor / 255.0 - IMAGENET_MEAN) / IMAGENET_STD
    # Convert to NCHW format
    input_tensor = np.transpose(input_tensor, (2, 0, 1))  # HWC -> CHW
    input_tensor = np.expand_dims(input_tensor, axis=0)  # Add batch dimension -> (1, 3, 448, 448)

    # --- 4. Run Image Encoder ---
    print("Running vision encoder...")
    import time
    start_time = time.time()
    try:
        img_vec_output = ort_session.run([output_name], {input_name: input_tensor.astype(np.float32)})[0]
        # Sanity-check the encoder output against the expected (batch, tokens, hidden) layout
        if img_vec_output.ndim != 3:
            raise RuntimeError(f"Unexpected encoder output shape {img_vec_output.shape}, expected (batch, tokens, hidden)")
        if img_vec_output.shape[-1] != MULTIMODAL_HIDDEN_DIM:
            print(f"Warning: hidden dim {img_vec_output.shape[-1]} differs from expected {MULTIMODAL_HIDDEN_DIM}")
        if img_vec_output.shape[1] != IMAGE_SEQ_LENGTH:
            print(f"Warning: token count {img_vec_output.shape[1]} differs from expected {IMAGE_SEQ_LENGTH}")
        elapsed_time = time.time() - start_time
        print(f"视觉编码器推理耗时: {elapsed_time:.4f} 秒")
        # The output from C++ is a flat float array. Let's flatten the ONNX output.
        img_vec = img_vec_output.flatten().astype(np.float32)

    except Exception as e:
        print(f"Failed to run vision encoder inference: {e}")
        rk_llm.destroy()
        sys.exit(1)

    print("Image encoded successfully.")

    # --- 5. Interactive Chat Loop ---
    rkllm_infer_params = RKLLMInferParam()
    rkllm_infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
    rkllm_infer_params.keep_history = 1

    # Set chat template
    # Looks the default template parsed by RKLLM gives better result than this one, don't know why.

    # rk_llm.set_chat_template(
    #     system_prompt="<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n",
    #     prompt_prefix="<|im_start|>user\n",
    #     prompt_postfix="<|im_end|>\n<|im_start|>assistant\n"
    # )

    pre_input = [
        "<image>What is in the image?",
        "<image>这张图片中有什么?"
    ]
    print("\n**********************可输入以下问题对应序号获取回答/或自定义输入********************\n")
    for i, p in enumerate(pre_input):
        print(f"[{i}] {p}")
    print("\n*************************************************************************\n")

    try:
        while True:
            print("\nuser: ", end="", flush=True)
            input_str = sys.stdin.readline().strip()

            if not input_str:
                continue
            if input_str == "exit":
                break
            if input_str == "clear":
                # Reset conversation state without restarting the process
                try:
                    rk_llm.clear_kv_cache(keep_system_prompt=True)
                    print("KV cache cleared.")
                except RuntimeError as e:
                    print(f"Failed to clear KV cache: {e}")
                continue

            # A bare number selects one of the canned prompts above
            try:
                idx = int(input_str)
                if 0 <= idx < len(pre_input):
                    input_str = pre_input[idx]
                    print(input_str)
            except (ValueError, IndexError):
                pass  # Use the raw string if not a valid index

            rkllm_input = RKLLMInput()
            rkllm_input.role = b"user"

            print("robot: ", end="", flush=True)

            if "<image>" in input_str:
                rkllm_input.input_type = RKLLMInputType.RKLLM_INPUT_MULTIMODAL

                # Setup multimodal input
                # NOTE(review): img_vec must stay alive while the C library
                # holds this raw pointer — it does here because img_vec is a
                # local that outlives the run() call.
                rkllm_input.multimodal_input.prompt = input_str.encode('utf-8')
                rkllm_input.multimodal_input.image_embed = img_vec.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
                rkllm_input.multimodal_input.n_image_tokens = img_vec_output.shape[1]
                print("n_image_tokens: ", rkllm_input.multimodal_input.n_image_tokens)
                rkllm_input.multimodal_input.n_image = 1
                rkllm_input.multimodal_input.image_height = IMAGE_HEIGHT
                rkllm_input.multimodal_input.image_width = IMAGE_WIDTH
            else:
                rkllm_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT
                rkllm_input.prompt_input = input_str.encode('utf-8')

            try:
                rk_llm.run(rkllm_input, rkllm_infer_params)
            except RuntimeError as e:
                print(f"\nError during rkllm_run: {e}")

    except KeyboardInterrupt:
        print("\nExiting...")
    finally:
        print("Releasing resources...")
        rk_llm.destroy()
        print("RKLLM instance destroyed.")
|
| 241 |
+
|
| 242 |
+
if __name__ == "__main__":
    # Script entry point: run the interactive multimodal chat demo.
    main()
|
test.jpg
ADDED
|
Git LFS Details
|
vision_encoder.rknn
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:69670e8e48938fcd2543c5cae22e7789d5783b525a34a7013edeba724744c461
|
| 3 |
+
size 674706120
|
ztu_somemodelruntime_rknnlite2.py
ADDED
|
@@ -0,0 +1,1195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 模块级常量和函数
|
| 2 |
+
import ctypes
import logging
import os
import re
import warnings
from typing import List, Dict, Union, Optional

import numpy as np
from rknnlite.api import RKNNLite
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
import onnxruntime as ort
|
| 11 |
+
HAS_ORT = True
|
| 12 |
+
except ImportError:
|
| 13 |
+
HAS_ORT = False
|
| 14 |
+
warnings.warn("onnxruntime未安装,只能使用RKNN后端", ImportWarning)
|
| 15 |
+
|
| 16 |
+
# 配置日志
|
| 17 |
+
logger = logging.getLogger("somemodelruntime_rknnlite2")
|
| 18 |
+
logger.setLevel(logging.ERROR) # 默认只输出错误信息
|
| 19 |
+
if not logger.handlers:
|
| 20 |
+
handler = logging.StreamHandler()
|
| 21 |
+
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
|
| 22 |
+
logger.addHandler(handler)
|
| 23 |
+
|
| 24 |
+
# ONNX Runtime日志级别到Python logging级别的映射
|
| 25 |
+
_LOGGING_LEVEL_MAP = {
|
| 26 |
+
0: logging.DEBUG, # Verbose
|
| 27 |
+
1: logging.INFO, # Info
|
| 28 |
+
2: logging.WARNING, # Warning
|
| 29 |
+
3: logging.ERROR, # Error
|
| 30 |
+
4: logging.CRITICAL # Fatal
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
# 检查环境变量中的日志级别设置
|
| 34 |
+
try:
|
| 35 |
+
env_log_level = os.getenv('ZTU_MODELRT_RKNNL2_LOG_LEVEL')
|
| 36 |
+
if env_log_level is not None:
|
| 37 |
+
log_level = int(env_log_level)
|
| 38 |
+
if log_level in _LOGGING_LEVEL_MAP:
|
| 39 |
+
logger.setLevel(_LOGGING_LEVEL_MAP[log_level])
|
| 40 |
+
logger.info(f"从环境变量设置日志级别: {log_level}")
|
| 41 |
+
else:
|
| 42 |
+
logger.warning(f"环境变量ZTU_MODELRT_RKNNL2_LOG_LEVEL的值无效: {log_level}, 应该是0-4之间的整数")
|
| 43 |
+
except ValueError:
|
| 44 |
+
logger.warning(f"环境变量ZTU_MODELRT_RKNNL2_LOG_LEVEL的值无效: {env_log_level}, 应该是0-4之间的整数")
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def set_default_logger_severity(level: int) -> None:
    """Set the default logging severity.

    0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal.

    Args:
        level: severity code in the range 0-4.

    Raises:
        ValueError: if *level* is not a recognised severity code.
    """
    mapped = _LOGGING_LEVEL_MAP.get(level)
    if mapped is None:
        raise ValueError(f"无效的日志级别: {level}, 应该是0-4之间的整数")
    logger.setLevel(mapped)
|
| 57 |
+
|
| 58 |
+
def set_default_logger_verbosity(level: int) -> None:
    """Set the default logging verbosity level.

    Verbose output only appears once the default severity is also set to
    0 (Verbose). Delegates to :func:`set_default_logger_severity`.

    Args:
        level: verbosity code in the range 0-4.
    """
    set_default_logger_severity(level)
|
| 67 |
+
|
| 68 |
+
# Mapping from RKNN tensor-type enum values to numpy dtypes.
RKNN_DTYPE_MAP = {
    0: np.float32,   # RKNN_TENSOR_FLOAT32
    1: np.float16,   # RKNN_TENSOR_FLOAT16
    2: np.int8,      # RKNN_TENSOR_INT8
    3: np.uint8,     # RKNN_TENSOR_UINT8
    4: np.int16,     # RKNN_TENSOR_INT16
    5: np.uint16,    # RKNN_TENSOR_UINT16
    6: np.int32,     # RKNN_TENSOR_INT32
    7: np.uint32,    # RKNN_TENSOR_UINT32
    8: np.int64,     # RKNN_TENSOR_INT64
    9: bool,         # RKNN_TENSOR_BOOL
    10: np.int8,     # RKNN_TENSOR_INT4 (represented as int8)
}
|
| 82 |
+
|
| 83 |
+
def get_available_providers() -> List[str]:
    """Placeholder kept for ONNX Runtime API compatibility.

    Returns:
        list: fixed provider list
        ["CPUExecutionProvider", "somemodelruntime_rknnlite2_ExecutionProvider"].
    """
    return [
        "CPUExecutionProvider",
        "somemodelruntime_rknnlite2_ExecutionProvider",
    ]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def get_device() -> str:
    """Return the identifier of the device backing this runtime."""
    return "RKNN2"
|
| 101 |
+
|
| 102 |
+
def get_version_info() -> Dict[str, str]:
    """Query API and driver versions via a throwaway RKNNLite instance.

    Returns:
        dict: {"api_version": ..., "driver_version": ...} parsed from the
        multi-line string produced by ``RKNNLite.get_sdk_version()``.
    """
    sdk_lines = RKNNLite().get_sdk_version().split('\n')
    # NOTE(review): parsing relies on the exact line layout of
    # get_sdk_version(); verify against the installed rknnlite version.
    api_version = sdk_lines[2].split(': ')[1].split(' ')[0]
    driver_version = sdk_lines[3].split(': ')[1]
    return {
        "api_version": api_version,
        "driver_version": driver_version,
    }
|
| 115 |
+
|
| 116 |
+
class IOTensor:
    """Lightweight description of a model input/output tensor."""

    def __init__(self, name, shape, type=None):
        # Tensor names coming from the C API arrive as bytes; normalise to str.
        self.name = name.decode() if isinstance(name, bytes) else name
        self.shape = shape
        self.type = type

    def __str__(self):
        return f"IOTensor(name='{self.name}', shape={self.shape}, type={self.type})"
|
| 125 |
+
|
| 126 |
+
class SessionOptions:
    """Session-level configuration, loosely mirroring ORT's SessionOptions."""

    def __init__(self):
        # Enable RKNN profiling (maps onto RKNNLite(verbose=...)).
        self.enable_profiling = False
        # Number of NPU cores to use; translated to an RKNN core_mask (1-3).
        self.intra_op_num_threads = 1
        # Alternative log-level knobs; -1 means "not set".
        self.log_severity_level = -1
        self.log_verbosity_level = -1
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class InferenceSession:
|
| 136 |
+
"""
|
| 137 |
+
RKNNLite运行时封装类,API风格类似ONNX Runtime
|
| 138 |
+
"""
|
| 139 |
+
|
| 140 |
+
def __new__(cls, model_path: str, sess_options: Optional[SessionOptions] = None, **kwargs):
    """Route .onnx models to ONNX Runtime, everything else to RKNN.

    Returns either an ``onnxruntime.InferenceSession`` (for ONNX models
    without an RKNN counterpart) or a regular ``InferenceSession``.
    """
    resolved = InferenceSession._process_model_path(model_path, sess_options)
    if isinstance(resolved, str) and resolved.lower().endswith('.onnx'):
        logger.info("使用ONNX Runtime加载模型")
        if not HAS_ORT:
            raise RuntimeError("未安装onnxruntime,无法加载ONNX模型")
        return ort.InferenceSession(resolved, sess_options=sess_options, **kwargs)
    # Not an ONNX model: build a plain instance and remember the resolved
    # path so that __init__ can pick it up.
    session = super().__new__(cls)
    session._processed_path = resolved
    return session
|
| 153 |
+
|
| 154 |
+
def __init__(self, model_path: str, sess_options: Optional[SessionOptions] = None, **kwargs):
    """Load an RKNN model and initialise the NPU runtime.

    Args:
        model_path: model file path (.rknn, or .onnx resolved by __new__).
        sess_options: optional session configuration.
        **kwargs: accepted for API compatibility; unused here.

    Raises:
        FileNotFoundError: if the model file does not exist.
        RuntimeError: if loading the model or initialising the runtime fails.
        ValueError: if intra_op_num_threads is not 1, 2 or 3.
    """
    options = sess_options or SessionOptions()

    # SessionOptions log levels apply only when the env var is absent.
    if os.getenv('ZTU_MODELRT_RKNNL2_LOG_LEVEL') is None:
        if options.log_severity_level != -1:
            set_default_logger_severity(options.log_severity_level)
        if options.log_verbosity_level != -1:
            set_default_logger_verbosity(options.log_verbosity_level)

    # __new__ may have swapped the path (e.g. .onnx -> sibling .rknn).
    model_path = getattr(self, '_processed_path', model_path)
    if isinstance(model_path, str) and model_path.lower().endswith('.onnx'):
        # ONNX models were already handed to ONNX Runtime in __new__.
        return

    self.model_path = model_path
    if not os.path.exists(self.model_path):
        logger.error(f"模型文件不存在: {self.model_path}")
        raise FileNotFoundError(f"模型文件不存在: {self.model_path}")

    self.runtime = RKNNLite(verbose=options.enable_profiling)

    logger.debug(f"正在加载模型: {self.model_path}")
    if self.runtime.load_rknn(self.model_path) != 0:
        logger.error(f"加载RKNN模型失败: {self.model_path}")
        raise RuntimeError(f'加载RKNN模型失败: {self.model_path}')
    logger.debug("模型加载成功")

    # Translate the thread-count option into an NPU core mask.
    core_mask_by_threads = {
        1: RKNNLite.NPU_CORE_AUTO,
        2: RKNNLite.NPU_CORE_0_1,
        3: RKNNLite.NPU_CORE_0_1_2,
    }
    if options.intra_op_num_threads not in core_mask_by_threads:
        raise ValueError(f"intra_op_num_threads的值无效: {options.intra_op_num_threads}, 只能是1,2或3")
    core_mask = core_mask_by_threads[options.intra_op_num_threads]

    logger.debug("正在初始化运行时环境")
    if self.runtime.init_runtime(core_mask=core_mask) != 0:
        logger.error("初始化运行时环境失败")
        raise RuntimeError('初始化运行时环境失败')

    logger.debug("运行时环境初始化成功")

    # After runtime init, auto-register custom-op plugin libraries.
    try:
        # User-specified plugin paths (split on comma/semicolon/colon).
        env_custom = os.getenv('ZTU_MODELRT_RKNN2_REG_CUSTOM_OP_LIB', '').strip()
        if env_custom:
            plugin_paths = [seg.strip() for seg in re.split(r"[,;:]", env_custom) if seg.strip()]
            registered = 0
            for plugin_path in plugin_paths:
                if self.register_custom_op_lib(plugin_path):
                    registered += 1
            if registered > 0:
                logger.info(f"已注册 {registered}/{len(plugin_paths)} 个自定义算子插件")
        # Plugins from the system directory (enabled by default).
        if os.getenv('ZTU_MODELRT_RKNN2_REG_SYSTEM_CUSTOM_OP_LIB', '1') == '1':
            system_count = self.register_system_custom_op_lib()
            if system_count > 0:
                logger.info(f"已从系统目录注册 {system_count} 个自定义算子插件")
    except Exception as e:
        logger.warning(f"自动注册自定义算子插件失败: {e}")

    # Optionally register the bundled (pure-Python) custom ops.
    if os.getenv('ZTU_MODELRT_RKNN2_REG_BUNDLED_OPS', '0') == '1':
        logger.info("根据环境变量注册捆绑算子")
        self.register_bundled_ops()

    self._init_io_info()
    self.options = options
|
| 238 |
+
|
| 239 |
+
def get_performance_info(self) -> Dict[str, float]:
    """Return timing information from the last inference run.

    Profiling must have been enabled via ``SessionOptions.enable_profiling``
    when the session was created.

    Returns:
        dict: {"run_duration": <duration in milliseconds>}

    Raises:
        RuntimeError: if profiling was not enabled.
    """
    # BUG FIX: SessionOptions defines ``enable_profiling``; there is no
    # ``perf_debug`` attribute, so the old check always raised
    # AttributeError instead of the intended RuntimeError.
    if not self.options.enable_profiling:
        raise RuntimeError("性能分析未启用,请在SessionOptions中设置enable_profiling=True")

    perf = self.runtime.rknn_runtime.get_run_perf()
    return {
        # run_duration is divided by 1000 to yield milliseconds
        # (presumably microseconds upstream — confirm with RKNN docs).
        "run_duration": perf.run_duration / 1000.0
    }
|
| 253 |
+
|
| 254 |
+
def set_core_mask(self, core_mask: int) -> None:
    """Select which NPU cores the session runs on.

    Args:
        core_mask: NPU core mask; use the RKNNLite NPU_CORE_* constants.

    Raises:
        RuntimeError: if the underlying runtime rejects the mask.
    """
    ret = self.runtime.rknn_runtime.set_core_mask(core_mask)
    if ret != 0:
        # BUG FIX: the original message contained mojibake (U+FFFD
        # replacement characters) in place of "模"; restored per the
        # docstring wording "设置NPU核心使用模式".
        raise RuntimeError("设置NPU核心模式失败")
|
| 264 |
+
|
| 265 |
+
@staticmethod
def _process_model_path(model_path, sess_options):
    """Resolve the model path, transparently preferring .rknn over .onnx.

    For an .onnx path: unless the file name appears in the
    ZTU_MODELRT_RKNNL2_SKIP environment list, look for a sibling .rknn
    file and return that when it exists. Non-ONNX paths pass through
    unchanged.
    """
    if not model_path.lower().endswith('.onnx'):
        return model_path

    logger.info("检测到ONNX模型文件")

    # Models listed in the skip env var always go to ONNX Runtime.
    skip_env = os.getenv('ZTU_MODELRT_RKNNL2_SKIP', '').strip()
    if skip_env:
        skipped = {entry.strip().lower() for entry in skip_env.split(',')}
        model_name = os.path.basename(model_path)
        if model_name.lower() in skipped:
            logger.info(f"模型{model_name}在跳过列表中,将使用ONNX Runtime")
            return model_path

    # Prefer a sibling .rknn file when one exists next to the .onnx.
    candidate = os.path.splitext(model_path)[0] + '.rknn'
    if os.path.exists(candidate):
        logger.info(f"找到对应的RKNN模型,将使用RKNN: {candidate}")
        return candidate

    logger.info("未找到对应的RKNN模型,将使用ONNX Runtime")
    return model_path
|
| 297 |
+
|
| 298 |
+
def _convert_nhwc_to_nchw(self, shape):
    """Reorder a 4-D NHWC shape into NCHW; other ranks pass through."""
    if len(shape) != 4:
        return shape
    batch, height, width, channels = shape
    return [batch, channels, height, width]
|
| 305 |
+
|
| 306 |
+
def _init_io_info(self):
    """Populate self.input_tensors / self.output_tensors from the model."""
    runtime = self.runtime.rknn_runtime

    n_input, n_output = runtime.get_in_out_num()

    self.input_tensors = []
    for idx in range(n_input):
        attr = runtime.get_tensor_attr(idx)
        dims = [attr.dims[d] for d in range(attr.n_dims)]
        # RKNN reports 4-D input shapes as NHWC; callers expect NCHW.
        dims = self._convert_nhwc_to_nchw(dims)
        np_dtype = RKNN_DTYPE_MAP.get(attr.type, None)
        self.input_tensors.append(IOTensor(attr.name, dims, np_dtype))

    self.output_tensors = []
    for idx in range(n_output):
        attr = runtime.get_tensor_attr(idx, is_output=True)
        out_shape = runtime.get_output_shape(idx)
        np_dtype = RKNN_DTYPE_MAP.get(attr.type, None)
        self.output_tensors.append(IOTensor(attr.name, out_shape, np_dtype))
|
| 334 |
+
|
| 335 |
+
def get_inputs(self):
    """Return the list of IOTensor descriptions for model inputs."""
    return self.input_tensors

def get_outputs(self):
    """Return the list of IOTensor descriptions for model outputs."""
    return self.output_tensors
|
| 352 |
+
|
| 353 |
+
def run(self, output_names=None, input_feed=None, data_format="nchw", **kwargs):
    """Run inference.

    Args:
        output_names: optional list of output names to return; None means
            every output in model order.
        input_feed: dict {input_name: array} or a positional list/tuple.
        data_format: layout of the supplied arrays, "nchw" or "nhwc".
        **kwargs: accepted for ORT signature compatibility; unused.

    Returns:
        list: output arrays — all outputs, or the requested subset in
        ``output_names`` order.

    Raises:
        ValueError: on missing/invalid inputs.
        RuntimeError: if the underlying inference call fails (including an
            unknown name in ``output_names``, matching the original
            wrap-everything behaviour of the try block).
    """
    if input_feed is None:
        logger.error("input_feed不能为None")
        raise ValueError("input_feed不能为None")

    # Normalise input_feed into a list ordered like the model's inputs.
    if isinstance(input_feed, dict):
        # CLEANUP: removed the unused ``input_map`` local from the
        # original implementation.
        inputs = []
        for tensor in self.input_tensors:
            if tensor.name not in input_feed:
                raise ValueError(f"缺少输入: {tensor.name}")
            inputs.append(input_feed[tensor.name])
    elif isinstance(input_feed, (list, tuple)):
        if len(input_feed) != len(self.input_tensors):
            raise ValueError(f"输入数量不匹配: 期望{len(self.input_tensors)}, 实际{len(input_feed)}")
        inputs = list(input_feed)
    else:
        logger.error("input_feed必须是字典或列表类型")
        raise ValueError("input_feed必须是字典或列表类型")

    try:
        logger.debug("开始执行推理")
        all_outputs = self.runtime.inference(inputs=inputs, data_format=data_format)

        # No selection requested: hand back everything.
        if output_names is None:
            return all_outputs

        # Pick the requested outputs by name, preserving request order.
        output_map = {tensor.name: i for i, tensor in enumerate(self.output_tensors)}
        selected_outputs = []
        for name in output_names:
            if name not in output_map:
                raise ValueError(f"未找到输出节点: {name}")
            selected_outputs.append(all_outputs[output_map[name]])

        return selected_outputs

    except Exception as e:
        logger.error(f"推理执行失败: {str(e)}")
        # FIX: chain the original exception so the root cause is kept.
        raise RuntimeError(f"推理执行失败: {str(e)}") from e
|
| 410 |
+
|
| 411 |
+
def close(self):
    """Release runtime resources; safe to call more than once.

    FIX: uses ``getattr`` because ``__exit__`` can run even when
    ``__init__`` raised before ``self.runtime`` was assigned; the
    original would then raise AttributeError.
    """
    runtime = getattr(self, 'runtime', None)
    if runtime is not None:
        logger.info("正在释放运行时资源")
        runtime.release()
        self.runtime = None
|
| 419 |
+
|
| 420 |
+
def __enter__(self):
    """Context-manager entry: hand back the session itself."""
    return self

def __exit__(self, exc_type, exc_val, exc_tb):
    """Context-manager exit: release the underlying runtime."""
    self.close()
|
| 425 |
+
|
| 426 |
+
def end_profiling(self) -> Optional[str]:
    """Stub kept for ONNX Runtime API compatibility; always None."""
    warnings.warn("end_profiling()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
    return None

def get_profiling_start_time_ns(self) -> int:
    """Stub kept for ONNX Runtime API compatibility; always 0."""
    warnings.warn("get_profiling_start_time_ns()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
    return 0

def get_modelmeta(self) -> Dict[str, str]:
    """Stub kept for ONNX Runtime API compatibility; always empty."""
    warnings.warn("get_modelmeta()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
    return {}

def get_session_options(self) -> SessionOptions:
    """Return the SessionOptions this session was created with."""
    return self.options
|
| 464 |
+
|
| 465 |
+
def get_providers(self) -> List[str]:
    """Stub; always reports CPUExecutionProvider."""
    warnings.warn("get_providers()是存根方法,始终返回CPUExecutionProvider", RuntimeWarning, stacklevel=2)
    return ["CPUExecutionProvider"]

def get_provider_options(self) -> Dict[str, Dict[str, str]]:
    """Stub; no provider options are available."""
    warnings.warn("get_provider_options()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
    return {}

def get_session_config(self) -> Dict[str, str]:
    """Stub; no session config is exposed."""
    warnings.warn("get_session_config()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
    return {}

def get_session_state(self) -> Dict[str, str]:
    """Stub; no session state is exposed."""
    warnings.warn("get_session_state()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
    return {}

def set_session_config(self, config: Dict[str, str]) -> None:
    """Stub; the supplied session config is ignored.

    Args:
        config: ignored.
    """
    warnings.warn("set_session_config()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
|
| 513 |
+
|
| 514 |
+
def get_memory_info(self) -> Dict[str, int]:
    """Stub; memory statistics are not available."""
    warnings.warn("get_memory_info()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
    return {}

def set_memory_pattern(self, enable: bool) -> None:
    """Stub; memory patterns are not supported.

    Args:
        enable: ignored.
    """
    warnings.warn("set_memory_pattern()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)

def disable_memory_pattern(self) -> None:
    """Stub; memory patterns are not supported."""
    warnings.warn("disable_memory_pattern()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)

def get_optimization_level(self) -> int:
    """Stub; always reports optimization level 0."""
    warnings.warn("get_optimization_level()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
    return 0

def set_optimization_level(self, level: int) -> None:
    """Stub; the optimization level is ignored.

    Args:
        level: ignored.
    """
    warnings.warn("set_optimization_level()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
|
| 557 |
+
|
| 558 |
+
def get_model_metadata(self) -> Dict[str, str]:
    """Stub; distinct interface from get_modelmeta() but equally empty."""
    warnings.warn("get_model_metadata()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
    return {}

def get_model_path(self) -> str:
    """Return the path of the loaded model file."""
    return self.model_path

def get_input_type_info(self) -> List[Dict[str, str]]:
    """Stub; detailed input type info is not available."""
    warnings.warn("get_input_type_info()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
    return []

def get_output_type_info(self) -> List[Dict[str, str]]:
    """Stub; detailed output type info is not available."""
    warnings.warn("get_output_type_info()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
    return []
|
| 596 |
+
|
| 597 |
+
################### 自定义算子 ###################
|
| 598 |
+
|
| 599 |
+
def _init_custom_op_types(self):
    """Define the ctypes structures and callback types of the RKNN
    custom-op C API and cache them on this instance.

    Mirrors the layouts declared in rknn_api.h.
    """
    # Tensor-type / target enum constants from rknn_api.h.
    self._RKNN_TENSOR_FLOAT32 = 0
    self._RKNN_TENSOR_UINT8 = 3
    self._RKNN_TENSOR_INT64 = 8
    self._RKNN_TARGET_TYPE_CPU = 1

    # BUG FIX: these limits were referenced but never defined anywhere in
    # the module, so this method raised NameError. Values are the
    # rknn_api.h constants; a module-level definition still wins if one
    # is added later.
    RKNN_MAX_NAME_LEN = globals().get("RKNN_MAX_NAME_LEN", 256)
    RKNN_MAX_DIMS = globals().get("RKNN_MAX_DIMS", 16)

    class RKNN_TensorAttr(ctypes.Structure):
        _fields_ = [
            ("index", ctypes.c_uint32),
            ("n_dims", ctypes.c_uint32),
            ("dims", ctypes.c_uint32 * RKNN_MAX_DIMS),
            ("name", ctypes.c_char * RKNN_MAX_NAME_LEN),
            ("n_elems", ctypes.c_uint32),
            ("size", ctypes.c_uint32),
            ("fmt", ctypes.c_int),
            ("type", ctypes.c_int),
            ("qnt_type", ctypes.c_int),
            ("fl", ctypes.c_int8),
            ("zp", ctypes.c_int32),
            ("scale", ctypes.c_float),
            ("w_stride", ctypes.c_uint32),
            ("size_with_stride", ctypes.c_uint32),
            ("pass_through", ctypes.c_uint8),
            ("h_stride", ctypes.c_uint32),
        ]

    class RKNN_TensorMem(ctypes.Structure):
        _fields_ = [
            ("virt_addr", ctypes.c_void_p),
            ("phys_addr", ctypes.c_uint64),
            ("fd", ctypes.c_int32),
            ("offset", ctypes.c_int32),
            ("size", ctypes.c_uint32),
            ("flags", ctypes.c_uint32),
            ("priv_data", ctypes.c_void_p),
        ]

    class RKNN_CustomOpTensor(ctypes.Structure):
        _fields_ = [
            ("attr", RKNN_TensorAttr),
            ("mem", RKNN_TensorMem),
        ]

    class RKNN_GPUOpContext(ctypes.Structure):
        _fields_ = [
            ("cl_context", ctypes.c_void_p),
            ("cl_command_queue", ctypes.c_void_p),
            ("cl_kernel", ctypes.c_void_p),
        ]

    # internal_ctx is a pointer-sized integer on the C side.
    InternalCtxType = (
        ctypes.c_uint64 if ctypes.sizeof(ctypes.c_void_p) == 8 else ctypes.c_uint32
    )

    class RKNN_CustomOpContext(ctypes.Structure):
        _fields_ = [
            ("target", ctypes.c_int),
            ("internal_ctx", InternalCtxType),
            ("gpu_ctx", RKNN_GPUOpContext),
            ("priv_data", ctypes.c_void_p),
        ]

    class RKNN_CustomOpAttr(ctypes.Structure):
        _fields_ = [
            ("name", ctypes.c_char * RKNN_MAX_NAME_LEN),
            ("dtype", ctypes.c_int),
            ("n_elems", ctypes.c_uint32),
            ("data", ctypes.c_void_p),
        ]

    # Callback signature shared by init/prepare/compute/compute_native.
    CB_SIG = ctypes.CFUNCTYPE(
        ctypes.c_int,
        ctypes.POINTER(RKNN_CustomOpContext),
        ctypes.POINTER(RKNN_CustomOpTensor),
        ctypes.c_uint32,
        ctypes.POINTER(RKNN_CustomOpTensor),
        ctypes.c_uint32,
    )

    DESTROY_SIG = ctypes.CFUNCTYPE(
        ctypes.c_int, ctypes.POINTER(RKNN_CustomOpContext)
    )

    class RKNN_CustomOp(ctypes.Structure):
        _fields_ = [
            ("version", ctypes.c_uint32),
            ("target", ctypes.c_int),
            ("op_type", ctypes.c_char * RKNN_MAX_NAME_LEN),
            ("cl_kernel_name", ctypes.c_char * RKNN_MAX_NAME_LEN),
            ("cl_kernel_source", ctypes.c_char_p),
            ("cl_source_size", ctypes.c_uint64),
            ("cl_build_options", ctypes.c_char * RKNN_MAX_NAME_LEN),
            ("init", CB_SIG),
            ("prepare", CB_SIG),
            ("compute", CB_SIG),
            ("compute_native", CB_SIG),
            ("destroy", DESTROY_SIG),
        ]

    # Cache the generated types on the instance for later use.
    self._RKNN_TensorAttr = RKNN_TensorAttr
    self._RKNN_TensorMem = RKNN_TensorMem
    self._RKNN_CustomOpTensor = RKNN_CustomOpTensor
    self._RKNN_CustomOpContext = RKNN_CustomOpContext
    self._RKNN_CustomOpAttr = RKNN_CustomOpAttr
    self._RKNN_CustomOp = RKNN_CustomOp
    self._CB_SIG = CB_SIG
    self._DESTROY_SIG = DESTROY_SIG
|
| 710 |
+
|
| 711 |
+
def _create_attr_readers(self, get_op_attr):
    """Build typed helpers for reading custom-op attributes.

    Args:
        get_op_attr: the ``rknn_custom_op_get_op_attr`` C function.

    Returns:
        tuple: (read_attr_int64, read_attr_str, read_attr_float32).
    """
    def _fetch(op_ctx_ptr, key):
        # Shared plumbing: query one attribute struct by key.
        attr = self._RKNN_CustomOpAttr()
        get_op_attr(op_ctx_ptr, key.encode("utf-8"), ctypes.byref(attr))
        return attr

    def read_attr_int64(op_ctx_ptr, key: str, default: int = 0) -> int:
        attr = _fetch(op_ctx_ptr, key)
        if attr.n_elems == 1 and attr.dtype == self._RKNN_TENSOR_INT64 and attr.data:
            return ctypes.c_int64.from_address(attr.data).value
        return default

    def read_attr_float32(op_ctx_ptr, key: str, default: float = 0) -> float:
        attr = _fetch(op_ctx_ptr, key)
        if attr.n_elems == 1 and attr.dtype == self._RKNN_TENSOR_FLOAT32 and attr.data:
            return ctypes.c_float.from_address(attr.data).value
        return default

    def read_attr_str(op_ctx_ptr, key: str, default: str = "") -> str:
        attr = _fetch(op_ctx_ptr, key)
        if attr.n_elems > 0 and attr.dtype == self._RKNN_TENSOR_UINT8 and attr.data:
            raw = (ctypes.c_ubyte * attr.n_elems).from_address(attr.data)
            try:
                return bytes(raw).decode("utf-8", errors="ignore").strip('"')
            except Exception:
                return default
        return default

    return read_attr_int64, read_attr_str, read_attr_float32
|
| 740 |
+
|
| 741 |
+
def _build_py_custom_op(self,
                        op_type: str,
                        n_inputs: int,
                        n_outputs: int,
                        on_init,
                        on_compute):
    """Generic constructor for a Python-implemented RKNN custom op.

    Args:
        op_type: operator type name (string).
        n_inputs: expected number of input tensors.
        n_outputs: expected number of output tensors.
        on_init: callback with signature
            on_init(op_ctx_p, read_attr_int64, read_attr_str, read_attr_float32) -> state.
        on_compute: callback with signature
            on_compute(op_ctx_p, inputs_p, outputs_p, state) -> int (0 on success).
    Returns:
        (RKNN_CustomOp instance, tuple of the ctypes callbacks).
        NOTE(review): the caller must keep the returned callback tuple alive;
        if the ctypes trampolines are garbage-collected the native runtime
        will call freed function pointers.
    """
    @self._CB_SIG
    def _py_init(op_ctx_p, inputs_p, n_inputs_p, outputs_p, n_outputs_p):
        try:
            # Build the attribute readers lazily so ops that read no
            # attributes need no up-front setup.
            runtime = self.runtime.rknn_base.rknn_runtime
            read_attr_int64, read_attr_str, read_attr_float32 = self._create_attr_readers(runtime.lib.rknn_custom_op_get_op_attr)
            user_state = on_init(op_ctx_p, read_attr_int64, read_attr_str, read_attr_float32)
            # Allocate a unique per-instance id and write it into priv_data
            # so compute/destroy can find this instance's Python-side state.
            if not hasattr(self, "_custom_op_states"):
                self._custom_op_states = {}
            if not hasattr(self, "_next_custom_op_id"):
                self._next_custom_op_id = 1
            inst_id = int(self._next_custom_op_id)
            self._next_custom_op_id += 1
            # Keep the Python-side state keyed by instance id.
            self._custom_op_states[inst_id] = user_state
            # Write the instance id into priv_data.
            try:
                op_ctx_p.contents.priv_data = ctypes.c_void_p(inst_id)
            except Exception:
                # Fallback: write the raw integer directly.
                op_ctx_p.contents.priv_data = inst_id
            return 0
        except Exception as e:
            logger.error(f"{op_type} init失败: {e}")
            return -1

    @self._CB_SIG
    def _py_prepare(op_ctx_p, inputs_p, n_inputs_p, outputs_p, n_outputs_p):
        # Nothing to do at prepare time for Python-backed ops.
        return 0

    @self._CB_SIG
    def _py_compute(op_ctx_p, inputs_p, n_inputs_p, outputs_p, n_outputs_p):
        try:
            # Reject calls whose arity does not match what this op was built for.
            if n_inputs_p != n_inputs or n_outputs_p != n_outputs:
                return -1
            # Recover this instance's state via the id stored in priv_data.
            try:
                inst_id = int(op_ctx_p.contents.priv_data) if op_ctx_p.contents.priv_data else 0
            except Exception:
                inst_id = 0
            user_state = None
            if hasattr(self, "_custom_op_states") and inst_id in self._custom_op_states:
                user_state = self._custom_op_states.get(inst_id)
            else:
                logger.error(f"{op_type} compute失败: 找不到实例状态, inst_id={inst_id}")
                return -1
            return on_compute(op_ctx_p, inputs_p, outputs_p, user_state)
        except Exception as e:
            logger.error(f"{op_type} compute失败: {e}")
            import traceback
            logger.error(f"{op_type} compute失败: {traceback.format_exc()}")
            return -1

    @self._DESTROY_SIG
    def _py_destroy(op_ctx_p):
        try:
            # Drop this instance's Python-side state.
            try:
                inst_id = int(op_ctx_p.contents.priv_data) if op_ctx_p.contents.priv_data else 0
            except Exception:
                inst_id = 0
            if hasattr(self, "_custom_op_states") and inst_id in self._custom_op_states:
                del self._custom_op_states[inst_id]
            # Clear priv_data so a stale id is never re-read.
            try:
                op_ctx_p.contents.priv_data = ctypes.c_void_p(0)
            except Exception:
                op_ctx_p.contents.priv_data = 0
            return 0
        except Exception:
            return -1

    op = self._RKNN_CustomOp()
    op.version = 1
    op.target = self._RKNN_TARGET_TYPE_CPU  # Python ops always run on CPU.
    op.op_type = op_type.encode("utf-8")
    op.cl_kernel_name = b""
    op.cl_kernel_source = None
    op.cl_source_size = 0
    op.cl_build_options = b""
    op.init = _py_init
    op.prepare = _py_prepare
    op.compute = _py_compute
    op.compute_native = self._CB_SIG()  # NULL
    op.destroy = _py_destroy

    return op, (_py_init, _py_prepare, _py_compute, _py_destroy)
|
| 846 |
+
|
| 847 |
+
|
| 848 |
+
def _tensor_to_numpy(self, rknn_tensor):
|
| 849 |
+
"""将 RKNN_CustomOpTensor 转换为 Numpy 数组视图"""
|
| 850 |
+
# 确定Numpy数据类型
|
| 851 |
+
# 您可以扩展这个映射
|
| 852 |
+
dtype_map = {
|
| 853 |
+
self._RKNN_TENSOR_FLOAT32: (ctypes.c_float, np.float32),
|
| 854 |
+
self._RKNN_TENSOR_UINT8: (ctypes.c_uint8, np.uint8),
|
| 855 |
+
self._RKNN_TENSOR_INT64: (ctypes.c_int64, np.int64),
|
| 856 |
+
}
|
| 857 |
+
c_type, np_dtype = dtype_map.get(rknn_tensor.attr.type, (None, None))
|
| 858 |
+
if c_type is None:
|
| 859 |
+
raise TypeError(f"不支持的RKNN张量类型: {rknn_tensor.attr.type}")
|
| 860 |
+
|
| 861 |
+
# 获取内存地址和形状
|
| 862 |
+
addr = (rknn_tensor.mem.virt_addr or 0) + int(rknn_tensor.mem.offset)
|
| 863 |
+
ptr = ctypes.cast(addr, ctypes.POINTER(c_type))
|
| 864 |
+
shape = tuple(rknn_tensor.attr.dims[i] for i in range(rknn_tensor.attr.n_dims))
|
| 865 |
+
|
| 866 |
+
# 创建Numpy数组视图
|
| 867 |
+
return np.ctypeslib.as_array(ptr, shape=shape)
|
| 868 |
+
|
| 869 |
+
|
| 870 |
+
def _create_onnxscript_op_creator(self,
|
| 871 |
+
op_type: str,
|
| 872 |
+
# 现在接收一个"函数模板构造器"
|
| 873 |
+
onnxscript_func_builder,
|
| 874 |
+
n_inputs: int,
|
| 875 |
+
n_outputs: int,
|
| 876 |
+
attributes: dict = {},
|
| 877 |
+
constants: dict = {}):
|
| 878 |
+
"""
|
| 879 |
+
一个高阶工厂函数,用于创建基于ONNXScript的自定义算子构造器。
|
| 880 |
+
它在 on_init 阶段动态生成最终的 onnxscript 计算函数。
|
| 881 |
+
|
| 882 |
+
Args:
|
| 883 |
+
op_type (str): 算子类型名。
|
| 884 |
+
onnxscript_func_builder: 一个函数,它接收所有属性和常量作为关键字参数,
|
| 885 |
+
并返回一个编译好的 onnxscript 函数。
|
| 886 |
+
例如: def builder(mean, scale):
|
| 887 |
+
@onnxscript.script()
|
| 888 |
+
def compute(like):
|
| 889 |
+
return opset.RandomNormalLike(like, mean=mean, scale=scale)
|
| 890 |
+
return compute
|
| 891 |
+
attributes (dict): 从模型中读取的属性字典。
|
| 892 |
+
constants (dict): 编译时常量字典。
|
| 893 |
+
n_inputs (int): 输入个数。
|
| 894 |
+
n_outputs (int): 输出个数。
|
| 895 |
+
"""
|
| 896 |
+
|
| 897 |
+
def creator_func():
|
| 898 |
+
def on_init(op_ctx_p, read_i64, read_s, read_f32):
|
| 899 |
+
# 1. 读取所有动态属性
|
| 900 |
+
attr_values = {}
|
| 901 |
+
for name, (attr_type, default) in attributes.items():
|
| 902 |
+
if attr_type == 'int64':
|
| 903 |
+
attr_values[name] = read_i64(op_ctx_p, name, default)
|
| 904 |
+
elif attr_type == 'str':
|
| 905 |
+
attr_values[name] = read_s(op_ctx_p, name, default)
|
| 906 |
+
elif attr_type == 'float32':
|
| 907 |
+
attr_values[name] = read_f32(op_ctx_p, name, default)
|
| 908 |
+
else:
|
| 909 |
+
raise ValueError(f"不支持的属性类型: {attr_type}")
|
| 910 |
+
|
| 911 |
+
# 2. 合并常量和属性
|
| 912 |
+
final_kwargs = {**constants, **attr_values}
|
| 913 |
+
|
| 914 |
+
# 3. 动态构建 onnxscript 函数! <<<<< 核心修改
|
| 915 |
+
# 这确保了所有属性值都作为常量被闭包捕获
|
| 916 |
+
compute_func = onnxscript_func_builder(**final_kwargs)
|
| 917 |
+
|
| 918 |
+
# 4. 将最终生成的、已编译的函数存入 state
|
| 919 |
+
return {"compute_func": compute_func}
|
| 920 |
+
|
| 921 |
+
def on_compute(op_ctx_p, inputs_p, outputs_p, state):
|
| 922 |
+
compute_func = state["compute_func"]
|
| 923 |
+
|
| 924 |
+
input_nps = [self._tensor_to_numpy(inputs_p[i]) for i in range(n_inputs)]
|
| 925 |
+
output_nps = [self._tensor_to_numpy(outputs_p[i]) for i in range(n_outputs)]
|
| 926 |
+
|
| 927 |
+
results = compute_func(*input_nps)
|
| 928 |
+
|
| 929 |
+
if n_outputs == 1:
|
| 930 |
+
result_val = results[0] if isinstance(results, tuple) else results
|
| 931 |
+
output_nps[0][...] = result_val
|
| 932 |
+
else:
|
| 933 |
+
for i in range(n_outputs):
|
| 934 |
+
output_nps[i][...] = results[i]
|
| 935 |
+
|
| 936 |
+
return 0
|
| 937 |
+
|
| 938 |
+
return self._build_py_custom_op(
|
| 939 |
+
op_type=op_type,
|
| 940 |
+
n_inputs=n_inputs,
|
| 941 |
+
n_outputs=n_outputs,
|
| 942 |
+
on_init=on_init,
|
| 943 |
+
on_compute=on_compute
|
| 944 |
+
)
|
| 945 |
+
|
| 946 |
+
return creator_func
|
| 947 |
+
|
| 948 |
+
def _create_gridsample_op(self):
    """Build a creator for a CPU 'GridSample' custom op backed by onnxscript."""
    import onnxscript
    from onnxscript import opset17 as opset

    def build_grid_sample(align_corners, mode, padding_mode):
        # Attribute values are baked in as constants when the script compiles.
        @onnxscript.script()
        def grid_sample_compute(X, G):
            return opset.GridSample(X, G, align_corners=align_corners, mode=mode, padding_mode=padding_mode)
        return grid_sample_compute

    return self._create_onnxscript_op_creator(
        op_type="GridSample",
        onnxscript_func_builder=build_grid_sample,
        n_inputs=2,
        n_outputs=1,
        attributes={
            "align_corners": ("int64", 0),
            "mode": ("str", "bilinear"),
            "padding_mode": ("str", "zeros"),
        },
    )
|
| 970 |
+
|
| 971 |
+
def _create_scatterelements_op(self):
    """Build a creator for a CPU 'ScatterElements' custom op backed by onnxscript."""
    import onnxscript
    from onnxscript import opset17 as opset

    # No dynamic attributes here, so the script is compiled once, up front.
    @onnxscript.script()
    def scatter_elements_compute(data, indices, updates):
        # ONNX ScatterElements requires int64 indices; cast defensively.
        indices_i64 = opset.Cast(indices, to=onnxscript.INT64.dtype)
        return opset.ScatterElements(data, indices_i64, updates)

    return self._create_onnxscript_op_creator(
        op_type="ScatterElements",
        onnxscript_func_builder=lambda: scatter_elements_compute,
        n_inputs=3,
        n_outputs=1,
    )
|
| 987 |
+
|
| 988 |
+
def _create_randomnormallike_op(self):
    """Build a creator for a CPU 'RandomNormalLike' custom op backed by onnxscript."""
    import onnxscript
    from onnxscript import opset17 as opset

    def build_random_normal_like(mean, scale):
        # mean/scale become compile-time constants of the script.
        @onnxscript.script()
        def random_normal_like_compute(like):
            return opset.RandomNormalLike(like, mean=mean, scale=scale)
        return random_normal_like_compute

    return self._create_onnxscript_op_creator(
        op_type="RandomNormalLike",
        onnxscript_func_builder=build_random_normal_like,
        n_inputs=1,
        n_outputs=1,
        attributes={
            "mean": ("float32", 0.0),
            "scale": ("float32", 1.0),
        },
    )
|
| 1011 |
+
|
| 1012 |
+
def _create_einsum_op(self):
    """Build a creator for a CPU 'Einsum' custom op backed by onnxscript."""
    import onnxscript
    from onnxscript import opset17 as opset

    def build_einsum(equation):
        # The equation string is baked in when the script is compiled.
        @onnxscript.script()
        def einsum_compute(in1, in2):
            return opset.Einsum(in1, in2, equation=equation)
        return einsum_compute

    return self._create_onnxscript_op_creator(
        op_type="Einsum",
        onnxscript_func_builder=build_einsum,
        n_inputs=2,
        n_outputs=1,
        attributes={
            "equation": ("str", ""),
        },
    )
|
| 1035 |
+
|
| 1036 |
+
def register_bundled_ops(self) -> None:
    """Register the bundled Python custom ops with the RKNN runtime.

    Idempotent: a second call is a no-op. Silently skips when the loaded SDK
    does not expose the custom-op C entry points.
    """
    if getattr(self, "_custom_ops_registered", False):
        return

    rt = self.runtime.rknn_base.rknn_runtime
    lib = rt.lib
    ctx = rt.context

    # Older SDKs lack these symbols; bail out quietly in that case.
    try:
        _ = lib.rknn_register_custom_ops
        _ = lib.rknn_custom_op_get_op_attr
    except AttributeError as e:
        logger.debug(f"SDK不支持自定义算子注册: {e}")
        return

    self._init_custom_op_types()

    # Note: plugin-library registration is driven by environment variables
    # after model load; it is intentionally not re-triggered here.

    # One factory per bundled op keeps this list easy to extend.
    factories = [
        self._create_gridsample_op,
        self._create_scatterelements_op,
        self._create_randomnormallike_op,
        self._create_einsum_op,
        # self._create_my_custom_add_op,  # adding a new op is this simple
    ]

    pending_ops = []
    keepalive_callbacks = []
    for make_creator in factories:
        try:
            # The factory yields a creator; the creator yields the op instance.
            build_op = make_creator()
            op, cbs = build_op()
            pending_ops.append(op)
            keepalive_callbacks.extend(cbs)
            logger.debug(f"成功创建自定义算子: {op.op_type.decode()}")
        except Exception as e:
            logger.warning(f"创建自定义算子失败: {e}", exc_info=True)

    if not pending_ops:
        logger.debug("没有可注册的自定义算子")
        return

    # Pack every op into one ctypes array and register them in a single call.
    count = len(pending_ops)
    op_array = (self._RKNN_CustomOp * count)(*pending_ops)
    ret = lib.rknn_register_custom_ops(ctx, op_array, count)
    if ret != 0:
        logger.error(f"注册自定义算子失败, ret={ret} (可能是误报, 继续执行...)")
        # raise RuntimeError(f"rknn_register_custom_ops 失败, ret={ret}")

    logger.info(f"成功注册 {len(pending_ops)} 个自定义算子")

    self._custom_ops_registered = True
    # Keep ops and callbacks referenced so the GC never frees the trampolines.
    self._registered_ops = pending_ops
    self._op_callbacks = keepalive_callbacks
|
| 1097 |
+
|
| 1098 |
+
def _load_and_register_plugin_op(self, so_path: str) -> bool:
    """Load one plugin .so and register the custom op it exports.

    The plugin must implement ``get_rknn_custom_op()`` returning an
    ``rknn_custom_op*``; that raw C pointer is handed directly to
    ``rknn_register_custom_ops`` (zero-copy, no Python-side struct mirror).

    Fixes vs. the original: the unused local ``ctx`` is removed, and the
    constant-string ``getattr(handle, "get_rknn_custom_op")`` is replaced by
    plain attribute access (same AttributeError behavior, idiomatic form).

    Args:
        so_path: filesystem path to the plugin shared library.
    Returns:
        True when the plugin was loaded and registration was attempted
        (registration failures are logged but treated as possible false
        alarms); False on any load/lookup failure.
    """
    if not os.path.isfile(so_path):
        logger.warning(f"插件库不存在: {so_path}")
        return False

    runtime = self.runtime.rknn_base.rknn_runtime
    lib = runtime.lib

    # rknn_context width follows the platform pointer size.
    ContextCType = ctypes.c_uint64 if ctypes.sizeof(ctypes.c_void_p) == 8 else ctypes.c_uint32
    # Declare rknn_register_custom_ops(ctx, op_ptr, num); the op pointer is
    # passed as void* to avoid depending on the C struct layout from Python.
    try:
        lib.rknn_register_custom_ops.argtypes = [ContextCType, ctypes.c_void_p, ctypes.c_uint32]
        lib.rknn_register_custom_ops.restype = ctypes.c_int
    except Exception:
        pass

    # dlopen the plugin library.
    try:
        handle = ctypes.CDLL(so_path)
    except Exception as e:
        logger.error(f"dlopen 失败: {so_path}, err={e}")
        return False

    # Resolve the factory symbol.
    try:
        get_sym = handle.get_rknn_custom_op
    except AttributeError:
        logger.error(f"插件缺少符号 get_rknn_custom_op: {so_path}")
        return False

    # Treat the return value as an opaque void*; Python never parses the struct.
    try:
        get_sym.argtypes = []
    except Exception:
        pass
    get_sym.restype = ctypes.c_void_p

    op_void_ptr = get_sym()
    if not op_void_ptr:
        logger.error(f"get_rknn_custom_op 返回空指针: {so_path}")
        return False

    # Register using the native pointer directly (zero-copy).
    ctx_val = ContextCType(runtime.context)
    ret = lib.rknn_register_custom_ops(ctx_val, ctypes.c_void_p(op_void_ptr), 1)
    if ret != 0:
        logger.error(f"rknn_register_custom_ops 失败, ret={ret}, so={so_path} (可能是误报, 继续执行...)")
        # return False

    # Keep the handle referenced so the GC never unloads the plugin.
    if not hasattr(self, "_plugin_handles"):
        self._plugin_handles = []
    self._plugin_handles.append(handle)
    logger.info(f"成功注册插件自定义算子: {so_path}")
    return True
|
| 1160 |
+
|
| 1161 |
+
def register_plugin_ops(self, plugin_paths: List[str]) -> int:
    """Register custom-op plugin libraries from the given paths; return how many succeeded."""
    if not plugin_paths:
        return 0
    registered = 0
    for so_path in plugin_paths:
        try:
            ok = self._load_and_register_plugin_op(so_path)
        except Exception as e:
            logger.error(f"注册插件失败: {so_path}, err={e}")
            continue
        if ok:
            registered += 1
    return registered
|
| 1173 |
+
|
| 1174 |
+
# 对外API:注册单个自定义算子插件库
|
| 1175 |
+
def register_custom_op_lib(self, path: str) -> bool:
    """Public API: register a single custom-op plugin library (.so) by path."""
    loaded = self._load_and_register_plugin_op(path)
    return loaded
|
| 1177 |
+
|
| 1178 |
+
# 对外API:扫描并注册 Linux 系统目录下所有插件库(Android 不处理)
|
| 1179 |
+
def register_system_custom_op_lib(self) -> int:
    """Public API: scan the Linux system directory and register every plugin found.

    Non-POSIX platforms (and Android, which lacks this directory) are skipped.
    Returns the number of plugins successfully registered.
    """
    if os.name != 'posix':
        return 0
    # Linux only: the official RKNN default plugin directory.
    plugin_dir = "/usr/lib/rknpu/op_plugins/"
    if not os.path.isdir(plugin_dir):
        return 0
    try:
        names = os.listdir(plugin_dir)
    except Exception:
        return 0
    # Official naming convention: plugin file names start with 'librkcst_'.
    candidates = [
        os.path.join(plugin_dir, fname)
        for fname in names
        if fname.startswith("librkcst_") and fname.endswith('.so')
    ]
    return self.register_plugin_ops(candidates)
|