import torch
import torch.nn as nn
from functools import partial


def create_activation(name):
    """Return an activation module by name; None maps to Identity."""
    if name == "relu":
        return nn.ReLU()
    elif name == "gelu":
        return nn.GELU()
    elif name == "prelu":
        return nn.PReLU()
    elif name == "elu":
        return nn.ELU()
    elif name is None:
        return nn.Identity()
    else:
        raise NotImplementedError(f"{name} is not implemented.")


def create_norm(name):
    """Return a norm *constructor* (not an instance); call it with hidden_dim."""
    if name == "layernorm":
        return nn.LayerNorm
    elif name == "batchnorm":
        return nn.BatchNorm1d
    elif name == "graphnorm":
        # NormLayer only recognizes "graphnorm"; passing "groupnorm" here
        # would raise NotImplementedError at construction time.
        return partial(NormLayer, norm_type="graphnorm")
    else:
        return nn.Identity


class NormLayer(nn.Module):
    """Wraps batch/layer norm, or applies graph-wise (GraphNorm-style)
    normalization over a batched graph when norm_type == "graphnorm"."""

    def __init__(self, hidden_dim, norm_type):
        super().__init__()
        if norm_type == "batchnorm":
            self.norm = nn.BatchNorm1d(hidden_dim)
        elif norm_type == "layernorm":
            self.norm = nn.LayerNorm(hidden_dim)
        elif norm_type == "graphnorm":
            # Keep the string as a sentinel; forward() branches on it.
            self.norm = norm_type
            self.weight = nn.Parameter(torch.ones(hidden_dim))
            self.bias = nn.Parameter(torch.zeros(hidden_dim))
            # Learnable fraction of the mean to subtract.
            self.mean_scale = nn.Parameter(torch.ones(hidden_dim))
        else:
            raise NotImplementedError(f"{norm_type} is not implemented.")

    def forward(self, graph, x):
        tensor = x
        # batchnorm / layernorm: delegate to the wrapped module.
        if isinstance(self.norm, nn.Module):
            return self.norm(tensor)

        # Graph-wise normalization. Per-graph node counts of the batched
        # graph; batch_num_nodes() is a method in DGL >= 0.5 (older DGL
        # exposed it as an attribute).
        batch_list = graph.batch_num_nodes()
        batch_size = len(batch_list)
        batch_list = torch.as_tensor(batch_list, dtype=torch.long, device=tensor.device)

        # Graph index for every node row, broadcast to the feature shape so
        # scatter_add_ can accumulate per-graph sums.
        batch_index = torch.arange(batch_size, device=tensor.device).repeat_interleave(batch_list)
        batch_index = batch_index.view((-1,) + (1,) * (tensor.dim() - 1)).expand_as(tensor)

        # Per-graph mean, expanded back to one row per node.
        mean = torch.zeros(batch_size, *tensor.shape[1:], device=tensor.device)
        mean = mean.scatter_add_(0, batch_index, tensor)
        mean = (mean.T / batch_list).T
        mean = mean.repeat_interleave(batch_list, dim=0)

        # Subtract a learnable fraction of the mean rather than the full mean.
        sub = tensor - mean * self.mean_scale

        # Per-graph standard deviation of the shifted features.
        std = torch.zeros(batch_size, *tensor.shape[1:], device=tensor.device)
        std = std.scatter_add_(0, batch_index, sub.pow(2))
        std = ((std.T / batch_list).T + 1e-6).sqrt()
        std = std.repeat_interleave(batch_list, dim=0)

        return self.weight * sub / std + self.bias
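

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It assumes
# the graph object is a DGL batched graph, since forward() above relies on
# batch_num_nodes(); the dgl.graph / dgl.batch calls are DGL's standard API.
# If you use a different graph library, substitute any object exposing the
# same batch_num_nodes() interface.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import dgl  # assumed dependency, for this example only

    act = create_activation("gelu")       # returns an instance, ready to call
    norm_ctor = create_norm("graphnorm")  # returns a constructor, not an instance
    norm = norm_ctor(hidden_dim=16)

    # Batch two tiny graphs (3 nodes + 2 nodes = 5 node rows total).
    g1 = dgl.graph(([0, 1], [1, 2]), num_nodes=3)
    g2 = dgl.graph(([0], [1]), num_nodes=2)
    graph = dgl.batch([g1, g2])

    x = torch.randn(graph.num_nodes(), 16)  # one feature row per node
    out = norm(graph, act(x))
    print(out.shape)  # torch.Size([5, 16])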