Spaces:
Running on Zero
Running on Zero
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import math | |
| from typing import Tuple, Optional | |
| from einops import rearrange | |
| from diffusers import ModelMixin | |
| from diffusers.configuration_utils import ConfigMixin, register_to_config | |
| import torch.cuda.amp as amp | |
| import torch.distributed as dist | |
| from xfuser.core.distributed import ( | |
| get_sequence_parallel_rank, | |
| get_sequence_parallel_world_size, | |
| get_sp_group, | |
| ) | |
| from xfuser.core.long_ctx_attention import xFuserLongContextAttention | |
| try: | |
| import flash_attn_interface | |
| FLASH_ATTN_3_AVAILABLE = True | |
| except ModuleNotFoundError: | |
| FLASH_ATTN_3_AVAILABLE = False | |
| try: | |
| import flash_attn | |
| FLASH_ATTN_2_AVAILABLE = True | |
| except ModuleNotFoundError: | |
| FLASH_ATTN_2_AVAILABLE = False | |
| try: | |
| from sageattention import sageattn | |
| SAGE_ATTN_AVAILABLE = True | |
| except ModuleNotFoundError: | |
| SAGE_ATTN_AVAILABLE = False | |
def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False):
    """Multi-head attention over packed ``(batch, seq, num_heads * head_dim)`` tensors.

    Backend priority, unless ``compatibility_mode`` forces PyTorch SDPA:
    SageAttention, FlashAttention-3, FlashAttention-2, then
    ``F.scaled_dot_product_attention``. No mask or causal flag is passed, so
    every backend uses its default (non-causal) behaviour and scaling.

    Args:
        q: queries, last dim packed as ``num_heads * head_dim``.
        k: keys, last dim packed as ``num_heads * head_dim``.
        v: values, last dim packed as ``num_heads * head_dim``.
        num_heads: number of heads the last dim is split into.
        compatibility_mode: if True, always use PyTorch SDPA.

    Returns:
        Attention output with the same packed layout as ``q``.
    """
    if compatibility_mode:
        backend = "sdpa"
    elif SAGE_ATTN_AVAILABLE:
        backend = "sage"
    elif FLASH_ATTN_3_AVAILABLE:
        backend = "fa3"
    elif FLASH_ATTN_2_AVAILABLE:
        backend = "fa2"
    else:
        backend = "sdpa"

    # (b, s, n * d) -> (b, s, n, d); splitting one dim is always a view.
    q, k, v = (t.unflatten(-1, (num_heads, -1)) for t in (q, k, v))

    if backend in ("fa3", "fa2"):
        # flash-attn kernels consume the (b, s, n, d) layout directly.
        attn_fn = flash_attn_interface.flash_attn_func if backend == "fa3" else flash_attn.flash_attn_func
        out = attn_fn(q, k, v)
        # Some FlashAttention-3 releases return (out, softmax_lse).
        if isinstance(out, tuple):
            out = out[0]
        return out.flatten(2)

    # SageAttention (default HND layout) and SDPA expect (b, n, s, d).
    attn_fn = sageattn if backend == "sage" else F.scaled_dot_product_attention
    out = attn_fn(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
    return out.transpose(1, 2).flatten(2)
def sinusoidal_embedding_1d(dim, position):
    """Sinusoidal embedding of scalar positions (e.g. diffusion timesteps).

    Args:
        dim: embedding width; the output has ``2 * (dim // 2)`` columns
            laid out as ``[cos(...), sin(...)]``.
        position: 1-D tensor of positions.

    Returns:
        Tensor of shape ``(len(position), 2 * (dim // 2))``. The table is
        computed in float64 and cast to ``position.dtype`` when that is a
        floating dtype, otherwise to the default floating dtype.
    """
    half = dim // 2
    inv_freq = torch.pow(
        10000, -torch.arange(half, dtype=torch.float64, device=position.device).div(half))
    sinusoid = torch.outer(position.type(torch.float64), inv_freq)
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
    # Casting back to an integer dtype would truncate every entry to -1/0/1,
    # so only floating-point positions keep their own dtype.
    if position.is_floating_point():
        return x.to(position.dtype)
    return x.to(torch.get_default_dtype())
def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
    """Build a 3-D RoPE table by concatenating time, height and width tables.

    Height and width each get ``dim // 3`` channels; time gets the rest
    (``dim - 2 * (dim // 3)``). Each part comes from ``precompute_freqs_cis``
    with the same ``end`` and ``theta``, concatenated along dim 1.
    """
    spatial_dim = dim // 3
    temporal_dim = dim - 2 * spatial_dim
    tables = [
        precompute_freqs_cis(part_dim, end, theta)
        for part_dim in (temporal_dim, spatial_dim, spatial_dim)
    ]
    return torch.cat(tables, dim=1)
def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
    """Build a 1-D rotary-embedding table of unit complex numbers.

    Entry ``[p, j]`` is ``exp(i * p / theta ** (2j / dim))`` for positions
    ``p in [0, end)`` and ``j in [0, dim // 2)``. The angles are computed in
    float64, so the returned tensor is complex128 with shape
    ``(end, dim // 2)``.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].double() / dim
    inv_freq = 1.0 / theta ** exponents
    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq)
    return torch.polar(torch.ones_like(angles), angles)
def pad_freqs(original_tensor, target_len):
    """Append rows of ones along dim 0 until the 3-D tensor has ``target_len`` rows.

    The filler matches the input's dtype and device. For a complex RoPE table,
    a value of 1 is the identity rotation.
    """
    rows, dim1, dim2 = original_tensor.shape
    filler = original_tensor.new_ones((target_len - rows, dim1, dim2))
    return torch.cat((original_tensor, filler), dim=0)
def rope_apply(x, freqs, grid_sizes, use_usp=False, sp_size=1, sp_rank=0):
    """Apply 3-D (time, height, width) rotary position embeddings to ``x``.

    Args:
        x: tensor indexed as [B, L, N, C] (batch, tokens, heads, head_dim).
            Under sequence parallelism ``L`` is this rank's local shard length.
        freqs: complex tensor [M, C // 2] whose last dim is split into
            time / height / width parts (e.g. from ``precompute_freqs_cis_3d``).
        grid_sizes: a single ``(f, h, w)`` triple describing the token grid,
            shared by every sample in the batch.
        use_usp: whether ``x`` is a sequence-parallel shard.
        sp_size: sequence-parallel world size (used only with ``use_usp``).
        sp_rank: this rank's index in the sequence-parallel group.

    Returns:
        Tensor with the same shape and dtype as ``x``. Without ``use_usp``,
        tokens beyond ``f * h * w`` are passed through unrotated.
    """
    b, s, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2

    # Split the per-position angles into time / height / width sub-bands.
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)

    f, h, w = grid_sizes
    seq_len = f * h * w

    # Broadcast each axis' angles over the (f, h, w) grid -> [seq_len, 1, C // 2].
    freqs_i = torch.cat([
        freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
        freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
        freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
    ], dim=-1).reshape(seq_len, 1, -1)

    if use_usp:
        # Rotate local tokens with the angles of their global positions;
        # assumes every rank holds exactly ``s`` tokens of the padded sequence.
        freqs_i = pad_freqs(freqs_i, s * sp_size)
        freqs_i = freqs_i[(sp_rank * s):((sp_rank + 1) * s)]
        n_rot = s
    else:
        n_rot = seq_len

    rotated = []
    for i in range(b):
        # Pair adjacent channels as complex numbers and rotate them in float64.
        x_i = torch.view_as_complex(
            x[i, :n_rot].to(torch.float64).reshape(n_rot, n, -1, 2))
        x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
        rotated.append(torch.cat([x_i, x[i, n_rot:]]))
    return torch.stack(rotated).to(x.dtype)
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dim with a learnable gain."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def norm(self, x):
        """Divide ``x`` by the RMS of its last dimension (``eps`` added for stability)."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        # Normalise in float32, cast back to the input dtype, then apply the gain.
        input_dtype = x.dtype
        normalized = self.norm(x.float())
        return normalized.to(input_dtype) * self.weight
class SelfAttention(nn.Module):
    """Multi-head self-attention with RMS-normed Q/K and 3-D rotary embeddings.

    If ``torch.distributed`` is initialised at construction time, attention
    runs through xDiT's ``xFuserLongContextAttention`` on this rank's
    sequence shard. Otherwise the local ``flash_attention`` helper is used.
    """

    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)
        self.use_usp = dist.is_initialized()
        if self.use_usp:
            self.sp_size = get_sequence_parallel_world_size()
            self.sp_rank = get_sequence_parallel_rank()
        else:
            self.sp_size = 1
            self.sp_rank = 0

    def forward(self, x, freqs, grid_sizes):
        """Attend over ``x`` (batch, tokens, dim) and return the output projection."""
        batch, seq = x.shape[:2]
        heads, head_dim = self.num_heads, self.head_dim
        query = self.norm_q(self.q(x)).view(batch, seq, heads, head_dim)
        key = self.norm_k(self.k(x)).view(batch, seq, heads, head_dim)
        value = self.v(x)

        if not self.use_usp:
            attended = flash_attention(
                q=rope_apply(query, freqs, grid_sizes).flatten(2),
                k=rope_apply(key, freqs, grid_sizes).flatten(2),
                v=value,
                num_heads=heads,
            )
            return self.o(attended)

        from yunchang.kernels import AttnType

        attn_type = AttnType.SAGE_AUTO if SAGE_ATTN_AVAILABLE else AttnType.FA
        sp_args = (self.use_usp, self.sp_size, self.sp_rank)
        attended = xFuserLongContextAttention(attn_type=attn_type)(
            None,
            query=rope_apply(query, freqs, grid_sizes, *sp_args),
            key=rope_apply(key, freqs, grid_sizes, *sp_args),
            value=value.view(batch, seq, heads, head_dim),
        ).flatten(2)
        return self.o(attended)
class CrossAttention(nn.Module):
    """Multi-head cross-attention from query tokens ``x`` to a context ``y``.

    With ``has_image_input`` the first ``num_image_tokens`` tokens of ``y``
    go through separate image key/value projections. Their attention output
    is added to the attention over the remaining context tokens before the
    output projection.
    """

    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False,
                 num_image_tokens: int = 257):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)
        self.has_image_input = has_image_input
        # Length of the image-token prefix of ``y`` (previously hard-coded to 257).
        self.num_image_tokens = num_image_tokens
        if has_image_input:
            self.k_img = nn.Linear(dim, dim)
            self.v_img = nn.Linear(dim, dim)
            self.norm_k_img = RMSNorm(dim, eps=eps)

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """Cross-attend ``x`` to ``y``.

        Args:
            x: query tokens, packed as (batch, s_q, dim).
            y: context tokens (batch, s_ctx, dim). With image input, the first
                ``num_image_tokens`` tokens are treated as image tokens.

        Returns:
            Tensor of shape (batch, s_q, dim).
        """
        if self.has_image_input:
            img = y[:, :self.num_image_tokens]
            ctx = y[:, self.num_image_tokens:]
        else:
            ctx = y
        q = self.norm_q(self.q(x))
        k = self.norm_k(self.k(ctx))
        v = self.v(ctx)
        out = flash_attention(q, k, v, num_heads=self.num_heads)
        if self.has_image_input:
            k_img = self.norm_k_img(self.k_img(img))
            v_img = self.v_img(img)
            out = out + flash_attention(q, k_img, v_img, num_heads=self.num_heads)
        return self.o(out)
class DiTAudioBlock(nn.Module):
    """Transformer block with three sub-layers.

    1. Modulated self-attention.
    2. Per-frame cross-attention to audio tokens.
    3. Modulated feed-forward.

    The six modulation vectors are ``self.modulation + t_mod`` split along
    dim 1: (shift, scale, gate) for self-attention, then for the FFN.
    """

    def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6, i=0, num_layers=0):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim
        # Block index and total depth; stored but not used inside this block.
        self.i = i
        self.num_layers = num_layers
        self.self_attn = SelfAttention(dim, num_heads, eps)
        self.cross_attn = CrossAttention(
            dim, num_heads, eps, has_image_input=has_image_input)
        self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm3 = nn.LayerNorm(dim, eps=eps)
        self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(
            approximate='tanh'), nn.Linear(ffn_dim, dim))
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
        self.use_usp = dist.is_initialized()
        self.sp_size = get_sequence_parallel_world_size() if self.use_usp else 1
        self.sp_rank = get_sequence_parallel_rank() if self.use_usp else 0

    def forward(self, x, context, t_mod, freqs, grid_sizes):
        """Run the block.

        Args:
            x: tokens (batch, frames * tokens_per_frame, dim); under sequence
                parallelism this is the local shard.
            context: audio tokens indexed as (batch, frames, ctx_tokens, dim).
                Frame ``k`` of ``context`` is attended by the k-th group of
                tokens in ``x``.
            t_mod: time modulation, added to ``self.modulation`` (1, 6, dim)
                and chunked into 6 along dim 1.
                NOTE(review): assumes t_mod has one row per batch sample so the
                modulation broadcasts over tokens — confirm for per-frame timesteps.
            freqs, grid_sizes: forwarded to self-attention's rotary embedding.
        """
        shift_sa, scale_sa, gate_sa, shift_ffn, scale_ffn, gate_ffn = (
            self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod
        ).chunk(6, dim=1)

        y = self.self_attn(
            self.norm1(x) * (1 + scale_sa) + shift_sa, freqs, grid_sizes)
        x = x + y * gate_sa

        batch, num_frames = x.shape[0], context.shape[1]
        # Group tokens per frame so each frame only attends to its own audio tokens.
        x_frames = rearrange(self.norm3(x), 'b (f l) c -> (b f) l c', f=num_frames)
        frame_ctx = context
        if self.use_usp:
            # Each rank holds 1/sp_size of every frame's tokens in order, so
            # repeat every frame sp_size times and keep this rank's slice.
            frame_ctx = frame_ctx.repeat_interleave(self.sp_size, dim=1)
            frame_ctx = torch.chunk(frame_ctx, self.sp_size, dim=1)[self.sp_rank]
        frame_ctx = frame_ctx.flatten(0, 1)
        cross = self.cross_attn(x_frames, frame_ctx)
        x = x + rearrange(cross, '(b f) l c -> b (f l) c', b=batch)

        y = self.ffn(self.norm2(x) * (1 + scale_ffn) + shift_ffn)
        x = x + y * gate_ffn
        return x
class MLP(torch.nn.Module):
    """Projection: LayerNorm -> Linear -> GELU -> Linear -> LayerNorm."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        layers = [
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, in_dim),
            nn.GELU(),
            nn.Linear(in_dim, out_dim),
            nn.LayerNorm(out_dim),
        ]
        # A single Sequential keeps checkpoint keys as proj.0 ... proj.4.
        self.proj = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.proj(x)
class Head(nn.Module):
    """Output head: per-frame shift/scale modulation of LayerNorm, then a linear map.

    The linear layer maps ``dim`` to ``out_dim * prod(patch_size)`` features
    per token.
    """

    def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float):
        super().__init__()
        self.dim = dim
        self.patch_size = patch_size
        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

    def forward(self, x, t_mod):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C], where L1 is a multiple of the frame count
            t_mod(Tensor): Shape [B*F, C], one modulation row per (sample, frame)
        """
        batch, _, _ = x.shape
        frames = t_mod.shape[0] // batch
        base = self.modulation.to(dtype=t_mod.dtype, device=t_mod.device).unsqueeze(1)   # (1, 1, 2, C)
        per_frame = t_mod.unflatten(dim=0, sizes=(batch, frames)).unsqueeze(2)          # (B, F, 1, C)
        shift, scale = (base + per_frame).chunk(2, dim=2)                               # each (B, F, 1, C)
        framed = x.unflatten(1, (frames, -1))                                           # (B, F, L, C)
        projected = self.head(self.norm(framed) * (1 + scale) + shift)
        return projected.flatten(1, 2)
class WanModelAudioProject(ModelMixin, ConfigMixin):
    """Wan-style video diffusion transformer conditioned on audio features.

    Processing steps:
    1. Video latents, optionally concatenated channel-wise with ``y``, are
       patchified by a 3-D convolution.
    2. The tokens pass through ``num_layers`` DiTAudioBlocks that
       cross-attend to projected audio tokens.
    3. ``Head`` and ``unpatchify`` map the tokens back to latent space.

    When torch.distributed is initialised, the token sequence is chunked
    across the sequence-parallel group and all-gathered after the head.
    """

    _no_split_modules = ['DiTAudioBlock']

    def __init__(
        self,
        dim: int,
        in_dim: int,
        ffn_dim: int,
        out_dim: int,
        text_dim: int,
        freq_dim: int,
        eps: float,
        vae_stride: Tuple[int, int, int],
        patch_size: Tuple[int, int, int],
        num_heads: int,
        num_layers: int,
        has_image_input: bool,
        **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.freq_dim = freq_dim
        self.has_image_input = has_image_input
        self.patch_size = patch_size
        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim),
            nn.GELU(approximate='tanh'),
            nn.Linear(dim, dim)
        )
        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim)
        )
        self.time_projection = nn.Sequential(
            nn.SiLU(), nn.Linear(dim, dim * 6))
        self.blocks = nn.ModuleList([
            DiTAudioBlock(has_image_input, dim, num_heads, ffn_dim, eps, i, num_layers)
            for i in range(num_layers)
        ])
        self.head = Head(dim, out_dim, patch_size, eps)
        head_dim = dim // num_heads
        self.freqs = precompute_freqs_cis_3d(head_dim)
        self.audio_emb = MLP(768, dim)
        if has_image_input:
            self.img_emb = MLP(1280, dim)

        # Audio adapter.
        audio_window = 5
        vae_scale = vae_stride[0]  # video frames folded into one latent frame
        intermediate_dim = 512
        # Adapter tokens feed CrossAttention's k/v projections, which are
        # nn.Linear(dim, dim), so they must be ``dim`` wide (was hard-coded 1536).
        output_dim = dim
        context_tokens = 32
        norm_output_audio = True
        self.audio_window = audio_window
        self.vae_scale = vae_scale
        self.audio_proj = AudioProjModel(
            seq_len=audio_window,
            seq_len_vf=audio_window + vae_scale - 1,
            intermediate_dim=intermediate_dim,
            output_dim=output_dim,
            context_tokens=context_tokens,
            norm_output_audio=norm_output_audio,
        )
        self.use_usp = dist.is_initialized()
        self.sp_size = get_sequence_parallel_world_size() if self.use_usp else 1
        self.sp_rank = get_sequence_parallel_rank() if self.use_usp else 0

    def patchify(self, x: torch.Tensor):
        """Embed ``x`` (b, c, f, h, w) into tokens; return (tokens, (f, h, w) grid)."""
        x = self.patch_embedding(x)
        grid_size = x.shape[2:]
        x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
        return x, grid_size

    def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
        """Inverse of ``patchify``: fold per-token patches back into a (b, c, F, H, W) volume."""
        return rearrange(
            x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
            f=grid_size[0], h=grid_size[1], w=grid_size[2],
            x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]
        )

    def _prepare_audio_context(self, context, device, dtype):
        """Project per-video-frame audio windows into cross-attention tokens.

        ``context`` is indexed as (batch, video_frames, window, blocks, channels).

        - Frame 0 keeps its full audio window.
        - The remaining frames are grouped in chunks of ``self.vae_scale``
          (one chunk per latent frame). Within each chunk, the first frame
          keeps ``window[:mid + 1]``, interior frames keep ``window[mid]`` and
          the last frame keeps ``window[mid:]``. The selections are
          concatenated along the window axis, giving
          ``audio_window + vae_scale - 1`` entries.

        Returns:
            ``self.audio_proj`` output cast to ``dtype``.
        """
        audio = context.to(device=device, dtype=dtype)
        first_frame = audio[:, :1]
        groups = rearrange(
            audio[:, 1:],
            "b (n_latent n_frame) w s c -> b n_latent n_frame w s c",
            n_frame=self.vae_scale,
        )
        mid = self.audio_window // 2
        selections = (
            groups[:, :, :1, :mid + 1],        # first frame of each group
            groups[:, :, 1:-1, mid:mid + 1],   # interior frames: window centre only
            groups[:, :, -1:, mid:],           # last frame of each group
        )
        latter_frames = torch.cat([
            rearrange(sel, "b n_latent n_f w s c -> b n_latent (n_f w) s c")
            for sel in selections
        ], dim=2)
        return self.audio_proj(first_frame, latter_frames).to(dtype)

    def forward(self,
                x: torch.Tensor,
                timestep: torch.Tensor,
                context: torch.Tensor,
                y: Optional[torch.Tensor] = None,
                use_gradient_checkpointing: bool = False,
                use_gradient_checkpointing_offload: bool = False,
                **kwargs,
                ):
        """Denoise one step.

        Example shapes noted by the original author: x (1, 16, 9, 64, 64),
        y (1, 16, 9, 64, 64), timestep (9,), context (bsz, 81, 5, 12, 768).

        Args:
            x: noisy video latents (b, c, f, h, w).
            timestep: diffusion timesteps, embedded with ``sinusoidal_embedding_1d``.
            context: audio features (batch, video_frames, window, blocks, channels).
            y: optional conditioning latents concatenated to ``x`` along channels.
            use_gradient_checkpointing: accepted for interface compatibility;
                not used in this implementation.
            use_gradient_checkpointing_offload: accepted for interface
                compatibility; not used in this implementation.

        Returns:
            Predicted latents in the same (b, c, F, H, W) layout as ``x``.
        """
        if self.freqs.device != x.device:
            self.freqs = self.freqs.to(x.device)
        if y is not None:
            x = torch.cat([x, y], dim=1)
        x, grid_sizes = self.patchify(x)
        t = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, timestep.to(dtype=x.dtype)))
        t_mod = self.time_projection(t).unflatten(1, (6, self.dim))

        context = self._prepare_audio_context(context, x.device, x.dtype)

        if self.use_usp:
            x = torch.chunk(x, self.sp_size, dim=1)[self.sp_rank]
        for block in self.blocks:
            x = block(x, context, t_mod, self.freqs, grid_sizes)
        x = self.head(x, t)
        if self.use_usp:
            x = get_sp_group().all_gather(x, dim=1)
        return self.unpatchify(x, grid_sizes)
class AudioProjModel(ModelMixin, ConfigMixin):
    """Project windows of multi-layer audio features into per-frame context tokens.

    Each frame's (window, blocks, channels) features are flattened and passed
    through ``proj1`` (first frame) or ``proj1_vf`` (later frames), then
    ``proj2`` and ``proj3``. Every frame yields ``context_tokens`` tokens of
    width ``output_dim``.

    NOTE(review): inherits ConfigMixin but ``__init__`` is not decorated with
    ``register_to_config`` — confirm whether config serialisation is expected.
    """

    def __init__(
        self,
        seq_len=5,
        seq_len_vf=12,
        blocks=12,
        channels=768,
        intermediate_dim=512,
        output_dim=768,
        context_tokens=32,
        norm_output_audio=False,
    ):
        super().__init__()
        self.seq_len = seq_len
        self.blocks = blocks
        self.channels = channels
        self.input_dim = seq_len * blocks * channels
        self.input_dim_vf = seq_len_vf * blocks * channels
        self.intermediate_dim = intermediate_dim
        self.context_tokens = context_tokens
        self.output_dim = output_dim
        self.proj1 = nn.Linear(self.input_dim, intermediate_dim)
        self.proj1_vf = nn.Linear(self.input_dim_vf, intermediate_dim)
        self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
        self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)
        self.norm = nn.LayerNorm(output_dim) if norm_output_audio else nn.Identity()

    def forward(self, audio_embeds, audio_embeds_vf):
        """Project first-frame and later-frame audio windows.

        Args:
            audio_embeds: (batch, frames, seq_len, blocks, channels).
            audio_embeds_vf: (batch, frames_vf, seq_len_vf, blocks, channels).

        Returns:
            Tensor (batch, frames + frames_vf, context_tokens, output_dim).
        """
        batch = audio_embeds.shape[0]
        video_length = audio_embeds.shape[1] + audio_embeds_vf.shape[1]

        # Flatten each frame's window into one vector; rearrange reshapes, so
        # non-contiguous inputs are fine (the old .view() was not).
        first = rearrange(audio_embeds, "bz f w b c -> (bz f) (w b c)")
        latter = rearrange(audio_embeds_vf, "bz f w b c -> (bz f) (w b c)")

        first = torch.relu(self.proj1(first))
        latter = torch.relu(self.proj1_vf(latter))
        tokens = torch.concat([
            rearrange(first, "(bz f) c -> bz f c", bz=batch),
            rearrange(latter, "(bz f) c -> bz f c", bz=batch),
        ], dim=1)

        tokens = torch.relu(self.proj2(tokens.flatten(0, 1)))
        tokens = self.proj3(tokens).reshape(batch * video_length, self.context_tokens, self.output_dim)

        # Same behaviour as the deprecated torch.cuda.amp.autocast(dtype=torch.float32).
        with torch.autocast(device_type="cuda", dtype=torch.float32):
            tokens = self.norm(tokens)
        return rearrange(tokens, "(bz f) m c -> bz f m c", f=video_length)