# Janus-Pro-1B-RKLLM / run_rkllm.py
# Uploaded by happyme531 (commit f4851b2, "Upload 3 files", verified).
import faulthandler
faulthandler.enable()
import os
os.environ["RKLLM_LOG_LEVEL"] = "1"
import numpy as np
import onnxruntime as real_ort
import ztu_somemodelruntime_rknnlite2 as ort
from tokenizers import Tokenizer
import cv2
import tqdm
import time
import ctypes
from rkllm_binding import *
# ---- Paths ---------------------------------------------------------------
# Every artifact (tokenizer, .onnx/.rknn/.rkllm models) is read from one directory.
model_path = "."
onnx_model_path = model_path
tokenizer = Tokenizer.from_file(f"{model_path}/tokenizer.json")

# ---- Run configuration -----------------------------------------------------
# Uncomment for reproducible sampling:
# np.random.seed(0)

# "t2i" = text -> image; "it2t" = image/text -> text (alternative settings below).
mode = "t2i"
prompt = "A stunning princess from kabul in red, white traditional clothing, blue eyes, brown hair"
image = None
# Image-understanding example (the prompt means "Describe this image in detail."):
# image = "./test.jpg"
# prompt = "仔细描述这张图片。"
# mode = "it2t"

# Sampling temperature. The spelling of the name is kept because later code reads it.
tempature = 0.7

# Shared state: written by rkllm_callback, reset and polled by run_rkllm_inference.
rkllm_result_data = dict(hidden_states=None, finished=False, error=False)
def rkllm_callback(result_ptr, userdata_ptr, state_enum):
    """RKLLM result callback: capture the last hidden layer and signal completion.

    Registered with the runtime through ``rkllm_runtime.init``. It talks to
    ``run_rkllm_inference`` only through the module-level ``rkllm_result_data``
    dict. That dict is looked up at call time, so rebinding it between runs is
    safe.

    Args:
        result_ptr: ctypes pointer to the runtime's result struct; may be NULL.
        userdata_ptr: opaque user-data pointer (unused).
        state_enum: raw integer value convertible to ``LLMCallState``.

    Returns:
        Always an int. ``1`` is returned once a NORMAL result has been handled,
        as the original code did. Other non-exception paths return ``0``.
        NOTE(review): this assumes the callback's return follows the rkllm.h
        ``LLMResultCallback`` contract, where 0 = continue and 1 = pause.
        Confirm against the binding's callback prototype.
    """
    global rkllm_result_data
    try:
        state = LLMCallState(state_enum)

        if state == LLMCallState.RKLLM_RUN_FINISH:
            rkllm_result_data['finished'] = True
            print("RKLLM 推理完成")
            return 0

        if state == LLMCallState.RKLLM_RUN_ERROR:
            rkllm_result_data['error'] = True
            rkllm_result_data['error_msg'] = "RKLLM 推理出错"
            rkllm_result_data['finished'] = True
            print("错误: RKLLM 推理出错")
            # Nothing meaningful can be read from the result after an error.
            return 0

        if not result_ptr:
            print("警告: result_ptr 为空指针")
            return 0

        if state != LLMCallState.RKLLM_RUN_NORMAL:
            return 0

        last_layer = result_ptr.contents.last_hidden_layer
        if not last_layer.hidden_states or last_layer.embd_size <= 0:
            print("警告: 没有获取到有效的 hidden states")
            # This path returns 1 just like the success path. If no further
            # callback arrives, the waiter would otherwise spin forever with
            # 'finished' still False. Wake it; it then sees hidden_states=None
            # and raises.
            rkllm_result_data['finished'] = True
            return 1

        num_tokens = last_layer.num_tokens
        hidden_size = last_layer.embd_size
        # Copy out of the runtime-owned C buffer so the data stays valid after
        # this callback returns.
        rkllm_result_data['hidden_states'] = np.ctypeslib.as_array(
            last_layer.hidden_states, shape=(num_tokens, hidden_size)
        ).copy()
        # Publish the data before the flag so the poller never observes
        # finished=True while hidden_states is still unset.
        rkllm_result_data['finished'] = True
        return 1
    except Exception as e:
        print(f"回调函数异常: {e}")
        rkllm_result_data['error'] = True
        rkllm_result_data['error_msg'] = str(e)
        rkllm_result_data['finished'] = True
        # The waiter will raise. Ask the runtime to stop (see NOTE above).
        return 1
# 1. Load the models.
# Vision encoder: an .rknn model run through the ztu_somemodelruntime_rknnlite2 wrapper (`ort`).
# <- pixel_values: float32[batch_size,num_images,3,384,384]
# -> inputs_embeds: float32[batch_size*num_images,576,2048]
vision_encoder = ort.InferenceSession(f"{onnx_model_path}/vision_encoder.rknn")
# Initialize the RKLLM language model.
print("初始化 RKLLM 语言模型...")
rkllm_runtime = RKLLMRuntime()
rkllm_params = rkllm_runtime.create_default_param()
rkllm_params.model_path = f"{model_path}/language_model.rkllm".encode('utf-8')
# Presumably the total context budget (prompt tokens + up to 576 generated steps) — confirm.
rkllm_params.max_context_len = 1024
# NOTE(review): the runtime is only asked for hidden states (see infer_params below),
# and tokens are sampled in Python. So this small value presumably does not cap
# the 576-step generation loop — confirm.
rkllm_params.max_new_tokens = 5
# Temperature is applied in Python to the head's logits, not inside RKLLM.
# rkllm_params.temperature = tempature
rkllm_params.skip_special_token = 0
# NOTE(review): the meaning of base_domain_id = 1 is not visible in this file — confirm against the RKLLM docs.
rkllm_params.extend_param.base_domain_id = 1
# Register rkllm_callback as the receiver of inference results.
rkllm_runtime.init(rkllm_params, rkllm_callback)
# LM head: text-vocabulary logits (selected when mode != "t2i").
# <- hidden_states: float32[batch_size,sequence_length,2048]
# -> logits: float32[batch_size,sequence_length,102400]
lm_head = ort.InferenceSession(f"{onnx_model_path}/lm_head.onnx")
# Image-generation head: image-codebook logits (selected when mode == "t2i").
# <- hidden_states: float32[batch_size,sequence_length,2048]
# -> logits: float32[batch_size,sequence_length,16384]
gen_head = ort.InferenceSession(f"{onnx_model_path}/gen_head.onnx")
# Image-generation embedding: maps image codebook ids back to LM input embeddings.
# <- image_ids: int64[batch_size,sequence_length]
# -> inputs_embeds: float32[batch_size,sequence_length,2048]
gen_img_embeds = ort.InferenceSession(f"{onnx_model_path}/gen_img_embeds.onnx")
# Text embedding. NOTE(review): unlike the other models, this one runs on the real
# onnxruntime instead of the RKNN wrapper. The reason is not visible here — confirm.
# <- input_ids: int64[batch_size,sequence_length]
# -> inputs_embeds: float32[batch_size,sequence_length,2048]
text_embeds = real_ort.InferenceSession(f"{onnx_model_path}/embed_tokens.onnx")
# VQ-VAE decoder: turns 576 image tokens into one 384x384 image.
# <- generated_tokens: int64[batch_size,sequence_length]
# -> decoded_image: float32[batch_size,3,384,384]
image_decode = ort.InferenceSession(f"{onnx_model_path}/image_decode.onnx")
# 2. Build the prompt.
# The tokenizer already prepends <|begin▁of▁sentence|>, so it must NOT be added here.
if mode == "t2i":
    input_str = (
        f"<|User|>: {prompt}\n"
        "<|Assistant|>:<begin_of_image>"
    )
else:
    input_str = (
        "You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n"
        "<|User|>: <image_placeholder>\n"
        f"{prompt}\n"
        "<|Assistant|>:"
    )

# 3. Tokenize (embeddings are produced further below).
# Expand the single <image_placeholder> into 576 copies, one per vision-encoder
# output token. This is a no-op for the t2i prompt, which has no placeholder.
input_str = input_str.replace("<image_placeholder>", "<image_placeholder>" * 576)
# Named `encoding` (not `input`) so the builtin input() is not shadowed.
encoding = tokenizer.encode(input_str)
input_ids = np.array([encoding.ids], dtype=np.int64)
input_len = len(encoding.ids)
attention_mask = np.array([encoding.attention_mask], dtype=np.int64)
# Boolean [1, input_len] mask, True at each <image_placeholder> token.
# NOTE(review): the original author observed that <image_placeholder> seems to
# have two ids; only 100581 is matched here — confirm it is the one encode() emits.
images_seq_mask = input_ids == 100581
position_ids = np.expand_dims(np.arange(input_len), axis=0)
# Image preprocessing: build the vision encoder's pixel_values tensor,
# float32[batch_size, num_images, 3, 384, 384].
if not image:
    # No image: empty tensors describing zero images.
    pixel_values = np.zeros((0, 0, 3, 384, 384), dtype=np.float32)
    images_emb_mask = np.zeros((1, 0, 576), dtype=np.bool_)
else:
    bgr = cv2.imread(image)
    if bgr is None:
        raise ValueError(f"无法读取图片: {image}")
    # OpenCV decodes to BGR; convert to RGB, then stretch to a 384x384 square.
    # Because the result is already square, the aspect-preserving padding with a
    # background colour is not needed.
    side = 384
    rgb = cv2.resize(
        cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB),
        (side, side),
        interpolation=cv2.INTER_LINEAR,
    )
    # Rescale to [0, 1] (0.00392156862745098 == 1/255), then normalise with a
    # per-channel mean and std of 0.5, which maps pixels to [-1, 1].
    half = np.full(3, 0.5, dtype=np.float32)
    normalized = (rgb.astype(np.float32) * 0.00392156862745098 - half) / half
    # HWC -> CHW, then prepend the batch and num_images axes: [1, 1, 3, 384, 384].
    pixel_values = np.transpose(normalized, (2, 0, 1))[np.newaxis, np.newaxis]
    images_emb_mask = np.ones((1, 1, 576), dtype=np.bool_)
# Build the LM input embeddings by hand.
# Step 1: embed the whole token sequence as text.
text_inputs_embeds = text_embeds.run(None, {"input_ids": input_ids})[0]  # [1, input_len, 2048]
# Step 2: if an image was given, overwrite the <image_placeholder> positions, in
# order, with the vision encoder's output tokens.
if image:
    vision_embeds = vision_encoder.run(None, {"pixel_values": pixel_values})[0]  # [1, 576, 2048]
    # Indices where images_seq_mask is True (unique and ascending).
    image_token_positions = np.where(images_seq_mask[0])[0]
    n_positions = len(image_token_positions)
    n_vision = vision_embeds.shape[1]
    n_fill = min(n_positions, n_vision)
    if n_positions != n_vision:
        # Either some vision tokens are dropped or some placeholders keep their
        # text embedding. Both degrade the answer, so say so rather than stay silent.
        print(f"警告: 图片占位符数量 ({n_positions}) 与视觉 embedding 数量 ({n_vision}) 不一致, 仅替换前 {n_fill} 个")
    # Single vectorized copy; equivalent to the per-token loop because the
    # positions are unique.
    text_inputs_embeds[0, image_token_positions[:n_fill], :] = vision_embeds[0, :n_fill, :]
inputs_embeds = text_inputs_embeds
# 4. Language-model inference (via RKLLM).
# Reusable ctypes request objects, shared with run_rkllm_inference().

# Inference parameters: request the last hidden layer instead of sampled tokens.
# keep_history = 1 presumably retains the KV cache across calls; the loop below
# feeds only the newest embedding each step.
infer_params = RKLLMInferParam()
infer_params.mode = RKLLMInferMode.RKLLM_INFER_GET_LAST_HIDDEN_LAYER
infer_params.keep_history = 1

# Input descriptor: the model consumes embeddings rather than token ids.
# The embedding buffer itself is attached on every call.
embed_input = RKLLMEmbedInput()
rkllm_input = RKLLMInput()
rkllm_input.input_type = RKLLMInputType.RKLLM_INPUT_EMBED

# Sampled token ids. Only tokens from the conditional branch are kept here.
generated_tokens = []
def run_rkllm_inference(inputs_embeds):
    """Run one RKLLM step on a block of input embeddings and return its hidden states.

    The embeddings reach the runtime through the module-level ``embed_input`` /
    ``rkllm_input`` / ``infer_params`` objects. The result arrives via
    ``rkllm_callback``, which fills the module-level ``rkllm_result_data`` dict.
    This function resets that dict and then polls it.

    Args:
        inputs_embeds: numpy array whose axis 1 is the sequence length (the
            script passes ``[1, seq_len, 2048]``). It is flattened to float32
            before being handed to the runtime.

    Returns:
        The hidden-state array stored by the callback, or ``None`` if the
        callback finished without producing one.

    Raises:
        RuntimeError: if the callback reported an error.
    """
    global rkllm_result_data
    # Fresh result slot for this run. The callback reads the global at call time.
    rkllm_result_data = {
        'hidden_states': None,
        'finished': False,
        'error': False
    }

    # flatten() always returns a new C-contiguous (writable) copy.
    # astype(copy=False) skips a second copy when the data is already float32.
    embed_flat = inputs_embeds.flatten().astype(np.float32, copy=False)
    # Wrap the numpy buffer directly instead of unpacking every element as a
    # Python argument (the prompt alone is over 1M floats). from_buffer keeps a
    # reference to embed_flat, and embed_input keeps the ctypes array, so the
    # memory stays alive until the next call replaces it.
    embed_c_array = (ctypes.c_float * embed_flat.size).from_buffer(embed_flat)
    embed_input.embed = embed_c_array
    embed_input.n_tokens = inputs_embeds.shape[1]  # sequence length
    rkllm_input._union_data.embed_input = embed_input

    rkllm_runtime.run(rkllm_input, infer_params)

    # The callback sets 'finished' on success, on error and on exceptions.
    while not rkllm_result_data['finished']:
        time.sleep(0.001)

    if rkllm_result_data['error']:
        message = "RKLLM 推理出错"
        detail = rkllm_result_data.get('error_msg')
        if detail and detail != message:
            message = f"{message}: {detail}"
        raise RuntimeError(message)

    return rkllm_result_data['hidden_states']
def _sample_token(logits, temperature):
    """Draw one token id from ``logits`` ([1, vocab_size]) by temperature sampling.

    Greedy decoding is deliberately not used. Randomness comes from the global
    NumPy RNG (``np.random.multinomial``), so ``np.random.seed`` makes runs
    reproducible.
    """
    scaled = logits / temperature
    # Numerically stable softmax over the vocabulary of batch row 0.
    exp_logits = np.exp(scaled - np.max(scaled, axis=-1, keepdims=True))[0]
    probs = exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
    # Renormalise in float64 so rounding in lower precision cannot push the
    # partial sums past 1, which multinomial rejects.
    probs = probs.astype(np.float64)
    probs /= probs.sum()
    return int(np.random.multinomial(1, probs).argmax())


# Autoregressive loop, up to 576 steps (576 = the token count of one image).
head = gen_head if mode == "t2i" else lm_head
with tqdm.tqdm(range(576)) as pbar:
    for _ in pbar:
        hidden_states = run_rkllm_inference(inputs_embeds)
        if hidden_states is None:
            raise RuntimeError("RKLLM 未返回有效的 hidden states")
        # The callback delivers [num_tokens, hidden_size]; add the batch axis.
        if hidden_states.ndim == 2:
            hidden_states = hidden_states[np.newaxis]
        hs = hidden_states[:, -1:, :]  # last position only: [1, 1, 2048]

        logits = head.run(None, {"hidden_states": hs})[0]
        logits = logits[:, -1, :]  # [1, vocab_size]
        next_token = _sample_token(logits, tempature)

        if mode == "t2i":
            # Image codebook ids are not text tokens, so decoding them with the
            # text tokenizer would show garbage. Show the raw id instead.
            pbar.set_postfix(next_token=next_token)
        else:
            pbar.set_postfix(next_token=tokenizer.decode([next_token]))
        generated_tokens.append(next_token)
        # Text EOS. Unreachable in t2i, whose head has a 16384-entry vocabulary.
        if next_token == 100001:
            break

        # Feed back only the new token's embedding; earlier context is presumably
        # kept in the runtime's KV cache (infer_params.keep_history = 1).
        if mode == "t2i":
            inputs_embeds = gen_img_embeds.run(None, {"image_ids": np.array([[next_token]], dtype=np.int64)})[0]
        else:
            inputs_embeds = text_embeds.run(None, {"input_ids": np.array([[next_token]], dtype=np.int64)})[0]

# Drop the accumulated KV cache once generation is done
# (the False argument's meaning is not visible in this file — confirm).
rkllm_runtime.clear_kv_cache(False)
# 5. Decode the output: image tokens -> PNG via the VQ-VAE decoder, or text tokens -> string.
if mode == "t2i":
    generated_tokens_array = np.array([generated_tokens], dtype=np.int64)  # [1, 576]
    decoded_image = image_decode.run(None, {"generated_tokens": generated_tokens_array})[0]  # [1, 3, 384, 384]
    # Map the decoder output from [-1, 1] to [0, 255] and clamp.
    decoded_image = np.clip((decoded_image + 1) / 2 * 255, 0, 255)
    # Drop the batch axis and go CHW -> HWC: [384, 384, 3].
    decoded_image = np.transpose(np.squeeze(decoded_image, axis=0), (1, 2, 0))
    # Convert to 8-bit explicitly. cv2.cvtColor rejects float64/float16 input,
    # and PNG stores integers. Round-to-nearest matches the conversion
    # cv2.imwrite would otherwise apply implicitly to float32 data.
    decoded_image = np.rint(decoded_image).astype(np.uint8)
    if cv2.imwrite("generated.png", cv2.cvtColor(decoded_image, cv2.COLOR_RGB2BGR)):
        print("(generated.png)")
    else:
        print("错误: 无法保存 generated.png")
else:
    decoded_text = tokenizer.decode(generated_tokens)
    print(f"{decoded_text}")

# Release the RKLLM runtime.
print("清理 RKLLM 资源...")
rkllm_runtime.destroy()