clique / GCond /utils_graphsaint.py
qingy2024's picture
Upload folder using huggingface_hub
f74dd01 verified
import scipy.sparse as sp
import numpy as np
import sys
import json
import os
from sklearn.preprocessing import StandardScaler
from torch_geometric.data import InMemoryDataset, Data
import torch
from itertools import repeat
from torch_geometric.data import NeighborSampler
class DataGraphSAINT:
'''datasets used in GraphSAINT paper'''
def __init__(self, dataset, **kwargs):
dataset_str='data/'+dataset+'/'
adj_full = sp.load_npz(dataset_str+'adj_full.npz')
self.nnodes = adj_full.shape[0]
if dataset == 'ogbn-arxiv':
adj_full = adj_full + adj_full.T
adj_full[adj_full > 1] = 1
role = json.load(open(dataset_str+'role.json','r'))
idx_train = role['tr']
idx_test = role['te']
idx_val = role['va']
if 'label_rate' in kwargs:
label_rate = kwargs['label_rate']
if label_rate < 1:
idx_train = idx_train[:int(label_rate*len(idx_train))]
self.adj_train = adj_full[np.ix_(idx_train, idx_train)]
self.adj_val = adj_full[np.ix_(idx_val, idx_val)]
self.adj_test = adj_full[np.ix_(idx_test, idx_test)]
feat = np.load(dataset_str+'feats.npy')
# ---- normalize feat ----
feat_train = feat[idx_train]
scaler = StandardScaler()
scaler.fit(feat_train)
feat = scaler.transform(feat)
self.feat_train = feat[idx_train]
self.feat_val = feat[idx_val]
self.feat_test = feat[idx_test]
class_map = json.load(open(dataset_str + 'class_map.json','r'))
labels = self.process_labels(class_map)
self.labels_train = labels[idx_train]
self.labels_val = labels[idx_val]
self.labels_test = labels[idx_test]
self.data_full = GraphData(adj_full, feat, labels, idx_train, idx_val, idx_test)
self.class_dict = None
self.class_dict2 = None
self.adj_full = adj_full
self.feat_full = feat
self.labels_full = labels
self.idx_train = np.array(idx_train)
self.idx_val = np.array(idx_val)
self.idx_test = np.array(idx_test)
self.samplers = None
def process_labels(self, class_map):
"""
setup vertex property map for output classests
"""
num_vertices = self.nnodes
if isinstance(list(class_map.values())[0], list):
num_classes = len(list(class_map.values())[0])
self.nclass = num_classes
class_arr = np.zeros((num_vertices, num_classes))
for k,v in class_map.items():
class_arr[int(k)] = v
else:
class_arr = np.zeros(num_vertices, dtype=np.int)
for k, v in class_map.items():
class_arr[int(k)] = v
class_arr = class_arr - class_arr.min()
self.nclass = max(class_arr) + 1
return class_arr
def retrieve_class(self, c, num=256):
if self.class_dict is None:
self.class_dict = {}
for i in range(self.nclass):
self.class_dict['class_%s'%i] = (self.labels_train == i)
idx = np.arange(len(self.labels_train))
idx = idx[self.class_dict['class_%s'%c]]
return np.random.permutation(idx)[:num]
def retrieve_class_sampler(self, c, adj, transductive, num=256, args=None):
if args.nlayers == 1:
sizes = [30]
if args.nlayers == 2:
if args.dataset in ['reddit', 'flickr']:
if args.option == 0:
sizes = [15, 8]
if args.option == 1:
sizes = [20, 10]
if args.option == 2:
sizes = [25, 10]
else:
sizes = [10, 5]
if self.class_dict2 is None:
print(sizes)
self.class_dict2 = {}
for i in range(self.nclass):
if transductive:
idx_train = np.array(self.idx_train)
idx = idx_train[self.labels_train == i]
else:
idx = np.arange(len(self.labels_train))[self.labels_train==i]
self.class_dict2[i] = idx
if self.samplers is None:
self.samplers = []
for i in range(self.nclass):
node_idx = torch.LongTensor(self.class_dict2[i])
if len(node_idx) == 0:
continue
self.samplers.append(NeighborSampler(adj,
node_idx=node_idx,
sizes=sizes, batch_size=num,
num_workers=8, return_e_id=False,
num_nodes=adj.size(0),
shuffle=True))
batch = np.random.permutation(self.class_dict2[c])[:num]
out = self.samplers[c].sample(batch)
return out
class GraphData:
def __init__(self, adj, features, labels, idx_train, idx_val, idx_test):
self.adj = adj
self.features = features
self.labels = labels
self.idx_train = idx_train
self.idx_val = idx_val
self.idx_test = idx_test
class Data2Pyg:
def __init__(self, data, device='cuda', transform=None, **kwargs):
self.data_train = Dpr2Pyg(data.data_train, transform=transform)[0].to(device)
self.data_val = Dpr2Pyg(data.data_val, transform=transform)[0].to(device)
self.data_test = Dpr2Pyg(data.data_test, transform=transform)[0].to(device)
self.nclass = data.nclass
self.nfeat = data.nfeat
self.class_dict = None
def retrieve_class(self, c, num=256):
if self.class_dict is None:
self.class_dict = {}
for i in range(self.nclass):
self.class_dict['class_%s'%i] = (self.data_train.y == i).cpu().numpy()
idx = np.arange(len(self.data_train.y))
idx = idx[self.class_dict['class_%s'%c]]
return np.random.permutation(idx)[:num]
class Dpr2Pyg(InMemoryDataset):
def __init__(self, dpr_data, transform=None, **kwargs):
root = 'data/' # dummy root; does not mean anything
self.dpr_data = dpr_data
super(Dpr2Pyg, self).__init__(root, transform)
pyg_data = self.process()
self.data, self.slices = self.collate([pyg_data])
self.transform = transform
def process(self):
dpr_data = self.dpr_data
edge_index = torch.LongTensor(dpr_data.adj.nonzero())
# by default, the features in pyg data is dense
if sp.issparse(dpr_data.features):
x = torch.FloatTensor(dpr_data.features.todense()).float()
else:
x = torch.FloatTensor(dpr_data.features).float()
y = torch.LongTensor(dpr_data.labels)
data = Data(x=x, edge_index=edge_index, y=y)
data.train_mask = None
data.val_mask = None
data.test_mask = None
return data
def get(self, idx):
data = self.data.__class__()
if hasattr(self.data, '__num_nodes__'):
data.num_nodes = self.data.__num_nodes__[idx]
for key in self.data.keys:
item, slices = self.data[key], self.slices[key]
s = list(repeat(slice(None), item.dim()))
s[self.data.__cat_dim__(key, item)] = slice(slices[idx],
slices[idx + 1])
data[key] = item[s]
return data
@property
def raw_file_names(self):
return ['some_file_1', 'some_file_2', ...]
@property
def processed_file_names(self):
return ['data.pt']
def _download(self):
pass