import dgl
import dgl.nn as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
import os
# Make root_gnn_base importable when running from the repository root.
repo_root = os.getcwd()
sys.path.append(repo_root)
import root_gnn_base.dataset as datasets
from root_gnn_base import utils
import gc
def Make_SLP(in_size, out_size, activation = nn.ReLU, dropout = 0):
layers = []
layers.append(nn.Linear(in_size, out_size))
layers.append(activation())
layers.append(nn.Dropout(dropout))
return layers
def Make_MLP(in_size, hid_size, out_size, n_layers, activation = nn.ReLU, dropout = 0):
layers = []
if n_layers > 1:
layers += Make_SLP(in_size, hid_size, activation, dropout)
for i in range(n_layers-2):
layers += Make_SLP(hid_size, hid_size, activation, dropout)
layers += Make_SLP(hid_size, out_size, activation, dropout)
else:
layers += Make_SLP(in_size, out_size, activation, dropout)
layers.append(nn.LayerNorm(out_size))
return nn.Sequential(*layers)
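# Illustrative sketch (not executed on import): Make_MLP(4, 8, 2, 3) builds
# Linear(4, 8) -> ReLU -> Dropout -> Linear(8, 8) -> ReLU -> Dropout ->
# Linear(8, 2) -> ReLU -> Dropout -> LayerNorm(2), i.e. n_layers Linear
# blocks followed by a LayerNorm on the output width:
#   mlp = Make_MLP(in_size=4, hid_size=8, out_size=2, n_layers=3)
#   y = mlp(torch.randn(5, 4))   # y.shape == torch.Size([5, 2])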
class MLP(nn.Module):
def __init__(self, in_size, hid_size, out_size, n_layers, activation = nn.ReLU, dropout = 0, **kwargs):
super().__init__()
print(f'Unused args while creating MLP: {kwargs}')
self.layers = Make_MLP(in_size, hid_size, hid_size, n_layers-1, activation, dropout)
self.linear = nn.Linear(hid_size, out_size)
def forward(self, x):
return self.linear(self.layers(x))
def broadcast_global_to_nodes(g, globals):
boundaries = g.batch_num_nodes()
return torch.repeat_interleave(globals, boundaries, dim=0)
def broadcast_global_to_edges(g, globals):
boundaries = g.batch_num_edges()
return torch.repeat_interleave(globals, boundaries, dim=0)
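# Sketch of the broadcast: for a batch of two graphs with 2 and 3 nodes, a
# per-graph global row is repeated once per node (or edge) so it can be
# concatenated with the node (or edge) features:
#   bg = dgl.batch([dgl.graph(([0], [1])), dgl.graph(([0, 1], [1, 2]))])
#   glob = torch.tensor([[1.0], [2.0]])          # one row per graph
#   broadcast_global_to_nodes(bg, glob)
#   # -> tensor([[1.], [1.], [2.], [2.], [2.]])  # one row per node
# copy_v mirrors dgl.function.copy_u: it copies destination-node features
# onto the incident edges.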
def copy_v(edges):
return {'m_v': edges.dst['h']}
def partial_reset(model : nn.Module):
in_size = model.classify.in_features
out_size = model.classify.out_features
device = next(model.classify.parameters()).device
torch.manual_seed(2)
model.classify = nn.Linear(in_size, out_size)
model.classify.to(device)
print(model.classify.weight)
def print_model(model: nn.Module):
print(model)
def print_mlp(layer):
for l in layer.children():
if isinstance(l, nn.Linear):
print(l.state_dict())
else:
print(l)
# Resets the whole GNN in place: every encoder/update/decoder MLP plus the classifier head
def full_reset(model : nn.Module):
mlp_list = [model.node_encoder, model.edge_encoder, model.global_encoder,
model.node_update, model.edge_update, model.global_update,
model.global_decoder]
for mlp in mlp_list:
for layer in mlp.children():
if hasattr(layer, 'reset_parameters'):
layer.reset_parameters()
partial_reset(model)
class GCN(nn.Module):
def __init__(self, in_size, hid_size, out_size, n_layers, **kwargs):
super().__init__()
print(f'Unused args while creating GCN: {kwargs}')
self.n_layers = n_layers
self.layers = nn.ModuleList()
# input projection, n_layers hidden Linear layers, n_layers GraphConv layers, n_layers output-side Linear layers
self.layers.extend(
[nn.Linear(in_size, hid_size),] +
[nn.Linear(hid_size, hid_size) for i in range(n_layers)] +
[dglnn.GraphConv(hid_size, hid_size) for i in range(n_layers)] +
[nn.Linear(hid_size, hid_size) for i in range(n_layers)]
)
self.classify = nn.Linear(hid_size, out_size)
#self.dropout = nn.Dropout(0.05)
def forward(self, g):
h = g.ndata['features']
for i, layer in enumerate(self.layers):
if i >= self.n_layers + 1 and i < self.n_layers * 2 + 1:
h = layer(g, h)
else:
h = layer(h)
h = F.relu(h)
with g.local_scope():
g.ndata['h'] = h
# Calculate graph representation by average readout.
hg = dgl.mean_nodes(g, 'h')
return self.classify(hg)
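# Usage sketch (assumes a batched DGLGraph whose nodes carry a 'features'
# tensor, as the root_gnn_base.dataset loaders appear to provide; values are
# illustrative):
#   model = GCN(in_size=16, hid_size=64, out_size=1, n_layers=2)
#   logits = model(batched_graph)   # one row per graph in the batch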
class GCN_global(nn.Module):
def __init__(self, in_size, hid_size=4, out_size=1, n_layers=1, dropout=0, **kwargs):
super().__init__()
print(f'Unused args while creating GCN_global: {kwargs}')
self.n_layers = n_layers
#encoder
self.node_encoder = Make_MLP(in_size, hid_size, hid_size, n_layers, dropout=dropout)
self.global_encoder = Make_MLP(1, hid_size, hid_size, n_layers, dropout=dropout)
#GCN
self.node_update = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.global_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.conv = dglnn.GraphConv(hid_size, hid_size)
#decoder
self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.classify = nn.Linear(hid_size, out_size)
def forward(self, g):
h = self.node_encoder(g.ndata['features'])
h_global = self.global_encoder(g.batch_num_nodes()[:, None].to(torch.float))
for i in range(self.n_layers):
h = self.node_update(h)
h = self.conv(g, h)
g.ndata['h'] = h
h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h')), dim = 1))
h_global = self.global_decoder(h_global)
return self.classify(h_global)
class GCN_global_2way(nn.Module):
def __init__(self, in_size, hid_size=4, out_size=1, n_layers=1, dropout=0, **kwargs):
super().__init__()
print(f'Unused args while creating GCN_global_2way: {kwargs}')
self.n_layers = n_layers
#encoder
self.node_encoder = Make_MLP(in_size, hid_size, hid_size, n_layers, dropout=dropout)
self.global_encoder = Make_MLP(1, hid_size, hid_size, n_layers, dropout=dropout)
#GCN
self.node_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.global_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.conv = dglnn.GraphConv(hid_size, hid_size)
#decoder
self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.classify = nn.Linear(hid_size, out_size)
def forward(self, g):
h = self.node_encoder(g.ndata['features'])
h_global = self.global_encoder(g.batch_num_nodes()[:, None].to(torch.float))
for i in range(self.n_layers):
h = self.node_update(torch.cat((h, broadcast_global_to_nodes(g, h_global)), dim = 1))
h = self.conv(g, h)
g.ndata['h'] = h
h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h')), dim = 1))
h_global = self.global_decoder(h_global)
return self.classify(h_global)
class Edge_Network(nn.Module):
def __init__(self, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
super().__init__()
print(f'Unused args while creating Edge_Network: {kwargs}')
self.n_layers = n_layers
self.n_proc_steps = n_proc_steps
self.hid_size = hid_size
self.layers = nn.ModuleList()
self.has_global = len(sample_global) != 0 and sample_global.shape[1] != 0
gl_size = sample_global.shape[1] if self.has_global else 1
#encoder
self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)
#GNN
self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
#decoder
self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.classify = nn.Linear(hid_size, out_size)
def forward(self, g, global_feats):
h = self.node_encoder(g.ndata['features'])
e = self.edge_encoder(g.edata['features'])
g.ndata['h'] = h
g.edata['e'] = e
if not self.has_global:
global_feats = g.batch_num_nodes()[:, None].to(torch.float)
batch_num_nodes = None
sum_weights = None
if "w" in g.ndata:
batch_indices = g.batch_num_nodes()
# Find non-zero rows (non-padded nodes)
non_padded_nodes_mask = torch.any(g.ndata['features'] != 0, dim=1)
# Split the mask according to the batch indices
batch_num_nodes = []
start_idx = 0
for num_nodes in batch_indices:
end_idx = start_idx + num_nodes
non_padded_count = non_padded_nodes_mask[start_idx:end_idx].sum().item()
batch_num_nodes.append(non_padded_count)
start_idx = end_idx
batch_num_nodes = torch.tensor(batch_num_nodes, device = g.ndata['features'].device)
sum_weights = batch_num_nodes[:, None].repeat(1, self.hid_size)
global_feats = batch_num_nodes[:, None].to(torch.float)
h_global = self.global_encoder(global_feats)
for i in range(self.n_proc_steps):
g.apply_edges(dgl.function.copy_u('h', 'm_u'))
g.apply_edges(copy_v)
g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
if "w" in g.ndata:
mean_nodes = dgl.sum_nodes(g, 'h', 'w') / sum_weights
h_global = self.global_update(torch.cat((h_global, mean_nodes, dgl.mean_edges(g, 'e')), dim = 1))
else:
h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
h_global = self.global_decoder(h_global)
return self.classify(h_global)
def representation(self, g, global_feats):
h = self.node_encoder(g.ndata['features'])
e = self.edge_encoder(g.edata['features'])
g.ndata['h'] = h
g.edata['e'] = e
if not self.has_global:
global_feats = g.batch_num_nodes()[:, None].to(torch.float)
batch_num_nodes = None
sum_weights = None
if "w" in g.ndata:
batch_indices = g.batch_num_nodes()
# Find non-zero rows (non-padded nodes)
non_padded_nodes_mask = torch.any(g.ndata['features'] != 0, dim=1)
# Split the mask according to the batch indices
batch_num_nodes = []
start_idx = 0
for num_nodes in batch_indices:
end_idx = start_idx + num_nodes
non_padded_count = non_padded_nodes_mask[start_idx:end_idx].sum().item()
batch_num_nodes.append(non_padded_count)
start_idx = end_idx
batch_num_nodes = torch.tensor(batch_num_nodes, device = g.ndata['features'].device)
sum_weights = batch_num_nodes[:, None].repeat(1, self.hid_size)
global_feats = batch_num_nodes[:, None].to(torch.float)
h_global = self.global_encoder(global_feats)
for i in range(self.n_proc_steps):
g.apply_edges(dgl.function.copy_u('h', 'm_u'))
g.apply_edges(copy_v)
g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
if "w" in g.ndata:
mean_nodes = dgl.sum_nodes(g, 'h', 'w') / sum_weights
h_global = self.global_update(torch.cat((h_global, mean_nodes, dgl.mean_edges(g, 'e')), dim = 1))
else:
h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
before_global_decoder = h_global
after_global_decoder = self.global_decoder(before_global_decoder)
after_classify = self.classify(after_global_decoder)
return before_global_decoder, after_global_decoder, after_classify
def __str__(self):
layer_names = ["node_encoder", "edge_encoder", "global_encoder",
"node_update", "edge_update", "global_update", "global_decoder"]
layers = [self.node_encoder, self.edge_encoder, self.global_encoder,
self.node_update, self.edge_update, self.global_update, self.global_decoder]
for i in range(len(layers)):
print(layer_names[i])
for layer in layers[i].children():
if isinstance(layer, nn.Linear):
print(layer.state_dict())
print("classify")
print(self.classify.weight)
return ""
class Transferred_Learning(nn.Module):
def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
super().__init__()
print(f'Unused args while creating Transferred_Learning: {kwargs}')
self.n_layers = n_layers
self.n_proc_steps = n_proc_steps
self.layers = nn.ModuleList()
self.has_global = len(sample_global) != 0 and sample_global.shape[1] != 0
gl_size = sample_global.shape[1] if self.has_global else 1
self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
checkpoint = torch.load(pretraining_path)
self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
pretrained_layers = list(self.pretrained_model.children())
pretrained_layers = pretrained_layers[:-1]
self.pretrained_model = nn.Sequential(*pretrained_layers)
# Freeze Weights
for param in self.pretrained_model.parameters():
param.requires_grad = False # Freeze all layers
self.global_decoder = Make_MLP(pretraining_model['args']['hid_size'], hid_size, hid_size, n_layers, dropout=dropout)
self.classify = nn.Linear(hid_size, out_size)
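# Note on the indices used by the TL_* helpers here and in the classes below:
# child 0 of the pretrained Edge_Network is its (empty) self.layers
# ModuleList, so node_encoder, edge_encoder, global_encoder, node_update,
# edge_update, global_update and global_decoder sit at children 1..7 of the
# nn.Sequential built above.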
def TL_node_encoder(self, x):
for layer in self.pretrained_model[1]:
x = layer(x)
return x
def TL_edge_encoder(self, x):
for layer in self.pretrained_model[2]:
x = layer(x)
return x
def TL_global_encoder(self, x):
for layer in self.pretrained_model[3]:
x = layer(x)
return x
def TL_node_update(self, x):
for layer in self.pretrained_model[4]:
x = layer(x)
return x
def TL_edge_update(self, x):
for layer in self.pretrained_model[5]:
x = layer(x)
return x
def TL_global_update(self, x):
for layer in self.pretrained_model[6]:
x = layer(x)
return x
def TL_global_decoder(self, x):
for layer in self.pretrained_model[7]:
x = layer(x)
return x
def forward(self, g, global_feats):
h = self.TL_node_encoder(g.ndata['features'])
e = self.TL_edge_encoder(g.edata['features'])
g.ndata['h'] = h
g.edata['e'] = e
if not self.has_global:
global_feats = g.batch_num_nodes()[:, None].to(torch.float)
h_global = self.TL_global_encoder(global_feats)
for i in range(self.n_proc_steps):
g.apply_edges(dgl.function.copy_u('h', 'm_u'))
g.apply_edges(copy_v)
g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
h_global = self.TL_global_decoder(h_global)
return self.classify(self.global_decoder(h_global))
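# Shape of the 'pretraining_model' config, as implied by the accesses above
# (a hedged sketch; the exact keys besides 'args' depend on
# utils.buildFromConfig, and the 'name' key here is assumed):
#   pretraining_model = {'name': 'Edge_Network',
#                        'args': {'hid_size': 64, 'out_size': 1,
#                                 'n_layers': 2, 'n_proc_steps': 3}}
#   tl = Transferred_Learning('checkpoint.pt', pretraining_model,
#                             sample_graph, sample_global, hid_size=64,
#                             out_size=1, n_layers=2, n_proc_steps=3)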
class Transferred_Learning_Graph(nn.Module):
def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, additional_proc_steps=1, dropout=0, **kwargs):
super().__init__()
print(f'Unused args while creating Transferred_Learning_Graph: {kwargs}')
self.n_layers = n_layers
self.n_proc_steps = n_proc_steps
self.layers = nn.ModuleList()
self.has_global = len(sample_global) != 0 and sample_global.shape[1] != 0
gl_size = sample_global.shape[1] if self.has_global else 1
self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
checkpoint = torch.load(pretraining_path)
self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
pretrained_layers = list(self.pretrained_model.children())
pretrained_layers = pretrained_layers[:-1]
self.pretrained_model = nn.Sequential(*pretrained_layers)
self.additional_proc_steps = additional_proc_steps
# Freeze Weights
for param in self.pretrained_model.parameters():
param.requires_grad = False # Freeze all layers
#GNN
self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
#decoder
self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.classify = nn.Linear(hid_size, out_size)
def TL_node_encoder(self, x):
for layer in self.pretrained_model[1]:
x = layer(x)
return x
def TL_edge_encoder(self, x):
for layer in self.pretrained_model[2]:
x = layer(x)
return x
def TL_global_encoder(self, x):
for layer in self.pretrained_model[3]:
x = layer(x)
return x
def TL_node_update(self, x):
for layer in self.pretrained_model[4]:
x = layer(x)
return x
def TL_edge_update(self, x):
for layer in self.pretrained_model[5]:
x = layer(x)
return x
def TL_global_update(self, x):
for layer in self.pretrained_model[6]:
x = layer(x)
return x
def forward(self, g, global_feats):
h = self.TL_node_encoder(g.ndata['features'])
e = self.TL_edge_encoder(g.edata['features'])
g.ndata['h'] = h
g.edata['e'] = e
if not self.has_global:
global_feats = g.batch_num_nodes()[:, None].to(torch.float)
h_global = self.TL_global_encoder(global_feats)
for i in range(self.n_proc_steps):
g.apply_edges(dgl.function.copy_u('h', 'm_u'))
g.apply_edges(copy_v)
g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
for j in range(self.additional_proc_steps):
g.apply_edges(dgl.function.copy_u('h', 'm_u'))
g.apply_edges(copy_v)
g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
h_global = self.global_decoder(h_global)
return self.classify(h_global)
class Transferred_Learning_Parallel(nn.Module):
def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
super().__init__()
print(f'Unused args while creating Transferred_Learning_Parallel: {kwargs}')
self.n_layers = n_layers
self.n_proc_steps = n_proc_steps
self.layers = nn.ModuleList()
self.has_global = sample_global.shape[1] != 0
gl_size = sample_global.shape[1] if self.has_global else 1
self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
checkpoint = torch.load(pretraining_path)
self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
pretrained_layers = list(self.pretrained_model.children())
pretrained_layers = pretrained_layers[:-1]
self.pretrained_model = nn.Sequential(*pretrained_layers)
# Freeze Weights
for param in self.pretrained_model.parameters():
param.requires_grad = False # Freeze all layers
#encoder
self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)
#GNN
self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
#decoder
self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.classify = nn.Linear(hid_size + pretraining_model['args']['hid_size'], out_size)
def TL_node_encoder(self, x):
for layer in self.pretrained_model[1]:
x = layer(x)
return x
def TL_edge_encoder(self, x):
for layer in self.pretrained_model[2]:
x = layer(x)
return x
def TL_global_encoder(self, x):
for layer in self.pretrained_model[3]:
x = layer(x)
return x
def TL_node_update(self, x):
for layer in self.pretrained_model[4]:
x = layer(x)
return x
def TL_edge_update(self, x):
for layer in self.pretrained_model[5]:
x = layer(x)
return x
def TL_global_update(self, x):
for layer in self.pretrained_model[6]:
x = layer(x)
return x
def TL_global_decoder(self, x):
for layer in self.pretrained_model[7]:
x = layer(x)
return x
def Pretrained_Output(self, g):
h = self.TL_node_encoder(g.ndata['features'])
e = self.TL_edge_encoder(g.edata['features'])
g.ndata['h'] = h
g.edata['e'] = e
if not self.has_global:
global_feats = g.batch_num_nodes()[:, None].to(torch.float)
h_global = self.TL_global_encoder(global_feats)
for i in range(self.n_proc_steps):
g.apply_edges(dgl.function.copy_u('h', 'm_u'))
g.apply_edges(copy_v)
g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
h_global = self.TL_global_decoder(h_global)
return h_global
def forward(self, g, global_feats):
pretrained_global = self.Pretrained_Output(g.clone())
h = self.node_encoder(g.ndata['features'])
e = self.edge_encoder(g.edata['features'])
g.ndata['h'] = h
g.edata['e'] = e
if not self.has_global:
global_feats = g.batch_num_nodes()[:, None].to(torch.float)
h_global = self.global_encoder(global_feats)
for i in range(self.n_proc_steps):
g.apply_edges(dgl.function.copy_u('h', 'm_u'))
g.apply_edges(copy_v)
g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
h_global = self.global_decoder(h_global)
return self.classify(torch.cat((pretrained_global, h_global), dim = 1))
class Transferred_Learning_Sequential(nn.Module):
def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
super().__init__()
print(f'Unused args while creating Transferred_Learning_Sequential: {kwargs}')
self.n_layers = n_layers
self.n_proc_steps = n_proc_steps
self.layers = nn.ModuleList()
self.has_global = sample_global.shape[1] != 0
gl_size = sample_global.shape[1] if self.has_global else 1
self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
checkpoint = torch.load(pretraining_path)
self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
pretrained_layers = list(self.pretrained_model.children())
pretrained_layers = pretrained_layers[:-1]
self.pretrained_model = nn.Sequential(*pretrained_layers)
# Freeze Weights
for param in self.pretrained_model.parameters():
param.requires_grad = False # Freeze all layers
#encoder
self.mlp = Make_MLP(pretraining_model['args']['hid_size'], hid_size, hid_size, n_layers, dropout=dropout)
self.classify = nn.Linear(hid_size, out_size)
def TL_node_encoder(self, x):
for layer in self.pretrained_model[1]:
x = layer(x)
return x
def TL_edge_encoder(self, x):
for layer in self.pretrained_model[2]:
x = layer(x)
return x
def TL_global_encoder(self, x):
for layer in self.pretrained_model[3]:
x = layer(x)
return x
def TL_node_update(self, x):
for layer in self.pretrained_model[4]:
x = layer(x)
return x
def TL_edge_update(self, x):
for layer in self.pretrained_model[5]:
x = layer(x)
return x
def TL_global_update(self, x):
for layer in self.pretrained_model[6]:
x = layer(x)
return x
def TL_global_decoder(self, x):
for layer in self.pretrained_model[7]:
x = layer(x)
return x
def Pretrained_Output(self, g):
h = self.TL_node_encoder(g.ndata['features'])
e = self.TL_edge_encoder(g.edata['features'])
g.ndata['h'] = h
g.edata['e'] = e
if not self.has_global:
global_feats = g.batch_num_nodes()[:, None].to(torch.float)
h_global = self.TL_global_encoder(global_feats)
for i in range(self.n_proc_steps):
g.apply_edges(dgl.function.copy_u('h', 'm_u'))
g.apply_edges(copy_v)
g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
h_global = self.TL_global_decoder(h_global)
return h_global
def forward(self, g, global_feats):
pretrained_global = self.Pretrained_Output(g.clone())
global_features = self.mlp(pretrained_global)
return self.classify(global_features)
class Transferred_Learning_Message_Passing(nn.Module):
def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
super().__init__()
print(f'Unused args while creating Transferred_Learning_Message_Passing: {kwargs}')
self.n_layers = n_layers
self.n_proc_steps = n_proc_steps
self.layers = nn.ModuleList()
self.has_global = sample_global.shape[1] != 0
gl_size = sample_global.shape[1] if self.has_global else 1
self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
checkpoint = torch.load(pretraining_path)
self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
pretrained_layers = list(self.pretrained_model.children())
pretrained_layers = pretrained_layers[:-1]
self.pretrained_model = nn.Sequential(*pretrained_layers)
# Freeze Weights
for param in self.pretrained_model.parameters():
param.requires_grad = False # Freeze all layers
#encoder
self.mlp = Make_MLP(pretraining_model['args']['hid_size']*pretraining_model['args']['n_proc_steps'], hid_size, hid_size, n_layers, dropout=dropout)
self.classify = nn.Linear(hid_size, out_size)
def TL_node_encoder(self, x):
for layer in self.pretrained_model[1]:
x = layer(x)
return x
def TL_edge_encoder(self, x):
for layer in self.pretrained_model[2]:
x = layer(x)
return x
def TL_global_encoder(self, x):
for layer in self.pretrained_model[3]:
x = layer(x)
return x
def TL_node_update(self, x):
for layer in self.pretrained_model[4]:
x = layer(x)
return x
def TL_edge_update(self, x):
for layer in self.pretrained_model[5]:
x = layer(x)
return x
def TL_global_update(self, x):
for layer in self.pretrained_model[6]:
x = layer(x)
return x
def TL_global_decoder(self, x):
for layer in self.pretrained_model[7]:
x = layer(x)
return x
def Pretrained_Output(self, g):
message_passing = None
h = self.TL_node_encoder(g.ndata['features'])
e = self.TL_edge_encoder(g.edata['features'])
g.ndata['h'] = h
g.edata['e'] = e
if not self.has_global:
global_feats = g.batch_num_nodes()[:, None].to(torch.float)
h_global = self.TL_global_encoder(global_feats)
for i in range(self.n_proc_steps):
g.apply_edges(dgl.function.copy_u('h', 'm_u'))
g.apply_edges(copy_v)
g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
if (message_passing is None):
message_passing = h_global.clone()
else:
message_passing = torch.cat((message_passing, h_global.clone()), dim=1)
h_global = self.TL_global_decoder(h_global)
return message_passing
def forward(self, g, global_feats):
pretrained_global = self.Pretrained_Output(g.clone())
#print(f"message_passing layers have size = {pretrained_global.shape}")
#print(pretrained_global)
global_features = self.mlp(pretrained_global)
return self.classify(global_features)
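# Note: Pretrained_Output concatenates the global state after every
# processing step, so its output width is the pretrained model's
# hid_size * n_proc_steps; that product is exactly the input size given to
# self.mlp in __init__ above.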
class Transferred_Learning_Message_Passing_Parallel(nn.Module):
def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
super().__init__()
print(f'Unused args while creating Transferred_Learning_Message_Passing_Parallel: {kwargs}')
self.n_layers = n_layers
self.n_proc_steps = n_proc_steps
self.layers = nn.ModuleList()
self.has_global = sample_global.shape[1] != 0
gl_size = sample_global.shape[1] if self.has_global else 1
self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
checkpoint = torch.load(pretraining_path)
self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
pretrained_layers = list(self.pretrained_model.children())
pretrained_layers = pretrained_layers[:-1]
self.pretrained_model = nn.Sequential(*pretrained_layers)
#encoder
self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)
#GNN
self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
#decoder
self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
# Freeze Weights
for param in self.pretrained_model.parameters():
param.requires_grad = False # Freeze all layers
self.classify = nn.Linear(pretraining_model['args']['hid_size']*pretraining_model['args']['n_proc_steps'] + hid_size, out_size)
def TL_node_encoder(self, x):
for layer in self.pretrained_model[1]:
x = layer(x)
return x
def TL_edge_encoder(self, x):
for layer in self.pretrained_model[2]:
x = layer(x)
return x
def TL_global_encoder(self, x):
for layer in self.pretrained_model[3]:
x = layer(x)
return x
def TL_node_update(self, x):
for layer in self.pretrained_model[4]:
x = layer(x)
return x
def TL_edge_update(self, x):
for layer in self.pretrained_model[5]:
x = layer(x)
return x
def TL_global_update(self, x):
for layer in self.pretrained_model[6]:
x = layer(x)
return x
def TL_global_decoder(self, x):
for layer in self.pretrained_model[7]:
x = layer(x)
return x
def Pretrained_Output(self, g):
message_passing = None
h = self.TL_node_encoder(g.ndata['features'])
e = self.TL_edge_encoder(g.edata['features'])
g.ndata['h'] = h
g.edata['e'] = e
if not self.has_global:
global_feats = g.batch_num_nodes()[:, None].to(torch.float)
h_global = self.TL_global_encoder(global_feats)
for i in range(self.n_proc_steps):
g.apply_edges(dgl.function.copy_u('h', 'm_u'))
g.apply_edges(copy_v)
g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
if (message_passing is None):
message_passing = h_global.clone()
else:
message_passing = torch.cat((message_passing, h_global.clone()), dim=1)
h_global = self.TL_global_decoder(h_global)
return message_passing
def forward(self, g, global_feats):
pretrained_message = self.Pretrained_Output(g.clone())
h = self.node_encoder(g.ndata['features'])
e = self.edge_encoder(g.edata['features'])
g.ndata['h'] = h
g.edata['e'] = e
if not self.has_global:
global_feats = g.batch_num_nodes()[:, None].to(torch.float)
h_global = self.global_encoder(global_feats)
for i in range(self.n_proc_steps):
g.apply_edges(dgl.function.copy_u('h', 'm_u'))
g.apply_edges(copy_v)
g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
h_global = self.global_decoder(h_global)
return self.classify(torch.cat((pretrained_message, h_global), dim = 1))
class Transferred_Learning_Finetuning(nn.Module):
def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, frozen_pretraining=False, **kwargs):
super().__init__()
print(f'Unused args while creating Transferred_Learning_Finetuning: {kwargs}')
self.n_layers = n_layers
self.n_proc_steps = n_proc_steps
self.layers = nn.ModuleList()
self.has_global = len(sample_global) != 0 and sample_global.shape[1] != 0
gl_size = sample_global.shape[1] if self.has_global else 1
self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
checkpoint = torch.load(pretraining_path)
self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
pretrained_layers = list(self.pretrained_model.children())
pretrained_layers = pretrained_layers[:-1]
self.pretrained_model = nn.Sequential(*pretrained_layers)
print(f"Freeze Pretraining = {frozen_pretraining}")
if frozen_pretraining:
for param in self.pretrained_model.parameters():
param.requires_grad = False # Freeze all layers...
for param in self.pretrained_model[7].parameters():
param.requires_grad = True # ...except the global decoder, which stays trainable
torch.manual_seed(2)
self.classify = nn.Linear(pretraining_model['args']['hid_size'], out_size)
def TL_node_encoder(self, x):
for layer in self.pretrained_model[1]:
x = layer(x)
return x
def TL_edge_encoder(self, x):
for layer in self.pretrained_model[2]:
x = layer(x)
return x
def TL_global_encoder(self, x):
for layer in self.pretrained_model[3]:
x = layer(x)
return x
def TL_node_update(self, x):
for layer in self.pretrained_model[4]:
x = layer(x)
return x
def TL_edge_update(self, x):
for layer in self.pretrained_model[5]:
x = layer(x)
return x
def TL_global_update(self, x):
for layer in self.pretrained_model[6]:
x = layer(x)
return x
def TL_global_decoder(self, x):
for layer in self.pretrained_model[7]:
x = layer(x)
return x
def Pretrained_Output(self, g):
h = self.TL_node_encoder(g.ndata['features'])
e = self.TL_edge_encoder(g.edata['features'])
g.ndata['h'] = h
g.edata['e'] = e
if not self.has_global:
global_feats = g.batch_num_nodes()[:, None].to(torch.float)
h_global = self.TL_global_encoder(global_feats)
for i in range(self.n_proc_steps):
g.apply_edges(dgl.function.copy_u('h', 'm_u'))
g.apply_edges(copy_v)
g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
h_global = self.TL_global_decoder(h_global)
return h_global
def forward(self, g, global_feats):
h_global = self.Pretrained_Output(g.clone())
return self.classify(h_global)
def representation(self, g, global_feats):
h = self.TL_node_encoder(g.ndata['features'])
e = self.TL_edge_encoder(g.edata['features'])
g.ndata['h'] = h
g.edata['e'] = e
if not self.has_global:
global_feats = g.batch_num_nodes()[:, None].to(torch.float)
h_global = self.TL_global_encoder(global_feats)
for i in range(self.n_proc_steps):
g.apply_edges(dgl.function.copy_u('h', 'm_u'))
g.apply_edges(copy_v)
g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
before_global_decoder = h_global
after_global_decoder = self.TL_global_decoder(before_global_decoder)
after_classify = self.classify(after_global_decoder)
return before_global_decoder, after_global_decoder, after_classify
def __str__(self):
layer_names = ["node_encoder", "edge_encoder", "global_encoder",
"node_update", "edge_update", "global_update", "global_decoder"]
layers = [self.pretrained_model[1], self.pretrained_model[2], self.pretrained_model[3],
self.pretrained_model[4], self.pretrained_model[5], self.pretrained_model[6],
self.pretrained_model[7]]
for i in range(len(layers)):
print(layer_names[i])
for layer in layers[i].children():
if isinstance(layer, nn.Linear):
print(layer.state_dict())
print("classify")
print(self.classify.weight)
return ""
class Transferred_Learning_Parallel_Finetuning(nn.Module):
def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, learning_rate=0.0001, **kwargs):
super().__init__()
print(f'Unused args while creating Transferred_Learning_Parallel_Finetuning: {kwargs}')
self.learning_rate = learning_rate
self.parallel_params = []
self.finetuning_params = []
self.n_layers = n_layers
self.n_proc_steps = n_proc_steps
self.layers = nn.ModuleList()
self.has_global = sample_global.shape[1] != 0
gl_size = sample_global.shape[1] if self.has_global else 1
self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
checkpoint = torch.load(pretraining_path)
self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
pretrained_layers = list(self.pretrained_model.children())
pretrained_layers = pretrained_layers[:-1]
self.pretrained_model = nn.Sequential(*pretrained_layers)
self.finetuning_params.append(self.pretrained_model)
#encoder
self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)
#GNN
self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
#decoder
self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.classify = nn.Linear(hid_size + pretraining_model['args']['hid_size'], out_size)
self.parallel_params.append(self.node_encoder)
self.parallel_params.append(self.edge_encoder)
self.parallel_params.append(self.global_encoder)
self.parallel_params.append(self.node_update)
self.parallel_params.append(self.edge_update)
self.parallel_params.append(self.global_update)
self.parallel_params.append(self.global_decoder)
self.parallel_params.append(self.classify)
def TL_node_encoder(self, x):
for layer in self.pretrained_model[1]:
x = layer(x)
return x
def TL_edge_encoder(self, x):
for layer in self.pretrained_model[2]:
x = layer(x)
return x
def TL_global_encoder(self, x):
for layer in self.pretrained_model[3]:
x = layer(x)
return x
def TL_node_update(self, x):
for layer in self.pretrained_model[4]:
x = layer(x)
return x
def TL_edge_update(self, x):
for layer in self.pretrained_model[5]:
x = layer(x)
return x
def TL_global_update(self, x):
for layer in self.pretrained_model[6]:
x = layer(x)
return x
def TL_global_decoder(self, x):
for layer in self.pretrained_model[7]:
x = layer(x)
return x
def Pretrained_Output(self, g):
h = self.TL_node_encoder(g.ndata['features'])
e = self.TL_edge_encoder(g.edata['features'])
g.ndata['h'] = h
g.edata['e'] = e
if not self.has_global:
global_feats = g.batch_num_nodes()[:, None].to(torch.float)
h_global = self.TL_global_encoder(global_feats)
for i in range(self.n_proc_steps):
g.apply_edges(dgl.function.copy_u('h', 'm_u'))
g.apply_edges(copy_v)
g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
h_global = self.TL_global_decoder(h_global)
return h_global
def forward(self, g, global_feats):
pretrained_global = self.Pretrained_Output(g.clone())
h = self.node_encoder(g.ndata['features'])
e = self.edge_encoder(g.edata['features'])
g.ndata['h'] = h
g.edata['e'] = e
if not self.has_global:
global_feats = g.batch_num_nodes()[:, None].to(torch.float)
h_global = self.global_encoder(global_feats)
for i in range(self.n_proc_steps):
g.apply_edges(dgl.function.copy_u('h', 'm_u'))
g.apply_edges(copy_v)
g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
h_global = self.global_decoder(h_global)
return self.classify(torch.cat((pretrained_global, h_global), dim = 1))
def parameters(self, recurse: bool = True):
params = []
for model_section in self.parallel_params:
if isinstance(self.learning_rate, dict) and self.learning_rate.get("trainable_lr"):
params.append({'params': model_section.parameters(), 'lr': self.learning_rate["trainable_lr"]})
else:
params.append({'params': model_section.parameters(), 'lr': 0.0001})
for model_section in self.finetuning_params:
if isinstance(self.learning_rate, dict) and self.learning_rate.get("finetuning_lr"):
params.append({'params': model_section.parameters(), 'lr': self.learning_rate["finetuning_lr"]})
else:
params.append({'params': model_section.parameters(), 'lr': 0.0001})
return params
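# Usage sketch: the parameter groups returned above plug directly into a
# torch optimizer ('trainable_lr' / 'finetuning_lr' are the keys this class
# reads; values here are illustrative):
#   model = Transferred_Learning_Parallel_Finetuning(
#       ..., learning_rate={'trainable_lr': 1e-3, 'finetuning_lr': 1e-5})
#   optimizer = torch.optim.Adam(model.parameters())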
class Attention(nn.Module):
def __init__(self, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, num_heads = 1, **kwargs):
super().__init__()
print(f'Unused args while creating Attention: {kwargs}')
self.n_layers = n_layers
self.n_proc_steps = n_proc_steps
self.layers = nn.ModuleList()
self.has_global = sample_global.shape[1] != 0
self.hid_size = hid_size
gl_size = sample_global.shape[1] if self.has_global else 1
#encoder
self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)
#GNN
self.node_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.global_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
#decoder
self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.classify = nn.Linear(hid_size, out_size)
#attention
self.multihead_attn = nn.MultiheadAttention(hid_size, num_heads, dropout=dropout, batch_first=True)
self.queries = nn.Linear(hid_size, hid_size)
self.keys = nn.Linear(hid_size, hid_size)
self.values = nn.Linear(hid_size, hid_size)
def forward(self, g, global_feats):
h = self.node_encoder(g.ndata['features'])
g.ndata['h'] = h
if not self.has_global:
global_feats = g.batch_num_nodes()[:, None].to(torch.float)
batch_num_nodes = None
sum_weights = None
if "w" in g.ndata:
batch_indices = g.batch_num_nodes()
# Find non-zero rows (non-padded nodes)
non_padded_nodes_mask = torch.any(g.ndata['features'] != 0, dim=1)
# Split the mask according to the batch indices
batch_num_nodes = []
start_idx = 0
for num_nodes in batch_indices:
end_idx = start_idx + num_nodes
non_padded_count = non_padded_nodes_mask[start_idx:end_idx].sum().item()
batch_num_nodes.append(non_padded_count)
start_idx = end_idx
batch_num_nodes = torch.tensor(batch_num_nodes, device = g.ndata['features'].device)
sum_weights = batch_num_nodes[:, None].repeat(1, self.hid_size)
global_feats = batch_num_nodes[:, None].to(torch.float)
h_global = self.global_encoder(global_feats)
h_original_shape = h.shape
num_graphs = g.batch_size
num_nodes = g.batch_num_nodes()[0].item() # assumes all graphs are padded to the same node count
padding_mask = g.ndata['padding_mask'] > 0
padding_mask = torch.reshape(padding_mask, (num_graphs, num_nodes))
h = g.ndata['h']
query = self.queries(h)
key = self.keys(h)
value = self.values(h)
query = torch.reshape(query, (num_graphs, num_nodes, h_original_shape[1]))
key = torch.reshape(key, (num_graphs, num_nodes, h_original_shape[1]))
value = torch.reshape(value, (num_graphs, num_nodes, h_original_shape[1]))
h, _ = self.multihead_attn(query, key, value, key_padding_mask=padding_mask)
h = torch.reshape(h, h_original_shape)
h = self.node_update(torch.cat((h, broadcast_global_to_nodes(g, h_global)), dim = 1))
g.ndata['h'] = h
mean_nodes = dgl.sum_nodes(g, 'h', 'w') / sum_weights
h_global = self.global_update(torch.cat((h_global, mean_nodes), dim = 1))
h_global = self.global_decoder(h_global)
return self.classify(h_global)
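# Notes on the attention classes here and below: nn.MultiheadAttention drops
# key positions where key_padding_mask is True, so g.ndata['padding_mask'] is
# assumed to be nonzero/True exactly on padded (dummy) nodes; the reshape to
# (num_graphs, num_nodes, hid_size) assumes every graph in the batch is
# padded to the same node count; and the weighted readout requires a
# per-node weight field g.ndata['w'].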
class Attention_Edge_Network(nn.Module):
def __init__(self, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, num_heads = 1, **kwargs):
super().__init__()
print(f'Unused args while creating Attention_Edge_Network: {kwargs}')
self.n_layers = n_layers
self.n_proc_steps = n_proc_steps
self.layers = nn.ModuleList()
self.has_global = sample_global.shape[1] != 0
gl_size = sample_global.shape[1] if self.has_global else 1
#encoder
self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)
#GNN
self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
#decoder
self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.classify = nn.Linear(hid_size, out_size)
#attention
self.multihead_attn = nn.MultiheadAttention(hid_size, num_heads, dropout=dropout, batch_first=True)
self.queries = nn.Linear(hid_size, hid_size)
self.keys = nn.Linear(hid_size, hid_size)
self.values = nn.Linear(hid_size, hid_size)
def forward(self, g, global_feats):
h = self.node_encoder(g.ndata['features'])
e = self.edge_encoder(g.edata['features'])
g.ndata['h'] = h
g.edata['e'] = e
if not self.has_global:
global_feats = g.batch_num_nodes()[:, None].to(torch.float)
h_global = self.global_encoder(global_feats)
h = g.ndata['h']
h_original_shape = h.shape
num_graphs = g.batch_size
num_nodes = g.batch_num_nodes()[0].item() # assumes all graphs are padded to the same node count
padding_mask = g.ndata['padding_mask'].bool()
padding_mask = torch.reshape(padding_mask, (num_graphs, num_nodes))
for i in range(self.n_proc_steps):
h = g.ndata['h']
query = self.queries(h)
key = self.keys(h)
value = self.values(h)
query = torch.reshape(query, (num_graphs, num_nodes, h_original_shape[1]))
key = torch.reshape(key, (num_graphs, num_nodes, h_original_shape[1]))
value = torch.reshape(value, (num_graphs, num_nodes, h_original_shape[1]))
h, _ = self.multihead_attn(query, key, value, key_padding_mask=padding_mask)
h = torch.reshape(h, h_original_shape)
g.apply_edges(dgl.function.copy_u('h', 'm_u'))
g.apply_edges(copy_v)
g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h', 'w'), dgl.mean_edges(g, 'e')), dim = 1))
h_global = self.global_decoder(h_global)
return self.classify(h_global)
class Attention_Unbatched(nn.Module):
def __init__(self, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, num_heads = 1, **kwargs):
super().__init__()
print(f'Unused args while creating Attention_Unbatched: {kwargs}')
self.n_layers = n_layers
self.n_proc_steps = n_proc_steps
self.layers = nn.ModuleList()
self.has_global = sample_global.shape[1] != 0
gl_size = sample_global.shape[1] if self.has_global else 1
#encoder
self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)
#GNN
self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
#decoder
self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.classify = nn.Linear(hid_size, out_size)
#attention
self.multihead_attn = nn.MultiheadAttention(hid_size, num_heads, dropout=dropout)
self.queries = nn.Linear(hid_size, hid_size)
self.keys = nn.Linear(hid_size, hid_size)
self.values = nn.Linear(hid_size, hid_size)
def forward(self, g, global_feats):
h = self.node_encoder(g.ndata['features'])
e = self.edge_encoder(g.edata['features'])
g.ndata['h'] = h
g.edata['e'] = e
if not self.has_global:
global_feats = g.batch_num_nodes()[:, None].to(torch.float)
h_global = self.global_encoder(global_feats)
for i in range(self.n_proc_steps):
unbatched_g = dgl.unbatch(g)
for graph in unbatched_g:
h = graph.ndata['h']
h, _ = self.multihead_attn(self.queries(h), self.keys(h), self.values(h))
graph.ndata['h'] = h
g = dgl.batch(unbatched_g)
g.apply_edges(dgl.function.copy_u('h', 'm_u'))
g.apply_edges(copy_v)
g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
h_global = self.global_decoder(h_global)
return self.classify(h_global)
class Transferred_Learning_Attention(nn.Module):
def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, num_heads, dropout=0, learning_rate=0.0001, **kwargs):
super().__init__()
print(f'Unused args while creating Transferred_Learning_Attention: {kwargs}')
self.n_layers = n_layers
self.n_proc_steps = n_proc_steps
self.layers = nn.ModuleList()
self.has_global = sample_global.shape[1] != 0
self.hid_size = hid_size
gl_size = sample_global.shape[1] if self.has_global else 1
self.learning_rate = learning_rate
self.pretraining_params = []
self.attention_params = []
self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
checkpoint = torch.load(pretraining_path)
self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
pretrained_layers = list(self.pretrained_model.children())
pretrained_layers = pretrained_layers[:-1]
self.pretrained_model = nn.Sequential(*pretrained_layers)
self.pretraining_params.append(self.pretrained_model[1])
self.pretraining_params.append(self.pretrained_model[3])
self.pretraining_params.append(self.pretrained_model[7])
#attention
self.multihead_attn = nn.MultiheadAttention(hid_size, num_heads, dropout=dropout, batch_first=True)
self.queries = nn.Linear(hid_size, hid_size)
self.keys = nn.Linear(hid_size, hid_size)
self.values = nn.Linear(hid_size, hid_size)
self.node_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.global_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.classify = nn.Linear(pretraining_model['args']['hid_size'], out_size)
self.attention_params.append(self.multihead_attn)
self.attention_params.append(self.queries)
self.attention_params.append(self.keys)
self.attention_params.append(self.values)
self.attention_params.append(self.classify)
self.attention_params.append(self.node_update)
self.attention_params.append(self.global_update)
def TL_node_encoder(self, x):
for layer in self.pretrained_model[1]:
x = layer(x)
return x
def TL_global_encoder(self, x):
for layer in self.pretrained_model[3]:
x = layer(x)
return x
def TL_global_decoder(self, x):
for layer in self.pretrained_model[7]:
x = layer(x)
return x
def forward(self, g, global_feats):
h = self.TL_node_encoder(g.ndata['features'])
g.ndata['h'] = h
if not self.has_global:
global_feats = g.batch_num_nodes()[:, None].to(torch.float)
batch_num_nodes = None
sum_weights = None
if "w" in g.ndata:
batch_indices = g.batch_num_nodes()
# Find non-zero rows (non-padded nodes)
non_padded_nodes_mask = torch.any(g.ndata['features'] != 0, dim=1)
# Split the mask according to the batch indices
batch_num_nodes = []
start_idx = 0
for num_nodes in batch_indices:
end_idx = start_idx + num_nodes
non_padded_count = non_padded_nodes_mask[start_idx:end_idx].sum().item()
batch_num_nodes.append(non_padded_count)
start_idx = end_idx
batch_num_nodes = torch.tensor(batch_num_nodes, device = g.ndata['features'].device)
sum_weights = batch_num_nodes[:, None].repeat(1, self.hid_size)
global_feats = batch_num_nodes[:, None].to(torch.float)
h_global = self.TL_global_encoder(global_feats)
h_original_shape = h.shape
num_graphs = g.batch_size
num_nodes = g.batch_num_nodes()[0].item() # assumes all graphs are padded to the same node count
padding_mask = g.ndata['padding_mask'] > 0
padding_mask = torch.reshape(padding_mask, (num_graphs, num_nodes))
h = g.ndata['h']
query = self.queries(h)
key = self.keys(h)
value = self.values(h)
query = torch.reshape(query, (num_graphs, num_nodes, h_original_shape[1]))
key = torch.reshape(key, (num_graphs, num_nodes, h_original_shape[1]))
value = torch.reshape(value, (num_graphs, num_nodes, h_original_shape[1]))
h, _ = self.multihead_attn(query, key, value, key_padding_mask=padding_mask)
h = torch.reshape(h, h_original_shape)
h = self.node_update(torch.cat((h, broadcast_global_to_nodes(g, h_global)), dim = 1))
g.ndata['h'] = h
mean_nodes = dgl.sum_nodes(g, 'h', 'w') / sum_weights
h_global = self.global_update(torch.cat((h_global, mean_nodes), dim = 1))
h_global = self.TL_global_decoder(h_global)
return self.classify(h_global)
def parameters(self, recurse: bool = True):
params = []
for model_section in self.pretraining_params:
if isinstance(self.learning_rate, dict) and self.learning_rate.get("pretraining_lr"):
params.append({'params': model_section.parameters(), 'lr': self.learning_rate["pretraining_lr"]})
else:
params.append({'params': model_section.parameters(), 'lr': 0.0001})
for model_section in self.attention_params:
if isinstance(self.learning_rate, dict) and self.learning_rate.get("attention_lr"):
params.append({'params': model_section.parameters(), 'lr': self.learning_rate["attention_lr"]})
else:
params.append({'params': model_section.parameters(), 'lr': 0.0001})
return params
class Multimodel_Transferred_Learning(nn.Module):
def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, frozen_pretraining=True, learning_rate=None, **kwargs):
super().__init__()
print(f'Unused args while creating Multimodel_Transferred_Learning: {kwargs}')
self.n_layers = n_layers
self.n_proc_steps = n_proc_steps
self.layers = nn.ModuleList()
self.has_global = sample_global.shape[1] != 0
gl_size = sample_global.shape[1] if self.has_global else 1
self.learning_rate = learning_rate
input_size = 0
self.pretraining_params = []
self.model_params = []
self.pretrained_models = []
for model, path in zip(pretraining_model, pretraining_path):
input_size += model['args']['hid_size']
model = utils.buildFromConfig(model, {'sample_graph': sample_graph, 'sample_global': sample_global})
checkpoint = torch.load(path)['model_state_dict']
new_state_dict = {}
for k, v in checkpoint.items():
new_key = k.replace('module.', '')
new_state_dict[new_key] = v
model.load_state_dict(new_state_dict)
pretrained_layers = list(model.children())
pretrained_layers = pretrained_layers[:-1]
model = nn.Sequential(*pretrained_layers)
# Freeze Weights
print(f"Freeze Pretraining = {frozen_pretraining}")
if (frozen_pretraining):
for param in model.parameters():
param.requires_grad = False # Freeze all layers
self.pretraining_params.append(model)
self.pretrained_models.append(model)
print(f"len(pretrained_models) = {len(self.pretrained_models)}")
print(f"input size = {input_size}")
self.final_mlp = Make_MLP(input_size, hid_size, hid_size, n_layers, dropout=dropout)
self.classify = nn.Linear(hid_size, out_size)
self.model_params.append(self.final_mlp)
self.model_params.append(self.classify)
# Each pretrained checkpoint stores its MLP blocks either at the top level or
# nested one level down (e.g. when the saved model was itself wrapped), so the
# shared accessor below retries with the nested layout on failure. Block
# indices follow the child order of the pretrained model: 1=node_encoder,
# 2=edge_encoder, 3=global_encoder, 4=node_update, 5=edge_update,
# 6=global_update, 7=global_decoder.
def TL_apply(self, x, model_idx, block_idx):
try:
for layer in self.pretrained_models[model_idx][block_idx]:
x = layer(x)
return x
except (NotImplementedError, IndexError):
for layer in self.pretrained_models[model_idx][1][block_idx]:
x = layer(x)
return x
def TL_node_encoder(self, x, model_idx):
return self.TL_apply(x, model_idx, 1)
def TL_edge_encoder(self, x, model_idx):
return self.TL_apply(x, model_idx, 2)
def TL_global_encoder(self, x, model_idx):
return self.TL_apply(x, model_idx, 3)
def TL_node_update(self, x, model_idx):
return self.TL_apply(x, model_idx, 4)
def TL_edge_update(self, x, model_idx):
return self.TL_apply(x, model_idx, 5)
def TL_global_update(self, x, model_idx):
return self.TL_apply(x, model_idx, 6)
def TL_global_decoder(self, x, model_idx):
return self.TL_apply(x, model_idx, 7)
# global_feats must be supplied when the graphs carry real global features;
# otherwise the per-graph node count is used as a stand-in global feature.
def Pretrained_Output(self, g, model_idx, global_feats=None):
h = self.TL_node_encoder(g.ndata['features'], model_idx)
e = self.TL_edge_encoder(g.edata['features'], model_idx)
g.ndata['h'] = h
g.edata['e'] = e
if not self.has_global:
global_feats = g.batch_num_nodes()[:, None].to(torch.float)
h_global = self.TL_global_encoder(global_feats, model_idx)
for i in range(self.n_proc_steps):
# One message-passing step: gather both endpoint states onto each edge,
# update the edge states, sum incoming edge states into each node, then
# update the node states and finally the global state.
g.apply_edges(dgl.function.copy_u('h', 'm_u'))
g.apply_edges(copy_v)
g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1), model_idx)
g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1), model_idx)
h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1), model_idx)
# Note: the pretrained global decoder is skipped here; the latent h_global
# is what gets concatenated and fed to the downstream head.
# h_global = self.TL_global_decoder(h_global, model_idx)
return h_global
def forward(self, g, global_feats):
# Each frozen pretrained branch embeds its own copy of the graph; the
# concatenated embeddings feed the new MLP head.
h_global = []
for i in range(len(self.pretrained_models)):
h_global.append(self.Pretrained_Output(g.clone(), i, global_feats))
h_global = torch.cat(h_global, dim=1)
return self.classify(self.final_mlp(h_global))
def to(self, device):
# pretrained_models is a plain Python list rather than an nn.ModuleList, so
# nn.Module.to() would not move the pretrained branches; move them explicitly.
for model in self.pretrained_models:
model.to(device)
self.classify.to(device)
self.final_mlp.to(device)
return self
def parameters(self, recurse: bool = True):
params = []
for model_section in self.pretraining_params:
if isinstance(self.learning_rate, dict) and self.learning_rate.get("pretraining_lr"):
params.append({'params': model_section.parameters(), 'lr': self.learning_rate["pretraining_lr"]})
else:
params.append({'params': model_section.parameters(), 'lr': 0.00001})
for model_section in self.model_params:
if isinstance(self.learning_rate, dict) and self.learning_rate.get("model_lr"):
params.append({'params': model_section.parameters(), 'lr': self.learning_rate["model_lr"]})
else:
params.append({'params': model_section.parameters(), 'lr': 0.0001})
return params
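# Example (sketch, all names hypothetical): assembling a transfer-learning
# model from two pretrained checkpoints. cfg_a/cfg_b stand for config dicts of
# the form expected by utils.buildFromConfig (with ['args']['hid_size'] set),
# and the .pt paths are placeholders.
#
#     model = Multimodel_Transferred_Learning(
#         pretraining_path=['runs/model_a.pt', 'runs/model_b.pt'],
#         pretraining_model=[cfg_a, cfg_b],
#         sample_graph=sample_graph, sample_global=sample_global,
#         hid_size=64, out_size=1, n_layers=2, n_proc_steps=3,
#         frozen_pretraining=True,
#         learning_rate={'pretraining_lr': 1e-5, 'model_lr': 1e-4})
#     logits = model(batched_graph, global_feats)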
class MultiModel(nn.Module):
def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, frozen_pretraining=True, learning_rate=None, **kwargs):
super().__init__()
print(f'Unused args while creating MultiModel: {kwargs}')
self.n_layers = n_layers
self.n_proc_steps = n_proc_steps
self.layers = nn.ModuleList()
self.has_global = sample_global.shape[1] != 0
gl_size = sample_global.shape[1] if self.has_global else 1
self.learning_rate = learning_rate
input_size = 0
self.model_params = []
self.pretraining_params = []
self.pretrained_models = []
for model_cfg, path in zip(pretraining_model, pretraining_path):
input_size += model_cfg['args']['hid_size']
model = utils.buildFromConfig(model_cfg, {'sample_graph': sample_graph, 'sample_global': sample_global})
checkpoint = torch.load(path)['model_state_dict']
# Strip the 'module.' prefix added by DataParallel/DistributedDataParallel
# so the weights load into an unwrapped model.
new_state_dict = {}
for k, v in checkpoint.items():
new_key = k.replace('module.', '')
new_state_dict[new_key] = v
model.load_state_dict(new_state_dict)
# Drop the final classification layer; only the GNN body is transferred.
pretrained_layers = list(model.children())
pretrained_layers = pretrained_layers[:-1]
model = nn.Sequential(*pretrained_layers)
# Freeze Weights
print(f"Freeze Pretraining = {frozen_pretraining}")
if (frozen_pretraining):
for param in model.parameters():
param.requires_grad = False # Freeze all layers
self.pretraining_params.append(model)
self.pretrained_models.append(model)
print(f"len(pretrained_models) = {len(self.pretrained_models)}")
print(f"input size = {input_size}")
#encoder
self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)
#GNN
self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.final_mlp = Make_MLP(input_size + hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.classify = nn.Linear(hid_size, out_size)
self.model_params.append(self.final_mlp)
self.model_params.append(self.classify)
# Same pretrained-block accessors as in Multimodel_Transferred_Learning above;
# see the comment there for the block-index mapping and the nested-layout
# fallback.
def TL_apply(self, x, model_idx, block_idx):
try:
for layer in self.pretrained_models[model_idx][block_idx]:
x = layer(x)
return x
except (NotImplementedError, IndexError):
for layer in self.pretrained_models[model_idx][1][block_idx]:
x = layer(x)
return x
def TL_node_encoder(self, x, model_idx):
return self.TL_apply(x, model_idx, 1)
def TL_edge_encoder(self, x, model_idx):
return self.TL_apply(x, model_idx, 2)
def TL_global_encoder(self, x, model_idx):
return self.TL_apply(x, model_idx, 3)
def TL_node_update(self, x, model_idx):
return self.TL_apply(x, model_idx, 4)
def TL_edge_update(self, x, model_idx):
return self.TL_apply(x, model_idx, 5)
def TL_global_update(self, x, model_idx):
return self.TL_apply(x, model_idx, 6)
def TL_global_decoder(self, x, model_idx):
return self.TL_apply(x, model_idx, 7)
# As above, global_feats must be supplied when the graphs carry real global
# features; otherwise the per-graph node count is used instead.
def Pretrained_Output(self, g, model_idx, global_feats=None):
h = self.TL_node_encoder(g.ndata['features'], model_idx)
e = self.TL_edge_encoder(g.edata['features'], model_idx)
g.ndata['h'] = h
g.edata['e'] = e
if not self.has_global:
global_feats = g.batch_num_nodes()[:, None].to(torch.float)
h_global = self.TL_global_encoder(global_feats, model_idx)
for i in range(self.n_proc_steps):
g.apply_edges(dgl.function.copy_u('h', 'm_u'))
g.apply_edges(copy_v)
g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1), model_idx)
g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1), model_idx)
h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1), model_idx)
# Note: the pretrained global decoder is skipped here; the latent h_global
# is what gets concatenated and fed to the downstream head.
# h_global = self.TL_global_decoder(h_global, model_idx)
return h_global
def forward(self, g, global_feats):
h = self.node_encoder(g.ndata['features'])
e = self.edge_encoder(g.edata['features'])
g.ndata['h'] = h
g.edata['e'] = e
if not self.has_global:
global_feats = g.batch_num_nodes()[:, None].to(torch.float)
h_global = self.global_encoder(global_feats)
for i in range(self.n_proc_steps):
g.apply_edges(dgl.function.copy_u('h', 'm_u'))
g.apply_edges(copy_v)
g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
# Concatenate this model's own global embedding with each pretrained
# branch's embedding and classify from the joint representation.
h_global = [h_global]
for i in range(len(self.pretrained_models)):
h_global.append(self.Pretrained_Output(g.clone(), i, global_feats))
h_global = torch.cat(h_global, dim=1)
return self.classify(self.final_mlp(h_global))
def to(self, device):
# As above: the pretrained branches live in a plain list, so move them
# explicitly along with this model's own modules.
for model in self.pretrained_models:
model.to(device)
self.classify.to(device)
self.final_mlp.to(device)
self.node_encoder.to(device)
self.edge_encoder.to(device)
self.global_encoder.to(device)
self.node_update.to(device)
self.edge_update.to(device)
self.global_update.to(device)
return self
def parameters(self, recurse: bool = True):
params = []
for i, model_section in enumerate(self.pretraining_params):
if isinstance(self.learning_rate, dict) and self.learning_rate.get("pretraining_lr"):
print(f"Pretraining LR = {self.learning_rate['pretraining_lr'][i]}")
params.append({'params': model_section.parameters(), 'lr': self.learning_rate["pretraining_lr"][i]})
else:
print(f"Pretraining LR = 0.00001")
params.append({'params': model_section.parameters(), 'lr': 0.00001})
for model_section in self.model_params:
if isinstance(self.learning_rate, dict) and self.learning_rate.get("model_lr"):
print(f"Model LR = {self.learning_rate['model_lr']}")
params.append({'params': model_section.parameters(), 'lr': self.learning_rate["model_lr"]})
else:
print(f"Model LR = 0.0001")
params.append({'params': model_section.parameters(), 'lr': 0.0001})
return params
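# Usage note (sketch): unlike Multimodel_Transferred_Learning above, this
# parameters() override indexes 'pretraining_lr' per pretrained branch, so the
# expected dict shape is e.g.
#
#     learning_rate = {'pretraining_lr': [1e-5, 1e-5], 'model_lr': 1e-4}
#
# with one list entry for every entry in pretraining_path.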
class Clustering(nn.Module):
def __init__(self, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
super().__init__()
print(f'Unused args while creating Clustering: {kwargs}')
self.n_layers = n_layers
self.n_proc_steps = n_proc_steps
self.layers = nn.ModuleList()
self.hid_size = hid_size
self.has_global = len(sample_global) != 0 and sample_global.shape[1] != 0
gl_size = sample_global.shape[1] if self.has_global else 1
#encoder
self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)
#GNN
self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
#decoder
self.global_decoder = Make_MLP(hid_size, hid_size, out_size, n_layers, dropout=dropout)
def model_forward(self, g, global_feats, features = 'features'):
h = self.node_encoder(g.ndata[features])
e = self.edge_encoder(g.edata[features])
g.ndata['h'] = h
g.edata['e'] = e
if not self.has_global:
global_feats = g.batch_num_nodes()[:, None].to(torch.float)
batch_num_nodes = None
sum_weights = None
if "w" in g.ndata:
batch_indices = g.batch_num_nodes()
# Find non-zero rows (non-padded nodes)
non_padded_nodes_mask = torch.any(g.ndata[features] != 0, dim=1)
# Split the mask according to the batch indices
batch_num_nodes = []
start_idx = 0
for num_nodes in batch_indices:
end_idx = start_idx + num_nodes
non_padded_count = non_padded_nodes_mask[start_idx:end_idx].sum().item()
batch_num_nodes.append(non_padded_count)
start_idx = end_idx
batch_num_nodes = torch.tensor(batch_num_nodes, device = g.ndata[features].device)
sum_weights = batch_num_nodes[:, None].repeat(1, self.hid_size)
global_feats = batch_num_nodes[:, None].to(torch.float)
h_global = self.global_encoder(global_feats)
for i in range(self.n_proc_steps):
g.apply_edges(dgl.function.copy_u('h', 'm_u'))
g.apply_edges(copy_v)
g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
if "w" in g.ndata:
mean_nodes = dgl.sum_nodes(g, 'h', 'w') / sum_weights
h_global = self.global_update(torch.cat((h_global, mean_nodes, dgl.mean_edges(g, 'e')), dim = 1))
else:
h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
h_global = self.global_decoder(h_global)
return h_global
def forward(self, g, global_feats):
# Embed the graph twice, once from the original features and once from the
# augmented ones, and return the two views concatenated.
h_global = self.model_forward(g, global_feats, 'features')
h_global_augmented = self.model_forward(g, global_feats, 'augmented_features')
return torch.cat((h_global, h_global_augmented), dim=1)
def representation(self, g, global_feats):
h_global = self.model_forward(g, global_feats, 'features')
h_global_augmented = self.model_forward(g, global_feats, 'augmented_features')
return h_global, h_global_augmented, torch.cat((h_global, h_global_augmented), dim=1)
def __str__(self):
layer_names = ["node_encoder", "edge_encoder", "global_encoder",
"node_update", "edge_update", "global_update", "global_decoder"]
layers = [self.node_encoder, self.edge_encoder, self.global_encoder,
self.node_update, self.edge_update, self.global_update, self.global_decoder]
for i in range(len(layers)):
print(layer_names[i])
for layer in layers[i].children():
if isinstance(layer, nn.Linear):
print(layer.state_dict())
print("classify")
print(self.classify.weight)
return ""