|
|
from torch_scatter import scatter |
|
|
from torch import nn |
|
|
from src.data import NAG |
|
|
from src.nn import MLP, BatchNorm |
|
|
|
|
|
|
|
|
__all__ = ['NodeMLP'] |
|
|
|
|
|
|
|
|
class NodeMLP(nn.Module):
    """Simple MLP on the handcrafted features of the level-i in a NAG.

    This is used as a baseline to test how expressive handcrafted
    features are.

    The MLP is applied to the `x` attribute of the selected NAG level.
    The result is then either returned as is (level 1), max-pooled with
    `scatter` using `nag[0].super_index` (level 0), or indexed with the
    output of `nag.get_super_index(level, low=1)` (any other level).

    :param dims: layer dimensions, forwarded as is to `MLP`
    :param level: index of the NAG level whose `x` features are fed to
        the MLP
    :param activation: activation module, forwarded as is to `MLP`.
        NOTE: the default `nn.LeakyReLU()` instance is created once,
        when the class is defined, and is therefore shared by every
        `NodeMLP` built with the default value
    :param norm: normalization layer, forwarded as is to `MLP`
    :param drop: forwarded as is to `MLP` (presumably a dropout
        setting -- confirm against `MLP`)
    :param norm_mode: value passed as `mode` to the `norm_index()`
        method of the selected NAG level
    """

    def __init__(
            self, dims, level=0, activation=nn.LeakyReLU(), norm=BatchNorm,
            drop=None, norm_mode='graph'):
        super().__init__()

        self.level = level
        self.mlp = MLP(dims, activation=activation, norm=norm, drop=drop)
        self.norm_mode = norm_mode

    @property
    def out_dim(self):
        """Output dimension, as reported by the underlying `MLP`."""
        return self.mlp.out_dim

    def forward(self, nag):
        """Run the MLP on the features of `nag[self.level]`.

        :param nag: NAG holding strictly more than `self.level` levels
        :return: the MLP output for level 1; the MLP output max-pooled
            into `nag[1].num_nodes` rows for level 0; the MLP output
            indexed by `nag.get_super_index(self.level, low=1)` for any
            other level
        """
        assert isinstance(nag, NAG), \
            f"Expected a NAG, got {type(nag).__name__}"
        assert nag.num_levels > self.level, \
            f"NAG has {nag.num_levels} levels, cannot access level " \
            f"{self.level}"

        # Compute node features from the handcrafted features. The
        # normalization index must come from the same level as the
        # features: `self.level` is the only level attribute set in
        # `__init__` (the former `self.i_level` did not exist and made
        # every forward call raise an AttributeError)
        norm_index = nag[self.level].norm_index(mode=self.norm_mode)
        x = self.mlp(nag[self.level].x, batch=norm_index)

        # If node level is 1, features can be returned without any
        # pooling or indexing
        if self.level == 1:
            return x

        # If node level is 0, max-pool the features following
        # `super_index`, with one output row per level-1 node
        if self.level == 0:
            return scatter(
                x, nag[0].super_index, dim=0, dim_size=nag[1].num_nodes,
                reduce='max')

        # If node level is larger than 1, gather the features using the
        # index relating level 1 to `self.level` (presumably the parent
        # of each level-1 node -- confirm against `NAG.get_super_index`)
        super_index = nag.get_super_index(self.level, low=1)
        return x[super_index]
|
|
|