import torch
import torch.nn as nn
import dgl
from models import utils
# Import the PhysicsNemo MeshGraphNet model
from physicsnemo.models.meshgraphnet import MeshGraphNet as PhysicsNemoMeshGraphNet
class MeshGraphNet(nn.Module):
    """Graph-level prediction model built on a PhysicsNemo MeshGraphNet backbone.

    Node-level predictions from the backbone are mean-pooled per graph,
    concatenated with mean-pooled edge features and an embedded global
    feature vector, and projected to the output dimension.
    """

    def __init__(self, cfg):
        """Build the model from a config object.

        cfg must provide:
            cfg.base_gnn: kwargs dict for PhysicsNemoMeshGraphNet; must contain
                at least 'output_dim' and 'input_dim_edges'.
            cfg.global_feat_dim: input dimension of the per-graph global features.
            cfg.global_emb_dim: embedding dimension for the global features.
            cfg.out_dim: final per-graph output dimension.
        """
        super().__init__()
        base_gnn_cfg = cfg.base_gnn
        # Node-level GNN backbone (PhysicsNemo MeshGraphNet).
        self.base_gnn = PhysicsNemoMeshGraphNet(**base_gnn_cfg)
        # Embeds the raw per-graph global features before fusion.
        self.global_mlp = nn.Sequential(
            nn.Linear(cfg.global_feat_dim, cfg.global_emb_dim),
            nn.ReLU(),
        )
        # Final projection over [pooled node preds | pooled edge feats | global emb].
        self.mlp = nn.Linear(
            base_gnn_cfg['output_dim'] + base_gnn_cfg['input_dim_edges'] + cfg.global_emb_dim,
            cfg.out_dim,
        )

    def forward(self, node_feats, edge_feats, global_feats, batched_graph, metadata=None):
        """Run the model on a batch of graphs.

        Args:
            node_feats: [total_num_nodes, node_feat_dim]
            edge_feats: [total_num_edges, edge_feat_dim]
            global_feats: [num_graphs, global_feat_dim] (a 3-D input is
                flattened on its last dim — presumably [num_graphs, 1, global_feat_dim];
                TODO confirm against callers)
            batched_graph: DGLGraph representing the batched collection of graphs
            metadata: optional dict; may contain 'batch_num_nodes' and
                'batch_num_edges' used to split the pooling per graph.

        Returns:
            graph_pred: [num_graphs, out_dim]
        """
        # Fix: the original used a mutable default argument (metadata={}),
        # a Python pitfall — use a None sentinel instead (behavior-identical
        # for all existing callers, since metadata is only read).
        if metadata is None:
            metadata = {}
        node_pred = self.base_gnn(node_feats, edge_feats, batched_graph)
        # Stash per-node/per-edge tensors on the graph so the pooling
        # helpers can read them by key.
        batched_graph.ndata['h'] = node_pred
        batched_graph.edata['e'] = edge_feats
        graph_node_feat = utils.mean_nodes(batched_graph, 'h', node_split=metadata.get("batch_num_nodes", None))
        graph_edge_feat = utils.mean_edges(batched_graph, 'e', edge_split=metadata.get("batch_num_edges", None))
        # Flatten global_feats if it arrives with an extra middle dimension.
        if global_feats.ndim == 3:
            global_feats = global_feats.view(-1, global_feats.shape[-1])
        global_emb = self.global_mlp(global_feats)  # [num_graphs, global_emb_dim]
        combined_feat = torch.cat([graph_node_feat, graph_edge_feat, global_emb], dim=-1)
        graph_pred = self.mlp(combined_feat)
        return graph_pred
|