CGSCORE / examples /graph /test_prbcd.py
Yaning1001's picture
Add files using upload-large-folder tool
c91d7b1 verified
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.utils import to_undirected
import torch_geometric.transforms as T
import argparse
import torch
import deeprobust.graph.utils as utils
from deeprobust.graph.global_attack import PRBCD
parser = argparse.ArgumentParser()
parser.add_argument('--ptb_rate', type=float, default=0.1, help='perturbation rate.')
args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = PygNodePropPredDataset(name='ogbn-arxiv')
dataset.transform = T.NormalizeFeatures()
data = dataset[0]
if not hasattr(data, 'train_mask'):
utils.add_mask(data, dataset)
data.edge_index = to_undirected(data.edge_index, data.num_nodes)
agent = PRBCD(data, device=device)
edge_index, edge_weight = agent.attack(ptb_rate=args.ptb_rate)