import os

import torch
from transformers import AutoModel, AutoProcessor


class PoolerOutputWrapper(torch.nn.Module):
    """Wraps the vision or text tower so the ONNX export returns only pooler_output."""

    def __init__(self, model, model_part):
        super().__init__()
        if model_part == 'vision':
            self.model = model.vision_model
        elif model_part == 'text':
            self.model = model.text_model
        else:
            raise ValueError("model_part must be either 'vision' or 'text'")

    def forward(self, x):
        outputs = self.model(x)
        return outputs.pooler_output


# load the model and processor on CPU for export
ckpt = "google/siglip2-base-patch16-224"
model = AutoModel.from_pretrained(ckpt).eval()
processor = AutoProcessor.from_pretrained(ckpt)

# dummy inputs matching the expected shapes: a 224x224 RGB image and 64 text token ids
dummy_img = torch.randn(1, 3, 224, 224)
dummy_ids = torch.randint(1, 1000, (1, 64))

# make sure the output directory exists
os.makedirs("./onnx", exist_ok=True)

# export the vision tower to ONNX
vision_wrapper = PoolerOutputWrapper(model, 'vision')
torch.onnx.export(
    vision_wrapper,
    dummy_img,
    "./onnx/siglip2-base-patch16-224_vision.onnx",
    input_names=['image'],
    output_names=['pooler_output'],
    export_params=True,
    opset_version=14,
)

# export the text tower to ONNX
text_wrapper = PoolerOutputWrapper(model, 'text')
torch.onnx.export(
    text_wrapper,
    dummy_ids,
    "./onnx/siglip2-base-patch16-224_text.onnx",
    input_names=['text'],
    output_names=['pooler_output'],
    export_params=True,
    opset_version=14,
)
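
# --- Optional sanity check (a minimal sketch, assuming onnxruntime is installed) ---
# Runs the exported graphs with onnxruntime and compares their pooler_output against
# the PyTorch wrappers on the same dummy inputs; the input names "image" and "text"
# match those passed to torch.onnx.export above.
import numpy as np
import onnxruntime as ort

vision_session = ort.InferenceSession("./onnx/siglip2-base-patch16-224_vision.onnx")
text_session = ort.InferenceSession("./onnx/siglip2-base-patch16-224_text.onnx")

with torch.no_grad():
    ref_img = vision_wrapper(dummy_img).numpy()
    ref_txt = text_wrapper(dummy_ids).numpy()

onnx_img = vision_session.run(None, {"image": dummy_img.numpy()})[0]
onnx_txt = text_session.run(None, {"text": dummy_ids.numpy()})[0]

print("vision max abs diff:", np.abs(onnx_img - ref_img).max())
print("text max abs diff:", np.abs(onnx_txt - ref_txt).max())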