import torch
import torch.nn as nn

from .Tuning import GNNTuning_Model


class MemoTuning(nn.Module):
    def __init__(self, args, tunning_layers_n, tunning_layers_dim, input_design_dim,
                 input_esm_dim, tunning_dropout, tokenizer, fix_memory=False):
        super().__init__()
        self.args = args
        self.tunning_layers_dim = tunning_layers_dim
        self.GNNTuning = GNNTuning_Model(args,
                                         num_encoder_layers=tunning_layers_n,
                                         hidden_dim=tunning_layers_dim,
                                         input_design_dim=input_design_dim,
                                         input_esm_dim=input_esm_dim,
                                         dropout=tunning_dropout)
        self.tokenizer = tokenizer
        self.memory = {}

    def save_param_memory(self, path):
        # Persist both the module parameters and the cached per-sample results.
        torch.save({"params": self.state_dict(), "memory": self.memory}, path)

    def load_param_memory(self, path):
        data = torch.load(path)
        self.load_state_dict(data['params'])
        self.memory = data['memory']

    def get_seqs(self, pred_ids_raw, attention_mask):
        # Decode token ids (restricted to valid positions) back into sequences.
        query_seqs = []
        for pred_ids, mask in zip(pred_ids_raw, attention_mask):
            seq = self.tokenizer.decode(pred_ids[mask], clean_up_tokenization_spaces=False)
            seq = "".join(seq.split(" "))
            query_seqs.append(seq)
        return query_seqs

    def initoutput(self, pretrain_design, B, max_L, device):
        # Initialize padded output buffers for the whole batch.
        self.out_pred_ids = torch.zeros_like(pretrain_design['pred_ids'])
        self.out_confs = torch.zeros_like(pretrain_design['confs'])
        self.out_embeds = torch.zeros(B, max_L, self.tunning_layers_dim, device=device)
        self.out_attention_mask = torch.zeros_like(pretrain_design['attention_mask'])
        self.out_probs = torch.zeros_like(pretrain_design['probs'])
        self.out_log_probs = torch.zeros_like(pretrain_design['probs'])
        self.titles = [None for _ in range(B)]

    def retrieval(self, keys, num_nodes, device, use_memory):
        # Copy cached results into the output buffers and return the indices of
        # samples that still need a forward pass. (A standalone sketch of this
        # memoization pattern follows the class definition.)
        unseen = []
        for idx, key in enumerate(keys):
            if (key in self.memory) and use_memory:
                entry = self.memory[key]
                n = num_nodes[idx]
                self.out_pred_ids[idx, :n] = entry['pred_ids'].to(device)
                self.out_confs[idx, :n] = entry['confs'].to(device)
                self.out_embeds[idx, :n] = entry['embeds'].to(device)
                self.out_attention_mask[idx, :n] = entry['attention_mask'].to(device)
                self.out_probs[idx, :n] = entry['probs'].to(device)
                self.out_log_probs[idx, :n] = entry['log_probs'].to(device)
                self.titles[idx] = key
            else:
                unseen.append(idx)
        return unseen
    def rebatch(self, unseen, batch_id_raw, E_idx_raw, h_E_raw, shift, num_nodes,
                pretrain_design, pretrain_esm_msa, pretrain_struct, pretrain_esmif, device):
        # Build a smaller sub-batch containing only the unseen samples, with
        # edge indices shifted to the sub-batch's compact node numbering.
        # (See the re-indexing sketch after the class definition.)
        unseen_design_pred_ids = []
        unseen_design_confs = []
        unseen_design_embeds = []
        unseen_design_attention_mask = []
        unseen_esm_pred_ids = []
        unseen_esm_confs = []
        unseen_esm_embeds = []
        unseen_esm_attention_mask = []
        unseen_struct_embeds = []
        unseen_esmif_embeds = []
        h_E = []
        E_idx = []
        batch_id = []
        new_shift = 0
        for bid, i in enumerate(unseen):
            # An edge belongs to graph i iff its source node does.
            edge_mask = batch_id_raw[E_idx_raw[0]] == i
            h_E.append(h_E_raw[edge_mask])
            # Remove the old node offset, add the offset in the new sub-batch.
            E_idx.append(E_idx_raw[:, edge_mask] - shift[i] + new_shift)
            batch_id.append(torch.ones(num_nodes[i], device=device).long() * bid)
            new_shift += num_nodes[i]

            unseen_design_pred_ids.append(pretrain_design['pred_ids'][i])
            unseen_design_confs.append(pretrain_design['confs'][i])
            unseen_design_embeds.append(pretrain_design['embeds'][i])
            unseen_design_attention_mask.append(pretrain_design['attention_mask'][i])
            if self.args.use_LM:
                unseen_esm_pred_ids.append(pretrain_esm_msa['pred_ids'][:, i])
                unseen_esm_confs.append(pretrain_esm_msa['confs'][:, i])
                unseen_esm_embeds.append(pretrain_esm_msa['embeds'][:, i])
                unseen_esm_attention_mask.append(pretrain_esm_msa['attention_mask'][:, i])
            if self.args.use_gearnet:
                unseen_struct_embeds.append(pretrain_struct['embeds'][:, i])
            if self.args.use_esmif:
                unseen_esmif_embeds.append(pretrain_esmif['embeds'][i])

        unseen_design_pred_ids = torch.stack(unseen_design_pred_ids)
        unseen_design_confs = torch.stack(unseen_design_confs)
        unseen_design_embeds = torch.stack(unseen_design_embeds)
        unseen_design_attention_mask = torch.stack(unseen_design_attention_mask)
        if self.args.use_LM:
            unseen_esm_pred_ids = torch.stack(unseen_esm_pred_ids, dim=1)
            unseen_esm_confs = torch.stack(unseen_esm_confs, dim=1)
            unseen_esm_embeds = torch.stack(unseen_esm_embeds, dim=1)
            unseen_esm_attention_mask = torch.stack(unseen_esm_attention_mask, dim=1)
        if self.args.use_gearnet:
            unseen_struct_embeds = torch.stack(unseen_struct_embeds, dim=1)
        if self.args.use_esmif:
            unseen_esmif_embeds = torch.stack(unseen_esmif_embeds, dim=0)

        unseen_batch = {
            "pretrain_design": {"pred_ids": unseen_design_pred_ids,
                                "confs": unseen_design_confs,
                                "embeds": unseen_design_embeds,
                                "attention_mask": unseen_design_attention_mask},
            "h_E": torch.cat(h_E),
            "E_idx": torch.cat(E_idx, dim=1),
            "batch_id": torch.cat(batch_id),
            "attention_mask": unseen_design_attention_mask,
        }
        if self.args.use_LM:
            unseen_batch["pretrain_esm_msa"] = {"pred_ids": unseen_esm_pred_ids,
                                                "confs": unseen_esm_confs,
                                                "embeds": unseen_esm_embeds,
                                                "attention_mask": unseen_esm_attention_mask}
        if self.args.use_gearnet:
            unseen_batch["pretrain_struct"] = {"embeds": unseen_struct_embeds}
        if self.args.use_esmif:
            unseen_batch["pretrain_esmif"] = {"embeds": unseen_esmif_embeds}
        return unseen_batch

    def save2memory(self, keys, unseen, num_nodes, unseen_results):
        # Cache the freshly computed (unpadded) results on CPU, keyed by title.
        for i in range(len(unseen)):
            key = keys[unseen[i]]
            num = num_nodes[unseen[i]]
            self.memory[key] = {
                "pred_ids": unseen_results['pred_ids'][i][:num].detach().to('cpu'),
                "confs": unseen_results['confs'][i][:num].detach().to('cpu'),
                "embeds": unseen_results['embeds'][i][:num].detach().to('cpu'),
                "probs": unseen_results['probs'][i][:num].detach().to('cpu'),
                "log_probs": unseen_results['log_probs'][i][:num].detach().to('cpu'),
                "attention_mask": unseen_results['attention_mask'][i][:num].detach().to('cpu'),
            }

    def update(self, unseen, num_nodes, unseen_results, keys):
        # Scatter the freshly computed results back into the padded output buffers.
        for i in range(len(unseen)):
            num = num_nodes[unseen[i]]
            self.out_pred_ids[unseen[i], :num] = unseen_results['pred_ids'][i][:num]
            self.out_confs[unseen[i], :num] = unseen_results['confs'][i][:num]
            self.out_embeds[unseen[i], :num] = unseen_results['embeds'][i][:num]
            self.out_probs[unseen[i], :num] = unseen_results['probs'][i][:num]
            self.out_log_probs[unseen[i], :num] = unseen_results['log_probs'][i][:num]
            self.titles[unseen[i]] = keys[unseen[i]]

    def forward(self, batch, use_memory=False):
        self.use_memory = use_memory
        pretrain_design = batch['pretrain_design']
        h_E_raw = batch['h_E']
        E_idx_raw = batch['E_idx']
        mask_attend = batch['attention_mask']
        batch_id_raw = batch['batch_id']
        device = h_E_raw.device

        pretrain_esm_msa = None
        if self.args.use_LM:
            pretrain_esm_msa = batch['pretrain_esm_msa']
        pretrain_struct = None
        if self.args.use_gearnet:
            pretrain_struct = batch['pretrain_struct']
        pretrain_esmif = None
        if self.args.use_esmif:
            pretrain_esmif = batch['esm_feat']

        num_nodes = batch['attention_mask'].sum(dim=-1)
        # shift[i] is the offset of sample i's first node in the flattened batch.
        shift = torch.cat([torch.zeros(1, device=device), torch.cumsum(num_nodes, dim=0)]).long()
        B, max_L = num_nodes.shape[0], num_nodes.max()
        self.initoutput(pretrain_design, B, max_L, device)

        # keys = list(zip(design_seqs, *lm_seqs))
        keys = batch['title']
        unseen = self.retrieval(keys, num_nodes, device, use_memory)
        if len(unseen) > 0:
            unseen_batch = self.rebatch(unseen, batch_id_raw, E_idx_raw, h_E_raw, shift,
                                        num_nodes, pretrain_design, pretrain_esm_msa,
                                        pretrain_struct, pretrain_esmif, device)
            unseen_results = self.GNNTuning(unseen_batch)
            self.save2memory(keys, unseen, num_nodes, unseen_results)
            self.update(unseen, num_nodes, unseen_results, keys)

        return {'title': self.titles,
                'pred_ids': self.out_pred_ids,
                'confs': self.out_confs,
                'embeds': self.out_embeds,
                'probs': self.out_probs,
                'log_probs': self.out_log_probs,
                'attention_mask': pretrain_design['attention_mask']}
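
# --- Standalone sketch of the title-keyed memoization pattern above. ---
# Illustration only: `toy_memory`, `toy_gnn`, and the random embeddings are
# hypothetical stand-ins for GNNTuning and its outputs, not part of this
# module's API. Copy into a scratch file (or run via `python -m <package>.<module>`,
# since the relative import above prevents direct execution).
if __name__ == "__main__":
    toy_memory = {}

    def toy_gnn(embeds):
        # Stand-in for the expensive GNN forward pass.
        return embeds * 2

    def run_batch(titles, embeds, num_nodes, use_memory=True):
        B, max_L, D = embeds.shape
        out = torch.zeros(B, max_L, D)
        unseen = []
        for idx, key in enumerate(titles):
            if use_memory and key in toy_memory:
                # Cache hit: copy the unpadded cached result into the padded buffer.
                out[idx, :int(num_nodes[idx])] = toy_memory[key]
            else:
                unseen.append(idx)
        if unseen:
            fresh = toy_gnn(embeds[unseen])
            for i, idx in enumerate(unseen):
                n = int(num_nodes[idx])
                # Cache unpadded results on CPU, then scatter back, as
                # save2memory/update do above.
                toy_memory[titles[idx]] = fresh[i, :n].detach().cpu()
                out[idx, :n] = fresh[i, :n]
        return out

    titles = ["1abc", "2xyz"]
    embeds = torch.randn(2, 5, 4)
    num_nodes = torch.tensor([5, 3])
    first = run_batch(titles, embeds, num_nodes)
    second = run_batch(titles, embeds, num_nodes)  # served entirely from memory
    assert torch.allclose(first, second)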
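
# --- Standalone sketch of the edge re-indexing done in `rebatch`. ---
# Illustration only: the tiny two-graph batch below is made up. Nodes of all
# graphs are concatenated, `shift[i]` is the offset of graph i's first node,
# and the edges of each selected graph are moved to a compact new numbering.
if __name__ == "__main__":
    num_nodes = torch.tensor([3, 2])              # graph 0 has 3 nodes, graph 1 has 2
    batch_id_raw = torch.tensor([0, 0, 0, 1, 1])  # graph id of each flattened node
    shift = torch.cat([torch.zeros(1, dtype=torch.long),
                       torch.cumsum(num_nodes, dim=0)])  # [0, 3, 5]
    # Edges in flattened numbering: graph 0 owns nodes 0-2, graph 1 owns nodes 3-4.
    E_idx_raw = torch.tensor([[0, 1, 3],
                              [1, 2, 4]])

    unseen = [1]  # keep only graph 1
    E_idx, new_shift = [], 0
    for bid, i in enumerate(unseen):
        # An edge belongs to graph i iff its source node does.
        edge_mask = batch_id_raw[E_idx_raw[0]] == i
        # Subtract the old offset, add the running offset in the new sub-batch.
        E_idx.append(E_idx_raw[:, edge_mask] - shift[i] + new_shift)
        new_shift += int(num_nodes[i])
    print(torch.cat(E_idx, dim=1))  # tensor([[0], [1]]): node 3 -> 0, node 4 -> 1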