# Uploaded by happyme531 ("Upload 12 files", commit 69499b6, verified).
from transformers import AutoProcessor
from PIL import Image
import numpy as np
import onnxruntime as ort
import time
import argparse
import random
# Use RKNN for some models
import ztu_somemodelruntime_rknnlite2 as rknnort
# Uncomment this to use ONNXRuntime for some models
# import onnxruntime as rknnort
# The model files below are opened by bare relative names (e.g.
# "vision_encoder.onnx"), so make them resolve against the folder that holds
# this script rather than wherever it was launched from.
import os

_script_dir = os.path.dirname(os.path.abspath(__file__))
os.chdir(_script_dir)
def _sample_next_token(logits, temperature):
    """Choose the next token id from the decoder's last-position logits.

    Args:
        logits: Decoder logits indexed as ``logits[batch, position, vocab]``.
            Only batch element 0 and the final position are used.
        temperature: ``0`` picks the argmax (greedy decoding). A positive
            value divides the logits before sampling from their softmax with
            ``np.random``, so results follow ``np.random.seed``. Negative
            values are not meaningful; ``run`` rejects them before decoding.

    Returns:
        The chosen token id as a Python ``int``.
    """
    last_logits = logits[0, -1, :]
    if temperature == 0:
        return int(np.argmax(last_logits))
    # Work on a float64 copy. This leaves the runtime's output array untouched.
    # It also makes the probabilities sum to 1 tightly enough for
    # np.random.choice even if the runtime returned low-precision logits.
    scaled = last_logits.astype(np.float64) / temperature
    scaled -= scaled.max()  # softmax is shift-invariant; this keeps exp() finite
    weights = np.exp(scaled)
    probs = weights / weights.sum()
    return int(np.random.choice(probs.shape[0], p=probs))


def _build_decode_feed(
    inputs_embeds, encoder_hidden_states, decoder_kv, encoder_kv, num_layers=6
):
    """Build the input dict for one cached step of ``decoder_model_merged.onnx``.

    ``decoder_kv`` and ``encoder_kv`` are flat lists of decoder outputs, with
    the logits excluded. Each has four tensors per layer, in this order:
    decoder key, decoder value, encoder key, encoder value.

    The ``decoder.*`` inputs are taken from ``decoder_kv``, the previous
    step's outputs. The ``encoder.*`` inputs are always taken from
    ``encoder_kv``, the prefill outputs.
    NOTE(review): presumably the merged decoder's cached branch does not
    return usable encoder entries, which is why the prefill ones are reused.
    Confirm against the export.

    Args:
        inputs_embeds: Embedding of the token being fed to the decoder.
        encoder_hidden_states: First output of the encoder model.
        decoder_kv: Cache outputs of the previous decoder step.
        encoder_kv: Cache outputs of the decoder prefill step.
        num_layers: Number of decoder layers in the cache. Defaults to 6,
            the number of layers (0-5) the feed was written for.

    Returns:
        A dict mapping decoder input names to arrays.
    """
    feed = {
        "use_cache_branch": np.array([True], dtype=np.bool_),
        "inputs_embeds": inputs_embeds,
        "encoder_hidden_states": encoder_hidden_states,
    }
    for layer in range(num_layers):
        base = 4 * layer
        prefix = f"past_key_values.{layer}"
        feed[f"{prefix}.decoder.key"] = decoder_kv[base]
        feed[f"{prefix}.decoder.value"] = decoder_kv[base + 1]
        feed[f"{prefix}.encoder.key"] = encoder_kv[base + 2]
        feed[f"{prefix}.encoder.value"] = encoder_kv[base + 3]
    return feed


def run(image_path, prompt, max_new_tokens, output_image_path, temperature, seed):
    """Run Florence-2 generation for one image/prompt pair and print the results.

    Pipeline: vision encoder -> prompt token embedding -> encoder ->
    decoder prefill (BOS token) -> cached autoregressive decoding.

    The ``.onnx`` model files are loaded by relative name, so they resolve
    against the current working directory. Per-stage timings (ms), the
    generated tokens and text, and the parsed answer are printed. Nothing
    is returned.

    Args:
        image_path: Path of the input image; it is converted to RGB.
        prompt: Task prompt such as ``"<CAPTION>"``. For post-processing, the
            task is the text before the first ``">"`` (stripped) plus ``">"``.
        max_new_tokens: Maximum number of generated tokens. An end-of-sequence
            token counts towards the limit.
        output_image_path: Currently unused; accepted for CLI compatibility.
        temperature: ``0`` for greedy decoding, ``> 0`` for temperature
            sampling.
        seed: If not None, seeds both ``random`` and ``np.random``.

    Raises:
        ValueError: If ``temperature`` is negative.
    """
    # Fail fast, before any model is loaded.
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")
    # Seed the RNGs so that temperature sampling (np.random) is reproducible.
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)

    # Accumulated time of the model stages, in milliseconds.
    total_time = 0

    # The vision encoder, encoder and decoder prefill go through `rknnort`
    # (the RKNN wrapper, or ONNX Runtime if the alternative import at the top
    # of the file is enabled). The token embedding and the cached decoder
    # always use ONNX Runtime.
    vision_encoder = rknnort.InferenceSession(
        "vision_encoder.onnx", providers=["CPUExecutionProvider"]
    )
    encoder = rknnort.InferenceSession(
        "encoder_model.onnx", providers=["CPUExecutionProvider"]
    )
    decoder_prefill = rknnort.InferenceSession(
        "decoder_model.onnx", providers=["CPUExecutionProvider"]
    )
    text_embed = ort.InferenceSession(
        "embed_tokens.onnx", providers=["CPUExecutionProvider"]
    )
    decoder_decode = ort.InferenceSession(
        "decoder_model_merged.onnx", providers=["CPUExecutionProvider"]
    )

    # 1. Processor: provides the tokenizer and the image preprocessing.
    processor = AutoProcessor.from_pretrained(
        "microsoft/Florence-2-base", trust_remote_code=True
    )

    # 2. Load the image. Post-processing uses the size before resizing.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    original_size = image.size
    image = image.resize((64, 64))

    # 3. Tokenize the prompt and convert the image to pixel values. The image
    # was already resized above, hence do_resize=False.
    inputs = processor(
        text=prompt, images=image, return_tensors="np", do_resize=False
    )
    for k, v in inputs.items():
        print(k, v.shape)

    # 4. Vision encoder.
    start_time = time.perf_counter()
    image_features = vision_encoder.run(
        None, {"pixel_values": inputs["pixel_values"]}
    )[0]
    vision_encoder_time = (time.perf_counter() - start_time) * 1000
    total_time += vision_encoder_time
    print(f"Vision encoder time: {vision_encoder_time:.2f} ms")
    print(image_features.shape)

    # 5. Prompt token embeddings (ONNX Runtime).
    start_time = time.perf_counter()
    text_embeds = text_embed.run(None, {"input_ids": inputs["input_ids"]})[0]
    text_embed_time = (time.perf_counter() - start_time) * 1000
    total_time += text_embed_time
    print(f"Text embed time: {text_embed_time:.2f} ms")
    print(text_embeds.shape)

    # 6. Encoder input: the image features followed by the prompt embeddings
    # along axis 1. Every position is attended to, so the mask is all ones.
    inputs_embeds = np.concatenate([image_features, text_embeds], axis=1)
    attention_mask = np.ones(inputs_embeds.shape[:2], dtype=np.int64)

    # 7. Encoder.
    start_time = time.perf_counter()
    encoder_out = encoder.run(
        None,
        {
            "inputs_embeds": inputs_embeds,
            "attention_mask": attention_mask,
        },
    )
    encoder_time = (time.perf_counter() - start_time) * 1000
    total_time += encoder_time
    print(f"Encoder time: {encoder_time:.2f} ms")
    encoder_hidden_states = encoder_out[0]
    print(encoder_hidden_states.shape)

    # 8. Decoder prefill: feed the BOS token. The outputs are the first
    # logits followed by the initial cache tensors.
    start_time = time.perf_counter()
    bos_ids = np.array([[processor.tokenizer.bos_token_id]], dtype=np.int64)
    bos_embeds = text_embed.run(None, {"input_ids": bos_ids})[0]
    decoder_outs = decoder_prefill.run(
        None,
        {
            "inputs_embeds": bos_embeds,
            "encoder_hidden_states": encoder_hidden_states,
        },
    )
    decoder_prefill_time = (time.perf_counter() - start_time) * 1000
    total_time += decoder_prefill_time
    print(f"Decoder prefill time: {decoder_prefill_time:.2f} ms")
    # The encoder-side cache entries are only ever read from the prefill outputs.
    encoder_kv = decoder_outs[1:]

    # 9. Autoregressive decoding with the merged (cached) decoder (ONNX Runtime).
    eos_token_id = processor.tokenizer.eos_token_id
    if eos_token_id is None:
        eos_token_id = 2  # "</s>", the id this script previously hard-coded
    generated_tokens = []
    decoder_decode_total_time = 0
    while len(generated_tokens) < max_new_tokens:
        logits, *decoder_kv = decoder_outs
        next_token = _sample_next_token(logits, temperature)
        print("next_token: ", processor.decode([next_token]))
        generated_tokens.append(next_token)
        # Stop at end-of-sequence. Also stop as soon as the token budget is
        # spent, so no decoder step is run for a token that would be discarded.
        if next_token == eos_token_id or len(generated_tokens) >= max_new_tokens:
            break

        start_time = time.perf_counter()
        next_input_embeds = text_embed.run(
            None, {"input_ids": np.array([[next_token]], dtype=np.int64)}
        )[0]
        decoder_decode_total_time += (time.perf_counter() - start_time) * 1000

        start_time = time.perf_counter()
        decoder_outs = decoder_decode.run(
            None,
            _build_decode_feed(
                next_input_embeds, encoder_hidden_states, decoder_kv, encoder_kv
            ),
        )
        decoder_decode_total_time += (time.perf_counter() - start_time) * 1000
    total_time += decoder_decode_total_time
    print(f"Decoder decode total time: {decoder_decode_total_time:.2f} ms")

    # 10. Convert the generated tokens to text and post-process them.
    print("generated_tokens: ", generated_tokens)
    generated_text = processor.batch_decode(
        [generated_tokens], skip_special_tokens=False
    )[0]
    print("Generated Text:", generated_text)
    parsed_answer = processor.post_process_generation(
        generated_text,
        task=prompt.split(">")[0].strip() + ">",
        image_size=original_size,
    )
    print("Parsed Answer:", parsed_answer)
    print(f"Total inference time: {total_time:.2f} ms")
if __name__ == "__main__":
    # Command-line entry point: parse the options and run a single inference.
    parser = argparse.ArgumentParser(
        description="Run Florence-2 inference on a single image and print the result.",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument("image_path", type=str, help="Path to the input image.")
    parser.add_argument(
        "--prompt",
        type=str,
        default="<CAPTION>",
        help="Task prompt passed to the model (default: <CAPTION>).",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=512,
        help="Maximum number of new tokens to generate.",
    )
    parser.add_argument(
        "--output_image_path",
        type=str,
        default="result_image.jpg",
        help="Path to save the output image with visualizations (currently unused).",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0,
        help="Temperature for sampling. Set to 0 for greedy decoding.",
    )
    parser.add_argument(
        "--seed", type=int, default=None, help="Random seed for reproducibility."
    )
    args = parser.parse_args()
    # Report a negative temperature as a usage error instead of a traceback.
    if args.temperature < 0:
        parser.error("--temperature must be >= 0")
    run(
        args.image_path,
        args.prompt,
        args.max_new_tokens,
        args.output_image_path,
        args.temperature,
        args.seed,
    )