| """ | |
| image_proj_model.py | |
| This module defines the ImageProjModel class, which is responsible for | |
| projecting image embeddings into a different dimensional space. The model | |
| leverages a linear transformation followed by a layer normalization to | |
| reshape and normalize the input image embeddings for further processing in | |
| cross-attention mechanisms or other downstream tasks. | |
| Classes: | |
| ImageProjModel | |
| Dependencies: | |
| torch | |
| diffusers.ModelMixin | |
| """ | |
import torch
from diffusers import ModelMixin


class ImageProjModel(ModelMixin):
    """
    ImageProjModel is a class that projects image embeddings into a different
    dimensional space. It inherits from diffusers' ModelMixin, which provides
    standard model utilities such as saving and loading weights.

    Attributes:
        cross_attention_dim (int): The dimension of the cross-attention features.
        clip_extra_context_tokens (int): The number of extra context tokens
            produced from the CLIP image embeddings.
        proj (torch.nn.Linear): Linear projection from the CLIP embedding
            dimension to clip_extra_context_tokens * cross_attention_dim.
        norm (torch.nn.LayerNorm): Layer normalization applied to the
            projected tokens.

    Methods:
        forward(image_embeds): Forward pass of the ImageProjModel, which takes
            in image embeddings and returns the projected tokens.
    """
    def __init__(
        self,
        cross_attention_dim=1024,
        clip_embeddings_dim=1024,
        clip_extra_context_tokens=4,
    ):
        super().__init__()

        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        # Project the CLIP image embedding into clip_extra_context_tokens
        # vectors, each of size cross_attention_dim.
        self.proj = torch.nn.Linear(
            clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim
        )
        self.norm = torch.nn.LayerNorm(cross_attention_dim)
    def forward(self, image_embeds):
        """
        Forward pass of the ImageProjModel, which takes in image embeddings and
        returns the projected tokens after reshaping and normalization.

        Args:
            image_embeds (torch.Tensor): The input CLIP image embeddings, with
                shape (batch_size, clip_embeddings_dim).

        Returns:
            torch.Tensor: The projected context tokens after reshaping and
                layer normalization, with shape
                (batch_size, clip_extra_context_tokens, cross_attention_dim).
        """
        embeds = image_embeds
        # Project, then split the flat projection into separate context tokens.
        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens
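

# Minimal usage sketch (not part of the original module): the batch size and
# dummy tensor below are assumptions chosen to match the default constructor
# arguments, i.e. a CLIP image embedding of dimension 1024 projected into
# 4 context tokens of dimension 1024.
if __name__ == "__main__":
    model = ImageProjModel(
        cross_attention_dim=1024,
        clip_embeddings_dim=1024,
        clip_extra_context_tokens=4,
    )
    dummy_image_embeds = torch.randn(2, 1024)  # (batch_size, clip_embeddings_dim)
    tokens = model(dummy_image_embeds)
    # Expected: torch.Size([2, 4, 1024]), i.e.
    # (batch_size, clip_extra_context_tokens, cross_attention_dim)
    print(tokens.shape)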