# Example: interactive multi-turn chat client for the RKLLM runtime,
# using the ctypes wrapper in rkllm_binding.
import os
# Must be set BEFORE importing rkllm_binding: the native librkllmrt.so reads
# RKLLM_LOG_LEVEL when it is loaded.
os.environ["RKLLM_LOG_LEVEL"] = "1"
from rkllm_binding import *
def my_python_callback(result_ptr, userdata_ptr, state_enum):
    """Streaming callback invoked by the RKLLM runtime for each output chunk.

    Prints every text fragment as it arrives (token-by-token streaming),
    emits a trailing newline when generation finishes normally, and an
    error notice when the runtime reports a failure.

    Returns:
        int: 0 to continue inference; returning 1 would pause it.
    """
    call_state = LLMCallState(state_enum)
    chunk = result_ptr.contents
    text = chunk.text
    if text:
        decoded = text.decode('utf-8', errors='ignore')
        print(decoded, end='', flush=True)
    if call_state == LLMCallState.RKLLM_RUN_FINISH:
        # Terminate the streamed response with a newline for clean formatting.
        print()
    elif call_state == LLMCallState.RKLLM_RUN_ERROR:
        print("\n推理过程中发生错误。")
    return 0
# --- Attempt to use the wrapper ---
try:
    print("Initializing RKLLMRuntime...")
    # Adjust library_path if librkllmrt.so is not in default search paths
    # e.g., library_path="./path/to/librkllmrt.so"
    rk_llm = RKLLMRuntime()
    print("Creating default parameters...")
    params = rk_llm.create_default_param()
    # --- Configure parameters ---
    model_file = "language_model.rkllm"
    if not os.path.exists(model_file):
        raise FileNotFoundError(f"Model file '{model_file}' does not exist.")
    # model_path is a C string field — must be bytes, not str.
    params.model_path = model_file.encode('utf-8')
    params.max_context_len = 4096
    params.max_new_tokens = 1024
    # params.top_k = 1 # Greedy
    params.temperature = 0.7
    params.repeat_penalty = 1.1
    # ... set other params as needed
    print(f"Initializing LLM with model: {params.model_path.decode()}...")
    # init() will raise if model_file is not a model the library recognizes.
    try:
        rk_llm.init(params, my_python_callback)
        print("LLM Initialized.")
    except RuntimeError as e:
        print(f"Error during LLM initialization: {e}")
        # SystemExit bypasses the `except Exception` below; `finally` still runs.
        exit()
    # --- Interactive multi-turn conversation loop ---
    print("\n进入多轮对话模式。输入 'exit' 或 'quit' 退出。")
    # Inference parameters — kept constant across all turns of the chat.
    infer_params = RKLLMInferParam()
    infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
    infer_params.keep_history = 1  # retain conversation history between turns
    while True:
        try:
            prompt_text = input("You: ")
            if prompt_text.lower() in ["exit", "quit"]:
                break
            print("Assistant: ", end='', flush=True)
            # Build the input struct for this turn.
            rk_input = RKLLMInput()
            rk_input.role = b"user"
            rk_input.enable_thinking = False
            rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT
            c_prompt = prompt_text.encode('utf-8')
            # NOTE(review): c_prompt is presumably held in a local so the bytes
            # object stays alive while the C side reads the pointer — confirm
            # against rkllm_binding before inlining the encode() call.
            rk_input._union_data.prompt_input = c_prompt
            # Run inference; output is streamed through my_python_callback.
            rk_llm.run(rk_input, infer_params)
        except KeyboardInterrupt:
            # Ctrl-C during a turn ends the conversation, not the whole process.
            print("\n\n对话中断。")
            break
        except RuntimeError as e:
            print(f"\n运行时发生错误: {e}")
            break
except OSError as e:
    print(f"OSError: {e}. Could not load the RKLLM library.")
    print("Please ensure 'librkllmrt.so' is in your LD_LIBRARY_PATH or provide the full path.")
except Exception as e:
    print(f"An unexpected error occurred: {e}")
finally:
    # Destroy only if rk_llm was constructed AND init() produced a live handle.
    if 'rk_llm' in locals() and rk_llm.llm_handle and rk_llm.llm_handle.value:
        print("Destroying LLM instance...")
        rk_llm.destroy()
        print("LLM instance destroyed.")
print("Example finished.")