Commit
·
e551dda
0
Parent(s):
init
Browse files- HeteroVG_MNIST.py +211 -0
- README.md +3 -0
- condainstall.txt +6 -0
- dataset.py +213 -0
- eval.py +261 -0
- model_new.py +224 -0
- train_new.py +397 -0
- util.py +102 -0
HeteroVG_MNIST.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import json
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import random
|
| 9 |
+
import pickle as pkl
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from torch import Tensor
|
| 12 |
+
from scipy.spatial import distance_matrix
|
| 13 |
+
import torch_geometric
|
| 14 |
+
from torch_geometric.data import HeteroData
|
| 15 |
+
from torch_geometric.nn import to_hetero
|
| 16 |
+
# from shapely.geometry import Point, Polygon
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def cross_product(p1, p2, p3):
    """Return the 2-D cross product of vectors (p1->p2) and (p1->p3).

    Positive when p3 lies to the left of the directed line p1->p2,
    negative when to the right, zero when the three points are collinear.
    """
    dx1, dy1 = p2[0] - p1[0], p2[1] - p1[1]
    dx2, dy2 = p3[0] - p1[0], p3[1] - p1[1]
    return dx1 * dy2 - dx2 * dy1
|
| 21 |
+
|
| 22 |
+
def colinear(p1, p2, p3):
    """Return True if p3 lies strictly between p1 and p2 on the same line.

    Fix: the original fell through and implicitly returned None in the
    non-collinear case; an explicit ``return False`` restores the boolean
    contract while staying truthiness-compatible with existing callers.
    """
    # Collinearity via equal cross-multiplied slopes: (p1-p2) x (p2-p3) == 0.
    on_line = (p1[1] - p2[1]) * (p2[0] - p3[0]) == (p1[0] - p2[0]) * (p2[1] - p3[1])
    if not on_line:
        return False
    # Strictly inside the bounding interval along either axis
    # (strict inequalities exclude the segment endpoints).
    if min(p1[0], p2[0]) < p3[0] < max(p1[0], p2[0]):
        return True
    if min(p1[1], p2[1]) < p3[1] < max(p1[1], p2[1]):
        return True
    return False
|
| 25 |
+
|
| 26 |
+
def is_intersected(p1, p2, p3, p4):
    """Return True when segment (p1, p2) crosses or touches segment (p3, p4)."""
    # Touching case: an endpoint of (p3, p4) lying strictly inside (p1, p2).
    if colinear(p1, p2, p3) or colinear(p1, p2, p4):
        return True
    # Proper crossing: each segment's endpoints straddle the other's line.
    straddles_a = cross_product(p1, p2, p3) * cross_product(p1, p2, p4) < 0
    straddles_b = cross_product(p3, p4, p1) * cross_product(p3, p4, p2) < 0
    if straddles_a and straddles_b:
        return True
    else:
        return False
|
| 30 |
+
|
| 31 |
+
#Pos of single number
|
| 32 |
+
def read_singe(df, i):
    """Parse the polygon string at df[0][i] into an (N, 2) coordinate array.

    Ring breaks ('), (') in the raw string are tagged by splicing sentinel
    coordinate pairs into the text before number extraction: '0.0 0.0' marks
    a break not preceded by a closed ring, '1.0 1.0' marks one that is.
    """
    raw = df[0][i]
    chars = list(raw)
    closes_seen = 0  # ')' characters passed so far; offsets later insertions
    for x in range(len(raw) - 4):
        if raw[x] == ')':
            closes_seen += 1
        if raw[x:x + 4] == '), (':
            marker = ' 1.0 1.0' if raw[x - 1] == ')' else ' 0.0 0.0'
            chars.insert(x + closes_seen + 1, marker)
    tagged = ''.join(chars)

    nums = re.findall(r"\d+\.?\d*", tagged)
    pos = np.empty((1, 2))
    # Consume numbers pairwise as (x, y) rows.
    for j in range(0, len(nums) - 1, 2):
        if j == 0:
            pos[0][0], pos[0][1] = float(nums[0]), float(nums[1])
        else:
            pos = np.append(pos, [[float(nums[j]), float(nums[j + 1])]], 0)
    return pos
|
| 56 |
+
|
| 57 |
+
def Visi_Edge(pos_join, flag):
    """Build visibility edges over the flattened multi-polygon vertex list.

    pos_join: (N, 2) vertex coordinates with separator rows removed.
    flag: per-vertex markers set by HeteroEdge — 1 marks the vertex just
    before a hole break, 2/3 the vertex just before a ring/digit break,
    0 an ordinary vertex (inferred from HeteroEdge; TODO confirm).

    Returns (inside_edge_index, apart_edge_index, exteriors), each a
    [heads, tails] pair of global vertex-index lists.
    """
    inside_edge_index = [[],[]]   # candidate hole<->shell edges within one polygon
    apart_edge_index = [[],[]]    # candidate edges between different polygons

    # Vertices as hashable (x, y) tuples so list.index / membership work below.
    vg_point = []
    for i in range(len(pos_join)): vg_point.append((pos_join[i][0], pos_join[i][1]))

    hole_p = np.where(flag==1)[0]  # vertices that open a hole ring
    if len(hole_p) != 0:
        last_id = 0
        # Walk polygon by polygon; each polygon ends at a flag-2/3 vertex.
        for m in range(len(flag)):
            if flag[m] == 2 or flag[m] == 3:
                # Polygons with no hole markers need no inside edges.
                if sum(flag[last_id:m]) == 0:
                    last_id = m+1
                    continue
                poly_i = vg_point[last_id:m+1]   # this polygon's vertices (local)
                pos_i = np.arange(last_id, m+1)  # their global indices
                last_id = m+1
                for i in range(len(poly_i)):
                    if flag[pos_i[i]] == 1:
                        # NOTE(review): j scans `flag` starting from the *local*
                        # index i while flag is indexed globally — verify this
                        # local/global mix is intentional.
                        for j in range(i, len(flag)):
                            if flag[j]==1 or flag[j] == 2 or flag[j] == 3:
                                hole_i = poly_i[i:j+1]        # hole-ring vertices
                                pos_hole = np.arange(i, j+1)  # NOTE(review): unused
                                # Connect each hole vertex to every shell vertex it
                                # can "see" (no polygon edge crossing the sightline).
                                for p1 in hole_i:
                                    for p2 in poly_i:
                                        if p2 not in hole_i:
                                            inter_count = 0
                                            for d in range(len(poly_i)-1):
                                                p3, p4 = poly_i[d], poly_i[d+1]
                                                if is_intersected(p1, p2, p3, p4): inter_count+=1
                                            if inter_count==0:
                                                head, tail = pos_i[poly_i.index(p1)], pos_i[poly_i.index(p2)]
                                                inside_edge_index[0].append(head), inside_edge_index[1].append(tail)

    # Candidate visibility edges between vertices of *different* polygons.
    for i in range(len(vg_point)):
        p1 = vg_point[i]
        # Polygon id = number of separator (flag 2/3) vertices seen before i.
        p1_id = np.count_nonzero(flag[0:i] == 2) + np.count_nonzero(flag[0:i] == 3)
        for j in range(len(vg_point)):
            p2 = vg_point[j]
            if p1 == p2: continue
            p2_id = np.count_nonzero(flag[0:j] == 2) + np.count_nonzero(flag[0:j] == 3)
            inter_count = 0
            # NOTE(review): `len(flag-1)` subtracts 1 elementwise from the
            # ndarray, so this equals len(flag) — likely meant len(flag)-1.
            # Also p3/p4 keep stale values when the flag test fails (and are
            # unbound if it fails on the very first iteration).
            for m in range(len(flag-1)):
                if flag[m]!=1 and flag[m]!=2 and flag[m]!=3: p3, p4 = vg_point[m], vg_point[m+1]
                if is_intersected(p1, p2, p3, p4): inter_count+=1
            if inter_count==0:
                head, tail = vg_point.index(p1), vg_point.index(p2)
                # Keep only pairs with at least one separator strictly between them.
                cc = np.count_nonzero(flag[min(head, tail):max(head, tail)] == 2) + np.count_nonzero(flag[min(head, tail): max(head, tail)] == 3)
                if p1_id!=p2_id and cc!=0: apart_edge_index[0].append(head), apart_edge_index[1].append(tail)
            #print(i)

    # Prune candidates: keep, per source vertex, only the nearest visible target.
    ninside_edge_index = [[],[]]
    napart_edge_index = [[],[]]
    exteriors = [[],[]]   # consecutive boundary edges (i, i+1)

    if len(hole_p)!=0:
        for i in range(len(flag)):
            link_i = [pos_join[inside_edge_index[1][j]] for j in range(len(inside_edge_index[1])) if inside_edge_index[0][j]==i]
            if len(link_i)==0: continue
            ninside_edge_index[0].append(i)
            dis_matrix = distance_matrix([pos_join[i]], link_i)
            node_i = (link_i[np.argmin(dis_matrix[0])][0], link_i[np.argmin(dis_matrix[0])][1])
            ninside_edge_index[1].append(vg_point.index(node_i))

    # Exterior boundary: link consecutive vertices except across separators.
    for i in range(len(vg_point)-1):
        if flag[i]!=1 and flag[i]!=2 and flag[i]!=3 : exteriors[0].append(i), exteriors[1].append(i+1)

    for i in range(len(flag)):
        link_i = [pos_join[apart_edge_index[1][j]] for j in range(len(apart_edge_index[1])) if apart_edge_index[0][j]==i]
        if len(link_i)==0: continue
        napart_edge_index[0].append(i)
        dis_matrix = distance_matrix([pos_join[i]], link_i)
        node_i = (link_i[np.argmin(dis_matrix[0])][0], link_i[np.argmin(dis_matrix[0])][1])
        napart_edge_index[1].append(vg_point.index(node_i))

    # Replace raw candidates with the nearest-neighbour pruned versions.
    inside_edge_index, apart_edge_index = ninside_edge_index, napart_edge_index

    return inside_edge_index, apart_edge_index, exteriors
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def HeteroEdge(pos,k):
    """Split the raw coordinate stream into clean vertices plus per-vertex flags.

    `pos` interleaves real vertices with sentinel rows injected upstream:
    [0,0] (stream head / hole break from read_singe), [1,1] (break after a
    closed ring) and [2,2] (end-of-digit marker from NNIST_HeteroVG).
    NOTE(review): a genuine vertex whose coordinates sum to 0, 2 or 4 would
    also be removed by the sum-based filters below — confirm this cannot occur.

    Returns (pos_join, inside_edge_index, apart_edge_index, exteriors).
    """
    # Strip the sentinel rows by their coordinate sums (0, 2, 4).
    pos_join = np.delete(pos, np.where(np.sum(pos, 1)==0)[0], axis=0)
    pos_join = np.delete(pos_join, np.where(np.sum(pos_join, 1)==2)[0], axis=0)
    pos_join = np.delete(pos_join, np.where(np.sum(pos_join, 1)==4)[0], axis=0)
    flag = np.zeros(len(pos_join))
    pos = np.delete(pos, 0, axis=0)   # drop the [0,0] stream head
    count, id = 0, 0   # digits consumed; start of current digit (`id` shadows the builtin)
    # Convert each remaining sentinel into a flag on the *preceding* vertex:
    # 1 = hole break, 2 = ring break, 3 = end of digit.
    while count<k:
        for i in range(len(pos)):
            if pos[i][0]==0:
                flag[i-1]=1
                pos = np.delete(pos, i, axis=0)
                break
            elif pos[i][0]==1:
                flag[i-1]=2
                pos = np.delete(pos, i, axis=0)
                break
            elif pos[i][0]==2:
                flag[i-1]=3
                pos = np.delete(pos, i, axis=0)
                # Shift this digit's x-coordinates by its ordinal — presumably
                # to keep consecutive digits from overlapping; TODO confirm.
                pos_join[id:i, 0]+=count
                count+=1
                id = i
                break
    pos_join = pos_join   # NOTE(review): no-op, left exactly as in the original
    inside_edge_index, apart_edge_index, exteriors = Visi_Edge(pos_join, flag)

    return pos_join, inside_edge_index, apart_edge_index, exteriors
|
| 166 |
+
|
| 167 |
+
#build heterovg of k-digit from MNIST
|
| 168 |
+
def NNIST_HeteroVG(df, label_df, k):
    """Build one HeteroData visibility graph for a random k-digit number.

    Picks k random MNIST polygon digits (the leading digit is re-drawn until
    it is nonzero so the concatenated label parses as a k-digit int), joins
    their coordinate streams with [2,2] end-of-digit sentinels, and packs the
    resulting visibility edges into a torch_geometric HeteroData object.
    """
    pos = [[0,0]]   # stream-head sentinel expected by HeteroEdge
    label = ''
    for i in np.random.randint(0, len(df), k):
        # Re-draw while the *leading* digit would be 0.
        while True:
            if len(pos) == 1 and label_df[0][i] == 0:
                # BUG FIX: random.randint is inclusive on both ends, so the
                # original randint(0, len(df)) could index one past the end.
                i = random.randrange(len(df))
            else: break
        pos = np.append(pos, read_singe(df, i), 0)
        pos = np.append(pos, [[2,2]], 0)   # end-of-digit sentinel
        label = label+'%d'%(label_df[0][i])

    label = int(label)
    pos_join, inside, apart, exteriors = HeteroEdge(pos,k)

    data = HeteroData()

    # Features are placeholders; coordinates carry the geometry.
    data['vertices'].x = torch.zeros((len(pos_join), 1), dtype=torch.float)
    data.y = torch.tensor(label, dtype=torch.int)
    data.pos = torch.tensor(pos_join, dtype=torch.float)

    # Undirected relations: store both directions; exterior boundary edges
    # ride along with the 'inside' relation (forward direction only).
    data['vertices', 'inside', 'vertices'].edge_index = torch.tensor([inside[0]+inside[1]+exteriors[0],inside[1]+inside[0]+exteriors[1]], dtype=torch.long)
    data['vertices', 'apart', 'vertices'].edge_index = torch.tensor([apart[0]+apart[1],apart[1]+apart[0]], dtype=torch.long)
    data['vertices', 'inside', 'vertices'].edge_attr = torch.zeros((len(data['vertices', 'inside', 'vertices'].edge_index[0]),1), dtype=torch.float)
    data['vertices', 'apart', 'vertices'].edge_attr = torch.zeros((len(data['vertices', 'apart', 'vertices'].edge_index[0]),1), dtype=torch.float)

    return data
|
| 194 |
+
|
| 195 |
+
# Script entry: generate multi-digit MNIST visibility graphs and pickle them.
# Paths point at a Colab-mounted Google Drive; adjust for local runs.
mnist_filename = '/content/drive/MyDrive/MINST_Polygons/polyMNIST/mnist_polygon_test.json'
label_filename = '/content/drive/MyDrive/MINST_Polygons/polyMNIST/mnist_label_test.json'
df = pd.read_json(mnist_filename)
label_df = pd.read_json(label_filename)

K = 2 # number of digits
N = 10 # number of generated graphs
multi_mnist_dataset = []
# N graphs for every digit count from 2 up to K.
for k in range(2, K+1):
    for i in tqdm(range(N)):
        data = NNIST_HeteroVG(df, label_df, k=k)
        multi_mnist_dataset.append(data)

# Persist the generated graphs for the training/eval scripts.
if not os.path.exists('/content/drive/MyDrive/MINST_Polygons/multi_mnist'):
    os.makedirs('/content/drive/MyDrive/MINST_Polygons/multi_mnist')
with open('/content/drive/MyDrive/MINST_Polygons/multi_mnist/multi_mnist.pkl','wb') as file:
    pkl.dump(multi_mnist_dataset, file)
|
README.md
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# KDD24 PolygonGNN: Representation Learning for Polygonal Geometries with Heterogeneous Visibility Graph
|
| 2 |
+
data is on [dropbox](https://www.dropbox.com/scl/fo/f7dir04pldz36n6m47m30/ABxnZk8Qyf16k0Yo75WqXpY?rlkey=f3lhgyv7um323ngpa2bmueimq&st=e4wg0uec&dl=0)
|
| 3 |
+
|
condainstall.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
conda create -n graph python=3.8
|
| 2 |
+
conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch-lts -c nvidia
|
| 3 |
+
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
|
| 4 |
+
conda install pyg -c pyg
|
| 5 |
+
conda install -c conda-forge pytorch_sparse
|
| 6 |
+
conda install matplotlib numpy ipykernel pandas tensorboard
|
dataset.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.utils.data as data
|
| 2 |
+
import os
|
| 3 |
+
import os.path
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import sys
|
| 8 |
+
import pickle
|
| 9 |
+
import time
|
| 10 |
+
import torchvision.datasets as datasets
|
| 11 |
+
import torchvision.transforms as transforms
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
|
| 14 |
+
from torchvision.datasets import VisionDataset
|
| 15 |
+
from torch.utils.data import Dataset
|
| 16 |
+
from datetime import date, timedelta,datetime
|
| 17 |
+
import random
|
| 18 |
+
import pickle as pkl
|
| 19 |
+
import string
|
| 20 |
+
|
| 21 |
+
# Characters whose glyph polygons are used as classes.
valid_chars = 'EFHILOTUYZ'

# Two-character labels: every ordered pair of valid characters, sorted.
alphabetic_labels = sorted(c1 + c2 for c1 in valid_chars for c2 in valid_chars)
label_mapping = {label: idx for idx, label in enumerate(alphabetic_labels)} # to number
reverse_label_mapping = {v: k for k, v in label_mapping.items()} # to alphabetic

# Single-character labels and their index mappings.
single_alphabetic_labels = sorted(valid_chars)
single_label_mapping = {label: idx for idx, label in enumerate(single_alphabetic_labels)}
single_reverse_label_mapping = {v: k for k, v in single_label_mapping.items()}
|
| 32 |
+
|
| 33 |
+
def _split_dataset(data_dir, Seed, test_ratio, transform=None):
    """Load a pickled graph list, optionally remap labels, and split it.

    Seeds the python/torch/numpy RNGs for a reproducible shuffle, then splits
    into train/val/test where val and test each get `test_ratio` of the data.

    data_dir: path to a pickle file containing a list of graph objects.
    transform: optional callable applied to each entry's `.y` label in place.
    Returns (train_ds, val_ds, test_ds).
    """
    random.seed(Seed)
    torch.manual_seed(Seed)
    np.random.seed(Seed)

    with open(data_dir, 'rb') as f:
        dataset = pkl.load(f)
    if transform is not None:
        for entry in dataset:
            entry.y = transform(entry.y)

    np.random.shuffle(dataset)
    val_test_split = int(np.around(test_ratio * len(dataset)))
    train_val_split = int(len(dataset) - 2 * val_test_split)
    train_ds = dataset[:train_val_split]
    val_ds = dataset[train_val_split:train_val_split + val_test_split]
    test_ds = dataset[train_val_split + val_test_split:]

    print(data_dir)
    print('Train: ' + str(len(train_ds)))
    print('Val : ' + str(len(val_ds)))
    print('Test : ' + str(len(test_ds)))

    return train_ds, val_ds, test_ds

def get_mnist_dataset(data_dir='data/multi_mnist.pkl',Seed=0,test_ratio=0.2):
    """Multi-digit MNIST polygons; stored labels start at 10, shifted to 0-based."""
    return _split_dataset(data_dir, Seed, test_ratio, transform=lambda y: y - 10)

def get_building_dataset(data_dir='data/building_with_index.pkl',Seed=0,test_ratio=0.2):
    """Building pairs; two-character alphabetic labels mapped to class indices."""
    return _split_dataset(data_dir, Seed, test_ratio, transform=lambda y: label_mapping[y])

def get_mbuilding_dataset(data_dir='data/mp_building.pkl',Seed=0,test_ratio=0.2):
    """Multi-part buildings; two-character alphabetic labels mapped to indices."""
    return _split_dataset(data_dir, Seed, test_ratio, transform=lambda y: label_mapping[y])

def get_sbuilding_dataset(data_dir='data/single_building.pkl',Seed=0,test_ratio=0.2):
    """Single buildings; one-character alphabetic labels mapped to indices."""
    return _split_dataset(data_dir, Seed, test_ratio, transform=lambda y: single_label_mapping[y])

def get_smnist_dataset(data_dir='data/single_mnist.pkl',Seed=0,test_ratio=0.2):
    """Single-digit MNIST polygons; labels used as stored."""
    return _split_dataset(data_dir, Seed, test_ratio)

def get_dbp_dataset(data_dir='data/triple_building.pkl',Seed=0,test_ratio=0.2):
    """Binary task: positive class when the stored label is >= 1."""
    return _split_dataset(data_dir, Seed, test_ratio, transform=lambda y: 1 if y >= 1 else 0)
|
| 180 |
+
|
| 181 |
+
def affine_transform_to_range(ds, target_range=(-1, 1)):
    """Rescale every item's 2-D coordinates into `target_range`, per item and per axis.

    Mutates each item's `pos` tensor in place and returns the same container.
    """
    lo, hi = target_range
    for item in ds:
        xs, ys = item.pos[:, 0], item.pos[:, 1]
        # Per-axis extent of this item's coordinates.
        min_x, max_x = torch.min(xs), torch.max(xs)
        min_y, max_y = torch.min(ys), torch.max(ys)

        # x' = x * scale + offset maps [min, max] onto [lo, hi].
        scale_x = (hi - lo) / (max_x - min_x)
        scale_y = (hi - lo) / (max_y - min_y)
        item.pos[:, 0] = xs * scale_x + (lo - min_x * scale_x)
        item.pos[:, 1] = ys * scale_y + (lo - min_y * scale_y)
    return ds
|
| 199 |
+
|
| 200 |
+
class CustomDataset(Dataset):
    """Thin Dataset wrapper around an in-memory list of graph objects."""

    def __init__(self, data_list):
        super().__init__()
        self.data_list = data_list

    def len(self):
        # Number of stored samples.
        return len(self.data_list)

    def get(self, idx):
        # Fetch the sample at position `idx`.
        return self.data_list[idx]
|
| 210 |
+
|
| 211 |
+
if __name__ == '__main__':
    # Smoke test: materialise the default MNIST split and exit quietly.
    train_ds, val_ds, test_ds = get_mnist_dataset()
    print("")
|
eval.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import torch
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import numpy as np
|
| 7 |
+
import time
|
| 8 |
+
import torch.optim as optim
|
| 9 |
+
import scipy
|
| 10 |
+
|
| 11 |
+
from matplotlib import cm
|
| 12 |
+
import matplotlib.pyplot as plt
|
| 13 |
+
import json
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from torch.nn.functional import softmax
|
| 16 |
+
|
| 17 |
+
torch.autograd.set_detect_anomaly(True)
|
| 18 |
+
import pickle
|
| 19 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 20 |
+
import dataset,util
|
| 21 |
+
from model_new import Smodel
|
| 22 |
+
import model_new
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
import torch.nn as nn
|
| 26 |
+
import torchvision.transforms as transforms
|
| 27 |
+
import torchvision.datasets
|
| 28 |
+
import torchvision.models
|
| 29 |
+
import math
|
| 30 |
+
import shutil
|
| 31 |
+
import time
|
| 32 |
+
from datetime import date, timedelta,datetime
|
| 33 |
+
import torch_geometric
|
| 34 |
+
from torch_geometric.data import Data, DataLoader
|
| 35 |
+
from torch_geometric.nn import MessagePassing
|
| 36 |
+
from torch_geometric.utils import add_self_loops
|
| 37 |
+
from torch_geometric.nn import GIN,GATConv,MLP
|
| 38 |
+
from torch_geometric.nn.pool import global_mean_pool,global_add_pool
|
| 39 |
+
import csv
|
| 40 |
+
|
| 41 |
+
# ANSI escape helpers for colourised console output.
def _ansi(code):
    """Return a function wrapping its argument in the given SGR colour code."""
    return lambda text: '\033[' + code + 'm' + text + '\033[0m'

blue = _ansi('94')
red = _ansi('31')
green = _ansi('32')
yellow = _ansi('33')
greenline = _ansi('42')   # green background
yellowline = _ansi('43')  # yellow background
|
| 47 |
+
|
| 48 |
+
def get_args():
    """Parse evaluation CLI arguments and normalise string-encoded booleans."""
    parser = argparse.ArgumentParser()
    # Model selection and batching.
    parser.add_argument('--model', default="our", type=str)
    parser.add_argument('--train_batch', default=64, type=int)
    parser.add_argument('--test_batch', default=128, type=int)
    parser.add_argument('--share', type=str, default="0")
    parser.add_argument('--edge_rep', type=str, default="True")
    parser.add_argument('--batchnorm', type=str, default="True")
    parser.add_argument('--extent_norm', type=str, default="T")
    parser.add_argument('--spanning_tree', type=str, default="F")

    # Architecture hyper-parameters.
    parser.add_argument('--loss_coef', default=0.1, type=float)
    parser.add_argument('--h_ch', default=512, type=int)
    parser.add_argument('--localdepth', type=int, default=1)
    parser.add_argument('--num_interactions', type=int, default=4)
    parser.add_argument('--finaldepth', type=int, default=4)
    parser.add_argument('--classifier_depth', type=int, default=4)
    parser.add_argument('--dropout', type=float, default=0.0)

    # Run / logging configuration.
    parser.add_argument('--dataset', type=str, default='mnist')
    parser.add_argument('--log', type=str, default="True")
    parser.add_argument('--test_per_round', type=int, default=10)
    parser.add_argument('--patience', type=int, default=30) #scheduler
    parser.add_argument('--nepoch', type=int, default=201)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--manualSeed', type=str, default="False")
    parser.add_argument('--man_seed', type=int, default=12345)

    parser.add_argument("--targetfiles", nargs='+', type=str, default=["Dec11-14:44:32.pth","Nov13-14:30:48.pth"])
    args = parser.parse_args()

    # "True" becomes True; every other string (including "False") becomes False.
    args.log = args.log == "True"
    args.edge_rep = args.edge_rep == "True"
    args.batchnorm = args.batchnorm == "True"
    args.manualSeed = args.manualSeed == "True"
    args.save_dir = os.path.join('./save/', args.dataset)
    return args
|
| 84 |
+
|
| 85 |
+
# Parse CLI options once at import time (this module runs as a script).
args = get_args()
# Prefer GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Multi-class classification objective shared by all forward passes below.
criterion=nn.CrossEntropyLoss()
|
| 88 |
+
|
| 89 |
+
def forward_HGT(args,data,model,mlpmodel):
    """One evaluation forward pass for the HGT/HAN baselines.

    Returns (loss, class logits, labels, pooled graph embeddings).
    """
    data = data.to(device)
    x,batch=data.pos, data['vertices'].batch
    # Vertex coordinates double as the heterogeneous model's input features.
    data["vertices"]['x']=data.pos
    label=data.y.long().view(-1)

    output=model(data.x_dict, data.edge_index_dict)
    # NOTE(review): both branches pool identically; only the clamp differs —
    # confirm whether the dbp branch is meant to skip the clamp.
    if args.dataset in ["dbp"]:
        graph_embeddings=global_add_pool(output,batch)
    else:
        graph_embeddings=global_add_pool(output,batch)
        # In-place clamp guards against activation blow-up.
        graph_embeddings.clamp_(max=1e6)

    output=mlpmodel(graph_embeddings)
    # log_probs = F.log_softmax(output, dim=1)

    loss = criterion(output, label)
    return loss,output,label, graph_embeddings
|
| 107 |
+
|
| 108 |
+
def forward(args,data,model,mlpmodel):
    """One evaluation forward pass for the heterogeneous visibility-graph model.

    Combines the 'inside' and 'apart' edge sets, optionally thins them to a
    random spanning tree, expands edges into triplets, then pools per-graph.
    Returns (loss, class logits, labels, pooled graph embeddings).
    """
    data = data.to(device)
    edge_index1 = data['vertices', 'inside', 'vertices']['edge_index']
    edge_index2 = data['vertices', 'apart', 'vertices']['edge_index']
    combined_edge_index = torch.cat([edge_index1, edge_index2], 1)

    x, batch = data.pos, data['vertices'].batch
    label = data.y.long().view(-1)
    # BUG FIX: num_nodes was previously computed *after* its use in the
    # spanning-tree branch, raising NameError when --spanning_tree True.
    num_nodes = x.shape[0]

    if args.spanning_tree == 'True':
        # Random positive weights so the spanning tree differs per call.
        edge_weight = torch.rand(combined_edge_index.shape[1]) + 1
        combined_edge_index = util.build_spanning_tree_edge(combined_edge_index, edge_weight, num_nodes=num_nodes,)

    num_edge_inside = edge_index1.shape[1]
    """
    triplets are not the same for graphs when training
    """
    edge_index_2rd, num_triplets_real, edx_jk, edx_ij = util.triplets(combined_edge_index, num_nodes)

    # Learned features start at zero; geometry enters through `x` (positions).
    input_feature = torch.zeros([x.shape[0], args.h_ch], device=device)
    output = model(input_feature, x, [edge_index1, edge_index2], edge_index_2rd, edx_jk, edx_ij, batch, num_edge_inside, args.edge_rep)
    output = torch.cat(output, dim=1)
    graph_embeddings = global_add_pool(output, batch)
    # In-place clamp guards against activation blow-up.
    graph_embeddings.clamp_(max=1e6)

    output = mlpmodel(graph_embeddings)
    # log_probs = F.log_softmax(output, dim=1)

    loss = criterion(output, label)
    return loss, output, label, graph_embeddings
|
| 138 |
+
def test(args,loader,model,mlpmodel,writer,reverse_mapping ):
    # Evaluate the model over `loader`, log embeddings to TensorBoard, and
    # return (mean loss, predicted classes, true classes, raw logits).
    y_hat, y_true,y_hat_logit = [], [], [],
    embeddings=[]

    loss_total, pred_num = 0, 0
    model.eval()
    mlpmodel.eval()
    with torch.no_grad():
        for data in loader:
            # Dispatch to the matching forward pass for the chosen model.
            if args.model=="our":
                loss,output,label,embedding =forward(args,data,model,mlpmodel)
            elif args.model in ["HGT","HAN"]:
                loss,output,label,embedding =forward_HGT(args,data,model,mlpmodel)
            # Top-1 class prediction per graph.
            _, pred = output.topk(1, dim=1, largest=True, sorted=True)
            pred,label,output=pred.cpu(),label.cpu(),output.cpu()
            y_hat += list(pred.detach().numpy().reshape(-1))
            y_true += list(label.detach().numpy().reshape(-1))
            y_hat_logit+=list(output.detach().numpy())
            embeddings.append(embedding)

            pred_num += len(label.reshape(-1, 1))
            # Weight the batch loss by batch size for a dataset-level mean.
            loss_total += loss.detach() * len(label.reshape(-1, 1))

    # Human-readable labels for the TensorBoard embedding projector.
    y_true_str=[reverse_mapping(item) for item in y_true]
    writer.add_embedding(torch.cat(embeddings,dim=0).detach().cpu(),metadata=y_true_str,tag="numbers")
    writer.close()
    return loss_total/pred_num,y_hat, y_true, y_hat_logit
|
| 165 |
+
|
| 166 |
+
def main(args,train_Loader,val_Loader,test_Loader):
    """Load saved checkpoints listed in args.targetfiles, rebuild the matching
    model, and evaluate each on the test loader, saving predictions and
    TensorBoard embeddings.

    Only `test_Loader` is actually consumed; the train/val loaders are unused.
    """
    donefiles=os.listdir(os.path.join(args.save_dir,args.model,'model'))
    tensorboard_dir=os.path.join(args.save_dir,args.model,'log')
    # Choose how integer class ids map back to display labels.
    # NOTE(review): if args.dataset matches none of these branches,
    # `reverse_mapping` stays unbound and test() will raise NameError.
    if args.dataset in ["mnist","mnist_sparse"]:
        reverse_mapping=lambda x: x + 10
        # list(map(lambda x: x - 10, []))
    elif args.dataset in ["building","mbuilding"]:
        reverse_mapping=lambda x: dataset.reverse_label_mapping[x]
    elif args.dataset in ["sbuilding"]:
        reverse_mapping=lambda x: dataset.single_reverse_label_mapping[x]
    elif args.dataset in ["dbp","smnist"]:
        reverse_mapping=lambda x: x
    for file in donefiles:
        # Only evaluate checkpoints explicitly requested.
        if file not in args.targetfiles:
            continue
        else:
            print(file)
        saved_dict=torch.load(os.path.join(args.save_dir,args.model,'model',file))
        # Output dimensionality depends on the dataset the checkpoint was trained on.
        if saved_dict['args'].dataset in ["mnist","mnist_sparse"]:
            x_out=90
        elif saved_dict['args'].dataset in ["building","mbuilding"]:
            x_out=100
        elif saved_dict['args'].dataset in ["sbuilding","smnist"]:
            x_out=10
        elif saved_dict['args'].dataset in ["dbp"]:
            x_out=2
        # Rebuild the architecture from the hyperparameters stored in the checkpoint.
        if saved_dict['args'].model=="our":
            model=Smodel(h_channel=saved_dict['args'].h_ch,input_featuresize=saved_dict['args'].h_ch,\
            localdepth=saved_dict['args'].localdepth,num_interactions=saved_dict['args'].num_interactions,finaldepth=saved_dict['args'].finaldepth,share=saved_dict['args'].share,batchnorm=saved_dict['args'].batchnorm)
            mlpmodel=MLP(in_channels=saved_dict['args'].h_ch*saved_dict['args'].num_interactions, hidden_channels=saved_dict['args'].h_ch,out_channels=x_out, num_layers=saved_dict['args'].classifier_depth)
        elif saved_dict['args'].model=="HGT":
            model=model_new.HGT(hidden_channels=saved_dict['args'].h_ch, out_channels=saved_dict['args'].h_ch, num_heads=2, num_layers=saved_dict['args'].num_interactions)
            mlpmodel=MLP(in_channels=saved_dict['args'].h_ch, hidden_channels=saved_dict['args'].h_ch,out_channels=x_out, num_layers=saved_dict['args'].classifier_depth,dropout=saved_dict['args'].dropout)
        elif saved_dict['args'].model=="HAN":
            model=model_new.HAN(hidden_channels=saved_dict['args'].h_ch, out_channels=saved_dict['args'].h_ch, num_heads=2, num_layers=saved_dict['args'].num_interactions)
            mlpmodel=MLP(in_channels=saved_dict['args'].h_ch, hidden_channels=saved_dict['args'].h_ch,out_channels=x_out, num_layers=saved_dict['args'].classifier_depth,dropout=saved_dict['args'].dropout)
        model.to(device), mlpmodel.to(device)
        try:
            model.load_state_dict(saved_dict['model'],strict=True)
            mlpmodel.load_state_dict(saved_dict['mlpmodel'],strict=True)
        # NOTE(review): load_state_dict typically raises RuntimeError on a
        # mismatch, not OSError — this handler probably never fires; confirm.
        except OSError:
            print('loadfail: ',file)
            pass
        print(saved_dict['args'])

        writer = SummaryWriter(os.path.join(tensorboard_dir,file+"_embedding"))
        test_loss, yhat_test, ytrue_test, yhatlogit_test = test(saved_dict['args'],test_Loader,model,mlpmodel,writer,reverse_mapping)

        # Persist raw predictions alongside the TensorBoard logs.
        pred_dir=os.path.join(tensorboard_dir,file+"_test_record")
        to_save_dict={'labels':ytrue_test,'yhat':yhat_test,'yhat_logit':yhatlogit_test}
        torch.save(to_save_dict, pred_dir)

        test_acc=util.calculate(yhat_test,ytrue_test,yhatlogit_test)
        util.print_1(0,'Test', {"loss":test_loss,"acc":test_acc},color=blue)
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
if __name__ == '__main__':
    # Fixed seed for the train/val/test split so evaluation uses the same
    # partition as training; a fresh random seed is drawn afterwards for
    # everything else.
    Seed = 0
    test_ratio=0.2
    print("data splitting Random Seed: ", Seed)
    # Select dataset pickle and loader function.
    # NOTE(review): "building" calls get_mbuilding_dataset while "mbuilding"
    # calls get_building_dataset — looks swapped; confirm against dataset.py.
    if args.dataset in ["mnist"]:
        args.data_dir='data/multi_mnist_with_index.pkl'
        train_ds,val_ds,test_ds=dataset.get_mnist_dataset(args.data_dir,Seed,test_ratio=test_ratio)
    elif args.dataset in ["mnist_sparse"]:
        args.data_dir='data/multi_mnist_sparse.pkl'
        train_ds,val_ds,test_ds=dataset.get_mnist_dataset(args.data_dir,Seed,test_ratio=test_ratio)
    elif args.dataset in ["building"]:
        args.data_dir='data/building_with_index.pkl'
        train_ds,val_ds,test_ds=dataset.get_mbuilding_dataset(args.data_dir,Seed,test_ratio=test_ratio)
    elif args.dataset in ["mbuilding"]:
        args.data_dir='data/mp_building.pkl'
        train_ds,val_ds,test_ds=dataset.get_building_dataset(args.data_dir,Seed,test_ratio=test_ratio)
    elif args.dataset in ["sbuilding"]:
        args.data_dir='data/single_building.pkl'
        train_ds,val_ds,test_ds=dataset.get_sbuilding_dataset(args.data_dir,Seed,test_ratio=test_ratio)
    elif args.dataset in ["smnist"]:
        args.data_dir='data/single_mnist.pkl'
        train_ds,val_ds,test_ds=dataset.get_smnist_dataset(args.data_dir,Seed,test_ratio=test_ratio)
    elif args.dataset in ['dbp']:
        args.data_dir='data/triple_building_600.pkl'
        train_ds,val_ds,test_ds=dataset.get_dbp_dataset(args.data_dir,Seed,test_ratio=test_ratio)

    # Optionally normalize coordinates into [-1, 1].
    if args.extent_norm=="T":
        train_ds= dataset.affine_transform_to_range(train_ds,target_range=(-1, 1))
        val_ds= dataset.affine_transform_to_range(val_ds,target_range=(-1, 1))
        test_ds= dataset.affine_transform_to_range(test_ds,target_range=(-1, 1))
    train_loader = torch_geometric.loader.DataLoader(train_ds,batch_size=args.train_batch, shuffle=False,pin_memory=True)
    val_loader = torch_geometric.loader.DataLoader(val_ds, batch_size=args.test_batch, shuffle=False, pin_memory=True)
    test_loader = torch_geometric.loader.DataLoader(test_ds,batch_size=args.test_batch, shuffle=False,pin_memory=True)

    # Re-seed with a random value for model/eval randomness.
    Seed=random.randint(1, 10000)
    print("Random Seed: ", Seed)
    random.seed(Seed)
    torch.manual_seed(Seed)
    np.random.seed(Seed)
    main(args,train_loader,val_loader,test_loader)
|
model_new.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from math import pi as PI
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
import torch.nn.parallel
|
| 9 |
+
import torch.utils.data
|
| 10 |
+
import torch_geometric.transforms as T
|
| 11 |
+
from torch.nn import ModuleList, Parameter
|
| 12 |
+
from torch_geometric.nn import HANConv, HEATConv, HGTConv, Linear
|
| 13 |
+
from torch_geometric.nn.conv import MessagePassing
|
| 14 |
+
from torch_geometric.nn.dense.linear import Linear
|
| 15 |
+
# from dataset import
|
| 16 |
+
from torch_geometric.nn.inits import glorot, zeros
|
| 17 |
+
from torch_geometric.utils import softmax
|
| 18 |
+
from torch_scatter import scatter
|
| 19 |
+
|
| 20 |
+
from util import get_angle, get_theta, triplets
|
| 21 |
+
|
| 22 |
+
class Smodel(nn.Module):
    """Triplet-based geometric GNN.

    Encodes each edge-pair triplet (i, j, k) by distances and the turning
    angle, passes the encoding through an MLP, then runs a stack of SPNN
    interaction blocks that aggregate triplet messages into node features.
    """
    def __init__(self, h_channel=16,input_featuresize=32,localdepth=2,num_interactions=3,finaldepth=3,share='0',batchnorm="True"):
        super(Smodel,self).__init__()
        self.training=True
        self.h_channel = h_channel
        self.input_featuresize=input_featuresize
        self.localdepth = localdepth
        self.num_interactions=num_interactions
        self.finaldepth=finaldepth
        # NOTE(review): compared with the *string* "True" below, but
        # train_new.py converts args.batchnorm to a bool before passing it in,
        # so BatchNorm layers are never added in that path — confirm intent.
        self.batchnorm = batchnorm
        self.activation=nn.ReLU()
        # Learnable per-relation-combination attention weights (4 edge-type pairs).
        self.att = Parameter(torch.ones(4),requires_grad=True)

        # One scalar each for dist(i,j), dist(j,k) and angle(ijk).
        num_gaussians=(1,1,1)
        # Geometry encoder used when edge_rep is True (triplet features).
        self.mlp_geo = ModuleList()
        for i in range(self.localdepth):
            if i == 0:
                self.mlp_geo.append(Linear(sum(num_gaussians), h_channel))
            else:
                self.mlp_geo.append(Linear(h_channel, h_channel))
            if self.batchnorm == "True":
                self.mlp_geo.append(nn.BatchNorm1d(h_channel))
            self.mlp_geo.append(self.activation)

        # Fallback encoder on raw endpoint coordinates (edge_rep False).
        self.mlp_geo_backup = ModuleList()
        for i in range(self.localdepth):
            if i == 0:
                self.mlp_geo_backup.append(Linear(4, h_channel))
            else:
                self.mlp_geo_backup.append(Linear(h_channel, h_channel))
            if self.batchnorm == "True":
                self.mlp_geo_backup.append(nn.BatchNorm1d(h_channel))
            self.mlp_geo_backup.append(self.activation)
        # Declared but not used in single_forward (kept for compatibility).
        self.translinear=Linear(input_featuresize+1, self.h_channel)
        self.interactions= ModuleList()
        for i in range(self.num_interactions):
            block = SPNN(
                in_ch=self.input_featuresize,
                hidden_channels=self.h_channel,
                activation=self.activation,
                finaldepth=self.finaldepth,
                batchnorm=self.batchnorm,
                num_input_geofeature=self.h_channel
            )
            self.interactions.append(block)
        self.reset_parameters()
    def reset_parameters(self):
        """Xavier-init mlp_geo linears and reset each interaction block.

        NOTE(review): mlp_geo_backup and translinear are not re-initialized
        here — confirm whether that is intentional.
        """
        for lin in self.mlp_geo:
            if isinstance(lin, Linear):
                torch.nn.init.xavier_uniform_(lin.weight)
                lin.bias.data.fill_(0)
        for i in (self.interactions):
            i.reset_parameters()

    def single_forward(self, input_feature,coords,edge_index,edge_index_2rd, edx_jk, edx_ij,batch,num_edge_inside,edge_rep):
        """One full pass through the interaction stack.

        Returns a list with the node features after each interaction block.
        """
        if edge_rep:
            # Triplet indices: message flows k -> j -> i.
            i, j, k = edge_index_2rd
            edge_index1,edge_index2= edge_index
            edge_index_all=torch.cat([edge_index1,edge_index2],1)
            distance_ij=(coords[j] - coords[i]).norm(p=2, dim=1)
            distance_jk=(coords[j] - coords[k]).norm(p=2, dim=1)
            # Turning angle between the two segments of the triplet.
            theta_ijk = get_angle(coords[j] - coords[i], coords[k] - coords[j])
            geo_encoding_1st=distance_ij[:,None]
            geo_encoding=torch.cat([geo_encoding_1st,distance_jk[:,None],theta_ijk[:,None]],dim=-1)
        else:
            # NOTE(review): here edge_index is the list [edge_index1, edge_index2],
            # so edge_index[0] is a (2, E) tensor, not a row of endpoints — verify
            # this branch is ever exercised with this calling convention.
            coords_j = coords[edge_index[0]]
            coords_i = coords[edge_index[1]]
            geo_encoding=torch.cat([coords_j,coords_i],dim=-1)
        if edge_rep:
            for lin in self.mlp_geo:
                geo_encoding=lin(geo_encoding)
        else:
            for lin in self.mlp_geo_backup:
                geo_encoding=lin(geo_encoding)
            # NOTE(review): this zeroes out the encoding just computed — it looks
            # like an ablation (no-geometry baseline); confirm it is deliberate.
            geo_encoding=torch.zeros_like(geo_encoding,device=geo_encoding.device,dtype=geo_encoding.dtype)
        node_feature= input_feature
        node_feature_list=[]
        for interaction in self.interactions:
            node_feature = interaction(node_feature,geo_encoding,edge_index_2rd,edx_jk,edx_ij,num_edge_inside,self.att)
            node_feature_list.append(node_feature)
        return node_feature_list
    def forward(self, input_feature, coords,edge_index,edge_index_2rd, edx_jk, edx_ij,batch,num_edge_inside,edge_rep):
        """Thin wrapper over single_forward; returns the per-layer feature list."""
        output=self.single_forward(input_feature,coords,edge_index,edge_index_2rd, edx_jk, edx_ij,batch,num_edge_inside,edge_rep)
        return output
|
| 106 |
+
|
| 107 |
+
class SPNN(torch.nn.Module):
    """One interaction block: routes each triplet through one of four MLPs
    chosen by which edge types (inside/apart) its two edges belong to, scales
    the result by a learnable attention weight, and sums messages per node.
    """
    def __init__(
        self,
        in_ch,
        hidden_channels,
        activation=torch.nn.ReLU(),
        finaldepth=3,
        batchnorm="True",
        num_input_geofeature=13
    ):
        super(SPNN, self).__init__()
        self.activation = activation
        self.finaldepth = finaldepth
        # NOTE(review): compared with the string "True" below; a bool value
        # silently disables BatchNorm (same caveat as in Smodel).
        self.batchnorm = batchnorm
        self.num_input_geofeature=num_input_geofeature

        # Four parallel MLPs, one per (edge-type of ij) x (edge-type of jk) combo.
        self.WMLP_list = ModuleList()
        for _ in range(4):
            WMLP = ModuleList()
            for i in range(self.finaldepth + 1):
                if i == 0:
                    # Input: features of i, j, k plus the geometry encoding.
                    WMLP.append(Linear(hidden_channels*3+num_input_geofeature, hidden_channels))
                else:
                    WMLP.append(Linear(hidden_channels, hidden_channels))
                if self.batchnorm == "True":
                    WMLP.append(nn.BatchNorm1d(hidden_channels))
                WMLP.append(self.activation)
            self.WMLP_list.append(WMLP)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-init every Linear in all four MLPs."""
        for mlp in self.WMLP_list:
            for lin in mlp:
                if isinstance(lin, Linear):
                    torch.nn.init.xavier_uniform_(lin.weight)
                    lin.bias.data.fill_(0)
    def forward(self, node_feature,geo_encoding,edge_index_2rd,edx_jk,edx_ij,num_edge_inside,att):
        """Compute triplet messages and scatter-sum them onto target nodes i."""
        i,j,k = edge_index_2rd
        if node_feature is None:
            concatenated_vector = geo_encoding
        else:
            node_attr_0st = node_feature[i]
            node_attr_1st = node_feature[j]
            node_attr_2 = node_feature[k]
            concatenated_vector = torch.cat(
                [
                    node_attr_0st,
                    node_attr_1st,node_attr_2,
                    geo_encoding,
                ],
                dim=-1,
            )
        x_i = concatenated_vector

        # Edge indices below num_edge_inside are "inside" edges, the rest "apart";
        # each of the four combinations gets its own MLP.
        edge1_edge1_mask = (edx_ij < num_edge_inside) & (edx_jk < num_edge_inside)
        edge1_edge2_mask = (edx_ij < num_edge_inside) & (edx_jk >= num_edge_inside)
        edge2_edge1_mask = (edx_ij >= num_edge_inside) & (edx_jk < num_edge_inside)
        edge2_edge2_mask = (edx_ij >= num_edge_inside) & (edx_jk >= num_edge_inside)
        masks=[edge1_edge1_mask,edge1_edge2_mask,edge2_edge1_mask,edge2_edge2_mask]

        # Output width is taken from the first Linear of the first MLP.
        x_output=torch.zeros(x_i.shape[0],self.WMLP_list[0][0].weight.shape[0],device=x_i.device)
        for index in range(4):
            WMLP=self.WMLP_list[index]
            x=x_i[masks[index]]
            for lin in WMLP:
                x=lin(x)
            # Per-combination learnable gate on top of a second nonlinearity.
            x = F.leaky_relu(x)*att[index]
            x_output[masks[index]]+=x

        # Sum all triplet messages arriving at each target node i.
        out_feature = scatter(x_output, i, dim=0, reduce='add')
        return out_feature
|
| 178 |
+
|
| 179 |
+
class HGT(torch.nn.Module):
    """HGT baseline over a single-node-type graph with two relations.

    A lazy linear projection per node type feeds `num_layers` HGTConv layers
    (sum-grouped heads), followed by a final linear head.
    """

    # Metadata of the heterogeneous graph: one node type, two relations.
    _NODE_TYPES = ['vertices']
    _EDGE_TYPES = [('vertices', 'inside', 'vertices'),
                   ('vertices', 'apart', 'vertices')]

    def __init__(self, hidden_channels, out_channels, num_heads, num_layers):
        super().__init__()

        # Lazy (-1) input projection for each node type.
        self.lin_dict = torch.nn.ModuleDict(
            {ntype: Linear(-1, hidden_channels) for ntype in self._NODE_TYPES}
        )

        metadata = (self._NODE_TYPES, self._EDGE_TYPES)
        self.convs = torch.nn.ModuleList(
            HGTConv(hidden_channels, hidden_channels, metadata,
                    num_heads, group='sum')
            for _ in range(num_layers)
        )

        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        """Project, run the conv stack, and return features for 'vertices'."""
        # In-place ReLU after the per-type projection (mutates x_dict values,
        # matching the original behavior).
        for ntype in list(x_dict):
            x_dict[ntype] = self.lin_dict[ntype](x_dict[ntype]).relu_()

        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
        return self.lin(x_dict['vertices'])
|
| 202 |
+
class HAN(torch.nn.Module):
    """HAN baseline over a single-node-type graph with two relations.

    Mirrors HGT but uses HANConv layers (meta-path attention) between the
    lazy input projection and the output head.
    """

    _NODE_TYPES = ['vertices']
    _EDGE_TYPES = [('vertices', 'inside', 'vertices'),
                   ('vertices', 'apart', 'vertices')]

    def __init__(self, hidden_channels, out_channels, num_heads, num_layers):
        super().__init__()

        # Lazy (-1) input projection per node type.
        self.lin_dict = torch.nn.ModuleDict(
            {ntype: Linear(-1, hidden_channels) for ntype in self._NODE_TYPES}
        )

        metadata = (self._NODE_TYPES, self._EDGE_TYPES)
        self.convs = torch.nn.ModuleList(
            HANConv(hidden_channels, hidden_channels, metadata, num_heads)
            for _ in range(num_layers)
        )

        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        """Project, run the conv stack, and return features for 'vertices'."""
        # In-place ReLU after the per-type projection (mutates x_dict values,
        # matching the original behavior).
        for ntype in list(x_dict):
            x_dict[ntype] = self.lin_dict[ntype](x_dict[ntype]).relu_()

        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
        return self.lin(x_dict['vertices'])
|
train_new.py
ADDED
|
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import copy
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import random
|
| 6 |
+
import time
|
| 7 |
+
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import numpy as np
|
| 10 |
+
import pandas as pd
|
| 11 |
+
import scipy
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
import torch.optim as optim
|
| 15 |
+
from matplotlib import cm
|
| 16 |
+
from sklearn.metrics import (auc, explained_variance_score, f1_score,
|
| 17 |
+
mean_absolute_error, mean_squared_error,
|
| 18 |
+
precision_score, r2_score, recall_score,
|
| 19 |
+
roc_auc_score, roc_curve)
|
| 20 |
+
from torch.nn.functional import softmax
|
| 21 |
+
from torch_geometric.utils import subgraph
|
| 22 |
+
|
| 23 |
+
torch.autograd.set_detect_anomaly(True)
|
| 24 |
+
import math
|
| 25 |
+
import pickle
|
| 26 |
+
import time
|
| 27 |
+
from datetime import date, datetime, timedelta
|
| 28 |
+
|
| 29 |
+
import torch.nn as nn
|
| 30 |
+
import torch_geometric
|
| 31 |
+
import torchvision.datasets
|
| 32 |
+
import torchvision.models
|
| 33 |
+
import torchvision.transforms as transforms
|
| 34 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 35 |
+
from torch_geometric.nn import GIN, MLP, GATConv
|
| 36 |
+
from torch_geometric.nn.pool import global_add_pool, global_mean_pool
|
| 37 |
+
from torch_geometric.utils import add_self_loops
|
| 38 |
+
|
| 39 |
+
import dataset
|
| 40 |
+
import model_new
|
| 41 |
+
import util
|
| 42 |
+
from dataset import label_mapping, reverse_label_mapping
|
| 43 |
+
from model_new import Smodel
|
| 44 |
+
|
| 45 |
+
def _ansi_color(code):
    """Return a function that wraps text in the given ANSI SGR color code."""
    prefix = '\033[' + code + 'm'
    reset = '\033[0m'

    def colorize(text):
        return prefix + text + reset

    return colorize


# Terminal color helpers (PEP 8: named functions instead of assigned lambdas).
blue = _ansi_color('94')        # bright blue foreground
red = _ansi_color('31')         # red foreground
green = _ansi_color('32')       # green foreground
yellow = _ansi_color('33')      # yellow foreground
greenline = _ansi_color('42')   # green background
yellowline = _ansi_color('43')  # yellow background
|
| 51 |
+
|
| 52 |
+
def get_args():
    """Parse command-line arguments for training and post-process flag strings.

    Several flags are passed as the strings "True"/"False" (or "T"/"F") so
    they can be set from shell scripts; log/edge_rep/batchnorm/manualSeed are
    converted to real booleans here.

    Returns:
        argparse.Namespace with an added `save_dir` attribute.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model',default="our", type=str)
    parser.add_argument('--train_batch', default=64, type=int)
    parser.add_argument('--test_batch', default=128, type=int)
    parser.add_argument('--share', type=str, default="0")
    parser.add_argument('--edge_rep', type=str, default="True")
    parser.add_argument('--batchnorm', type=str, default="True")
    parser.add_argument('--extent_norm', type=str, default="T")
    parser.add_argument('--spanning_tree', type=str, default="T")

    parser.add_argument('--loss_coef', default=0.1, type=float)
    parser.add_argument('--h_ch', default=512, type=int)
    parser.add_argument('--localdepth', type=int, default=1)
    parser.add_argument('--num_interactions', type=int, default=4)
    parser.add_argument('--finaldepth', type=int, default=4)
    parser.add_argument('--classifier_depth', type=int, default=4)
    parser.add_argument('--dropout', type=float, default=0.0)

    parser.add_argument('--dataset', type=str, default='mnist')
    parser.add_argument('--log', type=str, default="True")
    parser.add_argument('--test_per_round', type=int, default=10)
    parser.add_argument('--patience', type=int, default=30) #scheduler
    parser.add_argument('--nepoch', type=int, default=301)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--manualSeed', type=str, default="False")
    parser.add_argument('--man_seed', type=int, default=12345)
    args = parser.parse_args()
    args.log=True if args.log=="True" else False
    args.edge_rep=True if args.edge_rep=="True" else False
    # NOTE(review): batchnorm becomes a bool here, but Smodel/SPNN compare it
    # against the *string* "True" — so BatchNorm layers are never enabled via
    # this path. Confirm which representation is intended.
    args.batchnorm=True if args.batchnorm=="True" else False
    args.save_dir=os.path.join('./save/',args.dataset,args.model)
    args.manualSeed=True if args.manualSeed=="True" else False
    return args
|
| 86 |
+
|
| 87 |
+
# Module-level experiment setup: runs at import time (parses sys.argv).
args = get_args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
criterion=nn.CrossEntropyLoss()
# Pick output class count and dataset pickle path per dataset name.
# NOTE(review): if args.dataset matches none of these, x_out stays unbound
# and the model construction below raises NameError.
if args.dataset in ["mnist"]:
    x_out=90
    args.data_dir='data/multi_mnist_with_index.pkl'
elif args.dataset in ["mnist_sparse"]:
    x_out=90
    args.data_dir='data/multi_mnist_sparse.pkl'
elif args.dataset in ["building"]:
    x_out=100
    args.data_dir='data/building_with_index.pkl'
elif args.dataset in ["mbuilding"]:
    x_out=100
    args.data_dir='data/mp_building.pkl'
elif args.dataset in ["sbuilding"]:
    x_out=10
    args.data_dir='data/single_building.pkl'
elif args.dataset in ["smnist"]:
    x_out=10
    args.data_dir='data/single_mnist.pkl'
elif args.dataset in ["dbp"]:
    x_out=2
    args.data_dir='data/triple_building_600.pkl'


# Build backbone + classification head for the chosen architecture.
# NOTE(review): args.batchnorm is a bool at this point while Smodel compares
# against the string "True" — see get_args().
if args.model=="our":
    model=Smodel(h_channel=args.h_ch,input_featuresize=args.h_ch,\
    localdepth=args.localdepth,num_interactions=args.num_interactions,finaldepth=args.finaldepth,share=args.share,batchnorm=args.batchnorm)
    # "our" concatenates features from every interaction layer, hence h_ch * num_interactions.
    mlpmodel=MLP(in_channels=args.h_ch*args.num_interactions, hidden_channels=args.h_ch,out_channels=x_out, num_layers=args.classifier_depth,dropout=args.dropout)

elif args.model=="HGT":
    model=model_new.HGT(hidden_channels=args.h_ch, out_channels=args.h_ch, num_heads=2, num_layers=args.num_interactions)
    mlpmodel=MLP(in_channels=args.h_ch, hidden_channels=args.h_ch,out_channels=x_out, num_layers=args.classifier_depth,dropout=args.dropout)
elif args.model=="HAN":
    model=model_new.HAN(hidden_channels=args.h_ch, out_channels=args.h_ch, num_heads=2, num_layers=args.num_interactions)
    mlpmodel=MLP(in_channels=args.h_ch, hidden_channels=args.h_ch,out_channels=x_out, num_layers=args.classifier_depth,dropout=args.dropout)

model.to(device), mlpmodel.to(device)
# Optimize both networks jointly.
opt_list=list(model.parameters())+list(mlpmodel.parameters())

optimizer = torch.optim.Adam( opt_list, lr=args.lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=args.patience, min_lr=1e-8)
|
| 130 |
+
|
| 131 |
+
def contrastive_loss(embeddings,labels,margin):
    """Margin-based contrastive loss on pairwise distances.

    Positive pairs share a label, negative pairs do not. The majority class of
    pairs is subsampled to match the minority, then a hinge
    ``max(0, d_pos - d_neg + margin)`` is averaged over the paired-up samples.

    Args:
        embeddings: (N, D) float tensor of graph embeddings.
        labels: (N,) integer class labels.
        margin: desired separation between positive and negative distances.

    Returns:
        Scalar tensor on ``embeddings.device`` (zero if the batch has only
        positives or only negatives).
    """
    # (N, N) boolean "same label" matrix; its diagonal (self-pairs) is True.
    positive_mask = labels.view(-1, 1) == labels.view(1, -1)
    negative_mask = ~positive_mask

    # Count ordered pairs; subtract N to exclude the diagonal from positives.
    num_positive_pairs = positive_mask.sum() - labels.shape[0]
    num_negative_pairs = negative_mask.sum()

    # Degenerate batch (all-same or all-distinct labels): nothing to contrast.
    if num_negative_pairs==0 or num_positive_pairs== 0:
        print("all pos or neg")
        # BUGFIX: create the zero on the embeddings' device so callers can add
        # it to a GPU loss without a device-mismatch error.
        return torch.tensor(0.0, dtype=torch.float, device=embeddings.device)
    # Pairwise Euclidean distances, normalized by sqrt(embedding dim).
    distances = torch.cdist(embeddings, embeddings)/np.sqrt(embeddings.shape[1])

    if num_positive_pairs>num_negative_pairs:
        # Sample an equal number of + pairs.
        # BUGFIX: exclude the diagonal before sampling, otherwise zero-distance
        # self-pairs can be drawn as "positives" (the other branch already
        # excludes the diagonal from its positive selection).
        positive_mask.fill_diagonal_(False)
        positive_indices = torch.nonzero(positive_mask)
        random_positive_indices = torch.randperm(len(positive_indices))[:num_negative_pairs]
        selected_positive_indices = positive_indices[random_positive_indices]

        # Select corresponding negative pairs (diagonal is never negative, so
        # this fill is a no-op kept for symmetry).
        negative_mask.fill_diagonal_(False)
        negative_distances = distances[negative_mask].view(-1, 1)
        positive_distances = distances[selected_positive_indices[:,0],selected_positive_indices[:,1]].view(-1, 1)
    else: # case for most datasets
        # Sample an equal number of - pairs.
        negative_indices = torch.nonzero(negative_mask)
        random_negative_indices = torch.randperm(len(negative_indices))[:num_positive_pairs]
        selected_negative_indices = negative_indices[random_negative_indices]

        # Select corresponding positive pairs (drop self-pairs first).
        positive_mask.fill_diagonal_(False)
        positive_distances = distances[positive_mask].view(-1, 1)
        negative_distances = distances[selected_negative_indices[:,0],selected_negative_indices[:,1]].view(-1, 1)

    # Hinge: positives should be at least `margin` closer than negatives.
    loss = (positive_distances - negative_distances + margin).clamp(min=0).mean()
    return loss
|
| 171 |
+
|
| 172 |
+
def forward_HGT(data,model,mlpmodel):
    """Forward pass for the HGT/HAN baselines.

    Moves the batch to `device`, runs the heterogeneous GNN, sum-pools node
    features per graph, and combines cross-entropy with a weighted
    contrastive loss.

    Returns:
        (total loss, weighted contrastive loss, logits, labels)
    """
    data = data.to(device)
    batch = data['vertices'].batch
    x = data.pos  # kept for parity with forward(); not used below
    # Expose coordinates as the node features the hetero model consumes.
    data["vertices"]['x'] = data.pos
    label = data.y.long().view(-1)

    optimizer.zero_grad()

    node_out = model(data.x_dict, data.edge_index_dict)
    # Both dataset branches pool identically: sum node features per graph.
    graph_embeddings = global_add_pool(node_out, batch)
    graph_embeddings.clamp_(max=1e6)
    c_loss = contrastive_loss(graph_embeddings, label, margin=1)
    logits = mlpmodel(graph_embeddings)
    # log_probs = F.log_softmax(output, dim=1)

    weighted_closs = c_loss * args.loss_coef
    loss = criterion(logits, label) + weighted_closs
    return loss, weighted_closs, logits, label
|
| 193 |
+
|
| 194 |
+
def forward(data,model,mlpmodel):
    """Forward pass for the "our" model (Smodel).

    Builds the combined inside/apart edge set (optionally pruned to a random
    spanning tree over the apart edges), computes second-order triplets, runs
    the model, sum-pools per graph, and combines cross-entropy with a weighted
    contrastive loss.

    Returns:
        (total loss, weighted contrastive loss, logits, labels)
    """
    data = data.to(device)
    edge_index1=data['vertices', 'inside', 'vertices']['edge_index']
    edge_index2=data['vertices', 'apart', 'vertices']['edge_index']
    combined_edge_index=torch.cat([data['vertices', 'inside', 'vertices']['edge_index'],data['vertices', 'apart', 'vertices']['edge_index']],1)
    num_edge_inside=edge_index1.shape[1]

    if args.spanning_tree == 'T':
        # Random weights in [1, 2) so the spanning tree varies per call.
        # NOTE(review): edge_weight is created on CPU while edges may be on GPU
        # — confirm util.build_spanning_tree_edge handles the mixed devices.
        edge_weight=torch.rand(combined_edge_index.shape[1]) + 1
        undirected_spanning_edge = util.build_spanning_tree_edge(combined_edge_index, edge_weight,num_nodes=data.pos.shape[0])

        # Keep only the "apart" edges that survive in the spanning tree.
        edge_set_1 = set(map(tuple, edge_index2.t().tolist()))
        edge_set_2 = set(map(tuple, undirected_spanning_edge.t().tolist()))

        common_edges = edge_set_1.intersection(edge_set_2)
        # NOTE(review): if common_edges is empty, torch.tensor([]) has shape
        # (0,) and .t() will not yield a (2, 0) tensor — verify this edge case.
        common_edges_tensor = torch.tensor(list(common_edges), dtype=torch.long).t().to(device)
        # All inside edges + surviving apart edges.
        spanning_edge=torch.cat([edge_index1,common_edges_tensor],1)
        combined_edge_index=spanning_edge
    x,batch=data.pos, data['vertices'].batch
    label=data.y.long().view(-1)

    num_nodes=x.shape[0]
    # Second-order (i, j, k) triplets over the (possibly pruned) edge set.
    edge_index_2rd, num_triplets_real, edx_jk, edx_ij = util.triplets(combined_edge_index, num_nodes)
    optimizer.zero_grad()
    # Nodes start from zero features; geometry drives the representation.
    input_feature=torch.zeros([x.shape[0],args.h_ch],device=device)
    output=model(input_feature,x,[edge_index1,edge_index2], edge_index_2rd,edx_jk, edx_ij,batch,num_edge_inside,args.edge_rep)
    # Concatenate per-interaction-layer features along the channel dim.
    output=torch.cat(output,dim=1)
    # Both branches pool identically (kept as-is; likely a placeholder split).
    if args.dataset in ["dbp"]:
        graph_embeddings=global_add_pool(output,batch)
    else:
        graph_embeddings=global_add_pool(output,batch)
    graph_embeddings.clamp_(max=1e6)
    c_loss=contrastive_loss(graph_embeddings,label,margin=1)
    output=mlpmodel(graph_embeddings)

    loss = criterion(output, label)
    loss+=c_loss*args.loss_coef
    return loss,c_loss*args.loss_coef,output,label
|
| 232 |
+
def train(train_Loader,model,mlpmodel ):
    """Run one training epoch and return (mean loss, mean contrastive loss,
    predictions, ground-truth labels, raw logits).

    Uses module-level globals: `args`, `optimizer`, and the `forward` /
    `forward_HGT` functions selected by `args.model`.
    """
    total_loss, total_closs = 0, 0
    preds, trues, logits = [], [], []
    optimizer.zero_grad()
    model.train()
    mlpmodel.train()
    for step, batch_data in enumerate(train_Loader):
        # Dispatch to the forward pass matching the chosen architecture.
        if args.model == "our":
            loss, c_loss, output, label = forward(batch_data, model, mlpmodel)
        elif args.model in ["HGT", "HAN"]:
            loss, c_loss, output, label = forward_HGT(batch_data, model, mlpmodel)

        loss.backward()
        optimizer.step()
        total_loss += loss.detach().cpu()
        total_closs += c_loss.detach().cpu()

        # Top-1 prediction per sample.
        _, pred = output.topk(1, dim=1, largest=True, sorted=True)
        pred, label, output = pred.cpu(), label.cpu(), output.cpu()
        preds += list(pred.detach().numpy().reshape(-1))
        trues += list(label.detach().numpy().reshape(-1))
        logits += list(output.detach().numpy())
    return (total_loss.item() / len(train_Loader),
            total_closs.item() / len(train_Loader),
            preds, trues, logits)
|
| 256 |
+
|
| 257 |
+
def test(loader,model,mlpmodel ):
    """Evaluate on `loader` without gradients; return (mean per-sample loss,
    predictions, ground-truth labels, raw logits).

    Uses module-level globals `args` and `forward` / `forward_HGT`.
    """
    preds, trues, logits = [], [], []
    loss_sum, sample_count = 0, 0
    model.eval()
    mlpmodel.eval()
    with torch.no_grad():
        for batch_data in loader:
            if args.model == "our":
                loss, c_loss, output, label = forward(batch_data, model, mlpmodel)
            elif args.model in ["HGT", "HAN"]:
                loss, c_loss, output, label = forward_HGT(batch_data, model, mlpmodel)

            # Top-1 prediction per sample.
            _, pred = output.topk(1, dim=1, largest=True, sorted=True)
            pred, label, output = pred.cpu(), label.cpu(), output.cpu()
            preds += list(pred.detach().numpy().reshape(-1))
            trues += list(label.detach().numpy().reshape(-1))
            logits += list(output.detach().numpy())

            # Weight the batch loss by its sample count for a true mean.
            sample_count += len(label.reshape(-1, 1))
            loss_sum += loss.detach() * len(label.reshape(-1, 1))
    return loss_sum / sample_count, preds, trues, logits
|
| 278 |
+
def main(args,train_Loader,val_Loader,test_Loader):
    """Train for args.nepoch epochs, keep the model with the best validation
    accuracy, re-evaluate it, and persist logs, run info, and a checkpoint.

    Relies on module-level globals: model, mlpmodel, optimizer (via train),
    scheduler, Seed, the directories tensorboard_dir/info_dir/model_dir, and
    the color helpers blue/red.
    """
    best_val_trigger = -1
    # Sentinel above any realistic lr so the first lr change is always printed.
    old_lr=1e3
    # Timestamp suffix (month+day-H:M:S) shared by log dir, info file, checkpoint.
    suffix="{}{}-{}:{}:{}".format(datetime.now().strftime("%h"),
                                  datetime.now().strftime("%d"),
                                  datetime.now().strftime("%H"),
                                  datetime.now().strftime("%M"),
                                  datetime.now().strftime("%S"))
    # `writer` only exists when logging is enabled; the util.record calls
    # below are wrapped in try/except so they silently no-op otherwise.
    if args.log: writer = SummaryWriter(os.path.join(tensorboard_dir,suffix))

    for epoch in range(args.nepoch):
        train_loss,train_closs,y_hat, y_true,y_hat_logit=train(train_Loader,model,mlpmodel )

        train_acc=util.calculate(y_hat,y_true,y_hat_logit)
        # NOTE(review): bare except hides all errors, not just a missing writer.
        try:util.record({"loss":train_loss,"closs":train_closs,"acc":train_acc},epoch,writer,"Train")
        except: pass
        util.print_1(epoch,'Train',{"loss":train_loss,"closs":train_closs,"acc":train_acc})
        # Periodic validation/test evaluation.
        if epoch % args.test_per_round == 0:
            val_loss, yhat_val, ytrue_val, yhatlogit_val = test(val_Loader,model,mlpmodel )
            test_loss, yhat_test, ytrue_test, yhatlogit_test = test(test_Loader,model,mlpmodel )
            val_acc=util.calculate(yhat_val,ytrue_val,yhatlogit_val)
            try:util.record({"loss":val_loss,"acc":val_acc},epoch,writer,"Val")
            except: pass
            util.print_1(epoch,'Val',{"loss":val_loss,"acc":val_acc},color=blue)
            test_acc=util.calculate(yhat_test,ytrue_test,yhatlogit_test)
            try:util.record({"loss":test_loss,"acc":test_acc},epoch,writer,"Test")
            except: pass
            util.print_1(epoch,'Test',{"loss":test_loss,"acc":test_acc},color=blue)
            # Model selection is driven by validation accuracy.
            val_trigger=val_acc
            if val_trigger > best_val_trigger:
                best_val_trigger = val_trigger
                best_model = copy.deepcopy(model)
                best_mlpmodel=copy.deepcopy(mlpmodel)
                best_info=[epoch,val_trigger]
        """
        update lr when epoch≥30
        """
        # NOTE(review): scheduler.step uses the most recent val_trigger; if
        # test_per_round does not divide the epoch index this reuses a stale
        # value (and would NameError if no eval ran yet) — confirm intent.
        if epoch >= 30:
            lr = scheduler.optimizer.param_groups[0]['lr']
            if old_lr!=lr:
                print(red('lr'), epoch, (lr), sep=', ')
                old_lr=lr
            scheduler.step(val_trigger)
    """
    use best model to get best model result
    """
    val_loss, yhat_val, ytrue_val, yhat_logit_val = test(val_Loader,best_model,best_mlpmodel)
    test_loss, yhat_test, ytrue_test, yhat_logit_test= test(test_Loader,best_model,best_mlpmodel)

    val_acc=util.calculate(yhat_val,ytrue_val,yhat_logit_val)
    util.print_1(best_info[0],'BestVal',{"loss":val_loss,"acc":val_acc},color=blue)
    test_acc=util.calculate(yhat_test,ytrue_test,yhat_logit_test)
    util.print_1(best_info[0],'BestTest',{"loss":test_loss,"acc":test_acc},color=blue)
    # Attention weights are only defined on the custom model.
    if args.model=="our":print(best_model.att)

    """
    save training info and best result
    """
    # Dump seed, best metrics, and the full argument list to a text file.
    result_file=os.path.join(info_dir, suffix)
    with open(result_file, 'w') as f:
        print("Random Seed: ", Seed,file=f)
        print(f"acc val : {val_acc:.3f}, Test : {test_acc:.3f}", file=f)
        print(f"Best info: {best_info}", file=f)
        for i in [[a,getattr(args, a)] for a in args.__dict__]:
            print(i,sep='\n',file=f)
    # Checkpoint: weights of both models plus args and test predictions.
    to_save_dict={'model':best_model.state_dict(),'mlpmodel':best_mlpmodel.state_dict(),'args':args,'labels':ytrue_test,'yhat':yhat_test,'yhat_logit':yhat_logit_test}
    torch.save(to_save_dict, os.path.join(model_dir,suffix+'.pth') )
    print("done")
|
| 346 |
+
|
| 347 |
+
if __name__ == '__main__':
    """
    build dir
    """
    # Create output directories for tensorboard logs, checkpoints, run info.
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir,exist_ok=True)
    tensorboard_dir=os.path.join(args.save_dir,'log')
    if not os.path.exists(tensorboard_dir):
        os.makedirs(tensorboard_dir,exist_ok=True)
    model_dir=os.path.join(args.save_dir,'model')
    if not os.path.exists(model_dir):
        os.makedirs(model_dir,exist_ok=True)
    info_dir=os.path.join(args.save_dir,'info')
    if not os.path.exists(info_dir):
        os.makedirs(info_dir,exist_ok=True)

    # Fixed seed for the train/val/test split so the split is reproducible.
    Seed = 0
    test_ratio=0.2
    print("data splitting Random Seed: ", Seed)
    # Select the dataset loader by name.
    if args.dataset in ['mnist',"mnist_sparse"]:
        train_ds,val_ds,test_ds=dataset.get_mnist_dataset(args.data_dir,Seed,test_ratio=test_ratio)
    elif args.dataset in ['building']:
        train_ds,val_ds,test_ds=dataset.get_building_dataset(args.data_dir,Seed,test_ratio=test_ratio)
    elif args.dataset in ['mbuilding']:
        train_ds,val_ds,test_ds=dataset.get_mbuilding_dataset(args.data_dir,Seed,test_ratio=test_ratio)
    elif args.dataset in ['sbuilding']:
        train_ds,val_ds,test_ds=dataset.get_sbuilding_dataset(args.data_dir,Seed,test_ratio=test_ratio)
    elif args.dataset in ['smnist']:
        train_ds,val_ds,test_ds=dataset.get_smnist_dataset(args.data_dir,Seed,test_ratio=test_ratio)
    elif args.dataset in ['dbp']:
        train_ds,val_ds,test_ds=dataset.get_dbp_dataset(args.data_dir,Seed,test_ratio=test_ratio)
    # Optionally rescale coordinates of every split into [-1, 1].
    if args.extent_norm=="T":
        train_ds= dataset.affine_transform_to_range(train_ds,target_range=(-1, 1))
        val_ds= dataset.affine_transform_to_range(val_ds,target_range=(-1, 1))
        test_ds= dataset.affine_transform_to_range(test_ds,target_range=(-1, 1))

    # NOTE(review): shuffle=False on the *training* loader is unusual —
    # confirm deterministic batch order is intended.
    train_loader = torch_geometric.loader.DataLoader(train_ds,batch_size=args.train_batch, shuffle=False,pin_memory=True,drop_last=True)
    val_loader = torch_geometric.loader.DataLoader(val_ds, batch_size=args.test_batch, shuffle=False, pin_memory=True)
    test_loader = torch_geometric.loader.DataLoader(test_ds,batch_size=args.test_batch, shuffle=False,pin_memory=True)
    """
    set model seed
    """
    # NOTE(review): the hard-coded assignment below makes this manualSeed
    # logic dead code — the model seed is always 3407. Remove one of the two.
    Seed = args.man_seed if args.manualSeed else random.randint(1, 10000)
    Seed=3407
    print("Random Seed: ", Seed)
    print(args)
    # Seed every RNG source used downstream.
    random.seed(Seed)
    torch.manual_seed(Seed)
    np.random.seed(Seed)
    main(args,train_loader,val_loader,test_loader)
|
| 397 |
+
|
util.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from matplotlib import cm
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
import scipy
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import torchvision
|
| 10 |
+
|
| 11 |
+
from sklearn.metrics import explained_variance_score,mean_squared_error,mean_absolute_error,r2_score,precision_score,recall_score,f1_score,roc_auc_score,roc_curve, auc,confusion_matrix
|
| 12 |
+
from sklearn.feature_selection import r_regression
|
| 13 |
+
|
| 14 |
+
from torch_sparse import SparseTensor
|
| 15 |
+
from scipy.sparse import csr_matrix
|
| 16 |
+
from scipy.sparse.csgraph import minimum_spanning_tree
|
| 17 |
+
from math import pi as PI
|
| 18 |
+
|
| 19 |
+
def scipy_spanning_tree(edge_index, edge_weight,num_nodes ):
    """Compute a minimum spanning tree via SciPy.

    Args:
        edge_index: LongTensor [2, E] of (source, target) edge endpoints.
        edge_weight: Tensor [E] of edge weights.
        num_nodes: total node count of the graph.

    Returns:
        np.ndarray [2, T] of spanning-tree edges as emitted by SciPy.
    """
    src, dst = edge_index.cpu()
    weights = edge_weight.cpu()
    # Assemble a sparse adjacency matrix and let SciPy extract the MST.
    adjacency = csr_matrix((weights, (src, dst)), shape=(num_nodes, num_nodes))
    mst = minimum_spanning_tree(adjacency)
    mst_src, mst_dst = mst.nonzero()
    return np.stack([mst_src, mst_dst], 0)
|
| 27 |
+
|
| 28 |
+
def build_spanning_tree_edge(edge_index,edge_weight, num_nodes):
    """Return the MST of the graph as an undirected edge_index tensor.

    Each tree edge appears twice (forward and reversed) so the result can be
    used directly as a symmetric edge list.
    """
    tree_np = scipy_spanning_tree(edge_index, edge_weight, num_nodes)
    tree = torch.tensor(tree_np, dtype=torch.long, device=edge_index.device)
    reversed_tree = torch.stack([tree[1], tree[0]])
    # Concatenate both directions to make the tree undirected.
    return torch.cat([tree, reversed_tree], 1)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def record(values,epoch,writer,phase="Train"):
    """Write every metric in `values` to TensorBoard as `<metric>/<phase>`."""
    for metric_name in values:
        writer.add_scalar(metric_name + "/" + phase, values[metric_name], epoch)
|
| 42 |
+
def calculate(y_hat,y_true,y_hat_logit):
    """Return classification accuracy of y_hat against y_true.

    `y_hat_logit` is accepted for interface compatibility (other metrics that
    would use it are currently disabled) but does not affect the result.
    """
    predictions = np.asarray(y_hat)
    targets = np.asarray(y_true)
    correct = (predictions == targets).sum()
    return correct / len(y_true)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def print_1(epoch,phase,values,color=None):
    """Print a one-line epoch summary, optionally wrapped by a color function.

    Args:
        epoch: epoch index (int).
        phase: phase label, e.g. 'Train' / 'Val' / 'Test'.
        values: mapping of metric name -> numeric value (printed as %.3f).
        color: optional callable (e.g. a termcolor wrapper) applied to the
            whole line before printing.
    """
    # Build the message once instead of duplicating the format string in both
    # branches (the original two branches were byte-identical except for the
    # color wrapper).
    # NOTE(review): there is deliberately no space between `phase` and the
    # first metric ("Trainloss=...") to preserve the historical log format.
    msg = f"epoch[{epoch:d}] {phase}" + " ".join(
        f"{key}={value:.3f}" for key, value in values.items()
    )
    print(color(msg) if color is not None else msg)
|
| 61 |
+
|
| 62 |
+
def get_angle(v1, v2):
    """Return the unsigned angle (radians, in [0, pi]) between row vectors.

    2-D inputs are zero-padded to 3-D so that torch.cross applies.
    """
    if v1.shape[1] == 2:
        v1 = F.pad(v1, (0, 1), value=0)
    if v2.shape[1] == 2:
        v2 = F.pad(v2, (0, 1), value=0)
    # atan2(|v1 x v2|, v1 . v2) is numerically stabler than acos of the
    # normalized dot product.
    cross_norm = torch.cross(v1, v2, dim=1).norm(p=2, dim=1)
    dot = (v1 * v2).sum(dim=1)
    return torch.atan2(cross_norm, dot)
|
| 68 |
+
def get_theta(v1, v2):
    """Return the signed angle from v1 to v2 (right-hand rule: thumb up -> +).

    Magnitude comes from get_angle; the sign is the sign of the z-component
    of the cross product, with zero (collinear vectors) mapped to -1.
    """
    magnitude = get_angle(v1, v2)
    if v1.shape[1] == 2:
        v1 = F.pad(v1, (0, 1), value=0)
    if v2.shape[1] == 2:
        v2 = F.pad(v2, (0, 1), value=0)
    z_component = torch.cross(v1, v2, dim=1)[..., 2]
    orientation = torch.sign(z_component)
    # Collinear pairs (zero cross product) are treated as negative.
    orientation[orientation == 0] = -1
    return magnitude * orientation
|
| 79 |
+
|
| 80 |
+
def triplets(edge_index, num_nodes):
    """Enumerate 2-hop triplets (k -> j -> i) over a directed edge list.

    Args:
        edge_index: LongTensor [2, E]; row = source j, col = target i.
        num_nodes: number of nodes in the graph.

    Returns:
        A 4-tuple of:
            [3, T] stacked (idx_i, idx_j, idx_k) node indices per triplet,
            per-edge surviving-triplet counts (long tensor, length E),
            edx_1st: edge id of the first hop (k -> j) per triplet,
            edx_2nd: edge id of the second hop (j -> i) per triplet.
    """
    row, col = edge_index

    # Tag each edge with its index so edge ids survive the sparse selection.
    value = torch.arange(row.size(0), device=row.device)
    adj_t = SparseTensor(row=row, col=col, value=value,
                         sparse_sizes=(num_nodes, num_nodes))
    # For every edge (j, i), gather the edges whose target is j — these are
    # the candidate first hops k -> j of each triplet.
    adj_t_col = adj_t[:,row]
    # Candidate-triplet count per edge (before filtering below).
    num_triplets = adj_t_col.set_value(None).sum(dim=0).to(torch.long)

    # Expand each edge (j, i) once per candidate first hop.
    idx_j = row.repeat_interleave(num_triplets)
    idx_i = col.repeat_interleave(num_triplets)
    edx_2nd = value.repeat_interleave(num_triplets)
    # torch_sparse storage internals: column/value arrays of the transposed
    # selection give the source node k and edge id of each first hop.
    idx_k = adj_t_col.t().storage.col()
    edx_1st = adj_t_col.t().storage.value()
    # Drop degenerate triplets (comments per original author):
    mask1 = (idx_i == idx_k) & (idx_j != idx_i) # Remove go back triplets.
    mask2 = (idx_i == idx_j) & (idx_j != idx_k) # Remove repeat self loop triplets
    mask3 = (idx_j == idx_k) & (idx_i != idx_k) # Remove self-loop neighbors
    mask = ~(mask1 | mask2 | mask3)
    idx_i, idx_j, idx_k, edx_1st, edx_2nd = idx_i[mask], idx_j[mask], idx_k[mask], edx_1st[mask], edx_2nd[mask]

    # Correct the per-edge counts for the masked-out candidates.
    # NOTE(review): this expression looks like cumulative counts minus
    # cumulative removals, not a per-edge count — confirm downstream usage.
    num_triplets_real = torch.cumsum(num_triplets, dim=0) - torch.cumsum(~mask, dim=0)[torch.cumsum(num_triplets, dim=0)-1]

    return torch.stack([idx_i, idx_j, idx_k]), num_triplets_real.to(torch.long), edx_1st, edx_2nd
|