import math
from typing import Optional

import torch
import torch.nn as nn
from einops import rearrange
from diffusers.models.modeling_utils import ModelMixin

from .blocks import _basic_init, DiTBlock
from .modules import RMSNorm
from .positional_embedding import get_1d_sincos_pos_embed


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size: int, frequency_embedding_size: int = 256, dtype=None, device=None):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
        )
        self.frequency_embedding_size = frequency_embedding_size

    def initialize_weights(self):
        self.apply(_basic_init)
        for idx in [0, 2]:
            nn.init.normal_(self.mlp[idx].weight, std=0.02)

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        if torch.is_floating_point(t):
            embedding = embedding.to(dtype=t.dtype)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb
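

# Usage sketch (illustrative only; the hidden size and batch shape below are
# arbitrary assumptions, not values required by this module):
#
#   t_embedder = TimestepEmbedder(hidden_size=1152)
#   t_embedder.initialize_weights()
#   t = torch.randint(0, 1000, (4,))          # one timestep per batch element
#   t_emb = t_embedder(t)                     # -> (4, 1152)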


class LabelEmbedder(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout
    for classifier-free guidance.
    """
    def __init__(self, num_classes: int, hidden_size: int, dropout_prob: float, dtype=None, device=None):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        # Reserve one extra embedding row for the "null" class used by
        # classifier-free guidance when label dropout is enabled.
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size, dtype=dtype, device=device)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def initialize_weights(self):
        nn.init.normal_(self.embedding_table.weight, std=0.02)

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            drop_ids = force_drop_ids == 1
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels, train, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        return embeddings
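

# Usage sketch (illustrative only; the class count and sizes are assumptions):
#
#   label_embedder = LabelEmbedder(num_classes=1000, hidden_size=1152, dropout_prob=0.1)
#   label_embedder.initialize_weights()
#   labels = torch.randint(0, 1000, (4,))
#   y = label_embedder(labels, train=True)    # ~10% of labels mapped to the CFG null class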


class MotionEmbedder(nn.Module):
    """
    Embeds motion features of shape (B, L, D) into vector representations.
    """
    def __init__(self, motion_dim: int, hidden_size: int, dtype=None, device=None):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(motion_dim, hidden_size, bias=True, dtype=dtype, device=device),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
        )

    def initialize_weights(self):
        self.apply(_basic_init)
        for idx in [0, 2]:
            w = self.mlp[idx].weight.data
            nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
            nn.init.constant_(self.mlp[idx].bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)
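

# Usage sketch (illustrative only; dimensions are assumptions):
#
#   motion_embedder = MotionEmbedder(motion_dim=63, hidden_size=1152)
#   motion_embedder.initialize_weights()
#   motion = torch.randn(2, 80, 63)           # (B, L, D)
#   m_emb = motion_embedder(motion)           # -> (2, 80, 1152)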


class AudioEmbedder(ModelMixin):
    """Audio Projection Model

    This class defines an audio projection model that takes audio embeddings as
    input and produces context tokens as output. The model is based on the
    ModelMixin class and consists of multiple linear layers and activation
    functions. It can be used for various audio processing tasks.

    Attributes:
        seq_len (int): The length of the audio window per frame.
        blocks (int): The number of feature blocks per audio window step.
        channels (int): The number of channels per audio feature block.
        intermediate_dim (int): The intermediate dimension of the model.
        context_tokens (int): The number of context tokens in the output.
        output_dim (int): The output dimension of the context tokens.

    Methods:
        forward(audio_embeds, conditions=None, emo=None):
            Projects audio embeddings of shape
            (batch_size, video_length, window_size, blocks, channels) to
            context tokens of shape
            (batch_size, video_length, context_tokens, output_dim).
    """

    def __init__(
        self,
        seq_len=5,
        blocks=12,
        channels=768,
        intermediate_dim=512,
        output_dim=768,
        context_tokens=32,
        input_len=80,
        condition_dim=63,
        norm_type="rms_norm",
        qk_norm="rms_norm",
    ):
        super().__init__()
        # input_len, condition_dim, and qk_norm are unused here; they are kept
        # so this class shares a constructor signature with the conditional
        # audio embedders below.
        input_dim = seq_len * blocks * channels
        self.context_tokens = context_tokens
        self.output_dim = output_dim

        self.proj1 = nn.Linear(input_dim, intermediate_dim)
        self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
        self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)

        self.norm = nn.LayerNorm(output_dim) if norm_type == "layer_norm" else RMSNorm(output_dim)

    def initialize_weights(self):
        self.apply(_basic_init)
        for proj in [self.proj1, self.proj2, self.proj3]:
            w = proj.weight.data
            nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
            nn.init.constant_(proj.bias, 0)

    def forward(self, audio_embeds, conditions=None, emo=None):
        """
        Defines the forward pass for the AudioEmbedder.

        Parameters:
            audio_embeds (torch.Tensor): The input audio embeddings with shape
                (batch_size, video_length, window_size, blocks, channels).
            conditions (torch.Tensor): optional conditions; accepted for
                interface parity with the conditional embedders but unused here.
            emo (torch.Tensor): optional emotion embedding; likewise unused here.
        Returns:
            context_tokens (torch.Tensor): The output context tokens with shape
                (batch_size, video_length, context_tokens, output_dim).
        """
        video_length = audio_embeds.shape[1]
        # Fold the frame axis into the batch axis, then flatten each frame's
        # audio window into a single vector.
        audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
        batch_size, window_size, blocks, channels = audio_embeds.shape
        audio_embeds = audio_embeds.reshape(batch_size, window_size * blocks * channels)

        audio_embeds = torch.relu(self.proj1(audio_embeds))
        audio_embeds = torch.relu(self.proj2(audio_embeds))

        context_tokens = self.proj3(audio_embeds).reshape(
            batch_size, self.context_tokens, self.output_dim
        )

        context_tokens = self.norm(context_tokens)
        context_tokens = rearrange(
            context_tokens, "(bz f) m c -> bz f m c", f=video_length
        )

        return context_tokens
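

# Usage sketch (illustrative only; batch and frame counts are assumptions):
#
#   audio_proj = AudioEmbedder(seq_len=5, blocks=12, channels=768,
#                              context_tokens=32, output_dim=768)
#   audio_proj.initialize_weights()
#   feats = torch.randn(2, 16, 5, 12, 768)    # (batch, frames, window, blocks, channels)
#   tokens = audio_proj(feats)                # -> (2, 16, 32, 768)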


class ConditionAudioEmbedder(ModelMixin):
    """Audio Projection Model with conditions

    Identical to AudioEmbedder except that a per-frame condition vector is
    concatenated to the flattened audio features before projection.

    Attributes:
        seq_len (int): The length of the audio window per frame.
        blocks (int): The number of feature blocks per audio window step.
        channels (int): The number of channels per audio feature block.
        intermediate_dim (int): The intermediate dimension of the model.
        context_tokens (int): The number of context tokens in the output.
        output_dim (int): The output dimension of the context tokens.
        condition_dim (int): The dimension of the per-frame condition vector.

    Methods:
        forward(audio_embeds, conditions, emo=None):
            Projects audio embeddings of shape
            (batch_size, video_length, window_size, blocks, channels), together
            with conditions of shape (batch_size, video_length, condition_dim),
            to context tokens of shape
            (batch_size, video_length, context_tokens, output_dim).
    """

    def __init__(
        self,
        seq_len=5,
        blocks=12,
        channels=768,
        intermediate_dim=512,
        output_dim=768,
        context_tokens=32,
        input_len=80,
        condition_dim=63,
        norm_type="rms_norm",
        qk_norm="rms_norm",
    ):
        super().__init__()
        # input_len and qk_norm are unused here; kept for signature parity with
        # the other audio embedders.
        self.input_dim = seq_len * blocks * channels + condition_dim
        self.context_tokens = context_tokens
        self.output_dim = output_dim

        self.proj1 = nn.Linear(self.input_dim, intermediate_dim)
        self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
        self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)

        self.norm = nn.LayerNorm(output_dim) if norm_type == "layer_norm" else RMSNorm(output_dim)

    def initialize_weights(self):
        self.apply(_basic_init)
        for proj in [self.proj1, self.proj2, self.proj3]:
            w = proj.weight.data
            nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
            nn.init.constant_(proj.bias, 0)

    def forward(self, audio_embeds, conditions, emo=None):
        """
        Defines the forward pass for the ConditionAudioEmbedder.

        Parameters:
            audio_embeds (torch.Tensor): The input audio embeddings with shape
                (batch_size, video_length, window_size, blocks, channels).
            conditions (torch.Tensor): conditions with shape
                (batch_size, video_length, condition_dim).
            emo (torch.Tensor): optional emotion embedding; unused here.
        Returns:
            context_tokens (torch.Tensor): The output context tokens with shape
                (batch_size, video_length, context_tokens, output_dim).
        """
        video_length = audio_embeds.shape[1]
        audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
        batch_size, window_size, blocks, channels = audio_embeds.shape
        audio_embeds = audio_embeds.reshape(batch_size, window_size * blocks * channels)

        # Append the per-frame condition vector to the flattened audio features.
        conditions = rearrange(conditions, "bz f c -> (bz f) c")
        audio_embeds = torch.cat([audio_embeds, conditions], dim=1)

        audio_embeds = torch.relu(self.proj1(audio_embeds))
        audio_embeds = torch.relu(self.proj2(audio_embeds))

        context_tokens = self.proj3(audio_embeds).reshape(
            batch_size, self.context_tokens, self.output_dim
        )

        context_tokens = self.norm(context_tokens)
        context_tokens = rearrange(
            context_tokens, "(bz f) m c -> bz f m c", f=video_length
        )

        return context_tokens
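

# Usage sketch (illustrative only; shapes are assumptions):
#
#   cond_proj = ConditionAudioEmbedder(condition_dim=63)
#   cond_proj.initialize_weights()
#   feats = torch.randn(2, 16, 5, 12, 768)    # (batch, frames, window, blocks, channels)
#   conds = torch.randn(2, 16, 63)            # per-frame condition vectors
#   tokens = cond_proj(feats, conds)          # -> (2, 16, 32, 768)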


class SimpleAudioEmbedder(ModelMixin):
    """Simplified Audio Projection Model

    Projects windowed audio embeddings into three sets of context tokens: a
    keypoint-conditioned set, an emotion-conditioned set, and a main set that
    is refined by a small stack of DiT blocks.

    Attributes:
        seq_len (int): The length of the audio window per frame.
        blocks (int): The number of feature blocks per audio window step.
        channels (int): The number of channels per audio feature block.
        intermediate_dim (int): The intermediate dimension of the model.
        context_tokens (int): The number of context tokens in the output.
        output_dim (int): The output dimension of the context tokens.
        condition_dim (int): The dimension of the per-frame condition vector.

    Methods:
        forward(audio_embeds, conditions, emo_embeds, mask=None, freqs_cis=None):
            Projects audio embeddings of shape
            (batch_size, video_length, window_size, blocks, channels) to
            context tokens of shape
            (batch_size, video_length, context_tokens, output_dim).
    """

    def __init__(
        self,
        seq_len=5,
        blocks=12,
        channels=768,
        intermediate_dim=512,
        output_dim=768,
        context_tokens=32,
        input_len=80,
        condition_dim=63,
        norm_type="rms_norm",
        qk_norm="rms_norm",
        n_blocks=4,
        n_heads=4,
        mlp_ratio=4,
    ):
        super().__init__()
        self.input_dim = seq_len * blocks * channels
        self.context_tokens = context_tokens
        self.output_dim = output_dim
        self.condition_dim = condition_dim

        # The input projection emits condition_dim features for the keypoint
        # branch plus two intermediate_dim chunks (main and emotion branches).
        self.input_layer = nn.Sequential(
            nn.Linear(self.input_dim, intermediate_dim, bias=True),
            nn.SiLU(),
            nn.Linear(intermediate_dim, condition_dim + 2 * intermediate_dim, bias=True),
        )

        self.condition2_layer = nn.Linear(condition_dim, condition_dim)
        self.emo_layer = nn.Linear(intermediate_dim, intermediate_dim)

        self.use_condition = True
        self.condition_layer = nn.Linear(condition_dim + intermediate_dim, intermediate_dim)

        # Fixed (non-trainable) sinusoidal positional embedding over input_len frames.
        self.pos_embed = nn.Parameter(torch.zeros(1, input_len, intermediate_dim), requires_grad=False)

        self.mid_blocks = nn.ModuleList([
            DiTBlock(
                intermediate_dim, n_heads,
                mlp_ratio=mlp_ratio,
                norm_type=norm_type,
                qk_norm=qk_norm,
            ) for _ in range(n_blocks)
        ])

        self.output_layer = nn.Linear(intermediate_dim, context_tokens * output_dim)
        self.output_layer2 = nn.Linear(condition_dim + condition_dim, context_tokens * output_dim)
        self.output_layer3 = nn.Linear(intermediate_dim + intermediate_dim, context_tokens * output_dim)
        self.norm = nn.LayerNorm(output_dim) if norm_type == "layer_norm" else RMSNorm(output_dim)
        self.norm2 = nn.LayerNorm(output_dim) if norm_type == "layer_norm" else RMSNorm(output_dim)
        self.norm3 = nn.LayerNorm(output_dim) if norm_type == "layer_norm" else RMSNorm(output_dim)

    def initialize_weights(self):
        for idx in [0, 2]:
            w = self.input_layer[idx].weight.data
            nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
            nn.init.constant_(self.input_layer[idx].bias, 0)

        w = self.emo_layer.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.emo_layer.bias, 0)

        pos_embed = get_1d_sincos_pos_embed(self.pos_embed.shape[-1], self.pos_embed.shape[-2])
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        nn.init.normal_(self.condition_layer.weight, std=0.02)
        nn.init.constant_(self.condition_layer.bias, 0)
        nn.init.normal_(self.condition2_layer.weight, std=0.02)
        nn.init.constant_(self.condition2_layer.bias, 0)

        # Initialize the DiT blocks as well (mirrors ConditionEmbedder below).
        for block in self.mid_blocks:
            block.initialize_weights()

        for proj in [self.output_layer, self.output_layer2, self.output_layer3]:
            w = proj.weight.data
            nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
            nn.init.constant_(proj.bias, 0)

    def forward(self, audio_embeds, conditions, emo_embeds, mask=None, freqs_cis=None):
        """
        Defines the forward pass for the SimpleAudioEmbedder.

        Parameters:
            audio_embeds (torch.Tensor): The input audio embeddings with shape
                (batch_size, video_length, window_size, blocks, channels).
            conditions (torch.Tensor): conditions with shape
                (batch_size, video_length, condition_dim).
            emo_embeds (torch.Tensor): emotion embedding with shape
                (batch_size, intermediate_dim).
            mask, freqs_cis: optional attention mask and rotary frequencies
                forwarded to the DiT blocks.
        Returns:
            Tuple of (kp_context, emo_context, audio_xs, audio_kp, audio_emo,
            conditions, emo_embeds). The three context tensors each have shape
            (batch_size, video_length, context_tokens, output_dim); the
            remaining tensors are frame-level intermediates returned for
            downstream use.
        """
        condition2 = self.condition2_layer(conditions)
        emo2 = self.emo_layer(emo_embeds)

        video_length = audio_embeds.shape[1]
        emo_embeds = emo_embeds.unsqueeze(1).repeat(1, video_length, 1)
        audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")

        batch_size, window_size, blocks, channels = audio_embeds.shape
        audio_embeds = audio_embeds.reshape(batch_size, window_size * blocks * channels)

        audio_embeds = self.input_layer(audio_embeds)
        audio_embeds = rearrange(audio_embeds, "(bz f) c -> bz f c", f=video_length)

        # Split the projection into keypoint, main, and emotion branches.
        audio_kp = audio_embeds[:, :, :self.condition_dim]
        audio_xs, audio_emo = audio_embeds[:, :, self.condition_dim:].chunk(2, dim=-1)

        audio_enc_kp = torch.cat([audio_kp, conditions], dim=-1)
        audio_enc_emo = torch.cat([audio_emo, emo_embeds], dim=-1)
        audio_enc_kp = rearrange(audio_enc_kp, "bz f c -> (bz f) c")
        audio_enc_emo = rearrange(audio_enc_emo, "bz f c -> (bz f) c")
        kp_context = self.output_layer2(audio_enc_kp).reshape(
            batch_size, self.context_tokens, self.output_dim
        )
        kp_context = self.norm2(kp_context)
        emo_context = self.output_layer3(audio_enc_emo).reshape(
            batch_size, self.context_tokens, self.output_dim
        )
        emo_context = self.norm3(emo_context)

        if self.use_condition:
            audio_xs = self.condition_layer(torch.cat([audio_xs, condition2], dim=-1))

        # Requires video_length == input_len so that pos_embed broadcasts.
        audio_xs = audio_xs + self.pos_embed

        for block in self.mid_blocks:
            audio_xs = block(audio_xs, emo2, mask=mask, freqs_cis=freqs_cis)

        audio_xs = rearrange(audio_xs, "bz f c -> (bz f) c")
        audio_xs = self.output_layer(audio_xs).reshape(
            batch_size, self.context_tokens, self.output_dim
        )
        audio_xs = self.norm(audio_xs)

        kp_context = rearrange(kp_context, "(bz f) m c -> bz f m c", f=video_length)
        emo_context = rearrange(emo_context, "(bz f) m c -> bz f m c", f=video_length)
        audio_xs = rearrange(audio_xs, "(bz f) m c -> bz f m c", f=video_length)

        return kp_context, emo_context, audio_xs, audio_kp, audio_emo, conditions, emo_embeds
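

# Usage sketch (illustrative only; all shapes below are assumptions, and the
# frame count must equal input_len for the positional embedding to apply):
#
#   simple_proj = SimpleAudioEmbedder(input_len=80, condition_dim=63, intermediate_dim=512)
#   simple_proj.initialize_weights()
#   feats = torch.randn(2, 80, 5, 12, 768)    # (batch, frames, window, blocks, channels)
#   conds = torch.randn(2, 80, 63)            # per-frame condition vectors
#   emo = torch.randn(2, 512)                 # emotion embedding, dim = intermediate_dim
#   kp_ctx, emo_ctx, audio_ctx, *aux = simple_proj(feats, conds, emo)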


class ConditionEmbedder(nn.Module):
    """
    Embeds a sequence of condition features into context tokens, refining them
    with a small stack of DiT blocks conditioned on an emotion embedding.
    """
    def __init__(
        self,
        input_dim=768,
        intermediate_dim=1024,
        output_dim=2048,
        input_len=80,
        norm_type="rms_norm",
        qk_norm="rms_norm",
        n_blocks=4,
        n_heads=4,
        mlp_ratio=4,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        self.input_layer = nn.Sequential(
            nn.Linear(self.input_dim, intermediate_dim, bias=True),
            nn.SiLU(),
            nn.Linear(intermediate_dim, intermediate_dim, bias=True),
        )

        # Fixed (non-trainable) sinusoidal positional embedding over input_len steps.
        self.pos_embed = nn.Parameter(torch.zeros(1, input_len, intermediate_dim), requires_grad=False)

        self.mid_blocks = nn.ModuleList([
            DiTBlock(
                intermediate_dim, n_heads,
                mlp_ratio=mlp_ratio,
                norm_type=norm_type,
                qk_norm=qk_norm,
            ) for _ in range(n_blocks)
        ])

        self.output_layer = nn.Linear(intermediate_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim) if norm_type == "layer_norm" else RMSNorm(output_dim)

    def initialize_weights(self):
        for idx in [0, 2]:
            w = self.input_layer[idx].weight.data
            nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
            nn.init.constant_(self.input_layer[idx].bias, 0)

        pos_embed = get_1d_sincos_pos_embed(self.pos_embed.shape[-1], self.pos_embed.shape[-2])
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        for block in self.mid_blocks:
            block.initialize_weights()

        w = self.output_layer.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.output_layer.bias, 0)

    def forward(self, cond_embeds, emo_embeds):
        cond_embeds = self.input_layer(cond_embeds)

        # Requires the sequence length to equal input_len so pos_embed broadcasts.
        cond_embeds = cond_embeds + self.pos_embed

        for block in self.mid_blocks:
            cond_embeds = block(cond_embeds, emo_embeds)

        context_tokens = self.output_layer(cond_embeds)
        context_tokens = self.norm(context_tokens)

        return context_tokens
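

# Usage sketch (illustrative only; dimensions are assumptions, including the
# emotion embedding matching intermediate_dim for the DiT-block conditioning):
#
#   cond_embedder = ConditionEmbedder(input_dim=768, intermediate_dim=1024,
#                                     output_dim=2048, input_len=80)
#   cond_embedder.initialize_weights()
#   cond = torch.randn(2, 80, 768)            # sequence length must equal input_len
#   emo = torch.randn(2, 1024)
#   tokens = cond_embedder(cond, emo)         # -> (2, 80, 2048)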


class VectorEmbedder(nn.Module):
    """Embeds a flat vector of dimension input_dim."""

    def __init__(self, input_dim: int, hidden_size: int, dtype=None, device=None):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_size, bias=True, dtype=dtype, device=device),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)


class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding"""

    def __init__(
        self,
        img_size: Optional[int] = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        flatten: bool = True,
        bias: bool = True,
        strict_img_size: bool = True,
        dynamic_img_pad: bool = False,
        dtype=None,
        device=None,
    ):
        super().__init__()
        self.patch_size = (patch_size, patch_size)
        if img_size is not None:
            self.img_size = (img_size, img_size)
            self.grid_size = tuple(
                [s // p for s, p in zip(self.img_size, self.patch_size)]
            )
            self.num_patches = self.grid_size[0] * self.grid_size[1]
        else:
            self.img_size = None
            self.grid_size = None
            self.num_patches = None

        self.flatten = flatten
        self.strict_img_size = strict_img_size
        self.dynamic_img_pad = dynamic_img_pad

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=bias,
            dtype=dtype,
            device=device,
        )

    def forward(self, x):
        B, C, H, W = x.shape
        if self.dynamic_img_pad:
            # Pad H and W up to the next multiple of the patch size
            # (timm-style PatchEmbed semantics).
            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
            x = nn.functional.pad(x, (0, pad_w, 0, pad_h))
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
        return x
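

# Usage sketch (illustrative only; image and patch sizes are assumptions):
#
#   patchify = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
#   imgs = torch.randn(2, 3, 224, 224)
#   patches = patchify(imgs)                  # -> (2, 196, 768); 196 = (224 // 16) ** 2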