"""Export the IP-Adapter (SDXL) image projection head to ONNX."""
import torch
from huggingface_hub import hf_hub_download
class ImageProjModel(torch.nn.Module):
    """Projects a pooled CLIP image embedding into extra context tokens
    for the UNet's cross-attention layers (IP-Adapter projection head)."""

    def __init__(self, cross_attention_dim=2048, clip_embeddings_dim=1280, clip_extra_context_tokens=4):
        super().__init__()
        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        # A single linear layer expands the CLIP embedding into
        # `clip_extra_context_tokens` tokens of width `cross_attention_dim`.
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        embeds = image_embeds
        # (batch, 1280) -> (batch, tokens * dim) -> (batch, tokens, dim)
        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens
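
# Quick sanity check of the shape contract (illustrative addition; randomly
# initialized weights are fine here): a batch of pooled CLIP embeddings of
# shape (B, 1280) should map to (B, 4, 2048), i.e. four extra context tokens
# per image in the cross-attention dimension.
assert ImageProjModel()(torch.rand(2, 1280)).shape == (2, 4, 2048)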
# Instantiate the projection head and load the pretrained weights from the
# official IP-Adapter repository.
image_proj_model = ImageProjModel()
model_filename = hf_hub_download(repo_id="h94/IP-Adapter", filename="sdxl_models/ip-adapter_sdxl.bin")
state_dict = torch.load(model_filename, map_location="cpu", weights_only=True)
image_proj_model.load_state_dict(state_dict["image_proj"])
image_proj_model.eval()

# Dummy input matching the pooled CLIP image embedding size (1280).
clip_image_embeds = torch.rand((1, 1280))
onnx_output_path = 'model.onnx'

# Export with a dynamic batch dimension. Only axis 0 is declared dynamic: the
# embedding width is fixed at 1280 by the Linear layer, so marking axis 1 as
# dynamic would advertise input shapes the graph cannot actually handle.
torch.onnx.export(
    image_proj_model,
    (clip_image_embeds,),
    onnx_output_path,
    export_params=True,
    opset_version=18,
    do_constant_folding=True,
    input_names=['clip_image_embeds'],
    output_names=['image_prompt_embeds'],
    dynamic_axes={
        'clip_image_embeds': {0: 'batch_size'},
        'image_prompt_embeds': {0: 'batch_size'},
    },
)
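
# Optional verification sketch (not part of the original script; assumes the
# `onnxruntime` package is installed): run the exported graph and compare it
# against the eager PyTorch output.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession(onnx_output_path, providers=['CPUExecutionProvider'])
onnx_out = session.run(None, {'clip_image_embeds': clip_image_embeds.numpy()})[0]
with torch.no_grad():
    torch_out = image_proj_model(clip_image_embeds)
# Small numerical differences between the two runtimes are expected.
np.testing.assert_allclose(torch_out.numpy(), onnx_out, rtol=1e-3, atol=1e-5)
print(f'ONNX export verified; output shape: {onnx_out.shape}')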