# vb_modules_encodersv2.py
# started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang
from functools import partial
from math import pi
import torch
from einops import rearrange
from torch import nn
from torch.nn import Linear, Module, ModuleList
from torch.nn.functional import one_hot
from . import vb_layers_initialize as init
from .vb_layers_transition import Transition
from .vb_modules_transformersv2 import AtomTransformer
from .vb_modules_utils import LinearNoBias
class FourierEmbedding(Module):
"""Algorithm 22."""
def __init__(self, dim):
super().__init__()
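        # The projection weights and bias are drawn once from N(0, 1) and frozen,
        # so this layer acts as a fixed bank of random Fourier features of the timestep.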
self.proj = nn.Linear(1, dim)
torch.nn.init.normal_(self.proj.weight, mean=0, std=1)
torch.nn.init.normal_(self.proj.bias, mean=0, std=1)
self.proj.requires_grad_(False)
def forward(
self,
times, # Float[' b'],
): # -> Float['b d']:
times = rearrange(times, "b -> b 1")
rand_proj = self.proj(times)
return torch.cos(2 * pi * rand_proj)
class RelativePositionEncoder(Module):
"""Algorithm 3."""
def __init__(
self, token_z, r_max=32, s_max=2, fix_sym_check=False, cyclic_pos_enc=False
):
super().__init__()
self.r_max = r_max
self.s_max = s_max
self.linear_layer = LinearNoBias(4 * (r_max + 1) + 2 * (s_max + 1) + 1, token_z)
self.fix_sym_check = fix_sym_check
self.cyclic_pos_enc = cyclic_pos_enc
def forward(self, feats):
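        # Build the AF3 Algorithm 3 relative features: one-hot relative residue index
        # (within a chain), relative token index (within a residue), relative sym_id
        # (between chains), plus a same-entity flag, then project them to token_z.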
b_same_chain = torch.eq(
feats["asym_id"][:, :, None], feats["asym_id"][:, None, :]
)
b_same_residue = torch.eq(
feats["residue_index"][:, :, None], feats["residue_index"][:, None, :]
)
b_same_entity = torch.eq(
feats["entity_id"][:, :, None], feats["entity_id"][:, None, :]
)
d_residue = (
feats["residue_index"][:, :, None] - feats["residue_index"][:, None, :]
)
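        # For cyclic chains, wrap residue offsets modulo the cyclic period; chains
        # without a period fall back to a large dummy period (10000), which leaves
        # their offsets effectively unchanged.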
if self.cyclic_pos_enc and torch.any(feats["cyclic_period"] > 0):
period = torch.where(
feats["cyclic_period"] > 0,
feats["cyclic_period"],
torch.zeros_like(feats["cyclic_period"]) + 10000,
)
d_residue = (d_residue - period * torch.round(d_residue / period)).long()
d_residue = torch.clip(
d_residue + self.r_max,
0,
2 * self.r_max,
)
d_residue = torch.where(
b_same_chain, d_residue, torch.zeros_like(d_residue) + 2 * self.r_max + 1
)
a_rel_pos = one_hot(d_residue, 2 * self.r_max + 2)
d_token = torch.clip(
feats["token_index"][:, :, None]
- feats["token_index"][:, None, :]
+ self.r_max,
0,
2 * self.r_max,
)
d_token = torch.where(
b_same_chain & b_same_residue,
d_token,
torch.zeros_like(d_token) + 2 * self.r_max + 1,
)
a_rel_token = one_hot(d_token, 2 * self.r_max + 2)
d_chain = torch.clip(
feats["sym_id"][:, :, None] - feats["sym_id"][:, None, :] + self.s_max,
0,
2 * self.s_max,
)
        d_chain = torch.where(
            (b_same_chain | (~b_same_entity)) if self.fix_sym_check else b_same_chain,
            torch.zeros_like(d_chain) + 2 * self.s_max + 1,
            d_chain,
        )
        # Note: the | (~b_same_entity) term was added based on an observation in the
        # ProteinX manuscript.
a_rel_chain = one_hot(d_chain, 2 * self.s_max + 2)
p = self.linear_layer(
torch.cat(
[
a_rel_pos.float(),
a_rel_token.float(),
b_same_entity.unsqueeze(-1).float(),
a_rel_chain.float(),
],
dim=-1,
)
)
return p
class SingleConditioning(Module):
"""Algorithm 21."""
def __init__(
self,
sigma_data: float,
token_s: int = 384,
dim_fourier: int = 256,
num_transitions: int = 2,
transition_expansion_factor: int = 2,
eps: float = 1e-20,
disable_times: bool = False,
) -> None:
super().__init__()
self.eps = eps
self.sigma_data = sigma_data
self.disable_times = disable_times
self.norm_single = nn.LayerNorm(2 * token_s)
self.single_embed = nn.Linear(2 * token_s, 2 * token_s)
if not self.disable_times:
self.fourier_embed = FourierEmbedding(dim_fourier)
self.norm_fourier = nn.LayerNorm(dim_fourier)
self.fourier_to_single = LinearNoBias(dim_fourier, 2 * token_s)
transitions = ModuleList([])
for _ in range(num_transitions):
transition = Transition(
dim=2 * token_s, hidden=transition_expansion_factor * 2 * token_s
)
transitions.append(transition)
self.transitions = transitions
def forward(
self,
times, # Float[' b'],
s_trunk, # Float['b n ts'],
s_inputs, # Float['b n ts'],
): # -> Float['b n 2ts']:
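        # Concatenate the trunk and input single representations, embed them, and
        # (unless times are disabled) add a projection of the Fourier time embedding
        # before the residual transition blocks.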
s = torch.cat((s_trunk, s_inputs), dim=-1)
s = self.single_embed(self.norm_single(s))
if not self.disable_times:
fourier_embed = self.fourier_embed(
times
) # note: sigma rescaling done in diffusion module
normed_fourier = self.norm_fourier(fourier_embed)
fourier_to_single = self.fourier_to_single(normed_fourier)
s = rearrange(fourier_to_single, "b d -> b 1 d") + s
for transition in self.transitions:
s = transition(s) + s
        return s, (normed_fourier if not self.disable_times else None)
class PairwiseConditioning(Module):
"""Algorithm 21."""
def __init__(
self,
token_z,
dim_token_rel_pos_feats,
num_transitions=2,
transition_expansion_factor=2,
):
super().__init__()
self.dim_pairwise_init_proj = nn.Sequential(
nn.LayerNorm(token_z + dim_token_rel_pos_feats),
LinearNoBias(token_z + dim_token_rel_pos_feats, token_z),
)
transitions = ModuleList([])
for _ in range(num_transitions):
transition = Transition(
dim=token_z, hidden=transition_expansion_factor * token_z
)
transitions.append(transition)
self.transitions = transitions
def forward(
self,
z_trunk, # Float['b n n tz'],
token_rel_pos_feats, # Float['b n n 3'],
): # -> Float['b n n tz']:
z = torch.cat((z_trunk, token_rel_pos_feats), dim=-1)
z = self.dim_pairwise_init_proj(z)
for transition in self.transitions:
z = transition(z) + z
return z
def get_indexing_matrix(K, W, H, device):
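    # K: number of query windows of W atoms; H: key atoms gathered per window.
    # The returned matrix maps the 2K half-blocks of W // 2 atoms onto K key windows
    # of h = H // (W // 2) half-blocks each, so every query window attends to a local
    # neighbourhood of H atoms approximately centred on it. Shape: (2K, h * K).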
assert W % 2 == 0
assert H % (W // 2) == 0
h = H // (W // 2)
assert h % 2 == 0
arange = torch.arange(2 * K, device=device)
index = ((arange.unsqueeze(0) - arange.unsqueeze(1)) + h // 2).clamp(
min=0, max=h + 1
)
index = index.view(K, 2, 2 * K)[:, 0, :]
onehot = one_hot(index, num_classes=h + 2)[..., 1:-1].transpose(1, 0)
return onehot.reshape(2 * K, h * K).float()
def single_to_keys(single, indexing_matrix, W, H):
B, N, D = single.shape
K = N // W
single = single.view(B, 2 * K, W // 2, D)
return torch.einsum("b j i d, j k -> b k i d", single, indexing_matrix).reshape(
B, K, H, D
) # j = 2K, i = W//2, k = h * K
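# Illustrative windowing arithmetic (example numbers only, not taken from a config):
# with W = 32 query atoms per window and H = 128 key atoms per window, each key
# window is built from h = H // (W // 2) = 8 half-blocks of 16 atoms. For N = 256
# atoms there are K = N // W = 8 query windows, get_indexing_matrix returns a
# (2K, h * K) = (16, 64) gather matrix, and single_to_keys maps a (B, 256, D) tensor
# to (B, 8, 128, D): each window of 32 query atoms attends to the 128 key atoms in
# the half-blocks roughly centred on it (missing neighbours at the sequence ends
# are gathered as zeros and masked downstream).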
class AtomEncoder(Module):
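    """Embed per-atom reference features and build the windowed atom-pair
    representation that conditions the atom transformer (cf. AF3 Algorithm 5).
    """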
def __init__(
self,
atom_s,
atom_z,
token_s,
token_z,
atoms_per_window_queries,
atoms_per_window_keys,
atom_feature_dim,
structure_prediction=True,
use_no_atom_char=False,
use_atom_backbone_feat=False,
use_residue_feats_atoms=False,
):
super().__init__()
self.embed_atom_features = Linear(atom_feature_dim, atom_s)
self.embed_atompair_ref_pos = LinearNoBias(3, atom_z)
self.embed_atompair_ref_dist = LinearNoBias(1, atom_z)
self.embed_atompair_mask = LinearNoBias(1, atom_z)
self.atoms_per_window_queries = atoms_per_window_queries
self.atoms_per_window_keys = atoms_per_window_keys
self.use_no_atom_char = use_no_atom_char
self.use_atom_backbone_feat = use_atom_backbone_feat
self.use_residue_feats_atoms = use_residue_feats_atoms
self.structure_prediction = structure_prediction
if structure_prediction:
self.s_to_c_trans = nn.Sequential(
nn.LayerNorm(token_s), LinearNoBias(token_s, atom_s)
)
init.final_init_(self.s_to_c_trans[1].weight)
self.z_to_p_trans = nn.Sequential(
nn.LayerNorm(token_z), LinearNoBias(token_z, atom_z)
)
init.final_init_(self.z_to_p_trans[1].weight)
self.c_to_p_trans_k = nn.Sequential(
nn.ReLU(),
LinearNoBias(atom_s, atom_z),
)
init.final_init_(self.c_to_p_trans_k[1].weight)
self.c_to_p_trans_q = nn.Sequential(
nn.ReLU(),
LinearNoBias(atom_s, atom_z),
)
init.final_init_(self.c_to_p_trans_q[1].weight)
self.p_mlp = nn.Sequential(
nn.ReLU(),
LinearNoBias(atom_z, atom_z),
nn.ReLU(),
LinearNoBias(atom_z, atom_z),
nn.ReLU(),
LinearNoBias(atom_z, atom_z),
)
init.final_init_(self.p_mlp[5].weight)
def forward(
self,
feats,
s_trunk=None, # Float['bm n ts'],
z=None, # Float['bm n n tz'],
):
with torch.autocast("cuda", enabled=False):
B, N, _ = feats["ref_pos"].shape
atom_mask = feats["atom_pad_mask"].bool() # Bool['b m'],
atom_ref_pos = feats["ref_pos"] # Float['b m 3'],
atom_uid = feats["ref_space_uid"] # Long['b m'],
atom_feats = [
atom_ref_pos,
feats["ref_charge"].unsqueeze(-1),
feats["ref_element"],
]
if not self.use_no_atom_char:
atom_feats.append(feats["ref_atom_name_chars"].reshape(B, N, 4 * 64))
if self.use_atom_backbone_feat:
atom_feats.append(feats["atom_backbone_feat"])
if self.use_residue_feats_atoms:
res_feats = torch.cat(
[
feats["res_type"],
feats["modified"].unsqueeze(-1),
one_hot(feats["mol_type"], num_classes=4).float(),
],
dim=-1,
)
atom_to_token = feats["atom_to_token"].float()
atom_res_feats = torch.bmm(atom_to_token, res_feats)
atom_feats.append(atom_res_feats)
atom_feats = torch.cat(atom_feats, dim=-1)
c = self.embed_atom_features(atom_feats)
            # Note: the attention windows are already built here so that the
            # downstream atom attention can reuse them efficiently.
W, H = self.atoms_per_window_queries, self.atoms_per_window_keys
B, N = c.shape[:2]
K = N // W
keys_indexing_matrix = get_indexing_matrix(K, W, H, c.device)
to_keys = partial(
single_to_keys, indexing_matrix=keys_indexing_matrix, W=W, H=H
)
atom_ref_pos_queries = atom_ref_pos.view(B, K, W, 1, 3)
atom_ref_pos_keys = to_keys(atom_ref_pos).view(B, K, 1, H, 3)
d = atom_ref_pos_keys - atom_ref_pos_queries # Float['b k w h 3']
d_norm = torch.sum(d * d, dim=-1, keepdim=True) # Float['b k w h 1']
d_norm = 1 / (
1 + d_norm
) # AF3 feeds in the reciprocal of the distance norm
atom_mask_queries = atom_mask.view(B, K, W, 1)
atom_mask_keys = (
to_keys(atom_mask.unsqueeze(-1).float()).view(B, K, 1, H).bool()
)
atom_uid_queries = atom_uid.view(B, K, W, 1)
atom_uid_keys = (
to_keys(atom_uid.unsqueeze(-1).float()).view(B, K, 1, H).long()
)
v = (
(
atom_mask_queries
& atom_mask_keys
& (atom_uid_queries == atom_uid_keys)
)
.float()
.unsqueeze(-1)
) # Bool['b k w h 1']
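            # v restricts the reference-geometry pair features to pairs of real
            # (non-padded) atoms that share the same reference conformer
            # (matching ref_space_uid).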
p = self.embed_atompair_ref_pos(d) * v
p = p + self.embed_atompair_ref_dist(d_norm) * v
p = p + self.embed_atompair_mask(v) * v
q = c
if self.structure_prediction:
                # Run only in the structure (diffusion) module, not during the initial encoding.
atom_to_token = feats["atom_to_token"].float() # Long['b m n'],
s_to_c = self.s_to_c_trans(s_trunk.float())
s_to_c = torch.bmm(atom_to_token, s_to_c)
c = c + s_to_c.to(c)
atom_to_token_queries = atom_to_token.view(
B, K, W, atom_to_token.shape[-1]
)
atom_to_token_keys = to_keys(atom_to_token)
z_to_p = self.z_to_p_trans(z.float())
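                # Gather the token-pair representation into atom-pair windows: the einsum
                # contracts over the token assignments of the query and key atoms.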
z_to_p = torch.einsum(
"bijd,bwki,bwlj->bwkld",
z_to_p,
atom_to_token_queries,
atom_to_token_keys,
)
p = p + z_to_p.to(p)
p = p + self.c_to_p_trans_q(c.view(B, K, W, 1, c.shape[-1]))
p = p + self.c_to_p_trans_k(to_keys(c).view(B, K, 1, H, c.shape[-1]))
p = p + self.p_mlp(p)
return q, c, p, to_keys
class AtomAttentionEncoder(Module):
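    """Atom attention encoder (AF3 Algorithm 5): refines atom features with a
    windowed atom transformer and aggregates them into token-level features.
    """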
def __init__(
self,
atom_s,
token_s,
atoms_per_window_queries,
atoms_per_window_keys,
atom_encoder_depth=3,
atom_encoder_heads=4,
structure_prediction=True,
activation_checkpointing=False,
transformer_post_layer_norm=False,
):
super().__init__()
self.structure_prediction = structure_prediction
if structure_prediction:
self.r_to_q_trans = LinearNoBias(3, atom_s)
init.final_init_(self.r_to_q_trans.weight)
self.atom_encoder = AtomTransformer(
dim=atom_s,
dim_single_cond=atom_s,
attn_window_queries=atoms_per_window_queries,
attn_window_keys=atoms_per_window_keys,
depth=atom_encoder_depth,
heads=atom_encoder_heads,
activation_checkpointing=activation_checkpointing,
post_layer_norm=transformer_post_layer_norm,
)
self.atom_to_token_trans = nn.Sequential(
LinearNoBias(atom_s, 2 * token_s if structure_prediction else token_s),
nn.ReLU(),
)
def forward(
self,
feats,
q,
c,
atom_enc_bias,
to_keys,
r=None, # Float['bm m 3'],
multiplicity=1,
):
B, N, _ = feats["ref_pos"].shape
atom_mask = feats["atom_pad_mask"].bool() # Bool['b m'],
if self.structure_prediction:
            # Multiplicity matters only here, because each diffusion sample uses
            # different atom positions r.
q = q.repeat_interleave(multiplicity, 0)
r_to_q = self.r_to_q_trans(r)
q = q + r_to_q
c = c.repeat_interleave(multiplicity, 0)
atom_mask = atom_mask.repeat_interleave(multiplicity, 0)
q = self.atom_encoder(
q=q,
mask=atom_mask,
c=c,
bias=atom_enc_bias,
multiplicity=multiplicity,
to_keys=to_keys,
)
with torch.autocast("cuda", enabled=False):
q_to_a = self.atom_to_token_trans(q).float()
atom_to_token = feats["atom_to_token"].float()
atom_to_token = atom_to_token.repeat_interleave(multiplicity, 0)
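            # Mean-pool atom features into their parent tokens: normalise the
            # atom-to-token assignment over atoms, then pool with a batched matmul.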
atom_to_token_mean = atom_to_token / (
atom_to_token.sum(dim=1, keepdim=True) + 1e-6
)
a = torch.bmm(atom_to_token_mean.transpose(1, 2), q_to_a)
a = a.to(q)
return a, q, c, to_keys
class AtomAttentionDecoder(Module):
"""Algorithm 6."""
def __init__(
self,
atom_s,
token_s,
attn_window_queries,
attn_window_keys,
atom_decoder_depth=3,
atom_decoder_heads=4,
activation_checkpointing=False,
transformer_post_layer_norm=False,
):
super().__init__()
self.a_to_q_trans = LinearNoBias(2 * token_s, atom_s)
init.final_init_(self.a_to_q_trans.weight)
self.atom_decoder = AtomTransformer(
dim=atom_s,
dim_single_cond=atom_s,
attn_window_queries=attn_window_queries,
attn_window_keys=attn_window_keys,
depth=atom_decoder_depth,
heads=atom_decoder_heads,
activation_checkpointing=activation_checkpointing,
post_layer_norm=transformer_post_layer_norm,
)
if transformer_post_layer_norm:
self.atom_feat_to_atom_pos_update = LinearNoBias(atom_s, 3)
init.final_init_(self.atom_feat_to_atom_pos_update.weight)
else:
self.atom_feat_to_atom_pos_update = nn.Sequential(
nn.LayerNorm(atom_s), LinearNoBias(atom_s, 3)
)
init.final_init_(self.atom_feat_to_atom_pos_update[1].weight)
def forward(
self,
a, # Float['bm n 2ts'],
q, # Float['bm m as'],
c, # Float['bm m as'],
atom_dec_bias, # Float['bm m m az'],
feats,
to_keys,
multiplicity=1,
):
with torch.autocast("cuda", enabled=False):
atom_to_token = feats["atom_to_token"].float()
atom_to_token = atom_to_token.repeat_interleave(multiplicity, 0)
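            # Project token activations and broadcast them back onto their atoms via
            # the atom-to-token assignment before running the atom decoder.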
a_to_q = self.a_to_q_trans(a.float())
a_to_q = torch.bmm(atom_to_token, a_to_q)
q = q + a_to_q.to(q)
atom_mask = feats["atom_pad_mask"] # Bool['b m'],
atom_mask = atom_mask.repeat_interleave(multiplicity, 0)
q = self.atom_decoder(
q=q,
mask=atom_mask,
c=c,
bias=atom_dec_bias,
multiplicity=multiplicity,
to_keys=to_keys,
)
r_update = self.atom_feat_to_atom_pos_update(q)
return r_update
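

# Minimal smoke test for the standalone pieces above (illustrative sketch only:
# the shapes and feature names follow this file, but the dim/token_z values and the
# toy feature tensors are arbitrary). Because the module uses relative imports, run
# it as a module within its package (e.g. `python -m <package>.vb_modules_encodersv2`).
if __name__ == "__main__":
    torch.manual_seed(0)

    # Fourier time embedding: one scalar time per batch element -> [b, dim].
    fourier = FourierEmbedding(dim=64)
    t_emb = fourier(torch.rand(2))
    assert t_emb.shape == (2, 64)

    # Relative position encoding for a toy complex: two chains of 8 tokens each.
    b, n, token_z = 2, 16, 32
    chain = (torch.arange(n) // 8).expand(b, n)
    feats = {
        "asym_id": chain,
        "entity_id": chain,
        "sym_id": chain,
        "residue_index": (torch.arange(n) % 8).expand(b, n),
        "token_index": torch.arange(n).expand(b, n),
    }
    relpos = RelativePositionEncoder(token_z=token_z)
    p = relpos(feats)
    assert p.shape == (b, n, n, token_z)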