Commit ·
d646e7f
1
Parent(s): 5ceead6
adding edge network
Browse files
physicsnemo/configs/tHjb_CP_0_vs_90.yaml
CHANGED
|
@@ -31,13 +31,21 @@ performance:
|
|
| 31 |
jit: False
|
| 32 |
|
| 33 |
architecture:
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
paths:
|
| 43 |
data_dir: /global/cfs/projectdirs/atlas/joshua/ttHCP/ntuples/v02/preselection/merged_fixed/train/
|
|
|
|
| 31 |
jit: False
|
| 32 |
|
| 33 |
architecture:
|
| 34 |
+
module: models.MeshGraphNet
|
| 35 |
+
class: MeshGraphNet
|
| 36 |
+
args:
|
| 37 |
+
base_gnn:
|
| 38 |
+
input_dim_nodes: 7
|
| 39 |
+
input_dim_edges: 3
|
| 40 |
+
output_dim: 128
|
| 41 |
+
processor_size: 8
|
| 42 |
+
hidden_dim_node_encoder: 128
|
| 43 |
+
hidden_dim_edge_encoder: 128
|
| 44 |
+
hidden_dim_processor: 128
|
| 45 |
+
hidden_dim_node_decoder: 128
|
| 46 |
+
global_emb_dim: 128
|
| 47 |
+
global_feat_dim: 1
|
| 48 |
+
out_dim: 1
|
| 49 |
|
| 50 |
paths:
|
| 51 |
data_dir: /global/cfs/projectdirs/atlas/joshua/ttHCP/ntuples/v02/preselection/merged_fixed/train/
|
physicsnemo/configs/tHjb_CP_0_vs_90_edge_network.yaml
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ignore_header_test
|
| 2 |
+
# Copyright 2023 Stanford University
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
random_seed: 2
|
| 17 |
+
|
| 18 |
+
scheduler:
|
| 19 |
+
lr: 1.E-3
|
| 20 |
+
lr_decay: 1.E-3
|
| 21 |
+
|
| 22 |
+
training:
|
| 23 |
+
epochs: 100
|
| 24 |
+
|
| 25 |
+
checkpoints:
|
| 26 |
+
ckpt_path: "checkpoints"
|
| 27 |
+
ckpt_name: "tHjb_CP_0_vs_90_edge_network"
|
| 28 |
+
|
| 29 |
+
performance:
|
| 30 |
+
amp: False
|
| 31 |
+
jit: False
|
| 32 |
+
|
| 33 |
+
architecture:
|
| 34 |
+
module: models.Edge_Network
|
| 35 |
+
class: Edge_Network
|
| 36 |
+
args:
|
| 37 |
+
input_dim_nodes: 7
|
| 38 |
+
input_dim_edges: 3
|
| 39 |
+
input_dim_globals: 1
|
| 40 |
+
hid_size: 64
|
| 41 |
+
n_layers: 4
|
| 42 |
+
n_proc_steps: 4
|
| 43 |
+
out_dim: 1
|
| 44 |
+
|
| 45 |
+
paths:
|
| 46 |
+
data_dir: /global/cfs/projectdirs/atlas/joshua/ttHCP/ntuples/v02/preselection/merged_fixed/train/
|
| 47 |
+
save_dir: /pscratch/sd/j/joshuaho/physicsnemo/ttHCP/graphs/tHjb_CP_0_vs_90/
|
| 48 |
+
training_dir: ./tHjb_CP_0_vs_90_edge_network/
|
| 49 |
+
|
| 50 |
+
datasets:
|
| 51 |
+
- name: tHjb_cp_0_had
|
| 52 |
+
load_path: ${paths.data_dir}/merged_aMCPy8_tHjb125_CP_0_AF3_had_scaled.root
|
| 53 |
+
label: 0
|
| 54 |
+
- name: tHjb_cp_0_lep
|
| 55 |
+
load_path: ${paths.data_dir}/merged_aMCPy8_tHjb125_CP_0_AF3_lep_scaled.root
|
| 56 |
+
label: 0
|
| 57 |
+
- name: tHjb_cp_90_had
|
| 58 |
+
load_path: ${paths.data_dir}/merged_aMCPy8_tHjb125_CP_90_AF3_had_scaled.root
|
| 59 |
+
label: 1
|
| 60 |
+
- name: tHjb_cp_90_lep
|
| 61 |
+
load_path: ${paths.data_dir}/merged_aMCPy8_tHjb125_CP_90_AF3_lep_scaled.root
|
| 62 |
+
label: 1
|
| 63 |
+
|
| 64 |
+
root_dataset:
|
| 65 |
+
ttree: output
|
| 66 |
+
dtype: torch.bfloat16
|
| 67 |
+
features:
|
| 68 |
+
# pt, eta, phi, energy, btag, charge, node_type
|
| 69 |
+
jet: [m_jet_pt, m_jet_eta, m_jet_phi, CALC_E, m_jet_PCbtag, 0, 0]
|
| 70 |
+
electron: [m_el_pt, m_el_eta, m_el_phi, CALC_E, 0, m_el_charge, 1]
|
| 71 |
+
muon: [m_mu_pt, m_mu_eta, m_mu_phi, CALC_E, 0, m_mu_charge, 2]
|
| 72 |
+
photon: [ph_pt_myy, ph_eta, ph_phi, CALC_E, 0, 0, 3]
|
| 73 |
+
met: [m_met, 0, m_met_phi, CALC_E, 0, 0, 4]
|
| 74 |
+
globals: [NUM_NODES]
|
| 75 |
+
weights: 1
|
| 76 |
+
tracking: []
|
| 77 |
+
step_size: 16384
|
| 78 |
+
batch_size: 16384
|
| 79 |
+
train_val_test_split: [0.5, 0.25, 0.25]
|
| 80 |
+
prebatch:
|
| 81 |
+
enabled: True
|
| 82 |
+
chunk_size: 512
|
physicsnemo/models/Edge_Network.py
CHANGED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import dgl
|
| 4 |
+
|
| 5 |
+
from models import utils
|
| 6 |
+
|
| 7 |
+
class Edge_Network(nn.Module):
    """Graph network with node, edge, and global blocks (Battaglia-style GN).

    Encodes node/edge/global features into a shared hidden size, runs
    ``cfg.n_proc_steps`` rounds of edge -> node -> global updates, then
    decodes the global state into ``cfg.out_dim`` logits per graph.

    Expected cfg attributes: hid_size, n_layers, n_proc_steps,
    input_dim_nodes, input_dim_edges, input_dim_globals, out_dim.
    """

    def __init__(self, cfg):
        super().__init__()
        hid_size = cfg.hid_size
        n_layers = cfg.n_layers
        self.n_proc_steps = cfg.n_proc_steps

        # encoders: project raw features into the shared hidden dimension
        self.node_encoder = utils.Make_MLP(cfg.input_dim_nodes, hid_size, hid_size, n_layers)
        self.edge_encoder = utils.Make_MLP(cfg.input_dim_edges, hid_size, hid_size, n_layers)
        self.global_encoder = utils.Make_MLP(cfg.input_dim_globals, hid_size, hid_size, n_layers)

        # GNN update blocks; input widths follow the concatenations in forward():
        # node update sees [h, aggregated edges, global] -> 3*hid_size
        # edge update sees [e, src h, dst h, global]     -> 4*hid_size
        # global update sees [global, mean nodes, mean edges] -> 3*hid_size
        self.node_update = utils.Make_MLP(3 * hid_size, hid_size, hid_size, n_layers)
        self.edge_update = utils.Make_MLP(4 * hid_size, hid_size, hid_size, n_layers)
        self.global_update = utils.Make_MLP(3 * hid_size, hid_size, hid_size, n_layers)

        # decoder / classification head
        self.global_decoder = utils.Make_MLP(hid_size, hid_size, hid_size, n_layers)
        self.classify = nn.Linear(hid_size, cfg.out_dim)

    def forward(self, node_feats, edge_feats, global_feats, batched_graph, metadata=None):
        """Run message passing on a batched DGL graph.

        Args:
            node_feats: [total_nodes, input_dim_nodes] tensor.
            edge_feats: [total_edges, input_dim_edges] tensor.
            global_feats: [num_graphs, input_dim_globals] tensor (a 3-D
                tensor is flattened on its leading dims).
            batched_graph: batched DGLGraph; 'h'/'e' features are written
                onto it as a side effect.
            metadata: optional dict; "batch_num_nodes"/"batch_num_edges"
                give per-graph counts for the split-based aggregations.

        Returns:
            [num_graphs, out_dim] prediction tensor.
        """
        # NOTE: default was a mutable dict (metadata={}); use None sentinel
        # to avoid sharing one dict instance across calls.
        if metadata is None:
            metadata = {}

        # encoders
        batched_graph.ndata['h'] = self.node_encoder(node_feats)
        batched_graph.edata['e'] = self.edge_encoder(edge_feats)

        if global_feats.ndim == 3:
            global_feats = global_feats.view(-1, global_feats.shape[-1])
        h_global = self.global_encoder(global_feats)

        # message passing
        for _ in range(self.n_proc_steps):
            # expose endpoint node states on every edge
            batched_graph.apply_edges(dgl.function.copy_u('h', 'm_u'))
            batched_graph.apply_edges(utils.copy_v)

            # edge update
            edge_inputs = torch.cat([
                batched_graph.edata['e'],
                batched_graph.edata['m_u'],
                batched_graph.edata['m_v'],
                utils.broadcast_global_to_edges(h_global, edge_split=metadata.get("batch_num_edges", None))
            ], dim=1)
            batched_graph.edata['e'] = self.edge_update(edge_inputs)

            # node update: sum incoming edge states into each node
            batched_graph.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            node_inputs = torch.cat([
                batched_graph.ndata['h'],
                batched_graph.ndata['h_e'],
                utils.broadcast_global_to_nodes(h_global, node_split=metadata.get("batch_num_nodes", None))
            ], dim=1)
            batched_graph.ndata['h'] = self.node_update(node_inputs)

            # global update: pool nodes and edges per graph
            graph_node_feat = utils.mean_nodes(
                batched_graph, 'h', node_split=metadata.get("batch_num_nodes", None)
            )
            graph_edge_feat = utils.mean_edges(
                batched_graph, 'e', edge_split=metadata.get("batch_num_edges", None)
            )
            h_global = self.global_update(torch.cat([h_global, graph_node_feat, graph_edge_feat], dim=1))

        h_global = self.global_decoder(h_global)
        out = self.classify(h_global)
        return out
|
| 72 |
+
|
physicsnemo/models/MeshGraphNet.py
CHANGED
|
@@ -2,6 +2,8 @@ import torch
|
|
| 2 |
import torch.nn as nn
|
| 3 |
import dgl
|
| 4 |
|
|
|
|
|
|
|
| 5 |
# Import the PhysicsNemo MeshGraphNet model
|
| 6 |
from physicsnemo.models.meshgraphnet import MeshGraphNet as PhysicsNemoMeshGraphNet
|
| 7 |
|
|
@@ -35,8 +37,8 @@ class MeshGraphNet(nn.Module):
|
|
| 35 |
batched_graph.ndata['h'] = node_pred
|
| 36 |
batched_graph.edata['e'] = edge_feats
|
| 37 |
|
| 38 |
-
graph_node_feat = mean_nodes(batched_graph, 'h', node_split=metadata.get("batch_num_nodes", None))
|
| 39 |
-
graph_edge_feat = mean_edges(batched_graph, 'e', edge_split=metadata.get("batch_num_edges", None))
|
| 40 |
|
| 41 |
# Flatten global_feats if needed
|
| 42 |
if global_feats.ndim == 3:
|
|
@@ -47,82 +49,3 @@ class MeshGraphNet(nn.Module):
|
|
| 47 |
graph_pred = self.mlp(combined_feat)
|
| 48 |
return graph_pred
|
| 49 |
|
| 50 |
-
def mean_nodes(batched_graph, feat_key='h', op='mean', node_split=None):
|
| 51 |
-
"""
|
| 52 |
-
Aggregates node features per disjoint graph in a batched DGLGraph.
|
| 53 |
-
|
| 54 |
-
Args:
|
| 55 |
-
batched_graph: DGLGraph
|
| 56 |
-
feat_key: str, node feature key
|
| 57 |
-
op: 'mean', 'sum', or 'max'
|
| 58 |
-
node_split: 1D tensor or list of ints (num nodes per graph)
|
| 59 |
-
|
| 60 |
-
Returns:
|
| 61 |
-
Tensor of shape [num_graphs, node_feat_dim]
|
| 62 |
-
"""
|
| 63 |
-
h = batched_graph.ndata[feat_key]
|
| 64 |
-
if node_split is None or len(node_split) == 0:
|
| 65 |
-
if op == 'mean':
|
| 66 |
-
return dgl.mean_nodes(batched_graph, feat_key)
|
| 67 |
-
elif op == 'sum':
|
| 68 |
-
return dgl.sum_nodes(batched_graph, feat_key)
|
| 69 |
-
elif op == 'max':
|
| 70 |
-
return dgl.max_nodes(batched_graph, feat_key)
|
| 71 |
-
else:
|
| 72 |
-
raise ValueError(f"Unknown op: {op}")
|
| 73 |
-
else:
|
| 74 |
-
# Ensure node_split is a flat list of ints
|
| 75 |
-
if isinstance(node_split, torch.Tensor):
|
| 76 |
-
splits = node_split.view(-1).tolist()
|
| 77 |
-
else:
|
| 78 |
-
splits = [int(x) for x in node_split]
|
| 79 |
-
chunks = torch.split(h, splits, dim=0)
|
| 80 |
-
if op == 'mean':
|
| 81 |
-
out = torch.stack([chunk.mean(0) if chunk.shape[0] > 0 else torch.zeros_like(h[0]) for chunk in chunks])
|
| 82 |
-
elif op == 'sum':
|
| 83 |
-
out = torch.stack([chunk.sum(0) if chunk.shape[0] > 0 else torch.zeros_like(h[0]) for chunk in chunks])
|
| 84 |
-
elif op == 'max':
|
| 85 |
-
out = torch.stack([chunk.max(0).values if chunk.shape[0] > 0 else torch.zeros_like(h[0]) for chunk in chunks])
|
| 86 |
-
else:
|
| 87 |
-
raise ValueError(f"Unknown op: {op}")
|
| 88 |
-
return out
|
| 89 |
-
|
| 90 |
-
def mean_edges(batched_graph, feat_key='e', op='mean', edge_split=None):
|
| 91 |
-
"""
|
| 92 |
-
Aggregates edge features per disjoint graph in a batched DGLGraph.
|
| 93 |
-
|
| 94 |
-
Args:
|
| 95 |
-
batched_graph: DGLGraph
|
| 96 |
-
feat_key: str, edge feature key
|
| 97 |
-
op: 'mean', 'sum', or 'max'
|
| 98 |
-
edge_split: 1D tensor or list of ints (num edges per graph)
|
| 99 |
-
|
| 100 |
-
Returns:
|
| 101 |
-
Tensor of shape [num_graphs, edge_feat_dim]
|
| 102 |
-
"""
|
| 103 |
-
e = batched_graph.edata[feat_key]
|
| 104 |
-
if edge_split is None or len(edge_split) == 0:
|
| 105 |
-
if op == 'mean':
|
| 106 |
-
return dgl.mean_edges(batched_graph, feat_key)
|
| 107 |
-
elif op == 'sum':
|
| 108 |
-
return dgl.sum_edges(batched_graph, feat_key)
|
| 109 |
-
elif op == 'max':
|
| 110 |
-
return dgl.max_edges(batched_graph, feat_key)
|
| 111 |
-
else:
|
| 112 |
-
raise ValueError(f"Unknown op: {op}")
|
| 113 |
-
else:
|
| 114 |
-
# Ensure edge_split is a flat list of ints
|
| 115 |
-
if isinstance(edge_split, torch.Tensor):
|
| 116 |
-
splits = edge_split.view(-1).tolist()
|
| 117 |
-
else:
|
| 118 |
-
splits = [int(x) for x in edge_split]
|
| 119 |
-
chunks = torch.split(e, splits, dim=0)
|
| 120 |
-
if op == 'mean':
|
| 121 |
-
out = torch.stack([chunk.mean(0) if chunk.shape[0] > 0 else torch.zeros_like(e[0]) for chunk in chunks])
|
| 122 |
-
elif op == 'sum':
|
| 123 |
-
out = torch.stack([chunk.sum(0) if chunk.shape[0] > 0 else torch.zeros_like(e[0]) for chunk in chunks])
|
| 124 |
-
elif op == 'max':
|
| 125 |
-
out = torch.stack([chunk.max(0).values if chunk.shape[0] > 0 else torch.zeros_like(e[0]) for chunk in chunks])
|
| 126 |
-
else:
|
| 127 |
-
raise ValueError(f"Unknown op: {op}")
|
| 128 |
-
return out
|
|
|
|
| 2 |
import torch.nn as nn
|
| 3 |
import dgl
|
| 4 |
|
| 5 |
+
from models import utils
|
| 6 |
+
|
| 7 |
# Import the PhysicsNemo MeshGraphNet model
|
| 8 |
from physicsnemo.models.meshgraphnet import MeshGraphNet as PhysicsNemoMeshGraphNet
|
| 9 |
|
|
|
|
| 37 |
batched_graph.ndata['h'] = node_pred
|
| 38 |
batched_graph.edata['e'] = edge_feats
|
| 39 |
|
| 40 |
+
graph_node_feat = utils.mean_nodes(batched_graph, 'h', node_split=metadata.get("batch_num_nodes", None))
|
| 41 |
+
graph_edge_feat = utils.mean_edges(batched_graph, 'e', edge_split=metadata.get("batch_num_edges", None))
|
| 42 |
|
| 43 |
# Flatten global_feats if needed
|
| 44 |
if global_feats.ndim == 3:
|
|
|
|
| 49 |
graph_pred = self.mlp(combined_feat)
|
| 50 |
return graph_pred
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
physicsnemo/models/utils.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import dgl
|
| 4 |
+
|
| 5 |
+
def mean_nodes(batched_graph, feat_key='h', op='mean', node_split=None):
    """
    Aggregates node features per disjoint graph in a batched DGLGraph.

    Args:
        batched_graph: DGLGraph
        feat_key: str, node feature key
        op: 'mean', 'sum', or 'max'
        node_split: 1D tensor or list of ints (num nodes per graph)

    Returns:
        Tensor of shape [num_graphs, node_feat_dim]

    Raises:
        ValueError: if op is not one of 'mean', 'sum', 'max'.
    """
    h = batched_graph.ndata[feat_key]
    if node_split is None or len(node_split) == 0:
        # no explicit split: defer to DGL's own batch-aware readout
        if op == 'mean':
            return dgl.mean_nodes(batched_graph, feat_key)
        elif op == 'sum':
            return dgl.sum_nodes(batched_graph, feat_key)
        elif op == 'max':
            return dgl.max_nodes(batched_graph, feat_key)
        else:
            raise ValueError(f"Unknown op: {op}")

    # Ensure node_split is a flat list of ints
    if isinstance(node_split, torch.Tensor):
        splits = node_split.view(-1).tolist()
    else:
        splits = [int(x) for x in node_split]
    chunks = torch.split(h, splits, dim=0)

    # Fill value for graphs with zero nodes. Using h.new_zeros avoids the
    # IndexError that torch.zeros_like(h[0]) raises when h itself is empty,
    # and is hoisted out of the comprehension so it is built once.
    fill = h.new_zeros(h.shape[1:])
    if op == 'mean':
        out = torch.stack([chunk.mean(0) if chunk.shape[0] > 0 else fill for chunk in chunks])
    elif op == 'sum':
        out = torch.stack([chunk.sum(0) if chunk.shape[0] > 0 else fill for chunk in chunks])
    elif op == 'max':
        out = torch.stack([chunk.max(0).values if chunk.shape[0] > 0 else fill for chunk in chunks])
    else:
        raise ValueError(f"Unknown op: {op}")
    return out
|
| 44 |
+
|
| 45 |
+
def mean_edges(batched_graph, feat_key='e', op='mean', edge_split=None):
    """
    Aggregates edge features per disjoint graph in a batched DGLGraph.

    Args:
        batched_graph: DGLGraph
        feat_key: str, edge feature key
        op: 'mean', 'sum', or 'max'
        edge_split: 1D tensor or list of ints (num edges per graph)

    Returns:
        Tensor of shape [num_graphs, edge_feat_dim]

    Raises:
        ValueError: if op is not one of 'mean', 'sum', 'max'.
    """
    e = batched_graph.edata[feat_key]
    if edge_split is None or len(edge_split) == 0:
        # no explicit split: defer to DGL's own batch-aware readout
        if op == 'mean':
            return dgl.mean_edges(batched_graph, feat_key)
        elif op == 'sum':
            return dgl.sum_edges(batched_graph, feat_key)
        elif op == 'max':
            return dgl.max_edges(batched_graph, feat_key)
        else:
            raise ValueError(f"Unknown op: {op}")

    # Ensure edge_split is a flat list of ints
    if isinstance(edge_split, torch.Tensor):
        splits = edge_split.view(-1).tolist()
    else:
        splits = [int(x) for x in edge_split]
    chunks = torch.split(e, splits, dim=0)

    # Fill value for graphs with zero edges. Using e.new_zeros avoids the
    # IndexError that torch.zeros_like(e[0]) raises when e itself is empty,
    # and is hoisted out of the comprehension so it is built once.
    fill = e.new_zeros(e.shape[1:])
    if op == 'mean':
        out = torch.stack([chunk.mean(0) if chunk.shape[0] > 0 else fill for chunk in chunks])
    elif op == 'sum':
        out = torch.stack([chunk.sum(0) if chunk.shape[0] > 0 else fill for chunk in chunks])
    elif op == 'max':
        out = torch.stack([chunk.max(0).values if chunk.shape[0] > 0 else fill for chunk in chunks])
    else:
        raise ValueError(f"Unknown op: {op}")
    return out
|
| 84 |
+
|
| 85 |
+
def Make_SLP(in_size, out_size, activation=nn.ReLU, dropout=0):
    """Return one MLP stage as a list: [Linear, activation(), Dropout].

    Args:
        in_size: input feature width of the Linear layer.
        out_size: output feature width of the Linear layer.
        activation: activation module class (instantiated here).
        dropout: dropout probability for the trailing Dropout layer.
    """
    return [
        nn.Linear(in_size, out_size),
        activation(),
        nn.Dropout(dropout),
    ]
|
| 91 |
+
|
| 92 |
+
def Make_MLP(in_size, hid_size, out_size, n_layers, activation=nn.ReLU, dropout=0):
    """Build an MLP of n_layers Linear+activation+Dropout stages, ending in LayerNorm.

    With n_layers > 1 the widths run in_size -> hid_size -> ... -> out_size;
    otherwise a single in_size -> out_size stage is used. A LayerNorm over
    out_size is always appended last.
    """
    if n_layers > 1:
        widths = [in_size] + [hid_size] * (n_layers - 1) + [out_size]
    else:
        widths = [in_size, out_size]
    layers = []
    for w_in, w_out in zip(widths[:-1], widths[1:]):
        layers += Make_SLP(w_in, w_out, activation, dropout)
    layers.append(torch.nn.LayerNorm(out_size))
    return nn.Sequential(*layers)
|
| 103 |
+
|
| 104 |
+
def broadcast_global_to_nodes(globals, node_split):
    """
    globals: [num_graphs, global_dim]
    node_split: list/1D tensor of length num_graphs, number of nodes per graph
    Returns: [total_num_nodes, global_dim]
    """
    if node_split is None:
        raise ValueError("node_split must be provided")
    # normalize the per-graph counts to a flat long tensor on globals' device
    if torch.is_tensor(node_split):
        counts = node_split.to(device=globals.device, dtype=torch.long)
    else:
        counts = torch.tensor(node_split, dtype=torch.long, device=globals.device)
    return torch.repeat_interleave(globals, counts.flatten(), dim=0)
|
| 118 |
+
|
| 119 |
+
def broadcast_global_to_edges(globals, edge_split):
    """
    globals: [num_graphs, global_dim] (on CUDA or CPU)
    edge_split: list/1D tensor of length num_graphs, number of edges per graph (CPU or CUDA)
    Returns: [total_num_edges, global_dim]
    """
    if edge_split is None:
        raise ValueError("edge_split must be provided")
    # normalize the per-graph counts to a flat long tensor on globals' device
    if torch.is_tensor(edge_split):
        counts = edge_split.to(device=globals.device, dtype=torch.long)
    else:
        counts = torch.tensor(edge_split, dtype=torch.long, device=globals.device)
    return torch.repeat_interleave(globals, counts.flatten(), dim=0)
|
| 133 |
+
|
| 134 |
+
def copy_v(edges):
    """DGL edge UDF: expose each edge's destination-node feature 'h' as message 'm_v'."""
    return dict(m_v=edges.dst['h'])
|
physicsnemo/train.py
CHANGED
|
@@ -23,6 +23,7 @@ import models.MeshGraphNet as MeshGraphNet
|
|
| 23 |
from dataset.Dataset import get_dataset
|
| 24 |
import metrics
|
| 25 |
|
|
|
|
| 26 |
|
| 27 |
class MGNTrainer:
|
| 28 |
def __init__(self, logger, cfg, dist):
|
|
@@ -30,8 +31,6 @@ class MGNTrainer:
|
|
| 30 |
self.device = dist.device
|
| 31 |
logger.info(f"Using {self.device} device")
|
| 32 |
|
| 33 |
-
params = {}
|
| 34 |
-
|
| 35 |
start = time.time()
|
| 36 |
self.trainloader, self.valloader, self.testloader = get_dataset(cfg, self.device)
|
| 37 |
print(f"total time loading dataset: {time.time() - start:.2f} seconds")
|
|
@@ -42,20 +41,10 @@ class MGNTrainer:
|
|
| 42 |
else:
|
| 43 |
self.dtype = torch.float32
|
| 44 |
|
| 45 |
-
|
| 46 |
-
edge_features = ["dR", "deta", "dphi"]
|
| 47 |
-
global_features = ["num_nodes"]
|
| 48 |
-
|
| 49 |
-
params["infeat_nodes"] = len(node_features)
|
| 50 |
-
params["infeat_edges"] = len(edge_features)
|
| 51 |
-
params["infeat_globals"] = len(global_features)
|
| 52 |
-
params["out_dim"] = cfg.architecture.out_dim
|
| 53 |
-
params["node_features"] = list(node_features)
|
| 54 |
-
params["edge_features"] = edge_features
|
| 55 |
-
params["global_features"] = global_features
|
| 56 |
-
|
| 57 |
-
self.model = MeshGraphNet.MeshGraphNet(cfg.architecture)
|
| 58 |
self.model = self.model.to(dtype=self.dtype, device=self.device)
|
|
|
|
|
|
|
| 59 |
|
| 60 |
if cfg.performance.jit:
|
| 61 |
self.model = torch.jit.script(self.model).to(self.device)
|
|
@@ -81,7 +70,6 @@ class MGNTrainer:
|
|
| 81 |
device=self.device,
|
| 82 |
)
|
| 83 |
|
| 84 |
-
self.params = params
|
| 85 |
self.cfg = cfg
|
| 86 |
|
| 87 |
def backward(self, loss):
|
|
@@ -244,9 +232,6 @@ def do_training(cfg: DictConfig):
|
|
| 244 |
)
|
| 245 |
start = time.time()
|
| 246 |
trainer.scheduler.step()
|
| 247 |
-
|
| 248 |
-
with open(cfg.checkpoints.ckpt_path + "/parameters.json", "w") as outf:
|
| 249 |
-
json.dump(trainer.params, outf, indent=4)
|
| 250 |
logger.info("Training completed!")
|
| 251 |
|
| 252 |
|
|
|
|
| 23 |
from dataset.Dataset import get_dataset
|
| 24 |
import metrics
|
| 25 |
|
| 26 |
+
import utils
|
| 27 |
|
| 28 |
class MGNTrainer:
|
| 29 |
def __init__(self, logger, cfg, dist):
|
|
|
|
| 31 |
self.device = dist.device
|
| 32 |
logger.info(f"Using {self.device} device")
|
| 33 |
|
|
|
|
|
|
|
| 34 |
start = time.time()
|
| 35 |
self.trainloader, self.valloader, self.testloader = get_dataset(cfg, self.device)
|
| 36 |
print(f"total time loading dataset: {time.time() - start:.2f} seconds")
|
|
|
|
| 41 |
else:
|
| 42 |
self.dtype = torch.float32
|
| 43 |
|
| 44 |
+
self.model = utils.build_from_module(cfg.architecture)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
self.model = self.model.to(dtype=self.dtype, device=self.device)
|
| 46 |
+
# num_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
|
| 47 |
+
# print(f"Number of trainable parameters: {num_params}")
|
| 48 |
|
| 49 |
if cfg.performance.jit:
|
| 50 |
self.model = torch.jit.script(self.model).to(self.device)
|
|
|
|
| 70 |
device=self.device,
|
| 71 |
)
|
| 72 |
|
|
|
|
| 73 |
self.cfg = cfg
|
| 74 |
|
| 75 |
def backward(self, loss):
|
|
|
|
| 232 |
)
|
| 233 |
start = time.time()
|
| 234 |
trainer.scheduler.step()
|
|
|
|
|
|
|
|
|
|
| 235 |
logger.info("Training completed!")
|
| 236 |
|
| 237 |
|
physicsnemo/utils.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
from types import SimpleNamespace
|
| 3 |
+
|
| 4 |
+
def build_from_module(cfg):
    """Instantiate cfg['class'] from module cfg['module'].

    cfg['args'] (a mapping) is wrapped in a SimpleNamespace and passed as
    the single constructor argument, matching the model __init__(cfg) style.
    """
    module = importlib.import_module(cfg['module'])
    target_cls = getattr(module, cfg['class'])
    return target_cls(SimpleNamespace(**cfg['args']))
|