import faulthandler
faulthandler.enable()

import sys
import os

os.environ["RKLLM_LOG_LEVEL"] = "1"

import ctypes
import argparse
import time

import cv2
import numpy as np
import ztu_somemodelruntime_rknnlite2 as ort

from rkllm_binding import (
    RKLLMRuntime,
    RKLLMParam,
    RKLLMInput,
    RKLLMInferParam,
    LLMCallState,
    RKLLMInputType,
    RKLLMInferMode,
    RKLLMResult,
)

# Constants aligned with the InternVL config
IMAGE_HEIGHT = 448
IMAGE_WIDTH = 448
IMAGE_SEQ_LENGTH = 256
MULTIMODAL_HIDDEN_DIM = 2048
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def expand2square(img, background_color):
    """Expand the image into a square and fill the borders with the specified background color."""
    height, width, _ = img.shape
    if width == height:
        return img.copy()
    size = max(width, height)
    square_img = np.full((size, size, 3), background_color, dtype=np.uint8)
    x_offset = (size - width) // 2
    y_offset = (size - height) // 2
    square_img[y_offset:y_offset + height, x_offset:x_offset + width] = img
    return square_img


def llm_callback(result_ptr, userdata_ptr, state_enum):
    """Callback function to handle LLM results as they are streamed by the runtime."""
    state = LLMCallState(state_enum)
    result = result_ptr.contents
    if state == LLMCallState.RKLLM_RUN_NORMAL:
        if result.text:
            print(result.text.decode('utf-8', errors='ignore'), end='', flush=True)
    elif state == LLMCallState.RKLLM_RUN_FINISH:
        print("\n", flush=True)
    elif state == LLMCallState.RKLLM_RUN_ERROR:
        print("\nrun error", flush=True)
    return 0


def main():
    parser = argparse.ArgumentParser(
        description="Run RKLLM visual language model inference based on the C++ example."
    )
    parser.add_argument("image_path", type=str, help="Path to the input image.")
    parser.add_argument("encoder_model_path", type=str, help="Path to the ONNX vision encoder model.")
    parser.add_argument("llm_model_path", type=str, help="Path to the .rkllm language model.")
    parser.add_argument("max_new_tokens", type=int, help="Maximum number of new tokens to generate.")
    parser.add_argument("max_context_len", type=int, help="Maximum context length.")
    # rknn_core_num is not used by ONNX Runtime in the same way, but it is kept for
    # CLI consistency with the C++ example and mapped to intra_op_num_threads below.
    # ONNX Runtime manages its own threading and execution providers.
    parser.add_argument("rknn_core_num", type=int, help="Number of NPU cores to use for the vision encoder.")
    args = parser.parse_args()

    # --- 1. Initialize Image Encoder (ONNX Runtime) ---
    print("Initializing ONNX Runtime for vision encoder...")
    try:
        sess_options = ort.SessionOptions()
        sess_options.intra_op_num_threads = args.rknn_core_num
        ort_session = ort.InferenceSession(args.encoder_model_path, sess_options=sess_options)
    except Exception as e:
        print(f"Failed to load ONNX model: {e}")
        sys.exit(1)
    print("Vision encoder loaded successfully.")

    input_name = ort_session.get_inputs()[0].name
    output_name = ort_session.get_outputs()[0].name
    print(f"ONNX Input: {input_name}, ONNX Output: {output_name}")

    # --- 2. Initialize LLM ---
    print("Initializing RKLLM Runtime...")
    rk_llm = RKLLMRuntime()
    param = rk_llm.create_default_param()
    param.model_path = args.llm_model_path.encode('utf-8')
    param.top_k = 1
    param.max_new_tokens = args.max_new_tokens
    param.max_context_len = args.max_context_len
    param.skip_special_token = True
    param.img_start = b""
    param.img_end = b"\n"
    param.img_content = b""
    param.extend_param.base_domain_id = 1
    try:
        rk_llm.init(param, llm_callback)
        print("RKLLM initialized successfully.")
    except RuntimeError as e:
        print(f"RKLLM init failed: {e}")
        sys.exit(1)

    # --- 3. Image Preprocessing ---
    print("Preprocessing image...")
    img = cv2.imread(args.image_path)
    if img is None:
        print(f"Failed to read image from {args.image_path}")
        sys.exit(1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    background_color = (127.5, 127.5, 127.5)  # Keep close to the official preprocessing
    square_img = expand2square(img, background_color)
    resized_img = cv2.resize(square_img, (IMAGE_WIDTH, IMAGE_HEIGHT), interpolation=cv2.INTER_LINEAR)

    # Normalize and prepare for the ONNX model
    input_tensor = resized_img.astype(np.float32)
    # Normalize using InternVL vision config statistics
    input_tensor = (input_tensor / 255.0 - IMAGENET_MEAN) / IMAGENET_STD
    # Convert to NCHW format
    input_tensor = np.transpose(input_tensor, (2, 0, 1))   # HWC -> CHW
    input_tensor = np.expand_dims(input_tensor, axis=0)    # Add batch dimension -> (1, 3, 448, 448)

    # --- 4. Run Image Encoder ---
    print("Running vision encoder...")
    start_time = time.time()
    try:
        img_vec_output = ort_session.run([output_name], {input_name: input_tensor.astype(np.float32)})[0]
        if img_vec_output.ndim != 3:
            raise RuntimeError(f"Unexpected encoder output shape {img_vec_output.shape}, expected (batch, tokens, hidden)")
        if img_vec_output.shape[-1] != MULTIMODAL_HIDDEN_DIM:
            print(f"Warning: hidden dim {img_vec_output.shape[-1]} differs from expected {MULTIMODAL_HIDDEN_DIM}")
        if img_vec_output.shape[1] != IMAGE_SEQ_LENGTH:
            print(f"Warning: token count {img_vec_output.shape[1]} differs from expected {IMAGE_SEQ_LENGTH}")
        elapsed_time = time.time() - start_time
        print(f"Vision encoder inference took {elapsed_time:.4f} s")
        # The C++ example passes the image embedding as a flat float array, so flatten the ONNX output.
        img_vec = img_vec_output.flatten().astype(np.float32)
    except Exception as e:
        print(f"Failed to run vision encoder inference: {e}")
        rk_llm.destroy()
        sys.exit(1)
    print("Image encoded successfully.")

    # --- 5. Interactive Chat Loop ---
    rkllm_infer_params = RKLLMInferParam()
    rkllm_infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
    rkllm_infer_params.keep_history = 1
    # Set chat template.
    # The default template parsed by RKLLM seems to give better results than this one; not sure why.
    # rk_llm.set_chat_template(
    #     system_prompt="<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n",
    #     prompt_prefix="<|im_start|>user\n",
    #     prompt_postfix="<|im_end|>\n<|im_start|>assistant\n"
    # )

    # "<image>" marks where the image embedding is inserted, matching the placeholder
    # used by the C++ multimodal example.
    pre_input = [
        "<image>What is in the image?",
        "<image>这张图片中有什么?",  # Chinese: "What is in the image?"
    ]
    print("\n********************** Enter the index of a preset question below, or type your own prompt **********************\n")
    for i, p in enumerate(pre_input):
        print(f"[{i}] {p}")
    print("\n*************************************************************************\n")

    try:
        while True:
            print("\nuser: ", end="", flush=True)
            input_str = sys.stdin.readline().strip()
            if not input_str:
                continue
            if input_str == "exit":
                break
            if input_str == "clear":
                try:
                    rk_llm.clear_kv_cache(keep_system_prompt=True)
                    print("KV cache cleared.")
                except RuntimeError as e:
                    print(f"Failed to clear KV cache: {e}")
                continue
            try:
                idx = int(input_str)
                if 0 <= idx < len(pre_input):
                    input_str = pre_input[idx]
                    print(input_str)
            except (ValueError, IndexError):
                pass  # Use the raw string if not a valid index

            rkllm_input = RKLLMInput()
            rkllm_input.role = b"user"
            print("robot: ", end="", flush=True)

            if "<image>" in input_str:
                rkllm_input.input_type = RKLLMInputType.RKLLM_INPUT_MULTIMODAL
                # Set up multimodal input: prompt text plus the precomputed image embedding
                rkllm_input.multimodal_input.prompt = input_str.encode('utf-8')
                rkllm_input.multimodal_input.image_embed = img_vec.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
                rkllm_input.multimodal_input.n_image_tokens = img_vec_output.shape[1]
                print("n_image_tokens: ", rkllm_input.multimodal_input.n_image_tokens)
                rkllm_input.multimodal_input.n_image = 1
                rkllm_input.multimodal_input.image_height = IMAGE_HEIGHT
                rkllm_input.multimodal_input.image_width = IMAGE_WIDTH
            else:
                rkllm_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT
                rkllm_input.prompt_input = input_str.encode('utf-8')

            try:
                rk_llm.run(rkllm_input, rkllm_infer_params)
            except RuntimeError as e:
                print(f"\nError during rkllm_run: {e}")
    except KeyboardInterrupt:
        print("\nExiting...")
    finally:
        print("Releasing resources...")
        rk_llm.destroy()
        print("RKLLM instance destroyed.")


if __name__ == "__main__":
    main()
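# Example invocation (illustrative paths and values only; the positional arguments are
# defined by the argparse setup above, and the last one maps to intra_op_num_threads):
#   python this_script.py ./demo.jpg ./vision_encoder.onnx ./model.rkllm 512 1024 3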