Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn.functional as F | |
| from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool | |
| from gcn_lib.sparse.torch_vertex import GENConv | |
| from gcn_lib.sparse.torch_nn import norm_layer, MLP, MM_AtomEncoder | |
| from model.model_encoder import AtomEncoder, BondEncoder | |
| import logging | |
| class DeeperGCN(torch.nn.Module): | |
| def __init__(self, args, is_prot=False, saliency=False): | |
| super(DeeperGCN, self).__init__() | |
| # Set PM configuration | |
| if is_prot: | |
| self.num_layers = args.num_layers_prot | |
| mlp_layers = args.mlp_layers_prot | |
| hidden_channels = args.hidden_channels_prot | |
| self.msg_norm = args.msg_norm_prot | |
| learn_msg_scale = args.learn_msg_scale_prot | |
| self.conv_encode_edge = args.conv_encode_edge_prot | |
| # Set LM configuration | |
| else: | |
| self.num_layers = args.num_layers | |
| mlp_layers = args.mlp_layers | |
| hidden_channels = args.hidden_channels | |
| self.msg_norm = args.msg_norm | |
| learn_msg_scale = args.learn_msg_scale | |
| self.conv_encode_edge = args.conv_encode_edge | |
| # Set overall model configuration | |
| self.dropout = args.dropout | |
| self.block = args.block | |
| self.add_virtual_node = args.add_virtual_node | |
| self.training = True | |
| self.args = args | |
| num_classes = args.nclasses | |
| conv = args.conv | |
| aggr = args.gcn_aggr | |
| t = args.t | |
| self.learn_t = args.learn_t | |
| p = args.p | |
| self.learn_p = args.learn_p | |
| norm = args.norm | |
| graph_pooling = args.graph_pooling | |
| # Print model parameters | |
| print( | |
| "The number of layers {}".format(self.num_layers), | |
| "Aggr aggregation method {}".format(aggr), | |
| "block: {}".format(self.block), | |
| ) | |
| if self.block == "res+": | |
| print("LN/BN->ReLU->GraphConv->Res") | |
| elif self.block == "res": | |
| print("GraphConv->LN/BN->ReLU->Res") | |
| elif self.block == "dense": | |
| raise NotImplementedError("To be implemented") | |
| elif self.block == "plain": | |
| print("GraphConv->LN/BN->ReLU") | |
| else: | |
| raise Exception("Unknown block Type") | |
| self.gcns = torch.nn.ModuleList() | |
| self.norms = torch.nn.ModuleList() | |
| if self.add_virtual_node: | |
| self.virtualnode_embedding = torch.nn.Embedding(1, hidden_channels) | |
| torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0) | |
| self.mlp_virtualnode_list = torch.nn.ModuleList() | |
| for layer in range(self.num_layers - 1): | |
| self.mlp_virtualnode_list.append(MLP([hidden_channels] * 3, norm=norm)) | |
| # Set GCN layer configuration | |
| for layer in range(self.num_layers): | |
| if conv == "gen": | |
| gcn = GENConv( | |
| hidden_channels, | |
| hidden_channels, | |
| args, | |
| aggr=aggr, | |
| t=t, | |
| learn_t=self.learn_t, | |
| p=p, | |
| learn_p=self.learn_p, | |
| msg_norm=self.msg_norm, | |
| learn_msg_scale=learn_msg_scale, | |
| encode_edge=self.conv_encode_edge, | |
| bond_encoder=True, | |
| norm=norm, | |
| mlp_layers=mlp_layers, | |
| ) | |
| else: | |
| raise Exception("Unknown Conv Type") | |
| self.gcns.append(gcn) | |
| self.norms.append(norm_layer(norm, hidden_channels)) | |
| # Set embbeding layers | |
| self.atom_encoder = AtomEncoder(emb_dim=hidden_channels) | |
| if saliency: | |
| self.atom_encoder = MM_AtomEncoder(emb_dim=hidden_channels) | |
| else: | |
| self.atom_encoder = AtomEncoder(emb_dim=hidden_channels) | |
| if not self.conv_encode_edge: | |
| self.bond_encoder = BondEncoder(emb_dim=hidden_channels) | |
| # Set type of pooling | |
| if graph_pooling == "sum": | |
| self.pool = global_add_pool | |
| elif graph_pooling == "mean": | |
| self.pool = global_mean_pool | |
| elif graph_pooling == "max": | |
| self.pool = global_max_pool | |
| else: | |
| raise Exception("Unknown Pool Type") | |
| # Set classification layer | |
| self.graph_pred_linear = torch.nn.Linear(hidden_channels, num_classes) | |
| def forward(self, input_batch, dropout=True, embeddings=False): | |
| x = input_batch.x | |
| edge_index = input_batch.edge_index | |
| edge_attr = input_batch.edge_attr | |
| batch = input_batch.batch | |
| h = self.atom_encoder(x) | |
| if self.add_virtual_node: | |
| virtualnode_embedding = self.virtualnode_embedding( | |
| torch.zeros(batch[-1].item() + 1) | |
| .to(edge_index.dtype) | |
| .to(edge_index.device) | |
| ) | |
| h = h + virtualnode_embedding[batch] | |
| if self.conv_encode_edge: | |
| edge_emb = edge_attr | |
| else: | |
| edge_emb = self.bond_encoder(edge_attr) | |
| if self.block == "res+": | |
| h = self.gcns[0](h, edge_index, edge_emb) | |
| for layer in range(1, self.num_layers): | |
| h1 = self.norms[layer - 1](h) | |
| h2 = F.relu(h1) | |
| if dropout: | |
| h2 = F.dropout(h2, p=self.dropout, training=self.training) | |
| if self.add_virtual_node: | |
| virtualnode_embedding_temp = ( | |
| global_add_pool(h2, batch) + virtualnode_embedding | |
| ) | |
| if dropout: | |
| virtualnode_embedding = F.dropout( | |
| self.mlp_virtualnode_list[layer - 1]( | |
| virtualnode_embedding_temp | |
| ), | |
| self.dropout, | |
| training=self.training, | |
| ) | |
| h2 = h2 + virtualnode_embedding[batch] | |
| h = self.gcns[layer](h2, edge_index, edge_emb) + h | |
| h = self.norms[self.num_layers - 1](h) | |
| if dropout: | |
| h = F.dropout(h, p=self.dropout, training=self.training) | |
| elif self.block == "res": | |
| h = F.relu(self.norms[0](self.gcns[0](h, edge_index, edge_emb))) | |
| h = F.dropout(h, p=self.dropout, training=self.training) | |
| for layer in range(1, self.num_layers): | |
| h1 = self.gcns[layer](h, edge_index, edge_emb) | |
| h2 = self.norms[layer](h1) | |
| h = F.relu(h2) + h | |
| h = F.dropout(h, p=self.dropout, training=self.training) | |
| elif self.block == "dense": | |
| raise NotImplementedError("To be implemented") | |
| elif self.block == "plain": | |
| h = F.relu(self.norms[0](self.gcns[0](h, edge_index, edge_emb))) | |
| h = F.dropout(h, p=self.dropout, training=self.training) | |
| for layer in range(1, self.num_layers): | |
| h1 = self.gcns[layer](h, edge_index, edge_emb) | |
| h2 = self.norms[layer](h1) | |
| if layer != (self.num_layers - 1): | |
| h = F.relu(h2) | |
| else: | |
| h = h2 | |
| h = F.dropout(h, p=self.dropout, training=self.training) | |
| else: | |
| raise Exception("Unknown block Type") | |
| h_graph = self.pool(h, batch) | |
| if self.args.use_prot or embeddings: | |
| return h_graph | |
| else: | |
| return self.graph_pred_linear(h_graph) | |
| def print_params(self, epoch=None, final=False): | |
| if self.learn_t: | |
| ts = [] | |
| for gcn in self.gcns: | |
| ts.append(gcn.t.item()) | |
| if final: | |
| print("Final t {}".format(ts)) | |
| else: | |
| logging.info("Epoch {}, t {}".format(epoch, ts)) | |
| if self.learn_p: | |
| ps = [] | |
| for gcn in self.gcns: | |
| ps.append(gcn.p.item()) | |
| if final: | |
| print("Final p {}".format(ps)) | |
| else: | |
| logging.info("Epoch {}, p {}".format(epoch, ps)) | |
| if self.msg_norm: | |
| ss = [] | |
| for gcn in self.gcns: | |
| ss.append(gcn.msg_norm.msg_scale.item()) | |
| if final: | |
| print("Final s {}".format(ss)) | |
| else: | |
| logging.info("Epoch {}, s {}".format(epoch, ss)) | |