import torch
import torch.nn as nn
from rfdiffusion.Embeddings import MSA_emb, Extra_emb, Templ_emb, Recycling
from rfdiffusion.Track_module import IterativeSimulator
from rfdiffusion.AuxiliaryPredictor import DistanceNetwork, MaskedTokenNetwork, ExpResolvedNetwork, LDDTNetwork
from opt_einsum import contract as einsum

class RoseTTAFoldModule(nn.Module):
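    """
    RoseTTAFold-style prediction module as used in RFdiffusion (summary
    inferred from the code in this module): embeds the clustered and full
    MSAs, templates, and recycled features, runs the IterativeSimulator over
    the MSA / pair / structure tracks, and decodes auxiliary heads (distogram,
    masked-token, experimentally-resolved, and pLDDT predictions).
    """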

    def __init__(self,
                 n_extra_block,
                 n_main_block,
                 n_ref_block,
                 d_msa,
                 d_msa_full,
                 d_pair,
                 d_templ,
                 n_head_msa,
                 n_head_pair,
                 n_head_templ,
                 d_hidden,
                 d_hidden_templ,
                 p_drop,
                 d_t1d,
                 d_t2d,
                 T,                   # total number of diffusion timesteps (used by the timestep embedder)
                 use_motif_timestep,  # whether motif residues get a distinct timestep embedding
                 freeze_track_motif,  # whether to freeze structure updates to motif residues
                 SE3_param_full={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32},
                 SE3_param_topk={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32},
                 input_seq_onehot=False,
                 ):
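        """
        Hyperparameter summary (inferred from how the arguments are used in
        this module):
          n_extra_block / n_main_block / n_ref_block: block counts for the
              full-MSA, main, and refinement stages of the IterativeSimulator.
          d_msa / d_msa_full / d_pair / d_templ: feature dims of the clustered
              MSA, full MSA, pair, and template tracks.
          SE3_param_full / SE3_param_topk: SE(3)-transformer configs; the
              state dim is taken from SE3_param_topk['l0_out_features'].
        """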
        super(RoseTTAFoldModule, self).__init__()

        self.freeze_track_motif = freeze_track_motif

        # Input embeddings
        d_state = SE3_param_topk['l0_out_features']
        self.latent_emb = MSA_emb(d_msa=d_msa, d_pair=d_pair, d_state=d_state,
                                  p_drop=p_drop, input_seq_onehot=input_seq_onehot)
        self.full_emb = Extra_emb(d_msa=d_msa_full, d_init=25,
                                  p_drop=p_drop, input_seq_onehot=input_seq_onehot)
        self.templ_emb = Templ_emb(d_pair=d_pair, d_templ=d_templ, d_state=d_state,
                                   n_head=n_head_templ,
                                   d_hidden=d_hidden_templ, p_drop=0.25, d_t1d=d_t1d, d_t2d=d_t2d)

        # Update inputs with outputs from the previous recycle
        self.recycle = Recycling(d_msa=d_msa, d_pair=d_pair, d_state=d_state)

        # Iterative simulator: joint updates of the MSA, pair, and structure tracks
        self.simulator = IterativeSimulator(n_extra_block=n_extra_block,
                                            n_main_block=n_main_block,
                                            n_ref_block=n_ref_block,
                                            d_msa=d_msa, d_msa_full=d_msa_full,
                                            d_pair=d_pair, d_hidden=d_hidden,
                                            n_head_msa=n_head_msa,
                                            n_head_pair=n_head_pair,
                                            SE3_param_full=SE3_param_full,
                                            SE3_param_topk=SE3_param_topk,
                                            p_drop=p_drop)

        # Auxiliary prediction heads
        self.c6d_pred = DistanceNetwork(d_pair, p_drop=p_drop)
        self.aa_pred = MaskedTokenNetwork(d_msa)
        self.lddt_pred = LDDTNetwork(d_state)

        self.exp_pred = ExpResolvedNetwork(d_msa, d_state)

    def forward(self, msa_latent, msa_full, seq, xyz, idx, t,
                t1d=None, t2d=None, xyz_t=None, alpha_t=None,
                msa_prev=None, pair_prev=None, state_prev=None,
                return_raw=False, return_full=False, return_infer=False,
                use_checkpoint=False, motif_mask=None, i_cycle=None, n_cycle=None):
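        """
        One forward pass / recycle. Shape conventions inferred from the
        indexing below (treat as a sketch, not a spec):
          msa_latent: (B, N, L, d) clustered MSA features; msa_full: full MSA.
          seq: sequence features; xyz: (B, L, n_atoms, 3) current coordinates,
              with atom index 1 assumed to be CA; idx: residue index.
          t: diffusion timestep; t1d/t2d/xyz_t/alpha_t: template features.
          msa_prev/pair_prev/state_prev: recycled features from the previous
              call (zeros are substituted on the first cycle).
        Returns raw track features (return_raw), inference outputs with a
        predicted pLDDT (return_infer), or the full training outputs
        (logits, logits_aa, logits_exp, xyz, alpha_s, lddt); return_full is
        accepted but unused in this function.
        """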

        B, N, L = msa_latent.shape[:3]

        # Get input embeddings
        msa_latent, pair, state = self.latent_emb(msa_latent, seq, idx)
        msa_full = self.full_emb(msa_full, seq, idx)

        # Do recycling: on the first cycle there is nothing to recycle, so
        # substitute zeros for the previous-cycle features
        if msa_prev is None:
            msa_prev = torch.zeros_like(msa_latent[:,0])
            pair_prev = torch.zeros_like(pair)
            state_prev = torch.zeros_like(state)
        msa_recycle, pair_recycle, state_recycle = self.recycle(seq, msa_prev, pair_prev, xyz, state_prev)
        msa_latent[:,0] = msa_latent[:,0] + msa_recycle.reshape(B,L,-1)
        pair = pair + pair_recycle
        state = state + state_recycle

        # Add the timestep embedding to the template 1-D features
        # (timestep_embedder is not created in __init__ above; the hasattr
        # check means it is only used when attached elsewhere)
        if hasattr(self, 'timestep_embedder'):
            assert t is not None
            time_emb = self.timestep_embedder(L, t, motif_mask)
            n_tmpl = t1d.shape[1]
            t1d = torch.cat([t1d, time_emb[None,None,...].repeat(1,n_tmpl,1,1)], dim=-1)

        # Add template embeddings to the pair and state tracks
        pair, state = self.templ_emb(t1d, t2d, alpha_t, xyz_t, pair, state, use_checkpoint=use_checkpoint)

        # Run the simulator on the seed structure; optionally freeze the motif
        # residues in the structure track
        is_frozen_residue = motif_mask if self.freeze_track_motif else torch.zeros_like(motif_mask).bool()
        msa, pair, R, T, alpha_s, state = self.simulator(seq, msa_latent, msa_full, pair, xyz[:,:,:3],
                                                         state, idx, use_checkpoint=use_checkpoint,
                                                         motif_mask=is_frozen_residue)

        if return_raw:
            # Apply the last predicted rigid transform (R, T) to the backbone
            # atoms, centered on atom 1 (assumed CA)
            xyz = einsum('bnij,bnaj->bnai', R[-1], xyz[:,:,:3]-xyz[:,:,1].unsqueeze(-2)) + T[-1].unsqueeze(-2)
            return msa[:,0], pair, xyz, state, alpha_s[-1]

        # Auxiliary heads: masked-token (amino-acid) logits and per-residue lDDT
        logits_aa = self.aa_pred(msa)
        lddt = self.lddt_pred(state)

        if return_infer:
            # Apply the last predicted rigid transform (R, T) to get coordinates
            xyz = einsum('bnij,bnaj->bnai', R[-1], xyz[:,:,:3]-xyz[:,:,1].unsqueeze(-2)) + T[-1].unsqueeze(-2)

            # Collapse the binned lDDT logits to a scalar pLDDT per residue:
            # softmax over bins, then the expectation over the bin values
            nbin = lddt.shape[1]
            bin_step = 1.0 / nbin
            lddt_bins = torch.linspace(bin_step, 1.0, nbin, dtype=lddt.dtype, device=lddt.device)
            pred_lddt = nn.Softmax(dim=1)(lddt)
            pred_lddt = torch.sum(lddt_bins[None,:,None]*pred_lddt, dim=1)

            return msa[:,0], pair, xyz, state, alpha_s[-1], logits_aa.permute(0,2,1), pred_lddt

        # Full training outputs: distogram/orientation logits, amino-acid logits,
        # and experimentally-resolved logits
        logits = self.c6d_pred(pair)
        logits_exp = self.exp_pred(msa[:,0], state)

        # Apply the predicted rigid transforms from every block (leading r axis)
        xyz = einsum('rbnij,bnaj->rbnai', R, xyz[:,:,:3]-xyz[:,:,1].unsqueeze(-2)) + T.unsqueeze(-2)

        return logits, logits_aa, logits_exp, xyz, alpha_s, lddt
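
# Illustrative usage sketch (not from the original file; names and shapes are
# assumptions). During inference the module is typically called once per
# recycle, feeding the previous call's MSA/pair/state features back in:
#
#   msa_prev = pair_prev = state_prev = None
#   for i_cycle in range(n_cycle):
#       out = model(msa_latent, msa_full, seq, xyz, idx, t,
#                   t1d=t1d, t2d=t2d, xyz_t=xyz_t, alpha_t=alpha_t,
#                   msa_prev=msa_prev, pair_prev=pair_prev, state_prev=state_prev,
#                   return_infer=True, motif_mask=motif_mask,
#                   i_cycle=i_cycle, n_cycle=n_cycle)
#       msa_prev, pair_prev, xyz, state_prev = out[0], out[1], out[2], out[3]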