| | import os |
| | import argparse |
| |
|
| | from PIL import Image |
| | import numpy as np |
| | from utils.llava_qwen import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN |
| | from utils.llava_qwen import tokenizer_image_token, expand2square |
| | from utils.infer_func import InferManager |
| | from transformers import AutoTokenizer, AutoConfig, CLIPImageProcessor |
| | import axengine as ax |
| | from ml_dtypes import bfloat16 |
| | import argparse |
| |
|
def load_model_and_tokenizer(model_path):
    """Load the HF config and tokenizer, registering the image special tokens.

    Returns a ``(config, tokenizer)`` pair.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    config = AutoConfig.from_pretrained(model_path)

    # Register whatever image special tokens the checkpoint was trained with.
    use_patch_token = getattr(config, "mm_use_im_patch_token", True)
    use_start_end = getattr(config, "mm_use_im_start_end", False)
    if use_patch_token:
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
    if use_start_end:
        tokenizer.add_tokens(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
        )

    return config, tokenizer
| |
|
def vision_encoder(image_path, ax_session, args):
    """Preprocess an image and run it through the vision axmodel.

    The image is padded to a square, resized/cropped to ``args.input_size``,
    converted to a uint8 NHWC batch of one, and fed to ``ax_session``.
    Returns the first output of the vision encoder session.
    """
    side = int(args.input_size)
    # mean 0 / std 1/255 keeps raw 0..255 pixel values (uint8-friendly).
    processor = CLIPImageProcessor(
        size={"shortest_edge": side},
        crop_size={"height": side, "width": side},
        image_mean=[0, 0, 0],
        image_std=[1 / 255, 1 / 255, 1 / 255],
    )

    img = Image.open(image_path).convert('RGB')
    pad_color = tuple(int(c * 255) for c in processor.image_mean)
    img = expand2square(img, pad_color)

    pixels = processor.preprocess(img, return_tensors='pt')['pixel_values'][0]
    # Add batch dim, then NCHW -> NHWC uint8 as the axmodel expects.
    batch = pixels.unsqueeze(0).numpy().astype(np.uint8).transpose((0, 2, 3, 1))

    return ax_session.run(None, {"images": batch})[0]
| |
|
def llm_infer(image_features, llm_path, config, tokenizer, imer, get_input, token_length, embeds):
    """Run one chat turn through the LLM: build prompt, prefill, then decode.

    Args:
        image_features: vision-encoder output to splice into the prefill
            embeddings, or ``None`` for a text-only turn.
        llm_path: unused here; kept for call-site compatibility.
        config: model config (``eos_token_id`` is read from it).
        tokenizer: tokenizer used to encode the prompt.
        imer: InferManager driving the prefill/decode axmodels.
        get_input: the user's question text.
        token_length: number of vision tokens (i.e. ``<image>`` placeholders).
        embeds: token-embedding table, indexed by token id.

    Prints the streamed answer via ``imer.decode`` (no return value).
    """
    # ChatML-style prompt; these strings must match the model's training format.
    prompt = "<|im_start|>system\nYou are a helpful assistant, created by apple company.<|im_end|>\n"
    question = get_input
    prompt += "<|im_start|>user\n" + question

    if image_features is not None:
        # One <image> placeholder per vision token, so the corresponding
        # embedding rows can be overwritten with the vision features below.
        prompt += "\n" + "<image>"*token_length + "\n"
    prompt += "<|im_end|>\n<|im_start|>assistant\n"

    token_ids = tokenizer.encode(prompt)

    # Row-gather the token embeddings and cast to the NPU's compute dtype.
    prefill_data = np.take(embeds, token_ids, axis=0).astype(bfloat16)

    if image_features is not None:
        # 151646 is presumably the <image> special-token id — TODO confirm
        # against the tokenizer. Features are spliced in starting one row
        # after the first occurrence.
        image_start_index = np.where(np.array(token_ids) == 151646)[0][0]
        image_insert_index = image_start_index + 1
        prefill_data[image_insert_index : image_insert_index + token_length] = image_features[0, :, :]

    # Only forward a multi-id EOS list; a single id uses imer's default.
    eos_token_id = None
    if isinstance(config.eos_token_id, list) and len(config.eos_token_id) > 1:
        eos_token_id = config.eos_token_id

    slice_len = 128  # prefill chunk size the axmodel was exported with

    token_ids = imer.prefill(tokenizer, token_ids, prefill_data, slice_len=slice_len)
    imer.decode(tokenizer, token_ids, embeds, slice_len=slice_len, eos_token_id=eos_token_id)
    print("\n")
| |
|
if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--vision_model", "-v", type=str, default="./fastvlm_ax650_context_1k_prefill_640/image_encoder_1024x1024.axmodel", help="Path to the vision axmodel.")
    parser.add_argument("--model_path", "-m", type=str, default="./fastvlm_ax650_context_1k_prefill_640", help="Path to the llm axmodel.")
    parser.add_argument("--tokenizer_path", "-t", type=str, default="./fastvlm_tokenizer", help="Path to the tokenizer.")
    parser.add_argument("--input_size", "-i", type=str, default="1024", help="Input size of the vision encoder model.")
    args = parser.parse_args()

    # Vision tokens produced per supported input resolution.
    token_len_map = {"2048": 1280,
                     "1024": 256,
                     "768": 144,
                     "512": 64,
                     "256": 16}
    if args.input_size not in token_len_map:
        # Fail with a clear message instead of a bare KeyError.
        raise SystemExit(
            f"Unsupported --input_size {args.input_size!r}; "
            f"choose one of {sorted(token_len_map)}."
        )
    token_length = token_len_map[args.input_size]

    print("Loading config, tokenizer and init model.")
    config, tokenizer = load_model_and_tokenizer(model_path=args.tokenizer_path)

    max_seq_len = 1024  # context length the llm axmodel was exported with

    # Token-embedding table, memory-mapped so it is not loaded fully into RAM.
    embeds = np.load(os.path.join("./embeds", "model.embed_tokens.weight.npy"), mmap_mode='r')

    ax_session = ax.InferenceSession(args.vision_model)
    imer = InferManager(config, args.model_path, max_seq_len=max_seq_len)

    print("[INFO]: 输入文本进行对话,或者输入图片路径进行图片理解, 或者输入q退出对话。")
    while True:
        get_input = input("prompt<<").strip()
        if get_input == "q":
            print("[INFO]: 对话结束,再见。")
            break

        image_features = None
        if get_input.lower().endswith(("jpg", "jpeg", "png")):
            if not os.path.isfile(get_input):
                print("[INFO]: 输入错误,请检查图片输入路径。")
                continue
            image_features = vision_encoder(get_input, ax_session, args)
            # Image turns use a fixed caption request as the question.
            get_input = "Describe the image in detail."

        llm_infer(image_features, args.model_path, config, tokenizer, imer, get_input, token_length, embeds)
| |
|