import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# from torch_scatter import scatter_sum, scatter_softmax, scatter_mean
from torch_geometric.utils import scatter, softmax  # softmax replaces torch_scatter's scatter_softmax

#################################### node modules ###############################
class NeighborAttention(nn.Module):
    def __init__(self, num_hidden, num_in, num_heads=4, edge_drop=0.0, output_mlp=True):
        super(NeighborAttention, self).__init__()
        self.num_heads = num_heads
        self.num_hidden = num_hidden
        self.edge_drop = edge_drop  # unused in forward; kept for API compatibility
        self.output_mlp = output_mlp

        self.W_V = nn.Sequential(
            nn.Linear(num_in, num_hidden),
            nn.GELU(),
            nn.Linear(num_hidden, num_hidden),
            nn.GELU(),
            nn.Linear(num_hidden, num_hidden)
        )
        self.Bias = nn.Sequential(
            nn.Linear(num_hidden * 3, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_heads)
        )
        self.W_O = nn.Linear(num_hidden, num_hidden, bias=False)

    def forward(self, h_V, h_E, center_id, batch_id, dst_idx=None):
        N = h_V.shape[0]
        E = h_E.shape[0]
        n_heads = self.num_heads
        d = self.num_hidden // n_heads

        # Per-edge, per-head attention logits from the center node and edge features.
        w = self.Bias(torch.cat([h_V[center_id], h_E], dim=-1)).view(E, n_heads, 1)
        attend_logits = w / np.sqrt(d)

        V = self.W_V(h_E).view(-1, n_heads, d)
        # Softmax over the edges that share a center node, then a weighted sum
        # back onto the nodes; dim_size=N keeps nodes with no incoming edges
        # in the output (the original torch_scatter calls had the same latent
        # shape hazard).
        attend = softmax(attend_logits, index=center_id, dim=0)
        h_V = scatter(attend * V, center_id, dim=0, dim_size=N, reduce='sum').view([-1, self.num_hidden])

        if self.output_mlp:
            h_V_update = self.W_O(h_V)
        else:
            h_V_update = h_V
        return h_V_update
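# A minimal sketch of the scatter-softmax semantics used above (illustrative
# values, not from the original code): logits that share a center_id are
# normalized together, independently of every other group.
#
#   >>> logits = torch.tensor([1.0, 1.0, 2.0])
#   >>> center = torch.tensor([0, 0, 1])
#   >>> softmax(logits, index=center)
#   tensor([0.5000, 0.5000, 1.0000])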
#################################### edge modules ###############################
class EdgeMLP(nn.Module):
    def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30):
        super(EdgeMLP, self).__init__()
        self.num_hidden = num_hidden
        self.num_in = num_in
        self.scale = scale
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.BatchNorm1d(num_hidden)
        self.W11 = nn.Linear(num_hidden + num_in, num_hidden, bias=True)
        self.W12 = nn.Linear(num_hidden, num_hidden, bias=True)
        self.W13 = nn.Linear(num_hidden, num_hidden, bias=True)
        self.act = torch.nn.GELU()

    def forward(self, h_V, h_E, edge_idx, batch_id):
        src_idx = edge_idx[0]
        dst_idx = edge_idx[1]
        h_EV = torch.cat([h_V[src_idx], h_E, h_V[dst_idx]], dim=-1)
        h_message = self.W13(self.act(self.W12(self.act(self.W11(h_EV)))))
        h_E = self.norm(h_E + self.dropout(h_message))
        return h_E
#################################### context modules ###############################
class Context(nn.Module):
    def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30, node_context=False, edge_context=False):
        super(Context, self).__init__()
        self.num_hidden = num_hidden
        self.num_in = num_in
        self.scale = scale
        self.node_context = node_context
        self.edge_context = edge_context

        # V_MLP and E_MLP appear unused in forward; kept to preserve the
        # original module layout.
        self.V_MLP = nn.Sequential(
            nn.Linear(num_hidden, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_hidden),
        )
        self.V_MLP_g = nn.Sequential(
            nn.Linear(num_hidden, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_hidden),
            nn.Sigmoid()
        )
        self.E_MLP = nn.Sequential(
            nn.Linear(num_hidden, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_hidden)
        )
        self.E_MLP_g = nn.Sequential(
            nn.Linear(num_hidden, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_hidden),
            nn.Sigmoid()
        )

    def forward(self, h_V, h_E, edge_idx, batch_id):
        # Gate node features with a sigmoid MLP of the per-graph mean context:
        # h_V <- h_V * sigmoid(MLP(mean_graph(h_V))).
        if self.node_context:
            c_V = scatter(h_V, batch_id, dim=0, reduce='mean')
            h_V = h_V * self.V_MLP_g(c_V[batch_id])
        # The same gating for edges, indexed through each edge's source node.
        if self.edge_context:
            c_V = scatter(h_V, batch_id, dim=0, reduce='mean')
            h_E = h_E * self.E_MLP_g(c_V[batch_id[edge_idx[0]]])
        return h_V, h_E
class GeneralGNN(nn.Module):
    def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30, node_context=0, edge_context=0):
        super(GeneralGNN, self).__init__()
        self.num_hidden = num_hidden
        self.num_in = num_in
        self.scale = scale
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.ModuleList([nn.BatchNorm1d(num_hidden) for _ in range(3)])

        self.attention = NeighborAttention(num_hidden, num_in, num_heads=4)
        self.edge_update = EdgeMLP(num_hidden, num_in, num_heads=4)
        self.context = Context(num_hidden, num_in, num_heads=4, node_context=node_context, edge_context=edge_context)

        self.dense = nn.Sequential(
            nn.Linear(num_hidden, num_hidden * 4),
            nn.ReLU(),
            nn.Linear(num_hidden * 4, num_hidden)
        )
        # W11-W13 appear unused in forward (EdgeMLP owns the edge message MLP);
        # kept to preserve the original parameter layout.
        self.W11 = nn.Linear(num_hidden + num_in, num_hidden, bias=True)
        self.W12 = nn.Linear(num_hidden, num_hidden, bias=True)
        self.W13 = nn.Linear(num_hidden, num_hidden, bias=True)
        self.act = torch.nn.GELU()

    def forward(self, h_V, h_E, edge_idx, batch_id):
        src_idx = edge_idx[0]
        dst_idx = edge_idx[1]

        # Node update: attention over incoming edges, then a position-wise FFN,
        # each with a residual connection and BatchNorm.
        dh = self.attention(h_V, torch.cat([h_E, h_V[dst_idx]], dim=-1), src_idx, batch_id, dst_idx)
        h_V = self.norm[0](h_V + self.dropout(dh))
        dh = self.dense(h_V)
        h_V = self.norm[1](h_V + self.dropout(dh))

        # Edge update and (optional) global-context gating.
        h_E = self.edge_update(h_V, h_E, edge_idx, batch_id)
        h_V, h_E = self.context(h_V, h_E, edge_idx, batch_id)
        return h_V, h_E
class StructureEncoder(nn.Module):
    def __init__(self, hidden_dim, num_encoder_layers=3, dropout=0, updating_edges=0, node_context=False, edge_context=False):
        super(StructureEncoder, self).__init__()
        self.updating_edges = updating_edges
        module = GeneralGNN
        encoder_layers = []
        for i in range(num_encoder_layers):
            encoder_layers.append(
                module(hidden_dim, hidden_dim * 2, dropout=dropout, node_context=node_context, edge_context=edge_context)
            )
        self.encoder_layers = nn.ModuleList(encoder_layers)

    def forward(self, h_V, h_P, P_idx, batch_id):
        for layer in self.encoder_layers:
            if self.updating_edges == 0:
                # Edges stay fixed: keep only the updated node features.
                # GeneralGNN concatenates h_V[P_idx[1]] onto h_P internally and
                # returns a (h_V, h_E) tuple, so the original single-value
                # assignment with a pre-concatenated edge tensor would break.
                h_V, _ = layer(h_V, h_P, P_idx, batch_id)
            else:
                h_V, h_P = layer(h_V, h_P, P_idx, batch_id)
        return h_V, h_P
class MLPDecoder(nn.Module):
    def __init__(self, hidden_dim, vocab=20):
        super().__init__()
        self.readout = nn.Linear(hidden_dim, vocab)

    def forward(self, h_V):
        logits = self.readout(h_V)
        log_probs = F.log_softmax(logits, dim=-1)
        return log_probs, logits
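
# A minimal smoke test (not part of the original module; all sizes below are
# assumptions chosen for illustration): build a random two-graph batch and
# check that the encoder/decoder stack runs end to end.
if __name__ == "__main__":
    torch.manual_seed(0)
    hidden_dim, num_nodes, num_edges = 128, 10, 40
    h_V = torch.randn(num_nodes, hidden_dim)             # node features
    h_P = torch.randn(num_edges, hidden_dim)             # edge features
    P_idx = torch.randint(0, num_nodes, (2, num_edges))  # (src, dst) indices
    batch_id = torch.zeros(num_nodes, dtype=torch.long)
    batch_id[num_nodes // 2:] = 1                        # two graphs in the batch

    encoder = StructureEncoder(hidden_dim, num_encoder_layers=3, updating_edges=1)
    decoder = MLPDecoder(hidden_dim)
    h_V, h_P = encoder(h_V, h_P, P_idx, batch_id)
    log_probs, logits = decoder(h_V)
    print(log_probs.shape, logits.shape)  # expected: torch.Size([10, 20]) twice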