Spaces: Running on Zero
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter_sum

from .PretrainPiFold_model import PretrainPiFold_Model
class MemoPiFold_model(nn.Module):
    """PiFold wrapper that memoizes per-protein predictions.

    Outputs of the pretrained GNN (confidence, node embeddings, residue
    probabilities, edge features) are cached in ``self.memory`` keyed by each
    sample's ``title``, so repeated forward passes over the same proteins can
    skip the expensive pretrained-GNN computation.
    """

    def __init__(self, args):
        super().__init__()
        self.PretrainPiFold = PretrainPiFold_Model(args)
        # title -> {'conf', 'embed', 'prob', 'h_E'} unpadded CPU tensors.
        self.memory = {}

    def save_memory(self, path):
        """Persist the tunable GNN parameters plus the prediction cache."""
        # Only "GNNTuning" sub-module weights are checkpointed; the frozen
        # pretrained backbone is presumably restored from its own checkpoint.
        params = {key: val for key, val in self.state_dict().items()
                  if "GNNTuning" in key}
        torch.save({"params": params, "memory": self.memory}, path)

    def load_memory(self, path):
        """Restore the parameters and cache written by :meth:`save_memory`."""
        # NOTE(review): torch.load unpickles arbitrary objects by default --
        # only load checkpoints from trusted sources.
        data = torch.load(path)
        # strict=False: the checkpoint holds only the "GNNTuning" subset.
        self.load_state_dict(data['params'], strict=False)
        self.memory = data['memory']

    def initoutput(self, B, max_L, nums, device):
        """Allocate padded per-batch output buffers on ``device``.

        Args:
            B: number of samples in the batch.
            max_L: maximum sequence length (padding width).
            nums: per-sample sequence lengths, iterable of length ``B``.
            device: device for the allocated tensors.
        """
        self.confs = torch.ones(B, max_L, device=device)
        # 128-dim embeddings / 33-way probabilities mirror the pretrained
        # PiFold head sizes -- TODO confirm against PretrainPiFold_Model.
        self.embeds = torch.ones(B, max_L, 128, device=device)
        self.probs = torch.ones(B, max_L, 33, device=device)
        # Boolean mask: True on real positions, False on padding.
        self.attention_mask = torch.ones_like(self.confs) == 0
        self.titles = [None] * B
        for row, num in enumerate(nums):
            self.attention_mask[row, :num] = True
        self.edge_feats = []

    def retrivel(self, batch, nums, batch_uid, device, use_memory):
        """Fill the output buffers from the cache.

        Returns the list of batch indices that were NOT found in the cache
        (or all indices when ``use_memory`` is False).
        """
        unseen = []
        for idx, name in enumerate(batch['title']):
            if (name in self.memory) and use_memory:
                entry = self.memory[name]
                uid = batch_uid[idx]
                # Original code wrapped this assignment in a bare try/except
                # whose handler repeated the same line; that retry was a no-op
                # and hid nothing, so it is removed.
                self.confs[uid, :nums[idx]] = entry['conf'].to(device)
                self.embeds[uid, :nums[idx]] = entry['embed'].to(device)
                self.probs[uid, :nums[idx]] = entry['prob'].to(device)
                self.edge_feats.append((uid, entry['h_E'].to(device)))
                self.titles[uid] = name
            else:
                unseen.append(idx)
        return unseen

    def rebatch(self, unseen, batch_uid, batch_id, batch, shift, nums, device):
        """Build a contiguous sub-batch containing only the cache-miss samples.

        Node indices in ``E_idx`` are re-based: the old per-sample global
        offset (``shift``) is stripped and the new sub-batch offset is added.
        """
        h_V2, h_E2, E_idx2, batch_id2 = [], [], [], []
        shift2 = [0]  # running node offsets inside the new sub-batch
        new_uid = 0
        for uid in batch_uid:
            if uid not in unseen:
                continue
            node_mask = batch_id == uid
            # An edge belongs to sample `uid` iff its source node does.
            edge_mask = batch_id[batch['E_idx'][0]] == uid
            h_V2.append(batch['h_V'][node_mask])
            h_E2.append(batch['h_E'][edge_mask])
            new_E_idx = batch['E_idx'][:, edge_mask]
            new_E_idx = new_E_idx - shift[batch_id[new_E_idx[0]]] + shift2[-1]
            E_idx2.append(new_E_idx)
            batch_id2.append(
                torch.ones(node_mask.sum().long(), device=device) * new_uid)
            shift2.append(shift2[-1] + nums[uid])
            new_uid += 1
        return {"h_V": torch.cat(h_V2),
                "h_E": torch.cat(h_E2),
                "E_idx": torch.cat(E_idx2, dim=-1),
                "batch_id": torch.cat(batch_id2).long()}

    def update_save2memory(self, unseen, batch_id2, E_idx2, batch,
                           pretrain_gnn, max_L):
        """Scatter fresh GNN outputs into the padded buffers and cache them."""
        for sub_id in batch_id2.unique():
            edge_mask = batch_id2[E_idx2[0]] == sub_id
            # Map the sub-batch id back to the position in the original batch;
            # rebatch assigned sub-ids in the same order as `unseen`.
            orig_idx = unseen[int(sub_id)]
            title = batch['title'][orig_idx]
            # Right-pad each per-sample output to the batch-wide max length.
            conf = pretrain_gnn['confs'][sub_id]
            conf = F.pad(conf, (0, max_L - len(conf)))
            embed = pretrain_gnn['embeds'][sub_id]
            embed = F.pad(embed, (0, 0, 0, max_L - len(embed)))
            prob = pretrain_gnn['probs'][sub_id]
            prob = F.pad(prob, (0, 0, 0, max_L - len(prob)))
            self.edge_feats.append((orig_idx, pretrain_gnn['h_E'][edge_mask]))
            self.confs[orig_idx] = conf
            self.embeds[orig_idx] = embed
            self.probs[orig_idx] = prob
            self.titles[orig_idx] = title
            attn_mask = self.attention_mask[orig_idx]
            # Cache the unpadded tensors on CPU, detached from the graph.
            self.memory[title] = {
                'conf': conf[attn_mask].detach().to('cpu'),
                'embed': embed[attn_mask].detach().to('cpu'),
                'prob': prob[attn_mask].detach().to('cpu'),
                'h_E': pretrain_gnn['h_E'][edge_mask].detach().to('cpu')}

    def forward(self, batch, use_memory=False):
        """Predict per-residue outputs, serving cached samples from memory.

        Args:
            batch: dict with 'batch_id', 'title', 'h_V', 'h_E', 'E_idx'.
            use_memory: when True, samples whose title is cached are served
                from ``self.memory`` instead of re-running the GNN.

        Returns:
            Dict with padded 'confs'/'embeds'/'probs', boolean
            'attention_mask', greedy 'pred_ids', concatenated edge features
            'h_E', plus pass-through 'E_idx' and 'batch_id'.
        """
        batch_id = batch['batch_id']
        batch_uid = batch_id.unique()
        device = batch_id.device
        # Per-sample node counts and cumulative node offsets.
        nums = scatter_sum(torch.ones_like(batch_id), batch_id)
        shift = torch.cat([torch.zeros(1, device=device),
                           torch.cumsum(nums, dim=0)]).long()
        max_L, B = nums.max(), batch_uid.shape[0]
        self.initoutput(B, max_L, nums, device)
        unseen = self.retrivel(batch, nums, batch_uid, device, use_memory)
        if len(unseen) > 0:
            # Run the pretrained GNN only on the cache misses.
            new_batch = self.rebatch(unseen, batch_uid, batch_id, batch,
                                     shift, nums, device)
            pretrain_gnn = self.PretrainPiFold(new_batch)
            self.update_save2memory(unseen, pretrain_gnn['batch_id'],
                                    pretrain_gnn['E_idx'], batch,
                                    pretrain_gnn, max_L)
        # Restore original sample order before concatenating edge features.
        self.edge_feats = sorted(self.edge_feats, key=lambda x: x[0])
        self.edge_feats = torch.cat([one[1] for one in self.edge_feats])
        # Padding positions are forced to id 1 -- presumably a pad/unknown
        # token; TODO confirm against the tokenizer.
        pred_ids = (self.probs.argmax(dim=-1) * self.attention_mask
                    + (~self.attention_mask) * 1)
        return {'title': self.titles,
                'pred_ids': pred_ids,
                'confs': self.confs,
                'embeds': self.embeds,
                'probs': self.probs,
                'attention_mask': self.attention_mask,
                'h_E': self.edge_feats,
                'E_idx': batch['E_idx'],
                'batch_id': batch['batch_id']}

    def _get_features(self, S, score, X, mask, chain_mask, chain_encoding):
        """Delegate feature construction to the wrapped pretrained model."""
        return self.PretrainPiFold._get_features(S, score, X, mask,
                                                 chain_mask, chain_encoding)