VoxCPM-0.5B-RKNN2 / rkllm_binding.py
happyme531's picture
Upload 34 files
621e4aa verified
import argparse
import ctypes
import enum
import os
import threading
from typing import Optional, Sequence, Tuple
import numpy as np
# Define constants from the header
CPU0 = (1 << 0) # 0x01
CPU1 = (1 << 1) # 0x02
CPU2 = (1 << 2) # 0x04
CPU3 = (1 << 3) # 0x08
CPU4 = (1 << 4) # 0x10
CPU5 = (1 << 5) # 0x20
CPU6 = (1 << 6) # 0x40
CPU7 = (1 << 7) # 0x80
# --- Enums ---
class LLMCallState(enum.IntEnum):
RKLLM_RUN_NORMAL = 0
RKLLM_RUN_WAITING = 1
RKLLM_RUN_FINISH = 2
RKLLM_RUN_ERROR = 3
class RKLLMInputType(enum.IntEnum):
RKLLM_INPUT_PROMPT = 0
RKLLM_INPUT_TOKEN = 1
RKLLM_INPUT_EMBED = 2
RKLLM_INPUT_MULTIMODAL = 3
class RKLLMInferMode(enum.IntEnum):
RKLLM_INFER_GENERATE = 0
RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1
RKLLM_INFER_GET_LOGITS = 2
# --- Structures ---
class RKLLMExtendParam(ctypes.Structure):
base_domain_id: ctypes.c_int32
embed_flash: ctypes.c_int8
enabled_cpus_num: ctypes.c_int8
enabled_cpus_mask: ctypes.c_uint32
n_batch: ctypes.c_uint8
use_cross_attn: ctypes.c_int8
reserved: ctypes.c_uint8 * 104
_fields_ = [
("base_domain_id", ctypes.c_int32), # 基础域ID
("embed_flash", ctypes.c_int8), # 是否从闪存查询词嵌入向量(1启用,0禁用)
("enabled_cpus_num", ctypes.c_int8), # 推理启用的CPU数量
("enabled_cpus_mask", ctypes.c_uint32), # 指示启用哪些CPU的位掩码
("n_batch", ctypes.c_uint8), # 一次前向传播中并发处理的输入样本数,设置>1启用批量推理,默认为1
("use_cross_attn", ctypes.c_int8), # 是否启用交叉注意力(非零启用,0禁用)
("reserved", ctypes.c_uint8 * 104) # 保留字段
]
class RKLLMParam(ctypes.Structure):
model_path: ctypes.c_char_p
max_context_len: ctypes.c_int32
max_new_tokens: ctypes.c_int32
top_k: ctypes.c_int32
n_keep: ctypes.c_int32
top_p: ctypes.c_float
temperature: ctypes.c_float
repeat_penalty: ctypes.c_float
frequency_penalty: ctypes.c_float
presence_penalty: ctypes.c_float
mirostat: ctypes.c_int32
mirostat_tau: ctypes.c_float
mirostat_eta: ctypes.c_float
skip_special_token: ctypes.c_bool
is_async: ctypes.c_bool
img_start: ctypes.c_char_p
img_end: ctypes.c_char_p
img_content: ctypes.c_char_p
extend_param: RKLLMExtendParam
_fields_ = [
("model_path", ctypes.c_char_p), # 模型文件路径
("max_context_len", ctypes.c_int32), # 上下文窗口最大token数
("max_new_tokens", ctypes.c_int32), # 最大生成新token数
("top_k", ctypes.c_int32), # Top-K采样参数
("n_keep", ctypes.c_int32), # 上下文窗口移动时保留的kv缓存数量
("top_p", ctypes.c_float), # Top-P(nucleus)采样参数
("temperature", ctypes.c_float), # 采样温度,影响token选择的随机性
("repeat_penalty", ctypes.c_float), # 重复token惩罚
("frequency_penalty", ctypes.c_float), # 频繁token惩罚
("presence_penalty", ctypes.c_float), # 输入中已存在token的惩罚
("mirostat", ctypes.c_int32), # Mirostat采样策略标志(0表示禁用)
("mirostat_tau", ctypes.c_float), # Mirostat采样Tau参数
("mirostat_eta", ctypes.c_float), # Mirostat采样Eta参数
("skip_special_token", ctypes.c_bool), # 是否跳过特殊token
("is_async", ctypes.c_bool), # 是否异步推理
("img_start", ctypes.c_char_p), # 多模态输入中图像的起始位置
("img_end", ctypes.c_char_p), # 多模态输入中图像的结束位置
("img_content", ctypes.c_char_p), # 图像内容指针
("extend_param", RKLLMExtendParam) # 扩展参数
]
class RKLLMLoraAdapter(ctypes.Structure):
lora_adapter_path: ctypes.c_char_p
lora_adapter_name: ctypes.c_char_p
scale: ctypes.c_float
_fields_ = [
("lora_adapter_path", ctypes.c_char_p),
("lora_adapter_name", ctypes.c_char_p),
("scale", ctypes.c_float)
]
class RKLLMEmbedInput(ctypes.Structure):
embed: ctypes.POINTER(ctypes.c_float)
n_tokens: ctypes.c_size_t
_fields_ = [
("embed", ctypes.POINTER(ctypes.c_float)),
("n_tokens", ctypes.c_size_t)
]
class RKLLMTokenInput(ctypes.Structure):
input_ids: ctypes.POINTER(ctypes.c_int32)
n_tokens: ctypes.c_size_t
_fields_ = [
("input_ids", ctypes.POINTER(ctypes.c_int32)),
("n_tokens", ctypes.c_size_t)
]
class RKLLMMultiModelInput(ctypes.Structure):
prompt: ctypes.c_char_p
image_embed: ctypes.POINTER(ctypes.c_float)
n_image_tokens: ctypes.c_size_t
n_image: ctypes.c_size_t
image_width: ctypes.c_size_t
image_height: ctypes.c_size_t
_fields_ = [
("prompt", ctypes.c_char_p),
("image_embed", ctypes.POINTER(ctypes.c_float)),
("n_image_tokens", ctypes.c_size_t),
("n_image", ctypes.c_size_t),
("image_width", ctypes.c_size_t),
("image_height", ctypes.c_size_t)
]
class RKLLMCrossAttnParam(ctypes.Structure):
"""
交叉注意力参数结构体
该结构体用于在解码器中执行交叉注意力时使用。
它提供编码器输出(键/值缓存)、位置索引和注意力掩码。
- encoder_k_cache必须存储在连续内存中,布局为:
[num_layers][num_tokens][num_kv_heads][head_dim]
- encoder_v_cache必须存储在连续内存中,布局为:
[num_layers][num_kv_heads][head_dim][num_tokens]
"""
encoder_k_cache: ctypes.POINTER(ctypes.c_float)
encoder_v_cache: ctypes.POINTER(ctypes.c_float)
encoder_mask: ctypes.POINTER(ctypes.c_float)
encoder_pos: ctypes.POINTER(ctypes.c_int32)
num_tokens: ctypes.c_int
_fields_ = [
("encoder_k_cache", ctypes.POINTER(ctypes.c_float)), # 编码器键缓存指针(大小:num_layers * num_tokens * num_kv_heads * head_dim)
("encoder_v_cache", ctypes.POINTER(ctypes.c_float)), # 编码器值缓存指针(大小:num_layers * num_kv_heads * head_dim * num_tokens)
("encoder_mask", ctypes.POINTER(ctypes.c_float)), # 编码器注意力掩码指针(大小:num_tokens的数组)
("encoder_pos", ctypes.POINTER(ctypes.c_int32)), # 编码器token位置指针(大小:num_tokens的数组)
("num_tokens", ctypes.c_int) # 编码器序列中的token数量
]
class RKLLMPerfStat(ctypes.Structure):
"""
性能统计结构体
用于保存预填充和生成阶段的性能统计信息。
"""
prefill_time_ms: ctypes.c_float
prefill_tokens: ctypes.c_int
generate_time_ms: ctypes.c_float
generate_tokens: ctypes.c_int
memory_usage_mb: ctypes.c_float
_fields_ = [
("prefill_time_ms", ctypes.c_float), # 预填充阶段总耗时(毫秒)
("prefill_tokens", ctypes.c_int), # 预填充阶段处理的token数量
("generate_time_ms", ctypes.c_float), # 生成阶段总耗时(毫秒)
("generate_tokens", ctypes.c_int), # 生成阶段处理的token数量
("memory_usage_mb", ctypes.c_float) # 推理期间VmHWM常驻内存使用量(MB)
]
class _RKLLMInputUnion(ctypes.Union):
prompt_input: ctypes.c_char_p
embed_input: RKLLMEmbedInput
token_input: RKLLMTokenInput
multimodal_input: RKLLMMultiModelInput
_fields_ = [
("prompt_input", ctypes.c_char_p),
("embed_input", RKLLMEmbedInput),
("token_input", RKLLMTokenInput),
("multimodal_input", RKLLMMultiModelInput)
]
class RKLLMInput(ctypes.Structure):
"""
LLM输入结构体
通过联合体表示不同类型的LLM输入。
"""
role: ctypes.c_char_p
enable_thinking: ctypes.c_bool
input_type: ctypes.c_int
_union_data: _RKLLMInputUnion
_fields_ = [
("role", ctypes.c_char_p), # 消息角色:"user"(用户输入)、"tool"(函数结果)
("enable_thinking", ctypes.c_bool), # 控制Qwen3模型是否启用"思考模式"
("input_type", ctypes.c_int), # 枚举类型,指定输入类型(如prompt、token、embed、multimodal)
("_union_data", _RKLLMInputUnion) # 联合体数据
]
# Properties to make accessing union members easier
@property
def prompt_input(self) -> bytes: # Assuming c_char_p maps to bytes
if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT:
return self._union_data.prompt_input
raise AttributeError("Not a prompt input")
@prompt_input.setter
def prompt_input(self, value: bytes): # Assuming c_char_p maps to bytes
if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT:
self._union_data.prompt_input = value
else:
raise AttributeError("Not a prompt input")
@property
def embed_input(self) -> RKLLMEmbedInput:
if self.input_type == RKLLMInputType.RKLLM_INPUT_EMBED:
return self._union_data.embed_input
raise AttributeError("Not an embed input")
@embed_input.setter
def embed_input(self, value: RKLLMEmbedInput):
if self.input_type == RKLLMInputType.RKLLM_INPUT_EMBED:
self._union_data.embed_input = value
else:
raise AttributeError("Not an embed input")
@property
def token_input(self) -> RKLLMTokenInput:
if self.input_type == RKLLMInputType.RKLLM_INPUT_TOKEN:
return self._union_data.token_input
raise AttributeError("Not a token input")
@token_input.setter
def token_input(self, value: RKLLMTokenInput):
if self.input_type == RKLLMInputType.RKLLM_INPUT_TOKEN:
self._union_data.token_input = value
else:
raise AttributeError("Not a token input")
@property
def multimodal_input(self) -> RKLLMMultiModelInput:
if self.input_type == RKLLMInputType.RKLLM_INPUT_MULTIMODAL:
return self._union_data.multimodal_input
raise AttributeError("Not a multimodal input")
@multimodal_input.setter
def multimodal_input(self, value: RKLLMMultiModelInput):
if self.input_type == RKLLMInputType.RKLLM_INPUT_MULTIMODAL:
self._union_data.multimodal_input = value
else:
raise AttributeError("Not a multimodal input")
class RKLLMLoraParam(ctypes.Structure): # For inference
lora_adapter_name: ctypes.c_char_p
_fields_ = [
("lora_adapter_name", ctypes.c_char_p)
]
class RKLLMPromptCacheParam(ctypes.Structure): # For inference
save_prompt_cache: ctypes.c_int # bool-like
prompt_cache_path: ctypes.c_char_p
_fields_ = [
("save_prompt_cache", ctypes.c_int), # bool-like
("prompt_cache_path", ctypes.c_char_p)
]
class RKLLMInferParam(ctypes.Structure):
mode: ctypes.c_int
lora_params: ctypes.POINTER(RKLLMLoraParam)
prompt_cache_params: ctypes.POINTER(RKLLMPromptCacheParam)
keep_history: ctypes.c_int # bool-like
_fields_ = [
("mode", ctypes.c_int), # Enum will be passed as int, changed RKLLMInferMode to ctypes.c_int
("lora_params", ctypes.POINTER(RKLLMLoraParam)),
("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam)),
("keep_history", ctypes.c_int) # bool-like
]
class RKLLMResultLastHiddenLayer(ctypes.Structure):
hidden_states: ctypes.POINTER(ctypes.c_float)
embd_size: ctypes.c_int
num_tokens: ctypes.c_int
_fields_ = [
("hidden_states", ctypes.POINTER(ctypes.c_float)),
("embd_size", ctypes.c_int),
("num_tokens", ctypes.c_int)
]
class RKLLMResultLogits(ctypes.Structure):
logits: ctypes.POINTER(ctypes.c_float)
vocab_size: ctypes.c_int
num_tokens: ctypes.c_int
_fields_ = [
("logits", ctypes.POINTER(ctypes.c_float)),
("vocab_size", ctypes.c_int),
("num_tokens", ctypes.c_int)
]
class RKLLMResult(ctypes.Structure):
"""
LLM推理结果结构体
表示LLM推理的结果,包含生成的文本、token ID、隐藏层状态、logits和性能统计。
"""
text: ctypes.c_char_p
token_id: ctypes.c_int32
last_hidden_layer: RKLLMResultLastHiddenLayer
logits: RKLLMResultLogits
perf: RKLLMPerfStat
_fields_ = [
("text", ctypes.c_char_p), # 生成的文本结果
("token_id", ctypes.c_int32), # 生成的token ID
("last_hidden_layer", RKLLMResultLastHiddenLayer), # 最后一层的隐藏状态(如果请求的话)
("logits", RKLLMResultLogits), # 模型输出的logits
("perf", RKLLMPerfStat) # 性能统计(预填充和生成)
]
# --- Typedefs ---
LLMHandle = ctypes.c_void_p
# --- Callback Function Type ---
LLMResultCallback = ctypes.CFUNCTYPE(
ctypes.c_int, # 返回类型:int,表示处理状态
ctypes.POINTER(RKLLMResult), # LLM结果指针
ctypes.c_void_p, # 用户数据指针
ctypes.c_int # LLM调用状态(LLMCallState枚举值)
)
"""
回调函数类型定义
用于处理LLM结果的回调函数。
参数:
- result: 指向LLM结果的指针
- userdata: 回调的用户数据指针
- state: LLM调用状态(例如:完成、错误)
返回值:
- 0: 正常继续推理
- 1: 暂停推理。如果用户想要修改或干预结果(例如编辑输出、注入新提示),
返回1以暂停当前推理。稍后,使用更新的内容调用rkllm_run来恢复推理。
"""
class RKLLMRuntime:
def __init__(self, library_path="./librkllmrt.so"):
try:
self.lib = ctypes.CDLL(library_path)
except OSError as e:
raise OSError(f"Failed to load RKLLM library from {library_path}. "
f"Ensure it's in your LD_LIBRARY_PATH or provide the full path. Error: {e}")
self._setup_functions()
self.llm_handle = LLMHandle()
self._c_callback = None # To keep the callback object alive
self._user_callback = None
def _setup_functions(self):
# RKLLMParam rkllm_createDefaultParam();
self.lib.rkllm_createDefaultParam.restype = RKLLMParam
self.lib.rkllm_createDefaultParam.argtypes = []
# int rkllm_init(LLMHandle* handle, RKLLMParam* param, LLMResultCallback callback);
self.lib.rkllm_init.restype = ctypes.c_int
self.lib.rkllm_init.argtypes = [
ctypes.POINTER(LLMHandle),
ctypes.POINTER(RKLLMParam),
LLMResultCallback
]
# int rkllm_load_lora(LLMHandle handle, RKLLMLoraAdapter* lora_adapter);
self.lib.rkllm_load_lora.restype = ctypes.c_int
self.lib.rkllm_load_lora.argtypes = [LLMHandle, ctypes.POINTER(RKLLMLoraAdapter)]
# int rkllm_load_prompt_cache(LLMHandle handle, const char* prompt_cache_path);
self.lib.rkllm_load_prompt_cache.restype = ctypes.c_int
self.lib.rkllm_load_prompt_cache.argtypes = [LLMHandle, ctypes.c_char_p]
# int rkllm_release_prompt_cache(LLMHandle handle);
self.lib.rkllm_release_prompt_cache.restype = ctypes.c_int
self.lib.rkllm_release_prompt_cache.argtypes = [LLMHandle]
# int rkllm_destroy(LLMHandle handle);
self.lib.rkllm_destroy.restype = ctypes.c_int
self.lib.rkllm_destroy.argtypes = [LLMHandle]
# int rkllm_run(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
self.lib.rkllm_run.restype = ctypes.c_int
self.lib.rkllm_run.argtypes = [
LLMHandle,
ctypes.POINTER(RKLLMInput),
ctypes.POINTER(RKLLMInferParam),
ctypes.c_void_p # userdata
]
# int rkllm_run_async(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
# Assuming async also takes userdata for the callback context
self.lib.rkllm_run_async.restype = ctypes.c_int
self.lib.rkllm_run_async.argtypes = [
LLMHandle,
ctypes.POINTER(RKLLMInput),
ctypes.POINTER(RKLLMInferParam),
ctypes.c_void_p # userdata
]
# int rkllm_abort(LLMHandle handle);
self.lib.rkllm_abort.restype = ctypes.c_int
self.lib.rkllm_abort.argtypes = [LLMHandle]
# int rkllm_is_running(LLMHandle handle);
self.lib.rkllm_is_running.restype = ctypes.c_int # 0 if running, non-zero otherwise
self.lib.rkllm_is_running.argtypes = [LLMHandle]
# int rkllm_clear_kv_cache(LLMHandle handle, int keep_system_prompt, int* start_pos, int* end_pos);
self.lib.rkllm_clear_kv_cache.restype = ctypes.c_int
self.lib.rkllm_clear_kv_cache.argtypes = [
LLMHandle,
ctypes.c_int,
ctypes.POINTER(ctypes.c_int), # start_pos
ctypes.POINTER(ctypes.c_int) # end_pos
]
# int rkllm_get_kv_cache_size(LLMHandle handle, int* cache_sizes);
self.lib.rkllm_get_kv_cache_size.restype = ctypes.c_int
self.lib.rkllm_get_kv_cache_size.argtypes = [LLMHandle, ctypes.POINTER(ctypes.c_int)]
# int rkllm_set_chat_template(LLMHandle handle, const char* system_prompt, const char* prompt_prefix, const char* prompt_postfix);
self.lib.rkllm_set_chat_template.restype = ctypes.c_int
self.lib.rkllm_set_chat_template.argtypes = [
LLMHandle,
ctypes.c_char_p,
ctypes.c_char_p,
ctypes.c_char_p
]
# int rkllm_set_function_tools(LLMHandle handle, const char* system_prompt, const char* tools, const char* tool_response_str);
self.lib.rkllm_set_function_tools.restype = ctypes.c_int
self.lib.rkllm_set_function_tools.argtypes = [
LLMHandle,
ctypes.c_char_p, # system_prompt
ctypes.c_char_p, # tools
ctypes.c_char_p # tool_response_str
]
# int rkllm_set_cross_attn_params(LLMHandle handle, RKLLMCrossAttnParam* cross_attn_params);
self.lib.rkllm_set_cross_attn_params.restype = ctypes.c_int
self.lib.rkllm_set_cross_attn_params.argtypes = [LLMHandle, ctypes.POINTER(RKLLMCrossAttnParam)]
def create_default_param(self) -> RKLLMParam:
"""Creates a default RKLLMParam structure."""
return self.lib.rkllm_createDefaultParam()
def init(self, param: RKLLMParam, callback_func) -> int:
"""
Initializes the LLM.
:param param: RKLLMParam structure.
:param callback_func: A Python function that matches the signature:
def my_callback(result_ptr, userdata_ptr, state_enum):
result = result_ptr.contents # RKLLMResult
# Process result
# userdata can be retrieved if passed during run, or ignored
# state = LLMCallState(state_enum)
:return: 0 for success, non-zero for failure.
"""
if not callable(callback_func):
raise ValueError("callback_func must be a callable Python function.")
self._user_callback = callback_func
# Keep a reference to the ctypes callback object to prevent it from being garbage collected.
# Always register a trampoline so we can swap the Python-level handler when needed.
self._c_callback = LLMResultCallback(self._callback_trampoline)
ret = self.lib.rkllm_init(ctypes.byref(self.llm_handle), ctypes.byref(param), self._c_callback)
if ret != 0:
raise RuntimeError(f"rkllm_init failed with error code {ret}")
return ret
def load_lora(self, lora_adapter: RKLLMLoraAdapter) -> int:
"""Loads a Lora adapter."""
ret = self.lib.rkllm_load_lora(self.llm_handle, ctypes.byref(lora_adapter))
if ret != 0:
raise RuntimeError(f"rkllm_load_lora failed with error code {ret}")
return ret
def load_prompt_cache(self, prompt_cache_path: str) -> int:
"""Loads a prompt cache from a file."""
c_path = prompt_cache_path.encode('utf-8')
ret = self.lib.rkllm_load_prompt_cache(self.llm_handle, c_path)
if ret != 0:
raise RuntimeError(f"rkllm_load_prompt_cache failed for {prompt_cache_path} with error code {ret}")
return ret
def release_prompt_cache(self) -> int:
"""Releases the prompt cache from memory."""
ret = self.lib.rkllm_release_prompt_cache(self.llm_handle)
if ret != 0:
raise RuntimeError(f"rkllm_release_prompt_cache failed with error code {ret}")
return ret
def destroy(self) -> int:
"""Destroys the LLM instance and releases resources."""
if self.llm_handle and self.llm_handle.value: # Check if handle is not NULL
ret = self.lib.rkllm_destroy(self.llm_handle)
self.llm_handle = LLMHandle() # Reset handle
if ret != 0:
# Don't raise here as it might be called in __del__
print(f"Warning: rkllm_destroy failed with error code {ret}")
return ret
return 0 # Already destroyed or not initialized
def run(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
"""Runs an LLM inference task synchronously."""
# userdata can be a ctypes.py_object if you want to pass Python objects,
# then cast to c_void_p. Or simply None.
if userdata is not None:
# Store the userdata object to keep it alive during the call
self._userdata_ref = userdata
c_userdata = ctypes.cast(ctypes.pointer(ctypes.py_object(userdata)), ctypes.c_void_p)
else:
c_userdata = None
ret = self.lib.rkllm_run(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata)
if ret != 0:
raise RuntimeError(f"rkllm_run failed with error code {ret}")
return ret
def run_async(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
"""Runs an LLM inference task asynchronously."""
if userdata is not None:
# Store the userdata object to keep it alive during the call
self._userdata_ref = userdata
c_userdata = ctypes.cast(ctypes.pointer(ctypes.py_object(userdata)), ctypes.c_void_p)
else:
c_userdata = None
ret = self.lib.rkllm_run_async(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata)
if ret != 0:
raise RuntimeError(f"rkllm_run_async failed with error code {ret}")
return ret
def abort(self) -> int:
"""Aborts an ongoing LLM task."""
ret = self.lib.rkllm_abort(self.llm_handle)
if ret != 0:
raise RuntimeError(f"rkllm_abort failed with error code {ret}")
return ret
def is_running(self) -> bool:
"""Checks if an LLM task is currently running. Returns True if running."""
# The C API returns 0 if running, non-zero otherwise.
# This is a bit counter-intuitive for a boolean "is_running".
return self.lib.rkllm_is_running(self.llm_handle) == 0
def clear_kv_cache(self, keep_system_prompt: bool, start_pos: list = None, end_pos: list = None) -> int:
"""
清除键值缓存
此函数用于清除部分或全部KV缓存。
参数:
- keep_system_prompt: 是否在缓存中保留系统提示(True保留,False清除)
如果提供了特定范围[start_pos, end_pos),此标志将被忽略
- start_pos: 要清除的KV缓存范围的起始位置数组(包含),每个批次一个
- end_pos: 要清除的KV缓存范围的结束位置数组(不包含),每个批次一个
如果start_pos和end_pos都设置为None,将清除整个缓存,keep_system_prompt将生效
如果start_pos[i] < end_pos[i],只有指定的范围会被清除,keep_system_prompt将被忽略
注意:start_pos或end_pos只有在keep_history == 0且生成已通过在回调中返回1暂停时才有效
返回:0表示缓存清除成功,非零表示失败
"""
# 准备C数组参数
c_start_pos = None
c_end_pos = None
if start_pos is not None and end_pos is not None:
if len(start_pos) != len(end_pos):
raise ValueError("start_pos和end_pos数组长度必须相同")
# 创建C数组
c_start_pos = (ctypes.c_int * len(start_pos))(*start_pos)
c_end_pos = (ctypes.c_int * len(end_pos))(*end_pos)
ret = self.lib.rkllm_clear_kv_cache(
self.llm_handle,
ctypes.c_int(1 if keep_system_prompt else 0),
c_start_pos,
c_end_pos
)
if ret != 0:
raise RuntimeError(f"rkllm_clear_kv_cache失败,错误代码:{ret}")
return ret
def set_chat_template(self, system_prompt: str, prompt_prefix: str, prompt_postfix: str) -> int:
"""Sets the chat template for the LLM."""
c_system = system_prompt.encode('utf-8') if system_prompt else b""
c_prefix = prompt_prefix.encode('utf-8') if prompt_prefix else b""
c_postfix = prompt_postfix.encode('utf-8') if prompt_postfix else b""
ret = self.lib.rkllm_set_chat_template(self.llm_handle, c_system, c_prefix, c_postfix)
if ret != 0:
raise RuntimeError(f"rkllm_set_chat_template failed with error code {ret}")
return ret
def get_kv_cache_size(self, n_batch: int) -> list:
"""
获取给定LLM句柄的键值缓存当前大小
此函数返回当前存储在模型KV缓存中的位置总数。
参数:
- n_batch: 批次数量,用于确定返回数组的大小
返回:
- list: 每个批次的缓存大小列表
"""
# 预分配数组以存储每个批次的缓存大小
cache_sizes = (ctypes.c_int * n_batch)()
ret = self.lib.rkllm_get_kv_cache_size(self.llm_handle, cache_sizes)
if ret != 0:
raise RuntimeError(f"rkllm_get_kv_cache_size失败,错误代码:{ret}")
# 转换为Python列表
return [cache_sizes[i] for i in range(n_batch)]
def set_function_tools(self, system_prompt: str, tools: str, tool_response_str: str) -> int:
"""
为LLM设置函数调用配置,包括系统提示、工具定义和工具响应token
参数:
- system_prompt: 定义语言模型上下文或行为的系统提示
- tools: JSON格式的字符串,定义可用的函数,包括它们的名称、描述和参数
- tool_response_str: 用于识别对话中函数调用结果的唯一标签。它作为标记标签,
允许分词器将工具输出与正常对话轮次分开识别
返回:0表示配置设置成功,非零表示错误
"""
c_system = system_prompt.encode('utf-8') if system_prompt else b""
c_tools = tools.encode('utf-8') if tools else b""
c_tool_response = tool_response_str.encode('utf-8') if tool_response_str else b""
ret = self.lib.rkllm_set_function_tools(self.llm_handle, c_system, c_tools, c_tool_response)
if ret != 0:
raise RuntimeError(f"rkllm_set_function_tools失败,错误代码:{ret}")
return ret
def set_cross_attn_params(self, cross_attn_params: RKLLMCrossAttnParam) -> int:
"""
为LLM解码器设置交叉注意力参数
参数:
- cross_attn_params: 包含用于交叉注意力的编码器相关输入数据的结构体
(详见RKLLMCrossAttnParam说明)
返回:0表示参数设置成功,非零表示错误
"""
ret = self.lib.rkllm_set_cross_attn_params(self.llm_handle, ctypes.byref(cross_attn_params))
if ret != 0:
raise RuntimeError(f"rkllm_set_cross_attn_params失败,错误代码:{ret}")
return ret
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.destroy()
def __del__(self):
self.destroy() # Ensure resources are freed if object is garbage collected
def _callback_trampoline(self, result_ptr, userdata_ptr, state_enum):
"""
Bridge callback that forwards to the currently active Python handler.
This keeps the C callback pointer stable while allowing per-call overrides.
"""
handler = self._user_callback
if handler is None:
return 0
try:
return handler(result_ptr, userdata_ptr, state_enum)
except Exception as exc:
# Avoid propagating exceptions through the C callback boundary.
print(f"[rkllm_binding] Callback raised an exception: {exc}")
return 0
def forward_embed(
self,
embeds: np.ndarray,
*,
keep_history: bool = False,
timeout: Optional[float] = None,
return_last_only: bool = False,
) -> np.ndarray:
"""
Run a single forward pass with embedding input and return the last hidden layer.
Args:
embeds: Float32 embeddings shaped (T, H) or (1, T, H). Batch>1 is not supported.
keep_history: When False, KV cache will be cleared after the call. When True,
cache is kept; call clear_kv_cache() manually if needed.
timeout: Optional timeout (seconds) for waiting on the callback.
return_last_only: If True, return the last token vector shape (H,).
Returns:
np.ndarray containing hidden states (T, H) or the last token (H,).
"""
if embeds is None:
raise ValueError("embeds must not be None.")
np_embeds = np.asarray(embeds, dtype=np.float32)
if np_embeds.ndim == 3:
if np_embeds.shape[0] != 1:
raise ValueError("Only batch size 1 is supported for forward_embed.")
num_tokens = np_embeds.shape[1]
flat = np_embeds.reshape(-1)
elif np_embeds.ndim == 2:
num_tokens = np_embeds.shape[0]
flat = np_embeds.reshape(-1)
else:
raise ValueError("embeds must have shape (T, H) or (1, T, H).")
flat = np.ascontiguousarray(flat, dtype=np.float32)
embed_buffer = (ctypes.c_float * flat.size)(*flat)
rk_input = RKLLMInput()
rk_input.input_type = RKLLMInputType.RKLLM_INPUT_EMBED
embed_input = RKLLMEmbedInput()
embed_input.embed = embed_buffer
embed_input.n_tokens = num_tokens
rk_input._union_data.embed_input = embed_input
infer_params = RKLLMInferParam()
infer_params.mode = RKLLMInferMode.RKLLM_INFER_GET_LAST_HIDDEN_LAYER
infer_params.keep_history = 1 if keep_history else 0
infer_params.lora_params = None
infer_params.prompt_cache_params = None
done = threading.Event()
result_holder = {"hidden": None, "error": None}
def _capture_hidden(result_ptr, userdata_ptr, state_enum):
state = LLMCallState(state_enum)
if state == LLMCallState.RKLLM_RUN_ERROR:
result_holder["error"] = "RKLLM reported an error state."
done.set()
return 0
if not result_ptr:
result_holder["error"] = "Empty result pointer received."
done.set()
return 0
result = result_ptr.contents
if result.last_hidden_layer.hidden_states and result.last_hidden_layer.embd_size > 0:
hidden = np.ctypeslib.as_array(
result.last_hidden_layer.hidden_states,
shape=(1, result.last_hidden_layer.num_tokens, result.last_hidden_layer.embd_size),
).copy()
result_holder["hidden"] = hidden[-1].copy() if return_last_only else hidden
done.set()
return 1 # Pause further work; we already have the hidden states.
if state == LLMCallState.RKLLM_RUN_FINISH:
done.set()
return 0
previous_callback = self._user_callback
self._user_callback = _capture_hidden
try:
self.run(rk_input, infer_params)
if not done.wait(timeout):
raise TimeoutError("forward_embed timed out waiting for hidden states.")
finally:
self._user_callback = previous_callback
if result_holder["error"]:
raise RuntimeError(result_holder["error"])
if result_holder["hidden"] is None:
raise RuntimeError("forward_embed did not receive hidden states.")
try:
if not keep_history:
self.clear_kv_cache(True)
except Exception:
# Cache clearing best-effort; keep the forward result usable even if clearing fails.
pass
return result_holder["hidden"]
# --- Demo CLI ---
def _cli_parse_arguments() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Demo application showcasing rkllm_binding usage."
)
parser.add_argument(
"model",
help="Path to the .rkllm model file used for inference."
)
parser.add_argument(
"--lib",
default="./librkllmrt.so",
help="Path to librkllmrt.so. Defaults to ./librkllmrt.so."
)
# Core generation parameters
parser.add_argument("--max-context-len", type=int, default=512, help="Maximum context length.")
parser.add_argument("--max-new-tokens", type=int, default=128, help="Maximum number of new tokens to generate.")
parser.add_argument("--top-k", type=int, default=1, help="Top-K sampling parameter.")
parser.add_argument("--top-p", type=float, default=0.0, help="Top-P (nucleus) sampling parameter.")
parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature.")
parser.add_argument("--repeat-penalty", type=float, default=1.1, help="Penalty applied to repeated tokens.")
parser.add_argument("--n-keep", type=int, default=0, help="Number of tokens to keep when context slides.")
parser.add_argument("--mirostat", type=int, default=0, help="Enable Mirostat sampling (0 disables).")
parser.add_argument("--mirostat-tau", type=float, default=5.0, help="Mirostat tau parameter.")
parser.add_argument("--mirostat-eta", type=float, default=0.1, help="Mirostat eta parameter.")
parser.add_argument(
"--skip-special-token",
action="store_true",
help="Skip special tokens when generating output."
)
# Input management
parser.add_argument(
"--input-type",
choices=("prompt", "token", "multimodal"),
default="prompt",
help="Select prompt, raw token, or multimodal (image + prompt) input."
)
parser.add_argument("--prompt", help="Prompt text to send to the model.")
parser.add_argument("--prompt-file", help="Path to a UTF-8 text file containing the prompt.")
parser.add_argument(
"--token-ids",
type=int,
nargs="+",
help="Raw token IDs (space separated). Only valid when --input-type token."
)
parser.add_argument("--role", default="user", help="Role metadata for the input message (e.g., user/system).")
parser.add_argument(
"--enable-thinking",
action="store_true",
help="Enable thinking mode for supported models."
)
parser.add_argument("--image", help="Path to an image file used when --input-type multimodal.")
parser.add_argument("--vision-encoder", help="Path to the ONNX vision encoder model.")
parser.add_argument(
"--encoder-provider",
help="Comma separated ONNX Runtime providers (e.g., 'CPUExecutionProvider')."
)
parser.add_argument(
"--encoder-threads",
type=int,
help="Thread count hint for ONNX Runtime session."
)
parser.add_argument(
"--encoder-input-shape",
help="Override encoder input spatial size as HxW or H,W (e.g., 392x392)."
)
parser.add_argument(
"--norm",
choices=("imagenet", "divide_255", "divide_128_sub_1"),
default="imagenet",
help="Image normalization preset."
)
parser.add_argument(
"--norm-mean",
type=float,
nargs=3,
metavar=("R", "G", "B"),
help="Override normalization mean (RGB order)."
)
parser.add_argument(
"--norm-std",
type=float,
nargs=3,
metavar=("R", "G", "B"),
help="Override normalization std (RGB order)."
)
parser.add_argument(
"--image-background",
type=int,
nargs=3,
metavar=("R", "G", "B"),
default=(128, 128, 128),
help="Background color used when padding image to target size."
)
parser.add_argument("--img-start-token", help="Override image start token string passed to the model.")
parser.add_argument("--img-end-token", help="Override image end token string passed to the model.")
parser.add_argument("--img-content-token", help="Override image content token string passed to the model.")
# Inference options
parser.add_argument(
"--mode",
choices=("generate", "hidden", "logits"),
default="generate",
help="Inference mode: generate tokens, return last hidden layer, or logits."
)
parser.add_argument(
"--no-keep-history",
action="store_true",
help="Do not keep dialogue history on the device."
)
# Output options
parser.add_argument(
"--stream",
action="store_true",
default=True,
help="Stream tokens to stdout as they arrive from the callback."
)
parser.add_argument(
"--hide-stats",
action="store_true",
help="Suppress performance statistics after inference."
)
args = parser.parse_args()
if args.prompt and args.prompt_file:
parser.error("Arguments --prompt and --prompt-file cannot be used together.")
if args.input_type == "prompt":
if not args.prompt and not args.prompt_file:
parser.error("Provide --prompt or --prompt-file when --input-type is prompt.")
if args.token_ids:
parser.error("--token-ids is only valid when --input-type token.")
elif args.input_type == "token":
if not args.token_ids:
parser.error("--token-ids is required when --input-type token.")
if args.prompt or args.prompt_file:
parser.error("--prompt/--prompt-file cannot be combined with --input-type token.")
else: # multimodal
if args.token_ids:
parser.error("--token-ids cannot be used with --input-type multimodal.")
if not args.prompt and not args.prompt_file:
parser.error("Provide --prompt or --prompt-file when --input-type is multimodal.")
if not args.image:
parser.error("--image is required when --input-type multimodal.")
if not args.vision_encoder:
parser.error("--vision-encoder is required when --input-type multimodal.")
if args.image_background:
for component in args.image_background:
if component < 0 or component > 255:
parser.error("--image-background values must be in the range [0, 255].")
return args
def _load_prompt_from_args(args: argparse.Namespace) -> str:
if args.prompt:
return args.prompt
if args.prompt_file:
try:
with open(args.prompt_file, "r", encoding="utf-8") as fp:
return fp.read()
except OSError as exc:
raise RuntimeError(f"Failed to read prompt file '{args.prompt_file}': {exc}") from exc
raise RuntimeError("Prompt text is required but not provided.")
def _mode_to_enum(mode: str) -> int:
mapping = {
"generate": RKLLMInferMode.RKLLM_INFER_GENERATE,
"hidden": RKLLMInferMode.RKLLM_INFER_GET_LAST_HIDDEN_LAYER,
"logits": RKLLMInferMode.RKLLM_INFER_GET_LOGITS,
}
return mapping[mode]
def _parse_hw_string(value: str) -> Tuple[int, int]:
separators = ("x", "X", ",", " ")
token = value.strip()
for sep in separators:
if sep in token:
parts = [p for p in token.split(sep) if p]
break
else:
parts = [token]
if len(parts) != 2:
raise ValueError(f"Unable to parse height/width from '{value}'. Expected format like 392x392.")
try:
height = int(parts[0])
width = int(parts[1])
except ValueError as exc:
raise ValueError(f"Height/width must be integers, got '{value}'.") from exc
if height <= 0 or width <= 0:
raise ValueError("Height and width must be positive integers.")
return height, width
def _infer_hw_from_onnx_shape(shape: Sequence) -> Tuple[Optional[int], Optional[int]]:
if shape is None or len(shape) < 4:
return None, None
height = shape[-2]
width = shape[-1]
if isinstance(height, str) or height is None:
height = None
if isinstance(width, str) or width is None:
width = None
return height, width
def _parse_providers(provider_str: Optional[str]) -> Optional[list]:
if not provider_str:
return None
providers = [item.strip() for item in provider_str.split(",") if item.strip()]
return providers or None
def _load_vision_encoder_session(encoder_path: str, providers: Optional[list], threads: Optional[int]):
try:
import onnxruntime as ort
except ImportError as exc:
raise RuntimeError("onnxruntime is required for multimodal inference. Please install onnxruntime.") from exc
sess_options = ort.SessionOptions()
if threads and threads > 0:
sess_options.intra_op_num_threads = threads
try:
if providers:
session = ort.InferenceSession(encoder_path, sess_options=sess_options, providers=providers)
else:
session = ort.InferenceSession(encoder_path, sess_options=sess_options)
except Exception as exc:
raise RuntimeError(f"Failed to load vision encoder '{encoder_path}': {exc}") from exc
return session
def _letterbox_resize(image, target_hw: Tuple[int, int], background_color: Sequence[int]):
try:
import cv2
import numpy as np
except ImportError as exc:
raise RuntimeError("OpenCV (cv2) and numpy are required for multimodal preprocessing.") from exc
target_h, target_w = target_hw
if image.ndim != 3 or image.shape[2] != 3:
raise RuntimeError("Expected RGB image with 3 channels.")
src_h, src_w = image.shape[:2]
if src_h == 0 or src_w == 0:
raise RuntimeError("Loaded image has invalid dimensions.")
scale = min(target_w / src_w, target_h / src_h)
resized_w = max(1, int(round(src_w * scale)))
resized_h = max(1, int(round(src_h * scale)))
resized = cv2.resize(image, (resized_w, resized_h), interpolation=cv2.INTER_LINEAR)
canvas = np.full((target_h, target_w, 3), background_color, dtype=resized.dtype)
top = (target_h - resized_h) // 2
left = (target_w - resized_w) // 2
canvas[top:top + resized_h, left:left + resized_w] = resized
return canvas, resized_h, resized_w
def _normalize_image(image, method: str, mean: Optional[Sequence[float]], std: Optional[Sequence[float]]):
import numpy as np
img = image.astype(np.float32)
mean_arr = np.array(mean, dtype=np.float32) if mean else None
std_arr = np.array(std, dtype=np.float32) if std else None
if method == "imagenet":
img = img / 255.0
if mean_arr is None:
mean_arr = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
if std_arr is None:
std_arr = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)
img = (img - mean_arr) / std_arr
elif method == "divide_255":
img = img / 255.0
if mean_arr is not None:
img = img - mean_arr
if std_arr is not None:
img = img / std_arr
elif method == "divide_128_sub_1":
img = img / 128.0 - 1.0
if mean_arr is not None:
img = img - mean_arr
if std_arr is not None:
img = img / std_arr
else:
raise RuntimeError(f"Unsupported normalization method '{method}'.")
return img
def _encode_image_to_embedding(
session,
image_path: str,
input_name: str,
output_name: str,
target_hw: Tuple[int, int],
background_color: Sequence[int],
norm_method: str,
norm_mean: Optional[Sequence[float]],
norm_std: Optional[Sequence[float]]
):
try:
import cv2
import numpy as np
except ImportError as exc:
raise RuntimeError("OpenCV (cv2) and numpy are required for multimodal preprocessing.") from exc
image = cv2.imread(image_path, cv2.IMREAD_COLOR)
if image is None:
raise RuntimeError(f"Failed to read image from '{image_path}'.")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
padded, resized_h, resized_w = _letterbox_resize(image, target_hw, background_color)
normalized = _normalize_image(padded, norm_method, norm_mean, norm_std)
tensor = np.transpose(normalized, (2, 0, 1)) # HWC -> CHW
tensor = np.expand_dims(tensor, axis=0) # Add batch dimension
tensor = np.ascontiguousarray(tensor, dtype=np.float32)
try:
output_list = session.run([output_name], {input_name: tensor})
except Exception as exc:
raise RuntimeError(f"Vision encoder inference failed: {exc}") from exc
if not output_list:
raise RuntimeError("Vision encoder returned no outputs.")
embedding = output_list[0]
if embedding.ndim == 3:
if embedding.shape[0] != 1:
raise RuntimeError("Vision encoder output batch dimension must be 1 for a single image.")
n_tokens = embedding.shape[1]
elif embedding.ndim == 2:
n_tokens = embedding.shape[0]
else:
raise RuntimeError(f"Unsupported vision encoder output shape {embedding.shape}.")
flat_embedding = embedding.reshape(-1).astype(np.float32, copy=False)
flat_embedding = np.ascontiguousarray(flat_embedding)
return flat_embedding, n_tokens, target_hw
if __name__ == "__main__":
import os
os.environ["RKLLM_LOG_LEVEL"] = "1"
args = _cli_parse_arguments()
prompt_text = None
if args.input_type == "prompt":
prompt_text = _load_prompt_from_args(args)
token_id_array = None
token_input_struct = None
generated_chunks = []
perf_snapshot = {
"prefill_tokens": 0,
"prefill_time_ms": 0.0,
"generate_tokens": 0,
"generate_time_ms": 0.0,
"memory_usage_mb": 0.0,
}
def demo_callback(result_ptr, userdata_ptr, state_enum):
state = LLMCallState(state_enum)
result = result_ptr.contents
current_text = ""
if result.text:
current_text = result.text.decode("utf-8", errors="ignore")
generated_chunks.append(current_text)
if args.stream and current_text:
print(current_text, end="", flush=True)
perf_snapshot.update(
prefill_tokens=result.perf.prefill_tokens,
prefill_time_ms=result.perf.prefill_time_ms,
generate_tokens=result.perf.generate_tokens,
generate_time_ms=result.perf.generate_time_ms,
memory_usage_mb=result.perf.memory_usage_mb,
)
if state == LLMCallState.RKLLM_RUN_ERROR:
print("\n[Callback] 推理过程中出现错误。")
return 0
try:
with RKLLMRuntime(library_path=args.lib) as rk_llm:
params = rk_llm.create_default_param()
params.model_path = os.path.abspath(args.model).encode("utf-8")
params.max_context_len = args.max_context_len
params.max_new_tokens = args.max_new_tokens
params.top_k = args.top_k
params.top_p = float(args.top_p)
params.temperature = float(args.temperature)
params.repeat_penalty = float(args.repeat_penalty)
params.n_keep = args.n_keep
params.mirostat = args.mirostat
params.mirostat_tau = float(args.mirostat_tau)
params.mirostat_eta = float(args.mirostat_eta)
params.skip_special_token = bool(args.skip_special_token)
params.is_async = False
rk_llm.init(params, demo_callback)
rk_input = RKLLMInput()
rk_input.role = args.role.encode("utf-8")
rk_input.enable_thinking = bool(args.enable_thinking)
if args.input_type == "prompt":
rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT
rk_input._union_data.prompt_input = prompt_text.encode("utf-8")
else:
rk_input.input_type = RKLLMInputType.RKLLM_INPUT_TOKEN
token_id_array = (ctypes.c_int32 * len(args.token_ids))(*args.token_ids)
token_input_struct = RKLLMTokenInput()
token_input_struct.input_ids = token_id_array
token_input_struct.n_tokens = len(args.token_ids)
rk_input._union_data.token_input = token_input_struct
infer_params = RKLLMInferParam()
infer_params.mode = _mode_to_enum(args.mode)
infer_params.keep_history = 0 if args.no_keep_history else 1
infer_params.lora_params = None
infer_params.prompt_cache_params = None
if args.stream:
print("=== Streaming Output ===")
rk_llm.run(rk_input, infer_params)
except OSError as exc:
print(f"无法加载 RKLLM 运行时库:{exc}")
except RuntimeError as exc:
print(f"推理失败:{exc}")
except Exception as exc:
print(f"发生未预期的错误:{exc}")
else:
if args.stream:
print() # Ensure newline after streaming output
final_text = "".join(generated_chunks)
if final_text:
print("=== 生成结果 ===")
print(final_text)
else:
print("未收到生成文本。")
if not args.hide_stats:
print("=== 性能统计 ===")
print(
f"预填充: {perf_snapshot['prefill_tokens']} tokens / {perf_snapshot['prefill_time_ms']:.2f} ms"
)
print(
f"生成: {perf_snapshot['generate_tokens']} tokens / {perf_snapshot['generate_time_ms']:.2f} ms"
)
print(f"最大常驻内存: {perf_snapshot['memory_usage_mb']:.2f} MB")