| import torch |
| from diffusers import ConfigMixin, ModelMixin |
|
|
|
|
class ImageProjModel(ModelMixin, ConfigMixin):
    """Project a pooled CLIP image embedding into extra cross-attention context tokens.

    A single linear layer maps the ``clip_embeddings_dim``-dimensional image
    embedding to ``clip_extra_context_tokens * cross_attention_dim`` features,
    which are then reshaped into ``clip_extra_context_tokens`` tokens of width
    ``cross_attention_dim`` and layer-normalized (IP-Adapter style).

    NOTE(review): ``__init__`` is not decorated with ``@register_to_config``;
    if config serialization via ``ConfigMixin`` is relied upon elsewhere, that
    may need confirming against the callers.
    """

    def __init__(
        self,
        cross_attention_dim=768,
        clip_embeddings_dim=512,
        clip_extra_context_tokens=4,
    ):
        super().__init__()

        # Placeholder attribute; not used inside this class — kept because
        # external code may read/assign it.
        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        # One linear map produces all extra tokens at once.
        self.proj = torch.nn.Linear(
            clip_embeddings_dim,
            self.clip_extra_context_tokens * cross_attention_dim,
        )
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        """Map image embeddings to normalized context tokens.

        Args:
            image_embeds: tensor whose last dimension is
                ``clip_embeddings_dim`` (leading dims are flattened into the
                batch axis by the reshape).

        Returns:
            Tensor of shape ``(-1, clip_extra_context_tokens,
            cross_attention_dim)``.
        """
        tokens = self.proj(image_embeds)
        tokens = tokens.reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        return self.norm(tokens)
|
|