File size: 6,183 Bytes
7968cb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import torch
import torch.nn as nn
from .PretrainPiFold_model import PretrainPiFold_Model
from torch_scatter import scatter_sum
import torch.nn.functional as F

class MemoPiFold_model(nn.Module):
    """PiFold wrapper that memoizes per-protein GNN outputs.

    Each protein is keyed by its ``title``. On a cache miss the pretrained
    PiFold GNN is run and its outputs (confidence, node embedding, residue
    probabilities and edge features) are stored in ``self.memory`` on CPU;
    subsequent forward passes for the same protein are served from the cache.
    """

    def __init__(self, args):
        super().__init__()
        self.PretrainPiFold = PretrainPiFold_Model(args)
        # title -> {'conf', 'embed', 'prob', 'h_E'} CPU tensors (unpadded)
        self.memory = {}

    def save_memory(self, path):
        """Save the tuning parameters and the prediction cache to ``path``."""
        # Only the GNNTuning sub-module weights are checkpointed; the frozen
        # pretrained backbone is reconstructed from `args` at init time.
        params = {key: val for key, val in self.state_dict().items() if "GNNTuning" in key}
        torch.save({"params": params, "memory": self.memory}, path)

    def load_memory(self, path):
        """Restore tuning parameters and cache written by :meth:`save_memory`."""
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
        # cached tensors are moved to the right device on retrieval anyway.
        data = torch.load(path, map_location="cpu")
        self.load_state_dict(data['params'], strict=False)
        self.memory = data['memory']

    def initoutput(self, B, max_L, nums, device):
        """Allocate padded output buffers for ``B`` proteins of length <= ``max_L``.

        ``nums`` holds the true (unpadded) length of each protein.
        """
        self.confs = torch.ones(B, max_L, device=device)
        self.embeds = torch.ones(B, max_L, 128, device=device)  # node-embedding dim 128
        self.probs = torch.ones(B, max_L, 33, device=device)    # 33-way residue vocabulary
        # Boolean mask: True on real residues, False on padding.
        self.attention_mask = torch.zeros(B, max_L, dtype=torch.bool, device=device)
        self.titles = [None] * B
        for id, num in enumerate(nums):
            self.attention_mask[id, :num] = True
        self.edge_feats = []  # list of (batch position, per-protein edge features)

    def retrivel(self, batch, nums, batch_uid, device, use_memory):
        """Fill output buffers from the cache; return indices of cache misses.

        NOTE(review): assumes ``batch['title']`` is ordered consistently with
        ``batch_uid``/``nums`` — confirm against the data loader.
        """
        unseen = []
        for idx, name in enumerate(batch['title']):
            if use_memory and (name in self.memory):
                entry = self.memory[name]
                uid = batch_uid[idx]
                self.confs[uid, :nums[idx]] = entry['conf'].to(device)
                self.embeds[uid, :nums[idx]] = entry['embed'].to(device)
                self.probs[uid, :nums[idx]] = entry['prob'].to(device)
                self.edge_feats.append((uid, entry['h_E'].to(device)))
                self.titles[uid] = name
            else:
                unseen.append(idx)
        return unseen

    def rebatch(self, unseen, batch_uid, batch_id, batch, shift, nums, device):
        """Re-pack the cache-miss proteins into a compact, self-contained sub-batch.

        Node/edge features of the unseen proteins are concatenated and the edge
        indices are re-based so they address nodes of the new sub-batch only.
        """
        h_V2, h_E2, E_idx2, batch_id2 = [], [], [], []
        shift2 = [0]  # cumulative node offsets inside the new sub-batch
        idx = 0
        for id in batch_uid:
            if id not in unseen:
                continue
            node_mask = batch_id == id
            edge_mask = batch_id[batch['E_idx'][0]] == id
            h_V2.append(batch['h_V'][node_mask])
            h_E2.append(batch['h_E'][edge_mask])
            new_E_idx = batch['E_idx'][:, edge_mask]
            # Re-base: strip the original batch offset, add the new one.
            new_E_idx = new_E_idx - shift[batch_id[new_E_idx[0]]] + shift2[-1]
            E_idx2.append(new_E_idx)
            batch_id2.append(torch.full((int(node_mask.sum()),), idx, device=device))
            shift2.append(shift2[-1] + nums[id])
            idx += 1

        h_V2 = torch.cat(h_V2)
        h_E2 = torch.cat(h_E2)
        E_idx2 = torch.cat(E_idx2, dim=-1)
        batch_id2 = torch.cat(batch_id2).long()
        return {"h_V": h_V2, 'h_E': h_E2, 'E_idx': E_idx2, 'batch_id': batch_id2}

    def update_save2memory(self, unseen, batch_id2, E_idx2, batch, pretrain_gnn, max_L):
        """Write fresh GNN outputs into the padded buffers and cache them by title."""
        for id in batch_id2.unique():
            edge_mask = batch_id2[E_idx2[0]] == id
            uid = unseen[int(id)]  # position of this protein in the original batch
            title = batch['title'][uid]

            # Right-pad each per-protein tensor up to max_L along the sequence axis.
            conf = pretrain_gnn['confs'][id]
            conf = F.pad(conf, (0, max_L - len(conf)))
            embed = pretrain_gnn['embeds'][id]
            embed = F.pad(embed, (0, 0, 0, max_L - len(embed)))
            prob = pretrain_gnn['probs'][id]
            prob = F.pad(prob, (0, 0, 0, max_L - len(prob)))

            self.edge_feats.append((uid, pretrain_gnn['h_E'][edge_mask]))
            self.confs[uid] = conf
            self.embeds[uid] = embed
            self.probs[uid] = prob
            self.titles[uid] = title

            attn_mask = self.attention_mask[uid]
            # Cache the unpadded tensors on CPU, detached from the autograd graph.
            self.memory[title] = {'conf': conf[attn_mask].detach().to('cpu'),
                                  'embed': embed[attn_mask].detach().to('cpu'),
                                  'prob': prob[attn_mask].detach().to('cpu'),
                                  'h_E': pretrain_gnn['h_E'][edge_mask].detach().to('cpu')}

    @torch.no_grad()
    def forward(self, batch, use_memory=False):
        """Run (or retrieve) PiFold predictions for every protein in ``batch``.

        Returns a dict of padded, batch-major tensors plus titles, edge
        features and the original edge index / batch-id tensors.
        """
        batch_id = batch['batch_id']
        batch_uid = batch_id.unique()
        device = batch_id.device

        nums = scatter_sum(torch.ones_like(batch_id), batch_id)  # residues per protein
        shift = torch.cat([torch.zeros(1, device=device), torch.cumsum(nums, dim=0)]).long()
        max_L, B = nums.max(), batch_uid.shape[0]

        self.initoutput(B, max_L, nums, device)
        unseen = self.retrivel(batch, nums, batch_uid, device, use_memory)

        if len(unseen) > 0:
            # Re-batch the cache misses, run the pretrained GNN, cache results.
            new_batch = self.rebatch(unseen, batch_uid, batch_id, batch, shift, nums, device)
            pretrain_gnn = self.PretrainPiFold(new_batch)
            self.update_save2memory(unseen, pretrain_gnn['batch_id'], pretrain_gnn['E_idx'],
                                    batch, pretrain_gnn, max_L)

        # Restore original batch order of the per-protein edge features.
        self.edge_feats = sorted(self.edge_feats, key=lambda x: x[0])
        self.edge_feats = torch.cat([one[1] for one in self.edge_feats])

        # Argmax over the residue vocabulary; padded positions forced to token id 1.
        pred_ids = self.probs.argmax(dim=-1) * self.attention_mask + (~self.attention_mask) * 1

        return {'title': self.titles,
                'pred_ids': pred_ids,
                'confs': self.confs,
                'embeds': self.embeds,
                'probs': self.probs,
                'attention_mask': self.attention_mask,
                'h_E': self.edge_feats,
                'E_idx': batch['E_idx'],
                'batch_id': batch['batch_id']}

    def _get_features(self, S, score, X, mask, chain_mask, chain_encoding):
        """Delegate feature construction to the wrapped pretrained model."""
        return self.PretrainPiFold._get_features(S, score, X, mask, chain_mask, chain_encoding)