ArthurY's picture
update source
c3d0544
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from types import NoneType
from typing import Any, Callable, Dict, Tuple, TypeAlias, Union
import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint
from torch_geometric.data import Data as PyGData
from torch_geometric.data import HeteroData as PyGHeteroData
try:
import dgl # noqa: F401 for docs
import dgl.function as fn
from dgl import DGLGraph
except ImportError:
warnings.warn(
"Note: This only applies if you're using DGL.\n"
"MeshGraphNet (DGL version) requires the DGL library.\n"
"Install it with your preferred CUDA version from:\n"
"https://www.dgl.ai/pages/start.html\n"
)
DGLGraph: TypeAlias = NoneType
try:
import torch_scatter
except ImportError:
warnings.warn(
"MeshGraphNet will soon require PyTorch Geometric and torch_scatter.\n"
"Install it from here:\n"
"https://github.com/rusty1s/pytorch_scatter\n"
)
from physicsnemo.models.gnn_layers import CuGraphCSC
GraphType: TypeAlias = PyGData | PyGHeteroData | DGLGraph | CuGraphCSC
try:
from pylibcugraphops.pytorch.operators import (
agg_concat_e2n,
update_efeat_bipartite_e2e,
update_efeat_static_e2e,
)
USE_CUGRAPHOPS = True
except ImportError:
update_efeat_bipartite_e2e = None
update_efeat_static_e2e = None
agg_concat_e2n = None
USE_CUGRAPHOPS = False
def checkpoint_identity(layer: Callable, *args: Any, **kwargs: Any) -> Any:
    """Calls a layer directly, without activation checkpointing.

    This is the stand-in used when checkpointing is disabled. It accepts the
    same call shape as ``torch.utils.checkpoint.checkpoint`` so that call sites
    do not need to change depending on which of the two is active.

    Parameters
    ----------
    layer : Callable
        The model layer or function to apply to the input arguments.
    *args
        Positional arguments to be passed to the layer.
    **kwargs
        Keyword arguments to be passed to the layer. Options that only
        configure ``torch.utils.checkpoint.checkpoint`` itself
        (``use_reentrant``, ``preserve_rng_state``, ``context_fn``,
        ``determinism_check``, ``debug``) are discarded, because the real
        checkpoint function consumes them and never hands them to the layer.

    Returns
    -------
    Any
        The output of the specified layer after processing the input arguments.
    """
    # Keyword options owned by torch.utils.checkpoint.checkpoint; forwarding
    # them would make the layer fail on an unexpected keyword argument.
    checkpoint_only_options = (
        "use_reentrant",
        "preserve_rng_state",
        "context_fn",
        "determinism_check",
        "debug",
    )
    layer_kwargs = {
        key: value
        for key, value in kwargs.items()
        if key not in checkpoint_only_options
    }
    return layer(*args, **layer_kwargs)
def set_checkpoint_fn(do_checkpointing: bool) -> Callable:
    """Selects the function used to invoke model layers.

    Parameters
    ----------
    do_checkpointing : bool
        If truthy, layers are run through ``checkpoint`` from
        ``torch.utils.checkpoint``, which trades extra computation during
        backpropagation for lower memory usage. Otherwise layers are run
        through ``checkpoint_identity``, which simply calls them.

    Returns
    -------
    Callable
        Either ``checkpoint`` or ``checkpoint_identity``.
    """
    return checkpoint if do_checkpointing else checkpoint_identity
def concat_message_function(edges: Tensor) -> Dict[str, Tensor]:
    """Builds the per-edge concatenation of edge, source and destination features.

    Intended as the callback handed to ``graph.apply_edges``.

    Parameters
    ----------
    edges : Tensor
        Edge batch exposing ``data``, ``src`` and ``dst`` mappings, each
        holding its features under the key ``"x"``.

    Returns
    -------
    Dict[str, Tensor]
        Mapping with the single key ``"cat_feat"`` whose value is the edge,
        source-node and destination-node features joined along dimension 1,
        in that order.
    """
    pieces = [edges.data["x"], edges.src["x"], edges.dst["x"]]
    return {"cat_feat": torch.cat(pieces, dim=1)}
@torch.jit.ignore()
def concat_efeat_dgl(
    efeat: Tensor,
    nfeat: Union[Tensor, Tuple[torch.Tensor, torch.Tensor]],
    graph: DGLGraph,
) -> Tensor:
    """Concatenates edge features with source and destination node features.

    Use for homogeneous graphs.

    Parameters
    ----------
    efeat : Tensor
        Edge features.
    nfeat : Tensor | Tuple[Tensor, Tensor]
        Either one tensor of node features, stored on ``graph.ndata``, or a
        ``(source, destination)`` pair stored on ``graph.srcdata`` and
        ``graph.dstdata`` respectively.
    graph : DGLGraph
        Graph.

    Returns
    -------
    Tensor
        Concatenated edge features with source and destination node features.
    """
    # local_scope keeps the temporary "x" / "cat_feat" entries off the
    # caller's graph once the block exits.
    with graph.local_scope():
        if isinstance(nfeat, Tuple):
            src_part, dst_part = nfeat
            graph.srcdata["x"] = src_part
            graph.dstdata["x"] = dst_part
        else:
            graph.ndata["x"] = nfeat
        graph.edata["x"] = efeat
        graph.apply_edges(concat_message_function)
        return graph.edata["cat_feat"]
@torch.jit.ignore()
def concat_efeat_hetero_dgl(
    mesh_efeat: Tensor,
    world_efeat: Tensor,
    nfeat: Union[Tensor, Tuple[torch.Tensor, torch.Tensor]],
    graph: DGLGraph,
) -> Tensor:
    """Concatenates edge features with source and destination node features.

    Use for heterogeneous graphs. Mesh and world edge features are stacked
    along dimension 0 (mesh first) before being placed on the graph edges.

    Parameters
    ----------
    mesh_efeat : Tensor
        Mesh edge features.
    world_efeat : Tensor
        World edge features.
    nfeat : Tensor | Tuple[Tensor, Tensor]
        Either one tensor of node features, stored on ``graph.ndata``, or a
        ``(source, destination)`` pair stored on ``graph.srcdata`` and
        ``graph.dstdata`` respectively.
    graph : DGLGraph
        Graph.

    Returns
    -------
    Tensor
        Concatenated edge features with source and destination node features.
    """
    # local_scope keeps the temporary "x" / "cat_feat" entries off the
    # caller's graph once the block exits.
    with graph.local_scope():
        if isinstance(nfeat, Tuple):
            src_part, dst_part = nfeat
            graph.srcdata["x"] = src_part
            graph.dstdata["x"] = dst_part
        else:
            graph.ndata["x"] = nfeat
        graph.edata["x"] = torch.cat([mesh_efeat, world_efeat], dim=0)
        graph.apply_edges(concat_message_function)
        return graph.edata["cat_feat"]
def concat_efeat_pyg(
    efeat: Tensor,
    nfeat: Union[Tensor, Tuple[Tensor, Tensor]],
    graph: PyGData | PyGHeteroData,
) -> Tensor:
    """Concatenates edge features with source and destination node features.

    Use for PyG graphs.

    Parameters
    ----------
    efeat : Tensor
        Edge features.
    nfeat : Tensor | Tuple[Tensor, Tensor]
        Node features. A single tensor is used for both the source and the
        destination side; a tuple is read as ``(source, destination)``.
    graph : PyGData | PyGHeteroData
        Graph. For ``PyGHeteroData`` only the first entry of
        ``graph.edge_types`` is used.

    Returns
    -------
    Tensor
        Concatenated edge features with source and destination node features.
    """
    if isinstance(nfeat, Tuple):
        src_feat, dst_feat = nfeat
    else:
        src_feat = dst_feat = nfeat

    if isinstance(graph, PyGHeteroData):
        edge_index = graph[graph.edge_types[0]].edge_index
    else:
        edge_index = graph.edge_index
    src_idx, dst_idx = edge_index.long()

    # gather per-edge endpoint features and append them to the edge features
    return torch.cat((efeat, src_feat[src_idx], dst_feat[dst_idx]), dim=1)
def concat_efeat(
    efeat: Tensor,
    nfeat: Union[Tensor, Tuple[Tensor]],
    graph: GraphType,
) -> Tensor:
    """Concatenates edge features with source and destination node features.

    Use for homogeneous graphs. Dispatches on the graph type to the DGL, PyG
    or cugraph-ops implementation.

    Parameters
    ----------
    efeat : Tensor
        Edge features.
    nfeat : Tensor | Tuple[Tensor]
        Node features, either one tensor shared by source and destination
        nodes or a ``(source, destination)`` tuple.
    graph : GraphType
        Graph.

    Returns
    -------
    Tensor
        Concatenated edge features with source and destination node features.

    Raises
    ------
    ValueError
        If ``graph`` or ``nfeat`` is of an unsupported type.
    """
    if isinstance(nfeat, Tensor):
        if isinstance(graph, CuGraphCSC):
            if graph.dgl_graph is None and USE_CUGRAPHOPS:
                # cugraph-ops kernels
                if not graph.is_distributed:
                    return update_efeat_static_e2e(
                        efeat,
                        nfeat,
                        graph.to_static_csc(),
                        mode="concat",
                        use_source_emb=True,
                        use_target_emb=True,
                    )
                local_src = graph.get_src_node_features_in_local_graph(nfeat)
                # torch.int64 to avoid indexing overflows due to the current
                # behavior of cugraph-ops
                bipartite_graph = graph.to_bipartite_csc(dtype=torch.int64)
                return update_efeat_bipartite_e2e(
                    efeat, local_src, nfeat, bipartite_graph, "concat"
                )

            # DGL path: a DGL graph is attached or cugraph-ops is unavailable
            local_src = nfeat
            if graph.is_distributed:
                local_src = graph.get_src_node_features_in_local_graph(nfeat)
            return concat_efeat_dgl(efeat, (local_src, nfeat), graph.to_dgl_graph())

        if isinstance(graph, DGLGraph):
            return concat_efeat_dgl(efeat, nfeat, graph)
        if isinstance(graph, (PyGData, PyGHeteroData)):
            return concat_efeat_pyg(efeat, nfeat, graph)
        raise ValueError(f"Unsupported graph type: {type(graph)}")

    if isinstance(nfeat, Tuple):
        src_feat, dst_feat = nfeat
        if isinstance(graph, CuGraphCSC):
            via_dgl = graph.dgl_graph is not None or not USE_CUGRAPHOPS
            if graph.is_distributed:
                src_feat = graph.get_src_node_features_in_local_graph(src_feat)
            if via_dgl:
                return concat_efeat_dgl(
                    efeat, (src_feat, dst_feat), graph.to_dgl_graph()
                )
            # torch.int64 to avoid indexing overflows due to the current
            # behavior of cugraph-ops
            bipartite_graph = graph.to_bipartite_csc(dtype=torch.int64)
            return update_efeat_bipartite_e2e(
                efeat, src_feat, dst_feat, bipartite_graph, "concat"
            )

        if isinstance(graph, DGLGraph):
            return concat_efeat_dgl(efeat, (src_feat, dst_feat), graph)
        if isinstance(graph, (PyGData, PyGHeteroData)):
            return concat_efeat_pyg(efeat, (src_feat, dst_feat), graph)
        raise ValueError(f"Unsupported graph type: {type(graph)}")

    raise ValueError(f"Unsupported node feature type: {type(nfeat)}")
def concat_efeat_hetero(
    mesh_efeat: Tensor,
    world_efeat: Tensor,
    nfeat: Union[Tensor, Tuple[Tensor, Tensor]],
    graph: GraphType,
) -> Tensor:
    """Concatenates edge features with source and destination node features.

    Use for heterogeneous graphs.

    Parameters
    ----------
    mesh_efeat : Tensor
        Mesh edge features.
    world_efeat : Tensor
        World edge features.
    nfeat : Tensor | Tuple[Tensor, Tensor]
        Node features.
    graph : GraphType
        Graph. Only ``DGLGraph`` and ``PyGData`` are handled.

    Returns
    -------
    Tensor
        Concatenated edge features with source and destination node features.

    Raises
    ------
    NotImplementedError
        If ``graph`` is a ``CuGraphCSC``.
    ValueError
        If ``graph`` is of any other unsupported type.
    """
    if isinstance(graph, CuGraphCSC):
        raise NotImplementedError(
            "concat_efeat_hetero is not supported for CuGraphCSC graphs yet."
        )
    if isinstance(graph, DGLGraph):
        return concat_efeat_hetero_dgl(mesh_efeat, world_efeat, nfeat, graph)
    if isinstance(graph, PyGData):
        # mesh edges first, then world edges
        stacked_efeat = torch.cat((mesh_efeat, world_efeat), dim=0)
        return concat_efeat_pyg(stacked_efeat, nfeat, graph)
    raise ValueError(f"Unsupported graph type: {type(graph)}")
@torch.jit.script
def sum_edge_node_feat(
    efeat: Tensor, src_feat: Tensor, dst_feat: Tensor, src_idx: Tensor, dst_idx: Tensor
) -> Tensor:
    """Adds the features of both edge endpoints to the edge features.

    Parameters
    ----------
    efeat : Tensor
        Edge features.
    src_feat : Tensor
        Source node features.
    dst_feat : Tensor
        Destination node features.
    src_idx : Tensor
        Source node indices, used to index ``src_feat``.
    dst_idx : Tensor
        Destination node indices, used to index ``dst_feat``.

    Returns
    -------
    Tensor
        Sum of edge features with source and destination node features.
    """
    gathered_src = src_feat[src_idx]
    gathered_dst = dst_feat[dst_idx]
    return efeat + gathered_src + gathered_dst
def sum_efeat(
    efeat: Tensor,
    nfeat: Union[Tensor, Tuple[Tensor]],
    graph: GraphType,
) -> Tensor:
    """Sums edge features with source and destination node features.

    Parameters
    ----------
    efeat : Tensor
        Edge features.
    nfeat : Tensor | Tuple[Tensor]
        Node features (static setting) or tuple of node features of
        source and destination nodes (bipartite setting).
    graph : GraphType
        The underlying graph. For ``PyGHeteroData`` only the first entry of
        ``graph.edge_types`` is used.

    Returns
    -------
    Tensor
        Sum of edge features with source and destination node features.

    Raises
    ------
    ValueError
        If ``graph`` is of an unsupported type.
    """
    single_tensor = isinstance(nfeat, Tensor)
    if single_tensor:
        src_feat, dst_feat = nfeat, nfeat
    else:
        src_feat, dst_feat = nfeat

    if isinstance(graph, CuGraphCSC):
        via_dgl = graph.dgl_graph is not None or not USE_CUGRAPHOPS
        distributed = graph.is_distributed

        if single_tensor and not via_dgl and not distributed:
            # Static (non-bipartite) cugraph-ops path. This must use the
            # static kernel: the bipartite kernel takes separate source and
            # destination features and a bipartite graph, so calling it with
            # a static CSC graph cannot work. Mirrors the equivalent branch
            # of concat_efeat.
            static_graph = graph.to_static_csc()
            return update_efeat_static_e2e(
                efeat,
                nfeat,
                static_graph,
                mode="sum",
                use_source_emb=True,
                use_target_emb=True,
            )

        if distributed:
            src_feat = graph.get_src_node_features_in_local_graph(src_feat)

        if via_dgl:
            src, dst = (item.long() for item in graph.to_dgl_graph().edges())
            return sum_edge_node_feat(efeat, src_feat, dst_feat, src, dst)

        # NOTE(review): concat_efeat requests dtype=torch.int64 for the
        # bipartite graph to avoid index overflows; confirm whether the sum
        # path needs the same.
        bipartite_graph = graph.to_bipartite_csc()
        return update_efeat_bipartite_e2e(
            efeat, src_feat, dst_feat, bipartite_graph, mode="sum"
        )

    if isinstance(graph, DGLGraph):
        src, dst = (item.long() for item in graph.edges())
        return sum_edge_node_feat(efeat, src_feat, dst_feat, src, dst)

    # Both PyG graph flavours are accepted regardless of whether nfeat is a
    # single tensor or a tuple, consistent with concat_efeat.
    if isinstance(graph, (PyGData, PyGHeteroData)):
        if isinstance(graph, PyGHeteroData):
            src, dst = graph[graph.edge_types[0]].edge_index.long()
        else:
            src, dst = graph.edge_index.long()
        return sum_edge_node_feat(efeat, src_feat, dst_feat, src, dst)

    raise ValueError(f"Unsupported graph type: {type(graph)}")
@torch.jit.ignore()
def agg_concat_dgl(
    efeat: Tensor, dst_nfeat: Tensor, graph: DGLGraph, aggregation: str
) -> Tensor:
    """Aggregates edge features and concatenates result with destination node features.

    Parameters
    ----------
    efeat : Tensor
        Edge features.
    dst_nfeat : Tensor
        Node features (destination nodes).
    graph : DGLGraph
        Graph.
    aggregation : str
        Aggregation method (sum or mean).

    Returns
    -------
    Tensor
        Aggregated edge features concatenated with destination node features
        along the last dimension (aggregated features first).

    Raises
    ------
    RuntimeError
        If aggregation method is not sum or mean.
    """
    # local_scope keeps the temporary "x" / "h_dest" entries off the
    # caller's graph once the block exits.
    with graph.local_scope():
        graph.edata["x"] = efeat

        # pick the reducer that folds incoming edge messages per destination
        if aggregation == "sum":
            reducer = fn.sum("m", "h_dest")
        elif aggregation == "mean":
            reducer = fn.mean("m", "h_dest")
        else:
            raise RuntimeError("Not a valid aggregation!")
        graph.update_all(fn.copy_e("x", "m"), reducer)

        return torch.cat((graph.dstdata["h_dest"], dst_nfeat), -1)
@torch.jit.ignore()
def agg_concat_hetero_dgl(
    mesh_efeat: Tensor,
    world_efeat: Tensor,
    dst_nfeat: Tensor,
    graph: DGLGraph,
    aggregation: str,
) -> Tensor:
    """Aggregates edge features and concatenates result with destination node features.

    Use for heterogeneous graphs. Mesh and world edge features are stacked
    along dimension 0 (mesh first) before aggregation.

    Parameters
    ----------
    mesh_efeat : Tensor
        Mesh edge features.
    world_efeat : Tensor
        World edge features.
    dst_nfeat : Tensor
        Node features (destination nodes).
    graph : DGLGraph
        Graph.
    aggregation : str
        Aggregation method (sum or mean).

    Returns
    -------
    Tensor
        Aggregated edge features concatenated with destination node features
        along the last dimension (aggregated features first).

    Raises
    ------
    RuntimeError
        If aggregation method is not sum or mean.
    """
    # local_scope keeps the temporary "x" / "h_dest" entries off the
    # caller's graph once the block exits.
    with graph.local_scope():
        graph.edata["x"] = torch.cat([mesh_efeat, world_efeat], dim=0)

        # pick the reducer that folds incoming edge messages per destination
        if aggregation == "sum":
            reducer = fn.sum("m", "h_dest")
        elif aggregation == "mean":
            reducer = fn.mean("m", "h_dest")
        else:
            raise RuntimeError("Not a valid aggregation!")
        graph.update_all(fn.copy_e("x", "m"), reducer)

        return torch.cat((graph.dstdata["h_dest"], dst_nfeat), -1)
def agg_concat_pyg(
    efeat: Tensor,
    nfeat: Tensor,
    graph: PyGData | PyGHeteroData,
    aggregation: str,
) -> Tensor:
    """Aggregates edge features and concatenates result with node features.

    Use for PyG graphs.

    Parameters
    ----------
    efeat : Tensor
        Edge features.
    nfeat : Tensor
        Node features; its first dimension sets the number of output rows.
    graph : PyGData | PyGHeteroData
        Graph. For ``PyGHeteroData`` only the first entry of
        ``graph.edge_types`` is used.
    aggregation : str
        Reduction passed to ``torch_scatter.scatter`` as ``reduce``.

    Returns
    -------
    Tensor
        Aggregated edge features concatenated with node features along the
        last dimension (aggregated features first).
    """
    if isinstance(graph, PyGHeteroData):
        edge_index = graph[graph.edge_types[0]].edge_index
    else:
        edge_index = graph.edge_index
    _src_idx, dst_idx = edge_index.long()

    # reduce edge features onto the node each edge points to
    aggregated = torch_scatter.scatter(
        efeat, dst_idx, dim=0, dim_size=nfeat.shape[0], reduce=aggregation
    )
    return torch.cat((aggregated, nfeat), -1)
def aggregate_and_concat(
    efeat: Tensor,
    nfeat: Tensor,
    graph: GraphType,
    aggregation: str,
):
    """
    Aggregates edge features and concatenates result with destination node features.

    Dispatches on the graph type to the DGL, PyG or cugraph-ops implementation.

    Parameters
    ----------
    efeat : Tensor
        Edge features.
    nfeat : Tensor
        Node features (destination nodes).
    graph : GraphType
        Graph.
    aggregation : str
        Aggregation method (sum or mean).

    Returns
    -------
    Tensor
        Aggregated edge features concatenated with destination node features.

    Raises
    ------
    RuntimeError
        If aggregation method is not sum or mean (DGL path).
    ValueError
        If ``graph`` is of an unsupported type.
    """
    if isinstance(graph, CuGraphCSC):
        # No special handling is needed for the distributed setting here:
        # per the partitioning scheme, efeat and nfeat are already
        # guaranteed to be on the same rank in both the distributed and the
        # default case.
        if graph.dgl_graph is None and USE_CUGRAPHOPS:
            return agg_concat_e2n(nfeat, efeat, graph.to_static_csc(), aggregation)
        return agg_concat_dgl(efeat, nfeat, graph.to_dgl_graph(), aggregation)
    if isinstance(graph, DGLGraph):
        return agg_concat_dgl(efeat, nfeat, graph, aggregation)
    if isinstance(graph, (PyGData, PyGHeteroData)):
        return agg_concat_pyg(efeat, nfeat, graph, aggregation)
    raise ValueError(f"Unsupported graph type: {type(graph)}")
def aggregate_and_concat_hetero(
    mesh_efeat: Tensor,
    world_efeat: Tensor,
    nfeat: Tensor,
    graph: GraphType,
    aggregation: str,
):
    """
    Aggregates edge features and concatenates result with destination node features.

    Use for heterogeneous graphs.

    Parameters
    ----------
    mesh_efeat : Tensor
        Mesh edge features.
    world_efeat : Tensor
        World edge features.
    nfeat : Tensor
        Node features (destination nodes).
    graph : GraphType
        Graph. Only ``DGLGraph`` and ``PyGData`` are handled.
    aggregation : str
        Aggregation method (sum or mean).

    Returns
    -------
    Tensor
        Aggregated edge features concatenated with destination node features.

    Raises
    ------
    NotImplementedError
        If ``graph`` is a ``CuGraphCSC``.
    RuntimeError
        If aggregation method is not sum or mean (DGL path).
    ValueError
        If ``graph`` is of any other unsupported type.
    """
    if isinstance(graph, CuGraphCSC):
        raise NotImplementedError(
            "aggregate_and_concat_hetero is not supported for CuGraphCSC graphs yet."
        )
    if isinstance(graph, DGLGraph):
        return agg_concat_hetero_dgl(
            mesh_efeat, world_efeat, nfeat, graph, aggregation
        )
    if isinstance(graph, PyGData):
        # mesh edges first, then world edges
        stacked_efeat = torch.cat((mesh_efeat, world_efeat), dim=0)
        return agg_concat_pyg(stacked_efeat, nfeat, graph, aggregation)
    raise ValueError(f"Unsupported graph type: {type(graph)}")