English
Shanci's picture
Upload folder using huggingface_hub
26225c5 verified
from torch_scatter import scatter
from torch import nn
from src.data import NAG
from src.nn import MLP, BatchNorm
# Public names exported by `from <module> import *`
__all__ = ['NodeMLP']
class NodeMLP(nn.Module):
    """Simple MLP on the handcrafted features of the level-i in a NAG.

    This is used as a baseline to test how expressive handcrafted
    features are. Whatever the input level, the output features are
    expressed on the level-1 nodes of the NAG (see `forward`).

    :param dims:
        Layer dimensions, passed as-is to `MLP`
    :param level: int
        Index of the NAG level whose handcrafted features `x` are fed
        to the MLP
    :param activation:
        Activation passed to `MLP`. NOTE: the default instance is
        created once, when the class is defined, and is therefore
        shared by every `NodeMLP` built with the default. This is
        harmless for a parameter-free activation such as `LeakyReLU`.
    :param norm:
        Normalization passed to `MLP`
    :param drop:
        Dropout setting passed to `MLP`
    :param norm_mode: str
        Mode passed to `norm_index()` on the input level. The resulting
        index is given to the MLP as `batch` (presumably to group nodes
        in its normalization layers, TODO confirm against `MLP`)
    """

    def __init__(
            self, dims, level=0, activation=nn.LeakyReLU(), norm=BatchNorm,
            drop=None, norm_mode='graph'):
        super().__init__()
        self.level = level
        self.mlp = MLP(dims, activation=activation, norm=norm, drop=drop)
        self.norm_mode = norm_mode

    @property
    def out_dim(self):
        """Output feature dimension, as reported by the inner MLP."""
        return self.mlp.out_dim

    def forward(self, nag):
        """Compute level-1 node features from the handcrafted features
        of `nag[self.level]`.

        - level 1: the MLP output is returned directly
        - level 0: the MLP output is max-pooled into the level-1 parents
          of each level-0 node (via `nag[0].super_index`)
        - level > 1: each level-1 node receives the MLP output of its
          level-`self.level` ancestor (via `nag.get_super_index`)

        :param nag: NAG
        :return: features for the level-1 nodes of `nag`
        """
        assert isinstance(nag, NAG)
        # The output always lives on level 1 (level 0 pools into
        # `nag[1]`), so the NAG needs at least 2 levels as well as
        # enough levels to hold `self.level`
        assert nag.num_levels > max(self.level, 1)

        # Compute node features from the handcrafted features. The
        # normalization index is built on the same level as the features
        # it accompanies (the original code referenced a non-existent
        # `self.i_level` attribute here)
        level_data = nag[self.level]
        norm_index = level_data.norm_index(mode=self.norm_mode)
        x = self.mlp(level_data.x, batch=norm_index)

        # If node level is 1, output level-1 features
        if self.level == 1:
            return x

        # If node level is 0, max-pool to produce level-1 features
        if self.level == 0:
            return scatter(
                x, nag[0].super_index, dim=0, dim_size=nag[1].num_nodes,
                reduce='max')

        # If node level is larger than 1, distribute parent features to
        # level-1 nodes
        super_index = nag.get_super_index(self.level, low=1)
        return x[super_index]