ho22joshua committed on
Commit
a10ecc5
·
1 Parent(s): f95cb56

working physicsnemo training script

Browse files
physicsnemo/Dataset.py CHANGED
@@ -42,7 +42,7 @@ def make_graph(node_features: np.array, dtype=torch.float32):
42
  dphi = phi[src] - phi[dst]
43
  dphi = torch.remainder(dphi + np.pi, 2 * np.pi) - np.pi
44
  dR = torch.sqrt(deta ** 2 + dphi ** 2)
45
- edge_features = torch.stack([deta, dphi, dR], dim=1)
46
  g.edata['features'] = edge_features
47
 
48
  g.globals = torch.tensor([num_nodes], dtype=dtype)
@@ -102,7 +102,7 @@ def process_chunk(args):
102
  node_features = np.empty((0, len(features)))
103
  graphs.append(make_graph(node_features, dtype=dtype))
104
 
105
- labels = torch.full((len(graphs),), label, dtype=torch.long)
106
  dgl.save_graphs(f"{save_path}/{name}_{chunk_id:02d}.bin", graphs, {'label': labels})
107
  print(f"Saved {name} chunk {chunk_id:02d} to {save_path}/{name}_{chunk_id:03d}.bin")
108
  return
@@ -151,8 +151,7 @@ class Root_Graph:
151
  num_entries = tree.num_entries
152
 
153
  print(f"Getting branches: {branches}")
154
- graphs = []
155
-
156
  step_size = math.ceil(num_entries / self.chunks)
157
 
158
  # Prepare chunk arguments for each chunk
@@ -241,6 +240,8 @@ class Root_Graph:
241
  val_labels = val_label_dict['label']
242
  test_labels = test_label_dict['label']
243
 
 
 
244
  return train_graphs, train_labels, val_graphs, val_labels, test_graphs, test_labels
245
 
246
  class GraphDataset(Dataset):
@@ -284,8 +285,9 @@ def get_dataset(cfg: DictConfig):
284
 
285
  batch_size = cfg.root_graph.batch_size
286
 
287
- train_loader = GraphDataLoader(train_dataset, {'batch_size' : batch_size, 'shuffle' : True})
288
- val_loader = GraphDataLoader(val_dataset, {'batch_size' : batch_size, 'shuffle' : False})
289
- test_loader = GraphDataLoader(test_dataset, {'batch_size' : batch_size, 'shuffle' : False})
290
 
 
291
  return train_loader, val_loader, test_loader
 
42
  dphi = phi[src] - phi[dst]
43
  dphi = torch.remainder(dphi + np.pi, 2 * np.pi) - np.pi
44
  dR = torch.sqrt(deta ** 2 + dphi ** 2)
45
+ edge_features = torch.stack([dR, deta, dphi], dim=1)
46
  g.edata['features'] = edge_features
47
 
48
  g.globals = torch.tensor([num_nodes], dtype=dtype)
 
102
  node_features = np.empty((0, len(features)))
103
  graphs.append(make_graph(node_features, dtype=dtype))
104
 
105
+ labels = torch.full((len(graphs),), label, dtype=dtype)
106
  dgl.save_graphs(f"{save_path}/{name}_{chunk_id:02d}.bin", graphs, {'label': labels})
107
  print(f"Saved {name} chunk {chunk_id:02d} to {save_path}/{name}_{chunk_id:03d}.bin")
108
  return
 
151
  num_entries = tree.num_entries
152
 
153
  print(f"Getting branches: {branches}")
154
+
 
155
  step_size = math.ceil(num_entries / self.chunks)
156
 
157
  # Prepare chunk arguments for each chunk
 
240
  val_labels = val_label_dict['label']
241
  test_labels = test_label_dict['label']
242
 
243
+ print(f"successfully loaded {self.name}")
244
+
245
  return train_graphs, train_labels, val_graphs, val_labels, test_graphs, test_labels
246
 
247
  class GraphDataset(Dataset):
 
285
 
286
  batch_size = cfg.root_graph.batch_size
287
 
288
+ train_loader = GraphDataLoader(train_dataset, batch_size=batch_size, shuffle=True)
289
+ val_loader = GraphDataLoader(val_dataset, batch_size=batch_size, shuffle=False)
290
+ test_loader = GraphDataLoader(test_dataset, batch_size=batch_size, shuffle=False)
291
 
292
+ print("all data loaded successfully")
293
  return train_loader, val_loader, test_loader
physicsnemo/MeshGraphNet.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
import torch.nn as nn
import dgl

# Import the PhysicsNemo MeshGraphNet model
from physicsnemo.models.meshgraphnet import MeshGraphNet as PhysicsNemoMeshGraphNet


class MeshGraphNet(nn.Module):
    """Graph-level classification head on top of PhysicsNemo's MeshGraphNet.

    The base GNN produces per-node features; these are mean-pooled per graph
    and mapped through a linear layer to ``out_dim`` logits per graph.
    """

    def __init__(self, *args, out_dim=1, **kwargs):
        super().__init__()
        # Initialize the PhysicsNemo MeshGraphNet
        self.base_gnn = PhysicsNemoMeshGraphNet(*args, **kwargs)
        # The base model's per-node output width is its `output_dim` argument
        # (3rd positional / keyword), NOT the decoder hidden width.  The old
        # lookup of `hidden_dim_node_decoder` only worked when both happened
        # to be equal; fall back to it only when `output_dim` is unknown.
        if len(args) >= 3:
            node_output_dim = args[2]
        else:
            node_output_dim = kwargs.get(
                "output_dim", kwargs.get("hidden_dim_node_decoder", 64)
            )
        self.mlp = nn.Linear(node_output_dim, out_dim)

    def forward(self, node_feats, edge_feats, batched_graph):
        """
        Args:
            node_feats: [total_num_nodes, node_feat_dim]
            edge_feats: [total_num_edges, edge_feat_dim]
            batched_graph: DGLGraph, batched graphs
        Returns:
            graph_pred: [num_graphs, out_dim]
        """
        node_pred = self.base_gnn(node_feats, edge_feats, batched_graph)
        # Stash node outputs on the graph so DGL's readout can pool them.
        batched_graph.ndata['h'] = node_pred
        graph_feat = dgl.readout_nodes(batched_graph, 'h', op='mean')  # [num_graphs, node_output_dim]
        graph_pred = self.mlp(graph_feat)  # [num_graphs, out_dim]
        return graph_pred
physicsnemo/config.yaml CHANGED
@@ -20,16 +20,8 @@ scheduler:
20
  lr_decay: 1.E-3
21
 
22
  training:
23
- batch_size: 100
24
  epochs: 100
25
- geometries: "healthy"
26
- stride: 5
27
- rate_noise: 100
28
- train_test_split: 0.9
29
- loss_weight_1st_timestep: 1
30
- loss_weight_other_timesteps: 0.5
31
- loss_weight_boundary_nodes: 100
32
-
33
  checkpoints:
34
  ckpt_path: "checkpoints"
35
  ckpt_name: "model.pt"
@@ -47,6 +39,7 @@ architecture:
47
  hidden_dim_edge_encoder: 64
48
  hidden_dim_processor: 64
49
  hidden_dim_node_decoder: 64
 
50
 
51
  paths:
52
  data_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/stats_100K
 
20
  lr_decay: 1.E-3
21
 
22
  training:
 
23
  epochs: 100
24
+
 
 
 
 
 
 
 
25
  checkpoints:
26
  ckpt_path: "checkpoints"
27
  ckpt_name: "model.pt"
 
39
  hidden_dim_edge_encoder: 64
40
  hidden_dim_processor: 64
41
  hidden_dim_node_decoder: 64
42
+ out_dim: 1
43
 
44
  paths:
45
  data_dir: /global/cfs/projectdirs/trn007/lbl_atlas/data/stats_100K
physicsnemo/setup/Dockerfile ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Training image for the GNN4Colliders PhysicsNemo workflow.
# Based on NVIDIA's PhysicsNemo container; adds MPI, compilers, and
# the Python tooling the training/analysis scripts need.
FROM nvcr.io/nvidia/physicsnemo/physicsnemo:25.06

WORKDIR /global/cfs/projectdirs/atlas/joshua/GNN4Colliders

LABEL maintainer.name="Joshua Ho"
LABEL maintainer.email="ho22joshua@berkeley.edu"

ENV LANG=C.UTF-8

# Install system dependencies: vim, OpenMPI, and build tools.
# apt lists are removed in the same layer to keep the image small.
RUN apt-get update -qq \
    && apt-get install -y --no-install-recommends \
    wget lsb-release gnupg software-properties-common \
    vim \
    g++-11 gcc-11 libstdc++-11-dev \
    openmpi-bin openmpi-common libopenmpi-dev \
    && rm -rf /var/lib/apt/lists/*

# Install Python packages: mpi4py (MPI bindings), jupyter (notebooks),
# and uproot (reading ROOT files).
RUN pip install --no-cache-dir mpi4py jupyter uproot

# (Optional) Expose Jupyter port
EXPOSE 8888
physicsnemo/setup/build_image.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
#!/usr/bin/env bash
# Build and migrate the joshuaho/nemo container image with podman-hpc.
# Usage: ./build_image.sh <tag>
set -euo pipefail  # stop immediately if the build fails; reject a missing tag

tag=$1
echo "$tag"
podman-hpc build -t "joshuaho/nemo:$tag" --platform linux/amd64 .
podman-hpc migrate "joshuaho/nemo:$tag"
physicsnemo/train.py CHANGED
@@ -16,19 +16,212 @@ from physicsnemo.distributed.manager import DistributedManager
16
  from Dataset import get_dataset
17
  import json
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  class MGNTrainer:
20
  def __init__(self, logger, cfg, dist):
21
  # set device
22
  self.device = dist.device
23
  logger.info(f"Using {self.device} device")
24
 
 
 
25
  norm_type = {"features": "normal", "labels": "normal"}
26
 
27
- self.train_loader, self.val_loader, self.test_loader = get_dataset(cfg)
28
- print(f"train: {self.train_loader}")
29
- print(f"val: {self.val_loader}")
30
- print(f"test: {self.test_loader}")
 
 
 
 
 
 
 
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  @hydra.main(version_base=None, config_path=".", config_name="config")
34
  def do_training(cfg: DictConfig):
@@ -55,11 +248,19 @@ def do_training(cfg: DictConfig):
55
  start = time.time()
56
  logger.info("Training started...")
57
  for epoch in range(trainer.epoch_init, cfg.training.epochs):
58
- for graph in trainer.dataloader:
59
- loss = trainer.train(graph)
 
 
 
 
 
 
 
 
60
 
61
  logger.info(
62
- f"epoch: {epoch}, loss: {loss:10.3e}, time per epoch: {(time.time()-start):10.3e}"
63
  )
64
 
65
  # save checkpoint
 
16
  from Dataset import get_dataset
17
  import json
18
 
19
+ from sklearn.metrics import roc_auc_score
20
+
21
+ import MeshGraphNet
22
+
23
+ import torch.nn.functional as F
24
+
25
def weighted_bce(input, target, device=None, weights=None):
    """
    Weighted, per-label-normalized binary cross entropy with logits.

    The element-wise BCE loss is computed once, then weight-averaged
    separately within each distinct label value found in ``target``.
    The returned scalar is the plain mean of those per-label averages,
    so every class contributes equally regardless of its sample count.

    Args:
        input (Tensor): Predicted logits of shape (N, ...).
        target (Tensor): Ground truth labels of shape (N, ...), with
            discrete label values. A trailing singleton dimension on
            either tensor (but not both) is squeezed away.
        device (torch.device or None): Device to move tensors to (optional).
        weights (Tensor or None): Optional per-sample weights, same shape
            as input/target; defaults to all-ones.

    Returns:
        Tensor: Scalar tensor representing the normalized weighted BCE loss.
    """
    # Reconcile a trailing singleton dimension on exactly one side,
    # e.g. model logits of shape (N, 1) against labels of shape (N,).
    if input.shape != target.shape:
        if input.shape[-1] == 1 and input.shape[:-1] == target.shape:
            input = input.squeeze(-1)
        elif target.shape[-1] == 1 and target.shape[:-1] == input.shape:
            target = target.squeeze(-1)

    weights = torch.ones_like(target) if weights is None else weights

    if device is not None:
        input, target, weights = (t.to(device) for t in (input, target, weights))

    # Per-element loss; the reduction is done by hand, per label, below.
    elementwise = F.binary_cross_entropy_with_logits(input, target, reduction='none')

    per_label = []
    for lbl in torch.unique(target):
        mask = (target == lbl)
        lbl_weights = weights[mask]
        total_weight = lbl_weights.sum()
        # Skip labels whose total weight is zero to avoid dividing by zero.
        if total_weight > 0:
            per_label.append((lbl_weights * elementwise[mask]).sum() / total_weight)

    if not per_label:
        return torch.tensor(0.0, device=input.device)
    return torch.stack(per_label).mean()
76
+
77
class MGNTrainer:
    """Trainer for graph-level classification with a MeshGraphNet.

    Builds the model from the Hydra config, restores the latest checkpoint
    if one exists, and exposes `train` (one optimization step per batched
    graph) and `eval` (loss + ROC AUC over the validation loader).
    """

    def __init__(self, logger, cfg, dist):
        # set device from the distributed manager
        self.device = dist.device
        logger.info(f"Using {self.device} device")

        params = {}

        self.dataloader, self.valloader, self.testloader = get_dataset(cfg)

        # Resolve the training dtype from a config string such as "torch.float32".
        dtype_str = getattr(cfg.root_graph, "type", "torch.float32")
        if isinstance(dtype_str, str) and dtype_str.startswith("torch."):
            self.dtype = getattr(torch, dtype_str.split(".")[-1], torch.float32)
        else:
            self.dtype = torch.float32

        nodes_features = cfg.root_graph.features
        # Edge features produced in Dataset.make_graph, in this order.
        edges_features = ["dR", "deta", "dphi"]
        global_features = ["num_nodes"]

        params["infeat_nodes"] = len(nodes_features)
        params["infeat_edges"] = len(edges_features)
        params["infeat_globals"] = len(global_features)
        # Width of the base GNN's per-node output (input to the readout MLP).
        # NOTE(review): taken from hidden_dim_node_encoder — presumably the
        # encoder/decoder widths are meant to match; confirm against config.
        params["out_dim"] = cfg.architecture.hidden_dim_node_encoder
        params["node_features"] = list(nodes_features)
        params["edges_features"] = edges_features
        params["global_features"] = global_features

        self.model = MeshGraphNet.MeshGraphNet(
            params["infeat_nodes"],
            params["infeat_edges"],
            params["out_dim"],
            # Fix: wire the graph-level output width from the config
            # (cfg.architecture.out_dim); falls back to the previous
            # hard default of 1 when the key is absent.
            out_dim=getattr(cfg.architecture, "out_dim", 1),
            processor_size=cfg.architecture.processor_size,
            hidden_dim_node_encoder=cfg.architecture.hidden_dim_node_encoder,
            hidden_dim_edge_encoder=cfg.architecture.hidden_dim_edge_encoder,
            hidden_dim_processor=cfg.architecture.hidden_dim_processor,
            hidden_dim_node_decoder=cfg.architecture.hidden_dim_node_decoder,
        )
        self.model = self.model.to(dtype=self.dtype, device=self.device)

        if cfg.performance.jit:
            self.model = torch.jit.script(self.model).to(self.device)
        else:
            self.model = self.model.to(self.device)

        # instantiate optimizer, LR scheduler, and AMP gradient scaler
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=cfg.scheduler.lr)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=cfg.training.epochs,
            eta_min=cfg.scheduler.lr * cfg.scheduler.lr_decay,
        )
        self.scaler = GradScaler(self.device)

        # resume from the latest checkpoint, if any; returns the next epoch
        self.epoch_init = load_checkpoint(
            os.path.join(cfg.checkpoints.ckpt_path, cfg.checkpoints.ckpt_name),
            models=self.model,
            optimizer=self.optimizer,
            scheduler=self.scheduler,
            scaler=self.scaler,
            device=self.device,
        )

        self.params = params
        self.cfg = cfg

    def backward(self, loss):
        """
        Perform the backward pass and optimizer step.

        Uses the AMP gradient scaler when cfg.performance.amp is enabled.

        Arguments:
            loss: loss value.
        """
        if self.cfg.performance.amp:
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            loss.backward()
            self.optimizer.step()

    def train(self, graph, label):
        """
        Perform one training iteration over one batched graph.

        Arguments:
            graph: the batched DGL graph.
            label: the per-graph target labels.

        Returns:
            loss: loss value (scalar Tensor).
        """
        graph = graph.to(self.device)
        self.optimizer.zero_grad()
        pred = self.model(graph.ndata["features"], graph.edata["features"], graph)
        # weighted_bce moves `label` to self.device internally.
        loss = weighted_bce(pred, label, device=self.device)
        self.backward(loss)
        return loss

    @torch.no_grad()
    def eval(self):
        """
        Evaluate the model over the entire validation loader.

        Returns:
            loss (Tensor): normalized weighted BCE over all validation samples.
            auc (float): ROC AUC of the sigmoid probabilities, or NaN when
                fewer than two classes are present.
        """
        # Fix: switch dropout/normalization layers to inference behavior.
        # (The training loop re-enables train mode before each step.)
        self.model.eval()

        predictions = []
        labels = []
        for graph, label in self.valloader:
            graph = graph.to(self.device)
            pred = self.model(graph.ndata["features"], graph.edata["features"], graph)
            predictions.append(pred)
            labels.append(label)

        predictions = torch.cat(predictions, dim=0)
        labels = torch.cat(labels, dim=0)

        loss = weighted_bce(predictions, labels, device=self.device)

        # Convert logits to probabilities for the AUC computation.
        prob = torch.sigmoid(predictions)

        # Flatten to 1D float arrays for scikit-learn.
        prob_flat = prob.detach().to(torch.float32).cpu().numpy().flatten()
        labels_flat = labels.detach().to(torch.float32).cpu().numpy().flatten()

        try:
            auc = roc_auc_score(labels_flat, prob_flat)
        except ValueError:
            auc = float('nan')  # Not enough classes present for AUC
        return loss, auc
225
 
226
  @hydra.main(version_base=None, config_path=".", config_name="config")
227
  def do_training(cfg: DictConfig):
 
248
  start = time.time()
249
  logger.info("Training started...")
250
  for epoch in range(trainer.epoch_init, cfg.training.epochs):
251
+
252
+ # Training
253
+ train_loss = []
254
+ for graph, label in trainer.dataloader:
255
+ trainer.model.train()
256
+ train_loss.append(trainer.train(graph, label))
257
+
258
+ val_loss, val_auc = trainer.eval()
259
+
260
+ train_loss = torch.mean(torch.stack(train_loss)).item()
261
 
262
  logger.info(
263
+ f"epoch: {epoch}, loss: {train_loss:10.3e}, val_loss: {val_loss:10.3e}, val_auc = {val_auc:10.3e}, time per epoch: {(time.time()-start):10.3e}"
264
  )
265
 
266
  # save checkpoint