File size: 2,776 Bytes
fb9bb41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
from deeprobust.graph.defense import GCN
import pickle
import os.path as osp
from deeprobust.graph.data import Dataset
from deeprobust.graph.utils import preprocess
import os

def load_victim_model(data, model_name='gcn', device='cpu', file_path=None):
    """Load the victim model, training and saving it first if no checkpoint exists.

    Parameters
    ----------
    data : deeprobust.graph.Dataset
        graph data
    model_name : str
        victim model name, e.g. ('gcn', 'deepwalk'). Currently only 'gcn'
        is supported as victim model.
    device : str
        'cpu' or 'cuda'
    file_path : str or None
        directory holding the checkpoint. If given, the model is loaded from
        ``<file_path>/<model_name>_checkpoint``; otherwise from
        ``results/saved_models/<data.name>/<model_name>_checkpoint``.

    Returns
    -------
    GCN
        the victim model, moved to ``device`` and switched to eval mode.

    Raises
    ------
    ValueError
        if ``model_name`` is not 'gcn'.
    """
    # Explicit check rather than ``assert`` so the guard survives ``python -O``.
    if model_name != 'gcn':
        raise ValueError('Currently only support gcn as victim model...')

    if file_path is None:
        file_path = 'results/saved_models/{0}/{1}_checkpoint'.format(data.name, model_name)
    else:
        file_path = osp.join(file_path, '{}_checkpoint'.format(model_name))

    if not osp.exists(file_path):
        # No checkpoint yet: train a fresh model and save it in the directory
        # we just looked in, so the next call finds it.
        return train_victim_model(data=data, model_name=model_name,
                                  device=device, file_path=osp.dirname(file_path))

    # Same architecture/hyper-parameters as train_victim_model, so the saved
    # state_dict matches this module's parameter shapes.
    victim_model = GCN(nfeat=data.features.shape[1], nclass=data.labels.max().item() + 1,
                       nhid=16, dropout=0.5, weight_decay=5e-4, device=device)

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    victim_model.load_state_dict(torch.load(file_path, map_location=device))
    victim_model.to(device)
    victim_model.eval()
    return victim_model

def train_victim_model(data, model_name='gcn', file_path=None, device='cpu'):
    """Train the victim model (target classifier) and save its weights.

    Note that the attacker can only issue black-box queries to this model.

    Parameters
    ----------
    data : deeprobust.graph.Dataset
        graph data; its ``adj``, ``features``, ``labels`` and the
        ``idx_train`` / ``idx_val`` splits are used for training.
    model_name : str
        used to name the checkpoint file ``<model_name>_checkpoint``.
    file_path : str or None
        directory the checkpoint is written to. Defaults to
        ``results/saved_models/<data.name>/``; created if it does not exist.
    device : str
        'cpu' or 'cuda'

    Returns
    -------
    GCN
        the trained model, switched to eval mode.
    """
    if file_path is None:
        file_path = 'results/saved_models/%s/' % data.name

    idx_train, idx_val = data.idx_train, data.idx_val
    # preprocess_adj=False: the adjacency is left unnormalized here.
    adj, features, labels = preprocess(data.adj, data.features, data.labels,
                                       preprocess_adj=False)

    # Setup victim model
    victim_model = GCN(nfeat=features.shape[1], nclass=labels.max().item() + 1,
                       nhid=16, dropout=0.5, weight_decay=5e-4, device=device)

    adj = adj.to(device)
    features = features.to(device)
    labels = labels.to(device)
    victim_model = victim_model.to(device)
    victim_model.fit(features, adj, labels, idx_train, idx_val)

    # Save the model. os.makedirs avoids building a shell command from the
    # path (injection/quoting issues, non-portable `mkdir -p`) and, with
    # exist_ok=True, the check-then-create race. An empty path means the
    # current directory, which always exists.
    if file_path:
        os.makedirs(file_path, exist_ok=True)
    torch.save(victim_model.state_dict(), osp.join(file_path, model_name + '_checkpoint'))
    victim_model.eval()
    return victim_model