|
|
import torch |
|
|
import numpy as np |
|
|
import os |
|
|
import math |
|
|
import argparse |
|
|
import torch.nn.functional as F |
|
|
from transformers import AutoModel |
|
|
|
|
|
class minicpm_v_2_6_vision(torch.nn.Module):
    """Export wrapper around the MiniCPM-V 2.6 vision tower (``vpm``) and resampler.

    Everything that depends only on the fixed input resolution and batch size
    (vision position ids and the resampler's key positional embedding) is
    precomputed once here, so ``forward`` is a straight-line graph that traces
    cleanly for ONNX export.

    Args:
        vlm: loaded MiniCPM-V 2.6 model; ``vpm``, ``resampler``,
            ``config.patch_size``, ``vpm.embeddings.num_patches_per_side`` and
            ``device`` are read from it.
        batch_size: number of images passed to each ``forward`` call.
        in_h: input image height in pixels.
        in_w: input image width in pixels.
    """

    def __init__(self, vlm, batch_size, in_h, in_w):
        super(minicpm_v_2_6_vision, self).__init__()
        self.vpm = vlm.vpm
        self.resampler = vlm.resampler

        patch_size = vlm.config.patch_size
        num_patches_per_side = vlm.vpm.embeddings.num_patches_per_side

        # Patch grid of one image. The strided patch-embedding convolution drops
        # a trailing partial patch, so BOTH axes use floor division (the old code
        # used ceil for the width, which disagreed with the embedding length).
        nb_patches_h = in_h // patch_size
        nb_patches_w = in_w // patch_size

        # 0-dim int32 tensors: dividing 1 by these (instead of by Python ints)
        # keeps the exact float32 step the original tgt_sizes-based code used.
        tgt_h = torch.tensor(nb_patches_h, dtype=torch.int32)
        tgt_w = torch.tensor(nb_patches_w, dtype=torch.int32)

        with torch.no_grad():
            # Map each patch's fractional (row, col) to a bucket on the
            # num_patches_per_side x num_patches_per_side position table.
            boundaries = torch.arange(1 / num_patches_per_side, 1.0, 1 / num_patches_per_side)
            fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / tgt_h)
            fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / tgt_w)
            bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
            bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
            pos_ids = (bucket_coords_h[:, None] * num_patches_per_side + bucket_coords_w).flatten()
            # Every image has the same size, so each batch row is identical;
            # this replaces the per-image loop that indexed a 1-row tgt_sizes
            # (IndexError for batch_size > 1).
            position_ids = pos_ids.unsqueeze(0).repeat(batch_size, 1)

            # Resampler key positional embedding for an nb_h x nb_w grid,
            # laid out (seq, batch, dim) to match the permuted keys in forward.
            pos = self.resampler.pos_embed[:nb_patches_h, :nb_patches_w, :]
            pos = pos.reshape(nb_patches_h * nb_patches_w, -1).to(torch.float32)
            pos_embed = pos.unsqueeze(1).repeat(1, batch_size, 1)

        # Non-persistent buffers: they follow module.to(...) but are not saved
        # in the state dict.
        self.register_buffer("position_ids", position_ids.to(vlm.device), persistent=False)
        self.register_buffer("pos_embed", pos_embed, persistent=False)

    def forward(self, pixel_values):
        """Encode ``pixel_values`` and resample them into query tokens.

        Args:
            pixel_values: image batch; presumably (batch, 3, in_h, in_w) as
                built in this file's ``__main__`` -- batch must equal the
                ``batch_size`` given to ``__init__``.

        Returns:
            Resampled features ``ln_post(out) @ proj``; presumably
            (batch, num_queries, dim) -- confirm against the resampler config.
        """
        batch_size = pixel_values.size(0)

        patch_embeds = self.vpm.embeddings.patch_embedding(pixel_values)
        embeddings = patch_embeds.flatten(2).transpose(1, 2)
        hidden_states = embeddings + self.vpm.embeddings.position_embedding(self.position_ids)

        encoder_outputs = self.vpm.encoder(inputs_embeds=hidden_states)
        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.vpm.post_layernorm(last_hidden_state)

        # Keys/values go to the resampler attention in (seq, batch, dim) order.
        x = self.resampler.kv_proj(last_hidden_state)
        x = self.resampler.ln_kv(x).permute(1, 0, 2)

        q = self.resampler.ln_q(self.resampler.query)

        out = self.resampler.attn(
            self.resampler._repeat(q, batch_size),
            x + self.pos_embed,
            x)[0]

        x = out.permute(1, 0, 2)
        x = self.resampler.ln_post(x)
        x = x @ self.resampler.proj
        return x
|
|
|
|
|
class qwen2_5_vl_3b_vision(torch.nn.Module):
    """Export wrapper that patchifies raw frames and runs the Qwen2.5-VL vision tower.

    Args:
        vlm: loaded Qwen2.5-VL model; only ``vlm.visual`` is kept (as ``vpm``).
        batch_size: number of frames the exporter feeds in. Kept for backward
            compatibility; padding is decided from the actual input length.
    """

    def __init__(self, vlm, batch_size):
        super(qwen2_5_vl_3b_vision, self).__init__()
        self.merge_size = 2
        self.temporal_patch_size = 2
        self.patch_size = 14
        self.channel = 3
        self.vpm = vlm.visual
        self.batch_size = batch_size

    def forward(self, pixel_value, grid_thw):
        """Flatten frames into vision-tower patches and encode them.

        Args:
            pixel_value: frames, (num_frames, channel, H, W) per the reshape
                below, with H == grid_h * patch_size and W == grid_w * patch_size.
            grid_thw: (1, 3) tensor ``[[grid_t, grid_h, grid_w]]``; grid_t must
                equal the padded frame count divided by ``temporal_patch_size``.

        Returns:
            Whatever ``self.vpm(flatten_patches, grid_thw)`` returns.
        """
        # Pad with copies of the last frame up to a multiple of
        # temporal_patch_size. The old odd-batch branch appended TWO copies,
        # which left an odd count (e.g. 3 -> 5) and broke the reshape; for a
        # single frame this is identical to the old repeat(2, ...).
        num_frames = pixel_value.shape[0]
        remainder = num_frames % self.temporal_patch_size
        if remainder:
            pad = pixel_value[-1:, ...].repeat(self.temporal_patch_size - remainder, 1, 1, 1)
            patches = torch.cat((pixel_value, pad), dim=0)
        else:
            patches = pixel_value

        grid_t, grid_h, grid_w = grid_thw[0][0], grid_thw[0][1], grid_thw[0][2]
        patches = patches.reshape(
            grid_t, self.temporal_patch_size, self.channel,
            grid_h // self.merge_size, self.merge_size, self.patch_size,
            grid_w // self.merge_size, self.merge_size, self.patch_size)
        # Group patches so each merge_size x merge_size block is contiguous,
        # then flatten (channel, temporal, patch, patch) into one feature row.
        patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
        flatten_patches = patches.reshape(
            grid_t * grid_h * grid_w,
            self.channel * self.temporal_patch_size * self.patch_size * self.patch_size)

        return self.vpm(flatten_patches, grid_thw)
|
|
|
|
|
class smolvlm_vision(torch.nn.Module):
    """Export wrapper: SmolVLM vision encoder followed by its connector.

    Args:
        vlm: loaded SmolVLM model; ``vlm.model.vision_model`` and
            ``vlm.model.connector`` are kept.
    """

    def __init__(self, vlm):
        super().__init__()
        inner = vlm.model
        self.vpm = inner.vision_model
        self.connector = inner.connector

    def forward(self, pixel_values):
        """Encode ``pixel_values`` and project the last hidden state via the connector."""
        encoded = self.vpm(pixel_values)
        features = self.connector(encoded.last_hidden_state)
        # Debug output of the projected feature shape.
        print("image_features:", features.shape)
        return features
|
|
|
|
|
class vila1_5_3b_vision(torch.nn.Module):
    """Export wrapper that delegates to the wrapped model's ``encode_images``.

    Args:
        vlm: loaded VILA 1.5 model exposing ``encode_images``.
    """

    def __init__(self, vlm):
        super().__init__()
        self.vlm = vlm

    def forward(self, pixel_values):
        """Return ``self.vlm.encode_images(pixel_values)`` unchanged."""
        return self.vlm.encode_images(pixel_values)
|
|
|
|
|
if __name__ == "__main__":
    # Export the vision encoder of the selected VLM to ./onnx/<model_name>_vision.onnx.
    # Named `parser` so the `argparse` module is not shadowed.
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', type=str, default='CKPT/MiniCPM-V-2_6', help='model path', required=False)
    parser.add_argument('--model_name', type=str, default='minicpm-v-2_6', help='model name', required=False)
    parser.add_argument('--batch_size', type=int, default=1, help='batch size', required=False)
    parser.add_argument('--height', type=int, default=448, help='image height', required=False)
    parser.add_argument('--width', type=int, default=448, help='image width', required=False)
    parser.add_argument('--device', type=str, default="cpu", help='cpu or cuda', required=False)
    args = parser.parse_args()

    path = args.path
    model_name = args.model_name
    savepath = os.path.join("./onnx", model_name + "_vision.onnx")
    device_type = args.device
    os.makedirs(os.path.dirname(savepath), exist_ok=True)

    if model_name == 'minicpm-v-2_6':
        model = AutoModel.from_pretrained(
            path, trust_remote_code=True, torch_dtype=torch.float32,
        )
        model = model.to(device=device_type, dtype=torch.float32)
        model.eval()
        model = minicpm_v_2_6_vision(model, args.batch_size, args.height, args.width)
        # The wrapper is a plain nn.Module without a `.device` attribute, so the
        # dummy input is allocated on the requested device directly.
        pixel_values = torch.randn(args.batch_size, 3, args.height, args.width,
                                   device=device_type, dtype=torch.float32)
        out = model(pixel_values)
        print("Output shape:", out.shape)
        torch.onnx.export(model,
                          pixel_values,
                          savepath,
                          input_names=['pixel'],
                          opset_version=15)

    elif model_name == 'qwen2_5-vl-3b':
        from transformers import Qwen2_5_VLForConditionalGeneration

        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            path,
            low_cpu_mem_usage=True,
            _attn_implementation="eager",
            trust_remote_code=True
        )
        model = model.to(device=device_type, dtype=torch.float32).eval()
        model = qwen2_5_vl_3b_vision(model, args.batch_size)

        def get_window_index_static(self, grid_thw):
            """Static replacement for the vision tower's ``get_window_index``.

            Returns an identity ``window_index`` and a single window covering
            every patch, i.e. full attention instead of windowed attention,
            which keeps the exported graph free of data-dependent loops.

            NOTE(review): upstream Qwen2.5-VL indexes ``window_index`` over
            spatially merged groups (spatial_merge_size**2 patches each) while
            ``cu_window_seqlens`` is in patch units -- confirm against the
            installed transformers version.
            """
            device = grid_thw.device
            T, H, W = grid_thw[0]
            total = T * H * W
            merge_unit = self.spatial_merge_size * self.spatial_merge_size
            # One index per merged group; arange(total) indexed past the end.
            window_index = torch.arange(total // merge_unit, device=device)
            cu_window_seqlens = torch.tensor([0, total], device=device)
            return window_index, cu_window_seqlens

        # The wrapper stores the vision tower as `vpm`; it has no `visual` attribute.
        model.vpm.get_window_index = get_window_index_static.__get__(
            model.vpm, type(model.vpm)
        )

        # The wrapper pads frames to a multiple of temporal_patch_size and
        # reshapes H/W into merge_size x merge_size groups of patch_size patches.
        spatial_unit = model.patch_size * model.merge_size
        if args.height % spatial_unit or args.width % spatial_unit:
            raise ValueError(
                f"height and width must be multiples of {spatial_unit} for {model_name}")
        grid_t = math.ceil(args.batch_size / model.temporal_patch_size)
        # Must be defined before the first forward call (it was used before assignment).
        grid_thw = torch.tensor(
            [[grid_t, args.height // model.patch_size, args.width // model.patch_size]],
            dtype=torch.int64, device=device_type)

        pixel_values = torch.randn(args.batch_size, 3, args.height, args.width,
                                   device=device_type, dtype=torch.float32)
        out = model(pixel_values, grid_thw)
        print("Output shape:", out.shape)

        torch.onnx.export(
            model,
            (pixel_values, grid_thw),
            savepath,
            input_names=["pixel", "grid_thw"],
            opset_version=18,
        )

    elif model_name == 'smolvlm':
        from transformers import SmolVLMForConditionalGeneration

        model = SmolVLMForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=torch.float32,
            _attn_implementation="eager",
        ).to(device_type)
        pixel_values = torch.randn(args.batch_size, 3, args.height, args.width,
                                   device=model.device, dtype=torch.float32)
        print("pixel_values:", pixel_values.shape)
        model = smolvlm_vision(model)
        model = model.to(torch.float32).eval()
        out = model(pixel_values)
        torch.onnx.export(model,
                          pixel_values,
                          savepath,
                          input_names=['pixel'],
                          dynamic_axes={'pixel': {2: 'height', 3: 'width'}},
                          opset_version=15)

    elif model_name == 'internvl3-1b':
        model = AutoModel.from_pretrained(
            path,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
            trust_remote_code=True).eval().to(device_type)
        pixel_values = torch.randn(args.batch_size, 3, args.height, args.width,
                                   device=model.device, dtype=torch.float32)
        # Export only the image feature extractor by routing forward to it.
        model.forward = model.extract_feature
        model = model.to(torch.float32).eval()
        torch.onnx.export(model, pixel_values, savepath)

    else:
        raise ValueError(f"Unsupported model name: {model_name}")

    print(f"Exported to {savepath}")
|
|
|