| | """ |
| | This module provides the implementation of an Audio Projection Model, which is designed for |
| | audio processing tasks. The model takes audio embeddings as input and outputs context tokens |
| | that can be used for various downstream applications, such as audio analysis or synthesis. |
| | |
| | The AudioProjModel class is based on the ModelMixin class from the diffusers library, which |
| | provides a foundation for building custom models. This implementation includes multiple linear |
| | layers with ReLU activation functions and a LayerNorm for normalization. |
| | |
| | Key Features: |
| | - Audio embedding input with flexible sequence length and block structure. |
| | - Multiple linear layers for feature transformation. |
| | - ReLU activation for non-linear transformation. |
| | - LayerNorm for stabilizing and speeding up training. |
| | - Rearrangement of input embeddings to match the model's expected input shape. |
| | - Customizable number of blocks, channels, and context tokens for adaptability. |
| | |
| | The module is structured to be easily integrated into larger systems or used as a standalone |
| | component for audio feature extraction and processing. |
| | |
| | Classes: |
| | - AudioProjModel: A class representing the audio projection model with configurable parameters. |
| | |
| | Functions: |
| | - (none) |
| | |
| | Dependencies: |
| | - torch: For tensor operations and neural network components. |
| | - diffusers: For the ModelMixin base class. |
| | - einops: For tensor rearrangement operations. |
| | |
| | """ |

import math

import torch
import torch.nn as nn
import torch.nn.init as init
from diffusers import ModelMixin
from einops import rearrange


class AudioProjNet2(ModelMixin):
    """Audio Projection Model.

    This class defines an audio projection model that takes audio embeddings as
    input and produces context tokens as output. The model is based on the
    ModelMixin class and consists of multiple linear layers with ReLU
    activations. It can be used for various audio processing tasks.

    Attributes:
        seq_len (int): The length of the audio window, in frames.
        blocks (int): The number of blocks in the audio projection model.
        channels (int): The number of channels in the audio projection model.
        intermediate_dim (int): The intermediate dimension of the model.
        context_tokens (int): The number of context tokens in the output.
        output_dim (int): The feature dimension of the context tokens.

    Methods:
        __init__(self, seq_len=5, blocks=12, channels=768, intermediate_dim=512, output_dim=768, context_tokens=4):
            Initializes the AudioProjNet2 with the given parameters.
        forward(self, audio_embeds):
            Defines the forward pass for the AudioProjNet2.
            Parameters:
                audio_embeds (torch.Tensor): The input audio embeddings with shape
                    (batch_size, video_length, window_size, blocks, channels).
            Returns:
                context_tokens (torch.Tensor): The output context tokens with shape
                    (batch_size, video_length, context_tokens, output_dim).

    """

    def __init__(
        self,
        seq_len=5,
        blocks=12,
        channels=768,
        intermediate_dim=512,
        output_dim=768,
        context_tokens=4,
    ):
        super().__init__()

        self.seq_len = seq_len
        self.blocks = blocks
        self.channels = channels
        # Flattened input size: with the defaults this is 5 * 12 * 768 = 46080.
        self.input_dim = seq_len * blocks * channels
        self.intermediate_dim = intermediate_dim
        self.context_tokens = context_tokens
        self.output_dim = output_dim

        self.proj1 = nn.Linear(self.input_dim, intermediate_dim)
        self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
        self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)

        self.norm = nn.LayerNorm(output_dim)

    def forward(self, audio_embeds):
        # Fold the frame axis into the batch axis so every frame's window of
        # audio features is projected independently.
        video_length = audio_embeds.shape[1]
        audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
        batch_size, window_size, blocks, channels = audio_embeds.shape
        audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)

        audio_embeds = torch.relu(self.proj1(audio_embeds))
        audio_embeds = torch.relu(self.proj2(audio_embeds))

        context_tokens = self.proj3(audio_embeds).reshape(
            batch_size, self.context_tokens, self.output_dim
        )
        context_tokens = self.norm(context_tokens)
        # Restore the (batch, frame) split that was folded above.
        out_all = rearrange(
            context_tokens, "(bz f) m c -> bz f m c", f=video_length
        )

        return out_all


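# A minimal usage sketch for AudioProjNet2 (assumption: the shapes simply
# follow the default hyperparameters above; the random tensor stands in for
# real audio embeddings):
#
#     model = AudioProjNet2()
#     audio = torch.randn(2, 8, 5, 12, 768)  # (batch, frames, window, blocks, channels)
#     tokens = model(audio)                  # -> torch.Size([2, 8, 4, 768])

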
def reshape_tensor(x, heads):
    """Split the last dimension of a (bs, length, width) tensor into `heads`
    attention heads, returning shape (bs, heads, length, width // heads)."""
    bs, length, width = x.shape
    # (bs, length, width) -> (bs, length, heads, width // heads)
    x = x.view(bs, length, heads, -1)
    # Move the head axis before the sequence axis.
    x = x.transpose(1, 2)
    # (bs, heads, length, width // heads), materialized contiguously.
    x = x.reshape(bs, heads, length, -1)
    return x
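
# A quick shape sketch for reshape_tensor (sizes are hypothetical, for
# illustration only):
#
#     x = torch.randn(2, 10, 64)        # (bs, length, width)
#     reshape_tensor(x, heads=8).shape  # -> torch.Size([2, 8, 10, 8])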


class PerceiverAttentionCA(nn.Module):
    """Perceiver-style cross-attention from latent features to image features.

    Note: ``heads`` is stored but the projections use ``inner_dim = dim_head``,
    so attention is effectively computed with a single head.
    """

    def __init__(self, *, dim=3072, dim_head=1024, heads=33):
        super().__init__()
        self.scale = dim_head ** -0.5  # kept for reference; forward uses the split-scale variant below
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

        # Zero-initialize the output projection so the block contributes
        # nothing at the start of training.
        init.zeros_(self.to_out.weight)
        if self.to_out.bias is not None:
            init.zeros_(self.to_out.bias)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features,
                shape (b, t, aa, D).
            latents (torch.Tensor): latent features,
                shape (b, t, hw, D).
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        q = self.to_q(latents)
        k, v = self.to_kv(x).chunk(2, dim=-1)

        # Scale q and k by dim_head ** -0.25 each, which is equivalent to the
        # usual dim_head ** -0.5 on the attention logits but more numerically
        # stable; softmax is computed in float32 for the same reason.
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(-2, -1)
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        return self.to_out(out)
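

# A minimal smoke test (assumption: shapes are illustrative; random inputs
# stand in for real image/latent features).
if __name__ == "__main__":
    attn = PerceiverAttentionCA()
    img = torch.randn(1, 2, 7, 3072)  # (b, t, aa, D) image features
    lat = torch.randn(1, 2, 5, 3072)  # (b, t, hw, D) latent features
    out = attn(img, lat)
    # to_out is zero-initialized, so this prints the shape of an all-zeros tensor.
    print(out.shape)  # torch.Size([1, 2, 5, 3072])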