File size: 6,101 Bytes
e61f34a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7eab1cc
e61f34a
7eab1cc
 
e61f34a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49c9770
e61f34a
 
 
 
 
 
 
49c9770
 
e61f34a
 
49c9770
 
e61f34a
 
 
 
 
 
 
49c9770
e61f34a
7eab1cc
e61f34a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7eab1cc
e61f34a
 
 
 
7eab1cc
 
 
 
 
 
 
 
e61f34a
 
 
 
 
 
 
 
 
 
49c9770
 
 
e61f34a
 
 
 
 
 
 
 
 
 
 
 
7eab1cc
e61f34a
49c9770
e61f34a
 
49c9770
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import os
import argparse

from PIL import Image
import numpy as np
from utils.llava_qwen import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN
from utils.llava_qwen import tokenizer_image_token, expand2square
from utils.infer_func import InferManager
from transformers import AutoTokenizer, AutoConfig, CLIPImageProcessor
import axengine as ax
from ml_dtypes import bfloat16
import argparse

def load_model_and_tokenizer(model_path):
    """Load the HF config and tokenizer from *model_path*.

    Registers the multimodal placeholder tokens (image patch and, when the
    config enables it, image start/end) on the tokenizer so image prompts
    tokenize correctly.

    Returns:
        (config, tokenizer) tuple.
    """
    config = AutoConfig.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    use_patch_token = getattr(config, "mm_use_im_patch_token", True)
    use_start_end = getattr(config, "mm_use_im_start_end", False)

    if use_patch_token:
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
    if use_start_end:
        tokenizer.add_tokens(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
        )

    return config, tokenizer

def vision_encoder(image_path, ax_session, args):
    """Preprocess one image and run it through the vision .axmodel.

    Args:
        image_path: path to the input image file.
        ax_session: axengine InferenceSession for the vision encoder, with an
            input named "images".
        args: parsed CLI args; args.input_size is the square side length the
            image is resized/cropped to.

    Returns:
        First output of the vision model (image feature tensor).
    """

    # Mean 0 and std 1/255 keep pixel values in the 0..255 range, so the image
    # can be fed as uint8 below — normalization is presumably folded into the
    # axmodel itself (TODO confirm).
    image_processor = CLIPImageProcessor(size={"shortest_edge": int(args.input_size)},        # CLIP supports 336x336
                                         crop_size={"height": int(args.input_size), "width": int(args.input_size)},
                                         image_mean=[0, 0, 0],
                                         image_std=[1/255, 1/255, 1/255]
                                         )

    image = Image.open(image_path).convert('RGB')

    # Pad to square with the mean color before resizing, to avoid distortion.
    image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
    input_image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
    input_image = input_image.unsqueeze(0) # add batch dimension

    input_image = input_image.numpy().astype(np.uint8).transpose((0, 2, 3, 1))  # NCHW -> NHWC for the axmodel input
    vit_output = ax_session.run(None, {"images": input_image})[0]

    return vit_output

def llm_infer(image_features, llm_path, config, tokenizer, imer, get_input, token_length, embeds):
    """Build a ChatML prompt, splice image features into the embeddings, and
    run prefill + streaming decode on the NPU.

    Args:
        image_features: vision-encoder output, indexed as [0, :, :] below
            (assumed shape (1, token_length, hidden) — TODO confirm), or None
            for text-only chat.
        llm_path: path to the LLM axmodel (currently unused here; the caller
            constructs `imer`).
        config: model config; config.eos_token_id may be a list of stop ids.
        tokenizer: tokenizer with the <image> token registered.
        imer: InferManager driving prefill/decode on the .axmodel.
        get_input: the user's question text.
        token_length: number of <image> placeholder tokens to insert.
        embeds: token-embedding table as a numpy array, rows indexed by token id.
    """

    prompt = "<|im_start|>system\nYou are a helpful assistant, created by apple company.<|im_end|>\n"
    question = get_input
    prompt += "<|im_start|>user\n" + question

    if image_features is not None:
        # One <image> placeholder token per visual feature vector.
        prompt += "\n" + "<image>"*token_length + "\n"
    prompt += "<|im_end|>\n<|im_start|>assistant\n"

    # Plain encode works because <image> was added to the tokenizer vocab;
    # tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX) is the alternative path.
    token_ids = tokenizer.encode(prompt)

    # Image understanding: look up each prompt token's embedding row.
    prefill_data = np.take(embeds, token_ids, axis=0)
    prefill_data = prefill_data.astype(bfloat16)
    token_len = len(token_ids)

    if image_features is not None:
        # Overwrite the placeholder-token embeddings with the real image
        # features, starting right after the first <image> token (id 151646).
        image_start_index = np.where(np.array(token_ids) == 151646)[0][0] # <image> tag 151646
        image_insert_index = image_start_index + 1
        prefill_data[image_insert_index : image_insert_index + token_length] = image_features[0, :, :]

    # Pass stop ids explicitly only when the config carries several of them.
    eos_token_id = None
    if isinstance(config.eos_token_id, list) and len(config.eos_token_id) > 1:
        eos_token_id = config.eos_token_id

    slice_len = 128
    max_seq_len = 1024  # prefill + decode max length

    token_ids = imer.prefill(tokenizer, token_ids, prefill_data, slice_len=slice_len)
    imer.decode(tokenizer, token_ids, embeds, slice_len=slice_len, eos_token_id=eos_token_id)
    print("\n")

if __name__ == "__main__":

    # NOTE(review): the parser object used to be named `args` and was then
    # shadowed by parse_args(); renamed for clarity. Interface unchanged.
    parser = argparse.ArgumentParser()
    parser.add_argument("--vision_model", "-v", type=str, default="./fastvlm_ax650_context_1k_prefill_640/image_encoder_1024x1024.axmodel", help="Path to the vision axmodel.")
    parser.add_argument("--model_path", "-m", type=str, default="./fastvlm_ax650_context_1k_prefill_640", help="Path to the llm axmodel.")
    parser.add_argument("--tokenizer_path", "-t", type=str, default="./fastvlm_tokenizer", help="Path to the tokenizer.")
    # choices gives a clear argparse error instead of a raw KeyError on the
    # token_len_map lookup below for unsupported sizes.
    parser.add_argument("--input_size", "-i", type=str, default="1024",
                        choices=["256", "512", "768", "1024", "2048"],
                        help="Input size of the vision encoder model.")
    # parser.add_argument("--question", type=str, default="介绍一下你自己", help="The question to ask the model.")

    args = parser.parse_args()

    # Number of <image> placeholder tokens the vision encoder emits per
    # supported input resolution.
    token_len_map = {"2048": 1280,
                     "1024": 256,
                     "768": 144,
                     "512": 64,
                     "256": 16}

    token_length = token_len_map[args.input_size]

    print("Loading config, tokenizer and init model.")
    config, tokenizer = load_model_and_tokenizer(model_path=args.tokenizer_path)

    max_seq_len = 1024  # prefill + decode max length

    imer = InferManager(config, args.model_path, max_seq_len=max_seq_len) # prefill + decode max length
    ax_session = ax.InferenceSession(args.vision_model)

    # Token-embedding table exported from the LLM; llm_infer indexes it by
    # token id to build the prefill input.
    embeds = np.load(os.path.join(args.model_path, "model.embed_tokens.weight.npy"))
    print(embeds.shape)

    print(f"[INFO]: 输入文本进行对话,或者输入图片路径进行图片理解, 或者输入q退出对话。")
    while True:
        prompt = input("prompt<<")
        if prompt.strip() == "q":
            print(f"[INFO]: 对话结束,再见。")
            break
        else:
            get_input = prompt.strip()
            # An input ending in an image extension switches to image mode.
            if get_input.lower().endswith(("jpg", "jpeg", "png")):
                if not os.path.isfile(get_input):
                    print("[INFO]: 输入错误,请检查图片输入路径。")
                    continue
                image_features = vision_encoder(get_input, ax_session, args)
                get_input = "Describe the image in detail."
                llm_infer(image_features, args.model_path, config, tokenizer, imer, get_input, token_length, embeds)
            else:
                image_features = None
                llm_infer(image_features, args.model_path, config, tokenizer, imer, get_input, token_length, embeds)