charlie
#3
by
chultquist0
- opened
root_gnn_dgl/models/GCN.py
CHANGED
|
@@ -1154,6 +1154,7 @@ class Attention(nn.Module):
|
|
| 1154 |
self.n_proc_steps = n_proc_steps
|
| 1155 |
self.layers = nn.ModuleList()
|
| 1156 |
self.has_global = sample_global.shape[1] != 0
|
|
|
|
| 1157 |
gl_size = sample_global.shape[1] if self.has_global else 1
|
| 1158 |
|
| 1159 |
#encoder
|
|
@@ -1196,7 +1197,7 @@ class Attention(nn.Module):
|
|
| 1196 |
batch_num_nodes.append(non_padded_count)
|
| 1197 |
start_idx = end_idx
|
| 1198 |
batch_num_nodes = torch.tensor(batch_num_nodes, device = g.ndata['features'].device)
|
| 1199 |
-
sum_weights = batch_num_nodes[:, None].repeat(1,
|
| 1200 |
global_feats = batch_num_nodes[:, None].to(torch.float)
|
| 1201 |
|
| 1202 |
h_global = self.global_encoder(global_feats)
|
|
@@ -1364,6 +1365,7 @@ class Transferred_Learning_Attention(nn.Module):
|
|
| 1364 |
self.n_proc_steps = n_proc_steps
|
| 1365 |
self.layers = nn.ModuleList()
|
| 1366 |
self.has_global = sample_global.shape[1] != 0
|
|
|
|
| 1367 |
gl_size = sample_global.shape[1] if self.has_global else 1
|
| 1368 |
|
| 1369 |
self.learning_rate = learning_rate
|
|
@@ -1440,7 +1442,7 @@ class Transferred_Learning_Attention(nn.Module):
|
|
| 1440 |
batch_num_nodes.append(non_padded_count)
|
| 1441 |
start_idx = end_idx
|
| 1442 |
batch_num_nodes = torch.tensor(batch_num_nodes, device = g.ndata['features'].device)
|
| 1443 |
-
sum_weights = batch_num_nodes[:, None].repeat(1,
|
| 1444 |
global_feats = batch_num_nodes[:, None].to(torch.float)
|
| 1445 |
|
| 1446 |
h_global = self.TL_global_encoder(global_feats)
|
|
@@ -1856,6 +1858,7 @@ class Clustering(nn.Module):
|
|
| 1856 |
self.n_layers = n_layers
|
| 1857 |
self.n_proc_steps = n_proc_steps
|
| 1858 |
self.layers = nn.ModuleList()
|
|
|
|
| 1859 |
if (len(sample_global) == 0):
|
| 1860 |
self.has_global = False
|
| 1861 |
else:
|
|
@@ -1899,7 +1902,7 @@ class Clustering(nn.Module):
|
|
| 1899 |
batch_num_nodes.append(non_padded_count)
|
| 1900 |
start_idx = end_idx
|
| 1901 |
batch_num_nodes = torch.tensor(batch_num_nodes, device = g.ndata[features].device)
|
| 1902 |
-
sum_weights = batch_num_nodes[:, None].repeat(1,
|
| 1903 |
global_feats = batch_num_nodes[:, None].to(torch.float)
|
| 1904 |
|
| 1905 |
h_global = self.global_encoder(global_feats)
|
|
|
|
| 1154 |
self.n_proc_steps = n_proc_steps
|
| 1155 |
self.layers = nn.ModuleList()
|
| 1156 |
self.has_global = sample_global.shape[1] != 0
|
| 1157 |
+
self.hid_size = hid_size
|
| 1158 |
gl_size = sample_global.shape[1] if self.has_global else 1
|
| 1159 |
|
| 1160 |
#encoder
|
|
|
|
| 1197 |
batch_num_nodes.append(non_padded_count)
|
| 1198 |
start_idx = end_idx
|
| 1199 |
batch_num_nodes = torch.tensor(batch_num_nodes, device = g.ndata['features'].device)
|
| 1200 |
+
sum_weights = batch_num_nodes[:, None].repeat(1, self.hid_size)
|
| 1201 |
global_feats = batch_num_nodes[:, None].to(torch.float)
|
| 1202 |
|
| 1203 |
h_global = self.global_encoder(global_feats)
|
|
|
|
| 1365 |
self.n_proc_steps = n_proc_steps
|
| 1366 |
self.layers = nn.ModuleList()
|
| 1367 |
self.has_global = sample_global.shape[1] != 0
|
| 1368 |
+
self.hid_size = hid_size
|
| 1369 |
gl_size = sample_global.shape[1] if self.has_global else 1
|
| 1370 |
|
| 1371 |
self.learning_rate = learning_rate
|
|
|
|
| 1442 |
batch_num_nodes.append(non_padded_count)
|
| 1443 |
start_idx = end_idx
|
| 1444 |
batch_num_nodes = torch.tensor(batch_num_nodes, device = g.ndata['features'].device)
|
| 1445 |
+
sum_weights = batch_num_nodes[:, None].repeat(1, self.hid_size)
|
| 1446 |
global_feats = batch_num_nodes[:, None].to(torch.float)
|
| 1447 |
|
| 1448 |
h_global = self.TL_global_encoder(global_feats)
|
|
|
|
| 1858 |
self.n_layers = n_layers
|
| 1859 |
self.n_proc_steps = n_proc_steps
|
| 1860 |
self.layers = nn.ModuleList()
|
| 1861 |
+
self.hid_size = hid_size
|
| 1862 |
if (len(sample_global) == 0):
|
| 1863 |
self.has_global = False
|
| 1864 |
else:
|
|
|
|
| 1902 |
batch_num_nodes.append(non_padded_count)
|
| 1903 |
start_idx = end_idx
|
| 1904 |
batch_num_nodes = torch.tensor(batch_num_nodes, device = g.ndata[features].device)
|
| 1905 |
+
sum_weights = batch_num_nodes[:, None].repeat(1, self.hid_size)
|
| 1906 |
global_feats = batch_num_nodes[:, None].to(torch.float)
|
| 1907 |
|
| 1908 |
h_global = self.global_encoder(global_feats)
|
root_gnn_dgl/root_gnn_base/batched_dataset.py
CHANGED
|
@@ -16,7 +16,7 @@ def GetBatchedLoader(dataset, batch_size, mask_fn = None, drop_last=True, **kwar
|
|
| 16 |
|
| 17 |
#Dataset which contains prebatched shuffled graphs. Cannot be saved to disk, else batching info is lost.
|
| 18 |
class PreBatchedDataset(DGLDataset):
|
| 19 |
-
def __init__(self, start_dataset, batch_size, mask_fn = None, drop_last=True, save_to_disk = True, suffix = '', chunks = 1, chunkno = -1, shuffle = True, padding_mode = 'NONE', **kwargs):
|
| 20 |
print(f'Unused kwargs: {kwargs}')
|
| 21 |
self.start_dataset = start_dataset
|
| 22 |
self.start_dataset.load()
|
|
@@ -34,6 +34,7 @@ class PreBatchedDataset(DGLDataset):
|
|
| 34 |
self.suffix = suffix
|
| 35 |
self.current_chunk = None
|
| 36 |
self.current_chunk_idx = -1
|
|
|
|
| 37 |
super().__init__(name = start_dataset.name + '_prebatched_padded', save_dir=start_dataset.save_dir)
|
| 38 |
|
| 39 |
def process(self):
|
|
@@ -86,7 +87,7 @@ class PreBatchedDataset(DGLDataset):
|
|
| 86 |
for i in range(len(self.graphs)):
|
| 87 |
unbatched_g = dgl.unbatch(self.graphs[i])
|
| 88 |
max_num_nodes = max(g.number_of_nodes() for g in unbatched_g)
|
| 89 |
-
self.graphs[i] = utils.pad_batch_num_nodes(self.graphs[i], max_num_nodes)
|
| 90 |
self.batch_num_nodes.append(self.graphs[i].batch_num_nodes())
|
| 91 |
self.batch_num_edges.append(self.graphs[i].batch_num_edges())
|
| 92 |
else:
|
|
|
|
| 16 |
|
| 17 |
#Dataset which contains prebatched shuffled graphs. Cannot be saved to disk, else batching info is lost.
|
| 18 |
class PreBatchedDataset(DGLDataset):
|
| 19 |
+
def __init__(self, start_dataset, batch_size, mask_fn = None, drop_last=True, save_to_disk = True, suffix = '', chunks = 1, chunkno = -1, shuffle = True, padding_mode = 'NONE', hidden_size=64, **kwargs):
|
| 20 |
print(f'Unused kwargs: {kwargs}')
|
| 21 |
self.start_dataset = start_dataset
|
| 22 |
self.start_dataset.load()
|
|
|
|
| 34 |
self.suffix = suffix
|
| 35 |
self.current_chunk = None
|
| 36 |
self.current_chunk_idx = -1
|
| 37 |
+
self.hid_size = hidden_size
|
| 38 |
super().__init__(name = start_dataset.name + '_prebatched_padded', save_dir=start_dataset.save_dir)
|
| 39 |
|
| 40 |
def process(self):
|
|
|
|
| 87 |
for i in range(len(self.graphs)):
|
| 88 |
unbatched_g = dgl.unbatch(self.graphs[i])
|
| 89 |
max_num_nodes = max(g.number_of_nodes() for g in unbatched_g)
|
| 90 |
+
self.graphs[i] = utils.pad_batch_num_nodes(self.graphs[i], max_num_nodes, hid_size=self.hid_size)
|
| 91 |
self.batch_num_nodes.append(self.graphs[i].batch_num_nodes())
|
| 92 |
self.batch_num_edges.append(self.graphs[i].batch_num_edges())
|
| 93 |
else:
|
root_gnn_dgl/root_gnn_base/dataset.py
CHANGED
|
@@ -162,6 +162,7 @@ class RootDataset(DGLDataset):
|
|
| 162 |
self.global_chunks = []
|
| 163 |
chunk_id = -1
|
| 164 |
for chunk in chunks:
|
|
|
|
| 165 |
chunk_id += 1
|
| 166 |
graphs = []
|
| 167 |
labels = []
|
|
@@ -198,6 +199,12 @@ class RootDataset(DGLDataset):
|
|
| 198 |
labels = torch.tensor(labels)
|
| 199 |
tracking = torch.stack(tracking)
|
| 200 |
globals = torch.stack(globals)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
|
| 202 |
if (self.chunks > 1):
|
| 203 |
self.save_chunk(chunk_id, graphs, labels, tracking, globals)
|
|
@@ -223,16 +230,12 @@ class RootDataset(DGLDataset):
|
|
| 223 |
return
|
| 224 |
graph_path = os.path.join(self.save_dir, self.name + '.bin')
|
| 225 |
if self.chunks == 1:
|
| 226 |
-
# print(len(self.graphs))
|
| 227 |
-
# print(len(self.labels))
|
| 228 |
-
# print(len(self.tracking))
|
| 229 |
-
# print(len(self.globals))
|
| 230 |
print(f'Saving dataset to {os.path.join(self.save_dir, self.name + ".bin")}')
|
| 231 |
dgl.save_graphs(str(graph_path), self.graphs, {'labels': torch.tensor(self.labels), 'tracking': torch.tensor(self.tracking), 'global': torch.tensor(self.global_features)})
|
| 232 |
else:
|
| 233 |
-
print(len(self.graph_chunks))
|
| 234 |
for i in range(len(self.process_chunks)):
|
| 235 |
print(f'Saving dataset to {os.path.join(self.save_dir, self.name + f"_{self.process_chunks[i]}.bin")}')
|
|
|
|
| 236 |
dgl.save_graphs(str(graph_path).replace('.bin', f'_{self.process_chunks[i]}.bin'), self.graph_chunks[i], {'labels': self.label_chunks[i], 'tracking': self.tracking_chunks[i], 'global': self.global_chunks[i]})
|
| 237 |
|
| 238 |
def save_chunk(self, chunk_id, graphs, labels, tracking, globals):
|
|
|
|
| 162 |
self.global_chunks = []
|
| 163 |
chunk_id = -1
|
| 164 |
for chunk in chunks:
|
| 165 |
+
print('Processing chunk {}/{}'.format(chunk_id + 1, len(chunks)), flush=True)
|
| 166 |
chunk_id += 1
|
| 167 |
graphs = []
|
| 168 |
labels = []
|
|
|
|
| 199 |
labels = torch.tensor(labels)
|
| 200 |
tracking = torch.stack(tracking)
|
| 201 |
globals = torch.stack(globals)
|
| 202 |
+
|
| 203 |
+
self.graph_chunks.append(graphs)
|
| 204 |
+
self.label_chunks.append(labels)
|
| 205 |
+
self.tracking_chunks.append(tracking)
|
| 206 |
+
self.global_chunks.append(globals)
|
| 207 |
+
self.counts.append(len(graphs))
|
| 208 |
|
| 209 |
if (self.chunks > 1):
|
| 210 |
self.save_chunk(chunk_id, graphs, labels, tracking, globals)
|
|
|
|
| 230 |
return
|
| 231 |
graph_path = os.path.join(self.save_dir, self.name + '.bin')
|
| 232 |
if self.chunks == 1:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
print(f'Saving dataset to {os.path.join(self.save_dir, self.name + ".bin")}')
|
| 234 |
dgl.save_graphs(str(graph_path), self.graphs, {'labels': torch.tensor(self.labels), 'tracking': torch.tensor(self.tracking), 'global': torch.tensor(self.global_features)})
|
| 235 |
else:
|
|
|
|
| 236 |
for i in range(len(self.process_chunks)):
|
| 237 |
print(f'Saving dataset to {os.path.join(self.save_dir, self.name + f"_{self.process_chunks[i]}.bin")}')
|
| 238 |
+
|
| 239 |
dgl.save_graphs(str(graph_path).replace('.bin', f'_{self.process_chunks[i]}.bin'), self.graph_chunks[i], {'labels': self.label_chunks[i], 'tracking': self.tracking_chunks[i], 'global': self.global_chunks[i]})
|
| 240 |
|
| 241 |
def save_chunk(self, chunk_id, graphs, labels, tracking, globals):
|
root_gnn_dgl/root_gnn_base/utils.py
CHANGED
|
@@ -92,7 +92,7 @@ def pad_batch(batch, edges = 104000, nodes = 16000):
|
|
| 92 |
return make_padding_graph(batch, pad_nodes, pad_edges)
|
| 93 |
|
| 94 |
def pad_batch_num_nodes(batch, max_num_nodes, hid_size = 64):
|
| 95 |
-
print(f"Padding each graph to have {max_num_nodes} nodes")
|
| 96 |
|
| 97 |
unbatched = dgl.unbatch(batch)
|
| 98 |
for g in unbatched:
|
|
|
|
| 92 |
return make_padding_graph(batch, pad_nodes, pad_edges)
|
| 93 |
|
| 94 |
def pad_batch_num_nodes(batch, max_num_nodes, hid_size = 64):
|
| 95 |
+
print(f"Padding each graph to have {max_num_nodes} nodes. Using hidden size {hid_size}.")
|
| 96 |
|
| 97 |
unbatched = dgl.unbatch(batch)
|
| 98 |
for g in unbatched:
|
root_gnn_dgl/scripts/prep_data.py
CHANGED
|
@@ -33,12 +33,12 @@ def main():
|
|
| 33 |
fold_conf = dset_config["folding"]
|
| 34 |
print(f"shuffle_chunks = {shuffle_chunks}, args.chunk = {args.chunk}, padding_mode = {padding_mode}")
|
| 35 |
if dset_config["class"] == "LazyMultiLabelDataset":
|
| 36 |
-
LazyPreBatchedDataset(start_dataset = dset, batch_size = batch_size, mask_fn = utils.fold_selection(fold_conf, "train"), suffix = utils.fold_selection_name(fold_conf, "train"), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode, drop_last=args.drop_last)
|
| 37 |
-
LazyPreBatchedDataset(start_dataset = dset, batch_size = batch_size, mask_fn = utils.fold_selection(fold_conf, "test"), suffix = utils.fold_selection_name(fold_conf, 'test'), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode, drop_last=args.drop_last)
|
| 38 |
|
| 39 |
else:
|
| 40 |
-
PreBatchedDataset(dset, batch_size, utils.fold_selection(fold_conf, "train"), suffix = utils.fold_selection_name(fold_conf, "train"), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode, drop_last=args.drop_last)
|
| 41 |
-
PreBatchedDataset(dset, batch_size, utils.fold_selection(fold_conf, "test"), suffix = utils.fold_selection_name(fold_conf, 'test'), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode, drop_last=args.drop_last)
|
| 42 |
|
| 43 |
if __name__ == "__main__":
|
| 44 |
-
main()
|
|
|
|
| 33 |
fold_conf = dset_config["folding"]
|
| 34 |
print(f"shuffle_chunks = {shuffle_chunks}, args.chunk = {args.chunk}, padding_mode = {padding_mode}")
|
| 35 |
if dset_config["class"] == "LazyMultiLabelDataset":
|
| 36 |
+
LazyPreBatchedDataset(start_dataset = dset, batch_size = batch_size, mask_fn = utils.fold_selection(fold_conf, "train"), suffix = utils.fold_selection_name(fold_conf, "train"), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode, drop_last=args.drop_last, hidden_size=config['Model']['args']['hid_size'] )
|
| 37 |
+
LazyPreBatchedDataset(start_dataset = dset, batch_size = batch_size, mask_fn = utils.fold_selection(fold_conf, "test"), suffix = utils.fold_selection_name(fold_conf, 'test'), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode, drop_last=args.drop_last, hidden_size=config['Model']['args']['hid_size'])
|
| 38 |
|
| 39 |
else:
|
| 40 |
+
PreBatchedDataset(dset, batch_size, utils.fold_selection(fold_conf, "train"), suffix = utils.fold_selection_name(fold_conf, "train"), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode, drop_last=args.drop_last,hidden_size=config['Model']['args']['hid_size'])
|
| 41 |
+
PreBatchedDataset(dset, batch_size, utils.fold_selection(fold_conf, "test"), suffix = utils.fold_selection_name(fold_conf, 'test'), chunks = shuffle_chunks, chunkno = args.chunk, padding_mode = padding_mode, drop_last=args.drop_last,hidden_size=config['Model']['args']['hid_size'] )
|
| 42 |
|
| 43 |
if __name__ == "__main__":
|
| 44 |
+
main()
|
root_gnn_dgl/scripts/training_script.py
CHANGED
|
@@ -707,7 +707,7 @@ def main(rank=0, args=None, world_size=1, port=24500, seed=12345):
|
|
| 707 |
mask_fn = utils.fold_selection(fold_conf, "train")
|
| 708 |
if args.preshuffle:
|
| 709 |
# ldr = ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=mask_fn, suffix = utils.fold_selection_name(fold_conf, 'train'), chunks = shuffle_chunks, padding_mode = padding_mode, use_ddp = args.multigpu, rank=rank, world_size=world_size)
|
| 710 |
-
ldr = ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=mask_fn, suffix = utils.fold_selection_name(fold_conf, 'train'), chunks = shuffle_chunks, padding_mode = padding_mode)
|
| 711 |
gsamp, _, _, global_samp = ldr[0]
|
| 712 |
sampler = None
|
| 713 |
|
|
@@ -724,7 +724,7 @@ def main(rank=0, args=None, world_size=1, port=24500, seed=12345):
|
|
| 724 |
sampler = DistributedSampler(ldr, num_replicas=world_size, rank=pargs.global_rank, shuffle=False, drop_last=True)
|
| 725 |
train_loaders.append(torch.utils.data.DataLoader(ldr, batch_size = None, num_workers = 0, sampler = sampler))
|
| 726 |
sampler = None
|
| 727 |
-
ldr = ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=mask_fn, suffix = utils.fold_selection_name(fold_conf, 'test'), chunks = shuffle_chunks, padding_mode = padding_mode)
|
| 728 |
if (args.multigpu):
|
| 729 |
sampler = DistributedSampler(ldr, num_replicas=world_size, rank=rank, shuffle=False, drop_last=True)
|
| 730 |
# num_batches = len(ldr)
|
|
@@ -737,7 +737,7 @@ def main(rank=0, args=None, world_size=1, port=24500, seed=12345):
|
|
| 737 |
test_loaders.append(torch.utils.data.DataLoader(ldr, batch_size = None, num_workers = 0, sampler=sampler))
|
| 738 |
|
| 739 |
if "validation" in fold_conf:
|
| 740 |
-
val_loaders.append(torch.utils.data.DataLoader((ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=utils.fold_selection(fold_conf, "validation"), suffix = utils.fold_selection_name(fold_conf, 'validation'), chunks = shuffle_chunks, padding_mode = padding_mode, rank=rank, world_size=1)), batch_size = None, num_workers = 0, sampler = sampler))
|
| 741 |
else:
|
| 742 |
print("No validation set for dataset ", dset_conf)
|
| 743 |
else:
|
|
@@ -753,6 +753,8 @@ def main(rank=0, args=None, world_size=1, port=24500, seed=12345):
|
|
| 753 |
print("Load time: {:.4f} s".format(load_end - load_start))
|
| 754 |
|
| 755 |
model = utils.buildFromConfig(config["Model"], {'sample_graph': gsamp, 'sample_global': global_samp, 'seed': seed}).to(device)
|
|
|
|
|
|
|
| 756 |
if not args.nocompile:
|
| 757 |
model = torch.compile(model)
|
| 758 |
if args.multigpu:
|
|
|
|
| 707 |
mask_fn = utils.fold_selection(fold_conf, "train")
|
| 708 |
if args.preshuffle:
|
| 709 |
# ldr = ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=mask_fn, suffix = utils.fold_selection_name(fold_conf, 'train'), chunks = shuffle_chunks, padding_mode = padding_mode, use_ddp = args.multigpu, rank=rank, world_size=world_size)
|
| 710 |
+
ldr = ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=mask_fn, suffix = utils.fold_selection_name(fold_conf, 'train'), chunks = shuffle_chunks, padding_mode = padding_mode, hidden_size = config["Model"]["args"]["hid_size"])
|
| 711 |
gsamp, _, _, global_samp = ldr[0]
|
| 712 |
sampler = None
|
| 713 |
|
|
|
|
| 724 |
sampler = DistributedSampler(ldr, num_replicas=world_size, rank=pargs.global_rank, shuffle=False, drop_last=True)
|
| 725 |
train_loaders.append(torch.utils.data.DataLoader(ldr, batch_size = None, num_workers = 0, sampler = sampler))
|
| 726 |
sampler = None
|
| 727 |
+
ldr = ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=mask_fn, suffix = utils.fold_selection_name(fold_conf, 'test'), chunks = shuffle_chunks, padding_mode = padding_mode, hidden_size= config['Model']['args']['hid_size'])
|
| 728 |
if (args.multigpu):
|
| 729 |
sampler = DistributedSampler(ldr, num_replicas=world_size, rank=rank, shuffle=False, drop_last=True)
|
| 730 |
# num_batches = len(ldr)
|
|
|
|
| 737 |
test_loaders.append(torch.utils.data.DataLoader(ldr, batch_size = None, num_workers = 0, sampler=sampler))
|
| 738 |
|
| 739 |
if "validation" in fold_conf:
|
| 740 |
+
val_loaders.append(torch.utils.data.DataLoader((ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=utils.fold_selection(fold_conf, "validation"), suffix = utils.fold_selection_name(fold_conf, 'validation'), chunks = shuffle_chunks, hidden_size=config['Model']['args']['hid_size'], padding_mode = padding_mode, rank=rank, world_size=1)), batch_size = None, num_workers = 0, sampler = sampler))
|
| 741 |
else:
|
| 742 |
print("No validation set for dataset ", dset_conf)
|
| 743 |
else:
|
|
|
|
| 753 |
print("Load time: {:.4f} s".format(load_end - load_start))
|
| 754 |
|
| 755 |
model = utils.buildFromConfig(config["Model"], {'sample_graph': gsamp, 'sample_global': global_samp, 'seed': seed}).to(device)
|
| 756 |
+
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 757 |
+
print(f"Number of trainable parameters = {pytorch_total_params}")
|
| 758 |
if not args.nocompile:
|
| 759 |
model = torch.compile(model)
|
| 760 |
if args.multigpu:
|