File size: 9,319 Bytes
7fc4eb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
import faulthandler
faulthandler.enable()
import sys
import os
os.environ["RKLLM_LOG_LEVEL"] = "1"
import ctypes
import argparse
import cv2
import numpy as np
import ztu_somemodelruntime_rknnlite2 as ort
from rkllm_binding import (
    RKLLMRuntime,
    RKLLMParam,
    RKLLMInput,
    RKLLMInferParam,
    LLMCallState,
    RKLLMInputType,
    RKLLMInferMode,
    RKLLMResult
)

# Constants aligned with InternVL config
IMAGE_HEIGHT = 448  # vision-encoder input height in pixels
IMAGE_WIDTH = 448  # vision-encoder input width in pixels
IMAGE_SEQ_LENGTH = 256  # expected number of image tokens produced by the encoder
MULTIMODAL_HIDDEN_DIM = 2048  # expected hidden size of each image-token embedding
# Per-channel ImageNet statistics (RGB order) used to normalize encoder input.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def expand2square(img, background_color):
    """Pad an (H, W, 3) image to a centered square of the given background.

    Returns a new uint8 array; the input image is never modified. If the
    image is already square, a plain copy is returned.
    """
    h, w, _ = img.shape
    if h == w:
        return img.copy()

    side = max(h, w)
    canvas = np.full((side, side, 3), background_color, dtype=np.uint8)

    left = (side - w) // 2
    top = (side - h) // 2
    canvas[top:top + h, left:left + w] = img
    return canvas

def llm_callback(result_ptr, userdata_ptr, state_enum):
    """Streaming callback invoked by the RKLLM runtime for each result chunk.

    Prints generated text fragments as they arrive, a newline on completion,
    and an error marker on failure. Always returns 0.
    """
    state = LLMCallState(state_enum)
    res = result_ptr.contents

    if state == LLMCallState.RKLLM_RUN_FINISH:
        print("\n", flush=True)
    elif state == LLMCallState.RKLLM_RUN_ERROR:
        print("\nrun error", flush=True)
    elif state == LLMCallState.RKLLM_RUN_NORMAL and res.text:
        print(res.text.decode('utf-8', errors='ignore'), end='', flush=True)

    return 0

def main():
    """Command-line entry point for RKLLM visual-language inference.

    Pipeline: parse args -> load the ONNX vision encoder -> initialize the
    RKLLM runtime -> preprocess one image -> run the encoder once to get the
    image embedding -> interactive chat loop that feeds the embedding to the
    LLM for any prompt containing "<image>".
    """
    parser = argparse.ArgumentParser(
        description="Run RKLLM visual language model inference based on the C++ example."
    )
    parser.add_argument("image_path", type=str, help="Path to the input image.")
    parser.add_argument("encoder_model_path", type=str, help="Path to the ONNX vision encoder model.")
    parser.add_argument("llm_model_path", type=str, help="Path to the .rkllm language model.")
    parser.add_argument("max_new_tokens", type=int, help="Maximum number of new tokens to generate.")
    parser.add_argument("max_context_len", type=int, help="Maximum context length.")
    # The rknn_core_num is not directly used by onnxruntime in the same way,
    # but we keep it for API consistency with the C++ example.
    # ONNX Runtime will manage its own threading and execution providers.
    parser.add_argument("rknn_core_num", type=int, help="Sets the number of npu cores used in vision encoder.")

    args = parser.parse_args()

    # --- 1. Initialize Image Encoder (ONNX Runtime) ---
    print("Initializing ONNX Runtime for vision encoder...")
    try:
        sess_options = ort.SessionOptions()
        # Map the core-count argument onto intra-op threads (closest analogue).
        sess_options.intra_op_num_threads = args.rknn_core_num
        ort_session = ort.InferenceSession(args.encoder_model_path, sess_options=sess_options)
    except Exception as e:
        print(f"Failed to load ONNX model: {e}")
        sys.exit(1)
    print("Vision encoder loaded successfully.")

    # Assumes the encoder model has exactly one input and one output tensor.
    input_name = ort_session.get_inputs()[0].name
    output_name = ort_session.get_outputs()[0].name
    print(f"ONNX Input: {input_name}, ONNX Output: {output_name}")

    # --- 2. Initialize LLM ---
    print("Initializing RKLLM Runtime...")
    rk_llm = RKLLMRuntime()
    param = rk_llm.create_default_param()

    param.model_path = args.llm_model_path.encode('utf-8')
    param.top_k = 1  # greedy decoding
    param.max_new_tokens = args.max_new_tokens
    param.max_context_len = args.max_context_len
    param.skip_special_token = True
    # Markers the runtime uses to splice the image embedding into the prompt.
    param.img_start = b"<img>"
    param.img_end = b"</img>\n"
    param.img_content = b""
    param.extend_param.base_domain_id = 1

    try:
        rk_llm.init(param, llm_callback)
        print("RKLLM initialized successfully.")
    except RuntimeError as e:
        print(f"RKLLM init failed: {e}")
        sys.exit(1)

    # --- 3. Image Preprocessing ---
    print("Preprocessing image...")
    img = cv2.imread(args.image_path)
    if img is None:
        print(f"Failed to read image from {args.image_path}")
        sys.exit(1)

    # OpenCV loads BGR; the normalization constants are RGB-ordered.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # NOTE(review): np.full with dtype=uint8 truncates 127.5 to 127 —
    # presumably intended; confirm against the official preprocessing.
    background_color = (127.5, 127.5, 127.5) # Keep close to official preprocessing
    square_img = expand2square(img, background_color)
    resized_img = cv2.resize(square_img, (IMAGE_WIDTH, IMAGE_HEIGHT), interpolation=cv2.INTER_LINEAR)

    # Normalize and prepare for ONNX model
    input_tensor = resized_img.astype(np.float32)
    # Normalize using InternVL vision config statistics
    input_tensor = (input_tensor / 255.0 - IMAGENET_MEAN) / IMAGENET_STD
    # Convert to NCHW format
    input_tensor = np.transpose(input_tensor, (2, 0, 1))  # HWC -> CHW
    input_tensor = np.expand_dims(input_tensor, axis=0)  # Add batch dimension -> (1, 3, 448, 448)

    # --- 4. Run Image Encoder ---
    print("Running vision encoder...")
    import time
    start_time = time.time()
    try:
        # Expected output layout: (batch, tokens, hidden); validated below.
        img_vec_output = ort_session.run([output_name], {input_name: input_tensor.astype(np.float32)})[0]
        if img_vec_output.ndim != 3:
            raise RuntimeError(f"Unexpected encoder output shape {img_vec_output.shape}, expected (batch, tokens, hidden)")
        if img_vec_output.shape[-1] != MULTIMODAL_HIDDEN_DIM:
            print(f"Warning: hidden dim {img_vec_output.shape[-1]} differs from expected {MULTIMODAL_HIDDEN_DIM}")
        if img_vec_output.shape[1] != IMAGE_SEQ_LENGTH:
            print(f"Warning: token count {img_vec_output.shape[1]} differs from expected {IMAGE_SEQ_LENGTH}")
        elapsed_time = time.time() - start_time
        print(f"视觉编码器推理耗时: {elapsed_time:.4f} 秒")
        # The output from C++ is a flat float array. Let's flatten the ONNX output.
        # img_vec must outlive the chat loop below: the ctypes pointer passed
        # to the runtime points into this array's buffer.
        img_vec = img_vec_output.flatten().astype(np.float32)

    except Exception as e:
        print(f"Failed to run vision encoder inference: {e}")
        rk_llm.destroy()
        sys.exit(1)

    print("Image encoded successfully.")

    # --- 5. Interactive Chat Loop ---
    rkllm_infer_params = RKLLMInferParam()
    rkllm_infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
    rkllm_infer_params.keep_history = 1  # retain conversation context across turns

    # Set chat template
    # Looks the default template parsed by RKLLM gives better result than this one, don't know why.

    # rk_llm.set_chat_template(
    #     system_prompt="<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n",
    #     prompt_prefix="<|im_start|>user\n",
    #     prompt_postfix="<|im_end|>\n<|im_start|>assistant\n"
    # )

    # Canned prompts selectable by index at the prompt.
    pre_input = [
        "<image>What is in the image?",
        "<image>这张图片中有什么?"
    ]
    print("\n**********************可输入以下问题对应序号获取回答/或自定义输入********************\n")
    for i, p in enumerate(pre_input):
        print(f"[{i}] {p}")
    print("\n*************************************************************************\n")

    try:
        while True:
            print("\nuser: ", end="", flush=True)
            input_str = sys.stdin.readline().strip()

            if not input_str:
                continue
            if input_str == "exit":
                break
            if input_str == "clear":
                # Reset conversation state but keep the system prompt cached.
                try:
                    rk_llm.clear_kv_cache(keep_system_prompt=True)
                    print("KV cache cleared.")
                except RuntimeError as e:
                    print(f"Failed to clear KV cache: {e}")
                continue

            # A bare integer selects a canned prompt; anything else is used verbatim.
            try:
                idx = int(input_str)
                if 0 <= idx < len(pre_input):
                    input_str = pre_input[idx]
                    print(input_str)
            except (ValueError, IndexError):
                pass # Use the raw string if not a valid index
                # NOTE(review): IndexError looks unreachable here — the bounds
                # are checked before indexing; only ValueError can fire.

            rkllm_input = RKLLMInput()
            rkllm_input.role = b"user"

            print("robot: ", end="", flush=True)

            if "<image>" in input_str:
                rkllm_input.input_type = RKLLMInputType.RKLLM_INPUT_MULTIMODAL

                # Setup multimodal input
                rkllm_input.multimodal_input.prompt = input_str.encode('utf-8')
                # Raw pointer into img_vec's buffer — img_vec stays alive for
                # the whole loop, so the pointer remains valid.
                rkllm_input.multimodal_input.image_embed = img_vec.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
                rkllm_input.multimodal_input.n_image_tokens = img_vec_output.shape[1]
                print("n_image_tokens: ", rkllm_input.multimodal_input.n_image_tokens)
                rkllm_input.multimodal_input.n_image = 1
                rkllm_input.multimodal_input.image_height = IMAGE_HEIGHT
                rkllm_input.multimodal_input.image_width = IMAGE_WIDTH
            else:
                rkllm_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT
                rkllm_input.prompt_input = input_str.encode('utf-8')

            try:
                # Blocking call; llm_callback streams tokens to stdout.
                rk_llm.run(rkllm_input, rkllm_infer_params)
            except RuntimeError as e:
                print(f"\nError during rkllm_run: {e}")

    except KeyboardInterrupt:
        print("\nExiting...")
    finally:
        print("Releasing resources...")
        rk_llm.destroy()
        print("RKLLM instance destroyed.")

# Entry-point guard: importing this module performs no inference.
if __name__ == "__main__":
    main()