| | |
| |
|
| | from collections import namedtuple |
| | from functools import wraps |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from einops import rearrange, repeat |
| | from einops.layers.torch import Rearrange |
| | from packaging import version |
| | from torch import einsum, nn |
| |
|
| |
|
def exists(val):
    """Return True when the given value is anything other than None."""
    return val is not None
|
| |
|
def once(fn):
    """Wrap *fn* so that only the first invocation runs.

    The first call forwards its single argument to *fn* and returns the
    result; every later call is a no-op that returns None.
    """
    has_run = False

    @wraps(fn)
    def wrapper(x):
        nonlocal has_run
        if not has_run:
            has_run = True
            return fn(x)

    return wrapper
| |
|
| |
|
# module-level convenience: prints its argument the first time only,
# silently ignores all subsequent calls (used for one-shot GPU notices)
print_once = once(print)
| |
|
| |
|
| | |
class Attend(nn.Module):
    """Core attention kernel.

    Dispatches either to ``F.scaled_dot_product_attention`` (flash / fused
    path, PyTorch >= 2.0) or to a plain einsum implementation. Supports an
    optional boolean key-padding ``mask`` (True = attend) and causal masking.
    """

    def __init__(self, dropout=0.0, causal=False, use_flash=False):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.causal = causal
        # lazily-built causal mask cache; non-persistent so it is never serialized
        self.register_buffer("mask", None, persistent=False)

        self.use_flash = use_flash
        assert not (
            use_flash and version.parse(torch.__version__) < version.parse("2.0.0")
        ), "in order to use flash attention, you must be using pytorch 2.0 or above"

        # sdp-kernel backend selection for cpu vs cuda
        self.config = namedtuple("EfficientAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"])
        self.cpu_config = self.config(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not use_flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device("cuda"))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once("A100 GPU detected, using flash attention if input tensor is on cuda")
            self.cuda_config = self.config(True, False, False)
        else:
            print_once("Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda")
            self.cuda_config = self.config(False, True, True)

    def get_mask(self, n, device):
        """Return an (n, n) upper-triangular bool mask (True = masked out), cached and grown on demand."""
        if self.mask is not None and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent=False)
        return mask

    def flash_attn(self, q, k, v, mask=None):
        """Fused attention via torch SDPA. ``mask`` is a (b, j) bool key-padding mask (True = attend)."""
        _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda

        # single-headed (b, j, d) keys / values are broadcast across all heads
        # NOTE(review): expand_as(q) assumes j == q_len in this branch — confirm callers
        if k.ndim == 3:
            k = rearrange(k, "b ... -> b 1 ...").expand_as(q)

        if v.ndim == 3:
            v = rearrange(v, "b ... -> b 1 ...").expand_as(q)

        causal = self.causal

        if mask is not None:
            # broadcast the (b, j) padding mask over heads and query positions
            mask = rearrange(mask, "b j -> b 1 1 j")
            mask = mask.expand(-1, heads, q_len, -1)

            # BUGFIX: SDPA raises when attn_mask is combined with is_causal=True,
            # so fold the causal constraint into the boolean mask instead
            if causal:
                causal_mask = torch.ones((q_len, k_len), device=q.device, dtype=torch.bool).triu(k_len - q_len + 1)
                mask = mask & ~causal_mask
                causal = False

        config = self.cuda_config if is_cuda else self.cpu_config

        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v, attn_mask=mask, dropout_p=self.dropout if self.training else 0.0, is_causal=causal
            )

        return out

    def forward(self, q, k, v, mask=None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        n, device = q.shape[-2], q.device

        scale = q.shape[-1] ** -0.5

        if self.use_flash:
            return self.flash_attn(q, k, v, mask=mask)

        # keys / values may be single-headed (b, j, d) or multi-headed (b, h, j, d)
        kv_einsum_eq = "b j d" if k.ndim == 3 else "b h j d"

        # similarity
        sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale

        # key padding mask (True = attend)
        if mask is not None:
            mask = rearrange(mask, "b j -> b 1 1 j")
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # causal mask — note get_mask(n) is square, so this path assumes i == j
        if self.causal:
            causal_mask = self.get_mask(n, device)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention
        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate values
        out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)

        return out
| |
|
| |
|
def Sequential(*mods):
    """Build an nn.Sequential, silently dropping any None entries."""
    return nn.Sequential(*[m for m in mods if m is not None])
| |
|
| |
|
def exists(x):
    """Return True when *x* is not None.

    NOTE(review): this re-declares the identical `exists` helper defined
    earlier in this file — harmless, but one copy could be removed.
    """
    return x is not None
| |
|
| |
|
def default(val, d):
    """Return *val* when it is not None, otherwise the fallback *d*.

    A callable fallback is invoked (lazy default); anything else is
    returned as-is.
    """
    if val is not None:
        return val
    return d() if callable(d) else d
| |
|
| |
|
class RMSNorm(nn.Module):
    """RMS-style normalization with an optional learned scale and optional conditioning.

    When ``dim_cond`` is given, a conditioning vector is projected to
    per-channel gamma/beta that modulate the normalized output (FiLM-style).
    """

    def __init__(self, dim, scale=True, dim_cond=None):
        super().__init__()
        self.cond = dim_cond is not None
        self.to_gamma_beta = nn.Linear(dim_cond, dim * 2) if self.cond else None

        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(dim)) if scale else None

    def forward(self, x, cond=None):
        # fall back to a scalar 1 when no learned scale was requested
        gamma = self.gamma if self.gamma is not None else 1
        normed = F.normalize(x, dim=-1) * self.scale * gamma

        if not self.cond:
            return normed

        assert cond is not None
        gamma, beta = self.to_gamma_beta(cond).chunk(2, dim=-1)
        gamma = rearrange(gamma, "b d -> b 1 d")
        beta = rearrange(beta, "b d -> b 1 d")
        return normed * gamma + beta
| |
|
| |
|
class CausalConv1d(nn.Conv1d):
    """Conv1d that left-pads its input so each output depends only on current/past positions."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        (kernel_size,) = self.kernel_size
        (dilation,) = self.dilation
        (stride,) = self.stride

        # the causal padding arithmetic below is only valid for stride 1
        assert stride == 1
        self.causal_padding = dilation * (kernel_size - 1)

    def forward(self, x):
        # pad on the left only, so the receptive field never looks ahead
        padded = F.pad(x, (self.causal_padding, 0), value=0.0)
        return super().forward(padded)
| |
|
| |
|
class GEGLU(nn.Module):
    """Gated GELU: split the last dim in half, gate one half with GELU of the other."""

    def forward(self, x):
        out, gate = x.chunk(2, dim=-1)
        return out * F.gelu(gate)
| |
|
| |
|
def FeedForward(dim, mult=4, causal_conv=False):
    """GEGLU feed-forward block, optionally with a causal depthwise-style conv in the middle.

    The inner width uses the 2/3 factor so a GEGLU block matches the
    parameter count of a plain ``mult``-wide MLP.
    """
    dim_inner = int(dim * mult * 2 / 3)

    conv = (
        nn.Sequential(
            Rearrange("b n d -> b d n"),
            CausalConv1d(dim_inner, dim_inner, 3),
            Rearrange("b d n -> b n d"),
        )
        if causal_conv
        else None
    )

    # Sequential() drops the conv entry when it is None
    return Sequential(
        nn.Linear(dim, dim_inner * 2),
        GEGLU(),
        conv,
        nn.Linear(dim_inner, dim),
    )
| |
|
| |
|
class PerceiverResampler(nn.Module):
    """Compress a variable-length context into a fixed number of learned latents.

    Each layer cross-attends the latents to the (projected) context and then
    applies a feed-forward, both with residual connections; the latents
    themselves are included as keys/values via ``cross_attn_include_queries``.
    """

    def __init__(
        self,
        dim,
        depth=2,
        dim_context=None,
        num_latents=32,
        dim_head=64,
        heads=8,
        ff_mult=4,
        use_flash_attn=False,
    ):
        super().__init__()
        dim_context = default(dim_context, dim)

        # project context features into the latent dimension only when they differ
        self.proj_context = nn.Identity() if dim_context == dim else nn.Linear(dim_context, dim)

        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        nn.init.normal_(self.latents, std=0.02)

        layers = []
        for _ in range(depth):
            cross_attn = Attention(
                dim=dim,
                dim_head=dim_head,
                heads=heads,
                use_flash=use_flash_attn,
                cross_attn_include_queries=True,
            )
            layers.append(nn.ModuleList([cross_attn, FeedForward(dim=dim, mult=ff_mult)]))
        self.layers = nn.ModuleList(layers)

        self.norm = RMSNorm(dim)

    def forward(self, x, mask=None):
        batch = x.shape[0]

        x = self.proj_context(x)

        # one copy of the learned latent set per batch element
        latents = repeat(self.latents, "n d -> b n d", b=batch)

        for attn, ff in self.layers:
            latents = attn(latents, x, mask=mask) + latents
            latents = ff(latents) + latents

        return self.norm(latents)
| |
|
| |
|
class Attention(nn.Module):
    """Multi-head attention (self- or cross-) that delegates the core computation to Attend."""

    def __init__(
        self,
        dim,
        *,
        dim_context=None,
        causal=False,
        dim_head=64,
        heads=8,
        dropout=0.0,
        use_flash=False,
        cross_attn_include_queries=False,
    ):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        self.cross_attn_include_queries = cross_attn_include_queries

        dim_inner = dim_head * heads
        dim_context = default(dim_context, dim)

        self.attend = Attend(causal=causal, dropout=dropout, use_flash=use_flash)
        self.to_q = nn.Linear(dim, dim_inner, bias=False)
        self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias=False)
        self.to_out = nn.Linear(dim_inner, dim, bias=False)

    def forward(self, x, context=None, mask=None):
        heads = self.heads
        has_context = context is not None

        if not has_context:
            # self-attention: queries, keys and values all come from x
            context = x
        elif self.cross_attn_include_queries:
            # prepend the queries so the context also includes them as keys/values
            context = torch.cat((x, context), dim=-2)

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim=-1)
        q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=heads) for t in (q, k, v))

        out = self.attend(q, k, v, mask=mask)

        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)
| |
|