import os

import torch
from transformers import AutoModel, AutoProcessor


class PoolerOutputWrapper(torch.nn.Module):
    """Wraps one tower of the model so ONNX export traces only pooler_output."""

    def __init__(self, model, model_part):
        super().__init__()
        if model_part == 'vision':
            self.model = model.vision_model
        elif model_part == 'text':
            self.model = model.text_model
        else:
            raise ValueError("model_part must be either 'vision' or 'text'")

    def forward(self, x):
        # the vision tower takes pixel_values and the text tower takes input_ids;
        # both arrive as the first positional argument
        outputs = self.model(x)
        return outputs.pooler_output
# load the model and processor (export runs on CPU, so no device_map is needed)
ckpt = "google/siglip2-base-patch16-224"
model = AutoModel.from_pretrained(ckpt).eval().to("cpu")
processor = AutoProcessor.from_pretrained(ckpt)

# dummy inputs matching the expected shapes: a 224x224 RGB image and a
# 64-token sequence (the fixed text length SigLIP 2 uses)
dummy_img = torch.randn(1, 3, 224, 224)
dummy_ids = torch.randint(1, 1000, (1, 64))
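
# For tracing, random tensors are enough; real inputs could be built with the
# processor instead. A sketch, assuming a PIL image `img` and a caption of
# your choice (SigLIP 2 pads text to its fixed 64-token length):
#   inputs = processor(text=["a photo of a cat"], images=img,
#                      padding="max_length", max_length=64, return_tensors="pt")
#   dummy_img, dummy_ids = inputs["pixel_values"], inputs["input_ids"]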
os.makedirs("./onnx", exist_ok=True)  # torch.onnx.export does not create directories

# export image onnx
vision_wrapper = PoolerOutputWrapper(model, 'vision')
torch.onnx.export(vision_wrapper,
                  dummy_img,
                  "./onnx/siglip2-base-patch16-224_vision.onnx",
                  input_names=['image'],
                  output_names=['pooler_output'],
                  export_params=True,
                  opset_version=14)
# export text onnx
text_wrapper = PoolerOutputWrapper(model, 'text')
torch.onnx.export(text_wrapper,
                  dummy_ids,
                  "./onnx/siglip2-base-patch16-224_text.onnx",
                  input_names=['text'],
                  output_names=['pooler_output'],
                  export_params=True,
                  opset_version=14)
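
# --- optional sanity check of the exported graphs ---
# a minimal sketch assuming onnxruntime is installed (`pip install onnxruntime`);
# it is not required for the export itself
import numpy as np
import onnxruntime as ort

vision_sess = ort.InferenceSession("./onnx/siglip2-base-patch16-224_vision.onnx")
text_sess = ort.InferenceSession("./onnx/siglip2-base-patch16-224_text.onnx")

# run both backends on the same dummy inputs and compare the pooled embeddings
onnx_img = vision_sess.run(None, {"image": dummy_img.numpy()})[0]
onnx_txt = text_sess.run(None, {"text": dummy_ids.numpy()})[0]

with torch.no_grad():
    torch_img = vision_wrapper(dummy_img).numpy()
    torch_txt = text_wrapper(dummy_ids).numpy()

print(np.allclose(onnx_img, torch_img, atol=1e-4))  # expect True
print(np.allclose(onnx_txt, torch_txt, atol=1e-4))  # expect True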