Spaces: Running on Zero
| import torch | |
| from diffusers import ConfigMixin, ModelMixin | |
class ImageProjModel(ModelMixin, ConfigMixin):
    """Project image embeddings into a fixed number of normalized context tokens.

    A single linear layer maps each embedding of width ``clip_embeddings_dim``
    to ``clip_extra_context_tokens * cross_attention_dim`` features. The result
    is reshaped into ``clip_extra_context_tokens`` tokens of width
    ``cross_attention_dim`` and layer-normalized over that last dimension.
    """

    def __init__(
        self,
        cross_attention_dim=768,
        clip_embeddings_dim=512,
        clip_extra_context_tokens=4,
    ):
        """Build the projection and normalization layers.

        Args:
            cross_attention_dim: Width of every output token; also the
                normalized shape of the LayerNorm.
            clip_embeddings_dim: Input feature size of the linear projection.
            clip_extra_context_tokens: Number of tokens produced per input
                embedding.
        """
        super().__init__()

        # Not read anywhere inside this class; kept for external users.
        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens

        # One linear layer emits all tokens at once, laid out back to back.
        projected_width = self.clip_extra_context_tokens * cross_attention_dim
        self.proj = torch.nn.Linear(clip_embeddings_dim, projected_width)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        """Return normalized context tokens for ``image_embeds``.

        Args:
            image_embeds: Tensor whose last dimension matches the input size
                of ``self.proj``.

        Returns:
            Tensor of shape ``(-1, clip_extra_context_tokens,
            cross_attention_dim)``, where the leading dimension is inferred
            by ``reshape``.
        """
        token_shape = (-1, self.clip_extra_context_tokens, self.cross_attention_dim)
        tokens = self.proj(image_embeds).reshape(token_shape)
        return self.norm(tokens)