import math

import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange


class RelativePositionBias(nn.Module):
    """T5-style relative position bias: relative distances are mapped into a fixed
    number of buckets, and each bucket holds a learned per-head bias that is added
    to the attention logits."""

    def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8):
        super().__init__()
        self.scale = scale
        self.causal = causal
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
        ret = 0
        n = -relative_position
        if not causal:
            # Half the buckets cover each direction in the bidirectional case.
            num_buckets //= 2
            ret += (n < 0).long() * num_buckets
            n = torch.abs(n)
        else:
            n = torch.max(n, torch.zeros_like(n))

        # Small distances get exact buckets; larger ones are binned logarithmically.
        max_exact = num_buckets // 2
        is_small = n < max_exact

        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def forward(self, qk_dots):
        i, j, device = *qk_dots.shape[-2:], qk_dots.device
        q_pos = torch.arange(i, dtype=torch.long, device=device)
        k_pos = torch.arange(j, dtype=torch.long, device=device)
        rel_pos = k_pos[None, :] - q_pos[:, None]
        rp_bucket = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets,
                                                   max_distance=self.max_distance)
        values = self.relative_attention_bias(rp_bucket)
        bias = rearrange(values, 'i j h -> () h i j')
        return qk_dots + (bias * self.scale)
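

# Illustrative sketch (not part of the original module): how the bias can be
# applied to a dummy tensor of attention logits. The shapes and values below
# are assumptions chosen purely for demonstration.
def _demo_relative_position_bias():
    bias = RelativePositionBias(scale=1.0, causal=False, heads=8)
    qk_dots = torch.randn(2, 8, 16, 16)  # (batch, heads, query_len, key_len)
    out = bias(qk_dots)                  # same shape, with the learned per-head bias added
    return out.shape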


class AttentionQKV(nn.Module):
    def __init__(self, n_heads, head_dim, dropout_rate=0.1, scale=None, flash=False):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.scale = scale if scale is not None else head_dim ** -0.5
        self.flash = flash
        self.dropout_rate = dropout_rate
        self.dropout = nn.Dropout(dropout_rate)
        self.flash_config = self.setup_flash_config() if flash else None

    def setup_flash_config(self):
        flash_config = {
            'enable_flash': True,
            'enable_math': True,
            'enable_mem_efficient': True
        }
        return flash_config

    def forward(self, q, k, v, mask=None):
        q, k, v = [self.split_heads(tensor) for tensor in [q, k, v]]
        if self.flash:
            out = self.flash_attention(q, k, v, mask=mask)
        else:
            out = self.scaled_dot_product_attention(q, k, v, mask=mask)

        return self.combine_heads(out)

    def scaled_dot_product_attention(self, q, k, v, mask=None):
        # q, k, v are (batch, heads, seq_len, head_dim); contracting over head_dim
        # yields (batch, heads, query_len, key_len) attention logits.
        sim = torch.einsum("bhid,bhjd->bhij", q, k) * self.scale
        if mask is not None:
            sim = sim.masked_fill(mask == 0, float('-inf'))
        attn = torch.softmax(sim, dim=-1)
        attn = self.dropout(attn)
        return torch.einsum("bhij,bhjd->bhid", attn, v)

    def flash_attention(self, q, k, v, mask=None):
        config = self.flash_config if self.flash_config else {}
        with torch.backends.cuda.sdp_kernel(**config):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=mask,
                dropout_p=self.dropout_rate if self.training else 0.
            )
        return out

    def split_heads(self, x):
        # (batch, seq_len, n_heads * head_dim) -> (batch, n_heads, seq_len, head_dim)
        bs, length, _ = x.shape
        x = x.view(bs, length, self.n_heads, self.head_dim)
        return x.permute(0, 2, 1, 3)

    def combine_heads(self, x):
        # (batch, n_heads, seq_len, head_dim) -> (batch, seq_len, n_heads * head_dim)
        bs, _, length, _ = x.shape
        x = x.permute(0, 2, 1, 3).contiguous()
        return x.view(bs, length, -1)
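

# Illustrative sketch (not part of the original module): a minimal forward pass
# through AttentionQKV with random tensors, using the manual (non-flash) path.
# All shapes and values are assumptions chosen for demonstration.
def _demo_attention_qkv():
    attn = AttentionQKV(n_heads=4, head_dim=16, flash=False)
    q = torch.randn(2, 10, 64)  # (batch, query_len, n_heads * head_dim)
    k = torch.randn(2, 12, 64)  # key/value length may differ from query length
    v = torch.randn(2, 12, 64)
    out = attn(q, k, v)         # -> (2, 10, 64)
    return out.shape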


class AttentionBlock2(nn.Module):
    """
    An attention block that lets one sequence of positions (the queries, x1) attend
    to another (the keys/values, x2), using AttentionQKV with separate linear
    projections for Q, K, and V.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        relative_pos_embeddings=False,
        flash_attention=True,
        dropout_rate=0.2,
        scale=None
    ):
        super().__init__()
        self.channels = channels

        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels

        self.norm = nn.LayerNorm(channels)

        self.to_q = nn.Linear(channels, channels)
        self.to_k = nn.Linear(channels, channels)
        self.to_v = nn.Linear(channels, channels)

        self.attention = AttentionQKV(self.num_heads, channels // self.num_heads, dropout_rate=dropout_rate, flash=flash_attention, scale=scale)

        self.proj_out = nn.Linear(channels, channels)

        # Optional relative position bias; note it is instantiated here but not
        # applied inside forward(), which delegates the attention to AttentionQKV.
        if relative_pos_embeddings:
            self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64)
        else:
            self.relative_pos_embeddings = None

    def forward(self, x1, x2, mask=None):
        b1, c1, *spatial1 = x1.shape
        b2, c2, *spatial2 = x2.shape

        x1_norm = self.norm(x1)
        x2_norm = self.norm(x2)

        q = self.to_q(x1_norm)
        k = self.to_k(x2_norm)
        v = self.to_v(x2_norm)

        h = self.attention(q, k, v, mask=mask)
        h = self.proj_out(h)

        return (x1 + h).reshape(b1, c1, *spatial1)
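

# Illustrative sketch (not part of the original module): cross-attention where
# x1 provides the queries and x2 the keys/values. The shapes are assumptions
# chosen for demonstration; flash_attention is disabled to keep the example on
# the manual attention path.
def _demo_attention_block2():
    block = AttentionBlock2(channels=64, num_heads=4, flash_attention=False)
    x1 = torch.randn(2, 10, 64)  # (batch, query_tokens, channels)
    x2 = torch.randn(2, 20, 64)  # (batch, context_tokens, channels)
    out = block(x1, x2)          # -> (2, 10, 64)
    return out.shape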


class Perceiver(nn.Module):
    """Inspired by https://arxiv.org/abs/2103.03206"""

    def __init__(self, pre_attention_query_token=32, pre_attention_query_size=1024, embedding_dim=1024, num_attn_heads=4):
        """
        Initialize the perceiver module.

        :param pre_attention_query_token: Number of query tokens for pre-attention
        :param pre_attention_query_size: Size of each query token
        :param embedding_dim: Dimension of the embedding space
        :param num_attn_heads: Number of attention heads
        """
        super().__init__()

        # Learned latent query tokens that cross-attend to the input sequence.
        self.pre_attention_query = torch.nn.Parameter(
            torch.empty(1, pre_attention_query_token, pre_attention_query_size)
        )

        # Xavier-style bound for the uniform initialisation of the query tokens.
        query_bound = math.sqrt(3.0) * math.sqrt(2.0 / (pre_attention_query_token + pre_attention_query_token))
        self.pre_attention_query.data.uniform_(-query_bound, query_bound)

        self.attn = AttentionBlock2(embedding_dim, num_attn_heads)

    def forward(self, h):
        """
        Forward pass of the perceiver module.

        :param h: Input tensor
        :return: Output after applying attention mechanisms
        """
        # Broadcast the learned queries across the batch, cross-attend to the
        # input, then run one round of self-attention over the result.
        query_ = self.pre_attention_query.expand(h.shape[0], -1, -1)
        pre_att = self.attn(query_, h)
        attn = self.attn(pre_att, pre_att)
        return attn
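

# Illustrative sketch (not part of the original module): resampling a variable
# length input sequence down to a fixed number of latent tokens. The input
# shape below is an assumption chosen for demonstration; the default
# AttentionBlock2 configuration routes through the SDPA/flash path.
def _demo_perceiver():
    perceiver = Perceiver(pre_attention_query_token=32, pre_attention_query_size=1024,
                          embedding_dim=1024, num_attn_heads=4)
    h = torch.randn(2, 50, 1024)  # (batch, input_tokens, embedding_dim)
    out = perceiver(h)            # -> (2, 32, 1024): fixed-length latent summary
    return out.shape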