# Spaces:
# Sleeping
# Sleeping
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
"""
ResGatedGCN: Residual Gated Graph ConvNets
An Experimental Study of Neural Networks for Variable Graphs (Xavier Bresson and Thomas Laurent, ICLR 2018)
https://arxiv.org/pdf/1711.07553v2.pdf
"""
class GatedGCNLayer(nn.Module):
    """Gated graph convolution layer that updates node and edge features.

    For every edge the layer builds a gate pre-activation from the projected
    edge feature and the projected features of its two endpoints::

        e_new = C(e) + D(h_src) + E(h_dst)

    and every node then aggregates its neighbours through the normalised
    sigmoid gates::

        h_new = A(h) + sum(sigmoid(e_new) * B(h_src)) / (sum(sigmoid(e_new)) + eps)

    Both outputs then go through optional batch normalisation, ReLU, an
    optional residual connection and dropout, in that order.

    Args:
        input_dim: Size of the last dimension of the incoming node and edge
            features (they are fed to ``nn.Linear(input_dim, output_dim)``).
        output_dim: Size of the last dimension of the returned features.
        dropout: Drop probability passed to ``F.dropout``.
        batch_norm: When truthy, apply ``BatchNorm1d`` to both outputs.
        residual: When truthy, add the inputs to the outputs. Silently
            disabled when ``input_dim != output_dim`` because the shapes
            would not match.
    """

    # Added to the gate sum so the normalisation never divides by zero.
    _EPS = 1e-6

    def __init__(self, input_dim, output_dim, dropout, batch_norm, residual=False):
        super().__init__()
        self.in_channels = input_dim
        self.out_channels = output_dim
        self.dropout = dropout
        self.batch_norm = batch_norm
        # A residual sum is only possible when input and output widths agree.
        self.residual = residual and input_dim == output_dim

        self.A = nn.Linear(input_dim, output_dim, bias=True)
        self.B = nn.Linear(input_dim, output_dim, bias=True)
        self.C = nn.Linear(input_dim, output_dim, bias=True)
        self.D = nn.Linear(input_dim, output_dim, bias=True)
        self.E = nn.Linear(input_dim, output_dim, bias=True)
        # The norm layers are always created (even when batch_norm is off) so
        # that the set of registered sub-modules does not depend on the flag.
        self.bn_node_h = nn.BatchNorm1d(output_dim)
        self.bn_node_e = nn.BatchNorm1d(output_dim)

    def message_func(self, edges):
        """User-defined message function equivalent to the built-ins in ``forward``.

        Not called by ``forward`` (which uses ``dgl.function`` built-ins);
        kept as the reference implementation for ``g.update_all``.

        Reads ``edges.src['Bh']``, ``edges.src['Dh']``, ``edges.dst['Eh']`` and
        ``edges.data['Ce']``; stores the gate pre-activation in
        ``edges.data['e']`` and returns it together with the source ``Bh``.
        """
        Bh_j = edges.src['Bh']
        # Gate pre-activation: projected edge feature + source D-term + destination E-term.
        e_ij = edges.data['Ce'] + edges.src['Dh'] + edges.dst['Eh']
        edges.data['e'] = e_ij
        return {'Bh_j': Bh_j, 'e_ij': e_ij}

    def reduce_func(self, nodes):
        """User-defined reduce function matching ``message_func``.

        Combines ``nodes.data['Ah']`` with the gate-weighted average of the
        mailbox entries ``'Bh_j'``, reducing over dim 1 of the mailbox.

        Returns:
            ``{'h': tensor}`` with the aggregated node features.
        """
        Ah_i = nodes.data['Ah']
        Bh_j = nodes.mailbox['Bh_j']
        sigma_ij = torch.sigmoid(nodes.mailbox['e_ij'])
        # Normalised gated sum ("dense attention"); _EPS guards the division.
        gated_sum = torch.sum(sigma_ij * Bh_j, dim=1)
        gate_total = torch.sum(sigma_ij, dim=1) + self._EPS
        return {'h': Ah_i + gated_sum / gate_total}

    def forward(self, g, h, e):
        """Run one gated graph convolution.

        Args:
            g: Graph object exposing ``ndata``, ``edata``, ``apply_edges`` and
                ``update_all`` (presumably a ``DGLGraph`` — confirm with caller).
            h: Node features; last dimension must be ``input_dim``.
            e: Edge features; last dimension must be ``input_dim``.

        Returns:
            Tuple ``(h, e)`` of updated node and edge features.

        Side effects:
            Leaves ``'Ah'``, ``'Bh'``, ``'Dh'``, ``'Eh'``, ``'sum_sigma_h'``,
            ``'sum_sigma'`` and ``'h'`` in ``g.ndata`` and ``'Ce'``, ``'DEh'``,
            ``'e'`` and ``'sigma'`` in ``g.edata``. ``'h'`` and ``'e'`` hold the
            raw convolution result, before normalisation and activation.
        """
        h_in = h  # kept for the residual connection
        e_in = e  # kept for the residual connection

        g.ndata['Ah'] = self.A(h)
        g.ndata['Bh'] = self.B(h)
        g.ndata['Dh'] = self.D(h)
        g.ndata['Eh'] = self.E(h)
        g.edata['Ce'] = self.C(e)

        # Edge update: source D-term + destination E-term + projected edge feature.
        g.apply_edges(fn.u_add_v('Dh', 'Eh', 'DEh'))
        g.edata['e'] = g.edata['DEh'] + g.edata['Ce']
        g.edata['sigma'] = torch.sigmoid(g.edata['e'])

        # Node update: gate-weighted sum of neighbour features, normalised by
        # the sum of the gates.
        g.update_all(fn.u_mul_e('Bh', 'sigma', 'm'), fn.sum('m', 'sum_sigma_h'))
        g.update_all(fn.copy_e('sigma', 'm'), fn.sum('m', 'sum_sigma'))
        g.ndata['h'] = (
            g.ndata['Ah']
            + g.ndata['sum_sigma_h'] / (g.ndata['sum_sigma'] + self._EPS)
        )

        h = g.ndata['h']  # result of graph convolution
        e = g.edata['e']  # result of graph convolution

        if self.batch_norm:
            h = self.bn_node_h(h)
            e = self.bn_node_e(e)

        h = F.relu(h)
        e = F.relu(e)

        if self.residual:
            h = h_in + h
            e = e_in + e

        h = F.dropout(h, self.dropout, training=self.training)
        e = F.dropout(e, self.dropout, training=self.training)

        return h, e

    def __repr__(self):
        return '{}(in_channels={}, out_channels={})'.format(
            self.__class__.__name__, self.in_channels, self.out_channels
        )
##############################################################
#
# Additional layers for edge feature/representation analysis
#
##############################################################
class GatedGCNLayerEdgeFeatOnly(nn.Module):
    """Gated graph convolution whose gates depend on node features only.

    The gate pre-activation of an edge is ``D(h_src) + E(h_dst)``; the input
    edge features are not used in the computation and are returned unchanged.
    Each node is updated as::

        h_new = A(h) + sum(sigmoid(gate) * B(h_src)) / (sum(sigmoid(gate)) + eps)

    followed by optional batch normalisation, ReLU, an optional residual
    connection and dropout, in that order.

    Args:
        input_dim: Size of the last dimension of the incoming node features
            (they are fed to ``nn.Linear(input_dim, output_dim)``).
        output_dim: Size of the last dimension of the returned node features.
        dropout: Drop probability passed to ``F.dropout``.
        batch_norm: When truthy, apply ``BatchNorm1d`` to the node output.
        residual: When truthy, add the input node features to the output.
            Silently disabled when ``input_dim != output_dim``.
    """

    # Added to the gate sum so the normalisation never divides by zero.
    _EPS = 1e-6

    def __init__(self, input_dim, output_dim, dropout, batch_norm, residual=False):
        super().__init__()
        self.in_channels = input_dim
        self.out_channels = output_dim
        self.dropout = dropout
        self.batch_norm = batch_norm
        # A residual sum is only possible when input and output widths agree.
        self.residual = residual and input_dim == output_dim

        self.A = nn.Linear(input_dim, output_dim, bias=True)
        self.B = nn.Linear(input_dim, output_dim, bias=True)
        self.D = nn.Linear(input_dim, output_dim, bias=True)
        self.E = nn.Linear(input_dim, output_dim, bias=True)
        # Always created so the registered sub-modules do not depend on the flag.
        self.bn_node_h = nn.BatchNorm1d(output_dim)

    def message_func(self, edges):
        """User-defined message function equivalent to the built-ins in ``forward``.

        Not called by ``forward``; kept as the reference implementation.
        Reads ``edges.src['Bh']``, ``edges.src['Dh']`` and ``edges.dst['Eh']``;
        stores the gate pre-activation in ``edges.data['e']`` and returns it
        together with the source ``Bh``.
        """
        Bh_j = edges.src['Bh']
        # Gate pre-activation: source D-term + destination E-term.
        e_ij = edges.src['Dh'] + edges.dst['Eh']
        edges.data['e'] = e_ij
        return {'Bh_j': Bh_j, 'e_ij': e_ij}

    def reduce_func(self, nodes):
        """User-defined reduce function matching ``message_func``.

        Combines ``nodes.data['Ah']`` with the gate-weighted average of the
        mailbox entries ``'Bh_j'``, reducing over dim 1 of the mailbox.

        Returns:
            ``{'h': tensor}`` with the aggregated node features.
        """
        Ah_i = nodes.data['Ah']
        Bh_j = nodes.mailbox['Bh_j']
        sigma_ij = torch.sigmoid(nodes.mailbox['e_ij'])
        # Normalised gated sum ("dense attention"); _EPS guards the division.
        gated_sum = torch.sum(sigma_ij * Bh_j, dim=1)
        gate_total = torch.sum(sigma_ij, dim=1) + self._EPS
        return {'h': Ah_i + gated_sum / gate_total}

    def forward(self, g, h, e):
        """Run one node-gated graph convolution.

        Args:
            g: Graph object exposing ``ndata``, ``edata``, ``apply_edges`` and
                ``update_all`` (presumably a ``DGLGraph`` — confirm with caller).
            h: Node features; last dimension must be ``input_dim``.
            e: Edge features; passed through untouched.

        Returns:
            Tuple ``(h, e)``: updated node features and the *input* ``e``.

        Side effects:
            Leaves ``'Ah'``, ``'Bh'``, ``'Dh'``, ``'Eh'``, ``'sum_sigma_h'``,
            ``'sum_sigma'`` and ``'h'`` in ``g.ndata`` and ``'e'`` and
            ``'sigma'`` in ``g.edata``.
            NOTE(review): ``g.edata['e']`` is overwritten with the gate
            pre-activations, so it differs from the returned ``e`` — confirm
            no caller keeps its own edge features under that key.
        """
        h_in = h  # kept for the residual connection

        g.ndata['Ah'] = self.A(h)
        g.ndata['Bh'] = self.B(h)
        g.ndata['Dh'] = self.D(h)
        g.ndata['Eh'] = self.E(h)

        # Gates come from the two endpoint projections only.
        g.apply_edges(fn.u_add_v('Dh', 'Eh', 'e'))
        g.edata['sigma'] = torch.sigmoid(g.edata['e'])

        # Gate-weighted sum of neighbour features, normalised by the gate sum.
        g.update_all(fn.u_mul_e('Bh', 'sigma', 'm'), fn.sum('m', 'sum_sigma_h'))
        g.update_all(fn.copy_e('sigma', 'm'), fn.sum('m', 'sum_sigma'))
        g.ndata['h'] = (
            g.ndata['Ah']
            + g.ndata['sum_sigma_h'] / (g.ndata['sum_sigma'] + self._EPS)
        )

        h = g.ndata['h']  # result of graph convolution

        if self.batch_norm:
            h = self.bn_node_h(h)

        h = F.relu(h)

        if self.residual:
            h = h_in + h

        h = F.dropout(h, self.dropout, training=self.training)

        return h, e

    def __repr__(self):
        return '{}(in_channels={}, out_channels={})'.format(
            self.__class__.__name__, self.in_channels, self.out_channels
        )
##############################################################
class GatedGCNLayerIsotropic(nn.Module):
    """Ungated (isotropic) variant of the gated graph convolution layer.

    Every node is updated as ``h_new = A(h) + sum(B(h_src))`` over the
    messages it receives, followed by optional batch normalisation, ReLU, an
    optional residual connection and dropout, in that order. Edge features
    are not used and are returned unchanged.

    Args:
        input_dim: Size of the last dimension of the incoming node features
            (they are fed to ``nn.Linear(input_dim, output_dim)``).
        output_dim: Size of the last dimension of the returned node features.
        dropout: Drop probability passed to ``F.dropout``.
        batch_norm: When truthy, apply ``BatchNorm1d`` to the node output.
        residual: When truthy, add the input node features to the output.
            Silently disabled when ``input_dim != output_dim``.
    """

    def __init__(self, input_dim, output_dim, dropout, batch_norm, residual=False):
        super().__init__()
        self.in_channels = input_dim
        self.out_channels = output_dim
        self.dropout = dropout
        self.batch_norm = batch_norm
        # A residual sum is only possible when input and output widths agree.
        self.residual = residual and input_dim == output_dim

        self.A = nn.Linear(input_dim, output_dim, bias=True)
        self.B = nn.Linear(input_dim, output_dim, bias=True)
        # Always created so the registered sub-modules do not depend on the flag.
        self.bn_node_h = nn.BatchNorm1d(output_dim)

    def message_func(self, edges):
        """User-defined message function: forward the source node's ``'Bh'``.

        Not called by ``forward``; kept as the reference implementation.
        """
        return {'Bh_j': edges.src['Bh']}

    def reduce_func(self, nodes):
        """User-defined reduce function matching ``message_func``.

        Adds ``nodes.data['Ah']`` to the sum of the mailbox entries ``'Bh_j'``
        taken over dim 1 of the mailbox.

        Returns:
            ``{'h': tensor}`` with the aggregated node features.
        """
        neighbour_sum = torch.sum(nodes.mailbox['Bh_j'], dim=1)
        return {'h': nodes.data['Ah'] + neighbour_sum}

    def forward(self, g, h, e):
        """Run one isotropic graph convolution.

        Args:
            g: Graph object exposing ``ndata`` and ``update_all``
                (presumably a ``DGLGraph`` — confirm with caller).
            h: Node features; last dimension must be ``input_dim``.
            e: Edge features; passed through untouched.

        Returns:
            Tuple ``(h, e)``: updated node features and the input ``e``.

        Side effects:
            Leaves ``'Ah'``, ``'Bh'``, ``'sum_h'`` and ``'h'`` in ``g.ndata``;
            ``'h'`` holds the raw convolution result, before normalisation
            and activation.
        """
        h_in = h  # kept for the residual connection

        g.ndata['Ah'] = self.A(h)
        g.ndata['Bh'] = self.B(h)

        # Plain (unweighted) sum of the neighbours' projected features.
        g.update_all(fn.copy_u('Bh', 'm'), fn.sum('m', 'sum_h'))
        g.ndata['h'] = g.ndata['Ah'] + g.ndata['sum_h']

        h = g.ndata['h']  # result of graph convolution

        if self.batch_norm:
            h = self.bn_node_h(h)

        h = F.relu(h)

        if self.residual:
            h = h_in + h

        h = F.dropout(h, self.dropout, training=self.training)

        return h, e

    def __repr__(self):
        return '{}(in_channels={}, out_channels={})'.format(
            self.__class__.__name__, self.in_channels, self.out_channels
        )