"""Export the vision encoder of several vision-language models to ONNX.

Each wrapper class below exposes only the image branch of a model (vision
tower plus projector/resampler) so that it can be traced with fixed input
shapes by ``torch.onnx.export``.
"""
import argparse
import math
import os

import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoModel


class minicpm_v_2_6_vision(torch.nn.Module):
    """Vision tower + resampler of MiniCPM-V 2.6 with static position data.

    Position ids and resampler positional embeddings depend only on the
    input resolution, so they are computed once here instead of on every
    forward pass.

    Args:
        vlm: Loaded MiniCPM-V model; must expose ``vpm``, ``resampler``,
            ``config.patch_size`` and ``device``.
        batch_size: Number of images per forward call.
        in_h: Input image height in pixels.
        in_w: Input image width in pixels.

    Raises:
        ValueError: If ``batch_size`` is not positive or the image is
            smaller than one patch.
    """

    def __init__(self, vlm, batch_size, in_h, in_w):
        super().__init__()
        self.vpm = vlm.vpm
        self.resampler = vlm.resampler

        patch_size = vlm.config.patch_size
        num_patches_per_side = vlm.vpm.embeddings.num_patches_per_side

        # Use floor division for both axes: the original mixed ceil (width)
        # with floor (mask), which breaks for widths that are not a
        # multiple of the patch size.
        grid_h = in_h // patch_size
        grid_w = in_w // patch_size
        if batch_size < 1 or grid_h < 1 or grid_w < 1:
            raise ValueError(
                f"Invalid export shape: batch_size={batch_size}, "
                f"height={in_h}, width={in_w}, patch_size={patch_size}"
            )

        # Every image in the batch has the same resolution, so the bucketed
        # position ids are computed once and replicated. The step is kept
        # as a 0-d int32 tensor division to reproduce the original
        # float32 arithmetic exactly.
        nb_patches_h = torch.tensor(grid_h, dtype=torch.int32)
        nb_patches_w = torch.tensor(grid_w, dtype=torch.int32)
        boundaries = torch.arange(
            1 / num_patches_per_side, 1.0, 1 / num_patches_per_side
        )
        fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
        fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
        bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
        bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
        pos_ids = (
            bucket_coords_h[:, None] * num_patches_per_side + bucket_coords_w
        ).flatten()
        # (batch_size, grid_h * grid_w)
        self.position_ids = pos_ids.unsqueeze(0).repeat(batch_size, 1).to(vlm.device)

        # Resampler positional embedding, replicated per image: L * B * D.
        pos_embed = (
            self.resampler.pos_embed[:grid_h, :grid_w, :]
            .reshape((grid_h * grid_w, -1))
            .to(torch.float32)
        )
        self.pos_embed = pos_embed.unsqueeze(1).repeat(1, batch_size, 1)

    def forward(self, pixel_values):
        """Encode ``pixel_values`` and return the resampled features (B * Q * D)."""
        batch_size = pixel_values.size(0)

        # patch embedding
        patch_embeds = self.vpm.embeddings.patch_embedding(pixel_values)
        embeddings = patch_embeds.flatten(2).transpose(1, 2)
        hidden_states = embeddings + self.vpm.embeddings.position_embedding(
            self.position_ids
        )

        # encoder
        encoder_outputs = self.vpm.encoder(inputs_embeds=hidden_states)
        last_hidden_state = self.vpm.post_layernorm(encoder_outputs[0])

        # resampler
        x = self.resampler.kv_proj(last_hidden_state)  # B * L * D
        x = self.resampler.ln_kv(x).permute(1, 0, 2)  # L * B * D
        q = self.resampler.ln_q(self.resampler.query)  # Q * D
        out = self.resampler.attn(
            self.resampler._repeat(q, batch_size),  # Q * B * D
            x + self.pos_embed,  # L * B * D + L * B * D
            x,
        )[0]  # Q * B * D

        x = out.permute(1, 0, 2)  # B * Q * D
        x = self.resampler.ln_post(x)
        return x @ self.resampler.proj


class qwen2_5_vl_3b_vision(torch.nn.Module):
    """Vision tower of Qwen2.5-VL fed with raw NCHW frames.

    ``forward`` rearranges the frames into the flattened patch layout that
    ``vlm.visual`` is called with.

    Args:
        vlm: Loaded Qwen2.5-VL model; must expose ``visual``.
        batch_size: Number of frames passed to ``forward``.
    """

    def __init__(self, vlm, batch_size):
        super().__init__()
        self.merge_size = 2
        self.temporal_patch_size = 2
        self.patch_size = 14
        self.channel = 3
        self.vpm = vlm.visual
        self.batch_size = batch_size

    def forward(self, pixel_value, grid_thw):
        """Run the vision tower on ``pixel_value``.

        Args:
            pixel_value: Frames of shape (batch_size, channel, H, W).
            grid_thw: Integer tensor whose first row is
                ``(grid_t, grid_h, grid_w)``; ``grid_t`` must equal the
                padded frame count divided by ``temporal_patch_size``.

        Returns:
            Whatever ``self.vpm`` returns for the flattened patches.
        """
        # Pad with copies of the last frame up to a multiple of the temporal
        # patch size. The original appended two copies for odd batches,
        # which left the frame count odd (e.g. 3 -> 5) and broke the
        # reshape below.
        missing = (-self.batch_size) % self.temporal_patch_size
        if missing:
            tail = pixel_value[-1:, ...].repeat(missing, 1, 1, 1)
            patches = torch.cat((pixel_value, tail), dim=0)
        else:
            patches = pixel_value

        grid_t, grid_h, grid_w = grid_thw[0][0], grid_thw[0][1], grid_thw[0][2]
        patches = patches.reshape(
            grid_t,
            self.temporal_patch_size,
            self.channel,
            grid_h // self.merge_size,
            self.merge_size,
            self.patch_size,
            grid_w // self.merge_size,
            self.merge_size,
            self.patch_size,
        )
        patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
        flatten_patches = patches.reshape(
            grid_t * grid_h * grid_w,
            self.channel * self.temporal_patch_size * self.patch_size * self.patch_size,
        )
        return self.vpm(flatten_patches, grid_thw)


class smolvlm_vision(torch.nn.Module):
    """Vision model + connector of SmolVLM."""

    def __init__(self, vlm):
        super().__init__()
        self.vpm = vlm.model.vision_model
        self.connector = vlm.model.connector

    def forward(self, pixel_values):
        """Return the connector output for ``pixel_values``."""
        # Get sequence from the vision encoder
        image_hidden_states = self.vpm(pixel_values).last_hidden_state
        # Modality projection & resampling
        image_hidden_states = self.connector(image_hidden_states)
        print("image_features:", image_hidden_states.shape)
        return image_hidden_states


class vila1_5_3b_vision(torch.nn.Module):
    """Thin wrapper that forwards to ``vlm.encode_images``."""

    def __init__(self, vlm):
        super().__init__()
        self.vlm = vlm

    def forward(self, pixel_values):
        """Return ``vlm.encode_images(pixel_values)``."""
        return self.vlm.encode_images(pixel_values)


def _static_window_index(merge_size):
    """Build a shape-static replacement for the tower's ``get_window_index``.

    The returned function puts all patches into one window so that no
    data-dependent window partitioning ends up in the traced graph.
    """

    def get_window_index_static(self, grid_thw):
        # grid_thw: first row is (T, H, W), int64, static
        device = grid_thw.device
        T, H, W = grid_thw[0]
        total = T * H * W
        # NOTE(review): assumes the vision tower uses window_index to index
        # groups of merge_size * merge_size patches (so it needs
        # T * H/merge * W/merge entries, not T * H * W) - confirm against
        # the installed transformers version.
        window_index = torch.arange(
            T * (H // merge_size) * (W // merge_size), device=device
        )
        # A single window spanning every patch: [0, total]
        cu_window_seqlens = torch.tensor([0, total], device=device)
        return window_index, cu_window_seqlens

    return get_window_index_static


def _parse_args():
    """Parse the command-line options of the export script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', type=str, default='CKPT/MiniCPM-V-2_6', help='model path')
    parser.add_argument('--model_name', type=str, default='minicpm-v-2_6', help='model name')
    parser.add_argument('--batch_size', type=int, default=1, help='batch size')
    parser.add_argument('--height', type=int, default=448, help='image height')
    parser.add_argument('--width', type=int, default=448, help='image width')
    parser.add_argument('--device', type=str, default="cpu", help='cpu or cuda')
    return parser.parse_args()


def main():
    """Load the requested model and export its vision encoder to ONNX.

    Raises:
        ValueError: If ``--model_name`` is not one of the supported names.
    """
    args = _parse_args()
    path = args.path
    model_name = args.model_name
    device_type = args.device
    savepath = os.path.join("./onnx", model_name + "_vision.onnx")
    os.makedirs(os.path.dirname(savepath), exist_ok=True)

    # The wrappers are plain nn.Modules without a ``device`` attribute, so
    # the dummy input is always placed on the device given on the CLI.
    pixel_values = torch.randn(
        args.batch_size, 3, args.height, args.width,
        device=device_type, dtype=torch.float32,
    )

    if model_name == 'minicpm-v-2_6':
        vlm = AutoModel.from_pretrained(
            path,
            trust_remote_code=True,
            torch_dtype=torch.float32,
        )
        vlm = vlm.to(device=device_type, dtype=torch.float32)
        vlm.eval()
        model = minicpm_v_2_6_vision(vlm, args.batch_size, args.height, args.width)
        out = model(pixel_values)
        print("Output shape:", out.shape)
        torch.onnx.export(model,
                          pixel_values,
                          savepath,
                          input_names=['pixel'],
                          opset_version=15)
    elif model_name == 'qwen2_5-vl-3b':
        from transformers import Qwen2_5_VLForConditionalGeneration
        vlm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            path,
            low_cpu_mem_usage=True,
            _attn_implementation="eager",
            trust_remote_code=True,
        )
        vlm = vlm.to(device=device_type, dtype=torch.float32).eval()
        model = qwen2_5_vl_3b_vision(vlm, args.batch_size)

        # The wrapper exposes the vision tower as ``vpm`` (it has no
        # ``visual`` attribute), so the static patch is bound there.
        model.vpm.get_window_index = _static_window_index(model.merge_size).__get__(
            model.vpm, type(model.vpm)
        )

        # The grid must match the dummy input: frames are grouped in pairs
        # along the temporal axis after padding.
        grid_thw = torch.tensor(
            [[
                math.ceil(args.batch_size / model.temporal_patch_size),
                args.height // model.patch_size,
                args.width // model.patch_size,
            ]],
            dtype=torch.int64,
        )
        out = model(pixel_values, grid_thw)
        print("Output shape:", out.shape)
        torch.onnx.export(
            model,
            (pixel_values, grid_thw),
            savepath,
            input_names=["pixel", "grid_thw"],
            opset_version=18,
        )
    elif model_name == 'smolvlm':
        from transformers import SmolVLMForConditionalGeneration
        vlm = SmolVLMForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=torch.float32,
            _attn_implementation="eager",
        ).to(device_type)
        print("pixel_values:", pixel_values.shape)
        model = smolvlm_vision(vlm)
        model = model.to(torch.float32).eval()
        out = model(pixel_values)
        torch.onnx.export(model,
                          pixel_values,
                          savepath,
                          input_names=['pixel'],
                          dynamic_axes={'pixel': {2: 'height', 3: 'width'}},
                          opset_version=15)
    elif model_name == 'internvl3-1b':
        model = AutoModel.from_pretrained(
            path,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
            trust_remote_code=True).eval().to(device_type)
        # Export only the feature extractor instead of the full forward.
        model.forward = model.extract_feature
        model = model.to(torch.float32).eval()
        torch.onnx.export(model, pixel_values, savepath)
    else:
        raise ValueError(f"Unsupported model name: {model_name}")

    print(f"Exported to {savepath}")


if __name__ == "__main__":
    main()