File size: 12,585 Bytes
f5d111e
 
 
 
 
 
 
f4851b2
f5d111e
 
 
 
 
 
 
 
 
 
 
 
 
f4851b2
 
 
f5d111e
f4851b2
 
 
f5d111e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
import faulthandler
faulthandler.enable()

import os 
os.environ["RKLLM_LOG_LEVEL"] = "1"
import numpy as np
import onnxruntime as real_ort
import ztu_somemodelruntime_rknnlite2 as ort
from tokenizers import Tokenizer
import cv2
import tqdm
import time
import ctypes

from rkllm_binding import *

# --- Run configuration -------------------------------------------------------
# Directory holding tokenizer.json and every model file used by this script.
model_path = "."
onnx_model_path = str(model_path)
tokenizer = Tokenizer.from_file("{}/tokenizer.json".format(model_path))
# np.random.seed(0)

# Default run: text-to-image.
image = None
prompt = "A stunning princess from kabul in red, white traditional clothing, blue eyes, brown hair"
# "t2i" = generate an image from text; the opposite direction is "it2t"
# (image and/or text in, text out).
mode = "t2i"

# Alternative run: describe an image (uncomment all three lines together).
# image = "./test.jpg"
# prompt = "仔细描述这张图片。"   (i.e. "Describe this image in detail.")
# mode = "it2t"

# Sampling temperature; the logits are divided by this value before softmax.
tempature = 0.7

# Shared slot that the RKLLM callback fills in and the inference wrapper
# polls: the captured hidden states plus completion / error flags.
rkllm_result_data = dict(
    hidden_states=None,
    finished=False,
    error=False,
)

def rkllm_callback(result_ptr, userdata_ptr, state_enum):
    """Receive one RKLLM runtime event and record it in ``rkllm_result_data``.

    Args:
        result_ptr: ctypes pointer to the runtime's result struct; may be null.
        userdata_ptr: Opaque user-data pointer; not used here.
        state_enum: Raw call-state value, converted to ``LLMCallState``.

    Returns:
        ``None`` when the run finished, when ``result_ptr`` is null, or when
        an exception was caught; ``1`` in every other case.
    """
    global rkllm_result_data

    try:
        call_state = LLMCallState(state_enum)

        # Completion event: flag it and stop, there is no payload to read.
        if call_state == LLMCallState.RKLLM_RUN_FINISH:
            rkllm_result_data['finished'] = True
            print("RKLLM 推理完成")
            return

        # Error event: record it, then continue with the checks below
        # (there is deliberately no early return here).
        if call_state == LLMCallState.RKLLM_RUN_ERROR:
            rkllm_result_data.update(
                error=True,
                error_msg="RKLLM 推理出错",
                finished=True,
            )
            print("错误: RKLLM 推理出错")

        # A null pointer cannot be dereferenced.
        if not result_ptr:
            print("警告: result_ptr 为空指针")
            return

        result = result_ptr.contents

        # Only a normal-state event carries hidden states.
        if call_state != LLMCallState.RKLLM_RUN_NORMAL:
            return 1

        layer = result.last_hidden_layer
        if not (layer.hidden_states and layer.embd_size > 0):
            print("警告: 没有获取到有效的 hidden states")
            return 1

        n_tokens = layer.num_tokens
        embd_size = layer.embd_size

        # View the C buffer as (n_tokens, embd_size), then copy it so the
        # data no longer depends on memory owned by the runtime.
        borrowed = np.ctypeslib.as_array(
            layer.hidden_states,
            shape=(n_tokens, embd_size),
        )
        rkllm_result_data['hidden_states'] = borrowed.copy()
        rkllm_result_data['finished'] = True
        return 1
    except Exception as e:
        # Never let an exception escape into the native caller; report it
        # through the shared result dict instead.
        print(f"回调函数异常: {e}")
        rkllm_result_data.update(
            error=True,
            error_msg=str(e),
            finished=True,
        )

# 1. Load the models (creation order is kept as in the original script)

# Vision encoder
# <- pixel_values: float32[batch_size,num_images,3,384,384]
# -> inputs_embeds: float32[batch_size*num_images,576,2048]
vision_encoder = ort.InferenceSession("{}/vision_encoder.rknn".format(onnx_model_path))

# RKLLM language model
print("初始化 RKLLM 语言模型...")
rkllm_runtime = RKLLMRuntime()
rkllm_params = rkllm_runtime.create_default_param()
rkllm_params.model_path = "{}/language_model.rkllm".format(model_path).encode('utf-8')
rkllm_params.max_context_len = 1024
rkllm_params.max_new_tokens = 5
# rkllm_params.temperature = tempature
rkllm_params.skip_special_token = 0
rkllm_params.extend_param.base_domain_id = 1
rkllm_runtime.init(rkllm_params, rkllm_callback)

# Text LM head
# <- hidden_states: float32[batch_size,sequence_length,2048]
# -> logits: float32[batch_size,sequence_length,102400]
lm_head = ort.InferenceSession("{}/lm_head.onnx".format(onnx_model_path))

# Image-generation head
# <- hidden_states: float32[batch_size,sequence_length,2048]
# -> logits: float32[batch_size,sequence_length,16384]
gen_head = ort.InferenceSession("{}/gen_head.onnx".format(onnx_model_path))

# Image-token embedding
# <- image_ids: int64[batch_size,sequence_length]
# -> inputs_embeds: float32[batch_size,sequence_length,2048]
gen_img_embeds = ort.InferenceSession("{}/gen_img_embeds.onnx".format(onnx_model_path))

# Text-token embedding (this one is loaded through the stock onnxruntime)
# <- input_ids: int64[batch_size,sequence_length]
# -> inputs_embeds: float32[batch_size,sequence_length,2048]
text_embeds = real_ort.InferenceSession("{}/embed_tokens.onnx".format(onnx_model_path))

# VQVAE decoder (turns 576 tokens into one 384x384 image)
# <- generated_tokens: int64[batch_size,sequence_length]
# -> decoded_image: float32[batch_size,3,384,384]
image_decode = ort.InferenceSession("{}/image_decode.onnx".format(onnx_model_path))

# 2. Build the prompt
# The tokenizer prepends <|begin▁of▁sentence|> on its own, so it must NOT be
# added here.
if mode == "t2i":
    input_str = f"""<|User|>: {prompt}

<|Assistant|>:<begin_of_image>"""
else:
    input_str = f"""You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.

<|User|>: <image_placeholder>
{prompt}

<|Assistant|>:"""

# 3. Build the input embeddings

# Expand the single <image_placeholder> into 576 of them, one per image patch.
input_str = input_str.replace("<image_placeholder>", "<image_placeholder>" * 576)
# Named `encoding` rather than `input` so the builtin is not shadowed.
encoding = tokenizer.encode(input_str)
input_ids = np.array([encoding.ids], dtype=np.int64)
input_len = len(encoding.ids)
attention_mask = np.array([encoding.attention_mask], dtype=np.int64)
# Boolean mask of shape [1, input_len], True where the token is the image
# placeholder (id 100581). Open question left by the original author: why
# does <image_placeholder> have two ids?
images_seq_mask = (input_ids == 100581)
position_ids = np.expand_dims(np.arange(input_len), axis=0)

# Image preprocessing
if image:
    img = cv2.imread(image)
    if img is None:
        raise ValueError(f"无法读取图片: {image}")
    # OpenCV loads BGR; convert to RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # Resize straight to the 384x384 target size.
    target_size = 384
    img = cv2.resize(img, (target_size, target_size), interpolation=cv2.INTER_LINEAR)
    # Cast to float32 and rescale pixel values to [0, 1] (factor is 1/255).
    img = img.astype(np.float32) * 0.00392156862745098
    # Normalise as (img - image_mean) / image_std with mean = std = 0.5 per
    # channel (values taken from the model's preprocessing config, per the
    # original author).
    img = (img - np.array([0.5, 0.5, 0.5], dtype=np.float32)) / np.array([0.5, 0.5, 0.5], dtype=np.float32)
    # A non-square image could instead be padded with a background colour;
    # that step is skipped because the image was resized to a square above.
    # HWC -> CHW, then add batch and num_images axes.
    img = img.transpose(2, 0, 1)  # [3, 384, 384]
    pixel_values = np.expand_dims(np.expand_dims(img, axis=0), axis=1)  # [1, 1, 3, 384, 384]
    images_emb_mask = np.ones((1, 1, 576), dtype=np.bool_)
else:
    pixel_values = np.zeros((0, 0, 3, 384, 384), dtype=np.float32)
    images_emb_mask = np.zeros((1, 0, 576), dtype=np.bool_)

# Assemble the input embeddings by hand.
# Step 1: text embeddings for every token.
text_inputs_embeds = text_embeds.run(None, {"input_ids": input_ids})[0]  # [1, input_len, 2048]

# Step 2: when an image is given, overwrite the placeholder positions with
# the vision encoder's embeddings.
if image:
    vision_embeds = vision_encoder.run(None, {"pixel_values": pixel_values})[0]  # [1, 576, 2048]

    # Indices of all <image_placeholder> tokens in the sequence.
    image_token_positions = np.where(images_seq_mask[0])[0]

    # Pair the k-th placeholder with the k-th vision embedding. If the two
    # counts differ, only the common prefix is filled (same as a per-token
    # loop with a bounds check), done here as one vectorised assignment.
    n_fill = min(len(image_token_positions), vision_embeds.shape[1])
    text_inputs_embeds[0, image_token_positions[:n_fill], :] = vision_embeds[0, :n_fill, :]

inputs_embeds = text_inputs_embeds

# 4. Language-model inference (through RKLLM)
# Collects the generated tokens (per the original author, only tokens from
# the conditional branch are kept here).
generated_tokens = list()

# Objects created once and reused by every inference call.
infer_params = RKLLMInferParam()
infer_params.keep_history = 1
infer_params.mode = RKLLMInferMode.RKLLM_INFER_GET_LAST_HIDDEN_LAYER

embed_input = RKLLMEmbedInput()

rkllm_input = RKLLMInput()
rkllm_input.input_type = RKLLMInputType.RKLLM_INPUT_EMBED

def run_rkllm_inference(inputs_embeds):
    """Run one RKLLM pass on input embeddings and return the hidden states.

    Args:
        inputs_embeds: NumPy array whose axis 1 is the token count (the
            callers in this file pass [1, seq_len, hidden]). It is flattened
            and converted to float32 before being handed to the runtime.

    Returns:
        Whatever the callback stored under ``'hidden_states'``: a
        ``(num_tokens, embd_size)`` array, or ``None`` if none was captured.

    Raises:
        RuntimeError: If the callback flagged an error. The message includes
            the callback's ``error_msg`` when it adds detail.
    """
    global rkllm_result_data

    # Start each call from a clean result slot.
    rkllm_result_data = {
        'hidden_states': None,
        'finished': False,
        'error': False
    }

    # flatten() always returns a fresh, writable, C-contiguous copy, which is
    # what as_ctypes() requires.
    embed_flat = inputs_embeds.flatten().astype(np.float32)
    # Wrap the NumPy buffer as a ctypes float array without copying. This
    # avoids boxing every element into a Python float, as
    # (c_float * n)(*values) would. The ctypes array keeps a reference to
    # embed_flat, and the local names keep both alive until this returns.
    embed_c_array = np.ctypeslib.as_ctypes(embed_flat)
    embed_input.embed = embed_c_array
    embed_input.n_tokens = inputs_embeds.shape[1]  # sequence length

    rkllm_input._union_data.embed_input = embed_input

    rkllm_runtime.run(rkllm_input, infer_params)

    # Poll until the callback marks the run finished.
    # NOTE(review): there is no timeout; if the callback never sets
    # 'finished' this loop spins forever.
    while not rkllm_result_data['finished']:
        time.sleep(0.001)

    if rkllm_result_data['error']:
        message = "RKLLM 推理出错"
        # Surface the detail recorded by the callback (e.g. the text of an
        # exception raised inside it) instead of discarding it.
        detail = rkllm_result_data.get('error_msg')
        if detail and detail != message:
            message = f"{message}: {detail}"
        raise RuntimeError(message)

    return rkllm_result_data['hidden_states']

# Autoregressive generation of up to 576 tokens (one image is 576 tokens),
# followed by decoding. Everything runs inside try/finally so the RKLLM
# runtime is destroyed even if inference or decoding raises.
try:
    # The output head does not change during the run: image-token head for
    # "t2i", text LM head otherwise.
    head = gen_head if mode == "t2i" else lm_head

    with tqdm.tqdm(range(576)) as pbar:
        for _ in pbar:
            # One forward pass through RKLLM.
            hidden_states = run_rkllm_inference(inputs_embeds)

            if hidden_states is None:
                raise RuntimeError("RKLLM 未返回有效的 hidden states")

            # Bring [num_tokens, hidden] into [batch, seq, hidden] form.
            if hidden_states.ndim == 2:
                hidden_states = hidden_states[np.newaxis, ...]

            # Keep only the last token's hidden state: [1, 1, hidden].
            hs = hidden_states[:, -1:, :]

            # Project to logits for the current step: [1, vocab_size].
            logits = head.run(None, {"hidden_states": hs})[0]
            logits = logits[:, -1, :]

            # Temperature sampling (the original author notes that greedy
            # decoding must not be used here).
            logits = logits / tempature
            # Numerically stable softmax over the vocabulary.
            exp_logits = np.exp(logits - np.max(logits, axis=-1, keepdims=True))[0]
            probs = exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
            # Renormalise in float64 before handing the vector to multinomial.
            probs = probs.astype(np.float64)
            probs /= probs.sum()
            next_token = int(np.random.multinomial(1, probs).argmax())
            pbar.set_postfix(next_token=tokenizer.decode([next_token]))
            generated_tokens.append(next_token)
            if next_token == 100001:  # eos
                break

            # Embed the sampled token with the table matching the mode.
            if mode == "t2i":
                new_embed = gen_img_embeds.run(None, {"image_ids": np.array([[next_token]], dtype=np.int64)})[0]
            else:
                new_embed = text_embeds.run(None, {"input_ids": np.array([[next_token]], dtype=np.int64)})[0]

            # Only the new token's embedding is fed to the next step; it is
            # not concatenated onto the previous input. infer_params sets
            # keep_history = 1 elsewhere in this file, so presumably the
            # runtime retains the earlier context -- TODO confirm.
            inputs_embeds = new_embed

    rkllm_runtime.clear_kv_cache(False)

    # 5. Decode the result into an image or text
    if mode == "t2i":
        # Feed all generated image tokens to the VQVAE decoder as [1, n].
        generated_tokens_array = np.array([generated_tokens], dtype=np.int64)
        decoded_image = image_decode.run(None, {"generated_tokens": generated_tokens_array})[0]  # [1, 3, 384, 384]
        # Map the decoder output from [-1, 1] to [0, 255] and clamp.
        decoded_image = np.clip((decoded_image + 1) / 2 * 255, 0, 255)
        # CHW -> HWC, then RGB -> BGR for OpenCV, and save as PNG.
        decoded_image = np.squeeze(decoded_image, axis=0)  # [3, 384, 384]
        decoded_image = np.transpose(decoded_image, (1, 2, 0))  # [384, 384, 3]
        cv2.imwrite("generated.png", cv2.cvtColor(decoded_image, cv2.COLOR_RGB2BGR))
        print("(generated.png)")
    else:
        decoded_text = tokenizer.decode(generated_tokens)
        print(f"{decoded_text}")
finally:
    # Always release the RKLLM runtime, on success and on failure.
    print("清理 RKLLM 资源...")
    rkllm_runtime.destroy()