Upload 7 files
Browse files- .gitattributes +3 -0
- language_model.rkllm +3 -0
- language_model_w8a8.rkllm +3 -0
- librkllmrt.so +3 -0
- rkllm-convert-w8a8.py +25 -0
- rkllm-convert.py +23 -0
- rkllm_binding.py +867 -0
- run.py +106 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
language_model_w8a8.rkllm filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
language_model.rkllm filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
librkllmrt.so filter=lfs diff=lfs merge=lfs -text
|
language_model.rkllm
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5e458a6dee8ea66c8e166596027adb4e1b1cf30b5e150747f7a56630df1139c5
|
| 3 |
+
size 893228148
|
language_model_w8a8.rkllm
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:28df7af77cd515f63f07faad96be587618979b2a5f46541ae64f5fdbb080499e
|
| 3 |
+
size 627635884
|
librkllmrt.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a7e6f87f07bbb08058cad4871cc74e8069a054fe4f6259b43c29a4738b0affdd
|
| 3 |
+
size 7461896
|
rkllm-convert-w8a8.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pip install -U transformers==4.50.0 -i https://mirrors.aliyun.com/pypi/simple
|
| 2 |
+
|
| 3 |
+
from rkllm.api import RKLLM
|
| 4 |
+
|
| 5 |
+
modelpath = '.'
|
| 6 |
+
llm = RKLLM()
|
| 7 |
+
|
| 8 |
+
ret = llm.load_huggingface(model=modelpath, model_lora=None, device='cpu')
|
| 9 |
+
if ret != 0:
|
| 10 |
+
print('Load model failed!')
|
| 11 |
+
exit(ret)
|
| 12 |
+
|
| 13 |
+
qparams = None
|
| 14 |
+
ret = llm.build(do_quantization=True, optimization_level=1, quantized_dtype='w8a8',
|
| 15 |
+
quantized_algorithm='normal', target_platform='rk3588', num_npu_core=3, extra_qparams=qparams)
|
| 16 |
+
|
| 17 |
+
if ret != 0:
|
| 18 |
+
print('Build model failed!')
|
| 19 |
+
exit(ret)
|
| 20 |
+
|
| 21 |
+
# Export rkllm model
|
| 22 |
+
ret = llm.export_rkllm("./language_model_w8a8.rkllm")
|
| 23 |
+
if ret != 0:
|
| 24 |
+
print('Export model failed!')
|
| 25 |
+
exit(ret)
|
rkllm-convert.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from rkllm.api import RKLLM
|
| 2 |
+
|
| 3 |
+
modelpath = '.'
|
| 4 |
+
llm = RKLLM()
|
| 5 |
+
|
| 6 |
+
ret = llm.load_huggingface(model=modelpath, model_lora=None, device='cpu')
|
| 7 |
+
if ret != 0:
|
| 8 |
+
print('Load model failed!')
|
| 9 |
+
exit(ret)
|
| 10 |
+
|
| 11 |
+
qparams = None
|
| 12 |
+
ret = llm.build(do_quantization=False, optimization_level=1, quantized_dtype='w8a8',
|
| 13 |
+
quantized_algorithm='normal', target_platform='rk3588', num_npu_core=3, extra_qparams=qparams)
|
| 14 |
+
|
| 15 |
+
if ret != 0:
|
| 16 |
+
print('Build model failed!')
|
| 17 |
+
exit(ret)
|
| 18 |
+
|
| 19 |
+
# Export rkllm model
|
| 20 |
+
ret = llm.export_rkllm("./language_model.rkllm")
|
| 21 |
+
if ret != 0:
|
| 22 |
+
print('Export model failed!')
|
| 23 |
+
exit(ret)
|
rkllm_binding.py
ADDED
|
@@ -0,0 +1,867 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ctypes
|
| 2 |
+
import enum
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
# Define constants from the header
|
| 6 |
+
CPU0 = (1 << 0) # 0x01
|
| 7 |
+
CPU1 = (1 << 1) # 0x02
|
| 8 |
+
CPU2 = (1 << 2) # 0x04
|
| 9 |
+
CPU3 = (1 << 3) # 0x08
|
| 10 |
+
CPU4 = (1 << 4) # 0x10
|
| 11 |
+
CPU5 = (1 << 5) # 0x20
|
| 12 |
+
CPU6 = (1 << 6) # 0x40
|
| 13 |
+
CPU7 = (1 << 7) # 0x80
|
| 14 |
+
|
| 15 |
+
# --- Enums ---
|
| 16 |
+
class LLMCallState(enum.IntEnum):
|
| 17 |
+
RKLLM_RUN_NORMAL = 0
|
| 18 |
+
RKLLM_RUN_WAITING = 1
|
| 19 |
+
RKLLM_RUN_FINISH = 2
|
| 20 |
+
RKLLM_RUN_ERROR = 3
|
| 21 |
+
|
| 22 |
+
class RKLLMInputType(enum.IntEnum):
|
| 23 |
+
RKLLM_INPUT_PROMPT = 0
|
| 24 |
+
RKLLM_INPUT_TOKEN = 1
|
| 25 |
+
RKLLM_INPUT_EMBED = 2
|
| 26 |
+
RKLLM_INPUT_MULTIMODAL = 3
|
| 27 |
+
|
| 28 |
+
class RKLLMInferMode(enum.IntEnum):
|
| 29 |
+
RKLLM_INFER_GENERATE = 0
|
| 30 |
+
RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1
|
| 31 |
+
RKLLM_INFER_GET_LOGITS = 2
|
| 32 |
+
|
| 33 |
+
# --- Structures ---
|
| 34 |
+
class RKLLMExtendParam(ctypes.Structure):
|
| 35 |
+
base_domain_id: ctypes.c_int32
|
| 36 |
+
embed_flash: ctypes.c_int8
|
| 37 |
+
enabled_cpus_num: ctypes.c_int8
|
| 38 |
+
enabled_cpus_mask: ctypes.c_uint32
|
| 39 |
+
n_batch: ctypes.c_uint8
|
| 40 |
+
use_cross_attn: ctypes.c_int8
|
| 41 |
+
reserved: ctypes.c_uint8 * 104
|
| 42 |
+
|
| 43 |
+
_fields_ = [
|
| 44 |
+
("base_domain_id", ctypes.c_int32), # 基础域ID
|
| 45 |
+
("embed_flash", ctypes.c_int8), # 是否从闪存查询词嵌入向量(1启用,0禁用)
|
| 46 |
+
("enabled_cpus_num", ctypes.c_int8), # 推理启用的CPU数量
|
| 47 |
+
("enabled_cpus_mask", ctypes.c_uint32), # 指示启用哪些CPU的位掩码
|
| 48 |
+
("n_batch", ctypes.c_uint8), # 一次前向传播中并发处理的输入样本数,设置>1启用批量推理,默认为1
|
| 49 |
+
("use_cross_attn", ctypes.c_int8), # 是否启用交叉注意力(非零启用,0禁用)
|
| 50 |
+
("reserved", ctypes.c_uint8 * 104) # 保留字段
|
| 51 |
+
]
|
| 52 |
+
|
| 53 |
+
class RKLLMParam(ctypes.Structure):
|
| 54 |
+
model_path: ctypes.c_char_p
|
| 55 |
+
max_context_len: ctypes.c_int32
|
| 56 |
+
max_new_tokens: ctypes.c_int32
|
| 57 |
+
top_k: ctypes.c_int32
|
| 58 |
+
n_keep: ctypes.c_int32
|
| 59 |
+
top_p: ctypes.c_float
|
| 60 |
+
temperature: ctypes.c_float
|
| 61 |
+
repeat_penalty: ctypes.c_float
|
| 62 |
+
frequency_penalty: ctypes.c_float
|
| 63 |
+
presence_penalty: ctypes.c_float
|
| 64 |
+
mirostat: ctypes.c_int32
|
| 65 |
+
mirostat_tau: ctypes.c_float
|
| 66 |
+
mirostat_eta: ctypes.c_float
|
| 67 |
+
skip_special_token: ctypes.c_bool
|
| 68 |
+
is_async: ctypes.c_bool
|
| 69 |
+
img_start: ctypes.c_char_p
|
| 70 |
+
img_end: ctypes.c_char_p
|
| 71 |
+
img_content: ctypes.c_char_p
|
| 72 |
+
extend_param: RKLLMExtendParam
|
| 73 |
+
|
| 74 |
+
_fields_ = [
|
| 75 |
+
("model_path", ctypes.c_char_p), # 模型文件路径
|
| 76 |
+
("max_context_len", ctypes.c_int32), # 上下文窗口最大token数
|
| 77 |
+
("max_new_tokens", ctypes.c_int32), # 最大生成新token数
|
| 78 |
+
("top_k", ctypes.c_int32), # Top-K采样参数
|
| 79 |
+
("n_keep", ctypes.c_int32), # 上下文窗口移动时保留的kv缓存数量
|
| 80 |
+
("top_p", ctypes.c_float), # Top-P(nucleus)采样参数
|
| 81 |
+
("temperature", ctypes.c_float), # 采样温度,影响token选择的随机性
|
| 82 |
+
("repeat_penalty", ctypes.c_float), # 重复token惩罚
|
| 83 |
+
("frequency_penalty", ctypes.c_float), # 频繁token惩罚
|
| 84 |
+
("presence_penalty", ctypes.c_float), # 输入中已存在token的惩罚
|
| 85 |
+
("mirostat", ctypes.c_int32), # Mirostat采样策略标志(0表示禁用)
|
| 86 |
+
("mirostat_tau", ctypes.c_float), # Mirostat采样Tau参数
|
| 87 |
+
("mirostat_eta", ctypes.c_float), # Mirostat采样Eta参数
|
| 88 |
+
("skip_special_token", ctypes.c_bool), # 是否跳过特殊token
|
| 89 |
+
("is_async", ctypes.c_bool), # 是否异步推理
|
| 90 |
+
("img_start", ctypes.c_char_p), # 多模态输入中图像的起始位置
|
| 91 |
+
("img_end", ctypes.c_char_p), # 多模态输入中图像的结束位置
|
| 92 |
+
("img_content", ctypes.c_char_p), # 图像内容指针
|
| 93 |
+
("extend_param", RKLLMExtendParam) # 扩展参数
|
| 94 |
+
]
|
| 95 |
+
|
| 96 |
+
class RKLLMLoraAdapter(ctypes.Structure):
|
| 97 |
+
lora_adapter_path: ctypes.c_char_p
|
| 98 |
+
lora_adapter_name: ctypes.c_char_p
|
| 99 |
+
scale: ctypes.c_float
|
| 100 |
+
|
| 101 |
+
_fields_ = [
|
| 102 |
+
("lora_adapter_path", ctypes.c_char_p),
|
| 103 |
+
("lora_adapter_name", ctypes.c_char_p),
|
| 104 |
+
("scale", ctypes.c_float)
|
| 105 |
+
]
|
| 106 |
+
|
| 107 |
+
class RKLLMEmbedInput(ctypes.Structure):
|
| 108 |
+
embed: ctypes.POINTER(ctypes.c_float)
|
| 109 |
+
n_tokens: ctypes.c_size_t
|
| 110 |
+
|
| 111 |
+
_fields_ = [
|
| 112 |
+
("embed", ctypes.POINTER(ctypes.c_float)),
|
| 113 |
+
("n_tokens", ctypes.c_size_t)
|
| 114 |
+
]
|
| 115 |
+
|
| 116 |
+
class RKLLMTokenInput(ctypes.Structure):
|
| 117 |
+
input_ids: ctypes.POINTER(ctypes.c_int32)
|
| 118 |
+
n_tokens: ctypes.c_size_t
|
| 119 |
+
|
| 120 |
+
_fields_ = [
|
| 121 |
+
("input_ids", ctypes.POINTER(ctypes.c_int32)),
|
| 122 |
+
("n_tokens", ctypes.c_size_t)
|
| 123 |
+
]
|
| 124 |
+
|
| 125 |
+
class RKLLMMultiModelInput(ctypes.Structure):
|
| 126 |
+
prompt: ctypes.c_char_p
|
| 127 |
+
image_embed: ctypes.POINTER(ctypes.c_float)
|
| 128 |
+
n_image_tokens: ctypes.c_size_t
|
| 129 |
+
n_image: ctypes.c_size_t
|
| 130 |
+
image_width: ctypes.c_size_t
|
| 131 |
+
image_height: ctypes.c_size_t
|
| 132 |
+
|
| 133 |
+
_fields_ = [
|
| 134 |
+
("prompt", ctypes.c_char_p),
|
| 135 |
+
("image_embed", ctypes.POINTER(ctypes.c_float)),
|
| 136 |
+
("n_image_tokens", ctypes.c_size_t),
|
| 137 |
+
("n_image", ctypes.c_size_t),
|
| 138 |
+
("image_width", ctypes.c_size_t),
|
| 139 |
+
("image_height", ctypes.c_size_t)
|
| 140 |
+
]
|
| 141 |
+
|
| 142 |
+
class RKLLMCrossAttnParam(ctypes.Structure):
|
| 143 |
+
"""
|
| 144 |
+
交叉注意力参数结构体
|
| 145 |
+
|
| 146 |
+
该结构体用于在解码器中执行交叉注意力时使用。
|
| 147 |
+
它提供编码器输出(键/值缓存)、位置索引和注意力掩码。
|
| 148 |
+
|
| 149 |
+
- encoder_k_cache必须存储在连续内存中,布局为:
|
| 150 |
+
[num_layers][num_tokens][num_kv_heads][head_dim]
|
| 151 |
+
- encoder_v_cache必须存储在连续内存中,布局为:
|
| 152 |
+
[num_layers][num_kv_heads][head_dim][num_tokens]
|
| 153 |
+
"""
|
| 154 |
+
encoder_k_cache: ctypes.POINTER(ctypes.c_float)
|
| 155 |
+
encoder_v_cache: ctypes.POINTER(ctypes.c_float)
|
| 156 |
+
encoder_mask: ctypes.POINTER(ctypes.c_float)
|
| 157 |
+
encoder_pos: ctypes.POINTER(ctypes.c_int32)
|
| 158 |
+
num_tokens: ctypes.c_int
|
| 159 |
+
|
| 160 |
+
_fields_ = [
|
| 161 |
+
("encoder_k_cache", ctypes.POINTER(ctypes.c_float)), # 编码器键缓存指针(大小:num_layers * num_tokens * num_kv_heads * head_dim)
|
| 162 |
+
("encoder_v_cache", ctypes.POINTER(ctypes.c_float)), # 编码器值缓存指针(大小:num_layers * num_kv_heads * head_dim * num_tokens)
|
| 163 |
+
("encoder_mask", ctypes.POINTER(ctypes.c_float)), # 编码器注意力掩码指针(大小:num_tokens的数组)
|
| 164 |
+
("encoder_pos", ctypes.POINTER(ctypes.c_int32)), # 编码器token位置指针(大小:num_tokens的数组)
|
| 165 |
+
("num_tokens", ctypes.c_int) # 编码器序列中的token数量
|
| 166 |
+
]
|
| 167 |
+
|
| 168 |
+
class RKLLMPerfStat(ctypes.Structure):
|
| 169 |
+
"""
|
| 170 |
+
性能统计结构体
|
| 171 |
+
|
| 172 |
+
用于保存预填充和生成阶段的性能统计信息。
|
| 173 |
+
"""
|
| 174 |
+
prefill_time_ms: ctypes.c_float
|
| 175 |
+
prefill_tokens: ctypes.c_int
|
| 176 |
+
generate_time_ms: ctypes.c_float
|
| 177 |
+
generate_tokens: ctypes.c_int
|
| 178 |
+
memory_usage_mb: ctypes.c_float
|
| 179 |
+
|
| 180 |
+
_fields_ = [
|
| 181 |
+
("prefill_time_ms", ctypes.c_float), # 预填充阶段总耗时(毫秒)
|
| 182 |
+
("prefill_tokens", ctypes.c_int), # 预填充阶段处理的token数量
|
| 183 |
+
("generate_time_ms", ctypes.c_float), # 生成阶段总耗时(毫秒)
|
| 184 |
+
("generate_tokens", ctypes.c_int), # 生成阶段处理的token数量
|
| 185 |
+
("memory_usage_mb", ctypes.c_float) # 推理期间VmHWM常驻内存使用量(MB)
|
| 186 |
+
]
|
| 187 |
+
|
| 188 |
+
class _RKLLMInputUnion(ctypes.Union):
|
| 189 |
+
prompt_input: ctypes.c_char_p
|
| 190 |
+
embed_input: RKLLMEmbedInput
|
| 191 |
+
token_input: RKLLMTokenInput
|
| 192 |
+
multimodal_input: RKLLMMultiModelInput
|
| 193 |
+
|
| 194 |
+
_fields_ = [
|
| 195 |
+
("prompt_input", ctypes.c_char_p),
|
| 196 |
+
("embed_input", RKLLMEmbedInput),
|
| 197 |
+
("token_input", RKLLMTokenInput),
|
| 198 |
+
("multimodal_input", RKLLMMultiModelInput)
|
| 199 |
+
]
|
| 200 |
+
|
| 201 |
+
class RKLLMInput(ctypes.Structure):
|
| 202 |
+
"""
|
| 203 |
+
LLM输入结构体
|
| 204 |
+
|
| 205 |
+
通过联合体表示不同类型的LLM输入。
|
| 206 |
+
"""
|
| 207 |
+
role: ctypes.c_char_p
|
| 208 |
+
enable_thinking: ctypes.c_bool
|
| 209 |
+
input_type: ctypes.c_int
|
| 210 |
+
_union_data: _RKLLMInputUnion
|
| 211 |
+
|
| 212 |
+
_fields_ = [
|
| 213 |
+
("role", ctypes.c_char_p), # 消息角色:"user"(用户输入)、"tool"(函数结果)
|
| 214 |
+
("enable_thinking", ctypes.c_bool), # 控制Qwen3模型是否启用"思考模式"
|
| 215 |
+
("input_type", ctypes.c_int), # 枚举类型,指定输入类型(如prompt、token、embed、multimodal)
|
| 216 |
+
("_union_data", _RKLLMInputUnion) # 联合体数据
|
| 217 |
+
]
|
| 218 |
+
# Properties to make accessing union members easier
|
| 219 |
+
@property
|
| 220 |
+
def prompt_input(self) -> bytes: # Assuming c_char_p maps to bytes
|
| 221 |
+
if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT:
|
| 222 |
+
return self._union_data.prompt_input
|
| 223 |
+
raise AttributeError("Not a prompt input")
|
| 224 |
+
@prompt_input.setter
|
| 225 |
+
def prompt_input(self, value: bytes): # Assuming c_char_p maps to bytes
|
| 226 |
+
if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT:
|
| 227 |
+
self._union_data.prompt_input = value
|
| 228 |
+
else:
|
| 229 |
+
raise AttributeError("Not a prompt input")
|
| 230 |
+
@property
|
| 231 |
+
def embed_input(self) -> RKLLMEmbedInput:
|
| 232 |
+
if self.input_type == RKLLMInputType.RKLLM_INPUT_EMBED:
|
| 233 |
+
return self._union_data.embed_input
|
| 234 |
+
raise AttributeError("Not an embed input")
|
| 235 |
+
@embed_input.setter
|
| 236 |
+
def embed_input(self, value: RKLLMEmbedInput):
|
| 237 |
+
if self.input_type == RKLLMInputType.RKLLM_INPUT_EMBED:
|
| 238 |
+
self._union_data.embed_input = value
|
| 239 |
+
else:
|
| 240 |
+
raise AttributeError("Not an embed input")
|
| 241 |
+
|
| 242 |
+
@property
|
| 243 |
+
def token_input(self) -> RKLLMTokenInput:
|
| 244 |
+
if self.input_type == RKLLMInputType.RKLLM_INPUT_TOKEN:
|
| 245 |
+
return self._union_data.token_input
|
| 246 |
+
raise AttributeError("Not a token input")
|
| 247 |
+
@token_input.setter
|
| 248 |
+
def token_input(self, value: RKLLMTokenInput):
|
| 249 |
+
if self.input_type == RKLLMInputType.RKLLM_INPUT_TOKEN:
|
| 250 |
+
self._union_data.token_input = value
|
| 251 |
+
else:
|
| 252 |
+
raise AttributeError("Not a token input")
|
| 253 |
+
|
| 254 |
+
@property
|
| 255 |
+
def multimodal_input(self) -> RKLLMMultiModelInput:
|
| 256 |
+
if self.input_type == RKLLMInputType.RKLLM_INPUT_MULTIMODAL:
|
| 257 |
+
return self._union_data.multimodal_input
|
| 258 |
+
raise AttributeError("Not a multimodal input")
|
| 259 |
+
@multimodal_input.setter
|
| 260 |
+
def multimodal_input(self, value: RKLLMMultiModelInput):
|
| 261 |
+
if self.input_type == RKLLMInputType.RKLLM_INPUT_MULTIMODAL:
|
| 262 |
+
self._union_data.multimodal_input = value
|
| 263 |
+
else:
|
| 264 |
+
raise AttributeError("Not a multimodal input")
|
| 265 |
+
|
| 266 |
+
class RKLLMLoraParam(ctypes.Structure): # For inference
|
| 267 |
+
lora_adapter_name: ctypes.c_char_p
|
| 268 |
+
|
| 269 |
+
_fields_ = [
|
| 270 |
+
("lora_adapter_name", ctypes.c_char_p)
|
| 271 |
+
]
|
| 272 |
+
|
| 273 |
+
class RKLLMPromptCacheParam(ctypes.Structure): # For inference
|
| 274 |
+
save_prompt_cache: ctypes.c_int # bool-like
|
| 275 |
+
prompt_cache_path: ctypes.c_char_p
|
| 276 |
+
|
| 277 |
+
_fields_ = [
|
| 278 |
+
("save_prompt_cache", ctypes.c_int), # bool-like
|
| 279 |
+
("prompt_cache_path", ctypes.c_char_p)
|
| 280 |
+
]
|
| 281 |
+
|
| 282 |
+
class RKLLMInferParam(ctypes.Structure):
|
| 283 |
+
mode: ctypes.c_int
|
| 284 |
+
lora_params: ctypes.POINTER(RKLLMLoraParam)
|
| 285 |
+
prompt_cache_params: ctypes.POINTER(RKLLMPromptCacheParam)
|
| 286 |
+
keep_history: ctypes.c_int # bool-like
|
| 287 |
+
|
| 288 |
+
_fields_ = [
|
| 289 |
+
("mode", ctypes.c_int), # Enum will be passed as int, changed RKLLMInferMode to ctypes.c_int
|
| 290 |
+
("lora_params", ctypes.POINTER(RKLLMLoraParam)),
|
| 291 |
+
("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam)),
|
| 292 |
+
("keep_history", ctypes.c_int) # bool-like
|
| 293 |
+
]
|
| 294 |
+
|
| 295 |
+
class RKLLMResultLastHiddenLayer(ctypes.Structure):
|
| 296 |
+
hidden_states: ctypes.POINTER(ctypes.c_float)
|
| 297 |
+
embd_size: ctypes.c_int
|
| 298 |
+
num_tokens: ctypes.c_int
|
| 299 |
+
|
| 300 |
+
_fields_ = [
|
| 301 |
+
("hidden_states", ctypes.POINTER(ctypes.c_float)),
|
| 302 |
+
("embd_size", ctypes.c_int),
|
| 303 |
+
("num_tokens", ctypes.c_int)
|
| 304 |
+
]
|
| 305 |
+
|
| 306 |
+
class RKLLMResultLogits(ctypes.Structure):
|
| 307 |
+
logits: ctypes.POINTER(ctypes.c_float)
|
| 308 |
+
vocab_size: ctypes.c_int
|
| 309 |
+
num_tokens: ctypes.c_int
|
| 310 |
+
|
| 311 |
+
_fields_ = [
|
| 312 |
+
("logits", ctypes.POINTER(ctypes.c_float)),
|
| 313 |
+
("vocab_size", ctypes.c_int),
|
| 314 |
+
("num_tokens", ctypes.c_int)
|
| 315 |
+
]
|
| 316 |
+
|
| 317 |
+
class RKLLMResult(ctypes.Structure):
|
| 318 |
+
"""
|
| 319 |
+
LLM推理结果结构体
|
| 320 |
+
|
| 321 |
+
表示LLM推理的结果,包含生成的文本、token ID、隐藏层状态、logits和性能统计。
|
| 322 |
+
"""
|
| 323 |
+
text: ctypes.c_char_p
|
| 324 |
+
token_id: ctypes.c_int32
|
| 325 |
+
last_hidden_layer: RKLLMResultLastHiddenLayer
|
| 326 |
+
logits: RKLLMResultLogits
|
| 327 |
+
perf: RKLLMPerfStat
|
| 328 |
+
|
| 329 |
+
_fields_ = [
|
| 330 |
+
("text", ctypes.c_char_p), # 生成的文本结果
|
| 331 |
+
("token_id", ctypes.c_int32), # 生成的token ID
|
| 332 |
+
("last_hidden_layer", RKLLMResultLastHiddenLayer), # 最后一层的隐藏状态(如果请求的话)
|
| 333 |
+
("logits", RKLLMResultLogits), # 模型输出的logits
|
| 334 |
+
("perf", RKLLMPerfStat) # 性能统计(预填充和生成)
|
| 335 |
+
]
|
| 336 |
+
|
| 337 |
+
# --- Typedefs ---
|
| 338 |
+
LLMHandle = ctypes.c_void_p
|
| 339 |
+
|
| 340 |
+
# --- Callback Function Type ---
|
| 341 |
+
LLMResultCallback = ctypes.CFUNCTYPE(
|
| 342 |
+
ctypes.c_int, # 返回类型:int,表示处理状态
|
| 343 |
+
ctypes.POINTER(RKLLMResult), # LLM结果指针
|
| 344 |
+
ctypes.c_void_p, # 用户数据指针
|
| 345 |
+
ctypes.c_int # LLM调用状态(LLMCallState枚举值)
|
| 346 |
+
)
|
| 347 |
+
"""
|
| 348 |
+
回调函数类型定义
|
| 349 |
+
|
| 350 |
+
用于处理LLM结果的回调函数。
|
| 351 |
+
|
| 352 |
+
参数:
|
| 353 |
+
- result: 指向LLM结果的指针
|
| 354 |
+
- userdata: 回调的用户数据指针
|
| 355 |
+
- state: LLM调用状态(例如:完成、错误)
|
| 356 |
+
|
| 357 |
+
返回值:
|
| 358 |
+
- 0: 正常继续推理
|
| 359 |
+
- 1: 暂停推理。如果用户想要修改或干预结果(例如编辑输出、注入新提示),
|
| 360 |
+
返回1以暂停当前推理。稍后,使用更新的内容调用rkllm_run来恢复推理。
|
| 361 |
+
"""
|
| 362 |
+
|
| 363 |
+
class RKLLMRuntime:
|
| 364 |
+
def __init__(self, library_path="./librkllmrt.so"):
|
| 365 |
+
try:
|
| 366 |
+
self.lib = ctypes.CDLL(library_path)
|
| 367 |
+
except OSError as e:
|
| 368 |
+
raise OSError(f"Failed to load RKLLM library from {library_path}. "
|
| 369 |
+
f"Ensure it's in your LD_LIBRARY_PATH or provide the full path. Error: {e}")
|
| 370 |
+
self._setup_functions()
|
| 371 |
+
self.llm_handle = LLMHandle()
|
| 372 |
+
self._c_callback = None # To keep the callback object alive
|
| 373 |
+
|
| 374 |
+
def _setup_functions(self):
|
| 375 |
+
# RKLLMParam rkllm_createDefaultParam();
|
| 376 |
+
self.lib.rkllm_createDefaultParam.restype = RKLLMParam
|
| 377 |
+
self.lib.rkllm_createDefaultParam.argtypes = []
|
| 378 |
+
|
| 379 |
+
# int rkllm_init(LLMHandle* handle, RKLLMParam* param, LLMResultCallback callback);
|
| 380 |
+
self.lib.rkllm_init.restype = ctypes.c_int
|
| 381 |
+
self.lib.rkllm_init.argtypes = [
|
| 382 |
+
ctypes.POINTER(LLMHandle),
|
| 383 |
+
ctypes.POINTER(RKLLMParam),
|
| 384 |
+
LLMResultCallback
|
| 385 |
+
]
|
| 386 |
+
|
| 387 |
+
# int rkllm_load_lora(LLMHandle handle, RKLLMLoraAdapter* lora_adapter);
|
| 388 |
+
self.lib.rkllm_load_lora.restype = ctypes.c_int
|
| 389 |
+
self.lib.rkllm_load_lora.argtypes = [LLMHandle, ctypes.POINTER(RKLLMLoraAdapter)]
|
| 390 |
+
|
| 391 |
+
# int rkllm_load_prompt_cache(LLMHandle handle, const char* prompt_cache_path);
|
| 392 |
+
self.lib.rkllm_load_prompt_cache.restype = ctypes.c_int
|
| 393 |
+
self.lib.rkllm_load_prompt_cache.argtypes = [LLMHandle, ctypes.c_char_p]
|
| 394 |
+
|
| 395 |
+
# int rkllm_release_prompt_cache(LLMHandle handle);
|
| 396 |
+
self.lib.rkllm_release_prompt_cache.restype = ctypes.c_int
|
| 397 |
+
self.lib.rkllm_release_prompt_cache.argtypes = [LLMHandle]
|
| 398 |
+
|
| 399 |
+
# int rkllm_destroy(LLMHandle handle);
|
| 400 |
+
self.lib.rkllm_destroy.restype = ctypes.c_int
|
| 401 |
+
self.lib.rkllm_destroy.argtypes = [LLMHandle]
|
| 402 |
+
|
| 403 |
+
# int rkllm_run(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
|
| 404 |
+
self.lib.rkllm_run.restype = ctypes.c_int
|
| 405 |
+
self.lib.rkllm_run.argtypes = [
|
| 406 |
+
LLMHandle,
|
| 407 |
+
ctypes.POINTER(RKLLMInput),
|
| 408 |
+
ctypes.POINTER(RKLLMInferParam),
|
| 409 |
+
ctypes.c_void_p # userdata
|
| 410 |
+
]
|
| 411 |
+
|
| 412 |
+
# int rkllm_run_async(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
|
| 413 |
+
# Assuming async also takes userdata for the callback context
|
| 414 |
+
self.lib.rkllm_run_async.restype = ctypes.c_int
|
| 415 |
+
self.lib.rkllm_run_async.argtypes = [
|
| 416 |
+
LLMHandle,
|
| 417 |
+
ctypes.POINTER(RKLLMInput),
|
| 418 |
+
ctypes.POINTER(RKLLMInferParam),
|
| 419 |
+
ctypes.c_void_p # userdata
|
| 420 |
+
]
|
| 421 |
+
|
| 422 |
+
# int rkllm_abort(LLMHandle handle);
|
| 423 |
+
self.lib.rkllm_abort.restype = ctypes.c_int
|
| 424 |
+
self.lib.rkllm_abort.argtypes = [LLMHandle]
|
| 425 |
+
|
| 426 |
+
# int rkllm_is_running(LLMHandle handle);
|
| 427 |
+
self.lib.rkllm_is_running.restype = ctypes.c_int # 0 if running, non-zero otherwise
|
| 428 |
+
self.lib.rkllm_is_running.argtypes = [LLMHandle]
|
| 429 |
+
|
| 430 |
+
# int rkllm_clear_kv_cache(LLMHandle handle, int keep_system_prompt, int* start_pos, int* end_pos);
|
| 431 |
+
self.lib.rkllm_clear_kv_cache.restype = ctypes.c_int
|
| 432 |
+
self.lib.rkllm_clear_kv_cache.argtypes = [
|
| 433 |
+
LLMHandle,
|
| 434 |
+
ctypes.c_int,
|
| 435 |
+
ctypes.POINTER(ctypes.c_int), # start_pos
|
| 436 |
+
ctypes.POINTER(ctypes.c_int) # end_pos
|
| 437 |
+
]
|
| 438 |
+
|
| 439 |
+
# int rkllm_get_kv_cache_size(LLMHandle handle, int* cache_sizes);
|
| 440 |
+
self.lib.rkllm_get_kv_cache_size.restype = ctypes.c_int
|
| 441 |
+
self.lib.rkllm_get_kv_cache_size.argtypes = [LLMHandle, ctypes.POINTER(ctypes.c_int)]
|
| 442 |
+
|
| 443 |
+
# int rkllm_set_chat_template(LLMHandle handle, const char* system_prompt, const char* prompt_prefix, const char* prompt_postfix);
|
| 444 |
+
self.lib.rkllm_set_chat_template.restype = ctypes.c_int
|
| 445 |
+
self.lib.rkllm_set_chat_template.argtypes = [
|
| 446 |
+
LLMHandle,
|
| 447 |
+
ctypes.c_char_p,
|
| 448 |
+
ctypes.c_char_p,
|
| 449 |
+
ctypes.c_char_p
|
| 450 |
+
]
|
| 451 |
+
|
| 452 |
+
# int rkllm_set_function_tools(LLMHandle handle, const char* system_prompt, const char* tools, const char* tool_response_str);
|
| 453 |
+
self.lib.rkllm_set_function_tools.restype = ctypes.c_int
|
| 454 |
+
self.lib.rkllm_set_function_tools.argtypes = [
|
| 455 |
+
LLMHandle,
|
| 456 |
+
ctypes.c_char_p, # system_prompt
|
| 457 |
+
ctypes.c_char_p, # tools
|
| 458 |
+
ctypes.c_char_p # tool_response_str
|
| 459 |
+
]
|
| 460 |
+
|
| 461 |
+
# int rkllm_set_cross_attn_params(LLMHandle handle, RKLLMCrossAttnParam* cross_attn_params);
|
| 462 |
+
self.lib.rkllm_set_cross_attn_params.restype = ctypes.c_int
|
| 463 |
+
self.lib.rkllm_set_cross_attn_params.argtypes = [LLMHandle, ctypes.POINTER(RKLLMCrossAttnParam)]
|
| 464 |
+
|
| 465 |
+
def create_default_param(self) -> RKLLMParam:
|
| 466 |
+
"""Creates a default RKLLMParam structure."""
|
| 467 |
+
return self.lib.rkllm_createDefaultParam()
|
| 468 |
+
|
| 469 |
+
def init(self, param: RKLLMParam, callback_func) -> int:
|
| 470 |
+
"""
|
| 471 |
+
Initializes the LLM.
|
| 472 |
+
:param param: RKLLMParam structure.
|
| 473 |
+
:param callback_func: A Python function that matches the signature:
|
| 474 |
+
def my_callback(result_ptr, userdata_ptr, state_enum):
|
| 475 |
+
result = result_ptr.contents # RKLLMResult
|
| 476 |
+
# Process result
|
| 477 |
+
# userdata can be retrieved if passed during run, or ignored
|
| 478 |
+
# state = LLMCallState(state_enum)
|
| 479 |
+
:return: 0 for success, non-zero for failure.
|
| 480 |
+
"""
|
| 481 |
+
if not callable(callback_func):
|
| 482 |
+
raise ValueError("callback_func must be a callable Python function.")
|
| 483 |
+
|
| 484 |
+
# Keep a reference to the ctypes callback object to prevent it from being garbage collected
|
| 485 |
+
self._c_callback = LLMResultCallback(callback_func)
|
| 486 |
+
|
| 487 |
+
ret = self.lib.rkllm_init(ctypes.byref(self.llm_handle), ctypes.byref(param), self._c_callback)
|
| 488 |
+
if ret != 0:
|
| 489 |
+
raise RuntimeError(f"rkllm_init failed with error code {ret}")
|
| 490 |
+
return ret
|
| 491 |
+
|
| 492 |
+
def load_lora(self, lora_adapter: RKLLMLoraAdapter) -> int:
|
| 493 |
+
"""Loads a Lora adapter."""
|
| 494 |
+
ret = self.lib.rkllm_load_lora(self.llm_handle, ctypes.byref(lora_adapter))
|
| 495 |
+
if ret != 0:
|
| 496 |
+
raise RuntimeError(f"rkllm_load_lora failed with error code {ret}")
|
| 497 |
+
return ret
|
| 498 |
+
|
| 499 |
+
def load_prompt_cache(self, prompt_cache_path: str) -> int:
|
| 500 |
+
"""Loads a prompt cache from a file."""
|
| 501 |
+
c_path = prompt_cache_path.encode('utf-8')
|
| 502 |
+
ret = self.lib.rkllm_load_prompt_cache(self.llm_handle, c_path)
|
| 503 |
+
if ret != 0:
|
| 504 |
+
raise RuntimeError(f"rkllm_load_prompt_cache failed for {prompt_cache_path} with error code {ret}")
|
| 505 |
+
return ret
|
| 506 |
+
|
| 507 |
+
def release_prompt_cache(self) -> int:
|
| 508 |
+
"""Releases the prompt cache from memory."""
|
| 509 |
+
ret = self.lib.rkllm_release_prompt_cache(self.llm_handle)
|
| 510 |
+
if ret != 0:
|
| 511 |
+
raise RuntimeError(f"rkllm_release_prompt_cache failed with error code {ret}")
|
| 512 |
+
return ret
|
| 513 |
+
|
| 514 |
+
def destroy(self) -> int:
|
| 515 |
+
"""Destroys the LLM instance and releases resources."""
|
| 516 |
+
if self.llm_handle and self.llm_handle.value: # Check if handle is not NULL
|
| 517 |
+
ret = self.lib.rkllm_destroy(self.llm_handle)
|
| 518 |
+
self.llm_handle = LLMHandle() # Reset handle
|
| 519 |
+
if ret != 0:
|
| 520 |
+
# Don't raise here as it might be called in __del__
|
| 521 |
+
print(f"Warning: rkllm_destroy failed with error code {ret}")
|
| 522 |
+
return ret
|
| 523 |
+
return 0 # Already destroyed or not initialized
|
| 524 |
+
|
| 525 |
+
def run(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
|
| 526 |
+
"""Runs an LLM inference task synchronously."""
|
| 527 |
+
# userdata can be a ctypes.py_object if you want to pass Python objects,
|
| 528 |
+
# then cast to c_void_p. Or simply None.
|
| 529 |
+
if userdata is not None:
|
| 530 |
+
# Store the userdata object to keep it alive during the call
|
| 531 |
+
self._userdata_ref = userdata
|
| 532 |
+
c_userdata = ctypes.cast(ctypes.pointer(ctypes.py_object(userdata)), ctypes.c_void_p)
|
| 533 |
+
else:
|
| 534 |
+
c_userdata = None
|
| 535 |
+
ret = self.lib.rkllm_run(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata)
|
| 536 |
+
if ret != 0:
|
| 537 |
+
raise RuntimeError(f"rkllm_run failed with error code {ret}")
|
| 538 |
+
return ret
|
| 539 |
+
|
| 540 |
+
def run_async(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
|
| 541 |
+
"""Runs an LLM inference task asynchronously."""
|
| 542 |
+
if userdata is not None:
|
| 543 |
+
# Store the userdata object to keep it alive during the call
|
| 544 |
+
self._userdata_ref = userdata
|
| 545 |
+
c_userdata = ctypes.cast(ctypes.pointer(ctypes.py_object(userdata)), ctypes.c_void_p)
|
| 546 |
+
else:
|
| 547 |
+
c_userdata = None
|
| 548 |
+
ret = self.lib.rkllm_run_async(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata)
|
| 549 |
+
if ret != 0:
|
| 550 |
+
raise RuntimeError(f"rkllm_run_async failed with error code {ret}")
|
| 551 |
+
return ret
|
| 552 |
+
|
| 553 |
+
def abort(self) -> int:
|
| 554 |
+
"""Aborts an ongoing LLM task."""
|
| 555 |
+
ret = self.lib.rkllm_abort(self.llm_handle)
|
| 556 |
+
if ret != 0:
|
| 557 |
+
raise RuntimeError(f"rkllm_abort failed with error code {ret}")
|
| 558 |
+
return ret
|
| 559 |
+
|
| 560 |
+
def is_running(self) -> bool:
|
| 561 |
+
"""Checks if an LLM task is currently running. Returns True if running."""
|
| 562 |
+
# The C API returns 0 if running, non-zero otherwise.
|
| 563 |
+
# This is a bit counter-intuitive for a boolean "is_running".
|
| 564 |
+
return self.lib.rkllm_is_running(self.llm_handle) == 0
|
| 565 |
+
|
| 566 |
+
def clear_kv_cache(self, keep_system_prompt: bool, start_pos: list = None, end_pos: list = None) -> int:
|
| 567 |
+
"""
|
| 568 |
+
清除键值缓存
|
| 569 |
+
|
| 570 |
+
此函数用于清除部分或全部KV缓存。
|
| 571 |
+
|
| 572 |
+
参数:
|
| 573 |
+
- keep_system_prompt: 是否在缓存中保留系统提示(True保留,False清除)
|
| 574 |
+
如果提供了特定范围[start_pos, end_pos),此标志将被忽略
|
| 575 |
+
- start_pos: 要清除的KV缓存范围的起始位置数组(包含),每个批次一个
|
| 576 |
+
- end_pos: 要清除的KV缓存范围的结束位置数组(不包含),每个批次一个
|
| 577 |
+
如果start_pos和end_pos都设置为None,将清除整个缓存,keep_system_prompt将生效
|
| 578 |
+
如果start_pos[i] < end_pos[i],只有指定的范围会被清除,keep_system_prompt将被忽略
|
| 579 |
+
|
| 580 |
+
注意:start_pos或end_pos只有在keep_history == 0且生成已通过在回调中返回1暂停时才有效
|
| 581 |
+
|
| 582 |
+
返回:0表示缓存清除成功,非零表示失败
|
| 583 |
+
"""
|
| 584 |
+
# 准备C数组参数
|
| 585 |
+
c_start_pos = None
|
| 586 |
+
c_end_pos = None
|
| 587 |
+
|
| 588 |
+
if start_pos is not None and end_pos is not None:
|
| 589 |
+
if len(start_pos) != len(end_pos):
|
| 590 |
+
raise ValueError("start_pos和end_pos数组长度必须相同")
|
| 591 |
+
|
| 592 |
+
# 创建C数组
|
| 593 |
+
c_start_pos = (ctypes.c_int * len(start_pos))(*start_pos)
|
| 594 |
+
c_end_pos = (ctypes.c_int * len(end_pos))(*end_pos)
|
| 595 |
+
|
| 596 |
+
ret = self.lib.rkllm_clear_kv_cache(
|
| 597 |
+
self.llm_handle,
|
| 598 |
+
ctypes.c_int(1 if keep_system_prompt else 0),
|
| 599 |
+
c_start_pos,
|
| 600 |
+
c_end_pos
|
| 601 |
+
)
|
| 602 |
+
if ret != 0:
|
| 603 |
+
raise RuntimeError(f"rkllm_clear_kv_cache失败,错误代码:{ret}")
|
| 604 |
+
return ret
|
| 605 |
+
|
| 606 |
+
def set_chat_template(self, system_prompt: str, prompt_prefix: str, prompt_postfix: str) -> int:
|
| 607 |
+
"""Sets the chat template for the LLM."""
|
| 608 |
+
c_system = system_prompt.encode('utf-8') if system_prompt else b""
|
| 609 |
+
c_prefix = prompt_prefix.encode('utf-8') if prompt_prefix else b""
|
| 610 |
+
c_postfix = prompt_postfix.encode('utf-8') if prompt_postfix else b""
|
| 611 |
+
|
| 612 |
+
ret = self.lib.rkllm_set_chat_template(self.llm_handle, c_system, c_prefix, c_postfix)
|
| 613 |
+
if ret != 0:
|
| 614 |
+
raise RuntimeError(f"rkllm_set_chat_template failed with error code {ret}")
|
| 615 |
+
return ret
|
| 616 |
+
|
| 617 |
+
def get_kv_cache_size(self, n_batch: int) -> list:
|
| 618 |
+
"""
|
| 619 |
+
获取给定LLM句柄的键值缓存当前大小
|
| 620 |
+
|
| 621 |
+
此函数返回当前存储在模型KV缓存中的位置总数。
|
| 622 |
+
|
| 623 |
+
参数:
|
| 624 |
+
- n_batch: 批次数量,用于确定返回数组的大小
|
| 625 |
+
|
| 626 |
+
返回:
|
| 627 |
+
- list: 每个批次的缓存大小列表
|
| 628 |
+
"""
|
| 629 |
+
# 预分配数组以存储每个批次的缓存大小
|
| 630 |
+
cache_sizes = (ctypes.c_int * n_batch)()
|
| 631 |
+
|
| 632 |
+
ret = self.lib.rkllm_get_kv_cache_size(self.llm_handle, cache_sizes)
|
| 633 |
+
if ret != 0:
|
| 634 |
+
raise RuntimeError(f"rkllm_get_kv_cache_size失败,错误代码:{ret}")
|
| 635 |
+
|
| 636 |
+
# 转换为Python列表
|
| 637 |
+
return [cache_sizes[i] for i in range(n_batch)]
|
| 638 |
+
|
| 639 |
+
def set_function_tools(self, system_prompt: str, tools: str, tool_response_str: str) -> int:
|
| 640 |
+
"""
|
| 641 |
+
为LLM设置函数调用配置,包括系统提示、工具定义和工具响应token
|
| 642 |
+
|
| 643 |
+
参数:
|
| 644 |
+
- system_prompt: 定义语言模型上下文或行为的系统提示
|
| 645 |
+
- tools: JSON格式的字符串,定义可用的函数,包括它们的名称、描述和参数
|
| 646 |
+
- tool_response_str: 用于识别对话中函数调用结果的唯一标签。它作为标记标签,
|
| 647 |
+
允许分词器将工具输出与正常对话轮次分开识别
|
| 648 |
+
|
| 649 |
+
返回:0表示配置设置成功,非零表示错误
|
| 650 |
+
"""
|
| 651 |
+
c_system = system_prompt.encode('utf-8') if system_prompt else b""
|
| 652 |
+
c_tools = tools.encode('utf-8') if tools else b""
|
| 653 |
+
c_tool_response = tool_response_str.encode('utf-8') if tool_response_str else b""
|
| 654 |
+
|
| 655 |
+
ret = self.lib.rkllm_set_function_tools(self.llm_handle, c_system, c_tools, c_tool_response)
|
| 656 |
+
if ret != 0:
|
| 657 |
+
raise RuntimeError(f"rkllm_set_function_tools失败,错误代码:{ret}")
|
| 658 |
+
return ret
|
| 659 |
+
|
| 660 |
+
def set_cross_attn_params(self, cross_attn_params: RKLLMCrossAttnParam) -> int:
|
| 661 |
+
"""
|
| 662 |
+
为LLM解码器设置交叉注意力参数
|
| 663 |
+
|
| 664 |
+
参数:
|
| 665 |
+
- cross_attn_params: 包含用于交叉注意力的编码器相关输入数据的结构体
|
| 666 |
+
(详见RKLLMCrossAttnParam说明)
|
| 667 |
+
|
| 668 |
+
返回:0表示参数设置成功,非零表示错误
|
| 669 |
+
"""
|
| 670 |
+
ret = self.lib.rkllm_set_cross_attn_params(self.llm_handle, ctypes.byref(cross_attn_params))
|
| 671 |
+
if ret != 0:
|
| 672 |
+
raise RuntimeError(f"rkllm_set_cross_attn_params失败,错误代码:{ret}")
|
| 673 |
+
return ret
|
| 674 |
+
|
| 675 |
+
def __enter__(self):
|
| 676 |
+
return self
|
| 677 |
+
|
| 678 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
| 679 |
+
self.destroy()
|
| 680 |
+
|
| 681 |
+
def __del__(self):
|
| 682 |
+
self.destroy() # Ensure resources are freed if object is garbage collected
|
| 683 |
+
|
| 684 |
+
# --- Example Usage (Illustrative) ---
|
| 685 |
+
if __name__ == "__main__":
|
| 686 |
+
# This is a placeholder for how you might use it.
|
| 687 |
+
# You'll need a valid .rkllm model and librkllmrt.so in your path.
|
| 688 |
+
|
| 689 |
+
# Global list to store results from callback for demonstration
|
| 690 |
+
results_buffer = []
|
| 691 |
+
|
| 692 |
+
def my_python_callback(result_ptr, userdata_ptr, state_enum):
|
| 693 |
+
"""
|
| 694 |
+
回调函数,由C库调用来处理LLM结果
|
| 695 |
+
|
| 696 |
+
参数:
|
| 697 |
+
- result_ptr: 指向LLM结果的指针
|
| 698 |
+
- userdata_ptr: 用户数据指针
|
| 699 |
+
- state_enum: LLM调用状态枚举值
|
| 700 |
+
|
| 701 |
+
返回:
|
| 702 |
+
- 0: 继续推理
|
| 703 |
+
- 1: 暂停推理
|
| 704 |
+
"""
|
| 705 |
+
global results_buffer
|
| 706 |
+
state = LLMCallState(state_enum)
|
| 707 |
+
result = result_ptr.contents
|
| 708 |
+
|
| 709 |
+
current_text = ""
|
| 710 |
+
if result.text: # 检查char_p是否不为NULL
|
| 711 |
+
current_text = result.text.decode('utf-8', errors='ignore')
|
| 712 |
+
|
| 713 |
+
print(f"回调: State={state.name}, TokenID={result.token_id}, Text='{current_text}'")
|
| 714 |
+
|
| 715 |
+
# 显示性能统计信息
|
| 716 |
+
if result.perf.prefill_tokens > 0 or result.perf.generate_tokens > 0:
|
| 717 |
+
print(f" 性能统计: 预填充={result.perf.prefill_tokens}tokens/{result.perf.prefill_time_ms:.1f}ms, "
|
| 718 |
+
f"生成={result.perf.generate_tokens}tokens/{result.perf.generate_time_ms:.1f}ms, "
|
| 719 |
+
f"内存={result.perf.memory_usage_mb:.1f}MB")
|
| 720 |
+
|
| 721 |
+
results_buffer.append(current_text)
|
| 722 |
+
|
| 723 |
+
if state == LLMCallState.RKLLM_RUN_FINISH:
|
| 724 |
+
print("推理完成。")
|
| 725 |
+
elif state == LLMCallState.RKLLM_RUN_ERROR:
|
| 726 |
+
print("推理错误。")
|
| 727 |
+
|
| 728 |
+
# 返回0继续推理,返回1暂停推理
|
| 729 |
+
return 0
|
| 730 |
+
|
| 731 |
+
# --- Attempt to use the wrapper ---
|
| 732 |
+
try:
|
| 733 |
+
print("Initializing RKLLMRuntime...")
|
| 734 |
+
# Adjust library_path if librkllmrt.so is not in default search paths
|
| 735 |
+
# e.g., library_path="./path/to/librkllmrt.so"
|
| 736 |
+
rk_llm = RKLLMRuntime()
|
| 737 |
+
|
| 738 |
+
print("Creating default parameters...")
|
| 739 |
+
params = rk_llm.create_default_param()
|
| 740 |
+
|
| 741 |
+
# --- Configure parameters ---
|
| 742 |
+
# THIS IS CRITICAL: model_path must point to an actual .rkllm file
|
| 743 |
+
# For this example to run, you need a model file.
|
| 744 |
+
# Let's assume a dummy path for now, this will fail at init if not valid.
|
| 745 |
+
model_file = "language_model.rkllm"
|
| 746 |
+
if not os.path.exists(model_file):
|
| 747 |
+
raise FileNotFoundError(f"Model file '{model_file}' does not exist.")
|
| 748 |
+
|
| 749 |
+
params.model_path = model_file.encode('utf-8')
|
| 750 |
+
params.max_context_len = 512
|
| 751 |
+
params.max_new_tokens = 128
|
| 752 |
+
# params.top_k = 1 # Greedy
|
| 753 |
+
params.temperature = 0.7
|
| 754 |
+
params.repeat_penalty = 1.1
|
| 755 |
+
# ... set other params as needed
|
| 756 |
+
|
| 757 |
+
print(f"Initializing LLM with model: {params.model_path.decode()}...")
|
| 758 |
+
# This will likely fail if dummy_model.rkllm is not a valid model recognized by the library
|
| 759 |
+
try:
|
| 760 |
+
rk_llm.init(params, my_python_callback)
|
| 761 |
+
print("LLM Initialized.")
|
| 762 |
+
except RuntimeError as e:
|
| 763 |
+
print(f"Error during LLM initialization: {e}")
|
| 764 |
+
print("This is expected if 'dummy_model.rkllm' is not a valid model.")
|
| 765 |
+
print("Replace 'dummy_model.rkllm' with a real model path to test further.")
|
| 766 |
+
exit()
|
| 767 |
+
|
| 768 |
+
|
| 769 |
+
# --- Prepare input ---
|
| 770 |
+
print("准备输入...")
|
| 771 |
+
rk_input = RKLLMInput()
|
| 772 |
+
rk_input.role = b"user" # 设置角色为用户输入
|
| 773 |
+
rk_input.enable_thinking = False # 禁用思考模式(适用于Qwen3模型)
|
| 774 |
+
rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT
|
| 775 |
+
|
| 776 |
+
prompt_text = "将以下英文文本翻译成中文:'Hello, world!'"
|
| 777 |
+
c_prompt = prompt_text.encode('utf-8')
|
| 778 |
+
rk_input._union_data.prompt_input = c_prompt # 直接访问联合体成员
|
| 779 |
+
|
| 780 |
+
# --- Prepare inference parameters ---
|
| 781 |
+
print("Preparing inference parameters...")
|
| 782 |
+
infer_params = RKLLMInferParam()
|
| 783 |
+
infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
|
| 784 |
+
infer_params.keep_history = 1 # True
|
| 785 |
+
# infer_params.lora_params = None # or set up RKLLMLoraParam if using LoRA
|
| 786 |
+
# infer_params.prompt_cache_params = None # or set up RKLLMPromptCacheParam
|
| 787 |
+
|
| 788 |
+
# --- Run inference ---
|
| 789 |
+
print(f"Running inference with prompt: '{prompt_text}'")
|
| 790 |
+
results_buffer.clear()
|
| 791 |
+
try:
|
| 792 |
+
rk_llm.run(rk_input, infer_params) # Userdata is None by default
|
| 793 |
+
print("\n--- Full Response ---")
|
| 794 |
+
print("".join(results_buffer))
|
| 795 |
+
print("---------------------\n")
|
| 796 |
+
except RuntimeError as e:
|
| 797 |
+
print(f"Error during LLM run: {e}")
|
| 798 |
+
|
| 799 |
+
|
| 800 |
+
# --- Example: Set chat template (if model supports it) ---
|
| 801 |
+
# print("Setting chat template...")
|
| 802 |
+
# try:
|
| 803 |
+
# rk_llm.set_chat_template("You are a helpful assistant.", "<user>: ", "<assistant>: ")
|
| 804 |
+
# print("Chat template set.")
|
| 805 |
+
# except RuntimeError as e:
|
| 806 |
+
# print(f"Error setting chat template: {e}")
|
| 807 |
+
|
| 808 |
+
# --- Example: Clear KV Cache ---
|
| 809 |
+
# print("Clearing KV cache (keeping system prompt if any)...")
|
| 810 |
+
# try:
|
| 811 |
+
# rk_llm.clear_kv_cache(keep_system_prompt=True)
|
| 812 |
+
# print("KV cache cleared.")
|
| 813 |
+
# except RuntimeError as e:
|
| 814 |
+
# print(f"Error clearing KV cache: {e}")
|
| 815 |
+
|
| 816 |
+
# --- 示例:获取KV缓存大小 ---
|
| 817 |
+
# print("获取KV缓存大小...")
|
| 818 |
+
# try:
|
| 819 |
+
# cache_sizes = rk_llm.get_kv_cache_size(n_batch=1) # 假设批次大小为1
|
| 820 |
+
# print(f"当前KV缓存大小: {cache_sizes}")
|
| 821 |
+
# except RuntimeError as e:
|
| 822 |
+
# print(f"获取KV缓存大小错误: {e}")
|
| 823 |
+
|
| 824 |
+
# --- 示例:设置函数工具 ---
|
| 825 |
+
# print("设置函数调用工具...")
|
| 826 |
+
# try:
|
| 827 |
+
# system_prompt = "你是一个有用的助手,可以调用提供的函数来帮助用户。"
|
| 828 |
+
# tools = '''[{
|
| 829 |
+
# "name": "get_weather",
|
| 830 |
+
# "description": "获取指定城市的天气信息",
|
| 831 |
+
# "parameters": {
|
| 832 |
+
# "type": "object",
|
| 833 |
+
# "properties": {
|
| 834 |
+
# "city": {"type": "string", "description": "城市名称"}
|
| 835 |
+
# },
|
| 836 |
+
# "required": ["city"]
|
| 837 |
+
# }
|
| 838 |
+
# }]'''
|
| 839 |
+
# tool_response_str = "<tool_response>"
|
| 840 |
+
# rk_llm.set_function_tools(system_prompt, tools, tool_response_str)
|
| 841 |
+
# print("函数工具设置成功。")
|
| 842 |
+
# except RuntimeError as e:
|
| 843 |
+
# print(f"设置函数工具错误: {e}")
|
| 844 |
+
|
| 845 |
+
# --- 示例:清除KV缓存(带范围参数) ---
|
| 846 |
+
# print("使用范围参数清除KV缓存...")
|
| 847 |
+
# try:
|
| 848 |
+
# # 清除位置10到20的缓存
|
| 849 |
+
# start_positions = [10] # 批次0的起始位置
|
| 850 |
+
# end_positions = [20] # 批次0的结束位置
|
| 851 |
+
# rk_llm.clear_kv_cache(keep_system_prompt=True, start_pos=start_positions, end_pos=end_positions)
|
| 852 |
+
# print("范围KV缓存清除完成。")
|
| 853 |
+
# except RuntimeError as e:
|
| 854 |
+
# print(f"清除范围KV缓存错误: {e}")
|
| 855 |
+
|
| 856 |
+
except OSError as e:
|
| 857 |
+
print(f"OSError: {e}. Could not load the RKLLM library.")
|
| 858 |
+
print("Please ensure 'librkllmrt.so' is in your LD_LIBRARY_PATH or provide the full path.")
|
| 859 |
+
except Exception as e:
|
| 860 |
+
print(f"An unexpected error occurred: {e}")
|
| 861 |
+
finally:
|
| 862 |
+
if 'rk_llm' in locals() and rk_llm.llm_handle and rk_llm.llm_handle.value:
|
| 863 |
+
print("Destroying LLM instance...")
|
| 864 |
+
rk_llm.destroy()
|
| 865 |
+
print("LLM instance destroyed.")
|
| 866 |
+
|
| 867 |
+
print("Example finished.")
|
run.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
os.environ["RKLLM_LOG_LEVEL"] = "1"
|
| 3 |
+
from rkllm_binding import *
|
| 4 |
+
|
| 5 |
+
def my_python_callback(result_ptr, userdata_ptr, state_enum):
|
| 6 |
+
"""
|
| 7 |
+
回调函数,用于处理LLM的输出结果。
|
| 8 |
+
这个函数会以流式的方式逐字打印模型的响应。
|
| 9 |
+
"""
|
| 10 |
+
state = LLMCallState(state_enum)
|
| 11 |
+
result = result_ptr.contents
|
| 12 |
+
|
| 13 |
+
if result.text:
|
| 14 |
+
current_text = result.text.decode('utf-8', errors='ignore')
|
| 15 |
+
print(current_text, end='', flush=True)
|
| 16 |
+
|
| 17 |
+
if state == LLMCallState.RKLLM_RUN_FINISH:
|
| 18 |
+
# 在响应结束后打印一个换行符,保持格式整洁
|
| 19 |
+
print()
|
| 20 |
+
elif state == LLMCallState.RKLLM_RUN_ERROR:
|
| 21 |
+
print("\n推理过程中发生错误。")
|
| 22 |
+
|
| 23 |
+
# 返回0继续推理,返回1暂停推理
|
| 24 |
+
return 0
|
| 25 |
+
|
| 26 |
+
# --- Attempt to use the wrapper ---
|
| 27 |
+
try:
|
| 28 |
+
print("Initializing RKLLMRuntime...")
|
| 29 |
+
# Adjust library_path if librkllmrt.so is not in default search paths
|
| 30 |
+
# e.g., library_path="./path/to/librkllmrt.so"
|
| 31 |
+
rk_llm = RKLLMRuntime()
|
| 32 |
+
|
| 33 |
+
print("Creating default parameters...")
|
| 34 |
+
params = rk_llm.create_default_param()
|
| 35 |
+
|
| 36 |
+
# --- Configure parameters ---
|
| 37 |
+
model_file = "language_model.rkllm"
|
| 38 |
+
if not os.path.exists(model_file):
|
| 39 |
+
raise FileNotFoundError(f"Model file '{model_file}' does not exist.")
|
| 40 |
+
|
| 41 |
+
params.model_path = model_file.encode('utf-8')
|
| 42 |
+
params.max_context_len = 4096
|
| 43 |
+
params.max_new_tokens = 1024
|
| 44 |
+
# params.top_k = 1 # Greedy
|
| 45 |
+
params.temperature = 0.7
|
| 46 |
+
params.repeat_penalty = 1.1
|
| 47 |
+
# ... set other params as needed
|
| 48 |
+
|
| 49 |
+
print(f"Initializing LLM with model: {params.model_path.decode()}...")
|
| 50 |
+
# This will likely fail if dummy_model.rkllm is not a valid model recognized by the library
|
| 51 |
+
try:
|
| 52 |
+
rk_llm.init(params, my_python_callback)
|
| 53 |
+
print("LLM Initialized.")
|
| 54 |
+
except RuntimeError as e:
|
| 55 |
+
print(f"Error during LLM initialization: {e}")
|
| 56 |
+
exit()
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# --- 进入交互式对话循环 ---
|
| 60 |
+
print("\n进入多轮对话模式。输入 'exit' 或 'quit' 退出。")
|
| 61 |
+
|
| 62 |
+
# 准备推理参数 (这些参数在对话中保持不变)
|
| 63 |
+
infer_params = RKLLMInferParam()
|
| 64 |
+
infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
|
| 65 |
+
infer_params.keep_history = 1 # 保持对话历史
|
| 66 |
+
|
| 67 |
+
while True:
|
| 68 |
+
try:
|
| 69 |
+
prompt_text = input("You: ")
|
| 70 |
+
if prompt_text.lower() in ["exit", "quit"]:
|
| 71 |
+
break
|
| 72 |
+
|
| 73 |
+
print("Assistant: ", end='', flush=True)
|
| 74 |
+
|
| 75 |
+
# 准备输入
|
| 76 |
+
rk_input = RKLLMInput()
|
| 77 |
+
rk_input.role = b"user"
|
| 78 |
+
rk_input.enable_thinking = False
|
| 79 |
+
rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT
|
| 80 |
+
|
| 81 |
+
c_prompt = prompt_text.encode('utf-8')
|
| 82 |
+
rk_input._union_data.prompt_input = c_prompt
|
| 83 |
+
|
| 84 |
+
# 运行推理
|
| 85 |
+
rk_llm.run(rk_input, infer_params)
|
| 86 |
+
|
| 87 |
+
except KeyboardInterrupt:
|
| 88 |
+
print("\n\n对话中断。")
|
| 89 |
+
break
|
| 90 |
+
except RuntimeError as e:
|
| 91 |
+
print(f"\n运行时发生错误: {e}")
|
| 92 |
+
break
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
except OSError as e:
|
| 96 |
+
print(f"OSError: {e}. Could not load the RKLLM library.")
|
| 97 |
+
print("Please ensure 'librkllmrt.so' is in your LD_LIBRARY_PATH or provide the full path.")
|
| 98 |
+
except Exception as e:
|
| 99 |
+
print(f"An unexpected error occurred: {e}")
|
| 100 |
+
finally:
|
| 101 |
+
if 'rk_llm' in locals() and rk_llm.llm_handle and rk_llm.llm_handle.value:
|
| 102 |
+
print("Destroying LLM instance...")
|
| 103 |
+
rk_llm.destroy()
|
| 104 |
+
print("LLM instance destroyed.")
|
| 105 |
+
|
| 106 |
+
print("Example finished.")
|