# graphist/models/utils.py
# Uploaded by ogutsevda ("Upload 3 files", commit 8cd7b86, 2.45 kB).
import torch
import torch.nn as nn
from functools import partial
def create_activation(name):
    """Build a fresh activation module from its short name.

    Supported names are ``"relu"``, ``"gelu"``, ``"prelu"`` and ``"elu"``;
    ``None`` yields ``nn.Identity()`` (no activation).

    Raises:
        NotImplementedError: for any other ``name``.
    """
    if name is None:
        return nn.Identity()

    # Ordered (name, module class) pairs; a new instance is created per call.
    known = (
        ("relu", nn.ReLU),
        ("gelu", nn.GELU),
        ("prelu", nn.PReLU),
        ("elu", nn.ELU),
    )
    for key, module_cls in known:
        if name == key:
            return module_cls()

    raise NotImplementedError(f"{name} is not implemented.")
def create_norm(name):
    """Return a normalization-layer *factory* (a class or partial) for ``name``.

    The caller instantiates the result, e.g. ``create_norm(name)(hidden_dim)``.

    * ``"layernorm"`` -> ``nn.LayerNorm``
    * ``"batchnorm"`` -> ``nn.BatchNorm1d``
    * ``"graphnorm"`` -> ``NormLayer`` pre-bound to ``norm_type="graphnorm"``
      (note: that layer's ``forward`` takes ``(graph, x)``, not just ``x``)
    * anything else (including ``None``) -> ``nn.Identity``
    """
    if name == "layernorm":
        return nn.LayerNorm
    elif name == "batchnorm":
        return nn.BatchNorm1d
    elif name == "graphnorm":
        # The key must be one NormLayer.__init__ recognises ("graphnorm");
        # any other value makes NormLayer raise NotImplementedError.
        return partial(NormLayer, norm_type="graphnorm")
    else:
        return nn.Identity
class NormLayer(nn.Module):
    """Normalization layer whose ``forward`` takes ``(graph, x)``.

    ``norm_type`` selects the behavior:

    * ``"batchnorm"`` -- wraps ``nn.BatchNorm1d(hidden_dim)``; ``graph`` is ignored.
    * ``"layernorm"`` -- wraps ``nn.LayerNorm(hidden_dim)``; ``graph`` is ignored.
    * ``"graphnorm"`` -- normalizes the rows of ``x`` separately for each graph
      in a batch. The learnable ``mean_scale`` scales the subtracted per-graph
      mean. The learnable ``weight`` and ``bias`` (all shape ``(hidden_dim,)``)
      rescale and shift the result.

    Raises:
        NotImplementedError: if ``norm_type`` is not one of the above.
    """

    def __init__(self, hidden_dim, norm_type):
        super().__init__()
        if norm_type == "batchnorm":
            self.norm = nn.BatchNorm1d(hidden_dim)
        elif norm_type == "layernorm":
            self.norm = nn.LayerNorm(hidden_dim)
        elif norm_type == "graphnorm":
            # A string marker (not a module): forward() dispatches on it to the
            # per-graph normalization path below.
            self.norm = norm_type
            self.weight = nn.Parameter(torch.ones(hidden_dim))
            self.bias = nn.Parameter(torch.zeros(hidden_dim))
            self.mean_scale = nn.Parameter(torch.ones(hidden_dim))
        else:
            raise NotImplementedError(f"{norm_type} is not implemented.")

    def forward(self, graph, x):
        """Normalize ``x``; returns a tensor with the same shape as ``x``.

        Args:
            graph: used only for ``"graphnorm"``. It must expose
                ``batch_num_nodes``: either a sequence/tensor of per-graph node
                counts, or a zero-argument method returning one. Rows of ``x``
                are assumed to be grouped by graph in that order, with the
                counts summing to ``x.shape[0]`` -- TODO confirm with callers.
            x: node features; dim 0 indexes nodes.
        """
        if self.norm is None:
            return x
        if not isinstance(self.norm, str):
            return self.norm(x)

        counts = graph.batch_num_nodes
        # NOTE(review): older DGL exposed this as a list attribute, newer DGL
        # as a method returning a tensor -- accept both forms.
        if callable(counts):
            counts = counts()
        counts = torch.as_tensor(counts, dtype=torch.long, device=x.device)
        batch_size = counts.numel()

        # Graph id of every row of x, e.g. counts [2, 3] -> [0, 0, 1, 1, 1].
        graph_ids = torch.arange(batch_size, device=x.device).repeat_interleave(counts)
        # Counts reshaped to broadcast against per-graph (batch_size, *features) sums.
        denom = counts.view((-1,) + (1,) * (x.dim() - 1))
        per_graph_shape = (batch_size,) + tuple(x.shape[1:])

        # new_zeros keeps x's dtype/device so index_add_ works for any float dtype.
        mean = x.new_zeros(per_graph_shape).index_add_(0, graph_ids, x) / denom
        sub = x - mean[graph_ids] * self.mean_scale

        var = x.new_zeros(per_graph_shape).index_add_(0, graph_ids, sub.pow(2)) / denom
        std = (var + 1e-6).sqrt()
        return self.weight * sub / std[graph_ids] + self.bias