from typing import Optional

import torch
from einops.layers.torch import Rearrange
from torch import Tensor, nn

from . import vb_layers_initialize as init


class AttentionPairBias(nn.Module):
    """Attention pair bias layer."""

    def __init__(
        self,
        c_s: int,
        c_z: int,
        num_heads: int,
        inf: float = 1e6,
        initial_norm: bool = True,
    ) -> None:
        """Initialize the attention pair bias layer.

        Parameters
        ----------
        c_s : int
            The input sequence dimension.
        c_z : int
            The input pairwise dimension.
        num_heads : int
            The number of heads.
        inf : float, optional
            The large value used to mask out attention logits, by default 1e6.
        initial_norm : bool, optional
            Whether to apply layer norm to the input, by default True.

        """
        super().__init__()

        assert c_s % num_heads == 0, "c_s must be divisible by num_heads"

        self.c_s = c_s
        self.num_heads = num_heads
        self.head_dim = c_s // num_heads
        self.inf = inf

        self.initial_norm = initial_norm
        if self.initial_norm:
            self.norm_s = nn.LayerNorm(c_s)

        # Query, key, value and gating projections
        self.proj_q = nn.Linear(c_s, c_s)
        self.proj_k = nn.Linear(c_s, c_s, bias=False)
        self.proj_v = nn.Linear(c_s, c_s, bias=False)
        self.proj_g = nn.Linear(c_s, c_s, bias=False)

        # Pairwise projection: one attention bias per head
        self.proj_z = nn.Sequential(
            nn.LayerNorm(c_z),
            nn.Linear(c_z, num_heads, bias=False),
            Rearrange("b ... h -> b h ..."),
        )

        # Output projection with final initialization
        self.proj_o = nn.Linear(c_s, c_s, bias=False)
        init.final_init_(self.proj_o.weight)

    def forward(
        self,
        s: Tensor,
        z: Tensor,
        mask: Tensor,
        k_in: Optional[Tensor] = None,
        multiplicity: int = 1,
        to_keys=None,
        model_cache=None,
    ) -> Tensor:
        """Forward pass.

        Parameters
        ----------
        s : torch.Tensor
            The input sequence tensor (B, S, c_s).
        z : torch.Tensor
            The input pairwise tensor (B, N, N, c_z).
        mask : torch.Tensor
            The attention mask over keys (B, N).
        k_in : torch.Tensor, optional
            An explicit key/value input; defaults to `s` when neither
            `k_in` nor `to_keys` is provided.
        multiplicity : int, optional
            The diffusion batch size, by default 1.
        to_keys : callable, optional
            A function mapping the sequence tensor (and mask) to the
            key/value input, by default None.
        model_cache : dict, optional
            A cache used to store and reuse the projected pairwise bias
            across calls, by default None.

        Returns
        -------
        torch.Tensor
            The output sequence tensor (B, S, c_s).

        """
        B = s.shape[0]

        # Input normalization
        if self.initial_norm:
            s = self.norm_s(s)

        # Build the key/value input
        if to_keys is not None:
            k_in = to_keys(s)
            mask = to_keys(mask.unsqueeze(-1)).squeeze(-1)
        elif k_in is None:
            k_in = s

        # Project to per-head queries, keys and values
        q = self.proj_q(s).view(B, -1, self.num_heads, self.head_dim)
        k = self.proj_k(k_in).view(B, -1, self.num_heads, self.head_dim)
        v = self.proj_v(k_in).view(B, -1, self.num_heads, self.head_dim)

        # Project the pairwise tensor to a per-head bias, reusing the
        # cached projection when available
        if model_cache is None or "z" not in model_cache:
            z = self.proj_z(z)
            if model_cache is not None:
                model_cache["z"] = z
        else:
            z = model_cache["z"]
        z = z.repeat_interleave(multiplicity, 0)

        # Per-channel output gate
        g = self.proj_g(s).sigmoid()

        with torch.autocast("cuda", enabled=False):
            # Attention logits: scaled dot product plus the pairwise bias,
            # with masked keys suppressed before the softmax
            attn = torch.einsum("bihd,bjhd->bhij", q.float(), k.float())
            attn = attn / (self.head_dim**0.5) + z.float()
            attn = attn + (1 - mask[:, None, None].float()) * -self.inf
            attn = attn.softmax(dim=-1)

            # Aggregate values, apply the gate, and project the output
            o = torch.einsum("bhij,bjhd->bihd", attn, v.float()).to(v.dtype)
            o = o.reshape(B, -1, self.c_s)
            o = self.proj_o(g * o)

        return o
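
# A minimal usage sketch, kept separate from the layer definition. The shapes
# below are illustrative and simply follow the forward() docstring; nothing
# repo-specific is assumed beyond the imports above. Because this module uses a
# relative import, run it as part of its package (e.g. `python -m <package>.<module>`).
if __name__ == "__main__":
    B, S, c_s, c_z, num_heads = 2, 16, 64, 32, 8

    layer = AttentionPairBias(c_s=c_s, c_z=c_z, num_heads=num_heads)

    s = torch.randn(B, S, c_s)      # sequence representation
    z = torch.randn(B, S, S, c_z)   # pairwise representation
    mask = torch.ones(B, S)         # all positions valid

    out = layer(s, z, mask)
    print(out.shape)  # expected: torch.Size([2, 16, 64])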