|
|
import os |
|
|
import argparse |
|
|
|
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
from utils.llava_qwen import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN |
|
|
from utils.llava_qwen import tokenizer_image_token, expand2square |
|
|
from utils.infer_func import InferManager |
|
|
from transformers import AutoTokenizer, AutoConfig, CLIPImageProcessor |
|
|
import axengine as ax |
|
|
from ml_dtypes import bfloat16 |
|
|
import argparse |
|
|
|
|
|
def load_model_and_tokenizer(model_path):
    """Load the tokenizer and model config stored under ``model_path``.

    After loading, optional multimodal marker tokens are registered on the
    tokenizer as special tokens, driven by two config attributes:

    * ``mm_use_im_patch_token`` (treated as ``True`` when absent) adds
      ``DEFAULT_IMAGE_PATCH_TOKEN``.
    * ``mm_use_im_start_end`` (treated as ``False`` when absent) adds
      ``DEFAULT_IM_START_TOKEN`` and ``DEFAULT_IM_END_TOKEN``.

    Args:
        model_path: Directory (or identifier) handed to ``from_pretrained``.

    Returns:
        A ``(config, tokenizer)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    config = AutoConfig.from_pretrained(model_path)

    # (enabled?, tokens) pairs, processed in registration order.
    optional_token_groups = (
        (
            getattr(config, "mm_use_im_patch_token", True),
            [DEFAULT_IMAGE_PATCH_TOKEN],
        ),
        (
            getattr(config, "mm_use_im_start_end", False),
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN],
        ),
    )
    for enabled, extra_tokens in optional_token_groups:
        if enabled:
            tokenizer.add_tokens(extra_tokens, special_tokens=True)

    return config, tokenizer
|
|
|
|
|
def vision_encoder(image_path, ax_session, args):
    """Encode one image file into visual features with the vision axmodel.

    The image is converted to RGB, padded to a square, run through a
    ``CLIPImageProcessor`` configured for ``args.input_size``, converted to a
    uint8 NHWC batch of one image and fed to ``ax_session`` under the input
    name ``"images"``.

    Args:
        image_path: Path of the image file to load.
        ax_session: Inference session of the vision encoder.
        args: Parsed command-line arguments; only ``args.input_size`` is read.

    Returns:
        The first output returned by ``ax_session.run``.
    """
    input_size = int(args.input_size)

    # NOTE(review): mean 0 / std 1/255 presumably undoes the processor's
    # rescaling so the pixel values come back in the 0..255 range -- confirm
    # against the CLIPImageProcessor defaults.
    image_processor = CLIPImageProcessor(
        size={"shortest_edge": input_size},
        crop_size={"height": input_size, "width": input_size},
        image_mean=[0, 0, 0],
        image_std=[1 / 255, 1 / 255, 1 / 255],
    )

    # Close the underlying file as soon as the RGB copy has been made.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    pad_color = tuple(int(x * 255) for x in image_processor.image_mean)
    image = expand2square(image, pad_color)

    input_image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
    input_image = input_image.unsqueeze(0)

    # A bare float -> uint8 cast truncates toward zero, so a value such as
    # 254.99998 produced by floating-point round-off would become 254.
    # Round to the nearest integer (and clamp) before casting.
    pixels = np.clip(np.rint(input_image.numpy()), 0, 255).astype(np.uint8)

    # NCHW -> NHWC, the layout passed to the session.
    input_image = pixels.transpose((0, 2, 3, 1))
    vit_output = ax_session.run(None, {"images": input_image})[0]

    return vit_output
|
|
|
|
|
def llm_infer(image_features, llm_path, config, tokenizer, imer, get_input, token_length, embeds):
    """Run one chat turn: build the prompt, prefill and decode.

    Args:
        image_features: Output of the vision encoder, or ``None`` for a
            text-only turn. Only ``image_features[0]`` is used; it is written
            into ``token_length`` consecutive rows of the prefill embeddings.
        llm_path: Unused here; kept so existing callers keep working.
        config: Model config; ``config.eos_token_id`` is forwarded to the
            decoder only when it is a list with more than one entry.
        tokenizer: Tokenizer used to encode the prompt (and passed on to
            ``imer``).
        imer: Inference manager exposing ``prefill`` and ``decode``.
        get_input: The user's question text.
        token_length: Number of ``<image>`` placeholders appended to the
            prompt when image features are given.
        embeds: Token embedding table, indexed by token id along axis 0.

    Raises:
        ValueError: If image features are given but the image placeholder
            token id does not occur in the encoded prompt.
    """
    # Token id searched for in the encoded prompt to locate the image
    # placeholders. NOTE(review): presumably the id of "<image>" in the
    # shipped tokenizer -- confirm.
    image_token_id = 151646
    slice_len = 128

    prompt = "<|im_start|>system\nYou are a helpful assistant, created by apple company.<|im_end|>\n"
    prompt += "<|im_start|>user\n" + get_input
    if image_features is not None:
        prompt += "\n" + "<image>" * token_length + "\n"
    prompt += "<|im_end|>\n<|im_start|>assistant\n"

    token_ids = tokenizer.encode(prompt)

    # Look up the text embeddings for the whole prompt in one gather.
    prefill_data = np.take(embeds, token_ids, axis=0).astype(bfloat16)

    if image_features is not None:
        placeholder_positions = np.where(np.array(token_ids) == image_token_id)[0]
        if placeholder_positions.size == 0:
            raise ValueError(
                f"image placeholder token id {image_token_id} not found in the encoded prompt"
            )
        # NOTE(review): the features are written starting one position AFTER
        # the first placeholder, which leaves that placeholder's own embedding
        # in place and overwrites the token following the last placeholder.
        # Kept as in the original; confirm whether the +1 is intended.
        image_insert_index = placeholder_positions[0] + 1
        prefill_data[image_insert_index : image_insert_index + token_length] = image_features[0, :, :]

    # Only a multi-valued EOS list is forwarded; otherwise the decoder gets
    # None and falls back to whatever it does by default.
    eos_token_id = None
    if isinstance(config.eos_token_id, list) and len(config.eos_token_id) > 1:
        eos_token_id = config.eos_token_id

    token_ids = imer.prefill(tokenizer, token_ids, prefill_data, slice_len=slice_len)
    imer.decode(tokenizer, token_ids, embeds, slice_len=slice_len, eos_token_id=eos_token_id)
    print("\n")
|
|
|
|
|
if __name__ == "__main__":
    # Number of "<image>" placeholders inserted into the prompt for each
    # supported vision-encoder input size (keys are strings because
    # --input_size is parsed as a string).
    token_len_map = {
        "2048": 1280,
        "1024": 256,
        "768": 144,
        "512": 64,
        "256": 16,
    }

    parser = argparse.ArgumentParser()
    parser.add_argument("--vision_model", "-v", type=str, default="./fastvlm_ax650_context_1k_prefill_640/image_encoder_1024x1024.axmodel", help="Path to the vision axmodel.")
    parser.add_argument("--model_path", "-m", type=str, default="./fastvlm_ax650_context_1k_prefill_640", help="Path to the llm axmodel.")
    parser.add_argument("--tokenizer_path", "-t", type=str, default="./fastvlm_tokenizer", help="Path to the tokenizer.")
    # `choices` turns an unsupported size into a normal argparse usage error
    # instead of a KeyError traceback from the lookup below.
    parser.add_argument("--input_size", "-i", type=str, default="1024", choices=sorted(token_len_map, key=int), help="Input size of the vision encoder model.")
    args = parser.parse_args()

    token_length = token_len_map[args.input_size]

    print("Loading config, tokenizer and init model.")
    config, tokenizer = load_model_and_tokenizer(model_path=args.tokenizer_path)

    max_seq_len = 1024

    imer = InferManager(config, args.model_path, max_seq_len=max_seq_len)
    ax_session = ax.InferenceSession(args.vision_model)

    embeds = np.load(os.path.join(args.model_path, "model.embed_tokens.weight.npy"))
    print(embeds.shape)

    # Matched with the leading dot so that an ordinary question that merely
    # ends in "png" / "jpg" is not mistaken for an image path.
    image_suffixes = (".jpg", ".jpeg", ".png")

    print("[INFO]: 输入文本进行对话,或者输入图片路径进行图片理解, 或者输入q退出对话。")
    while True:
        try:
            get_input = input("prompt<<").strip()
        except EOFError:
            # stdin closed (Ctrl-D or end of piped input): end the chat
            # cleanly instead of crashing with a traceback.
            print("[INFO]: 对话结束,再见。")
            break

        if get_input == "q":
            print("[INFO]: 对话结束,再见。")
            break
        if not get_input:
            # Nothing typed; ask again rather than sending an empty question.
            continue

        image_features = None
        if get_input.lower().endswith(image_suffixes):
            if not os.path.isfile(get_input):
                print("[INFO]: 输入错误,请检查图片输入路径。")
                continue
            image_features = vision_encoder(get_input, ax_session, args)
            # An image path carries no question of its own; use a fixed one.
            get_input = "Describe the image in detail."

        llm_infer(image_features, args.model_path, config, tokenizer, imer, get_input, token_length, embeds)
|
|
|