# tinymind-90M-onnx / tinymind.py
# Author: TalkUHulk
# Last commit: "Update tinymind.py" (de14119, verified)
import math
import random
from typing import Optional, Tuple, List, Dict, Any
import os
import numpy as np
import torch
from PIL import Image
import torchvision
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
import onnxruntime
# 如果你用 transformers 的 AutoTokenizer(推荐)
from transformers import AutoTokenizer
# ---------------------------
# Config / default parameters
# ---------------------------
# Default side length for build_image_preprocess (square resize + centre crop).
DEFAULT_IMAGE_SIZE = 224
# Default maximum patch-grid size used by the splitting / placeholder helpers.
DEFAULT_MAX_ROWS = 2
DEFAULT_MAX_COLS = 2
# Floor for the block size returned by calculate_optimal_split_with_fixed_max.
MIN_BLOCK_SIZE = 4
# ---------------------------
# Tokenizer / Preprocess
# ---------------------------
def load_tokenizer(tokenizer_path: str):
    """Load a tokenizer with ``AutoTokenizer`` and make sure it has a chat template.

    Newer ``transformers`` versions load ``chat_template.jinja`` automatically
    (the training environment used 4.51.3). If the loaded tokenizer still has no
    template, it is read from ``<tokenizer_path>/chat_template.jinja``.

    Args:
        tokenizer_path: path passed to ``AutoTokenizer.from_pretrained``; also the
            directory that holds the fallback ``chat_template.jinja``.

    Returns:
        The loaded tokenizer with ``chat_template`` set.

    Raises:
        FileNotFoundError: if no template was loaded and the fallback file is missing.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    if tokenizer.chat_template is None:
        template_path = os.path.join(tokenizer_path, "chat_template.jinja")
        # Read as UTF-8 explicitly: the platform default encoding (e.g. on Windows)
        # can fail on, or garble, non-ASCII characters in the template.
        with open(template_path, "r", encoding="utf-8") as f:
            tokenizer.chat_template = f.read()
    return tokenizer
def build_image_preprocess(image_size: int = DEFAULT_IMAGE_SIZE):
    """Build the transform applied to every image patch and to the global view.

    Pipeline: bicubic, antialiased resize so the shorter side equals
    ``image_size``; centre crop to ``image_size x image_size``; conversion to
    RGB; ``ToTensor``; per-channel normalisation with fixed mean/std.

    Returns:
        A ``torchvision.transforms.Compose`` callable mapping a PIL image to a
        float tensor of shape (3, image_size, image_size).
    """
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)

    def _to_rgb(img):
        return img.convert("RGB")

    steps = [
        Resize(
            size=image_size,
            interpolation=torchvision.transforms.InterpolationMode.BICUBIC,
            max_size=None,
            antialias=True,
        ),
        CenterCrop(size=(image_size, image_size)),
        _to_rgb,
        ToTensor(),
        Normalize(mean=channel_mean, std=channel_std),
    ]
    return Compose(steps)
# ---------------------------
# RoPE频率预计算(precompute_freqs_cis)
# ---------------------------
def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6,
                         rope_scaling: Optional[dict] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the RoPE cosine and sine tables.

    Args:
        dim: size of the last axis of the returned tables (expected to be even).
        end: number of positions (rows) to precompute.
        rope_base: base of the inverse-frequency progression.
        rope_scaling: optional dict that rescales the frequencies for long
            contexts. Keys read (defaults in parentheses):
            "original_max_position_embeddings" (2048), "factor" (4),
            "beta_fast" (4.0), "beta_slow" (1.0). Rescaling is applied only when
            ``end`` exceeds original_max_position_embeddings.

    Returns:
        freqs_cos: (end, dim)
        freqs_sin: (end, dim)
        Row t holds cos/sin(t * freq_i) for i = 0 .. dim/2 - 1. These values are
        written twice, once in the first half and once in the second half of
        the last axis.
    """
    # Inverse frequencies rope_base ** (-2i / dim), i = 0 .. dim/2 - 1.
    freqs = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    if rope_scaling is not None:
        orig_max = rope_scaling.get("original_max_position_embeddings", 2048)
        factor = rope_scaling.get("factor", 4)
        beta_fast = rope_scaling.get("beta_fast", 4.0)
        beta_slow = rope_scaling.get("beta_slow", 1.0)
        if end / orig_max > 1.0:
            # First frequency index whose wavelength 2*pi/freq exceeds orig_max
            # (dim // 2 when none does).
            corr_dim = next((i for i in range(dim // 2) if 2 * math.pi / freqs[i] > orig_max), dim // 2)
            # beta ramps linearly from beta_slow (index 0) to beta_fast (last index).
            power = torch.arange(0, dim // 2, device=freqs.device).float() / max(dim // 2 - 1, 1)
            beta = beta_slow + (beta_fast - beta_slow) * power
            # Indices below corr_dim (wavelength <= orig_max) are scaled by
            # (beta*factor - beta + 1) / (beta*factor); the rest by 1 / factor.
            scale = torch.where(torch.arange(dim // 2, device=freqs.device) < corr_dim,
                                (beta * factor - beta + 1) / (beta * factor),
                                1.0 / factor)
            freqs = freqs * scale
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()  # (end, dim/2)
    # Duplicate along the last axis so each table has width dim.
    freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1)
    freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1)
    return freqs_cos, freqs_sin
# ---------------------------
# 图像自适应切分(adaptive_square_split)
# ---------------------------
def calculate_optimal_split_with_fixed_max(width: int, height: int, max_rows: int, max_cols: int) -> Tuple[int, int, int]:
    """Choose a grid of equal square blocks that covers as much of the image as possible.

    Candidate grids, evaluated in this order:
      1. rows == max_rows, cols = 1 .. max_cols
      2. cols == max_cols, rows = 1 .. max_rows
      3. rows == max_rows, cols == max_cols
    For each grid the square side is ``min(width // cols, height // rows)``.
    Grids whose side is not positive are skipped. The winner is the grid with
    the largest covered-area fraction; ties go to the larger side, and the
    earliest candidate wins a full tie. If no grid fits, the grid is (1, 1).

    Returns:
        (rows, cols, block_size). block_size is the winning side rounded down to
        a multiple of 16, but never below ``MIN_BLOCK_SIZE``.
    """
    candidates = [(max_rows, c) for c in range(1, max_cols + 1)]
    candidates += [(r, max_cols) for r in range(1, max_rows + 1)]
    candidates.append((max_rows, max_cols))

    chosen = (1, 1)
    top_side = 0
    top_cover = 0.0
    for n_rows, n_cols in candidates:
        side = min(width // n_cols, height // n_rows)
        if side <= 0:
            continue
        cover = (n_cols * side) * (n_rows * side) / (width * height)
        if cover > top_cover or (cover == top_cover and side > top_side):
            chosen, top_side, top_cover = (n_rows, n_cols), side, cover

    aligned_side = (top_side // 16) * 16
    return chosen[0], chosen[1], max(MIN_BLOCK_SIZE, aligned_side)
def adaptive_square_split(image: Image.Image, max_rows: int = DEFAULT_MAX_ROWS, max_cols: int = DEFAULT_MAX_COLS
                          ) -> Tuple[List[Image.Image], int, int, int]:
    """Cut ``image`` into a grid of equal square crops anchored at the top-left corner.

    The grid shape and square side come from
    ``calculate_optimal_split_with_fixed_max``. A crop can extend past the
    image edge when the block size was raised to ``MIN_BLOCK_SIZE``.

    Returns:
        (blocks, rows, cols, block_size). ``blocks`` holds rows * cols crops in
        row-major order, which may be fewer than max_rows * max_cols.
    """
    img_w, img_h = image.size
    rows, cols, block_size = calculate_optimal_split_with_fixed_max(img_w, img_h, max_rows, max_cols)
    boxes = (
        (c * block_size, r * block_size, (c + 1) * block_size, (r + 1) * block_size)
        for r in range(rows)
        for c in range(cols)
    )
    blocks = [image.crop(box) for box in boxes]
    return blocks, rows, cols, block_size
# ---------------------------
# 特殊 token 准备
# ---------------------------
def prepare_special_tokens(tokenizer, max_rows: int = DEFAULT_MAX_ROWS, max_cols: int = DEFAULT_MAX_COLS) -> Dict[str, int]:
    """Look up the ids of the image-related special tokens.

    Returns:
        Dict mapping token string -> id. Keys are ``<global-img>``,
        ``<fake_token_around_image>``, ``<image>`` and then every
        ``<row_{i}_col_{j}>`` for 1 <= i <= max_rows and 1 <= j <= max_cols,
        inserted in that order (row-major).
    """
    token_names = ["<global-img>", "<fake_token_around_image>", "<image>"]
    token_names.extend(
        f"<row_{r}_col_{c}>"
        for r in range(1, max_rows + 1)
        for c in range(1, max_cols + 1)
    )
    return {name: tokenizer.convert_tokens_to_ids(name) for name in token_names}
# ---------------------------
# 将图像切块、填充、stack 为模型输入张量
# ---------------------------
def prepare_image_patches(image: Image.Image, preprocess_fn, max_rows: int = DEFAULT_MAX_ROWS, max_cols: int = DEFAULT_MAX_COLS
                          ) -> Tuple[torch.Tensor, List[Tuple[int, int]]]:
    """Split an image into square patches, pad the grid and stack everything into one tensor.

    Args:
        image: input PIL image.
        preprocess_fn: callable mapping a PIL image to a tensor (C, H, W). It must
            return the same shape for every input so the results can be stacked.
        max_rows, max_cols: maximum grid size passed to ``adaptive_square_split``.

    Returns:
        pixel_values: tensor of shape (max_rows * max_cols + 1, C, H, W).
            - The first max_rows * max_cols entries are the grid cells in
              row-major order.
            - Cells outside the actual rows x cols split are all-zero tensors.
            - The last entry is the whole image.
        padded_cells: 0-based (row, col) positions that were zero-filled, in
            row-major order. The list is empty when the split fills the whole
            grid. The caller maps these to the matching row/col special tokens.
    """
    blocks, rows, cols, _ = adaptive_square_split(image, max_rows=max_rows, max_cols=max_cols)

    # Preprocess the whole image once. It serves both as the trailing global view
    # and as the shape/dtype template for the zero padding.
    full_image_tensor = preprocess_fn(image)
    pad_tensor = torch.zeros_like(full_image_tensor)

    patch_tensors = []
    padded_cells = []
    # A single row-major walk covers both cases: when the split fills the grid,
    # rows == max_rows and cols == max_cols, so no cell is padded.
    for i in range(max_rows):
        for j in range(max_cols):
            if i < rows and j < cols:
                patch_tensors.append(preprocess_fn(blocks[i * cols + j]))
            else:
                patch_tensors.append(pad_tensor)
                padded_cells.append((i, j))

    pixel_values = torch.stack(patch_tensors + [full_image_tensor], dim=0)  # (N_patches + 1, C, H, W)
    return pixel_values, padded_cells
# ---------------------------
# 在 token 流中构建 image placeholder(原始的占位 token 串)
# ---------------------------
def construct_image_placeholders(special_tokens: Dict[str, int], max_rows: int = DEFAULT_MAX_ROWS, max_cols: int = DEFAULT_MAX_COLS,
                                 n_image_tokens_per_patch: int = 49, prefix: Optional[str] = None) -> str:
    """Build the textual image placeholder that is appended to the user prompt.

    Layout:
      - ``prefix`` first.
      - Then, for every grid cell in row-major order:
        ``<fake_token_around_image><row_{r}_col_{c}>`` followed by
        ``<image>`` * n_image_tokens_per_patch.
      - Finally the global view:
        ``<fake_token_around_image><global-img>`` + ``<image>`` * n +
        ``<fake_token_around_image>``.

    Args:
        special_tokens: not used; kept for interface compatibility.
        max_rows, max_cols: grid size to emit placeholders for.
        n_image_tokens_per_patch: number of ``<image>`` tokens per patch. This
            presumably must equal the vision-embedding length per patch.
        prefix: leading caption. When None (the default), one of the built-in
            captions is picked with ``random.choice``. Pass a string to get
            reproducible prompts.

    Returns:
        The placeholder string.
    """
    if prefix is None:
        prefix = random.choice(["图片如下:", "如下所示的图片:", "请见下面这张图:", "如下图显示:", "参考下方图片:", "图示如下:"])
    image_run = "<image>" * n_image_tokens_per_patch
    parts = [prefix]
    parts.extend(
        f"<fake_token_around_image><row_{row + 1}_col_{col + 1}>{image_run}"
        for row in range(max_rows)
        for col in range(max_cols)
    )
    # Global image block (always last).
    parts.append(f"<fake_token_around_image><global-img>{image_run}<fake_token_around_image>")
    return "".join(parts)
# ---------------------------
# 寻找 token 序列中 image 标记出现的位置(用于 attention mask 修改)
# ---------------------------
def find_indices(tokens: torch.Tensor,
                 image_ids: Optional[List[List[int]]] = None,
                 n_image_tokens: int = 49) -> Optional[Dict[int, Dict[int, List[Tuple[int, int]]]]]:
    """Locate image-placeholder spans in a batch of token ids.

    Each entry of ``image_ids`` is a marker pattern of equal length. By default
    these are ``[<fake_token_around_image>, <row_i_col_j>]`` and
    ``[<fake_token_around_image>, <global-img>]`` written as raw ids for the
    bundled tokenizer. Adjust them if your tokenizer assigns different ids.
    The span is assumed to start right after the marker and to be
    ``n_image_tokens`` tokens long.

    Args:
        tokens: (B, T) tensor of token ids.
        image_ids: marker patterns; None selects the built-in defaults.
        n_image_tokens: length of the placeholder run following each marker.

    Returns:
        ``{batch_index: {k: [(start_idx, end_idx), ...], ...}, ...}``.
        - ``k`` indexes ``image_ids``; ``start_idx`` / ``end_idx`` are
          inclusive positions in ``tokens``.
        - Batch rows with no match are omitted.
        - Returns None when nothing matches, or when the patterns are empty or
          longer than T.
    """
    if image_ids is None:
        # Full 4x4 grid would be: [[3, i] for i in range(6, 22)] + [[3, 4]]
        image_ids = [[3, 6], [3, 7], [3, 10], [3, 11], [3, 4]]
    if not image_ids:
        return None

    B, T = tokens.size()
    patterns = torch.tensor(image_ids, device=tokens.device)  # (K, L)
    pattern_len = patterns.size(1)
    if pattern_len > T:
        return None

    windows = tokens.unfold(1, pattern_len, 1)  # (B, T - L + 1, L)
    # hits[k, b, s] is True when tokens[b, s:s + L] equals pattern k.
    hits = (windows.unsqueeze(0) == patterns[:, None, None, :]).all(dim=-1)

    results = {}
    for b in range(B):
        batch_res = {}
        for k in range(patterns.size(0)):
            starts = hits[k, b].nonzero(as_tuple=True)[0].tolist()
            if starts:
                batch_res[k] = [(s + pattern_len, s + pattern_len + n_image_tokens - 1) for s in starts]
        if batch_res:
            results[b] = batch_res
    return results or None
# ---------------------------
# ONNX Session helpers
# ---------------------------
def create_onnx_session(path: str, intra_threads: int = 1,
                        providers: Optional[List[str]] = None) -> onnxruntime.InferenceSession:
    """Create an ONNX Runtime session with sequential execution.

    Args:
        path: path to the ``.onnx`` model file.
        intra_threads: value for ``SessionOptions.intra_op_num_threads``.
        providers: optional execution-provider list forwarded to
            ``InferenceSession``, e.g. ``["CPUExecutionProvider"]``. None keeps
            ONNX Runtime's default behaviour. ONNX Runtime 1.9+ builds that
            enable non-CPU providers require this to be set explicitly.

    Returns:
        The created ``onnxruntime.InferenceSession``.
    """
    opts = onnxruntime.SessionOptions()
    opts.intra_op_num_threads = intra_threads
    opts.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
    return onnxruntime.InferenceSession(path, sess_options=opts, providers=providers)
# ---------------------------
# Prefill 阶段(将视觉嵌入插入并运行一次 LLM)
# ---------------------------
def prefill_llm(vision_session: onnxruntime.InferenceSession,
                embed_tokens_session: onnxruntime.InferenceSession,
                llm_session: onnxruntime.InferenceSession,
                pixel_values: torch.Tensor,
                input_ids: torch.Tensor,
                attention_mask: torch.Tensor,
                freqs_cos: torch.Tensor,
                freqs_sin: torch.Tensor,
                special_tokens: Dict[str, int],
                seqlen: int,
                device: str = "cpu") -> Dict[str, Any]:
    """
    Prefill step: splice the vision embeddings into the prompt and run the LLM once.

    1) Run ``vision_session`` on ``pixel_values`` to get ``deepstack_embeds``.
    2) Run ``embed_tokens_session`` on ``input_ids`` to get the token embeddings.
    3) Locate the image-placeholder spans with ``find_indices`` and replace each
       span with the corresponding slice of ``deepstack_embeds``. Only the first
       match per pattern is used. The result is truncated to ``seqlen`` positions.
    4) Call ``llm_session.run`` once with empty past keys/values. Collect logits,
       hidden_states, present_keys and present_values.

    Args:
        vision_session, embed_tokens_session, llm_session: ONNX Runtime sessions
            exposing the input/output names used below.
        pixel_values: image tensor fed as the vision model's "inputs".
        input_ids: (B, T) token ids of the prompt.
        attention_mask: prompt attention mask, passed to the LLM unchanged.
        freqs_cos, freqs_sin: RoPE tables; rows [0, seqlen) are used.
        special_tokens: not used by this function.
        seqlen: number of positions kept after splicing (the LLM input length).
        device: not used by this function.

    Returns dict:
        {
            "logits": np.ndarray,
            "hidden_states": np.ndarray,
            "present_keys": np.ndarray,
            "present_values": np.ndarray
        }
    """
    # 1) vision embed
    ort_inputs_vis = {"inputs": pixel_values.numpy()}
    deepstack_embeds = vision_session.run(["deepstack_embeds"], ort_inputs_vis)[0]  # assumed (B, P, L_patch, D) — TODO confirm
    # 2) token embed
    ort_inputs_emb = {"input_ids": input_ids.numpy()}
    embed_tokens = embed_tokens_session.run(["embed_tokens"], ort_inputs_emb)[0]  # (B, T, D)
    # 3) Find the image-placeholder spans in the tokens and replace them.
    image_batch_indices = find_indices(input_ids)
    B = input_ids.shape[0]
    seqlen = seqlen  # no-op self-assignment
    new_h = []
    for i in range(B):
        h_i = embed_tokens[i]  # numpy array (T, D)
        image_indices = image_batch_indices.get(i, {}) if image_batch_indices else {}
        # image_indices: {k: [(start, end), ...], ...} with inclusive positions
        # deepstack_embeds: assumed shape (B, P, L_patch, D), P = number_of_image_patches + global
        for tki, index_list in image_indices.items():
            # tki (the pattern index from find_indices) is used directly as the
            # index into the second axis of deepstack_embeds.
            # NOTE(review): confirm this matches the vision model's patch order.
            vision_proj_i = deepstack_embeds[i][tki]  # (L_patch, D)
            # Use only the first matching span.
            start_idx, end_idx = index_list[0]
            # Replace rows start_idx..end_idx of h_i with vision_proj_i, then
            # truncate to seqlen (numpy concatenation, since h_i is a numpy array).
            # NOTE(review): if vision_proj_i does not have exactly
            # end_idx - start_idx + 1 rows, the sequence length changes and no
            # longer lines up with attention_mask.
            h_i = np.concatenate((h_i[:start_idx], vision_proj_i, h_i[end_idx + 1:]), axis=0)[:seqlen]
        new_h.append(h_i)
    hidden_states = np.stack(new_h, axis=0)  # (B, seqlen, D)
    # 4) One forward pass of llm.onnx (prefill).
    # Past keys/values are zero-length placeholders.
    # Their shape is hard-coded here and must match what the exported model expects.
    # Assumed layout: (num_layers, past_len=0, num_kv_heads, head_dim) — TODO confirm against llm.onnx.
    past_keys = np.zeros([8, 0, 2, 64], dtype=np.float32)
    past_values = np.zeros([8, 0, 2, 64], dtype=np.float32)
    cos_pe = freqs_cos[0: seqlen].numpy()
    sin_pe = freqs_sin[0: seqlen].numpy()
    ort_inputs_llm = {
        "input_ids": hidden_states.astype(np.float32),
        "attention_mask": attention_mask.numpy(),
        "cos_pe": cos_pe.astype(np.float32),
        "sin_pe": sin_pe.astype(np.float32),
        "past_keys": past_keys,
        "past_values": past_values
    }
    logits, hidden_states_out, present_keys, present_values = llm_session.run(
        ["logits", "hidden_states", "present_keys", "present_values"], ort_inputs_llm
    )
    return {
        "logits": logits,
        "hidden_states": hidden_states_out,
        "present_keys": present_keys,
        "present_values": present_values
    }
# ---------------------------
# Next-token 自回归生成(基于 present keys/values)
# ---------------------------
def generate_autoregressive(llm_session: onnxruntime.InferenceSession,
                            embed_tokens_session: onnxruntime.InferenceSession,
                            tokenizer,
                            initial_present: Dict[str, np.ndarray],
                            start_token_id: int,
                            freqs_cos: torch.Tensor,
                            freqs_sin: torch.Tensor,
                            attention_mask: np.ndarray,
                            max_new_tokens: int = 128,
                            eos_token_id: int = 2,
                            start_pos: Optional[int] = None,
                            output_queue=None) -> List[int]:
    """
    Greedy decoding that continues from the KV cache returned by the prefill step.

    ``start_token_id`` (the token already predicted by prefill) is treated as the
    first generated token. Each further step:
      - embeds the latest token with ``embed_tokens_session``;
      - runs ``llm_session`` with the current past keys/values, the extended
        attention mask and the RoPE row for the token's position, which returns
        new present keys/values and logits;
      - takes the argmax of the last logits as the next token (replace this
        with a sampling strategy if needed).

    Args:
        llm_session, embed_tokens_session: ONNX Runtime sessions using the
            input/output names below.
        tokenizer: used only to decode tokens for ``output_queue``.
        initial_present: dict with "present_keys" and "present_values" from prefill.
        start_token_id: first generated token id.
        freqs_cos, freqs_sin: RoPE tables indexed by absolute position.
        attention_mask: 2-D prompt mask with a single row. One column of ones is
            appended per decoding step; the caller's array is not modified.
        max_new_tokens: maximum number of returned ids, counting
            ``start_token_id`` and a terminating ``eos_token_id``.
        eos_token_id: generation stops as soon as this id is produced.
        start_pos: position of ``start_token_id``; defaults to ``attention_mask.shape[1]``.
        output_queue: optional object with ``put``. It receives the decoded text
            of every returned non-EOS token in order, then ``None`` as an end marker.

    Returns:
        Generated token ids, starting with ``start_token_id`` and ending with
        ``eos_token_id`` if it was produced.

    Note: the names and shapes of the present keys/values depend on the exported
    model; make sure they match it.
    """
    present_keys = initial_present["present_keys"]
    present_values = initial_present["present_values"]
    if start_pos is None:
        start_pos = attention_mask.shape[1]

    token_id = int(start_token_id)
    generated_ids: List[int] = []
    for step in range(max_new_tokens):
        # Record the current token before anything else, so the prefill's
        # prediction is part of the result and EOS is checked for every token.
        generated_ids.append(token_id)
        if token_id == eos_token_id:
            break
        if output_queue is not None:
            # NOTE(review): decoding one token at a time can emit partial characters
            # if the tokenizer splits a character's bytes across tokens — confirm
            # for this tokenizer.
            output_queue.put(tokenizer.decode(token_id, skip_special_tokens=True, clean_up_tokenization_spaces=True))
        if step == max_new_tokens - 1:
            # Budget exhausted: skip the forward pass whose token would be discarded.
            break

        # Extend the mask by one position for the token being fed in.
        attention_mask = np.concatenate([attention_mask, np.array([[1]], dtype=np.int64)], axis=1)
        embed_tokens = embed_tokens_session.run(
            ["embed_tokens"], {"input_ids": np.array([[token_id]], dtype=np.int64)}
        )[0]
        cos_pe = freqs_cos[start_pos: start_pos + 1].numpy()
        sin_pe = freqs_sin[start_pos: start_pos + 1].numpy()
        ort_inputs = {
            "input_ids": embed_tokens.astype(np.float32),
            "attention_mask": attention_mask,
            "cos_pe": cos_pe.astype(np.float32),
            "sin_pe": sin_pe.astype(np.float32),
            "past_keys": present_keys,
            "past_values": present_values,
        }
        logits, _, present_keys, present_values = llm_session.run(
            ["logits", "hidden_states", "present_keys", "present_values"], ort_inputs
        )
        token_id = int(np.argmax(logits[:, -1, :], axis=-1)[0])
        start_pos += 1

    # End-of-stream marker for consumers of output_queue.
    if output_queue is not None:
        output_queue.put(None)
    return generated_ids
def main_example(tokenizer_path: str = "./custom_tokenizer",
                 image_path: str = "/Users/hulk/Downloads/coco128/images/train2017/000000000165.jpg",
                 onnx_dir: str = "./onnx_model",
                 query: str = "图片中的人在做什么。") -> str:
    """End-to-end example: answer ``query`` about the image at ``image_path``.

    Steps:
      1. Load the tokenizer and three ONNX models (vision encoder, token
         embedding, LLM).
      2. Build the multimodal chat prompt.
      3. Run prefill and greedy decoding.
      4. Print and return the decoded answer.

    Calling it with no arguments reproduces the original hard-coded setup.

    Returns:
        The decoded answer text.
    """
    tokenizer = load_tokenizer(tokenizer_path)
    preprocess = build_image_preprocess(DEFAULT_IMAGE_SIZE)

    # Image: the context manager closes the underlying file once converted.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    # Special tokens and pixel values for a 4x4 patch grid (+ the global view).
    special_tokens = prepare_special_tokens(tokenizer, max_rows=4, max_cols=4)
    pixel_values, _ = prepare_image_patches(image, preprocess, max_rows=4, max_cols=4)

    # Prompt + image placeholders (requires a tokenizer with apply_chat_template).
    # NOTE(review): the placeholders use construct_image_placeholders' default grid
    # (DEFAULT_MAX_ROWS x DEFAULT_MAX_COLS), while pixel_values above is built for
    # a 4x4 grid. This mirrors the original example — confirm it matches how
    # prefill_llm maps placeholder spans to vision embeddings.
    messages = [
        {"role": "system", "content": "你是一个多模态AI助手,能够理解图片和文本信息."},
        {"role": "user", "content": query + construct_image_placeholders(special_tokens)},
    ]
    inputs_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(inputs_text, return_tensors="pt", truncation=True)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # RoPE tables.
    freqs_cos, freqs_sin = precompute_freqs_cis(dim=64, end=32768, rope_base=1e6)

    # ONNX sessions.
    vision_session = create_onnx_session(os.path.join(onnx_dir, "vision_encoder.onnx"), intra_threads=2)
    embed_tokens_session = create_onnx_session(os.path.join(onnx_dir, "embed_tokens.onnx"), intra_threads=2)
    llm_session = create_onnx_session(os.path.join(onnx_dir, "llm.onnx"), intra_threads=2)

    # Prefill.
    seqlen = input_ids.shape[1]
    prefill_out = prefill_llm(
        vision_session=vision_session,
        embed_tokens_session=embed_tokens_session,
        llm_session=llm_session,
        pixel_values=pixel_values,
        input_ids=input_ids,
        attention_mask=attention_mask,
        freqs_cos=freqs_cos,
        freqs_sin=freqs_sin,
        special_tokens=special_tokens,
        seqlen=seqlen,
    )

    # The first generated token is the argmax of the last prefill logit.
    start_token_id = int(np.argmax(prefill_out["logits"][:, -1, :], axis=-1)[0])
    generated = generate_autoregressive(
        llm_session=llm_session,
        embed_tokens_session=embed_tokens_session,
        tokenizer=tokenizer,
        initial_present={"present_keys": prefill_out["present_keys"], "present_values": prefill_out["present_values"]},
        start_token_id=start_token_id,
        freqs_cos=freqs_cos,
        freqs_sin=freqs_sin,
        attention_mask=attention_mask.numpy(),
        max_new_tokens=128,
        eos_token_id=2,
        start_pos=seqlen,
    )

    # Decode the generated ids; previously they were computed and discarded.
    answer = tokenizer.decode(generated, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    print(answer)
    return answer
if __name__ == "__main__":
    # The example is disabled by default: main_example() loads a tokenizer, an
    # image and three ONNX models from hard-coded local paths. Uncomment the call
    # below to run it.
    # main_example()
    pass