| '''one transformation with multiple propagation''' |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
| import torch |
| import torch.optim as optim |
| from torch.nn.parameter import Parameter |
| from torch.nn.modules.module import Module |
| from deeprobust.graph import utils |
| from copy import deepcopy |
| from sklearn.metrics import f1_score |
| from torch.nn import init |
| import torch_sparse |
|
|
class GraphConvolution(Module):
    """Simple GCN layer, similar to https://github.com/tkipf/pygcn

    Computes ``adj @ (input @ weight)`` and adds ``bias`` when one exists.

    Parameters
    ----------
    in_features : int
        Number of input features (rows of ``weight``).
    out_features : int
        Number of output features (columns of ``weight``).
    with_bias : bool
        If True (default) a learnable bias of size ``out_features`` is
        created. If False, ``self.bias`` is ``None`` and nothing is added in
        ``forward``; code reading ``.bias`` directly must handle ``None``.
    """

    def __init__(self, in_features, out_features, with_bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if with_bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            # Register the name so `self.bias` exists and is None.
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw weight (and bias, if present) uniformly from [-stdv, stdv]."""
        # weight.size(0) is in_features, so stdv = 1 / sqrt(in_features).
        stdv = 1. / math.sqrt(self.weight.size(0))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        """ Graph Convolutional Layer forward function

        Parameters
        ----------
        input :
            node features; multiplied with ``weight`` via ``torch.spmm`` when
            sparse and ``torch.mm`` otherwise.
        adj :
            adjacency; a ``torch_sparse.SparseTensor`` is applied with
            ``torch_sparse.matmul``, anything else with ``torch.spmm``.

        Returns
        -------
        torch.Tensor
            ``adj @ (input @ weight)``, plus ``bias`` when the layer has one.
        """
        if input.is_sparse:
            support = torch.spmm(input, self.weight)
        else:
            support = torch.mm(input, self.weight)

        if isinstance(adj, torch_sparse.SparseTensor):
            output = torch_sparse.matmul(adj, support)
        else:
            output = torch.spmm(adj, support)

        if self.bias is None:
            return output
        return output + self.bias

    def __repr__(self):
        return '{} ({} -> {})'.format(
            self.__class__.__name__, self.in_features, self.out_features)


class SGC(nn.Module):
    """One linear transformation followed by repeated adjacency propagation.

    The model owns a single ``GraphConvolution`` layer but only uses its
    ``weight`` and ``bias``: features are multiplied by the weight once, the
    adjacency is applied several times, then the bias (if any) is added and
    the result goes through ``log_softmax`` (single-label) or ``sigmoid``
    (multi-label).

    Parameters
    ----------
    nfeat : int
        Input feature dimension.
    nhid : int
        Not used by this model; kept so the signature stays unchanged.
    nclass : int
        Number of output classes.
    nlayers : int
        Number of propagation steps performed by ``forward``.
    dropout : float
        Stored on the instance; not applied in any forward pass here.
    lr : float
        Adam learning rate used in training.
    weight_decay : float
        Adam weight decay; forced to 0 when ``with_relu`` is False.
    with_relu : bool
        Stored on the instance; only affects ``weight_decay`` (see above).
    with_bias : bool
        Forwarded to ``GraphConvolution``; when False no bias is added.
    with_bn : bool
        Not supported: a warning is printed and ``self.with_bn`` is False.
    device :
        Device used for tensors created during training/prediction.
        Must not be None.
    """

    def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, lr=0.01, weight_decay=5e-4,
                 with_relu=True, with_bias=True, with_bn=False, device=None):
        super().__init__()

        assert device is not None, "Please specify 'device'!"
        self.device = device
        self.nfeat = nfeat
        self.nclass = nclass

        # Only the layer's parameters are used; propagation happens in the
        # forward* methods below.
        self.conv = GraphConvolution(nfeat, nclass, with_bias=with_bias)

        self.nlayers = nlayers
        self.dropout = dropout
        self.lr = lr
        self.weight_decay = weight_decay if with_relu else 0
        self.with_relu = with_relu
        if with_bn:
            print('Warning: SGC does not have bn!!!')
        self.with_bn = False
        self.with_bias = with_bias
        self.output = None
        self.best_model = None
        self.best_output = None
        self.adj_norm = None
        self.features = None
        self.multi_label = None

    def _readout(self, x):
        """Add the bias (when present) and apply the output activation."""
        bias = self.conv.bias
        if bias is not None:
            x = x + bias
        if self.multi_label:
            return torch.sigmoid(x)
        return F.log_softmax(x, dim=1)

    def forward(self, x, adj):
        """Transform ``x`` once, then apply ``adj`` ``self.nlayers`` times.

        ``x`` is multiplied with ``torch.mm`` and ``adj`` with ``torch.spmm``.
        Returns sigmoid outputs when ``self.multi_label`` is truthy and
        log-softmax over dim 1 otherwise.
        """
        x = torch.mm(x, self.conv.weight)
        for _ in range(self.nlayers):
            x = torch.spmm(adj, x)
        return self._readout(x)

    def forward_sampler(self, x, adjs):
        """Forward pass over sampled adjacencies.

        ``adjs`` is iterated as 3-tuples whose first element is applied with
        ``torch_sparse.matmul``; the number of propagation steps is
        ``len(adjs)``, not ``self.nlayers``.
        """
        x = torch.mm(x, self.conv.weight)
        for adj, _, _ in adjs:
            x = torch_sparse.matmul(adj, x)
        return self._readout(x)

    def forward_sampler_syn(self, x, adjs):
        """Forward pass over a sequence of adjacencies.

        Each entry of ``adjs`` that is a ``torch.Tensor`` is applied with
        ``@``; any other entry is applied with ``torch_sparse.matmul``.
        """
        x = torch.mm(x, self.conv.weight)
        for adj in adjs:
            if isinstance(adj, torch.Tensor):
                x = adj @ x
            else:
                x = torch_sparse.matmul(adj, x)
        return self._readout(x)

    def initialize(self):
        """Initialize parameters of GCN.
        """
        # Batch norm is never enabled for this model (with_bn is forced to
        # False in __init__), so the conv layer is the only thing to reset.
        self.conv.reset_parameters()

    def fit_with_val(self, features, adj, labels, data, train_iters=200, initialize=True, verbose=False, normalize=True, patience=None, noval=False, **kwargs):
        """Train on (features, adj, labels) and keep the best validation weights.

        Parameters
        ----------
        features, adj, labels :
            training inputs. If ``adj`` is a ``torch.Tensor`` all three are
            moved to ``self.device`` with ``.to``; otherwise they are
            converted with ``utils.to_tensor``.
        data :
            full data class; validation reads ``feat_val``/``adj_val`` (when
            ``noval``) or ``feat_full``/``adj_full``/``idx_val``, plus
            ``labels_val``.
        train_iters : int
            number of optimisation steps.
        initialize : bool
            reset the parameters before training.
        verbose : bool
            print progress messages.
        normalize : bool
            normalize ``adj`` with ``utils.normalize_adj_tensor`` first.
        patience :
            not used.
        noval : bool
            validate on ``data.feat_val``/``data.adj_val`` instead of the
            full graph indexed by ``data.idx_val``.
        **kwargs :
            ``feat_norm`` (truthy) row-normalizes ``features - features.min()``.
        """
        if initialize:
            self.initialize()

        if isinstance(adj, torch.Tensor):
            features = features.to(self.device)
            adj = adj.to(self.device)
            labels = labels.to(self.device)
        else:
            features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device)

        if normalize:
            if utils.is_sparse_tensor(adj):
                adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
            else:
                adj_norm = utils.normalize_adj_tensor(adj)
        else:
            adj_norm = adj

        if kwargs.get('feat_norm'):
            from utils import row_normalize_tensor
            features = row_normalize_tensor(features - features.min())

        self.adj_norm = adj_norm
        self.features = features

        # Labels with more than one dimension are treated as multi-label.
        if len(labels.shape) > 1:
            self.multi_label = True
            self.loss = torch.nn.BCELoss()
            labels = labels.float()
        else:
            self.multi_label = False
            self.loss = F.nll_loss
        self.labels = labels

        self._train_with_val(labels, data, train_iters, verbose, adj_val=bool(noval))

    def _train_with_val(self, labels, data, train_iters, verbose, adj_val=False):
        """Run the training loop and restore the best-validation weights.

        With ``adj_val`` the model is validated on ``data.feat_val`` /
        ``data.adj_val`` and the whole output is compared with
        ``data.labels_val``; otherwise it is validated on ``data.feat_full`` /
        ``data.adj_full`` and the output is indexed by ``data.idx_val``.
        """
        if adj_val:
            feat_full, adj_full = data.feat_val, data.adj_val
        else:
            feat_full, adj_full = data.feat_full, data.adj_full

        feat_full, adj_full = utils.to_tensor(feat_full, adj_full, device=self.device)
        adj_full_norm = utils.normalize_adj_tensor(adj_full, sparse=True)
        labels_val = torch.LongTensor(data.labels_val).to(self.device)

        if verbose:
            print('=== training gcn model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        best_acc_val = 0
        # Start from the current weights so there is always something to
        # restore, even if no epoch ever improves on best_acc_val.
        weights = deepcopy(self.state_dict())

        for i in range(train_iters):
            if i == train_iters // 2:
                # Halfway through: continue with a fresh Adam at lr * 0.1.
                lr = self.lr * 0.1
                optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=self.weight_decay)

            self.train()
            optimizer.zero_grad()
            output = self.forward(self.features, self.adj_norm)
            loss_train = self.loss(output, labels)
            loss_train.backward()
            optimizer.step()

            if verbose and i % 100 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

            with torch.no_grad():
                self.eval()
                output = self.forward(feat_full, adj_full_norm)
                if adj_val:
                    acc_val = utils.accuracy(output, labels_val)
                else:
                    acc_val = utils.accuracy(output[data.idx_val], labels_val)

                if acc_val > best_acc_val:
                    best_acc_val = acc_val
                    self.output = output
                    weights = deepcopy(self.state_dict())

        if verbose:
            print('=== picking the best model according to the performance on validation ===')
        self.load_state_dict(weights)

    def test(self, idx_test):
        """Evaluate GCN performance on test set.
        Parameters
        ----------
        idx_test :
            node testing indices

        Returns
        -------
        float
            accuracy on ``idx_test``, computed from ``self.predict()`` and the
            labels stored by ``fit_with_val``.
        """
        self.eval()
        output = self.predict()

        loss_test = F.nll_loss(output[idx_test], self.labels[idx_test])
        acc_test = utils.accuracy(output[idx_test], self.labels[idx_test])
        print("Test set results:",
              "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))
        return acc_test.item()

    @torch.no_grad()
    def predict(self, features=None, adj=None):
        """By default, the inputs should be unnormalized adjacency
        Parameters
        ----------
        features :
            node features. If `features` and `adj` are not given, this function will use previous stored `features` and `adj` from training to make predictions.
        adj :
            adjacency matrix. If `features` and `adj` are not given, this function will use previous stored `features` and `adj` from training to make predictions.
        Returns
        -------
        torch.FloatTensor
            output of ``forward`` (log probabilities, or sigmoid outputs in
            the multi-label case)

        Notes
        -----
        When inputs are given they replace the stored ``self.features`` and
        ``self.adj_norm`` (the latter after normalization).
        """
        self.eval()
        if features is None and adj is None:
            return self.forward(self.features, self.adj_norm)

        if not isinstance(adj, torch.Tensor):
            features, adj = utils.to_tensor(features, adj, device=self.device)

        self.features = features
        if utils.is_sparse_tensor(adj):
            self.adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
        else:
            self.adj_norm = utils.normalize_adj_tensor(adj)
        return self.forward(self.features, self.adj_norm)

    @torch.no_grad()
    def predict_unnorm(self, features=None, adj=None):
        """Like ``predict`` but uses ``adj`` as given, without normalizing it.

        When inputs are given they replace the stored ``self.features`` and
        ``self.adj_norm``; otherwise the stored ones are used.
        """
        self.eval()
        if features is None and adj is None:
            return self.forward(self.features, self.adj_norm)

        if not isinstance(adj, torch.Tensor):
            features, adj = utils.to_tensor(features, adj, device=self.device)

        self.features = features
        self.adj_norm = adj
        return self.forward(self.features, self.adj_norm)
|
|
|
|
|
|