root_gnn_dgl/models/GCN.py CHANGED
@@ -1154,6 +1154,7 @@ class Attention(nn.Module):
  self.n_proc_steps = n_proc_steps
  self.layers = nn.ModuleList()
  self.has_global = sample_global.shape[1] != 0
+ self.hid_size = hid_size
  gl_size = sample_global.shape[1] if self.has_global else 1

  #encoder
@@ -1196,7 +1197,7 @@ class Attention(nn.Module):
  batch_num_nodes.append(non_padded_count)
  start_idx = end_idx
  batch_num_nodes = torch.tensor(batch_num_nodes, device = g.ndata['features'].device)
- sum_weights = batch_num_nodes[:, None].repeat(1, 64)
+ sum_weights = batch_num_nodes[:, None].repeat(1, self.hid_size)
  global_feats = batch_num_nodes[:, None].to(torch.float)

  h_global = self.global_encoder(global_feats)
@@ -1364,6 +1365,7 @@ class Transferred_Learning_Attention(nn.Module):
  self.n_proc_steps = n_proc_steps
  self.layers = nn.ModuleList()
  self.has_global = sample_global.shape[1] != 0
+ self.hid_size = hid_size
  gl_size = sample_global.shape[1] if self.has_global else 1

  self.learning_rate = learning_rate
@@ -1440,7 +1442,7 @@ class Transferred_Learning_Attention(nn.Module):
  batch_num_nodes.append(non_padded_count)
  start_idx = end_idx
  batch_num_nodes = torch.tensor(batch_num_nodes, device = g.ndata['features'].device)
- sum_weights = batch_num_nodes[:, None].repeat(1, 64)
+ sum_weights = batch_num_nodes[:, None].repeat(1, self.hid_size)
  global_feats = batch_num_nodes[:, None].to(torch.float)

  h_global = self.TL_global_encoder(global_feats)
@@ -1856,6 +1858,7 @@ class Clustering(nn.Module):
  self.n_layers = n_layers
  self.n_proc_steps = n_proc_steps
  self.layers = nn.ModuleList()
+ self.hid_size = hid_size
  if (len(sample_global) == 0):
  self.has_global = False
  else:
@@ -1899,7 +1902,7 @@ class Clustering(nn.Module):
  batch_num_nodes.append(non_padded_count)
  start_idx = end_idx
  batch_num_nodes = torch.tensor(batch_num_nodes, device = g.ndata[features].device)
- sum_weights = batch_num_nodes[:, None].repeat(1, 64)
+ sum_weights = batch_num_nodes[:, None].repeat(1, self.hid_size)
  global_feats = batch_num_nodes[:, None].to(torch.float)

  h_global = self.global_encoder(global_feats)
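These hunks thread the model's configurable hid_size through in place of the hardcoded width 64 used when expanding per-graph node counts. A minimal sketch of the tensor operation, with hypothetical values (in these models, sum_weights presumably serves to normalize summed node features by the number of non-padded nodes per graph):

    import torch

    hid_size = 32                              # previously fixed at 64
    batch_num_nodes = torch.tensor([5, 7, 3])  # non-padded nodes per graph
    sum_weights = batch_num_nodes[:, None].repeat(1, hid_size)
    print(sum_weights.shape)                   # torch.Size([3, 32])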
root_gnn_dgl/root_gnn_base/batched_dataset.py CHANGED
@@ -16,7 +16,7 @@ def GetBatchedLoader(dataset, batch_size, mask_fn = None, drop_last=True, **kwargs):

  #Dataset which contains prebatched shuffled graphs. Cannot be saved to disk, else batching info is lost.
  class PreBatchedDataset(DGLDataset):
- def __init__(self, start_dataset, batch_size, mask_fn = None, drop_last=True, save_to_disk = True, suffix = '', chunks = 1, chunkno = -1, shuffle = True, padding_mode = 'NONE', **kwargs):
+ def __init__(self, start_dataset, batch_size, mask_fn = None, drop_last=True, save_to_disk = True, suffix = '', chunks = 1, chunkno = -1, shuffle = True, padding_mode = 'NONE', hidden_size=64, **kwargs):
  print(f'Unused kwargs: {kwargs}')
  self.start_dataset = start_dataset
  self.start_dataset.load()
@@ -34,6 +34,7 @@ class PreBatchedDataset(DGLDataset):
  self.suffix = suffix
  self.current_chunk = None
  self.current_chunk_idx = -1
+ self.hid_size = hidden_size
  super().__init__(name = start_dataset.name + '_prebatched_padded', save_dir=start_dataset.save_dir)

  def process(self):
@@ -86,7 +87,7 @@ class PreBatchedDataset(DGLDataset):
  for i in range(len(self.graphs)):
  unbatched_g = dgl.unbatch(self.graphs[i])
  max_num_nodes = max(g.number_of_nodes() for g in unbatched_g)
- self.graphs[i] = utils.pad_batch_num_nodes(self.graphs[i], max_num_nodes)
+ self.graphs[i] = utils.pad_batch_num_nodes(self.graphs[i], max_num_nodes, hid_size=self.hid_size)
  self.batch_num_nodes.append(self.graphs[i].batch_num_nodes())
  self.batch_num_edges.append(self.graphs[i].batch_num_edges())
  else:
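PreBatchedDataset now accepts a hidden_size keyword (default 64, matching the old hardcoded value), stores it as self.hid_size, and forwards it to utils.pad_batch_num_nodes. An illustrative construction, with hypothetical argument values:

    ds = PreBatchedDataset(start_dataset=my_dataset, batch_size=128,
                           padding_mode='NONE', hidden_size=32)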
root_gnn_dgl/root_gnn_base/dataset.py CHANGED
@@ -162,6 +162,7 @@ class RootDataset(DGLDataset):
  self.global_chunks = []
  chunk_id = -1
  for chunk in chunks:
+ print('Processing chunk {}/{}'.format(chunk_id + 1, len(chunks)), flush=True)
  chunk_id += 1
  graphs = []
  labels = []
@@ -198,6 +199,12 @@ class RootDataset(DGLDataset):
  labels = torch.tensor(labels)
  tracking = torch.stack(tracking)
  globals = torch.stack(globals)
+
+ self.graph_chunks.append(graphs)
+ self.label_chunks.append(labels)
+ self.tracking_chunks.append(tracking)
+ self.global_chunks.append(globals)
+ self.counts.append(len(graphs))

  if (self.chunks > 1):
  self.save_chunk(chunk_id, graphs, labels, tracking, globals)
@@ -223,16 +230,12 @@ class RootDataset(DGLDataset):
  return
  graph_path = os.path.join(self.save_dir, self.name + '.bin')
  if self.chunks == 1:
- # print(len(self.graphs))
- # print(len(self.labels))
- # print(len(self.tracking))
- # print(len(self.globals))
  print(f'Saving dataset to {os.path.join(self.save_dir, self.name + ".bin")}')
  dgl.save_graphs(str(graph_path), self.graphs, {'labels': torch.tensor(self.labels), 'tracking': torch.tensor(self.tracking), 'global': torch.tensor(self.global_features)})
  else:
- print(len(self.graph_chunks))
  for i in range(len(self.process_chunks)):
  print(f'Saving dataset to {os.path.join(self.save_dir, self.name + f"_{self.process_chunks[i]}.bin")}')
+
  dgl.save_graphs(str(graph_path).replace('.bin', f'_{self.process_chunks[i]}.bin'), self.graph_chunks[i], {'labels': self.label_chunks[i], 'tracking': self.tracking_chunks[i], 'global': self.global_chunks[i]})

  def save_chunk(self, chunk_id, graphs, labels, tracking, globals):
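process() now also accumulates each chunk's graphs, labels, tracking, and global tensors in the corresponding *_chunks lists (plus a per-chunk count), which is what save() iterates over when writing one .bin file per chunk. A self-contained sketch of the DGL save/load round trip used here, with a hypothetical file name and tensors:

    import dgl
    import torch

    g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
    dgl.save_graphs('chunk_0.bin', [g], {'labels': torch.tensor([1])})
    graphs, label_dict = dgl.load_graphs('chunk_0.bin')  # one graph, one label tensor back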
root_gnn_dgl/root_gnn_base/utils.py CHANGED
@@ -92,7 +92,7 @@ def pad_batch(batch, edges = 104000, nodes = 16000):
  return make_padding_graph(batch, pad_nodes, pad_edges)

  def pad_batch_num_nodes(batch, max_num_nodes, hid_size = 64):
- print(f"Padding each graph to have {max_num_nodes} nodes")
+ print(f"Padding each graph to have {max_num_nodes} nodes. Using hidden size {hid_size}.")

  unbatched = dgl.unbatch(batch)
  for g in unbatched:
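The padding message now also reports hid_size, the feature width used for padding nodes. A hedged sketch of the idea behind pad_batch_num_nodes (not the repo's implementation; it assumes the node feature field is named 'features' and already has width hid_size):

    import dgl
    import torch

    def pad_to_num_nodes(batch, max_num_nodes, hid_size=64):
        graphs = dgl.unbatch(batch)
        for g in graphs:
            n_pad = max_num_nodes - g.number_of_nodes()
            if n_pad > 0:
                # append isolated dummy nodes with zero feature vectors
                g.add_nodes(n_pad, data={'features': torch.zeros(n_pad, hid_size)})
        return dgl.batch(graphs)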
root_gnn_dgl/scripts/prep_data.py CHANGED
@@ -33,12 +33,12 @@ def main():
  fold_conf = dset_config["folding"]
  print(f"shuffle_chunks = {shuffle_chunks}, args.chunk = {args.chunk}, padding_mode = {padding_mode}")
  if dset_config["class"] == "LazyMultiLabelDataset":
- LazyPreBatchedDataset(start_dataset = dset, batch_size = batch_size, mask_fn = utils.fold_selection(fold_conf, "train"), suffix = utils.fold_selection_name(fold_conf, "train"), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode, drop_last=args.drop_last)
- LazyPreBatchedDataset(start_dataset = dset, batch_size = batch_size, mask_fn = utils.fold_selection(fold_conf, "test"), suffix = utils.fold_selection_name(fold_conf, 'test'), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode, drop_last=args.drop_last)
+ LazyPreBatchedDataset(start_dataset = dset, batch_size = batch_size, mask_fn = utils.fold_selection(fold_conf, "train"), suffix = utils.fold_selection_name(fold_conf, "train"), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode, drop_last=args.drop_last, hidden_size=config['Model']['args']['hid_size'])
+ LazyPreBatchedDataset(start_dataset = dset, batch_size = batch_size, mask_fn = utils.fold_selection(fold_conf, "test"), suffix = utils.fold_selection_name(fold_conf, 'test'), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode, drop_last=args.drop_last, hidden_size=config['Model']['args']['hid_size'])

  else:
- PreBatchedDataset(dset, batch_size, utils.fold_selection(fold_conf, "train"), suffix = utils.fold_selection_name(fold_conf, "train"), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode, drop_last=args.drop_last)
- PreBatchedDataset(dset, batch_size, utils.fold_selection(fold_conf, "test"), suffix = utils.fold_selection_name(fold_conf, 'test'), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode, drop_last=args.drop_last)
+ PreBatchedDataset(dset, batch_size, utils.fold_selection(fold_conf, "train"), suffix = utils.fold_selection_name(fold_conf, "train"), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode, drop_last=args.drop_last, hidden_size=config['Model']['args']['hid_size'])
+ PreBatchedDataset(dset, batch_size, utils.fold_selection(fold_conf, "test"), suffix = utils.fold_selection_name(fold_conf, 'test'), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode, drop_last=args.drop_last, hidden_size=config['Model']['args']['hid_size'])

  if __name__ == "__main__":
- main()
+ main()
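Both branches now read the hidden size from the model section of the config, so the prebatched padding width always matches the model. If older config files might lack hid_size, a defensive lookup (hypothetical, assuming config is a plain nested dict) would be:

    hid_size = config.get('Model', {}).get('args', {}).get('hid_size', 64)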
root_gnn_dgl/scripts/training_script.py CHANGED
@@ -707,7 +707,7 @@ def main(rank=0, args=None, world_size=1, port=24500, seed=12345):
  mask_fn = utils.fold_selection(fold_conf, "train")
  if args.preshuffle:
  # ldr = ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=mask_fn, suffix = utils.fold_selection_name(fold_conf, 'train'), chunks = shuffle_chunks, padding_mode = padding_mode, use_ddp = args.multigpu, rank=rank, world_size=world_size)
- ldr = ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=mask_fn, suffix = utils.fold_selection_name(fold_conf, 'train'), chunks = shuffle_chunks, padding_mode = padding_mode)
+ ldr = ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=mask_fn, suffix = utils.fold_selection_name(fold_conf, 'train'), chunks = shuffle_chunks, padding_mode = padding_mode, hidden_size = config["Model"]["args"]["hid_size"])
  gsamp, _, _, global_samp = ldr[0]
  sampler = None

@@ -724,7 +724,7 @@ def main(rank=0, args=None, world_size=1, port=24500, seed=12345):
  sampler = DistributedSampler(ldr, num_replicas=world_size, rank=pargs.global_rank, shuffle=False, drop_last=True)
  train_loaders.append(torch.utils.data.DataLoader(ldr, batch_size = None, num_workers = 0, sampler = sampler))
  sampler = None
- ldr = ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=mask_fn, suffix = utils.fold_selection_name(fold_conf, 'test'), chunks = shuffle_chunks, padding_mode = padding_mode)
+ ldr = ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=mask_fn, suffix = utils.fold_selection_name(fold_conf, 'test'), chunks = shuffle_chunks, padding_mode = padding_mode, hidden_size = config['Model']['args']['hid_size'])
  if (args.multigpu):
  sampler = DistributedSampler(ldr, num_replicas=world_size, rank=rank, shuffle=False, drop_last=True)
  # num_batches = len(ldr)
@@ -737,7 +737,7 @@ def main(rank=0, args=None, world_size=1, port=24500, seed=12345):
  test_loaders.append(torch.utils.data.DataLoader(ldr, batch_size = None, num_workers = 0, sampler=sampler))

  if "validation" in fold_conf:
- val_loaders.append(torch.utils.data.DataLoader((ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=utils.fold_selection(fold_conf, "validation"), suffix = utils.fold_selection_name(fold_conf, 'validation'), chunks = shuffle_chunks, padding_mode = padding_mode, rank=rank, world_size=1)), batch_size = None, num_workers = 0, sampler = sampler))
+ val_loaders.append(torch.utils.data.DataLoader((ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=utils.fold_selection(fold_conf, "validation"), suffix = utils.fold_selection_name(fold_conf, 'validation'), chunks = shuffle_chunks, hidden_size=config['Model']['args']['hid_size'], padding_mode = padding_mode, rank=rank, world_size=1)), batch_size = None, num_workers = 0, sampler = sampler))
  else:
  print("No validation set for dataset ", dset_conf)
  else:
@@ -753,6 +753,8 @@ def main(rank=0, args=None, world_size=1, port=24500, seed=12345):
  print("Load time: {:.4f} s".format(load_end - load_start))

  model = utils.buildFromConfig(config["Model"], {'sample_graph': gsamp, 'sample_global': global_samp, 'seed': seed}).to(device)
+ pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ print(f"Number of trainable parameters = {pytorch_total_params}")
  if not args.nocompile:
  model = torch.compile(model)
  if args.multigpu:
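The two added lines report the trainable parameter count right after the model is built, a cheap sanity check that hid_size from the config actually changed the architecture. The same idiom on a self-contained toy model:

    import torch.nn as nn

    model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 4))
    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Number of trainable parameters = {pytorch_total_params}")  # 1088 + 260 = 1348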