"""Interactive FastVLM demo on axengine: chat with text or describe an image."""
import argparse
import os

import numpy as np
from PIL import Image
from transformers import AutoTokenizer, AutoConfig, CLIPImageProcessor
import axengine as ax
from ml_dtypes import bfloat16

from utils.llava_qwen import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN
from utils.llava_qwen import tokenizer_image_token, expand2square
from utils.infer_func import InferManager

# Token id searched for in the encoded prompt; image features are written into
# the prompt embeddings starting right after its first occurrence.
IMAGE_START_TOKEN_ID = 151646

# Number of image feature rows spliced into the prompt for each supported
# --input_size value (keys are the CLI strings).
TOKEN_LEN_MAP = {"2048": 1280, "1024": 256, "768": 144, "512": 64, "256": 16}

# Slice length passed to InferManager.prefill / decode.
SLICE_LEN = 128
# Maximum total length (prefill + decode) given to InferManager.
MAX_SEQ_LEN = 1024


def load_model_and_tokenizer(model_path):
    """Load the HF config and tokenizer, registering the multimodal tokens.

    Args:
        model_path: Directory (or hub id) holding the tokenizer and config.

    Returns:
        A ``(config, tokenizer)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    config = AutoConfig.from_pretrained(model_path)
    mm_use_im_start_end = getattr(config, "mm_use_im_start_end", False)
    mm_use_im_patch_token = getattr(config, "mm_use_im_patch_token", True)
    if mm_use_im_patch_token:
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
    if mm_use_im_start_end:
        tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
    return config, tokenizer


def vision_encoder(image_path, ax_session, args):
    """Run the vision axmodel on one image file.

    Args:
        image_path: Path to an image readable by PIL.
        ax_session: ``axengine.InferenceSession`` of the vision encoder; its
            input is fed under the name ``"images"``.
        args: Parsed CLI namespace; only ``args.input_size`` is read.

    Returns:
        The first output of ``ax_session.run``.

    Raises:
        OSError: If PIL cannot open or decode the file.
    """
    size = int(args.input_size)
    # mean 0 / std 1/255 cancels the processor's 1/255 rescale, so pixel
    # values stay in the 0..255 range before the uint8 cast below.
    image_processor = CLIPImageProcessor(
        size={"shortest_edge": size},
        crop_size={"height": size, "width": size},
        image_mean=[0, 0, 0],
        image_std=[1 / 255, 1 / 255, 1 / 255],
    )
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    # expand2square is a project helper; presumably pads to a square using the
    # given background colour (black here, from image_mean) — confirm.
    image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
    pixel_values = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
    pixel_values = pixel_values.unsqueeze(0).numpy()  # add batch dimension -> NCHW
    # Round before casting: the float rescale/normalize round-trip can land a
    # hair below an integer, and astype(np.uint8) truncates toward zero.
    pixel_values = np.clip(np.rint(pixel_values), 0, 255).astype(np.uint8)
    input_image = pixel_values.transpose((0, 2, 3, 1))  # NCHW -> NHWC
    return ax_session.run(None, {"images": input_image})[0]


def llm_infer(image_features, llm_path, config, tokenizer, imer, get_input, token_length, embeds):
    """Build the chat prompt, prefill it (optionally with image features) and decode a reply.

    Args:
        image_features: Vision encoder output, or ``None`` for a text-only
            turn. ``image_features[0]`` is written into ``token_length`` rows
            of the prompt embeddings.
        llm_path: Unused; kept for call compatibility.
        config: HF config; ``config.eos_token_id`` is read.
        tokenizer: Tokenizer from :func:`load_model_and_tokenizer`.
        imer: ``InferManager`` that runs prefill and decode.
        get_input: The user question text.
        token_length: Number of embedding rows replaced by image features.
        embeds: Token-embedding table, indexed by token id along axis 0.

    Raises:
        ValueError: If image features are given but the encoded prompt lacks
            ``IMAGE_START_TOKEN_ID`` or is too short to hold the features.
    """
    prompt = "<|im_start|>system\nYou are a helpful assistant, created by apple company.<|im_end|>\n"
    prompt += "<|im_start|>user\n" + get_input
    if image_features is not None:
        # NOTE(review): as written this placeholder evaluates to "\n\n", so it
        # produces no image tag tokens. It must encode to IMAGE_START_TOKEN_ID
        # followed by token_length placeholder tokens for the splice below to
        # line up — confirm this literal against the intended template.
        prompt += "\n" + ""*token_length + "\n"
    prompt += "<|im_end|>\n<|im_start|>assistant\n"
    token_ids = tokenizer.encode(prompt)

    # Image understanding: embed every prompt token, cast to bfloat16 for prefill.
    prefill_data = np.take(embeds, token_ids, axis=0).astype(bfloat16)

    if image_features is not None:
        tag_positions = np.flatnonzero(np.asarray(token_ids) == IMAGE_START_TOKEN_ID)
        if tag_positions.size == 0:
            raise ValueError(
                f"image start token {IMAGE_START_TOKEN_ID} not found in the encoded prompt"
            )
        image_insert_index = int(tag_positions[0]) + 1
        if image_insert_index + token_length > len(token_ids):
            raise ValueError(
                f"prompt has {len(token_ids)} tokens; cannot place {token_length} "
                f"image features starting at position {image_insert_index}"
            )
        prefill_data[image_insert_index:image_insert_index + token_length] = image_features[0, :, :]

    # Only a multi-id eos list is forwarded; otherwise None is passed and decode
    # presumably falls back to its own stopping rule — confirm in InferManager.
    eos_token_id = None
    if isinstance(config.eos_token_id, list) and len(config.eos_token_id) > 1:
        eos_token_id = config.eos_token_id

    token_ids = imer.prefill(tokenizer, token_ids, prefill_data, slice_len=SLICE_LEN)
    imer.decode(tokenizer, token_ids, embeds, slice_len=SLICE_LEN, eos_token_id=eos_token_id)
    print("\n")


def _parse_args():
    """Parse the demo's command-line options."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--vision_model", "-v", type=str,
        default="./fastvlm_ax650_context_1k_prefill_640/image_encoder_1024x1024.axmodel",
        help="Path to the vision axmodel.",
    )
    parser.add_argument(
        "--model_path", "-m", type=str,
        default="./fastvlm_ax650_context_1k_prefill_640",
        help="Path to the llm axmodel.",
    )
    parser.add_argument(
        "--tokenizer_path", "-t", type=str,
        default="./fastvlm_tokenizer",
        help="Path to the tokenizer.",
    )
    # choices turns an unsupported size into a clear argparse error instead of
    # a KeyError on the TOKEN_LEN_MAP lookup.
    parser.add_argument(
        "--input_size", "-i", type=str, default="1024",
        choices=sorted(TOKEN_LEN_MAP, key=int),
        help="Input size of the vision encoder model.",
    )
    return parser.parse_args()


def main():
    """Load the models and run the interactive text / image chat loop."""
    args = _parse_args()
    token_length = TOKEN_LEN_MAP[args.input_size]

    print("Loading config, tokenizer and init model.")
    config, tokenizer = load_model_and_tokenizer(model_path=args.tokenizer_path)
    imer = InferManager(config, args.model_path, max_seq_len=MAX_SEQ_LEN)  # prefill + decode max length
    ax_session = ax.InferenceSession(args.vision_model)
    embeds = np.load(os.path.join(args.model_path, "model.embed_tokens.weight.npy"))
    print(embeds.shape)
    print("[INFO]: 输入文本进行对话,或者输入图片路径进行图片理解, 或者输入q退出对话。")

    while True:
        try:
            prompt = input("prompt<<")
        except (EOFError, KeyboardInterrupt):
            # End of input / Ctrl-C at the prompt ends the session like "q".
            print()
            print("[INFO]: 对话结束,再见。")
            break

        get_input = prompt.strip()
        if get_input == "q":
            print("[INFO]: 对话结束,再见。")
            break
        if not get_input:
            # Nothing to ask; don't spend an LLM turn on an empty prompt.
            continue

        if get_input.lower().endswith(("jpg", "jpeg", "png")):
            if not os.path.isfile(get_input):
                print("[INFO]: 输入错误,请检查图片输入路径。")
                continue
            try:
                image_features = vision_encoder(get_input, ax_session, args)
            except OSError as err:
                print(f"[ERROR]: {err}")
                continue
            question = "Describe the image in detail."
        else:
            image_features = None
            question = get_input

        try:
            llm_infer(image_features, args.model_path, config, tokenizer, imer, question, token_length, embeds)
        except ValueError as err:
            # Report a bad turn and keep the chat session alive.
            print(f"[ERROR]: {err}")


if __name__ == "__main__":
    main()