import math
import random
from typing import Optional, Tuple, List, Dict, Any
import os
import numpy as np
import torch
from PIL import Image
import torchvision
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
import onnxruntime
# Recommended: use transformers' AutoTokenizer
from transformers import AutoTokenizer
# ---------------------------
# Config / default parameters
# ---------------------------
DEFAULT_IMAGE_SIZE = 224
DEFAULT_MAX_ROWS = 2
DEFAULT_MAX_COLS = 2
MIN_BLOCK_SIZE = 4
# ---------------------------
# Tokenizer / Preprocess
# ---------------------------
def load_tokenizer(tokenizer_path: str):
    """Load the tokenizer (AutoTokenizer)."""
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    if tokenizer.chat_template is None:
        # Newer transformers versions load the chat template automatically;
        # the training environment (4.51.3) supports this fallback.
        with open(os.path.join(tokenizer_path, "chat_template.jinja"), "r") as f:
            tokenizer.chat_template = f.read()
    return tokenizer
def build_image_preprocess(image_size: int = DEFAULT_IMAGE_SIZE):
    """Return a torchvision.transforms.Compose preprocessing callable."""
    return Compose([
        Resize(size=image_size, interpolation=torchvision.transforms.InterpolationMode.BICUBIC,
               max_size=None, antialias=True),
        CenterCrop(size=(image_size, image_size)),
        lambda img: img.convert("RGB"),  # ensure 3-channel input before ToTensor
        ToTensor(),
        Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                  std=(0.26862954, 0.26130258, 0.27577711))
    ])
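# A quick sanity check of the preprocessing pipeline (illustrative only; the
# dummy image below is a stand-in, not part of the original pipeline):
#
#     preprocess = build_image_preprocess()
#     dummy = Image.new("RGB", (640, 480), color=(128, 128, 128))
#     t = preprocess(dummy)
#     assert t.shape == (3, DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE)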
# ---------------------------
# RoPE frequency precomputation (precompute_freqs_cis)
# ---------------------------
def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6,
                         rope_scaling: Optional[dict] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the RoPE cos/sin tables.
    Returns:
        freqs_cos: (end, dim)
        freqs_sin: (end, dim)
    """
    freqs = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    if rope_scaling is not None:
        orig_max = rope_scaling.get("original_max_position_embeddings", 2048)
        factor = rope_scaling.get("factor", 4)
        beta_fast = rope_scaling.get("beta_fast", 4.0)
        beta_slow = rope_scaling.get("beta_slow", 1.0)
        if end / orig_max > 1.0:
            corr_dim = next((i for i in range(dim // 2) if 2 * math.pi / freqs[i] > orig_max), dim // 2)
            power = torch.arange(0, dim // 2, device=freqs.device).float() / max(dim // 2 - 1, 1)
            beta = beta_slow + (beta_fast - beta_slow) * power
            scale = torch.where(torch.arange(dim // 2, device=freqs.device) < corr_dim,
                                (beta * factor - beta + 1) / (beta * factor),
                                1.0 / factor)
            freqs = freqs * scale
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()  # (end, dim/2)
    freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1)
    freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1)
    return freqs_cos, freqs_sin
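# Quick shape check (illustrative; dim=64 matches the per-head size used in
# main_example below):
#
#     cos, sin = precompute_freqs_cis(dim=64, end=1024)
#     assert cos.shape == (1024, 64) and sin.shape == (1024, 64)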
# ---------------------------
# Adaptive square image splitting (adaptive_square_split)
# ---------------------------
def calculate_optimal_split_with_fixed_max(width: int, height: int, max_rows: int, max_cols: int) -> Tuple[int, int, int]:
    """
    Compute the best split (returns rows, cols, block_size).
    block_size is rounded down to a multiple of 16, with a minimum of MIN_BLOCK_SIZE.
    """
    best_rows = 1
    best_cols = 1
    best_block_size = 0
    best_coverage = 0.0
    # Scheme 1: fix the row count at max_rows, adapt the column count
    rows_fixed = max_rows
    for cols in range(1, max_cols + 1):
        block_width = width // cols
        block_height = height // rows_fixed
        square_size = min(block_width, block_height)
        if square_size <= 0:
            continue
        coverage = (cols * square_size) * (rows_fixed * square_size) / (width * height)
        if coverage > best_coverage or (coverage == best_coverage and square_size > best_block_size):
            best_rows, best_cols, best_block_size, best_coverage = rows_fixed, cols, square_size, coverage
    # Scheme 2: fix the column count at max_cols, adapt the row count
    cols_fixed = max_cols
    for rows in range(1, max_rows + 1):
        block_width = width // cols_fixed
        block_height = height // rows
        square_size = min(block_width, block_height)
        if square_size <= 0:
            continue
        coverage = (cols_fixed * square_size) * (rows * square_size) / (width * height)
        if coverage > best_coverage or (coverage == best_coverage and square_size > best_block_size):
            best_rows, best_cols, best_block_size, best_coverage = rows, cols_fixed, square_size, coverage
    # Scheme 3: both rows and cols at their maxima
    block_width = width // max_cols
    block_height = height // max_rows
    square_size = min(block_width, block_height)
    if square_size > 0:
        coverage = (max_cols * square_size) * (max_rows * square_size) / (width * height)
        if coverage > best_coverage or (coverage == best_coverage and square_size > best_block_size):
            best_rows, best_cols, best_block_size, best_coverage = max_rows, max_cols, square_size, coverage
    # Align down to a multiple of 16 and enforce the minimum
    best_block_size = max(MIN_BLOCK_SIZE, (best_block_size // 16) * 16)
    return best_rows, best_cols, best_block_size
def adaptive_square_split(image: Image.Image, max_rows: int = DEFAULT_MAX_ROWS, max_cols: int = DEFAULT_MAX_COLS
                          ) -> Tuple[List[Image.Image], int, int, int]:
    """
    Adaptively split a PIL Image into square blocks.
    Returns (blocks_list, rows, cols, block_size); blocks_list is in row-major
    order and may contain fewer than max_rows * max_cols blocks.
    """
    width, height = image.size
    rows, cols, block_size = calculate_optimal_split_with_fixed_max(width, height, max_rows, max_cols)
    blocks = []
    for r in range(rows):
        for c in range(cols):
            left = c * block_size
            upper = r * block_size
            right = left + block_size
            lower = upper + block_size
            blocks.append(image.crop((left, upper, right, lower)))
    return blocks, rows, cols, block_size
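# Worked example of the split (illustrative): a 640x480 image with the 2x2
# defaults yields block_size = min(640 // 2, 480 // 2) = 240, already a
# multiple of 16, so 4 square crops of 240x240.
#
#     img = Image.new("RGB", (640, 480))
#     blocks, rows, cols, block_size = adaptive_square_split(img)
#     assert (rows, cols, block_size) == (2, 2, 240) and len(blocks) == 4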
# ---------------------------
# Special-token preparation
# ---------------------------
def prepare_special_tokens(tokenizer, max_rows: int = DEFAULT_MAX_ROWS, max_cols: int = DEFAULT_MAX_COLS) -> Dict[str, int]:
    """
    Return a dict of special token ids, covering <global-img>,
    <fake_token_around_image>, <image>, and every <row_i_col_j>.
    """
    special = {
        "<global-img>": tokenizer.convert_tokens_to_ids("<global-img>"),
        "<fake_token_around_image>": tokenizer.convert_tokens_to_ids("<fake_token_around_image>"),
        "<image>": tokenizer.convert_tokens_to_ids("<image>"),
    }
    for i in range(max_rows):
        for j in range(max_cols):
            special[f"<row_{i + 1}_col_{j + 1}>"] = tokenizer.convert_tokens_to_ids(f"<row_{i + 1}_col_{j + 1}>")
    return special
# ---------------------------
# Split the image into patches, pad, and stack into a model input tensor
# ---------------------------
def prepare_image_patches(image: Image.Image, preprocess_fn, max_rows: int = DEFAULT_MAX_ROWS, max_cols: int = DEFAULT_MAX_COLS
                          ) -> Tuple[torch.Tensor, List[Tuple[int, int]]]:
    """
    image: PIL.Image
    preprocess_fn: callable that maps PIL.Image -> tensor (C, H, W)
    Returns:
        pixel_values: torch.Tensor of shape (num_patches + 1, C, H, W); the
            last entry is the full image.
        mask_positions: list of (row, col) grid positions that were padded with
            zero tensors; the caller maps these to <row_i_col_j> token ids and
            masks the corresponding placeholder spans.
    """
    blocks, rows, cols, block_size = adaptive_square_split(image, max_rows=max_rows, max_cols=max_cols)
    patch_num = len(blocks)
    pad_num = max_rows * max_cols - patch_num
    mask_token_id_list = []
    patch_tensors = []
    if pad_num > 0:
        # Pad in row-major order: positions beyond rows/cols get a zero tensor,
        # and the (row, col) position is recorded so the caller can mask the
        # matching placeholder tokens.
        for i in range(max_rows):
            for j in range(max_cols):
                if i >= rows or j >= cols:
                    patch_tensors.append(torch.zeros_like(preprocess_fn(image)))
                    # The actual token id is resolved by the caller, which finds
                    # the special token's position in the text and applies the
                    # attention-mask update.
                    mask_token_id_list.append((i, j))
                else:
                    patch_tensors.append(preprocess_fn(blocks[i * cols + j]))
    else:
        patch_tensors = [preprocess_fn(b) for b in blocks]
    # Append the full image's pixel values last (same as the original logic)
    full_image_tensor = preprocess_fn(image)
    pixel_values = torch.stack(patch_tensors + [full_image_tensor], dim=0)  # (N_patches + 1, C, H, W)
    return pixel_values, mask_token_id_list
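# Expected shapes (illustrative): with the 2x2 defaults and a 224x224
# preprocess, a full grid gives pixel_values of shape (5, 3, 224, 224)
# (4 patches + 1 full image) and an empty mask_positions list.
#
#     pv, masked = prepare_image_patches(Image.new("RGB", (640, 480)), build_image_preprocess())
#     assert pv.shape == (5, 3, 224, 224) and masked == []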
# ---------------------------
# Build the image placeholder string for the token stream
# ---------------------------
def construct_image_placeholders(special_tokens: Dict[str, int], max_rows: int = DEFAULT_MAX_ROWS, max_cols: int = DEFAULT_MAX_COLS,
                                 n_image_tokens_per_patch: int = 49) -> str:
    """
    Generate a placeholder string to splice into the prompt.
    Returns a single string containing all per-patch placeholders followed by
    the global-image placeholder. (special_tokens is accepted for symmetry with
    the other helpers; the placeholders are emitted as literal token text.)
    """
    image_place_holder = random.choice(["图片如下:", "如下所示的图片:", "请见下面这张图:", "如下图显示:", "参考下方图片:", "图示如下:"])
    for row in range(max_rows):
        for col in range(max_cols):
            image_place_holder += f"<fake_token_around_image><row_{row + 1}_col_{col + 1}>"
            image_place_holder += "<image>" * n_image_tokens_per_patch
    # Global image block (last)
    image_place_holder += f"<fake_token_around_image><global-img>{'<image>' * n_image_tokens_per_patch}<fake_token_around_image>"
    return image_place_holder
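# Placeholder layout (illustrative), for max_rows = max_cols = 1 and
# n_image_tokens_per_patch = 2:
#
#     <intro text><fake_token_around_image><row_1_col_1><image><image>
#     <fake_token_around_image><global-img><image><image><fake_token_around_image>
#
# i.e. each patch contributes a 2-token header followed by a fixed run of
# <image> tokens, which is exactly the pattern find_indices() below matches.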
# ---------------------------
# Locate image-marker positions in the token sequence (for attention-mask edits)
# ---------------------------
def find_indices(tokens: torch.Tensor) -> Optional[Dict[int, Dict[int, List[Tuple[int, int]]]]]:
    """
    tokens: tensor of shape (B, T)
    Returns:
        results = { batch_index: { k: [(start_idx, end_idx), ...], ... }, ... }
    where k indexes the preset image-token patterns below, and start_idx/end_idx
    delimit the placeholder span in tokens (inclusive).
    Note: this keeps the original matching patterns
    ([<fake>, <row_i_col_j>] and [<fake>, <global-img>]).
    """
    B, T = tokens.size()
    # Id sequences consistent with the original code (adjust if your tokenizer differs)
    # image_ids = [[3, i] for i in range(6, 22)] + [[3, 4]]  # preset pattern
    image_ids = [[3, 6], [3, 7], [3, 10], [3, 11], [3, 4]]
    image_ids_tensor = torch.tensor(image_ids, device=tokens.device)
    len_image_ids = image_ids_tensor.size(1)
    if len_image_ids > tokens.size(1):
        return None
    tokens_view = tokens.unfold(1, len_image_ids, 1)  # (B, T - len_image_ids + 1, len_image_ids)
    matches = []
    for image_id_tensor in image_ids_tensor:
        match = (tokens_view == image_id_tensor).all(dim=2)  # (B, T - len + 1)
        matches.append(match)
    results = {}
    for b in range(B):
        batch_res = {}
        for k, m in enumerate(matches):
            idxs = m[b].nonzero(as_tuple=True)[0]
            if len(idxs) > 0:
                # The span starts right after the 2-token header and covers the
                # 49 <image> tokens that follow (indices +2 .. +50, inclusive).
                batch_res[k] = [(i.item() + 2, i.item() + 50) for i in idxs]
        if batch_res:
            results[b] = batch_res
    return results or None
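# Toy check of the pattern matcher (illustrative ids only): a sequence that
# contains the header [3, 6] at position 1 should report the span (3, 51)
# for pattern index 0.
#
#     toy = torch.zeros((1, 64), dtype=torch.long)
#     toy[0, 1], toy[0, 2] = 3, 6
#     assert find_indices(toy)[0][0] == [(3, 51)]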
# ---------------------------
# ONNX Session helpers
# ---------------------------
def create_onnx_session(path: str, intra_threads: int = 1) -> onnxruntime.InferenceSession:
    opts = onnxruntime.SessionOptions()
    opts.intra_op_num_threads = intra_threads
    opts.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
    return onnxruntime.InferenceSession(path, sess_options=opts)
# ---------------------------
# Prefill stage (insert vision embeddings and run the LLM once)
# ---------------------------
def prefill_llm(vision_session: onnxruntime.InferenceSession,
                embed_tokens_session: onnxruntime.InferenceSession,
                llm_session: onnxruntime.InferenceSession,
                pixel_values: torch.Tensor,
                input_ids: torch.Tensor,
                attention_mask: torch.Tensor,
                freqs_cos: torch.Tensor,
                freqs_sin: torch.Tensor,
                special_tokens: Dict[str, int],
                seqlen: int,
                device: str = "cpu") -> Dict[str, Any]:
    """
    Run the prefill step:
      1) get vision embeddings (deepstack_embeds) from vision_session
      2) get token embeddings from embed_tokens_session
      3) splice the vision embeddings into the hidden stream (replacing the
         placeholder token spans)
      4) call llm_session.run once to obtain logits, hidden_states,
         present_keys, present_values
    Returns a dict:
        {
            "logits": np.ndarray,
            "hidden_states": np.ndarray,
            "present_keys": np.ndarray,
            "present_values": np.ndarray
        }
    """
    # 1) vision embeddings
    ort_inputs_vis = {"inputs": pixel_values.numpy()}
    deepstack_embeds = vision_session.run(["deepstack_embeds"], ort_inputs_vis)[0]  # e.g. (B, P, L_patch, D)
    # 2) token embeddings
    ort_inputs_emb = {"input_ids": input_ids.numpy()}
    embed_tokens = embed_tokens_session.run(["embed_tokens"], ort_inputs_emb)[0]  # (B, T, D)
    # 3) find the image-placeholder positions in the tokens and replace them
    image_batch_indices = find_indices(input_ids)
    B = input_ids.shape[0]
    new_h = []
    for i in range(B):
        h_i = embed_tokens[i]  # np array (T, D)
        image_indices = image_batch_indices.get(i, {}) if image_batch_indices else {}
        # image_indices: {k: [(start, end), ...], ...}
        # deepstack_embeds: assumed shape (B, P, L_patch, D), P = number of image patches + global
        for tki, index_list in image_indices.items():
            # tki indexes the second dimension of deepstack_embeds
            vision_proj_i = deepstack_embeds[i][tki]  # (L_patch, D)
            # take the first matched span
            start_idx, end_idx = index_list[0]
            # replace h_i[start_idx..end_idx] with vision_proj_i and truncate to seqlen
            # (numpy concatenation, since h_i is a numpy array)
            h_i = np.concatenate((h_i[:start_idx], vision_proj_i, h_i[end_idx + 1:]), axis=0)[:seqlen]
        new_h.append(h_i)
    hidden_states = np.stack(new_h, axis=0)  # (B, seqlen, D)
    # 4) call llm.onnx once for the prefill forward pass.
    # past_keys/past_values start empty (zero length along the sequence axis);
    # the exact shapes must match what the model expects -- prepare them on the
    # calling side if your export differs.
    past_keys = np.zeros([8, 0, 2, 64], dtype=np.float32)
    past_values = np.zeros([8, 0, 2, 64], dtype=np.float32)
    cos_pe = freqs_cos[0: seqlen].numpy()
    sin_pe = freqs_sin[0: seqlen].numpy()
    ort_inputs_llm = {
        "input_ids": hidden_states.astype(np.float32),
        "attention_mask": attention_mask.numpy(),
        "cos_pe": cos_pe.astype(np.float32),
        "sin_pe": sin_pe.astype(np.float32),
        "past_keys": past_keys,
        "past_values": past_values
    }
    logits, hidden_states_out, present_keys, present_values = llm_session.run(
        ["logits", "hidden_states", "present_keys", "present_values"], ort_inputs_llm
    )
    return {
        "logits": logits,
        "hidden_states": hidden_states_out,
        "present_keys": present_keys,
        "present_values": present_values
    }
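# A small helper for building the empty past-KV placeholders, so the layer /
# head / head-dim assumptions live in one place. The default layout
# (num_layers=8, seq=0, num_kv_heads=2, head_dim=64) mirrors the hard-coded
# arrays above and is an assumption about this particular ONNX export, not a
# general rule.
def make_empty_past_kv(num_layers: int = 8, num_kv_heads: int = 2,
                       head_dim: int = 64) -> Tuple[np.ndarray, np.ndarray]:
    shape = [num_layers, 0, num_kv_heads, head_dim]
    return np.zeros(shape, dtype=np.float32), np.zeros(shape, dtype=np.float32)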
# ---------------------------
# Autoregressive next-token generation (using present keys/values)
# ---------------------------
def generate_autoregressive(llm_session: onnxruntime.InferenceSession,
                            embed_tokens_session: onnxruntime.InferenceSession,
                            tokenizer,
                            initial_present: Dict[str, np.ndarray],
                            start_token_id: int,
                            freqs_cos: torch.Tensor,
                            freqs_sin: torch.Tensor,
                            attention_mask: np.ndarray,
                            max_new_tokens: int = 128,
                            eos_token_id: int = 2,
                            start_pos: int = None,
                            output_queue=None):
    """
    Autoregressive generation driven by the present_keys/present_values
    returned from prefill. Each step:
      - embed the new token with embed_tokens_session
      - run llm_session with the current present keys/values to get fresh
        present keys/values and logits
      - pick the argmax logit as the next token (swap in a sampling strategy
        if you prefer; see sample_next_token below)
    If output_queue is provided, each decoded text fragment is pushed onto it.
    Note: the present keys/values names and shapes are model-specific; make
    sure they match your export.
    """
    present_keys = initial_present["present_keys"]
    present_values = initial_present["present_values"]
    token_id = int(start_token_id)
    if start_pos is None:
        start_pos = attention_mask.shape[1]
    generated_ids = []
    for step in range(max_new_tokens):
        # decode the current token
        decoded = tokenizer.decode(token_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        # push the decoded text if an output queue was provided
        if output_queue is not None:
            output_queue.put(decoded)
        # extend the attention mask
        attention_mask = np.concatenate([attention_mask, np.array([[1]], dtype=np.int64)], axis=1)
        # embed the current token
        embed_tokens = embed_tokens_session.run(["embed_tokens"], {"input_ids": np.array([[token_id]], dtype=np.int64)})[0]
        cos_pe = freqs_cos[start_pos: start_pos + 1].numpy()
        sin_pe = freqs_sin[start_pos: start_pos + 1].numpy()
        ort_inputs = {
            "input_ids": embed_tokens.astype(np.float32),
            "attention_mask": attention_mask,
            "cos_pe": cos_pe.astype(np.float32),
            "sin_pe": sin_pe.astype(np.float32),
            "past_keys": present_keys,
            "past_values": present_values
        }
        logits, hidden_states, present_keys, present_values = llm_session.run(
            ["logits", "hidden_states", "present_keys", "present_values"], ort_inputs
        )
        token_id = int(np.argmax(logits[:, -1, :], axis=-1)[0])
        generated_ids.append(token_id)
        if token_id == eos_token_id:
            break
        start_pos += 1
    # signal end of generation
    if output_queue is not None:
        output_queue.put(None)
    return generated_ids
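# The loop above decodes greedily (argmax). A minimal temperature / top-k
# sampler that could be dropped in instead -- a sketch, not part of the
# original pipeline:
def sample_next_token(logits: np.ndarray, temperature: float = 0.8, top_k: int = 50) -> int:
    """Sample a token id from the last-position logits (shape (B, T, V))."""
    scores = logits[0, -1, :].astype(np.float64) / max(temperature, 1e-5)
    if top_k > 0:
        top_k = min(top_k, scores.size)
        # mask everything outside the top-k scores
        kth = np.sort(scores)[-top_k]
        scores = np.where(scores < kth, -np.inf, scores)
    probs = np.exp(scores - np.max(scores))
    probs /= probs.sum()
    return int(np.random.choice(len(probs), p=probs))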
def main_example():
    tokenizer_path = "./custom_tokenizer"
    tokenizer = load_tokenizer(tokenizer_path)
    preprocess = build_image_preprocess(DEFAULT_IMAGE_SIZE)
    # image
    image_path = "/Users/hulk/Downloads/coco128/images/train2017/000000000165.jpg"
    image = Image.open(image_path).convert("RGB")
    # special tokens; the grid must match the 2x2 patterns hard-coded in find_indices()
    special_tokens = prepare_special_tokens(tokenizer, max_rows=DEFAULT_MAX_ROWS, max_cols=DEFAULT_MAX_COLS)
    pixel_values, mask_positions = prepare_image_patches(image, preprocess, max_rows=DEFAULT_MAX_ROWS, max_cols=DEFAULT_MAX_COLS)
    # build prompt + image placeholders (assumes the tokenizer supports apply_chat_template)
    query = "图片中的人在做什么。"
    messages = [
        {"role": "system", "content": "你是一个多模态AI助手,能够理解图片和文本信息."},
        {"role": "user", "content": query + construct_image_placeholders(special_tokens)}
    ]
    inputs_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(inputs_text, return_tensors="pt", truncation=True)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    # precompute RoPE
    freqs_cos, freqs_sin = precompute_freqs_cis(dim=64, end=32768, rope_base=1e6)
    # create onnx sessions
    vision_session = create_onnx_session("./onnx_model/vision_encoder.onnx", intra_threads=2)
    embed_tokens_session = create_onnx_session("./onnx_model/embed_tokens.onnx", intra_threads=2)
    llm_session = create_onnx_session("./onnx_model/llm.onnx", intra_threads=2)
    # prefill
    seqlen = input_ids.shape[1]
    prefill_out = prefill_llm(
        vision_session=vision_session,
        embed_tokens_session=embed_tokens_session,
        llm_session=llm_session,
        pixel_values=pixel_values,
        input_ids=input_ids,
        attention_mask=attention_mask,
        freqs_cos=freqs_cos,
        freqs_sin=freqs_sin,
        special_tokens=special_tokens,
        seqlen=seqlen
    )
    # start token id = argmax of the last prefill logit
    start_token_id = int(np.argmax(prefill_out["logits"][:, -1, :], axis=-1)[0])
    generated = generate_autoregressive(
        llm_session=llm_session,
        embed_tokens_session=embed_tokens_session,
        tokenizer=tokenizer,
        initial_present={"present_keys": prefill_out["present_keys"], "present_values": prefill_out["present_values"]},
        start_token_id=start_token_id,
        freqs_cos=freqs_cos,
        freqs_sin=freqs_sin,
        attention_mask=attention_mask.numpy(),
        max_new_tokens=128,
        eos_token_id=2,
        start_pos=seqlen
    )
    # decode the full generation (start token plus the ids produced in the loop)
    print(tokenizer.decode([start_token_id] + generated, skip_special_tokens=True))
if __name__ == "__main__":
    # main_example()
    pass