happyme531 committed on
Commit
b2c3325
·
verified ·
1 Parent(s): 042e332

Upload 7 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ language_model_w8a8.rkllm filter=lfs diff=lfs merge=lfs -text
37
+ language_model.rkllm filter=lfs diff=lfs merge=lfs -text
38
+ librkllmrt.so filter=lfs diff=lfs merge=lfs -text
language_model.rkllm ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5e458a6dee8ea66c8e166596027adb4e1b1cf30b5e150747f7a56630df1139c5
3
+ size 893228148
language_model_w8a8.rkllm ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:28df7af77cd515f63f07faad96be587618979b2a5f46541ae64f5fdbb080499e
3
+ size 627635884
librkllmrt.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7e6f87f07bbb08058cad4871cc74e8069a054fe4f6259b43c29a4738b0affdd
3
+ size 7461896
rkllm-convert-w8a8.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# pip install -U transformers==4.50.0 -i https://mirrors.aliyun.com/pypi/simple
"""Quantize the local HuggingFace model to w8a8 and export an RKLLM blob for RK3588."""
import sys

from rkllm.api import RKLLM

modelpath = '.'
llm = RKLLM()

ret = llm.load_huggingface(model=modelpath, model_lora=None, device='cpu')
if ret != 0:
    print('Load model failed!')
    # sys.exit instead of exit(): exit() is a site-module convenience helper
    # and is not guaranteed to exist in every interpreter invocation.
    sys.exit(ret)

qparams = None  # no extra quantization parameters
ret = llm.build(do_quantization=True, optimization_level=1, quantized_dtype='w8a8',
                quantized_algorithm='normal', target_platform='rk3588', num_npu_core=3, extra_qparams=qparams)

if ret != 0:
    print('Build model failed!')
    sys.exit(ret)

# Export rkllm model
ret = llm.export_rkllm("./language_model_w8a8.rkllm")
if ret != 0:
    print('Export model failed!')
    sys.exit(ret)
rkllm-convert.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Convert the local HuggingFace model (no quantization) and export an RKLLM blob for RK3588."""
import sys

from rkllm.api import RKLLM

modelpath = '.'
llm = RKLLM()

ret = llm.load_huggingface(model=modelpath, model_lora=None, device='cpu')
if ret != 0:
    print('Load model failed!')
    # sys.exit instead of exit(): exit() is a site-module convenience helper
    # and is not guaranteed to exist in every interpreter invocation.
    sys.exit(ret)

qparams = None  # unused: do_quantization=False, dtype below is ignored
ret = llm.build(do_quantization=False, optimization_level=1, quantized_dtype='w8a8',
                quantized_algorithm='normal', target_platform='rk3588', num_npu_core=3, extra_qparams=qparams)

if ret != 0:
    print('Build model failed!')
    sys.exit(ret)

# Export rkllm model
ret = llm.export_rkllm("./language_model.rkllm")
if ret != 0:
    print('Export model failed!')
    sys.exit(ret)
rkllm_binding.py ADDED
@@ -0,0 +1,867 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+ import enum
3
+ import os
4
+
5
# CPU affinity bit flags from the C header: one bit per core, CPU0 = 0x01
# through CPU7 = 0x80. Combine with | to build enabled_cpus_mask.
CPU0, CPU1, CPU2, CPU3, CPU4, CPU5, CPU6, CPU7 = (1 << core for core in range(8))
14
+
15
# --- Enums ---
class LLMCallState(enum.IntEnum):
    """State reported to the result callback by the RKLLM runtime."""

    RKLLM_RUN_NORMAL = 0   # a normal result chunk
    RKLLM_RUN_WAITING = 1  # runtime is waiting
    RKLLM_RUN_FINISH = 2   # inference finished
    RKLLM_RUN_ERROR = 3    # inference failed
21
+
22
class RKLLMInputType(enum.IntEnum):
    """Which member of the RKLLMInput payload union is active."""

    RKLLM_INPUT_PROMPT = 0      # plain text prompt
    RKLLM_INPUT_TOKEN = 1       # pre-tokenized ids
    RKLLM_INPUT_EMBED = 2       # raw embeddings
    RKLLM_INPUT_MULTIMODAL = 3  # text + image embeddings
27
+
28
class RKLLMInferMode(enum.IntEnum):
    """What rkllm_run should produce."""

    RKLLM_INFER_GENERATE = 0               # generate text tokens
    RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1  # return last hidden states
    RKLLM_INFER_GET_LOGITS = 2             # return raw logits
32
+
33
# --- Structures ---
class RKLLMExtendParam(ctypes.Structure):
    """Extended runtime parameters embedded in RKLLMParam."""

    base_domain_id: ctypes.c_int32
    embed_flash: ctypes.c_int8
    enabled_cpus_num: ctypes.c_int8
    enabled_cpus_mask: ctypes.c_uint32
    n_batch: ctypes.c_uint8
    use_cross_attn: ctypes.c_int8
    reserved: ctypes.c_uint8 * 104

    _fields_ = [
        ("base_domain_id", ctypes.c_int32),      # base memory-domain id
        ("embed_flash", ctypes.c_int8),          # 1: read word embeddings from flash, 0: disabled
        ("enabled_cpus_num", ctypes.c_int8),     # number of CPUs enabled for inference
        ("enabled_cpus_mask", ctypes.c_uint32),  # bitmask selecting which CPUs are enabled
        ("n_batch", ctypes.c_uint8),             # samples per forward pass; >1 enables batching (default 1)
        ("use_cross_attn", ctypes.c_int8),       # non-zero enables cross attention
        ("reserved", ctypes.c_uint8 * 104),      # reserved
    ]
52
+
53
class RKLLMParam(ctypes.Structure):
    """Top-level model and sampling configuration consumed by rkllm_init."""

    model_path: ctypes.c_char_p
    max_context_len: ctypes.c_int32
    max_new_tokens: ctypes.c_int32
    top_k: ctypes.c_int32
    n_keep: ctypes.c_int32
    top_p: ctypes.c_float
    temperature: ctypes.c_float
    repeat_penalty: ctypes.c_float
    frequency_penalty: ctypes.c_float
    presence_penalty: ctypes.c_float
    mirostat: ctypes.c_int32
    mirostat_tau: ctypes.c_float
    mirostat_eta: ctypes.c_float
    skip_special_token: ctypes.c_bool
    is_async: ctypes.c_bool
    img_start: ctypes.c_char_p
    img_end: ctypes.c_char_p
    img_content: ctypes.c_char_p
    extend_param: RKLLMExtendParam

    _fields_ = [
        ("model_path", ctypes.c_char_p),        # path to the .rkllm model file
        ("max_context_len", ctypes.c_int32),    # maximum tokens in the context window
        ("max_new_tokens", ctypes.c_int32),     # maximum number of newly generated tokens
        ("top_k", ctypes.c_int32),              # top-K sampling parameter
        ("n_keep", ctypes.c_int32),             # kv-cache entries kept when the context window slides
        ("top_p", ctypes.c_float),              # top-P (nucleus) sampling parameter
        ("temperature", ctypes.c_float),        # sampling temperature (token-choice randomness)
        ("repeat_penalty", ctypes.c_float),     # penalty for repeated tokens
        ("frequency_penalty", ctypes.c_float),  # penalty for frequent tokens
        ("presence_penalty", ctypes.c_float),   # penalty for tokens already present in the input
        ("mirostat", ctypes.c_int32),           # mirostat sampling flag (0 disables)
        ("mirostat_tau", ctypes.c_float),       # mirostat tau parameter
        ("mirostat_eta", ctypes.c_float),       # mirostat eta parameter
        ("skip_special_token", ctypes.c_bool),  # whether special tokens are skipped
        ("is_async", ctypes.c_bool),            # whether inference runs asynchronously
        ("img_start", ctypes.c_char_p),         # marker for the start of an image in multimodal input
        ("img_end", ctypes.c_char_p),           # marker for the end of an image in multimodal input
        ("img_content", ctypes.c_char_p),       # pointer to the image content
        ("extend_param", RKLLMExtendParam),     # extended parameters
    ]
95
+
96
class RKLLMLoraAdapter(ctypes.Structure):
    """Description of a LoRA adapter for rkllm_load_lora."""

    lora_adapter_path: ctypes.c_char_p
    lora_adapter_name: ctypes.c_char_p
    scale: ctypes.c_float

    _fields_ = [
        ("lora_adapter_path", ctypes.c_char_p),  # adapter file location
        ("lora_adapter_name", ctypes.c_char_p),  # name used to select the adapter at run time
        ("scale", ctypes.c_float),               # blending scale applied to the adapter
    ]
106
+
107
class RKLLMEmbedInput(ctypes.Structure):
    """Raw embedding input: a float buffer covering n_tokens positions."""

    embed: ctypes.POINTER(ctypes.c_float)
    n_tokens: ctypes.c_size_t

    _fields_ = [
        ("embed", ctypes.POINTER(ctypes.c_float)),  # pointer to the embedding data
        ("n_tokens", ctypes.c_size_t),              # number of token positions in the buffer
    ]
115
+
116
class RKLLMTokenInput(ctypes.Structure):
    """Pre-tokenized input: an int32 id buffer of length n_tokens."""

    input_ids: ctypes.POINTER(ctypes.c_int32)
    n_tokens: ctypes.c_size_t

    _fields_ = [
        ("input_ids", ctypes.POINTER(ctypes.c_int32)),  # pointer to the token ids
        ("n_tokens", ctypes.c_size_t),                  # number of ids in the buffer
    ]
124
+
125
class RKLLMMultiModelInput(ctypes.Structure):
    """Multimodal input: a text prompt plus one or more image embeddings."""

    prompt: ctypes.c_char_p
    image_embed: ctypes.POINTER(ctypes.c_float)
    n_image_tokens: ctypes.c_size_t
    n_image: ctypes.c_size_t
    image_width: ctypes.c_size_t
    image_height: ctypes.c_size_t

    _fields_ = [
        ("prompt", ctypes.c_char_p),                     # text part of the input
        ("image_embed", ctypes.POINTER(ctypes.c_float)), # image embedding buffer
        ("n_image_tokens", ctypes.c_size_t),             # tokens per image embedding
        ("n_image", ctypes.c_size_t),                    # number of images
        ("image_width", ctypes.c_size_t),                # image width in pixels
        ("image_height", ctypes.c_size_t),               # image height in pixels
    ]
141
+
142
class RKLLMCrossAttnParam(ctypes.Structure):
    """Cross-attention inputs for the decoder.

    Supplies encoder outputs (key/value caches), position indices and an
    attention mask used when the decoder performs cross attention.

    - encoder_k_cache must be contiguous with layout
      [num_layers][num_tokens][num_kv_heads][head_dim]
    - encoder_v_cache must be contiguous with layout
      [num_layers][num_kv_heads][head_dim][num_tokens]
    """

    encoder_k_cache: ctypes.POINTER(ctypes.c_float)
    encoder_v_cache: ctypes.POINTER(ctypes.c_float)
    encoder_mask: ctypes.POINTER(ctypes.c_float)
    encoder_pos: ctypes.POINTER(ctypes.c_int32)
    num_tokens: ctypes.c_int

    _fields_ = [
        ("encoder_k_cache", ctypes.POINTER(ctypes.c_float)),  # encoder key cache (num_layers * num_tokens * num_kv_heads * head_dim floats)
        ("encoder_v_cache", ctypes.POINTER(ctypes.c_float)),  # encoder value cache (num_layers * num_kv_heads * head_dim * num_tokens floats)
        ("encoder_mask", ctypes.POINTER(ctypes.c_float)),     # attention mask, one entry per token
        ("encoder_pos", ctypes.POINTER(ctypes.c_int32)),      # token positions, one entry per token
        ("num_tokens", ctypes.c_int),                         # number of tokens in the encoder sequence
    ]
167
+
168
class RKLLMPerfStat(ctypes.Structure):
    """Performance statistics for the prefill and generation stages."""

    prefill_time_ms: ctypes.c_float
    prefill_tokens: ctypes.c_int
    generate_time_ms: ctypes.c_float
    generate_tokens: ctypes.c_int
    memory_usage_mb: ctypes.c_float

    _fields_ = [
        ("prefill_time_ms", ctypes.c_float),   # total prefill time in milliseconds
        ("prefill_tokens", ctypes.c_int),      # tokens processed during prefill
        ("generate_time_ms", ctypes.c_float),  # total generation time in milliseconds
        ("generate_tokens", ctypes.c_int),     # tokens processed during generation
        ("memory_usage_mb", ctypes.c_float),   # VmHWM resident memory during inference (MB)
    ]
187
+
188
class _RKLLMInputUnion(ctypes.Union):
    """Payload variants for RKLLMInput; the active member is selected by input_type."""

    prompt_input: ctypes.c_char_p
    embed_input: RKLLMEmbedInput
    token_input: RKLLMTokenInput
    multimodal_input: RKLLMMultiModelInput

    _fields_ = [
        ("prompt_input", ctypes.c_char_p),
        ("embed_input", RKLLMEmbedInput),
        ("token_input", RKLLMTokenInput),
        ("multimodal_input", RKLLMMultiModelInput),
    ]
200
+
201
class RKLLMInput(ctypes.Structure):
    """LLM input container.

    A tagged union: ``input_type`` selects which member of ``_union_data``
    is active (prompt, token, embed, or multimodal).
    """

    role: ctypes.c_char_p
    enable_thinking: ctypes.c_bool
    input_type: ctypes.c_int
    _union_data: _RKLLMInputUnion

    _fields_ = [
        ("role", ctypes.c_char_p),           # message role: "user" (user input) or "tool" (function result)
        ("enable_thinking", ctypes.c_bool),  # toggles "thinking mode" for Qwen3 models
        ("input_type", ctypes.c_int),        # RKLLMInputType value selecting the active union member
        ("_union_data", _RKLLMInputUnion),   # payload union
    ]

    # Typed accessors guarding the tagged union: each raises AttributeError
    # unless input_type matches the member being read/written.
    @property
    def prompt_input(self) -> bytes:
        if self.input_type != RKLLMInputType.RKLLM_INPUT_PROMPT:
            raise AttributeError("Not a prompt input")
        return self._union_data.prompt_input

    @prompt_input.setter
    def prompt_input(self, value: bytes):
        if self.input_type != RKLLMInputType.RKLLM_INPUT_PROMPT:
            raise AttributeError("Not a prompt input")
        self._union_data.prompt_input = value

    @property
    def embed_input(self) -> RKLLMEmbedInput:
        if self.input_type != RKLLMInputType.RKLLM_INPUT_EMBED:
            raise AttributeError("Not an embed input")
        return self._union_data.embed_input

    @embed_input.setter
    def embed_input(self, value: RKLLMEmbedInput):
        if self.input_type != RKLLMInputType.RKLLM_INPUT_EMBED:
            raise AttributeError("Not an embed input")
        self._union_data.embed_input = value

    @property
    def token_input(self) -> RKLLMTokenInput:
        if self.input_type != RKLLMInputType.RKLLM_INPUT_TOKEN:
            raise AttributeError("Not a token input")
        return self._union_data.token_input

    @token_input.setter
    def token_input(self, value: RKLLMTokenInput):
        if self.input_type != RKLLMInputType.RKLLM_INPUT_TOKEN:
            raise AttributeError("Not a token input")
        self._union_data.token_input = value

    @property
    def multimodal_input(self) -> RKLLMMultiModelInput:
        if self.input_type != RKLLMInputType.RKLLM_INPUT_MULTIMODAL:
            raise AttributeError("Not a multimodal input")
        return self._union_data.multimodal_input

    @multimodal_input.setter
    def multimodal_input(self, value: RKLLMMultiModelInput):
        if self.input_type != RKLLMInputType.RKLLM_INPUT_MULTIMODAL:
            raise AttributeError("Not a multimodal input")
        self._union_data.multimodal_input = value
265
+
266
class RKLLMLoraParam(ctypes.Structure):  # For inference
    """Selects a previously loaded LoRA adapter by name for a run."""

    lora_adapter_name: ctypes.c_char_p

    _fields_ = [
        ("lora_adapter_name", ctypes.c_char_p),
    ]
272
+
273
class RKLLMPromptCacheParam(ctypes.Structure):  # For inference
    """Controls saving/loading of the prompt cache during a run."""

    save_prompt_cache: ctypes.c_int  # bool-like
    prompt_cache_path: ctypes.c_char_p

    _fields_ = [
        ("save_prompt_cache", ctypes.c_int),    # bool-like: non-zero saves the cache
        ("prompt_cache_path", ctypes.c_char_p), # where the cache is written/read
    ]
281
+
282
class RKLLMInferParam(ctypes.Structure):
    """Per-call inference options passed to rkllm_run / rkllm_run_async."""

    mode: ctypes.c_int
    lora_params: ctypes.POINTER(RKLLMLoraParam)
    prompt_cache_params: ctypes.POINTER(RKLLMPromptCacheParam)
    keep_history: ctypes.c_int  # bool-like

    _fields_ = [
        ("mode", ctypes.c_int),  # RKLLMInferMode value passed as a plain int
        ("lora_params", ctypes.POINTER(RKLLMLoraParam)),
        ("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam)),
        ("keep_history", ctypes.c_int),  # bool-like: keep conversation history
    ]
294
+
295
class RKLLMResultLastHiddenLayer(ctypes.Structure):
    """Last-layer hidden states returned when RKLLM_INFER_GET_LAST_HIDDEN_LAYER is used."""

    hidden_states: ctypes.POINTER(ctypes.c_float)
    embd_size: ctypes.c_int
    num_tokens: ctypes.c_int

    _fields_ = [
        ("hidden_states", ctypes.POINTER(ctypes.c_float)),  # hidden-state buffer
        ("embd_size", ctypes.c_int),                        # embedding dimension
        ("num_tokens", ctypes.c_int),                       # number of token positions
    ]
305
+
306
class RKLLMResultLogits(ctypes.Structure):
    """Raw logits returned when RKLLM_INFER_GET_LOGITS is used."""

    logits: ctypes.POINTER(ctypes.c_float)
    vocab_size: ctypes.c_int
    num_tokens: ctypes.c_int

    _fields_ = [
        ("logits", ctypes.POINTER(ctypes.c_float)),  # logits buffer
        ("vocab_size", ctypes.c_int),                # vocabulary size
        ("num_tokens", ctypes.c_int),                # number of token positions
    ]
316
+
317
class RKLLMResult(ctypes.Structure):
    """One inference result delivered to the callback.

    Carries the generated text chunk, its token id, optional hidden states
    and logits (when requested), and performance statistics.
    """

    text: ctypes.c_char_p
    token_id: ctypes.c_int32
    last_hidden_layer: RKLLMResultLastHiddenLayer
    logits: RKLLMResultLogits
    perf: RKLLMPerfStat

    _fields_ = [
        ("text", ctypes.c_char_p),                          # generated text
        ("token_id", ctypes.c_int32),                       # generated token id
        ("last_hidden_layer", RKLLMResultLastHiddenLayer),  # last-layer hidden states (if requested)
        ("logits", RKLLMResultLogits),                      # model output logits (if requested)
        ("perf", RKLLMPerfStat),                            # prefill/generation performance stats
    ]
336
+
337
# --- Typedefs ---
LLMHandle = ctypes.c_void_p

# --- Callback Function Type ---
# Invoked by the C runtime for each result chunk:
#   result  : POINTER(RKLLMResult) with the current chunk
#   userdata: user-data pointer supplied to rkllm_run
#   state   : LLMCallState value (e.g. finished, error)
# Return 0 to continue inference normally; return 1 to pause it (e.g. to
# edit the output or inject a new prompt), then call rkllm_run again with
# the updated content to resume.
LLMResultCallback = ctypes.CFUNCTYPE(
    ctypes.c_int,                 # return type: handling status
    ctypes.POINTER(RKLLMResult),  # pointer to the LLM result
    ctypes.c_void_p,              # user data pointer
    ctypes.c_int,                 # LLM call state (LLMCallState value)
)
362
+
363
class RKLLMRuntime:
    """ctypes wrapper around the RKLLM C runtime (librkllmrt.so).

    Declares the C function signatures once at construction time and exposes
    thin Python methods that raise RuntimeError on non-zero C return codes.
    Usable as a context manager; destroy() releases native resources.
    """

    def __init__(self, library_path="./librkllmrt.so"):
        """Load the shared library and declare the C signatures.

        :param library_path: filesystem path of librkllmrt.so.
        :raises OSError: if the shared library cannot be loaded.
        """
        # Bug fix: set all attributes *before* the fallible CDLL load so that
        # __del__ -> destroy() is safe even when construction fails (previously
        # a failed load left the object half-built and garbage collection
        # raised AttributeError).
        self.lib = None
        self.llm_handle = LLMHandle()
        self._c_callback = None  # keeps the ctypes callback object alive
        try:
            self.lib = ctypes.CDLL(library_path)
        except OSError as e:
            raise OSError(f"Failed to load RKLLM library from {library_path}. "
                          f"Ensure it's in your LD_LIBRARY_PATH or provide the full path. Error: {e}")
        self._setup_functions()

    def _setup_functions(self):
        """Declare restype/argtypes for every C entry point this wrapper uses."""
        # RKLLMParam rkllm_createDefaultParam();
        self.lib.rkllm_createDefaultParam.restype = RKLLMParam
        self.lib.rkllm_createDefaultParam.argtypes = []

        # int rkllm_init(LLMHandle* handle, RKLLMParam* param, LLMResultCallback callback);
        self.lib.rkllm_init.restype = ctypes.c_int
        self.lib.rkllm_init.argtypes = [
            ctypes.POINTER(LLMHandle),
            ctypes.POINTER(RKLLMParam),
            LLMResultCallback
        ]

        # int rkllm_load_lora(LLMHandle handle, RKLLMLoraAdapter* lora_adapter);
        self.lib.rkllm_load_lora.restype = ctypes.c_int
        self.lib.rkllm_load_lora.argtypes = [LLMHandle, ctypes.POINTER(RKLLMLoraAdapter)]

        # int rkllm_load_prompt_cache(LLMHandle handle, const char* prompt_cache_path);
        self.lib.rkllm_load_prompt_cache.restype = ctypes.c_int
        self.lib.rkllm_load_prompt_cache.argtypes = [LLMHandle, ctypes.c_char_p]

        # int rkllm_release_prompt_cache(LLMHandle handle);
        self.lib.rkllm_release_prompt_cache.restype = ctypes.c_int
        self.lib.rkllm_release_prompt_cache.argtypes = [LLMHandle]

        # int rkllm_destroy(LLMHandle handle);
        self.lib.rkllm_destroy.restype = ctypes.c_int
        self.lib.rkllm_destroy.argtypes = [LLMHandle]

        # int rkllm_run(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
        self.lib.rkllm_run.restype = ctypes.c_int
        self.lib.rkllm_run.argtypes = [
            LLMHandle,
            ctypes.POINTER(RKLLMInput),
            ctypes.POINTER(RKLLMInferParam),
            ctypes.c_void_p  # userdata
        ]

        # int rkllm_run_async(...); same shape as rkllm_run, userdata feeds the callback.
        self.lib.rkllm_run_async.restype = ctypes.c_int
        self.lib.rkllm_run_async.argtypes = [
            LLMHandle,
            ctypes.POINTER(RKLLMInput),
            ctypes.POINTER(RKLLMInferParam),
            ctypes.c_void_p  # userdata
        ]

        # int rkllm_abort(LLMHandle handle);
        self.lib.rkllm_abort.restype = ctypes.c_int
        self.lib.rkllm_abort.argtypes = [LLMHandle]

        # int rkllm_is_running(LLMHandle handle); returns 0 while running.
        self.lib.rkllm_is_running.restype = ctypes.c_int
        self.lib.rkllm_is_running.argtypes = [LLMHandle]

        # int rkllm_clear_kv_cache(LLMHandle handle, int keep_system_prompt, int* start_pos, int* end_pos);
        self.lib.rkllm_clear_kv_cache.restype = ctypes.c_int
        self.lib.rkllm_clear_kv_cache.argtypes = [
            LLMHandle,
            ctypes.c_int,
            ctypes.POINTER(ctypes.c_int),  # start_pos
            ctypes.POINTER(ctypes.c_int)   # end_pos
        ]

        # int rkllm_get_kv_cache_size(LLMHandle handle, int* cache_sizes);
        self.lib.rkllm_get_kv_cache_size.restype = ctypes.c_int
        self.lib.rkllm_get_kv_cache_size.argtypes = [LLMHandle, ctypes.POINTER(ctypes.c_int)]

        # int rkllm_set_chat_template(LLMHandle handle, const char* system_prompt, const char* prompt_prefix, const char* prompt_postfix);
        self.lib.rkllm_set_chat_template.restype = ctypes.c_int
        self.lib.rkllm_set_chat_template.argtypes = [
            LLMHandle,
            ctypes.c_char_p,
            ctypes.c_char_p,
            ctypes.c_char_p
        ]

        # int rkllm_set_function_tools(LLMHandle handle, const char* system_prompt, const char* tools, const char* tool_response_str);
        self.lib.rkllm_set_function_tools.restype = ctypes.c_int
        self.lib.rkllm_set_function_tools.argtypes = [
            LLMHandle,
            ctypes.c_char_p,  # system_prompt
            ctypes.c_char_p,  # tools
            ctypes.c_char_p   # tool_response_str
        ]

        # int rkllm_set_cross_attn_params(LLMHandle handle, RKLLMCrossAttnParam* cross_attn_params);
        self.lib.rkllm_set_cross_attn_params.restype = ctypes.c_int
        self.lib.rkllm_set_cross_attn_params.argtypes = [LLMHandle, ctypes.POINTER(RKLLMCrossAttnParam)]

    def _wrap_userdata(self, userdata):
        """Convert optional Python userdata to c_void_p, keeping it alive.

        Shared by run() and run_async() (previously duplicated in both).
        """
        if userdata is None:
            return None
        self._userdata_ref = userdata  # prevent GC while the C call may use it
        return ctypes.cast(ctypes.pointer(ctypes.py_object(userdata)), ctypes.c_void_p)

    def create_default_param(self) -> RKLLMParam:
        """Create an RKLLMParam pre-filled with library defaults."""
        return self.lib.rkllm_createDefaultParam()

    def init(self, param: RKLLMParam, callback_func) -> int:
        """Initialize the LLM.

        :param param: configured RKLLMParam structure.
        :param callback_func: Python callable matching LLMResultCallback:
            def cb(result_ptr, userdata_ptr, state_enum) -> int
            Return 0 to continue inference; 1 to pause it.
        :return: 0 on success.
        :raises ValueError: if callback_func is not callable.
        :raises RuntimeError: on non-zero C return code.
        """
        if not callable(callback_func):
            raise ValueError("callback_func must be a callable Python function.")

        # Hold a reference so the ctypes trampoline is not garbage collected
        # while the C library can still invoke it.
        self._c_callback = LLMResultCallback(callback_func)

        ret = self.lib.rkllm_init(ctypes.byref(self.llm_handle), ctypes.byref(param), self._c_callback)
        if ret != 0:
            raise RuntimeError(f"rkllm_init failed with error code {ret}")
        return ret

    def load_lora(self, lora_adapter: RKLLMLoraAdapter) -> int:
        """Load a LoRA adapter."""
        ret = self.lib.rkllm_load_lora(self.llm_handle, ctypes.byref(lora_adapter))
        if ret != 0:
            raise RuntimeError(f"rkllm_load_lora failed with error code {ret}")
        return ret

    def load_prompt_cache(self, prompt_cache_path: str) -> int:
        """Load a prompt cache from a file."""
        c_path = prompt_cache_path.encode('utf-8')
        ret = self.lib.rkllm_load_prompt_cache(self.llm_handle, c_path)
        if ret != 0:
            raise RuntimeError(f"rkllm_load_prompt_cache failed for {prompt_cache_path} with error code {ret}")
        return ret

    def release_prompt_cache(self) -> int:
        """Release the prompt cache from memory."""
        ret = self.lib.rkllm_release_prompt_cache(self.llm_handle)
        if ret != 0:
            raise RuntimeError(f"rkllm_release_prompt_cache failed with error code {ret}")
        return ret

    def destroy(self) -> int:
        """Destroy the LLM instance and release resources (idempotent)."""
        # Guard self.lib too: __init__ may have failed before the library loaded.
        if self.lib is not None and self.llm_handle and self.llm_handle.value:
            ret = self.lib.rkllm_destroy(self.llm_handle)
            self.llm_handle = LLMHandle()  # reset handle
            if ret != 0:
                # Don't raise here as it might be called in __del__
                print(f"Warning: rkllm_destroy failed with error code {ret}")
            return ret
        return 0  # already destroyed or never initialized

    def run(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
        """Run an inference task synchronously; results arrive via the callback."""
        c_userdata = self._wrap_userdata(userdata)
        ret = self.lib.rkllm_run(self.llm_handle, ctypes.byref(rkllm_input),
                                 ctypes.byref(rkllm_infer_params), c_userdata)
        if ret != 0:
            raise RuntimeError(f"rkllm_run failed with error code {ret}")
        return ret

    def run_async(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
        """Run an inference task asynchronously; results arrive via the callback."""
        c_userdata = self._wrap_userdata(userdata)
        ret = self.lib.rkllm_run_async(self.llm_handle, ctypes.byref(rkllm_input),
                                       ctypes.byref(rkllm_infer_params), c_userdata)
        if ret != 0:
            raise RuntimeError(f"rkllm_run_async failed with error code {ret}")
        return ret

    def abort(self) -> int:
        """Abort an ongoing LLM task."""
        ret = self.lib.rkllm_abort(self.llm_handle)
        if ret != 0:
            raise RuntimeError(f"rkllm_abort failed with error code {ret}")
        return ret

    def is_running(self) -> bool:
        """Return True while an LLM task is running.

        Note: the C API returns 0 when running and non-zero otherwise, hence
        the inverted comparison.
        """
        return self.lib.rkllm_is_running(self.llm_handle) == 0

    def clear_kv_cache(self, keep_system_prompt: bool, start_pos: list = None, end_pos: list = None) -> int:
        """Clear part or all of the KV cache.

        :param keep_system_prompt: keep the system prompt in the cache
            (ignored when an explicit [start_pos, end_pos) range is given).
        :param start_pos: per-batch inclusive range starts to clear, or None.
        :param end_pos: per-batch exclusive range ends to clear, or None.
            With both None the whole cache is cleared and keep_system_prompt
            applies. Ranges are only valid when keep_history == 0 and
            generation was paused by returning 1 from the callback.
        :return: 0 on success.
        :raises ValueError: if only one of start_pos/end_pos is given, or
            their lengths differ.
        :raises RuntimeError: on non-zero C return code.
        """
        c_start_pos = None
        c_end_pos = None

        # Bug fix: previously a lone start_pos or end_pos was silently dropped
        # and the WHOLE cache got cleared; fail loudly instead.
        if (start_pos is None) != (end_pos is None):
            raise ValueError("start_pos and end_pos must be provided together")

        if start_pos is not None and end_pos is not None:
            if len(start_pos) != len(end_pos):
                raise ValueError("start_pos和end_pos数组长度必须相同")

            # Marshal the ranges into C int arrays.
            c_start_pos = (ctypes.c_int * len(start_pos))(*start_pos)
            c_end_pos = (ctypes.c_int * len(end_pos))(*end_pos)

        ret = self.lib.rkllm_clear_kv_cache(
            self.llm_handle,
            ctypes.c_int(1 if keep_system_prompt else 0),
            c_start_pos,
            c_end_pos
        )
        if ret != 0:
            raise RuntimeError(f"rkllm_clear_kv_cache失败,错误代码:{ret}")
        return ret

    def set_chat_template(self, system_prompt: str, prompt_prefix: str, prompt_postfix: str) -> int:
        """Set the chat template (system prompt, per-turn prefix and postfix)."""
        c_system = system_prompt.encode('utf-8') if system_prompt else b""
        c_prefix = prompt_prefix.encode('utf-8') if prompt_prefix else b""
        c_postfix = prompt_postfix.encode('utf-8') if prompt_postfix else b""

        ret = self.lib.rkllm_set_chat_template(self.llm_handle, c_system, c_prefix, c_postfix)
        if ret != 0:
            raise RuntimeError(f"rkllm_set_chat_template failed with error code {ret}")
        return ret

    def get_kv_cache_size(self, n_batch: int) -> list:
        """Return the current KV-cache size (stored positions) per batch.

        :param n_batch: number of batches; sizes the output array.
        :return: list with one cache size per batch.
        :raises RuntimeError: on non-zero C return code.
        """
        # Pre-allocate the output buffer, one slot per batch.
        cache_sizes = (ctypes.c_int * n_batch)()

        ret = self.lib.rkllm_get_kv_cache_size(self.llm_handle, cache_sizes)
        if ret != 0:
            raise RuntimeError(f"rkllm_get_kv_cache_size失败,错误代码:{ret}")

        return list(cache_sizes)

    def set_function_tools(self, system_prompt: str, tools: str, tool_response_str: str) -> int:
        """Configure function calling.

        :param system_prompt: system prompt defining the model's context/behavior.
        :param tools: JSON string describing the available functions
            (names, descriptions, parameters).
        :param tool_response_str: unique tag marking function-call results in
            the conversation, letting the tokenizer distinguish tool output
            from normal turns.
        :return: 0 on success.
        :raises RuntimeError: on non-zero C return code.
        """
        c_system = system_prompt.encode('utf-8') if system_prompt else b""
        c_tools = tools.encode('utf-8') if tools else b""
        c_tool_response = tool_response_str.encode('utf-8') if tool_response_str else b""

        ret = self.lib.rkllm_set_function_tools(self.llm_handle, c_system, c_tools, c_tool_response)
        if ret != 0:
            raise RuntimeError(f"rkllm_set_function_tools失败,错误代码:{ret}")
        return ret

    def set_cross_attn_params(self, cross_attn_params: RKLLMCrossAttnParam) -> int:
        """Set decoder cross-attention inputs (see RKLLMCrossAttnParam).

        :return: 0 on success.
        :raises RuntimeError: on non-zero C return code.
        """
        ret = self.lib.rkllm_set_cross_attn_params(self.llm_handle, ctypes.byref(cross_attn_params))
        if ret != 0:
            raise RuntimeError(f"rkllm_set_cross_attn_params失败,错误代码:{ret}")
        return ret

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.destroy()

    def __del__(self):
        self.destroy()  # ensure native resources are freed on garbage collection
683
+
684
+ # --- Example Usage (Illustrative) ---
685
+ if __name__ == "__main__":
686
+ # This is a placeholder for how you might use it.
687
+ # You'll need a valid .rkllm model and librkllmrt.so in your path.
688
+
689
+ # Global list to store results from callback for demonstration
690
+ results_buffer = []
691
+
692
+ def my_python_callback(result_ptr, userdata_ptr, state_enum):
693
+ """
694
+ 回调函数,由C库调用来处理LLM结果
695
+
696
+ 参数:
697
+ - result_ptr: 指向LLM结果的指针
698
+ - userdata_ptr: 用户数据指针
699
+ - state_enum: LLM调用状态枚举值
700
+
701
+ 返回:
702
+ - 0: 继续推理
703
+ - 1: 暂停推理
704
+ """
705
+ global results_buffer
706
+ state = LLMCallState(state_enum)
707
+ result = result_ptr.contents
708
+
709
+ current_text = ""
710
+ if result.text: # 检查char_p是否不为NULL
711
+ current_text = result.text.decode('utf-8', errors='ignore')
712
+
713
+ print(f"回调: State={state.name}, TokenID={result.token_id}, Text='{current_text}'")
714
+
715
+ # 显示性能统计信息
716
+ if result.perf.prefill_tokens > 0 or result.perf.generate_tokens > 0:
717
+ print(f" 性能统计: 预填充={result.perf.prefill_tokens}tokens/{result.perf.prefill_time_ms:.1f}ms, "
718
+ f"生成={result.perf.generate_tokens}tokens/{result.perf.generate_time_ms:.1f}ms, "
719
+ f"内存={result.perf.memory_usage_mb:.1f}MB")
720
+
721
+ results_buffer.append(current_text)
722
+
723
+ if state == LLMCallState.RKLLM_RUN_FINISH:
724
+ print("推理完成。")
725
+ elif state == LLMCallState.RKLLM_RUN_ERROR:
726
+ print("推理错误。")
727
+
728
+ # 返回0继续推理,返回1暂停推理
729
+ return 0
730
+
731
# --- Attempt to use the wrapper ---
# Demonstration driver: loads a model, runs one prompt through the runtime,
# and prints the response collected by my_python_callback.
try:
    print("Initializing RKLLMRuntime...")
    # Adjust library_path if librkllmrt.so is not in default search paths,
    # e.g. RKLLMRuntime(library_path="./path/to/librkllmrt.so").
    rk_llm = RKLLMRuntime()

    print("Creating default parameters...")
    params = rk_llm.create_default_param()

    # --- Configure parameters ---
    # CRITICAL: model_path must point to an actual .rkllm file; init fails
    # otherwise, so we check for existence up front.
    model_file = "language_model.rkllm"
    if not os.path.exists(model_file):
        raise FileNotFoundError(f"Model file '{model_file}' does not exist.")

    params.model_path = model_file.encode('utf-8')
    params.max_context_len = 512
    params.max_new_tokens = 128
    # params.top_k = 1  # Greedy
    params.temperature = 0.7
    params.repeat_penalty = 1.1
    # ... set other params as needed

    print(f"Initializing LLM with model: {params.model_path.decode()}...")
    # Initialization fails if the file is not a model recognized by the library.
    try:
        rk_llm.init(params, my_python_callback)
        print("LLM Initialized.")
    except RuntimeError as e:
        print(f"Error during LLM initialization: {e}")
        # NOTE: message now reflects the actual configured path instead of the
        # stale 'dummy_model.rkllm' placeholder.
        print(f"This is expected if '{model_file}' is not a valid model.")
        print(f"Replace '{model_file}' with a real model path to test further.")
        exit()


    # --- Prepare input ---
    print("准备输入...")
    rk_input = RKLLMInput()
    rk_input.role = b"user"  # mark this as user input
    rk_input.enable_thinking = False  # disable thinking mode (Qwen3-style models)
    rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT

    prompt_text = "将以下英文文本翻译成中文:'Hello, world!'"
    c_prompt = prompt_text.encode('utf-8')
    rk_input._union_data.prompt_input = c_prompt  # access the union member directly

    # --- Prepare inference parameters ---
    print("Preparing inference parameters...")
    infer_params = RKLLMInferParam()
    infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
    infer_params.keep_history = 1  # True
    # infer_params.lora_params = None  # or set up RKLLMLoraParam if using LoRA
    # infer_params.prompt_cache_params = None  # or set up RKLLMPromptCacheParam

    # --- Run inference ---
    print(f"Running inference with prompt: '{prompt_text}'")
    results_buffer.clear()
    try:
        rk_llm.run(rk_input, infer_params)  # userdata is None by default
        print("\n--- Full Response ---")
        print("".join(results_buffer))
        print("---------------------\n")
    except RuntimeError as e:
        print(f"Error during LLM run: {e}")


    # --- Example: Set chat template (if model supports it) ---
    # print("Setting chat template...")
    # try:
    #     rk_llm.set_chat_template("You are a helpful assistant.", "<user>: ", "<assistant>: ")
    #     print("Chat template set.")
    # except RuntimeError as e:
    #     print(f"Error setting chat template: {e}")

    # --- Example: Clear KV Cache ---
    # print("Clearing KV cache (keeping system prompt if any)...")
    # try:
    #     rk_llm.clear_kv_cache(keep_system_prompt=True)
    #     print("KV cache cleared.")
    # except RuntimeError as e:
    #     print(f"Error clearing KV cache: {e}")

    # --- Example: query the KV cache size ---
    # print("获取KV缓存大小...")
    # try:
    #     cache_sizes = rk_llm.get_kv_cache_size(n_batch=1)  # assume batch size 1
    #     print(f"当前KV缓存大小: {cache_sizes}")
    # except RuntimeError as e:
    #     print(f"获取KV缓存大小错误: {e}")

    # --- Example: register function-calling tools ---
    # print("设置函数调用工具...")
    # try:
    #     system_prompt = "你是一个有用的助手,可以调用提供的函数来帮助用户。"
    #     tools = '''[{
    #         "name": "get_weather",
    #         "description": "获取指定城市的天气信息",
    #         "parameters": {
    #             "type": "object",
    #             "properties": {
    #                 "city": {"type": "string", "description": "城市名称"}
    #             },
    #             "required": ["city"]
    #         }
    #     }]'''
    #     tool_response_str = "<tool_response>"
    #     rk_llm.set_function_tools(system_prompt, tools, tool_response_str)
    #     print("函数工具设置成功。")
    # except RuntimeError as e:
    #     print(f"设置函数工具错误: {e}")

    # --- Example: clear KV cache over a position range ---
    # print("使用范围参数清除KV缓存...")
    # try:
    #     # clear cached positions 10 through 20
    #     start_positions = [10]  # start position for batch 0
    #     end_positions = [20]    # end position for batch 0
    #     rk_llm.clear_kv_cache(keep_system_prompt=True, start_pos=start_positions, end_pos=end_positions)
    #     print("范围KV缓存清除完成。")
    # except RuntimeError as e:
    #     print(f"清除范围KV缓存错误: {e}")

except OSError as e:
    print(f"OSError: {e}. Could not load the RKLLM library.")
    print("Please ensure 'librkllmrt.so' is in your LD_LIBRARY_PATH or provide the full path.")
except Exception as e:
    print(f"An unexpected error occurred: {e}")
finally:
    # Tear down only if a native handle was actually created.
    if 'rk_llm' in locals() and rk_llm.llm_handle and rk_llm.llm_handle.value:
        print("Destroying LLM instance...")
        rk_llm.destroy()
        print("LLM instance destroyed.")

print("Example finished.")
run.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["RKLLM_LOG_LEVEL"] = "1"
3
+ from rkllm_binding import *
4
+
5
def my_python_callback(result_ptr, userdata_ptr, state_enum):
    """Stream-print each text fragment produced by the LLM.

    Invoked by the C runtime for every generated result; prints fragments
    immediately so the response appears character-by-character.

    Returns 0 to continue inference, 1 to pause it.
    """
    state = LLMCallState(state_enum)
    payload = result_ptr.contents

    # A NULL char_p means this callback carries no text.
    if payload.text:
        print(payload.text.decode('utf-8', errors='ignore'), end='', flush=True)

    if state == LLMCallState.RKLLM_RUN_FINISH:
        # Terminate the streamed line once the response is complete.
        print()
    elif state == LLMCallState.RKLLM_RUN_ERROR:
        print("\n推理过程中发生错误。")

    # 0 = keep generating, 1 = pause inference.
    return 0
25
+
26
# --- Attempt to use the wrapper ---
# Interactive driver: initializes the runtime once, then loops over user
# prompts, streaming each response through my_python_callback.
try:
    print("Initializing RKLLMRuntime...")
    # Pass library_path=... if librkllmrt.so is not on the default loader path.
    rk_llm = RKLLMRuntime()

    print("Creating default parameters...")
    params = rk_llm.create_default_param()

    # --- Configure parameters ---
    model_file = "language_model.rkllm"
    if not os.path.exists(model_file):
        raise FileNotFoundError(f"Model file '{model_file}' does not exist.")

    params.model_path = model_file.encode('utf-8')
    params.max_context_len = 4096
    params.max_new_tokens = 1024
    # params.top_k = 1  # Greedy
    params.temperature = 0.7
    params.repeat_penalty = 1.1
    # ... set other params as needed

    print(f"Initializing LLM with model: {params.model_path.decode()}...")
    # Fails here if the file is not a model recognized by the library.
    try:
        rk_llm.init(params, my_python_callback)
        print("LLM Initialized.")
    except RuntimeError as e:
        print(f"Error during LLM initialization: {e}")
        exit()


    # --- Interactive multi-turn chat loop ---
    print("\n进入多轮对话模式。输入 'exit' 或 'quit' 退出。")

    # Inference parameters stay constant for the whole conversation.
    infer_params = RKLLMInferParam()
    infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
    infer_params.keep_history = 1  # preserve conversation history between turns

    while True:
        try:
            user_text = input("You: ")
            if user_text.lower() in ("exit", "quit"):
                break

            print("Assistant: ", end='', flush=True)

            # Build a fresh prompt input for this turn.
            turn_input = RKLLMInput()
            turn_input.role = b"user"
            turn_input.enable_thinking = False
            turn_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT

            encoded_prompt = user_text.encode('utf-8')
            turn_input._union_data.prompt_input = encoded_prompt

            # Run inference; output is streamed by the callback.
            rk_llm.run(turn_input, infer_params)

        except KeyboardInterrupt:
            print("\n\n对话中断。")
            break
        except RuntimeError as e:
            print(f"\n运行时发生错误: {e}")
            break


except OSError as e:
    print(f"OSError: {e}. Could not load the RKLLM library.")
    print("Please ensure 'librkllmrt.so' is in your LD_LIBRARY_PATH or provide the full path.")
except Exception as e:
    print(f"An unexpected error occurred: {e}")
finally:
    # Tear down only if a native handle was actually created.
    if 'rk_llm' in locals() and rk_llm.llm_handle and rk_llm.llm_handle.value:
        print("Destroying LLM instance...")
        rk_llm.destroy()
        print("LLM instance destroyed.")

print("Example finished.")