| import argparse |
| import ctypes |
| import enum |
| import os |
| import threading |
| from typing import Optional, Sequence, Tuple |
|
|
| import numpy as np |
|
|
| |
# Single-bit CPU-core flags (bit i set <=> CPU i).
# NOTE(review): presumably OR-ed together into RKLLMExtendParam.enabled_cpus_mask
# -- confirm against the RKLLM documentation.
CPU0 = 0x01
CPU1 = 0x02
CPU2 = 0x04
CPU3 = 0x08
CPU4 = 0x10
CPU5 = 0x20
CPU6 = 0x40
CPU7 = 0x80
|
|
| |
class LLMCallState(enum.IntEnum):
    """State code delivered as the third argument of the result callback."""

    # Values must match the native enum; do not renumber.
    RKLLM_RUN_NORMAL = 0
    RKLLM_RUN_WAITING = 1
    RKLLM_RUN_FINISH = 2
    RKLLM_RUN_ERROR = 3
|
|
class RKLLMInputType(enum.IntEnum):
    """Tag selecting which member of the RKLLMInput union is populated."""

    # Values must match the native enum; do not renumber.
    RKLLM_INPUT_PROMPT = 0
    RKLLM_INPUT_TOKEN = 1
    RKLLM_INPUT_EMBED = 2
    RKLLM_INPUT_MULTIMODAL = 3
|
|
class RKLLMInferMode(enum.IntEnum):
    """Inference mode stored in RKLLMInferParam.mode."""

    # Values must match the native enum; do not renumber.
    RKLLM_INFER_GENERATE = 0
    RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1
    RKLLM_INFER_GET_LOGITS = 2
|
|
| |
class RKLLMExtendParam(ctypes.Structure):
    """Extended initialisation options embedded in RKLLMParam.extend_param.

    The ``_fields_`` order and types define the C memory layout and must match
    the native declaration exactly.
    """

    _fields_ = [
        ("base_domain_id", ctypes.c_int32),
        ("embed_flash", ctypes.c_int8),
        ("enabled_cpus_num", ctypes.c_int8),
        ("enabled_cpus_mask", ctypes.c_uint32),  # presumably built from the CPUn flags -- confirm
        ("n_batch", ctypes.c_uint8),
        ("use_cross_attn", ctypes.c_int8),
        ("reserved", ctypes.c_uint8 * 104),      # padding kept for ABI compatibility
    ]

    # Static type hints mirroring _fields_ (no runtime effect).
    base_domain_id: ctypes.c_int32
    embed_flash: ctypes.c_int8
    enabled_cpus_num: ctypes.c_int8
    enabled_cpus_mask: ctypes.c_uint32
    n_batch: ctypes.c_uint8
    use_cross_attn: ctypes.c_int8
    reserved: ctypes.c_uint8 * 104
|
|
class RKLLMParam(ctypes.Structure):
    """Model initialisation parameters, passed by pointer to ``rkllm_init``.

    Obtain a populated instance from ``rkllm_createDefaultParam`` and override
    individual fields. ``_fields_`` order must match the C declaration.
    """

    _fields_ = [
        # Model location and context/generation limits.
        ("model_path", ctypes.c_char_p),
        ("max_context_len", ctypes.c_int32),
        ("max_new_tokens", ctypes.c_int32),
        # Sampling controls.
        ("top_k", ctypes.c_int32),
        ("n_keep", ctypes.c_int32),
        ("top_p", ctypes.c_float),
        ("temperature", ctypes.c_float),
        ("repeat_penalty", ctypes.c_float),
        ("frequency_penalty", ctypes.c_float),
        ("presence_penalty", ctypes.c_float),
        ("mirostat", ctypes.c_int32),
        ("mirostat_tau", ctypes.c_float),
        ("mirostat_eta", ctypes.c_float),
        # Output / execution flags.
        ("skip_special_token", ctypes.c_bool),
        ("is_async", ctypes.c_bool),
        # Image placeholder token strings (multimodal models).
        ("img_start", ctypes.c_char_p),
        ("img_end", ctypes.c_char_p),
        ("img_content", ctypes.c_char_p),
        # Nested extended options.
        ("extend_param", RKLLMExtendParam),
    ]

    # Static type hints mirroring _fields_ (no runtime effect).
    model_path: ctypes.c_char_p
    max_context_len: ctypes.c_int32
    max_new_tokens: ctypes.c_int32
    top_k: ctypes.c_int32
    n_keep: ctypes.c_int32
    top_p: ctypes.c_float
    temperature: ctypes.c_float
    repeat_penalty: ctypes.c_float
    frequency_penalty: ctypes.c_float
    presence_penalty: ctypes.c_float
    mirostat: ctypes.c_int32
    mirostat_tau: ctypes.c_float
    mirostat_eta: ctypes.c_float
    skip_special_token: ctypes.c_bool
    is_async: ctypes.c_bool
    img_start: ctypes.c_char_p
    img_end: ctypes.c_char_p
    img_content: ctypes.c_char_p
    extend_param: RKLLMExtendParam
|
|
class RKLLMLoraAdapter(ctypes.Structure):
    """LoRA adapter descriptor passed by pointer to ``rkllm_load_lora``."""

    _fields_ = [
        ("lora_adapter_path", ctypes.c_char_p),
        ("lora_adapter_name", ctypes.c_char_p),
        ("scale", ctypes.c_float),
    ]

    # Static type hints mirroring _fields_ (no runtime effect).
    lora_adapter_path: ctypes.c_char_p
    lora_adapter_name: ctypes.c_char_p
    scale: ctypes.c_float
|
|
class RKLLMEmbedInput(ctypes.Structure):
    """Embedding input: a float buffer plus the number of tokens it holds."""

    _fields_ = [
        ("embed", ctypes.POINTER(ctypes.c_float)),
        ("n_tokens", ctypes.c_size_t),
    ]

    # Static type hints mirroring _fields_ (no runtime effect).
    embed: ctypes.POINTER(ctypes.c_float)
    n_tokens: ctypes.c_size_t
|
|
class RKLLMTokenInput(ctypes.Structure):
    """Token input: an int32 token-id buffer plus its length."""

    _fields_ = [
        ("input_ids", ctypes.POINTER(ctypes.c_int32)),
        ("n_tokens", ctypes.c_size_t),
    ]

    # Static type hints mirroring _fields_ (no runtime effect).
    input_ids: ctypes.POINTER(ctypes.c_int32)
    n_tokens: ctypes.c_size_t
|
|
class RKLLMMultiModelInput(ctypes.Structure):
    """Multimodal input: prompt text plus precomputed image embeddings."""

    _fields_ = [
        ("prompt", ctypes.c_char_p),
        ("image_embed", ctypes.POINTER(ctypes.c_float)),
        ("n_image_tokens", ctypes.c_size_t),
        ("n_image", ctypes.c_size_t),
        ("image_width", ctypes.c_size_t),
        ("image_height", ctypes.c_size_t),
    ]

    # Static type hints mirroring _fields_ (no runtime effect).
    prompt: ctypes.c_char_p
    image_embed: ctypes.POINTER(ctypes.c_float)
    n_image_tokens: ctypes.c_size_t
    n_image: ctypes.c_size_t
    image_width: ctypes.c_size_t
    image_height: ctypes.c_size_t
|
|
class RKLLMCrossAttnParam(ctypes.Structure):
    """
    Cross-attention parameters.

    Used when the decoder performs cross-attention. Supplies the encoder output
    (key/value caches), position indices and the attention mask.

    - encoder_k_cache must be stored in contiguous memory with layout:
      [num_layers][num_tokens][num_kv_heads][head_dim]
    - encoder_v_cache must be stored in contiguous memory with layout:
      [num_layers][num_kv_heads][head_dim][num_tokens]
    """

    _fields_ = [
        ("encoder_k_cache", ctypes.POINTER(ctypes.c_float)),
        ("encoder_v_cache", ctypes.POINTER(ctypes.c_float)),
        ("encoder_mask", ctypes.POINTER(ctypes.c_float)),
        ("encoder_pos", ctypes.POINTER(ctypes.c_int32)),
        ("num_tokens", ctypes.c_int),
    ]

    # Static type hints mirroring _fields_ (no runtime effect).
    encoder_k_cache: ctypes.POINTER(ctypes.c_float)
    encoder_v_cache: ctypes.POINTER(ctypes.c_float)
    encoder_mask: ctypes.POINTER(ctypes.c_float)
    encoder_pos: ctypes.POINTER(ctypes.c_int32)
    num_tokens: ctypes.c_int
|
|
class RKLLMPerfStat(ctypes.Structure):
    """
    Performance statistics.

    Holds timing, token counts and memory usage for the prefill and
    generation stages of one inference.
    """

    _fields_ = [
        ("prefill_time_ms", ctypes.c_float),
        ("prefill_tokens", ctypes.c_int),
        ("generate_time_ms", ctypes.c_float),
        ("generate_tokens", ctypes.c_int),
        ("memory_usage_mb", ctypes.c_float),
    ]

    # Static type hints mirroring _fields_ (no runtime effect).
    prefill_time_ms: ctypes.c_float
    prefill_tokens: ctypes.c_int
    generate_time_ms: ctypes.c_float
    generate_tokens: ctypes.c_int
    memory_usage_mb: ctypes.c_float
|
|
class _RKLLMInputUnion(ctypes.Union):
    """Payload union of RKLLMInput; the active member is chosen by RKLLMInput.input_type."""

    _fields_ = [
        ("prompt_input", ctypes.c_char_p),            # RKLLM_INPUT_PROMPT
        ("embed_input", RKLLMEmbedInput),             # RKLLM_INPUT_EMBED
        ("token_input", RKLLMTokenInput),             # RKLLM_INPUT_TOKEN
        ("multimodal_input", RKLLMMultiModelInput),   # RKLLM_INPUT_MULTIMODAL
    ]

    # Static type hints mirroring _fields_ (no runtime effect).
    prompt_input: ctypes.c_char_p
    embed_input: RKLLMEmbedInput
    token_input: RKLLMTokenInput
    multimodal_input: RKLLMMultiModelInput
|
|
def _tagged_union_member(member, expected_type, error_message):
    """Build a property exposing ``self._union_data.<member>`` only for one input type.

    Reading or assigning the property while ``self.input_type`` differs from
    ``expected_type`` raises ``AttributeError(error_message)``.
    """

    def _read(self):
        if self.input_type != expected_type:
            raise AttributeError(error_message)
        return getattr(self._union_data, member)

    def _write(self, value):
        if self.input_type != expected_type:
            raise AttributeError(error_message)
        setattr(self._union_data, member, value)

    return property(_read, _write)


class RKLLMInput(ctypes.Structure):
    """
    LLM input structure.

    Represents the different kinds of LLM input through a tagged union:
    ``input_type`` selects which member of ``_union_data`` is valid. Set
    ``input_type`` before using one of the typed accessor properties below.
    """

    _fields_ = [
        ("role", ctypes.c_char_p),
        ("enable_thinking", ctypes.c_bool),
        ("input_type", ctypes.c_int),
        ("_union_data", _RKLLMInputUnion),
    ]

    # Static type hints mirroring _fields_ (no runtime effect).
    role: ctypes.c_char_p
    enable_thinking: ctypes.c_bool
    input_type: ctypes.c_int
    _union_data: _RKLLMInputUnion

    # Checked accessors for the union payload.
    prompt_input = _tagged_union_member(
        "prompt_input", RKLLMInputType.RKLLM_INPUT_PROMPT, "Not a prompt input")
    embed_input = _tagged_union_member(
        "embed_input", RKLLMInputType.RKLLM_INPUT_EMBED, "Not an embed input")
    token_input = _tagged_union_member(
        "token_input", RKLLMInputType.RKLLM_INPUT_TOKEN, "Not a token input")
    multimodal_input = _tagged_union_member(
        "multimodal_input", RKLLMInputType.RKLLM_INPUT_MULTIMODAL, "Not a multimodal input")
|
|
class RKLLMLoraParam(ctypes.Structure):
    """Per-inference LoRA selection, referenced from RKLLMInferParam.lora_params."""

    _fields_ = [
        ("lora_adapter_name", ctypes.c_char_p),
    ]

    # Static type hint mirroring _fields_ (no runtime effect).
    lora_adapter_name: ctypes.c_char_p
|
|
class RKLLMPromptCacheParam(ctypes.Structure):
    """Prompt-cache options, referenced from RKLLMInferParam.prompt_cache_params."""

    _fields_ = [
        ("save_prompt_cache", ctypes.c_int),
        ("prompt_cache_path", ctypes.c_char_p),
    ]

    # Static type hints mirroring _fields_ (no runtime effect).
    save_prompt_cache: ctypes.c_int
    prompt_cache_path: ctypes.c_char_p
|
|
class RKLLMInferParam(ctypes.Structure):
    """Per-call inference options passed by pointer to ``rkllm_run`` / ``rkllm_run_async``.

    ``mode`` holds an RKLLMInferMode value; the two pointer fields may be NULL.
    """

    _fields_ = [
        ("mode", ctypes.c_int),
        ("lora_params", ctypes.POINTER(RKLLMLoraParam)),
        ("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam)),
        ("keep_history", ctypes.c_int),
    ]

    # Static type hints mirroring _fields_ (no runtime effect).
    mode: ctypes.c_int
    lora_params: ctypes.POINTER(RKLLMLoraParam)
    prompt_cache_params: ctypes.POINTER(RKLLMPromptCacheParam)
    keep_history: ctypes.c_int
|
|
class RKLLMResultLastHiddenLayer(ctypes.Structure):
    """Last-hidden-layer output: ``num_tokens * embd_size`` floats behind ``hidden_states``."""

    _fields_ = [
        ("hidden_states", ctypes.POINTER(ctypes.c_float)),
        ("embd_size", ctypes.c_int),
        ("num_tokens", ctypes.c_int),
    ]

    # Static type hints mirroring _fields_ (no runtime effect).
    hidden_states: ctypes.POINTER(ctypes.c_float)
    embd_size: ctypes.c_int
    num_tokens: ctypes.c_int
|
|
class RKLLMResultLogits(ctypes.Structure):
    """Logits output: a float buffer plus vocabulary size and token count."""

    _fields_ = [
        ("logits", ctypes.POINTER(ctypes.c_float)),
        ("vocab_size", ctypes.c_int),
        ("num_tokens", ctypes.c_int),
    ]

    # Static type hints mirroring _fields_ (no runtime effect).
    logits: ctypes.POINTER(ctypes.c_float)
    vocab_size: ctypes.c_int
    num_tokens: ctypes.c_int
|
|
class RKLLMResult(ctypes.Structure):
    """
    LLM inference result.

    Delivered (by pointer) to the result callback. Carries the generated text,
    the token ID, the last hidden layer, logits and performance statistics.
    """

    _fields_ = [
        ("text", ctypes.c_char_p),
        ("token_id", ctypes.c_int32),
        ("last_hidden_layer", RKLLMResultLastHiddenLayer),
        ("logits", RKLLMResultLogits),
        ("perf", RKLLMPerfStat),
    ]

    # Static type hints mirroring _fields_ (no runtime effect).
    text: ctypes.c_char_p
    token_id: ctypes.c_int32
    last_hidden_layer: RKLLMResultLastHiddenLayer
    logits: RKLLMResultLogits
    perf: RKLLMPerfStat
|
|
| |
# Opaque native LLM handle; rkllm_init writes it through a pointer.
LLMHandle = ctypes.c_void_p

# Result callback type: int callback(RKLLMResult* result, void* userdata, int state)
#
# Parameters:
#   - result:   pointer to the LLM result
#   - userdata: user data pointer supplied to the run call
#   - state:    LLM call state (e.g. finished, error); see LLMCallState
#
# Return value:
#   - 0: continue inference normally
#   - 1: pause inference. Return 1 when the caller wants to modify or intervene
#        in the result (e.g. edit the output, inject a new prompt); later, call
#        rkllm_run with the updated content to resume.
LLMResultCallback = ctypes.CFUNCTYPE(
    ctypes.c_int,                 # return value
    ctypes.POINTER(RKLLMResult),  # result
    ctypes.c_void_p,              # userdata
    ctypes.c_int,                 # state
)
|
|
class RKLLMRuntime:
    """ctypes front-end for the RKLLM runtime library (``librkllmrt.so``).

    Most methods forward to the native function of the same name and raise
    ``RuntimeError`` when it returns a non-zero status code. The instance owns
    one native LLM handle, created by :meth:`init` and released by
    :meth:`destroy` (which ``__exit__`` and ``__del__`` also call).
    """

    def __init__(self, library_path="./librkllmrt.so"):
        """Load the shared library and declare the native function signatures.

        :param library_path: Path handed to ``ctypes.CDLL``.
        :raises OSError: if the library cannot be loaded.
        :raises AttributeError: if the library lacks one of the expected symbols.
        """
        try:
            self.lib = ctypes.CDLL(library_path)
        except OSError as e:
            raise OSError(f"Failed to load RKLLM library from {library_path}. "
                          f"Ensure it's in your LD_LIBRARY_PATH or provide the full path. Error: {e}") from e
        # Instance state is set before _setup_functions() so destroy()/__del__
        # still behave if a symbol lookup there fails.
        self.llm_handle = LLMHandle()
        # The CFUNCTYPE object must stay referenced while the native library may
        # call it; dropping it would leave a dangling function pointer.
        self._c_callback = None
        self._user_callback = None
        # Userdata passed to run()/run_async() and the py_object box whose
        # address the native side receives (see _userdata_pointer).
        self._userdata_ref = None
        self._userdata_box = None
        self._setup_functions()

    def _setup_functions(self):
        """Declare ``restype``/``argtypes`` for every native entry point used here."""
        lib = self.lib

        lib.rkllm_createDefaultParam.restype = RKLLMParam
        lib.rkllm_createDefaultParam.argtypes = []

        run_argtypes = [
            LLMHandle,
            ctypes.POINTER(RKLLMInput),
            ctypes.POINTER(RKLLMInferParam),
            ctypes.c_void_p,  # userdata
        ]
        # Every remaining function returns an int status code.
        status_functions = (
            ("rkllm_init", [ctypes.POINTER(LLMHandle), ctypes.POINTER(RKLLMParam), LLMResultCallback]),
            ("rkllm_load_lora", [LLMHandle, ctypes.POINTER(RKLLMLoraAdapter)]),
            ("rkllm_load_prompt_cache", [LLMHandle, ctypes.c_char_p]),
            ("rkllm_release_prompt_cache", [LLMHandle]),
            ("rkllm_destroy", [LLMHandle]),
            ("rkllm_run", run_argtypes),
            ("rkllm_run_async", run_argtypes),
            ("rkllm_abort", [LLMHandle]),
            ("rkllm_is_running", [LLMHandle]),
            ("rkllm_clear_kv_cache", [LLMHandle, ctypes.c_int,
                                      ctypes.POINTER(ctypes.c_int), ctypes.POINTER(ctypes.c_int)]),
            ("rkllm_get_kv_cache_size", [LLMHandle, ctypes.POINTER(ctypes.c_int)]),
            ("rkllm_set_chat_template", [LLMHandle, ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p]),
            ("rkllm_set_function_tools", [LLMHandle, ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p]),
            ("rkllm_set_cross_attn_params", [LLMHandle, ctypes.POINTER(RKLLMCrossAttnParam)]),
        )
        for name, argtypes in status_functions:
            fn = getattr(lib, name)
            fn.restype = ctypes.c_int
            fn.argtypes = list(argtypes)

    @staticmethod
    def _encode_or_empty(text):
        """UTF-8 encode ``text``; falsy values (None, "") become ``b""``."""
        return text.encode('utf-8') if text else b""

    def _userdata_pointer(self, userdata):
        """Return a ``c_void_p`` addressing a ``ctypes.py_object`` box around ``userdata``.

        Returns None (NULL) when ``userdata`` is None. The object and its box are
        stored on ``self`` so the address remains valid after the Python call
        returns; run_async presumably delivers callbacks after it has returned.
        The references are replaced by the next call that passes userdata.
        """
        if userdata is None:
            return None
        box = ctypes.py_object(userdata)
        self._userdata_ref = userdata
        self._userdata_box = box
        return ctypes.cast(ctypes.pointer(box), ctypes.c_void_p)

    def create_default_param(self) -> RKLLMParam:
        """Creates a default RKLLMParam structure."""
        return self.lib.rkllm_createDefaultParam()

    def init(self, param: RKLLMParam, callback_func) -> int:
        """
        Initializes the LLM.

        :param param: RKLLMParam structure.
        :param callback_func: A Python callable with the signature::

                def my_callback(result_ptr, userdata_ptr, state_enum):
                    result = result_ptr.contents  # RKLLMResult
                    # state = LLMCallState(state_enum)
                    return 0  # 0 = continue, 1 = pause

            ``userdata_ptr`` is NULL, or the address of a ``ctypes.py_object``
            wrapping the ``userdata`` given to run()/run_async(); recover it with
            ``ctypes.cast(userdata_ptr, ctypes.POINTER(ctypes.py_object)).contents.value``.
            Returning ``None`` is treated as 0.
        :return: 0 for success.
        :raises ValueError: if ``callback_func`` is not callable.
        :raises RuntimeError: if ``rkllm_init`` returns a non-zero status.
        """
        if not callable(callback_func):
            raise ValueError("callback_func must be a callable Python function.")

        self._user_callback = callback_func
        # The native side always calls the trampoline, which forwards to the
        # current handler; this lets forward_embed() swap handlers per call.
        self._c_callback = LLMResultCallback(self._callback_trampoline)

        ret = self.lib.rkllm_init(ctypes.byref(self.llm_handle), ctypes.byref(param), self._c_callback)
        if ret != 0:
            raise RuntimeError(f"rkllm_init failed with error code {ret}")
        return ret

    def load_lora(self, lora_adapter: RKLLMLoraAdapter) -> int:
        """Loads a Lora adapter."""
        ret = self.lib.rkllm_load_lora(self.llm_handle, ctypes.byref(lora_adapter))
        if ret != 0:
            raise RuntimeError(f"rkllm_load_lora failed with error code {ret}")
        return ret

    def load_prompt_cache(self, prompt_cache_path: str) -> int:
        """Loads a prompt cache from a file (path is UTF-8 encoded)."""
        c_path = prompt_cache_path.encode('utf-8')
        ret = self.lib.rkllm_load_prompt_cache(self.llm_handle, c_path)
        if ret != 0:
            raise RuntimeError(f"rkllm_load_prompt_cache failed for {prompt_cache_path} with error code {ret}")
        return ret

    def release_prompt_cache(self) -> int:
        """Releases the prompt cache from memory."""
        ret = self.lib.rkllm_release_prompt_cache(self.llm_handle)
        if ret != 0:
            raise RuntimeError(f"rkllm_release_prompt_cache failed with error code {ret}")
        return ret

    def destroy(self) -> int:
        """Destroys the LLM instance and releases resources.

        Safe to call repeatedly and on a partially constructed instance. A
        native failure is reported with a printed warning rather than raised.

        :return: the native status code, or 0 when there was no handle.
        """
        handle = getattr(self, "llm_handle", None)
        if not handle or not handle.value:
            return 0
        ret = self.lib.rkllm_destroy(handle)
        self.llm_handle = LLMHandle()
        if ret != 0:
            print(f"Warning: rkllm_destroy failed with error code {ret}")
        return ret

    def run(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
        """Runs an LLM inference task synchronously.

        :param userdata: Optional Python object made available to the callback
            through ``userdata_ptr`` (see :meth:`init`).
        :raises RuntimeError: if ``rkllm_run`` returns a non-zero status.
        """
        c_userdata = self._userdata_pointer(userdata)
        ret = self.lib.rkllm_run(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata)
        if ret != 0:
            raise RuntimeError(f"rkllm_run failed with error code {ret}")
        return ret

    def run_async(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
        """Runs an LLM inference task asynchronously.

        ``rkllm_input``/``rkllm_infer_params`` are passed by reference; keep them
        alive until the task finishes. ``userdata`` is pinned on ``self`` (see
        :meth:`_userdata_pointer`).

        :raises RuntimeError: if ``rkllm_run_async`` returns a non-zero status.
        """
        c_userdata = self._userdata_pointer(userdata)
        ret = self.lib.rkllm_run_async(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata)
        if ret != 0:
            raise RuntimeError(f"rkllm_run_async failed with error code {ret}")
        return ret

    def abort(self) -> int:
        """Aborts an ongoing LLM task."""
        ret = self.lib.rkllm_abort(self.llm_handle)
        if ret != 0:
            raise RuntimeError(f"rkllm_abort failed with error code {ret}")
        return ret

    def is_running(self) -> bool:
        """Checks if an LLM task is currently running. Returns True if running."""
        # The native call returns 0 while a task is running and non-zero
        # otherwise. NOTE(review): inverted convention -- confirm against rkllm.h.
        return self.lib.rkllm_is_running(self.llm_handle) == 0

    def clear_kv_cache(self, keep_system_prompt: bool, start_pos: list = None, end_pos: list = None) -> int:
        """
        Clear part or all of the key-value cache.

        Parameters:
        - keep_system_prompt: whether to keep the system prompt in the cache
          (True keeps it, False clears it). Ignored when a specific range
          [start_pos, end_pos) is given.
        - start_pos: per-batch start positions (inclusive) of the ranges to clear.
        - end_pos: per-batch end positions (exclusive) of the ranges to clear.
          If both start_pos and end_pos are None, the whole cache is cleared and
          keep_system_prompt applies. If start_pos[i] < end_pos[i], only that
          range is cleared and keep_system_prompt is ignored.

        Note: start_pos/end_pos only take effect when keep_history == 0 and
        generation has been paused by returning 1 from the callback.

        Returns: 0 when the cache was cleared.
        Raises:
        - ValueError: if only one of start_pos/end_pos is given or their lengths differ.
        - RuntimeError: if the native call fails.
        """
        if (start_pos is None) != (end_pos is None):
            raise ValueError("start_pos and end_pos must be given together (or both omitted).")

        c_start_pos = None
        c_end_pos = None
        if start_pos is not None:
            if len(start_pos) != len(end_pos):
                raise ValueError("start_pos和end_pos数组长度必须相同")
            c_start_pos = (ctypes.c_int * len(start_pos))(*start_pos)
            c_end_pos = (ctypes.c_int * len(end_pos))(*end_pos)

        ret = self.lib.rkllm_clear_kv_cache(
            self.llm_handle,
            ctypes.c_int(1 if keep_system_prompt else 0),
            c_start_pos,
            c_end_pos
        )
        if ret != 0:
            raise RuntimeError(f"rkllm_clear_kv_cache失败,错误代码:{ret}")
        return ret

    def set_chat_template(self, system_prompt: str, prompt_prefix: str, prompt_postfix: str) -> int:
        """Sets the chat template for the LLM (None/empty strings are sent as "")."""
        ret = self.lib.rkllm_set_chat_template(
            self.llm_handle,
            self._encode_or_empty(system_prompt),
            self._encode_or_empty(prompt_prefix),
            self._encode_or_empty(prompt_postfix),
        )
        if ret != 0:
            raise RuntimeError(f"rkllm_set_chat_template failed with error code {ret}")
        return ret

    def get_kv_cache_size(self, n_batch: int) -> list:
        """
        Get the current size of the key-value cache for this LLM handle.

        Returns the total number of positions currently stored in the model's
        KV cache, one entry per batch.

        Parameters:
        - n_batch: number of batches; sizes the buffer the native side fills.
          NOTE(review): presumably must be at least the n_batch the model was
          initialised with, otherwise the native write may overrun -- confirm.

        Returns:
        - list: cache size of each batch.

        Raises:
        - ValueError: if n_batch < 1.
        - RuntimeError: if the native call fails.
        """
        if n_batch < 1:
            raise ValueError(f"n_batch must be a positive integer, got {n_batch}")
        cache_sizes = (ctypes.c_int * n_batch)()
        ret = self.lib.rkllm_get_kv_cache_size(self.llm_handle, cache_sizes)
        if ret != 0:
            raise RuntimeError(f"rkllm_get_kv_cache_size失败,错误代码:{ret}")
        return list(cache_sizes)

    def set_function_tools(self, system_prompt: str, tools: str, tool_response_str: str) -> int:
        """
        Configure function calling: system prompt, tool definitions and tool-response tag.

        Parameters:
        - system_prompt: system prompt defining the model's context or behavior.
        - tools: JSON string defining the available functions, including their
          names, descriptions and parameters.
        - tool_response_str: unique tag identifying function-call results in the
          conversation, letting the tokenizer tell tool output apart from
          normal dialogue turns.

        Returns: 0 on success. Raises RuntimeError on a non-zero status.
        """
        ret = self.lib.rkllm_set_function_tools(
            self.llm_handle,
            self._encode_or_empty(system_prompt),
            self._encode_or_empty(tools),
            self._encode_or_empty(tool_response_str),
        )
        if ret != 0:
            raise RuntimeError(f"rkllm_set_function_tools失败,错误代码:{ret}")
        return ret

    def set_cross_attn_params(self, cross_attn_params: RKLLMCrossAttnParam) -> int:
        """
        Set cross-attention parameters for the LLM decoder.

        Parameters:
        - cross_attn_params: structure holding the encoder-side inputs used for
          cross-attention (see RKLLMCrossAttnParam).

        Returns: 0 on success. Raises RuntimeError on a non-zero status.
        """
        ret = self.lib.rkllm_set_cross_attn_params(self.llm_handle, ctypes.byref(cross_attn_params))
        if ret != 0:
            raise RuntimeError(f"rkllm_set_cross_attn_params失败,错误代码:{ret}")
        return ret

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.destroy()

    def __del__(self):
        # destroy() tolerates a partially constructed instance (e.g. CDLL failed).
        self.destroy()

    def _callback_trampoline(self, result_ptr, userdata_ptr, state_enum):
        """
        Bridge callback that forwards to the currently active Python handler.
        This keeps the C callback pointer stable while allowing per-call overrides.

        The C callback returns ``int``; a handler returning ``None`` (no explicit
        return) is mapped to 0 because ctypes cannot convert ``None`` to an int
        result. Exceptions are printed and also mapped to 0.
        """
        handler = self._user_callback
        if handler is None:
            return 0
        try:
            ret = handler(result_ptr, userdata_ptr, state_enum)
            return 0 if ret is None else int(ret)
        except Exception as exc:
            print(f"[rkllm_binding] Callback raised an exception: {exc}")
            return 0

    def forward_embed(
        self,
        embeds: np.ndarray,
        *,
        keep_history: bool = False,
        timeout: Optional[float] = None,
        return_last_only: bool = False,
    ) -> np.ndarray:
        """
        Run a single forward pass with embedding input and return the last hidden layer.

        Requires init() to have been called. The handler installed by init() is
        temporarily replaced for the duration of the call.

        Args:
            embeds: Float32-convertible embeddings shaped (T, H) or (1, T, H).
                Batch > 1 is not supported.
            keep_history: When False, the KV cache is cleared (keeping the system
                prompt) after the call; a failure to clear is printed as a
                warning. When True, the cache is kept; call clear_kv_cache()
                manually if needed.
            timeout: Optional timeout (seconds) for waiting on the callback.
            return_last_only: If True, return only the last token's vector, shape (H,).

        Returns:
            Hidden states shaped (1, num_tokens, embd_size) as reported by the
            runtime, or shape (embd_size,) when ``return_last_only`` is True.

        Raises:
            ValueError: for None, empty or wrongly shaped ``embeds``.
            TimeoutError: if no result arrives within ``timeout``.
            RuntimeError: if the runtime reports an error or no hidden states.
        """
        if embeds is None:
            raise ValueError("embeds must not be None.")

        np_embeds = np.asarray(embeds, dtype=np.float32)
        if np_embeds.ndim == 3:
            if np_embeds.shape[0] != 1:
                raise ValueError("Only batch size 1 is supported for forward_embed.")
            num_tokens = np_embeds.shape[1]
        elif np_embeds.ndim == 2:
            num_tokens = np_embeds.shape[0]
        else:
            raise ValueError("embeds must have shape (T, H) or (1, T, H).")
        if np_embeds.size == 0:
            raise ValueError("embeds must contain at least one token with a non-empty hidden dimension.")

        # Zero-copy hand-off: the runtime reads directly from this contiguous
        # float32 buffer. ``flat`` is a local that outlives the wait below, so it
        # stays valid for as long as the synchronous run can read it.
        flat = np.ascontiguousarray(np_embeds.reshape(-1), dtype=np.float32)

        rk_input = RKLLMInput()
        rk_input.input_type = RKLLMInputType.RKLLM_INPUT_EMBED
        embed_input = RKLLMEmbedInput()
        embed_input.embed = flat.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
        embed_input.n_tokens = num_tokens
        rk_input._union_data.embed_input = embed_input

        infer_params = RKLLMInferParam()
        infer_params.mode = RKLLMInferMode.RKLLM_INFER_GET_LAST_HIDDEN_LAYER
        infer_params.keep_history = 1 if keep_history else 0
        infer_params.lora_params = None
        infer_params.prompt_cache_params = None

        done = threading.Event()
        result_holder = {"hidden": None, "error": None}

        def _capture_hidden(result_ptr, userdata_ptr, state_enum):
            try:
                state = LLMCallState(state_enum)
                if state == LLMCallState.RKLLM_RUN_ERROR:
                    result_holder["error"] = "RKLLM reported an error state."
                    done.set()
                    return 0

                if not result_ptr:
                    result_holder["error"] = "Empty result pointer received."
                    done.set()
                    return 0

                layer = result_ptr.contents.last_hidden_layer
                if layer.hidden_states and layer.embd_size > 0 and layer.num_tokens > 0:
                    # Copy out of the native buffer (presumably only valid during
                    # the callback) before returning.
                    hidden = np.ctypeslib.as_array(
                        layer.hidden_states,
                        shape=(1, layer.num_tokens, layer.embd_size),
                    ).copy()
                    result_holder["hidden"] = hidden[0, -1].copy() if return_last_only else hidden
                    done.set()
                    # 1 pauses inference: the hidden states are all this call needs.
                    return 1

                if state == LLMCallState.RKLLM_RUN_FINISH:
                    done.set()
                return 0
            except Exception as exc:
                # Never leave the waiter below blocked on an unexpected failure.
                result_holder["error"] = f"forward_embed callback failed: {exc}"
                done.set()
                return 0

        previous_callback = self._user_callback
        self._user_callback = _capture_hidden
        try:
            self.run(rk_input, infer_params)
            if not done.wait(timeout):
                raise TimeoutError("forward_embed timed out waiting for hidden states.")
        finally:
            self._user_callback = previous_callback

        if result_holder["error"]:
            raise RuntimeError(result_holder["error"])
        if result_holder["hidden"] is None:
            raise RuntimeError("forward_embed did not receive hidden states.")

        if not keep_history:
            try:
                self.clear_kv_cache(True)
            except Exception as exc:
                # Best effort: the hidden states are already captured, so a failed
                # cache reset is reported but not fatal.
                print(f"[rkllm_binding] Warning: failed to clear KV cache after forward_embed: {exc}")

        return result_holder["hidden"]
|
|
| |
| def _cli_parse_arguments() -> argparse.Namespace: |
| parser = argparse.ArgumentParser( |
| description="Demo application showcasing rkllm_binding usage." |
| ) |
| parser.add_argument( |
| "model", |
| help="Path to the .rkllm model file used for inference." |
| ) |
| parser.add_argument( |
| "--lib", |
| default="./librkllmrt.so", |
| help="Path to librkllmrt.so. Defaults to ./librkllmrt.so." |
| ) |
|
|
| |
| parser.add_argument("--max-context-len", type=int, default=512, help="Maximum context length.") |
| parser.add_argument("--max-new-tokens", type=int, default=128, help="Maximum number of new tokens to generate.") |
| parser.add_argument("--top-k", type=int, default=1, help="Top-K sampling parameter.") |
| parser.add_argument("--top-p", type=float, default=0.0, help="Top-P (nucleus) sampling parameter.") |
| parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature.") |
| parser.add_argument("--repeat-penalty", type=float, default=1.1, help="Penalty applied to repeated tokens.") |
| parser.add_argument("--n-keep", type=int, default=0, help="Number of tokens to keep when context slides.") |
| parser.add_argument("--mirostat", type=int, default=0, help="Enable Mirostat sampling (0 disables).") |
| parser.add_argument("--mirostat-tau", type=float, default=5.0, help="Mirostat tau parameter.") |
| parser.add_argument("--mirostat-eta", type=float, default=0.1, help="Mirostat eta parameter.") |
| parser.add_argument( |
| "--skip-special-token", |
| action="store_true", |
| help="Skip special tokens when generating output." |
| ) |
|
|
| |
| parser.add_argument( |
| "--input-type", |
| choices=("prompt", "token", "multimodal"), |
| default="prompt", |
| help="Select prompt, raw token, or multimodal (image + prompt) input." |
| ) |
| parser.add_argument("--prompt", help="Prompt text to send to the model.") |
| parser.add_argument("--prompt-file", help="Path to a UTF-8 text file containing the prompt.") |
| parser.add_argument( |
| "--token-ids", |
| type=int, |
| nargs="+", |
| help="Raw token IDs (space separated). Only valid when --input-type token." |
| ) |
| parser.add_argument("--role", default="user", help="Role metadata for the input message (e.g., user/system).") |
| parser.add_argument( |
| "--enable-thinking", |
| action="store_true", |
| help="Enable thinking mode for supported models." |
| ) |
| parser.add_argument("--image", help="Path to an image file used when --input-type multimodal.") |
| parser.add_argument("--vision-encoder", help="Path to the ONNX vision encoder model.") |
| parser.add_argument( |
| "--encoder-provider", |
| help="Comma separated ONNX Runtime providers (e.g., 'CPUExecutionProvider')." |
| ) |
| parser.add_argument( |
| "--encoder-threads", |
| type=int, |
| help="Thread count hint for ONNX Runtime session." |
| ) |
| parser.add_argument( |
| "--encoder-input-shape", |
| help="Override encoder input spatial size as HxW or H,W (e.g., 392x392)." |
| ) |
| parser.add_argument( |
| "--norm", |
| choices=("imagenet", "divide_255", "divide_128_sub_1"), |
| default="imagenet", |
| help="Image normalization preset." |
| ) |
| parser.add_argument( |
| "--norm-mean", |
| type=float, |
| nargs=3, |
| metavar=("R", "G", "B"), |
| help="Override normalization mean (RGB order)." |
| ) |
| parser.add_argument( |
| "--norm-std", |
| type=float, |
| nargs=3, |
| metavar=("R", "G", "B"), |
| help="Override normalization std (RGB order)." |
| ) |
| parser.add_argument( |
| "--image-background", |
| type=int, |
| nargs=3, |
| metavar=("R", "G", "B"), |
| default=(128, 128, 128), |
| help="Background color used when padding image to target size." |
| ) |
| parser.add_argument("--img-start-token", help="Override image start token string passed to the model.") |
| parser.add_argument("--img-end-token", help="Override image end token string passed to the model.") |
| parser.add_argument("--img-content-token", help="Override image content token string passed to the model.") |
|
|
| |
| parser.add_argument( |
| "--mode", |
| choices=("generate", "hidden", "logits"), |
| default="generate", |
| help="Inference mode: generate tokens, return last hidden layer, or logits." |
| ) |
| parser.add_argument( |
| "--no-keep-history", |
| action="store_true", |
| help="Do not keep dialogue history on the device." |
| ) |
|
|
| |
| parser.add_argument( |
| "--stream", |
| action="store_true", |
| default=True, |
| help="Stream tokens to stdout as they arrive from the callback." |
| ) |
| parser.add_argument( |
| "--hide-stats", |
| action="store_true", |
| help="Suppress performance statistics after inference." |
| ) |
|
|
| args = parser.parse_args() |
|
|
| if args.prompt and args.prompt_file: |
| parser.error("Arguments --prompt and --prompt-file cannot be used together.") |
|
|
| if args.input_type == "prompt": |
| if not args.prompt and not args.prompt_file: |
| parser.error("Provide --prompt or --prompt-file when --input-type is prompt.") |
| if args.token_ids: |
| parser.error("--token-ids is only valid when --input-type token.") |
| elif args.input_type == "token": |
| if not args.token_ids: |
| parser.error("--token-ids is required when --input-type token.") |
| if args.prompt or args.prompt_file: |
| parser.error("--prompt/--prompt-file cannot be combined with --input-type token.") |
| else: |
| if args.token_ids: |
| parser.error("--token-ids cannot be used with --input-type multimodal.") |
| if not args.prompt and not args.prompt_file: |
| parser.error("Provide --prompt or --prompt-file when --input-type is multimodal.") |
| if not args.image: |
| parser.error("--image is required when --input-type multimodal.") |
| if not args.vision_encoder: |
| parser.error("--vision-encoder is required when --input-type multimodal.") |
|
|
| if args.image_background: |
| for component in args.image_background: |
| if component < 0 or component > 255: |
| parser.error("--image-background values must be in the range [0, 255].") |
|
|
| return args |
|
|
|
|
| def _load_prompt_from_args(args: argparse.Namespace) -> str: |
| if args.prompt: |
| return args.prompt |
| if args.prompt_file: |
| try: |
| with open(args.prompt_file, "r", encoding="utf-8") as fp: |
| return fp.read() |
| except OSError as exc: |
| raise RuntimeError(f"Failed to read prompt file '{args.prompt_file}': {exc}") from exc |
| raise RuntimeError("Prompt text is required but not provided.") |
|
|
|
|
def _mode_to_enum(mode: str) -> int:
    """Translate a CLI ``--mode`` choice into the matching RKLLMInferMode member.

    Raises:
        KeyError: for a mode string other than "generate", "hidden" or "logits".
    """
    return {
        "generate": RKLLMInferMode.RKLLM_INFER_GENERATE,
        "hidden": RKLLMInferMode.RKLLM_INFER_GET_LAST_HIDDEN_LAYER,
        "logits": RKLLMInferMode.RKLLM_INFER_GET_LOGITS,
    }[mode]
|
|
|
|
| def _parse_hw_string(value: str) -> Tuple[int, int]: |
| separators = ("x", "X", ",", " ") |
| token = value.strip() |
| for sep in separators: |
| if sep in token: |
| parts = [p for p in token.split(sep) if p] |
| break |
| else: |
| parts = [token] |
| if len(parts) != 2: |
| raise ValueError(f"Unable to parse height/width from '{value}'. Expected format like 392x392.") |
| try: |
| height = int(parts[0]) |
| width = int(parts[1]) |
| except ValueError as exc: |
| raise ValueError(f"Height/width must be integers, got '{value}'.") from exc |
| if height <= 0 or width <= 0: |
| raise ValueError("Height and width must be positive integers.") |
| return height, width |
|
|
|
|
| def _infer_hw_from_onnx_shape(shape: Sequence) -> Tuple[Optional[int], Optional[int]]: |
| if shape is None or len(shape) < 4: |
| return None, None |
| height = shape[-2] |
| width = shape[-1] |
| if isinstance(height, str) or height is None: |
| height = None |
| if isinstance(width, str) or width is None: |
| width = None |
| return height, width |
|
|
|
|
| def _parse_providers(provider_str: Optional[str]) -> Optional[list]: |
| if not provider_str: |
| return None |
| providers = [item.strip() for item in provider_str.split(",") if item.strip()] |
| return providers or None |
|
|
|
|
| def _load_vision_encoder_session(encoder_path: str, providers: Optional[list], threads: Optional[int]): |
| try: |
| import onnxruntime as ort |
| except ImportError as exc: |
| raise RuntimeError("onnxruntime is required for multimodal inference. Please install onnxruntime.") from exc |
|
|
| sess_options = ort.SessionOptions() |
| if threads and threads > 0: |
| sess_options.intra_op_num_threads = threads |
| try: |
| if providers: |
| session = ort.InferenceSession(encoder_path, sess_options=sess_options, providers=providers) |
| else: |
| session = ort.InferenceSession(encoder_path, sess_options=sess_options) |
| except Exception as exc: |
| raise RuntimeError(f"Failed to load vision encoder '{encoder_path}': {exc}") from exc |
| return session |
|
|
|
|
| def _letterbox_resize(image, target_hw: Tuple[int, int], background_color: Sequence[int]): |
| try: |
| import cv2 |
| import numpy as np |
| except ImportError as exc: |
| raise RuntimeError("OpenCV (cv2) and numpy are required for multimodal preprocessing.") from exc |
|
|
| target_h, target_w = target_hw |
| if image.ndim != 3 or image.shape[2] != 3: |
| raise RuntimeError("Expected RGB image with 3 channels.") |
|
|
| src_h, src_w = image.shape[:2] |
| if src_h == 0 or src_w == 0: |
| raise RuntimeError("Loaded image has invalid dimensions.") |
|
|
| scale = min(target_w / src_w, target_h / src_h) |
| resized_w = max(1, int(round(src_w * scale))) |
| resized_h = max(1, int(round(src_h * scale))) |
| resized = cv2.resize(image, (resized_w, resized_h), interpolation=cv2.INTER_LINEAR) |
|
|
| canvas = np.full((target_h, target_w, 3), background_color, dtype=resized.dtype) |
| top = (target_h - resized_h) // 2 |
| left = (target_w - resized_w) // 2 |
| canvas[top:top + resized_h, left:left + resized_w] = resized |
| return canvas, resized_h, resized_w |
|
|
|
|
| def _normalize_image(image, method: str, mean: Optional[Sequence[float]], std: Optional[Sequence[float]]): |
| import numpy as np |
|
|
| img = image.astype(np.float32) |
| mean_arr = np.array(mean, dtype=np.float32) if mean else None |
| std_arr = np.array(std, dtype=np.float32) if std else None |
|
|
| if method == "imagenet": |
| img = img / 255.0 |
| if mean_arr is None: |
| mean_arr = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32) |
| if std_arr is None: |
| std_arr = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32) |
| img = (img - mean_arr) / std_arr |
| elif method == "divide_255": |
| img = img / 255.0 |
| if mean_arr is not None: |
| img = img - mean_arr |
| if std_arr is not None: |
| img = img / std_arr |
| elif method == "divide_128_sub_1": |
| img = img / 128.0 - 1.0 |
| if mean_arr is not None: |
| img = img - mean_arr |
| if std_arr is not None: |
| img = img / std_arr |
| else: |
| raise RuntimeError(f"Unsupported normalization method '{method}'.") |
|
|
| return img |
|
|
|
|
def _encode_image_to_embedding(
    session,
    image_path: str,
    input_name: str,
    output_name: str,
    target_hw: Tuple[int, int],
    background_color: Sequence[int],
    norm_method: str,
    norm_mean: Optional[Sequence[float]],
    norm_std: Optional[Sequence[float]]
):
    """Run one image through the vision encoder and return a flat embedding.

    Pipeline:
    1. Read ``image_path`` with OpenCV (BGR) and convert it to RGB.
    2. Letterbox it to ``target_hw`` and normalize it.
    3. Feed the encoder a contiguous float32 tensor of shape
       ``(1, C, target_h, target_w)`` under ``input_name``.

    Only ``output_name`` is requested from ``session.run``.

    Returns:
        ``(flat_embedding, n_tokens, target_hw)``.
        - ``flat_embedding`` is the encoder output flattened to a contiguous
          1-D float32 array.
        - ``n_tokens`` is the second-to-batch dimension: ``shape[1]`` for a
          3-D ``(1, N, ...)`` output, ``shape[0]`` for a 2-D ``(N, ...)`` one.

    Raises:
        RuntimeError: if cv2/numpy are unavailable, the image cannot be read,
            inference fails or returns nothing, or the output rank/batch size
            is unsupported.
    """
    try:
        import cv2
        import numpy as np
    except ImportError as exc:
        raise RuntimeError("OpenCV (cv2) and numpy are required for multimodal preprocessing.") from exc

    bgr = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if bgr is None:
        raise RuntimeError(f"Failed to read image from '{image_path}'.")
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # The resized extent reported by the letterbox step is not needed here.
    letterboxed, _, _ = _letterbox_resize(rgb, target_hw, background_color)
    hwc = _normalize_image(letterboxed, norm_method, norm_mean, norm_std)

    # HWC -> CHW, then add a leading batch axis of one.
    nchw = np.ascontiguousarray(hwc.transpose(2, 0, 1)[np.newaxis, ...], dtype=np.float32)

    try:
        outputs = session.run([output_name], {input_name: nchw})
    except Exception as exc:
        raise RuntimeError(f"Vision encoder inference failed: {exc}") from exc
    if not outputs:
        raise RuntimeError("Vision encoder returned no outputs.")

    features = outputs[0]
    rank = features.ndim
    if rank == 3:
        if features.shape[0] != 1:
            raise RuntimeError("Vision encoder output batch dimension must be 1 for a single image.")
        n_tokens = features.shape[1]
    elif rank == 2:
        n_tokens = features.shape[0]
    else:
        raise RuntimeError(f"Unsupported vision encoder output shape {features.shape}.")

    flat = np.ascontiguousarray(features.reshape(-1).astype(np.float32, copy=False))
    return flat, n_tokens, target_hw
|
|
if __name__ == "__main__":
    # Demo entry point: parse CLI arguments, run one RKLLM inference with
    # is_async=False, and print the generated text plus the perf counters
    # that the runtime reports through the callback.
    import os
    # NOTE(review): presumably read by the RKLLM runtime when it loads; this
    # overrides any value already present in the environment.
    os.environ["RKLLM_LOG_LEVEL"] = "1"
    args = _cli_parse_arguments()

    prompt_text = None
    if args.input_type == "prompt":
        prompt_text = _load_prompt_from_args(args)

    # These hold the ctypes buffers handed to rk_input in the token branch.
    # As module-level names they stay referenced for the whole run() call.
    token_id_array = None
    token_input_struct = None

    # Filled in by demo_callback while the runtime is producing output.
    generated_chunks = []
    perf_snapshot = {
        "prefill_tokens": 0,
        "prefill_time_ms": 0.0,
        "generate_tokens": 0,
        "generate_time_ms": 0.0,
        "memory_usage_mb": 0.0,
    }

    def demo_callback(result_ptr, userdata_ptr, state_enum):
        """Runtime callback: collect text chunks and the latest perf counters.

        Always returns 0.
        """
        state = LLMCallState(state_enum)
        result = result_ptr.contents

        current_text = ""
        if result.text:
            # Undecodable bytes are dropped instead of failing the callback.
            current_text = result.text.decode("utf-8", errors="ignore")
            generated_chunks.append(current_text)
        if args.stream and current_text:
            print(current_text, end="", flush=True)

        # Overwritten on every call, so the last callback's counters win.
        perf_snapshot.update(
            prefill_tokens=result.perf.prefill_tokens,
            prefill_time_ms=result.perf.prefill_time_ms,
            generate_tokens=result.perf.generate_tokens,
            generate_time_ms=result.perf.generate_time_ms,
            memory_usage_mb=result.perf.memory_usage_mb,
        )

        if state == LLMCallState.RKLLM_RUN_ERROR:
            # Message: "An error occurred during inference."
            print("\n[Callback] 推理过程中出现错误。")

        return 0

    try:
        # This entry point only builds prompt and token inputs. Without this
        # guard, "--input-type multimodal" (accepted by argument validation)
        # would fall through to the token branch and build a token input from
        # args.token_ids, which that validation forbids for multimodal. It
        # would then fail with an unrelated error after loading the model.
        if args.input_type == "multimodal":
            raise RuntimeError(
                "--input-type multimodal is not supported by this entry point; "
                "only prompt and token-id inputs can be run here."
            )

        with RKLLMRuntime(library_path=args.lib) as rk_llm:
            params = rk_llm.create_default_param()
            params.model_path = os.path.abspath(args.model).encode("utf-8")
            params.max_context_len = args.max_context_len
            params.max_new_tokens = args.max_new_tokens
            params.top_k = args.top_k
            params.top_p = float(args.top_p)
            params.temperature = float(args.temperature)
            params.repeat_penalty = float(args.repeat_penalty)
            params.n_keep = args.n_keep
            params.mirostat = args.mirostat
            params.mirostat_tau = float(args.mirostat_tau)
            params.mirostat_eta = float(args.mirostat_eta)
            params.skip_special_token = bool(args.skip_special_token)
            params.is_async = False

            rk_llm.init(params, demo_callback)

            rk_input = RKLLMInput()
            rk_input.role = args.role.encode("utf-8")
            rk_input.enable_thinking = bool(args.enable_thinking)

            if args.input_type == "prompt":
                rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT
                rk_input._union_data.prompt_input = prompt_text.encode("utf-8")
            else:
                # Token-id input: pack the ids into a ctypes int32 array.
                rk_input.input_type = RKLLMInputType.RKLLM_INPUT_TOKEN
                token_id_array = (ctypes.c_int32 * len(args.token_ids))(*args.token_ids)
                token_input_struct = RKLLMTokenInput()
                token_input_struct.input_ids = token_id_array
                token_input_struct.n_tokens = len(args.token_ids)
                rk_input._union_data.token_input = token_input_struct

            infer_params = RKLLMInferParam()
            infer_params.mode = _mode_to_enum(args.mode)
            infer_params.keep_history = 0 if args.no_keep_history else 1
            infer_params.lora_params = None
            infer_params.prompt_cache_params = None

            if args.stream:
                print("=== Streaming Output ===")

            rk_llm.run(rk_input, infer_params)

    except OSError as exc:
        # Message: "Unable to load the RKLLM runtime library: ..."
        print(f"无法加载 RKLLM 运行时库:{exc}")
    except RuntimeError as exc:
        # Message: "Inference failed: ..."
        print(f"推理失败:{exc}")
    except Exception as exc:
        # Message: "An unexpected error occurred: ..."
        print(f"发生未预期的错误:{exc}")
    else:
        if args.stream:
            # Terminate the streamed line.
            print()

    final_text = "".join(generated_chunks)
    if final_text:
        # Header: "Generated result"
        print("=== 生成结果 ===")
        print(final_text)
    else:
        # Message: "No generated text received."
        print("未收到生成文本。")

    if not args.hide_stats:
        # Header: "Performance statistics"; lines: prefill, generation, peak memory.
        print("=== 性能统计 ===")
        print(
            f"预填充: {perf_snapshot['prefill_tokens']} tokens / {perf_snapshot['prefill_time_ms']:.2f} ms"
        )
        print(
            f"生成: {perf_snapshot['generate_tokens']} tokens / {perf_snapshot['generate_time_ms']:.2f} ms"
        )
        print(f"最大常驻内存: {perf_snapshot['memory_usage_mb']:.2f} MB")
|