|
|
from pathlib import Path

import torch
from transformers import AutoModel, AutoProcessor
|
|
|
|
|
class PoolerOutputWrapper(torch.nn.Module):
    """Expose a single tower of a two-tower model, returning only its pooled output.

    ONNX export needs a module whose ``forward`` returns plain tensors, so this
    wrapper selects either the vision or the text sub-model and strips the
    structured output down to ``pooler_output``.

    Args:
        model: a model exposing ``vision_model`` and ``text_model`` submodules.
        model_part: which tower to wrap, ``'vision'`` or ``'text'``.

    Raises:
        ValueError: if ``model_part`` is neither ``'vision'`` nor ``'text'``.
    """

    def __init__(self, model, model_part):
        super().__init__()
        # Validate first so an invalid selector never touches the sub-models.
        if model_part not in ('vision', 'text'):
            raise ValueError("model_part must be either 'vision' or 'text'")
        self.model = model.vision_model if model_part == 'vision' else model.text_model

    def forward(self, x):
        """Run the selected tower on ``x`` and return its pooled embedding."""
        return self.model(x).pooler_output
|
|
|
|
|
|
|
|
# Checkpoint to export and where the ONNX files go.
ckpt = "google/siglip2-base-patch16-224"
model_name = ckpt.rsplit("/", 1)[-1]  # "siglip2-base-patch16-224"
onnx_dir = Path("./onnx")
# torch.onnx.export does not create missing directories — make sure it exists.
onnx_dir.mkdir(parents=True, exist_ok=True)

# Load on CPU for export; eval() disables dropout etc. for a deterministic trace.
# NOTE(review): the previous device_map="auto" followed by .to("cpu") was
# contradictory — accelerate may dispatch/shard the model, and moving a
# dispatched model with .to() can fail. Loading directly on CPU is what the
# export actually needs (the dummy inputs below are CPU tensors).
model = AutoModel.from_pretrained(ckpt).eval()
# Loaded for parity with downstream preprocessing; not used by the export itself.
processor = AutoProcessor.from_pretrained(ckpt)

# Dummy tracing inputs: one 224x224 RGB image and one 64-token id sequence.
dummy_img = torch.randn(1, 3, 224, 224)
dummy_ids = torch.randint(1, 1000, (1, 64))

# Export the vision tower (input: pixel values, output: pooled embedding).
vision_wrapper = PoolerOutputWrapper(model, 'vision')
torch.onnx.export(
    vision_wrapper,
    dummy_img,
    str(onnx_dir / f"{model_name}_vision.onnx"),
    input_names=['image'],
    output_names=['pooler_output'],
    dynamic_axes={'image': {0: 'batch'}},  # don't bake batch=1 into the graph
    export_params=True,
    opset_version=14,
)

# Export the text tower (input: token ids, output: pooled embedding).
text_wrapper = PoolerOutputWrapper(model, 'text')
torch.onnx.export(
    text_wrapper,
    dummy_ids,
    str(onnx_dir / f"{model_name}_text.onnx"),
    input_names=['text'],
    output_names=['pooler_output'],
    dynamic_axes={'text': {0: 'batch'}},  # don't bake batch=1 into the graph
    export_params=True,
    opset_version=14,
)
|
|
|
|
|
|