import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from itertools import product


class PGE(nn.Module):
    """Parametrizes the synthetic graph structure: an MLP scores every
    ordered node pair from the concatenated pair features, and the scores
    are assembled into a dense, symmetric adjacency with an empty diagonal."""

    def __init__(self, nfeat, nnodes, nhid=128, nlayers=3, device=None, args=None):
        super(PGE, self).__init__()
        # Dataset-specific capacity overrides.
        if args.dataset in ['ogbn-arxiv', 'arxiv', 'flickr']:
            nhid = 256
        if args.dataset in ['reddit']:
            nhid = 256
            if args.reduction_rate == 0.01:
                nhid = 128
                nlayers = 3

        # MLP edge scorer: (2 * nfeat) -> nhid -> ... -> nhid -> 1,
        # with BatchNorm and ReLU between hidden layers.
        self.layers = nn.ModuleList([])
        self.layers.append(nn.Linear(nfeat * 2, nhid))
        self.bns = torch.nn.ModuleList()
        self.bns.append(nn.BatchNorm1d(nhid))
        for i in range(nlayers - 2):
            self.layers.append(nn.Linear(nhid, nhid))
            self.bns.append(nn.BatchNorm1d(nhid))
        self.layers.append(nn.Linear(nhid, 1))

        # Precompute all nnodes^2 ordered node pairs as a (2, nnodes^2) index array.
        edge_index = np.array(list(product(range(nnodes), range(nnodes))))
        self.edge_index = edge_index.T
        self.nnodes = nnodes
        self.device = device
        self.reset_parameters()
        self.cnt = 0
        self.args = args

    def forward(self, x, inference=False):
        if self.args.dataset == 'reddit' and self.args.reduction_rate >= 0.01:
            # The full pair list is too large to embed in one pass here,
            # so score the edges in chunks to bound peak memory.
            edge_index = self.edge_index
            n_part = 5
            splits = np.array_split(np.arange(edge_index.shape[1]), n_part)
            edge_embed = []
            for idx in splits:
                tmp_edge_embed = torch.cat([x[edge_index[0][idx]],
                                            x[edge_index[1][idx]]], dim=1)
                for ix, layer in enumerate(self.layers):
                    tmp_edge_embed = layer(tmp_edge_embed)
                    if ix != len(self.layers) - 1:
                        tmp_edge_embed = self.bns[ix](tmp_edge_embed)
                        tmp_edge_embed = F.relu(tmp_edge_embed)
                edge_embed.append(tmp_edge_embed)
            edge_embed = torch.cat(edge_embed)
        else:
            # Score all node pairs in a single pass.
            edge_index = self.edge_index
            edge_embed = torch.cat([x[edge_index[0]],
                                    x[edge_index[1]]], dim=1)
            for ix, layer in enumerate(self.layers):
                edge_embed = layer(edge_embed)
                if ix != len(self.layers) - 1:
                    edge_embed = self.bns[ix](edge_embed)
                    edge_embed = F.relu(edge_embed)

        adj = edge_embed.reshape(self.nnodes, self.nnodes)

        # Symmetrize, map scores into (0, 1), and zero the diagonal
        # so the synthetic graph has no self-loops.
        adj = (adj + adj.T) / 2
        adj = torch.sigmoid(adj)
        adj = adj - torch.diag(torch.diag(adj, 0))
        return adj

    @torch.no_grad()
    def inference(self, x):
        adj_syn = self.forward(x, inference=True)
        return adj_syn

    def reset_parameters(self):
        def weight_reset(m):
            if isinstance(m, nn.Linear):
                m.reset_parameters()
            if isinstance(m, nn.BatchNorm1d):
                m.reset_parameters()
        self.apply(weight_reset)
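

# A minimal usage sketch (not part of the original module): builds a PGE for a
# small synthetic graph and materializes its dense adjacency. The `args`
# namespace below is a hypothetical stand-in for the project's argparse
# config; `dataset` and `reduction_rate` are the only fields PGE reads.
if __name__ == '__main__':
    from types import SimpleNamespace

    args = SimpleNamespace(dataset='cora', reduction_rate=0.5)  # hypothetical config
    nnodes, nfeat = 50, 16
    x = torch.randn(nnodes, nfeat)  # synthetic node features

    pge = PGE(nfeat=nfeat, nnodes=nnodes, device='cpu', args=args)
    adj = pge(x)  # differentiable adjacency, usable in a training loop

    assert adj.shape == (nnodes, nnodes)
    assert torch.allclose(adj, adj.T)        # symmetric by construction
    assert adj.diagonal().abs().max() == 0   # no self-loops
    print(adj.shape, adj.min().item(), adj.max().item())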