| import torch |
| import torch.nn as nn |
| import dgl |
|
|
| from models import utils |
|
|
| |
| from physicsnemo.models.meshgraphnet import MeshGraphNet as PhysicsNemoMeshGraphNet |
|
|
class MeshGraphNet(nn.Module):
    """Graph-level predictor built on PhysicsNeMo's MeshGraphNet.

    Each graph's prediction is built in four steps:

    1. PhysicsNeMo's MeshGraphNet produces a per-node output.
    2. The node outputs and the input edge features are mean-pooled per graph.
    3. Both pooled vectors are concatenated with an embedding of the
       per-graph global features.
    4. A single linear layer maps the concatenation to the prediction.

    Args:
        cfg: Config object exposing:
            - ``base_gnn``: a mapping of keyword arguments for
              ``physicsnemo.models.meshgraphnet.MeshGraphNet``. It must
              contain ``output_dim`` and ``input_dim_edges``.
            - ``global_feat_dim``
            - ``global_emb_dim``
            - ``out_dim``
    """

    def __init__(self, cfg):
        super().__init__()
        base_gnn_cfg = cfg.base_gnn
        # Message-passing backbone producing per-node outputs of size output_dim.
        self.base_gnn = PhysicsNemoMeshGraphNet(**base_gnn_cfg)

        # Embeds per-graph global features: Linear(global_feat_dim -> global_emb_dim) + ReLU.
        self.global_mlp = nn.Sequential(
            nn.Linear(cfg.global_feat_dim, cfg.global_emb_dim),
            nn.ReLU(),
        )

        # Readout head. Its input is the concatenation
        # [pooled node output | pooled raw edge features | global embedding],
        # in the same order that forward() concatenates them.
        self.mlp = nn.Linear(
            base_gnn_cfg['output_dim'] + base_gnn_cfg['input_dim_edges'] + cfg.global_emb_dim,
            cfg.out_dim,
        )

    def forward(self, node_feats, edge_feats, global_feats, batched_graph, metadata=None):
        """Predict one output vector per graph in the batch.

        Args:
            node_feats: [total_num_nodes, node_feat_dim]
            edge_feats: [total_num_edges, edge_feat_dim]
            global_feats: [num_graphs, global_feat_dim]. A 3-D tensor is
                flattened to [-1, global_feat_dim].
            batched_graph: DGLGraph representing the collection of graphs
                in a batch. Its stored node/edge features are not modified.
            metadata: optional dict. It may contain 'batch_num_nodes' and
                'batch_num_edges', which are forwarded to the pooling
                helpers as ``node_split`` / ``edge_split``. Defaults to an
                empty dict.

        Returns:
            graph_pred: [num_graphs, out_dim]
        """
        # Avoid a shared mutable default; a fresh dict is used per call.
        if metadata is None:
            metadata = {}

        node_pred = self.base_gnn(node_feats, edge_feats, batched_graph)

        # Attach the temporary 'h'/'e' fields inside a local scope. On exit
        # they are discarded, so the caller's existing features are not
        # overwritten and the autograd-carrying node_pred is not left
        # attached to the caller's graph.
        with batched_graph.local_scope():
            batched_graph.ndata['h'] = node_pred
            batched_graph.edata['e'] = edge_feats

            graph_node_feat = utils.mean_nodes(
                batched_graph, 'h', node_split=metadata.get("batch_num_nodes", None)
            )
            graph_edge_feat = utils.mean_edges(
                batched_graph, 'e', edge_split=metadata.get("batch_num_edges", None)
            )

        if global_feats.ndim == 3:
            # Collapse leading dims. reshape (unlike view) also handles
            # non-contiguous inputs.
            global_feats = global_feats.reshape(-1, global_feats.shape[-1])
        global_emb = self.global_mlp(global_feats)

        combined_feat = torch.cat([graph_node_feat, graph_edge_feat, global_emb], dim=-1)
        graph_pred = self.mlp(combined_feat)
        return graph_pred
| |
|
|