import dgl
import dgl.nn as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
import os
file_path = os.getcwd()
sys.path.append(file_path)
import root_gnn_base.dataset as datasets
from root_gnn_base import utils
import gc
def Make_SLP(in_size, out_size, activation=nn.ReLU, dropout=0):
    """Build one Linear -> activation -> Dropout block as a list of layers."""
    layers = []
    layers.append(nn.Linear(in_size, out_size))
    layers.append(activation())
    layers.append(nn.Dropout(dropout))
    return layers
def Make_MLP(in_size, hid_size, out_size, n_layers, activation=nn.ReLU, dropout=0):
    """Build an MLP of n_layers Linear blocks, finished with a LayerNorm on the output."""
    layers = []
    if n_layers > 1:
        layers += Make_SLP(in_size, hid_size, activation, dropout)
        for i in range(n_layers - 2):
            layers += Make_SLP(hid_size, hid_size, activation, dropout)
        layers += Make_SLP(hid_size, out_size, activation, dropout)
    else:
        layers += Make_SLP(in_size, out_size, activation, dropout)
    layers.append(nn.LayerNorm(out_size))
    return nn.Sequential(*layers)
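# A minimal sketch of the stack Make_MLP assembles (hypothetical sizes):
#   Make_MLP(4, 16, 8, n_layers=3) -> nn.Sequential(
#       Linear(4, 16), ReLU, Dropout,
#       Linear(16, 16), ReLU, Dropout,
#       Linear(16, 8), ReLU, Dropout,
#       LayerNorm(8))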
class MLP(nn.Module):
    def __init__(self, in_size, hid_size, out_size, n_layers, activation=nn.ReLU, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating MLP: {kwargs}')
        self.layers = Make_MLP(in_size, hid_size, hid_size, n_layers - 1, activation, dropout)
        self.linear = nn.Linear(hid_size, out_size)

    def forward(self, x):
        return self.linear(self.layers(x))
def broadcast_global_to_nodes(g, global_feats):
    # Repeat each graph's global feature row once per node of that graph.
    boundaries = g.batch_num_nodes()
    return torch.repeat_interleave(global_feats, boundaries, dim=0)

def broadcast_global_to_edges(g, global_feats):
    # Repeat each graph's global feature row once per edge of that graph.
    boundaries = g.batch_num_edges()
    return torch.repeat_interleave(global_feats, boundaries, dim=0)
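# Shape sketch (hypothetical batch): with batch_num_nodes() == [2, 3] and a
# global tensor of shape (2, d), broadcast_global_to_nodes returns shape (5, d):
# graph 0's row repeated twice, then graph 1's row repeated three times.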
def copy_v(edges):
    # Edge UDF: copy the destination node's state onto the edge as message 'm_v'.
    return {'m_v': edges.dst['h']}
def partial_reset(model: nn.Module):
    # Reinitialize only the final classification layer, keeping its shape and device.
    in_size = len(model.classify.weight[0])
    out_size = len(model.classify.weight)
    device = next(model.classify.parameters()).device
    torch.manual_seed(2)
    model.classify = nn.Linear(in_size, out_size)
    model.classify.to(device)
    print(model.classify.weight)
def print_model(model: nn.Module):
    print(model)

def print_mlp(layer):
    for l in layer.children():
        if isinstance(l, nn.Linear):
            print(l.state_dict())
        else:
            print(l)
# Reset the whole GNN in place: reinitialize every encoder/update/decoder MLP,
# then reset the classification layer via partial_reset.
def full_reset(model: nn.Module):
    mlp_list = [model.node_encoder, model.edge_encoder, model.global_encoder,
                model.node_update, model.edge_update, model.global_update,
                model.global_decoder]
    for mlp in mlp_list:
        for layer in mlp.children():
            if hasattr(layer, 'reset_parameters'):
                layer.reset_parameters()
    partial_reset(model)
class GCN(nn.Module):
    def __init__(self, in_size, hid_size, out_size, n_layers, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.layers = nn.ModuleList()
        # n_layers + 1 input Linear layers, n_layers GraphConv layers,
        # then n_layers output Linear layers
        self.layers.extend(
            [nn.Linear(in_size, hid_size)] +
            [nn.Linear(hid_size, hid_size) for i in range(n_layers)] +
            [dglnn.GraphConv(hid_size, hid_size) for i in range(n_layers)] +
            [nn.Linear(hid_size, hid_size) for i in range(n_layers)]
        )
        self.classify = nn.Linear(hid_size, out_size)
        #self.dropout = nn.Dropout(0.05)

    def forward(self, g):
        h = g.ndata['features']
        for i, layer in enumerate(self.layers):
            # GraphConv layers occupy indices n_layers + 1 .. 2 * n_layers
            if i >= self.n_layers + 1 and i < self.n_layers * 2 + 1:
                h = layer(g, h)
            else:
                h = layer(h)
            h = F.relu(h)
        with g.local_scope():
            g.ndata['h'] = h
            # Calculate graph representation by average readout.
            hg = dgl.mean_nodes(g, 'h')
            return self.classify(hg)
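# Hypothetical usage, assuming each node carries a 5-wide 'features' tensor:
#   model = GCN(in_size=5, hid_size=64, out_size=2, n_layers=2)
#   logits = model(batched_graph)  # shape (num_graphs, 2)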
class GCN_global(nn.Module):
    def __init__(self, in_size, hid_size=4, out_size=1, n_layers=1, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN_global: {kwargs}')
        self.n_layers = n_layers
        # encoder
        self.node_encoder = Make_MLP(in_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(1, hid_size, hid_size, n_layers, dropout=dropout)
        # GCN
        self.node_update = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(2 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.conv = dglnn.GraphConv(hid_size, hid_size)
        # decoder
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

    def forward(self, g):
        h = self.node_encoder(g.ndata['features'])
        h_global = self.global_encoder(g.batch_num_nodes()[:, None].to(torch.float))
        for i in range(self.n_layers):
            h = self.node_update(h)
            h = self.conv(g, h)
            g.ndata['h'] = h
            h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h')), dim=1))
        h_global = self.global_decoder(h_global)
        return self.classify(h_global)
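# GCN_global takes no explicit global input: the only global feature is each
# graph's node count, taken from g.batch_num_nodes(). Hypothetical usage:
#   model = GCN_global(in_size=5, hid_size=8, out_size=1, n_layers=2)
#   pred = model(batched_graph)  # shape (num_graphs, 1)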
class GCN_global_2way(nn.Module):
    def __init__(self, in_size, hid_size=4, out_size=1, n_layers=1, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN_global_2way: {kwargs}')
        self.n_layers = n_layers
        # encoder
        self.node_encoder = Make_MLP(in_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(1, hid_size, hid_size, n_layers, dropout=dropout)
        # GCN
        self.node_update = Make_MLP(2 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(2 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.conv = dglnn.GraphConv(hid_size, hid_size)
        # decoder
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

    def forward(self, g):
        h = self.node_encoder(g.ndata['features'])
        h_global = self.global_encoder(g.batch_num_nodes()[:, None].to(torch.float))
        for i in range(self.n_layers):
            # The global state is broadcast back to the nodes ("2-way" coupling)
            h = self.node_update(torch.cat((h, broadcast_global_to_nodes(g, h_global)), dim=1))
            h = self.conv(g, h)
            g.ndata['h'] = h
            h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h')), dim=1))
        h_global = self.global_decoder(h_global)
        return self.classify(h_global)
class Edge_Network(nn.Module):
    def __init__(self, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating Edge_Network: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.hid_size = hid_size
        self.layers = nn.ModuleList()
        if len(sample_global) == 0:
            self.has_global = False
        else:
            self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1
        # encoder
        self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)
        # GNN
        self.node_update = Make_MLP(3 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_update = Make_MLP(4 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(3 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        # decoder
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)
    def forward(self, g, global_feats):
        h = self.node_encoder(g.ndata['features'])
        e = self.edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        batch_num_nodes = None
        sum_weights = None
        if "w" in g.ndata:
            batch_indices = g.batch_num_nodes()
            # Find non-zero rows (non-padded nodes)
            non_padded_nodes_mask = torch.any(g.ndata['features'] != 0, dim=1)
            # Split the mask according to the batch indices
            batch_num_nodes = []
            start_idx = 0
            for num_nodes in batch_indices:
                end_idx = start_idx + num_nodes
                non_padded_count = non_padded_nodes_mask[start_idx:end_idx].sum().item()
                batch_num_nodes.append(non_padded_count)
                start_idx = end_idx
            batch_num_nodes = torch.tensor(batch_num_nodes, device=g.ndata['features'].device)
            # Denominator for the weighted node mean, one column per hidden feature
            sum_weights = batch_num_nodes[:, None].repeat(1, self.hid_size)
            global_feats = batch_num_nodes[:, None].to(torch.float)
        h_global = self.global_encoder(global_feats)
        for i in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim=1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim=1))
            if "w" in g.ndata:
                mean_nodes = dgl.sum_nodes(g, 'h', 'w') / sum_weights
                h_global = self.global_update(torch.cat((h_global, mean_nodes, dgl.mean_edges(g, 'e')), dim=1))
            else:
                h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1))
        h_global = self.global_decoder(h_global)
        return self.classify(h_global)
    def representation(self, g, global_feats):
        # Same pipeline as forward, but also returns the intermediate global embeddings.
        h = self.node_encoder(g.ndata['features'])
        e = self.edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        batch_num_nodes = None
        sum_weights = None
        if "w" in g.ndata:
            batch_indices = g.batch_num_nodes()
            # Find non-zero rows (non-padded nodes)
            non_padded_nodes_mask = torch.any(g.ndata['features'] != 0, dim=1)
            # Split the mask according to the batch indices
            batch_num_nodes = []
            start_idx = 0
            for num_nodes in batch_indices:
                end_idx = start_idx + num_nodes
                non_padded_count = non_padded_nodes_mask[start_idx:end_idx].sum().item()
                batch_num_nodes.append(non_padded_count)
                start_idx = end_idx
            batch_num_nodes = torch.tensor(batch_num_nodes, device=g.ndata['features'].device)
            sum_weights = batch_num_nodes[:, None].repeat(1, self.hid_size)
            global_feats = batch_num_nodes[:, None].to(torch.float)
        h_global = self.global_encoder(global_feats)
        for i in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim=1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim=1))
            if "w" in g.ndata:
                mean_nodes = dgl.sum_nodes(g, 'h', 'w') / sum_weights
                h_global = self.global_update(torch.cat((h_global, mean_nodes, dgl.mean_edges(g, 'e')), dim=1))
            else:
                h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1))
        before_global_decoder = h_global
        after_global_decoder = self.global_decoder(before_global_decoder)
        after_classify = self.classify(after_global_decoder)
        return before_global_decoder, after_global_decoder, after_classify
    def __str__(self):
        layer_names = ["node_encoder", "edge_encoder", "global_encoder",
                       "node_update", "edge_update", "global_update", "global_decoder"]
        layers = [self.node_encoder, self.edge_encoder, self.global_encoder,
                  self.node_update, self.edge_update, self.global_update, self.global_decoder]
        for i in range(len(layers)):
            print(layer_names[i])
            for layer in layers[i].children():
                if isinstance(layer, nn.Linear):
                    print(layer.state_dict())
        print("classify")
        print(self.classify.weight)
        return ""
class Transferred_Learning(nn.Module):
    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating Transferred_Learning: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        if len(sample_global) == 0:
            self.has_global = False
        else:
            self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1
        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
        checkpoint = torch.load(pretraining_path)
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        # Drop the pretrained classify layer; children 1..7 then correspond to the
        # Edge_Network submodules (index 0 is its empty ModuleList).
        pretrained_layers = list(self.pretrained_model.children())
        pretrained_layers = pretrained_layers[:-1]
        self.pretrained_model = nn.Sequential(*pretrained_layers)
        # Freeze Weights
        for param in self.pretrained_model.parameters():
            param.requires_grad = False  # Freeze all layers
        self.global_decoder = Make_MLP(pretraining_model['args']['hid_size'], hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

    def TL_node_encoder(self, x):
        for layer in self.pretrained_model[1]:
            x = layer(x)
        return x

    def TL_edge_encoder(self, x):
        for layer in self.pretrained_model[2]:
            x = layer(x)
        return x

    def TL_global_encoder(self, x):
        for layer in self.pretrained_model[3]:
            x = layer(x)
        return x

    def TL_node_update(self, x):
        for layer in self.pretrained_model[4]:
            x = layer(x)
        return x

    def TL_edge_update(self, x):
        for layer in self.pretrained_model[5]:
            x = layer(x)
        return x

    def TL_global_update(self, x):
        for layer in self.pretrained_model[6]:
            x = layer(x)
        return x

    def TL_global_decoder(self, x):
        for layer in self.pretrained_model[7]:
            x = layer(x)
        return x

    def forward(self, g, global_feats):
        h = self.TL_node_encoder(g.ndata['features'])
        e = self.TL_edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.TL_global_encoder(global_feats)
        for i in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim=1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim=1))
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1))
        h_global = self.TL_global_decoder(h_global)
        return self.classify(self.global_decoder(h_global))
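# Hypothetical config sketch: the transfer-learning classes above and below only
# rely on pretraining_model being a dict that utils.buildFromConfig accepts and
# that exposes pretraining_model['args']['hid_size'] (and, for the message-passing
# variants, ['args']['n_proc_steps']); the names here are illustrative:
#   pretraining_model = {'name': 'Edge_Network',
#                        'args': {'hid_size': 64, 'out_size': 1,
#                                 'n_layers': 2, 'n_proc_steps': 3}}
#   tl = Transferred_Learning('pretrained_checkpoint.pt', pretraining_model,
#                             sample_graph, sample_global, hid_size=64,
#                             out_size=1, n_layers=2, n_proc_steps=3)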
class Transferred_Learning_Graph(nn.Module):
    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, additional_proc_steps=1, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating Transferred_Learning_Graph: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        if len(sample_global) == 0:
            self.has_global = False
        else:
            self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1
        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
        checkpoint = torch.load(pretraining_path)
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        pretrained_layers = list(self.pretrained_model.children())
        pretrained_layers = pretrained_layers[:-1]
        self.pretrained_model = nn.Sequential(*pretrained_layers)
        self.additional_proc_steps = additional_proc_steps
        # Freeze Weights
        for param in self.pretrained_model.parameters():
            param.requires_grad = False  # Freeze all layers
        # GNN
        self.node_update = Make_MLP(3 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_update = Make_MLP(4 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(3 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        # decoder
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

    def TL_node_encoder(self, x):
        for layer in self.pretrained_model[1]:
            x = layer(x)
        return x

    def TL_edge_encoder(self, x):
        for layer in self.pretrained_model[2]:
            x = layer(x)
        return x

    def TL_global_encoder(self, x):
        for layer in self.pretrained_model[3]:
            x = layer(x)
        return x

    def TL_node_update(self, x):
        for layer in self.pretrained_model[4]:
            x = layer(x)
        return x

    def TL_edge_update(self, x):
        for layer in self.pretrained_model[5]:
            x = layer(x)
        return x

    def TL_global_update(self, x):
        for layer in self.pretrained_model[6]:
            x = layer(x)
        return x

    def forward(self, g, global_feats):
        h = self.TL_node_encoder(g.ndata['features'])
        e = self.TL_edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.TL_global_encoder(global_feats)
        # Frozen pretrained processing steps
        for i in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim=1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim=1))
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1))
        # Trainable additional processing steps
        for j in range(self.additional_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim=1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim=1))
            h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1))
        h_global = self.global_decoder(h_global)
        return self.classify(h_global)
class Transferred_Learning_Parallel(nn.Module):
    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating Transferred_Learning_Parallel: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1
        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
        checkpoint = torch.load(pretraining_path)
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        pretrained_layers = list(self.pretrained_model.children())
        pretrained_layers = pretrained_layers[:-1]
        self.pretrained_model = nn.Sequential(*pretrained_layers)
        # Freeze Weights
        for param in self.pretrained_model.parameters():
            param.requires_grad = False  # Freeze all layers
        # encoder
        self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)
        # GNN
        self.node_update = Make_MLP(3 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_update = Make_MLP(4 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(3 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        # decoder
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size + pretraining_model['args']['hid_size'], out_size)

    def TL_node_encoder(self, x):
        for layer in self.pretrained_model[1]:
            x = layer(x)
        return x

    def TL_edge_encoder(self, x):
        for layer in self.pretrained_model[2]:
            x = layer(x)
        return x

    def TL_global_encoder(self, x):
        for layer in self.pretrained_model[3]:
            x = layer(x)
        return x

    def TL_node_update(self, x):
        for layer in self.pretrained_model[4]:
            x = layer(x)
        return x

    def TL_edge_update(self, x):
        for layer in self.pretrained_model[5]:
            x = layer(x)
        return x

    def TL_global_update(self, x):
        for layer in self.pretrained_model[6]:
            x = layer(x)
        return x

    def TL_global_decoder(self, x):
        for layer in self.pretrained_model[7]:
            x = layer(x)
        return x

    def Pretrained_Output(self, g, global_feats=None):
        h = self.TL_node_encoder(g.ndata['features'])
        e = self.TL_edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.TL_global_encoder(global_feats)
        for i in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim=1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim=1))
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1))
        h_global = self.TL_global_decoder(h_global)
        return h_global

    def forward(self, g, global_feats):
        # Run the frozen pretrained branch and the trainable branch side by side
        pretrained_global = self.Pretrained_Output(g.clone(), global_feats)
        h = self.node_encoder(g.ndata['features'])
        e = self.edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.global_encoder(global_feats)
        for i in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim=1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim=1))
            h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1))
        h_global = self.global_decoder(h_global)
        return self.classify(torch.cat((pretrained_global, h_global), dim=1))
class Transferred_Learning_Sequential(nn.Module):
    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating Transferred_Learning_Sequential: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1
        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
        checkpoint = torch.load(pretraining_path)
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        pretrained_layers = list(self.pretrained_model.children())
        pretrained_layers = pretrained_layers[:-1]
        self.pretrained_model = nn.Sequential(*pretrained_layers)
        # Freeze Weights
        for param in self.pretrained_model.parameters():
            param.requires_grad = False  # Freeze all layers
        # trainable head on top of the frozen pretrained global embedding
        self.mlp = Make_MLP(pretraining_model['args']['hid_size'], hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

    def TL_node_encoder(self, x):
        for layer in self.pretrained_model[1]:
            x = layer(x)
        return x

    def TL_edge_encoder(self, x):
        for layer in self.pretrained_model[2]:
            x = layer(x)
        return x

    def TL_global_encoder(self, x):
        for layer in self.pretrained_model[3]:
            x = layer(x)
        return x

    def TL_node_update(self, x):
        for layer in self.pretrained_model[4]:
            x = layer(x)
        return x

    def TL_edge_update(self, x):
        for layer in self.pretrained_model[5]:
            x = layer(x)
        return x

    def TL_global_update(self, x):
        for layer in self.pretrained_model[6]:
            x = layer(x)
        return x

    def TL_global_decoder(self, x):
        for layer in self.pretrained_model[7]:
            x = layer(x)
        return x

    def Pretrained_Output(self, g, global_feats=None):
        h = self.TL_node_encoder(g.ndata['features'])
        e = self.TL_edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.TL_global_encoder(global_feats)
        for i in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim=1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim=1))
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1))
        h_global = self.TL_global_decoder(h_global)
        return h_global

    def forward(self, g, global_feats):
        pretrained_global = self.Pretrained_Output(g.clone(), global_feats)
        global_features = self.mlp(pretrained_global)
        return self.classify(global_features)
class Transferred_Learning_Message_Passing(nn.Module):
    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating Transferred_Learning_Message_Passing: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1
        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
        checkpoint = torch.load(pretraining_path)
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        pretrained_layers = list(self.pretrained_model.children())
        pretrained_layers = pretrained_layers[:-1]
        self.pretrained_model = nn.Sequential(*pretrained_layers)
        # Freeze Weights
        for param in self.pretrained_model.parameters():
            param.requires_grad = False  # Freeze all layers
        # trainable head over the concatenated per-step global states
        self.mlp = Make_MLP(pretraining_model['args']['hid_size'] * pretraining_model['args']['n_proc_steps'], hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

    def TL_node_encoder(self, x):
        for layer in self.pretrained_model[1]:
            x = layer(x)
        return x

    def TL_edge_encoder(self, x):
        for layer in self.pretrained_model[2]:
            x = layer(x)
        return x

    def TL_global_encoder(self, x):
        for layer in self.pretrained_model[3]:
            x = layer(x)
        return x

    def TL_node_update(self, x):
        for layer in self.pretrained_model[4]:
            x = layer(x)
        return x

    def TL_edge_update(self, x):
        for layer in self.pretrained_model[5]:
            x = layer(x)
        return x

    def TL_global_update(self, x):
        for layer in self.pretrained_model[6]:
            x = layer(x)
        return x

    def TL_global_decoder(self, x):
        for layer in self.pretrained_model[7]:
            x = layer(x)
        return x

    def Pretrained_Output(self, g, global_feats=None):
        message_passing = None
        h = self.TL_node_encoder(g.ndata['features'])
        e = self.TL_edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.TL_global_encoder(global_feats)
        for i in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim=1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim=1))
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1))
            # Collect the global state after every processing step
            if message_passing is None:
                message_passing = h_global.clone()
            else:
                message_passing = torch.cat((message_passing, h_global.clone()), dim=1)
        h_global = self.TL_global_decoder(h_global)
        # The decoded global is not used; the per-step concatenation is returned
        return message_passing

    def forward(self, g, global_feats):
        pretrained_global = self.Pretrained_Output(g.clone(), global_feats)
        #print(f"message_passing layers have size = {pretrained_global.shape}")
        #print(pretrained_global)
        global_features = self.mlp(pretrained_global)
        return self.classify(global_features)
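# The concatenated per-step globals set the input width of self.mlp:
# pretrained hid_size * n_proc_steps. E.g. (hypothetical) hid_size=64 and
# n_proc_steps=3 gives a (num_graphs, 192) message_passing tensor.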
class Transferred_Learning_Message_Passing_Parallel(nn.Module):
    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating Transferred_Learning_Message_Passing_Parallel: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1
        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
        checkpoint = torch.load(pretraining_path)
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        pretrained_layers = list(self.pretrained_model.children())
        pretrained_layers = pretrained_layers[:-1]
        self.pretrained_model = nn.Sequential(*pretrained_layers)
        # encoder
        self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)
        # GNN
        self.node_update = Make_MLP(3 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_update = Make_MLP(4 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(3 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        # decoder
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        # Freeze Weights
        for param in self.pretrained_model.parameters():
            param.requires_grad = False  # Freeze all layers
        self.classify = nn.Linear(pretraining_model['args']['hid_size'] * pretraining_model['args']['n_proc_steps'] + hid_size, out_size)

    def TL_node_encoder(self, x):
        for layer in self.pretrained_model[1]:
            x = layer(x)
        return x

    def TL_edge_encoder(self, x):
        for layer in self.pretrained_model[2]:
            x = layer(x)
        return x

    def TL_global_encoder(self, x):
        for layer in self.pretrained_model[3]:
            x = layer(x)
        return x

    def TL_node_update(self, x):
        for layer in self.pretrained_model[4]:
            x = layer(x)
        return x

    def TL_edge_update(self, x):
        for layer in self.pretrained_model[5]:
            x = layer(x)
        return x

    def TL_global_update(self, x):
        for layer in self.pretrained_model[6]:
            x = layer(x)
        return x

    def TL_global_decoder(self, x):
        for layer in self.pretrained_model[7]:
            x = layer(x)
        return x

    def Pretrained_Output(self, g, global_feats=None):
        message_passing = None
        h = self.TL_node_encoder(g.ndata['features'])
        e = self.TL_edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.TL_global_encoder(global_feats)
        for i in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim=1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim=1))
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1))
            if message_passing is None:
                message_passing = h_global.clone()
            else:
                message_passing = torch.cat((message_passing, h_global.clone()), dim=1)
        h_global = self.TL_global_decoder(h_global)
        return message_passing

    def forward(self, g, global_feats):
        pretrained_message = self.Pretrained_Output(g.clone(), global_feats)
        h = self.node_encoder(g.ndata['features'])
        e = self.edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.global_encoder(global_feats)
        for i in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim=1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim=1))
            h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1))
        h_global = self.global_decoder(h_global)
        return self.classify(torch.cat((pretrained_message, h_global), dim=1))
class Transferred_Learning_Finetuning(nn.Module):
    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, frozen_pretraining=False, **kwargs):
        super().__init__()
        print(f'Unused args while creating Transferred_Learning_Finetuning: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        if len(sample_global) == 0:
            self.has_global = False
        else:
            self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1
        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
        checkpoint = torch.load(pretraining_path)
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        pretrained_layers = list(self.pretrained_model.children())
        pretrained_layers = pretrained_layers[:-1]
        self.pretrained_model = nn.Sequential(*pretrained_layers)
        print(f"Freeze Pretraining = {frozen_pretraining}")
        if frozen_pretraining:
            for param in self.pretrained_model.parameters():
                param.requires_grad = False  # Freeze all layers
            # ... except the pretrained global decoder, which stays trainable
            for param in self.pretrained_model[7].parameters():
                param.requires_grad = True
        torch.manual_seed(2)
        self.classify = nn.Linear(pretraining_model['args']['hid_size'], out_size)

    def TL_node_encoder(self, x):
        for layer in self.pretrained_model[1]:
            x = layer(x)
        return x

    def TL_edge_encoder(self, x):
        for layer in self.pretrained_model[2]:
            x = layer(x)
        return x

    def TL_global_encoder(self, x):
        for layer in self.pretrained_model[3]:
            x = layer(x)
        return x

    def TL_node_update(self, x):
        for layer in self.pretrained_model[4]:
            x = layer(x)
        return x

    def TL_edge_update(self, x):
        for layer in self.pretrained_model[5]:
            x = layer(x)
        return x

    def TL_global_update(self, x):
        for layer in self.pretrained_model[6]:
            x = layer(x)
        return x

    def TL_global_decoder(self, x):
        for layer in self.pretrained_model[7]:
            x = layer(x)
        return x

    def Pretrained_Output(self, g, global_feats=None):
        h = self.TL_node_encoder(g.ndata['features'])
        e = self.TL_edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.TL_global_encoder(global_feats)
        for i in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim=1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim=1))
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1))
        h_global = self.TL_global_decoder(h_global)
        return h_global

    def forward(self, g, global_feats):
        h_global = self.Pretrained_Output(g.clone(), global_feats)
        return self.classify(h_global)

    def representation(self, g, global_feats):
        h = self.TL_node_encoder(g.ndata['features'])
        e = self.TL_edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.TL_global_encoder(global_feats)
        for i in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim=1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim=1))
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1))
        before_global_decoder = h_global
        after_global_decoder = self.TL_global_decoder(before_global_decoder)
        after_classify = self.classify(after_global_decoder)
        return before_global_decoder, after_global_decoder, after_classify

    def __str__(self):
        layer_names = ["node_encoder", "edge_encoder", "global_encoder",
                       "node_update", "edge_update", "global_update", "global_decoder"]
        layers = [self.pretrained_model[1], self.pretrained_model[2], self.pretrained_model[3],
                  self.pretrained_model[4], self.pretrained_model[5], self.pretrained_model[6],
                  self.pretrained_model[7]]
        for i in range(len(layers)):
            print(layer_names[i])
            for layer in layers[i].children():
                if isinstance(layer, nn.Linear):
                    print(layer.state_dict())
        print("classify")
        print(self.classify.weight)
        return ""
class Transferred_Learning_Parallel_Finetuning(nn.Module):
    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, learning_rate=0.0001, **kwargs):
        super().__init__()
        print(f'Unused args while creating Transferred_Learning_Parallel_Finetuning: {kwargs}')
        self.learning_rate = learning_rate
        self.parallel_params = []
        self.finetuning_params = []
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1
        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
        checkpoint = torch.load(pretraining_path)
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        pretrained_layers = list(self.pretrained_model.children())
        pretrained_layers = pretrained_layers[:-1]
        self.pretrained_model = nn.Sequential(*pretrained_layers)
        self.finetuning_params.append(self.pretrained_model)
        # encoder
        self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)
        # GNN
        self.node_update = Make_MLP(3 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_update = Make_MLP(4 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(3 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        # decoder
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size + pretraining_model['args']['hid_size'], out_size)
        self.parallel_params.append(self.node_encoder)
        self.parallel_params.append(self.edge_encoder)
        self.parallel_params.append(self.global_encoder)
        self.parallel_params.append(self.node_update)
        self.parallel_params.append(self.edge_update)
        self.parallel_params.append(self.global_update)
        self.parallel_params.append(self.global_decoder)
        self.parallel_params.append(self.classify)

    def TL_node_encoder(self, x):
        for layer in self.pretrained_model[1]:
            x = layer(x)
        return x

    def TL_edge_encoder(self, x):
        for layer in self.pretrained_model[2]:
            x = layer(x)
        return x

    def TL_global_encoder(self, x):
        for layer in self.pretrained_model[3]:
            x = layer(x)
        return x

    def TL_node_update(self, x):
        for layer in self.pretrained_model[4]:
            x = layer(x)
        return x

    def TL_edge_update(self, x):
        for layer in self.pretrained_model[5]:
            x = layer(x)
        return x

    def TL_global_update(self, x):
        for layer in self.pretrained_model[6]:
            x = layer(x)
        return x

    def TL_global_decoder(self, x):
        for layer in self.pretrained_model[7]:
            x = layer(x)
        return x

    def Pretrained_Output(self, g, global_feats=None):
        h = self.TL_node_encoder(g.ndata['features'])
        e = self.TL_edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.TL_global_encoder(global_feats)
        for i in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim=1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim=1))
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1))
        h_global = self.TL_global_decoder(h_global)
        return h_global

    def forward(self, g, global_feats):
        pretrained_global = self.Pretrained_Output(g.clone(), global_feats)
        h = self.node_encoder(g.ndata['features'])
        e = self.edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.global_encoder(global_feats)
        for i in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim=1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim=1))
            h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1))
        h_global = self.global_decoder(h_global)
        return self.classify(torch.cat((pretrained_global, h_global), dim=1))

    def parameters(self, recurse: bool = True):
        # Per-group learning rates: learning_rate may be a dict with
        # 'trainable_lr' / 'finetuning_lr'; otherwise fall back to 1e-4.
        params = []
        for model_section in self.parallel_params:
            if isinstance(self.learning_rate, dict) and self.learning_rate.get("trainable_lr"):
                params.append({'params': model_section.parameters(), 'lr': self.learning_rate["trainable_lr"]})
            else:
                params.append({'params': model_section.parameters(), 'lr': 0.0001})
        for model_section in self.finetuning_params:
            if isinstance(self.learning_rate, dict) and self.learning_rate.get("finetuning_lr"):
                params.append({'params': model_section.parameters(), 'lr': self.learning_rate["finetuning_lr"]})
            else:
                params.append({'params': model_section.parameters(), 'lr': 0.0001})
        return params
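# Sketch, assuming a dict-valued learning_rate: the per-group dicts returned by
# parameters() let a standard optimizer apply both rates directly.
#   model = Transferred_Learning_Parallel_Finetuning(
#       ..., learning_rate={'trainable_lr': 1e-4, 'finetuning_lr': 1e-5})
#   optimizer = torch.optim.Adam(model.parameters())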
class Attention(nn.Module):
    def __init__(self, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, num_heads=1, **kwargs):
        super().__init__()
        print(f'Unused args while creating Attention: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        self.has_global = sample_global.shape[1] != 0
        self.hid_size = hid_size
        gl_size = sample_global.shape[1] if self.has_global else 1
        # encoder
        self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)
        # GNN
        self.node_update = Make_MLP(2 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(2 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        # decoder
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)
        # attention
        self.multihead_attn = nn.MultiheadAttention(hid_size, num_heads, dropout=dropout, batch_first=True)
        self.queries = nn.Linear(hid_size, hid_size)
        self.keys = nn.Linear(hid_size, hid_size)
        self.values = nn.Linear(hid_size, hid_size)

    def forward(self, g, global_feats):
        h = self.node_encoder(g.ndata['features'])
        g.ndata['h'] = h
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        batch_num_nodes = None
        sum_weights = None
        if "w" in g.ndata:
            batch_indices = g.batch_num_nodes()
            # Find non-zero rows (non-padded nodes)
            non_padded_nodes_mask = torch.any(g.ndata['features'] != 0, dim=1)
            # Split the mask according to the batch indices
            batch_num_nodes = []
            start_idx = 0
            for num_nodes in batch_indices:
                end_idx = start_idx + num_nodes
                non_padded_count = non_padded_nodes_mask[start_idx:end_idx].sum().item()
                batch_num_nodes.append(non_padded_count)
                start_idx = end_idx
            batch_num_nodes = torch.tensor(batch_num_nodes, device=g.ndata['features'].device)
            sum_weights = batch_num_nodes[:, None].repeat(1, self.hid_size)
            global_feats = batch_num_nodes[:, None].to(torch.float)
        h_global = self.global_encoder(global_feats)
        h_original_shape = h.shape
        num_graphs = len(dgl.unbatch(g))
        num_nodes = g.batch_num_nodes()[0].item()
        padding_mask = g.ndata['padding_mask'] > 0
        padding_mask = torch.reshape(padding_mask, (num_graphs, num_nodes))
        h = g.ndata['h']
        query = self.queries(h)
        key = self.keys(h)
        value = self.values(h)
        query = torch.reshape(query, (num_graphs, num_nodes, h_original_shape[1]))
        key = torch.reshape(key, (num_graphs, num_nodes, h_original_shape[1]))
        value = torch.reshape(value, (num_graphs, num_nodes, h_original_shape[1]))
        h, _ = self.multihead_attn(query, key, value, key_padding_mask=padding_mask)
        h = torch.reshape(h, h_original_shape)
        h = self.node_update(torch.cat((h, broadcast_global_to_nodes(g, h_global)), dim=1))
        g.ndata['h'] = h
        # Weighted mean over non-padded nodes (requires the 'w' node weights set above)
        mean_nodes = dgl.sum_nodes(g, 'h', 'w') / sum_weights
        h_global = self.global_update(torch.cat((h_global, mean_nodes), dim=1))
        h_global = self.global_decoder(h_global)
        return self.classify(h_global)
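# The reshapes to (num_graphs, num_nodes, hid_size) assume every graph in the
# batch is padded to the same node count; 'padding_mask' flags the padded nodes
# so nn.MultiheadAttention excludes them via key_padding_mask.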
class Attention_Edge_Network(nn.Module):
    def __init__(self, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, num_heads=1, **kwargs):
        super().__init__()
        print(f'Unused args while creating Attention_Edge_Network: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1
        # encoder
        self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)
        # GNN
        self.node_update = Make_MLP(3 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_update = Make_MLP(4 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(3 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        # decoder
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)
        # attention
        self.multihead_attn = nn.MultiheadAttention(hid_size, num_heads, dropout=dropout, batch_first=True)
        self.queries = nn.Linear(hid_size, hid_size)
        self.keys = nn.Linear(hid_size, hid_size)
        self.values = nn.Linear(hid_size, hid_size)

    def forward(self, g, global_feats):
        h = self.node_encoder(g.ndata['features'])
        e = self.edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.global_encoder(global_feats)
        h = g.ndata['h']
        h_original_shape = h.shape
        num_graphs = len(dgl.unbatch(g))
        num_nodes = g.batch_num_nodes()[0].item()
        padding_mask = g.ndata['padding_mask'].bool()
        padding_mask = torch.reshape(padding_mask, (num_graphs, num_nodes))
        for i in range(self.n_proc_steps):
            h = g.ndata['h']
            query = self.queries(h)
            key = self.keys(h)
            value = self.values(h)
            query = torch.reshape(query, (num_graphs, num_nodes, h_original_shape[1]))
            key = torch.reshape(key, (num_graphs, num_nodes, h_original_shape[1]))
            value = torch.reshape(value, (num_graphs, num_nodes, h_original_shape[1]))
            h, _ = self.multihead_attn(query, key, value, key_padding_mask=padding_mask)
            h = torch.reshape(h, h_original_shape)
            # Write the attended node states back so the edge/node updates below see them
            g.ndata['h'] = h
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim=1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim=1))
            h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h', 'w'), dgl.mean_edges(g, 'e')), dim=1))
        h_global = self.global_decoder(h_global)
        return self.classify(h_global)
class Attention_Unbatched(nn.Module):
    def __init__(self, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, num_heads=1, **kwargs):
        super().__init__()
        print(f'Unused args while creating Attention_Unbatched: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1
        # encoder
        self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)
        # GNN
        self.node_update = Make_MLP(3 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_update = Make_MLP(4 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(3 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        # decoder
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)
        # attention (single head; runs per graph, so no padding mask is needed)
        self.multihead_attn = nn.MultiheadAttention(hid_size, 1, dropout=dropout)
        self.queries = nn.Linear(hid_size, hid_size)
        self.keys = nn.Linear(hid_size, hid_size)
        self.values = nn.Linear(hid_size, hid_size)

    def forward(self, g, global_feats):
        h = self.node_encoder(g.ndata['features'])
        e = self.edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.global_encoder(global_feats)
        for i in range(self.n_proc_steps):
            # Run self-attention one graph at a time to avoid cross-graph attention
            unbatched_g = dgl.unbatch(g)
            for graph in unbatched_g:
                h = graph.ndata['h']
                h, _ = self.multihead_attn(self.queries(h), self.keys(h), self.values(h))
                graph.ndata['h'] = h
            g = dgl.batch(unbatched_g)
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim=1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim=1))
            h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1))
        h_global = self.global_decoder(h_global)
        return self.classify(h_global)
class Transferred_Learning_Attention(nn.Module):
    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, num_heads, dropout=0, learning_rate=0.0001, **kwargs):
        super().__init__()
        print(f'Unused args while creating Transferred_Learning_Attention: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        self.has_global = sample_global.shape[1] != 0
        self.hid_size = hid_size
        gl_size = sample_global.shape[1] if self.has_global else 1
        self.learning_rate = learning_rate
        self.pretraining_params = []
        self.attention_params = []
        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
        checkpoint = torch.load(pretraining_path)
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        pretrained_layers = list(self.pretrained_model.children())
        pretrained_layers = pretrained_layers[:-1]
        self.pretrained_model = nn.Sequential(*pretrained_layers)
        # Only the pretrained node encoder, global encoder, and global decoder are reused
        self.pretraining_params.append(self.pretrained_model[1])
        self.pretraining_params.append(self.pretrained_model[3])
        self.pretraining_params.append(self.pretrained_model[7])
        # attention
        self.multihead_attn = nn.MultiheadAttention(hid_size, num_heads, dropout=dropout, batch_first=True)
        self.queries = nn.Linear(hid_size, hid_size)
        self.keys = nn.Linear(hid_size, hid_size)
        self.values = nn.Linear(hid_size, hid_size)
        self.node_update = Make_MLP(2 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(2 * hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(pretraining_model['args']['hid_size'], out_size)
        self.attention_params.append(self.multihead_attn)
        self.attention_params.append(self.queries)
        self.attention_params.append(self.keys)
        self.attention_params.append(self.values)
        self.attention_params.append(self.classify)
        self.attention_params.append(self.node_update)
        self.attention_params.append(self.global_update)

    def TL_node_encoder(self, x):
        for layer in self.pretrained_model[1]:
            x = layer(x)
        return x

    def TL_global_encoder(self, x):
        for layer in self.pretrained_model[3]:
            x = layer(x)
        return x

    def TL_global_decoder(self, x):
        for layer in self.pretrained_model[7]:
            x = layer(x)
        return x

    def forward(self, g, global_feats):
        h = self.TL_node_encoder(g.ndata['features'])
        g.ndata['h'] = h
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        batch_num_nodes = None
        sum_weights = None
        if "w" in g.ndata:
            batch_indices = g.batch_num_nodes()
            # Find non-zero rows (non-padded nodes)
            non_padded_nodes_mask = torch.any(g.ndata['features'] != 0, dim=1)
            # Split the mask according to the batch indices
            batch_num_nodes = []
            start_idx = 0
            for num_nodes in batch_indices:
                end_idx = start_idx + num_nodes
                non_padded_count = non_padded_nodes_mask[start_idx:end_idx].sum().item()
                batch_num_nodes.append(non_padded_count)
                start_idx = end_idx
            batch_num_nodes = torch.tensor(batch_num_nodes, device=g.ndata['features'].device)
            sum_weights = batch_num_nodes[:, None].repeat(1, self.hid_size)
            global_feats = batch_num_nodes[:, None].to(torch.float)
        h_global = self.TL_global_encoder(global_feats)
        h_original_shape = h.shape
        num_graphs = len(dgl.unbatch(g))
        num_nodes = g.batch_num_nodes()[0].item()
        padding_mask = g.ndata['padding_mask'] > 0
        padding_mask = torch.reshape(padding_mask, (num_graphs, num_nodes))
        h = g.ndata['h']
        query = self.queries(h)
        key = self.keys(h)
        value = self.values(h)
        query = torch.reshape(query, (num_graphs, num_nodes, h_original_shape[1]))
        key = torch.reshape(key, (num_graphs, num_nodes, h_original_shape[1]))
        value = torch.reshape(value, (num_graphs, num_nodes, h_original_shape[1]))
        h, _ = self.multihead_attn(query, key, value, key_padding_mask=padding_mask)
        h = torch.reshape(h, h_original_shape)
        h = self.node_update(torch.cat((h, broadcast_global_to_nodes(g, h_global)), dim=1))
        g.ndata['h'] = h
        mean_nodes = dgl.sum_nodes(g, 'h', 'w') / sum_weights
        h_global = self.global_update(torch.cat((h_global, mean_nodes), dim=1))
        h_global = self.TL_global_decoder(h_global)
        return self.classify(h_global)

    def parameters(self, recurse: bool = True):
        # Per-group learning rates: learning_rate may be a dict with
        # 'pretraining_lr' / 'attention_lr'; otherwise fall back to 1e-4.
        params = []
        for model_section in self.pretraining_params:
            if isinstance(self.learning_rate, dict) and self.learning_rate.get("pretraining_lr"):
                params.append({'params': model_section.parameters(), 'lr': self.learning_rate["pretraining_lr"]})
            else:
                params.append({'params': model_section.parameters(), 'lr': 0.0001})
        for model_section in self.attention_params:
            if isinstance(self.learning_rate, dict) and self.learning_rate.get("attention_lr"):
                params.append({'params': model_section.parameters(), 'lr': self.learning_rate["attention_lr"]})
            else:
                params.append({'params': model_section.parameters(), 'lr': 0.0001})
        return params
class Multimodel_Transferred_Learning(nn.Module):
def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, frozen_pretraining=True, learning_rate=None, **kwargs):
super().__init__()
print(f'Unused args while creating GCN: {kwargs}')
self.n_layers = n_layers
self.n_proc_steps = n_proc_steps
self.layers = nn.ModuleList()
self.has_global = sample_global.shape[1] != 0
gl_size = sample_global.shape[1] if self.has_global else 1
self.learning_rate = learning_rate
input_size = 0
self.pretraining_params = []
self.model_params = []
self.pretrained_models = []
for model_cfg, path in zip(pretraining_model, pretraining_path):
input_size += model_cfg['args']['hid_size']
model = utils.buildFromConfig(model_cfg, {'sample_graph': sample_graph, 'sample_global': sample_global})
# map_location keeps loading device-agnostic; move to GPU later via .to().
checkpoint = torch.load(path, map_location='cpu')['model_state_dict']
# Checkpoints saved from a DataParallel/DDP wrapper prefix keys with
# 'module.'; strip it so keys match the bare model.
new_state_dict = {k.replace('module.', ''): v for k, v in checkpoint.items()}
model.load_state_dict(new_state_dict)
# Drop the final classification head; only the representation layers are reused.
pretrained_layers = list(model.children())[:-1]
model = nn.Sequential(*pretrained_layers)
print(f"Freeze Pretraining = {frozen_pretraining}")
if frozen_pretraining:
for param in model.parameters():
param.requires_grad = False  # Freeze all pretrained layers
self.pretraining_params.append(model)
self.pretrained_models.append(model)
print(f"len(pretrained_models) = {len(self.pretrained_models)}")
print(f"input size = {input_size}")
self.final_mlp = Make_MLP(input_size, hid_size, hid_size, n_layers, dropout=dropout)
self.classify = nn.Linear(hid_size, out_size)
self.model_params.append(self.final_mlp)
self.model_params.append(self.classify)
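# Indexing convention for the wrapped pretrained nn.Sequential (classify head
# already dropped): [1]=node_encoder, [2]=edge_encoder, [3]=global_encoder,
# [4]=node_update, [5]=edge_update, [6]=global_update, [7]=global_decoder.
# Some checkpoints wrap these one level deeper, hence the fallback indexing
# in the helper below.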
def _run_pretrained_section(self, x, model_idx, section_idx):
try:
for layer in self.pretrained_models[model_idx][section_idx]:
x = layer(x)
return x
except (NotImplementedError, IndexError):
for layer in self.pretrained_models[model_idx][1][section_idx]:
x = layer(x)
return x
def TL_node_encoder(self, x, model_idx):
return self._run_pretrained_section(x, model_idx, 1)
def TL_edge_encoder(self, x, model_idx):
return self._run_pretrained_section(x, model_idx, 2)
def TL_global_encoder(self, x, model_idx):
return self._run_pretrained_section(x, model_idx, 3)
def TL_node_update(self, x, model_idx):
return self._run_pretrained_section(x, model_idx, 4)
def TL_edge_update(self, x, model_idx):
return self._run_pretrained_section(x, model_idx, 5)
def TL_global_update(self, x, model_idx):
return self._run_pretrained_section(x, model_idx, 6)
def TL_global_decoder(self, x, model_idx):
return self._run_pretrained_section(x, model_idx, 7)
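# Runs one frozen pretrained tower end to end (encode + n_proc_steps of
# message passing) and returns its global embedding before the decoder.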
def Pretrained_Output(self, g, global_feats, model_idx):
h = self.TL_node_encoder(g.ndata['features'], model_idx)
e = self.TL_edge_encoder(g.edata['features'], model_idx)
g.ndata['h'] = h
g.edata['e'] = e
if not self.has_global:
# No real global features: fall back to per-graph node counts.
global_feats = g.batch_num_nodes()[:, None].to(torch.float)
h_global = self.TL_global_encoder(global_feats, model_idx)
for i in range(self.n_proc_steps):
g.apply_edges(dgl.function.copy_u('h', 'm_u'))
g.apply_edges(copy_v)
g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1), model_idx)
g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1), model_idx)
h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1), model_idx)
# h_global = self.TL_global_decoder(h_global, model_idx)
return h_global
def forward(self, g, global_feats):
h_global = []
for i in range(len(self.pretrained_models)):
# Clone so each frozen tower message-passes on its own copy of the graph.
h_global.append(self.Pretrained_Output(g.clone(), global_feats, i))
h_global = torch.cat(h_global, dim=1)
return self.classify(self.final_mlp(h_global))
def to(self, device):
# nn.Module.to moves the registered submodules (final_mlp, classify); the
# pretrained towers live in a plain list, so move them explicitly.
super().to(device)
for pretrained in self.pretrained_models:
pretrained.to(device)
return self
def parameters(self, recurse: bool = True):
params = []
for model_section in self.pretraining_params:
if isinstance(self.learning_rate, dict) and self.learning_rate.get("pretraining_lr"):
params.append({'params': model_section.parameters(), 'lr': self.learning_rate["pretraining_lr"]})
else:
params.append({'params': model_section.parameters(), 'lr': 0.00001})
for model_section in self.model_params:
if isinstance(self.learning_rate, dict) and self.learning_rate.get("model_lr"):
params.append({'params': model_section.parameters(), 'lr': self.learning_rate["model_lr"]})
else:
params.append({'params': model_section.parameters(), 'lr': 0.0001})
return params
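# Construction sketch (hedged; the config shape is inferred from the
# 'args'/'hid_size' access above, and buildFromConfig is assumed to accept
# these dicts):
#   pretraining_model = [{'args': {'hid_size': 64, ...}, ...}, ...]
#   pretraining_path = ['checkpoints/tower_a.pt', 'checkpoints/tower_b.pt']
#   tl = Multimodel_Transferred_Learning(pretraining_path, pretraining_model,
#        sample_graph, sample_global, hid_size=64, out_size=1,
#        n_layers=2, n_proc_steps=3)
# Each checkpoint is expected to hold {'model_state_dict': ...}, possibly
# saved from a DataParallel/DDP wrapper (hence the 'module.' stripping).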
class MultiModel(nn.Module):
def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, frozen_pretraining=True, learning_rate=None, **kwargs):
super().__init__()
print(f'Unused args while creating MultiModel: {kwargs}')
self.n_layers = n_layers
self.n_proc_steps = n_proc_steps
self.layers = nn.ModuleList()
self.has_global = len(sample_global) != 0 and sample_global.shape[1] != 0
gl_size = sample_global.shape[1] if self.has_global else 1
self.learning_rate = learning_rate
input_size = 0
self.model_params = []
self.pretraining_params = []
self.pretrained_models = []
for model_cfg, path in zip(pretraining_model, pretraining_path):
input_size += model_cfg['args']['hid_size']
model = utils.buildFromConfig(model_cfg, {'sample_graph': sample_graph, 'sample_global': sample_global})
# map_location keeps loading device-agnostic; move to GPU later via .to().
checkpoint = torch.load(path, map_location='cpu')['model_state_dict']
# Checkpoints saved from a DataParallel/DDP wrapper prefix keys with
# 'module.'; strip it so keys match the bare model.
new_state_dict = {k.replace('module.', ''): v for k, v in checkpoint.items()}
model.load_state_dict(new_state_dict)
# Drop the final classification head; only the representation layers are reused.
pretrained_layers = list(model.children())[:-1]
model = nn.Sequential(*pretrained_layers)
print(f"Freeze Pretraining = {frozen_pretraining}")
if frozen_pretraining:
for param in model.parameters():
param.requires_grad = False  # Freeze all pretrained layers
self.pretraining_params.append(model)
self.pretrained_models.append(model)
print(f"len(pretrained_models) = {len(self.pretrained_models)}")
print(f"input size = {input_size}")
#encoder
self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)
#GNN
self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.final_mlp = Make_MLP(input_size + hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.classify = nn.Linear(hid_size, out_size)
self.model_params.append(self.final_mlp)
self.model_params.append(self.classify)
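# MultiModel trains its own encode-process block from scratch and fuses its
# global embedding with the frozen pretrained towers' embeddings through
# final_mlp. The TL_* helpers below mirror Multimodel_Transferred_Learning
# and use the same positional-indexing convention with a nested fallback.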
def _run_pretrained_section(self, x, model_idx, section_idx):
try:
for layer in self.pretrained_models[model_idx][section_idx]:
x = layer(x)
return x
except (NotImplementedError, IndexError):
for layer in self.pretrained_models[model_idx][1][section_idx]:
x = layer(x)
return x
def TL_node_encoder(self, x, model_idx):
return self._run_pretrained_section(x, model_idx, 1)
def TL_edge_encoder(self, x, model_idx):
return self._run_pretrained_section(x, model_idx, 2)
def TL_global_encoder(self, x, model_idx):
return self._run_pretrained_section(x, model_idx, 3)
def TL_node_update(self, x, model_idx):
return self._run_pretrained_section(x, model_idx, 4)
def TL_edge_update(self, x, model_idx):
return self._run_pretrained_section(x, model_idx, 5)
def TL_global_update(self, x, model_idx):
return self._run_pretrained_section(x, model_idx, 6)
def TL_global_decoder(self, x, model_idx):
return self._run_pretrained_section(x, model_idx, 7)
def Pretrained_Output(self, g, global_feats, model_idx):
h = self.TL_node_encoder(g.ndata['features'], model_idx)
e = self.TL_edge_encoder(g.edata['features'], model_idx)
g.ndata['h'] = h
g.edata['e'] = e
if not self.has_global:
global_feats = g.batch_num_nodes()[:, None].to(torch.float)
h_global = self.TL_global_encoder(global_feats, model_idx)
for i in range(self.n_proc_steps):
g.apply_edges(dgl.function.copy_u('h', 'm_u'))
g.apply_edges(copy_v)
g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1), model_idx)
g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1), model_idx)
h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1), model_idx)
# h_global = self.TL_global_decoder(h_global, model_idx)
return h_global
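# forward: run the trainable GN block, then append each pretrained tower's
# global embedding (each on its own clone of the graph) and classify the
# concatenation.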
def forward(self, g, global_feats):
h = self.node_encoder(g.ndata['features'])
e = self.edge_encoder(g.edata['features'])
g.ndata['h'] = h
g.edata['e'] = e
if not self.has_global:
global_feats = g.batch_num_nodes()[:, None].to(torch.float)
h_global = self.global_encoder(global_feats)
for i in range(self.n_proc_steps):
g.apply_edges(dgl.function.copy_u('h', 'm_u'))
g.apply_edges(copy_v)
g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
h_global = [h_global]
for i in range(len(self.pretrained_models)):
h_global.append(self.Pretrained_Output(g.clone(), global_feats, i))
h_global = torch.cat(h_global, dim=1)
return self.classify(self.final_mlp(h_global))
def to(self, device):
# nn.Module.to moves all registered submodules (encoders, updates,
# final_mlp, classify); the pretrained towers live in a plain list, so
# move them explicitly.
super().to(device)
for pretrained in self.pretrained_models:
pretrained.to(device)
return self
def parameters(self, recurse: bool = True):
params = []
for i, model_section in enumerate(self.pretraining_params):
if isinstance(self.learning_rate, dict) and self.learning_rate.get("pretraining_lr"):
print(f"Pretraining LR = {self.learning_rate['pretraining_lr'][i]}")
params.append({'params': model_section.parameters(), 'lr': self.learning_rate["pretraining_lr"][i]})
else:
print(f"Pretraining LR = 0.00001")
params.append({'params': model_section.parameters(), 'lr': 0.00001})
for model_section in self.model_params:
if isinstance(self.learning_rate, dict) and self.learning_rate.get("model_lr"):
print(f"Model LR = {self.learning_rate['model_lr']}")
params.append({'params': model_section.parameters(), 'lr': self.learning_rate["model_lr"]})
else:
print(f"Model LR = 0.0001")
params.append({'params': model_section.parameters(), 'lr': 0.0001})
return params
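# Example (sketch; note that here "pretraining_lr" is a per-tower list,
# unlike the scalar used by the classes above):
#   model.learning_rate = {"pretraining_lr": [1e-5, 2e-5], "model_lr": 1e-4}
#   optimizer = torch.optim.Adam(model.parameters())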
class Clustering(nn.Module):
def __init__(self, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
super().__init__()
print(f'Unused args while creating Clustering: {kwargs}')
self.n_layers = n_layers
self.n_proc_steps = n_proc_steps
self.layers = nn.ModuleList()
self.hid_size = hid_size
self.has_global = len(sample_global) != 0 and sample_global.shape[1] != 0
gl_size = sample_global.shape[1] if self.has_global else 1
#encoder
self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)
#GNN
self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
#decoder
self.global_decoder = Make_MLP(hid_size, hid_size, out_size, n_layers, dropout=dropout)
def model_forward(self, g, global_feats, features = 'features'):
h = self.node_encoder(g.ndata[features])
e = self.edge_encoder(g.edata[features])
g.ndata['h'] = h
g.edata['e'] = e
if not self.has_global:
global_feats = g.batch_num_nodes()[:, None].to(torch.float)
batch_num_nodes = None
sum_weights = None
if "w" in g.ndata:
batch_indices = g.batch_num_nodes()
# Find non-zero rows (non-padded nodes)
non_padded_nodes_mask = torch.any(g.ndata[features] != 0, dim=1)
# Split the mask according to the batch indices
batch_num_nodes = []
start_idx = 0
for num_nodes in batch_indices:
end_idx = start_idx + num_nodes
non_padded_count = non_padded_nodes_mask[start_idx:end_idx].sum().item()
batch_num_nodes.append(non_padded_count)
start_idx = end_idx
batch_num_nodes = torch.tensor(batch_num_nodes, device = g.ndata[features].device)
sum_weights = batch_num_nodes[:, None].repeat(1, self.hid_size)
global_feats = batch_num_nodes[:, None].to(torch.float)
h_global = self.global_encoder(global_feats)
for i in range(self.n_proc_steps):
g.apply_edges(dgl.function.copy_u('h', 'm_u'))
g.apply_edges(copy_v)
g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
if "w" in g.ndata:
mean_nodes = dgl.sum_nodes(g, 'h', 'w') / sum_weights
h_global = self.global_update(torch.cat((h_global, mean_nodes, dgl.mean_edges(g, 'e')), dim = 1))
else:
h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
h_global = self.global_decoder(h_global)
return h_global
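# forward/representation run the same weights over two views of each graph
# ('features' and 'augmented_features'), the usual setup for a contrastive
# clustering objective.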
def forward(self, g, global_feats):
h_global = self.model_forward(g, global_feats, 'features')
h_global_augmented = self.model_forward(g, global_feats, 'augmented_features')
return torch.cat((h_global, h_global_augmented), dim=1)
def representation(self, g, global_feats):
h_global = self.model_forward(g, global_feats, 'features')
h_global_augmented = self.model_forward(g, global_feats, 'augmented_features')
return h_global, h_global_augmented, torch.cat((h_global, h_global_augmented), dim=1)
def __str__(self):
layer_names = ["node_encoder", "edge_encoder", "global_encoder",
"node_update", "edge_update", "global_update", "global_decoder"]
layers = [self.node_encoder, self.edge_encoder, self.global_encoder,
self.node_update, self.edge_update, self.global_update, self.global_decoder]
for i in range(len(layers)):
print(layer_names[i])
for layer in layers[i].children():
if isinstance(layer, nn.Linear):
print(layer.state_dict())
print("classify")
print(self.classify.weight)
return ""