"""multiple transformaiton and multiple propagation""" import torch.nn as nn import torch.nn.functional as F import math import torch import torch.optim as optim from torch.nn.parameter import Parameter from torch.nn.modules.module import Module from deeprobust.graph import utils from copy import deepcopy from sklearn.metrics import f1_score from torch.nn import init import torch_sparse class SGC(nn.Module): def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, lr=0.01, weight_decay=5e-4, ntrans=2, with_bias=True, with_bn=False, device=None): """nlayers indicates the number of propagations""" super(SGC, self).__init__() assert device is not None, "Please specify 'device'!" self.device = device self.nfeat = nfeat self.nclass = nclass self.layers = nn.ModuleList([]) if ntrans == 1: self.layers.append(MyLinear(nfeat, nclass)) else: self.layers.append(MyLinear(nfeat, nhid)) if with_bn: self.bns = torch.nn.ModuleList() self.bns.append(nn.BatchNorm1d(nhid)) for i in range(ntrans-2): if with_bn: self.bns.append(nn.BatchNorm1d(nhid)) self.layers.append(MyLinear(nhid, nhid)) self.layers.append(MyLinear(nhid, nclass)) self.nlayers = nlayers self.dropout = dropout self.lr = lr self.with_bn = with_bn self.with_bias = with_bias self.weight_decay = weight_decay self.output = None self.best_model = None self.best_output = None self.adj_norm = None self.features = None self.multi_label = None def forward(self, x, adj): for ix, layer in enumerate(self.layers): x = layer(x) if ix != len(self.layers) - 1: x = self.bns[ix](x) if self.with_bn else x x = F.relu(x) x = F.dropout(x, self.dropout, training=self.training) for i in range(self.nlayers): x = torch.spmm(adj, x) if self.multi_label: return torch.sigmoid(x) else: return F.log_softmax(x, dim=1) def forward_sampler(self, x, adjs): for ix, layer in enumerate(self.layers): x = layer(x) if ix != len(self.layers) - 1: x = self.bns[ix](x) if self.with_bn else x x = F.relu(x) x = F.dropout(x, self.dropout, training=self.training) for ix, (adj, _, size) in enumerate(adjs): # x_target = x[: size[1]] # x = self.layers[ix]((x, x_target), edge_index) # adj = adj.to(self.device) x = torch_sparse.matmul(adj, x) if self.multi_label: return torch.sigmoid(x) else: return F.log_softmax(x, dim=1) def forward_sampler_syn(self, x, adjs): for ix, layer in enumerate(self.layers): x = layer(x) if ix != len(self.layers) - 1: x = self.bns[ix](x) if self.with_bn else x x = F.relu(x) x = F.dropout(x, self.dropout, training=self.training) for ix, (adj) in enumerate(adjs): if type(adj) == torch.Tensor: x = adj @ x else: x = torch_sparse.matmul(adj, x) if self.multi_label: return torch.sigmoid(x) else: return F.log_softmax(x, dim=1) def initialize(self): """Initialize parameters of GCN. 
""" for layer in self.layers: layer.reset_parameters() if self.with_bn: for bn in self.bns: bn.reset_parameters() def fit_with_val(self, features, adj, labels, data, train_iters=200, initialize=True, verbose=False, normalize=True, patience=None, noval=False, **kwargs): '''data: full data class''' if initialize: self.initialize() # features, adj, labels = data.feat_train, data.adj_train, data.labels_train if type(adj) is not torch.Tensor: features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device) else: features = features.to(self.device) adj = adj.to(self.device) labels = labels.to(self.device) if normalize: if utils.is_sparse_tensor(adj): adj_norm = utils.normalize_adj_tensor(adj, sparse=True) else: adj_norm = utils.normalize_adj_tensor(adj) else: adj_norm = adj if 'feat_norm' in kwargs and kwargs['feat_norm']: from utils import row_normalize_tensor features = row_normalize_tensor(features-features.min()) self.adj_norm = adj_norm self.features = features if len(labels.shape) > 1: self.multi_label = True self.loss = torch.nn.BCELoss() else: self.multi_label = False self.loss = F.nll_loss labels = labels.float() if self.multi_label else labels self.labels = labels if noval: self._train_with_val(labels, data, train_iters, verbose, adj_val=True) else: self._train_with_val(labels, data, train_iters, verbose) def _train_with_val(self, labels, data, train_iters, verbose, adj_val=False): if adj_val: feat_full, adj_full = data.feat_val, data.adj_val else: feat_full, adj_full = data.feat_full, data.adj_full feat_full, adj_full = utils.to_tensor(feat_full, adj_full, device=self.device) adj_full_norm = utils.normalize_adj_tensor(adj_full, sparse=True) labels_val = torch.LongTensor(data.labels_val).to(self.device) if verbose: print('=== training gcn model ===') optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay) best_acc_val = 0 for i in range(train_iters): if i == train_iters // 2: lr = self.lr*0.1 optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=self.weight_decay) self.train() optimizer.zero_grad() output = self.forward(self.features, self.adj_norm) loss_train = self.loss(output, labels) loss_train.backward() optimizer.step() if verbose and i % 100 == 0: print('Epoch {}, training loss: {}'.format(i, loss_train.item())) with torch.no_grad(): self.eval() output = self.forward(feat_full, adj_full_norm) if adj_val: loss_val = F.nll_loss(output, labels_val) acc_val = utils.accuracy(output, labels_val) else: loss_val = F.nll_loss(output[data.idx_val], labels_val) acc_val = utils.accuracy(output[data.idx_val], labels_val) if acc_val > best_acc_val: best_acc_val = acc_val self.output = output weights = deepcopy(self.state_dict()) if verbose: print('=== picking the best model according to the performance on validation ===') self.load_state_dict(weights) def test(self, idx_test): """Evaluate GCN performance on test set. Parameters ---------- idx_test : node testing indices """ self.eval() output = self.predict() # output = self.output loss_test = F.nll_loss(output[idx_test], self.labels[idx_test]) acc_test = utils.accuracy(output[idx_test], self.labels[idx_test]) print("Test set results:", "loss= {:.4f}".format(loss_test.item()), "accuracy= {:.4f}".format(acc_test.item())) return acc_test.item() @torch.no_grad() def predict(self, features=None, adj=None): """By default, the inputs should be unnormalized adjacency Parameters ---------- features : node features. 
If `features` and `adj` are not given, this function will use previous stored `features` and `adj` from training to make predictions. adj : adjcency matrix. If `features` and `adj` are not given, this function will use previous stored `features` and `adj` from training to make predictions. Returns ------- torch.FloatTensor output (log probabilities) of GCN """ self.eval() if features is None and adj is None: return self.forward(self.features, self.adj_norm) else: if type(adj) is not torch.Tensor: features, adj = utils.to_tensor(features, adj, device=self.device) self.features = features if utils.is_sparse_tensor(adj): self.adj_norm = utils.normalize_adj_tensor(adj, sparse=True) else: self.adj_norm = utils.normalize_adj_tensor(adj) return self.forward(self.features, self.adj_norm) @torch.no_grad() def predict_unnorm(self, features=None, adj=None): self.eval() if features is None and adj is None: return self.forward(self.features, self.adj_norm) else: if type(adj) is not torch.Tensor: features, adj = utils.to_tensor(features, adj, device=self.device) self.features = features self.adj_norm = adj return self.forward(self.features, self.adj_norm) class MyLinear(Module): """Simple Linear layer, modified from https://github.com/tkipf/pygcn """ def __init__(self, in_features, out_features, with_bias=True): super(MyLinear, self).__init__() self.in_features = in_features self.out_features = out_features self.weight = Parameter(torch.FloatTensor(in_features, out_features)) if with_bias: self.bias = Parameter(torch.FloatTensor(out_features)) else: self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): # stdv = 1. / math.sqrt(self.weight.size(1)) stdv = 1. / math.sqrt(self.weight.T.size(1)) self.weight.data.uniform_(-stdv, stdv) if self.bias is not None: self.bias.data.uniform_(-stdv, stdv) def forward(self, input): if input.data.is_sparse: support = torch.spmm(input, self.weight) else: support = torch.mm(input, self.weight) output = support if self.bias is not None: return output + self.bias else: return output def __repr__(self): return self.__class__.__name__ + ' (' \ + str(self.in_features) + ' -> ' \ + str(self.out_features) + ')'
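
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original pipeline): builds a tiny
# random graph, row-normalizes a dense adjacency, and runs one forward pass
# through SGC. All names below (`n_nodes`, `n_feat`, `n_class`) are made up
# for this demo; real training goes through `fit_with_val` with a
# deeprobust-style data object.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    n_nodes, n_feat, n_class = 10, 16, 3
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    x = torch.randn(n_nodes, n_feat, device=device)

    # random symmetric adjacency with self-loops, then simple row normalization
    adj = (torch.rand(n_nodes, n_nodes, device=device) > 0.7).float()
    adj = ((adj + adj.t()) > 0).float() + torch.eye(n_nodes, device=device)
    adj = adj / adj.sum(1, keepdim=True)

    model = SGC(nfeat=n_feat, nhid=32, nclass=n_class, nlayers=2, ntrans=2,
                device=device).to(device)
    model.eval()  # disable dropout for the demo forward pass

    # `forward` expects a (normalized) sparse adjacency for torch.spmm
    out = model(x, adj.to_sparse())  # log-probabilities, shape [n_nodes, n_class]
    print(out.shape)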