import argparse
import os
import random
import time

import numpy as np
import onnxruntime as ort
from PIL import Image
from transformers import AutoProcessor

# Use RKNN for some models
import ztu_somemodelruntime_rknnlite2 as rknnort

# Uncomment this to use ONNXRuntime for some models
# import onnxruntime as rknnort

# Set the current working directory to the directory of this file
os.chdir(os.path.dirname(os.path.abspath(__file__)))


def run(image_path, prompt, max_new_tokens, output_image_path, temperature, seed):
    # Set seeds for reproducibility
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)

    # Total inference time accumulator (ms)
    total_time = 0

    # Initialize inference sessions: RKNN for the vision encoder, encoder and
    # prefill decoder, ONNXRuntime for the token embedding and merged decoder
    vision_encoder = rknnort.InferenceSession(
        "vision_encoder.onnx", providers=["CPUExecutionProvider"]
    )
    encoder = rknnort.InferenceSession(
        "encoder_model.onnx", providers=["CPUExecutionProvider"]
    )
    decoder_prefill = rknnort.InferenceSession(
        "decoder_model.onnx", providers=["CPUExecutionProvider"]
    )
    text_embed = ort.InferenceSession(
        "embed_tokens.onnx", providers=["CPUExecutionProvider"]
    )
    decoder_decode = ort.InferenceSession(
        "decoder_model_merged.onnx", providers=["CPUExecutionProvider"]
    )

    # 1. Prepare the processor
    processor = AutoProcessor.from_pretrained(
        "microsoft/Florence-2-base", trust_remote_code=True
    )

    # 2. Prepare the image
    image = Image.open(image_path).convert("RGB")
    original_image = image.copy()
    original_size = image.size
    # Resize the image to 64x64 (the size the exported vision encoder expects)
    image = image.resize((64, 64))

    # 3. Prepare the text
    inputs = processor(
        text=prompt, images=image, return_tensors="np", do_resize=False
    )  # , padding="max_length", max_length=pad_to + 577, truncation=True)
    for k, v in inputs.items():
        print(k, v.shape)
    # print(inputs)

    # 4. Run the vision encoder (RKNN)
    start_time = time.time()
    image_features = vision_encoder.run(
        None, {"pixel_values": inputs["pixel_values"]}
    )[0]
    end_time = time.time()
    vision_encoder_time = (end_time - start_time) * 1000
    total_time += vision_encoder_time
    print(f"Vision encoder time: {vision_encoder_time:.2f} ms")
    print(image_features.shape)
    # np.save("image_features.npy", image_features)

    # 5. Run the text embedding
    start_time = time.time()
    inputs_embeds = text_embed.run(None, {"input_ids": inputs["input_ids"]})[0]
    end_time = time.time()
    text_embed_time = (end_time - start_time) * 1000
    total_time += text_embed_time
    print(f"Text embed time: {text_embed_time:.2f} ms")
    print(inputs_embeds.shape)
    # print(inputs_embeds)

    # 6. Concatenate image features and text embeddings
    batch_size, image_token_length = image_features.shape[:-1]
    image_attention_mask = np.ones((batch_size, image_token_length))
    task_prefix_embeds = inputs_embeds
    task_prefix_attention_mask = np.ones((batch_size, task_prefix_embeds.shape[1]))
    # task_prefix_attention_mask = inputs["attention_mask"]
    if len(task_prefix_attention_mask.shape) == 3:
        task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
    inputs_embeds = np.concatenate([image_features, task_prefix_embeds], axis=1)
    attention_mask = np.concatenate(
        [image_attention_mask, task_prefix_attention_mask], axis=1
    )

    # 7. Run the encoder (RKNN)
    start_time = time.time()
    encoder_out = encoder.run(
        None,
        {
            "inputs_embeds": inputs_embeds,
            "attention_mask": attention_mask.astype(np.int64),
        },
    )
    end_time = time.time()
    encoder_time = (end_time - start_time) * 1000
    total_time += encoder_time
    print(f"Encoder time: {encoder_time:.2f} ms")
    encoder_hidden_states = encoder_out[0]
    print(encoder_hidden_states.shape)

    # 8. Run the decoder prefill stage (RKNN)
    start_time = time.time()
    next_token = processor.tokenizer.bos_token_id
    next_input_embeds = text_embed.run(
        None, {"input_ids": np.array([[next_token]], dtype=np.int64)}
    )[0]
    decoder_outs = decoder_prefill.run(
        None,
        {
            "inputs_embeds": next_input_embeds,
            "encoder_hidden_states": encoder_hidden_states,
            # "encoder_attention_mask": attention_mask.astype(np.int64),
        },
    )
    end_time = time.time()
    decoder_prefill_time = (end_time - start_time) * 1000
    total_time += decoder_prefill_time
    print(f"Decoder prefill time: {decoder_prefill_time:.2f} ms")
    # for output in decoder_outs:
    #     print(output.shape)
    # The encoder (cross-attention) KV cache is computed once at prefill and reused
    encoder_kv = decoder_outs[1:]

    # 9. Run the decoder decode stage autoregressively (ONNXRuntime)
    generated_tokens = []
    decoder_decode_total_time = 0
    while len(generated_tokens) < max_new_tokens:
        # Outputs from the previous step
        logits = decoder_outs[0]
        decoder_kv = decoder_outs[1:]

        # Take the logits of the last token
        next_token_logits = logits[:, -1, :]

        if temperature == 0:
            # Greedy decoding
            next_token = np.argmax(next_token_logits, axis=-1)[0]
        else:
            # Temperature sampling
            next_token_logits /= temperature
            # Subtract the max from the logits for numerical stability
            next_token_logits -= np.max(next_token_logits)
            # Softmax
            probs = np.exp(next_token_logits) / np.sum(np.exp(next_token_logits))
            # Sample from the probability distribution
            next_token = np.random.choice(len(probs[0]), p=probs[0])

        print("next_token: ", processor.decode([next_token]))

        # Append the newly generated token to the result
        generated_tokens.append(next_token)

        # Stop when the end-of-sequence token (id 2) is generated
        if next_token == 2:
            break

        # Prepare the input for the next step
        start_time = time.time()
        next_input_embeds = text_embed.run(
            None, {"input_ids": np.array([[next_token]], dtype=np.int64)}
        )[0]
        end_time = time.time()
        text_embed_time = (end_time - start_time) * 1000
        decoder_decode_total_time += text_embed_time

        # Run the decode stage of the decoder
        start_time = time.time()
        decoder_outs = decoder_decode.run(
            None,
            {
                "use_cache_branch": np.array([True], dtype=np.bool_),
                "inputs_embeds": next_input_embeds,
                "encoder_hidden_states": encoder_hidden_states,
                # "encoder_attention_mask": attention_mask.astype(np.int64),
                "past_key_values.0.decoder.key": decoder_kv[0],
                "past_key_values.0.decoder.value": decoder_kv[1],
                "past_key_values.0.encoder.key": encoder_kv[2],
                "past_key_values.0.encoder.value": encoder_kv[3],
                "past_key_values.1.decoder.key": decoder_kv[4],
                "past_key_values.1.decoder.value": decoder_kv[5],
                "past_key_values.1.encoder.key": encoder_kv[6],
                "past_key_values.1.encoder.value": encoder_kv[7],
                "past_key_values.2.decoder.key": decoder_kv[8],
                "past_key_values.2.decoder.value": decoder_kv[9],
                "past_key_values.2.encoder.key": encoder_kv[10],
                "past_key_values.2.encoder.value": encoder_kv[11],
                "past_key_values.3.decoder.key": decoder_kv[12],
                "past_key_values.3.decoder.value": decoder_kv[13],
                "past_key_values.3.encoder.key": encoder_kv[14],
                "past_key_values.3.encoder.value": encoder_kv[15],
                "past_key_values.4.decoder.key": decoder_kv[16],
                "past_key_values.4.decoder.value": decoder_kv[17],
                "past_key_values.4.encoder.key": encoder_kv[18],
                "past_key_values.4.encoder.value": encoder_kv[19],
                "past_key_values.5.decoder.key": decoder_kv[20],
                "past_key_values.5.decoder.value": decoder_kv[21],
                "past_key_values.5.encoder.key": encoder_kv[22],
                "past_key_values.5.encoder.value": encoder_kv[23],
            },
        )
        end_time = time.time()
        decoder_decode_time = (end_time - start_time) * 1000
        decoder_decode_total_time += decoder_decode_time

    total_time += decoder_decode_total_time
    print(f"Decoder decode total time: {decoder_decode_total_time:.2f} ms")

    # Decode the generated tokens into text
    print("generated_tokens: ", generated_tokens)
    generated_text = processor.batch_decode(
        [generated_tokens], skip_special_tokens=False
    )[0]
    print("Generated Text:", generated_text)

    parsed_answer = processor.post_process_generation(
        generated_text,
        task=prompt.split(">")[0].strip() + ">",
        image_size=original_size,
    )
    print("Parsed Answer:", parsed_answer)
    print(f"Total inference time: {total_time:.2f} ms")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("image_path", type=str, help="Path to the input image.")
    # The original script hardcoded an empty prompt, which breaks the task parsing
    # in post_process_generation; default to Florence-2's captioning task token
    parser.add_argument(
        "--prompt",
        type=str,
        default="<CAPTION>",
        help="Task prompt, e.g. <CAPTION>, <DETAILED_CAPTION>, <OD>.",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=512,
        help="Maximum number of new tokens to generate.",
    )
    parser.add_argument(
        "--output_image_path",
        type=str,
        default="result_image.jpg",
        help="Path to save the output image with visualizations.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0,
        help="Temperature for sampling. Set to 0 for greedy decoding.",
    )
    parser.add_argument(
        "--seed", type=int, default=None, help="Random seed for reproducibility."
    )
    args = parser.parse_args()
    run(
        args.image_path,
        args.prompt,
        args.max_new_tokens,
        args.output_image_path,
        args.temperature,
        args.seed,
    )