| import dgl |
| import dgl.nn as dglnn |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| import sys |
| import os |
| file_path = os.getcwd() |
| sys.path.append(file_path) |
|
|
| import root_gnn_base.dataset as datasets |
| from root_gnn_base import utils |
|
|
| import gc |
|
|
def Make_SLP(in_size, out_size, activation = nn.ReLU, dropout = 0):
    """Build one Linear -> activation -> Dropout block.

    Returned as a plain list (not an nn.Sequential) so callers can
    concatenate several blocks before wrapping them.
    """
    return [nn.Linear(in_size, out_size), activation(), nn.Dropout(dropout)]
|
|
def Make_MLP(in_size, hid_size, out_size, n_layers, activation = nn.ReLU, dropout = 0):
    """Assemble an MLP of `n_layers` Linear->activation->Dropout blocks,
    finished with a LayerNorm over the output width.

    Width schedule: in_size -> hid_size (n_layers - 1 times) -> out_size;
    with n_layers <= 1 a single in_size -> out_size block is used.
    """
    def _block(fan_in, fan_out):
        # Same layout as Make_SLP: Linear -> activation -> Dropout.
        return [nn.Linear(fan_in, fan_out), activation(), nn.Dropout(dropout)]

    if n_layers > 1:
        widths = [in_size] + [hid_size] * (n_layers - 1) + [out_size]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += _block(fan_in, fan_out)
    else:
        layers = _block(in_size, out_size)
    layers.append(nn.LayerNorm(out_size))
    return nn.Sequential(*layers)
|
|
class MLP(nn.Module):
    """Multi-layer perceptron: a Make_MLP trunk plus a final Linear head.

    The trunk is built with n_layers - 1 blocks so that the closing Linear
    brings the total depth to n_layers.
    """

    def __init__(self, in_size, hid_size, out_size, n_layers, activation = nn.ReLU, dropout = 0, **kwargs):
        super().__init__()
        # Surface configuration keys that are silently ignored.
        print(f'Unused args while creating MLP: {kwargs}')
        self.layers = Make_MLP(in_size, hid_size, hid_size, n_layers-1, activation, dropout)
        self.linear = nn.Linear(hid_size, out_size)

    def forward(self, x):
        hidden = self.layers(x)
        return self.linear(hidden)
|
|
def broadcast_global_to_nodes(g, globals):
    """Tile per-graph global features onto nodes: row i of `globals` is
    repeated once for every node of the i-th graph in the batch."""
    return globals.repeat_interleave(g.batch_num_nodes(), dim=0)
|
|
def broadcast_global_to_edges(g, globals):
    """Tile per-graph global features onto edges: row i of `globals` is
    repeated once for every edge of the i-th graph in the batch."""
    return globals.repeat_interleave(g.batch_num_edges(), dim=0)
|
|
def copy_v(edges):
    """DGL edge UDF: expose each edge's destination-node 'h' as message 'm_v'."""
    return dict(m_v=edges.dst['h'])
|
|
def partial_reset(model : nn.Module):
    """Re-initialise only the final `classify` head of `model`, in place.

    The replacement nn.Linear keeps the original in/out dimensions and
    device. A fixed RNG seed makes the re-initialisation reproducible
    across calls.
    """
    head = model.classify
    # Use the documented Linear attributes instead of probing weight rows.
    in_size = head.in_features
    out_size = head.out_features
    device = next(head.parameters()).device
    torch.manual_seed(2)  # intentional fixed seed: deterministic re-init
    model.classify = nn.Linear(in_size, out_size)
    model.classify.to(device)
    # NOTE(review): debug output kept for backward compatibility; consider logging.
    print(model.classify.weight)
|
|
def print_model(model: nn.Module):
    """Dump the full module hierarchy of `model` to stdout."""
    print(str(model))
|
|
def print_mlp(layer):
    """Print every child of `layer`: weights (state_dict) for Linear
    children, the module repr otherwise."""
    for child in layer.children():
        shown = child.state_dict() if isinstance(child, nn.Linear) else child
        print(shown)
|
|
|
|
| |
def full_reset(model : nn.Module):
    """Re-initialise every encoder/update/decoder MLP of `model` in place,
    then rebuild the `classify` head via partial_reset.

    Assumes `model` exposes the Edge_Network-style attribute set.
    """
    for mlp in (model.node_encoder, model.edge_encoder, model.global_encoder,
                model.node_update, model.edge_update, model.global_update,
                model.global_decoder):
        for layer in mlp.children():
            # Only layers that define reset_parameters (e.g. Linear) are reset.
            reset = getattr(layer, 'reset_parameters', None)
            if reset is not None:
                reset()
    partial_reset(model)
|
|
|
|
class GCN(nn.Module):
    """Graph classifier: Linear stack -> GraphConv stack -> Linear stack,
    ReLU after every layer, then mean-pooled node states into a Linear head.

    Layer indices inside `self.layers` (n = n_layers):
      0          : input Linear
      1 .. n     : hidden Linear
      n+1 .. 2n  : GraphConv (these take the graph as first argument)
      2n+1 .. 3n : hidden Linear
    """

    def __init__(self, in_size, hid_size, out_size, n_layers, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.layers = nn.ModuleList()
        # Registration order must match the index bands documented above.
        self.layers.append(nn.Linear(in_size, hid_size))
        for _ in range(n_layers):
            self.layers.append(nn.Linear(hid_size, hid_size))
        for _ in range(n_layers):
            self.layers.append(dglnn.GraphConv(hid_size, hid_size))
        for _ in range(n_layers):
            self.layers.append(nn.Linear(hid_size, hid_size))
        self.classify = nn.Linear(hid_size, out_size)

    def forward(self, g):
        h = g.ndata['features']
        for i, layer in enumerate(self.layers):
            # Only the GraphConv band (indices n+1 .. 2n) receives the graph.
            is_conv = self.n_layers < i <= 2 * self.n_layers
            h = layer(g, h) if is_conv else layer(h)
            h = F.relu(h)
        with g.local_scope():
            g.ndata['h'] = h
            pooled = dgl.mean_nodes(g, 'h')
            return self.classify(pooled)
| |
class GCN_global(nn.Module):
    """GCN variant that keeps a per-graph global state alongside node states.

    The global state is seeded from each graph's node count and refined every
    round from the mean node embedding.
    """

    def __init__(self, in_size, hid_size=4, out_size=1, n_layers=1, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers

        # Encoders. Attribute assignment order is kept unchanged so the
        # child-module registration order (and state_dict keys) stay stable.
        self.node_encoder = Make_MLP(in_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(1, hid_size, hid_size, n_layers, dropout=dropout)

        # Per-round update networks.
        self.node_update = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.conv = dglnn.GraphConv(hid_size, hid_size)

        # Readout.
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

    def forward(self, g):
        h = self.node_encoder(g.ndata['features'])
        # Seed the global state with each graph's node count.
        node_counts = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.global_encoder(node_counts)
        for _ in range(self.n_layers):
            h = self.conv(g, self.node_update(h))
            g.ndata['h'] = h
            pooled = dgl.mean_nodes(g, 'h')
            h_global = self.global_update(torch.cat((h_global, pooled), dim=1))
        return self.classify(self.global_decoder(h_global))
| |
class GCN_global_2way(nn.Module):
    """Like GCN_global, but the information also flows global -> node: each
    node update sees its graph's global state broadcast alongside its own."""

    def __init__(self, in_size, hid_size=4, out_size=1, n_layers=1, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers

        # Encoders. Assignment order kept unchanged (stable state_dict keys).
        self.node_encoder = Make_MLP(in_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(1, hid_size, hid_size, n_layers, dropout=dropout)

        # Per-round update networks; node_update input is node state + global.
        self.node_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.conv = dglnn.GraphConv(hid_size, hid_size)

        # Readout.
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

    def forward(self, g):
        h = self.node_encoder(g.ndata['features'])
        # Seed the global state with each graph's node count.
        node_counts = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.global_encoder(node_counts)
        for _ in range(self.n_layers):
            per_node_global = broadcast_global_to_nodes(g, h_global)
            h = self.node_update(torch.cat((h, per_node_global), dim=1))
            h = self.conv(g, h)
            g.ndata['h'] = h
            pooled = dgl.mean_nodes(g, 'h')
            h_global = self.global_update(torch.cat((h_global, pooled), dim=1))
        return self.classify(self.global_decoder(h_global))
|
|
class Edge_Network(nn.Module):
    """Graph network over nodes, edges, and a per-graph global state.

    Pipeline: encode node/edge/global features with MLPs, run `n_proc_steps`
    rounds of message passing (edge update -> node update -> global update),
    then decode the global state into `out_size` logits.

    The child-module registration order (layers, node_encoder, edge_encoder,
    global_encoder, node_update, edge_update, global_update, global_decoder,
    classify) is part of this class's de-facto contract: the
    Transferred_Learning* wrappers address the children positionally.
    Do not reorder the attribute assignments in __init__.
    """

    def __init__(self, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        # Unused, but kept so the positional child-module order stays stable.
        self.layers = nn.ModuleList()
        if (len(sample_global) == 0):
            self.has_global = False
        else:
            self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1

        # Feature encoders; input widths come from the sample graph/globals.
        self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)

        # Update networks; inputs are concatenations, hence the 3x/4x widths.
        self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        # Readout.
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

    def _real_node_counts(self, g):
        """Per-graph count of non-padded nodes.

        A node is treated as padding when its feature row is all zeros
        (assumed padding convention — TODO confirm with the dataset code).
        """
        real = torch.any(g.ndata['features'] != 0, dim=1)
        counts = []
        start = 0
        for num_nodes in g.batch_num_nodes():
            end = start + num_nodes
            counts.append(real[start:end].sum().item())
            start = end
        return torch.tensor(counts, device=g.ndata['features'].device)

    def _process(self, g, global_feats):
        """Encode features and run message passing; return the per-graph
        global state *before* the decoder.

        Mutates g.ndata / g.edata in place. Shared by forward() and
        representation(), which previously duplicated this code verbatim.
        """
        g.ndata['h'] = self.node_encoder(g.ndata['features'])
        g.edata['e'] = self.edge_encoder(g.edata['features'])
        if not self.has_global:
            # No real globals: seed the global state with the node count.
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)

        sum_weights = None
        if "w" in g.ndata:
            # Weighted graphs: pool with node weights, normalising by the
            # number of non-padded nodes. The (B, 1) column broadcasts
            # against the (B, hid_size) weighted sum. (The previous code
            # tiled it to a hard-coded width of 64, which silently assumed
            # hid_size == 64.)
            counts = self._real_node_counts(g)
            sum_weights = counts[:, None]
            global_feats = counts[:, None].to(torch.float)

        h_global = self.global_encoder(global_feats)
        for _ in range(self.n_proc_steps):
            # Edge update: own state + both endpoint states + broadcast global.
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
            # Node update: own state + summed incoming edges + broadcast global.
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
            # Global update: previous global + pooled nodes + pooled edges.
            if sum_weights is not None:
                pooled_nodes = dgl.sum_nodes(g, 'h', 'w') / sum_weights
            else:
                pooled_nodes = dgl.mean_nodes(g, 'h')
            h_global = self.global_update(torch.cat((h_global, pooled_nodes, dgl.mean_edges(g, 'e')), dim = 1))
        return h_global

    def forward(self, g, global_feats):
        """Return graph-level logits of shape (batch, out_size)."""
        return self.classify(self.global_decoder(self._process(g, global_feats)))

    def representation(self, g, global_feats):
        """Return (pre-decoder global, post-decoder global, logits)."""
        before_global_decoder = self._process(g, global_feats)
        after_global_decoder = self.global_decoder(before_global_decoder)
        after_classify = self.classify(after_global_decoder)
        return before_global_decoder, after_global_decoder, after_classify

    def __str__(self):
        # NOTE(review): prints weights as a side effect and returns "" so that
        # print(model) shows nothing extra; kept for backward compatibility.
        layer_names = ["node_encoder", "edge_encoder", "global_encoder",
                       "node_update", "edge_update", "global_update", "global_decoder"]
        mlps = [self.node_encoder, self.edge_encoder, self.global_encoder,
                self.node_update, self.edge_update, self.global_update, self.global_decoder]
        for name, mlp in zip(layer_names, mlps):
            print(name)
            for layer in mlp.children():
                if isinstance(layer, nn.Linear):
                    print(layer.state_dict())
        print("classify")
        print(self.classify.weight)
        return ""
|
|
class Transferred_Learning(nn.Module):
    """Frozen pretrained Edge_Network-style backbone plus a fresh trainable
    decoder and classification head.

    The backbone's child modules are addressed positionally (index 1..7 =
    node/edge/global encoder, node/edge/global update, global decoder), so
    the pretrained model's attribute order must match Edge_Network's.
    """

    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()

        if (len(sample_global) == 0):
            self.has_global = False
        else:
            self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1

        # Rebuild the pretrained architecture, load its weights, and drop the
        # final classification head.
        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
        checkpoint = torch.load(pretraining_path)
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        backbone = list(self.pretrained_model.children())[:-1]
        self.pretrained_model = nn.Sequential(*backbone)

        # Freeze the backbone.
        for param in self.pretrained_model.parameters():
            param.requires_grad = False

        # Trainable readout on top of the backbone's global embedding.
        self.global_decoder = Make_MLP(pretraining_model['args']['hid_size'], hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

    def _run_stage(self, index, x):
        # Apply one frozen backbone MLP (selected by child index) layer by layer.
        for layer in self.pretrained_model[index]:
            x = layer(x)
        return x

    def TL_node_encoder(self, x):
        return self._run_stage(1, x)

    def TL_edge_encoder(self, x):
        return self._run_stage(2, x)

    def TL_global_encoder(self, x):
        return self._run_stage(3, x)

    def TL_node_update(self, x):
        return self._run_stage(4, x)

    def TL_edge_update(self, x):
        return self._run_stage(5, x)

    def TL_global_update(self, x):
        return self._run_stage(6, x)

    def TL_global_decoder(self, x):
        return self._run_stage(7, x)

    def forward(self, g, global_feats):
        g.ndata['h'] = self.TL_node_encoder(g.ndata['features'])
        g.edata['e'] = self.TL_edge_encoder(g.edata['features'])
        if not self.has_global:
            # No real globals: seed the global state with the node count.
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.TL_global_encoder(global_feats)
        for _ in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            edge_in = torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1)
            g.edata['e'] = self.TL_edge_update(edge_in)
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            node_in = torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1)
            g.ndata['h'] = self.TL_node_update(node_in)
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
        h_global = self.TL_global_decoder(h_global)
        return self.classify(self.global_decoder(h_global))
|
|
class Transferred_Learning_Graph(nn.Module):
    """Frozen pretrained backbone for the first `n_proc_steps` message-passing
    rounds, followed by freshly trained update networks for
    `additional_proc_steps` extra rounds.

    Backbone children are addressed positionally (index 1..6 = node/edge/
    global encoder, node/edge/global update), so the pretrained model's
    attribute order must match Edge_Network's.
    """

    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, additional_proc_steps=1, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()

        if (len(sample_global) == 0):
            self.has_global = False
        else:
            self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1

        # Rebuild + load the pretrained model, then drop its classification head.
        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
        checkpoint = torch.load(pretraining_path)
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        self.pretrained_model = nn.Sequential(*list(self.pretrained_model.children())[:-1])

        self.additional_proc_steps = additional_proc_steps

        # Freeze the backbone.
        for param in self.pretrained_model.parameters():
            param.requires_grad = False

        # Fresh, trainable update networks for the additional rounds.
        self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        # Readout.
        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

    def _run_stage(self, index, x):
        # Apply one frozen backbone MLP (selected by child index) layer by layer.
        for layer in self.pretrained_model[index]:
            x = layer(x)
        return x

    def TL_node_encoder(self, x):
        return self._run_stage(1, x)

    def TL_edge_encoder(self, x):
        return self._run_stage(2, x)

    def TL_global_encoder(self, x):
        return self._run_stage(3, x)

    def TL_node_update(self, x):
        return self._run_stage(4, x)

    def TL_edge_update(self, x):
        return self._run_stage(5, x)

    def TL_global_update(self, x):
        return self._run_stage(6, x)

    def forward(self, g, global_feats):
        g.ndata['h'] = self.TL_node_encoder(g.ndata['features'])
        g.edata['e'] = self.TL_edge_encoder(g.edata['features'])
        if not self.has_global:
            # No real globals: seed the global state with the node count.
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.TL_global_encoder(global_feats)
        # Frozen rounds (pretrained update networks).
        for _ in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
        # Trainable rounds (freshly initialised update networks).
        for _ in range(self.additional_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
            h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))

        return self.classify(self.global_decoder(h_global))
|
|
class Transferred_Learning_Parallel(nn.Module):
    """Runs a frozen pretrained backbone and a fresh graph network in
    parallel; the classifier sees both global embeddings concatenated.

    Backbone children are addressed positionally (index 1..7 = node/edge/
    global encoder, node/edge/global update, global decoder), so the
    pretrained model's attribute order must match Edge_Network's.
    """

    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        # Guard empty sample_global (consistent with Edge_Network).
        if (len(sample_global) == 0):
            self.has_global = False
        else:
            self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1

        # Rebuild the pretrained architecture, load weights, drop its head.
        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
        checkpoint = torch.load(pretraining_path)
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        self.pretrained_model = nn.Sequential(*list(self.pretrained_model.children())[:-1])

        # Freeze the backbone.
        for param in self.pretrained_model.parameters():
            param.requires_grad = False

        # Fresh, trainable parallel network (same shape as Edge_Network).
        self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)

        self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        # Head input: fresh global embedding + frozen backbone embedding.
        self.classify = nn.Linear(hid_size + pretraining_model['args']['hid_size'], out_size)

    def _run_stage(self, index, x):
        # Apply one frozen backbone MLP (selected by child index) layer by layer.
        for layer in self.pretrained_model[index]:
            x = layer(x)
        return x

    def TL_node_encoder(self, x):
        return self._run_stage(1, x)

    def TL_edge_encoder(self, x):
        return self._run_stage(2, x)

    def TL_global_encoder(self, x):
        return self._run_stage(3, x)

    def TL_node_update(self, x):
        return self._run_stage(4, x)

    def TL_edge_update(self, x):
        return self._run_stage(5, x)

    def TL_global_update(self, x):
        return self._run_stage(6, x)

    def TL_global_decoder(self, x):
        return self._run_stage(7, x)

    def Pretrained_Output(self, g, global_feats=None):
        """Global embedding of `g` from the frozen backbone.

        BUGFIX: `global_feats` used to be read while never assigned whenever
        self.has_global was True (NameError at runtime); it is now an
        explicit optional argument supplied by forward().
        """
        g.ndata['h'] = self.TL_node_encoder(g.ndata['features'])
        g.edata['e'] = self.TL_edge_encoder(g.edata['features'])
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.TL_global_encoder(global_feats)
        for _ in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
        return self.TL_global_decoder(h_global)

    def forward(self, g, global_feats):
        """Logits from the concatenated (frozen backbone, fresh network) globals."""
        # Clone so the frozen pass does not clobber g.ndata/edata for the
        # trainable pass below.
        pretrained_global = self.Pretrained_Output(g.clone(), global_feats)
        g.ndata['h'] = self.node_encoder(g.ndata['features'])
        g.edata['e'] = self.edge_encoder(g.edata['features'])
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.global_encoder(global_feats)
        for _ in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
            h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
        h_global = self.global_decoder(h_global)

        return self.classify(torch.cat((pretrained_global, h_global), dim = 1))
| |
class Transferred_Learning_Sequential(nn.Module):
    """Frozen pretrained backbone followed sequentially by a trainable MLP
    and a linear classification head.

    Backbone children are addressed positionally (index 1..7 = node/edge/
    global encoder, node/edge/global update, global decoder), so the
    pretrained model's attribute order must match Edge_Network's.
    """

    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        # Guard empty sample_global (consistent with Edge_Network).
        if (len(sample_global) == 0):
            self.has_global = False
        else:
            self.has_global = sample_global.shape[1] != 0

        # Rebuild the pretrained architecture, load weights, drop its head.
        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
        checkpoint = torch.load(pretraining_path)
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        self.pretrained_model = nn.Sequential(*list(self.pretrained_model.children())[:-1])

        # Freeze the backbone.
        for param in self.pretrained_model.parameters():
            param.requires_grad = False

        # Trainable head on top of the backbone's global embedding.
        self.mlp = Make_MLP(pretraining_model['args']['hid_size'], hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

    def _run_stage(self, index, x):
        # Apply one frozen backbone MLP (selected by child index) layer by layer.
        for layer in self.pretrained_model[index]:
            x = layer(x)
        return x

    def TL_node_encoder(self, x):
        return self._run_stage(1, x)

    def TL_edge_encoder(self, x):
        return self._run_stage(2, x)

    def TL_global_encoder(self, x):
        return self._run_stage(3, x)

    def TL_node_update(self, x):
        return self._run_stage(4, x)

    def TL_edge_update(self, x):
        return self._run_stage(5, x)

    def TL_global_update(self, x):
        return self._run_stage(6, x)

    def TL_global_decoder(self, x):
        return self._run_stage(7, x)

    def Pretrained_Output(self, g, global_feats=None):
        """Global embedding of `g` from the frozen backbone.

        BUGFIX: `global_feats` used to be read while never assigned whenever
        self.has_global was True (NameError at runtime); it is now an
        explicit optional argument supplied by forward().
        """
        g.ndata['h'] = self.TL_node_encoder(g.ndata['features'])
        g.edata['e'] = self.TL_edge_encoder(g.edata['features'])
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.TL_global_encoder(global_feats)
        for _ in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
        return self.TL_global_decoder(h_global)

    def forward(self, g, global_feats):
        """Logits from the trainable MLP over the frozen global embedding."""
        # Clone so the frozen pass does not clobber the caller's graph features.
        pretrained_global = self.Pretrained_Output(g.clone(), global_feats)
        return self.classify(self.mlp(pretrained_global))
|
|
|
|
class Transferred_Learning_Message_Passing(nn.Module):
    """Trainable MLP + head over the *concatenation* of the frozen backbone's
    global state after every message-passing round
    (width = pretrained hid_size * pretrained n_proc_steps).

    Backbone children are addressed positionally (index 1..7 = node/edge/
    global encoder, node/edge/global update, global decoder), so the
    pretrained model's attribute order must match Edge_Network's.
    """

    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        # Guard empty sample_global (consistent with Edge_Network).
        if (len(sample_global) == 0):
            self.has_global = False
        else:
            self.has_global = sample_global.shape[1] != 0

        # Rebuild the pretrained architecture, load weights, drop its head.
        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
        checkpoint = torch.load(pretraining_path)
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        self.pretrained_model = nn.Sequential(*list(self.pretrained_model.children())[:-1])

        # Freeze the backbone.
        for param in self.pretrained_model.parameters():
            param.requires_grad = False

        # Trainable head; input width = one global snapshot per backbone round.
        self.mlp = Make_MLP(pretraining_model['args']['hid_size']*pretraining_model['args']['n_proc_steps'], hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

    def _run_stage(self, index, x):
        # Apply one frozen backbone MLP (selected by child index) layer by layer.
        for layer in self.pretrained_model[index]:
            x = layer(x)
        return x

    def TL_node_encoder(self, x):
        return self._run_stage(1, x)

    def TL_edge_encoder(self, x):
        return self._run_stage(2, x)

    def TL_global_encoder(self, x):
        return self._run_stage(3, x)

    def TL_node_update(self, x):
        return self._run_stage(4, x)

    def TL_edge_update(self, x):
        return self._run_stage(5, x)

    def TL_global_update(self, x):
        return self._run_stage(6, x)

    def TL_global_decoder(self, x):
        return self._run_stage(7, x)

    def Pretrained_Output(self, g, global_feats=None):
        """Concatenated global states from every frozen message-passing round,
        shape (batch, hid_size * n_proc_steps).

        BUGFIX: `global_feats` used to be read while never assigned whenever
        self.has_global was True (NameError at runtime); it is now an
        explicit optional argument supplied by forward(). The original's
        trailing TL_global_decoder call was dead code (its result was
        discarded) and has been removed.
        """
        snapshots = []
        g.ndata['h'] = self.TL_node_encoder(g.ndata['features'])
        g.edata['e'] = self.TL_edge_encoder(g.edata['features'])
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.TL_global_encoder(global_feats)
        for _ in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
            snapshots.append(h_global)
        return torch.cat(snapshots, dim=1)

    def forward(self, g, global_feats):
        """Logits from the trainable MLP over the round-wise global snapshots."""
        # Clone so the frozen pass does not clobber the caller's graph features.
        pretrained_global = self.Pretrained_Output(g.clone(), global_feats)
        return self.classify(self.mlp(pretrained_global))
| |
class Transferred_Learning_Message_Passing_Parallel(nn.Module):
    """Frozen pretrained GNN run in parallel with a freshly trained GNN.

    The pretrained model is rebuilt from its config, its checkpoint loaded
    and its final (classifier) child dropped.  Its global state after every
    processing step is collected, concatenated with the trainable branch's
    decoded global state, and fed to a linear classifier of width
    ``pretrained_hid * pretrained_steps + hid_size``.
    """

    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()  # unused; kept for interface parity with siblings
        # Zero-width graph-level features mean "no globals"; the per-graph
        # node count is used as a 1-dim stand-in at forward time.
        self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1

        # Rebuild the pretrained model from its config, restore its weights
        # and drop the last child (the classifier head).
        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
        checkpoint = torch.load(pretraining_path)
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        pretrained_layers = list(self.pretrained_model.children())
        pretrained_layers = pretrained_layers[:-1]
        self.pretrained_model = nn.Sequential(*pretrained_layers)

        # The classifier width below is sized from the PRETRAINED model's
        # step count, so the frozen branch must iterate that many steps
        # (previously it used self.n_proc_steps, which produced a shape
        # mismatch whenever the two configs disagreed).
        self.pretrained_n_proc_steps = pretraining_model['args']['n_proc_steps']

        # Trainable branch: feature encoders.
        self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)

        # Trainable branch: message-passing update blocks.
        self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        # Only the parallel branch and the classifier are trained.
        for param in self.pretrained_model.parameters():
            param.requires_grad = False

        # One pretrained global vector per pretrained processing step plus
        # the trainable branch's final global vector.
        self.classify = nn.Linear(pretraining_model['args']['hid_size']*pretraining_model['args']['n_proc_steps'] + hid_size, out_size)

    def TL_node_encoder(self, x):
        """Apply the pretrained node encoder (child 1)."""
        for layer in self.pretrained_model[1]:
            x = layer(x)
        return x

    def TL_edge_encoder(self, x):
        """Apply the pretrained edge encoder (child 2)."""
        for layer in self.pretrained_model[2]:
            x = layer(x)
        return x

    def TL_global_encoder(self, x):
        """Apply the pretrained global encoder (child 3)."""
        for layer in self.pretrained_model[3]:
            x = layer(x)
        return x

    def TL_node_update(self, x):
        """Apply the pretrained node-update MLP (child 4)."""
        for layer in self.pretrained_model[4]:
            x = layer(x)
        return x

    def TL_edge_update(self, x):
        """Apply the pretrained edge-update MLP (child 5)."""
        for layer in self.pretrained_model[5]:
            x = layer(x)
        return x

    def TL_global_update(self, x):
        """Apply the pretrained global-update MLP (child 6)."""
        for layer in self.pretrained_model[6]:
            x = layer(x)
        return x

    def TL_global_decoder(self, x):
        """Apply the pretrained global decoder (child 7)."""
        for layer in self.pretrained_model[7]:
            x = layer(x)
        return x

    def Pretrained_Output(self, g, global_feats=None):
        """Run the frozen pretrained branch on ``g`` (mutated in place).

        Returns the concatenation of the global state after each pretrained
        processing step.  ``global_feats`` must be supplied when the dataset
        provides graph-level features (``self.has_global``); previously it
        was read without ever being defined in that case (NameError).
        """
        message_passing = None
        h = self.TL_node_encoder(g.ndata['features'])
        e = self.TL_edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            # Fall back to the per-graph node count as the global feature.
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.TL_global_encoder(global_feats)
        for i in range(self.pretrained_n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
            # Collect the global state of every step for the classifier.
            if (message_passing is None):
                message_passing = h_global.clone()
            else:
                message_passing = torch.cat((message_passing, h_global.clone()), dim=1)
        h_global = self.TL_global_decoder(h_global)
        return message_passing

    def forward(self, g, global_feats):
        """Classify a batched graph from both branches' global features."""
        # The frozen branch runs on a clone so its ndata/edata writes do not
        # leak into the trainable branch's pass over the same graph.
        pretrained_message = self.Pretrained_Output(g.clone(), global_feats)
        h = self.node_encoder(g.ndata['features'])
        e = self.edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.global_encoder(global_feats)
        for i in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
            h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
        h_global = self.global_decoder(h_global)
        return self.classify(torch.cat((pretrained_message, h_global), dim = 1))
| |
class Transferred_Learning_Finetuning(nn.Module):
    """Fine-tunes a pretrained message-passing GNN under a fresh classifier.

    The pretrained model is rebuilt from its config, its checkpoint loaded,
    and its original classifier head dropped; a new ``classify`` linear
    layer is trained on top.  With ``frozen_pretraining`` the whole GNN body
    is frozen except the global decoder (child 7), which stays trainable.
    """

    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, frozen_pretraining=False, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()  # unused; kept for interface parity with siblings

        # An empty or zero-width global feature tensor means "no globals";
        # the per-graph node count is used as a stand-in at forward time.
        if (len(sample_global) == 0):
            self.has_global = False
        else:
            self.has_global = sample_global.shape[1] != 0

        # Rebuild the pretrained model, restore its weights and drop the
        # final child (the classifier head).
        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
        checkpoint = torch.load(pretraining_path)
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        pretrained_layers = list(self.pretrained_model.children())
        pretrained_layers = pretrained_layers[:-1]
        self.pretrained_model = nn.Sequential(*pretrained_layers)

        print(f"Freeze Pretraining = {frozen_pretraining}")
        if (frozen_pretraining):
            for param in self.pretrained_model.parameters():
                param.requires_grad = False
            # Keep the global decoder (child 7) trainable.  The original
            # code iterated the decoder's sub-MODULES and set requires_grad
            # on the module objects, which does not unfreeze any weights;
            # iterate the actual parameters instead.
            for param in self.pretrained_model[7].parameters():
                param.requires_grad = True

        # Fixed seed so the fresh classifier head is initialized
        # reproducibly across runs.
        torch.manual_seed(2)
        self.classify = nn.Linear(pretraining_model['args']['hid_size'], out_size)

    def TL_node_encoder(self, x):
        """Apply the pretrained node encoder (child 1)."""
        for layer in self.pretrained_model[1]:
            x = layer(x)
        return x

    def TL_edge_encoder(self, x):
        """Apply the pretrained edge encoder (child 2)."""
        for layer in self.pretrained_model[2]:
            x = layer(x)
        return x

    def TL_global_encoder(self, x):
        """Apply the pretrained global encoder (child 3)."""
        for layer in self.pretrained_model[3]:
            x = layer(x)
        return x

    def TL_node_update(self, x):
        """Apply the pretrained node-update MLP (child 4)."""
        for layer in self.pretrained_model[4]:
            x = layer(x)
        return x

    def TL_edge_update(self, x):
        """Apply the pretrained edge-update MLP (child 5)."""
        for layer in self.pretrained_model[5]:
            x = layer(x)
        return x

    def TL_global_update(self, x):
        """Apply the pretrained global-update MLP (child 6)."""
        for layer in self.pretrained_model[6]:
            x = layer(x)
        return x

    def TL_global_decoder(self, x):
        """Apply the pretrained global decoder (child 7)."""
        for layer in self.pretrained_model[7]:
            x = layer(x)
        return x

    def Pretrained_Output(self, g, global_feats=None):
        """Run the pretrained GNN body on ``g`` and return its decoded global state.

        ``global_feats`` must be supplied when the dataset has graph-level
        features (``self.has_global``); previously it was read without ever
        being defined in that case, which raised a NameError.
        """
        h = self.TL_node_encoder(g.ndata['features'])
        e = self.TL_edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            # Per-graph node count as a 1-dim global feature.
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.TL_global_encoder(global_feats)
        for i in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
        h_global = self.TL_global_decoder(h_global)
        return h_global

    def forward(self, g, global_feats):
        """Classify from the pretrained body's global output.

        ``g`` is cloned so message passing does not mutate the caller's graph.
        """
        h_global = self.Pretrained_Output(g.clone(), global_feats)
        return self.classify(h_global)

    def representation(self, g, global_feats):
        """Return intermediate global representations for analysis.

        Returns ``(before_decoder, after_decoder, logits)``.  Mutates ``g``
        in place — pass a clone if the caller reuses the graph.
        """
        h = self.TL_node_encoder(g.ndata['features'])
        e = self.TL_edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.TL_global_encoder(global_feats)
        for i in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))

        before_global_decoder = h_global
        after_global_decoder = self.TL_global_decoder(before_global_decoder)
        after_classify = self.classify(after_global_decoder)
        return before_global_decoder, after_global_decoder, after_classify

    def __str__(self):
        """Debug dump: print every pretrained Linear layer's state dict and
        the classifier weights, then return an empty string."""
        layer_names = ["node_encoder", "edge_encoder", "global_encoder",
                       "node_update", "edge_update", "global_update", "global_decoder"]

        layers = [self.pretrained_model[1], self.pretrained_model[2], self.pretrained_model[3],
                  self.pretrained_model[4], self.pretrained_model[5], self.pretrained_model[6],
                  self.pretrained_model[7]]

        for i in range(len(layers)):
            print(layer_names[i])
            for layer in layers[i].children():
                if isinstance(layer, nn.Linear):
                    print(layer.state_dict())

        print("classify")
        print(self.classify.weight)
        return ""
| |
|
|
class Transferred_Learning_Parallel_Finetuning(nn.Module):
    """Pretrained GNN fine-tuned in parallel with a freshly trained GNN.

    Both branches run message passing over the input graph; the classifier
    consumes the concatenation of their decoded global states.  The custom
    ``parameters()`` exposes the two branches as separate optimizer groups
    so they can be trained at different learning rates.
    """

    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, learning_rate=0.0001, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')

        # Scalar learning rate, or a dict with 'trainable_lr' / 'finetuning_lr'.
        self.learning_rate = learning_rate

        # Optimizer parameter groups (see parameters()).
        self.parallel_params = []
        self.finetuning_params = []

        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()  # unused; kept for interface parity with siblings
        self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1

        # Rebuild the pretrained model, restore weights, drop the classifier head.
        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})
        checkpoint = torch.load(pretraining_path)
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        pretrained_layers = list(self.pretrained_model.children())
        pretrained_layers = pretrained_layers[:-1]
        self.pretrained_model = nn.Sequential(*pretrained_layers)

        self.finetuning_params.append(self.pretrained_model)

        # Trainable branch: feature encoders.
        self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)

        # Trainable branch: message-passing update blocks.
        self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        # Classifier sees both branches' global vectors.
        self.classify = nn.Linear(hid_size + pretraining_model['args']['hid_size'], out_size)

        self.parallel_params.append(self.node_encoder)
        self.parallel_params.append(self.edge_encoder)
        self.parallel_params.append(self.global_encoder)
        self.parallel_params.append(self.node_update)
        self.parallel_params.append(self.edge_update)
        self.parallel_params.append(self.global_update)
        self.parallel_params.append(self.global_decoder)
        self.parallel_params.append(self.classify)

    def TL_node_encoder(self, x):
        """Apply the pretrained node encoder (child 1)."""
        for layer in self.pretrained_model[1]:
            x = layer(x)
        return x

    def TL_edge_encoder(self, x):
        """Apply the pretrained edge encoder (child 2)."""
        for layer in self.pretrained_model[2]:
            x = layer(x)
        return x

    def TL_global_encoder(self, x):
        """Apply the pretrained global encoder (child 3)."""
        for layer in self.pretrained_model[3]:
            x = layer(x)
        return x

    def TL_node_update(self, x):
        """Apply the pretrained node-update MLP (child 4)."""
        for layer in self.pretrained_model[4]:
            x = layer(x)
        return x

    def TL_edge_update(self, x):
        """Apply the pretrained edge-update MLP (child 5)."""
        for layer in self.pretrained_model[5]:
            x = layer(x)
        return x

    def TL_global_update(self, x):
        """Apply the pretrained global-update MLP (child 6)."""
        for layer in self.pretrained_model[6]:
            x = layer(x)
        return x

    def TL_global_decoder(self, x):
        """Apply the pretrained global decoder (child 7)."""
        for layer in self.pretrained_model[7]:
            x = layer(x)
        return x

    def Pretrained_Output(self, g, global_feats=None):
        """Run the pretrained GNN body on ``g`` and return its decoded global state.

        ``global_feats`` must be supplied when the dataset has graph-level
        features (``self.has_global``); previously it was read without ever
        being defined in that case, which raised a NameError.
        """
        h = self.TL_node_encoder(g.ndata['features'])
        e = self.TL_edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            # Per-graph node count as a 1-dim global feature.
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.TL_global_encoder(global_feats)
        for i in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
        h_global = self.TL_global_decoder(h_global)
        return h_global

    def forward(self, g, global_feats):
        """Classify a batched graph from both branches' global features."""
        # Pretrained branch runs on a clone so its ndata/edata writes do not
        # leak into the trainable branch's pass over the same graph.
        pretrained_global = self.Pretrained_Output(g.clone(), global_feats)
        h = self.node_encoder(g.ndata['features'])
        e = self.edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.global_encoder(global_feats)
        for i in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
            h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
        h_global = self.global_decoder(h_global)

        return self.classify(torch.cat((pretrained_global, h_global), dim = 1))

    def parameters(self, recurse: bool = True):
        """Return optimizer parameter groups with per-branch learning rates.

        ``self.learning_rate`` may be a scalar (applied to every group) or a
        dict with ``'trainable_lr'`` / ``'finetuning_lr'`` entries.  Missing
        dict keys fall back to 1e-4 instead of raising KeyError, and a
        scalar learning rate is now honored instead of being ignored.
        """
        def _lr(key):
            if isinstance(self.learning_rate, dict):
                return self.learning_rate.get(key) or 0.0001
            return self.learning_rate if self.learning_rate else 0.0001

        params = []
        for model_section in self.parallel_params:
            params.append({'params': model_section.parameters(), 'lr': _lr("trainable_lr")})
        for model_section in self.finetuning_params:
            params.append({'params': model_section.parameters(), 'lr': _lr("finetuning_lr")})
        return params
| |
class Attention(nn.Module):
    """Attention-pooling graph classifier.

    Node features are encoded, mixed with a single multi-head self-attention
    pass across the (padded) nodes of each graph, combined with a global
    state, and classified from the decoded global vector.
    """

    def __init__(self, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, num_heads = 1, **kwargs):
        # sample_graph / sample_global are only used to size the input layers.
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        # NOTE(review): n_proc_steps is stored but forward() performs a
        # single attention pass — there is no processing-step loop here.
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        self.has_global = sample_global.shape[1] != 0
        self.hid_size = hid_size
        gl_size = sample_global.shape[1] if self.has_global else 1

        # Encoders for node and global features.
        self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)

        # Nodes see (attended node state, broadcast global); the global sees
        # (previous global, weighted node mean).
        self.node_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

        # Self-attention over nodes, with learned Q/K/V projections.
        self.multihead_attn = nn.MultiheadAttention(hid_size, num_heads, dropout=dropout, batch_first=True)
        self.queries = nn.Linear(hid_size, hid_size)
        self.keys = nn.Linear(hid_size, hid_size)
        self.values = nn.Linear(hid_size, hid_size)

    def forward(self, g, global_feats):
        # assumes every graph in the batch is padded to the same node count
        # (num_nodes is taken from the first graph) and carries
        # g.ndata['padding_mask'] — TODO confirm with the dataset.
        h = self.node_encoder(g.ndata['features'])
        g.ndata['h'] = h

        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)

        batch_num_nodes = None
        sum_weights = None
        if "w" in g.ndata:
            batch_indices = g.batch_num_nodes()
            # Count real (non-padded) nodes per graph: padded rows are
            # all-zero feature vectors.
            non_padded_nodes_mask = torch.any(g.ndata['features'] != 0, dim=1)

            batch_num_nodes = []
            start_idx = 0
            for num_nodes in batch_indices:
                end_idx = start_idx + num_nodes
                non_padded_count = non_padded_nodes_mask[start_idx:end_idx].sum().item()
                batch_num_nodes.append(non_padded_count)
                start_idx = end_idx
            batch_num_nodes = torch.tensor(batch_num_nodes, device = g.ndata['features'].device)
            # Per-graph divisor for the weighted node mean, broadcast to hid_size.
            sum_weights = batch_num_nodes[:, None].repeat(1, self.hid_size)
            # The global feature becomes the real node count.
            global_feats = batch_num_nodes[:, None].to(torch.float)

        h_global = self.global_encoder(global_feats)

        # Reshape the flat node tensor into (graphs, nodes, hid) for MHA.
        h_original_shape = h.shape
        num_graphs = len(dgl.unbatch(g))
        num_nodes = g.batch_num_nodes()[0].item()
        padding_mask = g.ndata['padding_mask'] > 0
        padding_mask = torch.reshape(padding_mask, (num_graphs, num_nodes))

        h = g.ndata['h']
        query = self.queries(h)
        key = self.keys(h)
        value = self.values(h)
        query = torch.reshape(query, (num_graphs, num_nodes, h_original_shape[1]))
        key = torch.reshape(key, (num_graphs, num_nodes, h_original_shape[1]))
        value = torch.reshape(value, (num_graphs, num_nodes, h_original_shape[1]))
        h, _ = self.multihead_attn(query, key, value, key_padding_mask=padding_mask)
        h = torch.reshape(h, h_original_shape)

        h = self.node_update(torch.cat((h, broadcast_global_to_nodes(g, h_global)), dim = 1))
        g.ndata['h'] = h
        # NOTE(review): if 'w' is absent from g.ndata, sum_weights is still
        # None here and the division fails (and dgl.sum_nodes has no weight
        # field) — this model appears to require weighted, padded graphs;
        # confirm against the dataset.
        mean_nodes = dgl.sum_nodes(g, 'h', 'w') / sum_weights
        h_global = self.global_update(torch.cat((h_global, mean_nodes), dim = 1))
        h_global = self.global_decoder(h_global)
        return self.classify(h_global)
|
|
class Attention_Edge_Network(nn.Module):
    """Message-passing GNN with a per-step self-attention pass over nodes.

    Runs ``n_proc_steps`` iterations of (attention, edge update, node
    update, global update) and classifies from the decoded global state.
    """

    def __init__(self, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, num_heads = 1, **kwargs):
        # sample_graph / sample_global are only used to size the input layers.
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1

        # Feature encoders.
        self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)

        # Message-passing update blocks.
        self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

        # Self-attention over nodes, with learned Q/K/V projections.
        self.multihead_attn = nn.MultiheadAttention(hid_size, num_heads, dropout=dropout, batch_first=True)
        self.queries = nn.Linear(hid_size, hid_size)
        self.keys = nn.Linear(hid_size, hid_size)
        self.values = nn.Linear(hid_size, hid_size)

    def forward(self, g, global_feats):
        # assumes every graph is padded to the same node count (taken from
        # the first graph) and carries g.ndata['padding_mask'] and 'w'
        # node weights — TODO confirm with the dataset.
        h = self.node_encoder(g.ndata['features'])
        e = self.edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e

        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.global_encoder(global_feats)

        h = g.ndata['h']
        h_original_shape = h.shape
        num_graphs = len(dgl.unbatch(g))
        num_nodes = g.batch_num_nodes()[0].item()
        padding_mask = g.ndata['padding_mask'].bool()

        padding_mask = torch.reshape(padding_mask, (num_graphs, num_nodes))

        for i in range(self.n_proc_steps):
            # Self-attention over each graph's (padded) node set.
            h = g.ndata['h']
            query = self.queries(h)
            key = self.keys(h)
            value = self.values(h)
            query = torch.reshape(query, (num_graphs, num_nodes, h_original_shape[1]))
            key = torch.reshape(key, (num_graphs, num_nodes, h_original_shape[1]))
            value = torch.reshape(value, (num_graphs, num_nodes, h_original_shape[1]))
            h, _ = self.multihead_attn(query, key, value, key_padding_mask=padding_mask)
            h = torch.reshape(h, h_original_shape)
            # NOTE(review): the attention output `h` above is never written
            # back into g.ndata['h'], and the updates below all read
            # g.ndata['h'] — so the attention pass currently has no effect
            # on the result.  Looks like `g.ndata['h'] = h` was intended
            # here; confirm before changing (checkpoints depend on it).

            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
            h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h', 'w'), dgl.mean_edges(g, 'e')), dim = 1))
        h_global = self.global_decoder(h_global)
        return self.classify(h_global)
|
|
class Attention_Unbatched(nn.Module):
    """Message-passing GNN with per-graph (unbatched) self-attention.

    Unlike Attention_Edge_Network, each processing step unbatches the graph
    and runs attention on each graph's own node set, so no padding mask is
    needed (at the cost of a Python loop over graphs).
    """

    def __init__(self, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, num_heads = 1, **kwargs):
        # sample_graph / sample_global are only used to size the input layers.
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1

        # Feature encoders.
        self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)

        # Message-passing update blocks.
        self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        self.global_decoder = Make_MLP(hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

        # NOTE(review): the num_heads parameter is ignored — the head count
        # is hard-coded to 1 here; confirm whether that is intentional.
        self.multihead_attn = nn.MultiheadAttention(hid_size, 1, dropout=dropout)
        self.queries = nn.Linear(hid_size, hid_size)
        self.keys = nn.Linear(hid_size, hid_size)
        self.values = nn.Linear(hid_size, hid_size)

    def forward(self, g, global_feats):
        # Mutates g's ndata/edata in place and rebinds g after re-batching.

        h = self.node_encoder(g.ndata['features'])
        e = self.edge_encoder(g.edata['features'])
        g.ndata['h'] = h
        g.edata['e'] = e

        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.global_encoder(global_feats)

        for i in range(self.n_proc_steps):

            # Per-graph attention: a 2-D (nodes, hid) input is treated by
            # MultiheadAttention as a single unbatched sequence.
            unbatched_g = dgl.unbatch(g)
            for graph in unbatched_g:
                h = graph.ndata['h']
                h, _ = self.multihead_attn(self.queries(h), self.keys(h), self.values(h))
                graph.ndata['h'] = h
            g = dgl.batch(unbatched_g)

            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1))
            h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1))
        h_global = self.global_decoder(h_global)
        return self.classify(h_global)
| |
class Transferred_Learning_Attention(nn.Module):
    """Attention head trained on top of a pretrained GNN's encoders/decoder.

    Reuses the pretrained node encoder (child 1), global encoder (child 3)
    and global decoder (child 7); adds fresh Q/K/V projections, multi-head
    self-attention over padded nodes, and node/global update MLPs.  The
    custom ``parameters()`` exposes pretrained and attention parts as
    separate optimizer groups.
    """

    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, num_heads, dropout=0, learning_rate=0.0001, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()  # unused; kept for interface parity with siblings
        self.has_global = sample_global.shape[1] != 0
        self.hid_size = hid_size
        gl_size = sample_global.shape[1] if self.has_global else 1

        # Scalar learning rate, or a dict with 'pretraining_lr' / 'attention_lr'.
        self.learning_rate = learning_rate

        # Optimizer parameter groups (see parameters()).
        self.pretraining_params = []
        self.attention_params = []

        # Rebuild the pretrained model, restore weights, drop the classifier head.
        self.pretrained_model = utils.buildFromConfig(pretraining_model, {'sample_graph': sample_graph, 'sample_global': sample_global})

        checkpoint = torch.load(pretraining_path)
        self.pretrained_model.load_state_dict(checkpoint['model_state_dict'])
        pretrained_layers = list(self.pretrained_model.children())
        pretrained_layers = pretrained_layers[:-1]
        self.pretrained_model = nn.Sequential(*pretrained_layers)

        # Only these pretrained children are reused (and trained).
        self.pretraining_params.append(self.pretrained_model[1])
        self.pretraining_params.append(self.pretrained_model[3])
        self.pretraining_params.append(self.pretrained_model[7])

        # Self-attention over nodes, with learned Q/K/V projections.
        self.multihead_attn = nn.MultiheadAttention(hid_size, num_heads, dropout=dropout, batch_first=True)
        self.queries = nn.Linear(hid_size, hid_size)
        self.keys = nn.Linear(hid_size, hid_size)
        self.values = nn.Linear(hid_size, hid_size)

        self.node_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(2*hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        # Classifier is sized by the PRETRAINED hidden width, since the
        # decoded global state comes from the pretrained decoder.
        self.classify = nn.Linear(pretraining_model['args']['hid_size'], out_size)

        self.attention_params.append(self.multihead_attn)
        self.attention_params.append(self.queries)
        self.attention_params.append(self.keys)
        self.attention_params.append(self.values)
        self.attention_params.append(self.classify)
        self.attention_params.append(self.node_update)
        self.attention_params.append(self.global_update)

    def TL_node_encoder(self, x):
        """Apply the pretrained node encoder (child 1)."""
        for layer in self.pretrained_model[1]:
            x = layer(x)
        return x

    def TL_global_encoder(self, x):
        """Apply the pretrained global encoder (child 3)."""
        for layer in self.pretrained_model[3]:
            x = layer(x)
        return x

    def TL_global_decoder(self, x):
        """Apply the pretrained global decoder (child 7)."""
        for layer in self.pretrained_model[7]:
            x = layer(x)
        return x

    def forward(self, g, global_feats):
        # assumes graphs are padded to a uniform node count (taken from the
        # first graph) and carry g.ndata['padding_mask'] — TODO confirm.
        h = self.TL_node_encoder(g.ndata['features'])
        g.ndata['h'] = h

        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)

        batch_num_nodes = None
        sum_weights = None
        if "w" in g.ndata:
            batch_indices = g.batch_num_nodes()
            # Count real (non-padded) nodes per graph: padded rows are all-zero.
            non_padded_nodes_mask = torch.any(g.ndata['features'] != 0, dim=1)

            batch_num_nodes = []
            start_idx = 0
            for num_nodes in batch_indices:
                end_idx = start_idx + num_nodes
                non_padded_count = non_padded_nodes_mask[start_idx:end_idx].sum().item()
                batch_num_nodes.append(non_padded_count)
                start_idx = end_idx
            batch_num_nodes = torch.tensor(batch_num_nodes, device = g.ndata['features'].device)
            sum_weights = batch_num_nodes[:, None].repeat(1, self.hid_size)
            global_feats = batch_num_nodes[:, None].to(torch.float)

        h_global = self.TL_global_encoder(global_feats)

        # Reshape the flat node tensor into (graphs, nodes, hid) for MHA.
        h_original_shape = h.shape
        num_graphs = len(dgl.unbatch(g))
        num_nodes = g.batch_num_nodes()[0].item()
        padding_mask = g.ndata['padding_mask'] > 0
        padding_mask = torch.reshape(padding_mask, (num_graphs, num_nodes))

        h = g.ndata['h']
        query = self.queries(h)
        key = self.keys(h)
        value = self.values(h)
        query = torch.reshape(query, (num_graphs, num_nodes, h_original_shape[1]))
        key = torch.reshape(key, (num_graphs, num_nodes, h_original_shape[1]))
        value = torch.reshape(value, (num_graphs, num_nodes, h_original_shape[1]))
        h, _ = self.multihead_attn(query, key, value, key_padding_mask=padding_mask)
        h = torch.reshape(h, h_original_shape)

        h = self.node_update(torch.cat((h, broadcast_global_to_nodes(g, h_global)), dim = 1))
        g.ndata['h'] = h
        # NOTE(review): requires 'w' node weights; without them sum_weights
        # is still None and this division fails.
        mean_nodes = dgl.sum_nodes(g, 'h', 'w') / sum_weights
        h_global = self.global_update(torch.cat((h_global, mean_nodes), dim = 1))
        h_global = self.TL_global_decoder(h_global)
        return self.classify(h_global)

    def parameters(self, recurse: bool = True):
        """Return optimizer parameter groups with per-branch learning rates.

        ``self.learning_rate`` may be a scalar (applied to every group) or a
        dict with ``'pretraining_lr'`` / ``'attention_lr'`` entries.  Missing
        dict keys fall back to 1e-4 instead of raising KeyError, and a
        scalar learning rate is now honored instead of being ignored.
        """
        def _lr(key):
            if isinstance(self.learning_rate, dict):
                return self.learning_rate.get(key) or 0.0001
            return self.learning_rate if self.learning_rate else 0.0001

        params = []
        for model_section in self.pretraining_params:
            params.append({'params': model_section.parameters(), 'lr': _lr("pretraining_lr")})
        for model_section in self.attention_params:
            params.append({'params': model_section.parameters(), 'lr': _lr("attention_lr")})
        return params
| |
| class Multimodel_Transferred_Learning(nn.Module): |
    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, frozen_pretraining=True, learning_rate=None, **kwargs):
        """Ensemble of pretrained GNN bodies feeding one trainable MLP head.

        ``pretraining_model`` / ``pretraining_path`` are parallel sequences:
        each config is rebuilt, its checkpoint loaded (tolerating
        DataParallel ``module.`` key prefixes), and its final classifier
        child dropped.
        """
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        self.has_global = sample_global.shape[1] != 0
        # gl_size is unused in the visible methods of this class.
        gl_size = sample_global.shape[1] if self.has_global else 1

        # Scalar lr or a dict of per-group rates (used by parameters()).
        self.learning_rate = learning_rate
        # Total width of the concatenated per-model global outputs.
        input_size = 0

        self.pretraining_params = []
        self.model_params = []

        # NOTE: kept in a plain Python list, so these models are NOT
        # registered as sub-modules; the custom to() / parameters() methods
        # below compensate for that.
        self.pretrained_models = []
        for model, path in zip(pretraining_model, pretraining_path):
            input_size += model['args']['hid_size']
            model = utils.buildFromConfig(model, {'sample_graph': sample_graph, 'sample_global': sample_global})

            checkpoint = torch.load(path)['model_state_dict']
            # Strip DataParallel's 'module.' prefix so keys match the
            # freshly built model.
            new_state_dict = {}
            for k, v in checkpoint.items():
                new_key = k.replace('module.', '')
                new_state_dict[new_key] = v
            model.load_state_dict(new_state_dict)
            pretrained_layers = list(model.children())
            # Drop the final child (the classifier head).
            pretrained_layers = pretrained_layers[:-1]

            model = nn.Sequential(*pretrained_layers)

            print(f"Freeze Pretraining = {frozen_pretraining}")
            if (frozen_pretraining):
                for param in model.parameters():
                    param.requires_grad = False
            self.pretraining_params.append(model)
            self.pretrained_models.append(model)

        print(f"len(pretrained_models) = {len(self.pretrained_models)}")
        print(f"input size = {input_size}")

        # Trainable head over the concatenated global outputs.
        self.final_mlp = Make_MLP(input_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

        self.model_params.append(self.final_mlp)
        self.model_params.append(self.classify)
|
|
| def TL_node_encoder(self, x, model_idx): |
| try: |
| for layer in self.pretrained_models[model_idx][1]: |
| x = layer(x) |
| return x |
| except (NotImplementedError, IndexError): |
| for layer in self.pretrained_models[model_idx][1][1]: |
| x = layer(x) |
| return x |
| |
| def TL_edge_encoder(self, x, model_idx): |
| try: |
| for layer in self.pretrained_models[model_idx][2]: |
| x = layer(x) |
| return x |
| except (NotImplementedError, IndexError): |
| for layer in self.pretrained_models[model_idx][1][2]: |
| x = layer(x) |
| return x |
| |
| def TL_global_encoder(self, x, model_idx): |
| try: |
| for layer in self.pretrained_models[model_idx][3]: |
| x = layer(x) |
| return x |
| except (NotImplementedError, IndexError): |
| for layer in self.pretrained_models[model_idx][1][3]: |
| x = layer(x) |
| return x |
| |
| def TL_node_update(self, x, model_idx): |
| try: |
| for layer in self.pretrained_models[model_idx][4]: |
| x = layer(x) |
| return x |
| except (NotImplementedError, IndexError): |
| for layer in self.pretrained_models[model_idx][1][4]: |
| x = layer(x) |
| return x |
|
|
| def TL_edge_update(self, x, model_idx): |
| try: |
| for layer in self.pretrained_models[model_idx][5]: |
| x = layer(x) |
| return x |
| except (NotImplementedError, IndexError): |
| for layer in self.pretrained_models[model_idx][1][5]: |
| x = layer(x) |
| return x |
|
|
| def TL_global_update(self, x, model_idx): |
| try: |
| for layer in self.pretrained_models[model_idx][6]: |
| x = layer(x) |
| return x |
| except (NotImplementedError, IndexError): |
| for layer in self.pretrained_models[model_idx][1][6]: |
| x = layer(x) |
| return x |
| |
| def TL_global_decoder(self, x, model_idx): |
| try: |
| for layer in self.pretrained_models[model_idx][7]: |
| x = layer(x) |
| return x |
| except (NotImplementedError, IndexError): |
| for layer in self.pretrained_models[model_idx][1][7]: |
| x = layer(x) |
| return x |
|
|
| def Pretrained_Output(self, g, model_idx): |
| h = self.TL_node_encoder(g.ndata['features'], model_idx) |
| e = self.TL_edge_encoder(g.edata['features'], model_idx) |
| g.ndata['h'] = h |
| g.edata['e'] = e |
| if not self.has_global: |
| global_feats = g.batch_num_nodes()[:, None].to(torch.float) |
| h_global = self.TL_global_encoder(global_feats, model_idx) |
| for i in range(self.n_proc_steps): |
| g.apply_edges(dgl.function.copy_u('h', 'm_u')) |
| g.apply_edges(copy_v) |
| g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1), model_idx) |
| g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e')) |
| g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1), model_idx) |
| h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1), model_idx) |
| |
| return h_global |
|
|
| def forward(self, g, global_feats): |
| h_global = [] |
| for i in range(len(self.pretrained_models)): |
| h_global.append(self.Pretrained_Output(g.clone(), i)) |
| h_global = torch.concatenate(h_global, dim=1) |
| return self.classify(self.final_mlp(h_global)) |
| |
| def to(self, device): |
| for i in range(len(self.pretrained_models)): |
| self.pretrained_models[i].to(device) |
| self.classify.to(device) |
| self.final_mlp.to(device) |
| return self |
| |
| def parameters(self, recurse: bool = True): |
| params = [] |
| for model_section in self.pretraining_params: |
| if (type(self.learning_rate) == dict and self.learning_rate["pretraining_lr"]): |
| params.append({'params': model_section.parameters(), 'lr': self.learning_rate["pretraining_lr"]}) |
| else: |
| params.append({'params': model_section.parameters(), 'lr': 0.00001}) |
| for model_section in self.model_params: |
| if (type(self.learning_rate) == dict and self.learning_rate["model_lr"]): |
| params.append({'params': model_section.parameters(), 'lr': self.learning_rate["model_lr"]}) |
| else: |
| params.append({'params': model_section.parameters(), 'lr': 0.0001}) |
| return params |
|
|
|
|
class MultiModel(nn.Module):
    """Graph network that fuses its own message-passing trunk with the global
    embeddings produced by a set of pretrained graph models.

    Each entry of `pretraining_model`/`pretraining_path` is built via
    utils.buildFromConfig, loaded from its checkpoint (a DataParallel
    'module.' prefix is stripped from the state-dict keys), its final child
    (decoder head) is dropped, and the remaining children are flattened into
    an nn.Sequential. At forward time the trunk's global embedding is
    concatenated with every pretrained model's global embedding and
    classified by a small head.
    """

    def __init__(self, pretraining_path, pretraining_model, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, frozen_pretraining=True, learning_rate=None, **kwargs):
        super().__init__()
        print(f'Unused args while creating GCN: {kwargs}')
        self.n_layers = n_layers
        self.n_proc_steps = n_proc_steps
        self.layers = nn.ModuleList()
        # A (batch, 0) sample_global tensor means "no global features".
        self.has_global = sample_global.shape[1] != 0
        gl_size = sample_global.shape[1] if self.has_global else 1

        self.learning_rate = learning_rate
        input_size = 0  # total width of the concatenated pretrained embeddings

        self.model_params = []        # sections trained at the "model" LR
        self.pretraining_params = []  # sections trained at the "pretraining" LR

        self.pretrained_models = []
        for model, path in zip(pretraining_model, pretraining_path):
            input_size += model['args']['hid_size']
            model = utils.buildFromConfig(model, {'sample_graph': sample_graph, 'sample_global': sample_global})

            # Strip the 'module.' prefix DataParallel adds to checkpoint keys.
            checkpoint = torch.load(path)['model_state_dict']
            new_state_dict = {k.replace('module.', ''): v for k, v in checkpoint.items()}
            model.load_state_dict(new_state_dict)

            # Drop the decoder head and flatten the remaining children.
            pretrained_layers = list(model.children())[:-1]
            model = nn.Sequential(*pretrained_layers)

            print(f"Freeze Pretraining = {frozen_pretraining}")
            if frozen_pretraining:
                for param in model.parameters():
                    param.requires_grad = False
            self.pretraining_params.append(model)
            self.pretrained_models.append(model)

        print(f"len(pretrained_models) = {len(self.pretrained_models)}")
        print(f"input size = {input_size}")

        # Trunk encoders / updaters (same topology as the pretrained models).
        self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout)
        self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout)

        self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout)

        # Head over [trunk embedding | all pretrained embeddings].
        self.final_mlp = Make_MLP(input_size + hid_size, hid_size, hid_size, n_layers, dropout=dropout)
        self.classify = nn.Linear(hid_size, out_size)

        self.model_params.append(self.final_mlp)
        self.model_params.append(self.classify)

    def _TL_apply(self, x, model_idx, section):
        """Run `x` through child `section` of pretrained model `model_idx`.

        Checkpoints come in two layouts: the section list sits either at
        pretrained_models[i][section] or nested one level deeper at
        pretrained_models[i][1][section]; fall back on lookup failure.
        Deduplicates the seven formerly copy-pasted TL_* method bodies.
        """
        try:
            for layer in self.pretrained_models[model_idx][section]:
                x = layer(x)
        except (NotImplementedError, IndexError):
            for layer in self.pretrained_models[model_idx][1][section]:
                x = layer(x)
        return x

    # Thin wrappers kept for interface compatibility; the section index is
    # the child position inside the flattened pretrained model.
    def TL_node_encoder(self, x, model_idx):
        return self._TL_apply(x, model_idx, 1)

    def TL_edge_encoder(self, x, model_idx):
        return self._TL_apply(x, model_idx, 2)

    def TL_global_encoder(self, x, model_idx):
        return self._TL_apply(x, model_idx, 3)

    def TL_node_update(self, x, model_idx):
        return self._TL_apply(x, model_idx, 4)

    def TL_edge_update(self, x, model_idx):
        return self._TL_apply(x, model_idx, 5)

    def TL_global_update(self, x, model_idx):
        return self._TL_apply(x, model_idx, 6)

    def TL_global_decoder(self, x, model_idx):
        return self._TL_apply(x, model_idx, 7)

    def Pretrained_Output(self, g, model_idx, global_feats=None):
        """Return the global embedding of (batched) graph `g` computed by
        pretrained model `model_idx`. `g` is mutated — pass a clone.

        Bug fix: the original read an undefined `global_feats` name whenever
        self.has_global was True; it is now an explicit optional argument.
        """
        g.ndata['h'] = self.TL_node_encoder(g.ndata['features'], model_idx)
        g.edata['e'] = self.TL_edge_encoder(g.edata['features'], model_idx)
        if not self.has_global:
            # No real global features: use the per-graph node count instead.
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        elif global_feats is None:
            raise ValueError("global_feats is required when the model uses global features")
        h_global = self.TL_global_encoder(global_feats, model_idx)
        for _ in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.TL_edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim=1), model_idx)
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.TL_node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim=1), model_idx)
            h_global = self.TL_global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1), model_idx)
        return h_global

    def forward(self, g, global_feats):
        """Class logits for batched graph `g`.

        Runs the trunk's message passing, then concatenates the trunk's
        global embedding with each pretrained model's embedding (computed on
        a clone of `g` so the trunk's node/edge state is not clobbered) and
        classifies. `global_feats` is now threaded through to the pretrained
        models (part of the Pretrained_Output NameError fix).
        """
        g.ndata['h'] = self.node_encoder(g.ndata['features'])
        g.edata['e'] = self.edge_encoder(g.edata['features'])
        if not self.has_global:
            global_feats = g.batch_num_nodes()[:, None].to(torch.float)
        h_global = self.global_encoder(global_feats)
        for _ in range(self.n_proc_steps):
            g.apply_edges(dgl.function.copy_u('h', 'm_u'))
            g.apply_edges(copy_v)
            g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim=1))
            g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim=1))
            h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim=1))
        embeddings = [h_global]
        for i in range(len(self.pretrained_models)):
            embeddings.append(self.Pretrained_Output(g.clone(), i, global_feats))
        return self.classify(self.final_mlp(torch.cat(embeddings, dim=1)))

    def to(self, device):
        """Move all sections to `device`, including the pretrained models
        held in a plain list (invisible to nn.Module's recursive .to())."""
        for pretrained in self.pretrained_models:
            pretrained.to(device)
        self.classify.to(device)
        self.final_mlp.to(device)
        self.node_encoder.to(device)
        self.edge_encoder.to(device)
        self.global_encoder.to(device)
        self.node_update.to(device)
        self.edge_update.to(device)
        self.global_update.to(device)
        return self

    def parameters(self, recurse: bool = True):
        """Build optimizer parameter groups.

        Pretrained section i uses learning_rate['pretraining_lr'][i]
        (default 1e-5); the head sections use learning_rate['model_lr']
        (default 1e-4). Bug fix: dict.get keeps a partially-filled
        learning_rate dict from raising KeyError.
        """
        pretraining_lrs = None
        model_lr = None
        if isinstance(self.learning_rate, dict):
            pretraining_lrs = self.learning_rate.get("pretraining_lr")
            model_lr = self.learning_rate.get("model_lr")
        params = []
        for i, model_section in enumerate(self.pretraining_params):
            lr = pretraining_lrs[i] if pretraining_lrs else 0.00001
            print(f"Pretraining LR = {lr}")
            params.append({'params': model_section.parameters(), 'lr': lr})
        for model_section in self.model_params:
            lr = model_lr if model_lr else 0.0001
            print(f"Model LR = {lr}")
            params.append({'params': model_section.parameters(), 'lr': lr})
        return params
|
|
|
|
| class Clustering(nn.Module): |
| def __init__(self, sample_graph, sample_global, hid_size, out_size, n_layers, n_proc_steps, dropout=0, **kwargs): |
| super().__init__() |
| print(f'Unused args while creating GCN: {kwargs}') |
| self.n_layers = n_layers |
| self.n_proc_steps = n_proc_steps |
| self.layers = nn.ModuleList() |
| self.hid_size = hid_size |
| if (len(sample_global) == 0): |
| self.has_global = False |
| else: |
| self.has_global = sample_global.shape[1] != 0 |
| gl_size = sample_global.shape[1] if self.has_global else 1 |
|
|
| |
| self.node_encoder = Make_MLP(sample_graph.ndata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout) |
| self.edge_encoder = Make_MLP(sample_graph.edata['features'].shape[1], hid_size, hid_size, n_layers, dropout=dropout) |
| self.global_encoder = Make_MLP(gl_size, hid_size, hid_size, n_layers, dropout=dropout) |
|
|
| |
| self.node_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout) |
| self.edge_update = Make_MLP(4*hid_size, hid_size, hid_size, n_layers, dropout=dropout) |
| self.global_update = Make_MLP(3*hid_size, hid_size, hid_size, n_layers, dropout=dropout) |
|
|
| |
| self.global_decoder = Make_MLP(hid_size, hid_size, out_size, n_layers, dropout=dropout) |
|
|
| def model_forward(self, g, global_feats, features = 'features'): |
| h = self.node_encoder(g.ndata[features]) |
| e = self.edge_encoder(g.edata[features]) |
|
|
| g.ndata['h'] = h |
| g.edata['e'] = e |
| if not self.has_global: |
| global_feats = g.batch_num_nodes()[:, None].to(torch.float) |
|
|
| batch_num_nodes = None |
| sum_weights = None |
| if "w" in g.ndata: |
| batch_indices = g.batch_num_nodes() |
| |
| non_padded_nodes_mask = torch.any(g.ndata[features] != 0, dim=1) |
| |
| batch_num_nodes = [] |
| start_idx = 0 |
| for num_nodes in batch_indices: |
| end_idx = start_idx + num_nodes |
| non_padded_count = non_padded_nodes_mask[start_idx:end_idx].sum().item() |
| batch_num_nodes.append(non_padded_count) |
| start_idx = end_idx |
| batch_num_nodes = torch.tensor(batch_num_nodes, device = g.ndata[features].device) |
| sum_weights = batch_num_nodes[:, None].repeat(1, self.hid_size) |
| global_feats = batch_num_nodes[:, None].to(torch.float) |
|
|
| h_global = self.global_encoder(global_feats) |
| for i in range(self.n_proc_steps): |
| g.apply_edges(dgl.function.copy_u('h', 'm_u')) |
| g.apply_edges(copy_v) |
| g.edata['e'] = self.edge_update(torch.cat((g.edata['e'], g.edata['m_u'], g.edata['m_v'], broadcast_global_to_edges(g, h_global)), dim = 1)) |
| g.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e')) |
| g.ndata['h'] = self.node_update(torch.cat((g.ndata['h'], g.ndata['h_e'], broadcast_global_to_nodes(g, h_global)), dim = 1)) |
| if "w" in g.ndata: |
| mean_nodes = dgl.sum_nodes(g, 'h', 'w') / sum_weights |
| h_global = self.global_update(torch.cat((h_global, mean_nodes, dgl.mean_edges(g, 'e')), dim = 1)) |
| else: |
| h_global = self.global_update(torch.cat((h_global, dgl.mean_nodes(g, 'h'), dgl.mean_edges(g, 'e')), dim = 1)) |
| h_global = self.global_decoder(h_global) |
| return h_global |
| |
| def forward(self, g, global_feats): |
| h_global = self.model_forward(g, global_feats, 'features') |
| h_global_augmented = self.model_forward(g, global_feats, 'augmented_features') |
| return torch.cat((h_global, h_global_augmented), dim=1) |
|
|
| def representation(self, g, global_feats): |
| h_global = self.model_forward(g, global_feats, 'features') |
| h_global_augmented = self.model_forward(g, global_feats, 'augmented_features') |
| return h_global, h_global_augmented, torch.cat((h_global, h_global_augmented), dim=1) |
| |
| def __str__(self): |
| layer_names = ["node_encoder", "edge_encoder", "global_encoder", |
| "node_update", "edge_update", "global_update", "global_decoder"] |
| |
| layers = [self.node_encoder, self.edge_encoder, self.global_encoder, |
| self.node_update, self.edge_update, self.global_update, self.global_decoder] |
| |
| for i in range(len(layers)): |
| print(layer_names[i]) |
| for layer in layers[i].children(): |
| if isinstance(layer, nn.Linear): |
| print(layer.state_dict()) |
| |
| print("classify") |
| print(self.classify.weight) |
| return "" |