| import torch | |
| import torch.nn as nn | |
| from functools import partial | |
def create_activation(name):
    """Build a fresh activation module for the given identifier.

    Args:
        name: One of ``"relu"``, ``"gelu"``, ``"prelu"``, ``"elu"``, or
            ``None``. ``None`` selects a pass-through ``nn.Identity``.

    Returns:
        A newly constructed ``nn.Module`` instance (default arguments).

    Raises:
        NotImplementedError: If ``name`` is not one of the supported values.
    """
    if name is None:
        return nn.Identity()

    # Ordered (identifier, constructor) pairs; compared with ``==`` so any
    # value that merely equals one of the identifiers is accepted.
    known_activations = (
        ("relu", nn.ReLU),
        ("gelu", nn.GELU),
        ("prelu", nn.PReLU),
        ("elu", nn.ELU),
    )
    for identifier, constructor in known_activations:
        if name == identifier:
            return constructor()

    raise NotImplementedError(f"{name} is not implemented.")
def create_norm(name):
    """Return a constructor (not an instance) for the requested norm layer.

    Args:
        name: ``"layernorm"``, ``"batchnorm"`` or ``"graphnorm"``. Any other
            value (including ``None``) selects ``nn.Identity``.

    Returns:
        A callable that builds the layer: ``nn.LayerNorm``,
        ``nn.BatchNorm1d``, ``nn.Identity``, or a ``functools.partial`` of
        ``NormLayer`` bound to ``norm_type="graphnorm"`` (call it with the
        hidden dimension). Note that ``NormLayer.forward`` takes
        ``(graph, x)`` whereas the torch layers take only ``x``.
    """
    if name == "layernorm":
        return nn.LayerNorm
    if name == "batchnorm":
        return nn.BatchNorm1d
    if name == "graphnorm":
        # NormLayer only understands "batchnorm", "layernorm" and
        # "graphnorm"; binding "groupnorm" here made the returned factory
        # raise NotImplementedError as soon as it was called.
        return partial(NormLayer, norm_type="graphnorm")
    # Unknown / missing names deliberately fall back to a no-op layer.
    return nn.Identity
class NormLayer(nn.Module):
    """Normalization layer selectable by name, including per-graph "graphnorm".

    For ``"batchnorm"`` and ``"layernorm"`` the layer wraps the matching torch
    module. For ``"graphnorm"`` it normalizes the rows belonging to each graph
    of a batch separately, using the learnable parameters ``weight``,
    ``bias`` and ``mean_scale`` (each of shape ``(hidden_dim,)``).

    Args:
        hidden_dim: Size of the feature dimension being normalized.
        norm_type: ``"batchnorm"``, ``"layernorm"`` or ``"graphnorm"``.

    Raises:
        NotImplementedError: If ``norm_type`` is none of the above.
    """

    def __init__(self, hidden_dim, norm_type):
        super().__init__()
        if norm_type == "batchnorm":
            self.norm = nn.BatchNorm1d(hidden_dim)
        elif norm_type == "layernorm":
            self.norm = nn.LayerNorm(hidden_dim)
        elif norm_type == "graphnorm":
            # The string itself is stored as a marker; forward() treats a
            # str-valued ``norm`` as "use the hand-written graph norm".
            self.norm = norm_type
            self.weight = nn.Parameter(torch.ones(hidden_dim))
            self.bias = nn.Parameter(torch.zeros(hidden_dim))
            self.mean_scale = nn.Parameter(torch.ones(hidden_dim))
        else:
            raise NotImplementedError(f"{norm_type} is not implemented.")

    def forward(self, graph, x):
        """Normalize ``x``.

        Args:
            graph: Only used for ``"graphnorm"``. Must expose
                ``batch_num_nodes`` as either a sequence/tensor of per-graph
                row counts or a zero-argument callable returning one
                (presumably a DGL batched graph; the attribute form matches
                older DGL, the callable form newer DGL).
            x: Tensor whose first dimension enumerates the rows of all graphs
                in batch order; its length must equal the sum of the counts.

        Returns:
            A tensor with the same shape as ``x``.
        """
        if self.norm is None:
            return x
        if not isinstance(self.norm, str):
            return self.norm(x)

        counts = graph.batch_num_nodes
        if callable(counts):
            counts = counts()
        # Build the counts directly as int64 on the right device (avoids the
        # float32 round-trip of ``torch.Tensor(...).long()``).
        counts = torch.as_tensor(counts, dtype=torch.long, device=x.device)
        counts = counts.reshape(-1)
        num_graphs = counts.numel()

        # Shape suffix that lets per-graph / per-row vectors broadcast
        # against ``x`` regardless of its rank.
        trailing = (1,) * (x.dim() - 1)
        graph_ids = torch.arange(num_graphs, device=x.device)
        graph_ids = graph_ids.repeat_interleave(counts)
        index = graph_ids.view(-1, *trailing).expand_as(x)
        denom = counts.view(-1, *trailing)

        # Per-graph mean, then expanded back to one row per input row.
        # ``new_zeros`` keeps the dtype/device of ``x`` so scatter_add_ works
        # for non-float32 inputs too.
        mean = x.new_zeros(num_graphs, *x.shape[1:])
        mean = mean.scatter_add_(0, index, x)
        mean = (mean / denom).repeat_interleave(counts, dim=0)

        # Subtract a learnably scaled mean rather than the full mean.
        centered = x - mean * self.mean_scale

        # Per-graph standard deviation of the shifted values (eps = 1e-6).
        var = x.new_zeros(num_graphs, *x.shape[1:])
        var = var.scatter_add_(0, index, centered.pow(2))
        std = (var / denom + 1e-6).sqrt()
        std = std.repeat_interleave(counts, dim=0)

        return self.weight * centered / std + self.bias