|
|
import faulthandler |
|
|
faulthandler.enable() |
|
|
|
|
|
import os |
|
|
os.environ["RKLLM_LOG_LEVEL"] = "1" |
|
|
import numpy as np |
|
|
import onnxruntime as real_ort |
|
|
import ztu_somemodelruntime_rknnlite2 as ort |
|
|
from tokenizers import Tokenizer |
|
|
import cv2 |
|
|
import tqdm |
|
|
import time |
|
|
import ctypes |
|
|
|
|
|
from rkllm_binding import * |
|
|
|
|
|
model_path = "." |
|
|
onnx_model_path = f"{model_path}" |
|
|
tokenizer = Tokenizer.from_file(f"{model_path}/tokenizer.json") |
|
|
|
|
|
|
|
|
image = None |
|
|
prompt = "A stunning princess from kabul in red, white traditional clothing, blue eyes, brown hair" |
|
|
mode = "t2i" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tempature = 0.7 |
|
|
|
|
|
|
|
|
rkllm_result_data = { |
|
|
'hidden_states': None, |
|
|
'finished': False, |
|
|
'error': False |
|
|
} |
|
|
|
|
|
def rkllm_callback(result_ptr, userdata_ptr, state_enum):
    """RKLLM inference callback (called from the C runtime).

    Copies the last hidden layer out of the C-owned result buffer into the
    module-level ``rkllm_result_data`` mailbox and flips the
    ``finished``/``error`` flags that ``run_rkllm_inference()`` polls.

    Args:
        result_ptr: ctypes pointer to the RKLLM result struct (may be NULL).
        userdata_ptr: opaque user data pointer (unused here).
        state_enum: raw integer converted to ``LLMCallState``.
    """
    global rkllm_result_data

    try:
        state = LLMCallState(state_enum)

        if state == LLMCallState.RKLLM_RUN_FINISH:
            rkllm_result_data['finished'] = True
            print("RKLLM 推理完成")  # "RKLLM inference finished"
            return
        elif state == LLMCallState.RKLLM_RUN_ERROR:
            rkllm_result_data['error'] = True
            rkllm_result_data['error_msg'] = "RKLLM 推理出错"
            rkllm_result_data['finished'] = True
            print("错误: RKLLM 推理出错")  # "error: RKLLM inference failed"
            # NOTE(review): no `return` here — on error execution falls through
            # to the result_ptr checks below and may return 1. Confirm whether
            # that is intentional; the return value reaching the C runtime
            # differs between the FINISH (None) and ERROR (1) paths.

        if not result_ptr:
            print("警告: result_ptr 为空指针")  # "warning: result_ptr is NULL"
            return

        result = result_ptr.contents

        if state == LLMCallState.RKLLM_RUN_NORMAL:
            # Hidden states are only populated when running with
            # RKLLM_INFER_GET_LAST_HIDDEN_LAYER (see infer_params below).
            if result.last_hidden_layer.hidden_states and result.last_hidden_layer.embd_size > 0:
                hidden_size = result.last_hidden_layer.embd_size
                num_tokens = result.last_hidden_layer.num_tokens

                # .copy() detaches the data from the C-owned buffer —
                # presumably that buffer is only valid for the duration of
                # this callback (TODO confirm against rkllm_binding).
                hidden_array = np.ctypeslib.as_array(
                    result.last_hidden_layer.hidden_states,
                    shape=(num_tokens, hidden_size)
                ).copy()

                rkllm_result_data['hidden_states'] = hidden_array

                rkllm_result_data['finished'] = True
                return 1
            else:
                print("警告: 没有获取到有效的 hidden states")  # "warning: no valid hidden states"

        return 1
    except Exception as e:
        # Never let a Python exception propagate into the C runtime: record
        # it and unblock the polling loop instead.
        print(f"回调函数异常: {e}")  # "callback exception"
        rkllm_result_data['error'] = True
        rkllm_result_data['error_msg'] = str(e)
        rkllm_result_data['finished'] = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Vision encoder runs on the NPU through the rknnlite2 ONNX-runtime shim.
vision_encoder = ort.InferenceSession(f"{onnx_model_path}/vision_encoder.rknn")

print("初始化 RKLLM 语言模型...")  # "initializing RKLLM language model..."
rkllm_runtime = RKLLMRuntime()
rkllm_params = rkllm_runtime.create_default_param()
rkllm_params.model_path = f"{model_path}/language_model.rkllm".encode('utf-8')
rkllm_params.max_context_len = 1024
# Small budget: only per-step hidden states are consumed (the heads and
# sampling run in numpy below), not runtime-generated text — TODO confirm.
rkllm_params.max_new_tokens = 5
rkllm_params.skip_special_token = 0
# NOTE(review): base_domain_id semantics not visible here — verify against
# the rkllm_binding documentation.
rkllm_params.extend_param.base_domain_id = 1
rkllm_runtime.init(rkllm_params, rkllm_callback)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Text-logits head, used in understanding mode (see generation loop).
lm_head = ort.InferenceSession(f"{onnx_model_path}/lm_head.onnx")

# Image-token head, used in t2i mode (see generation loop).
gen_head = ort.InferenceSession(f"{onnx_model_path}/gen_head.onnx")

# Embedding lookup for sampled image tokens (fed back into the LM).
gen_img_embeds = ort.InferenceSession(f"{onnx_model_path}/gen_img_embeds.onnx")

# Text-token embedding lookup — deliberately on real onnxruntime (CPU)
# rather than the rknnlite2 shim; presumably the lookup is unsupported or
# slower on the NPU — TODO confirm.
text_embeds = real_ort.InferenceSession(f"{onnx_model_path}/embed_tokens.onnx")

# Decoder that turns the generated image-token sequence back into pixels.
image_decode = ort.InferenceSession(f"{onnx_model_path}/image_decode.onnx")
|
|
|
|
|
|
|
|
|
|
|
if mode == "t2i": |
|
|
input_str = f"""<|User|>: {prompt} |
|
|
|
|
|
<|Assistant|>:<begin_of_image>""" |
|
|
else: |
|
|
input_str = f"""You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. |
|
|
|
|
|
<|User|>: <image_placeholder> |
|
|
{prompt} |
|
|
|
|
|
<|Assistant|>:""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
input_str = input_str.replace("<image_placeholder>", "<image_placeholder>" * 576) |
|
|
input = tokenizer.encode(input_str) |
|
|
input_ids = np.array([input.ids], dtype=np.int64) |
|
|
input_len = len(input.ids) |
|
|
attention_mask = np.array([input.attention_mask], dtype=np.int64) |
|
|
images_seq_mask = np.array([[1 if id == 100581 else 0 for id in input.ids]], dtype=np.bool_) |
|
|
position_ids = np.expand_dims(np.arange(input_len), axis=0) |
|
|
|
|
|
if image:
    img = cv2.imread(image)
    if img is None:
        raise ValueError(f"无法读取图片: {image}")

    # OpenCV loads BGR; the encoder pipeline works in RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Resize to the vision encoder's fixed input resolution.
    target_size = 384
    img = cv2.resize(img, (target_size, target_size), interpolation=cv2.INTER_LINEAR)

    # Scale to [0, 1] (0.00392156862745098 == 1/255).
    img = img.astype(np.float32) * 0.00392156862745098

    # Normalize to [-1, 1] with per-channel mean = std = 0.5.
    img = (img - np.array([0.5, 0.5, 0.5], dtype=np.float32)) / np.array([0.5, 0.5, 0.5], dtype=np.float32)

    # HWC -> CHW, then add batch and per-image axes: (1, 1, 3, 384, 384).
    img = img.transpose(2, 0, 1)
    pixel_values = np.expand_dims(np.expand_dims(img, axis=0), axis=1)
    images_emb_mask = np.ones((1, 1, 576), dtype=np.bool_)
else:
    # No image: zero-length pixel tensor and an empty embedding mask.
    pixel_values = np.zeros((0, 0, 3, 384, 384), dtype=np.float32)
    images_emb_mask = np.zeros((1, 0, 576), dtype=np.bool_)
|
|
|
|
|
|
|
|
|
|
|
# Token-id -> embedding lookup for the whole prompt (CPU onnxruntime).
text_inputs_embeds = text_embeds.run(None, {"input_ids": input_ids})[0]

if image:
    # Encode the image into patch embeddings on the NPU.
    vision_embeds = vision_encoder.run(None, {"pixel_values": pixel_values})[0]

    # Sequence positions occupied by <image_placeholder> tokens.
    image_token_positions = np.where(images_seq_mask[0])[0]

    # Overwrite each placeholder embedding with the corresponding vision
    # patch embedding; placeholders beyond the encoder's output length are
    # left untouched.
    for idx, pos in enumerate(image_token_positions):
        if idx < vision_embeds.shape[1]:
            text_inputs_embeds[0, pos, :] = vision_embeds[0, idx, :]

inputs_embeds = text_inputs_embeds
|
|
|
|
|
|
|
|
|
|
|
# Token ids sampled so far (image tokens in t2i mode, text tokens otherwise).
generated_tokens = []

# RKLLM input/parameter structs reused across every step of the loop below.
rkllm_input = RKLLMInput()
rkllm_input.input_type = RKLLMInputType.RKLLM_INPUT_EMBED
embed_input = RKLLMEmbedInput()
infer_params = RKLLMInferParam()
# Request the last hidden layer instead of runtime-sampled text; the heads
# and sampling run in numpy in the generation loop.
infer_params.mode = RKLLMInferMode.RKLLM_INFER_GET_LAST_HIDDEN_LAYER
# Keep the KV cache between run() calls so each step only needs to feed the
# newest token's embedding.
infer_params.keep_history = 1
|
|
|
|
|
def run_rkllm_inference(inputs_embeds, timeout_s=None):
    """Run one RKLLM forward pass on embeddings and return hidden states.

    Resets the shared ``rkllm_result_data`` mailbox, hands the embeddings to
    the RKLLM runtime (results arrive asynchronously via ``rkllm_callback``),
    and polls until the callback flags completion.

    Args:
        inputs_embeds: float array of shape (1, n_tokens, embd_size).
        timeout_s: optional wall-clock timeout in seconds for the polling
            loop; ``None`` (the default, matching the old behavior) waits
            forever. Raises ``TimeoutError`` when exceeded, so a dead
            callback no longer hangs the process silently.

    Returns:
        numpy array with the last hidden layer (as stored by the callback),
        or ``None`` if the callback finished without producing one.

    Raises:
        RuntimeError: the callback reported an inference error.
        TimeoutError: ``timeout_s`` elapsed before the callback finished.
    """
    global rkllm_result_data

    # Fresh mailbox for this call; the callback writes into it.
    rkllm_result_data = {
        'hidden_states': None,
        'finished': False,
        'error': False
    }

    # One bulk memcpy into a ctypes buffer. The old code did
    # (ctypes.c_float * n)(*embed_flat), which unpacks every element through
    # a Python float — pathologically slow for large embeddings.
    embed_flat = inputs_embeds.flatten().astype(np.float32)
    embed_c_array = (ctypes.c_float * embed_flat.size).from_buffer_copy(embed_flat)
    embed_input.embed = embed_c_array
    embed_input.n_tokens = inputs_embeds.shape[1]

    rkllm_input._union_data.embed_input = embed_input

    rkllm_runtime.run(rkllm_input, infer_params)

    # Poll until the callback marks completion (1 ms granularity).
    deadline = None if timeout_s is None else time.monotonic() + timeout_s
    while not rkllm_result_data['finished']:
        if deadline is not None and time.monotonic() > deadline:
            raise TimeoutError(f"RKLLM inference did not finish within {timeout_s}s")
        time.sleep(0.001)

    if rkllm_result_data['error']:
        raise RuntimeError("RKLLM 推理出错")

    return rkllm_result_data['hidden_states']
|
|
|
|
|
|
|
|
# Autoregressive sampling loop: up to 576 steps (t2i generates a full
# 576-token image grid; understanding mode usually stops early on EOS).
with tqdm.tqdm(range(576)) as pbar:
    for i in pbar:
        # NPU forward pass. The first iteration feeds the full prompt
        # embeddings; later ones feed only the newest token (the KV cache
        # keeps the history — see infer_params.keep_history).
        hidden_states = run_rkllm_inference(inputs_embeds)

        if hidden_states is None:
            raise RuntimeError("RKLLM 未返回有效的 hidden states")

        # Callback delivers (num_tokens, hidden); add a batch axis.
        if len(hidden_states.shape) == 2:
            hidden_states = hidden_states.reshape(1, hidden_states.shape[0], hidden_states.shape[1])

        # Only the last position's hidden state feeds the head.
        hs = hidden_states[:, -1:, :]

        # t2i samples image tokens via gen_head; otherwise text via lm_head.
        logits = (gen_head if mode == "t2i" else lm_head).run(None, {"hidden_states": hs})[0]
        logits = logits[:, -1, :]

        # Temperature-scaled, numerically-stable softmax.
        logits = logits / tempature

        exp_logits = np.exp(logits - np.max(logits, axis=-1, keepdims=True))[0]
        probs = exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)

        # Renormalize in float64 so np.random.multinomial does not reject a
        # probability vector whose float32 sum drifts away from 1.0.
        probs = probs.astype(np.float64)
        probs /= probs.sum()
        next_token = int(np.random.multinomial(1, probs).argmax())
        pbar.set_postfix(next_token=tokenizer.decode([next_token]))
        generated_tokens.append(next_token)
        # 100001: end-of-sequence token id — TODO confirm against tokenizer.
        if next_token == 100001:
            break

        # Embed the sampled token for the next step; image and text tokens
        # use different embedding tables.
        if mode == "t2i":
            new_embed = gen_img_embeds.run(None, {"image_ids": np.array([[next_token]], dtype=np.int64)})[0]
        else:
            new_embed = text_embeds.run(None, {"input_ids": np.array([[next_token]], dtype=np.int64)})[0]

        inputs_embeds = new_embed

# NOTE(review): indentation reconstructed — placed after the loop, since
# clearing the KV cache per step would discard the history that
# keep_history=1 relies on. Confirm against the original file.
rkllm_runtime.clear_kv_cache(False)
|
|
|
|
|
|
|
|
if mode == "t2i": |
|
|
|
|
|
generated_tokens_array = np.array([generated_tokens], dtype=np.int64) |
|
|
decoded_image = image_decode.run(None, {"generated_tokens": generated_tokens_array})[0] |
|
|
decoded_image = np.clip((decoded_image + 1) / 2 * 255, 0, 255) |
|
|
|
|
|
decoded_image = np.squeeze(decoded_image, axis=0) |
|
|
decoded_image = np.transpose(decoded_image, (1, 2, 0)) |
|
|
cv2.imwrite("generated.png", cv2.cvtColor(decoded_image, cv2.COLOR_RGB2BGR)) |
|
|
print("(generated.png)") |
|
|
else: |
|
|
decoded_text = tokenizer.decode(generated_tokens) |
|
|
print(f"{decoded_text}") |
|
|
|
|
|
|
|
|
print("清理 RKLLM 资源...") |
|
|
rkllm_runtime.destroy() |
|
|
|