import torch
import torch.nn as nn
from functools import partial
def create_activation(name):
    """Instantiate the activation module identified by *name*.

    Args:
        name: one of ``"relu"``, ``"gelu"``, ``"prelu"``, ``"elu"``, or
            ``None`` (which yields a pass-through ``nn.Identity``).

    Returns:
        A freshly constructed ``nn.Module`` activation.

    Raises:
        NotImplementedError: if *name* is not a supported activation.
    """
    if name is None:
        return nn.Identity()
    factories = {
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "prelu": nn.PReLU,
        "elu": nn.ELU,
    }
    factory = factories.get(name)
    if factory is None:
        raise NotImplementedError(f"{name} is not implemented.")
    return factory()
def create_norm(name):
    """Return a normalization-layer constructor for *name*.

    The returned value is a callable taking ``hidden_dim`` and producing a
    normalization ``nn.Module``. Unrecognized names (including ``None``)
    fall back to ``nn.Identity``.

    Args:
        name: ``"layernorm"``, ``"batchnorm"``, ``"graphnorm"``, or anything
            else for the identity fallback.
    """
    if name == "layernorm":
        return nn.LayerNorm
    elif name == "batchnorm":
        return nn.BatchNorm1d
    elif name == "graphnorm":
        # BUG FIX: this previously passed norm_type="groupnorm", which
        # NormLayer.__init__ rejects with NotImplementedError, so the
        # graphnorm option could never actually be constructed.
        return partial(NormLayer, norm_type="graphnorm")
    else:
        return nn.Identity
class NormLayer(nn.Module):
    """Normalization layer supporting batchnorm, layernorm, and GraphNorm.

    For ``"batchnorm"`` / ``"layernorm"`` this simply wraps the corresponding
    ``torch.nn`` module. For ``"graphnorm"`` it implements per-graph feature
    normalization over a batched graph: each graph's node features are
    centered by a learnable-scaled mean and divided by the per-graph std.
    """

    def __init__(self, hidden_dim, norm_type):
        """Build the requested normalization.

        Args:
            hidden_dim: feature dimension being normalized.
            norm_type: ``"batchnorm"``, ``"layernorm"``, or ``"graphnorm"``.

        Raises:
            NotImplementedError: for any other ``norm_type``.
        """
        super().__init__()
        if norm_type == "batchnorm":
            self.norm = nn.BatchNorm1d(hidden_dim)
        elif norm_type == "layernorm":
            self.norm = nn.LayerNorm(hidden_dim)
        elif norm_type == "graphnorm":
            # Sentinel: storing the plain string routes forward() to the
            # hand-rolled GraphNorm path below.
            self.norm = norm_type
            self.weight = nn.Parameter(torch.ones(hidden_dim))
            self.bias = nn.Parameter(torch.zeros(hidden_dim))
            # Learnable scale applied to the subtracted mean (GraphNorm alpha).
            self.mean_scale = nn.Parameter(torch.ones(hidden_dim))
        else:
            raise NotImplementedError

    def forward(self, graph, x):
        """Apply the configured normalization to node features *x*.

        Args:
            graph: batched graph exposing ``batch_num_nodes`` as a sequence
                of per-graph node counts (DGL-style attribute). Only consulted
                on the graphnorm path. NOTE(review): newer DGL exposes
                ``batch_num_nodes`` as a method — confirm against the caller's
                DGL version.
            x: node feature tensor whose first dimension enumerates the nodes
                of all graphs in the batch, in batch order.

        Returns:
            The normalized feature tensor, same shape as *x*.
        """
        tensor = x
        # batchnorm/layernorm were stored as modules — delegate directly.
        if self.norm is not None and not isinstance(self.norm, str):
            return self.norm(tensor)
        elif self.norm is None:
            # Defensive no-op branch (unreachable with the current __init__,
            # which never leaves self.norm as None).
            return tensor

        # --- GraphNorm: normalize per graph within the batch ---
        batch_list = graph.batch_num_nodes
        batch_size = len(batch_list)
        batch_list = torch.as_tensor(
            batch_list, dtype=torch.long, device=tensor.device
        )
        # Map every node row to the index of the graph it belongs to.
        batch_index = (
            torch.arange(batch_size, device=tensor.device)
            .repeat_interleave(batch_list)
        )
        batch_index = batch_index.view(
            (-1,) + (1,) * (tensor.dim() - 1)
        ).expand_as(tensor)
        # Per-graph mean, broadcast back to one row per node.
        mean = torch.zeros(
            batch_size, *tensor.shape[1:], device=tensor.device
        )
        mean = mean.scatter_add_(0, batch_index, tensor)
        mean = (mean.T / batch_list).T
        mean = mean.repeat_interleave(batch_list, dim=0)
        sub = tensor - mean * self.mean_scale
        # Per-graph std (eps = 1e-6 for numerical stability), broadcast back.
        std = torch.zeros(
            batch_size, *tensor.shape[1:], device=tensor.device
        )
        std = std.scatter_add_(0, batch_index, sub.pow(2))
        std = ((std.T / batch_list).T + 1e-6).sqrt()
        std = std.repeat_interleave(batch_list, dim=0)
        return self.weight * sub / std + self.bias