"""
(c) Adaptation of the code from https://github.com/SitaoLuan/ACM-GNN
"""
import torch
import torch.nn as nn
from torch import Tensor
from typing import Union
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import reset
from torch_geometric.typing import OptPairTensor, OptTensor, Size
from torch_geometric.utils import scatter
from .utils import create_activation
class ACM_GIN(MessagePassing):
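    """Adaptive Channel Mixing (ACM) GIN convolution layer.

    For every node this layer builds three filter channels from the root
    features ``x`` and the degree-normalized neighbourhood aggregate ``out``:
    low-pass ``(x + out) / 2``, high-pass ``(x - out) / 2``, and full-pass
    (identity) ``x``. Each channel is embedded by its own MLP, and the three
    embeddings are combined with node-wise attention weights computed from
    the scalar projection heads, the mixing layer ``nn_mix``, and a softmax
    with temperature ``T``.
    """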
def __init__(
self,
nn_lowpass: torch.nn.Module,
nn_highpass: torch.nn.Module,
nn_fullpass: torch.nn.Module,
nn_lowpass_proj: torch.nn.Module,
nn_highpass_proj: torch.nn.Module,
nn_fullpass_proj: torch.nn.Module,
nn_mix: torch.nn.Module,
T: float = 3.0,
**kwargs,
):
kwargs.setdefault("aggr", "add")
super().__init__(**kwargs)
self.nn_lowpass = nn_lowpass
self.nn_highpass = nn_highpass
self.nn_fullpass = nn_fullpass
self.nn_lowpass_proj = nn_lowpass_proj
self.nn_highpass_proj = nn_highpass_proj
self.nn_fullpass_proj = nn_fullpass_proj
self.nn_mix = nn_mix
self.sigmoid = torch.nn.Sigmoid()
self.softmax = torch.nn.Softmax(dim=1)
self.T = T
self.reset_parameters()
def reset_parameters(self):
reset(self.nn_lowpass)
reset(self.nn_highpass)
reset(self.nn_fullpass)
reset(self.nn_lowpass_proj)
reset(self.nn_highpass_proj)
reset(self.nn_fullpass_proj)
reset(self.nn_mix)
def forward(
self,
x: Union[Tensor, OptPairTensor],
edge_index: Tensor,
edge_weight: OptTensor = None,
size: Size = None,
) -> Tensor:
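        """Aggregate mean-normalized neighbour messages, form the three
        filter channels, and return their attention-weighted mixture."""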
        if isinstance(x, Tensor):
            x: OptPairTensor = (x, x)
        # fall back to unit weights so aggregation and degree normalization
        # also work when no edge weights are provided
        if edge_weight is None:
            edge_weight = torch.ones(
                edge_index.size(1), dtype=x[0].dtype, device=edge_index.device
            )
        # propagate_type: (x: OptPairTensor, edge_weight: OptTensor)
        out = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=size)
        # mean-normalize the aggregated messages by the weighted in-degree
        deg = scatter(edge_weight, edge_index[1], 0, out.size(0), reduce="sum")
        deg_inv = 1.0 / deg
        deg_inv.masked_fill_(deg_inv == float("inf"), 0)
        out = deg_inv.view(-1, 1) * out
        # recombine root features with the aggregated neighbourhood to form
        # the low-pass and high-pass channels
        x_r = x[1]
        if x_r is None:
            raise ValueError(f"'{self.__class__.__name__}' requires root node features")
        out_lowpass = (x_r + out) / 2.0
        out_highpass = (x_r - out) / 2.0
# compute embeddings for each filter
out_lowpass = self.nn_lowpass(out_lowpass)
out_highpass = self.nn_highpass(out_highpass)
out_fullpass = self.nn_fullpass(x_r)
# compute importance weights per filter
alpha_lowpass = self.sigmoid(self.nn_lowpass_proj(out_lowpass))
alpha_highpass = self.sigmoid(self.nn_highpass_proj(out_highpass))
alpha_fullpass = self.sigmoid(self.nn_fullpass_proj(out_fullpass))
        alpha_cat = torch.cat([alpha_lowpass, alpha_highpass, alpha_fullpass], dim=1)
        # temperature-scaled attention over the three channels
        alpha_cat = self.softmax(self.nn_mix(alpha_cat / self.T))
        # mix the channel embeddings with the learned node-wise weights
        out = alpha_cat[:, 0].view(-1, 1) * out_lowpass
        out = out + alpha_cat[:, 1].view(-1, 1) * out_highpass
        out = out + alpha_cat[:, 2].view(-1, 1) * out_fullpass
return out
def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:
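        # scale each incoming neighbour feature by its edge weight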
return edge_weight.view(-1, 1) * x_j
    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(T={self.T})"
class ACM_GIN_model(nn.Module):
""" """
def __init__(
self, in_dim, out_dim, num_layers, hidden_dim, batchnorm, activation="relu"
):
        super().__init__()
self.num_layers = num_layers
self.hidden_dim = hidden_dim
self.gnn_batchnorm = batchnorm
self.out_dim = out_dim
self.ACM_convs = nn.ModuleList()
self.nns_lowpass = nn.ModuleList()
self.nns_highpass = nn.ModuleList()
self.nns_fullpass = nn.ModuleList()
self.nns_lowpass_proj = nn.ModuleList()
self.nns_highpass_proj = nn.ModuleList()
self.nns_fullpass_proj = nn.ModuleList()
self.nns_mix = nn.ModuleList()
self.activation = create_activation(activation)
for i in range(self.num_layers):
# projection modules to compute importance weights
for channel_proj_module in [
self.nns_lowpass_proj,
self.nns_highpass_proj,
self.nns_fullpass_proj,
]:
if i == self.num_layers - 1:
channel_proj_module.append(nn.Linear(self.out_dim, 1))
else:
channel_proj_module.append(nn.Linear(self.hidden_dim, 1))
# weights mixing module as attention mechanism
self.nns_mix.append(nn.Linear(3, 3))
# GIN embedding scheme per channel
            local_input_dim = in_dim if i == 0 else self.hidden_dim
            local_out_dim = (
                self.out_dim if i == self.num_layers - 1 else self.hidden_dim
            )
for channel_module in [
self.nns_lowpass,
self.nns_highpass,
self.nns_fullpass,
]:
if self.gnn_batchnorm:
sequential = nn.Sequential(
nn.Linear(local_input_dim, self.hidden_dim),
nn.BatchNorm1d(self.hidden_dim),
self.activation,
nn.Linear(self.hidden_dim, local_out_dim),
nn.BatchNorm1d(local_out_dim),
self.activation,
)
else:
sequential = nn.Sequential(
nn.Linear(local_input_dim, self.hidden_dim),
self.activation,
nn.Linear(self.hidden_dim, local_out_dim),
self.activation,
)
channel_module.append(sequential)
self.ACM_convs.append(
ACM_GIN(
nn_lowpass=self.nns_lowpass[i],
nn_highpass=self.nns_highpass[i],
nn_fullpass=self.nns_fullpass[i],
nn_lowpass_proj=self.nns_lowpass_proj[i],
nn_highpass_proj=self.nns_highpass_proj[i],
nn_fullpass_proj=self.nns_fullpass_proj[i],
nn_mix=self.nns_mix[i],
)
)
    def reset_parameters(self):
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.BatchNorm1d)):
                m.reset_parameters()
def forward(self, x, edge_index, edge_attr, return_hidden=False):
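        """Apply all ACM_GIN layers in sequence; if ``return_hidden`` is True,
        also return the list of every layer's output."""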
outs = []
for i in range(self.num_layers):
x = self.ACM_convs[i](x=x, edge_index=edge_index, edge_weight=edge_attr)
outs.append(x)
if return_hidden:
return x, outs
else:
return x
if __name__ == "__main__":
acm_gin = ACM_GIN_model(46, 46, 2, 256, True)
print(sum(p.numel() for p in acm_gin.parameters() if p.requires_grad))
print("")