File size: 7,552 Bytes
7fc4eb4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 |
import numpy as np
import os
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F
from PIL import Image
import torchvision.transforms as T
from torchvision.transforms import InterpolationMode
from transformers.modeling_utils import PreTrainedModel
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def build_transform(input_size):
    """Return the per-tile preprocessing pipeline.

    Each PIL image is forced to RGB, resized to an ``input_size`` x
    ``input_size`` square with bicubic interpolation, converted to a tensor,
    and normalized with the ImageNet channel mean/std defined above.
    """
    def _ensure_rgb(img):
        # Skip the conversion (and the copy it makes) when already RGB.
        return img if img.mode == 'RGB' else img.convert('RGB')

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the (columns, rows) grid whose aspect ratio is closest to the image's.

    Candidates are scanned in the order given; the first one with the smallest
    ``|aspect_ratio - cols / rows|`` wins. On an exact tie, a later candidate
    replaces the current choice only when the image area (``width * height``)
    exceeds half the pixel area of that candidate grid
    (``image_size**2 * cols * rows``).

    Returns the chosen candidate object itself, or ``(1, 1)`` when
    ``target_ratios`` is empty.
    """
    image_area = width * height
    chosen = (1, 1)
    smallest_gap = float('inf')
    for candidate in target_ratios:
        cols, rows = candidate[0], candidate[1]
        gap = abs(aspect_ratio - cols / rows)
        if gap < smallest_gap:
            smallest_gap = gap
            chosen = candidate
            continue
        if gap == smallest_gap and image_area > 0.5 * image_size * image_size * cols * rows:
            chosen = candidate
    return chosen
def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    """Split ``image`` into a grid of ``image_size`` x ``image_size`` tiles.

    The grid shape (cols, rows) is chosen among all grids whose tile count
    lies in ``[min_num, max_num]``, using ``find_closest_aspect_ratio``. The
    image is resized to ``(cols * image_size, rows * image_size)`` and cropped
    row by row, left to right. When ``use_thumbnail`` is true and more than one
    tile was produced, a resized copy of the whole image is appended last.

    ``image`` is expected to expose ``.size``, ``.resize`` and ``.crop`` (a
    PIL image in this file). Returns a list of the cropped tiles.
    """
    width, height = image.size
    aspect = width / height

    # Every (cols, rows) grid whose tile count is within [min_num, max_num],
    # ordered by tile count.
    candidates = set(
        (cols, rows)
        for total in range(min_num, max_num + 1)
        for cols in range(1, total + 1)
        for rows in range(1, total + 1)
        if min_num <= cols * rows <= max_num
    )
    candidates = sorted(candidates, key=lambda grid: grid[0] * grid[1])

    cols, rows = find_closest_aspect_ratio(aspect, candidates, width, height, image_size)
    grid_w = image_size * cols
    grid_h = image_size * rows
    tile_count = cols * rows

    resized = image.resize((grid_w, grid_h))

    # Tiles per row of the resized image; constant for the whole loop.
    per_row = grid_w // image_size
    tiles = []
    for idx in range(tile_count):
        row, col = divmod(idx, per_row)
        box = (
            col * image_size,
            row * image_size,
            (col + 1) * image_size,
            (row + 1) * image_size,
        )
        tiles.append(resized.crop(box))

    assert len(tiles) == tile_count
    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles
def load_image(image_file, input_size=448, max_num=12):
    """Load an image file and convert it into a stack of normalized tiles.

    Args:
        image_file: path or file object accepted by ``PIL.Image.open``.
        input_size: side length, in pixels, of each square tile.
        max_num: maximum number of grid tiles (the thumbnail is not counted).

    Returns:
        A tensor of shape ``(num_tiles, 3, input_size, input_size)``, where
        ``num_tiles`` is the grid tile count plus one thumbnail whenever the
        grid has more than one tile.
    """
    # Image.open is lazy and keeps the file handle open; the context manager
    # closes it once convert() has produced an independent in-memory copy.
    with Image.open(image_file) as raw:
        image = raw.convert('RGB')
    transform = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    return torch.stack([transform(tile) for tile in tiles])
# Local-run configuration: directory of the model to load ('.' = current
# directory), destination of the exported ONNX file, and a sample test image.
path, save_path, image_file = '.', 'vision_encoder.onnx', 'test.jpg'
def _verify_onnx_export(torch_module, sample_input, onnx_path, input_name='pixel_values',
                        rtol=1e-1, atol=1e-2):
    """Compare ONNX Runtime output for ``onnx_path`` against ``torch_module``.

    Difference statistics are printed before the tolerance check so they are
    visible even when the check fails.

    Raises:
        AssertionError: if the outputs differ beyond ``rtol`` / ``atol``
            (via ``numpy.testing.assert_allclose``).
    """
    import onnxruntime

    # Pin the CPU provider: the PyTorch reference runs on CPU, and GPU builds
    # of onnxruntime refuse to create a session without explicit providers.
    ort_session = onnxruntime.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
    onnx_output_np = ort_session.run(None, {input_name: sample_input.numpy()})[0]

    # Reference forward pass; no_grad avoids building an unused autograd graph.
    with torch.no_grad():
        torch_output_np = torch_module(sample_input).detach().numpy()

    abs_diff = np.abs(torch_output_np - onnx_output_np)
    # Divide by |reference| + eps so negative or near-zero reference values do
    # not produce negative or exploding ratios.
    rel_diff = abs_diff / (np.abs(torch_output_np) + 1e-7)
    print("\nValidation Statistics:")
    print(f"Max absolute difference: {np.max(abs_diff):.6f}")
    print(f"Mean absolute difference: {np.mean(abs_diff):.6f}")
    print(f"Max relative difference: {np.max(rel_diff):.6f}")
    print(f"Mean relative difference: {np.mean(rel_diff):.6f}")

    # Deliberately loose tolerances.
    np.testing.assert_allclose(torch_output_np, onnx_output_np, rtol=rtol, atol=atol)
    print("ONNX model verification successful!")


def export_vision_InternVL(model_path: str, save_path: str):
    """Export the InternVL vision encoder + projector to ONNX and verify it.

    Loads the model at ``model_path`` (remote code allowed) in float32 on the
    CPU, wraps its ``get_image_features`` method in a small ``nn.Module``,
    exports that wrapper to ``save_path`` with a dynamic batch axis, then
    checks ONNX Runtime output against PyTorch output on the same random input.

    Args:
        model_path: directory or hub id passed to ``AutoModel.from_pretrained``.
        save_path: destination path for the ``.onnx`` file.

    Raises:
        AssertionError: if ONNX Runtime and PyTorch outputs differ beyond
            ``rtol=1e-1`` / ``atol=1e-2``.
    """
    # Note: this changes the process-wide default dtype, not just this model.
    torch.set_default_dtype(torch.float32)
    vl_gpt = AutoModel.from_pretrained(model_path, torch_dtype=torch.float32, trust_remote_code=True)
    # Export on CPU, in eval mode, with float32 weights.
    vl_gpt = vl_gpt.cpu().eval().float()

    class VisionWrapper(nn.Module):
        """Expose only the model's image-feature path as ``forward``."""

        def __init__(self, model: PreTrainedModel):
            super().__init__()
            self.vision_model = model

        def forward(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
            # Delegate to the model's own helper so the exported graph matches
            # what the model itself computes for image features.
            return self.vision_model.get_image_features(pixel_values=pixel_values)

    vision_wrapper = VisionWrapper(vl_gpt)
    vision_wrapper.eval().float()

    # A single 448x448 RGB input; the batch axis is exported as dynamic below.
    # NOTE: a real image can be used instead via load_image(image_file, max_num=12),
    # which yields a batch of tiles.
    batch_size, num_channels, height, width = 1, 3, 448, 448
    dummy_input = torch.randn(batch_size, num_channels, height, width, dtype=torch.float32)

    torch.onnx.export(
        vision_wrapper,
        dummy_input,
        save_path,
        export_params=True,
        opset_version=17,  # higher opset chosen to support scaled_dot_product_attention
        do_constant_folding=True,
        input_names=['pixel_values'],
        output_names=['projected_features'],
        dynamic_axes={
            'pixel_values': {0: 'batch_size'},
            'projected_features': {0: 'batch_size'},
        },
        dynamo=True,
        verbose=False,
    )
    print(f"Successfully exported vision components to {save_path}")

    _verify_onnx_export(vision_wrapper, dummy_input, save_path)
def _describe_onnx_version():
    """Return the installed ``onnx`` version string.

    Raises ImportError when ``onnx`` is not installed. Falls back to
    ``onnx.version.version`` when ``__version__`` is missing, and to
    ``"Unknown"`` when neither attribute exists.
    """
    import onnx

    try:
        return onnx.__version__
    except AttributeError:
        pass
    try:
        return onnx.version.version
    except AttributeError:
        return "Unknown"


if __name__ == "__main__":
    # Report library versions before exporting.
    try:
        print(f"ONNX version: {_describe_onnx_version()}")
    except ImportError:
        print("ONNX not installed")
    import onnxruntime
    print(f"ONNX Runtime version: {onnxruntime.__version__}")
    export_vision_InternVL(path, save_path)
|