import itertools

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from src.modules.graphtrans_module import gather_nodes, cat_neighbors_nodes
from src.modules.proteinmpnn_module import EncLayer, DecLayer, ProteinFeatures


class ProteinMPNN_Model(nn.Module):
    def __init__(self, args, **kwargs):
        """Graph labeling network"""
        super(ProteinMPNN_Model, self).__init__()

        # Hyperparameters
        self.node_features = args.hidden
        self.edge_features = args.hidden
        self.hidden_dim = args.hidden
        self.k_neighbors = args.k_neighbors
        self.augment_eps = args.augment_eps
        self.vocab = args.vocab
        self.num_letters = args.num_letters
        self.num_encoder_layers = args.num_encoder_layers
        self.num_decoder_layers = args.num_decoder_layers
        self.dropout = args.dropout
        self.proteinmpnn_type = args.proteinmpnn_type
        if args.proteinmpnn_type == 1:
            self.augment_eps = 0.02  # override coordinate-noise level for type-1 models
        self.init_flex_features = args.init_flex_features
        self.use_dynamics = args.use_dynamics

        # Featurization layers
        self.features = ProteinFeatures(self.node_features, self.edge_features,
                                        top_k=self.k_neighbors,
                                        augment_eps=self.augment_eps,
                                        proteinmpnn_type=self.proteinmpnn_type)

        self.W_e = nn.Linear(self.edge_features, self.hidden_dim, bias=True)
        self.W_s = nn.Embedding(self.vocab, self.hidden_dim)

        # TODO: check that the checkpoint path is correctly read from the config
        self.init_pmpnn_weights = args.use_pmpnn_checkpoint
        self.pmpnn_init_weights_path = None if not self.init_pmpnn_weights else args.starting_checkpoint_path

        # Encoder layers
        self.encoder_layers = nn.ModuleList([
            EncLayer(self.hidden_dim, self.hidden_dim * 2, dropout=self.dropout,
                     proteinmpnn_type=self.proteinmpnn_type)
            for _ in range(self.num_encoder_layers)
        ])

        # Decoder layers
        self.decoder_layers = nn.ModuleList([
            DecLayer(self.hidden_dim, self.hidden_dim * 3, dropout=self.dropout)
            for _ in range(self.num_decoder_layers)
        ])
        self.W_out = nn.Linear(self.hidden_dim, self.num_letters, bias=True)

        self._init_params()

    def _init_params(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def _autoregressive_mask(self, E_idx):
        # 1.0 for neighbors whose index precedes the current position.
        N_nodes = E_idx.size(1)
        ii = torch.arange(N_nodes)
        ii = ii.view((1, -1, 1)).to(E_idx.device)
        mask = E_idx - ii < 0
        mask = mask.type(torch.float32)
        return mask

    def _get_features(self, batch):
        # No-op: featurization happens inside forward().
        return batch

    def forward(self, batch, use_input_decoding_order=False, decoding_order=None):
        """Graph-conditioned sequence model"""
        X, S, score, mask, lengths, chain_M, chain_M_pos, residue_idx, chain_encoding_all = (
            batch['X'], batch['S'], batch['score'], batch['mask'], batch['lengths'],
            batch['chain_M'], batch['chain_M_pos'], batch['residue_idx'],
            batch['chain_encoding_all'])
        randn = torch.randn(chain_M.shape, device=X.device)
        chain_M = chain_M * chain_M_pos
        device = X.device

        # Prepare node and edge embeddings
        E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
        if self.init_flex_features:  # and self.use_dynamics:
            # Initialize node features from the ground-truth per-residue
            # flexibility profile: broadcast the scalar [B, L] profile across
            # the hidden dimension to [B, L, H].
            gt_flex = batch['gt_flex']
            # (disabled) an earlier variant predicted gt_flex on the fly with
            # self.flex_model and cached the result per batch title.
            h_V = gt_flex.unsqueeze(-1).expand(-1, -1, E.shape[-1]).clone()
            h_V = torch.nan_to_num(h_V, nan=0.0)
        else:
            h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=E.device)
        h_E = self.W_e(E)

        # Encoder is unmasked self-attention
        mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
        mask_attend = mask.unsqueeze(-1) * mask_attend
        for layer in self.encoder_layers:
            h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)

        # Concatenate sequence embeddings for autoregressive decoder
        h_S = self.W_s(S)
        h_ES = cat_neighbors_nodes(h_S, h_E, E_idx)

        # Build encoder embeddings
        h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
        h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)

        if self.proteinmpnn_type == 4:
            mask_attend = self._autoregressive_mask(E_idx).unsqueeze(-1)
        else:
            chain_M = chain_M * mask  # update chain_M to include missing regions
            if not use_input_decoding_order:
                # Sort keys are small where chain_M == 0 and large where
                # chain_M == 1, so fixed positions decode before designed ones.
                decoding_order = torch.argsort((chain_M + 0.0001) * (torch.abs(randn)))  # [B, L]
            mask_size = E_idx.shape[1]
            permutation_matrix_reverse = torch.nn.functional.one_hot(decoding_order, num_classes=mask_size).float()  # [B, L, L]
            order_mask_backward = torch.einsum(
                'ij, biq, bjp->bqp',
                (1 - torch.triu(torch.ones(mask_size, mask_size, device=device))),
                permutation_matrix_reverse, permutation_matrix_reverse)
            mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
        mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
        mask_bw = mask_1D * mask_attend
        mask_fw = mask_1D * (1. - mask_attend)
        h_EXV_encoder_fw = mask_fw * h_EXV_encoder
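        # mask_bw keeps edges from neighbors that decode before the current
        # position (these carry sequence information), while mask_fw keeps the
        # complementary edges, which fall back to encoder-only features via
        # h_EXV_encoder_fw.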
        for layer in self.decoder_layers:
            # Masked positions attend to encoder information; unmasked positions
            # see both encoder and decoder (sequence) information.
            h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx)
            h_ESV = mask_bw * h_ESV + h_EXV_encoder_fw
            h_V = layer(h_V, h_ESV, mask)

        logits = self.W_out(h_V)
        log_probs = F.log_softmax(logits, dim=-1)
        return {'log_probs': log_probs, 'logits': logits}
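
    # sample() decodes autoregressively: h_V_stack caches the node states of
    # every decoder layer, and at each step only the row for the position
    # currently being decoded is recomputed and scattered back in, so earlier
    # positions are never re-run.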
    def sample(self, X, randn, S_true, chain_mask, chain_encoding_all, residue_idx,
               mask=None, temperature=1.0, omit_AAs_np=None, bias_AAs_np=None,
               chain_M_pos=None, omit_AA_mask=None, pssm_coef=None, pssm_bias=None,
               pssm_multi=None, pssm_log_odds_flag=None, pssm_log_odds_mask=None,
               pssm_bias_flag=None, bias_by_res=None):
        device = X.device

        # Prepare node and edge embeddings
        E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
        h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=device)
        h_E = self.W_e(E)

        # Encoder is unmasked self-attention
        mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
        mask_attend = mask.unsqueeze(-1) * mask_attend
        for layer in self.encoder_layers:
            h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)

        # Decoder uses masked self-attention
        chain_mask = chain_mask * chain_M_pos * mask  # update chain_mask to include missing regions
        decoding_order = torch.argsort((chain_mask + 0.0001) * (torch.abs(randn)))  # fixed positions (chain_mask == 0) get small sort keys and decode first
        mask_size = E_idx.shape[1]
        permutation_matrix_reverse = torch.nn.functional.one_hot(decoding_order, num_classes=mask_size).float()
        order_mask_backward = torch.einsum(
            'ij, biq, bjp->bqp',
            (1 - torch.triu(torch.ones(mask_size, mask_size, device=device))),
            permutation_matrix_reverse, permutation_matrix_reverse)
        mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
        mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
        mask_bw = mask_1D * mask_attend
        mask_fw = mask_1D * (1. - mask_attend)

        N_batch, N_nodes = X.size(0), X.size(1)
        all_probs = torch.zeros((N_batch, N_nodes, self.num_letters), device=device, dtype=torch.float32)
        h_S = torch.zeros_like(h_V, device=device)
        S = torch.zeros((N_batch, N_nodes), dtype=torch.int64, device=device)
        h_V_stack = [h_V] + [torch.zeros_like(h_V, device=device) for _ in range(len(self.decoder_layers))]
        # AA-omission and global-bias terms are disabled in the softmax below
        # (cf. tied_sample, where they are applied).
        omit_AA_mask_flag = omit_AA_mask is not None

        h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
        h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
        h_EXV_encoder_fw = mask_fw * h_EXV_encoder
        for t_ in range(N_nodes):
            t = decoding_order[:, t_]  # [B]
            chain_mask_gathered = torch.gather(chain_mask, 1, t[:, None])  # [B, 1]
            mask_gathered = torch.gather(mask, 1, t[:, None])  # [B, 1]
            if (mask_gathered == 0).all():  # for padded or missing regions only
                S_t = torch.gather(S_true, 1, t[:, None])
            else:
                # Hidden layers
                E_idx_t = torch.gather(E_idx, 1, t[:, None, None].repeat(1, 1, E_idx.shape[-1]))
                h_E_t = torch.gather(h_E, 1, t[:, None, None, None].repeat(1, 1, h_E.shape[-2], h_E.shape[-1]))
                h_ES_t = cat_neighbors_nodes(h_S, h_E_t, E_idx_t)
                h_EXV_encoder_t = torch.gather(h_EXV_encoder_fw, 1, t[:, None, None, None].repeat(1, 1, h_EXV_encoder_fw.shape[-2], h_EXV_encoder_fw.shape[-1]))
                mask_t = torch.gather(mask, 1, t[:, None])
                for l, layer in enumerate(self.decoder_layers):
                    # Updated relational features for future states
                    h_ESV_decoder_t = cat_neighbors_nodes(h_V_stack[l], h_ES_t, E_idx_t)
                    h_V_t = torch.gather(h_V_stack[l], 1, t[:, None, None].repeat(1, 1, h_V_stack[l].shape[-1]))
                    h_ESV_t = torch.gather(mask_bw, 1, t[:, None, None, None].repeat(1, 1, mask_bw.shape[-2], mask_bw.shape[-1])) * h_ESV_decoder_t + h_EXV_encoder_t
                    h_V_stack[l + 1].scatter_(1, t[:, None, None].repeat(1, 1, h_V.shape[-1]),
                                              layer(h_V_t, h_ESV_t, mask_V=mask_t))
                # Sampling step
                h_V_t = torch.gather(h_V_stack[-1], 1, t[:, None, None].repeat(1, 1, h_V_stack[-1].shape[-1]))[:, 0]
                logits = self.W_out(h_V_t) / temperature
                probs = F.softmax(logits, dim=-1)
                if pssm_bias_flag:
                    pssm_coef_gathered = torch.gather(pssm_coef, 1, t[:, None])[:, 0]
                    pssm_bias_gathered = torch.gather(pssm_bias, 1, t[:, None, None].repeat(1, 1, pssm_bias.shape[-1]))[:, 0]
                    probs = (1 - pssm_multi * pssm_coef_gathered[:, None]) * probs + pssm_multi * pssm_coef_gathered[:, None] * pssm_bias_gathered
                if pssm_log_odds_flag:
                    pssm_log_odds_mask_gathered = torch.gather(pssm_log_odds_mask, 1, t[:, None, None].repeat(1, 1, pssm_log_odds_mask.shape[-1]))[:, 0]
                    probs_masked = probs * pssm_log_odds_mask_gathered
                    probs_masked += probs * 0.001
                    probs = probs_masked / torch.sum(probs_masked, dim=-1, keepdim=True)
                if omit_AA_mask_flag:
                    omit_AA_mask_gathered = torch.gather(omit_AA_mask, 1, t[:, None, None].repeat(1, 1, omit_AA_mask.shape[-1]))[:, 0]
                    probs_masked = probs * (1.0 - omit_AA_mask_gathered)
                    probs = probs_masked / torch.sum(probs_masked, dim=-1, keepdim=True)
                # Greedy decoding; use torch.multinomial(probs, 1) for stochastic sampling.
                S_t = probs.argmax(dim=-1, keepdim=True)
                all_probs.scatter_(1, t[:, None, None].repeat(1, 1, self.num_letters),
                                   (chain_mask_gathered[:, :, None] * probs[:, None, :]).float())
            # Keep sampled residues only at designed positions; elsewhere copy
            # the true sequence.
            S_true_gathered = torch.gather(S_true, 1, t[:, None])
            S_t = (S_t * chain_mask_gathered + S_true_gathered * (1.0 - chain_mask_gathered)).long()
            temp1 = self.W_s(S_t)
            h_S.scatter_(1, t[:, None, None].repeat(1, 1, temp1.shape[-1]), temp1)
            S.scatter_(1, t[:, None], S_t)
        output_dict = {"S": S, "probs": all_probs, "decoding_order": decoding_order}
        return output_dict
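
    # tied_sample() decodes sets of tied positions (e.g. symmetric chains)
    # jointly: each set's logits are averaged with per-position tied_beta
    # weights, a single residue type is drawn, and it is written to every
    # member of the set.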
    def tied_sample(self, X, randn, S_true, chain_mask, chain_encoding_all, residue_idx,
                    mask=None, temperature=1.0, omit_AAs_np=None, bias_AAs_np=None,
                    chain_M_pos=None, omit_AA_mask=None, pssm_coef=None, pssm_bias=None,
                    pssm_multi=None, pssm_log_odds_flag=None, pssm_log_odds_mask=None,
                    pssm_bias_flag=None, tied_pos=None, tied_beta=None, bias_by_res=None):
        device = X.device

        # Prepare node and edge embeddings
        E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
        h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=device)
        h_E = self.W_e(E)

        # Encoder is unmasked self-attention
        mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
        mask_attend = mask.unsqueeze(-1) * mask_attend
        for layer in self.encoder_layers:
            h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)

        # Decoder uses masked self-attention
        chain_mask = chain_mask * chain_M_pos * mask  # update chain_mask to include missing regions
        decoding_order = torch.argsort((chain_mask + 0.0001) * (torch.abs(randn)))  # fixed positions (chain_mask == 0) decode first
        # Group tied positions so every member of a tied set decodes together.
        new_decoding_order = []
        for t_dec in list(decoding_order[0].cpu().numpy()):
            if t_dec not in list(itertools.chain(*new_decoding_order)):
                list_a = [item for item in tied_pos if t_dec in item]
                if list_a:
                    new_decoding_order.append(list_a[0])
                else:
                    new_decoding_order.append([t_dec])
        decoding_order = torch.tensor(list(itertools.chain(*new_decoding_order)), device=device)[None, :].repeat(X.shape[0], 1)
        mask_size = E_idx.shape[1]
        permutation_matrix_reverse = torch.nn.functional.one_hot(decoding_order, num_classes=mask_size).float()
        order_mask_backward = torch.einsum(
            'ij, biq, bjp->bqp',
            (1 - torch.triu(torch.ones(mask_size, mask_size, device=device))),
            permutation_matrix_reverse, permutation_matrix_reverse)
        mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
        mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
        mask_bw = mask_1D * mask_attend
        mask_fw = mask_1D * (1. - mask_attend)

        N_batch, N_nodes = X.size(0), X.size(1)
        all_probs = torch.zeros((N_batch, N_nodes, self.num_letters), device=device, dtype=torch.float32)
        h_S = torch.zeros_like(h_V, device=device)
        S = torch.zeros((N_batch, N_nodes), dtype=torch.int64, device=device)
        h_V_stack = [h_V] + [torch.zeros_like(h_V, device=device) for _ in range(len(self.decoder_layers))]
        constant = torch.tensor(omit_AAs_np, device=device)
        constant_bias = torch.tensor(bias_AAs_np, device=device)
        omit_AA_mask_flag = omit_AA_mask is not None

        h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
        h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
        h_EXV_encoder_fw = mask_fw * h_EXV_encoder
        for t_list in new_decoding_order:
            logits = 0.0
            logit_list = []
            done_flag = False
            for t in t_list:
                if (chain_mask[:, t] == 0).all():
                    # Fixed set: copy the true residue to every tied position.
                    S_t = S_true[:, t]
                    for t in t_list:
                        h_S[:, t, :] = self.W_s(S_t)
                        S[:, t] = S_t
                    done_flag = True
                    break
                else:
                    E_idx_t = E_idx[:, t:t + 1, :]
                    h_E_t = h_E[:, t:t + 1, :, :]
                    h_ES_t = cat_neighbors_nodes(h_S, h_E_t, E_idx_t)
                    h_EXV_encoder_t = h_EXV_encoder_fw[:, t:t + 1, :, :]
                    mask_t = mask[:, t:t + 1]
                    for l, layer in enumerate(self.decoder_layers):
                        h_ESV_decoder_t = cat_neighbors_nodes(h_V_stack[l], h_ES_t, E_idx_t)
                        h_V_t = h_V_stack[l][:, t:t + 1, :]
                        h_ESV_t = mask_bw[:, t:t + 1, :, :] * h_ESV_decoder_t + h_EXV_encoder_t
                        h_V_stack[l + 1][:, t, :] = layer(h_V_t, h_ESV_t, mask_V=mask_t).squeeze(1)
                    h_V_t = h_V_stack[-1][:, t, :]
                    logit_list.append((self.W_out(h_V_t) / temperature) / len(t_list))
                    # tied_beta-weighted average of the tied set's logits.
                    logits += tied_beta[t] * (self.W_out(h_V_t) / temperature) / len(t_list)
            if not done_flag:
                bias_by_res_gathered = bias_by_res[:, t, :]  # [B, num_letters]
                probs = F.softmax(logits - constant[None, :] * 1e8 + constant_bias[None, :] / temperature + bias_by_res_gathered / temperature, dim=-1)
                if pssm_bias_flag:
                    pssm_coef_gathered = pssm_coef[:, t]
                    pssm_bias_gathered = pssm_bias[:, t]
                    probs = (1 - pssm_multi * pssm_coef_gathered[:, None]) * probs + pssm_multi * pssm_coef_gathered[:, None] * pssm_bias_gathered
                if pssm_log_odds_flag:
                    pssm_log_odds_mask_gathered = pssm_log_odds_mask[:, t]
                    probs_masked = probs * pssm_log_odds_mask_gathered
                    probs_masked += probs * 0.001
                    probs = probs_masked / torch.sum(probs_masked, dim=-1, keepdim=True)
                if omit_AA_mask_flag:
                    omit_AA_mask_gathered = omit_AA_mask[:, t]
                    probs_masked = probs * (1.0 - omit_AA_mask_gathered)
                    probs = probs_masked / torch.sum(probs_masked, dim=-1, keepdim=True)
                S_t_repeat = torch.multinomial(probs, 1).squeeze(-1)
                for t in t_list:
                    h_S[:, t, :] = self.W_s(S_t_repeat)
                    S[:, t] = S_t_repeat
                    all_probs[:, t, :] = probs.float()
        output_dict = {"S": S, "probs": all_probs, "decoding_order": decoding_order}
        return output_dict
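
    # conditional_probs() scores one designable position at a time: for each
    # index it builds a decoding order in which that position comes last
    # (or first, with backbone_only=True, so it sees backbone context only)
    # and records the resulting log-probabilities at that index.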
    def conditional_probs(self, X, S, mask, chain_M, residue_idx, chain_encoding_all, randn, backbone_only=False):
        """Graph-conditioned sequence model"""
        device = X.device

        # Prepare node and edge embeddings
        E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
        h_V_enc = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=E.device)
        h_E = self.W_e(E)

        # Encoder is unmasked self-attention
        mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
        mask_attend = mask.unsqueeze(-1) * mask_attend
        for layer in self.encoder_layers:
            h_V_enc, h_E = layer(h_V_enc, h_E, E_idx, mask, mask_attend)

        # Concatenate sequence embeddings for autoregressive decoder
        h_S = self.W_s(S)
        h_ES = cat_neighbors_nodes(h_S, h_E, E_idx)

        # Build encoder embeddings
        h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
        h_EXV_encoder = cat_neighbors_nodes(h_V_enc, h_EX_encoder, E_idx)

        chain_M = chain_M * mask  # update chain_M to include missing regions
        chain_M_np = chain_M.cpu().numpy()
        idx_to_loop = np.argwhere(chain_M_np[0, :] == 1)[:, 0]
        log_conditional_probs = torch.zeros([X.shape[0], chain_M.shape[1], self.num_letters], device=device).float()

        for idx in idx_to_loop:
            h_V = torch.clone(h_V_enc)
            if backbone_only:
                # idx decodes first, so it sees backbone (encoder) context only.
                order_mask = torch.ones(chain_M.shape[1], device=device).float()
                order_mask[idx] = 0.
            else:
                # idx decodes last, so it is conditioned on all other positions.
                order_mask = torch.zeros(chain_M.shape[1], device=device).float()
                order_mask[idx] = 1.
            decoding_order = torch.argsort((order_mask[None, :] + 0.0001) * (torch.abs(randn)))
            mask_size = E_idx.shape[1]
            permutation_matrix_reverse = torch.nn.functional.one_hot(decoding_order, num_classes=mask_size).float()
            order_mask_backward = torch.einsum(
                'ij, biq, bjp->bqp',
                (1 - torch.triu(torch.ones(mask_size, mask_size, device=device))),
                permutation_matrix_reverse, permutation_matrix_reverse)
            mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
            mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
            mask_bw = mask_1D * mask_attend
            mask_fw = mask_1D * (1. - mask_attend)

            h_EXV_encoder_fw = mask_fw * h_EXV_encoder
            for layer in self.decoder_layers:
                # Masked positions attend to encoder information; unmasked
                # positions see both encoder and decoder information.
                h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx)
                h_ESV = mask_bw * h_ESV + h_EXV_encoder_fw
                h_V = layer(h_V, h_ESV, mask)

            logits = self.W_out(h_V)
            log_probs = F.log_softmax(logits, dim=-1)
            log_conditional_probs[:, idx, :] = log_probs[:, idx, :]
        return log_conditional_probs

    def unconditional_probs(self, X, mask, residue_idx, chain_encoding_all):
        """Graph-conditioned sequence model"""
        device = X.device

        # Prepare node and edge embeddings
        E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
        h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=E.device)
        h_E = self.W_e(E)

        # Encoder is unmasked self-attention
        mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
        mask_attend = mask.unsqueeze(-1) * mask_attend
        for layer in self.encoder_layers:
            h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)

        # Build encoder embeddings
        h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_V), h_E, E_idx)
        h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)

        # An all-zero backward mask: no position sees any decoded sequence, so
        # every position is predicted from encoder (backbone) context alone.
        order_mask_backward = torch.zeros([X.shape[0], X.shape[1], X.shape[1]], device=device)
        mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
        mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
        mask_bw = mask_1D * mask_attend
        mask_fw = mask_1D * (1. - mask_attend)

        h_EXV_encoder_fw = mask_fw * h_EXV_encoder
        for layer in self.decoder_layers:
            h_V = layer(h_V, h_EXV_encoder_fw, mask)

        logits = self.W_out(h_V)
        log_probs = F.log_softmax(logits, dim=-1)
        return log_probs
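
# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the model). The args fields mirror the
# attributes read in __init__, and the batch keys mirror those unpacked in
# forward(); the concrete values, proteinmpnn_type=0, and the [B, L, 4, 3]
# backbone-atom layout of X are illustrative assumptions only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    B, L = 2, 16  # assumed batch size and sequence length
    args = SimpleNamespace(
        hidden=128, k_neighbors=8, augment_eps=0.0, vocab=21, num_letters=21,
        num_encoder_layers=3, num_decoder_layers=3, dropout=0.1,
        proteinmpnn_type=0, init_flex_features=False, use_dynamics=False,
        use_pmpnn_checkpoint=False, starting_checkpoint_path=None)
    model = ProteinMPNN_Model(args).eval()

    batch = {
        'X': torch.randn(B, L, 4, 3),                    # backbone coords (assumed N/CA/C/O)
        'S': torch.randint(0, args.vocab, (B, L)),       # integer sequence labels
        'score': torch.zeros(B),                         # unused by forward, required key
        'mask': torch.ones(B, L),                        # 1 = valid residue
        'lengths': torch.full((B,), L),                  # unused by forward, required key
        'chain_M': torch.ones(B, L),                     # 1 = position to design
        'chain_M_pos': torch.ones(B, L),
        'residue_idx': torch.arange(L)[None].repeat(B, 1),
        'chain_encoding_all': torch.ones(B, L),
    }
    with torch.no_grad():
        out = model(batch)
    print(out['log_probs'].shape)  # expected: [B, L, num_letters]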