import math
import random
from typing import Optional, Tuple, List, Dict, Any
import os
import numpy as np
import torch
from PIL import Image
import torchvision
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
import onnxruntime

# AutoTokenizer from transformers (recommended)
from transformers import AutoTokenizer

# ---------------------------
# Config / default parameters
# ---------------------------

DEFAULT_IMAGE_SIZE = 224
DEFAULT_MAX_ROWS = 2
DEFAULT_MAX_COLS = 2
MIN_BLOCK_SIZE = 4


# ---------------------------
# Tokenizer / Preprocess
# ---------------------------

def load_tokenizer(tokenizer_path: str):
    """加载 tokenizer(AutoTokenizer)"""
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    if tokenizer.chat_template is None:
        # Newer transformers versions load the chat template automatically; the training environment (transformers 4.51.3) supports this.
        with open(os.path.join(tokenizer_path, "chat_template.jinja"), "r") as f:
            tokenizer.chat_template = f.read()
    return tokenizer


def build_image_preprocess(image_size: int = DEFAULT_IMAGE_SIZE):
    """返回 torchvision.transforms.Compose 的预处理 callable"""
    return Compose([
        Resize(size=image_size, interpolation=torchvision.transforms.InterpolationMode.BICUBIC, max_size=None, antialias=True),
        CenterCrop(size=(image_size, image_size)),
        lambda img: img.convert("RGB"),
        ToTensor(),
        Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                  std=(0.26862954, 0.26130258, 0.27577711))
    ])
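

# A minimal sanity sketch (not called anywhere): the preprocess maps any PIL
# image to a normalized (3, image_size, image_size) float tensor. The gray
# test image is purely illustrative.
def _example_preprocess_shape():
    preprocess = build_image_preprocess()
    img = Image.new("RGB", (640, 480), color=(128, 128, 128))
    t = preprocess(img)
    assert t.shape == (3, DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE)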


# ---------------------------
# RoPE frequency precomputation (precompute_freqs_cis)
# ---------------------------

def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6,
                         rope_scaling: Optional[dict] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the RoPE cos and sin tables.

    Returns:
      freqs_cos: (end, dim)
      freqs_sin: (end, dim)
    """
    freqs = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    if rope_scaling is not None:
        orig_max = rope_scaling.get("original_max_position_embeddings", 2048)
        factor = rope_scaling.get("factor", 4)
        beta_fast = rope_scaling.get("beta_fast", 4.0)
        beta_slow = rope_scaling.get("beta_slow", 1.0)

        if end / orig_max > 1.0:
            corr_dim = next((i for i in range(dim // 2) if 2 * math.pi / freqs[i] > orig_max), dim // 2)
            power = torch.arange(0, dim // 2, device=freqs.device).float() / max(dim // 2 - 1, 1)
            beta = beta_slow + (beta_fast - beta_slow) * power
            scale = torch.where(torch.arange(dim // 2, device=freqs.device) < corr_dim,
                                (beta * factor - beta + 1) / (beta * factor),
                                1.0 / factor)
            freqs = freqs * scale

    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()  # (end, dim/2)
    freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1)
    freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1)
    return freqs_cos, freqs_sin
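

# Shape sketch: for a 64-dim head the tables are (end, 64); main_example below
# builds them once with end=32768. The small `end` here just keeps the check cheap.
def _example_rope_shapes():
    freqs_cos, freqs_sin = precompute_freqs_cis(dim=64, end=1024)
    assert freqs_cos.shape == (1024, 64)
    assert freqs_sin.shape == (1024, 64)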


# ---------------------------
# Adaptive square image splitting (adaptive_square_split)
# ---------------------------

def calculate_optimal_split_with_fixed_max(width: int, height: int, max_rows: int, max_cols: int) -> Tuple[int, int, int]:
    """
    Compute the best split (returns rows, cols, block_size).
    block_size is floored to a multiple of 16, with a minimum of MIN_BLOCK_SIZE.
    """
    best_rows = 1
    best_cols = 1
    best_block_size = 0
    best_coverage = 0.0

    # Scheme 1: fix the row count at max_rows and adapt the column count
    rows_fixed = max_rows
    for cols in range(1, max_cols + 1):
        block_width = width // cols
        block_height = height // rows_fixed
        square_size = min(block_width, block_height)
        if square_size <= 0:
            continue
        coverage = (cols * square_size) * (rows_fixed * square_size) / (width * height)
        if coverage > best_coverage or (coverage == best_coverage and square_size > best_block_size):
            best_rows, best_cols, best_block_size, best_coverage = rows_fixed, cols, square_size, coverage

    # Scheme 2: fix the column count at max_cols and adapt the row count
    cols_fixed = max_cols
    for rows in range(1, max_rows + 1):
        block_width = width // cols_fixed
        block_height = height // rows
        square_size = min(block_width, block_height)
        if square_size <= 0:
            continue
        coverage = (cols_fixed * square_size) * (rows * square_size) / (width * height)
        if coverage > best_coverage or (coverage == best_coverage and square_size > best_block_size):
            best_rows, best_cols, best_block_size, best_coverage = rows, cols_fixed, square_size, coverage

    # Scheme 3: use the maximum for both
    block_width = width // max_cols
    block_height = height // max_rows
    square_size = min(block_width, block_height)
    if square_size > 0:
        coverage = (max_cols * square_size) * (max_rows * square_size) / (width * height)
        if coverage > best_coverage or (coverage == best_coverage and square_size > best_block_size):
            best_rows, best_cols, best_block_size, best_coverage = max_rows, max_cols, square_size, coverage

    # Floor to a multiple of 16 and enforce the minimum block size
    best_block_size = max(MIN_BLOCK_SIZE, (best_block_size // 16) * 16)
    return best_rows, best_cols, best_block_size
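

# Worked example of the split arithmetic: a 640x480 image with the 2x2 maximum
# yields 2 rows x 2 cols of 240px squares (min(640 // 2, 480 // 2) = 240, already
# a multiple of 16), covering 4 * 240^2 / (640 * 480) = 75% of the image.
def _example_optimal_split():
    assert calculate_optimal_split_with_fixed_max(640, 480, 2, 2) == (2, 2, 240)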


def adaptive_square_split(image: Image.Image, max_rows: int = DEFAULT_MAX_ROWS, max_cols: int = DEFAULT_MAX_COLS
                          ) -> Tuple[List[Image.Image], int, int, int]:
    """
    Adaptively split a PIL Image into square blocks; returns (blocks_list, rows, cols, block_size).
    blocks_list is a row-major list of blocks (possibly fewer than max_rows * max_cols).
    """
    width, height = image.size
    rows, cols, block_size = calculate_optimal_split_with_fixed_max(width, height, max_rows, max_cols)

    blocks = []
    for r in range(rows):
        for c in range(cols):
            left = c * block_size
            upper = r * block_size
            right = left + block_size
            lower = upper + block_size
            blocks.append(image.crop((left, upper, right, lower)))

    return blocks, rows, cols, block_size
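

# Usage sketch on a synthetic image: the 640x480 case above produces four
# 240x240 crops in row-major order.
def _example_adaptive_split():
    img = Image.new("RGB", (640, 480))
    blocks, rows, cols, block_size = adaptive_square_split(img)
    assert (rows, cols, block_size) == (2, 2, 240)
    assert len(blocks) == 4 and blocks[0].size == (240, 240)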


# ---------------------------
# Special-token preparation
# ---------------------------

def prepare_special_tokens(tokenizer, max_rows: int = DEFAULT_MAX_ROWS, max_cols: int = DEFAULT_MAX_COLS) -> Dict[str, int]:
    """
    Return a dict of special token ids: <global-img>, <fake_token_around_image>, <image>, and the <row_i_col_j> tokens.
    """
    special = {
        "<global-img>": tokenizer.convert_tokens_to_ids("<global-img>"),
        "<fake_token_around_image>": tokenizer.convert_tokens_to_ids("<fake_token_around_image>"),
        "<image>": tokenizer.convert_tokens_to_ids("<image>"),
    }
    for i in range(max_rows):
        for j in range(max_cols):
            special[f"<row_{i + 1}_col_{j + 1}>"] = tokenizer.convert_tokens_to_ids(f"<row_{i + 1}_col_{j + 1}>")
    return special


# ---------------------------
# Split the image into patches, pad, and stack into the model input tensor
# ---------------------------

def prepare_image_patches(image: Image.Image, preprocess_fn, max_rows: int = DEFAULT_MAX_ROWS, max_cols: int = DEFAULT_MAX_COLS
                          ) -> Tuple[torch.Tensor, List[Tuple[int, int]]]:
    """
    image: PIL.Image
    preprocess_fn: callable that maps PIL.Image -> tensor(C,H,W)
    Returns:
      pixel_values: torch.Tensor, shape (max_rows * max_cols + 1, C, H, W) -- the last entry is the full image
      mask_positions: list[tuple[int, int]] -- (row, col) grid positions padded with zeros; the caller
        maps these to the matching <row_i_col_j> token ids (not deduplicated)
    """
    blocks, rows, cols, block_size = adaptive_square_split(image, max_rows=max_rows, max_cols=max_cols)
    patch_num = len(blocks)
    pad_num = max_rows * max_cols - patch_num
    mask_token_id_list = []
    patch_tensors = []

    if pad_num > 0:
        # Pad in row-major order: positions beyond rows/cols get a zero tensor and their (row, col) index is recorded (mapped to a token id by the caller)
        for i in range(max_rows):
            for j in range(max_cols):
                if i >= rows or j >= cols:
                    patch_tensors.append(torch.zeros_like(preprocess_fn(image)))
                    # The actual mask token id is mapped by the caller; only the (row, col) placeholder is recorded here.
                    # The caller then locates the matching special token in the text and adjusts the attention mask.
                    mask_token_id_list.append((i, j))
                else:
                    patch_tensors.append(preprocess_fn(blocks[i * cols + j]))
    else:
        patch_tensors = [preprocess_fn(b) for b in blocks]

    # Finally, append the full image's pixel values (same as the original logic)
    full_image_tensor = preprocess_fn(image)
    pixel_values = torch.stack(patch_tensors + [full_image_tensor], dim=0)  # (N_patches+1, C, H, W)

    return pixel_values, mask_token_id_list
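

# Shape sketch: with the 2x2 defaults the stack always holds
# max_rows * max_cols patch slots plus the full image, i.e. 5 entries; a
# 640x480 input fills all four slots, so no positions are masked.
def _example_prepare_image_patches():
    preprocess = build_image_preprocess()
    img = Image.new("RGB", (640, 480))
    pixel_values, mask_positions = prepare_image_patches(img, preprocess)
    assert pixel_values.shape == (5, 3, DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE)
    assert mask_positions == []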


# ---------------------------
# Build the image placeholder string for the token stream
# ---------------------------

def construct_image_placeholders(special_tokens: Dict[str, int], max_rows: int = DEFAULT_MAX_ROWS, max_cols: int = DEFAULT_MAX_COLS,
                                 n_image_tokens_per_patch: int = 49) -> str:
    """
    Generate a placeholder string for concatenation into the prompt.
    Returns a string containing the placeholder tokens.
    """
    # Randomly pick a Chinese lead-in phrase (variants of "the image is shown below:")
    image_place_holder = random.choice(["图片如下:", "如下所示的图片:", "请见下面这张图:", "如下图显示:", "参考下方图片:", "图示如下:"])
    for row in range(max_rows):
        for col in range(max_cols):
            image_place_holder += f"<fake_token_around_image><row_{row + 1}_col_{col + 1}>"
            image_place_holder += "<image>" * n_image_tokens_per_patch
    # Global image block (appended last)
    image_place_holder += f"<fake_token_around_image><global-img>{'<image>' * n_image_tokens_per_patch}<fake_token_around_image>"
    return image_place_holder
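

# Token-count sketch: for the 2x2 default and 49 <image> tokens per patch the
# placeholder carries 5 * 49 = 245 <image> tokens (4 grid patches + the global
# image) and 6 <fake_token_around_image> anchors. Note the special_tokens
# argument is accepted for API symmetry but not consulted here.
def _example_placeholder_counts():
    s = construct_image_placeholders({})
    assert s.count("<image>") == 5 * 49
    assert s.count("<fake_token_around_image>") == 6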


# ---------------------------
# Locate image-marker spans in the token sequence (used for attention-mask edits)
# ---------------------------

def find_indices(tokens: torch.Tensor) -> Optional[Dict[int, Dict[int, List[Tuple[int, int]]]]]:
    """
    tokens: tensor of shape (B, T)
    Returns:
      results = { batch_index: { k: [(start_idx, end_idx), ...], ... }, ... }
    where k indexes the preset image-token patterns below, and start_idx/end_idx are the
    inclusive bounds of the placeholder span within tokens.
    Note: this keeps the original matching scheme ([<fake>, <row_i_col_j>] and [<fake>, <global-img>]).
    """
    B, T = tokens.size()
    # These ids mirror the original code; adjust here if your tokenizer assigns different ids.
    # image_ids = [[3, i] for i in range(6, 22)] + [[3, 4]]  # preset pattern for the full 4x4 grid
    image_ids = [[3, 6], [3, 7], [3, 10], [3, 11], [3, 4]]
    image_ids_tensor = torch.tensor(image_ids, device=tokens.device)
    len_image_ids = image_ids_tensor.size(1)
    if len_image_ids > tokens.size(1):
        return None
    tokens_view = tokens.unfold(1, len_image_ids, 1)  # (B, T - len_image_ids +1, len_image_ids)
    matches = []
    for image_id_tensor in image_ids_tensor:
        match = (tokens_view == image_id_tensor).all(dim=2)  # (B, T-len+1)
        matches.append(match)
    results = {}
    for b in range(B):
        batch_res = {}
        for k, m in enumerate(matches):
            idxs = m[b].nonzero(as_tuple=True)[0]
            if len(idxs) > 0:
                # Span offsets assume 49 <image> tokens following the 2-token marker
                batch_res[k] = [(i.item() + 2, i.item() + 50) for i in idxs]
        if batch_res:
            results[b] = batch_res
    return results or None
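

# A synthetic check of the matcher: a [3, 6] marker (<fake_token_around_image>,
# <row_1_col_1> under the preset ids) at position 1 yields the inclusive span
# (3, 51) -- the 49 <image> slots that follow the 2-token marker.
def _example_find_indices():
    tokens = torch.tensor([[0, 3, 6] + [5] * 49 + [0]])
    assert find_indices(tokens) == {0: {0: [(3, 51)]}}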


# ---------------------------
# ONNX Session helpers
# ---------------------------

def create_onnx_session(path: str, intra_threads: int = 1) -> onnxruntime.InferenceSession:
    opts = onnxruntime.SessionOptions()
    opts.intra_op_num_threads = intra_threads
    opts.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
    return onnxruntime.InferenceSession(path, sess_options=opts)


# ---------------------------
# Prefill stage (splice in visual embeddings and run the LLM once)
# ---------------------------

def prefill_llm(vision_session: onnxruntime.InferenceSession,
                embed_tokens_session: onnxruntime.InferenceSession,
                llm_session: onnxruntime.InferenceSession,
                pixel_values: torch.Tensor,
                input_ids: torch.Tensor,
                attention_mask: torch.Tensor,
                freqs_cos: torch.Tensor,
                freqs_sin: torch.Tensor,
                special_tokens: Dict[str, int],
                seqlen: int,
                device: str = "cpu") -> Dict[str, Any]:
    """
    Run the prefill step:
      1) obtain visual embeddings (deepstack_embeds) from vision_session
      2) obtain token embeddings from embed_tokens_session (or use input hidden states directly)
      3) splice the visual embeddings into the hidden stream (replacing the placeholder token spans)
      4) run llm_session once to get logits, hidden_states, present_keys, present_values

    Returns a dict:
       {
         "logits": np.ndarray,
         "hidden_states": np.ndarray,
         "present_keys": np.ndarray,
         "present_values": np.ndarray
       }
    """
    # 1) vision embed
    ort_inputs_vis = {"inputs": pixel_values.numpy()}
    deepstack_embeds = vision_session.run(["deepstack_embeds"], ort_inputs_vis)[0]  # e.g. (B, P, L_patch, D)

    # 2) token embed
    ort_inputs_emb = {"input_ids": input_ids.numpy()}
    embed_tokens = embed_tokens_session.run(["embed_tokens"], ort_inputs_emb)[0]  # (B, T, D)

    # 3) Locate the image placeholders among the tokens and splice in the visual embeddings
    image_batch_indices = find_indices(input_ids)
    B = input_ids.shape[0]
    new_h = []

    for i in range(B):
        h_i = embed_tokens[i]  # np array (T, D)
        image_indices = image_batch_indices.get(i, {}) if image_batch_indices else {}
        # image_indices: {k: [(start,end), ...], ...}
        # deepstack_embeds: assume shape (B, P, L_patch, D), P = number_of_image_patches + global
        for tki, index_list in image_indices.items():
            # tki indexes the second dimension of deepstack_embeds
            vision_proj_i = deepstack_embeds[i][tki]  # (L_patch, D)
            # Use the first matched span
            start_idx, end_idx = index_list[0]
            # Replace h_i[start_idx..end_idx] with vision_proj_i and truncate to seqlen
            # (numpy concatenation, since h_i is a numpy array)
            h_i = np.concatenate((h_i[:start_idx], vision_proj_i, h_i[end_idx + 1:]), axis=0)[:seqlen]
        new_h.append(h_i)

    hidden_states = np.stack(new_h, axis=0)  # (B, seqlen, D)

    # 4) Run llm.onnx once (the prefill forward pass).
    # past_keys/past_values start out empty (zero length along the sequence axis).
    # Their shapes must match what the model expects; zero-length arrays serve as
    # placeholders here -- if your model requires concrete shapes, prepare them at the call site.
    past_keys = np.zeros([8, 0, 2, 64], dtype=np.float32)
    past_values = np.zeros([8, 0, 2, 64], dtype=np.float32)
    cos_pe = freqs_cos[0: seqlen].numpy()
    sin_pe = freqs_sin[0: seqlen].numpy()

    ort_inputs_llm = {
        "input_ids": hidden_states.astype(np.float32),
        "attention_mask": attention_mask.numpy(),
        "cos_pe": cos_pe.astype(np.float32),
        "sin_pe": sin_pe.astype(np.float32),
        "past_keys": past_keys,
        "past_values": past_values
    }

    logits, hidden_states_out, present_keys, present_values = llm_session.run(
        ["logits", "hidden_states", "present_keys", "present_values"], ort_inputs_llm
    )

    return {
        "logits": logits,
        "hidden_states": hidden_states_out,
        "present_keys": present_keys,
        "present_values": present_values
    }
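

# A minimal sketch of the splice in step 3 above (hypothetical sizes): the
# placeholder span [start_idx, end_idx] in the token embeddings is replaced by
# a vision embedding of the same length, leaving the total length unchanged.
def _example_embedding_splice():
    D = 8
    h = np.zeros((20, D), dtype=np.float32)          # token embeddings (T, D)
    vision_proj = np.ones((5, D), dtype=np.float32)  # patch embedding (L_patch, D)
    start_idx, end_idx = 4, 8                        # inclusive span of 5 tokens
    h = np.concatenate((h[:start_idx], vision_proj, h[end_idx + 1:]), axis=0)
    assert h.shape == (20, D) and h[start_idx:end_idx + 1].sum() == 5 * D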


# ---------------------------
# Autoregressive next-token generation (driven by present keys/values)
# ---------------------------

def generate_autoregressive(llm_session: onnxruntime.InferenceSession,
                            embed_tokens_session: onnxruntime.InferenceSession,
                            tokenizer,
                            initial_present: Dict[str, np.ndarray],
                            start_token_id: int,
                            freqs_cos: torch.Tensor,
                            freqs_sin: torch.Tensor,
                            attention_mask: np.ndarray,
                            max_new_tokens: int = 128,
                            eos_token_id: int = 2,
                            start_pos: Optional[int] = None,
                            output_queue=None):
    """
    Autoregressive generation driven by the present_keys/present_values returned from prefill.
    Each step:
      - get the new token's embedding from embed_tokens_session
      - feed the present keys/values into llm_session to obtain new present keys/values and logits
      - pick the argmax logit as the next token (swap in a sampling strategy if desired)

    If output_queue is provided, each generated text fragment is put on the queue.
    Note: the names and shapes of the present keys/values are model-specific; make sure they match.
    """
    present_keys = initial_present["present_keys"]
    present_values = initial_present["present_values"]

    token_id = int(start_token_id)
    if start_pos is None:
        start_pos = attention_mask.shape[1]

    generated_ids = []
    for step in range(max_new_tokens):
        # Decode the current token
        decoded = tokenizer.decode(token_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        
        # If an output queue was provided, push the generated text fragment
        if output_queue is not None:
            output_queue.put(decoded)
        
        # Extend the attention mask by one position
        attention_mask = np.concatenate([attention_mask, np.array([[1]], dtype=np.int64)], axis=1)

        # Embed the current token
        embed_tokens = embed_tokens_session.run(["embed_tokens"], {"input_ids": np.array([[token_id]], dtype=np.int64)})[0]

        cos_pe = freqs_cos[start_pos: start_pos + 1].numpy()
        sin_pe = freqs_sin[start_pos: start_pos + 1].numpy()

        ort_inputs = {
            "input_ids": embed_tokens.astype(np.float32),
            "attention_mask": attention_mask,
            "cos_pe": cos_pe.astype(np.float32),
            "sin_pe": sin_pe.astype(np.float32),
            "past_keys": present_keys,
            "past_values": present_values
        }

        logits, hidden_states, present_keys, present_values = llm_session.run(
            ["logits", "hidden_states", "present_keys", "present_values"], ort_inputs
        )

        token_id = int(np.argmax(logits[:, -1, :], axis=-1)[0])
        generated_ids.append(token_id)

        if token_id == eos_token_id:
            break

        start_pos += 1
    
    # Signal the end of generation
    if output_queue is not None:
        output_queue.put(None)

    return generated_ids
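

# A drop-in alternative to the greedy argmax above: temperature sampling over
# the last-step logits (a sketch; generate_autoregressive uses argmax by default).
def sample_next_token(logits: np.ndarray, temperature: float = 0.8) -> int:
    z = logits[:, -1, :].astype(np.float64) / max(temperature, 1e-6)
    z -= z.max(axis=-1, keepdims=True)  # numerical stability before exp
    probs = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)
    return int(np.random.choice(probs.shape[-1], p=probs[0]))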


def main_example():
    
    tokenizer_path = "./custom_tokenizer"
    tokenizer = load_tokenizer(tokenizer_path)
    preprocess = build_image_preprocess(DEFAULT_IMAGE_SIZE)

    # image
    image_path = "/Users/hulk/Downloads/coco128/images/train2017/000000000165.jpg"
    image = Image.open(image_path).convert("RGB")

    # Special tokens (the tokenizer's vocabulary carries the full 4x4 grid of <row_i_col_j> tokens)
    special_tokens = prepare_special_tokens(tokenizer, max_rows=4, max_cols=4)

    # Split with the 2x2 defaults so the patch order lines up with the 2x2 placeholder grid
    # from construct_image_placeholders and with the patterns hardcoded in find_indices.
    pixel_values, mask_positions = prepare_image_patches(image, preprocess)

    # Build the prompt + image placeholders (assumes the tokenizer supports apply_chat_template)
    query = "图片中的人在做什么。"
    messages = [
        {"role": "system", "content": "你是一个多模态AI助手,能够理解图片和文本信息."},
        {"role": "user", "content": query + construct_image_placeholders(special_tokens)}
    ]
    inputs_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(inputs_text, return_tensors="pt", truncation=True)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # precompute RoPE
    freqs_cos, freqs_sin = precompute_freqs_cis(dim=64, end=32768, rope_base=1e6)

    # create onnx sessions
    vision_session = create_onnx_session("./onnx_model/vision_encoder.onnx", intra_threads=2)
    embed_tokens_session = create_onnx_session("./onnx_model/embed_tokens.onnx", intra_threads=2)
    llm_session = create_onnx_session("./onnx_model/llm.onnx", intra_threads=2)

    # prefill
    seqlen = input_ids.shape[1]
    prefill_out = prefill_llm(
        vision_session=vision_session,
        embed_tokens_session=embed_tokens_session,
        llm_session=llm_session,
        pixel_values=pixel_values,
        input_ids=input_ids,
        attention_mask=attention_mask,
        freqs_cos=freqs_cos,
        freqs_sin=freqs_sin,
        special_tokens=special_tokens,
        seqlen=seqlen
    )

    # start token id = argmax last logit
    start_token_id = int(np.argmax(prefill_out["logits"][:, -1, :], axis=-1)[0])

    generated = generate_autoregressive(
        llm_session=llm_session,
        embed_tokens_session=embed_tokens_session,
        tokenizer=tokenizer,
        initial_present={"present_keys": prefill_out["present_keys"], "present_values": prefill_out["present_values"]},
        start_token_id=start_token_id,
        freqs_cos=freqs_cos,
        freqs_sin=freqs_sin,
        attention_mask=attention_mask.numpy(),
        max_new_tokens=128,
        eos_token_id=2,
        start_pos=seqlen
    )



if __name__ == "__main__":
    # main_example()
    pass