import os
import argparse
from PIL import Image
import numpy as np
from utils.llava_qwen import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN
from utils.llava_qwen import tokenizer_image_token, expand2square
from utils.infer_func import InferManager
from transformers import AutoTokenizer, AutoConfig, CLIPImageProcessor
import axengine as ax
from ml_dtypes import bfloat16
import argparse
def load_model_and_tokenizer(model_path):
    """Load the HF config and tokenizer, registering the extra image tokens.

    Args:
        model_path: directory holding the tokenizer/config files.

    Returns:
        A ``(config, tokenizer)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    config = AutoConfig.from_pretrained(model_path)
    use_start_end = getattr(config, "mm_use_im_start_end", False)
    use_patch_token = getattr(config, "mm_use_im_patch_token", True)
    if use_patch_token:
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
    if use_start_end:
        tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
    return config, tokenizer
def vision_encoder(image_path, ax_session, args):
    """Preprocess an image and run it through the vision-encoder axmodel.

    Args:
        image_path: path to the input image file.
        ax_session: axengine ``InferenceSession`` for the vision model.
        args: parsed CLI args; ``args.input_size`` is the square input resolution.

    Returns:
        The first output tensor of the vision model (image feature embeddings).
    """
    size = int(args.input_size)  # hoisted: used three times below
    # mean=0 / std=1/255 keeps raw 0..255 pixel values, so the uint8 cast
    # below is lossless (the model expects un-normalized bytes).
    image_processor = CLIPImageProcessor(
        size={"shortest_edge": size},
        crop_size={"height": size, "width": size},
        image_mean=[0, 0, 0],
        image_std=[1 / 255, 1 / 255, 1 / 255],
    )
    image = Image.open(image_path).convert('RGB')
    # Pad to a square using the mean color (black here) before resizing.
    image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
    input_image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
    input_image = input_image.unsqueeze(0)  # add batch dimension
    # NCHW -> NHWC (the original comment had the direction reversed).
    input_image = input_image.numpy().astype(np.uint8).transpose((0, 2, 3, 1))
    vit_output = ax_session.run(None, {"images": input_image})[0]
    return vit_output
def llm_infer(image_features, llm_path, config, tokenizer, imer, get_input, token_length, embeds):
    """Build the chat prompt, splice in image features, and run prefill + decode.

    Args:
        image_features: vision-encoder output, indexed as ``[0, :, :]`` below,
            so presumably shaped (1, token_length, hidden) -- TODO confirm.
            Pass None for text-only chat.
        llm_path: path to the LLM axmodel (unused here; kept for interface
            compatibility with callers).
        config: model config; ``eos_token_id`` may be a list.
        tokenizer: tokenizer with the ``<image>`` token registered.
        imer: InferManager performing the actual prefill/decode.
        get_input: the user question text.
        token_length: number of ``<image>`` placeholder tokens to insert.
        embeds: token-embedding matrix, indexed by token id.
    """
    prompt = "<|im_start|>system\nYou are a helpful assistant, created by apple company.<|im_end|>\n"
    prompt += "<|im_start|>user\n" + get_input
    if image_features is not None:
        # Reserve one <image> placeholder token per visual embedding.
        prompt += "\n" + "<image>" * token_length + "\n"
    prompt += "<|im_end|>\n<|im_start|>assistant\n"
    token_ids = tokenizer.encode(prompt)
    # Gather the text-token embeddings for the whole prompt.
    prefill_data = np.take(embeds, token_ids, axis=0).astype(bfloat16)
    if image_features is not None:
        # 151646 is the id of the <image> placeholder token.  The vision
        # features are written starting one position after the first
        # placeholder, overwriting `token_length` placeholder embeddings.
        image_start_index = np.where(np.array(token_ids) == 151646)[0][0]
        image_insert_index = image_start_index + 1
        prefill_data[image_insert_index:image_insert_index + token_length] = image_features[0, :, :]
    eos_token_id = None
    if isinstance(config.eos_token_id, list) and len(config.eos_token_id) > 1:
        eos_token_id = config.eos_token_id
    slice_len = 128  # prefill chunk size fed to the InferManager
    token_ids = imer.prefill(tokenizer, token_ids, prefill_data, slice_len=slice_len)
    imer.decode(tokenizer, token_ids, embeds, slice_len=slice_len, eos_token_id=eos_token_id)
    print("\n")
if __name__ == "__main__":
    # CLI setup.  NOTE: the parser is no longer stored in a variable named
    # `args` (the original shadowed the parser with its own parse result).
    parser = argparse.ArgumentParser()
    parser.add_argument("--vision_model", "-v", type=str, default="./fastvlm_ax650_context_1k_prefill_640/image_encoder_1024x1024.axmodel", help="Path to the vision axmodel.")
    parser.add_argument("--model_path", "-m", type=str, default="./fastvlm_ax650_context_1k_prefill_640", help="Path to the llm axmodel.")
    parser.add_argument("--tokenizer_path", "-t", type=str, default="./fastvlm_tokenizer", help="Path to the tokenizer.")
    parser.add_argument("--input_size", "-i", type=str, default="1024", help="Input size of the vision encoder model.")
    args = parser.parse_args()

    # Input resolution -> number of visual tokens emitted by the encoder.
    token_len_map = {
        "2048": 1280,
        "1024": 256,
        "768": 144,
        "512": 64,
        "256": 16,
    }
    token_length = token_len_map[args.input_size]

    print("Loading config, tokenizer and init model.")
    config, tokenizer = load_model_and_tokenizer(model_path=args.tokenizer_path)
    max_seq_len = 1024  # prefill + decode max length
    imer = InferManager(config, args.model_path, max_seq_len=max_seq_len)
    ax_session = ax.InferenceSession(args.vision_model)
    # Token-embedding matrix exported alongside the axmodel.
    embeds = np.load(os.path.join(args.model_path, "model.embed_tokens.weight.npy"))
    print(embeds.shape)

    print("[INFO]: 输入文本进行对话,或者输入图片路径进行图片理解, 或者输入q退出对话。")
    while True:
        get_input = input("prompt<<").strip()
        if get_input == "q":
            print("[INFO]: 对话结束,再见。")
            break
        if get_input.lower().endswith(("jpg", "jpeg", "png")):
            if not os.path.isfile(get_input):
                print("[INFO]: 输入错误,请检查图片输入路径。")
                continue
            image_features = vision_encoder(get_input, ax_session, args)
            get_input = "Describe the image in detail."
        else:
            image_features = None
        # Single call site for both text-only and image chat (was duplicated).
        llm_infer(image_features, args.model_path, config, tokenizer, imer, get_input, token_length, embeds)