| | from torch_scatter import scatter |
| | from torch import nn |
| | from src.data import NAG |
| | from src.nn import MLP, BatchNorm |
| |
|
| |
|
| | __all__ = ['NodeMLP'] |
| |
|
| |
|
class NodeMLP(nn.Module):
    """Simple MLP on the handcrafted features of the level-`level`
    nodes in a NAG. This is used as a baseline to test how expressive
    handcrafted features are.

    Whatever the input `level`, the output is expressed on level-1
    nodes: level-0 features are max-pooled to their level-1 parents,
    and features from levels above 1 are presumably broadcast down to
    level-1 nodes via the hierarchical super-index (TODO confirm
    `NAG.get_super_index` semantics against its definition).

    :param dims: list of channel sizes passed to the underlying MLP
    :param level: NAG level whose handcrafted features are processed
    :param activation: activation module used in the MLP
    :param norm: normalization constructor used in the MLP
    :param drop: dropout specification forwarded to the MLP
    :param norm_mode: mode passed to `Data.norm_index` to group nodes
        for normalization
    """

    def __init__(
            self, dims, level=0, activation=nn.LeakyReLU(), norm=BatchNorm,
            drop=None, norm_mode='graph'):
        super().__init__()
        self.level = level
        self.mlp = MLP(dims, activation=activation, norm=norm, drop=drop)
        self.norm_mode = norm_mode

    @property
    def out_dim(self):
        # Output feature dimension, as exposed by the wrapped MLP
        return self.mlp.out_dim

    def forward(self, nag):
        assert isinstance(nag, NAG)
        assert nag.num_levels > self.level

        # Apply the MLP on the handcrafted features of the target
        # level, with normalization groups given by `norm_mode`.
        # NOTE(fix): was `self.i_level`, an attribute that is never
        # set anywhere on this module (AttributeError at runtime);
        # `self.level` is the attribute assigned in __init__ and used
        # on the very next line
        norm_index = nag[self.level].norm_index(mode=self.norm_mode)
        x = self.mlp(nag[self.level].x, batch=norm_index)

        # Level-1 features are already expressed on level-1 nodes
        if self.level == 1:
            return x

        # Level-0 features are max-pooled to their level-1 parents
        if self.level == 0:
            return scatter(
                x, nag[0].super_index, dim=0, dim_size=nag[1].num_nodes,
                reduce='max')

        # Levels above 1: index features down to level-1 nodes
        super_index = nag.get_super_index(self.level, low=1)
        return x[super_index]
| |
|