| | """ |
| | (c) Adaptation of the code from https://github.com/SitaoLuan/ACM-GNN |
| | """ |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| | from torch import Tensor |
| | from typing import Union |
| | from torch_geometric.nn.conv import MessagePassing |
| | from torch_geometric.nn.inits import reset |
| | from torch_geometric.typing import OptPairTensor, OptTensor, Size |
| | from torch_geometric.utils import scatter |
| |
|
| | from .utils import create_activation |
| |
|
| |
|
class ACM_GIN(MessagePassing):
    """Message-passing layer that mixes low-pass, high-pass and full-pass channels.

    For every target node the layer:

    1. sums the edge-weighted source features (``aggr="add"`` unless the
       caller overrides it through ``kwargs``) and divides the sum by the
       total incoming edge weight of that node;
    2. builds a low-pass signal ``(x + agg) / 2`` and a high-pass signal
       ``(x - agg) / 2``, and uses ``x`` itself as the full-pass signal;
    3. transforms each signal with its own network (``nn_lowpass``,
       ``nn_highpass``, ``nn_fullpass``);
    4. scores each transformed channel with ``sigmoid(nn_*_proj(.))``,
       concatenates the scores along dim 1, divides by the temperature
       ``T``, passes them through ``nn_mix`` and a softmax over dim 1;
    5. returns the sum of the three channels weighted by columns 0, 1 and 2
       of that softmax output.

    Args:
        nn_lowpass: Network applied to the low-pass signal.
        nn_highpass: Network applied to the high-pass signal.
        nn_fullpass: Network applied to the unfiltered target features.
        nn_lowpass_proj: Scores the transformed low-pass channel.
        nn_highpass_proj: Scores the transformed high-pass channel.
        nn_fullpass_proj: Scores the transformed full-pass channel.
        nn_mix: Maps the concatenated scores to mixing logits; its output
            must have at least three columns because columns 0..2 are
            indexed in ``forward``.
        T: Temperature dividing the concatenated scores before ``nn_mix``.
        **kwargs: Forwarded to ``MessagePassing.__init__``.
    """

    def __init__(
        self,
        nn_lowpass: torch.nn.Module,
        nn_highpass: torch.nn.Module,
        nn_fullpass: torch.nn.Module,
        nn_lowpass_proj: torch.nn.Module,
        nn_highpass_proj: torch.nn.Module,
        nn_fullpass_proj: torch.nn.Module,
        nn_mix: torch.nn.Module,
        T: float = 3.0,
        **kwargs,
    ):
        kwargs.setdefault("aggr", "add")
        super().__init__(**kwargs)
        self.nn_lowpass = nn_lowpass
        self.nn_highpass = nn_highpass
        self.nn_fullpass = nn_fullpass
        self.nn_lowpass_proj = nn_lowpass_proj
        self.nn_highpass_proj = nn_highpass_proj
        self.nn_fullpass_proj = nn_fullpass_proj
        self.nn_mix = nn_mix
        self.sigmoid = torch.nn.Sigmoid()
        self.softmax = torch.nn.Softmax(dim=1)
        self.T = T
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every sub-network handed to the constructor."""
        for module in (
            self.nn_lowpass,
            self.nn_highpass,
            self.nn_fullpass,
            self.nn_lowpass_proj,
            self.nn_highpass_proj,
            self.nn_fullpass_proj,
            self.nn_mix,
        ):
            reset(module)

    def forward(
        self,
        x: Union[Tensor, OptPairTensor],
        edge_index: Tensor,
        edge_weight: OptTensor = None,
        size: Size = None,
    ) -> Tensor:
        """Run one adaptive channel-mixing step.

        Args:
            x: Node features, or a ``(source, target)`` pair of features.
            edge_index: Edge list; row 1 is used as the target index when
                accumulating the per-node edge-weight totals.
            edge_weight: One weight per edge. ``None`` is treated as a
                weight of one for every edge.
            size: Forwarded unchanged to ``propagate``.

        Returns:
            The softmax-weighted sum of the three transformed channels.

        Raises:
            ValueError: If ``x`` is a pair whose target entry is ``None``;
                the channels are defined relative to the target features.
        """
        if isinstance(x, Tensor):
            x: OptPairTensor = (x, x)

        x_r = x[1]
        if x_r is None:
            # Previously this case fell through and died later with an
            # unbound-local error; fail early with an actionable message.
            raise ValueError(
                "ACM_GIN requires target node features: x[1] must not be None"
            )

        if edge_weight is None:
            # The signature advertises an optional weight, but `message` and
            # the degree computation both need a tensor: use unit weights.
            edge_weight = x[0].new_ones(edge_index.size(1))

        out = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=size)

        # Normalise the aggregate by the total incoming edge weight per node.
        deg = scatter(edge_weight, edge_index[1], 0, out.size(0), reduce="sum")
        deg_inv = 1.0 / deg
        # Nodes whose weight total is zero would yield inf; zero them instead.
        deg_inv.masked_fill_(deg_inv == float("inf"), 0)
        out = deg_inv.view(-1, 1) * out

        out_lowpass = self.nn_lowpass((x_r + out) / 2.0)
        out_highpass = self.nn_highpass((x_r - out) / 2.0)
        out_fullpass = self.nn_fullpass(x_r)

        # One sigmoid score per channel, mixed into per-node channel weights.
        alpha_lowpass = self.sigmoid(self.nn_lowpass_proj(out_lowpass))
        alpha_highpass = self.sigmoid(self.nn_highpass_proj(out_highpass))
        alpha_fullpass = self.sigmoid(self.nn_fullpass_proj(out_fullpass))
        alpha_cat = torch.concat(
            [alpha_lowpass, alpha_highpass, alpha_fullpass], dim=1
        )
        alpha_cat = self.softmax(self.nn_mix(alpha_cat / self.T))

        out = alpha_cat[:, 0].view(-1, 1) * out_lowpass
        out = out + alpha_cat[:, 1].view(-1, 1) * out_highpass
        out = out + alpha_cat[:, 2].view(-1, 1) * out_fullpass

        return out

    def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:
        """Scale each source-node feature row by the weight of its edge."""
        return edge_weight.view(-1, 1) * x_j

    def __repr__(self) -> str:
        # The layer has no single `nn` attribute; referencing one raised
        # AttributeError whenever the layer (or a parent module) was printed.
        return (
            f"{self.__class__.__name__}("
            f"nn_lowpass={self.nn_lowpass}, "
            f"nn_highpass={self.nn_highpass}, "
            f"nn_fullpass={self.nn_fullpass})"
        )
| |
|
| |
|
class ACM_GIN_model(nn.Module):
    """Stack of ``num_layers`` ACM_GIN layers applied one after another.

    Each layer owns three two-stage channel networks (Linear, optional
    BatchNorm1d, activation, Linear, optional BatchNorm1d, activation), three
    ``Linear(width, 1)`` channel scorers and one ``Linear(3, 3)`` mixer.
    The first layer reads ``in_dim`` features, the last layer emits
    ``out_dim`` features, and every other width is ``hidden_dim``.
    """

    def __init__(
        self, in_dim, out_dim, num_layers, hidden_dim, batchnorm, activation="relu"
    ):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.gnn_batchnorm = batchnorm
        self.out_dim = out_dim

        # Containers are registered in this exact order so that the module
        # tree (and therefore state-dict keys) is laid out as before.
        self.ACM_convs = nn.ModuleList()
        self.nns_lowpass = nn.ModuleList()
        self.nns_highpass = nn.ModuleList()
        self.nns_fullpass = nn.ModuleList()
        self.nns_lowpass_proj = nn.ModuleList()
        self.nns_highpass_proj = nn.ModuleList()
        self.nns_fullpass_proj = nn.ModuleList()
        self.nns_mix = nn.ModuleList()

        # A single activation instance is reused by every channel network.
        self.activation = create_activation(activation)

        scorer_banks = (
            self.nns_lowpass_proj,
            self.nns_highpass_proj,
            self.nns_fullpass_proj,
        )
        channel_banks = (self.nns_lowpass, self.nns_highpass, self.nns_fullpass)
        final_idx = self.num_layers - 1

        for layer_idx in range(self.num_layers):
            width_in = in_dim if layer_idx == 0 else self.hidden_dim
            width_out = self.out_dim if layer_idx == final_idx else self.hidden_dim

            # Channel scorers first, then the mixer, then the channel
            # networks: same construction order as before, so seeded
            # initialisation is unchanged.
            for bank in scorer_banks:
                bank.append(nn.Linear(width_out, 1))

            self.nns_mix.append(nn.Linear(3, 3))

            for bank in channel_banks:
                stages = [nn.Linear(width_in, self.hidden_dim)]
                if self.gnn_batchnorm:
                    stages.append(nn.BatchNorm1d(self.hidden_dim))
                stages.append(self.activation)
                stages.append(nn.Linear(self.hidden_dim, width_out))
                if self.gnn_batchnorm:
                    stages.append(nn.BatchNorm1d(width_out))
                stages.append(self.activation)
                bank.append(nn.Sequential(*stages))

            self.ACM_convs.append(
                ACM_GIN(
                    nn_lowpass=self.nns_lowpass[layer_idx],
                    nn_highpass=self.nns_highpass[layer_idx],
                    nn_fullpass=self.nns_fullpass[layer_idx],
                    nn_lowpass_proj=self.nns_lowpass_proj[layer_idx],
                    nn_highpass_proj=self.nns_highpass_proj[layer_idx],
                    nn_fullpass_proj=self.nns_fullpass_proj[layer_idx],
                    nn_mix=self.nns_mix[layer_idx],
                )
            )

    def reset_parameters(self):
        """Re-initialise every Linear and BatchNorm1d module in the model."""
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.BatchNorm1d)):
                module.reset_parameters()

    def forward(self, x, edge_index, edge_attr, return_hidden=False):
        """Apply the layers in sequence.

        ``edge_attr`` is handed to each layer as its ``edge_weight``. When
        ``return_hidden`` is truthy the result is ``(final, per_layer_outputs)``;
        otherwise only the final layer output is returned.
        """
        per_layer = []
        for layer_idx in range(self.num_layers):
            conv = self.ACM_convs[layer_idx]
            x = conv(x=x, edge_index=edge_index, edge_weight=edge_attr)
            per_layer.append(x)
        return (x, per_layer) if return_hidden else x
| |
|
| |
|
if __name__ == "__main__":
    # Smoke check: build a small model and report its trainable parameter count.
    acm_gin = ACM_GIN_model(46, 46, 2, 256, True)
    trainable_params = [p for p in acm_gin.parameters() if p.requires_grad]
    print(sum(p.numel() for p in trainable_params))
    print("")
| |
|