| import torch | |
| from huggingface_hub import hf_hub_download | |
class ImageProjModel(torch.nn.Module):
    """Project an image embedding into a fixed number of context tokens.

    A single linear layer maps each embedding vector to
    ``clip_extra_context_tokens * cross_attention_dim`` values, which are
    then split into ``clip_extra_context_tokens`` tokens of width
    ``cross_attention_dim`` and layer-normalized per token.
    """

    def __init__(
        self,
        cross_attention_dim: int = 2048,
        clip_embeddings_dim: int = 1280,
        clip_extra_context_tokens: int = 4,
    ) -> None:
        """Create the projection and normalization layers.

        Args:
            cross_attention_dim: Width of every produced token.
            clip_embeddings_dim: Width of the incoming embedding vectors
                (the linear layer's ``in_features``).
            clip_extra_context_tokens: Number of tokens produced per
                embedding vector.
        """
        super().__init__()

        # Plain (non-parameter) attributes; forward() reads the two sizes
        # to shape its output.
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.cross_attention_dim = cross_attention_dim
        # Not read anywhere in this class; kept so the attribute still
        # exists for any external code that expects it.
        self.generator = None

        # Submodule names ("proj", "norm") determine the state_dict keys,
        # so they must stay as they are for checkpoint loading.
        flattened_width = self.clip_extra_context_tokens * cross_attention_dim
        self.proj = torch.nn.Linear(clip_embeddings_dim, flattened_width)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
        """Return normalized context tokens for ``image_embeds``.

        Args:
            image_embeds: Tensor whose last dimension matches the
                ``clip_embeddings_dim`` the module was built with.

        Returns:
            Tensor of shape
            ``(-1, clip_extra_context_tokens, cross_attention_dim)``,
            where the leading dimension is inferred by ``reshape``.
        """
        projected = self.proj(image_embeds)
        # Split the flat projection into separate tokens; -1 lets the
        # leading (batch) dimension be inferred.
        tokens = projected.reshape(
            -1,
            self.clip_extra_context_tokens,
            self.cross_attention_dim,
        )
        return self.norm(tokens)
# Build the projection module with its default dimensions and load the
# "image_proj" weights from the IP-Adapter SDXL checkpoint on the Hub.
image_proj_model = ImageProjModel()
model_filename = hf_hub_download(repo_id="h94/IP-Adapter", filename="sdxl_models/ip-adapter_sdxl.bin")
state_dict = torch.load(model_filename, map_location="cpu", weights_only=True)
image_proj_model.load_state_dict(state_dict["image_proj"])
# Export an inference graph: make eval mode explicit rather than relying on
# the exporter's defaults.
image_proj_model.eval()

# Dummy input used only for tracing. Its width is taken from the linear
# layer so it always matches the model (1280 with the default arguments).
clip_image_embeds = torch.rand((1, image_proj_model.proj.in_features))
onnx_output_path = 'model.onnx'
torch.onnx.export(
    image_proj_model,
    clip_image_embeds,
    onnx_output_path,
    export_params=True,
    opset_version=18,
    do_constant_folding=True,
    input_names=['clip_image_embeds'],
    output_names=['image_prompt_embeds'],
    # Only the batch axis is dynamic. The embedding axis feeds a Linear layer
    # with a fixed in_features, so advertising it as dynamic would describe
    # inputs the exported graph cannot actually accept.
    dynamic_axes={
        'clip_image_embeds': {0: 'batch_size'},
        'image_prompt_embeds': {0: 'batch_size'},
    },
)