|
|
import scipy.sparse as sp |
|
|
import numpy as np |
|
|
import sys |
|
|
import json |
|
|
import os |
|
|
from sklearn.preprocessing import StandardScaler |
|
|
from torch_geometric.data import InMemoryDataset, Data |
|
|
import torch |
|
|
from itertools import repeat |
|
|
from torch_geometric.data import NeighborSampler |
|
|
|
|
|
class DataGraphSAINT:
    '''Datasets used in the GraphSAINT paper.

    Loads a pre-processed graph from ``data/<dataset>/`` (adjacency matrix,
    node features, role split, class map), standardizes features with
    statistics computed on the training nodes only, and exposes per-split
    views plus lazily-built class-conditional neighbor samplers.
    '''

    def __init__(self, dataset, **kwargs):
        """Load ``data/<dataset>/`` artifacts and build train/val/test views.

        Parameters
        ----------
        dataset : str
            Directory name under ``data/`` containing ``adj_full.npz``,
            ``feats.npy``, ``role.json`` and ``class_map.json``.
        **kwargs
            ``label_rate`` (float < 1): keep only that leading fraction of
            the training indices (low-label-rate experiments).
        """
        dataset_str = 'data/' + dataset + '/'
        adj_full = sp.load_npz(dataset_str + 'adj_full.npz')
        self.nnodes = adj_full.shape[0]
        if dataset == 'ogbn-arxiv':
            # ogbn-arxiv ships directed; symmetrize, then binarize edges that
            # appear in both directions.
            adj_full = adj_full + adj_full.T
            adj_full[adj_full > 1] = 1

        # Use context managers so the JSON file handles are closed promptly.
        with open(dataset_str + 'role.json', 'r') as f:
            role = json.load(f)
        idx_train = role['tr']
        idx_test = role['te']
        idx_val = role['va']

        if 'label_rate' in kwargs:
            label_rate = kwargs['label_rate']
            if label_rate < 1:
                # Sub-sample the labeled training set.
                idx_train = idx_train[:int(label_rate * len(idx_train))]

        # Induced subgraph of each split.
        self.adj_train = adj_full[np.ix_(idx_train, idx_train)]
        self.adj_val = adj_full[np.ix_(idx_val, idx_val)]
        self.adj_test = adj_full[np.ix_(idx_test, idx_test)]

        feat = np.load(dataset_str + 'feats.npy')

        # Standardize with statistics from the training nodes only, so no
        # information leaks from val/test into the scaler.
        scaler = StandardScaler()
        scaler.fit(feat[idx_train])
        feat = scaler.transform(feat)

        self.feat_train = feat[idx_train]
        self.feat_val = feat[idx_val]
        self.feat_test = feat[idx_test]

        with open(dataset_str + 'class_map.json', 'r') as f:
            class_map = json.load(f)
        labels = self.process_labels(class_map)

        self.labels_train = labels[idx_train]
        self.labels_val = labels[idx_val]
        self.labels_test = labels[idx_test]

        self.data_full = GraphData(adj_full, feat, labels, idx_train, idx_val, idx_test)
        self.class_dict = None    # lazy cache for retrieve_class
        self.class_dict2 = None   # lazy cache for retrieve_class_sampler
        self.samplers = None      # lazy per-class NeighborSamplers

        self.adj_full = adj_full
        self.feat_full = feat
        self.labels_full = labels
        self.idx_train = np.array(idx_train)
        self.idx_val = np.array(idx_val)
        self.idx_test = np.array(idx_test)

    def process_labels(self, class_map):
        """Set up the vertex label array for the output classes.

        Parameters
        ----------
        class_map : dict
            Maps node id (as a string) to either a list (multi-label
            indicator vector) or a scalar class id.

        Returns
        -------
        numpy.ndarray
            ``(nnodes, nclass)`` multi-hot array for the multi-label case,
            otherwise a 1-D integer array of class ids re-based to start at
            0. ``self.nclass`` is set as a side effect.
        """
        num_vertices = self.nnodes
        if isinstance(list(class_map.values())[0], list):
            # Multi-label: each value is already a binary indicator vector.
            num_classes = len(list(class_map.values())[0])
            self.nclass = num_classes
            class_arr = np.zeros((num_vertices, num_classes))
            for k, v in class_map.items():
                class_arr[int(k)] = v
        else:
            # Single-label. NOTE: the deprecated alias ``np.int`` was removed
            # in NumPy 1.24; use the builtin ``int`` dtype instead.
            class_arr = np.zeros(num_vertices, dtype=int)
            for k, v in class_map.items():
                class_arr[int(k)] = v
            # Shift label ids so the smallest one is 0.
            class_arr = class_arr - class_arr.min()
            self.nclass = int(class_arr.max()) + 1
        return class_arr

    def retrieve_class(self, c, num=256):
        """Return up to ``num`` random training-set positions of class ``c``."""
        if self.class_dict is None:
            # Cache one boolean mask per class over the training labels.
            self.class_dict = {}
            for i in range(self.nclass):
                self.class_dict['class_%s' % i] = (self.labels_train == i)
        idx = np.arange(len(self.labels_train))
        idx = idx[self.class_dict['class_%s' % c]]
        return np.random.permutation(idx)[:num]

    def retrieve_class_sampler(self, c, adj, transductive, num=256, args=None):
        """Sample a neighborhood batch of up to ``num`` nodes of class ``c``.

        On first use this lazily builds a per-class node-index cache and one
        ``NeighborSampler`` per class over ``adj``; subsequent calls reuse
        them.

        Parameters
        ----------
        c : int
            Class id to draw the batch from.
        adj : sparse adjacency accepted by ``NeighborSampler`` (must support
            ``adj.size(0)``).
        transductive : bool
            If True, cached indices are positions in the full graph;
            otherwise positions within the training split.
        num : int
            Maximum batch size.
        args : namespace exposing ``nlayers``, ``dataset`` and ``option``.
        """
        # Fan-out per layer. Start from a safe default so ``sizes`` is always
        # defined (the original raised NameError for nlayers > 2 or for
        # reddit/flickr with an option outside {0, 1, 2}).
        sizes = [10, 5]
        if args.nlayers == 1:
            sizes = [30]
        elif args.nlayers == 2:
            if args.dataset in ['reddit', 'flickr']:
                if args.option == 0:
                    sizes = [15, 8]
                if args.option == 1:
                    sizes = [20, 10]
                if args.option == 2:
                    sizes = [25, 10]
            else:
                sizes = [10, 5]

        if self.class_dict2 is None:
            print(sizes)
            self.class_dict2 = {}
            for i in range(self.nclass):
                if transductive:
                    idx_train = np.array(self.idx_train)
                    idx = idx_train[self.labels_train == i]
                else:
                    idx = np.arange(len(self.labels_train))[self.labels_train == i]
                self.class_dict2[i] = idx

        if self.samplers is None:
            self.samplers = []
            for i in range(self.nclass):
                node_idx = torch.LongTensor(self.class_dict2[i])
                if len(node_idx) == 0:
                    # Keep the list aligned with class ids: skipping empty
                    # classes (as the original did) shifted every later
                    # sampler onto the wrong class.
                    self.samplers.append(None)
                    continue

                self.samplers.append(NeighborSampler(adj,
                                     node_idx=node_idx,
                                     sizes=sizes, batch_size=num,
                                     num_workers=8, return_e_id=False,
                                     num_nodes=adj.size(0),
                                     shuffle=True))
        batch = np.random.permutation(self.class_dict2[c])[:num]
        out = self.samplers[c].sample(batch)
        return out
|
|
|
|
|
|
|
|
class GraphData:
    """Passive record bundling one graph: adjacency, node features, labels,
    and the train/val/test index splits. Stores everything verbatim."""

    def __init__(self, adj, features, labels, idx_train, idx_val, idx_test):
        # Single tuple-unpacking assignment; no copies, no validation.
        (self.adj, self.features, self.labels,
         self.idx_train, self.idx_val, self.idx_test) = (
            adj, features, labels, idx_train, idx_val, idx_test)
|
|
|
|
|
|
|
|
class Data2Pyg:
    """Convert a dpr-style dataset (with ``data_train``/``data_val``/
    ``data_test`` attributes) into PyG graphs moved onto ``device``."""

    def __init__(self, data, device='cuda', transform=None, **kwargs):
        # Convert each split through Dpr2Pyg and push it to the device.
        for split in ('train', 'val', 'test'):
            attr = 'data_' + split
            converted = Dpr2Pyg(getattr(data, attr), transform=transform)
            setattr(self, attr, converted[0].to(device))
        self.nclass = data.nclass
        self.nfeat = data.nfeat
        self.class_dict = None  # lazy cache built on first retrieve_class

    def retrieve_class(self, c, num=256):
        """Return at most ``num`` shuffled training indices with label ``c``."""
        if self.class_dict is None:
            # One boolean mask per class, computed once on the CPU.
            self.class_dict = {
                'class_%s' % i: (self.data_train.y == i).cpu().numpy()
                for i in range(self.nclass)
            }
        candidates = np.arange(len(self.data_train.y))
        candidates = candidates[self.class_dict['class_%s' % c]]
        return np.random.permutation(candidates)[:num]
|
|
|
|
|
|
|
|
class Dpr2Pyg(InMemoryDataset):
    """Wrap a single dpr-style graph (``adj`` / ``features`` / ``labels``)
    as a one-graph ``torch_geometric`` ``InMemoryDataset``.

    NOTE(review): ``process()`` is invoked explicitly in ``__init__`` (in
    addition to whatever the ``InMemoryDataset`` base class does on
    construction), and the collated result overwrites ``self.data`` /
    ``self.slices`` directly.
    """

    def __init__(self, dpr_data, transform=None, **kwargs):
        # `dpr_data` must expose .adj (dense or scipy sparse), .features and
        # .labels; see process() below.
        root = 'data/'
        self.dpr_data = dpr_data
        super(Dpr2Pyg, self).__init__(root, transform)
        pyg_data = self.process()
        self.data, self.slices = self.collate([pyg_data])
        # Re-assign transform after the base __init__ so it is the one
        # applied when items are fetched.
        self.transform = transform

    def process(self):
        """Build and return a PyG ``Data`` object from ``self.dpr_data``."""
        dpr_data = self.dpr_data
        # nonzero() yields the (row, col) index pair -> [2, num_edges].
        edge_index = torch.LongTensor(dpr_data.adj.nonzero())

        # Densify sparse feature matrices before tensor conversion.
        if sp.issparse(dpr_data.features):
            x = torch.FloatTensor(dpr_data.features.todense()).float()
        else:
            x = torch.FloatTensor(dpr_data.features).float()
        y = torch.LongTensor(dpr_data.labels)
        data = Data(x=x, edge_index=edge_index, y=y)
        # Masks are intentionally left unset; callers use index arrays
        # (idx_train / idx_val / idx_test) instead of boolean masks.
        data.train_mask = None
        data.val_mask = None
        data.test_mask = None
        return data

    def get(self, idx):
        """Slice graph ``idx`` out of the collated storage.

        NOTE(review): assumes a torch_geometric version where
        ``self.data.keys`` is iterable as an attribute and ``__cat_dim__``
        is the slicing hook — confirm against the pinned PyG version.
        """
        data = self.data.__class__()

        if hasattr(self.data, '__num_nodes__'):
            data.num_nodes = self.data.__num_nodes__[idx]

        for key in self.data.keys:
            item, slices = self.data[key], self.slices[key]
            # Build a full-dimension slice list, then narrow the
            # concatenation dimension to this graph's window
            # [slices[idx], slices[idx + 1]).
            s = list(repeat(slice(None), item.dim()))
            s[self.data.__cat_dim__(key, item)] = slice(slices[idx],
                                                        slices[idx + 1])
            data[key] = item[s]
        return data

    @property
    def raw_file_names(self):
        # Placeholder only — nothing is ever read from raw files (the
        # trailing Ellipsis literal is part of the original placeholder).
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def _download(self):
        # No download step: the graph comes from the in-memory dpr object.
        pass
|
|
|
|
|
|
|
|
|