import os.path as osp
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator
from deeprobust.graph.data import Dataset
from deeprobust.graph.utils import get_train_val_test
from deeprobust.graph.utils import *
from torch_geometric.utils import train_test_split_edges
from sklearn.model_selection import train_test_split
from sklearn import metrics
from sklearn.preprocessing import StandardScaler
from torch_geometric.data import NeighborSampler
from torch_geometric.utils import add_remaining_self_loops, to_undirected
from torch_geometric.datasets import Planetoid


def get_dataset(name, normalize_features=False, transform=None, if_dpr=True):
    """Load a Planetoid or OGB node-classification dataset and return it in
    DeepRobust format."""
    path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', name)
    if name in ['cora', 'citeseer', 'pubmed']:
        dataset = Planetoid(path, name)
    elif name in ['ogbn-arxiv']:
        dataset = PygNodePropPredDataset(name='ogbn-arxiv')
    else:
        raise NotImplementedError

    if transform is not None and normalize_features:
        dataset.transform = T.Compose([T.NormalizeFeatures(), transform])
    elif normalize_features:
        dataset.transform = T.NormalizeFeatures()
    elif transform is not None:
        dataset.transform = transform

    dpr_data = Pyg2Dpr(dataset)
    if name in ['ogbn-arxiv']:
        # the features differ from those provided by GraphSAINT;
        # standardize them with the training statistics, following GraphSAINT
        feat, idx_train = dpr_data.features, dpr_data.idx_train
        feat_train = feat[idx_train]
        scaler = StandardScaler()
        scaler.fit(feat_train)
        feat = scaler.transform(feat)
        dpr_data.features = feat
    return dpr_data
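
# Example usage (a sketch; assumes the Planetoid files can be downloaded to
# ./data/cora next to this file):
#   data = get_dataset('cora', normalize_features=True)
#   print(data.adj.shape, data.features.shape, data.labels.shape)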


class Pyg2Dpr(Dataset):
    """Wrap a PyG dataset so it exposes the DeepRobust Dataset interface
    (adj, features, labels, idx_train/idx_val/idx_test)."""

    def __init__(self, pyg_data, **kwargs):
        try:
            splits = pyg_data.get_idx_split()  # only OGB datasets provide this
        except AttributeError:
            pass
        dataset_name = pyg_data.name
        pyg_data = pyg_data[0]
        n = pyg_data.num_nodes

        if dataset_name == 'ogbn-arxiv':  # symmetrization
            pyg_data.edge_index = to_undirected(pyg_data.edge_index, pyg_data.num_nodes)

        self.adj = sp.csr_matrix((np.ones(pyg_data.edge_index.shape[1]),
            (pyg_data.edge_index[0], pyg_data.edge_index[1])), shape=(n, n))
        self.features = pyg_data.x.numpy()
        self.labels = pyg_data.y.numpy()
        if len(self.labels.shape) == 2 and self.labels.shape[1] == 1:
            self.labels = self.labels.reshape(-1)  # ogbn-arxiv needs to reshape

        if hasattr(pyg_data, 'train_mask'):
            # for fixed splits (e.g. Planetoid)
            self.idx_train = mask_to_index(pyg_data.train_mask, n)
            self.idx_val = mask_to_index(pyg_data.val_mask, n)
            self.idx_test = mask_to_index(pyg_data.test_mask, n)
            self.name = 'Pyg2Dpr'
        else:
            try:
                # for OGB
                self.idx_train = splits['train']
                self.idx_val = splits['valid']
                self.idx_test = splits['test']
                self.name = 'Pyg2Dpr'
            except (NameError, KeyError):
                # for other datasets: generate a random split
                self.idx_train, self.idx_val, self.idx_test = get_train_val_test(
                    nnodes=n, val_size=0.1, test_size=0.8, stratify=self.labels)
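
# A sketch of direct use, mirroring the call inside get_dataset above:
#   dpr = Pyg2Dpr(Planetoid('data/cora', 'cora'))
#   print(dpr.adj.nnz, len(dpr.idx_train))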


def mask_to_index(index, size):
    """Convert a boolean mask into an array of node indices."""
    all_idx = np.arange(size)
    return all_idx[index]


def index_to_mask(index, size):
    """Convert an array of node indices into a boolean mask of length `size`."""
    mask = torch.zeros((size, ), dtype=torch.bool)
    mask[index] = 1
    return mask
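
# Round-trip sketch: index_to_mask(np.array([0, 2]), 4) gives
# tensor([True, False, True, False]), and mask_to_index on that mask
# recovers array([0, 2]).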


class Transd2Ind:
    """Convert a transductive dataset to the inductive setting by restricting
    the graph to the train/val/test node subsets."""

    def __init__(self, dpr_data, keep_ratio):
        idx_train, idx_val, idx_test = dpr_data.idx_train, dpr_data.idx_val, dpr_data.idx_test
        adj, features, labels = dpr_data.adj, dpr_data.features, dpr_data.labels
        self.nclass = labels.max() + 1
        self.adj_full, self.feat_full, self.labels_full = adj, features, labels
        self.idx_train = np.array(idx_train)
        self.idx_val = np.array(idx_val)
        self.idx_test = np.array(idx_test)

        if keep_ratio < 1:
            # subsample the training nodes, stratified by class
            idx_train, _ = train_test_split(idx_train,
                                            random_state=None,
                                            train_size=keep_ratio,
                                            test_size=1 - keep_ratio,
                                            stratify=labels[idx_train])

        self.adj_train = adj[np.ix_(idx_train, idx_train)]
        self.adj_val = adj[np.ix_(idx_val, idx_val)]
        self.adj_test = adj[np.ix_(idx_test, idx_test)]

        print('size of adj_train:', self.adj_train.shape)
        print('#edges in adj_train:', self.adj_train.sum())

        self.labels_train = labels[idx_train]
        self.labels_val = labels[idx_val]
        self.labels_test = labels[idx_test]
        self.feat_train = features[idx_train]
        self.feat_val = features[idx_val]
        self.feat_test = features[idx_test]

        self.class_dict = None
        self.samplers = None
        self.class_dict2 = None

    def retrieve_class(self, c, num=256):
        """Return up to `num` randomly permuted training-node positions of class `c`."""
        if self.class_dict is None:
            self.class_dict = {}
            for i in range(self.nclass):
                self.class_dict['class_%s' % i] = (self.labels_train == i)
        idx = np.arange(len(self.labels_train))
        idx = idx[self.class_dict['class_%s' % c]]
        return np.random.permutation(idx)[:num]

    def retrieve_class_sampler(self, c, adj, transductive, num=256, args=None):
        """Sample a NeighborSampler mini-batch around up to `num` nodes of
        class `c`, with a fan-out matching args.nlayers."""
        if self.class_dict2 is None:
            self.class_dict2 = {}
            for i in range(self.nclass):
                if transductive:
                    idx = self.idx_train[self.labels_train == i]
                else:
                    idx = np.arange(len(self.labels_train))[self.labels_train == i]
                self.class_dict2[i] = idx

        # neighbors sampled per hop; deeper hops use smaller fan-outs
        if args.nlayers == 1:
            sizes = [15]
        if args.nlayers == 2:
            sizes = [10, 5]
            # sizes = [-1, -1]
        if args.nlayers == 3:
            sizes = [15, 10, 5]
        if args.nlayers == 4:
            sizes = [15, 10, 5, 5]
        if args.nlayers == 5:
            sizes = [15, 10, 5, 5, 5]

        if self.samplers is None:
            self.samplers = []
            for i in range(self.nclass):
                node_idx = torch.LongTensor(self.class_dict2[i])
                self.samplers.append(NeighborSampler(adj,
                                                     node_idx=node_idx,
                                                     sizes=sizes, batch_size=num,
                                                     num_workers=12, return_e_id=False,
                                                     num_nodes=adj.size(0),
                                                     shuffle=True))
        batch = np.random.permutation(self.class_dict2[c])[:num]
        out = self.samplers[c].sample(batch)
        return out

    def retrieve_class_multi_sampler(self, c, adj, transductive, num=256, args=None):
        """Like retrieve_class_sampler, but caches one sampler per depth
        (sizes [15] for 1 layer, [10, 5] for 2) and picks it via args.nlayers."""
        if self.class_dict2 is None:
            self.class_dict2 = {}
            for i in range(self.nclass):
                if transductive:
                    idx = self.idx_train[self.labels_train == i]
                else:
                    idx = np.arange(len(self.labels_train))[self.labels_train == i]
                self.class_dict2[i] = idx

        if self.samplers is None:
            self.samplers = []
            for l in range(2):
                layer_samplers = []
                sizes = [15] if l == 0 else [10, 5]
                for i in range(self.nclass):
                    node_idx = torch.LongTensor(self.class_dict2[i])
                    layer_samplers.append(NeighborSampler(adj,
                                                          node_idx=node_idx,
                                                          sizes=sizes, batch_size=num,
                                                          num_workers=12, return_e_id=False,
                                                          num_nodes=adj.size(0),
                                                          shuffle=True))
                self.samplers.append(layer_samplers)
        batch = np.random.permutation(self.class_dict2[c])[:num]
        out = self.samplers[args.nlayers - 1][c].sample(batch)
        return out
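
# Example usage (a sketch; keep_ratio=1.0 keeps the full training split):
#   data = Transd2Ind(get_dataset('cora'), keep_ratio=1.0)
#   idx = data.retrieve_class(c=0, num=32)  # up to 32 train nodes of class 0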


def match_loss(gw_syn, gw_real, args, device):
    """Gradient-matching loss between synthetic and real gradients, summed
    over all parameter tensors."""
    dis = torch.tensor(0.0).to(device)

    if args.dis_metric == 'ours':
        # layer-wise distance (see distance_wb below)
        for ig in range(len(gw_real)):
            gwr = gw_real[ig]
            gws = gw_syn[ig]
            dis += distance_wb(gwr, gws)

    elif args.dis_metric == 'mse':
        # squared L2 distance between the flattened, concatenated gradients
        gw_real_vec = []
        gw_syn_vec = []
        for ig in range(len(gw_real)):
            gw_real_vec.append(gw_real[ig].reshape((-1)))
            gw_syn_vec.append(gw_syn[ig].reshape((-1)))
        gw_real_vec = torch.cat(gw_real_vec, dim=0)
        gw_syn_vec = torch.cat(gw_syn_vec, dim=0)
        dis = torch.sum((gw_syn_vec - gw_real_vec)**2)

    elif args.dis_metric == 'cos':
        # cosine distance between the flattened, concatenated gradients
        gw_real_vec = []
        gw_syn_vec = []
        for ig in range(len(gw_real)):
            gw_real_vec.append(gw_real[ig].reshape((-1)))
            gw_syn_vec.append(gw_syn[ig].reshape((-1)))
        gw_real_vec = torch.cat(gw_real_vec, dim=0)
        gw_syn_vec = torch.cat(gw_syn_vec, dim=0)
        dis = 1 - torch.sum(gw_real_vec * gw_syn_vec, dim=-1) / (
            torch.norm(gw_real_vec, dim=-1) * torch.norm(gw_syn_vec, dim=-1) + 1e-6)

    else:
        raise ValueError('DC error: unknown distance function')

    return dis
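
# Sketch of a typical call (argparse.Namespace stands in for the parsed CLI
# args; the real gradient lists come from torch.autograd.grad):
#   from argparse import Namespace
#   gw_real = [torch.randn(16, 8), torch.randn(8)]
#   gw_syn = [torch.randn(16, 8), torch.randn(8)]
#   loss = match_loss(gw_syn, gw_real, Namespace(dis_metric='ours'), device='cpu')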


def distance_wb(gwr, gws):
    """Sum of one-minus-cosine distances between matched slices of two
    gradient tensors, after flattening each to 2-D."""
    shape = gwr.shape

    # TODO: output node!!!!
    if len(gwr.shape) == 2:
        gwr = gwr.T
        gws = gws.T

    if len(shape) == 4:  # conv, out*in*h*w
        gwr = gwr.reshape(shape[0], shape[1] * shape[2] * shape[3])
        gws = gws.reshape(shape[0], shape[1] * shape[2] * shape[3])
    elif len(shape) == 3:  # layernorm, C*h*w
        gwr = gwr.reshape(shape[0], shape[1] * shape[2])
        gws = gws.reshape(shape[0], shape[1] * shape[2])
    elif len(shape) == 2:  # linear, out*in: already 2-D
        pass
    elif len(shape) == 1:  # batchnorm/instancenorm C; groupnorm x, bias
        gwr = gwr.reshape(1, shape[0])
        gws = gws.reshape(1, shape[0])
        return 0  # 1-D parameters (e.g. biases) are skipped

    dis_weight = torch.sum(1 - torch.sum(gwr * gws, dim=-1) / (
        torch.norm(gwr, dim=-1) * torch.norm(gws, dim=-1) + 1e-6))
    return dis_weight


def calc_f1(y_true, y_pred, is_sigmoid):
    """Micro and macro F1. Sigmoid (multi-label) outputs are thresholded
    at 0.5; otherwise the argmax class is taken."""
    if not is_sigmoid:
        y_pred = np.argmax(y_pred, axis=1)
    else:
        y_pred[y_pred > 0.5] = 1
        y_pred[y_pred <= 0.5] = 0
    return (metrics.f1_score(y_true, y_pred, average="micro"),
            metrics.f1_score(y_true, y_pred, average="macro"))
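
# Quick sketch: 3 samples, 2 classes, argmax predictions [0, 1, 0] against
# ground truth [0, 1, 1]:
#   calc_f1(np.array([0, 1, 1]),
#           np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]]), is_sigmoid=False)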


def evaluate(output, labels, args):
    """Print test-set metrics: micro/macro F1 for the GraphSAINT datasets,
    NLL loss and accuracy otherwise."""
    data_graphsaint = ['yelp', 'ppi', 'ppi-large', 'flickr', 'reddit', 'amazon']
    if args.dataset in data_graphsaint:
        labels = labels.cpu().numpy()
        output = output.cpu().numpy()
        if len(labels.shape) > 1:
            # multi-label case: treat outputs as sigmoid scores
            micro, macro = calc_f1(labels, output, is_sigmoid=True)
        else:
            micro, macro = calc_f1(labels, output, is_sigmoid=False)
        print("Test set results:", "F1-micro= {:.4f}".format(micro),
              "F1-macro= {:.4f}".format(macro))
    else:
        loss_test = F.nll_loss(output, labels)
        acc_test = accuracy(output, labels)  # accuracy comes from deeprobust.graph.utils
        print("Test set results:",
              "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))
    return
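
# Sketch (Namespace again stands in for the parsed CLI args; output should be
# log-probabilities for the nll_loss branch):
#   evaluate(log_probs, labels, Namespace(dataset='cora'))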


from torchvision import datasets, transforms


def get_mnist(data_path):
    """Load MNIST and wrap it as an edgeless graph (all-zero adjacency) so the
    image data can be pushed through the graph pipeline."""
    channel = 1
    im_size = (28, 28)
    num_classes = 10
    mean = [0.1307]
    std = [0.3081]
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize(mean=mean, std=std)])
    dst_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)  # no augmentation
    dst_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)
    class_names = [str(c) for c in range(num_classes)]

    labels = []
    feat = []
    for x, y in dst_train:
        feat.append(x.view(1, -1))
        labels.append(y)
    feat = torch.cat(feat, dim=0).numpy()

    from utils_graphsaint import GraphData
    adj = sp.eye(len(feat))
    idx = np.arange(len(feat))
    # adj - adj yields an all-zero matrix, i.e. a graph with no edges
    dpr_data = GraphData(adj - adj, feat, labels, idx, idx, idx)
    from deeprobust.graph.data import Dpr2Pyg
    return Dpr2Pyg(dpr_data)
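
# Example usage (a sketch; 'data/mnist' is a hypothetical download directory):
#   pyg_mnist = get_mnist('data/mnist')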


def regularization(adj, x, eig_real=None):
    # feature smoothness term tr(X^T L X)
    loss = 0
    # loss += torch.norm(adj, p=1)
    loss += feature_smoothing(adj, x)
    return loss


def maxdegree(adj):
    """Penalize the maximum degree when it exceeds half the number of nodes."""
    n = adj.shape[0]
    return F.relu(adj.sum(1).max() / n - 0.5)


def sparsity2(adj):
    """Frobenius-norm sparsity penalty (the log-degree term is weighted 0)."""
    n = adj.shape[0]
    loss_degree = -torch.log(adj.sum(1)).sum() / n
    loss_fro = torch.norm(adj) / n
    return 0 * loss_degree + loss_fro


def sparsity(adj):
    """Penalize total edge weight above 1% density."""
    n = adj.shape[0]
    thresh = n * n * 0.01
    return F.relu(adj.sum() - thresh)
    # return F.relu(adj.sum() - thresh) / n**2
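
# Sketch of how these penalties combine on a learnable dense adjacency
# (shapes and weights are illustrative, not from this repo):
#   adj = torch.rand(64, 64, requires_grad=True)
#   x = torch.rand(64, 16)
#   loss = regularization(adj, x) + 0.1 * sparsity(adj)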


def feature_smoothing(adj, X):
    """Feature smoothness tr(X^T L_sym X), where L_sym = D^{-1/2}(D - A)D^{-1/2}
    is the symmetrically normalized Laplacian."""
    adj = (adj.t() + adj) / 2  # symmetrize
    rowsum = adj.sum(1)
    r_inv = rowsum.flatten()
    D = torch.diag(r_inv)
    L = D - adj

    r_inv = r_inv + 1e-8  # guard against zero-degree nodes
    r_inv = r_inv.pow(-1 / 2).flatten()
    r_inv[torch.isinf(r_inv)] = 0.
    r_mat_inv = torch.diag(r_inv)
    # L = r_mat_inv @ L
    L = r_mat_inv @ L @ r_mat_inv

    XLXT = torch.matmul(torch.matmul(X.t(), L), X)
    loss_smooth_feat = torch.trace(XLXT)
    # loss_smooth_feat = loss_smooth_feat / (adj.shape[0]**2)
    return loss_smooth_feat
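
# Up to the 1e-8 degree offset, the returned value equals
#   (1/2) * sum_{i,j} A[i, j] * ||x_i / sqrt(d_i) - x_j / sqrt(d_j)||^2,
# the classic normalized-Laplacian smoothness of the features.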


def row_normalize_tensor(mx):
    """Row-normalize a dense tensor so that every row sums to 1."""
    rowsum = mx.sum(1)
    r_inv = rowsum.pow(-1).flatten()
    r_inv[torch.isinf(r_inv)] = 0.  # all-zero rows would otherwise yield inf
    r_mat_inv = torch.diag(r_inv)
    mx = r_mat_inv @ mx
    return mx
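
# Sketch: row_normalize_tensor(torch.tensor([[1., 3.], [2., 2.]]))
#   -> tensor([[0.2500, 0.7500], [0.5000, 0.5000]])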