# Source: FlowProt/model/models/proteinflow.py
# (Hugging Face Space "FlowProt" by alibtsd, commit f34af6f: "Deploy FlowProt Docker Space")
"""
ProteinFlow model.

Example model configuration (YAML):

model:
  node_embed_size: 256
  edge_embed_size: 128
  symmetric: False
  node_features:
    c_s: ${model.node_embed_size}
    c_pos_emb: 128
    c_timestep_emb: 128
    embed_diffuse_mask: False
    max_num_res: 2000
    timestep_int: 1000
  edge_features:
    single_bias_transition_n: 2
    c_s: ${model.node_embed_size}
    c_p: ${model.edge_embed_size}
    relpos_k: 64
    use_rbf: True
    num_rbf: 32
    feat_dim: 64
    num_bins: 22
    self_condition: True
  ipa:
    c_s: ${model.node_embed_size}
    c_z: ${model.edge_embed_size}
    c_hidden: 128
    no_heads: 8
    no_qk_points: 8
    no_v_points: 12
    seq_tfmr_num_heads: 4
    seq_tfmr_num_layers: 2
    num_blocks: 6
"""
import torch
from torch import nn
from utils import modelUtils as u
from models import ipa_pytorch
# Translation unit conversion factors (1 nm == 10 angstrom).
NM_TO_ANG_SCALE = 10.0
ANG_TO_NM_SCALE = 1.0 / NM_TO_ANG_SCALE
class ProteinFlow(nn.Module):
    """Predict backbone rigid frames with an invariant-point-attention trunk.

    Residues are embedded by ``NodeEmbedder`` and residue pairs by
    ``EdgeEmbedder``. The model then runs ``ipa.num_blocks`` trunk blocks.
    Each block applies invariant point attention over the current frames, a
    sequence transformer, a node transition and a backbone update that is
    composed onto the frames. Every block except the last also refreshes the
    edge embedding.

    Args:
        model_conf: config exposing ``node_features``, ``edge_features``,
            ``ipa`` and ``edge_embed_size`` (see the module docstring for an
            example layout).
    """

    def __init__(self, model_conf):
        super().__init__()
        self._model_conf = model_conf
        self._ipa_conf = model_conf.ipa

        self.node_embedder = NodeEmbedder(model_conf.node_features)
        self.edge_embedder = EdgeEmbedder(model_conf.edge_features)

        # Attention trunk: per-block sub-modules stored under "<name>_<b>" keys.
        self.trunk = nn.ModuleDict()
        num_blocks = self._ipa_conf.num_blocks
        for b in range(num_blocks):
            self.trunk[f'ipa_{b}'] = ipa_pytorch.InvariantPointAttention(self._ipa_conf)
            self.trunk[f'ipa_ln_{b}'] = nn.LayerNorm(self._ipa_conf.c_s)
            tfmr_in = self._ipa_conf.c_s
            tfmr_layer = torch.nn.TransformerEncoderLayer(
                d_model=tfmr_in,
                nhead=self._ipa_conf.seq_tfmr_num_heads,
                dim_feedforward=tfmr_in,
                batch_first=True,
                dropout=0.0,
                norm_first=False,
            )
            self.trunk[f'seq_tfmr_{b}'] = torch.nn.TransformerEncoder(
                tfmr_layer, self._ipa_conf.seq_tfmr_num_layers, enable_nested_tensor=False
            )
            self.trunk[f'post_tfmr_{b}'] = ipa_pytorch.Linear(
                tfmr_in, self._ipa_conf.c_s, init='final'
            )
            self.trunk[f'node_transition_{b}'] = ipa_pytorch.StructureModuleTransition(
                c=self._ipa_conf.c_s
            )
            self.trunk[f'bb_update_{b}'] = ipa_pytorch.BackboneUpdate(
                self._ipa_conf.c_s, use_rot_updates=True
            )
            # No edge transition after the last block: forward() never reads
            # the edge embedding once the trunk is finished.
            if b < num_blocks - 1:
                self.trunk[f'edge_transition_{b}'] = ipa_pytorch.EdgeTransition(
                    node_embed_size=self._ipa_conf.c_s,
                    edge_embed_in=self._model_conf.edge_embed_size,
                    edge_embed_out=self._model_conf.edge_embed_size,
                )

    # Regular methods instead of lambda attributes: lambdas stored on the
    # module make it unpicklable (e.g. ``torch.save(model)``).
    def rigids_ang_to_nm(self, rigids):
        """Return ``rigids`` with its translations converted from angstrom to nm."""
        return rigids.apply_trans_fn(lambda trans: trans * ANG_TO_NM_SCALE)

    def rigids_nm_to_ang(self, rigids):
        """Return ``rigids`` with its translations converted from nm to angstrom."""
        return rigids.apply_trans_fn(lambda trans: trans * NM_TO_ANG_SCALE)

    def forward(self, input_features):
        """Run the trunk and return predicted frames.

        Args:
            input_features: mapping with ``res_mask`` (a (batch, num_res)
                residue mask), ``t``, ``trans_t`` and ``rotmats_t``.
                ``trans_sc`` (self-conditioning translations) is optional.
                When missing or ``None`` it defaults to zeros shaped like
                ``trans_t``. Translations are treated as angstrom: they are
                converted to nm for the trunk and back afterwards.

        Returns:
            dict with ``pred_trans`` and ``pred_rotmats`` from the final frames.
        """
        node_mask = input_features['res_mask']
        # Pairwise mask: an edge is kept only when both of its residues are.
        edge_mask = node_mask[:, None] * node_mask[:, :, None]
        continuous_t = input_features['t']
        trans_t = input_features['trans_t']
        rotmats_t = input_features['rotmats_t']

        trans_sc = input_features.get('trans_sc')
        if trans_sc is None:
            trans_sc = torch.zeros_like(trans_t)

        init_node_embed = self.node_embedder(continuous_t, node_mask)
        init_edge_embed = self.edge_embedder(
            init_node_embed, trans_t, trans_sc, edge_mask
        )

        curr_rigids = self.rigids_ang_to_nm(u.create_rigid(rotmats_t, trans_t))

        node_mask_col = node_mask[..., None]
        edge_mask_col = edge_mask[..., None]
        # Transformer padding mask (True = ignore this position). `!= 1` is
        # equivalent to the former `(1 - mask).to(torch.bool)` for numeric
        # masks, and also accepts boolean masks, for which `1 - mask` raises.
        # It does not change across blocks, so it is built once.
        padding_mask = node_mask != 1

        init_node_embed = init_node_embed * node_mask_col
        node_embed = init_node_embed * node_mask_col
        edge_embed = init_edge_embed * edge_mask_col

        num_blocks = self._ipa_conf.num_blocks
        for b in range(num_blocks):
            ipa_embed = self.trunk[f'ipa_{b}'](
                node_embed, edge_embed, curr_rigids, node_mask
            )
            ipa_embed *= node_mask_col
            node_embed = self.trunk[f'ipa_ln_{b}'](node_embed + ipa_embed)

            seq_tfmr_out = self.trunk[f'seq_tfmr_{b}'](
                node_embed, src_key_padding_mask=padding_mask
            )
            node_embed = node_embed + self.trunk[f'post_tfmr_{b}'](seq_tfmr_out)
            node_embed = self.trunk[f'node_transition_{b}'](node_embed)
            node_embed = node_embed * node_mask_col

            rigid_update = self.trunk[f'bb_update_{b}'](node_embed * node_mask_col)
            curr_rigids = curr_rigids.compose_q_update_vec(rigid_update, node_mask_col)

            if b < num_blocks - 1:
                edge_embed = self.trunk[f'edge_transition_{b}'](node_embed, edge_embed)
                edge_embed *= edge_mask_col

        curr_rigids = self.rigids_nm_to_ang(curr_rigids)
        return {
            'pred_trans': curr_rigids.get_trans(),
            'pred_rotmats': curr_rigids.get_rots().get_rot_mats(),
        }
class NodeEmbedder(nn.Module):
    """Per-residue embedding built from residue index and diffusion time.

    Config keys read from ``module_cfg`` (the ``node_features`` section):

        c_s: output node embedding size
        c_pos_emb: size of the residue-index embedding
        c_timestep_emb: size of the timestep embedding

    ``embed_diffuse_mask``, ``max_num_res`` and ``timestep_int`` appear in the
    example config but are not read by this module.
    """

    def __init__(self, module_cfg):
        super().__init__()
        self._cfg = module_cfg
        self.c_s = self._cfg.c_s
        self.c_pos_emb = self._cfg.c_pos_emb
        self.c_timestep_emb = self._cfg.c_timestep_emb
        # Projects the concatenated [index, time] features to c_s.
        self.linear = nn.Linear(self.c_pos_emb + self.c_timestep_emb, self.c_s)

    def embed_t(self, timesteps, mask):
        """Broadcast a per-sample timestep embedding over every residue.

        Args:
            timesteps: tensor whose first column (``timesteps[:, 0]``) holds
                one time value per batch element.
            mask: (batch, num_res) residue mask. Masked residues get zeros.

        Returns:
            (batch, num_res, c_timestep_emb) tensor, assuming
            ``u.get_time_embedding`` returns ``c_timestep_emb`` channels.
        """
        # max_positions is hard-coded and independent of the config values.
        timestep_emb = u.get_time_embedding(
            timesteps[:, 0],
            self.c_timestep_emb,
            max_positions=2056
        )
        timestep_emb = timestep_emb[:, None, :].repeat(1, mask.shape[1], 1)
        return timestep_emb * mask.unsqueeze(-1)

    def forward(self, timesteps, mask):
        """Embed residues 0..num_res-1 together with their timestep.

        Args:
            timesteps: see ``embed_t``.
            mask: (batch, num_res) residue mask.

        Returns:
            (batch, num_res, c_s) node embeddings.
        """
        num_batch, num_res = mask.shape[0], mask.shape[1]
        # Create the index range directly on the mask's device instead of
        # allocating on the CPU and copying it over on every call.
        pos = torch.arange(num_res, dtype=torch.float32, device=mask.device)[None]
        pos_emb = u.get_index_embedding(pos, self.c_pos_emb, max_len=2056)
        pos_emb = pos_emb.repeat([num_batch, 1, 1]) * mask.unsqueeze(-1)
        input_features = [pos_emb, self.embed_t(timesteps, mask)]
        return self.linear(torch.cat(input_features, dim=-1))
class EdgeEmbedder(nn.Module):
    """Pair (edge) embedding from node features, residue offsets and distances.

    Config keys read from ``module_cfg`` (the ``edge_features`` section):

        c_s: input node embedding size
        c_p: output edge embedding size
        feat_dim: projection size for node and relative-position features
        num_bins: number of distogram bins

    Other keys in the example config (``relpos_k``, ``use_rbf``, ``num_rbf``,
    ``single_bias_transition_n``, ``self_condition``) are not read here.
    """

    def __init__(self, module_cfg):
        super().__init__()
        self._cfg = module_cfg
        self.c_s = self._cfg.c_s
        self.c_p = self._cfg.c_p
        self.feat_dim = self._cfg.feat_dim

        self.linear_s_p = nn.Linear(self.c_s, self.feat_dim)
        self.linear_relpos = nn.Linear(self.feat_dim, self.feat_dim)

        # [f_i, f_j] (2 * feat_dim) + relpos (feat_dim) + two distograms.
        total_edge_feats = self.feat_dim * 3 + self._cfg.num_bins * 2
        self.edge_embedder = nn.Sequential(
            nn.Linear(total_edge_feats, self.c_p),
            nn.ReLU(),
            nn.Linear(self.c_p, self.c_p),
            nn.ReLU(),
            nn.Linear(self.c_p, self.c_p),
            nn.LayerNorm(self.c_p),
        )

    def embed_relpos(self, pos):
        """Embed signed residue offsets ``pos[:, i] - pos[:, j]`` for all pairs."""
        rel_pos = pos[:, :, None] - pos[:, None, :]
        pos_emb = u.get_index_embedding(rel_pos, self._cfg.feat_dim, max_len=2056)
        return self.linear_relpos(pos_emb)

    def _cross_concat(self, feats_1d, num_batch, num_res):
        """Build pair features whose entry [b, i, j] is [feats[b, i], feats[b, j]].

        Args:
            feats_1d: (num_batch, num_res, F) per-residue features.

        Returns:
            (num_batch, num_res, num_res, 2 * F) float32 tensor.
        """
        # expand() returns broadcast views, so only torch.cat allocates the
        # pair tensor. torch.tile would first allocate two full copies.
        rows = feats_1d[:, :, None, :].expand(num_batch, num_res, num_res, -1)
        cols = feats_1d[:, None, :, :].expand(num_batch, num_res, num_res, -1)
        return torch.cat([rows, cols], dim=-1).float()

    def forward(self, s, t, sc_t, p_mask):
        """Compute masked edge embeddings.

        Args:
            s: (num_batch, num_res, c_s) node embeddings.
            t: translations binned into a distogram.
            sc_t: self-conditioning translations, binned the same way.
            p_mask: pair mask, broadcast over the channel dimension
                (presumably (num_batch, num_res, num_res)).

        Returns:
            (num_batch, num_res, num_res, c_p) edge embeddings.
        """
        num_batch, num_res, _ = s.shape
        p_i = self.linear_s_p(s)
        cross_node_feats = self._cross_concat(p_i, num_batch, num_res)

        pos = torch.arange(num_res, device=s.device).unsqueeze(0).repeat(num_batch, 1)
        relpos_feats = self.embed_relpos(pos)

        dist_feats = u.calc_distogram(
            t, min_bin=1e-3, max_bin=20.0, num_bins=self._cfg.num_bins)
        sc_feats = u.calc_distogram(
            sc_t, min_bin=1e-3, max_bin=20.0, num_bins=self._cfg.num_bins)

        all_edge_feats = torch.cat(
            [cross_node_feats, relpos_feats, dist_feats, sc_feats], dim=-1)
        edge_feats = self.edge_embedder(all_edge_feats)
        edge_feats *= p_mask.unsqueeze(-1)
        return edge_feats