# Source: Hugging Face upload (via huggingface_hub), revision 714cf46.
from typing import Optional
import torch
from einops.layers.torch import Rearrange
from torch import Tensor, nn
from . import vb_layers_initialize as init
class AttentionPairBias(nn.Module):
    """Multi-head attention with a pairwise bias added to the logits.

    The sequence activations provide queries/keys/values; the pairwise
    representation is projected to one scalar bias per head and added to
    the attention logits before the softmax. The output is gated by a
    sigmoid gate computed from the (normalized) sequence activations.
    """

    def __init__(
        self,
        c_s: int,
        c_z: int,
        num_heads: int,
        inf: float = 1e6,
        initial_norm: bool = True,
    ) -> None:
        """Initialize the attention pair bias layer.

        Parameters
        ----------
        c_s : int
            The input sequence dimension.
        c_z : int
            The input pairwise dimension.
        num_heads : int
            The number of heads.
        inf : float, optional
            Magnitude used to mask out attention logits, by default 1e6.
        initial_norm : bool, optional
            Whether to apply layer norm to the input, by default True.
        """
        super().__init__()
        assert c_s % num_heads == 0, "c_s must be divisible by num_heads"

        self.c_s = c_s
        self.num_heads = num_heads
        self.head_dim = c_s // num_heads
        self.inf = inf
        self.initial_norm = initial_norm

        if self.initial_norm:
            self.norm_s = nn.LayerNorm(c_s)

        self.proj_q = nn.Linear(c_s, c_s)
        self.proj_k = nn.Linear(c_s, c_s, bias=False)
        self.proj_v = nn.Linear(c_s, c_s, bias=False)
        self.proj_g = nn.Linear(c_s, c_s, bias=False)

        # Pairwise features -> one attention-bias logit per head, then
        # moved to (B, H, ...) so it can be added directly to the logits.
        self.proj_z = nn.Sequential(
            nn.LayerNorm(c_z),
            nn.Linear(c_z, num_heads, bias=False),
            Rearrange("b ... h -> b h ..."),
        )

        self.proj_o = nn.Linear(c_s, c_s, bias=False)
        # Zero-init the output projection so the layer starts as identity
        # under the residual stream (presumably — see vb_layers_initialize).
        init.final_init_(self.proj_o.weight)

    def forward(
        self,
        s: Tensor,
        z: Tensor,
        mask: Tensor,
        k_in: Optional[Tensor] = None,
        multiplicity: int = 1,
        to_keys=None,
        model_cache=None,
    ) -> Tensor:
        """Forward pass.

        Parameters
        ----------
        s : torch.Tensor
            The input sequence tensor (B, S, D).
        z : torch.Tensor
            The input pairwise tensor (B, N, N, D).
        mask : torch.Tensor
            The pairwise mask tensor (B, N).
        k_in : torch.Tensor, optional
            Precomputed key/value input; defaults to ``s`` when neither
            ``k_in`` nor ``to_keys`` is provided.
        multiplicity : int, optional
            The diffusion batch size, by default 1. The projected pair
            bias is repeated along the batch dim to match ``s``.
        to_keys : callable, optional
            Transform mapping the (normalized) sequence tensor and the
            mask into key/value space; when given it overrides ``k_in``.
        model_cache : dict, optional
            Cache for the projected pair bias across diffusion steps.

        Returns
        -------
        torch.Tensor
            The output sequence tensor (B, S, D).
        """
        B = s.shape[0]

        # Layer norms
        if self.initial_norm:
            s = self.norm_s(s)

        if to_keys is not None:
            k_in = to_keys(s)
            mask = to_keys(mask.unsqueeze(-1)).squeeze(-1)
        elif k_in is None:
            k_in = s

        # Compute projections
        q = self.proj_q(s).view(B, -1, self.num_heads, self.head_dim)
        k = self.proj_k(k_in).view(B, -1, self.num_heads, self.head_dim)
        v = self.proj_v(k_in).view(B, -1, self.num_heads, self.head_dim)

        # Caching z projection during diffusion roll-out
        if model_cache is None or "z" not in model_cache:
            z = self.proj_z(z)
            if model_cache is not None:
                model_cache["z"] = z
        else:
            z = model_cache["z"]
        # Applied on BOTH branches: the incoming/cached pair bias has the
        # un-expanded batch, while ``s`` (and hence ``attn``) carries the
        # multiplicity factor in its batch dim.
        z = z.repeat_interleave(multiplicity, 0)

        g = self.proj_g(s).sigmoid()

        # Attention logits are computed in fp32 for numerical stability.
        with torch.autocast("cuda", enabled=False):
            # Compute attention weights
            attn = torch.einsum("bihd,bjhd->bhij", q.float(), k.float())
            attn = attn / (self.head_dim**0.5) + z.float()
            # The pairwise mask (B, N) broadcasts as (B, 1, 1, N) over (B, H, N, N)
            attn = attn + (1 - mask[:, None, None].float()) * -self.inf
            attn = attn.softmax(dim=-1)
            # Compute output
            o = torch.einsum("bhij,bjhd->bihd", attn, v.float()).to(v.dtype)

        o = o.reshape(B, -1, self.c_s)
        o = self.proj_o(g * o)
        return o