ho22joshua commited on
Commit
d646e7f
·
1 Parent(s): 5ceead6

adding edge network

Browse files
physicsnemo/configs/tHjb_CP_0_vs_90.yaml CHANGED
@@ -31,13 +31,21 @@ performance:
31
  jit: False
32
 
33
  architecture:
34
- processor_size: 8
35
- hidden_dim_node_encoder: 128
36
- hidden_dim_edge_encoder: 128
37
- hidden_dim_processor: 128
38
- hidden_dim_node_decoder: 128
39
- global_emb_dim: 128
40
- out_dim: 1
 
 
 
 
 
 
 
 
41
 
42
  paths:
43
  data_dir: /global/cfs/projectdirs/atlas/joshua/ttHCP/ntuples/v02/preselection/merged_fixed/train/
 
31
  jit: False
32
 
33
  architecture:
34
+ module: models.MeshGraphNet
35
+ class: MeshGraphNet
36
+ args:
37
+ base_gnn:
38
+ input_dim_nodes: 7
39
+ input_dim_edges: 3
40
+ output_dim: 128
41
+ processor_size: 8
42
+ hidden_dim_node_encoder: 128
43
+ hidden_dim_edge_encoder: 128
44
+ hidden_dim_processor: 128
45
+ hidden_dim_node_decoder: 128
46
+ global_emb_dim: 128
47
+ global_feat_dim: 1
48
+ out_dim: 1
49
 
50
  paths:
51
  data_dir: /global/cfs/projectdirs/atlas/joshua/ttHCP/ntuples/v02/preselection/merged_fixed/train/
physicsnemo/configs/tHjb_CP_0_vs_90_edge_network.yaml ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ignore_header_test
2
+ # Copyright 2023 Stanford University
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ random_seed: 2
17
+
18
+ scheduler:
19
+ lr: 1.E-3
20
+ lr_decay: 1.E-3
21
+
22
+ training:
23
+ epochs: 100
24
+
25
+ checkpoints:
26
+ ckpt_path: "checkpoints"
27
+ ckpt_name: "tHjb_CP_0_vs_90_edge_network"
28
+
29
+ performance:
30
+ amp: False
31
+ jit: False
32
+
33
+ architecture:
34
+ module: models.Edge_Network
35
+ class: Edge_Network
36
+ args:
37
+ input_dim_nodes: 7
38
+ input_dim_edges: 3
39
+ input_dim_globals: 1
40
+ hid_size: 64
41
+ n_layers: 4
42
+ n_proc_steps: 4
43
+ out_dim: 1
44
+
45
+ paths:
46
+ data_dir: /global/cfs/projectdirs/atlas/joshua/ttHCP/ntuples/v02/preselection/merged_fixed/train/
47
+ save_dir: /pscratch/sd/j/joshuaho/physicsnemo/ttHCP/graphs/tHjb_CP_0_vs_90/
48
+ training_dir: ./tHjb_CP_0_vs_90_edge_network/
49
+
50
+ datasets:
51
+ - name: tHjb_cp_0_had
52
+ load_path: ${paths.data_dir}/merged_aMCPy8_tHjb125_CP_0_AF3_had_scaled.root
53
+ label: 0
54
+ - name: tHjb_cp_0_lep
55
+ load_path: ${paths.data_dir}/merged_aMCPy8_tHjb125_CP_0_AF3_lep_scaled.root
56
+ label: 0
57
+ - name: tHjb_cp_90_had
58
+ load_path: ${paths.data_dir}/merged_aMCPy8_tHjb125_CP_90_AF3_had_scaled.root
59
+ label: 1
60
+ - name: tHjb_cp_90_lep
61
+ load_path: ${paths.data_dir}/merged_aMCPy8_tHjb125_CP_90_AF3_lep_scaled.root
62
+ label: 1
63
+
64
+ root_dataset:
65
+ ttree: output
66
+ dtype: torch.bfloat16
67
+ features:
68
+ # pt, eta, phi, energy, btag, charge, node_type
69
+ jet: [m_jet_pt, m_jet_eta, m_jet_phi, CALC_E, m_jet_PCbtag, 0, 0]
70
+ electron: [m_el_pt, m_el_eta, m_el_phi, CALC_E, 0, m_el_charge, 1]
71
+ muon: [m_mu_pt, m_mu_eta, m_mu_phi, CALC_E, 0, m_mu_charge, 2]
72
+ photon: [ph_pt_myy, ph_eta, ph_phi, CALC_E, 0, 0, 3]
73
+ met: [m_met, 0, m_met_phi, CALC_E, 0, 0, 4]
74
+ globals: [NUM_NODES]
75
+ weights: 1
76
+ tracking: []
77
+ step_size: 16384
78
+ batch_size: 16384
79
+ train_val_test_split: [0.5, 0.25, 0.25]
80
+ prebatch:
81
+ enabled: True
82
+ chunk_size: 512
physicsnemo/models/Edge_Network.py CHANGED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import dgl
4
+
5
+ from models import utils
6
+
7
class Edge_Network(nn.Module):
    """Graph network with explicit edge, node, and global update steps.

    Encodes node/edge/global features into a shared hidden size, runs
    ``n_proc_steps`` rounds of message passing (edge update -> node update ->
    global update), then classifies from the final global embedding.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg: namespace-like config providing input_dim_nodes,
                input_dim_edges, input_dim_globals, hid_size, n_layers,
                n_proc_steps, and out_dim.
        """
        super().__init__()
        hid_size = cfg.hid_size
        n_layers = cfg.n_layers
        self.n_proc_steps = cfg.n_proc_steps

        # encoders: project raw features into the shared hidden width
        self.node_encoder = utils.Make_MLP(cfg.input_dim_nodes, hid_size, hid_size, n_layers)
        self.edge_encoder = utils.Make_MLP(cfg.input_dim_edges, hid_size, hid_size, n_layers)
        self.global_encoder = utils.Make_MLP(cfg.input_dim_globals, hid_size, hid_size, n_layers)

        # GNN update MLPs; input widths follow the concatenations in forward():
        #   edge update   <- [e, src h, dst h, global]      = 4 * hid_size
        #   node update   <- [h, aggregated e, global]      = 3 * hid_size
        #   global update <- [global, mean h, mean e]       = 3 * hid_size
        self.node_update = utils.Make_MLP(3 * hid_size, hid_size, hid_size, n_layers)
        self.edge_update = utils.Make_MLP(4 * hid_size, hid_size, hid_size, n_layers)
        self.global_update = utils.Make_MLP(3 * hid_size, hid_size, hid_size, n_layers)

        # decoder + classification head on the global embedding
        self.global_decoder = utils.Make_MLP(hid_size, hid_size, hid_size, n_layers)
        self.classify = nn.Linear(hid_size, cfg.out_dim)

    def forward(self, node_feats, edge_feats, global_feats, batched_graph, metadata=None):
        """Run encode -> message passing -> decode.

        Args:
            node_feats: [total_nodes, input_dim_nodes]
            edge_feats: [total_edges, input_dim_edges]
            global_feats: [num_graphs, input_dim_globals] (a 3-D tensor is
                flattened to 2-D on its last dim)
            batched_graph: batched DGLGraph matching the feature tensors
            metadata: optional dict; may carry "batch_num_nodes" /
                "batch_num_edges" splits for per-graph readout/broadcast.

        Returns:
            [num_graphs, out_dim] classification logits.
        """
        # Fix: use a None sentinel instead of a shared mutable default ({}).
        metadata = {} if metadata is None else metadata

        # encoders
        batched_graph.ndata['h'] = self.node_encoder(node_feats)
        batched_graph.edata['e'] = self.edge_encoder(edge_feats)

        if global_feats.ndim == 3:
            global_feats = global_feats.view(-1, global_feats.shape[-1])
        h_global = self.global_encoder(global_feats)

        # message passing
        for _ in range(self.n_proc_steps):
            # stage source ('m_u') and destination ('m_v') node states on edges
            batched_graph.apply_edges(dgl.function.copy_u('h', 'm_u'))
            batched_graph.apply_edges(utils.copy_v)

            # edge update
            edge_inputs = torch.cat([
                batched_graph.edata['e'],
                batched_graph.edata['m_u'],
                batched_graph.edata['m_v'],
                utils.broadcast_global_to_edges(h_global, edge_split=metadata.get("batch_num_edges", None))
            ], dim=1)
            batched_graph.edata['e'] = self.edge_update(edge_inputs)

            # node update: sum incoming edge states, then mix with node + global
            batched_graph.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            node_inputs = torch.cat([
                batched_graph.ndata['h'],
                batched_graph.ndata['h_e'],
                utils.broadcast_global_to_nodes(h_global, node_split=metadata.get("batch_num_nodes", None))
            ], dim=1)
            batched_graph.ndata['h'] = self.node_update(node_inputs)

            # global update from per-graph mean node / edge readouts
            graph_node_feat = utils.mean_nodes(
                batched_graph, 'h', node_split=metadata.get("batch_num_nodes", None)
            )
            graph_edge_feat = utils.mean_edges(
                batched_graph, 'e', edge_split=metadata.get("batch_num_edges", None)
            )
            h_global = self.global_update(torch.cat([h_global, graph_node_feat, graph_edge_feat], dim=1))

        h_global = self.global_decoder(h_global)
        out = self.classify(h_global)
        return out
72
+
physicsnemo/models/MeshGraphNet.py CHANGED
@@ -2,6 +2,8 @@ import torch
2
  import torch.nn as nn
3
  import dgl
4
 
 
 
5
  # Import the PhysicsNemo MeshGraphNet model
6
  from physicsnemo.models.meshgraphnet import MeshGraphNet as PhysicsNemoMeshGraphNet
7
 
@@ -35,8 +37,8 @@ class MeshGraphNet(nn.Module):
35
  batched_graph.ndata['h'] = node_pred
36
  batched_graph.edata['e'] = edge_feats
37
 
38
- graph_node_feat = mean_nodes(batched_graph, 'h', node_split=metadata.get("batch_num_nodes", None))
39
- graph_edge_feat = mean_edges(batched_graph, 'e', edge_split=metadata.get("batch_num_edges", None))
40
 
41
  # Flatten global_feats if needed
42
  if global_feats.ndim == 3:
@@ -47,82 +49,3 @@ class MeshGraphNet(nn.Module):
47
  graph_pred = self.mlp(combined_feat)
48
  return graph_pred
49
 
50
- def mean_nodes(batched_graph, feat_key='h', op='mean', node_split=None):
51
- """
52
- Aggregates node features per disjoint graph in a batched DGLGraph.
53
-
54
- Args:
55
- batched_graph: DGLGraph
56
- feat_key: str, node feature key
57
- op: 'mean', 'sum', or 'max'
58
- node_split: 1D tensor or list of ints (num nodes per graph)
59
-
60
- Returns:
61
- Tensor of shape [num_graphs, node_feat_dim]
62
- """
63
- h = batched_graph.ndata[feat_key]
64
- if node_split is None or len(node_split) == 0:
65
- if op == 'mean':
66
- return dgl.mean_nodes(batched_graph, feat_key)
67
- elif op == 'sum':
68
- return dgl.sum_nodes(batched_graph, feat_key)
69
- elif op == 'max':
70
- return dgl.max_nodes(batched_graph, feat_key)
71
- else:
72
- raise ValueError(f"Unknown op: {op}")
73
- else:
74
- # Ensure node_split is a flat list of ints
75
- if isinstance(node_split, torch.Tensor):
76
- splits = node_split.view(-1).tolist()
77
- else:
78
- splits = [int(x) for x in node_split]
79
- chunks = torch.split(h, splits, dim=0)
80
- if op == 'mean':
81
- out = torch.stack([chunk.mean(0) if chunk.shape[0] > 0 else torch.zeros_like(h[0]) for chunk in chunks])
82
- elif op == 'sum':
83
- out = torch.stack([chunk.sum(0) if chunk.shape[0] > 0 else torch.zeros_like(h[0]) for chunk in chunks])
84
- elif op == 'max':
85
- out = torch.stack([chunk.max(0).values if chunk.shape[0] > 0 else torch.zeros_like(h[0]) for chunk in chunks])
86
- else:
87
- raise ValueError(f"Unknown op: {op}")
88
- return out
89
-
90
- def mean_edges(batched_graph, feat_key='e', op='mean', edge_split=None):
91
- """
92
- Aggregates edge features per disjoint graph in a batched DGLGraph.
93
-
94
- Args:
95
- batched_graph: DGLGraph
96
- feat_key: str, edge feature key
97
- op: 'mean', 'sum', or 'max'
98
- edge_split: 1D tensor or list of ints (num edges per graph)
99
-
100
- Returns:
101
- Tensor of shape [num_graphs, edge_feat_dim]
102
- """
103
- e = batched_graph.edata[feat_key]
104
- if edge_split is None or len(edge_split) == 0:
105
- if op == 'mean':
106
- return dgl.mean_edges(batched_graph, feat_key)
107
- elif op == 'sum':
108
- return dgl.sum_edges(batched_graph, feat_key)
109
- elif op == 'max':
110
- return dgl.max_edges(batched_graph, feat_key)
111
- else:
112
- raise ValueError(f"Unknown op: {op}")
113
- else:
114
- # Ensure edge_split is a flat list of ints
115
- if isinstance(edge_split, torch.Tensor):
116
- splits = edge_split.view(-1).tolist()
117
- else:
118
- splits = [int(x) for x in edge_split]
119
- chunks = torch.split(e, splits, dim=0)
120
- if op == 'mean':
121
- out = torch.stack([chunk.mean(0) if chunk.shape[0] > 0 else torch.zeros_like(e[0]) for chunk in chunks])
122
- elif op == 'sum':
123
- out = torch.stack([chunk.sum(0) if chunk.shape[0] > 0 else torch.zeros_like(e[0]) for chunk in chunks])
124
- elif op == 'max':
125
- out = torch.stack([chunk.max(0).values if chunk.shape[0] > 0 else torch.zeros_like(e[0]) for chunk in chunks])
126
- else:
127
- raise ValueError(f"Unknown op: {op}")
128
- return out
 
2
  import torch.nn as nn
3
  import dgl
4
 
5
+ from models import utils
6
+
7
  # Import the PhysicsNemo MeshGraphNet model
8
  from physicsnemo.models.meshgraphnet import MeshGraphNet as PhysicsNemoMeshGraphNet
9
 
 
37
  batched_graph.ndata['h'] = node_pred
38
  batched_graph.edata['e'] = edge_feats
39
 
40
+ graph_node_feat = utils.mean_nodes(batched_graph, 'h', node_split=metadata.get("batch_num_nodes", None))
41
+ graph_edge_feat = utils.mean_edges(batched_graph, 'e', edge_split=metadata.get("batch_num_edges", None))
42
 
43
  # Flatten global_feats if needed
44
  if global_feats.ndim == 3:
 
49
  graph_pred = self.mlp(combined_feat)
50
  return graph_pred
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
physicsnemo/models/utils.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import dgl
4
+
5
def mean_nodes(batched_graph, feat_key='h', op='mean', node_split=None):
    """
    Aggregates node features per disjoint graph in a batched DGLGraph.

    Args:
        batched_graph: DGLGraph (only ``.ndata`` is read when ``node_split`` is given)
        feat_key: str, node feature key
        op: 'mean', 'sum', or 'max'
        node_split: 1D tensor or list of ints (num nodes per graph)

    Returns:
        Tensor of shape [num_graphs, node_feat_dim]

    Raises:
        ValueError: if ``op`` is not one of 'mean', 'sum', 'max'.
    """
    h = batched_graph.ndata[feat_key]
    if node_split is None or len(node_split) == 0:
        # No explicit split: defer to DGL's built-in per-graph readout ops.
        dgl_readout = {'mean': dgl.mean_nodes, 'sum': dgl.sum_nodes, 'max': dgl.max_nodes}
        if op not in dgl_readout:
            raise ValueError(f"Unknown op: {op}")
        return dgl_readout[op](batched_graph, feat_key)

    # Normalize node_split to a flat list of ints for torch.split.
    if isinstance(node_split, torch.Tensor):
        splits = node_split.view(-1).tolist()
    else:
        splits = [int(x) for x in node_split]

    reducers = {
        'mean': lambda c: c.mean(0),
        'sum': lambda c: c.sum(0),
        'max': lambda c: c.max(0).values,
    }
    if op not in reducers:
        raise ValueError(f"Unknown op: {op}")
    reduce_fn = reducers[op]
    # Fix: h.new_zeros works even when h itself has zero rows, whereas the
    # previous torch.zeros_like(h[0]) raised IndexError on an empty tensor.
    empty = h.new_zeros(h.shape[1:])
    chunks = torch.split(h, splits, dim=0)
    return torch.stack([reduce_fn(c) if c.shape[0] > 0 else empty for c in chunks])
44
+
45
def mean_edges(batched_graph, feat_key='e', op='mean', edge_split=None):
    """
    Aggregates edge features per disjoint graph in a batched DGLGraph.

    Args:
        batched_graph: DGLGraph (only ``.edata`` is read when ``edge_split`` is given)
        feat_key: str, edge feature key
        op: 'mean', 'sum', or 'max'
        edge_split: 1D tensor or list of ints (num edges per graph)

    Returns:
        Tensor of shape [num_graphs, edge_feat_dim]

    Raises:
        ValueError: if ``op`` is not one of 'mean', 'sum', 'max'.
    """
    e = batched_graph.edata[feat_key]
    if edge_split is None or len(edge_split) == 0:
        # No explicit split: defer to DGL's built-in per-graph readout ops.
        dgl_readout = {'mean': dgl.mean_edges, 'sum': dgl.sum_edges, 'max': dgl.max_edges}
        if op not in dgl_readout:
            raise ValueError(f"Unknown op: {op}")
        return dgl_readout[op](batched_graph, feat_key)

    # Normalize edge_split to a flat list of ints for torch.split.
    if isinstance(edge_split, torch.Tensor):
        splits = edge_split.view(-1).tolist()
    else:
        splits = [int(x) for x in edge_split]

    reducers = {
        'mean': lambda c: c.mean(0),
        'sum': lambda c: c.sum(0),
        'max': lambda c: c.max(0).values,
    }
    if op not in reducers:
        raise ValueError(f"Unknown op: {op}")
    reduce_fn = reducers[op]
    # Fix: e.new_zeros works even when e itself has zero rows, whereas the
    # previous torch.zeros_like(e[0]) raised IndexError on an empty tensor.
    empty = e.new_zeros(e.shape[1:])
    chunks = torch.split(e, splits, dim=0)
    return torch.stack([reduce_fn(c) if c.shape[0] > 0 else empty for c in chunks])
84
+
85
def Make_SLP(in_size, out_size, activation = nn.ReLU, dropout = 0):
    """Build one Linear -> activation -> Dropout stage as a list of modules."""
    return [nn.Linear(in_size, out_size), activation(), nn.Dropout(dropout)]
91
+
92
def Make_MLP(in_size, hid_size, out_size, n_layers, activation = nn.ReLU, dropout = 0):
    """Build an MLP of ``n_layers`` Linear/activation/Dropout stages.

    Widths run in_size -> hid_size ... hid_size -> out_size (a single
    in_size -> out_size stage when n_layers <= 1), and the stack ends with a
    LayerNorm over out_size. Returns an nn.Sequential.
    """
    if n_layers > 1:
        widths = [in_size] + [hid_size] * (n_layers - 1) + [out_size]
    else:
        widths = [in_size, out_size]
    layers = []
    for w_in, w_out in zip(widths[:-1], widths[1:]):
        # each stage: Linear -> activation -> Dropout (same recipe as Make_SLP)
        layers.extend([nn.Linear(w_in, w_out), activation(), nn.Dropout(dropout)])
    layers.append(torch.nn.LayerNorm(out_size))
    return nn.Sequential(*layers)
103
+
104
def broadcast_global_to_nodes(globals, node_split):
    """
    Repeat each graph's global feature row once per node of that graph.

    globals: [num_graphs, global_dim]
    node_split: list/1D tensor of length num_graphs, number of nodes per graph
    Returns: [total_num_nodes, global_dim]
    Raises ValueError if node_split is None.
    """
    if node_split is None:
        raise ValueError("node_split must be provided")
    # as_tensor handles both list and tensor inputs, moving to globals' device
    counts = torch.as_tensor(node_split, dtype=torch.long, device=globals.device).flatten()
    return torch.repeat_interleave(globals, counts, dim=0)
118
+
119
def broadcast_global_to_edges(globals, edge_split):
    """
    Repeat each graph's global feature row once per edge of that graph.

    globals: [num_graphs, global_dim] (on CUDA or CPU)
    edge_split: list/1D tensor of length num_graphs, number of edges per graph
    Returns: [total_num_edges, global_dim]
    Raises ValueError if edge_split is None.
    """
    if edge_split is None:
        raise ValueError("edge_split must be provided")
    # as_tensor handles both list and tensor inputs, moving to globals' device
    counts = torch.as_tensor(edge_split, dtype=torch.long, device=globals.device).flatten()
    return torch.repeat_interleave(globals, counts, dim=0)
133
+
134
def copy_v(edges):
    """DGL edge UDF: expose each edge's destination-node feature 'h' as message 'm_v'."""
    dst_state = edges.dst['h']
    return {'m_v': dst_state}
physicsnemo/train.py CHANGED
@@ -23,6 +23,7 @@ import models.MeshGraphNet as MeshGraphNet
23
  from dataset.Dataset import get_dataset
24
  import metrics
25
 
 
26
 
27
  class MGNTrainer:
28
  def __init__(self, logger, cfg, dist):
@@ -30,8 +31,6 @@ class MGNTrainer:
30
  self.device = dist.device
31
  logger.info(f"Using {self.device} device")
32
 
33
- params = {}
34
-
35
  start = time.time()
36
  self.trainloader, self.valloader, self.testloader = get_dataset(cfg, self.device)
37
  print(f"total time loading dataset: {time.time() - start:.2f} seconds")
@@ -42,20 +41,10 @@ class MGNTrainer:
42
  else:
43
  self.dtype = torch.float32
44
 
45
- node_features = list(cfg.root_dataset.features.values())[0]
46
- edge_features = ["dR", "deta", "dphi"]
47
- global_features = ["num_nodes"]
48
-
49
- params["infeat_nodes"] = len(node_features)
50
- params["infeat_edges"] = len(edge_features)
51
- params["infeat_globals"] = len(global_features)
52
- params["out_dim"] = cfg.architecture.out_dim
53
- params["node_features"] = list(node_features)
54
- params["edge_features"] = edge_features
55
- params["global_features"] = global_features
56
-
57
- self.model = MeshGraphNet.MeshGraphNet(cfg.architecture)
58
  self.model = self.model.to(dtype=self.dtype, device=self.device)
 
 
59
 
60
  if cfg.performance.jit:
61
  self.model = torch.jit.script(self.model).to(self.device)
@@ -81,7 +70,6 @@ class MGNTrainer:
81
  device=self.device,
82
  )
83
 
84
- self.params = params
85
  self.cfg = cfg
86
 
87
  def backward(self, loss):
@@ -244,9 +232,6 @@ def do_training(cfg: DictConfig):
244
  )
245
  start = time.time()
246
  trainer.scheduler.step()
247
-
248
- with open(cfg.checkpoints.ckpt_path + "/parameters.json", "w") as outf:
249
- json.dump(trainer.params, outf, indent=4)
250
  logger.info("Training completed!")
251
 
252
 
 
23
  from dataset.Dataset import get_dataset
24
  import metrics
25
 
26
+ import utils
27
 
28
  class MGNTrainer:
29
  def __init__(self, logger, cfg, dist):
 
31
  self.device = dist.device
32
  logger.info(f"Using {self.device} device")
33
 
 
 
34
  start = time.time()
35
  self.trainloader, self.valloader, self.testloader = get_dataset(cfg, self.device)
36
  print(f"total time loading dataset: {time.time() - start:.2f} seconds")
 
41
  else:
42
  self.dtype = torch.float32
43
 
44
+ self.model = utils.build_from_module(cfg.architecture)
 
 
 
 
 
 
 
 
 
 
 
 
45
  self.model = self.model.to(dtype=self.dtype, device=self.device)
46
+ # num_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
47
+ # print(f"Number of trainable parameters: {num_params}")
48
 
49
  if cfg.performance.jit:
50
  self.model = torch.jit.script(self.model).to(self.device)
 
70
  device=self.device,
71
  )
72
 
 
73
  self.cfg = cfg
74
 
75
  def backward(self, loss):
 
232
  )
233
  start = time.time()
234
  trainer.scheduler.step()
 
 
 
235
  logger.info("Training completed!")
236
 
237
 
physicsnemo/utils.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ from types import SimpleNamespace
3
+
4
def build_from_module(cfg):
    """
    Instantiate a model from a config mapping.

    ``cfg`` must provide 'module' (importable module path), 'class'
    (attribute name inside that module), and 'args' (mapping of constructor
    options, wrapped in a SimpleNamespace before being passed through).
    Returns the constructed object.
    """
    target = getattr(importlib.import_module(cfg['module']), cfg['class'])
    return target(SimpleNamespace(**cfg['args']))