import math

import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange


class RelativePositionBias(nn.Module):
    """T5-style relative position bias: bucketed relative distances are mapped to a
    learned per-head bias that is added to the attention logits."""

    def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8):
        super().__init__()
        self.scale = scale
        self.causal = causal
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
        ret = 0
        n = -relative_position
        if not causal:
            # Non-causal: split the buckets between past and future positions.
            num_buckets //= 2
            ret += (n < 0).long() * num_buckets
            n = torch.abs(n)
        else:
            n = torch.max(n, torch.zeros_like(n))

        # Half of the buckets cover exact small distances; the rest cover
        # logarithmically larger distances up to max_distance.
        max_exact = num_buckets // 2
        is_small = n < max_exact

        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def forward(self, qk_dots):
        # qk_dots: attention logits of shape (batch, heads, query_len, key_len)
        i, j, device = *qk_dots.shape[-2:], qk_dots.device
        q_pos = torch.arange(i, dtype=torch.long, device=device)
        k_pos = torch.arange(j, dtype=torch.long, device=device)
        rel_pos = k_pos[None, :] - q_pos[:, None]
        rp_bucket = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets,
                                                   max_distance=self.max_distance)
        values = self.relative_attention_bias(rp_bucket)
        bias = rearrange(values, 'i j h -> () h i j')
        return qk_dots + (bias * self.scale)


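# Illustrative sketch (not part of the original module): how the learned relative
# position bias is meant to be applied to raw attention logits. Shapes are assumed.
def _example_relative_position_bias():
    bias_module = RelativePositionBias(scale=1.0, heads=8)
    qk_dots = torch.randn(2, 8, 16, 16)  # (batch, heads, query_len, key_len)
    return bias_module(qk_dots)          # same shape, with a per-head bias added

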
class AttentionQKV(nn.Module):
    def __init__(self, n_heads, head_dim, dropout_rate=0.1, scale=None, flash=False):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.scale = scale if scale is not None else head_dim ** -0.5
        self.flash = flash
        self.dropout_rate = dropout_rate
        self.dropout = nn.Dropout(dropout_rate)
        self.flash_config = self.setup_flash_config() if flash else None

    def setup_flash_config(self):
        # Allow all SDP backends; PyTorch picks the fastest one available.
        flash_config = {
            'enable_flash': True,
            'enable_math': True,
            'enable_mem_efficient': True
        }
        return flash_config

    def forward(self, q, k, v, mask=None):
        # q, k, v: (batch, seq_len, n_heads * head_dim)
        q, k, v = [self.split_heads(tensor) for tensor in [q, k, v]]
        if self.flash:
            out = self.flash_attention(q, k, v, mask=mask)
        else:
            out = self.scaled_dot_product_attention(q, k, v, mask=mask)

        return self.combine_heads(out)

    def scaled_dot_product_attention(self, q, k, v, mask=None):
        # Manual fallback that mirrors F.scaled_dot_product_attention in the flash path:
        # attention weights are computed over key positions (j) for each query position (i).
        sim = torch.einsum("bhid,bhjd->bhij", q, k) * self.scale
        if mask is not None:
            # mask == 0 marks positions that must not be attended to
            sim = sim.masked_fill(mask == 0, float('-inf'))
        attn = torch.softmax(sim, dim=-1)
        attn = self.dropout(attn)
        return torch.einsum("bhij,bhjd->bhid", attn, v)

    def flash_attention(self, q, k, v, mask=None):
        # Note: F.scaled_dot_product_attention applies its default 1/sqrt(head_dim)
        # scaling here; self.scale is only used in the manual path.
        config = self.flash_config if self.flash_config else {}
        with torch.backends.cuda.sdp_kernel(**config):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=mask,
                dropout_p=self.dropout_rate if self.training else 0.
            )
        return out

    def split_heads(self, x):
        # (batch, seq_len, n_heads * head_dim) -> (batch, n_heads, seq_len, head_dim)
        bs, length, _ = x.shape
        x = x.view(bs, length, self.n_heads, self.head_dim)
        return x.permute(0, 2, 1, 3)

    def combine_heads(self, x):
        # (batch, n_heads, seq_len, head_dim) -> (batch, seq_len, n_heads * head_dim)
        bs, _, length, _ = x.shape
        x = x.permute(0, 2, 1, 3).contiguous()
        return x.view(bs, length, -1)


class AttentionBlock2(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other,
    using AttentionQKV and separate linear transformations for Q, K, and V.
    x1 provides the queries and x2 the keys/values, so the block also supports
    cross-attention between two sequences.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        relative_pos_embeddings=False,
        flash_attention=True,
        dropout_rate=0.2,
        scale=None
    ):
        super().__init__()
        self.channels = channels

        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels

        self.norm = nn.LayerNorm(channels)

        self.to_q = nn.Linear(channels, channels)
        self.to_k = nn.Linear(channels, channels)
        self.to_v = nn.Linear(channels, channels)

        self.attention = AttentionQKV(self.num_heads, channels // self.num_heads, dropout_rate=dropout_rate, flash=flash_attention, scale=scale)

        self.proj_out = nn.Linear(channels, channels)

        if relative_pos_embeddings:
            # Note: this bias module is instantiated here but is not applied inside
            # AttentionQKV's forward pass.
            self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64)
        else:
            self.relative_pos_embeddings = None

    def forward(self, x1, x2, mask=None):
        # Inputs are expected as (batch, seq_len, channels); the unpacking below keeps
        # the original shapes so the residual output can be restored at the end.
        b1, c1, *spatial1 = x1.shape
        b2, c2, *spatial2 = x2.shape

        x1_norm = self.norm(x1)
        x2_norm = self.norm(x2)

        q = self.to_q(x1_norm)
        k = self.to_k(x2_norm)
        v = self.to_v(x2_norm)

        h = self.attention(q, k, v, mask=mask)
        h = self.proj_out(h)

        return (x1 + h).reshape(b1, c1, *spatial1)


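# Illustrative sketch (not part of the original module): cross-attention with
# AttentionBlock2, where a short query sequence attends to a longer context.
# Sizes are assumed for demonstration only.
def _example_attention_block2():
    block = AttentionBlock2(channels=256, num_heads=4, flash_attention=False)
    queries = torch.randn(2, 16, 256)   # (batch, query_len, channels)
    context = torch.randn(2, 100, 256)  # (batch, context_len, channels)
    out = block(queries, context)       # (batch, query_len, channels), residual on queries
    return out

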
class Perceiver(nn.Module):
    """Inspired by https://arxiv.org/abs/2103.03206"""

    def __init__(self, pre_attention_query_token=32, pre_attention_query_size=1024, embedding_dim=1024, num_attn_heads=4):
        """
        Initialize the perceiver module.

        :param pre_attention_query_token: Number of query tokens for pre-attention
        :param pre_attention_query_size: Size of each query token; must match embedding_dim,
            since the same attention block processes both the queries and the input
        :param embedding_dim: Dimension of the embedding space
        :param num_attn_heads: Number of attention heads
        """
        super().__init__()

        # Learned latent queries that pool the input sequence down to a fixed length.
        self.pre_attention_query = torch.nn.Parameter(
            torch.empty(1, pre_attention_query_token, pre_attention_query_size)
        )

        # Xavier-uniform-style bound for initializing the query tokens
        # (uses the query token count for both fan terms).
        query_variance = math.sqrt(3.0) * math.sqrt(2.0 / (pre_attention_query_token + pre_attention_query_token))

        self.pre_attention_query.data.uniform_(-query_variance, query_variance)

        # A single attention block is shared by the cross-attention and
        # self-attention steps in forward().
        self.attn = AttentionBlock2(embedding_dim, num_attn_heads)

    def forward(self, h):
        """
        Forward pass of the perceiver module.

        :param h: Input tensor of shape (batch, seq_len, embedding_dim)
        :return: Pooled tensor of shape (batch, pre_attention_query_token, embedding_dim)
        """
        # Broadcast the learned queries across the batch.
        query_ = self.pre_attention_query.expand(h.shape[0], -1, -1)

        # Cross-attention: latent queries attend to the input sequence.
        pre_att = self.attn(query_, h)

        # Self-attention over the pooled latent tokens.
        attn = self.attn(pre_att, pre_att)
        return attn
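

# Minimal smoke test (illustrative only; the sizes below are assumptions, not values
# tied to any particular training setup). Running this file directly pools a
# variable-length sequence into a fixed number of latent tokens with the Perceiver.
if __name__ == "__main__":
    model = Perceiver(pre_attention_query_token=32, pre_attention_query_size=1024,
                      embedding_dim=1024, num_attn_heads=4)
    x = torch.randn(2, 250, 1024)  # (batch, seq_len, embedding_dim)
    with torch.no_grad():
        latents = model(x)
    print(latents.shape)  # expected: torch.Size([2, 32, 1024])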