File size: 4,250 Bytes
7968cb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import subprocess
import os
from joblib import Parallel, delayed, cpu_count
from tqdm import tqdm
import pandas as pd
import torch
import torch.nn as nn
from .PretrainESMIF_model import PretrainESMIF_Model
from torch_scatter import scatter_sum

class MemoESMIF(nn.Module):
    """Wrap ``PretrainESMIF_Model`` with an in-process, per-title embedding cache.

    ``forward`` takes a flattened batch (all samples concatenated along dim 0
    and told apart by ``batch['batch_id']``), runs the wrapped model only on
    the samples that cannot be served from ``self.memory`` and returns a
    zero-padded ``(B, maxL, EMBED_DIM)`` tensor of per-position embeddings.

    The cache lives only in ``self.memory`` (a plain dict keyed by title with
    CPU tensors as values); this class does not persist it to disk.

    ``initoutput`` / ``retrivel`` / ``save2memory`` / ``update`` communicate
    through ``self.out_embeds`` and ``self.titles``, which are re-created on
    every ``forward`` call.
    """

    # Width of the output buffer. It has to equal the last dimension of
    # ``outputs['feat']`` produced by the wrapped model, otherwise the slice
    # assignments in ``update`` fail.
    EMBED_DIM = 512

    def __init__(self):
        super().__init__()
        self.PretrainESMIF = PretrainESMIF_Model()
        # title -> {"embeds": CPU tensor of shape (num_nodes, EMBED_DIM)}
        self.memory = {}

    def initoutput(self, B, maxL, device):
        """Allocate the zero-filled output buffer and the per-sample title list.

        Args:
            B: number of samples in the batch.
            maxL: length of the longest sample (padding length).
            device: device on which the output buffer is allocated.
        """
        self.out_embeds = torch.zeros(
            B, maxL, self.EMBED_DIM, dtype=torch.float, device=device
        )
        self.titles = [None] * B

    def retrivel(self, titles, num_nodes, device, use_memory):
        """Fill the output buffer from the cache and report the cache misses.

        Args:
            titles: one cache key per sample.
            num_nodes: per-sample lengths, indexable by sample index.
            device: device the cached embeddings are moved to.
            use_memory: when false, every sample is reported as unseen.

        Returns:
            List of sample indices that still have to be computed.
        """
        unseen = []
        for idx, name in enumerate(titles):
            cached = self.memory.get(name) if use_memory else None
            if cached is not None:
                length = int(num_nodes[idx])
                embeds = cached['embeds']
                # A cached entry whose length disagrees with the current
                # sample cannot be copied into the buffer; treat it as a
                # miss so it is recomputed (and overwritten) instead of
                # raising a shape-mismatch error.
                if embeds.shape[0] == length:
                    self.out_embeds[idx, :length] = embeds.to(device)
                    self.titles[idx] = name
                    continue
            unseen.append(idx)
        return unseen

    def rebatch(self, unseen, batch):
        """Split the flattened batch into one position tensor per unseen sample.

        Only the first three entries along dim 1 of ``batch['position']`` are
        kept (presumably the backbone atoms expected by the wrapped model --
        TODO confirm against ``PretrainESMIF_Model``).

        Returns:
            ``{"position": [tensor, ...]}`` ordered like ``unseen``.
        """
        batch_id = batch['batch_id']
        position = batch['position']
        return {
            "position": [position[batch_id == i][:, :3, :] for i in unseen]
        }

    def save2memory(self, unseen, outputs, titles, num_nodes):
        """Store freshly computed embeddings in the cache and record titles.

        Row ``i`` of ``outputs['feat']`` belongs to sample ``unseen[i]``; only
        its first ``num_nodes[unseen[i]]`` positions (the unpadded part) are
        cached.
        """
        feat = outputs['feat']
        for row, idx in enumerate(unseen):
            name = titles[idx]
            self.titles[idx] = name
            length = int(num_nodes[idx])
            # copy=True guarantees the cache owns its storage even when the
            # model output already lives on the CPU; a plain .to('cpu') would
            # then return a view that aliases (and keeps alive) the whole
            # padded batch output.
            self.memory[name] = {
                "embeds": feat[row, :length].detach().to('cpu', copy=True)
            }

    def update(self, unseen, num_nodes, outputs):
        """Copy freshly computed embeddings into their output-buffer rows."""
        feat = outputs['feat']
        for row, idx in enumerate(unseen):
            length = int(num_nodes[idx])
            self.out_embeds[idx, :length] = feat[row, :length]

    @torch.no_grad()
    def forward(self, batch, use_memory=False):
        """Return per-sample embeddings, reusing cached ones when allowed.

        Args:
            batch: mapping with the keys ``'position'``, ``'batch_id'`` and
                ``'title'`` as read by this method and its helpers.
            use_memory: read cached embeddings when true. Computed embeddings
                are written to the cache regardless of this flag.

        Returns:
            ``{'title': list of titles, 'embeds': (B, maxL, EMBED_DIM) tensor}``
            where positions past a sample's length are zero.
        """
        device = batch['position'].device
        # Number of entries per sample, obtained by counting each batch id.
        num_nodes = scatter_sum(
            torch.ones_like(batch['batch_id']), batch['batch_id'], dim=0
        )
        # Plain ints for the allocation instead of a 0-dim tensor.
        B, maxL = num_nodes.shape[0], int(num_nodes.max())
        self.initoutput(B, maxL, device)
        unseen = self.retrivel(batch['title'], num_nodes, device, use_memory)

        if unseen:
            # Run the model only on the samples that were not served from cache.
            new_batch = self.rebatch(unseen, batch)
            outputs = self.PretrainESMIF(new_batch['position'])

            self.save2memory(unseen, outputs, batch['title'], num_nodes)
            self.update(unseen, num_nodes, outputs)

        return {'title': self.titles, 'embeds': self.out_embeds}



if __name__ == '__main__': 
    
    # work_space = '/gaozhangyang/experiments/PiFoldV2/data/mmseq_workspace2'
    # target_seqs = ["MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPQTKTYFPHFDLSHGSAQVKGHG", "MVHLTPEEKSAVTALWGKVNVDEVGVEALGRLLVVYPWTQRFFESFGDLSTPDAVMGNPKV",
    #  "MVLSPADKTNVKAAWGKVGAGGAEALERMFLSFPQKTYYTYFPHFDLSHGSAQVKGHG"]

    # query_seqs = ["MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKFPHFDLSHGSAQV", "MVHLTPEEKSAVTALWGKVNVDEVGGGRLLVVYPWTQRFFESFGDLSTPDAV",]
    
    # results = search_seqs(query_seqs, target_seqs, work_space)
    # print(results)
        

    import biotite.sequence as seq
    import biotite.sequence.align as align

    # Create example query and target protein sequences
    query_seq1 = seq.ProteinSequence("MSKXXKAFLNKXXL")
    target_seq1 = seq.ProteinSequence("MSKVKAALNKVLL")
    target_seq2 = seq.ProteinSequence("MSKVKKALNKVLL")
    target_seq3 = seq.ProteinSequence("MSTVAAALKMLLL")

    results = search_seqs_biotite(["MSKXXKAFLNKXXL"], ["MSKVKAALNKVLL", "MSKVKKALNKVLL", "MSTVAAALKMLLL"])

    # Print the alignments
    print("Query alignments:")