import torch
import torch.nn as nn
from UltraFlow import layers, losses
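
# Models for protein-ligand binding affinity prediction. Each model encodes the
# ligand graph and the protein (pocket) graph, writes the aligned node features
# onto the ligand-protein interaction graph, runs an edge-centric convolution
# over it, and predicts affinity from a graph-level readout, optionally
# conditioned on an embedding of the assay description.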
class IGN_basic(nn.Module):
    def __init__(self, config):
        super(IGN_basic, self).__init__()
        self.config = config
        self.pretrain_assay_mlp_share = config.train.pretrain_assay_mlp_share
        self.pretrain_use_assay_description = config.train.pretrain_use_assay_description
        # A single AttentiveFP-style encoder is shared by the ligand and protein graphs.
        self.graph_conv = layers.ModifiedAttentiveFPGNNV2(config.model.lig_node_dim, config.model.lig_edge_dim, config.model.num_layers, config.model.hidden_dim, config.model.dropout, config.model.jk)
        # With jumping-knowledge 'concat', node features are num_layers * hidden_dim wide.
        if config.model.jk == 'concat':
            self.noncov_graph = layers.DTIConvGraph3Layer_IGN_basic(config.model.hidden_dim * config.model.num_layers + config.model.inter_edge_dim, config.model.inter_out_dim, config.model.dropout)
        else:
            self.noncov_graph = layers.DTIConvGraph3Layer_IGN_basic(config.model.hidden_dim + config.model.inter_edge_dim, config.model.inter_out_dim, config.model.dropout)
        # A multi-head attention readout with 'concat' merging widens the graph embedding.
        if config.model.readout.startswith('multi_head') and config.model.attn_merge == 'concat':
            self.FC = layers.FC(config.model.inter_out_dim * (config.model.num_head + 1), config.model.fc_hidden_dim, config.model.dropout, config.model.out_dim)
        else:
            self.FC = layers.FC(config.model.inter_out_dim * 2, config.model.fc_hidden_dim, config.model.dropout, config.model.out_dim)
        self.readout = layers.ReadsOutLayer(config.model.inter_out_dim, config.model.readout, config.model.num_head, config.model.attn_merge)
        self.softmax = nn.Softmax(dim=1)
        if self.pretrain_use_assay_description:
            print(f'use assay description type: {config.data.assay_des_type}')
            if self.pretrain_assay_mlp_share:
                # A single MLP embeds the assay description for both the regression and ranking paths.
                self.assay_info_aggre_mlp = layers.FC(config.data.assay_des_dim, config.model.assay_des_fc_hidden_dim,
                                                      config.model.dropout, config.model.inter_out_dim * 2)
            else:
                self.assay_info_aggre_mlp_pointwise = layers.FC(config.data.assay_des_dim, config.model.assay_des_fc_hidden_dim,
                                                                config.model.dropout, config.model.inter_out_dim * 2)
                self.assay_info_aggre_mlp_pairwise = layers.FC(config.data.assay_des_dim, config.model.assay_des_fc_hidden_dim,
                                                               config.model.dropout, config.model.inter_out_dim * 2)
    def forward(self, batch):
        bg_lig, bg_prot, bg_inter, labels, _, ass_des = batch
        node_feats_lig = self.graph_conv(bg_lig)
        node_feats_prot = self.graph_conv(bg_prot)
        # Write the batched ligand/protein node features onto the interaction graph.
        bg_inter.ndata['h'] = self.alignfeature(bg_lig, bg_prot, node_feats_lig, node_feats_prot)
        bond_feats_inter = self.noncov_graph(bg_inter)
        graph_embedding = self.readout(bg_inter, bond_feats_inter)
        if self.pretrain_use_assay_description:
            if self.pretrain_assay_mlp_share:
                ranking_assay_embedding = self.assay_info_aggre_mlp(ass_des)
                affinity_pred = self.FC(graph_embedding + ranking_assay_embedding)
            else:
                regression_assay_embedding = self.assay_info_aggre_mlp_pointwise(ass_des)
                affinity_pred = self.FC(graph_embedding + regression_assay_embedding)
                ranking_assay_embedding = self.assay_info_aggre_mlp_pairwise(ass_des)
        else:
            affinity_pred = self.FC(graph_embedding)
            ranking_assay_embedding = torch.zeros(len(affinity_pred))
        return affinity_pred, graph_embedding, ranking_assay_embedding
    def alignfeature(self, bg_lig, bg_prot, node_feats_lig, node_feats_prot):
        # DGL batches graphs by concatenating all ligand nodes, then all protein
        # nodes. The interaction graph instead stores each complex's nodes as
        # [lig_i nodes, prot_i nodes], so re-order the features accordingly.
        inter_feature = torch.cat((node_feats_lig, node_feats_prot))
        lig_num, prot_num = bg_lig.batch_num_nodes(), bg_prot.batch_num_nodes()
        lig_start, prot_start = lig_num.cumsum(0) - lig_num, prot_num.cumsum(0) - prot_num
        inter_start = lig_start + prot_start
        for i in range(lig_num.shape[0]):
            inter_feature[inter_start[i]:inter_start[i] + lig_num[i]] = node_feats_lig[lig_start[i]:lig_start[i] + lig_num[i]]
            inter_feature[inter_start[i] + lig_num[i]:inter_start[i] + lig_num[i] + prot_num[i]] = node_feats_prot[prot_start[i]:prot_start[i] + prot_num[i]]
        return inter_feature
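
# IGN is identical to IGN_basic except that the ligand and protein graphs are
# encoded by separate networks, so the interaction-graph input width doubles.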
class IGN(nn.Module):
    def __init__(self, config):
        super(IGN, self).__init__()
        self.config = config
        self.pretrain_assay_mlp_share = config.train.pretrain_assay_mlp_share
        self.pretrain_use_assay_description = config.train.pretrain_use_assay_description
        self.ligand_conv = layers.ModifiedAttentiveFPGNNV2(config.model.lig_node_dim, config.model.lig_edge_dim, config.model.num_layers, config.model.hidden_dim, config.model.dropout, config.model.jk)
        self.protein_conv = layers.ModifiedAttentiveFPGNNV2(config.model.pro_node_dim, config.model.pro_edge_dim, config.model.num_layers, config.model.hidden_dim, config.model.dropout, config.model.jk)
        if config.model.jk == 'concat':
            # Two encoders, each with num_layers concatenated JK outputs.
            self.noncov_graph = layers.DTIConvGraph3Layer(config.model.hidden_dim * config.model.num_layers * 2 + config.model.inter_edge_dim, config.model.inter_out_dim, config.model.dropout)
        else:
            self.noncov_graph = layers.DTIConvGraph3Layer(config.model.hidden_dim * 2 + config.model.inter_edge_dim, config.model.inter_out_dim, config.model.dropout)
        if config.model.readout.startswith('multi_head') and config.model.attn_merge == 'concat':
            self.FC = layers.FC(config.model.inter_out_dim * (config.model.num_head + 1), config.model.fc_hidden_dim, config.model.dropout, config.model.out_dim)
        else:
            self.FC = layers.FC(config.model.inter_out_dim * 2, config.model.fc_hidden_dim, config.model.dropout, config.model.out_dim)
        self.readout = layers.ReadsOutLayer(config.model.inter_out_dim, config.model.readout, config.model.num_head, config.model.attn_merge)
        self.softmax = nn.Softmax(dim=1)
        if self.pretrain_use_assay_description:
            print(f'use assay description type: {config.data.assay_des_type}')
            if self.pretrain_assay_mlp_share:
                self.assay_info_aggre_mlp = layers.FC(config.data.assay_des_dim, config.model.assay_des_fc_hidden_dim,
                                                      config.model.dropout, config.model.inter_out_dim * 2)
            else:
                self.assay_info_aggre_mlp_pointwise = layers.FC(config.data.assay_des_dim, config.model.assay_des_fc_hidden_dim,
                                                                config.model.dropout, config.model.inter_out_dim * 2)
                self.assay_info_aggre_mlp_pairwise = layers.FC(config.data.assay_des_dim, config.model.assay_des_fc_hidden_dim,
                                                               config.model.dropout, config.model.inter_out_dim * 2)
    def forward(self, batch):
        bg_lig, bg_prot, bg_inter, labels, _, ass_des = batch
        node_feats_lig = self.ligand_conv(bg_lig)
        node_feats_prot = self.protein_conv(bg_prot)
        bg_inter.ndata['h'] = self.alignfeature(bg_lig, bg_prot, node_feats_lig, node_feats_prot)
        bond_feats_inter = self.noncov_graph(bg_inter)
        graph_embedding = self.readout(bg_inter, bond_feats_inter)
        if self.pretrain_use_assay_description:
            if self.pretrain_assay_mlp_share:
                ranking_assay_embedding = self.assay_info_aggre_mlp(ass_des)
                affinity_pred = self.FC(graph_embedding + ranking_assay_embedding)
            else:
                regression_assay_embedding = self.assay_info_aggre_mlp_pointwise(ass_des)
                affinity_pred = self.FC(graph_embedding + regression_assay_embedding)
                ranking_assay_embedding = self.assay_info_aggre_mlp_pairwise(ass_des)
        else:
            affinity_pred = self.FC(graph_embedding)
            ranking_assay_embedding = torch.zeros(len(affinity_pred))
        return affinity_pred, graph_embedding, ranking_assay_embedding
    def alignfeature(self, bg_lig, bg_prot, node_feats_lig, node_feats_prot):
        # Same node-feature re-ordering as IGN_basic.alignfeature.
        inter_feature = torch.cat((node_feats_lig, node_feats_prot))
        lig_num, prot_num = bg_lig.batch_num_nodes(), bg_prot.batch_num_nodes()
        lig_start, prot_start = lig_num.cumsum(0) - lig_num, prot_num.cumsum(0) - prot_num
        inter_start = lig_start + prot_start
        for i in range(lig_num.shape[0]):
            inter_feature[inter_start[i]:inter_start[i] + lig_num[i]] = node_feats_lig[lig_start[i]:lig_start[i] + lig_num[i]]
            inter_feature[inter_start[i] + lig_num[i]:inter_start[i] + lig_num[i] + prot_num[i]] = node_feats_prot[prot_start[i]:prot_start[i] + prot_num[i]]
        return inter_feature
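
# Thin wrapper that selects a node-feature encoder backbone by name.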
class GNNs(nn.Module):
    def __init__(self, nLigNode, nLigEdge, nLayer, nHid, JK, GNN):
        super(GNNs, self).__init__()
        if GNN == 'GCN':
            self.Encoder = layers.GCN(nLigNode, hidden_feats=[nHid] * nLayer)
        elif GNN == 'GAT':
            self.Encoder = layers.GAT(nLigNode, hidden_feats=[nHid] * nLayer)
        elif GNN == 'GIN':
            self.Encoder = layers.GIN(nLigNode, nHid, nLayer, num_mlp_layers=2, dropout=0.1, learn_eps=False,
                                      neighbor_pooling_type='sum', JK=JK)
        elif GNN == 'EGNN':
            self.Encoder = layers.EGNN(nLigNode, nLigEdge, nHid, nLayer, dropout=0.1, JK=JK)
        elif GNN == 'AttentiveFP':
            self.Encoder = layers.ModifiedAttentiveFPGNNV2(nLigNode, nLigEdge, nLayer, nHid, 0.1, JK)
        else:
            raise ValueError(f'unsupported GNN type: {GNN}')

    def forward(self, Graph, Perturb=None):
        Node_Rep = self.Encoder(Graph, Perturb)
        return Node_Rep
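
# Hypothetical usage of GNNs (a sketch, not executed; the feature sizes below
# are illustrative, not values taken from any config):
#   lig_encoder = GNNs(nLigNode=41, nLigEdge=10, nLayer=3, nHid=128, JK='last', GNN='AttentiveFP')
#   node_feats = lig_encoder(bg_lig)  # per-node representations for a batched DGL graph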
# Full affinity model with a configurable GNN backbone per tower.
class Affinity_GNNs(nn.Module):
    def __init__(self, config):
        super(Affinity_GNNs, self).__init__()
        lig_node_dim = config.model.lig_node_dim
        lig_edge_dim = config.model.lig_edge_dim
        pro_node_dim = config.model.pro_node_dim
        pro_edge_dim = config.model.pro_edge_dim
        layer_num = config.model.num_layers
        hidden_dim = config.model.hidden_dim
        jk = config.model.jk
        GNN = config.model.GNN_type
        self.pretrain_assay_mlp_share = config.train.pretrain_assay_mlp_share
        self.pretrain_use_assay_description = config.train.pretrain_use_assay_description
        self.lig_encoder = GNNs(lig_node_dim, lig_edge_dim, layer_num, hidden_dim, jk, GNN)
        self.pro_encoder = GNNs(pro_node_dim, pro_edge_dim, layer_num, hidden_dim, jk, GNN)
        if config.model.jk == 'concat':
            self.noncov_graph = layers.DTIConvGraph3Layer(hidden_dim * layer_num * 2 + config.model.inter_edge_dim, config.model.inter_out_dim, config.model.dropout)
        else:
            self.noncov_graph = layers.DTIConvGraph3Layer(hidden_dim * 2 + config.model.inter_edge_dim, config.model.inter_out_dim, config.model.dropout)
        self.readout = layers.ReadsOutLayer(config.model.inter_out_dim, config.model.readout, config.model.num_head, config.model.attn_merge)
        if config.model.readout.startswith('multi_head') and config.model.attn_merge == 'concat':
            self.FC = layers.FC(config.model.inter_out_dim * (config.model.num_head + 1), config.model.fc_hidden_dim, config.model.dropout, config.model.out_dim)
        else:
            self.FC = layers.FC(config.model.inter_out_dim * 2, config.model.fc_hidden_dim, config.model.dropout, config.model.out_dim)
        self.softmax = nn.Softmax(dim=1)
        if self.pretrain_use_assay_description:
            print(f'use assay description type: {config.data.assay_des_type}')
            if self.pretrain_assay_mlp_share:
                self.assay_info_aggre_mlp = layers.FC(config.data.assay_des_dim, config.model.assay_des_fc_hidden_dim,
                                                      config.model.dropout, config.model.inter_out_dim * 2)
            else:
                self.assay_info_aggre_mlp_pointwise = layers.FC(config.data.assay_des_dim, config.model.assay_des_fc_hidden_dim,
                                                                config.model.dropout, config.model.inter_out_dim * 2)
                self.assay_info_aggre_mlp_pairwise = layers.FC(config.data.assay_des_dim, config.model.assay_des_fc_hidden_dim,
                                                               config.model.dropout, config.model.inter_out_dim * 2)
    def forward(self, batch):
        bg_lig, bg_prot, bg_inter, labels, _, ass_des = batch
        node_feats_lig = self.lig_encoder(bg_lig)
        node_feats_prot = self.pro_encoder(bg_prot)
        bg_inter.ndata['h'] = self.alignfeature(bg_lig, bg_prot, node_feats_lig, node_feats_prot)
        bond_feats_inter = self.noncov_graph(bg_inter)
        graph_embedding = self.readout(bg_inter, bond_feats_inter)
        if self.pretrain_use_assay_description:
            if self.pretrain_assay_mlp_share:
                ranking_assay_embedding = self.assay_info_aggre_mlp(ass_des)
                affinity_pred = self.FC(graph_embedding + ranking_assay_embedding)
            else:
                regression_assay_embedding = self.assay_info_aggre_mlp_pointwise(ass_des)
                affinity_pred = self.FC(graph_embedding + regression_assay_embedding)
                ranking_assay_embedding = self.assay_info_aggre_mlp_pairwise(ass_des)
        else:
            affinity_pred = self.FC(graph_embedding)
            ranking_assay_embedding = torch.zeros(len(affinity_pred))
        return affinity_pred, graph_embedding, ranking_assay_embedding
    def alignfeature(self, bg_lig, bg_prot, node_feats_lig, node_feats_prot):
        inter_feature = torch.cat((node_feats_lig, node_feats_prot))
        lig_num, prot_num = bg_lig.batch_num_nodes(), bg_prot.batch_num_nodes()
        lig_start, prot_start = lig_num.cumsum(0) - lig_num, prot_num.cumsum(0) - prot_num
        inter_start = lig_start + prot_start
        for i in range(lig_num.shape[0]):
            inter_feature[inter_start[i]:inter_start[i] + lig_num[i]] = node_feats_lig[lig_start[i]:lig_start[i] + lig_num[i]]
            inter_feature[inter_start[i] + lig_num[i]:inter_start[i] + lig_num[i] + prot_num[i]] = node_feats_prot[prot_start[i]:prot_start[i] + prot_num[i]]
        return inter_feature
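
# Prediction head used at fine-tuning time: maps a precomputed graph embedding
# (plus an optional assay-description embedding) to an affinity value.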
class affinity_head(nn.Module):
    def __init__(self, config):
        super(affinity_head, self).__init__()
        self.pretrain_assay_mlp_share = config.train.pretrain_assay_mlp_share
        self.pretrain_use_assay_description = config.train.pretrain_use_assay_description
        if self.pretrain_use_assay_description:
            print(f'use assay description type: {config.data.assay_des_type}')
            if self.pretrain_assay_mlp_share:
                self.assay_info_aggre_mlp = layers.FC(config.data.assay_des_dim, config.model.assay_des_fc_hidden_dim,
                                                      config.model.dropout, config.model.inter_out_dim * 2)
            else:
                self.assay_info_aggre_mlp_pointwise = layers.FC(config.data.assay_des_dim, config.model.assay_des_fc_hidden_dim,
                                                                config.model.dropout, config.model.inter_out_dim * 2)
                self.assay_info_aggre_mlp_pairwise = layers.FC(config.data.assay_des_dim, config.model.assay_des_fc_hidden_dim,
                                                               config.model.dropout, config.model.inter_out_dim * 2)
        # 'fintune' follows the spelling of the config key.
        if config.model.readout.startswith('multi_head') and config.model.attn_merge == 'concat':
            self.FC = layers.FC(config.model.inter_out_dim * (config.model.num_head + 1), config.model.fintune_fc_hidden_dim, config.model.dropout, config.model.out_dim)
        else:
            self.FC = layers.FC(config.model.inter_out_dim * 2, config.model.fintune_fc_hidden_dim, config.model.dropout, config.model.out_dim)
    def forward(self, graph_embedding, ass_des):
        if self.pretrain_use_assay_description:
            if self.pretrain_assay_mlp_share:
                ranking_assay_embedding = self.assay_info_aggre_mlp(ass_des)
                affinity_pred = self.FC(graph_embedding + ranking_assay_embedding)
            else:
                regression_assay_embedding = self.assay_info_aggre_mlp_pointwise(ass_des)
                affinity_pred = self.FC(graph_embedding + regression_assay_embedding)
                ranking_assay_embedding = self.assay_info_aggre_mlp_pairwise(ass_des)
        else:
            affinity_pred = self.FC(graph_embedding)
            ranking_assay_embedding = torch.zeros(len(affinity_pred))
        # Unlike the full models, this head returns only the affinity prediction;
        # the ranking embedding computed above is not used by callers.
        return affinity_pred
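
# ASRP head: one readout plus FC regressor that yields both a per-sample
# regression loss and a pairwise ranking loss over paired batches. (Reading
# "ASRP" as assay-specific ranking/pointwise training is an assumption; the
# file does not expand the acronym.)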
class ASRP_head(nn.Module):
    def __init__(self, config):
        super(ASRP_head, self).__init__()
        self.readout = layers.ReadsOutLayer(config.model.inter_out_dim, config.model.readout, config.model.num_head, config.model.attn_merge)
        self.pretrain_assay_mlp_share = config.train.pretrain_assay_mlp_share
        self.pretrain_use_assay_description = config.train.pretrain_use_assay_description
        if self.pretrain_use_assay_description:
            print(f'use assay description type: {config.data.assay_des_type}')
            if self.pretrain_assay_mlp_share:
                self.assay_info_aggre_mlp = layers.FC(config.data.assay_des_dim, config.model.assay_des_fc_hidden_dim,
                                                      config.model.dropout, config.model.inter_out_dim * 2)
            else:
                self.assay_info_aggre_mlp_pointwise = layers.FC(config.data.assay_des_dim, config.model.assay_des_fc_hidden_dim,
                                                                config.model.dropout, config.model.inter_out_dim * 2)
                self.assay_info_aggre_mlp_pairwise = layers.FC(config.data.assay_des_dim, config.model.assay_des_fc_hidden_dim,
                                                               config.model.dropout, config.model.inter_out_dim * 2)
        if config.model.readout.startswith('multi_head') and config.model.attn_merge == 'concat':
            self.FC = layers.FC(config.model.inter_out_dim * (config.model.num_head + 1), config.model.fintune_fc_hidden_dim, config.model.dropout, config.model.out_dim)
        else:
            self.FC = layers.FC(config.model.inter_out_dim * 2, config.model.fintune_fc_hidden_dim, config.model.dropout, config.model.out_dim)
        self.regression_loss_fn = nn.MSELoss(reduction='none')  # per-sample losses
        self.ranking_loss_fn = losses.pairwise_BCE_loss(config)
        self.pairwise_two_tower_regression_loss = config.train.pairwise_two_tower_regression_loss
        if self.pairwise_two_tower_regression_loss:
            print('use two tower regression loss')
    def forward(self, bg_inter, bond_feats_inter, ass_des, labels, select_flag):
        graph_embedding = self.readout(bg_inter, bond_feats_inter)
        if self.pretrain_use_assay_description:
            if self.pretrain_assay_mlp_share:
                ranking_assay_embedding = self.assay_info_aggre_mlp(ass_des)
                affinity_pred = self.FC(graph_embedding + ranking_assay_embedding)
            else:
                regression_assay_embedding = self.assay_info_aggre_mlp_pointwise(ass_des)
                affinity_pred = self.FC(graph_embedding + regression_assay_embedding)
                ranking_assay_embedding = self.assay_info_aggre_mlp_pairwise(ass_des)
        else:
            affinity_pred = self.FC(graph_embedding)
            ranking_assay_embedding = torch.zeros(len(affinity_pred))
        # The batch is laid out as concatenated pairs: the first and second
        # halves hold the two members of each ranking pair.
        y_pred_num = len(affinity_pred)
        assert y_pred_num % 2 == 0
        if self.pairwise_two_tower_regression_loss:
            # Regress on both members of every pair.
            regression_loss = self.regression_loss_fn(affinity_pred, labels)
            labels_select = labels[select_flag]
            affinity_pred_select = affinity_pred[select_flag]
            regression_loss_select = regression_loss[select_flag].sum()
        else:
            # Regress only on the first member of each pair.
            regression_loss = self.regression_loss_fn(affinity_pred[:y_pred_num // 2], labels[:y_pred_num // 2])
            labels_select = labels[:y_pred_num // 2][select_flag[:y_pred_num // 2]]
            affinity_pred_select = affinity_pred[:y_pred_num // 2][select_flag[:y_pred_num // 2]]
            regression_loss_select = regression_loss[select_flag[:y_pred_num // 2]].sum()
        ranking_loss, relation, relation_pred = self.ranking_loss_fn(graph_embedding, labels, ranking_assay_embedding)
        ranking_loss_select = ranking_loss[select_flag[:y_pred_num // 2]].sum()
        relation_select = relation[select_flag[:y_pred_num // 2]]
        relation_pred_select = relation_pred[select_flag[:y_pred_num // 2]]
        return regression_loss_select, ranking_loss_select, \
               labels_select, affinity_pred_select, \
               relation_select, relation_pred_select
    def forward_pointwise(self, bg_inter, bond_feats_inter, ass_des, labels, select_flag):
        graph_embedding = self.readout(bg_inter, bond_feats_inter)
        affinity_pred = self.FC(graph_embedding)
        regression_loss = self.regression_loss_fn(affinity_pred, labels)
        regression_loss_select = regression_loss[select_flag].sum()
        labels_select = labels[select_flag]
        affinity_pred_select = affinity_pred[select_flag]
        return regression_loss_select, labels_select, affinity_pred_select
    def evaluate_mtl(self, bg_inter, bond_feats_inter, ass_des, labels):
        graph_embedding = self.readout(bg_inter, bond_feats_inter)
        if self.pretrain_use_assay_description:
            if self.pretrain_assay_mlp_share:
                ranking_assay_embedding = self.assay_info_aggre_mlp(ass_des)
                affinity_pred = self.FC(graph_embedding + ranking_assay_embedding)
            else:
                regression_assay_embedding = self.assay_info_aggre_mlp_pointwise(ass_des)
                affinity_pred = self.FC(graph_embedding + regression_assay_embedding)
                ranking_assay_embedding = self.assay_info_aggre_mlp_pairwise(ass_des)
        else:
            affinity_pred = self.FC(graph_embedding)
            ranking_assay_embedding = torch.zeros(len(affinity_pred))
        # Score every ordered pair (i, j), i != j, in the batch for ranking metrics.
        n = graph_embedding.shape[0]
        pair_a_index, pair_b_index = [], []
        for i in range(n):
            pair_a_index.extend([i] * (n - 1))
            pair_b_index.extend([j for j in range(n) if i != j])
        pair_index = pair_a_index + pair_b_index
        _, relation, relation_pred = self.ranking_loss_fn(graph_embedding[pair_index], labels[pair_index], ranking_assay_embedding[pair_index])
        return affinity_pred, relation, relation_pred
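
# Multi-task variant: a shared encoder and interaction stack feed a separate
# ASRP_head per measurement type (IC50/Kd/Ki, or IC50 plus a merged 'K' task).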
class Affinity_GNNs_MTL(nn.Module):
    def __init__(self, config):
        super(Affinity_GNNs_MTL, self).__init__()
        lig_node_dim = config.model.lig_node_dim
        lig_edge_dim = config.model.lig_edge_dim
        pro_node_dim = config.model.pro_node_dim
        pro_edge_dim = config.model.pro_edge_dim
        layer_num = config.model.num_layers
        hidden_dim = config.model.hidden_dim
        jk = config.model.jk
        GNN = config.model.GNN_type
        self.multi_task = config.train.multi_task
        self.pretrain_assay_mlp_share = config.train.pretrain_assay_mlp_share
        self.pretrain_use_assay_description = config.train.pretrain_use_assay_description
        self.lig_encoder = GNNs(lig_node_dim, lig_edge_dim, layer_num, hidden_dim, jk, GNN)
        self.pro_encoder = GNNs(pro_node_dim, pro_edge_dim, layer_num, hidden_dim, jk, GNN)
        if config.model.jk == 'concat':
            self.noncov_graph = layers.DTIConvGraph3Layer(hidden_dim * layer_num * 2 + config.model.inter_edge_dim, config.model.inter_out_dim, config.model.dropout)
        else:
            self.noncov_graph = layers.DTIConvGraph3Layer(hidden_dim * 2 + config.model.inter_edge_dim, config.model.inter_out_dim, config.model.dropout)
        self.softmax = nn.Softmax(dim=1)
        if self.multi_task == 'IC50KdKi':
            self.IC50_ASRP_head = ASRP_head(config)
            self.Kd_ASRP_head = ASRP_head(config)
            self.Ki_ASRP_head = ASRP_head(config)
        elif self.multi_task == 'IC50K':
            self.IC50_ASRP_head = ASRP_head(config)
            self.K_ASRP_head = ASRP_head(config)
        self.config = config
    def forward(self, batch, ASRP=True, Perturb=None, Perturb_v=None):
        if self.multi_task == 'IC50KdKi':
            bg_lig, bg_prot, bg_inter, labels, _, ass_des, IC50_f, Kd_f, Ki_f = batch
            lig_node_feats_init = bg_lig.ndata['h']
            pro_node_feats_init = bg_prot.ndata['h']
            if Perturb is not None and Perturb_v == 'v_intra':
                # Split the intra-graph perturbation between ligand and protein nodes.
                node_feats_lig = self.lig_encoder(bg_lig, Perturb[:bg_lig.number_of_nodes()])
                node_feats_prot = self.pro_encoder(bg_prot, Perturb[bg_lig.number_of_nodes():])
            else:
                node_feats_lig = self.lig_encoder(bg_lig)
                node_feats_prot = self.pro_encoder(bg_prot)
            if self.config.train.encoder_ablation == 'interact':
                return node_feats_lig, node_feats_prot
            elif self.config.train.encoder_ablation == 'ligand':
                # Ablate the ligand encoder: keep only the raw input node features.
                node_feats_lig = node_feats_lig.zero_()
                node_feats_lig[:, :self.config.model.lig_node_dim] = lig_node_feats_init
            elif self.config.train.encoder_ablation == 'protein':
                node_feats_prot = node_feats_prot.zero_()
                node_feats_prot[:, :self.config.model.pro_node_dim] = pro_node_feats_init
            bg_inter.ndata['h'] = self.alignfeature(bg_lig, bg_prot, node_feats_lig, node_feats_prot)
            if Perturb is not None and Perturb_v == 'v_inter':
                bg_inter.ndata['h'] = bg_inter.ndata['h'] + Perturb
            bond_feats_inter = self.noncov_graph(bg_inter)
            if ASRP:
                return self.multi_head_pred(bg_inter, bond_feats_inter, labels, ass_des, IC50_f, Kd_f, Ki_f)
            else:
                return self.multi_head_pointwise(bg_inter, bond_feats_inter, labels, ass_des, IC50_f, Kd_f, Ki_f)
        elif self.multi_task == 'IC50K':
            bg_lig, bg_prot, bg_inter, labels, _, ass_des, IC50_f, K_f = batch
            lig_node_feats_init = bg_lig.ndata['h']
            pro_node_feats_init = bg_prot.ndata['h']
            if Perturb is not None and Perturb_v == 'v_intra':
                node_feats_lig = self.lig_encoder(bg_lig, Perturb[:bg_lig.number_of_nodes()])
                node_feats_prot = self.pro_encoder(bg_prot, Perturb[bg_lig.number_of_nodes():])
            else:
                node_feats_lig = self.lig_encoder(bg_lig)
                node_feats_prot = self.pro_encoder(bg_prot)
            if self.config.train.encoder_ablation == 'interact':
                return node_feats_lig, node_feats_prot
            elif self.config.train.encoder_ablation == 'ligand':
                node_feats_lig = node_feats_lig.zero_()
                node_feats_lig[:, :self.config.model.lig_node_dim] = lig_node_feats_init
            elif self.config.train.encoder_ablation == 'protein':
                node_feats_prot = node_feats_prot.zero_()
                node_feats_prot[:, :self.config.model.pro_node_dim] = pro_node_feats_init
            bg_inter.ndata['h'] = self.alignfeature(bg_lig, bg_prot, node_feats_lig, node_feats_prot)
            if Perturb is not None and Perturb_v == 'v_inter':
                bg_inter.ndata['h'] = bg_inter.ndata['h'] + Perturb
            bond_feats_inter = self.noncov_graph(bg_inter)
            if ASRP:
                return self.multi_head_pred_v2(bg_inter, bond_feats_inter, labels, ass_des, IC50_f, K_f)
            else:
                return self.multi_head_pointwise_v2(bg_inter, bond_feats_inter, labels, ass_des, IC50_f, K_f)
    def multi_head_pointwise(self, bg_inter, bond_feats_inter, labels, ass_des, IC50_f, Kd_f, Ki_f):
        regression_loss_IC50, affinity_IC50, affinity_pred_IC50 = \
            self.IC50_ASRP_head.forward_pointwise(bg_inter, bond_feats_inter, ass_des, labels, IC50_f)
        regression_loss_Kd, affinity_Kd, affinity_pred_Kd = \
            self.Kd_ASRP_head.forward_pointwise(bg_inter, bond_feats_inter, ass_des, labels, Kd_f)
        regression_loss_Ki, affinity_Ki, affinity_pred_Ki = \
            self.Ki_ASRP_head.forward_pointwise(bg_inter, bond_feats_inter, ass_des, labels, Ki_f)
        return (regression_loss_IC50, regression_loss_Kd, regression_loss_Ki), \
               (affinity_pred_IC50, affinity_pred_Kd, affinity_pred_Ki), \
               (affinity_IC50, affinity_Kd, affinity_Ki)

    def multi_head_pointwise_v2(self, bg_inter, bond_feats_inter, labels, ass_des, IC50_f, K_f):
        regression_loss_IC50, affinity_IC50, affinity_pred_IC50 = \
            self.IC50_ASRP_head.forward_pointwise(bg_inter, bond_feats_inter, ass_des, labels, IC50_f)
        regression_loss_K, affinity_K, affinity_pred_K = \
            self.K_ASRP_head.forward_pointwise(bg_inter, bond_feats_inter, ass_des, labels, K_f)
        return (regression_loss_IC50, regression_loss_K), \
               (affinity_pred_IC50, affinity_pred_K), \
               (affinity_IC50, affinity_K)
    def multi_head_pred(self, bg_inter, bond_feats_inter, labels, ass_des, IC50_f, Kd_f, Ki_f):
        regression_loss_IC50, ranking_loss_IC50, \
        affinity_IC50, affinity_pred_IC50, \
        relation_IC50, relation_pred_IC50 = self.IC50_ASRP_head(bg_inter, bond_feats_inter, ass_des, labels, IC50_f)
        regression_loss_Kd, ranking_loss_Kd, \
        affinity_Kd, affinity_pred_Kd, \
        relation_Kd, relation_pred_Kd = self.Kd_ASRP_head(bg_inter, bond_feats_inter, ass_des, labels, Kd_f)
        regression_loss_Ki, ranking_loss_Ki, \
        affinity_Ki, affinity_pred_Ki, \
        relation_Ki, relation_pred_Ki = self.Ki_ASRP_head(bg_inter, bond_feats_inter, ass_des, labels, Ki_f)
        return (regression_loss_IC50, regression_loss_Kd, regression_loss_Ki), \
               (ranking_loss_IC50, ranking_loss_Kd, ranking_loss_Ki), \
               (affinity_pred_IC50, affinity_pred_Kd, affinity_pred_Ki), \
               (relation_pred_IC50, relation_pred_Kd, relation_pred_Ki), \
               (affinity_IC50, affinity_Kd, affinity_Ki), \
               (relation_IC50, relation_Kd, relation_Ki)
    def multi_head_pred_v2(self, bg_inter, bond_feats_inter, labels, ass_des, IC50_f, K_f):
        regression_loss_IC50, ranking_loss_IC50, \
        affinity_IC50, affinity_pred_IC50, \
        relation_IC50, relation_pred_IC50 = self.IC50_ASRP_head(bg_inter, bond_feats_inter, ass_des, labels, IC50_f)
        regression_loss_K, ranking_loss_K, \
        affinity_K, affinity_pred_K, \
        relation_K, relation_pred_K = self.K_ASRP_head(bg_inter, bond_feats_inter, ass_des, labels, K_f)
        return (regression_loss_IC50, regression_loss_K), \
               (ranking_loss_IC50, ranking_loss_K), \
               (affinity_pred_IC50, affinity_pred_K), \
               (relation_pred_IC50, relation_pred_K), \
               (affinity_IC50, affinity_K), \
               (relation_IC50, relation_K)
    def multi_head_evaluate(self, bg_inter, bond_feats_inter, labels, ass_des, IC50_f, Kd_f, Ki_f):
        # Each evaluation batch must contain a single measurement type.
        if sum(IC50_f):
            assert sum(Kd_f) == 0 and sum(Ki_f) == 0
            return self.IC50_ASRP_head.evaluate_mtl(bg_inter, bond_feats_inter, ass_des, labels)
        elif sum(Kd_f):
            assert sum(IC50_f) == 0 and sum(Ki_f) == 0
            return self.Kd_ASRP_head.evaluate_mtl(bg_inter, bond_feats_inter, ass_des, labels)
        elif sum(Ki_f):
            assert sum(IC50_f) == 0 and sum(Kd_f) == 0
            return self.Ki_ASRP_head.evaluate_mtl(bg_inter, bond_feats_inter, ass_des, labels)
    def alignfeature(self, bg_lig, bg_prot, node_feats_lig, node_feats_prot):
        inter_feature = torch.cat((node_feats_lig, node_feats_prot))
        lig_num, prot_num = bg_lig.batch_num_nodes(), bg_prot.batch_num_nodes()
        lig_start, prot_start = lig_num.cumsum(0) - lig_num, prot_num.cumsum(0) - prot_num
        inter_start = lig_start + prot_start
        for i in range(lig_num.shape[0]):
            inter_feature[inter_start[i]:inter_start[i] + lig_num[i]] = node_feats_lig[lig_start[i]:lig_start[i] + lig_num[i]]
            inter_feature[inter_start[i] + lig_num[i]:inter_start[i] + lig_num[i] + prot_num[i]] = node_feats_prot[prot_start[i]:prot_start[i] + prot_num[i]]
        return inter_feature
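
# Ablation model for the 'interact' encoder-ablation setting: regresses
# affinity directly from a precomputed graph embedding, with one regression
# head per task (IC50 and K).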
class interact_ablation(nn.Module):
    def __init__(self, config):
        super(interact_ablation, self).__init__()
        self.IC50_ASRP_head = interact_ablation_head(config)
        self.K_ASRP_head = interact_ablation_head(config)
        self.config = config

    def forward(self, graph_embedding, labels, IC50_f, K_f):
        regression_loss_IC50, \
        affinity_IC50, affinity_pred_IC50 = self.IC50_ASRP_head(graph_embedding, labels, IC50_f)
        regression_loss_K, \
        affinity_K, affinity_pred_K = self.K_ASRP_head(graph_embedding, labels, K_f)
        return (regression_loss_IC50, regression_loss_K), \
               (affinity_pred_IC50, affinity_pred_K), \
               (affinity_IC50, affinity_K)
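
# Per-task regression head for the interaction ablation: a plain FC regressor
# with per-sample MSE, masked by the task-selection flags.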
class interact_ablation_head(nn.Module):
    def __init__(self, config):
        super(interact_ablation_head, self).__init__()
        self.FC = layers.FC(config.model.inter_out_dim * 2, config.model.fintune_fc_hidden_dim, config.model.dropout,
                            config.model.out_dim)
        self.regression_loss_fn = nn.MSELoss(reduction='none')  # per-sample losses

    def forward(self, graph_embedding, labels, select_flag):
        affinity_pred = self.FC(graph_embedding)
        regression_loss = self.regression_loss_fn(affinity_pred, labels)
        regression_loss_select = regression_loss[select_flag].sum()
        labels_select = labels[select_flag]
        affinity_pred_select = affinity_pred[select_flag]
        return regression_loss_select, labels_select, affinity_pred_select
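
# For reference, the config fields these models read (inferred from the
# attribute accesses above; this is a sketch of the expected schema, not an
# authoritative default config):
#   config.model: lig_node_dim, lig_edge_dim, pro_node_dim, pro_edge_dim,
#                 num_layers, hidden_dim, dropout, jk ('concat' or a non-concat
#                 mode), GNN_type, inter_edge_dim, inter_out_dim, readout,
#                 num_head, attn_merge, fc_hidden_dim, fintune_fc_hidden_dim,
#                 assay_des_fc_hidden_dim, out_dim
#   config.data:  assay_des_type, assay_des_dim
#   config.train: pretrain_assay_mlp_share, pretrain_use_assay_description,
#                 multi_task ('IC50KdKi' or 'IC50K'), encoder_ablation,
#                 pairwise_two_tower_regression_loss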