import time
import copy

import torch
import torch.nn as nn
from torch_scatter import scatter_sum

from .MemoryTuning import MemoTuning
from .MemoryESM import MemoESM
from .MemoryPiFold import MemoPiFold_model
from .MemoryESMIF import MemoESMIF


def beam_search(post, k):
    """Beam search decoder.

    Parameters:
        post (Tensor): the posterior of the network.
        k (int): beam size of the decoder.

    Outputs:
        indices (Tensor): a beam of index sequences.
        log_prob (Tensor): a beam of log likelihoods of the sequences.

    Shape:
        post: (batch_size, seq_length, vocab_size).
        indices: (batch_size, beam_size, seq_length).
        log_prob: (batch_size, beam_size).

    Examples:
        >>> post = torch.softmax(torch.randn([32, 20, 1000]), -1)
        >>> indices, log_prob = beam_search(post, 3)
    """
    batch_size, seq_length, token_size = post.shape
    log_post = post.log()
    log_prob, indices = log_post[:, 0, :].topk(k, sorted=True)
    indices = indices.unsqueeze(-1)
    for i in range(1, seq_length):
        # Score every (beam, token) extension: [batch, k, token_size]
        log_prob = log_prob.unsqueeze(-1) + log_post[:, i, :].unsqueeze(1).repeat(1, k, 1)
        log_prob, index = log_prob.view(batch_size, -1).topk(k, sorted=True)
        beam_id = torch.div(index, token_size, rounding_mode='floor')  # which beam each candidate extends
        token_id = index % token_size                                  # the token appended to that beam
        # Reorder the kept prefixes so each new token is attached to its own ancestor beam.
        indices = torch.gather(indices, 1, beam_id.unsqueeze(-1).expand(-1, -1, indices.shape[-1]))
        indices = torch.cat([indices, token_id.unsqueeze(-1)], dim=-1)
    return indices, log_prob


class Design_Model(nn.Module):
    def __init__(self, args, temporature, msa_n, tunning_layers_n, tunning_layers_dim,
                 input_design_dim, input_esm_dim, tunning_dropout,
                 design_model, LM_model, ESMIF_model, param_path=None):
        super(Design_Model, self).__init__()
        self.args = args
        self.temporature = temporature
        self.msa_n = msa_n
        self.design_model = design_model
        self.LM_model = LM_model
        self.ESMIF_model = ESMIF_model
        # self.GNNTuning = GNNTuning_Model(num_encoder_layers=tunning_layers_n, hidden_dim=tunning_layers_dim, input_design_dim=input_design_dim, input_esm_dim=input_esm_dim, dropout=tunning_dropout)
        self.GNNTuning = MemoTuning(args, tunning_layers_n, tunning_layers_dim,
                                    input_design_dim, input_esm_dim, tunning_dropout,
                                    tokenizer=self.LM_model.tokenizer)
        # self.Predictor = nn.Linear(tunning_layers_dim, 21)
        self.conf_max = 0
        self.patience = 0
        self.best_params = None
        self.confidence = []
        if param_path is not None:
            self.GNNTuning.load_state_dict(torch.load(param_path))

    def get_MSA(self, pretrain_design):
        pretrain_gnn_msa = pretrain_design
        B, N = pretrain_gnn_msa['confs'].shape
        probs, pred_ids, confs, attention_mask, titles = [], [], [], [], []
        for m in range(self.msa_n):
            titles.append(pretrain_design['title'])
            # Temperature-scaled token distribution over the 33-token vocabulary.
            probs.append(torch.softmax(pretrain_gnn_msa['probs'] / self.temporature, dim=-1))
            # pred_ids.append(msa_pred_ids[:, m, :])
            pred_ids.append(torch.multinomial(probs[-1].reshape(-1, 33), 1).reshape(B, N))
            # Per-token confidence of the sampled ids.
            confs.append(probs[-1].reshape(-1, 33)[
                torch.arange(pred_ids[-1].reshape(-1).shape[0], device=pred_ids[-1].device),
                pred_ids[-1].reshape(-1)].reshape(B, N))
            attention_mask.append(pretrain_gnn_msa['attention_mask'])

        pretrain_esm_msa = {}
        pretrain_esm_msa['title'] = sum(titles, [])
        pretrain_esm_msa['probs'] = torch.cat(probs, dim=0)
        pretrain_esm_msa['pred_ids'] = torch.cat(pred_ids, dim=0)
        pretrain_esm_msa['confs'] = torch.cat(confs, dim=0)
        pretrain_esm_msa['attention_mask'] = torch.cat(attention_mask, dim=0)
        return pretrain_esm_msa

    def forward(self, batch, design_memory=True, LM_memory=True, Struct_memory=True,
                Tuning_memory=True, ESMIF_memory=True):
        '''
        MemoPiFold: batch_id, title, E_idx, h_V, h_E
        MemoESM: pred_ids, attention_mask, confs
        Tuning:
            pretrain_design - pred_ids, confs, embeds
            pretrain_esm_msa - pred_ids, confs, embeds
            h_E, E_idx, batch_id
        '''
        with torch.no_grad():
            pretrain_design = self.design_model(batch, design_memory)
            if self.args.use_LM:
                # Language model forward: sample an MSA from the design output and refine it with the LM.
                pretrain_msa = self.get_MSA(pretrain_design)
                pretrain_esm_msa = self.LM_model(pretrain_msa, LM_memory)
                B, N = pretrain_design['confs'].shape
                pretrain_esm_msa['embeds'] = pretrain_esm_msa['embeds'].reshape(self.msa_n, B, N, -1)
                pretrain_esm_msa['pred_ids'] = pretrain_esm_msa['pred_ids'].reshape(self.msa_n, B, N)
                pretrain_esm_msa['confs'] = pretrain_esm_msa['confs'].reshape(self.msa_n, B, N)
                pretrain_esm_msa['attention_mask'] = pretrain_esm_msa['attention_mask'].reshape(self.msa_n, B, N)

            if self.args.use_gearnet:
                # Structure model forward (expects a structure encoder bound to self.Struct_model).
                pretrain_msa = self.get_MSA(pretrain_design)
                protein_seqs_msa = self.LM_model.tokenizer.decode(
                    pretrain_msa['pred_ids'][pretrain_msa['attention_mask']],
                    clean_up_tokenization_spaces=False).split(" ")
                protein_coords_msa = batch['position'][:, 1, :].repeat((self.msa_n, 1))
                num_nodes = pretrain_msa['attention_mask'].sum(dim=1)
                msa_id = torch.arange(self.msa_n, device=num_nodes.device).repeat_interleave(pretrain_design['attention_mask'].shape[0])
                pretrain_struct_msa = self.Struct_model(protein_seqs_msa, protein_coords_msa, num_nodes, msa_id, pretrain_msa['title'], Struct_memory)

            if self.args.use_esmif:
                # ESM-IF model forward
                esm_feat = self.ESMIF_model(batch, ESMIF_memory)

        new_batch = {}
        new_batch['title'] = pretrain_design['title']
        new_batch['pretrain_design'] = pretrain_design
        new_batch['h_E'] = batch['h_E']
        new_batch['E_idx'] = batch['E_idx']
        new_batch['batch_id'] = batch['batch_id']
        new_batch['attention_mask'] = pretrain_design['attention_mask']
        if self.args.use_LM:
            new_batch['pretrain_esm_msa'] = pretrain_esm_msa
        if self.args.use_gearnet:
            new_batch['pretrain_struct'] = pretrain_struct_msa
        if self.args.use_esmif:
            new_batch['esm_feat'] = esm_feat

        results = self.GNNTuning(new_batch, Tuning_memory)
        # Per-sample average confidence over valid positions.
        avg_confs = (results['attention_mask'] * results['confs']).sum(dim=1) / results['attention_mask'].sum(dim=1)
        self.confidence.append(avg_confs)
        return results


class KWDesign_model(nn.Module):
    def __init__(self, args):
        super(KWDesign_model, self).__init__()
        self.args = args
        input_design_dim, input_esm_dim = args.input_design_dim, args.input_esm_dim
        tunning_layers_dim = args.tunning_layers_dim
        self.memo_pifold = MemoPiFold_model(args)
        self.memo_esmif = MemoESMIF()
        # if args.load_memory:
        #     memory = torch.load(args.memory_path)
        #     self.memo_pifold = memory['memo_pifold']
        #     self.memo_esmif = memory['memo_esmif']

        # Build recycle_n refinement stages; stage i>1 wraps stage i-1 as its design model.
        for i in range(1, self.args.recycle_n + 1):
            if i == 1:
                self.register_module(
                    f"Design{i}",
                    Design_Model(args, args.temporature, args.msa_n, args.tunning_layers_n,
                                 args.tunning_layers_dim, input_design_dim, input_esm_dim,
                                 args.tunning_dropout, self.memo_pifold, MemoESM(args), self.memo_esmif))
            else:
                self.register_module(
                    f"Design{i}",
                    Design_Model(args, args.temporature, args.msa_n, args.tunning_layers_n,
                                 args.tunning_layers_dim, tunning_layers_dim, input_esm_dim,
                                 args.tunning_dropout, self.get_submodule(f"Design{i-1}"),
                                 MemoESM(args), self.memo_esmif))

    def update(self, batch, node_nums, conf, results, log_probs_mat, threshold, current_batch_id):
        # Freeze samples whose confidence exceeds the threshold and keep refining the rest.
        fix_mask = conf > threshold
        log_probs_mat[current_batch_id[fix_mask]] = results['log_probs'][fix_mask]
        current_batch_id = current_batch_id[conf <= threshold]
        batch_id_old = batch['batch_id']
        batch_id_old2new = torch.zeros_like(batch_id_old) - 1
        batch_id_old2new[current_batch_id] = torch.arange(current_batch_id.shape[0], device=conf.device)
        node_mask = (batch_id_old.view(-1, 1) == current_batch_id).any(dim=1)
        edge_mask = node_mask[batch['E_idx'][0]]
        shift_old = torch.cat([torch.zeros(1, device=node_nums.device), node_nums.cumsum(dim=0)]).long()
        shift_new = torch.cat([torch.zeros(1, device=node_nums.device), node_nums[current_batch_id].cumsum(dim=0)]).long()
        edge_batch_id = batch_id_old[batch['E_idx'][0]]
        # Re-index edges into the compacted batch of still-unfixed samples.
        E_idx = (batch['E_idx'] - shift_old[edge_batch_id] + shift_new[batch_id_old2new[edge_batch_id]])[:, edge_mask]
        new_batch = {"title": [batch['title'][int(idx)] for idx in current_batch_id],
                     "h_V": batch['h_V'][node_mask],
                     "h_E": batch['h_E'][edge_mask],
                     "E_idx": E_idx,
                     "batch_id": batch_id_old2new[batch_id_old[node_mask]],
                     "alphabet": batch["alphabet"],
                     "S": batch["S"],
                     "position": batch["position"]}
        return new_batch, log_probs_mat, current_batch_id

    def forward(self, batch):
        mask_select_feat = lambda x, mask_attend: torch.masked_select(x, mask_attend.bool().unsqueeze(-1)).reshape(-1, x.shape[-1])

        log_probs_list, confs_list = [], []
        for i in range(1, self.args.recycle_n + 1):
            module = self.get_submodule(f"Design{i}")
            if i < self.args.recycle_n:
                results = module(batch, Tuning_memory=True)
            else:
                results = module(batch, Tuning_memory=False)
            log_probs = mask_select_feat(results['log_probs'], results['attention_mask'])
            log_probs_list.append(log_probs)
            confs = mask_select_feat(results['confs'][:, :, None], results['attention_mask'])
            confs_list.append(confs)

        # For every residue, keep the prediction of the recycle stage with the highest confidence.
        max_conf_idx = torch.cat(confs_list, dim=1).argmax(dim=1)
        log_probs_mat = torch.stack(log_probs_list)
        log_probs = log_probs_mat[max_conf_idx, torch.arange(max_conf_idx.shape[0], device=max_conf_idx.device)]

        outputs = {f"log_probs{i+1}": log_probs_list[i] for i in range(len(log_probs_list))}
        outputs["log_probs"] = log_probs
        return outputs
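

# ---------------------------------------------------------------------------
# Minimal usage sketch for `beam_search`, the only routine in this file that
# runs without the pretrained design/LM/ESM-IF modules. The batch size,
# sequence length, and vocabulary size below are illustrative assumptions
# chosen to mirror the shapes in the docstring; running this block directly
# also assumes the module's relative imports resolve (e.g. via `python -m`).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    post = torch.softmax(torch.randn(2, 8, 33), dim=-1)   # (batch, seq_length, vocab_size)
    indices, log_prob = beam_search(post, k=3)
    print(indices.shape)   # expected: torch.Size([2, 3, 8])
    print(log_prob.shape)  # expected: torch.Size([2, 3])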