import os import torch import argparse import gc from copy import deepcopy from ml_dtypes import bfloat16 import axengine as ort from PIL import Image from tqdm import tqdm import numpy as np from transformers import AutoModel, AutoTokenizer, AutoProcessor, AutoConfig def post_process(data, topk=1, topp=0.9, temperature=0.6): def top_p(l: np.ndarray, p: float) -> np.ndarray: index = np.argsort(l) res = l.copy() sum_p = 0 for i in index[::-1]: if sum_p >= p: res[i] = 0 sum_p += res[i] return res / sum_p def softmax(l: np.ndarray) -> np.ndarray: l_max = l - l.max() l_exp = np.exp(l_max) res = l_exp / np.sum(l_exp) return res.astype(np.float64) r = data.astype(np.float32) r = r.flatten() # topk candidate_index = np.argpartition(r, -topk)[-topk:] candidate_value = r[candidate_index] # temperature candidate_value /= temperature # softmax candidate_soft = softmax(candidate_value) # topp candidate_soft = top_p(candidate_soft, topp) candidate_soft = candidate_soft.astype(np.float64) / candidate_soft.sum() pos = np.random.multinomial(1, candidate_soft).argmax() next_token = candidate_index[pos] return np.array(next_token), candidate_index, candidate_soft class MiniCPMV: def __init__(self, siglip_onnx_path, resampler_onnx_path, embed_token_path, llm_axmodel_path, config) -> None: self.config = config self.vpm = ort.InferenceSession(siglip_onnx_path) self.resampler = ort.InferenceSession(resampler_onnx_path) self.embed_tokens = torch.load(embed_token_path, weights_only=False) # llm embedding self.prefill_slice_len=320 self.kv_cache_len=1023 self.prefill_decoder_sessions = [] for i in tqdm(range(self.config.num_hidden_layers), desc="Init InferenceSession"): session = ort.InferenceSession( f"{llm_axmodel_path}/llama_p{self.prefill_slice_len}_l{i}_together.axmodel" ) self.prefill_decoder_sessions.append(session) self.post_process_session = ort.InferenceSession( f"{llm_axmodel_path}/llama_post.axmodel" ) print("model load done!") self.kv_dim = 256 # self.config.hidden_size // self.config.num_attention_heads * self.config.num_key_value_heads self.terminators = ['<|im_end|>', ''] def get_position_ids(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor, tgt_sizes: torch.IntTensor=None): batch_size = pixel_values.size(0) max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3) max_nb_patches_h, max_nb_patches_w = max_im_h // self.config.vision_config.patch_size, max_im_w // self.config.vision_config.patch_size num_patches_per_side = self.config.vision_config.image_size // self.config.vision_config.patch_size boundaries = torch.arange(1 / num_patches_per_side, 1.0, 1 / num_patches_per_side) position_ids = torch.full( size=( batch_size, max_nb_patches_h * max_nb_patches_w, ), fill_value=0, ) for batch_idx, p_attn_mask in enumerate(patch_attention_mask): if tgt_sizes is not None: nb_patches_h = tgt_sizes[batch_idx][0] nb_patches_w = tgt_sizes[batch_idx][1] else: nb_patches_h = p_attn_mask[:, 0].sum() nb_patches_w = p_attn_mask[0].sum() fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) pos_ids = (bucket_coords_h[:, None] * num_patches_per_side + bucket_coords_w).flatten() position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids # position_ids[batch_idx] = torch.where(p_attn_mask.view(-1).cpu(), pos_ids, position_ids[batch_idx]) return position_ids @torch.no_grad() def get_vllm_embedding(self, data): if 'vision_hidden_states' not in data: dtype = torch.float32 device = "cpu" tgt_sizes = data['tgt_sizes'] pixel_values_list = data['pixel_values'] vision_hidden_states = [] all_pixel_values = [] img_cnt = [] for pixel_values in pixel_values_list: img_cnt.append(len(pixel_values)) all_pixel_values.extend([i.flatten(end_dim=1).permute(1, 0) for i in pixel_values]) # exist image if all_pixel_values: tgt_sizes = [tgt_size for tgt_size in tgt_sizes if isinstance(tgt_size, torch.Tensor)] tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32) max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1]) all_pixel_values = torch.nn.utils.rnn.pad_sequence(all_pixel_values, batch_first=True, padding_value=0.0) B, L, _ = all_pixel_values.shape all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L) patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device) for i in range(B): patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True vision_batch_size = self.config.vision_batch_size all_pixel_values = all_pixel_values.type(dtype).to(device=device) if B > vision_batch_size: hs = [] for i in range(0, B, vision_batch_size): start_idx = i end_idx = i + vision_batch_size tmp_hs = self.vpm(all_pixel_values[start_idx:end_idx], patch_attention_mask=patch_attn_mask[start_idx:end_idx], tgt_sizes=tgt_sizes[start_idx:end_idx]).last_hidden_state hs.append(tmp_hs) vision_embedding = torch.cat(hs, dim=0) else: # 走这里 position_ids = self.get_position_ids(all_pixel_values, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes) siglip_inputs = { "all_pixel_values": all_pixel_values.numpy(), "position_ids": position_ids.numpy().astype(np.int32), } vision_embedding = self.vpm.run(None, input_feed=siglip_inputs)[0] resampler_inputs = { "vision_embedding": vision_embedding, # "tgt_sizes": tgt_sizes.type(torch.int64).numpy(), } vision_embedding = self.resampler.run(None, input_feed=resampler_inputs)[0] vision_embedding = torch.from_numpy(vision_embedding) start = 0 for pixel_values in pixel_values_list: img_cnt = len(pixel_values) if img_cnt > 0: vision_hidden_states.append(vision_embedding[start: start + img_cnt]) start += img_cnt else: vision_hidden_states.append([]) else: # no image if self.training: dummy_image = torch.zeros( (1, 3, 224, 224), device=device, dtype=dtype ) tgt_sizes = torch.Tensor([[(224 // self.config.patch_size), math.ceil(224 / self.config.patch_size)]]).type(torch.int32) dummy_feature = self.resampler(self.vpm(dummy_image).last_hidden_state, tgt_sizes) else: dummy_feature = [] for _ in range(len(pixel_values_list)): vision_hidden_states.append(dummy_feature) else: vision_hidden_states = data['vision_hidden_states'] vllm_embedding = self.embed_tokens(data['input_ids']) vision_hidden_states = [i.type(vllm_embedding.dtype) if isinstance( i, torch.Tensor) else i for i in vision_hidden_states] bs = len(data['input_ids']) device = vllm_embedding.device embed_dim = vllm_embedding.shape[-1] new_vllm_embeddings = [] for i in range(bs): cur_vs_hs = vision_hidden_states[i] cur_vllm_emb = vllm_embedding[i] if len(cur_vs_hs) == 0: new_vllm_embeddings.append(cur_vllm_emb) continue cur_image_bound = data['image_bound'][i] if len(cur_image_bound) > 0: image_indices = torch.stack([ torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound ], dim=0).flatten().to(device) indices_expanded = image_indices.view(-1, 1).expand(-1, embed_dim) vision_features = cur_vs_hs.view(-1, embed_dim) updated_emb = cur_vllm_emb.scatter(0, indices_expanded, vision_features) new_vllm_embeddings.append(updated_emb) elif self.training: dummy_term = cur_vs_hs[0].sum() * 0 new_vllm_embeddings.append(cur_vllm_emb + dummy_term) else: new_vllm_embeddings.append(cur_vllm_emb) vllm_embedding = torch.stack(new_vllm_embeddings, dim=0) return vllm_embedding, vision_hidden_states def _decode(self, inputs_embeds, tokenizer, attention_mask, decode_text=False, **kwargs): token_ids = model_inputs["input_ids"].tolist()[0] terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators] k_caches = [ np.zeros((1, self.kv_cache_len, self.kv_dim), dtype=bfloat16) for _ in range(self.config.num_hidden_layers) ] v_caches = [ np.zeros((1, self.kv_cache_len, self.kv_dim), dtype=bfloat16) for _ in range(self.config.num_hidden_layers) ] token_len = inputs_embeds.shape[1] # (B, L, D) """ prefill """ prefill_slice_len = self.prefill_slice_len # slice_indexs = [0, 1, 2, 3, 4, 5, 6, 7, 8] slice_indexs = [ e for e in range(token_len // prefill_slice_len + 1) ] prefill_len = prefill_slice_len * slice_indexs[-1] if slice_indexs[-1] != 0 else prefill_slice_len if prefill_len > 0: for slice_index in tqdm(slice_indexs, desc="prefill"): indices = np.array( list( range( slice_index * prefill_slice_len, (slice_index + 1) * prefill_slice_len, ) ), np.uint32, ).reshape((1, prefill_slice_len)) mask = ( np.zeros((1, prefill_slice_len, prefill_slice_len * (slice_index + 1))) - 65536 ) data = np.zeros((1, prefill_slice_len, self.config.hidden_size)).astype(bfloat16) for i, t in enumerate( range( slice_index * prefill_slice_len, (slice_index + 1) * prefill_slice_len, ) ): if t < token_len: mask[:, i, : slice_index * prefill_slice_len + i + 1] = 0 data[:, i : i + 1, :] = ( inputs_embeds[0][t] .reshape((1, 1, self.config.hidden_size)) .astype(bfloat16) ) if slice_index == slice_indexs[-1]: remain_len = token_len - slice_index * prefill_slice_len else: remain_len = prefill_slice_len mask = mask.astype(bfloat16) for i in range(self.config.num_hidden_layers): input_feed = { "K_cache": ( k_caches[i][:, 0 : prefill_slice_len * slice_index, :] if slice_index else np.zeros((1, 1, self.config.hidden_size), dtype=bfloat16) ), "V_cache": ( v_caches[i][:, 0 : prefill_slice_len * slice_index, :] if slice_index else np.zeros((1, 1, self.config.hidden_size), dtype=bfloat16) ), "indices": indices, "input": data, "mask": mask, } outputs = self.prefill_decoder_sessions[i].run(None, input_feed, shape_group=slice_index + 1) k_caches[i][ :, slice_index * prefill_slice_len : slice_index * prefill_slice_len + remain_len, :, ] = outputs[0][:, :remain_len, :] v_caches[i][ :, slice_index * prefill_slice_len : slice_index * prefill_slice_len + remain_len, :, ] = outputs[1][:, :remain_len, :] data = outputs[2] post_out = self.post_process_session.run( None, { "input": data[ :, token_len - (len(slice_indexs) - 1) * prefill_slice_len - 1, None, : ] } )[0] next_token, posssible_tokens, possible_soft = post_process(post_out) posibles = [tokenizer.decode([t]) for t in posssible_tokens] posible_soft = [str((t, s)) for t, s in zip(posibles, possible_soft)] token_ids.append(next_token) # set to decoder token_ids_cached = [] token_ids_cached.append(next_token) mask = np.zeros((1, 1, self.kv_cache_len + 1), dtype=np.float32).astype(bfloat16) mask[:, :, :self.kv_cache_len + 1] -= 65536 if prefill_len > 0: mask[:, :, :token_len + 1] = 0 for start_indice in range(self.kv_cache_len): if prefill_len > 0 and start_indice < token_len: continue next_token = token_ids[start_indice] indices = np.array([start_indice], np.uint32).reshape((1, 1)) data = self.embed_tokens(torch.from_numpy(next_token)).reshape((1, 1, self.config.hidden_size)).detach().numpy().astype(bfloat16) for i in range(self.config.num_hidden_layers): input_feed = { "K_cache": k_caches[i], "V_cache": v_caches[i], "indices": indices, "input": data, "mask": mask, } outputs = self.prefill_decoder_sessions[i].run(None, input_feed, shape_group=0) k_caches[i][:, start_indice, :] = outputs[0][:, :, :] v_caches[i][:, start_indice, :] = outputs[1][:, :, :] data = outputs[2] mask[..., start_indice + 1] = 0 if start_indice < token_len - 1: pass else: post_out = self.post_process_session.run(None, {"input": data})[0] next_token, posssible_tokens, possible_soft = post_process(post_out) token_ids.append(next_token) if next_token in terminators: if len(token_ids_cached) > 0: msg = tokenizer.decode(token_ids_cached) token_ids_cached.clear() if "\ufffd" in msg: msg = msg.replace("\ufffd", "") print(msg, end='\n<|im_end|>\n', flush=True) return token_ids_cached.append(next_token) if len(token_ids_cached) >= 10: msg = tokenizer.decode(token_ids_cached) token_ids_cached.clear() if "\ufffd" in msg: msg = msg.replace("\ufffd", "") print(msg, end=" ", flush=True) return if __name__ == '__main__': parser = argparse.ArgumentParser(description="MiniCPM-v4 axmodel demo") parser.add_argument("--hf_model_path", type=str, default="../hf_cache/MiniCPM-V-4", help="Path to HuggingFace model") parser.add_argument("--siglip_axmodel", type=str, default="./siglip.axmodel") parser.add_argument("--resampler_axmodel", type=str, default="./resampler.axmodel") parser.add_argument("--embed_token_path", type=str, default="./embed_tokens.pth") parser.add_argument("--minicpm_axmodel", type=str, default="./minicpm-v-4_axmodel") parser.add_argument("-i", "--image", type=str, default="./show_demo.jpg", help="Path to the test image.") parser.add_argument("-q", "--question", type=str, default="What is the landform in the picture?", help="Your question that you want to ask the model.") args = parser.parse_args() hf_model_path = args.hf_model_path img_path = args.image image = Image.open(img_path).convert('RGB').resize((448, 448)) question = args.question msgs = [{'role': 'user', 'content': [image, question]}] resampler_axmodel = args.resampler_axmodel siglip_axmodel = args.siglip_axmodel embed_token_path = args.embed_token_path llm_axmodel_path = args.minicpm_axmodel processor = AutoProcessor.from_pretrained(hf_model_path, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(hf_model_path, trust_remote_code=True) config = AutoConfig.from_pretrained(hf_model_path, trust_remote_code=True) processor.image_processor.slice_mode = False # 不对图像做切分操作 minicpm_axmodel = MiniCPMV(siglip_axmodel, resampler_axmodel, embed_token_path, llm_axmodel_path, config) msgs_list = [msgs] prompts_lists = [] input_images_lists = [] for msgs in msgs_list: copy_msgs = deepcopy(msgs) images = [] for i, msg in enumerate(copy_msgs): role = msg["role"] content = msg["content"] assert role in ["user", "assistant"] if i == 0: assert role == "user", "The role of first msg should be user" if isinstance(content, str): content = [content] cur_msgs = [] for c in content: if isinstance(c, Image.Image): images.append(c) cur_msgs.append("(./)") elif isinstance(c, str): cur_msgs.append(c) msg["content"] = "\n".join(cur_msgs) prompts_lists.append(processor.tokenizer.apply_chat_template(copy_msgs, tokenize=False, add_generation_prompt=True)) input_images_lists.append(images) inputs = processor( prompts_lists, input_images_lists, max_slice_nums=None, use_image_id=None, return_tensors="pt", max_length=32768 ) generation_config = { "top_p": 0.8, "top_k": 100, "temperature": 0.7, "do_sample": True, "repetition_penalty": 1.05 } inputs.pop("image_sizes") model_inputs = { "input_ids": inputs.input_ids, "image_bound": inputs.image_bound, } model_inputs["pixel_values"] = inputs.pixel_values model_inputs['tgt_sizes'] = inputs.tgt_sizes model_inputs["inputs_embeds"], vision_hidden_states = minicpm_axmodel.get_vllm_embedding(model_inputs) del minicpm_axmodel.vpm, minicpm_axmodel.resampler, vision_hidden_states gc.collect() result = minicpm_axmodel._decode(model_inputs["inputs_embeds"].detach().numpy(), tokenizer, inputs.attention_mask, decode_text=True)