|
|
import os

# Quiet/limit native RKLLM runtime logging. Set before importing the binding
# so the environment variable is already present when the shared library is
# loaded. NOTE(review): the meaning of level "1" comes from the RKLLM
# runtime itself — confirm against the Rockchip rkllm documentation.
os.environ["RKLLM_LOG_LEVEL"] = "1"

# Project-local ctypes binding; provides RKLLMRuntime, RKLLMInput,
# RKLLMInferParam, LLMCallState, RKLLMInferMode, RKLLMInputType, etc.
from rkllm_binding import *
|
|
|
|
|
def my_python_callback(result_ptr, userdata_ptr, state_enum):
    """Stream the model's response to stdout, fragment by fragment.

    Invoked by the native RKLLM runtime for each generated chunk. Decodes
    the fragment's bytes and prints them without a trailing newline so the
    reply appears incrementally; emits a final newline when the run ends,
    or an error notice when the runtime reports a failure.

    Args:
        result_ptr: ctypes pointer to the result struct; ``contents.text``
            carries the raw UTF-8 bytes of this fragment (may be empty/null).
        userdata_ptr: opaque user-data pointer passed through by the
            runtime (unused here).
        state_enum: integer run state, convertible to ``LLMCallState``.

    Returns:
        int: always 0, per the native callback contract.
    """
    run_state = LLMCallState(state_enum)
    payload = result_ptr.contents

    fragment = payload.text
    if fragment:
        # Ignore undecodable bytes: a chunk boundary can split a UTF-8
        # sequence, and a hard failure here would abort the stream.
        print(fragment.decode('utf-8', errors='ignore'), end='', flush=True)

    if run_state == LLMCallState.RKLLM_RUN_FINISH:
        # Terminate the streamed line once generation completes.
        print()
    elif run_state == LLMCallState.RKLLM_RUN_ERROR:
        print("\n推理过程中发生错误。")

    return 0
|
|
|
|
|
|
|
|
# Demo driver: initialize the runtime, run an interactive multi-turn chat
# loop, and guarantee the native LLM handle is destroyed on any exit path.
try:
    print("Initializing RKLLMRuntime...")

    rk_llm = RKLLMRuntime()

    print("Creating default parameters...")
    params = rk_llm.create_default_param()

    # Fail fast with a clear message before handing the path to native code.
    model_file = "language_model.rkllm"
    if not os.path.exists(model_file):
        raise FileNotFoundError(f"Model file '{model_file}' does not exist.")

    # The binding expects C strings, hence the explicit UTF-8 encode.
    params.model_path = model_file.encode('utf-8')
    params.max_context_len = 4096
    params.max_new_tokens = 1024

    # Sampling configuration; semantics defined by the RKLLM runtime.
    params.temperature = 0.7
    params.repeat_penalty = 1.1

    print(f"Initializing LLM with model: {params.model_path.decode()}...")

    try:
        # Register the streaming callback; the runtime will call it for
        # every generated fragment during rk_llm.run().
        rk_llm.init(params, my_python_callback)
        print("LLM Initialized.")
    except RuntimeError as e:
        print(f"Error during LLM initialization: {e}")
        # SystemExit is not caught by the outer `except Exception`, so this
        # skips straight to `finally` for cleanup.
        exit()

    # User-facing banner (Chinese): enter multi-turn chat; type
    # 'exit'/'quit' to leave.
    print("\n进入多轮对话模式。输入 'exit' 或 'quit' 退出。")

    infer_params = RKLLMInferParam()
    infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
    # keep_history=1: the runtime retains conversation context between
    # turns, enabling multi-turn dialogue.
    infer_params.keep_history = 1

    while True:
        try:
            prompt_text = input("You: ")
            if prompt_text.lower() in ["exit", "quit"]:
                break

            print("Assistant: ", end='', flush=True)

            rk_input = RKLLMInput()
            rk_input.role = b"user"
            rk_input.enable_thinking = False
            rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT

            # Keep a reference to the encoded bytes while run() executes so
            # the buffer pointed to by the union field stays alive.
            c_prompt = prompt_text.encode('utf-8')
            rk_input._union_data.prompt_input = c_prompt

            # Blocking call; output is streamed via my_python_callback.
            rk_llm.run(rk_input, infer_params)

        except KeyboardInterrupt:
            # Ctrl-C during a turn: announce interruption and leave the loop.
            print("\n\n对话中断。")
            break
        except RuntimeError as e:
            # Runtime error during a turn (Chinese message), then bail out.
            print(f"\n运行时发生错误: {e}")
            break

except OSError as e:
    # Typically the native shared library could not be loaded.
    print(f"OSError: {e}. Could not load the RKLLM library.")
    print("Please ensure 'librkllmrt.so' is in your LD_LIBRARY_PATH or provide the full path.")
except Exception as e:
    print(f"An unexpected error occurred: {e}")
finally:
    # Destroy the native handle only if the runtime object was created AND
    # init() actually produced a non-null handle; destroying an
    # uninitialized handle would be invalid.
    if 'rk_llm' in locals() and rk_llm.llm_handle and rk_llm.llm_handle.value:
        print("Destroying LLM instance...")
        rk_llm.destroy()
        print("LLM instance destroyed.")

    print("Example finished.")
|
|
|