# NOTE: the lines below were scraped from a Hugging Face Spaces page header
# ("Spaces: Running on Zero") — page chrome, not part of the source module.
| import subprocess | |
| import os | |
| from joblib import Parallel, delayed, cpu_count | |
| from tqdm import tqdm | |
| import pandas as pd | |
| import torch | |
| import torch.nn as nn | |
| from .PretrainESMIF_model import PretrainESMIF_Model | |
| from torch_scatter import scatter_sum | |
class MemoESMIF(nn.Module):
    """Memoizing wrapper around a pretrained ESM-IF structure encoder.

    Per-structure node embeddings are cached on CPU, keyed by the structure's
    title, so structures already seen in earlier batches skip the expensive
    model forward pass and are copied straight into the padded output buffer.
    """

    def __init__(self):
        super().__init__()
        self.PretrainESMIF = PretrainESMIF_Model()
        # title -> {"embeds": FloatTensor [num_nodes, embed_dim]} kept on CPU
        self.memory = {}

    def save_memory(self, path):
        """Persist the embedding cache to *path* (replaces old dead code)."""
        torch.save({"memory": self.memory}, path)

    def load_memory(self, path):
        """Restore the embedding cache previously written by save_memory."""
        self.memory = torch.load(path)["memory"]

    def initoutput(self, B, maxL, device, embed_dim=512):
        """Allocate zeroed per-batch output buffers.

        Args:
            B: number of structures in the batch.
            maxL: maximum node count over the batch (padding length).
            device: device for the output tensor.
            embed_dim: embedding width; defaults to 512, the width the
                cached/pretrained embeddings use (kept as default for
                backward compatibility).
        """
        self.out_embeds = torch.zeros(B, maxL, embed_dim, dtype=torch.float, device=device)
        self.titles = [None] * B

    def retrivel(self, titles, num_nodes, device, use_memory):
        """Copy cached embeddings into self.out_embeds; return uncached indices.

        (Method name spelling kept as-is — external callers may rely on it.)

        Args:
            titles: per-structure identifiers, one per batch element.
            num_nodes: integer tensor of node counts per batch element.
            device: device the cached CPU embeddings are moved to.
            use_memory: when False the cache is ignored and every index
                is reported as unseen.

        Returns:
            List of batch indices whose embeddings must be computed.
        """
        unseen = []
        for idx, name in enumerate(titles):
            if use_memory and name in self.memory:
                memo_embeds = self.memory[name]["embeds"].to(device)
                self.out_embeds[idx, :num_nodes[idx]] = memo_embeds
                self.titles[idx] = name
            else:
                unseen.append(idx)
        return unseen

    def rebatch(self, unseen, batch):
        """Gather positions of uncached structures for a fresh forward pass.

        Keeps only the first 3 atoms per residue (presumably the N/CA/C
        backbone expected by ESM-IF — TODO confirm against the model).
        """
        unseen_position = [
            batch['position'][batch['batch_id'] == i][:, :3, :] for i in unseen
        ]
        return {"position": unseen_position}

    def save2memory(self, unseen, outputs, titles, num_nodes):
        """Detach freshly computed embeddings and store them in the CPU cache."""
        for i, idx in enumerate(unseen):
            name = titles[idx]
            self.titles[idx] = name
            num = num_nodes[idx]
            self.memory[name] = {"embeds": outputs['feat'][i, :num].detach().to('cpu')}

    def update(self, unseen, num_nodes, outputs):
        """Write newly computed embeddings into the padded output buffer."""
        for i, idx in enumerate(unseen):
            num = num_nodes[idx]
            self.out_embeds[idx, :num] = outputs['feat'][i, :num]

    def forward(self, batch, use_memory=False):
        """Embed a batch of structures, reusing cached results when allowed.

        Args:
            batch: dict with 'position' [N, atoms, 3], 'batch_id' [N]
                (structure index per node), and 'title' (list of ids) —
                schema inferred from usage here; verify against caller.
            use_memory: reuse cached embeddings for previously seen titles.

        Returns:
            dict with 'title' (list, len B) and 'embeds' [B, maxL, 512].
        """
        device = batch['position'].device
        # Node count per structure: sum of ones grouped by batch_id.
        num_nodes = scatter_sum(torch.ones_like(batch['batch_id']), batch['batch_id'], dim=0)
        B, maxL = num_nodes.shape[0], num_nodes.max()
        self.initoutput(B, maxL, device)
        unseen = self.retrivel(batch['title'], num_nodes, device, use_memory)
        if len(unseen) > 0:
            # Forward only the structures the cache could not serve.
            new_batch = self.rebatch(unseen, batch)
            outputs = self.PretrainESMIF(new_batch['position'])
            self.save2memory(unseen, outputs, batch['title'], num_nodes)
            self.update(unseen, num_nodes, outputs)
        return {'title': self.titles, 'embeds': self.out_embeds}
| if __name__ == '__main__': | |
| # work_space = '/gaozhangyang/experiments/PiFoldV2/data/mmseq_workspace2' | |
| # target_seqs = ["MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPQTKTYFPHFDLSHGSAQVKGHG", "MVHLTPEEKSAVTALWGKVNVDEVGVEALGRLLVVYPWTQRFFESFGDLSTPDAVMGNPKV", | |
| # "MVLSPADKTNVKAAWGKVGAGGAEALERMFLSFPQKTYYTYFPHFDLSHGSAQVKGHG"] | |
| # query_seqs = ["MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKFPHFDLSHGSAQV", "MVHLTPEEKSAVTALWGKVNVDEVGGGRLLVVYPWTQRFFESFGDLSTPDAV",] | |
| # results = search_seqs(query_seqs, target_seqs, work_space) | |
| # print(results) | |
| import biotite.sequence as seq | |
| import biotite.sequence.align as align | |
| # Create example query and target protein sequences | |
| query_seq1 = seq.ProteinSequence("MSKXXKAFLNKXXL") | |
| target_seq1 = seq.ProteinSequence("MSKVKAALNKVLL") | |
| target_seq2 = seq.ProteinSequence("MSKVKKALNKVLL") | |
| target_seq3 = seq.ProteinSequence("MSTVAAALKMLLL") | |
| results = search_seqs_biotite(["MSKXXKAFLNKXXL"], ["MSKVKAALNKVLL", "MSKVKKALNKVLL", "MSTVAAALKMLLL"]) | |
| # Print the alignments | |
| print("Query alignments:") | |