# FastVLM-1.5B / infer_axmodel.py
# Uploaded by wli1995 in commit 49c9770 ("Upload c++ demo", verified).
import os
import argparse
from PIL import Image
import numpy as np
from utils.llava_qwen import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN
from utils.llava_qwen import tokenizer_image_token, expand2square
from utils.infer_func import InferManager
from transformers import AutoTokenizer, AutoConfig, CLIPImageProcessor
import axengine as ax
from ml_dtypes import bfloat16
import argparse
def load_model_and_tokenizer(model_path):
    """Load the Hugging Face tokenizer and config stored at ``model_path``.

    The multimodal marker tokens are registered as special tokens when the
    config asks for them: the image-patch token when ``mm_use_im_patch_token``
    is set (treated as True when the attribute is absent), and the image
    start/end tokens when ``mm_use_im_start_end`` is set (treated as False
    when absent).

    Args:
        model_path: Location passed to both ``from_pretrained`` calls.

    Returns:
        tuple: ``(config, tokenizer)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    config = AutoConfig.from_pretrained(model_path)

    # (config flag, value assumed when the flag is missing, tokens to add),
    # applied in this order so the new token ids are assigned deterministically.
    optional_token_groups = (
        ("mm_use_im_patch_token", True, [DEFAULT_IMAGE_PATCH_TOKEN]),
        ("mm_use_im_start_end", False, [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]),
    )
    for flag_name, fallback, new_tokens in optional_token_groups:
        if getattr(config, flag_name, fallback):
            tokenizer.add_tokens(new_tokens, special_tokens=True)

    return config, tokenizer
def vision_encoder(image_path, ax_session, args):
    """Run the AX vision encoder on one image file and return its first output.

    The image is padded to a square with ``expand2square``, then resized and
    center-cropped to ``args.input_size`` by a ``CLIPImageProcessor``. It is
    converted to a uint8 NHWC batch of one and fed to ``ax_session`` under the
    input name ``"images"``.

    Args:
        image_path: Path to an image file readable by PIL.
        ax_session: Inference session for the vision axmodel (``ax_session.run``
            is called with an ``{"images": ...}`` feed).
        args: Parsed CLI namespace; only ``args.input_size`` is read.

    Returns:
        The first element of ``ax_session.run(...)``, consumed by ``llm_infer``
        as the image features.
    """
    size = int(args.input_size)
    # With mean 0 and std 1/255 the normalization step cancels the processor's
    # default 1/255 rescale (per CLIPImageProcessor's documented defaults), so
    # pixel values come back on a 0..255 scale for the uint8 model input.
    image_processor = CLIPImageProcessor(
        size={"shortest_edge": size},
        crop_size={"height": size, "width": size},
        image_mean=[0, 0, 0],
        image_std=[1 / 255, 1 / 255, 1 / 255],
    )

    # Context manager closes the underlying file handle; convert() returns a
    # fully loaded copy, so the image stays usable after the file is closed.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    # Pad to square using the mean colour (all zeros here, i.e. black).
    image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))

    pixel_values = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
    pixel_values = pixel_values.unsqueeze(0).numpy()  # add batch dimension -> NCHW

    # Round before casting: the float rescale/normalize round trip can land just
    # below an integer (e.g. 254.99998) and a bare astype(np.uint8) truncates it.
    input_image = np.clip(np.rint(pixel_values), 0, 255).astype(np.uint8)
    input_image = input_image.transpose((0, 2, 3, 1))  # NCHW -> NHWC

    vit_output = ax_session.run(None, {"images": input_image})[0]
    return vit_output
def llm_infer(image_features, llm_path, config, tokenizer, imer, get_input, token_length, embeds):
    """Build a chat prompt, prefill it (optionally with image features) and decode a reply.

    For an image turn, ``token_length`` ``<image>`` placeholder tokens are added
    to the user message. Their embeddings are then overwritten with the rows of
    ``image_features[0]``.

    Args:
        image_features: Vision-encoder output, or ``None`` for a text-only turn.
            Assumed shape ``(1, token_length, hidden)``; only ``[0, :, :]`` is
            used. TODO confirm against the vision axmodel.
        llm_path: Unused; kept so existing call sites keep working.
        config: Model config; ``config.eos_token_id`` is read.
        tokenizer: Tokenizer used to encode the prompt; also passed to ``imer``.
        imer: Object exposing ``prefill``/``decode`` (an ``InferManager`` here).
        get_input: The user's question text.
        token_length: Number of image placeholder tokens, which must equal the
            number of image-feature rows.
        embeds: Token-embedding table, indexed by token id along axis 0.

    Raises:
        ValueError: For an image turn, if the encoded prompt does not contain
            exactly ``token_length`` ``<image>`` tokens.
    """
    image_token_id = 151646  # token id of "<image>" in the exported tokenizer
    slice_len = 128

    prompt = "<|im_start|>system\nYou are a helpful assistant, created by apple company.<|im_end|>\n"
    prompt += "<|im_start|>user\n" + get_input
    if image_features is not None:
        prompt += "\n" + "<image>" * token_length + "\n"
    prompt += "<|im_end|>\n<|im_start|>assistant\n"
    token_ids = tokenizer.encode(prompt)

    # Look up text embeddings for every prompt token, then splice in the image.
    prefill_data = np.take(embeds, token_ids, axis=0).astype(bfloat16)
    if image_features is not None:
        # Replace exactly the placeholder positions. The previous
        # "first <image> index + 1" offset kept one raw <image> embedding and
        # overwrote the "\n" that follows the placeholders.
        image_positions = np.flatnonzero(np.asarray(token_ids) == image_token_id)
        if image_positions.size != token_length:
            raise ValueError(
                f"expected {token_length} <image> tokens (id {image_token_id}) in the prompt, "
                f"found {image_positions.size}"
            )
        prefill_data[image_positions] = image_features[0, :, :]

    # Only a multi-id eos list is forwarded explicitly. Otherwise None is
    # passed and the stop condition is left to imer.decode (not visible here).
    eos_token_id = None
    if isinstance(config.eos_token_id, list) and len(config.eos_token_id) > 1:
        eos_token_id = config.eos_token_id

    token_ids = imer.prefill(tokenizer, token_ids, prefill_data, slice_len=slice_len)
    imer.decode(tokenizer, token_ids, embeds, slice_len=slice_len, eos_token_id=eos_token_id)
    print("\n")
def main():
    """Interactive FastVLM demo: chat in text, or describe an image given its path.

    Loads the tokenizer/config, the LLM (``InferManager``), the vision axmodel
    and the token-embedding table. It then reads prompts until the user types
    ``q`` or stdin is closed. A prompt ending in .jpg/.jpeg/.png is treated as
    an image path.
    """
    # Number of "<image>" placeholder tokens (vision-encoder output rows) per --input_size.
    # NOTE(review): every entry except "2048" equals (size / 64) ** 2 (2048 would give
    # 1024, not 1280), and with the text prompt either value overflows max_seq_len (1024)
    # below - confirm before using 2048.
    token_len_map = {
        "2048": 1280,
        "1024": 256,
        "768": 144,
        "512": 64,
        "256": 16,
    }

    parser = argparse.ArgumentParser()
    parser.add_argument("--vision_model", "-v", type=str, default="./fastvlm_ax650_context_1k_prefill_640/image_encoder_1024x1024.axmodel", help="Path to the vision axmodel.")
    parser.add_argument("--model_path", "-m", type=str, default="./fastvlm_ax650_context_1k_prefill_640", help="Path to the llm axmodel.")
    parser.add_argument("--tokenizer_path", "-t", type=str, default="./fastvlm_tokenizer", help="Path to the tokenizer.")
    # choices= reports an unsupported size as a clear CLI error instead of a KeyError.
    parser.add_argument("--input_size", "-i", type=str, default="1024", choices=list(token_len_map), help="Input size of the vision encoder model.")
    args = parser.parse_args()
    token_length = token_len_map[args.input_size]

    print("Loading config, tokenizer and init model.")
    config, tokenizer = load_model_and_tokenizer(model_path=args.tokenizer_path)
    max_seq_len = 1024  # prefill + decode max length
    imer = InferManager(config, args.model_path, max_seq_len=max_seq_len)
    ax_session = ax.InferenceSession(args.vision_model)
    embeds = np.load(os.path.join(args.model_path, "model.embed_tokens.weight.npy"))
    print(embeds.shape)

    # Prompt hint: type text to chat, an image path for image understanding, or q to quit.
    print("[INFO]: 输入文本进行对话,或者输入图片路径进行图片理解, 或者输入q退出对话。")
    while True:
        try:
            user_input = input("prompt<<").strip()
        except EOFError:
            # stdin closed (Ctrl-D / exhausted pipe): finish the session like "q".
            user_input = "q"
        if user_input == "q":
            print("[INFO]: 对话结束,再见。")  # "Conversation ended, goodbye."
            break
        if not user_input:
            continue  # ignore blank lines instead of generating for an empty prompt

        # Match the dotted extension so a text question that merely ends in
        # "png"/"jpg" is not mistaken for an image path.
        if user_input.lower().endswith((".jpg", ".jpeg", ".png")):
            if not os.path.isfile(user_input):
                # "Input error, please check the image path."
                print("[INFO]: 输入错误,请检查图片输入路径。")
                continue
            image_features = vision_encoder(user_input, ax_session, args)
            question = "Describe the image in detail."
        else:
            image_features = None
            question = user_input
        llm_infer(image_features, args.model_path, config, tokenizer, imer, question, token_length, embeds)


if __name__ == "__main__":
    main()