| """ |
| Extended from https://github.com/rusty1s/pytorch_geometric/tree/master/benchmark/citation |
| """ |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
| import torch |
| import torch.optim as optim |
| from torch.nn.parameter import Parameter |
| from torch.nn.modules.module import Module |
| from deeprobust.graph import utils |
| from copy import deepcopy |
| from torch_geometric.nn import ChebConv |
|
|
class ChebNet(nn.Module):
    """ 2 Layer ChebNet based on pytorch geometric.

    Parameters
    ----------
    nfeat : int
        size of input feature dimension
    nhid : int
        number of hidden units
    nclass : int
        size of output dimension
    num_hops: int
        number of hops (Chebyshev filter size ``K``) in ChebConv
    dropout : float
        dropout rate for ChebNet
    lr : float
        learning rate for ChebNet
    weight_decay : float
        weight decay coefficient (l2 normalization) for ChebNet.
    with_bias: bool
        whether to include bias term in ChebNet weights.
    device: str
        'cpu' or 'cuda'.

    Examples
    --------
	We can first load dataset and then train ChebNet.

    >>> from deeprobust.graph.data import Dataset
    >>> from deeprobust.graph.defense import ChebNet
    >>> data = Dataset(root='/tmp/', name='cora')
    >>> adj, features, labels = data.adj, data.features, data.labels
    >>> idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
    >>> cheby = ChebNet(nfeat=features.shape[1],
              nhid=16, num_hops=3,
              nclass=labels.max().item() + 1,
              dropout=0.5, device='cpu')
    >>> cheby = cheby.to('cpu')
    >>> pyg_data = Dpr2Pyg(data) # convert deeprobust dataset to pyg dataset
    >>> cheby.fit(pyg_data, patience=10, verbose=True) # train with earlystopping
    """

    def __init__(self, nfeat, nhid, nclass, num_hops=3, dropout=0.5, lr=0.01,
            weight_decay=5e-4, with_bias=True, device=None):

        super(ChebNet, self).__init__()

        assert device is not None, "Please specify 'device'!"
        self.device = device

        # Two-layer Chebyshev spectral graph convolution stack.
        self.conv1 = ChebConv(
            nfeat,
            nhid,
            K=num_hops,
            bias=with_bias)

        self.conv2 = ChebConv(
            nhid,
            nclass,
            K=num_hops,
            bias=with_bias)

        self.dropout = dropout
        self.weight_decay = weight_decay
        self.lr = lr
        self.output = None  # log-probabilities of the best model seen during fit()
        self.best_model = None
        self.best_output = None

    def forward(self, data):
        """Return per-node log-probabilities for a pyg ``Data`` object."""
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

    def initialize(self):
        """Initialize parameters of ChebNet.
        """
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()

    def fit(self, pyg_data, train_iters=200, initialize=True, verbose=False, patience=500, **kwargs):
        """Train the ChebNet model, when idx_val is not None, pick the best model
        according to the validation loss.

        Parameters
        ----------
        pyg_data :
            pytorch geometric dataset object
        train_iters : int
            number of training epochs
        initialize : bool
            whether to initialize parameters before training
        verbose : bool
            whether to show verbose logs
        patience : int
            patience for early stopping, only valid when `idx_val` is given
        """
        # BUG FIX: recent torch_geometric versions removed ChebConv's
        # `.weight` attribute (weights now live in `.lins`), so
        # `self.conv1.weight.device` raises AttributeError. Infer the device
        # from the module's own parameters instead.
        self.device = next(self.parameters()).device
        if initialize:
            self.initialize()

        self.data = pyg_data[0].to(self.device)

        self.train_with_early_stopping(train_iters, patience, verbose)

    def train_with_early_stopping(self, train_iters, patience, verbose):
        """early stopping based on the validation loss
        """
        if verbose:
            print('=== training ChebNet model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        labels = self.data.y
        train_mask, val_mask = self.data.train_mask, self.data.val_mask

        early_stopping = patience
        # BUG FIX: start from +inf instead of the magic value 100, so an
        # initial validation loss above 100 cannot prevent any checkpoint
        # from ever being saved.
        best_loss_val = float('inf')
        # BUG FIX: snapshot the initial weights so `weights` (and `i`) are
        # always bound, even when train_iters == 0 or the validation loss
        # never improves.
        weights = deepcopy(self.state_dict())
        i = -1

        for i in range(train_iters):
            self.train()
            optimizer.zero_grad()
            output = self.forward(self.data)

            loss_train = F.nll_loss(output[train_mask], labels[train_mask])
            loss_train.backward()
            optimizer.step()

            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

            # Validation pass in eval mode (disables dropout).
            self.eval()
            output = self.forward(self.data)
            loss_val = F.nll_loss(output[val_mask], labels[val_mask])

            if best_loss_val > loss_val:
                # Store as a plain float so the comparison value does not
                # keep the autograd graph alive across epochs.
                best_loss_val = float(loss_val)
                self.output = output
                weights = deepcopy(self.state_dict())
                patience = early_stopping  # reset the countdown on improvement
            else:
                patience -= 1
                if i > early_stopping and patience <= 0:
                    break

        if verbose:
            print('=== early stopping at {0}, loss_val = {1} ==='.format(i, best_loss_val) )
        self.load_state_dict(weights)

    def test(self):
        """Evaluate ChebNet performance on the test set stored in ``self.data``.

        Returns
        -------
        float
            accuracy on the test nodes.
        """
        self.eval()
        test_mask = self.data.test_mask
        labels = self.data.y
        output = self.forward(self.data)

        loss_test = F.nll_loss(output[test_mask], labels[test_mask])
        acc_test = utils.accuracy(output[test_mask], labels[test_mask])
        print("Test set results:",
              "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))
        return acc_test.item()

    def predict(self):
        """
        Returns
        -------
        torch.FloatTensor
            output (log probabilities) of ChebNet
        """
        self.eval()
        return self.forward(self.data)
|
|
|
|
|
|
if __name__ == "__main__":
    from deeprobust.graph.data import Dataset, Dpr2Pyg

    # Quick end-to-end check: load cora, train a ChebNet on CPU with early
    # stopping, report test accuracy, and dump the predicted log-probabilities.
    data = Dataset(root='/tmp/', name='cora')
    adj = data.adj
    features = data.features
    labels = data.labels
    idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test

    n_class = labels.max().item() + 1
    cheby = ChebNet(nfeat=features.shape[1],
                    nhid=16,
                    nclass=n_class,
                    dropout=0.5,
                    device='cpu').to('cpu')
    pyg_data = Dpr2Pyg(data)  # convert deeprobust dataset to pyg dataset
    cheby.fit(pyg_data, verbose=True)
    cheby.test()
    print(cheby.predict())
|
|
|
|