| | """ Attention Pool 2D |
| | |
| | Implementations of 2D spatial feature pooling using multi-head attention instead of average pool. |
| | |
| | Based on idea in CLIP by OpenAI, licensed Apache 2.0 |
| | https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py |
| | |
| | Hacked together by / Copyright 2021 Ross Wightman |
| | """ |
| | import math |
| | from typing import List, Union, Tuple |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| | from .helpers import to_2tuple |
| | from .weight_init import trunc_normal_ |
| |
|
| |
|
| | def rot(x): |
| | return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape) |
| |
|
| |
|
| | def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb): |
| | return x * cos_emb + rot(x) * sin_emb |
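

# A minimal sketch of what the pair rotation does (values illustrative, not from the original):
#   rot(torch.tensor([1., 2., 3., 4.])) -> tensor([-2., 1., -4., 3.])
# apply_rot_embed then mixes this rotated copy with per-position sin/cos tables, so the
# dot product between embedded q and k depends on the relative offset between positions.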


def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):
    if isinstance(x, torch.Tensor):
        x = [x]
    return [t * cos_emb + rot(t) * sin_emb for t in x]


class RotaryEmbedding(nn.Module):
    """ Rotary position embedding

    NOTE: This is my initial attempt at impl rotary embedding for spatial use; it has not
    been well tested, and will likely change. It will be moved to its own file.

    The following impl/resources were referenced for this impl:
    * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
    * https://blog.eleuther.ai/rotary-embeddings/
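
    Example (a minimal usage sketch; shapes here are illustrative, not from the original):
        >>> rope = RotaryEmbedding(dim=64)
        >>> sin_emb, cos_emb = rope.get_embed((7, 7))  # each (49, 64) for a 7x7 grid
        >>> q = torch.randn(2, 8, 49, 64)  # (batch, heads, positions, head_dim)
        >>> q = apply_rot_embed(q, sin_emb, cos_emb)  # same shape, positions encoded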
| | """ |
| | def __init__(self, dim, max_freq=4): |
| | super().__init__() |
| | self.dim = dim |
| | self.register_buffer('bands', 2 ** torch.linspace(0., max_freq - 1, self.dim // 4), persistent=False) |
| |
|
| | def get_embed(self, shape: torch.Size, device: torch.device = None, dtype: torch.dtype = None): |
| | """ |
| | NOTE: shape arg should include spatial dim only |
| | """ |
| | device = device or self.bands.device |
| | dtype = dtype or self.bands.dtype |
| | if not isinstance(shape, torch.Size): |
| | shape = torch.Size(shape) |
| | N = shape.numel() |
| | grid = torch.stack(torch.meshgrid( |
| | [torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in shape]), dim=-1).unsqueeze(-1) |
| | emb = grid * math.pi * self.bands |
| | sin = emb.sin().reshape(N, -1).repeat_interleave(2, -1) |
| | cos = emb.cos().reshape(N, -1).repeat_interleave(2, -1) |
| | return sin, cos |
| |
|
| | def forward(self, x): |
| | |
| | sin_emb, cos_emb = self.get_embed(x.shape[2:]) |
| | return apply_rot_embed(x, sin_emb, cos_emb) |
| |
|
| |
|
| | class RotAttentionPool2d(nn.Module): |
| | """ Attention based 2D feature pooling w/ rotary (relative) pos embedding. |
| | This is a multi-head attention based replacement for (spatial) average pooling in NN architectures. |
| | |
| | Adapted from the AttentionPool2d in CLIP w/ rotary embedding instead of learned embed. |
| | https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py |
| | |
| | NOTE: While this impl does not require a fixed feature size, performance at differeing resolutions from |
| | train varies widely and falls off dramatically. I'm not sure if there is a way around this... -RW |
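
    Example (a minimal usage sketch; the feature shape here is illustrative, not from the original):
        >>> pool = RotAttentionPool2d(in_features=2048, num_heads=8)
        >>> feats = torch.randn(2, 2048, 7, 7)  # e.g. CNN output (batch, channels, H, W)
        >>> out = pool(feats)  # -> (2, 2048), the attention-pooled token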
| | """ |
| | def __init__( |
| | self, |
| | in_features: int, |
| | out_features: int = None, |
| | embed_dim: int = None, |
| | num_heads: int = 4, |
| | qkv_bias: bool = True, |
| | ): |
| | super().__init__() |
| | embed_dim = embed_dim or in_features |
| | out_features = out_features or in_features |
| | self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias) |
| | self.proj = nn.Linear(embed_dim, out_features) |
| | self.num_heads = num_heads |
| | assert embed_dim % num_heads == 0 |
| | self.head_dim = embed_dim // num_heads |
| | self.scale = self.head_dim ** -0.5 |
| | self.pos_embed = RotaryEmbedding(self.head_dim) |
| |
|
| | trunc_normal_(self.qkv.weight, std=in_features ** -0.5) |
| | nn.init.zeros_(self.qkv.bias) |
| |
|
| | def forward(self, x): |
| | B, _, H, W = x.shape |
| | N = H * W |
| | sin_emb, cos_emb = self.pos_embed.get_embed(x.shape[2:]) |
| | x = x.reshape(B, -1, N).permute(0, 2, 1) |
| |
|
| | x = torch.cat([x.mean(1, keepdim=True), x], dim=1) |
| |
|
| | x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) |
| | q, k, v = x[0], x[1], x[2] |
| |
|
| | qc, q = q[:, :, :1], q[:, :, 1:] |
| | q = apply_rot_embed(q, sin_emb, cos_emb) |
| | q = torch.cat([qc, q], dim=2) |
| |
|
| | kc, k = k[:, :, :1], k[:, :, 1:] |
| | k = apply_rot_embed(k, sin_emb, cos_emb) |
| | k = torch.cat([kc, k], dim=2) |
| |
|
| | attn = (q @ k.transpose(-2, -1)) * self.scale |
| | attn = attn.softmax(dim=-1) |
| |
|
| | x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1) |
| | x = self.proj(x) |
| | return x[:, 0] |
| |
|
| |
|
| | class AttentionPool2d(nn.Module): |
| | """ Attention based 2D feature pooling w/ learned (absolute) pos embedding. |
| | This is a multi-head attention based replacement for (spatial) average pooling in NN architectures. |
| | |
| | It was based on impl in CLIP by OpenAI |
| | https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py |
| | |
| | NOTE: This requires feature size upon construction and well prevent adaptive sizing of the network. |
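
    Example (a minimal usage sketch; the 7x7 feature size is illustrative, not from the original):
        >>> pool = AttentionPool2d(in_features=2048, feat_size=7, num_heads=8)
        >>> feats = torch.randn(2, 2048, 7, 7)  # spatial size must match feat_size
        >>> out = pool(feats)  # -> (2, 2048)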
| | """ |
| | def __init__( |
| | self, |
| | in_features: int, |
| | feat_size: Union[int, Tuple[int, int]], |
| | out_features: int = None, |
| | embed_dim: int = None, |
| | num_heads: int = 4, |
| | qkv_bias: bool = True, |
| | ): |
| | super().__init__() |
| |
|
| | embed_dim = embed_dim or in_features |
| | out_features = out_features or in_features |
| | assert embed_dim % num_heads == 0 |
| | self.feat_size = to_2tuple(feat_size) |
| | self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias) |
| | self.proj = nn.Linear(embed_dim, out_features) |
| | self.num_heads = num_heads |
| | self.head_dim = embed_dim // num_heads |
| | self.scale = self.head_dim ** -0.5 |
| |
|
| | spatial_dim = self.feat_size[0] * self.feat_size[1] |
| | self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features)) |
| | trunc_normal_(self.pos_embed, std=in_features ** -0.5) |
| | trunc_normal_(self.qkv.weight, std=in_features ** -0.5) |
| | nn.init.zeros_(self.qkv.bias) |
| |
|
| | def forward(self, x): |
| | B, _, H, W = x.shape |
| | N = H * W |
| | assert self.feat_size[0] == H |
| | assert self.feat_size[1] == W |
| | x = x.reshape(B, -1, N).permute(0, 2, 1) |
| | x = torch.cat([x.mean(1, keepdim=True), x], dim=1) |
| | x = x + self.pos_embed.unsqueeze(0).to(x.dtype) |
| |
|
| | x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) |
| | q, k, v = x[0], x[1], x[2] |
| | attn = (q @ k.transpose(-2, -1)) * self.scale |
| | attn = attn.softmax(dim=-1) |
| |
|
| | x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1) |
| | x = self.proj(x) |
| | return x[:, 0] |
| |
|