import torch
import torch.nn as nn
import torch.nn.functional as F
from src.modules.pifold_module import *
from torch_scatter import scatter_softmax, scatter_log_softmax


def positional_encoding(x):
    # Standard sinusoidal positional encoding, built on the same device as x.
    batch_size, seq_len, hidden_size = x.size()
    pos = torch.arange(0, seq_len, device=x.device).float().unsqueeze(1).repeat(1, hidden_size // 2)
    div = torch.exp(torch.arange(0, hidden_size, 2, device=x.device).float() * (-torch.log(torch.tensor(10000.0)) / hidden_size))
    sin = torch.sin(pos * div)
    cos = torch.cos(pos * div)
    pos_encoding = torch.cat([sin, cos], dim=-1).unsqueeze(0).repeat(batch_size, 1, 1)
    return pos_encoding


class MSAAttention(nn.Module):
    def __init__(self, hidden_dim) -> None:
        super().__init__()
        self.MSA_Q = nn.Linear(hidden_dim, hidden_dim)
        self.MSA_K = nn.Linear(hidden_dim, hidden_dim)
        self.MSA_V = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, inputs_embeds):
        pos_enc = positional_encoding(inputs_embeds)
        inputs_embeds = inputs_embeds + pos_enc
        query = self.MSA_Q(inputs_embeds)  # shape: [batch, N, hidden_dim]
        key = self.MSA_K(inputs_embeds)    # shape: [batch, N, hidden_dim]
        value = self.MSA_V(inputs_embeds)  # shape: [batch, N, hidden_dim]
        attn_scores = torch.bmm(query, key.transpose(1, 2))
        attn_weights = nn.functional.softmax(attn_scores, dim=2)
        attn_output = torch.bmm(attn_weights, value)
        return attn_output
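
# A minimal shape sketch for the two helpers above (illustrative only; the
# hidden size 128, batch size 2, and length 50 are assumptions, not values
# taken from the training code):
#
#     attn = MSAAttention(hidden_dim=128)
#     x = torch.randn(2, 50, 128)   # [batch, length, hidden]
#     y = attn(x)                   # -> [2, 50, 128], same shape as the input
#
# Note this is plain full self-attention over positions (scores are not
# scaled by sqrt(hidden_dim)), with the sinusoidal encoding added beforehand.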


class GNNTuning_Model(nn.Module):
    def __init__(self, args, num_encoder_layers, hidden_dim, input_design_dim, input_esm_dim, input_struct_dim=3072, input_esmif_dim=512, dropout=0.1):
        super(GNNTuning_Model, self).__init__()
        self.args = args
        encoder_layers = []
        for i in range(num_encoder_layers):
            encoder_layers.append(
                GeneralGNN(hidden_dim, hidden_dim*2, dropout=dropout, node_net="AttMLP", edge_net="EdgeMLP", node_context=1, edge_context=0),
            )
        self.encoder_layers = nn.Sequential(*encoder_layers)
        from transformers import AutoTokenizer
        from transformers.models.esm.modeling_esm import EsmModel, EsmEmbeddings
        from transformers.models.esm.configuration_esm import EsmConfig
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir="./cache_dir/")
        config = EsmConfig(attention_probs_dropout_prob=0,
                           hidden_size=hidden_dim,
                           intermediate_size=1280,
                           mask_token_id=32,
                           num_attention_heads=12,
                           num_hidden_layers=3,
                           pad_token_id=1,
                           position_embedding_type="rotary",
                           token_dropout=False,
                           vocab_size=33)
        self.DesignEmbed = EsmEmbeddings(config)
        self.ESMEmbed = EsmEmbeddings(config)
        self.EdgeEmbed = nn.Sequential(nn.Linear(416+16+16, 512),
                                       nn.ReLU(),
                                       nn.Linear(512, hidden_dim),
                                       nn.ReLU(),
                                       nn.Linear(hidden_dim, hidden_dim))
        self.DesignConf = nn.Sequential(nn.Linear(1, 128),
                                        nn.ReLU(),
                                        nn.Linear(128, 128),
                                        nn.ReLU(),
                                        nn.Linear(128, 1),
                                        nn.Sigmoid())
        self.ESMConf = nn.Sequential(nn.Linear(1, 128),
                                     nn.ReLU(),
                                     nn.Linear(128, 128),
                                     nn.ReLU(),
                                     nn.Linear(128, 1))
        self.DesignProj = nn.Sequential(nn.Linear(input_design_dim, 512),
                                        nn.ReLU(),
                                        nn.Linear(512, hidden_dim),
                                        nn.ReLU(),
                                        nn.Linear(hidden_dim, hidden_dim))
        self.ESMProj = nn.Sequential(nn.Linear(input_esm_dim, 512),
                                     nn.ReLU(),
                                     nn.Linear(512, hidden_dim),
                                     nn.ReLU(),
                                     nn.Linear(hidden_dim, hidden_dim))
        self.StructProj = nn.Sequential(nn.Linear(input_struct_dim, 512),
                                        nn.ReLU(),
                                        nn.Linear(512, hidden_dim),
                                        nn.ReLU(),
                                        nn.Linear(hidden_dim, hidden_dim))
        self.ESMIFProj = nn.Sequential(nn.Linear(input_esmif_dim, 512),
                                       nn.ReLU(),
                                       nn.Linear(512, hidden_dim),
                                       nn.ReLU(),
                                       nn.Linear(hidden_dim, hidden_dim))
        self.ReadOut = nn.Linear(hidden_dim, 33)
        # self.TimeEmbed = nn.Embedding(20, hidden_dim)
        # self.ProbEmbed = nn.Sequential(nn.Linear(33, 512),
        #                                nn.ReLU(),
        #                                nn.Linear(512, hidden_dim),
        #                                nn.ReLU(),
        #                                nn.Linear(hidden_dim, hidden_dim))
        self.MLP1 = nn.Sequential(nn.Linear(1, 512),
                                  nn.ReLU(),
                                  nn.Linear(512, hidden_dim),
                                  nn.ReLU(),
                                  nn.Linear(hidden_dim, 1),
                                  nn.Sigmoid())
        self.MLP2 = nn.Sequential(nn.Linear(1, 512),
                                  nn.ReLU(),
                                  nn.Linear(512, hidden_dim),
                                  nn.ReLU(),
                                  nn.Linear(hidden_dim, 1),
                                  nn.Sigmoid())
    # def embed_gnn(self, pretrain_gnn, mask_select_id, mask_select_feat):
    #     gnn_embed = self.DesignEmbed(mask_select_id(pretrain_gnn['pred_ids'])).squeeze()
    #     gnn_conf = self.DesignConf(mask_select_id(pretrain_gnn['confs']))
    #     gnn_proj = self.DesignProj(mask_select_feat(pretrain_gnn['embeds']))
    #     if self.args.use_confembed:
    #         return gnn_embed*F.sigmoid(gnn_conf) + gnn_proj
    #     else:
    #         return gnn_embed + gnn_proj

    # def embed_esm(self, pretrain_esm, mask_select_id, mask_select_feat):
    #     esm_embed = self.ESMEmbed(mask_select_id(pretrain_esm['pred_ids'])).squeeze()
    #     esm_conf = self.ESMConf(mask_select_id(pretrain_esm['confs']))
    #     esm_proj = self.ESMProj(mask_select_feat(pretrain_esm['embeds']))
    #     if self.args.use_confembed:
    #         return esm_embed*F.sigmoid(esm_conf) + esm_proj
    #     else:
    #         return esm_embed + esm_proj

    # def embed_struct(self, pretrain_struct, mask_select_feat):
    #     struct_proj = self.StructProj(mask_select_feat(pretrain_struct['embeds']))
    #     return struct_proj

    # def embed_esmif(self, pretrain_esmif, mask_select_feat):
    #     struct_proj = self.ESMIFProj(mask_select_feat(pretrain_esmif['embeds']))
    #     return struct_proj

    def fuse(self, mask_select_feat, mask_select_id, gnn_embed=None, esm_embed=None, gearnet_embed=None, esmif_embed=None, gnn_pred_id=None, esm_pred_id=None, confidence=None, confidence_esm=None):
        # Project each available pretrained representation into the shared
        # hidden space; missing sources contribute zero.
        gnn, esm, gearnet, esmif, conf, esm_conf = 0, 0, 0, 0, 1.0, 1.0
        if gnn_embed is not None:
            gnn = self.DesignProj(mask_select_feat(gnn_embed))
            gnn += self.DesignEmbed(mask_select_id(gnn_pred_id)).squeeze()
        if esm_embed is not None:
            esm = self.ESMProj(mask_select_feat(esm_embed))
            esm += self.ESMEmbed(mask_select_id(esm_pred_id)).squeeze()
        if gearnet_embed is not None:
            gearnet = self.StructProj(mask_select_feat(gearnet_embed))
        if esmif_embed is not None:
            esmif = self.ESMIFProj(mask_select_feat(esmif_embed))
        if confidence is not None:
            conf = self.DesignConf(mask_select_id(confidence))
            esm_conf = self.ESMConf(mask_select_id(confidence_esm))
        return gnn*conf + esm*esm_conf + gearnet + esmif
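
    # Illustrative note on the gating above (numbers are made up, not taken
    # from the code or data): if the design branch gets conf = 0.9 and the ESM
    # branch gets esm_conf = 0.2 at some residue, the fused embedding is
    # roughly 0.9*gnn + 0.2*esm + gearnet + esmif, i.e. each per-residue
    # confidence scales only its own source before the sum.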

    def forward(self, batch):
        pretrain_design, h_E_raw, E_idx, mask_attend, batch_id = batch['pretrain_design'], batch['h_E'], batch['E_idx'], batch['attention_mask'], batch['batch_id']
        if self.args.use_LM:
            pretrain_esm_msa = batch['pretrain_esm_msa']
        if self.args.use_gearnet:
            pretrain_struct = batch['pretrain_struct']
        if self.args.use_esmif:
            pretrain_esmif = batch['pretrain_esmif']
        # Select only valid (unpadded) positions: ids/confidences as [M, 1],
        # features as [M, feat_dim], where M is the number of valid nodes.
        mask_select_id = lambda x: torch.masked_select(x, mask_attend.bool()).reshape(-1, 1)
        mask_select_feat = lambda x: torch.masked_select(x, mask_attend.bool().unsqueeze(-1)).reshape(-1, x.shape[-1])
        # Fuse the pretrained sources for each MSA sample and accumulate.
        inputs_embeds = 0
        for i in range(self.args.msa_n):
            gnn_embed = pretrain_design['embeds']
            esm_embed = pretrain_esm_msa['embeds'][i] if self.args.use_LM else None
            gearnet_embed = pretrain_struct['embeds'][i] if self.args.use_gearnet else None
            esmif_embed = pretrain_esmif['embeds'] if self.args.use_esmif else None
            confidence = pretrain_design['confs']
            confidence_esm = pretrain_esm_msa['confs'][i]
            inputs_embeds += self.fuse(mask_select_feat, mask_select_id, gnn_embed, esm_embed, gearnet_embed, esmif_embed, pretrain_design['pred_ids'], pretrain_esm_msa['pred_ids'][i], confidence, confidence_esm)
        h_V = inputs_embeds
        h_E = self.EdgeEmbed(h_E_raw)
        for layer in self.encoder_layers:
            h_V, h_E = layer(h_V, h_E, E_idx, batch_id)
        logits = self.ReadOut(h_V)
        # Confidence update: re-weight the refined and fused embeddings by how
        # much the predicted confidence changed, then read out again.
        old_confs = mask_select_id(pretrain_design['confs'])
        confs = torch.softmax(logits, dim=-1).max(dim=-1)[0][:, None]
        h_V = h_V*self.MLP1(confs-old_confs) + inputs_embeds*self.MLP2(old_confs-confs)
        logits = self.ReadOut(h_V)
        # Scatter per-node logits back into a padded [B, N, vocab] tensor.
        B, N = pretrain_design['confs'].shape
        vocab_size = logits.shape[-1]
        new_logits = torch.zeros(B, N, vocab_size, device=logits.device).reshape(B*N, vocab_size)
        new_logits = new_logits.masked_scatter_(mask_attend.bool().view(-1, 1), logits)
        new_logits = new_logits.reshape(B, N, vocab_size)
        log_probs = torch.log_softmax(new_logits, dim=-1)
        device = logits.device
        seqs, confs, embeds, probs2 = self.to_matrix(h_V, logits, batch_id)
        ret = {"pred_ids": seqs['input_ids'].to(device),
               "confs": confs,
               "embeds": embeds,
               "probs": probs2,
               "attention_mask": seqs['attention_mask'].to(device),
               "h_E": h_E_raw,
               "E_idx": E_idx,
               "batch_id": batch_id,
               "log_probs": log_probs}
        return ret

    def to_matrix(self, h_V, logits, batch_id):
        # Convert flat per-node predictions back into padded per-sample matrices.
        probs = F.softmax(logits, dim=-1)
        conf, pred_id = probs.max(dim=-1)
        maxL = 0
        for b in batch_id.unique():
            mask = batch_id == b
            L = int(mask.sum())
            if L > maxL:
                maxL = L
        confs = []
        seqs = []
        embeds = []
        probs2 = []
        for b in batch_id.unique():
            mask = batch_id == b
            # elements = [alphabet[int(id)] for id in pred_id[mask]]
            elements = self.tokenizer.decode(pred_id[mask]).split(" ")
            seqs.append(elements)
            confs.append(conf[mask])
            embeds.append(h_V[mask])
            probs2.append(probs[mask])
        seqs = self.tokenizer(["".join(one) for one in seqs], padding=True, truncation=True, return_tensors='pt', add_special_tokens=False)
        confs = torch.stack([F.pad(one, (0, maxL-len(one))) for one in confs])
        embeds = torch.stack([F.pad(one, (0, 0, 0, maxL-len(one))) for one in embeds])
        probs2 = torch.stack([F.pad(one, (0, 0, 0, maxL-len(one)), value=1/33) for one in probs2])
        return seqs, confs, embeds, probs2
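

if __name__ == "__main__":
    # Minimal smoke test (a sketch with made-up shapes; it only runs where the
    # imports at the top of this file resolve, and it does not build
    # GNNTuning_Model, which needs the full pretrain_* batch inputs).
    x = torch.randn(2, 16, 128)                       # [batch, length, hidden]
    attn = MSAAttention(hidden_dim=128)
    print(attn(x).shape)                              # torch.Size([2, 16, 128])

    # Same masked_scatter_ pattern as in forward(): scatter per-node logits
    # (valid positions only) back into a padded [B, N, vocab] tensor.
    B, N, V = 2, 5, 33
    mask = torch.tensor([[1, 1, 1, 0, 0],
                         [1, 1, 1, 1, 1]], dtype=torch.bool)
    node_logits = torch.randn(int(mask.sum()), V)     # one row per valid node
    padded = torch.zeros(B*N, V).masked_scatter_(mask.view(-1, 1), node_logits)
    print(padded.reshape(B, N, V).shape)              # torch.Size([2, 5, 33])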