from typing import Optional

import torch
from einops.layers.torch import Rearrange
from torch import Tensor, nn

from . import vb_layers_initialize as init


class AttentionPairBias(nn.Module):
    """Attention pair bias layer."""

    def __init__(
        self,
        c_s: int,
        c_z: int,
        num_heads: int,
        inf: float = 1e6,
        initial_norm: bool = True,
    ) -> None:
        """Initialize the attention pair bias layer.

        Parameters
        ----------
        c_s : int
            The input sequence dimension.
        c_z : int
            The input pairwise dimension.
        num_heads : int
            The number of heads.
        inf : float, optional
            The value used to mask out invalid attention logits, by default 1e6
        initial_norm : bool, optional
            Whether to apply layer norm to the input, by default True

        """
        super().__init__()

        assert c_s % num_heads == 0

        self.c_s = c_s
        self.num_heads = num_heads
        self.head_dim = c_s // num_heads
        self.inf = inf
        self.initial_norm = initial_norm

        if self.initial_norm:
            self.norm_s = nn.LayerNorm(c_s)

        self.proj_q = nn.Linear(c_s, c_s)
        self.proj_k = nn.Linear(c_s, c_s, bias=False)
        self.proj_v = nn.Linear(c_s, c_s, bias=False)
        self.proj_g = nn.Linear(c_s, c_s, bias=False)

        self.proj_z = nn.Sequential(
            nn.LayerNorm(c_z),
            nn.Linear(c_z, num_heads, bias=False),
            Rearrange("b ... h -> b h ..."),
        )

        self.proj_o = nn.Linear(c_s, c_s, bias=False)
        init.final_init_(self.proj_o.weight)

    def forward(
        self,
        s: Tensor,
        z: Tensor,
        mask: Tensor,
        k_in: Optional[Tensor] = None,
        multiplicity: int = 1,
        to_keys=None,
        model_cache=None,
    ) -> Tensor:
        """Forward pass.

        Parameters
        ----------
        s : torch.Tensor
            The input sequence tensor (B, S, D)
        z : torch.Tensor
            The input pairwise tensor (B, N, N, D)
        mask : torch.Tensor
            The pairwise mask tensor (B, N)
        k_in : torch.Tensor, optional
            The key/value input tensor; defaults to `s` when neither
            `k_in` nor `to_keys` is provided, by default None
        multiplicity : int, optional
            The diffusion batch size, by default 1
        to_keys : callable, optional
            A function mapping the query-side tokens (and the mask) to the
            key-side tokens, by default None
        model_cache : dict, optional
            A cache for the pairwise bias projection during the diffusion
            roll-out, by default None

        Returns
        -------
        torch.Tensor
            The output sequence tensor.

        """
        B = s.shape[0]

        # Layer norms
        if self.initial_norm:
            s = self.norm_s(s)

        if to_keys is not None:
            k_in = to_keys(s)
            mask = to_keys(mask.unsqueeze(-1)).squeeze(-1)
        elif k_in is None:
            k_in = s

        # Compute projections
        q = self.proj_q(s).view(B, -1, self.num_heads, self.head_dim)
        k = self.proj_k(k_in).view(B, -1, self.num_heads, self.head_dim)
        v = self.proj_v(k_in).view(B, -1, self.num_heads, self.head_dim)

        # Caching z projection during diffusion roll-out
        if model_cache is None or "z" not in model_cache:
            z = self.proj_z(z)
            if model_cache is not None:
                model_cache["z"] = z
        else:
            z = model_cache["z"]
        z = z.repeat_interleave(multiplicity, 0)

        g = self.proj_g(s).sigmoid()

        with torch.autocast("cuda", enabled=False):
            # Compute attention weights
            attn = torch.einsum("bihd,bjhd->bhij", q.float(), k.float())
            attn = attn / (self.head_dim**0.5) + z.float()
            # The mask (B, N) is broadcast to (B, 1, 1, N) and applied to the
            # attention logits (B, H, N, N)
            attn = attn + (1 - mask[:, None, None].float()) * -self.inf
            attn = attn.softmax(dim=-1)

            # Compute output
            o = torch.einsum("bhij,bjhd->bihd", attn, v.float()).to(v.dtype)

        o = o.reshape(B, -1, self.c_s)
        o = self.proj_o(g * o)
        return o
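

# Usage sketch (illustrative only, not part of the layer implementation).
# The dimensions below (c_s=128, c_z=64, num_heads=8, B=2, N=16) are
# hypothetical; the layer only requires that c_s is divisible by num_heads
# and that the sequence length of `s` matches the pairwise dims of `z`/`mask`.
#
#     layer = AttentionPairBias(c_s=128, c_z=64, num_heads=8)
#     s = torch.randn(2, 16, 128)      # (B, S, c_s) single representation
#     z = torch.randn(2, 16, 16, 64)   # (B, N, N, c_z) pairwise representation
#     mask = torch.ones(2, 16)         # (B, N) token mask, 1 = keep
#     out = layer(s, z, mask)          # -> (B, S, c_s)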