from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .config import use_fused_attn
from .mlp import Mlp
from .weight_init import trunc_normal_tf_
|
class AttentionPoolLatent(nn.Module):
    """ Attention pooling w/ latent query.

    A small set of learned latent tokens cross-attends to the input token
    sequence and is followed by an MLP block, producing a pooled feature.
    """
    fused_attn: torch.jit.Final[bool]
|
    def __init__(
            self,
            in_features: int,
            out_features: Optional[int] = None,
            embed_dim: Optional[int] = None,
            num_heads: int = 8,
            feat_size: Optional[int] = None,
            mlp_ratio: float = 4.0,
            qkv_bias: bool = True,
            qk_norm: bool = False,
            latent_len: int = 1,
            latent_dim: Optional[int] = None,
            pos_embed: str = '',
            pool_type: str = 'token',
            norm_layer: Optional[nn.Module] = None,
            drop: float = 0.0,
    ):
        super().__init__()
        embed_dim = embed_dim or in_features
        out_features = out_features or in_features
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.feat_size = feat_size
        self.scale = self.head_dim ** -0.5
        self.pool = pool_type
        self.fused_attn = use_fused_attn()
|
        if pos_embed == 'abs':
            assert feat_size is not None
            self.pos_embed = nn.Parameter(torch.zeros(feat_size, in_features))
        else:
            self.pos_embed = None
|
        self.latent_dim = latent_dim or embed_dim
        self.latent_len = latent_len
        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim))
|
        self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
        self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(drop)
|
        self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity()
        self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))

        self.init_weights()
|
    def init_weights(self):
        if self.pos_embed is not None:
            trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
        trunc_normal_tf_(self.latent, std=self.latent_dim ** -0.5)
|
    def forward(self, x):
        B, N, C = x.shape

        if self.pos_embed is not None:
            # add absolute position embedding to the input tokens
            x = x + self.pos_embed.unsqueeze(0).to(x.dtype)
|
        # learned latent tokens act as the query; keys/values come from the input sequence
        q_latent = self.latent.expand(B, -1, -1)
        q = self.q(q_latent).reshape(B, self.latent_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)
|
        q, k = self.q_norm(q), self.k_norm(k)
|
        # cross-attention: fused SDPA kernel when available, otherwise explicit attention
        if self.fused_attn:
            x = F.scaled_dot_product_attention(q, k, v)
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            x = attn @ v
        x = x.transpose(1, 2).reshape(B, self.latent_len, C)
        x = self.proj(x)
        x = self.proj_drop(x)
|
        x = x + self.mlp(self.norm(x))
|
        # optionally reduce the latent tokens to a single pooled vector
        if self.pool == 'token':
            x = x[:, 0]
        elif self.pool == 'avg':
            x = x.mean(1)
        return x
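

# Minimal usage sketch (not part of the original module, shapes are illustrative):
# pool a ViT-style token sequence of shape (batch, seq_len, channels) into a single
# feature vector. Assumes the relative imports above resolve as in a timm-style
# `layers` package.
if __name__ == '__main__':
    pool = AttentionPoolLatent(in_features=768, num_heads=12, norm_layer=nn.LayerNorm)
    tokens = torch.randn(2, 196, 768)
    pooled = pool(tokens)  # -> (2, 768) with default latent_len=1, pool_type='token'
    print(pooled.shape)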