"""Run a PRBCD global structure attack on the ogbn-arxiv graph.

Usage::

    python <this_script>.py --ptb_rate 0.1

The script loads ogbn-arxiv through OGB and applies PyG's ``NormalizeFeatures``
transform. It makes the edge set undirected, then asks DeepRobust's ``PRBCD``
attacker for a perturbed graph at the requested perturbation rate.
"""
import argparse

import torch
import torch_geometric.transforms as T
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.utils import to_undirected
import deeprobust.graph.utils as utils
from deeprobust.graph.global_attack import PRBCD


def parse_args(argv=None):
    """Parse the command-line options for this script.

    Args:
        argv: Optional list of argument strings; ``None`` means use
            ``sys.argv[1:]`` (argparse's default).

    Returns:
        An ``argparse.Namespace`` with a float ``ptb_rate`` attribute
        (default 0.1).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--ptb_rate', type=float, default=0.1,
                        help='perturbation rate.')
    return parser.parse_args(argv)


def main(argv=None):
    """Load ogbn-arxiv, run the PRBCD attack and return the attacked graph.

    Args:
        argv: Optional argument list forwarded to ``parse_args``.

    Returns:
        The ``(edge_index, edge_weight)`` pair produced by ``PRBCD.attack``.
    """
    args = parse_args(argv)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Set the transform before indexing so that dataset[0] comes back with
    # NormalizeFeatures already applied.
    dataset = PygNodePropPredDataset(name='ogbn-arxiv')
    dataset.transform = T.NormalizeFeatures()
    data = dataset[0]

    # NOTE(review): presumably add_mask attaches train/val/test masks built
    # from the dataset's split when the graph ships without them -- confirm
    # against deeprobust.graph.utils.add_mask.
    if not hasattr(data, 'train_mask'):
        utils.add_mask(data, dataset)

    # Pass num_nodes by keyword. In PyG 2.x the second positional parameter
    # of to_undirected is edge_attr, so a positional int only works through a
    # backward-compatibility shim.
    data.edge_index = to_undirected(data.edge_index, num_nodes=data.num_nodes)

    agent = PRBCD(data, device=device)
    edge_index, edge_weight = agent.attack(ptb_rate=args.ptb_rate)
    return edge_index, edge_weight


if __name__ == "__main__":
    main()