# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.,
        proj_drop=0.,
        use_sdpa=True
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop_prob = proj_drop
        self.proj_drop = nn.Dropout(proj_drop)
        self.use_sdpa = use_sdpa

    def forward(self, x, mask=None):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # [B, num_heads, N, D]

        if self.use_sdpa:
            with torch.backends.cuda.sdp_kernel():
                x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.proj_drop_prob)
                attn = None
        else:
            attn = (q @ k.transpose(-2, -1)) * self.scale  # [B, num_heads, N, N]
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (attn @ v)

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn


class Block(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.,
        qkv_bias=False,
        qk_scale=None,
        drop=0.,
        attn_drop=0.,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        grid_size=None,
        grid_depth=None,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop)

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop)

    def forward(self, x, return_attention=False, mask=None):
        y, attn = self.attn(self.norm1(x), mask=mask)
        if return_attention:
            return attn
        x = x + y
        x = x + self.mlp(self.norm2(x))
        return x
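
# Usage sketch (illustrative, not part of the original module): a pre-norm
# transformer Block maps a token sequence of shape [B, N, C] to the same shape.
# The dimensions below are arbitrary example values, not a specific model config.
#
#   blk = Block(dim=768, num_heads=12)
#   tokens = torch.randn(2, 196, 768)   # [batch, num_tokens, dim]
#   out = blk(tokens)                   # -> shape [2, 196, 768]
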
class CrossAttention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=12,
        qkv_bias=False,
        use_sdpa=True
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, int(dim * 2), bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.use_sdpa = use_sdpa

    def forward(self, q, x):
        B, n, C = q.shape
        q = self.q(q).reshape(B, n, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        # B: batch size, N: number of (spatio-temporal) context tokens, C: embedding dimension (from the encoder)
        B, N, C = x.shape
        kv = self.kv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]  # (batch_size, num_heads, seq_len, feature_dim_per_head)

        if self.use_sdpa:
            with torch.backends.cuda.sdp_kernel():
                q = F.scaled_dot_product_attention(q, k, v)
        else:
            xattn = (q @ k.transpose(-2, -1)) * self.scale
            xattn = xattn.softmax(dim=-1)  # (batch_size, num_heads, query_len, seq_len)
            q = (xattn @ v)

        q = q.transpose(1, 2).reshape(B, n, C)
        q = self.proj(q)
        return q


class CrossAttentionBlock(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.,
        qkv_bias=False,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.xattn = CrossAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)

    def forward(self, q, x):
        y = self.xattn(q, self.norm1(x))
        q = q + y
        q = q + self.mlp(self.norm2(q))
        return q
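
# Minimal smoke test (illustrative sketch, not part of the original module).
# It checks that a CrossAttentionBlock lets a small set of query tokens attend
# over a longer context sequence while preserving the query shape [B, n, C].
# The sizes below are arbitrary example values.
if __name__ == "__main__":
    dim, num_heads = 64, 8
    block = CrossAttentionBlock(dim=dim, num_heads=num_heads)
    queries = torch.randn(2, 4, dim)   # [B, n, C] query tokens
    context = torch.randn(2, 16, dim)  # [B, N, C] context tokens
    out = block(queries, context)
    assert out.shape == (2, 4, dim)
    print('CrossAttentionBlock output shape:', tuple(out.shape))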