# OneForecast / models / GraphCast.py
from typing import Any, List, Optional
import dgl
import torch
from dgl import DGLGraph
from torch import Tensor
try:
from typing import Self
except ImportError:
# for Python versions < 3.11
from typing_extensions import Self
try:
from pylibcugraphops.pytorch import BipartiteCSC, StaticCSC
USE_CUGRAPHOPS = True
except ImportError:
StaticCSC = None
BipartiteCSC = None
USE_CUGRAPHOPS = False
class CuGraphCSC:
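    """CSC (compressed sparse column) representation of a (bipartite) graph.

    Thin wrapper that stores the offsets/indices tensors and converts them on
    demand to cugraph-ops structures (BipartiteCSC / StaticCSC) or to a DGLGraph.
    """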
def __init__(
self,
offsets: Tensor,
indices: Tensor,
num_src_nodes: int,
num_dst_nodes: int,
ef_indices: Optional[Tensor] = None,
reverse_graph_bwd: bool = True,
cache_graph: bool = True,
partition_size: Optional[int] = -1,
partition_group_name: Optional[str] = None,
) -> None:
self.offsets = offsets
self.indices = indices
self.num_src_nodes = num_src_nodes
self.num_dst_nodes = num_dst_nodes
self.ef_indices = ef_indices
self.reverse_graph_bwd = reverse_graph_bwd
self.cache_graph = cache_graph
# cugraph-ops structures
self.bipartite_csc = None
self.static_csc = None
# dgl graph
self.dgl_graph = None
self.is_distributed = False
self.dist_csc = None
if partition_size <= 1:
self.is_distributed = False
return
if self.ef_indices is not None:
raise AssertionError(
"DistributedGraph does not support mapping CSC-indices to COO-indices."
)
        # NOTE: the distributed code path relies on Modulus' DistributedGraph
        # utilities (e.g. ``self.dist_graph`` and
        # ``get_src_node_features_in_local_graph``), which are not included in
        # this standalone file.
        raise NotImplementedError(
            "partition_size > 1 requires the DistributedGraph utilities, "
            "which are not part of this file."
        )
@staticmethod
def from_dgl(
graph: DGLGraph,
partition_size: int = 1,
partition_group_name: Optional[str] = None,
partition_by_bbox: bool = False,
src_coordinates: Optional[torch.Tensor] = None,
dst_coordinates: Optional[torch.Tensor] = None,
coordinate_separators_min: Optional[List[List[Optional[float]]]] = None,
coordinate_separators_max: Optional[List[List[Optional[float]]]] = None,
): # pragma: no cover
        # DGL changed its API for accessing sparse formats;
        # support both the new and the old accessor here
if hasattr(graph, "adj_tensors"):
offsets, indices, edge_perm = graph.adj_tensors("csc")
elif hasattr(graph, "adj_sparse"):
offsets, indices, edge_perm = graph.adj_sparse("csc")
else:
raise ValueError("Passed graph object doesn't support conversion to CSC.")
n_src_nodes, n_dst_nodes = (graph.num_src_nodes(), graph.num_dst_nodes())
        graph_csc = CuGraphCSC(
            offsets.to(dtype=torch.int64),
            indices.to(dtype=torch.int64),
            n_src_nodes,
            n_dst_nodes,
            partition_size=partition_size,
            partition_group_name=partition_group_name,
        )
return graph_csc, edge_perm
def to(self, *args: Any, **kwargs: Any) -> Self:
device, dtype, _, _ = torch._C._nn._parse_to(*args, **kwargs)
if dtype not in (
None,
torch.int32,
torch.int64,
):
raise TypeError(
f"Invalid dtype, expected torch.int32 or torch.int64, got {dtype}."
)
self.offsets = self.offsets.to(device=device, dtype=dtype)
self.indices = self.indices.to(device=device, dtype=dtype)
if self.ef_indices is not None:
self.ef_indices = self.ef_indices.to(device=device, dtype=dtype)
return self
def to_bipartite_csc(self, dtype=None) -> BipartiteCSC:
if not (USE_CUGRAPHOPS):
raise RuntimeError(
"Conversion failed, expected cugraph-ops to be installed."
)
if not self.offsets.is_cuda:
raise RuntimeError("Expected the graph structures to reside on GPU.")
if self.bipartite_csc is None or not self.cache_graph:
graph_offsets = self.offsets
graph_indices = self.indices
graph_ef_indices = self.ef_indices
if dtype is not None:
graph_offsets = self.offsets.to(dtype=dtype)
graph_indices = self.indices.to(dtype=dtype)
if self.ef_indices is not None:
graph_ef_indices = self.ef_indices.to(dtype=dtype)
graph = BipartiteCSC(
graph_offsets,
graph_indices,
self.num_src_nodes,
graph_ef_indices,
reverse_graph_bwd=self.reverse_graph_bwd,
)
self.bipartite_csc = graph
return self.bipartite_csc
def to_static_csc(self, dtype=None) -> StaticCSC:
if not (USE_CUGRAPHOPS):
raise RuntimeError(
"Conversion failed, expected cugraph-ops to be installed."
)
if not self.offsets.is_cuda:
raise RuntimeError("Expected the graph structures to reside on GPU.")
if self.static_csc is None or not self.cache_graph:
graph_offsets = self.offsets
graph_indices = self.indices
graph_ef_indices = self.ef_indices
if dtype is not None:
graph_offsets = self.offsets.to(dtype=dtype)
graph_indices = self.indices.to(dtype=dtype)
if self.ef_indices is not None:
graph_ef_indices = self.ef_indices.to(dtype=dtype)
graph = StaticCSC(
graph_offsets,
graph_indices,
graph_ef_indices,
)
self.static_csc = graph
return self.static_csc
def to_dgl_graph(self) -> DGLGraph: # pragma: no cover
if self.dgl_graph is None or not self.cache_graph:
if self.ef_indices is not None:
raise AssertionError("ef_indices is not supported.")
graph_offsets = self.offsets
dst_degree = graph_offsets[1:] - graph_offsets[:-1]
src_indices = self.indices
dst_indices = torch.arange(
0,
graph_offsets.size(0) - 1,
dtype=graph_offsets.dtype,
device=graph_offsets.device,
)
dst_indices = torch.repeat_interleave(dst_indices, dst_degree, dim=0)
# labels not important here
self.dgl_graph = dgl.heterograph(
{("src", "src2dst", "dst"): ("coo", (src_indices, dst_indices))},
idtype=torch.int32,
)
return self.dgl_graph
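# Feature-update helpers shared by the GraphCast MLP blocks (DGL and cugraph-ops code paths).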
from typing import Any, Callable, Dict, Tuple, Union
import dgl.function as fn
import torch
from dgl import DGLGraph
from torch import Tensor
from torch.utils.checkpoint import checkpoint
try:
from pylibcugraphops.pytorch.operators import (
agg_concat_e2n,
update_efeat_bipartite_e2e,
update_efeat_static_e2e,
)
USE_CUGRAPHOPS = True
except ImportError:
update_efeat_bipartite_e2e = None
update_efeat_static_e2e = None
agg_concat_e2n = None
USE_CUGRAPHOPS = False
def checkpoint_identity(layer: Callable, *args: Any, **kwargs: Any) -> Any:
return layer(*args)
def set_checkpoint_fn(do_checkpointing: bool) -> Callable:
if do_checkpointing:
return checkpoint
else:
return checkpoint_identity
def concat_message_function(edges: Tensor) -> Dict[str, Tensor]:
    # concatenates edge, source-node, and destination-node features per edge
cat_feat = torch.cat((edges.data["x"], edges.src["x"], edges.dst["x"]), dim=1)
return {"cat_feat": cat_feat}
@torch.jit.ignore()
def concat_efeat_dgl(
efeat: Tensor,
nfeat: Union[Tensor, Tuple[torch.Tensor, torch.Tensor]],
graph: DGLGraph,
) -> Tensor:
if isinstance(nfeat, Tuple):
src_feat, dst_feat = nfeat
with graph.local_scope():
graph.srcdata["x"] = src_feat
graph.dstdata["x"] = dst_feat
graph.edata["x"] = efeat
graph.apply_edges(concat_message_function)
return graph.edata["cat_feat"]
with graph.local_scope():
graph.ndata["x"] = nfeat
graph.edata["x"] = efeat
graph.apply_edges(concat_message_function)
return graph.edata["cat_feat"]
def concat_efeat(
efeat: Tensor,
nfeat: Union[Tensor, Tuple[Tensor]],
graph: Union[DGLGraph, CuGraphCSC],
) -> Tensor:
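    """Concatenates edge features with the corresponding source- and
    destination-node features, dispatching to the DGL or cugraph-ops kernels
    depending on the graph type and on whether cugraph-ops is available."""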
if isinstance(nfeat, Tensor):
if isinstance(graph, CuGraphCSC):
if graph.dgl_graph is not None or not USE_CUGRAPHOPS:
src_feat, dst_feat = nfeat, nfeat
if graph.is_distributed:
src_feat = graph.get_src_node_features_in_local_graph(nfeat)
efeat = concat_efeat_dgl(
efeat, (src_feat, dst_feat), graph.to_dgl_graph()
)
else:
if graph.is_distributed:
src_feat = graph.get_src_node_features_in_local_graph(nfeat)
                    # torch.int64 to avoid indexing overflows due to current behavior of cugraph-ops
bipartite_graph = graph.to_bipartite_csc(dtype=torch.int64)
dst_feat = nfeat
efeat = update_efeat_bipartite_e2e(
efeat, src_feat, dst_feat, bipartite_graph, "concat"
)
else:
static_graph = graph.to_static_csc()
efeat = update_efeat_static_e2e(
efeat,
nfeat,
static_graph,
mode="concat",
use_source_emb=True,
use_target_emb=True,
)
else:
efeat = concat_efeat_dgl(efeat, nfeat, graph)
else:
src_feat, dst_feat = nfeat
# update edge features through concatenating edge and node features
if isinstance(graph, CuGraphCSC):
if graph.dgl_graph is not None or not USE_CUGRAPHOPS:
if graph.is_distributed:
src_feat = graph.get_src_node_features_in_local_graph(src_feat)
efeat = concat_efeat_dgl(
efeat, (src_feat, dst_feat), graph.to_dgl_graph()
)
else:
if graph.is_distributed:
src_feat = graph.get_src_node_features_in_local_graph(src_feat)
                # torch.int64 to avoid indexing overflows due to current behavior of cugraph-ops
bipartite_graph = graph.to_bipartite_csc(dtype=torch.int64)
efeat = update_efeat_bipartite_e2e(
efeat, src_feat, dst_feat, bipartite_graph, "concat"
)
else:
efeat = concat_efeat_dgl(efeat, (src_feat, dst_feat), graph)
return efeat
@torch.jit.script
def sum_efeat_dgl(
efeat: Tensor, src_feat: Tensor, dst_feat: Tensor, src_idx: Tensor, dst_idx: Tensor
) -> Tensor:
return efeat + src_feat[src_idx] + dst_feat[dst_idx]
def sum_efeat(
efeat: Tensor,
nfeat: Union[Tensor, Tuple[Tensor]],
graph: Union[DGLGraph, CuGraphCSC],
):
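    """Sums edge features with the (already projected) source- and
    destination-node features gathered per edge; the summation counterpart of
    ``concat_efeat`` used by the "concat trick" in ``MeshGraphEdgeMLPSum``."""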
if isinstance(nfeat, Tensor):
if isinstance(graph, CuGraphCSC):
if graph.dgl_graph is not None or not USE_CUGRAPHOPS:
src_feat, dst_feat = nfeat, nfeat
if graph.is_distributed:
src_feat = graph.get_src_node_features_in_local_graph(src_feat)
src, dst = (item.long() for item in graph.to_dgl_graph().edges())
sum_efeat = sum_efeat_dgl(efeat, src_feat, dst_feat, src, dst)
else:
if graph.is_distributed:
src_feat = graph.get_src_node_features_in_local_graph(nfeat)
dst_feat = nfeat
bipartite_graph = graph.to_bipartite_csc()
sum_efeat = update_efeat_bipartite_e2e(
efeat, src_feat, dst_feat, bipartite_graph, mode="sum"
)
else:
static_graph = graph.to_static_csc()
                    sum_efeat = update_efeat_static_e2e(
                        efeat,
                        nfeat,
                        static_graph,
                        mode="sum",
                        use_source_emb=True,
                        use_target_emb=True,
                    )
else:
src_feat, dst_feat = nfeat, nfeat
src, dst = (item.long() for item in graph.edges())
sum_efeat = sum_efeat_dgl(efeat, src_feat, dst_feat, src, dst)
else:
src_feat, dst_feat = nfeat
if isinstance(graph, CuGraphCSC):
if graph.dgl_graph is not None or not USE_CUGRAPHOPS:
if graph.is_distributed:
src_feat = graph.get_src_node_features_in_local_graph(src_feat)
src, dst = (item.long() for item in graph.to_dgl_graph().edges())
sum_efeat = sum_efeat_dgl(efeat, src_feat, dst_feat, src, dst)
else:
if graph.is_distributed:
src_feat = graph.get_src_node_features_in_local_graph(src_feat)
bipartite_graph = graph.to_bipartite_csc()
sum_efeat = update_efeat_bipartite_e2e(
efeat, src_feat, dst_feat, bipartite_graph, mode="sum"
)
else:
src, dst = (item.long() for item in graph.edges())
sum_efeat = sum_efeat_dgl(efeat, src_feat, dst_feat, src, dst)
return sum_efeat
@torch.jit.ignore()
def agg_concat_dgl(
efeat: Tensor, dst_nfeat: Tensor, graph: DGLGraph, aggregation: str
) -> Tensor:
with graph.local_scope():
# populate features on graph edges
graph.edata["x"] = efeat
# aggregate edge features
if aggregation == "sum":
graph.update_all(fn.copy_e("x", "m"), fn.sum("m", "h_dest"))
elif aggregation == "mean":
graph.update_all(fn.copy_e("x", "m"), fn.mean("m", "h_dest"))
else:
raise RuntimeError("Not a valid aggregation!")
# concat dst-node & edge features
cat_feat = torch.cat((graph.dstdata["h_dest"], dst_nfeat), -1)
return cat_feat
def aggregate_and_concat(
efeat: Tensor,
nfeat: Tensor,
graph: Union[DGLGraph, CuGraphCSC],
aggregation: str,
):
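    """Aggregates edge features onto the destination nodes (sum or mean) and
    concatenates the result with the destination-node features."""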
if isinstance(graph, CuGraphCSC):
if graph.dgl_graph is not None or not USE_CUGRAPHOPS:
cat_feat = agg_concat_dgl(efeat, nfeat, graph.to_dgl_graph(), aggregation)
else:
static_graph = graph.to_static_csc()
cat_feat = agg_concat_e2n(nfeat, efeat, static_graph, aggregation)
else:
cat_feat = agg_concat_dgl(efeat, nfeat, graph, aggregation)
return cat_feat
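# nvFuser-based fused SiLU derivative kernels (used by the recompute_activation path below).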
import functools
import logging
from typing import Tuple
import torch
from torch.autograd import Function
logger = logging.getLogger(__name__)
try:
import nvfuser
from nvfuser import DataType, FusionDefinition
except ImportError:
logger.error(
"An error occured. Either nvfuser is not installed or the version is "
"incompatible. Please retry after installing correct version of nvfuser. "
"The new version of nvfuser should be available in PyTorch container version "
">= 23.10. "
"https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html. "
"If using a source install method, please refer nvFuser repo for installation "
"guidelines https://github.com/NVIDIA/Fuser.",
)
raise
_torch_dtype_to_nvfuser = {
torch.double: DataType.Double,
torch.float: DataType.Float,
torch.half: DataType.Half,
torch.int: DataType.Int,
torch.int32: DataType.Int32,
torch.bool: DataType.Bool,
torch.bfloat16: DataType.BFloat16,
torch.cfloat: DataType.ComplexFloat,
torch.cdouble: DataType.ComplexDouble,
}
@functools.lru_cache(maxsize=None)
def silu_backward_for(
fd: FusionDefinition,
dtype: torch.dtype,
dim: int,
size: torch.Size,
stride: Tuple[int, ...],
): # pragma: no cover
try:
dtype = _torch_dtype_to_nvfuser[dtype]
except KeyError:
raise TypeError("Unsupported dtype")
x = fd.define_tensor(
shape=[-1] * dim,
contiguity=nvfuser.compute_contiguity(size, stride),
dtype=dtype,
)
one = fd.define_constant(1.0)
# y = sigmoid(x)
y = fd.ops.sigmoid(x)
    # grad = y * (1 + x * (1 - y))
grad_input = fd.ops.mul(y, fd.ops.add(one, fd.ops.mul(x, fd.ops.sub(one, y))))
grad_input = fd.ops.cast(grad_input, dtype)
fd.add_output(grad_input)
@functools.lru_cache(maxsize=None)
def silu_double_backward_for(
fd: FusionDefinition,
dtype: torch.dtype,
dim: int,
size: torch.Size,
stride: Tuple[int, ...],
): # pragma: no cover
try:
dtype = _torch_dtype_to_nvfuser[dtype]
except KeyError:
raise TypeError("Unsupported dtype")
x = fd.define_tensor(
shape=[-1] * dim,
contiguity=nvfuser.compute_contiguity(size, stride),
dtype=dtype,
)
one = fd.define_constant(1.0)
# y = sigmoid(x)
y = fd.ops.sigmoid(x)
# dy = y * (1 - y)
dy = fd.ops.mul(y, fd.ops.sub(one, y))
# z = 1 + x * (1 - y)
z = fd.ops.add(one, fd.ops.mul(x, fd.ops.sub(one, y)))
# term1 = dy * z
term1 = fd.ops.mul(dy, z)
# term2 = y * ((1 - y) - x * dy)
term2 = fd.ops.mul(y, fd.ops.sub(fd.ops.sub(one, y), fd.ops.mul(x, dy)))
grad_input = fd.ops.add(term1, term2)
grad_input = fd.ops.cast(grad_input, dtype)
fd.add_output(grad_input)
@functools.lru_cache(maxsize=None)
def silu_triple_backward_for(
fd: FusionDefinition,
dtype: torch.dtype,
dim: int,
size: torch.Size,
stride: Tuple[int, ...],
): # pragma: no cover
try:
dtype = _torch_dtype_to_nvfuser[dtype]
except KeyError:
raise TypeError("Unsupported dtype")
x = fd.define_tensor(
shape=[-1] * dim,
contiguity=nvfuser.compute_contiguity(size, stride),
dtype=dtype,
)
one = fd.define_constant(1.0)
two = fd.define_constant(2.0)
# y = sigmoid(x)
y = fd.ops.sigmoid(x)
# dy = y * (1 - y)
dy = fd.ops.mul(y, fd.ops.sub(one, y))
# ddy = (1 - 2y) * dy
ddy = fd.ops.mul(fd.ops.sub(one, fd.ops.mul(two, y)), dy)
# term1 = ddy * (2 + x - 2xy)
term1 = fd.ops.mul(
ddy, fd.ops.sub(fd.ops.add(two, x), fd.ops.mul(two, fd.ops.mul(x, y)))
)
# term2 = dy * (1 - 2 (y + x * dy))
term2 = fd.ops.mul(
dy, fd.ops.sub(one, fd.ops.mul(two, fd.ops.add(y, fd.ops.mul(x, dy))))
)
grad_input = fd.ops.add(term1, term2)
grad_input = fd.ops.cast(grad_input, dtype)
fd.add_output(grad_input)
class FusedSiLU(Function):
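    """SiLU activation whose backward pass uses nvFuser-compiled derivative
    kernels (first to third derivative via the FusedSiLU_deriv_* functions)."""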
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return torch.nn.functional.silu(x)
@staticmethod
def backward(ctx, grad_output): # pragma: no cover
(x,) = ctx.saved_tensors
return FusedSiLU_deriv_1.apply(x) * grad_output
class FusedSiLU_deriv_1(Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
with FusionDefinition() as fd:
silu_backward_for(fd, x.dtype, x.dim(), x.size(), x.stride())
out = fd.execute([x])[0]
return out
@staticmethod
def backward(ctx, grad_output): # pragma: no cover
(x,) = ctx.saved_tensors
return FusedSiLU_deriv_2.apply(x) * grad_output
class FusedSiLU_deriv_2(Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
with FusionDefinition() as fd:
silu_double_backward_for(fd, x.dtype, x.dim(), x.size(), x.stride())
out = fd.execute([x])[0]
return out
@staticmethod
def backward(ctx, grad_output): # pragma: no cover
(x,) = ctx.saved_tensors
return FusedSiLU_deriv_3.apply(x) * grad_output
class FusedSiLU_deriv_3(Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
with FusionDefinition() as fd:
silu_triple_backward_for(fd, x.dtype, x.dim(), x.size(), x.stride())
out = fd.execute([x])[0]
return out
@staticmethod
def backward(ctx, grad_output): # pragma: no cover
(x,) = ctx.saved_tensors
y = torch.sigmoid(x)
dy = y * (1 - y)
ddy = (1 - 2 * y) * dy
dddy = (1 - 2 * y) * ddy - 2 * dy * dy
z = 1 - 2 * (y + x * dy)
term1 = dddy * (2 + x - 2 * x * y)
term2 = 2 * ddy * z
term3 = dy * (-2) * (2 * dy + x * ddy)
return (term1 + term2 + term3) * grad_output
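# Usage sketch (assumes a CUDA device and a working nvfuser installation):
#   x = torch.randn(8, 512, device="cuda", requires_grad=True)
#   y = FusedSiLU.apply(x)   # forward: F.silu(x)
#   y.sum().backward()       # backward: nvFuser-fused SiLU derivative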
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
from torch import Tensor
from torch.autograd.function import once_differentiable
try:
from transformer_engine import pytorch as te
te_imported = True
except ImportError:
te_imported = False
class CustomSiLuLinearAutogradFunction(torch.autograd.Function):
"""Custom SiLU + Linear autograd function"""
@staticmethod
def forward(
ctx,
features: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
) -> torch.Tensor:
out = F.silu(features)
out = F.linear(out, weight, bias)
ctx.save_for_backward(features, weight)
return out
@staticmethod
@once_differentiable
def backward(
ctx, grad_output: torch.Tensor
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor],]:
"""backward pass of the SiLU + Linear function"""
from nvfuser import FusionDefinition
(
need_dgrad,
need_wgrad,
need_bgrad,
) = ctx.needs_input_grad
features, weight = ctx.saved_tensors
grad_features = None
grad_weight = None
grad_bias = None
if need_bgrad:
grad_bias = grad_output.sum(dim=0)
if need_wgrad:
out = F.silu(features)
grad_weight = grad_output.T @ out
if need_dgrad:
grad_features = grad_output @ weight
with FusionDefinition() as fd:
silu_backward_for(
fd,
features.dtype,
features.dim(),
features.size(),
features.stride(),
)
grad_silu = fd.execute([features])[0]
grad_features = grad_features * grad_silu
return grad_features, grad_weight, grad_bias
class MeshGraphMLP(nn.Module):
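    """MLP used throughout GraphCast: a stack of Linear + activation layers
    followed by an optional LayerNorm, with an optional fused SiLU+Linear
    path (``recompute_activation``) that recomputes the activation in the
    backward pass to save memory."""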
def __init__(
self,
input_dim: int,
output_dim: int = 512,
hidden_dim: int = 512,
hidden_layers: Union[int, None] = 1,
activation_fn: nn.Module = nn.SiLU(),
norm_type: str = "LayerNorm",
recompute_activation: bool = False,
):
super().__init__()
if hidden_layers is not None:
layers = [nn.Linear(input_dim, hidden_dim), activation_fn]
self.hidden_layers = hidden_layers
for _ in range(hidden_layers - 1):
layers += [nn.Linear(hidden_dim, hidden_dim), activation_fn]
layers.append(nn.Linear(hidden_dim, output_dim))
self.norm_type = norm_type
if norm_type is not None:
if norm_type not in [
"LayerNorm",
"TELayerNorm",
]:
raise ValueError(
f"Invalid norm type {norm_type}. Supported types are LayerNorm and TELayerNorm."
)
if norm_type == "TELayerNorm" and te_imported:
norm_layer = te.LayerNorm
elif norm_type == "TELayerNorm" and not te_imported:
raise ValueError(
"TELayerNorm requires transformer-engine to be installed."
)
else:
norm_layer = getattr(nn, norm_type)
layers.append(norm_layer(output_dim))
self.model = nn.Sequential(*layers)
else:
self.model = nn.Identity()
if recompute_activation:
if not isinstance(activation_fn, nn.SiLU):
                raise ValueError(
                    "recompute_activation requires SiLU as the activation function, "
                    f"got {activation_fn}"
                )
self.recompute_activation = True
else:
self.recompute_activation = False
def default_forward(self, x: Tensor) -> Tensor:
"""default forward pass of the MLP"""
return self.model(x)
@torch.jit.ignore()
def custom_silu_linear_forward(self, x: Tensor) -> Tensor:
"""forward pass of the MLP where SiLU is recomputed in backward"""
lin = self.model[0]
hidden = lin(x)
for i in range(1, self.hidden_layers + 1):
lin = self.model[2 * i]
hidden = CustomSiLuLinearAutogradFunction.apply(
hidden, lin.weight, lin.bias
)
if self.norm_type is not None:
norm = self.model[2 * self.hidden_layers + 1]
hidden = norm(hidden)
return hidden
def forward(self, x: Tensor) -> Tensor:
if self.recompute_activation:
return self.custom_silu_linear_forward(x)
return self.default_forward(x)
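# Example (sketch, hypothetical shapes): a MeshGraphMLP maps per-node features
# to the latent dimension, e.g.
#   mlp = MeshGraphMLP(input_dim=474, output_dim=512, hidden_dim=512, hidden_layers=1)
#   out = mlp(torch.randn(num_nodes, 474))   # -> (num_nodes, 512)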
class MeshGraphEdgeMLPConcat(MeshGraphMLP):
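    """Edge MLP that first concatenates edge, source-node, and destination-node
    features and then applies a MeshGraphMLP."""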
def __init__(
self,
efeat_dim: int = 512,
src_dim: int = 512,
dst_dim: int = 512,
output_dim: int = 512,
hidden_dim: int = 512,
hidden_layers: int = 2,
activation_fn: nn.Module = nn.SiLU(),
norm_type: str = "LayerNorm",
bias: bool = True,
recompute_activation: bool = False,
):
cat_dim = efeat_dim + src_dim + dst_dim
super(MeshGraphEdgeMLPConcat, self).__init__(
cat_dim,
output_dim,
hidden_dim,
hidden_layers,
activation_fn,
norm_type,
recompute_activation,
)
def forward(
self,
efeat: Tensor,
nfeat: Union[Tensor, Tuple[Tensor]],
graph: Union[DGLGraph, CuGraphCSC],
) -> Tensor:
efeat = concat_efeat(efeat, nfeat, graph)
efeat = self.model(efeat)
return efeat
class MeshGraphEdgeMLPSum(nn.Module):
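    """Edge MLP that avoids the explicit concatenation of edge, source-node,
    and destination-node features ("concat trick"): the first linear layer is
    split into three partial projections whose per-edge results are summed."""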
def __init__(
self,
efeat_dim: int,
src_dim: int,
dst_dim: int,
output_dim: int = 512,
hidden_dim: int = 512,
hidden_layers: int = 1,
activation_fn: nn.Module = nn.SiLU(),
norm_type: str = "LayerNorm",
bias: bool = True,
recompute_activation: bool = False,
):
super().__init__()
self.efeat_dim = efeat_dim
self.src_dim = src_dim
self.dst_dim = dst_dim
tmp_lin = nn.Linear(efeat_dim + src_dim + dst_dim, hidden_dim, bias=bias)
# orig_weight has shape (hidden_dim, efeat_dim + src_dim + dst_dim)
orig_weight = tmp_lin.weight
w_efeat, w_src, w_dst = torch.split(
orig_weight, [efeat_dim, src_dim, dst_dim], dim=1
)
self.lin_efeat = nn.Parameter(w_efeat)
self.lin_src = nn.Parameter(w_src)
self.lin_dst = nn.Parameter(w_dst)
if bias:
self.bias = tmp_lin.bias
else:
self.bias = None
layers = [activation_fn]
self.hidden_layers = hidden_layers
for _ in range(hidden_layers - 1):
layers += [nn.Linear(hidden_dim, hidden_dim), activation_fn]
layers.append(nn.Linear(hidden_dim, output_dim))
self.norm_type = norm_type
if norm_type is not None:
if norm_type not in [
"LayerNorm",
"TELayerNorm",
]:
raise ValueError(
f"Invalid norm type {norm_type}. Supported types are LayerNorm and TELayerNorm."
)
if norm_type == "TELayerNorm" and te_imported:
norm_layer = te.LayerNorm
elif norm_type == "TELayerNorm" and not te_imported:
raise ValueError(
"TELayerNorm requires transformer-engine to be installed."
)
else:
norm_layer = getattr(nn, norm_type)
layers.append(norm_layer(output_dim))
self.model = nn.Sequential(*layers)
if recompute_activation:
if not isinstance(activation_fn, nn.SiLU):
                raise ValueError(
                    "recompute_activation requires SiLU as the activation function, "
                    f"got {activation_fn}"
                )
self.recompute_activation = True
else:
self.recompute_activation = False
def forward_truncated_sum(
self,
efeat: Tensor,
nfeat: Union[Tensor, Tuple[Tensor]],
graph: Union[DGLGraph, CuGraphCSC],
) -> Tensor:
if isinstance(nfeat, Tensor):
src_feat, dst_feat = nfeat, nfeat
else:
src_feat, dst_feat = nfeat
mlp_efeat = F.linear(efeat, self.lin_efeat, None)
mlp_src = F.linear(src_feat, self.lin_src, None)
mlp_dst = F.linear(dst_feat, self.lin_dst, self.bias)
mlp_sum = sum_efeat(mlp_efeat, (mlp_src, mlp_dst), graph)
return mlp_sum
def default_forward(
self,
efeat: Tensor,
nfeat: Union[Tensor, Tuple[Tensor]],
graph: Union[DGLGraph, CuGraphCSC],
) -> Tensor:
"""Default forward pass of the truncated MLP."""
mlp_sum = self.forward_truncated_sum(
efeat,
nfeat,
graph,
)
return self.model(mlp_sum)
def custom_silu_linear_forward(
self,
efeat: Tensor,
nfeat: Union[Tensor, Tuple[Tensor]],
graph: Union[DGLGraph, CuGraphCSC],
) -> Tensor:
"""Forward pass of the truncated MLP with custom SiLU function."""
mlp_sum = self.forward_truncated_sum(
efeat,
nfeat,
graph,
)
lin = self.model[1]
hidden = CustomSiLuLinearAutogradFunction.apply(mlp_sum, lin.weight, lin.bias)
for i in range(2, self.hidden_layers + 1):
lin = self.model[2 * i - 1]
hidden = CustomSiLuLinearAutogradFunction.apply(
hidden, lin.weight, lin.bias
)
if self.norm_type is not None:
norm = self.model[2 * self.hidden_layers]
hidden = norm(hidden)
return hidden
def forward(
self,
efeat: Tensor,
nfeat: Union[Tensor, Tuple[Tensor]],
graph: Union[DGLGraph, CuGraphCSC],
) -> Tensor:
if self.recompute_activation:
return self.custom_silu_linear_forward(efeat, nfeat, graph)
return self.default_forward(efeat, nfeat, graph)
from typing import Tuple
import torch.nn as nn
from torch import Tensor
class GraphCastEncoderEmbedder(nn.Module):
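    """Embeds raw grid-node, mesh-node, mesh-edge, and grid2mesh-edge features
    into the latent dimension, with one MLP per feature type."""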
def __init__(
self,
input_dim_grid_nodes: int = 474,
input_dim_mesh_nodes: int = 3,
input_dim_edges: int = 4,
output_dim: int = 512,
hidden_dim: int = 512,
hidden_layers: int = 1,
activation_fn: nn.Module = nn.SiLU(),
norm_type: str = "LayerNorm",
recompute_activation: bool = False,
):
super().__init__()
# MLP for grid node embedding
self.grid_node_mlp = MeshGraphMLP(
input_dim=input_dim_grid_nodes,
output_dim=output_dim,
hidden_dim=hidden_dim,
hidden_layers=hidden_layers,
activation_fn=activation_fn,
norm_type=norm_type,
recompute_activation=recompute_activation,
)
# MLP for mesh node embedding
self.mesh_node_mlp = MeshGraphMLP(
input_dim=input_dim_mesh_nodes,
output_dim=output_dim,
hidden_dim=hidden_dim,
hidden_layers=hidden_layers,
activation_fn=activation_fn,
norm_type=norm_type,
recompute_activation=recompute_activation,
)
# MLP for mesh edge embedding
self.mesh_edge_mlp = MeshGraphMLP(
input_dim=input_dim_edges,
output_dim=output_dim,
hidden_dim=hidden_dim,
hidden_layers=hidden_layers,
activation_fn=activation_fn,
norm_type=norm_type,
recompute_activation=recompute_activation,
)
# MLP for grid2mesh edge embedding
self.grid2mesh_edge_mlp = MeshGraphMLP(
input_dim=input_dim_edges,
output_dim=output_dim,
hidden_dim=hidden_dim,
hidden_layers=hidden_layers,
activation_fn=activation_fn,
norm_type=norm_type,
recompute_activation=recompute_activation,
)
def forward(
self,
grid_nfeat: Tensor,
mesh_nfeat: Tensor,
g2m_efeat: Tensor,
mesh_efeat: Tensor,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
# Input node feature embedding
grid_nfeat = self.grid_node_mlp(grid_nfeat)
mesh_nfeat = self.mesh_node_mlp(mesh_nfeat)
# Input edge feature embedding
g2m_efeat = self.grid2mesh_edge_mlp(g2m_efeat)
mesh_efeat = self.mesh_edge_mlp(mesh_efeat)
return grid_nfeat, mesh_nfeat, g2m_efeat, mesh_efeat
class GraphCastDecoderEmbedder(nn.Module):
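    """Embeds raw mesh2grid-edge features into the latent dimension."""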
def __init__(
self,
input_dim_edges: int = 4,
output_dim: int = 512,
hidden_dim: int = 512,
hidden_layers: int = 1,
activation_fn: nn.Module = nn.SiLU(),
norm_type: str = "LayerNorm",
recompute_activation: bool = False,
):
super().__init__()
# MLP for mesh2grid edge embedding
self.mesh2grid_edge_mlp = MeshGraphMLP(
input_dim=input_dim_edges,
output_dim=output_dim,
hidden_dim=hidden_dim,
hidden_layers=hidden_layers,
activation_fn=activation_fn,
norm_type=norm_type,
recompute_activation=recompute_activation,
)
def forward(
self,
m2g_efeat: Tensor,
) -> Tensor:
m2g_efeat = self.mesh2grid_edge_mlp(m2g_efeat)
return m2g_efeat
from typing import Union
import torch
import torch.nn as nn
from dgl import DGLGraph
from torch import Tensor
class MeshGraphDecoder(nn.Module):
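    """Mesh2grid decoder: updates mesh2grid edge features, aggregates them onto
    the grid nodes, and applies a node MLP with a residual connection."""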
def __init__(
self,
aggregation: str = "sum",
input_dim_src_nodes: int = 512,
input_dim_dst_nodes: int = 512,
input_dim_edges: int = 512,
output_dim_dst_nodes: int = 512,
output_dim_edges: int = 512,
hidden_dim: int = 512,
hidden_layers: int = 1,
activation_fn: nn.Module = nn.SiLU(),
norm_type: str = "LayerNorm",
do_concat_trick: bool = False,
recompute_activation: bool = False,
):
super().__init__()
self.aggregation = aggregation
MLP = MeshGraphEdgeMLPSum if do_concat_trick else MeshGraphEdgeMLPConcat
# edge MLP
self.edge_mlp = MLP(
efeat_dim=input_dim_edges,
src_dim=input_dim_src_nodes,
dst_dim=input_dim_dst_nodes,
output_dim=output_dim_edges,
hidden_dim=hidden_dim,
hidden_layers=hidden_layers,
activation_fn=activation_fn,
norm_type=norm_type,
recompute_activation=recompute_activation,
)
# dst node MLP
self.node_mlp = MeshGraphMLP(
input_dim=input_dim_dst_nodes + output_dim_edges,
output_dim=output_dim_dst_nodes,
hidden_dim=hidden_dim,
hidden_layers=hidden_layers,
activation_fn=activation_fn,
norm_type=norm_type,
recompute_activation=recompute_activation,
)
@torch.jit.ignore()
def forward(
self,
m2g_efeat: Tensor,
grid_nfeat: Tensor,
mesh_nfeat: Tensor,
graph: Union[DGLGraph, CuGraphCSC],
) -> Tensor:
# update edge features
efeat = self.edge_mlp(m2g_efeat, (mesh_nfeat, grid_nfeat), graph)
# aggregate messages (edge features) to obtain updated node features
cat_feat = aggregate_and_concat(efeat, grid_nfeat, graph, self.aggregation)
# transformation and residual connection
dst_feat = self.node_mlp(cat_feat) + grid_nfeat
return dst_feat
from typing import Tuple, Union
import torch
import torch.nn as nn
from dgl import DGLGraph
from torch import Tensor
class MeshGraphEncoder(nn.Module):
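    """Grid2mesh encoder: updates grid2mesh edge features, aggregates them onto
    the mesh nodes, and updates both grid- and mesh-node features with residual
    connections."""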
def __init__(
self,
aggregation: str = "sum",
input_dim_src_nodes: int = 512,
input_dim_dst_nodes: int = 512,
input_dim_edges: int = 512,
output_dim_src_nodes: int = 512,
output_dim_dst_nodes: int = 512,
output_dim_edges: int = 512,
hidden_dim: int = 512,
hidden_layers: int = 1,
        activation_fn: nn.Module = nn.SiLU(),
norm_type: str = "LayerNorm",
do_concat_trick: bool = False,
recompute_activation: bool = False,
):
super().__init__()
self.aggregation = aggregation
MLP = MeshGraphEdgeMLPSum if do_concat_trick else MeshGraphEdgeMLPConcat
# edge MLP
self.edge_mlp = MLP(
efeat_dim=input_dim_edges,
src_dim=input_dim_src_nodes,
dst_dim=input_dim_dst_nodes,
output_dim=output_dim_edges,
hidden_dim=hidden_dim,
hidden_layers=hidden_layers,
activation_fn=activation_fn,
norm_type=norm_type,
recompute_activation=recompute_activation,
)
# src node MLP
self.src_node_mlp = MeshGraphMLP(
input_dim=input_dim_src_nodes,
output_dim=output_dim_src_nodes,
hidden_dim=hidden_dim,
hidden_layers=hidden_layers,
activation_fn=activation_fn,
norm_type=norm_type,
recompute_activation=recompute_activation,
)
# dst node MLP
self.dst_node_mlp = MeshGraphMLP(
input_dim=input_dim_dst_nodes + output_dim_edges,
output_dim=output_dim_dst_nodes,
hidden_dim=hidden_dim,
hidden_layers=hidden_layers,
activation_fn=activation_fn,
norm_type=norm_type,
recompute_activation=recompute_activation,
)
@torch.jit.ignore()
def forward(
self,
g2m_efeat: Tensor,
grid_nfeat: Tensor,
mesh_nfeat: Tensor,
graph: Union[DGLGraph, CuGraphCSC],
) -> Tuple[Tensor, Tensor]:
        # update edge features by concatenating (mesh and grid) node features with the
        # existing edge features (or by applying the concat trick instead)
efeat = self.edge_mlp(g2m_efeat, (grid_nfeat, mesh_nfeat), graph)
# aggregate messages (edge features) to obtain updated node features
cat_feat = aggregate_and_concat(efeat, mesh_nfeat, graph, self.aggregation)
# update src, dst node features + residual connections
mesh_nfeat = mesh_nfeat + self.dst_node_mlp(cat_feat)
grid_nfeat = grid_nfeat + self.src_node_mlp(grid_nfeat)
return grid_nfeat, mesh_nfeat
import torch
import torch.nn as nn
Tensor = torch.Tensor
class Identity(nn.Module):
def forward(self, x: Tensor) -> Tensor:
return x
class Stan(nn.Module):
def __init__(self, out_features: int = 1):
super().__init__()
self.beta = nn.Parameter(torch.ones(out_features))
def forward(self, x: Tensor) -> Tensor:
if x.shape[-1] != self.beta.shape[-1]:
raise ValueError(
f"The last dimension of the input must be equal to the dimension of Stan parameters. Got inputs: {x.shape}, params: {self.beta.shape}"
)
return torch.tanh(x) * (1.0 + self.beta * x)
class SquarePlus(nn.Module):
def __init__(self):
super().__init__()
self.b = 4
def forward(self, x: Tensor) -> Tensor:
return 0.5 * (x + torch.sqrt(x * x + self.b))
class CappedLeakyReLU(torch.nn.Module):
def __init__(self, cap_value=1.0, **kwargs):
super().__init__()
self.add_module("leaky_relu", torch.nn.LeakyReLU(**kwargs))
self.register_buffer("cap", torch.tensor(cap_value, dtype=torch.float32))
def forward(self, inputs):
x = self.leaky_relu(inputs)
x = torch.clamp(x, max=self.cap)
return x
class CappedGELU(torch.nn.Module):
def __init__(self, cap_value=1.0, **kwargs):
super().__init__()
self.add_module("gelu", torch.nn.GELU(**kwargs))
self.register_buffer("cap", torch.tensor(cap_value, dtype=torch.float32))
def forward(self, inputs):
x = self.gelu(inputs)
x = torch.clamp(x, max=self.cap)
return x
# Dictionary of activation functions
ACT2FN = {
"relu": nn.ReLU,
"leaky_relu": (nn.LeakyReLU, {"negative_slope": 0.1}),
"prelu": nn.PReLU,
"relu6": nn.ReLU6,
"elu": nn.ELU,
"selu": nn.SELU,
"silu": nn.SiLU,
"gelu": nn.GELU,
"sigmoid": nn.Sigmoid,
"logsigmoid": nn.LogSigmoid,
"softplus": nn.Softplus,
"softshrink": nn.Softshrink,
"softsign": nn.Softsign,
"tanh": nn.Tanh,
"tanhshrink": nn.Tanhshrink,
"threshold": (nn.Threshold, {"threshold": 1.0, "value": 1.0}),
"hardtanh": nn.Hardtanh,
"identity": Identity,
"stan": Stan,
"squareplus": SquarePlus,
"cappek_leaky_relu": CappedLeakyReLU,
"capped_gelu": CappedGELU,
}
def get_activation(activation: str) -> nn.Module:
try:
activation = activation.lower()
module = ACT2FN[activation]
if isinstance(module, tuple):
return module[0](**module[1])
else:
return module()
except KeyError:
raise KeyError(
f"Activation function {activation} not found. Available options are: {list(ACT2FN.keys())}"
)
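# Example (sketch): get_activation("gelu") returns an nn.GELU() instance; entries
# stored as (class, kwargs) tuples are instantiated with those kwargs, e.g.
# get_activation("leaky_relu") -> nn.LeakyReLU(negative_slope=0.1).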
from dataclasses import dataclass
@dataclass
class ModelMetaData:
"""Data class for storing essential meta data needed for all Modulus Models"""
# Model info
name: str = "ModulusModule"
# Optimization
jit: bool = False
cuda_graphs: bool = False
amp: bool = False
amp_cpu: bool = None
amp_gpu: bool = None
torch_fx: bool = False
# Data type
bf16: bool = False
# Inference
onnx: bool = False
onnx_gpu: bool = None
onnx_cpu: bool = None
onnx_runtime: bool = False
trt: bool = False
# Physics informed
var_dim: int = -1
func_torch: bool = False
auto_grad: bool = False
def __post_init__(self):
self.amp_cpu = self.amp if self.amp_cpu is None else self.amp_cpu
self.amp_gpu = self.amp if self.amp_gpu is None else self.amp_gpu
self.onnx_cpu = self.onnx if self.onnx_cpu is None else self.onnx_cpu
self.onnx_gpu = self.onnx if self.onnx_gpu is None else self.onnx_gpu
from importlib.metadata import EntryPoint, entry_points
from typing import List, Union
# This import is required for compatibility with doctests.
import importlib_metadata
class ModelRegistry:
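    """Shared-state (Borg-style) registry mapping model names to classes or
    entry points, initially populated from the ``modulus.models`` entry-point group."""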
_shared_state = {"_model_registry": None}
def __new__(cls, *args, **kwargs):
obj = super(ModelRegistry, cls).__new__(cls)
obj.__dict__ = cls._shared_state
if cls._shared_state["_model_registry"] is None:
cls._shared_state["_model_registry"] = cls._construct_registry()
return obj
@staticmethod
def _construct_registry() -> dict:
registry = {}
entrypoints = entry_points(group="modulus.models")
for entry_point in entrypoints:
registry[entry_point.name] = entry_point
return registry
    def register(self, model: "Module", name: Union[str, None] = None) -> None:
        # Check that the model is a Module subclass (Module is defined later in
        # this file and plays the role of modulus.Module here)
        if not issubclass(model, Module):
            raise ValueError(
                f"Only subclasses of Module can be registered. "
                f"Provided model is of type {type(model)}"
            )
# If no name provided, use the model's name
if name is None:
name = model.__name__
# Check if name already in use
if name in self._model_registry:
raise ValueError(f"Name {name} already in use")
# Add this class to the dict of model registry
self._model_registry[name] = model
    def factory(self, name: str) -> "Module":
model = self._model_registry.get(name)
if model is not None:
if isinstance(model, (EntryPoint, importlib_metadata.EntryPoint)):
model = model.load()
return model
raise KeyError(f"No model is registered under the name {name}")
def list_models(self) -> List[str]:
return list(self._model_registry.keys())
def __clear_registry__(self):
# NOTE: This is only used for testing purposes
self._model_registry = {}
def __restore_registry__(self):
# NOTE: This is only used for testing purposes
self._model_registry = self._construct_registry()
import importlib
import inspect
import json
import logging
import os
import tarfile
import tempfile
from pathlib import Path
from typing import Any, Dict, Union
import torch
class Module(torch.nn.Module):
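    """Base class mirroring ``modulus.Module``: records constructor arguments in
    ``__new__`` so instances can be re-created from ``.mdlus`` checkpoints (a tar
    archive containing ``args.json``, ``metadata.json``, and ``model.pt``).

    NOTE: ``save``, ``load``, and ``from_checkpoint`` rely on helpers from the
    Modulus package (``modulus.__version__``, ``_get_fs``, ``_download_cached``)
    that are not defined in this standalone file.
    """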
_file_extension = ".mdlus" # Set file extension for saving and loading
__model_checkpoint_version__ = (
"0.1.0" # Used for file versioning and is not the same as modulus version
)
def __new__(cls, *args, **kwargs):
out = super().__new__(cls)
# Get signature of __init__ function
sig = inspect.signature(cls.__init__)
# Bind args and kwargs to signature
bound_args = sig.bind_partial(
*([None] + list(args)), **kwargs
) # Add None to account for self
bound_args.apply_defaults()
# Get args and kwargs (excluding self and unroll kwargs)
instantiate_args = {}
for param, (k, v) in zip(sig.parameters.values(), bound_args.arguments.items()):
# Skip self
if k == "self":
continue
# Add args and kwargs to instantiate_args
if param.kind == param.VAR_KEYWORD:
instantiate_args.update(v)
else:
instantiate_args[k] = v
# Store args needed for instantiation
out._args = {
"__name__": cls.__name__,
"__module__": cls.__module__,
"__args__": instantiate_args,
}
return out
def __init__(self, meta: Union[ModelMetaData, None] = None):
super().__init__()
self.meta = meta
self.register_buffer("device_buffer", torch.empty(0))
self._setup_logger()
def _setup_logger(self):
self.logger = logging.getLogger("core.module")
handler = logging.StreamHandler()
formatter = logging.Formatter(
"[%(asctime)s - %(levelname)s] %(message)s", datefmt="%H:%M:%S"
)
handler.setFormatter(formatter)
self.logger.addHandler(handler)
self.logger.setLevel(logging.WARNING)
@staticmethod
def _safe_members(tar, local_path):
for member in tar.getmembers():
            # only yield members that resolve inside local_path and do not use
            # absolute paths or parent-directory references
            if (
                ".." not in member.name
                and not os.path.isabs(member.name)
                and os.path.realpath(os.path.join(local_path, member.name)).startswith(
                    os.path.realpath(local_path)
                )
            ):
yield member
else:
print(f"Skipping potentially malicious file: {member.name}")
@classmethod
def instantiate(cls, arg_dict: Dict[str, Any]) -> "Module":
_cls_name = arg_dict["__name__"]
registry = ModelRegistry()
if cls.__name__ == arg_dict["__name__"]: # If cls is the class
_cls = cls
elif _cls_name in registry.list_models(): # Built in registry
_cls = registry.factory(_cls_name)
else:
try:
# Otherwise, try to import the class
_mod = importlib.import_module(arg_dict["__module__"])
_cls = getattr(_mod, arg_dict["__name__"])
except AttributeError:
# Cross fingers and hope for the best (maybe the class name changed)
_cls = cls
return _cls(**arg_dict["__args__"])
def debug(self):
"""Turn on debug logging"""
self.logger.handlers.clear()
handler = logging.StreamHandler()
formatter = logging.Formatter(
f"[%(asctime)s - %(levelname)s - {self.meta.name}] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
handler.setFormatter(formatter)
self.logger.addHandler(handler)
self.logger.setLevel(logging.DEBUG)
# TODO: set up debug log
# fh = logging.FileHandler(f'modulus-core-{self.meta.name}.log')
def save(self, file_name: Union[str, None] = None, verbose: bool = False) -> None:
if file_name is not None and not file_name.endswith(self._file_extension):
raise ValueError(
f"File name must end with {self._file_extension} extension"
)
with tempfile.TemporaryDirectory() as temp_dir:
local_path = Path(temp_dir)
torch.save(self.state_dict(), local_path / "model.pt")
with open(local_path / "args.json", "w") as f:
json.dump(self._args, f)
# Save the modulus version and git hash (if available)
metadata_info = {
"modulus_version": modulus.__version__,
"mdlus_file_version": self.__model_checkpoint_version__,
}
if verbose:
import git
try:
repo = git.Repo(search_parent_directories=True)
metadata_info["git_hash"] = repo.head.object.hexsha
except git.InvalidGitRepositoryError:
metadata_info["git_hash"] = None
with open(local_path / "metadata.json", "w") as f:
json.dump(metadata_info, f)
# Once all files are saved, package them into a tar file
with tarfile.open(local_path / "model.tar", "w") as tar:
for file in local_path.iterdir():
tar.add(str(file), arcname=file.name)
if file_name is None:
file_name = self.meta.name + ".mdlus"
# Save files to remote destination
fs = _get_fs(file_name)
fs.put(str(local_path / "model.tar"), file_name)
@staticmethod
    def _check_checkpoint(local_path: Path) -> None:
if not local_path.joinpath("args.json").exists():
raise IOError("File 'args.json' not found in checkpoint")
if not local_path.joinpath("metadata.json").exists():
raise IOError("File 'metadata.json' not found in checkpoint")
if not local_path.joinpath("model.pt").exists():
raise IOError("Model weights 'model.pt' not found in checkpoint")
# Check if the checkpoint version is compatible with the current version
with open(local_path.joinpath("metadata.json"), "r") as f:
metadata_info = json.load(f)
if (
metadata_info["mdlus_file_version"]
!= Module.__model_checkpoint_version__
):
raise IOError(
f"Model checkpoint version {metadata_info['mdlus_file_version']} is not compatible with current version {Module.__version__}"
)
def load(
self,
file_name: str,
map_location: Union[None, str, torch.device] = None,
strict: bool = True,
) -> None:
# Download and cache the checkpoint file if needed
cached_file_name = _download_cached(file_name)
# Use a temporary directory to extract the tar file
with tempfile.TemporaryDirectory() as temp_dir:
local_path = Path(temp_dir)
# Open the tar file and extract its contents to the temporary directory
with tarfile.open(cached_file_name, "r") as tar:
tar.extractall(
path=local_path, members=list(Module._safe_members(tar, local_path))
)
# Check if the checkpoint is valid
Module._check_checkpoint(local_path)
# Load the model weights
device = map_location if map_location is not None else self.device
model_dict = torch.load(
local_path.joinpath("model.pt"), map_location=device
)
self.load_state_dict(model_dict, strict=strict)
@classmethod
def from_checkpoint(cls, file_name: str) -> "Module":
# Download and cache the checkpoint file if needed
cached_file_name = _download_cached(file_name)
# Use a temporary directory to extract the tar file
with tempfile.TemporaryDirectory() as temp_dir:
local_path = Path(temp_dir)
# Open the tar file and extract its contents to the temporary directory
with tarfile.open(cached_file_name, "r") as tar:
tar.extractall(
path=local_path, members=list(cls._safe_members(tar, local_path))
)
# Check if the checkpoint is valid
Module._check_checkpoint(local_path)
# Load model arguments and instantiate the model
with open(local_path.joinpath("args.json"), "r") as f:
args = json.load(f)
model = cls.instantiate(args)
# Load the model weights
model_dict = torch.load(
local_path.joinpath("model.pt"), map_location=model.device
)
model.load_state_dict(model_dict)
return model
@staticmethod
def from_torch(
torch_model_class: torch.nn.Module, meta: ModelMetaData = None
) -> "Module":
# Define an internal class as before
class ModulusModel(Module):
def __init__(self, *args, **kwargs):
super().__init__(meta=meta)
self.inner_model = torch_model_class(*args, **kwargs)
def forward(self, x):
return self.inner_model(x)
# Get the argument names and default values of the PyTorch model's init method
init_argspec = inspect.getfullargspec(torch_model_class.__init__)
model_argnames = init_argspec.args[1:] # Exclude 'self'
model_defaults = init_argspec.defaults or []
defaults_dict = dict(
zip(model_argnames[-len(model_defaults) :], model_defaults)
)
# Define the signature of new init
params = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)]
params += [
inspect.Parameter(
argname,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=defaults_dict.get(argname, inspect.Parameter.empty),
)
for argname in model_argnames
]
init_signature = inspect.Signature(params)
# Replace ModulusModel.__init__ signature with new init signature
ModulusModel.__init__.__signature__ = init_signature
# Generate a unique name for the created class
new_class_name = f"{torch_model_class.__name__}ModulusModel"
ModulusModel.__name__ = new_class_name
# Add this class to the dict of models classes
registry = ModelRegistry()
registry.register(ModulusModel, new_class_name)
return ModulusModel
@property
def device(self) -> torch.device:
return self.device_buffer.device
def num_parameters(self) -> int:
"""Gets the number of learnable parameters"""
count = 0
for name, param in self.named_parameters():
count += param.numel()
return count
from typing import List, Tuple
import dgl
import numpy as np
import torch
from dgl import DGLGraph
from torch import Tensor, testing
def create_graph(
src: List,
dst: List,
to_bidirected: bool = True,
add_self_loop: bool = False,
dtype: torch.dtype = torch.int32,
) -> DGLGraph:
graph = dgl.graph((src, dst), idtype=dtype)
if to_bidirected:
graph = dgl.to_bidirected(graph)
if add_self_loop:
graph = dgl.add_self_loop(graph)
return graph
def create_heterograph(
src: List,
dst: List,
labels: str,
dtype: torch.dtype = torch.int32,
num_nodes_dict: dict = None,
) -> DGLGraph:
graph = dgl.heterograph(
{labels: ("coo", (src, dst))}, num_nodes_dict=num_nodes_dict, idtype=dtype
)
return graph
def add_edge_features(graph: DGLGraph, pos: Tensor, normalize: bool = True) -> DGLGraph:
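    """Adds the GraphCast edge features: the 3D displacement between the edge's
    endpoints, expressed in a local coordinate system rotated onto the
    destination node, plus its norm (optionally normalized by the longest edge)."""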
if isinstance(pos, tuple):
src_pos, dst_pos = pos
else:
src_pos = dst_pos = pos
src, dst = graph.edges()
src_pos, dst_pos = src_pos[src.long()], dst_pos[dst.long()]
dst_latlon = xyz2latlon(dst_pos, unit="rad")
dst_lat, dst_lon = dst_latlon[:, 0], dst_latlon[:, 1]
# azimuthal & polar rotation
theta_azimuthal = azimuthal_angle(dst_lon)
theta_polar = polar_angle(dst_lat)
src_pos = geospatial_rotation(src_pos, theta=theta_azimuthal, axis="z", unit="rad")
dst_pos = geospatial_rotation(dst_pos, theta=theta_azimuthal, axis="z", unit="rad")
# y values should be zero
    try:
        testing.assert_close(dst_pos[:, 1], torch.zeros_like(dst_pos[:, 1]))
    except AssertionError:
        raise ValueError("Invalid projection of edge nodes to local coordinate system")
src_pos = geospatial_rotation(src_pos, theta=theta_polar, axis="y", unit="rad")
dst_pos = geospatial_rotation(dst_pos, theta=theta_polar, axis="y", unit="rad")
# x values should be one, y & z values should be zero
    try:
        testing.assert_close(dst_pos[:, 0], torch.ones_like(dst_pos[:, 0]))
        testing.assert_close(dst_pos[:, 1], torch.zeros_like(dst_pos[:, 1]))
        testing.assert_close(dst_pos[:, 2], torch.zeros_like(dst_pos[:, 2]))
    except AssertionError:
        raise ValueError("Invalid projection of edge nodes to local coordinate system")
# prepare edge features
disp = src_pos - dst_pos
disp_norm = torch.linalg.norm(disp, dim=-1, keepdim=True)
# normalize using the longest edge
if normalize:
max_disp_norm = torch.max(disp_norm)
graph.edata["x"] = torch.cat(
(disp / max_disp_norm, disp_norm / max_disp_norm), dim=-1
)
else:
graph.edata["x"] = torch.cat((disp, disp_norm), dim=-1)
return graph
def add_node_features(graph: DGLGraph, pos: Tensor) -> DGLGraph:
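    """Adds the GraphCast node features cos(lat), sin(lon), cos(lon)."""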
latlon = xyz2latlon(pos)
lat, lon = latlon[:, 0], latlon[:, 1]
graph.ndata["x"] = torch.stack(
(torch.cos(lat), torch.sin(lon), torch.cos(lon)), dim=-1
)
return graph
def latlon2xyz(latlon: Tensor, radius: float = 1, unit: str = "deg") -> Tensor:
if unit == "deg":
latlon = deg2rad(latlon)
elif unit == "rad":
pass
else:
raise ValueError("Not a valid unit")
lat, lon = latlon[:, 0], latlon[:, 1]
x = radius * torch.cos(lat) * torch.cos(lon)
y = radius * torch.cos(lat) * torch.sin(lon)
z = radius * torch.sin(lat)
return torch.stack((x, y, z), dim=1)
def xyz2latlon(xyz: Tensor, radius: float = 1, unit: str = "deg") -> Tensor:
lat = torch.arcsin(xyz[:, 2] / radius)
lon = torch.arctan2(xyz[:, 1], xyz[:, 0])
if unit == "deg":
return torch.stack((rad2deg(lat), rad2deg(lon)), dim=1)
elif unit == "rad":
return torch.stack((lat, lon), dim=1)
else:
raise ValueError("Not a valid unit")
def geospatial_rotation(
invar: Tensor, theta: Tensor, axis: str, unit: str = "rad"
) -> Tensor:
    # make sure the rotation angle is in radians
    if unit == "deg":
        theta = deg2rad(theta)
elif unit == "rad":
pass
else:
raise ValueError("Not a valid unit")
invar = torch.unsqueeze(invar, -1)
rotation = torch.zeros((theta.size(0), 3, 3))
cos = torch.cos(theta)
sin = torch.sin(theta)
if axis == "x":
rotation[:, 0, 0] += 1.0
rotation[:, 1, 1] += cos
rotation[:, 1, 2] -= sin
rotation[:, 2, 1] += sin
rotation[:, 2, 2] += cos
elif axis == "y":
rotation[:, 0, 0] += cos
rotation[:, 0, 2] += sin
rotation[:, 1, 1] += 1.0
rotation[:, 2, 0] -= sin
rotation[:, 2, 2] += cos
elif axis == "z":
rotation[:, 0, 0] += cos
rotation[:, 0, 1] -= sin
rotation[:, 1, 0] += sin
rotation[:, 1, 1] += cos
rotation[:, 2, 2] += 1.0
else:
raise ValueError("Invalid axis")
outvar = torch.matmul(rotation, invar)
outvar = outvar.squeeze()
return outvar
def azimuthal_angle(lon: Tensor) -> Tensor:
angle = torch.where(lon >= 0.0, 2 * np.pi - lon, -lon)
return angle
def polar_angle(lat: Tensor) -> Tensor:
angle = torch.where(lat >= 0.0, lat, 2 * np.pi + lat)
return angle
def deg2rad(deg: Tensor) -> Tensor:
return deg * np.pi / 180
def rad2deg(rad):
return rad * 180 / np.pi
def cell_to_adj(cells: List[List[int]]):
num_cells = np.shape(cells)[0]
src = [cells[i][indx] for i in range(num_cells) for indx in [0, 1, 2]]
dst = [cells[i][indx] for i in range(num_cells) for indx in [1, 2, 0]]
return src, dst
def max_edge_length(
vertices: List[List[float]], source_nodes: List[int], destination_nodes: List[int]
) -> float:
vertices_np = np.array(vertices)
source_coords = vertices_np[source_nodes]
dest_coords = vertices_np[destination_nodes]
# Compute the squared distances for all edges
squared_differences = np.sum((source_coords - dest_coords) ** 2, axis=1)
# Compute the maximum edge length
max_length = np.sqrt(np.max(squared_differences))
return max_length
def get_face_centroids(
vertices: List[Tuple[float, float, float]], faces: List[List[int]]
) -> List[Tuple[float, float, float]]:
centroids = []
for face in faces:
# Extract the coordinates of the vertices for the current face
v0 = vertices[face[0]]
v1 = vertices[face[1]]
v2 = vertices[face[2]]
# Compute the centroid of the triangle
centroid = (
(v0[0] + v1[0] + v2[0]) / 3,
(v0[1] + v1[1] + v2[1]) / 3,
(v0[2] + v1[2] + v2[2]) / 3,
)
centroids.append(centroid)
return centroids
import itertools
from typing import List, NamedTuple, Sequence, Tuple
import numpy as np
from scipy.spatial import transform
class TriangularMesh(NamedTuple):
vertices: np.ndarray
faces: np.ndarray
def merge_meshes(mesh_list: Sequence[TriangularMesh]) -> TriangularMesh:
for mesh_i, mesh_ip1 in itertools.pairwise(mesh_list):
num_nodes_mesh_i = mesh_i.vertices.shape[0]
assert np.allclose(mesh_i.vertices, mesh_ip1.vertices[:num_nodes_mesh_i])
return TriangularMesh(
vertices=mesh_list[-1].vertices,
faces=np.concatenate([mesh.faces for mesh in mesh_list], axis=0),
)
def get_hierarchy_of_triangular_meshes_for_sphere(splits: int) -> List[TriangularMesh]:
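    """Returns the icosahedron and its successively refined spherical meshes,
    one entry per refinement level (``splits`` + 1 meshes in total)."""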
current_mesh = get_icosahedron()
output_meshes = [current_mesh]
for _ in range(splits):
current_mesh = _two_split_unit_sphere_triangle_faces(current_mesh)
output_meshes.append(current_mesh)
return output_meshes
def get_icosahedron() -> TriangularMesh:
phi = (1 + np.sqrt(5)) / 2
vertices = []
for c1 in [1.0, -1.0]:
for c2 in [phi, -phi]:
vertices.append((c1, c2, 0.0))
vertices.append((0.0, c1, c2))
vertices.append((c2, 0.0, c1))
vertices = np.array(vertices, dtype=np.float32)
vertices /= np.linalg.norm([1.0, phi])
# I did this manually, checking the orientation one by one.
faces = [
(0, 1, 2),
(0, 6, 1),
(8, 0, 2),
(8, 4, 0),
(3, 8, 2),
(3, 2, 7),
(7, 2, 1),
(0, 4, 6),
(4, 11, 6),
(6, 11, 5),
(1, 5, 7),
(4, 10, 11),
(4, 8, 10),
(10, 8, 3),
(10, 3, 9),
(11, 10, 9),
(11, 9, 5),
(5, 9, 7),
(9, 3, 7),
(1, 6, 5),
]
angle_between_faces = 2 * np.arcsin(phi / np.sqrt(3))
rotation_angle = (np.pi - angle_between_faces) / 2
rotation = transform.Rotation.from_euler(seq="y", angles=rotation_angle)
rotation_matrix = rotation.as_matrix()
vertices = np.dot(vertices, rotation_matrix)
return TriangularMesh(
vertices=vertices.astype(np.float32), faces=np.array(faces, dtype=np.int32)
)
def _two_split_unit_sphere_triangle_faces(
triangular_mesh: TriangularMesh,
) -> TriangularMesh:
"""Splits each triangular face into 4 triangles keeping the orientation."""
new_vertices_builder = _ChildVerticesBuilder(triangular_mesh.vertices)
new_faces = []
for ind1, ind2, ind3 in triangular_mesh.faces:
ind12 = new_vertices_builder.get_new_child_vertex_index((ind1, ind2))
ind23 = new_vertices_builder.get_new_child_vertex_index((ind2, ind3))
ind31 = new_vertices_builder.get_new_child_vertex_index((ind3, ind1))
new_faces.extend(
[
[ind1, ind12, ind31], # 1
[ind12, ind2, ind23], # 2
[ind31, ind23, ind3], # 3
[ind12, ind23, ind31], # 4
]
)
return TriangularMesh(
vertices=new_vertices_builder.get_all_vertices(),
faces=np.array(new_faces, dtype=np.int32),
)
class _ChildVerticesBuilder(object):
"""Bookkeeping of new child vertices added to an existing set of vertices."""
def __init__(self, parent_vertices):
self._child_vertices_index_mapping = {}
self._parent_vertices = parent_vertices
# We start with all previous vertices.
self._all_vertices_list = list(parent_vertices)
def _get_child_vertex_key(self, parent_vertex_indices):
return tuple(sorted(parent_vertex_indices))
def _create_child_vertex(self, parent_vertex_indices):
"""Creates a new vertex."""
child_vertex_position = self._parent_vertices[list(parent_vertex_indices)].mean(
0
)
child_vertex_position /= np.linalg.norm(child_vertex_position)
child_vertex_key = self._get_child_vertex_key(parent_vertex_indices)
self._child_vertices_index_mapping[child_vertex_key] = len(
self._all_vertices_list
)
self._all_vertices_list.append(child_vertex_position)
def get_new_child_vertex_index(self, parent_vertex_indices):
"""Returns index for a child vertex, creating it if necessary."""
# Get the key to see if we already have a new vertex in the middle.
child_vertex_key = self._get_child_vertex_key(parent_vertex_indices)
if child_vertex_key not in self._child_vertices_index_mapping:
self._create_child_vertex(parent_vertex_indices)
return self._child_vertices_index_mapping[child_vertex_key]
def get_all_vertices(self):
"""Returns an array with old vertices."""
return np.array(self._all_vertices_list)
def faces_to_edges(faces: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
assert faces.ndim == 2
assert faces.shape[-1] == 3
senders = np.concatenate([faces[:, 0], faces[:, 1], faces[:, 2]])
receivers = np.concatenate([faces[:, 1], faces[:, 2], faces[:, 0]])
return senders, receivers
import logging
import numpy as np
import torch
from sklearn.neighbors import NearestNeighbors
from torch import Tensor
logger = logging.getLogger(__name__)
class Graph:
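    """Builds the GraphCast graphs from a lat/lon grid: the (multi-)mesh graph,
    the grid2mesh bipartite graph, and the mesh2grid bipartite graph."""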
def __init__(
self,
lat_lon_grid: Tensor,
mesh_level: int = 6,
multimesh: bool = True,
khop_neighbors: int = 0,
dtype=torch.float,
) -> None:
self.khop_neighbors = khop_neighbors
self.dtype = dtype
        # flatten lat/lon grid
self.lat_lon_grid_flat = lat_lon_grid.permute(2, 0, 1).view(2, -1).permute(1, 0)
# create the multi-mesh
_meshes = get_hierarchy_of_triangular_meshes_for_sphere(splits=mesh_level)
finest_mesh = _meshes[-1] # get the last one in the list of meshes
self.finest_mesh_src, self.finest_mesh_dst = faces_to_edges(finest_mesh.faces)
self.finest_mesh_vertices = np.array(finest_mesh.vertices)
if multimesh:
mesh = merge_meshes(_meshes)
self.mesh_src, self.mesh_dst = faces_to_edges(mesh.faces)
self.mesh_vertices = np.array(mesh.vertices)
else:
mesh = finest_mesh
self.mesh_src, self.mesh_dst = self.finest_mesh_src, self.finest_mesh_dst
self.mesh_vertices = self.finest_mesh_vertices
self.mesh_faces = mesh.faces
@staticmethod
def khop_adj_all_k(g, kmax):
if not g.is_homogeneous:
raise NotImplementedError("only homogeneous graph is supported")
min_degree = g.in_degrees().min()
with torch.no_grad():
adj = g.adj_external(transpose=True, scipy_fmt=None)
adj_k = adj
adj_all = adj.clone()
for _ in range(2, kmax + 1):
# scale with min-degree to avoid too large values
# but >= 1.0
adj_k = (adj @ adj_k) / min_degree
adj_all += adj_k
return adj_all.to_dense().bool()
def create_mesh_graph(self, verbose: bool = True) -> Tensor:
mesh_graph = create_graph(
self.mesh_src,
self.mesh_dst,
to_bidirected=True,
add_self_loop=False,
dtype=torch.int32,
)
mesh_pos = torch.tensor(
self.mesh_vertices,
dtype=torch.float32,
)
mesh_graph = add_edge_features(mesh_graph, mesh_pos)
mesh_graph = add_node_features(mesh_graph, mesh_pos)
mesh_graph.ndata["lat_lon"] = xyz2latlon(mesh_pos)
# ensure fields set to dtype to avoid later conversions
mesh_graph.ndata["x"] = mesh_graph.ndata["x"].to(dtype=self.dtype)
mesh_graph.edata["x"] = mesh_graph.edata["x"].to(dtype=self.dtype)
if self.khop_neighbors > 0:
# Make a graph whose edges connect the k-hop neighbors of the original graph.
khop_adj_bool = self.khop_adj_all_k(g=mesh_graph, kmax=self.khop_neighbors)
mask = ~khop_adj_bool
else:
mask = None
if verbose:
print("mesh graph:", mesh_graph)
return mesh_graph, mask
def create_g2m_graph(self, verbose: bool = True) -> Tensor:
# get the max edge length of icosphere with max order
max_edge_len = max_edge_length(
self.finest_mesh_vertices, self.finest_mesh_src, self.finest_mesh_dst
)
# create the grid2mesh bipartite graph
cartesian_grid = latlon2xyz(self.lat_lon_grid_flat)
n_nbrs = 4
neighbors = NearestNeighbors(n_neighbors=n_nbrs).fit(self.mesh_vertices)
distances, indices = neighbors.kneighbors(cartesian_grid)
src, dst = [], []
for i in range(len(cartesian_grid)):
for j in range(n_nbrs):
if distances[i][j] <= 0.6 * max_edge_len:
src.append(i)
dst.append(indices[i][j])
# NOTE this gives 1,618,820 edges, in the paper it is 1,618,746
g2m_graph = create_heterograph(
src, dst, ("grid", "g2m", "mesh"), dtype=torch.int32
)
g2m_graph.srcdata["pos"] = cartesian_grid.to(torch.float32)
g2m_graph.dstdata["pos"] = torch.tensor(
self.mesh_vertices,
dtype=torch.float32,
)
g2m_graph.srcdata["lat_lon"] = self.lat_lon_grid_flat
g2m_graph.dstdata["lat_lon"] = xyz2latlon(g2m_graph.dstdata["pos"])
g2m_graph = add_edge_features(
g2m_graph, (g2m_graph.srcdata["pos"], g2m_graph.dstdata["pos"])
)
# avoid potential conversions at later points
g2m_graph.srcdata["pos"] = g2m_graph.srcdata["pos"].to(dtype=self.dtype)
g2m_graph.dstdata["pos"] = g2m_graph.dstdata["pos"].to(dtype=self.dtype)
g2m_graph.ndata["pos"]["grid"] = g2m_graph.ndata["pos"]["grid"].to(
dtype=self.dtype
)
g2m_graph.ndata["pos"]["mesh"] = g2m_graph.ndata["pos"]["mesh"].to(
dtype=self.dtype
)
g2m_graph.edata["x"] = g2m_graph.edata["x"].to(dtype=self.dtype)
if verbose:
print("g2m graph:", g2m_graph)
return g2m_graph
def create_m2g_graph(self, verbose: bool = True) -> Tensor:
# create the mesh2grid bipartite graph
cartesian_grid = latlon2xyz(self.lat_lon_grid_flat)
face_centroids = get_face_centroids(self.mesh_vertices, self.mesh_faces)
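        # Mesh2grid connectivity: each grid point receives edges from the three vertices
        # of the mesh face whose centroid is closest to it, i.e. 3 edges per grid point.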
n_nbrs = 1
neighbors = NearestNeighbors(n_neighbors=n_nbrs).fit(face_centroids)
_, indices = neighbors.kneighbors(cartesian_grid)
indices = indices.flatten()
src = [p for i in indices for p in self.mesh_faces[i]]
dst = [i for i in range(len(cartesian_grid)) for _ in range(3)]
m2g_graph = create_heterograph(
src, dst, ("mesh", "m2g", "grid"), dtype=torch.int32
) # number of edges is 3,114,720, exactly matches with the paper
m2g_graph.srcdata["pos"] = torch.tensor(
self.mesh_vertices,
dtype=torch.float32,
)
m2g_graph.dstdata["pos"] = cartesian_grid.to(dtype=torch.float32)
m2g_graph.srcdata["lat_lon"] = xyz2latlon(m2g_graph.srcdata["pos"])
m2g_graph.dstdata["lat_lon"] = self.lat_lon_grid_flat
m2g_graph = add_edge_features(
m2g_graph, (m2g_graph.srcdata["pos"], m2g_graph.dstdata["pos"])
)
# avoid potential conversions at later points
m2g_graph.srcdata["pos"] = m2g_graph.srcdata["pos"].to(dtype=self.dtype)
m2g_graph.dstdata["pos"] = m2g_graph.dstdata["pos"].to(dtype=self.dtype)
m2g_graph.ndata["pos"]["grid"] = m2g_graph.ndata["pos"]["grid"].to(
dtype=self.dtype
)
m2g_graph.ndata["pos"]["mesh"] = m2g_graph.ndata["pos"]["mesh"].to(
dtype=self.dtype
)
m2g_graph.edata["x"] = m2g_graph.edata["x"].to(dtype=self.dtype)
if verbose:
print("m2g graph:", m2g_graph)
return m2g_graph
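# Minimal usage sketch for Graph (mirrors how GraphCast.__init__ uses it below;
# `lats`/`lons` are placeholder 1D tensors of latitudes/longitudes in degrees):
#   grid = torch.stack(torch.meshgrid(lats, lons, indexing="ij"), dim=-1)  # (H, W, 2)
#   graph = Graph(grid, mesh_level=6, multimesh=True)
#   mesh_graph, attn_mask = graph.create_mesh_graph(verbose=False)
#   g2m_graph = graph.create_g2m_graph(verbose=False)
#   m2g_graph = graph.create_m2g_graph(verbose=False)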
from typing import Union
import torch
import torch.nn as nn
from dgl import DGLGraph
from torch import Tensor
class MeshEdgeBlock(nn.Module):
def __init__(
self,
input_dim_nodes: int = 512,
input_dim_edges: int = 512,
output_dim: int = 512,
hidden_dim: int = 512,
hidden_layers: int = 1,
activation_fn: nn.Module = nn.SiLU(),
norm_type: str = "LayerNorm",
do_concat_trick: bool = False,
recompute_activation: bool = False,
):
super().__init__()
MLP = MeshGraphEdgeMLPSum if do_concat_trick else MeshGraphEdgeMLPConcat
self.edge_mlp = MLP(
efeat_dim=input_dim_edges,
src_dim=input_dim_nodes,
dst_dim=input_dim_nodes,
output_dim=output_dim,
hidden_dim=hidden_dim,
hidden_layers=hidden_layers,
activation_fn=activation_fn,
norm_type=norm_type,
recompute_activation=recompute_activation,
)
@torch.jit.ignore()
def forward(
self,
efeat: Tensor,
nfeat: Tensor,
graph: Union[DGLGraph, CuGraphCSC],
) -> Tensor:
efeat_new = self.edge_mlp(efeat, nfeat, graph)
efeat_new = efeat_new + efeat
return efeat_new, nfeat
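# Note: MeshEdgeBlock returns (updated edge features, unchanged node features) so that
# edge and node blocks share a common (efeat, nfeat, graph) -> (efeat, nfeat) interface
# and can be chained alternately inside GraphCastProcessor.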
from typing import Tuple, Union
import torch
import torch.nn as nn
from dgl import DGLGraph
from torch import Tensor
class MeshNodeBlock(nn.Module):
def __init__(
self,
aggregation: str = "sum",
input_dim_nodes: int = 512,
input_dim_edges: int = 512,
output_dim: int = 512,
hidden_dim: int = 512,
hidden_layers: int = 1,
activation_fn: nn.Module = nn.SiLU(),
norm_type: str = "LayerNorm",
recompute_activation: bool = False,
):
super().__init__()
self.aggregation = aggregation
self.node_mlp = MeshGraphMLP(
input_dim=input_dim_nodes + input_dim_edges,
output_dim=output_dim,
hidden_dim=hidden_dim,
hidden_layers=hidden_layers,
activation_fn=activation_fn,
norm_type=norm_type,
recompute_activation=recompute_activation,
)
@torch.jit.ignore()
def forward(
self,
efeat: Tensor,
nfeat: Tensor,
graph: Union[DGLGraph, CuGraphCSC],
) -> Tuple[Tensor, Tensor]:
# update edge features
cat_feat = aggregate_and_concat(efeat, nfeat, graph, self.aggregation)
# update node features + residual connection
nfeat_new = self.node_mlp(cat_feat) + nfeat
return efeat, nfeat_new
from typing import Union
import torch
import torch.nn as nn
try:
    import transformer_engine as te
except ImportError:
    # transformer_engine is optional; it is only needed by the GraphTransformer
    # processor (GraphCastProcessorGraphTransformer) below.
    te = None
from dgl import DGLGraph
from torch import Tensor
class GraphCastProcessor(nn.Module):
def __init__(
self,
aggregation: str = "sum",
processor_layers: int = 16,
input_dim_nodes: int = 512,
input_dim_edges: int = 512,
hidden_dim: int = 512,
hidden_layers: int = 1,
activation_fn: nn.Module = nn.SiLU(),
norm_type: str = "LayerNorm",
do_concat_trick: bool = False,
recompute_activation: bool = False,
):
super().__init__()
edge_block_invars = (
input_dim_nodes,
input_dim_edges,
input_dim_edges,
hidden_dim,
hidden_layers,
activation_fn,
norm_type,
do_concat_trick,
recompute_activation,
)
node_block_invars = (
aggregation,
input_dim_nodes,
input_dim_edges,
input_dim_nodes,
hidden_dim,
hidden_layers,
activation_fn,
norm_type,
recompute_activation,
)
layers = []
for _ in range(processor_layers):
layers.append(MeshEdgeBlock(*edge_block_invars))
layers.append(MeshNodeBlock(*node_block_invars))
self.processor_layers = nn.ModuleList(layers)
self.num_processor_layers = len(self.processor_layers)
        # by default, no checkpointing; keep a single segment for compatibility
self.checkpoint_segments = [(0, self.num_processor_layers)]
self.checkpoint_fn = set_checkpoint_fn(False)
def set_checkpoint_segments(self, checkpoint_segments: int):
if checkpoint_segments > 0:
if self.num_processor_layers % checkpoint_segments != 0:
raise ValueError(
"Processor layers must be a multiple of checkpoint_segments"
)
segment_size = self.num_processor_layers // checkpoint_segments
self.checkpoint_segments = []
for i in range(0, self.num_processor_layers, segment_size):
self.checkpoint_segments.append((i, i + segment_size))
self.checkpoint_fn = set_checkpoint_fn(True)
else:
self.checkpoint_fn = set_checkpoint_fn(False)
self.checkpoint_segments = [(0, self.num_processor_layers)]
def run_function(self, segment_start: int, segment_end: int):
segment = self.processor_layers[segment_start:segment_end]
def custom_forward(efeat, nfeat, graph):
"""Custom forward function"""
for module in segment:
efeat, nfeat = module(efeat, nfeat, graph)
return efeat, nfeat
return custom_forward
def forward(
self,
efeat: Tensor,
nfeat: Tensor,
graph: Union[DGLGraph, CuGraphCSC],
) -> Tensor:
for segment_start, segment_end in self.checkpoint_segments:
efeat, nfeat = self.checkpoint_fn(
self.run_function(segment_start, segment_end),
efeat,
nfeat,
graph,
use_reentrant=False,
preserve_rng_state=False,
)
return efeat, nfeat
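# Checkpointing example (sketch): a GraphCastProcessor built with processor_layers=16
# holds 32 modules (one MeshEdgeBlock and one MeshNodeBlock per layer), so
#     processor.set_checkpoint_segments(4)
# yields checkpoint_segments [(0, 8), (8, 16), (16, 24), (24, 32)]; each segment is
# recomputed in the backward pass instead of storing its intermediate activations.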
class GraphCastProcessorGraphTransformer(nn.Module):
def __init__(
self,
attention_mask: torch.Tensor,
num_attention_heads: int = 4,
processor_layers: int = 16,
input_dim_nodes: int = 512,
hidden_dim: int = 512,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.hidden_dim = hidden_dim
        # use as_tensor to avoid the copy-construct warning when a tensor is passed
        self.attention_mask = torch.as_tensor(attention_mask, dtype=torch.bool)
        self.register_buffer("mask", self.attention_mask, persistent=False)
layers = [
te.pytorch.TransformerLayer(
hidden_size=input_dim_nodes,
ffn_hidden_size=hidden_dim,
num_attention_heads=num_attention_heads,
layer_number=i + 1,
fuse_qkv_params=False,
)
for i in range(processor_layers)
]
self.processor_layers = nn.ModuleList(layers)
def forward(
self,
nfeat: Tensor,
) -> Tensor:
nfeat = nfeat.unsqueeze(1)
# TODO make sure reshaping the last dim to (h, d) is done automatically in the transformer layer
for module in self.processor_layers:
nfeat = module(
nfeat,
attention_mask=self.mask,
self_attn_mask_type="arbitrary",
)
return torch.squeeze(nfeat, 1)
import logging
import warnings
from dataclasses import dataclass
from typing import Any, Optional
import torch
from torch import Tensor
try:
from typing import Self
except ImportError:
# for Python versions < 3.11
from typing_extensions import Self
logger = logging.getLogger(__name__)
def get_lat_lon_partition_separators(partition_size: int):
    """Returns (min_seps, max_seps) lat/lon bounding-box separators that split the globe into partition_size rectangular chunks."""
def _divide(num_lat_chunks: int, num_lon_chunks: int):
        # divide the lat-lon grid into equally-sized chunks along both latitude and longitude
if (num_lon_chunks * num_lat_chunks) != partition_size:
raise ValueError(
"Can't divide lat-lon grid into grid {num_lat_chunks} x {num_lon_chunks} chunks for partition_size={partition_size}."
)
lat_bin_width = 180.0 / num_lat_chunks
lon_bin_width = 360.0 / num_lon_chunks
lat_ranges = []
lon_ranges = []
for p_lat in range(num_lat_chunks):
for p_lon in range(num_lon_chunks):
lat_ranges += [
(lat_bin_width * p_lat - 90.0, lat_bin_width * (p_lat + 1) - 90.0)
]
lon_ranges += [
(lon_bin_width * p_lon - 180.0, lon_bin_width * (p_lon + 1) - 180.0)
]
lat_ranges[-1] = (lat_ranges[-1][0], None)
lon_ranges[-1] = (lon_ranges[-1][0], None)
return lat_ranges, lon_ranges
# use two closest factors of partition_size
lat_chunks, lon_chunks, i = 1, partition_size, 0
while lat_chunks < lon_chunks:
i += 1
if partition_size % i == 0:
lat_chunks = i
lon_chunks = partition_size // lat_chunks
lat_ranges, lon_ranges = _divide(lat_chunks, lon_chunks)
# mainly for debugging
if (lat_ranges is None) or (lon_ranges is None):
raise ValueError("unexpected error, abort")
min_seps = []
max_seps = []
for i in range(partition_size):
lat = lat_ranges[i]
lon = lon_ranges[i]
min_seps.append([lat[0], lon[0]])
max_seps.append([lat[1], lon[1]])
return min_seps, max_seps
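# Worked example: partition_size=4 factors into 2 latitude x 2 longitude chunks of
# 90 x 180 degrees each, giving
#   min_seps = [[-90.0, -180.0], [-90.0, 0.0], [0.0, -180.0], [0.0, 0.0]]
#   max_seps = [[0.0, 0.0], [0.0, 180.0], [90.0, 0.0], [None, None]]
# (None marks the open upper bound of the last latitude/longitude range).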
@dataclass
class MetaData(ModelMetaData):
name: str = "GraphCast"
# Optimization
jit: bool = False
cuda_graphs: bool = False
amp_cpu: bool = False
amp_gpu: bool = True
torch_fx: bool = False
# Data type
bf16: bool = True
# Inference
onnx: bool = False
# Physics informed
func_torch: bool = False
auto_grad: bool = False
class GraphCast(Module):
def __init__(
self,
# params,
mesh_level: Optional[int] = 5,
multimesh_level: Optional[int] = None,
multimesh: bool = True,
input_res: tuple = (120, 240),
input_dim_grid_nodes: int = 69,
input_dim_mesh_nodes: int = 3,
input_dim_edges: int = 4,
output_dim_grid_nodes: int = 69,
processor_type: str = "MessagePassing",
khop_neighbors: int = 32,
num_attention_heads: int = 4,
processor_layers: int = 16,
hidden_layers: int = 1,
hidden_dim: int = 512,
aggregation: str = "sum",
activation_fn: str = "silu",
norm_type: str = "LayerNorm",
use_cugraphops_encoder: bool = False,
use_cugraphops_processor: bool = False,
use_cugraphops_decoder: bool = False,
do_concat_trick: bool = False,
recompute_activation: bool = True,
partition_size: int = 1,
partition_group_name: Optional[str] = None,
use_lat_lon_partitioning: bool = False,
expect_partitioned_input: bool = False,
global_features_on_rank_0: bool = False,
produce_aggregated_output: bool = True,
produce_aggregated_output_on_all_ranks: bool = True,
):
super().__init__(meta=MetaData())
# 'multimesh_level' deprecation handling
if multimesh_level is not None:
warnings.warn(
"'multimesh_level' is deprecated and will be removed in a future version. Use 'mesh_level' instead.",
DeprecationWarning,
stacklevel=2,
)
mesh_level = multimesh_level
self.processor_type = processor_type
if self.processor_type == "MessagePassing":
khop_neighbors = 0
self.is_distributed = False
if partition_size > 1:
self.is_distributed = True
self.expect_partitioned_input = expect_partitioned_input
self.global_features_on_rank_0 = global_features_on_rank_0
self.produce_aggregated_output = produce_aggregated_output
self.produce_aggregated_output_on_all_ranks = (
produce_aggregated_output_on_all_ranks
)
self.partition_group_name = partition_group_name
# create the lat_lon_grid
self.latitudes = torch.linspace(-90, 90, steps=input_res[0])
self.longitudes = torch.linspace(-180, 180, steps=input_res[1] + 1)[1:]
# self.latitudes = torch.linspace(-90, 90, steps=input_res[0] + 1)[:-1]
# self.longitudes = torch.linspace(-180, 180, steps=input_res[1] + 1)[1:]
self.lat_lon_grid = torch.stack(
torch.meshgrid(self.latitudes, self.longitudes, indexing="ij"), dim=-1
)
# Set activation function
activation_fn = get_activation(activation_fn)
# construct the graph
self.graph = Graph(self.lat_lon_grid, mesh_level, multimesh, khop_neighbors)
self.mesh_graph, self.attn_mask = self.graph.create_mesh_graph(verbose=False)
self.g2m_graph = self.graph.create_g2m_graph(verbose=False)
self.m2g_graph = self.graph.create_m2g_graph(verbose=False)
self.g2m_edata = self.g2m_graph.edata["x"]
self.m2g_edata = self.m2g_graph.edata["x"]
self.mesh_ndata = self.mesh_graph.ndata["x"]
if self.processor_type == "MessagePassing":
self.mesh_edata = self.mesh_graph.edata["x"]
elif self.processor_type == "GraphTransformer":
# Dummy tensor to avoid breaking the API
self.mesh_edata = torch.zeros((1, input_dim_edges))
else:
raise ValueError(f"Invalid processor type {processor_type}")
if use_cugraphops_encoder or self.is_distributed:
kwargs = {}
if use_lat_lon_partitioning:
min_seps, max_seps = get_lat_lon_partition_separators(partition_size)
kwargs = {
"src_coordinates": self.g2m_graph.srcdata["lat_lon"],
"dst_coordinates": self.g2m_graph.dstdata["lat_lon"],
"coordinate_separators_min": min_seps,
"coordinate_separators_max": max_seps,
}
self.g2m_graph, edge_perm = CuGraphCSC.from_dgl(
graph=self.g2m_graph,
partition_size=partition_size,
partition_group_name=partition_group_name,
partition_by_bbox=use_lat_lon_partitioning,
**kwargs,
)
self.g2m_edata = self.g2m_edata[edge_perm]
if self.is_distributed:
self.g2m_edata = self.g2m_graph.get_edge_features_in_partition(
self.g2m_edata
)
if use_cugraphops_decoder or self.is_distributed:
kwargs = {}
if use_lat_lon_partitioning:
min_seps, max_seps = get_lat_lon_partition_separators(partition_size)
kwargs = {
"src_coordinates": self.m2g_graph.srcdata["lat_lon"],
"dst_coordinates": self.m2g_graph.dstdata["lat_lon"],
"coordinate_separators_min": min_seps,
"coordinate_separators_max": max_seps,
}
self.m2g_graph, edge_perm = CuGraphCSC.from_dgl(
graph=self.m2g_graph,
partition_size=partition_size,
partition_group_name=partition_group_name,
partition_by_bbox=use_lat_lon_partitioning,
**kwargs,
)
self.m2g_edata = self.m2g_edata[edge_perm]
if self.is_distributed:
self.m2g_edata = self.m2g_graph.get_edge_features_in_partition(
self.m2g_edata
)
if use_cugraphops_processor or self.is_distributed:
kwargs = {}
if use_lat_lon_partitioning:
min_seps, max_seps = get_lat_lon_partition_separators(partition_size)
kwargs = {
"src_coordinates": self.mesh_graph.ndata["lat_lon"],
"dst_coordinates": self.mesh_graph.ndata["lat_lon"],
"coordinate_separators_min": min_seps,
"coordinate_separators_max": max_seps,
}
self.mesh_graph, edge_perm = CuGraphCSC.from_dgl(
graph=self.mesh_graph,
partition_size=partition_size,
partition_group_name=partition_group_name,
partition_by_bbox=use_lat_lon_partitioning,
**kwargs,
)
self.mesh_edata = self.mesh_edata[edge_perm]
if self.is_distributed:
self.mesh_edata = self.mesh_graph.get_edge_features_in_partition(
self.mesh_edata
)
self.mesh_ndata = self.mesh_graph.get_dst_node_features_in_partition(
self.mesh_ndata
)
self.input_dim_grid_nodes = input_dim_grid_nodes
self.output_dim_grid_nodes = output_dim_grid_nodes
self.input_res = input_res
# by default: don't checkpoint at all
self.model_checkpoint_fn = set_checkpoint_fn(False)
self.encoder_checkpoint_fn = set_checkpoint_fn(False)
self.decoder_checkpoint_fn = set_checkpoint_fn(False)
# initial feature embedder
self.encoder_embedder = GraphCastEncoderEmbedder(
input_dim_grid_nodes=input_dim_grid_nodes,
input_dim_mesh_nodes=input_dim_mesh_nodes,
input_dim_edges=input_dim_edges,
output_dim=hidden_dim,
hidden_dim=hidden_dim,
hidden_layers=hidden_layers,
activation_fn=activation_fn,
norm_type=norm_type,
recompute_activation=recompute_activation,
)
self.decoder_embedder = GraphCastDecoderEmbedder(
input_dim_edges=input_dim_edges,
output_dim=hidden_dim,
hidden_dim=hidden_dim,
hidden_layers=hidden_layers,
activation_fn=activation_fn,
norm_type=norm_type,
recompute_activation=recompute_activation,
)
# grid2mesh encoder
self.encoder = MeshGraphEncoder(
aggregation=aggregation,
input_dim_src_nodes=hidden_dim,
input_dim_dst_nodes=hidden_dim,
input_dim_edges=hidden_dim,
output_dim_src_nodes=hidden_dim,
output_dim_dst_nodes=hidden_dim,
output_dim_edges=hidden_dim,
hidden_dim=hidden_dim,
hidden_layers=hidden_layers,
activation_fn=activation_fn,
norm_type=norm_type,
do_concat_trick=do_concat_trick,
recompute_activation=recompute_activation,
)
# icosahedron processor
if processor_layers <= 2:
raise ValueError("Expected at least 3 processor layers")
if processor_type == "MessagePassing":
self.processor_encoder = GraphCastProcessor(
aggregation=aggregation,
processor_layers=1,
input_dim_nodes=hidden_dim,
input_dim_edges=hidden_dim,
hidden_dim=hidden_dim,
hidden_layers=hidden_layers,
activation_fn=activation_fn,
norm_type=norm_type,
do_concat_trick=do_concat_trick,
recompute_activation=recompute_activation,
)
self.processor = GraphCastProcessor(
aggregation=aggregation,
processor_layers=processor_layers - 2,
input_dim_nodes=hidden_dim,
input_dim_edges=hidden_dim,
hidden_dim=hidden_dim,
hidden_layers=hidden_layers,
activation_fn=activation_fn,
norm_type=norm_type,
do_concat_trick=do_concat_trick,
recompute_activation=recompute_activation,
)
self.processor_decoder = GraphCastProcessor(
aggregation=aggregation,
processor_layers=1,
input_dim_nodes=hidden_dim,
input_dim_edges=hidden_dim,
hidden_dim=hidden_dim,
hidden_layers=hidden_layers,
activation_fn=activation_fn,
norm_type=norm_type,
do_concat_trick=do_concat_trick,
recompute_activation=recompute_activation,
)
else:
self.processor_encoder = torch.nn.Identity()
self.processor = GraphCastProcessorGraphTransformer(
attention_mask=self.attn_mask,
num_attention_heads=num_attention_heads,
processor_layers=processor_layers,
input_dim_nodes=hidden_dim,
hidden_dim=hidden_dim,
)
self.processor_decoder = torch.nn.Identity()
# mesh2grid decoder
self.decoder = MeshGraphDecoder(
aggregation=aggregation,
input_dim_src_nodes=hidden_dim,
input_dim_dst_nodes=hidden_dim,
input_dim_edges=hidden_dim,
output_dim_dst_nodes=hidden_dim,
output_dim_edges=hidden_dim,
hidden_dim=hidden_dim,
hidden_layers=hidden_layers,
activation_fn=activation_fn,
norm_type=norm_type,
do_concat_trick=do_concat_trick,
recompute_activation=recompute_activation,
)
# final MLP
self.finale = MeshGraphMLP(
input_dim=hidden_dim,
output_dim=output_dim_grid_nodes,
hidden_dim=hidden_dim,
hidden_layers=hidden_layers,
activation_fn=activation_fn,
norm_type=None,
recompute_activation=recompute_activation,
)
def set_checkpoint_model(self, checkpoint_flag: bool):
# force a single checkpoint for the whole model
self.model_checkpoint_fn = set_checkpoint_fn(checkpoint_flag)
if checkpoint_flag:
self.processor.set_checkpoint_segments(-1)
self.encoder_checkpoint_fn = set_checkpoint_fn(False)
self.decoder_checkpoint_fn = set_checkpoint_fn(False)
def set_checkpoint_processor(self, checkpoint_segments: int):
self.processor.set_checkpoint_segments(checkpoint_segments)
def set_checkpoint_encoder(self, checkpoint_flag: bool):
self.encoder_checkpoint_fn = set_checkpoint_fn(checkpoint_flag)
def set_checkpoint_decoder(self, checkpoint_flag: bool):
self.decoder_checkpoint_fn = set_checkpoint_fn(checkpoint_flag)
def encoder_forward(
self,
grid_nfeat: Tensor,
) -> Tensor:
        # embed grid-node, mesh-node, and edge features
(
grid_nfeat_embedded,
mesh_nfeat_embedded,
g2m_efeat_embedded,
mesh_efeat_embedded,
) = self.encoder_embedder(
grid_nfeat,
self.mesh_ndata,
self.g2m_edata,
self.mesh_edata,
)
# encode lat/lon to multimesh
grid_nfeat_encoded, mesh_nfeat_encoded = self.encoder(
g2m_efeat_embedded,
grid_nfeat_embedded,
mesh_nfeat_embedded,
self.g2m_graph,
)
# process multimesh graph
if self.processor_type == "MessagePassing":
mesh_efeat_processed, mesh_nfeat_processed = self.processor_encoder(
mesh_efeat_embedded,
mesh_nfeat_encoded,
self.mesh_graph,
)
else:
mesh_nfeat_processed = self.processor_encoder(
mesh_nfeat_encoded,
)
mesh_efeat_processed = None
return mesh_efeat_processed, mesh_nfeat_processed, grid_nfeat_encoded
def decoder_forward(
self,
mesh_efeat_processed: Tensor,
mesh_nfeat_processed: Tensor,
grid_nfeat_encoded: Tensor,
) -> Tensor:
# process multimesh graph
if self.processor_type == "MessagePassing":
_, mesh_nfeat_processed = self.processor_decoder(
mesh_efeat_processed,
mesh_nfeat_processed,
self.mesh_graph,
)
else:
mesh_nfeat_processed = self.processor_decoder(
mesh_nfeat_processed,
)
m2g_efeat_embedded = self.decoder_embedder(self.m2g_edata)
# decode multimesh to lat/lon
grid_nfeat_decoded = self.decoder(
m2g_efeat_embedded, grid_nfeat_encoded, mesh_nfeat_processed, self.m2g_graph
)
# map to the target output dimension
grid_nfeat_finale = self.finale(
grid_nfeat_decoded,
)
return grid_nfeat_finale
def custom_forward(self, grid_nfeat: Tensor) -> Tensor:
(
mesh_efeat_processed,
mesh_nfeat_processed,
grid_nfeat_encoded,
) = self.encoder_checkpoint_fn(
self.encoder_forward,
grid_nfeat,
use_reentrant=False,
preserve_rng_state=False,
)
# checkpoint of processor done in processor itself
if self.processor_type == "MessagePassing":
mesh_efeat_processed, mesh_nfeat_processed = self.processor(
mesh_efeat_processed,
mesh_nfeat_processed,
self.mesh_graph,
)
else:
mesh_nfeat_processed = self.processor(
mesh_nfeat_processed,
)
mesh_efeat_processed = None
grid_nfeat_finale = self.decoder_checkpoint_fn(
self.decoder_forward,
mesh_efeat_processed,
mesh_nfeat_processed,
grid_nfeat_encoded,
use_reentrant=False,
preserve_rng_state=False,
)
return grid_nfeat_finale
def forward(
self,
grid_nfeat: Tensor,
) -> Tensor:
invar = self.prepare_input(
grid_nfeat, self.expect_partitioned_input, self.global_features_on_rank_0
)
outvar = self.model_checkpoint_fn(
self.custom_forward,
invar,
use_reentrant=False,
preserve_rng_state=False,
)
outvar = self.prepare_output(
outvar,
self.produce_aggregated_output,
self.produce_aggregated_output_on_all_ranks,
)
return outvar
def prepare_input(
self,
invar: Tensor,
expect_partitioned_input: bool,
global_features_on_rank_0: bool,
) -> Tensor:
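        """Reshapes gridded inputs [1, C, H, W] into per-node features [H*W, C]; in the
        distributed case the features are additionally scattered/restricted to the
        local graph partition."""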
if global_features_on_rank_0 and expect_partitioned_input:
raise ValueError(
"global_features_on_rank_0 and expect_partitioned_input cannot be set at the same time."
)
if not self.is_distributed:
if invar.size(0) != 1:
raise ValueError("GraphCast does not support batch size > 1")
invar = invar[0].view(self.input_dim_grid_nodes, -1).permute(1, 0)
else:
# is_distributed
if not expect_partitioned_input:
# global_features_on_rank_0
if invar.size(0) != 1:
raise ValueError("GraphCast does not support batch size > 1")
invar = invar[0].view(self.input_dim_grid_nodes, -1).permute(1, 0)
# scatter global features
invar = self.g2m_graph.get_src_node_features_in_partition(
invar,
scatter_features=global_features_on_rank_0,
)
return invar
def prepare_output(
self,
outvar: Tensor,
produce_aggregated_output: bool,
produce_aggregated_output_on_all_ranks: bool = True,
) -> Tensor:
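        """Inverse of prepare_input: gathers per-node outputs [N, C_out] (across ranks
        if requested) and reshapes them back to [1, C_out, H, W]."""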
if produce_aggregated_output or not self.is_distributed:
# default case: output of shape [N, C, H, W]
if self.is_distributed:
outvar = self.m2g_graph.get_global_dst_node_features(
outvar,
get_on_all_ranks=produce_aggregated_output_on_all_ranks,
)
outvar = outvar.permute(1, 0)
outvar = outvar.view(self.output_dim_grid_nodes, *self.input_res)
outvar = torch.unsqueeze(outvar, dim=0)
return outvar
def to(self, *args: Any, **kwargs: Any) -> Self:
self = super(GraphCast, self).to(*args, **kwargs)
self.g2m_edata = self.g2m_edata.to(*args, **kwargs)
self.m2g_edata = self.m2g_edata.to(*args, **kwargs)
self.mesh_ndata = self.mesh_ndata.to(*args, **kwargs)
self.mesh_edata = self.mesh_edata.to(*args, **kwargs)
device, _, _, _ = torch._C._nn._parse_to(*args, **kwargs)
self.g2m_graph = self.g2m_graph.to(device)
self.mesh_graph = self.mesh_graph.to(device)
self.m2g_graph = self.m2g_graph.to(device)
return self
if __name__ == '__main__':
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = GraphCast().to(device)
    # use `x` rather than shadowing the built-in `input`
    x = torch.randn(1, 69, 120, 240).to(device)
    output = net(x)
    print(output.shape)
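    # With the default configuration this prints torch.Size([1, 69, 120, 240]):
    # batch 1, output_dim_grid_nodes=69 channels on the (120, 240) lat/lon grid.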