|
|
from typing import Any, List, Optional, Tuple
|
|
|
|
|
import dgl |
|
|
import torch |
|
|
from dgl import DGLGraph |
|
|
from torch import Tensor |
|
|
|
|
|
try: |
|
|
from typing import Self |
|
|
except ImportError: |
|
|
|
|
|
from typing_extensions import Self |
|
|
|
|
|
|
|
|
try: |
|
|
from pylibcugraphops.pytorch import BipartiteCSC, StaticCSC |
|
|
|
|
|
USE_CUGRAPHOPS = True |
|
|
|
|
|
except ImportError: |
|
|
StaticCSC = None |
|
|
BipartiteCSC = None |
|
|
USE_CUGRAPHOPS = False |
|
|
|
|
|
|
|
|
class CuGraphCSC: |
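    """CSC (compressed sparse column) graph representation compatible with cugraph-ops.

    Stores ``offsets``/``indices`` tensors and lazily builds (and optionally caches)
    cugraph-ops ``BipartiteCSC``/``StaticCSC`` structures or an equivalent DGL graph.
    """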
|
|
def __init__( |
|
|
self, |
|
|
offsets: Tensor, |
|
|
indices: Tensor, |
|
|
num_src_nodes: int, |
|
|
num_dst_nodes: int, |
|
|
ef_indices: Optional[Tensor] = None, |
|
|
reverse_graph_bwd: bool = True, |
|
|
cache_graph: bool = True, |
|
|
partition_size: Optional[int] = -1, |
|
|
        partition_group_name: Optional[str] = None,
        graph_partition: Optional[Any] = None,  # assumed partition metadata; forwarded by ``from_dgl`` below
|
|
) -> None: |
|
|
self.offsets = offsets |
|
|
self.indices = indices |
|
|
self.num_src_nodes = num_src_nodes |
|
|
self.num_dst_nodes = num_dst_nodes |
|
|
self.ef_indices = ef_indices |
|
|
self.reverse_graph_bwd = reverse_graph_bwd |
|
|
self.cache_graph = cache_graph |
|
|
|
|
|
|
|
|
self.bipartite_csc = None |
|
|
self.static_csc = None |
|
|
|
|
|
self.dgl_graph = None |
|
|
|
|
|
self.is_distributed = False |
|
|
self.dist_csc = None |
|
|
|
|
|
        if partition_size is None or partition_size <= 1:
|
|
self.is_distributed = False |
|
|
return |
|
|
|
|
|
if self.ef_indices is not None: |
|
|
raise AssertionError( |
|
|
"DistributedGraph does not support mapping CSC-indices to COO-indices." |
|
|
) |
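
        # NOTE: construction of ``self.dist_graph`` (a distributed graph wrapper built
        # from partition_size, partition_group_name, and graph_partition) is elided in
        # this excerpt; the local partition views below are read from it.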
|
|
|
|
|
|
|
|
|
|
|
self.offsets = self.dist_graph.graph_partition.local_offsets |
|
|
self.indices = self.dist_graph.graph_partition.local_indices |
|
|
self.num_src_nodes = self.dist_graph.graph_partition.num_local_src_nodes |
|
|
self.num_dst_nodes = self.dist_graph.graph_partition.num_local_dst_nodes |
|
|
self.is_distributed = True |
|
|
|
|
|
@staticmethod |
|
|
def from_dgl( |
|
|
graph: DGLGraph, |
|
|
partition_size: int = 1, |
|
|
partition_group_name: Optional[str] = None, |
|
|
partition_by_bbox: bool = False, |
|
|
src_coordinates: Optional[torch.Tensor] = None, |
|
|
dst_coordinates: Optional[torch.Tensor] = None, |
|
|
coordinate_separators_min: Optional[List[List[Optional[float]]]] = None, |
|
|
coordinate_separators_max: Optional[List[List[Optional[float]]]] = None, |
|
|
    ) -> Tuple["CuGraphCSC", Tensor]:
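        """Builds a ``CuGraphCSC`` from a ``DGLGraph``.

        Returns the graph plus the edge permutation produced by DGL's CSC
        conversion, which must be applied to any edge features stored in the
        original COO ordering.
        """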
|
|
|
|
|
|
|
|
if hasattr(graph, "adj_tensors"): |
|
|
offsets, indices, edge_perm = graph.adj_tensors("csc") |
|
|
elif hasattr(graph, "adj_sparse"): |
|
|
offsets, indices, edge_perm = graph.adj_sparse("csc") |
|
|
else: |
|
|
raise ValueError("Passed graph object doesn't support conversion to CSC.") |
|
|
|
|
|
n_src_nodes, n_dst_nodes = (graph.num_src_nodes(), graph.num_dst_nodes()) |
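
        # NOTE: when ``partition_by_bbox`` is set, a graph partition would be derived
        # here from ``src_coordinates``/``dst_coordinates`` and the coordinate
        # separators (elided in this excerpt).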
|
|
|
|
|
graph_partition = None |
|
|
|
|
|
|
|
|
graph_csc = CuGraphCSC( |
|
|
offsets.to(dtype=torch.int64), |
|
|
indices.to(dtype=torch.int64), |
|
|
n_src_nodes, |
|
|
n_dst_nodes, |
|
|
partition_size=partition_size, |
|
|
partition_group_name=partition_group_name, |
|
|
graph_partition=graph_partition, |
|
|
) |
|
|
|
|
|
return graph_csc, edge_perm |
|
|
|
|
|
|
|
|
def to(self, *args: Any, **kwargs: Any) -> Self: |
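        """Moves the graph structures to a device and/or casts the index dtype (int32/int64 only)."""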
|
|
|
|
|
device, dtype, _, _ = torch._C._nn._parse_to(*args, **kwargs) |
|
|
if dtype not in ( |
|
|
None, |
|
|
torch.int32, |
|
|
torch.int64, |
|
|
): |
|
|
raise TypeError( |
|
|
f"Invalid dtype, expected torch.int32 or torch.int64, got {dtype}." |
|
|
) |
|
|
self.offsets = self.offsets.to(device=device, dtype=dtype) |
|
|
self.indices = self.indices.to(device=device, dtype=dtype) |
|
|
if self.ef_indices is not None: |
|
|
self.ef_indices = self.ef_indices.to(device=device, dtype=dtype) |
|
|
|
|
|
return self |
|
|
|
|
|
def to_bipartite_csc(self, dtype=None) -> BipartiteCSC: |
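        """Converts (and optionally caches) the graph as a cugraph-ops ``BipartiteCSC``;
        requires cugraph-ops and GPU-resident graph tensors."""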
|
|
|
|
|
|
|
|
        if not USE_CUGRAPHOPS:
|
|
raise RuntimeError( |
|
|
"Conversion failed, expected cugraph-ops to be installed." |
|
|
) |
|
|
if not self.offsets.is_cuda: |
|
|
raise RuntimeError("Expected the graph structures to reside on GPU.") |
|
|
|
|
|
if self.bipartite_csc is None or not self.cache_graph: |
|
|
|
|
|
graph_offsets = self.offsets |
|
|
graph_indices = self.indices |
|
|
graph_ef_indices = self.ef_indices |
|
|
|
|
|
if dtype is not None: |
|
|
graph_offsets = self.offsets.to(dtype=dtype) |
|
|
graph_indices = self.indices.to(dtype=dtype) |
|
|
if self.ef_indices is not None: |
|
|
graph_ef_indices = self.ef_indices.to(dtype=dtype) |
|
|
|
|
|
graph = BipartiteCSC( |
|
|
graph_offsets, |
|
|
graph_indices, |
|
|
self.num_src_nodes, |
|
|
graph_ef_indices, |
|
|
reverse_graph_bwd=self.reverse_graph_bwd, |
|
|
) |
|
|
self.bipartite_csc = graph |
|
|
|
|
|
return self.bipartite_csc |
|
|
|
|
|
def to_static_csc(self, dtype=None) -> StaticCSC: |
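        """Converts (and optionally caches) the graph as a cugraph-ops ``StaticCSC``;
        requires cugraph-ops and GPU-resident graph tensors."""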
|
|
        if not USE_CUGRAPHOPS:
|
|
raise RuntimeError( |
|
|
"Conversion failed, expected cugraph-ops to be installed." |
|
|
) |
|
|
if not self.offsets.is_cuda: |
|
|
raise RuntimeError("Expected the graph structures to reside on GPU.") |
|
|
|
|
|
if self.static_csc is None or not self.cache_graph: |
|
|
graph_offsets = self.offsets |
|
|
graph_indices = self.indices |
|
|
graph_ef_indices = self.ef_indices |
|
|
|
|
|
if dtype is not None: |
|
|
graph_offsets = self.offsets.to(dtype=dtype) |
|
|
graph_indices = self.indices.to(dtype=dtype) |
|
|
if self.ef_indices is not None: |
|
|
graph_ef_indices = self.ef_indices.to(dtype=dtype) |
|
|
|
|
|
graph = StaticCSC( |
|
|
graph_offsets, |
|
|
graph_indices, |
|
|
graph_ef_indices, |
|
|
) |
|
|
self.static_csc = graph |
|
|
|
|
|
return self.static_csc |
|
|
|
|
|
def to_dgl_graph(self) -> DGLGraph: |
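        """Reconstructs (and optionally caches) a DGL graph from the CSC structure by
        expanding ``offsets`` into per-destination-node COO indices."""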
|
|
|
|
|
if self.dgl_graph is None or not self.cache_graph: |
|
|
if self.ef_indices is not None: |
|
|
raise AssertionError("ef_indices is not supported.") |
|
|
graph_offsets = self.offsets |
|
|
dst_degree = graph_offsets[1:] - graph_offsets[:-1] |
|
|
src_indices = self.indices |
|
|
dst_indices = torch.arange( |
|
|
0, |
|
|
graph_offsets.size(0) - 1, |
|
|
dtype=graph_offsets.dtype, |
|
|
device=graph_offsets.device, |
|
|
) |
|
|
dst_indices = torch.repeat_interleave(dst_indices, dst_degree, dim=0) |
|
|
|
|
|
|
|
|
self.dgl_graph = dgl.heterograph( |
|
|
{("src", "src2dst", "dst"): ("coo", (src_indices, dst_indices))}, |
|
|
idtype=torch.int32, |
|
|
) |
|
|
|
|
|
return self.dgl_graph |
|
|
|
|
|
|
|
|
from typing import Any, Callable, Dict, Tuple, Union |
|
|
|
|
|
import dgl.function as fn |
|
|
import torch |
|
|
from dgl import DGLGraph |
|
|
from torch import Tensor |
|
|
from torch.utils.checkpoint import checkpoint |
|
|
|
|
|
|
|
|
try: |
|
|
from pylibcugraphops.pytorch.operators import ( |
|
|
agg_concat_e2n, |
|
|
update_efeat_bipartite_e2e, |
|
|
update_efeat_static_e2e, |
|
|
) |
|
|
|
|
|
USE_CUGRAPHOPS = True |
|
|
|
|
|
except ImportError: |
|
|
update_efeat_bipartite_e2e = None |
|
|
update_efeat_static_e2e = None |
|
|
agg_concat_e2n = None |
|
|
USE_CUGRAPHOPS = False |
|
|
|
|
|
|
|
|
def checkpoint_identity(layer: Callable, *args: Any, **kwargs: Any) -> Any: |
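    """Applies ``layer`` directly; signature-compatible stand-in for
    ``torch.utils.checkpoint.checkpoint`` when checkpointing is disabled
    (checkpoint-specific kwargs are ignored)."""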
|
|
|
|
|
return layer(*args) |
|
|
|
|
|
|
|
|
def set_checkpoint_fn(do_checkpointing: bool) -> Callable: |
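    """Returns ``torch.utils.checkpoint.checkpoint`` if ``do_checkpointing`` else the identity wrapper."""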
|
|
|
|
|
if do_checkpointing: |
|
|
return checkpoint |
|
|
else: |
|
|
return checkpoint_identity |
|
|
|
|
|
|
|
|
def concat_message_function(edges: "dgl.udf.EdgeBatch") -> Dict[str, Tensor]:
|
|
|
|
|
|
|
|
cat_feat = torch.cat((edges.data["x"], edges.src["x"], edges.dst["x"]), dim=1) |
|
|
return {"cat_feat": cat_feat} |
|
|
|
|
|
|
|
|
@torch.jit.ignore() |
|
|
def concat_efeat_dgl( |
|
|
efeat: Tensor, |
|
|
nfeat: Union[Tensor, Tuple[torch.Tensor, torch.Tensor]], |
|
|
graph: DGLGraph, |
|
|
) -> Tensor: |
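    """Concatenates edge features with src/dst node features via DGL's ``apply_edges``."""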
|
|
|
|
|
    if isinstance(nfeat, tuple):
|
|
src_feat, dst_feat = nfeat |
|
|
with graph.local_scope(): |
|
|
graph.srcdata["x"] = src_feat |
|
|
graph.dstdata["x"] = dst_feat |
|
|
graph.edata["x"] = efeat |
|
|
graph.apply_edges(concat_message_function) |
|
|
return graph.edata["cat_feat"] |
|
|
|
|
|
with graph.local_scope(): |
|
|
graph.ndata["x"] = nfeat |
|
|
graph.edata["x"] = efeat |
|
|
graph.apply_edges(concat_message_function) |
|
|
return graph.edata["cat_feat"] |
|
|
|
|
|
|
|
|
def concat_efeat( |
|
|
efeat: Tensor, |
|
|
    nfeat: Union[Tensor, Tuple[Tensor, Tensor]],
|
|
graph: Union[DGLGraph, CuGraphCSC], |
|
|
) -> Tensor: |
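    """Concatenates edge features with source/destination node features, dispatching to
    DGL message passing or cugraph-ops (bipartite/static CSC) depending on the graph
    type, cugraph-ops availability, and distribution mode."""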
|
|
|
|
|
if isinstance(nfeat, Tensor): |
|
|
if isinstance(graph, CuGraphCSC): |
|
|
if graph.dgl_graph is not None or not USE_CUGRAPHOPS: |
|
|
src_feat, dst_feat = nfeat, nfeat |
|
|
if graph.is_distributed: |
|
|
src_feat = graph.get_src_node_features_in_local_graph(nfeat) |
|
|
efeat = concat_efeat_dgl( |
|
|
efeat, (src_feat, dst_feat), graph.to_dgl_graph() |
|
|
) |
|
|
|
|
|
else: |
|
|
if graph.is_distributed: |
|
|
src_feat = graph.get_src_node_features_in_local_graph(nfeat) |
|
|
|
|
|
bipartite_graph = graph.to_bipartite_csc(dtype=torch.int64) |
|
|
dst_feat = nfeat |
|
|
efeat = update_efeat_bipartite_e2e( |
|
|
efeat, src_feat, dst_feat, bipartite_graph, "concat" |
|
|
) |
|
|
|
|
|
else: |
|
|
static_graph = graph.to_static_csc() |
|
|
efeat = update_efeat_static_e2e( |
|
|
efeat, |
|
|
nfeat, |
|
|
static_graph, |
|
|
mode="concat", |
|
|
use_source_emb=True, |
|
|
use_target_emb=True, |
|
|
) |
|
|
|
|
|
else: |
|
|
efeat = concat_efeat_dgl(efeat, nfeat, graph) |
|
|
|
|
|
else: |
|
|
src_feat, dst_feat = nfeat |
|
|
|
|
|
if isinstance(graph, CuGraphCSC): |
|
|
if graph.dgl_graph is not None or not USE_CUGRAPHOPS: |
|
|
if graph.is_distributed: |
|
|
src_feat = graph.get_src_node_features_in_local_graph(src_feat) |
|
|
efeat = concat_efeat_dgl( |
|
|
efeat, (src_feat, dst_feat), graph.to_dgl_graph() |
|
|
) |
|
|
|
|
|
else: |
|
|
if graph.is_distributed: |
|
|
src_feat = graph.get_src_node_features_in_local_graph(src_feat) |
|
|
|
|
|
bipartite_graph = graph.to_bipartite_csc(dtype=torch.int64) |
|
|
efeat = update_efeat_bipartite_e2e( |
|
|
efeat, src_feat, dst_feat, bipartite_graph, "concat" |
|
|
) |
|
|
else: |
|
|
efeat = concat_efeat_dgl(efeat, (src_feat, dst_feat), graph) |
|
|
|
|
|
return efeat |
|
|
|
|
|
|
|
|
@torch.jit.script |
|
|
def sum_efeat_dgl( |
|
|
efeat: Tensor, src_feat: Tensor, dst_feat: Tensor, src_idx: Tensor, dst_idx: Tensor |
|
|
) -> Tensor: |
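    """Sums edge features with source/destination node features gathered by edge index."""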
|
|
|
|
|
return efeat + src_feat[src_idx] + dst_feat[dst_idx] |
|
|
|
|
|
|
|
|
def sum_efeat( |
|
|
efeat: Tensor, |
|
|
    nfeat: Union[Tensor, Tuple[Tensor, Tensor]],
|
|
graph: Union[DGLGraph, CuGraphCSC], |
|
|
): |
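    """Sums edge features with the corresponding source/destination node features,
    dispatching to DGL or cugraph-ops analogously to ``concat_efeat``."""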
|
|
|
|
|
if isinstance(nfeat, Tensor): |
|
|
if isinstance(graph, CuGraphCSC): |
|
|
if graph.dgl_graph is not None or not USE_CUGRAPHOPS: |
|
|
src_feat, dst_feat = nfeat, nfeat |
|
|
if graph.is_distributed: |
|
|
src_feat = graph.get_src_node_features_in_local_graph(src_feat) |
|
|
|
|
|
src, dst = (item.long() for item in graph.to_dgl_graph().edges()) |
|
|
sum_efeat = sum_efeat_dgl(efeat, src_feat, dst_feat, src, dst) |
|
|
|
|
|
else: |
|
|
if graph.is_distributed: |
|
|
src_feat = graph.get_src_node_features_in_local_graph(nfeat) |
|
|
dst_feat = nfeat |
|
|
bipartite_graph = graph.to_bipartite_csc() |
|
|
sum_efeat = update_efeat_bipartite_e2e( |
|
|
efeat, src_feat, dst_feat, bipartite_graph, mode="sum" |
|
|
) |
|
|
|
|
|
else: |
|
|
                    static_graph = graph.to_static_csc()
                    sum_efeat = update_efeat_static_e2e(
                        efeat,
                        nfeat,
                        static_graph,
                        mode="sum",
                        use_source_emb=True,
                        use_target_emb=True,
                    )
|
|
|
|
|
else: |
|
|
src_feat, dst_feat = nfeat, nfeat |
|
|
src, dst = (item.long() for item in graph.edges()) |
|
|
sum_efeat = sum_efeat_dgl(efeat, src_feat, dst_feat, src, dst) |
|
|
|
|
|
else: |
|
|
src_feat, dst_feat = nfeat |
|
|
if isinstance(graph, CuGraphCSC): |
|
|
if graph.dgl_graph is not None or not USE_CUGRAPHOPS: |
|
|
if graph.is_distributed: |
|
|
src_feat = graph.get_src_node_features_in_local_graph(src_feat) |
|
|
|
|
|
src, dst = (item.long() for item in graph.to_dgl_graph().edges()) |
|
|
sum_efeat = sum_efeat_dgl(efeat, src_feat, dst_feat, src, dst) |
|
|
|
|
|
else: |
|
|
if graph.is_distributed: |
|
|
src_feat = graph.get_src_node_features_in_local_graph(src_feat) |
|
|
|
|
|
bipartite_graph = graph.to_bipartite_csc() |
|
|
sum_efeat = update_efeat_bipartite_e2e( |
|
|
efeat, src_feat, dst_feat, bipartite_graph, mode="sum" |
|
|
) |
|
|
else: |
|
|
src, dst = (item.long() for item in graph.edges()) |
|
|
sum_efeat = sum_efeat_dgl(efeat, src_feat, dst_feat, src, dst) |
|
|
|
|
|
return sum_efeat |
|
|
|
|
|
|
|
|
@torch.jit.ignore() |
|
|
def agg_concat_dgl( |
|
|
efeat: Tensor, dst_nfeat: Tensor, graph: DGLGraph, aggregation: str |
|
|
) -> Tensor: |
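    """Aggregates edge features onto destination nodes ("sum" or "mean") and
    concatenates the result with the destination node features."""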
|
|
|
|
|
with graph.local_scope(): |
|
|
|
|
|
graph.edata["x"] = efeat |
|
|
|
|
|
|
|
|
if aggregation == "sum": |
|
|
graph.update_all(fn.copy_e("x", "m"), fn.sum("m", "h_dest")) |
|
|
elif aggregation == "mean": |
|
|
graph.update_all(fn.copy_e("x", "m"), fn.mean("m", "h_dest")) |
|
|
else: |
|
|
raise RuntimeError("Not a valid aggregation!") |
|
|
|
|
|
|
|
|
cat_feat = torch.cat((graph.dstdata["h_dest"], dst_nfeat), -1) |
|
|
return cat_feat |
|
|
|
|
|
|
|
|
def aggregate_and_concat( |
|
|
efeat: Tensor, |
|
|
nfeat: Tensor, |
|
|
graph: Union[DGLGraph, CuGraphCSC], |
|
|
aggregation: str, |
|
|
): |
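    """Aggregate-and-concat, dispatching to DGL message passing or cugraph-ops
    ``agg_concat_e2n`` depending on the graph type and availability."""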
|
|
|
|
|
|
|
|
if isinstance(graph, CuGraphCSC): |
|
|
|
|
|
if graph.dgl_graph is not None or not USE_CUGRAPHOPS: |
|
|
cat_feat = agg_concat_dgl(efeat, nfeat, graph.to_dgl_graph(), aggregation) |
|
|
|
|
|
else: |
|
|
static_graph = graph.to_static_csc() |
|
|
cat_feat = agg_concat_e2n(nfeat, efeat, static_graph, aggregation) |
|
|
else: |
|
|
cat_feat = agg_concat_dgl(efeat, nfeat, graph, aggregation) |
|
|
|
|
|
return cat_feat |
|
|
|
|
|
|
|
|
import functools |
|
|
import logging |
|
|
from typing import Tuple |
|
|
|
|
|
import torch |
|
|
from torch.autograd import Function |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
try: |
|
|
import nvfuser |
|
|
from nvfuser import DataType, FusionDefinition |
|
|
except ImportError: |
|
|
logger.error( |
|
|
"An error occured. Either nvfuser is not installed or the version is " |
|
|
"incompatible. Please retry after installing correct version of nvfuser. " |
|
|
"The new version of nvfuser should be available in PyTorch container version " |
|
|
">= 23.10. " |
|
|
"https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html. " |
|
|
"If using a source install method, please refer nvFuser repo for installation " |
|
|
"guidelines https://github.com/NVIDIA/Fuser.", |
|
|
) |
|
|
raise |
|
|
|
|
|
_torch_dtype_to_nvfuser = { |
|
|
torch.double: DataType.Double, |
|
|
torch.float: DataType.Float, |
|
|
torch.half: DataType.Half, |
|
|
torch.int: DataType.Int, |
|
|
torch.int32: DataType.Int32, |
|
|
torch.bool: DataType.Bool, |
|
|
torch.bfloat16: DataType.BFloat16, |
|
|
torch.cfloat: DataType.ComplexFloat, |
|
|
torch.cdouble: DataType.ComplexDouble, |
|
|
} |
|
|
|
|
|
|
|
|
@functools.lru_cache(maxsize=None) |
|
|
def silu_backward_for( |
|
|
fd: FusionDefinition, |
|
|
dtype: torch.dtype, |
|
|
dim: int, |
|
|
size: torch.Size, |
|
|
stride: Tuple[int, ...], |
|
|
): |
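    """Registers an nvfuser definition of the SiLU derivative:
    silu'(x) = y * (1 + x * (1 - y)), with y = sigmoid(x).
    """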
|
|
|
|
|
try: |
|
|
dtype = _torch_dtype_to_nvfuser[dtype] |
|
|
except KeyError: |
|
|
raise TypeError("Unsupported dtype") |
|
|
|
|
|
x = fd.define_tensor( |
|
|
shape=[-1] * dim, |
|
|
contiguity=nvfuser.compute_contiguity(size, stride), |
|
|
dtype=dtype, |
|
|
) |
|
|
one = fd.define_constant(1.0) |
|
|
|
|
|
|
|
|
y = fd.ops.sigmoid(x) |
|
|
|
|
|
grad_input = fd.ops.mul(y, fd.ops.add(one, fd.ops.mul(x, fd.ops.sub(one, y)))) |
|
|
|
|
|
grad_input = fd.ops.cast(grad_input, dtype) |
|
|
|
|
|
fd.add_output(grad_input) |
|
|
|
|
|
|
|
|
@functools.lru_cache(maxsize=None) |
|
|
def silu_double_backward_for( |
|
|
fd: FusionDefinition, |
|
|
dtype: torch.dtype, |
|
|
dim: int, |
|
|
size: torch.Size, |
|
|
stride: Tuple[int, ...], |
|
|
): |
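    """Registers an nvfuser definition of the second derivative of SiLU:
    silu''(x) = dy * (1 + x * (1 - y)) + y * ((1 - y) - x * dy),
    with y = sigmoid(x) and dy = y * (1 - y).
    """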
|
|
|
|
|
try: |
|
|
dtype = _torch_dtype_to_nvfuser[dtype] |
|
|
except KeyError: |
|
|
raise TypeError("Unsupported dtype") |
|
|
|
|
|
x = fd.define_tensor( |
|
|
shape=[-1] * dim, |
|
|
contiguity=nvfuser.compute_contiguity(size, stride), |
|
|
dtype=dtype, |
|
|
) |
|
|
one = fd.define_constant(1.0) |
|
|
|
|
|
|
|
|
y = fd.ops.sigmoid(x) |
|
|
|
|
|
dy = fd.ops.mul(y, fd.ops.sub(one, y)) |
|
|
|
|
|
z = fd.ops.add(one, fd.ops.mul(x, fd.ops.sub(one, y))) |
|
|
|
|
|
term1 = fd.ops.mul(dy, z) |
|
|
|
|
|
|
|
|
term2 = fd.ops.mul(y, fd.ops.sub(fd.ops.sub(one, y), fd.ops.mul(x, dy))) |
|
|
|
|
|
grad_input = fd.ops.add(term1, term2) |
|
|
|
|
|
grad_input = fd.ops.cast(grad_input, dtype) |
|
|
|
|
|
fd.add_output(grad_input) |
|
|
|
|
|
|
|
|
@functools.lru_cache(maxsize=None) |
|
|
def silu_triple_backward_for( |
|
|
fd: FusionDefinition, |
|
|
dtype: torch.dtype, |
|
|
dim: int, |
|
|
size: torch.Size, |
|
|
stride: Tuple[int, ...], |
|
|
): |
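    """Registers an nvfuser definition of the third derivative of SiLU, built from
    y = sigmoid(x), dy = y * (1 - y), and ddy = (1 - 2y) * dy (see the ops below).
    """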
|
|
|
|
|
try: |
|
|
dtype = _torch_dtype_to_nvfuser[dtype] |
|
|
except KeyError: |
|
|
raise TypeError("Unsupported dtype") |
|
|
|
|
|
x = fd.define_tensor( |
|
|
shape=[-1] * dim, |
|
|
contiguity=nvfuser.compute_contiguity(size, stride), |
|
|
dtype=dtype, |
|
|
) |
|
|
one = fd.define_constant(1.0) |
|
|
two = fd.define_constant(2.0) |
|
|
|
|
|
|
|
|
y = fd.ops.sigmoid(x) |
|
|
|
|
|
dy = fd.ops.mul(y, fd.ops.sub(one, y)) |
|
|
|
|
|
ddy = fd.ops.mul(fd.ops.sub(one, fd.ops.mul(two, y)), dy) |
|
|
|
|
|
term1 = fd.ops.mul( |
|
|
ddy, fd.ops.sub(fd.ops.add(two, x), fd.ops.mul(two, fd.ops.mul(x, y))) |
|
|
) |
|
|
|
|
|
|
|
|
term2 = fd.ops.mul( |
|
|
dy, fd.ops.sub(one, fd.ops.mul(two, fd.ops.add(y, fd.ops.mul(x, dy)))) |
|
|
) |
|
|
|
|
|
grad_input = fd.ops.add(term1, term2) |
|
|
|
|
|
grad_input = fd.ops.cast(grad_input, dtype) |
|
|
|
|
|
fd.add_output(grad_input) |
|
|
|
|
|
|
|
|
class FusedSiLU(Function): |
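    """SiLU (x * sigmoid(x)) whose backward pass uses an nvfuser-fused first derivative.

    Usage sketch: ``y = FusedSiLU.apply(x)`` for a CUDA tensor ``x``.
    """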
|
|
|
|
|
@staticmethod |
|
|
def forward(ctx, x): |
|
|
|
|
|
ctx.save_for_backward(x) |
|
|
return torch.nn.functional.silu(x) |
|
|
|
|
|
@staticmethod |
|
|
def backward(ctx, grad_output): |
|
|
|
|
|
(x,) = ctx.saved_tensors |
|
|
return FusedSiLU_deriv_1.apply(x) * grad_output |
|
|
|
|
|
|
|
|
class FusedSiLU_deriv_1(Function): |
|
|
|
|
|
|
|
|
@staticmethod |
|
|
def forward(ctx, x): |
|
|
ctx.save_for_backward(x) |
|
|
with FusionDefinition() as fd: |
|
|
silu_backward_for(fd, x.dtype, x.dim(), x.size(), x.stride()) |
|
|
out = fd.execute([x])[0] |
|
|
return out |
|
|
|
|
|
@staticmethod |
|
|
def backward(ctx, grad_output): |
|
|
(x,) = ctx.saved_tensors |
|
|
return FusedSiLU_deriv_2.apply(x) * grad_output |
|
|
|
|
|
|
|
|
class FusedSiLU_deriv_2(Function): |
|
|
|
|
|
@staticmethod |
|
|
def forward(ctx, x): |
|
|
ctx.save_for_backward(x) |
|
|
with FusionDefinition() as fd: |
|
|
silu_double_backward_for(fd, x.dtype, x.dim(), x.size(), x.stride()) |
|
|
out = fd.execute([x])[0] |
|
|
return out |
|
|
|
|
|
@staticmethod |
|
|
def backward(ctx, grad_output): |
|
|
(x,) = ctx.saved_tensors |
|
|
return FusedSiLU_deriv_3.apply(x) * grad_output |
|
|
|
|
|
|
|
|
class FusedSiLU_deriv_3(Function): |
|
|
|
|
|
@staticmethod |
|
|
def forward(ctx, x): |
|
|
ctx.save_for_backward(x) |
|
|
with FusionDefinition() as fd: |
|
|
silu_triple_backward_for(fd, x.dtype, x.dim(), x.size(), x.stride()) |
|
|
out = fd.execute([x])[0] |
|
|
return out |
|
|
|
|
|
@staticmethod |
|
|
def backward(ctx, grad_output): |
|
|
(x,) = ctx.saved_tensors |
|
|
y = torch.sigmoid(x) |
|
|
dy = y * (1 - y) |
|
|
ddy = (1 - 2 * y) * dy |
|
|
dddy = (1 - 2 * y) * ddy - 2 * dy * dy |
|
|
z = 1 - 2 * (y + x * dy) |
|
|
term1 = dddy * (2 + x - 2 * x * y) |
|
|
term2 = 2 * ddy * z |
|
|
term3 = dy * (-2) * (2 * dy + x * ddy) |
|
|
return (term1 + term2 + term3) * grad_output |
|
|
|
|
|
|
|
|
from typing import Optional, Tuple, Union |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from dgl import DGLGraph |
|
|
from torch import Tensor |
|
|
from torch.autograd.function import once_differentiable |
|
|
|
|
|
|
|
|
try: |
|
|
from transformer_engine import pytorch as te |
|
|
|
|
|
te_imported = True |
|
|
except ImportError: |
|
|
te_imported = False |
|
|
|
|
|
|
|
|
class CustomSiLuLinearAutogradFunction(torch.autograd.Function): |
|
|
"""Custom SiLU + Linear autograd function""" |
|
|
|
|
|
@staticmethod |
|
|
def forward( |
|
|
ctx, |
|
|
features: torch.Tensor, |
|
|
weight: torch.Tensor, |
|
|
bias: torch.Tensor, |
|
|
) -> torch.Tensor: |
|
|
|
|
|
out = F.silu(features) |
|
|
out = F.linear(out, weight, bias) |
|
|
ctx.save_for_backward(features, weight) |
|
|
return out |
|
|
|
|
|
@staticmethod |
|
|
@once_differentiable |
|
|
def backward( |
|
|
ctx, grad_output: torch.Tensor |
|
|
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor],]: |
|
|
"""backward pass of the SiLU + Linear function""" |
|
|
|
|
|
from nvfuser import FusionDefinition |
|
|
|
|
|
|
|
|
|
|
|
( |
|
|
need_dgrad, |
|
|
need_wgrad, |
|
|
need_bgrad, |
|
|
) = ctx.needs_input_grad |
|
|
features, weight = ctx.saved_tensors |
|
|
|
|
|
grad_features = None |
|
|
grad_weight = None |
|
|
grad_bias = None |
|
|
|
|
|
if need_bgrad: |
|
|
grad_bias = grad_output.sum(dim=0) |
|
|
|
|
|
if need_wgrad: |
|
|
out = F.silu(features) |
|
|
grad_weight = grad_output.T @ out |
|
|
|
|
|
if need_dgrad: |
|
|
grad_features = grad_output @ weight |
|
|
|
|
|
with FusionDefinition() as fd: |
|
|
silu_backward_for( |
|
|
fd, |
|
|
features.dtype, |
|
|
features.dim(), |
|
|
features.size(), |
|
|
features.stride(), |
|
|
) |
|
|
|
|
|
grad_silu = fd.execute([features])[0] |
|
|
grad_features = grad_features * grad_silu |
|
|
|
|
|
return grad_features, grad_weight, grad_bias |
|
|
|
|
|
|
|
|
class MeshGraphMLP(nn.Module): |
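    """MLP of ``hidden_layers`` Linear+activation blocks followed by an output Linear
    and an optional LayerNorm/TELayerNorm; ``hidden_layers=None`` degenerates to
    ``nn.Identity``. With ``recompute_activation`` (SiLU only), activations are
    recomputed in backward via ``CustomSiLuLinearAutogradFunction`` to save memory."""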
|
|
|
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
input_dim: int, |
|
|
output_dim: int = 512, |
|
|
hidden_dim: int = 512, |
|
|
hidden_layers: Union[int, None] = 1, |
|
|
activation_fn: nn.Module = nn.SiLU(), |
|
|
norm_type: str = "LayerNorm", |
|
|
recompute_activation: bool = False, |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
if hidden_layers is not None: |
|
|
layers = [nn.Linear(input_dim, hidden_dim), activation_fn] |
|
|
self.hidden_layers = hidden_layers |
|
|
for _ in range(hidden_layers - 1): |
|
|
layers += [nn.Linear(hidden_dim, hidden_dim), activation_fn] |
|
|
layers.append(nn.Linear(hidden_dim, output_dim)) |
|
|
|
|
|
self.norm_type = norm_type |
|
|
if norm_type is not None: |
|
|
if norm_type not in [ |
|
|
"LayerNorm", |
|
|
"TELayerNorm", |
|
|
]: |
|
|
raise ValueError( |
|
|
f"Invalid norm type {norm_type}. Supported types are LayerNorm and TELayerNorm." |
|
|
) |
|
|
if norm_type == "TELayerNorm" and te_imported: |
|
|
norm_layer = te.LayerNorm |
|
|
elif norm_type == "TELayerNorm" and not te_imported: |
|
|
raise ValueError( |
|
|
"TELayerNorm requires transformer-engine to be installed." |
|
|
) |
|
|
else: |
|
|
norm_layer = getattr(nn, norm_type) |
|
|
layers.append(norm_layer(output_dim)) |
|
|
|
|
|
self.model = nn.Sequential(*layers) |
|
|
else: |
|
|
self.model = nn.Identity() |
|
|
|
|
|
if recompute_activation: |
|
|
if not isinstance(activation_fn, nn.SiLU): |
|
|
                raise ValueError(
                    f"recompute_activation only supports SiLU, got {activation_fn}"
                )
|
|
self.recompute_activation = True |
|
|
else: |
|
|
self.recompute_activation = False |
|
|
|
|
|
def default_forward(self, x: Tensor) -> Tensor: |
|
|
"""default forward pass of the MLP""" |
|
|
return self.model(x) |
|
|
|
|
|
@torch.jit.ignore() |
|
|
def custom_silu_linear_forward(self, x: Tensor) -> Tensor: |
|
|
"""forward pass of the MLP where SiLU is recomputed in backward""" |
|
|
lin = self.model[0] |
|
|
hidden = lin(x) |
|
|
for i in range(1, self.hidden_layers + 1): |
|
|
lin = self.model[2 * i] |
|
|
hidden = CustomSiLuLinearAutogradFunction.apply( |
|
|
hidden, lin.weight, lin.bias |
|
|
) |
|
|
|
|
|
if self.norm_type is not None: |
|
|
norm = self.model[2 * self.hidden_layers + 1] |
|
|
hidden = norm(hidden) |
|
|
return hidden |
|
|
|
|
|
def forward(self, x: Tensor) -> Tensor: |
|
|
if self.recompute_activation: |
|
|
return self.custom_silu_linear_forward(x) |
|
|
return self.default_forward(x) |
|
|
|
|
|
|
|
|
class MeshGraphEdgeMLPConcat(MeshGraphMLP): |
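    """Edge-feature MLP applied to the concatenation [efeat | src nfeat | dst nfeat].
    (``bias`` is accepted for signature parity with ``MeshGraphEdgeMLPSum`` and is
    currently unused.)"""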
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
efeat_dim: int = 512, |
|
|
src_dim: int = 512, |
|
|
dst_dim: int = 512, |
|
|
output_dim: int = 512, |
|
|
hidden_dim: int = 512, |
|
|
hidden_layers: int = 2, |
|
|
activation_fn: nn.Module = nn.SiLU(), |
|
|
norm_type: str = "LayerNorm", |
|
|
bias: bool = True, |
|
|
recompute_activation: bool = False, |
|
|
): |
|
|
cat_dim = efeat_dim + src_dim + dst_dim |
|
|
        super().__init__(
|
|
cat_dim, |
|
|
output_dim, |
|
|
hidden_dim, |
|
|
hidden_layers, |
|
|
activation_fn, |
|
|
norm_type, |
|
|
recompute_activation, |
|
|
) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
efeat: Tensor, |
|
|
        nfeat: Union[Tensor, Tuple[Tensor, Tensor]],
|
|
graph: Union[DGLGraph, CuGraphCSC], |
|
|
) -> Tensor: |
|
|
efeat = concat_efeat(efeat, nfeat, graph) |
|
|
efeat = self.model(efeat) |
|
|
return efeat |
|
|
|
|
|
|
|
|
class MeshGraphEdgeMLPSum(nn.Module): |
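    """Edge-feature MLP using the "concat trick": instead of materializing the
    concatenated [efeat | src | dst] input, the first Linear's weight is split into
    three blocks applied separately and the per-edge partial products are summed,
    which is mathematically equivalent but avoids the large intermediate tensor."""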
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
efeat_dim: int, |
|
|
src_dim: int, |
|
|
dst_dim: int, |
|
|
output_dim: int = 512, |
|
|
hidden_dim: int = 512, |
|
|
hidden_layers: int = 1, |
|
|
activation_fn: nn.Module = nn.SiLU(), |
|
|
norm_type: str = "LayerNorm", |
|
|
bias: bool = True, |
|
|
recompute_activation: bool = False, |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
self.efeat_dim = efeat_dim |
|
|
self.src_dim = src_dim |
|
|
self.dst_dim = dst_dim |
|
|
|
|
|
tmp_lin = nn.Linear(efeat_dim + src_dim + dst_dim, hidden_dim, bias=bias) |
|
|
|
|
|
orig_weight = tmp_lin.weight |
|
|
w_efeat, w_src, w_dst = torch.split( |
|
|
orig_weight, [efeat_dim, src_dim, dst_dim], dim=1 |
|
|
) |
|
|
self.lin_efeat = nn.Parameter(w_efeat) |
|
|
self.lin_src = nn.Parameter(w_src) |
|
|
self.lin_dst = nn.Parameter(w_dst) |
|
|
|
|
|
if bias: |
|
|
self.bias = tmp_lin.bias |
|
|
else: |
|
|
self.bias = None |
|
|
|
|
|
layers = [activation_fn] |
|
|
self.hidden_layers = hidden_layers |
|
|
for _ in range(hidden_layers - 1): |
|
|
layers += [nn.Linear(hidden_dim, hidden_dim), activation_fn] |
|
|
layers.append(nn.Linear(hidden_dim, output_dim)) |
|
|
|
|
|
self.norm_type = norm_type |
|
|
if norm_type is not None: |
|
|
if norm_type not in [ |
|
|
"LayerNorm", |
|
|
"TELayerNorm", |
|
|
]: |
|
|
raise ValueError( |
|
|
f"Invalid norm type {norm_type}. Supported types are LayerNorm and TELayerNorm." |
|
|
) |
|
|
if norm_type == "TELayerNorm" and te_imported: |
|
|
norm_layer = te.LayerNorm |
|
|
elif norm_type == "TELayerNorm" and not te_imported: |
|
|
raise ValueError( |
|
|
"TELayerNorm requires transformer-engine to be installed." |
|
|
) |
|
|
else: |
|
|
norm_layer = getattr(nn, norm_type) |
|
|
layers.append(norm_layer(output_dim)) |
|
|
|
|
|
self.model = nn.Sequential(*layers) |
|
|
|
|
|
if recompute_activation: |
|
|
if not isinstance(activation_fn, nn.SiLU): |
|
|
                raise ValueError(
                    f"recompute_activation only supports SiLU, got {activation_fn}"
                )
|
|
self.recompute_activation = True |
|
|
else: |
|
|
self.recompute_activation = False |
|
|
|
|
|
def forward_truncated_sum( |
|
|
self, |
|
|
efeat: Tensor, |
|
|
        nfeat: Union[Tensor, Tuple[Tensor, Tensor]],
|
|
graph: Union[DGLGraph, CuGraphCSC], |
|
|
) -> Tensor: |
|
|
|
|
|
if isinstance(nfeat, Tensor): |
|
|
src_feat, dst_feat = nfeat, nfeat |
|
|
else: |
|
|
src_feat, dst_feat = nfeat |
|
|
mlp_efeat = F.linear(efeat, self.lin_efeat, None) |
|
|
mlp_src = F.linear(src_feat, self.lin_src, None) |
|
|
mlp_dst = F.linear(dst_feat, self.lin_dst, self.bias) |
|
|
mlp_sum = sum_efeat(mlp_efeat, (mlp_src, mlp_dst), graph) |
|
|
return mlp_sum |
|
|
|
|
|
def default_forward( |
|
|
self, |
|
|
efeat: Tensor, |
|
|
        nfeat: Union[Tensor, Tuple[Tensor, Tensor]],
|
|
graph: Union[DGLGraph, CuGraphCSC], |
|
|
) -> Tensor: |
|
|
"""Default forward pass of the truncated MLP.""" |
|
|
mlp_sum = self.forward_truncated_sum( |
|
|
efeat, |
|
|
nfeat, |
|
|
graph, |
|
|
) |
|
|
return self.model(mlp_sum) |
|
|
|
|
|
def custom_silu_linear_forward( |
|
|
self, |
|
|
efeat: Tensor, |
|
|
        nfeat: Union[Tensor, Tuple[Tensor, Tensor]],
|
|
graph: Union[DGLGraph, CuGraphCSC], |
|
|
) -> Tensor: |
|
|
"""Forward pass of the truncated MLP with custom SiLU function.""" |
|
|
mlp_sum = self.forward_truncated_sum( |
|
|
efeat, |
|
|
nfeat, |
|
|
graph, |
|
|
) |
|
|
lin = self.model[1] |
|
|
hidden = CustomSiLuLinearAutogradFunction.apply(mlp_sum, lin.weight, lin.bias) |
|
|
for i in range(2, self.hidden_layers + 1): |
|
|
lin = self.model[2 * i - 1] |
|
|
hidden = CustomSiLuLinearAutogradFunction.apply( |
|
|
hidden, lin.weight, lin.bias |
|
|
) |
|
|
|
|
|
if self.norm_type is not None: |
|
|
norm = self.model[2 * self.hidden_layers] |
|
|
hidden = norm(hidden) |
|
|
return hidden |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
efeat: Tensor, |
|
|
        nfeat: Union[Tensor, Tuple[Tensor, Tensor]],
|
|
graph: Union[DGLGraph, CuGraphCSC], |
|
|
) -> Tensor: |
|
|
if self.recompute_activation: |
|
|
return self.custom_silu_linear_forward(efeat, nfeat, graph) |
|
|
return self.default_forward(efeat, nfeat, graph) |
|
|
|
|
|
|
|
|
from typing import Tuple |
|
|
|
|
|
import torch.nn as nn |
|
|
from torch import Tensor |
|
|
|
|
|
|
|
|
|
|
|
class GraphCastEncoderEmbedder(nn.Module): |
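    """Embeds raw grid-node, mesh-node, mesh-edge, and grid2mesh-edge features into a
    shared latent space, one ``MeshGraphMLP`` per feature type."""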
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
input_dim_grid_nodes: int = 474, |
|
|
input_dim_mesh_nodes: int = 3, |
|
|
input_dim_edges: int = 4, |
|
|
output_dim: int = 512, |
|
|
hidden_dim: int = 512, |
|
|
hidden_layers: int = 1, |
|
|
activation_fn: nn.Module = nn.SiLU(), |
|
|
norm_type: str = "LayerNorm", |
|
|
recompute_activation: bool = False, |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
self.grid_node_mlp = MeshGraphMLP( |
|
|
input_dim=input_dim_grid_nodes, |
|
|
output_dim=output_dim, |
|
|
hidden_dim=hidden_dim, |
|
|
hidden_layers=hidden_layers, |
|
|
activation_fn=activation_fn, |
|
|
norm_type=norm_type, |
|
|
recompute_activation=recompute_activation, |
|
|
) |
|
|
|
|
|
|
|
|
self.mesh_node_mlp = MeshGraphMLP( |
|
|
input_dim=input_dim_mesh_nodes, |
|
|
output_dim=output_dim, |
|
|
hidden_dim=hidden_dim, |
|
|
hidden_layers=hidden_layers, |
|
|
activation_fn=activation_fn, |
|
|
norm_type=norm_type, |
|
|
recompute_activation=recompute_activation, |
|
|
) |
|
|
|
|
|
|
|
|
self.mesh_edge_mlp = MeshGraphMLP( |
|
|
input_dim=input_dim_edges, |
|
|
output_dim=output_dim, |
|
|
hidden_dim=hidden_dim, |
|
|
hidden_layers=hidden_layers, |
|
|
activation_fn=activation_fn, |
|
|
norm_type=norm_type, |
|
|
recompute_activation=recompute_activation, |
|
|
) |
|
|
|
|
|
|
|
|
self.grid2mesh_edge_mlp = MeshGraphMLP( |
|
|
input_dim=input_dim_edges, |
|
|
output_dim=output_dim, |
|
|
hidden_dim=hidden_dim, |
|
|
hidden_layers=hidden_layers, |
|
|
activation_fn=activation_fn, |
|
|
norm_type=norm_type, |
|
|
recompute_activation=recompute_activation, |
|
|
) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
grid_nfeat: Tensor, |
|
|
mesh_nfeat: Tensor, |
|
|
g2m_efeat: Tensor, |
|
|
mesh_efeat: Tensor, |
|
|
) -> Tuple[Tensor, Tensor, Tensor, Tensor]: |
|
|
|
|
|
grid_nfeat = self.grid_node_mlp(grid_nfeat) |
|
|
mesh_nfeat = self.mesh_node_mlp(mesh_nfeat) |
|
|
|
|
|
g2m_efeat = self.grid2mesh_edge_mlp(g2m_efeat) |
|
|
mesh_efeat = self.mesh_edge_mlp(mesh_efeat) |
|
|
return grid_nfeat, mesh_nfeat, g2m_efeat, mesh_efeat |
|
|
|
|
|
|
|
|
class GraphCastDecoderEmbedder(nn.Module): |
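    """Embeds raw mesh2grid edge features into the latent space."""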
|
|
|
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
input_dim_edges: int = 4, |
|
|
output_dim: int = 512, |
|
|
hidden_dim: int = 512, |
|
|
hidden_layers: int = 1, |
|
|
activation_fn: nn.Module = nn.SiLU(), |
|
|
norm_type: str = "LayerNorm", |
|
|
recompute_activation: bool = False, |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
self.mesh2grid_edge_mlp = MeshGraphMLP( |
|
|
input_dim=input_dim_edges, |
|
|
output_dim=output_dim, |
|
|
hidden_dim=hidden_dim, |
|
|
hidden_layers=hidden_layers, |
|
|
activation_fn=activation_fn, |
|
|
norm_type=norm_type, |
|
|
recompute_activation=recompute_activation, |
|
|
) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
m2g_efeat: Tensor, |
|
|
) -> Tensor: |
|
|
m2g_efeat = self.mesh2grid_edge_mlp(m2g_efeat) |
|
|
return m2g_efeat |
|
|
|
|
|
|
|
|
|
|
|
from typing import Union |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from dgl import DGLGraph |
|
|
from torch import Tensor |
|
|
|
|
|
|
|
|
|
|
|
class MeshGraphDecoder(nn.Module): |
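    """Mesh2grid decoder: updates edge features from (mesh src, grid dst) node pairs,
    aggregates them onto grid nodes, and applies a residual node MLP."""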
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
aggregation: str = "sum", |
|
|
input_dim_src_nodes: int = 512, |
|
|
input_dim_dst_nodes: int = 512, |
|
|
input_dim_edges: int = 512, |
|
|
output_dim_dst_nodes: int = 512, |
|
|
output_dim_edges: int = 512, |
|
|
hidden_dim: int = 512, |
|
|
hidden_layers: int = 1, |
|
|
activation_fn: nn.Module = nn.SiLU(), |
|
|
norm_type: str = "LayerNorm", |
|
|
do_concat_trick: bool = False, |
|
|
recompute_activation: bool = False, |
|
|
): |
|
|
super().__init__() |
|
|
self.aggregation = aggregation |
|
|
|
|
|
MLP = MeshGraphEdgeMLPSum if do_concat_trick else MeshGraphEdgeMLPConcat |
|
|
|
|
|
self.edge_mlp = MLP( |
|
|
efeat_dim=input_dim_edges, |
|
|
src_dim=input_dim_src_nodes, |
|
|
dst_dim=input_dim_dst_nodes, |
|
|
output_dim=output_dim_edges, |
|
|
hidden_dim=hidden_dim, |
|
|
hidden_layers=hidden_layers, |
|
|
activation_fn=activation_fn, |
|
|
norm_type=norm_type, |
|
|
recompute_activation=recompute_activation, |
|
|
) |
|
|
|
|
|
|
|
|
self.node_mlp = MeshGraphMLP( |
|
|
input_dim=input_dim_dst_nodes + output_dim_edges, |
|
|
output_dim=output_dim_dst_nodes, |
|
|
hidden_dim=hidden_dim, |
|
|
hidden_layers=hidden_layers, |
|
|
activation_fn=activation_fn, |
|
|
norm_type=norm_type, |
|
|
recompute_activation=recompute_activation, |
|
|
) |
|
|
|
|
|
@torch.jit.ignore() |
|
|
def forward( |
|
|
self, |
|
|
m2g_efeat: Tensor, |
|
|
grid_nfeat: Tensor, |
|
|
mesh_nfeat: Tensor, |
|
|
graph: Union[DGLGraph, CuGraphCSC], |
|
|
) -> Tensor: |
|
|
|
|
|
efeat = self.edge_mlp(m2g_efeat, (mesh_nfeat, grid_nfeat), graph) |
|
|
|
|
|
cat_feat = aggregate_and_concat(efeat, grid_nfeat, graph, self.aggregation) |
|
|
|
|
|
dst_feat = self.node_mlp(cat_feat) + grid_nfeat |
|
|
return dst_feat |
|
|
|
|
|
|
|
|
|
|
|
from typing import Tuple, Union |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from dgl import DGLGraph |
|
|
from torch import Tensor |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MeshGraphEncoder(nn.Module): |
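    """Grid2mesh encoder: updates edge features from (grid src, mesh dst) node pairs,
    aggregates them onto mesh nodes, and applies residual node MLPs to both sides."""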
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
aggregation: str = "sum", |
|
|
input_dim_src_nodes: int = 512, |
|
|
input_dim_dst_nodes: int = 512, |
|
|
input_dim_edges: int = 512, |
|
|
output_dim_src_nodes: int = 512, |
|
|
output_dim_dst_nodes: int = 512, |
|
|
output_dim_edges: int = 512, |
|
|
hidden_dim: int = 512, |
|
|
hidden_layers: int = 1, |
|
|
        activation_fn: nn.Module = nn.SiLU(),
|
|
norm_type: str = "LayerNorm", |
|
|
do_concat_trick: bool = False, |
|
|
recompute_activation: bool = False, |
|
|
): |
|
|
super().__init__() |
|
|
self.aggregation = aggregation |
|
|
|
|
|
MLP = MeshGraphEdgeMLPSum if do_concat_trick else MeshGraphEdgeMLPConcat |
|
|
|
|
|
self.edge_mlp = MLP( |
|
|
efeat_dim=input_dim_edges, |
|
|
src_dim=input_dim_src_nodes, |
|
|
dst_dim=input_dim_dst_nodes, |
|
|
output_dim=output_dim_edges, |
|
|
hidden_dim=hidden_dim, |
|
|
hidden_layers=hidden_layers, |
|
|
activation_fn=activation_fn, |
|
|
norm_type=norm_type, |
|
|
recompute_activation=recompute_activation, |
|
|
) |
|
|
|
|
|
|
|
|
self.src_node_mlp = MeshGraphMLP( |
|
|
input_dim=input_dim_src_nodes, |
|
|
output_dim=output_dim_src_nodes, |
|
|
hidden_dim=hidden_dim, |
|
|
hidden_layers=hidden_layers, |
|
|
activation_fn=activation_fn, |
|
|
norm_type=norm_type, |
|
|
recompute_activation=recompute_activation, |
|
|
) |
|
|
|
|
|
|
|
|
self.dst_node_mlp = MeshGraphMLP( |
|
|
input_dim=input_dim_dst_nodes + output_dim_edges, |
|
|
output_dim=output_dim_dst_nodes, |
|
|
hidden_dim=hidden_dim, |
|
|
hidden_layers=hidden_layers, |
|
|
activation_fn=activation_fn, |
|
|
norm_type=norm_type, |
|
|
recompute_activation=recompute_activation, |
|
|
) |
|
|
|
|
|
@torch.jit.ignore() |
|
|
def forward( |
|
|
self, |
|
|
g2m_efeat: Tensor, |
|
|
grid_nfeat: Tensor, |
|
|
mesh_nfeat: Tensor, |
|
|
graph: Union[DGLGraph, CuGraphCSC], |
|
|
) -> Tuple[Tensor, Tensor]: |
|
|
|
|
|
|
|
|
efeat = self.edge_mlp(g2m_efeat, (grid_nfeat, mesh_nfeat), graph) |
|
|
|
|
|
cat_feat = aggregate_and_concat(efeat, mesh_nfeat, graph, self.aggregation) |
|
|
|
|
|
mesh_nfeat = mesh_nfeat + self.dst_node_mlp(cat_feat) |
|
|
grid_nfeat = grid_nfeat + self.src_node_mlp(grid_nfeat) |
|
|
return grid_nfeat, mesh_nfeat |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
Tensor = torch.Tensor |
|
|
|
|
|
|
|
|
class Identity(nn.Module): |
|
|
|
|
|
def forward(self, x: Tensor) -> Tensor: |
|
|
return x |
|
|
|
|
|
|
|
|
class Stan(nn.Module): |
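    """Self-scalable tanh (Stan) activation: ``tanh(x) * (1 + beta * x)`` with a
    learnable per-feature ``beta``."""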
|
|
|
|
|
def __init__(self, out_features: int = 1): |
|
|
super().__init__() |
|
|
self.beta = nn.Parameter(torch.ones(out_features)) |
|
|
|
|
|
def forward(self, x: Tensor) -> Tensor: |
|
|
if x.shape[-1] != self.beta.shape[-1]: |
|
|
raise ValueError( |
|
|
f"The last dimension of the input must be equal to the dimension of Stan parameters. Got inputs: {x.shape}, params: {self.beta.shape}" |
|
|
) |
|
|
return torch.tanh(x) * (1.0 + self.beta * x) |
|
|
|
|
|
|
|
|
class SquarePlus(nn.Module): |
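    """Squareplus activation, a smooth ReLU-like function: ``0.5 * (x + sqrt(x^2 + b))`` with b = 4."""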
|
|
|
|
|
def __init__(self): |
|
|
super().__init__() |
|
|
self.b = 4 |
|
|
|
|
|
def forward(self, x: Tensor) -> Tensor: |
|
|
return 0.5 * (x + torch.sqrt(x * x + self.b)) |
|
|
|
|
|
|
|
|
class CappedLeakyReLU(torch.nn.Module): |
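    """LeakyReLU with outputs clamped from above at ``cap_value``."""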
|
|
|
|
|
def __init__(self, cap_value=1.0, **kwargs): |
|
|
|
|
|
super().__init__() |
|
|
self.add_module("leaky_relu", torch.nn.LeakyReLU(**kwargs)) |
|
|
self.register_buffer("cap", torch.tensor(cap_value, dtype=torch.float32)) |
|
|
|
|
|
def forward(self, inputs): |
|
|
x = self.leaky_relu(inputs) |
|
|
x = torch.clamp(x, max=self.cap) |
|
|
return x |
|
|
|
|
|
|
|
|
class CappedGELU(torch.nn.Module): |
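    """GELU with outputs clamped from above at ``cap_value``."""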
|
|
|
|
|
|
|
|
def __init__(self, cap_value=1.0, **kwargs): |
|
|
|
|
|
|
|
|
super().__init__() |
|
|
self.add_module("gelu", torch.nn.GELU(**kwargs)) |
|
|
self.register_buffer("cap", torch.tensor(cap_value, dtype=torch.float32)) |
|
|
|
|
|
def forward(self, inputs): |
|
|
x = self.gelu(inputs) |
|
|
x = torch.clamp(x, max=self.cap) |
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
ACT2FN = { |
|
|
"relu": nn.ReLU, |
|
|
"leaky_relu": (nn.LeakyReLU, {"negative_slope": 0.1}), |
|
|
"prelu": nn.PReLU, |
|
|
"relu6": nn.ReLU6, |
|
|
"elu": nn.ELU, |
|
|
"selu": nn.SELU, |
|
|
"silu": nn.SiLU, |
|
|
"gelu": nn.GELU, |
|
|
"sigmoid": nn.Sigmoid, |
|
|
"logsigmoid": nn.LogSigmoid, |
|
|
"softplus": nn.Softplus, |
|
|
"softshrink": nn.Softshrink, |
|
|
"softsign": nn.Softsign, |
|
|
"tanh": nn.Tanh, |
|
|
"tanhshrink": nn.Tanhshrink, |
|
|
"threshold": (nn.Threshold, {"threshold": 1.0, "value": 1.0}), |
|
|
"hardtanh": nn.Hardtanh, |
|
|
"identity": Identity, |
|
|
"stan": Stan, |
|
|
"squareplus": SquarePlus, |
|
|
"cappek_leaky_relu": CappedLeakyReLU, |
|
|
"capped_gelu": CappedGELU, |
|
|
} |
|
|
|
|
|
|
|
|
def get_activation(activation: str) -> nn.Module: |
|
|
|
|
|
try: |
|
|
activation = activation.lower() |
|
|
module = ACT2FN[activation] |
|
|
if isinstance(module, tuple): |
|
|
return module[0](**module[1]) |
|
|
else: |
|
|
return module() |
|
|
except KeyError: |
|
|
raise KeyError( |
|
|
f"Activation function {activation} not found. Available options are: {list(ACT2FN.keys())}" |
|
|
) |
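
# Minimal usage sketch (values are illustrative):
#   act = get_activation("gelu")        # -> nn.GELU()
#   act = get_activation("leaky_relu")  # -> nn.LeakyReLU(negative_slope=0.1)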
|
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass
from typing import Optional
|
|
|
|
|
|
|
|
@dataclass |
|
|
class ModelMetaData: |
|
|
"""Data class for storing essential meta data needed for all Modulus Models""" |
|
|
|
|
|
|
|
|
name: str = "ModulusModule" |
|
|
|
|
|
jit: bool = False |
|
|
cuda_graphs: bool = False |
|
|
amp: bool = False |
|
|
    amp_cpu: Optional[bool] = None
    amp_gpu: Optional[bool] = None
|
|
torch_fx: bool = False |
|
|
|
|
|
bf16: bool = False |
|
|
|
|
|
onnx: bool = False |
|
|
    onnx_gpu: Optional[bool] = None
    onnx_cpu: Optional[bool] = None
|
|
onnx_runtime: bool = False |
|
|
trt: bool = False |
|
|
|
|
|
var_dim: int = -1 |
|
|
func_torch: bool = False |
|
|
auto_grad: bool = False |
|
|
|
|
|
def __post_init__(self): |
|
|
self.amp_cpu = self.amp if self.amp_cpu is None else self.amp_cpu |
|
|
self.amp_gpu = self.amp if self.amp_gpu is None else self.amp_gpu |
|
|
self.onnx_cpu = self.onnx if self.onnx_cpu is None else self.onnx_cpu |
|
|
self.onnx_gpu = self.onnx if self.onnx_gpu is None else self.onnx_gpu |
|
|
|
|
|
|
|
|
|
|
|
from importlib.metadata import EntryPoint, entry_points |
|
|
from typing import List, Union |
|
|
|
|
|
|
|
|
import importlib_metadata

import modulus  # assumed importable here; provides ``modulus.Module`` for the subclass check below
|
|
|
|
|
|
|
|
class ModelRegistry: |
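    """Registry of model classes, shared across instances (Borg pattern) and seeded
    from the ``modulus.models`` entry-point group."""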
|
|
_shared_state = {"_model_registry": None} |
|
|
|
|
|
def __new__(cls, *args, **kwargs): |
|
|
obj = super(ModelRegistry, cls).__new__(cls) |
|
|
obj.__dict__ = cls._shared_state |
|
|
if cls._shared_state["_model_registry"] is None: |
|
|
cls._shared_state["_model_registry"] = cls._construct_registry() |
|
|
return obj |
|
|
|
|
|
@staticmethod |
|
|
def _construct_registry() -> dict: |
|
|
registry = {} |
|
|
entrypoints = entry_points(group="modulus.models") |
|
|
for entry_point in entrypoints: |
|
|
registry[entry_point.name] = entry_point |
|
|
return registry |
|
|
|
|
|
def register(self, model: "modulus.Module", name: Union[str, None] = None) -> None: |
|
|
|
|
|
|
|
|
if not issubclass(model, modulus.Module): |
|
|
raise ValueError( |
|
|
f"Only subclasses of modulus.Module can be registered. " |
|
|
f"Provided model is of type {type(model)}" |
|
|
) |
|
|
|
|
|
|
|
|
if name is None: |
|
|
name = model.__name__ |
|
|
|
|
|
|
|
|
if name in self._model_registry: |
|
|
raise ValueError(f"Name {name} already in use") |
|
|
|
|
|
|
|
|
self._model_registry[name] = model |
|
|
|
|
|
def factory(self, name: str) -> "modulus.Module": |
|
|
|
|
|
model = self._model_registry.get(name) |
|
|
if model is not None: |
|
|
if isinstance(model, (EntryPoint, importlib_metadata.EntryPoint)): |
|
|
model = model.load() |
|
|
return model |
|
|
|
|
|
raise KeyError(f"No model is registered under the name {name}") |
|
|
|
|
|
def list_models(self) -> List[str]: |
|
|
|
|
|
return list(self._model_registry.keys()) |
|
|
|
|
|
def __clear_registry__(self): |
|
|
|
|
|
self._model_registry = {} |
|
|
|
|
|
def __restore_registry__(self): |
|
|
|
|
|
self._model_registry = self._construct_registry() |
|
|
|
|
|
|
|
|
import importlib |
|
|
import inspect |
|
|
import json |
|
|
import logging |
|
|
import os |
|
|
import tarfile |
|
|
import tempfile |
|
|
from pathlib import Path |
|
|
from typing import Any, Dict, Union |
|
|
|
|
|
import torch |
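
# NOTE: ``_download_cached`` and ``_get_fs`` used below are filesystem helpers whose
# import is elided in this excerpt (they resolve remote/local checkpoint paths);
# ``modulus.__version__`` likewise assumes the package namespace is available.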
|
|
|
|
|
|
|
|
|
|
|
class Module(torch.nn.Module): |
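    """Base class for Modulus models.

    Records constructor arguments at instantiation so instances can be saved to and
    restored from ``.mdlus`` checkpoint archives (weights + args + metadata).

    Usage sketch (hypothetical model)::

        model = MyModel(dim=64)
        model.save("my_model.mdlus")
        model = Module.from_checkpoint("my_model.mdlus")
    """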
|
|
_file_extension = ".mdlus" |
|
|
    __model_checkpoint_version__ = "0.1.0"
|
|
|
|
|
def __new__(cls, *args, **kwargs): |
|
|
out = super().__new__(cls) |
|
|
|
|
|
|
|
|
sig = inspect.signature(cls.__init__) |
|
|
|
|
|
|
|
|
bound_args = sig.bind_partial( |
|
|
*([None] + list(args)), **kwargs |
|
|
) |
|
|
bound_args.apply_defaults() |
|
|
|
|
|
|
|
|
instantiate_args = {} |
|
|
for param, (k, v) in zip(sig.parameters.values(), bound_args.arguments.items()): |
|
|
|
|
|
if k == "self": |
|
|
continue |
|
|
|
|
|
|
|
|
if param.kind == param.VAR_KEYWORD: |
|
|
instantiate_args.update(v) |
|
|
else: |
|
|
instantiate_args[k] = v |
|
|
|
|
|
|
|
|
out._args = { |
|
|
"__name__": cls.__name__, |
|
|
"__module__": cls.__module__, |
|
|
"__args__": instantiate_args, |
|
|
} |
|
|
return out |
|
|
|
|
|
def __init__(self, meta: Union[ModelMetaData, None] = None): |
|
|
super().__init__() |
|
|
self.meta = meta |
|
|
self.register_buffer("device_buffer", torch.empty(0)) |
|
|
self._setup_logger() |
|
|
|
|
|
def _setup_logger(self): |
|
|
self.logger = logging.getLogger("core.module") |
|
|
handler = logging.StreamHandler() |
|
|
formatter = logging.Formatter( |
|
|
"[%(asctime)s - %(levelname)s] %(message)s", datefmt="%H:%M:%S" |
|
|
) |
|
|
handler.setFormatter(formatter) |
|
|
self.logger.addHandler(handler) |
|
|
self.logger.setLevel(logging.WARNING) |
|
|
|
|
|
@staticmethod |
|
|
def _safe_members(tar, local_path): |
|
|
        for member in tar.getmembers():
            # Keep only members that cannot escape the extraction directory: no "..",
            # no absolute paths, and the resolved path must stay under local_path.
            if (
                ".." not in member.name
                and not os.path.isabs(member.name)
                and os.path.realpath(os.path.join(local_path, member.name)).startswith(
                    os.path.realpath(local_path)
                )
            ):
                yield member
            else:
                print(f"Skipping potentially malicious file: {member.name}")
|
|
|
|
|
@classmethod |
|
|
def instantiate(cls, arg_dict: Dict[str, Any]) -> "Module": |
|
|
|
|
|
_cls_name = arg_dict["__name__"] |
|
|
registry = ModelRegistry() |
|
|
if cls.__name__ == arg_dict["__name__"]: |
|
|
_cls = cls |
|
|
elif _cls_name in registry.list_models(): |
|
|
_cls = registry.factory(_cls_name) |
|
|
else: |
|
|
try: |
|
|
|
|
|
_mod = importlib.import_module(arg_dict["__module__"]) |
|
|
_cls = getattr(_mod, arg_dict["__name__"]) |
|
|
except AttributeError: |
|
|
|
|
|
_cls = cls |
|
|
return _cls(**arg_dict["__args__"]) |
|
|
|
|
|
def debug(self): |
|
|
"""Turn on debug logging""" |
|
|
self.logger.handlers.clear() |
|
|
handler = logging.StreamHandler() |
|
|
formatter = logging.Formatter( |
|
|
f"[%(asctime)s - %(levelname)s - {self.meta.name}] %(message)s", |
|
|
datefmt="%Y-%m-%d %H:%M:%S", |
|
|
) |
|
|
handler.setFormatter(formatter) |
|
|
self.logger.addHandler(handler) |
|
|
self.logger.setLevel(logging.DEBUG) |
|
|
|
|
|
|
|
|
|
|
|
def save(self, file_name: Union[str, None] = None, verbose: bool = False) -> None: |
|
|
|
|
|
if file_name is not None and not file_name.endswith(self._file_extension): |
|
|
raise ValueError( |
|
|
f"File name must end with {self._file_extension} extension" |
|
|
) |
|
|
|
|
|
with tempfile.TemporaryDirectory() as temp_dir: |
|
|
local_path = Path(temp_dir) |
|
|
|
|
|
torch.save(self.state_dict(), local_path / "model.pt") |
|
|
|
|
|
with open(local_path / "args.json", "w") as f: |
|
|
json.dump(self._args, f) |
|
|
|
|
|
|
|
|
metadata_info = { |
|
|
"modulus_version": modulus.__version__, |
|
|
"mdlus_file_version": self.__model_checkpoint_version__, |
|
|
} |
|
|
|
|
|
if verbose: |
|
|
import git |
|
|
|
|
|
try: |
|
|
repo = git.Repo(search_parent_directories=True) |
|
|
metadata_info["git_hash"] = repo.head.object.hexsha |
|
|
except git.InvalidGitRepositoryError: |
|
|
metadata_info["git_hash"] = None |
|
|
|
|
|
with open(local_path / "metadata.json", "w") as f: |
|
|
json.dump(metadata_info, f) |
|
|
|
|
|
|
|
|
with tarfile.open(local_path / "model.tar", "w") as tar: |
|
|
for file in local_path.iterdir(): |
|
|
tar.add(str(file), arcname=file.name) |
|
|
|
|
|
if file_name is None: |
|
|
file_name = self.meta.name + ".mdlus" |
|
|
|
|
|
|
|
|
fs = _get_fs(file_name) |
|
|
fs.put(str(local_path / "model.tar"), file_name) |
|
|
|
|
|
@staticmethod |
|
|
    def _check_checkpoint(local_path: Path) -> None:
|
|
if not local_path.joinpath("args.json").exists(): |
|
|
raise IOError("File 'args.json' not found in checkpoint") |
|
|
|
|
|
if not local_path.joinpath("metadata.json").exists(): |
|
|
raise IOError("File 'metadata.json' not found in checkpoint") |
|
|
|
|
|
if not local_path.joinpath("model.pt").exists(): |
|
|
raise IOError("Model weights 'model.pt' not found in checkpoint") |
|
|
|
|
|
|
|
|
with open(local_path.joinpath("metadata.json"), "r") as f: |
|
|
metadata_info = json.load(f) |
|
|
if ( |
|
|
metadata_info["mdlus_file_version"] |
|
|
!= Module.__model_checkpoint_version__ |
|
|
): |
|
|
raise IOError( |
|
|
f"Model checkpoint version {metadata_info['mdlus_file_version']} is not compatible with current version {Module.__version__}" |
|
|
) |
|
|
|
|
|
def load( |
|
|
self, |
|
|
file_name: str, |
|
|
map_location: Union[None, str, torch.device] = None, |
|
|
strict: bool = True, |
|
|
) -> None: |
|
|
|
|
|
|
|
|
cached_file_name = _download_cached(file_name) |
|
|
|
|
|
|
|
|
with tempfile.TemporaryDirectory() as temp_dir: |
|
|
local_path = Path(temp_dir) |
|
|
|
|
|
|
|
|
with tarfile.open(cached_file_name, "r") as tar: |
|
|
tar.extractall( |
|
|
path=local_path, members=list(Module._safe_members(tar, local_path)) |
|
|
) |
|
|
|
|
|
|
|
|
Module._check_checkpoint(local_path) |
|
|
|
|
|
|
|
|
device = map_location if map_location is not None else self.device |
|
|
model_dict = torch.load( |
|
|
local_path.joinpath("model.pt"), map_location=device |
|
|
) |
|
|
self.load_state_dict(model_dict, strict=strict) |
|
|
|
|
|
@classmethod |
|
|
def from_checkpoint(cls, file_name: str) -> "Module": |
|
|
|
|
|
|
|
|
cached_file_name = _download_cached(file_name) |
|
|
|
|
|
|
|
|
with tempfile.TemporaryDirectory() as temp_dir: |
|
|
local_path = Path(temp_dir) |
|
|
|
|
|
|
|
|
with tarfile.open(cached_file_name, "r") as tar: |
|
|
tar.extractall( |
|
|
path=local_path, members=list(cls._safe_members(tar, local_path)) |
|
|
) |
|
|
|
|
|
|
|
|
Module._check_checkpoint(local_path) |
|
|
|
|
|
|
|
|
with open(local_path.joinpath("args.json"), "r") as f: |
|
|
args = json.load(f) |
|
|
model = cls.instantiate(args) |
|
|
|
|
|
|
|
|
model_dict = torch.load( |
|
|
local_path.joinpath("model.pt"), map_location=model.device |
|
|
) |
|
|
model.load_state_dict(model_dict) |
|
|
|
|
|
return model |
|
|
|
|
|
@staticmethod |
|
|
def from_torch( |
|
|
torch_model_class: torch.nn.Module, meta: ModelMetaData = None |
|
|
) -> "Module": |
|
|
|
|
|
|
|
|
|
|
|
class ModulusModel(Module): |
|
|
def __init__(self, *args, **kwargs): |
|
|
super().__init__(meta=meta) |
|
|
self.inner_model = torch_model_class(*args, **kwargs) |
|
|
|
|
|
def forward(self, x): |
|
|
return self.inner_model(x) |
|
|
|
|
|
|
|
|
init_argspec = inspect.getfullargspec(torch_model_class.__init__) |
|
|
model_argnames = init_argspec.args[1:] |
|
|
model_defaults = init_argspec.defaults or [] |
|
|
defaults_dict = dict( |
|
|
zip(model_argnames[-len(model_defaults) :], model_defaults) |
|
|
) |
|
|
|
|
|
|
|
|
params = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)] |
|
|
params += [ |
|
|
inspect.Parameter( |
|
|
argname, |
|
|
inspect.Parameter.POSITIONAL_OR_KEYWORD, |
|
|
default=defaults_dict.get(argname, inspect.Parameter.empty), |
|
|
) |
|
|
for argname in model_argnames |
|
|
] |
|
|
init_signature = inspect.Signature(params) |
|
|
|
|
|
|
|
|
ModulusModel.__init__.__signature__ = init_signature |
|
|
|
|
|
|
|
|
new_class_name = f"{torch_model_class.__name__}ModulusModel" |
|
|
ModulusModel.__name__ = new_class_name |
|
|
|
|
|
|
|
|
registry = ModelRegistry() |
|
|
registry.register(ModulusModel, new_class_name) |
|
|
|
|
|
return ModulusModel |
|
|
|
|
|
@property |
|
|
def device(self) -> torch.device: |
|
|
|
|
|
return self.device_buffer.device |
|
|
|
|
|
def num_parameters(self) -> int: |
|
|
"""Gets the number of learnable parameters""" |
|
|
count = 0 |
|
|
for name, param in self.named_parameters(): |
|
|
count += param.numel() |
|
|
return count |
|
|
|
|
|
|
|
|
|
|
|
from typing import List, Tuple, Union
|
|
|
|
|
import dgl |
|
|
import numpy as np |
|
|
import torch |
|
|
from dgl import DGLGraph |
|
|
from torch import Tensor, testing |
|
|
|
|
|
|
|
|
def create_graph( |
|
|
src: List, |
|
|
dst: List, |
|
|
to_bidirected: bool = True, |
|
|
add_self_loop: bool = False, |
|
|
dtype: torch.dtype = torch.int32, |
|
|
) -> DGLGraph: |
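    """Creates a DGL graph from src/dst lists, optionally bidirected and with self-loops."""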
|
|
graph = dgl.graph((src, dst), idtype=dtype) |
|
|
if to_bidirected: |
|
|
graph = dgl.to_bidirected(graph) |
|
|
if add_self_loop: |
|
|
graph = dgl.add_self_loop(graph) |
|
|
return graph |
|
|
|
|
|
|
|
|
def create_heterograph( |
|
|
src: List, |
|
|
dst: List, |
|
|
    labels: Tuple[str, str, str],
|
|
dtype: torch.dtype = torch.int32, |
|
|
num_nodes_dict: dict = None, |
|
|
) -> DGLGraph: |
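    """Creates a single-relation DGL heterograph from COO src/dst lists."""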
|
|
|
|
|
graph = dgl.heterograph( |
|
|
{labels: ("coo", (src, dst))}, num_nodes_dict=num_nodes_dict, idtype=dtype |
|
|
) |
|
|
return graph |
|
|
|
|
|
|
|
|
def add_edge_features(
    graph: DGLGraph, pos: Union[Tensor, Tuple[Tensor, Tensor]], normalize: bool = True
) -> DGLGraph:
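    """Adds displacement-based edge features (normalized displacement + its norm) in
    each destination node's local coordinate frame, stored in ``graph.edata["x"]``."""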
|
|
|
|
|
if isinstance(pos, tuple): |
|
|
src_pos, dst_pos = pos |
|
|
else: |
|
|
src_pos = dst_pos = pos |
|
|
src, dst = graph.edges() |
|
|
|
|
|
src_pos, dst_pos = src_pos[src.long()], dst_pos[dst.long()] |
|
|
dst_latlon = xyz2latlon(dst_pos, unit="rad") |
|
|
dst_lat, dst_lon = dst_latlon[:, 0], dst_latlon[:, 1] |
|
|
|
|
|
|
|
|
theta_azimuthal = azimuthal_angle(dst_lon) |
|
|
theta_polar = polar_angle(dst_lat) |
|
|
|
|
|
src_pos = geospatial_rotation(src_pos, theta=theta_azimuthal, axis="z", unit="rad") |
|
|
dst_pos = geospatial_rotation(dst_pos, theta=theta_azimuthal, axis="z", unit="rad") |
|
|
|
|
|
try: |
|
|
testing.assert_close(dst_pos[:, 1], torch.zeros_like(dst_pos[:, 1])) |
|
|
    except AssertionError:
        raise ValueError("Invalid projection of edge nodes to local coordinate system")
|
|
src_pos = geospatial_rotation(src_pos, theta=theta_polar, axis="y", unit="rad") |
|
|
dst_pos = geospatial_rotation(dst_pos, theta=theta_polar, axis="y", unit="rad") |
|
|
|
|
|
try: |
|
|
testing.assert_close(dst_pos[:, 0], torch.ones_like(dst_pos[:, 0])) |
|
|
testing.assert_close(dst_pos[:, 1], torch.zeros_like(dst_pos[:, 1])) |
|
|
testing.assert_close(dst_pos[:, 2], torch.zeros_like(dst_pos[:, 2])) |
|
|
    except AssertionError:
        raise ValueError("Invalid projection of edge nodes to local coordinate system")
|
|
|
|
|
|
|
|
disp = src_pos - dst_pos |
|
|
disp_norm = torch.linalg.norm(disp, dim=-1, keepdim=True) |
|
|
|
|
|
|
|
|
if normalize: |
|
|
max_disp_norm = torch.max(disp_norm) |
|
|
graph.edata["x"] = torch.cat( |
|
|
(disp / max_disp_norm, disp_norm / max_disp_norm), dim=-1 |
|
|
) |
|
|
else: |
|
|
graph.edata["x"] = torch.cat((disp, disp_norm), dim=-1) |
|
|
return graph |
|
|
|
|
|
|
|
|
def add_node_features(graph: DGLGraph, pos: Tensor) -> DGLGraph: |
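    """Adds node features [cos(lat), sin(lon), cos(lon)] to ``graph.ndata["x"]``."""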
|
|
|
|
|
latlon = xyz2latlon(pos) |
|
|
lat, lon = latlon[:, 0], latlon[:, 1] |
|
|
graph.ndata["x"] = torch.stack( |
|
|
(torch.cos(lat), torch.sin(lon), torch.cos(lon)), dim=-1 |
|
|
) |
|
|
return graph |
|
|
|
|
|
|
|
|
def latlon2xyz(latlon: Tensor, radius: float = 1, unit: str = "deg") -> Tensor: |
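    """Converts lat/lon pairs (N x 2, degrees by default) to Cartesian xyz on a sphere of ``radius``."""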
|
|
|
|
|
if unit == "deg": |
|
|
latlon = deg2rad(latlon) |
|
|
elif unit == "rad": |
|
|
pass |
|
|
else: |
|
|
raise ValueError("Not a valid unit") |
|
|
lat, lon = latlon[:, 0], latlon[:, 1] |
|
|
x = radius * torch.cos(lat) * torch.cos(lon) |
|
|
y = radius * torch.cos(lat) * torch.sin(lon) |
|
|
z = radius * torch.sin(lat) |
|
|
return torch.stack((x, y, z), dim=1) |
|
|
|
|
|
|
|
|
def xyz2latlon(xyz: Tensor, radius: float = 1, unit: str = "deg") -> Tensor: |
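    """Converts Cartesian xyz (N x 3) on a sphere of ``radius`` to lat/lon pairs (degrees by default)."""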
|
|
|
|
|
lat = torch.arcsin(xyz[:, 2] / radius) |
|
|
lon = torch.arctan2(xyz[:, 1], xyz[:, 0]) |
|
|
if unit == "deg": |
|
|
return torch.stack((rad2deg(lat), rad2deg(lon)), dim=1) |
|
|
elif unit == "rad": |
|
|
return torch.stack((lat, lon), dim=1) |
|
|
else: |
|
|
raise ValueError("Not a valid unit") |
|
|
|
|
|
|
|
|
def geospatial_rotation( |
|
|
invar: Tensor, theta: Tensor, axis: str, unit: str = "rad" |
|
|
) -> Tensor: |
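    """Rotates ``invar`` (N x 3 coordinates) about ``axis`` by per-row angles ``theta``
    (right-hand rule); ``unit`` selects whether ``theta`` is given in degrees or radians."""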
|
|
|
|
|
|
|
|
if unit == "deg": |
|
|
invar = rad2deg(invar) |
|
|
elif unit == "rad": |
|
|
pass |
|
|
else: |
|
|
raise ValueError("Not a valid unit") |
|
|
|
|
|
invar = torch.unsqueeze(invar, -1) |
|
|
    rotation = torch.zeros((theta.size(0), 3, 3), device=invar.device)
|
|
cos = torch.cos(theta) |
|
|
sin = torch.sin(theta) |
|
|
|
|
|
if axis == "x": |
|
|
rotation[:, 0, 0] += 1.0 |
|
|
rotation[:, 1, 1] += cos |
|
|
rotation[:, 1, 2] -= sin |
|
|
rotation[:, 2, 1] += sin |
|
|
rotation[:, 2, 2] += cos |
|
|
elif axis == "y": |
|
|
rotation[:, 0, 0] += cos |
|
|
rotation[:, 0, 2] += sin |
|
|
rotation[:, 1, 1] += 1.0 |
|
|
rotation[:, 2, 0] -= sin |
|
|
rotation[:, 2, 2] += cos |
|
|
elif axis == "z": |
|
|
rotation[:, 0, 0] += cos |
|
|
rotation[:, 0, 1] -= sin |
|
|
rotation[:, 1, 0] += sin |
|
|
rotation[:, 1, 1] += cos |
|
|
rotation[:, 2, 2] += 1.0 |
|
|
else: |
|
|
raise ValueError("Invalid axis") |
|
|
|
|
|
outvar = torch.matmul(rotation, invar) |
|
|
    outvar = outvar.squeeze(-1)  # drop only the trailing singleton dim added above
|
|
return outvar |
|
|
|
|
|
|
|
|
def azimuthal_angle(lon: Tensor) -> Tensor: |
|
|
|
|
|
angle = torch.where(lon >= 0.0, 2 * np.pi - lon, -lon) |
|
|
return angle |
|
|
|
|
|
|
|
|
def polar_angle(lat: Tensor) -> Tensor: |
|
|
|
|
|
angle = torch.where(lat >= 0.0, lat, 2 * np.pi + lat) |
|
|
return angle |
|
|
|
|
|
|
|
|
def deg2rad(deg: Tensor) -> Tensor: |
|
|
|
|
|
return deg * np.pi / 180 |
|
|
|
|
|
|
|
|
def rad2deg(rad): |
|
|
|
|
|
return rad * 180 / np.pi |
|
|
|
|
|
|
|
|
def cell_to_adj(cells: List[List[int]]): |
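    """Converts triangle cells (vertex-index triplets) to src/dst adjacency lists by
    walking each triangle's directed edges 0->1, 1->2, 2->0."""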
|
|
|
|
|
num_cells = np.shape(cells)[0] |
|
|
src = [cells[i][indx] for i in range(num_cells) for indx in [0, 1, 2]] |
|
|
dst = [cells[i][indx] for i in range(num_cells) for indx in [1, 2, 0]] |
|
|
return src, dst |
|
|
|
|
|
|
|
|
def max_edge_length( |
|
|
vertices: List[List[float]], source_nodes: List[int], destination_nodes: List[int] |
|
|
) -> float: |
|
|
|
|
|
vertices_np = np.array(vertices) |
|
|
source_coords = vertices_np[source_nodes] |
|
|
dest_coords = vertices_np[destination_nodes] |
|
|
|
|
|
|
|
|
squared_differences = np.sum((source_coords - dest_coords) ** 2, axis=1) |
|
|
|
|
|
|
|
|
max_length = np.sqrt(np.max(squared_differences)) |
|
|
|
|
|
return max_length |
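# Sketch: with two axis-aligned unit edges the maximum edge length is 1.0:
#
#   verts = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
#   max_edge_length(verts, [0, 0], [1, 2])   # -> 1.0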
|
|
|
|
|
|
|
|
def get_face_centroids( |
|
|
vertices: List[Tuple[float, float, float]], faces: List[List[int]] |
|
|
) -> List[Tuple[float, float, float]]: |
|
|
|
|
|
centroids = [] |
|
|
|
|
|
for face in faces: |
|
|
|
|
|
v0 = vertices[face[0]] |
|
|
v1 = vertices[face[1]] |
|
|
v2 = vertices[face[2]] |
|
|
|
|
|
|
|
|
centroid = ( |
|
|
(v0[0] + v1[0] + v2[0]) / 3, |
|
|
(v0[1] + v1[1] + v2[1]) / 3, |
|
|
(v0[2] + v1[2] + v2[2]) / 3, |
|
|
) |
|
|
|
|
|
centroids.append(centroid) |
|
|
|
|
|
return centroids |
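# Sketch: each centroid is the component-wise mean of a face's three vertices:
#
#   verts = [(0.0, 0.0, 0.0), (3.0, 0.0, 0.0), (0.0, 3.0, 0.0)]
#   get_face_centroids(verts, [[0, 1, 2]])   # -> [(1.0, 1.0, 0.0)]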
|
|
|
|
|
|
|
|
import itertools |
|
|
from typing import List, NamedTuple, Sequence, Tuple |
|
|
|
|
|
import numpy as np |
|
|
from scipy.spatial import transform |
|
|
|
|
|
|
|
|
class TriangularMesh(NamedTuple): |
|
|
|
|
|
|
|
|
vertices: np.ndarray |
|
|
faces: np.ndarray |
|
|
|
|
|
|
|
|
def merge_meshes(mesh_list: Sequence[TriangularMesh]) -> TriangularMesh: |
|
|
|
|
|
for mesh_i, mesh_ip1 in itertools.pairwise(mesh_list): |
|
|
num_nodes_mesh_i = mesh_i.vertices.shape[0] |
|
|
assert np.allclose(mesh_i.vertices, mesh_ip1.vertices[:num_nodes_mesh_i]) |
|
|
|
|
|
return TriangularMesh( |
|
|
vertices=mesh_list[-1].vertices, |
|
|
faces=np.concatenate([mesh.faces for mesh in mesh_list], axis=0), |
|
|
) |
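# The prefix check above encodes the multimesh invariant: every refined mesh
# reuses the coarser mesh's vertices as its first rows, so the merged mesh
# can keep the finest vertex array and simply concatenate all face lists.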
|
|
|
|
|
|
|
|
def get_hierarchy_of_triangular_meshes_for_sphere(splits: int) -> List[TriangularMesh]: |
|
|
|
|
|
current_mesh = get_icosahedron() |
|
|
output_meshes = [current_mesh] |
|
|
for _ in range(splits): |
|
|
current_mesh = _two_split_unit_sphere_triangle_faces(current_mesh) |
|
|
output_meshes.append(current_mesh) |
|
|
return output_meshes |
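# Each split replaces every face with four children, so level s has
# 20 * 4**s faces and 10 * 4**s + 2 vertices (level 0 is the icosahedron
# with 20 faces and 12 vertices). For example:
#
#   meshes = get_hierarchy_of_triangular_meshes_for_sphere(splits=2)
#   # len(meshes) == 3 and meshes[-1].faces.shape == (320, 3)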
|
|
|
|
|
|
|
|
def get_icosahedron() -> TriangularMesh: |
|
|
|
|
|
phi = (1 + np.sqrt(5)) / 2 |
|
|
vertices = [] |
|
|
for c1 in [1.0, -1.0]: |
|
|
for c2 in [phi, -phi]: |
|
|
vertices.append((c1, c2, 0.0)) |
|
|
vertices.append((0.0, c1, c2)) |
|
|
vertices.append((c2, 0.0, c1)) |
|
|
|
|
|
vertices = np.array(vertices, dtype=np.float32) |
|
|
vertices /= np.linalg.norm([1.0, phi]) |
|
|
|
|
|
|
|
|
faces = [ |
|
|
(0, 1, 2), |
|
|
(0, 6, 1), |
|
|
(8, 0, 2), |
|
|
(8, 4, 0), |
|
|
(3, 8, 2), |
|
|
(3, 2, 7), |
|
|
(7, 2, 1), |
|
|
(0, 4, 6), |
|
|
(4, 11, 6), |
|
|
(6, 11, 5), |
|
|
(1, 5, 7), |
|
|
(4, 10, 11), |
|
|
(4, 8, 10), |
|
|
(10, 8, 3), |
|
|
(10, 3, 9), |
|
|
(11, 10, 9), |
|
|
(11, 9, 5), |
|
|
(5, 9, 7), |
|
|
(9, 3, 7), |
|
|
(1, 6, 5), |
|
|
] |
|
|
|
|
|
|
|
|
angle_between_faces = 2 * np.arcsin(phi / np.sqrt(3)) |
|
|
rotation_angle = (np.pi - angle_between_faces) / 2 |
|
|
rotation = transform.Rotation.from_euler(seq="y", angles=rotation_angle) |
|
|
rotation_matrix = rotation.as_matrix() |
|
|
vertices = np.dot(vertices, rotation_matrix) |
|
|
|
|
|
return TriangularMesh( |
|
|
vertices=vertices.astype(np.float32), faces=np.array(faces, dtype=np.int32) |
|
|
) |
|
|
|
|
|
|
|
|
def _two_split_unit_sphere_triangle_faces( |
|
|
triangular_mesh: TriangularMesh, |
|
|
) -> TriangularMesh: |
|
|
"""Splits each triangular face into 4 triangles keeping the orientation.""" |
|
|
|
|
|
|
|
|
new_vertices_builder = _ChildVerticesBuilder(triangular_mesh.vertices) |
|
|
|
|
|
new_faces = [] |
|
|
for ind1, ind2, ind3 in triangular_mesh.faces: |
|
|
|
|
|
ind12 = new_vertices_builder.get_new_child_vertex_index((ind1, ind2)) |
|
|
ind23 = new_vertices_builder.get_new_child_vertex_index((ind2, ind3)) |
|
|
ind31 = new_vertices_builder.get_new_child_vertex_index((ind3, ind1)) |
|
|
|
|
|
new_faces.extend( |
|
|
[ |
|
|
[ind1, ind12, ind31], |
|
|
[ind12, ind2, ind23], |
|
|
[ind31, ind23, ind3], |
|
|
[ind12, ind23, ind31], |
|
|
] |
|
|
) |
|
|
return TriangularMesh( |
|
|
vertices=new_vertices_builder.get_all_vertices(), |
|
|
faces=np.array(new_faces, dtype=np.int32), |
|
|
) |
|
|
|
|
|
|
|
|
class _ChildVerticesBuilder(object): |
|
|
"""Bookkeeping of new child vertices added to an existing set of vertices.""" |
|
|
|
|
|
def __init__(self, parent_vertices): |
|
|
|
|
|
|
|
|
self._child_vertices_index_mapping = {} |
|
|
self._parent_vertices = parent_vertices |
|
|
|
|
|
self._all_vertices_list = list(parent_vertices) |
|
|
|
|
|
def _get_child_vertex_key(self, parent_vertex_indices): |
|
|
return tuple(sorted(parent_vertex_indices)) |
|
|
|
|
|
def _create_child_vertex(self, parent_vertex_indices): |
|
|
"""Creates a new vertex.""" |
|
|
|
|
|
child_vertex_position = self._parent_vertices[list(parent_vertex_indices)].mean( |
|
|
0 |
|
|
) |
|
|
child_vertex_position /= np.linalg.norm(child_vertex_position) |
|
|
|
|
|
|
|
|
child_vertex_key = self._get_child_vertex_key(parent_vertex_indices) |
|
|
self._child_vertices_index_mapping[child_vertex_key] = len( |
|
|
self._all_vertices_list |
|
|
) |
|
|
self._all_vertices_list.append(child_vertex_position) |
|
|
|
|
|
def get_new_child_vertex_index(self, parent_vertex_indices): |
|
|
"""Returns index for a child vertex, creating it if necessary.""" |
|
|
|
|
|
child_vertex_key = self._get_child_vertex_key(parent_vertex_indices) |
|
|
if child_vertex_key not in self._child_vertices_index_mapping: |
|
|
self._create_child_vertex(parent_vertex_indices) |
|
|
return self._child_vertices_index_mapping[child_vertex_key] |
|
|
|
|
|
def get_all_vertices(self): |
|
|
"""Returns an array with old vertices.""" |
|
|
return np.array(self._all_vertices_list) |
|
|
|
|
|
|
|
|
def faces_to_edges(faces: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: |
|
|
|
|
|
assert faces.ndim == 2 |
|
|
assert faces.shape[-1] == 3 |
|
|
senders = np.concatenate([faces[:, 0], faces[:, 1], faces[:, 2]]) |
|
|
receivers = np.concatenate([faces[:, 1], faces[:, 2], faces[:, 0]]) |
|
|
return senders, receivers |
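# Sketch: the NumPy analogue of cell_to_adj above, one directed edge per
# triangle side:
#
#   faces_to_edges(np.array([[0, 1, 2]]))
#   # -> (array([0, 1, 2]), array([1, 2, 0]))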
|
|
|
|
|
|
|
|
import logging |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from sklearn.neighbors import NearestNeighbors |
|
|
from torch import Tensor |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class Graph: |
|
|
|
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
lat_lon_grid: Tensor, |
|
|
mesh_level: int = 6, |
|
|
multimesh: bool = True, |
|
|
khop_neighbors: int = 0, |
|
|
dtype=torch.float, |
|
|
) -> None: |
|
|
self.khop_neighbors = khop_neighbors |
|
|
self.dtype = dtype |
|
|
|
|
|
|
|
|
self.lat_lon_grid_flat = lat_lon_grid.permute(2, 0, 1).view(2, -1).permute(1, 0) |
|
|
|
|
|
|
|
|
_meshes = get_hierarchy_of_triangular_meshes_for_sphere(splits=mesh_level) |
|
|
finest_mesh = _meshes[-1] |
|
|
self.finest_mesh_src, self.finest_mesh_dst = faces_to_edges(finest_mesh.faces) |
|
|
self.finest_mesh_vertices = np.array(finest_mesh.vertices) |
|
|
if multimesh: |
|
|
mesh = merge_meshes(_meshes) |
|
|
self.mesh_src, self.mesh_dst = faces_to_edges(mesh.faces) |
|
|
self.mesh_vertices = np.array(mesh.vertices) |
|
|
else: |
|
|
mesh = finest_mesh |
|
|
self.mesh_src, self.mesh_dst = self.finest_mesh_src, self.finest_mesh_dst |
|
|
self.mesh_vertices = self.finest_mesh_vertices |
|
|
self.mesh_faces = mesh.faces |
|
|
|
|
|
@staticmethod |
|
|
def khop_adj_all_k(g, kmax): |
|
|
if not g.is_homogeneous: |
|
|
            raise NotImplementedError("Only homogeneous graphs are supported.")
|
|
min_degree = g.in_degrees().min() |
|
|
with torch.no_grad(): |
|
|
adj = g.adj_external(transpose=True, scipy_fmt=None) |
|
|
adj_k = adj |
|
|
adj_all = adj.clone() |
|
|
for _ in range(2, kmax + 1): |
|
|
|
|
|
|
|
|
adj_k = (adj @ adj_k) / min_degree |
|
|
adj_all += adj_k |
|
|
return adj_all.to_dense().bool() |
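    # The loop above accumulates sparse adjacency powers, so adj_all marks all
    # node pairs reachable within kmax hops; dividing by the minimum in-degree
    # merely keeps the accumulated values small before the cast to bool. The
    # dense boolean result is inverted below (mask = ~khop_adj_bool) to serve
    # as an attention mask in which True means "masked out".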
|
|
|
|
|
    def create_mesh_graph(self, verbose: bool = True) -> Tuple[DGLGraph, Optional[Tensor]]:
|
|
|
|
|
mesh_graph = create_graph( |
|
|
self.mesh_src, |
|
|
self.mesh_dst, |
|
|
to_bidirected=True, |
|
|
add_self_loop=False, |
|
|
dtype=torch.int32, |
|
|
) |
|
|
mesh_pos = torch.tensor( |
|
|
self.mesh_vertices, |
|
|
dtype=torch.float32, |
|
|
) |
|
|
mesh_graph = add_edge_features(mesh_graph, mesh_pos) |
|
|
mesh_graph = add_node_features(mesh_graph, mesh_pos) |
|
|
mesh_graph.ndata["lat_lon"] = xyz2latlon(mesh_pos) |
|
|
|
|
|
mesh_graph.ndata["x"] = mesh_graph.ndata["x"].to(dtype=self.dtype) |
|
|
mesh_graph.edata["x"] = mesh_graph.edata["x"].to(dtype=self.dtype) |
|
|
if self.khop_neighbors > 0: |
|
|
|
|
|
khop_adj_bool = self.khop_adj_all_k(g=mesh_graph, kmax=self.khop_neighbors) |
|
|
mask = ~khop_adj_bool |
|
|
else: |
|
|
mask = None |
|
|
if verbose: |
|
|
print("mesh graph:", mesh_graph) |
|
|
return mesh_graph, mask |
|
|
|
|
|
    def create_g2m_graph(self, verbose: bool = True) -> DGLGraph:
|
|
|
|
|
|
|
|
|
|
|
max_edge_len = max_edge_length( |
|
|
self.finest_mesh_vertices, self.finest_mesh_src, self.finest_mesh_dst |
|
|
) |
|
|
|
|
|
|
|
|
cartesian_grid = latlon2xyz(self.lat_lon_grid_flat) |
|
|
n_nbrs = 4 |
|
|
neighbors = NearestNeighbors(n_neighbors=n_nbrs).fit(self.mesh_vertices) |
|
|
distances, indices = neighbors.kneighbors(cartesian_grid) |
|
|
|
|
|
src, dst = [], [] |
|
|
for i in range(len(cartesian_grid)): |
|
|
for j in range(n_nbrs): |
|
|
if distances[i][j] <= 0.6 * max_edge_len: |
|
|
src.append(i) |
|
|
dst.append(indices[i][j]) |
|
|
|
|
|
|
|
|
g2m_graph = create_heterograph( |
|
|
src, dst, ("grid", "g2m", "mesh"), dtype=torch.int32 |
|
|
) |
|
|
g2m_graph.srcdata["pos"] = cartesian_grid.to(torch.float32) |
|
|
g2m_graph.dstdata["pos"] = torch.tensor( |
|
|
self.mesh_vertices, |
|
|
dtype=torch.float32, |
|
|
) |
|
|
g2m_graph.srcdata["lat_lon"] = self.lat_lon_grid_flat |
|
|
g2m_graph.dstdata["lat_lon"] = xyz2latlon(g2m_graph.dstdata["pos"]) |
|
|
|
|
|
g2m_graph = add_edge_features( |
|
|
g2m_graph, (g2m_graph.srcdata["pos"], g2m_graph.dstdata["pos"]) |
|
|
) |
|
|
|
|
|
|
|
|
g2m_graph.srcdata["pos"] = g2m_graph.srcdata["pos"].to(dtype=self.dtype) |
|
|
g2m_graph.dstdata["pos"] = g2m_graph.dstdata["pos"].to(dtype=self.dtype) |
|
|
g2m_graph.ndata["pos"]["grid"] = g2m_graph.ndata["pos"]["grid"].to( |
|
|
dtype=self.dtype |
|
|
) |
|
|
g2m_graph.ndata["pos"]["mesh"] = g2m_graph.ndata["pos"]["mesh"].to( |
|
|
dtype=self.dtype |
|
|
) |
|
|
g2m_graph.edata["x"] = g2m_graph.edata["x"].to(dtype=self.dtype) |
|
|
if verbose: |
|
|
print("g2m graph:", g2m_graph) |
|
|
return g2m_graph |
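    # Grid-to-mesh edges connect each grid point to up to 4 nearest mesh
    # nodes, kept only if they lie within 0.6 x the longest edge of the
    # finest mesh - the locality radius used by the original GraphCast
    # graph construction.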
|
|
|
|
|
    def create_m2g_graph(self, verbose: bool = True) -> DGLGraph:
|
|
|
|
|
|
|
|
cartesian_grid = latlon2xyz(self.lat_lon_grid_flat) |
|
|
face_centroids = get_face_centroids(self.mesh_vertices, self.mesh_faces) |
|
|
n_nbrs = 1 |
|
|
neighbors = NearestNeighbors(n_neighbors=n_nbrs).fit(face_centroids) |
|
|
_, indices = neighbors.kneighbors(cartesian_grid) |
|
|
indices = indices.flatten() |
|
|
|
|
|
src = [p for i in indices for p in self.mesh_faces[i]] |
|
|
dst = [i for i in range(len(cartesian_grid)) for _ in range(3)] |
|
|
m2g_graph = create_heterograph( |
|
|
src, dst, ("mesh", "m2g", "grid"), dtype=torch.int32 |
|
|
) |
|
|
|
|
|
m2g_graph.srcdata["pos"] = torch.tensor( |
|
|
self.mesh_vertices, |
|
|
dtype=torch.float32, |
|
|
) |
|
|
m2g_graph.dstdata["pos"] = cartesian_grid.to(dtype=torch.float32) |
|
|
|
|
|
m2g_graph.srcdata["lat_lon"] = xyz2latlon(m2g_graph.srcdata["pos"]) |
|
|
m2g_graph.dstdata["lat_lon"] = self.lat_lon_grid_flat |
|
|
|
|
|
m2g_graph = add_edge_features( |
|
|
m2g_graph, (m2g_graph.srcdata["pos"], m2g_graph.dstdata["pos"]) |
|
|
) |
|
|
|
|
|
m2g_graph.srcdata["pos"] = m2g_graph.srcdata["pos"].to(dtype=self.dtype) |
|
|
m2g_graph.dstdata["pos"] = m2g_graph.dstdata["pos"].to(dtype=self.dtype) |
|
|
m2g_graph.ndata["pos"]["grid"] = m2g_graph.ndata["pos"]["grid"].to( |
|
|
dtype=self.dtype |
|
|
) |
|
|
m2g_graph.ndata["pos"]["mesh"] = m2g_graph.ndata["pos"]["mesh"].to( |
|
|
dtype=self.dtype |
|
|
) |
|
|
m2g_graph.edata["x"] = m2g_graph.edata["x"].to(dtype=self.dtype) |
|
|
|
|
|
if verbose: |
|
|
print("m2g graph:", m2g_graph) |
|
|
return m2g_graph |
|
|
|
|
|
|
|
|
|
|
|
from typing import Union |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from dgl import DGLGraph |
|
|
from torch import Tensor |
|
|
|
|
|
|
|
|
|
|
|
class MeshEdgeBlock(nn.Module): |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
input_dim_nodes: int = 512, |
|
|
input_dim_edges: int = 512, |
|
|
output_dim: int = 512, |
|
|
hidden_dim: int = 512, |
|
|
hidden_layers: int = 1, |
|
|
activation_fn: nn.Module = nn.SiLU(), |
|
|
norm_type: str = "LayerNorm", |
|
|
do_concat_trick: bool = False, |
|
|
recompute_activation: bool = False, |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
MLP = MeshGraphEdgeMLPSum if do_concat_trick else MeshGraphEdgeMLPConcat |
|
|
|
|
|
self.edge_mlp = MLP( |
|
|
efeat_dim=input_dim_edges, |
|
|
src_dim=input_dim_nodes, |
|
|
dst_dim=input_dim_nodes, |
|
|
output_dim=output_dim, |
|
|
hidden_dim=hidden_dim, |
|
|
hidden_layers=hidden_layers, |
|
|
activation_fn=activation_fn, |
|
|
norm_type=norm_type, |
|
|
recompute_activation=recompute_activation, |
|
|
) |
|
|
|
|
|
@torch.jit.ignore() |
|
|
def forward( |
|
|
self, |
|
|
efeat: Tensor, |
|
|
nfeat: Tensor, |
|
|
graph: Union[DGLGraph, CuGraphCSC], |
|
|
    ) -> Tuple[Tensor, Tensor]:
|
|
efeat_new = self.edge_mlp(efeat, nfeat, graph) |
|
|
efeat_new = efeat_new + efeat |
|
|
return efeat_new, nfeat |
|
|
|
|
|
|
|
|
|
|
|
from typing import Tuple, Union |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from dgl import DGLGraph |
|
|
from torch import Tensor |
|
|
|
|
|
|
|
|
|
|
|
class MeshNodeBlock(nn.Module): |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
aggregation: str = "sum", |
|
|
input_dim_nodes: int = 512, |
|
|
input_dim_edges: int = 512, |
|
|
output_dim: int = 512, |
|
|
hidden_dim: int = 512, |
|
|
hidden_layers: int = 1, |
|
|
activation_fn: nn.Module = nn.SiLU(), |
|
|
norm_type: str = "LayerNorm", |
|
|
recompute_activation: bool = False, |
|
|
): |
|
|
super().__init__() |
|
|
self.aggregation = aggregation |
|
|
|
|
|
self.node_mlp = MeshGraphMLP( |
|
|
input_dim=input_dim_nodes + input_dim_edges, |
|
|
output_dim=output_dim, |
|
|
hidden_dim=hidden_dim, |
|
|
hidden_layers=hidden_layers, |
|
|
activation_fn=activation_fn, |
|
|
norm_type=norm_type, |
|
|
recompute_activation=recompute_activation, |
|
|
) |
|
|
|
|
|
@torch.jit.ignore() |
|
|
def forward( |
|
|
self, |
|
|
efeat: Tensor, |
|
|
nfeat: Tensor, |
|
|
graph: Union[DGLGraph, CuGraphCSC], |
|
|
) -> Tuple[Tensor, Tensor]: |
|
|
|
|
|
cat_feat = aggregate_and_concat(efeat, nfeat, graph, self.aggregation) |
|
|
|
|
|
nfeat_new = self.node_mlp(cat_feat) + nfeat |
|
|
return efeat, nfeat_new |
|
|
|
|
|
|
|
|
|
|
|
from typing import Union |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from dgl import DGLGraph |
|
|
from torch import Tensor

# Transformer Engine provides the fused TransformerLayer consumed by
# GraphCastProcessorGraphTransformer below; importing the submodule makes
# te.pytorch resolvable (assumes transformer_engine is installed).
import transformer_engine as te
import transformer_engine.pytorch
|
|
|
|
|
|
|
|
class GraphCastProcessor(nn.Module): |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
aggregation: str = "sum", |
|
|
processor_layers: int = 16, |
|
|
input_dim_nodes: int = 512, |
|
|
input_dim_edges: int = 512, |
|
|
hidden_dim: int = 512, |
|
|
hidden_layers: int = 1, |
|
|
activation_fn: nn.Module = nn.SiLU(), |
|
|
norm_type: str = "LayerNorm", |
|
|
do_concat_trick: bool = False, |
|
|
recompute_activation: bool = False, |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
edge_block_invars = ( |
|
|
input_dim_nodes, |
|
|
input_dim_edges, |
|
|
input_dim_edges, |
|
|
hidden_dim, |
|
|
hidden_layers, |
|
|
activation_fn, |
|
|
norm_type, |
|
|
do_concat_trick, |
|
|
recompute_activation, |
|
|
) |
|
|
node_block_invars = ( |
|
|
aggregation, |
|
|
input_dim_nodes, |
|
|
input_dim_edges, |
|
|
input_dim_nodes, |
|
|
hidden_dim, |
|
|
hidden_layers, |
|
|
activation_fn, |
|
|
norm_type, |
|
|
recompute_activation, |
|
|
) |
|
|
|
|
|
layers = [] |
|
|
for _ in range(processor_layers): |
|
|
layers.append(MeshEdgeBlock(*edge_block_invars)) |
|
|
layers.append(MeshNodeBlock(*node_block_invars)) |
|
|
|
|
|
self.processor_layers = nn.ModuleList(layers) |
|
|
self.num_processor_layers = len(self.processor_layers) |
|
|
|
|
|
|
|
|
self.checkpoint_segments = [(0, self.num_processor_layers)] |
|
|
self.checkpoint_fn = set_checkpoint_fn(False) |
|
|
|
|
|
def set_checkpoint_segments(self, checkpoint_segments: int): |
|
|
|
|
|
if checkpoint_segments > 0: |
|
|
if self.num_processor_layers % checkpoint_segments != 0: |
|
|
raise ValueError( |
|
|
"Processor layers must be a multiple of checkpoint_segments" |
|
|
) |
|
|
segment_size = self.num_processor_layers // checkpoint_segments |
|
|
self.checkpoint_segments = [] |
|
|
for i in range(0, self.num_processor_layers, segment_size): |
|
|
self.checkpoint_segments.append((i, i + segment_size)) |
|
|
|
|
|
self.checkpoint_fn = set_checkpoint_fn(True) |
|
|
else: |
|
|
self.checkpoint_fn = set_checkpoint_fn(False) |
|
|
self.checkpoint_segments = [(0, self.num_processor_layers)] |
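    # Sketch: with the default processor_layers=16 the module list holds 32
    # blocks (one edge block plus one node block per layer), so valid segment
    # counts must divide 32:
    #
    #   proc = GraphCastProcessor(processor_layers=16)
    #   proc.set_checkpoint_segments(4)   # -> 4 segments of 8 blocks each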
|
|
|
|
|
def run_function(self, segment_start: int, segment_end: int): |
|
|
|
|
|
segment = self.processor_layers[segment_start:segment_end] |
|
|
|
|
|
def custom_forward(efeat, nfeat, graph): |
|
|
"""Custom forward function""" |
|
|
for module in segment: |
|
|
efeat, nfeat = module(efeat, nfeat, graph) |
|
|
return efeat, nfeat |
|
|
|
|
|
return custom_forward |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
efeat: Tensor, |
|
|
nfeat: Tensor, |
|
|
graph: Union[DGLGraph, CuGraphCSC], |
|
|
    ) -> Tuple[Tensor, Tensor]:
|
|
for segment_start, segment_end in self.checkpoint_segments: |
|
|
efeat, nfeat = self.checkpoint_fn( |
|
|
self.run_function(segment_start, segment_end), |
|
|
efeat, |
|
|
nfeat, |
|
|
graph, |
|
|
use_reentrant=False, |
|
|
preserve_rng_state=False, |
|
|
) |
|
|
|
|
|
return efeat, nfeat |
|
|
|
|
|
|
|
|
class GraphCastProcessorGraphTransformer(nn.Module): |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
attention_mask: torch.Tensor, |
|
|
num_attention_heads: int = 4, |
|
|
processor_layers: int = 16, |
|
|
input_dim_nodes: int = 512, |
|
|
hidden_dim: int = 512, |
|
|
): |
|
|
super().__init__() |
|
|
self.num_attention_heads = num_attention_heads |
|
|
self.hidden_dim = hidden_dim |
|
|
        # .to() avoids the copy-construct warning torch.tensor() emits on tensors
        self.attention_mask = attention_mask.to(dtype=torch.bool)
|
|
self.register_buffer("mask", self.attention_mask, persistent=False) |
|
|
|
|
|
layers = [ |
|
|
te.pytorch.TransformerLayer( |
|
|
hidden_size=input_dim_nodes, |
|
|
ffn_hidden_size=hidden_dim, |
|
|
num_attention_heads=num_attention_heads, |
|
|
layer_number=i + 1, |
|
|
fuse_qkv_params=False, |
|
|
) |
|
|
for i in range(processor_layers) |
|
|
] |
|
|
self.processor_layers = nn.ModuleList(layers) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
nfeat: Tensor, |
|
|
) -> Tensor: |
|
|
nfeat = nfeat.unsqueeze(1) |
|
|
|
|
|
for module in self.processor_layers: |
|
|
nfeat = module( |
|
|
nfeat, |
|
|
attention_mask=self.mask, |
|
|
self_attn_mask_type="arbitrary", |
|
|
) |
|
|
|
|
|
return torch.squeeze(nfeat, 1) |
|
|
|
|
|
|
|
|
import logging |
|
|
import warnings |
|
|
from dataclasses import dataclass |
|
|
from typing import Any, Optional |
|
|
|
|
|
import torch |
|
|
from torch import Tensor |
|
|
|
|
|
try: |
|
|
from typing import Self |
|
|
except ImportError: |
|
|
|
|
|
from typing_extensions import Self |
|
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
def get_lat_lon_partition_separators(partition_size: int): |
|
|
|
|
|
|
|
|
def _divide(num_lat_chunks: int, num_lon_chunks: int): |
|
|
|
|
|
if (num_lon_chunks * num_lat_chunks) != partition_size: |
|
|
raise ValueError( |
|
|
"Can't divide lat-lon grid into grid {num_lat_chunks} x {num_lon_chunks} chunks for partition_size={partition_size}." |
|
|
) |
|
|
|
|
|
lat_bin_width = 180.0 / num_lat_chunks |
|
|
lon_bin_width = 360.0 / num_lon_chunks |
|
|
|
|
|
lat_ranges = [] |
|
|
lon_ranges = [] |
|
|
|
|
|
for p_lat in range(num_lat_chunks): |
|
|
for p_lon in range(num_lon_chunks): |
|
|
lat_ranges += [ |
|
|
(lat_bin_width * p_lat - 90.0, lat_bin_width * (p_lat + 1) - 90.0) |
|
|
] |
|
|
lon_ranges += [ |
|
|
(lon_bin_width * p_lon - 180.0, lon_bin_width * (p_lon + 1) - 180.0) |
|
|
] |
|
|
|
|
|
lat_ranges[-1] = (lat_ranges[-1][0], None) |
|
|
lon_ranges[-1] = (lon_ranges[-1][0], None) |
|
|
|
|
|
return lat_ranges, lon_ranges |
|
|
|
|
|
|
|
|
lat_chunks, lon_chunks, i = 1, partition_size, 0 |
|
|
while lat_chunks < lon_chunks: |
|
|
i += 1 |
|
|
if partition_size % i == 0: |
|
|
lat_chunks = i |
|
|
lon_chunks = partition_size // lat_chunks |
|
|
|
|
|
lat_ranges, lon_ranges = _divide(lat_chunks, lon_chunks) |
|
|
|
|
|
|
|
|
if (lat_ranges is None) or (lon_ranges is None): |
|
|
raise ValueError("unexpected error, abort") |
|
|
|
|
|
min_seps = [] |
|
|
max_seps = [] |
|
|
|
|
|
for i in range(partition_size): |
|
|
lat = lat_ranges[i] |
|
|
lon = lon_ranges[i] |
|
|
min_seps.append([lat[0], lon[0]]) |
|
|
max_seps.append([lat[1], lon[1]]) |
|
|
|
|
|
return min_seps, max_seps |
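# Sketch: partition_size=4 resolves to a 2 x 2 chunking of the sphere:
#
#   min_seps, max_seps = get_lat_lon_partition_separators(4)
#   # min_seps == [[-90.0, -180.0], [-90.0, 0.0], [0.0, -180.0], [0.0, 0.0]]
#   # max_seps == [[0.0, 0.0], [0.0, 180.0], [90.0, 0.0], [None, None]]
#
# The trailing None bounds mark the last chunk as open-ended so that points
# sitting exactly on the final separators are not dropped.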
|
|
|
|
|
|
|
|
@dataclass |
|
|
class MetaData(ModelMetaData): |
|
|
name: str = "GraphCast" |
|
|
|
|
|
jit: bool = False |
|
|
cuda_graphs: bool = False |
|
|
amp_cpu: bool = False |
|
|
amp_gpu: bool = True |
|
|
torch_fx: bool = False |
|
|
|
|
|
bf16: bool = True |
|
|
|
|
|
onnx: bool = False |
|
|
|
|
|
func_torch: bool = False |
|
|
auto_grad: bool = False |
|
|
|
|
|
|
|
|
class GraphCast(Module): |
|
|
def __init__( |
|
|
self, |
|
|
|
|
|
mesh_level: Optional[int] = 5, |
|
|
multimesh_level: Optional[int] = None, |
|
|
multimesh: bool = True, |
|
|
input_res: tuple = (120, 240), |
|
|
input_dim_grid_nodes: int = 69, |
|
|
input_dim_mesh_nodes: int = 3, |
|
|
input_dim_edges: int = 4, |
|
|
output_dim_grid_nodes: int = 69, |
|
|
processor_type: str = "MessagePassing", |
|
|
khop_neighbors: int = 32, |
|
|
num_attention_heads: int = 4, |
|
|
processor_layers: int = 16, |
|
|
hidden_layers: int = 1, |
|
|
hidden_dim: int = 512, |
|
|
aggregation: str = "sum", |
|
|
activation_fn: str = "silu", |
|
|
norm_type: str = "LayerNorm", |
|
|
use_cugraphops_encoder: bool = False, |
|
|
use_cugraphops_processor: bool = False, |
|
|
use_cugraphops_decoder: bool = False, |
|
|
do_concat_trick: bool = False, |
|
|
recompute_activation: bool = True, |
|
|
partition_size: int = 1, |
|
|
partition_group_name: Optional[str] = None, |
|
|
use_lat_lon_partitioning: bool = False, |
|
|
expect_partitioned_input: bool = False, |
|
|
global_features_on_rank_0: bool = False, |
|
|
produce_aggregated_output: bool = True, |
|
|
produce_aggregated_output_on_all_ranks: bool = True, |
|
|
): |
|
|
super().__init__(meta=MetaData()) |
|
|
|
|
|
|
|
|
if multimesh_level is not None: |
|
|
warnings.warn( |
|
|
"'multimesh_level' is deprecated and will be removed in a future version. Use 'mesh_level' instead.", |
|
|
DeprecationWarning, |
|
|
stacklevel=2, |
|
|
) |
|
|
mesh_level = multimesh_level |
|
|
|
|
|
self.processor_type = processor_type |
|
|
if self.processor_type == "MessagePassing": |
|
|
khop_neighbors = 0 |
|
|
self.is_distributed = False |
|
|
if partition_size > 1: |
|
|
self.is_distributed = True |
|
|
self.expect_partitioned_input = expect_partitioned_input |
|
|
self.global_features_on_rank_0 = global_features_on_rank_0 |
|
|
self.produce_aggregated_output = produce_aggregated_output |
|
|
self.produce_aggregated_output_on_all_ranks = ( |
|
|
produce_aggregated_output_on_all_ranks |
|
|
) |
|
|
self.partition_group_name = partition_group_name |
|
|
|
|
|
|
|
|
self.latitudes = torch.linspace(-90, 90, steps=input_res[0]) |
|
|
self.longitudes = torch.linspace(-180, 180, steps=input_res[1] + 1)[1:] |
|
|
|
|
|
|
|
|
self.lat_lon_grid = torch.stack( |
|
|
torch.meshgrid(self.latitudes, self.longitudes, indexing="ij"), dim=-1 |
|
|
) |
|
|
|
|
|
|
|
|
activation_fn = get_activation(activation_fn) |
|
|
|
|
|
|
|
|
self.graph = Graph(self.lat_lon_grid, mesh_level, multimesh, khop_neighbors) |
|
|
|
|
|
self.mesh_graph, self.attn_mask = self.graph.create_mesh_graph(verbose=False) |
|
|
self.g2m_graph = self.graph.create_g2m_graph(verbose=False) |
|
|
self.m2g_graph = self.graph.create_m2g_graph(verbose=False) |
|
|
|
|
|
self.g2m_edata = self.g2m_graph.edata["x"] |
|
|
self.m2g_edata = self.m2g_graph.edata["x"] |
|
|
self.mesh_ndata = self.mesh_graph.ndata["x"] |
|
|
if self.processor_type == "MessagePassing": |
|
|
self.mesh_edata = self.mesh_graph.edata["x"] |
|
|
elif self.processor_type == "GraphTransformer": |
|
|
|
|
|
self.mesh_edata = torch.zeros((1, input_dim_edges)) |
|
|
else: |
|
|
raise ValueError(f"Invalid processor type {processor_type}") |
|
|
|
|
|
if use_cugraphops_encoder or self.is_distributed: |
|
|
kwargs = {} |
|
|
if use_lat_lon_partitioning: |
|
|
min_seps, max_seps = get_lat_lon_partition_separators(partition_size) |
|
|
kwargs = { |
|
|
"src_coordinates": self.g2m_graph.srcdata["lat_lon"], |
|
|
"dst_coordinates": self.g2m_graph.dstdata["lat_lon"], |
|
|
"coordinate_separators_min": min_seps, |
|
|
"coordinate_separators_max": max_seps, |
|
|
} |
|
|
self.g2m_graph, edge_perm = CuGraphCSC.from_dgl( |
|
|
graph=self.g2m_graph, |
|
|
partition_size=partition_size, |
|
|
partition_group_name=partition_group_name, |
|
|
partition_by_bbox=use_lat_lon_partitioning, |
|
|
**kwargs, |
|
|
) |
|
|
self.g2m_edata = self.g2m_edata[edge_perm] |
|
|
|
|
|
if self.is_distributed: |
|
|
self.g2m_edata = self.g2m_graph.get_edge_features_in_partition( |
|
|
self.g2m_edata |
|
|
) |
|
|
|
|
|
if use_cugraphops_decoder or self.is_distributed: |
|
|
kwargs = {} |
|
|
if use_lat_lon_partitioning: |
|
|
min_seps, max_seps = get_lat_lon_partition_separators(partition_size) |
|
|
kwargs = { |
|
|
"src_coordinates": self.m2g_graph.srcdata["lat_lon"], |
|
|
"dst_coordinates": self.m2g_graph.dstdata["lat_lon"], |
|
|
"coordinate_separators_min": min_seps, |
|
|
"coordinate_separators_max": max_seps, |
|
|
} |
|
|
|
|
|
self.m2g_graph, edge_perm = CuGraphCSC.from_dgl( |
|
|
graph=self.m2g_graph, |
|
|
partition_size=partition_size, |
|
|
partition_group_name=partition_group_name, |
|
|
partition_by_bbox=use_lat_lon_partitioning, |
|
|
**kwargs, |
|
|
) |
|
|
self.m2g_edata = self.m2g_edata[edge_perm] |
|
|
|
|
|
if self.is_distributed: |
|
|
self.m2g_edata = self.m2g_graph.get_edge_features_in_partition( |
|
|
self.m2g_edata |
|
|
) |
|
|
|
|
|
if use_cugraphops_processor or self.is_distributed: |
|
|
kwargs = {} |
|
|
if use_lat_lon_partitioning: |
|
|
min_seps, max_seps = get_lat_lon_partition_separators(partition_size) |
|
|
kwargs = { |
|
|
"src_coordinates": self.mesh_graph.ndata["lat_lon"], |
|
|
"dst_coordinates": self.mesh_graph.ndata["lat_lon"], |
|
|
"coordinate_separators_min": min_seps, |
|
|
"coordinate_separators_max": max_seps, |
|
|
} |
|
|
|
|
|
self.mesh_graph, edge_perm = CuGraphCSC.from_dgl( |
|
|
graph=self.mesh_graph, |
|
|
partition_size=partition_size, |
|
|
partition_group_name=partition_group_name, |
|
|
partition_by_bbox=use_lat_lon_partitioning, |
|
|
**kwargs, |
|
|
) |
|
|
self.mesh_edata = self.mesh_edata[edge_perm] |
|
|
if self.is_distributed: |
|
|
self.mesh_edata = self.mesh_graph.get_edge_features_in_partition( |
|
|
self.mesh_edata |
|
|
) |
|
|
self.mesh_ndata = self.mesh_graph.get_dst_node_features_in_partition( |
|
|
self.mesh_ndata |
|
|
) |
|
|
|
|
|
self.input_dim_grid_nodes = input_dim_grid_nodes |
|
|
self.output_dim_grid_nodes = output_dim_grid_nodes |
|
|
self.input_res = input_res |
|
|
|
|
|
|
|
|
self.model_checkpoint_fn = set_checkpoint_fn(False) |
|
|
self.encoder_checkpoint_fn = set_checkpoint_fn(False) |
|
|
self.decoder_checkpoint_fn = set_checkpoint_fn(False) |
|
|
|
|
|
|
|
|
self.encoder_embedder = GraphCastEncoderEmbedder( |
|
|
input_dim_grid_nodes=input_dim_grid_nodes, |
|
|
input_dim_mesh_nodes=input_dim_mesh_nodes, |
|
|
input_dim_edges=input_dim_edges, |
|
|
output_dim=hidden_dim, |
|
|
hidden_dim=hidden_dim, |
|
|
hidden_layers=hidden_layers, |
|
|
activation_fn=activation_fn, |
|
|
norm_type=norm_type, |
|
|
recompute_activation=recompute_activation, |
|
|
) |
|
|
self.decoder_embedder = GraphCastDecoderEmbedder( |
|
|
input_dim_edges=input_dim_edges, |
|
|
output_dim=hidden_dim, |
|
|
hidden_dim=hidden_dim, |
|
|
hidden_layers=hidden_layers, |
|
|
activation_fn=activation_fn, |
|
|
norm_type=norm_type, |
|
|
recompute_activation=recompute_activation, |
|
|
) |
|
|
|
|
|
|
|
|
self.encoder = MeshGraphEncoder( |
|
|
aggregation=aggregation, |
|
|
input_dim_src_nodes=hidden_dim, |
|
|
input_dim_dst_nodes=hidden_dim, |
|
|
input_dim_edges=hidden_dim, |
|
|
output_dim_src_nodes=hidden_dim, |
|
|
output_dim_dst_nodes=hidden_dim, |
|
|
output_dim_edges=hidden_dim, |
|
|
hidden_dim=hidden_dim, |
|
|
hidden_layers=hidden_layers, |
|
|
activation_fn=activation_fn, |
|
|
norm_type=norm_type, |
|
|
do_concat_trick=do_concat_trick, |
|
|
recompute_activation=recompute_activation, |
|
|
) |
|
|
|
|
|
|
|
|
if processor_layers <= 2: |
|
|
raise ValueError("Expected at least 3 processor layers") |
|
|
if processor_type == "MessagePassing": |
|
|
self.processor_encoder = GraphCastProcessor( |
|
|
aggregation=aggregation, |
|
|
processor_layers=1, |
|
|
input_dim_nodes=hidden_dim, |
|
|
input_dim_edges=hidden_dim, |
|
|
hidden_dim=hidden_dim, |
|
|
hidden_layers=hidden_layers, |
|
|
activation_fn=activation_fn, |
|
|
norm_type=norm_type, |
|
|
do_concat_trick=do_concat_trick, |
|
|
recompute_activation=recompute_activation, |
|
|
) |
|
|
self.processor = GraphCastProcessor( |
|
|
aggregation=aggregation, |
|
|
processor_layers=processor_layers - 2, |
|
|
input_dim_nodes=hidden_dim, |
|
|
input_dim_edges=hidden_dim, |
|
|
hidden_dim=hidden_dim, |
|
|
hidden_layers=hidden_layers, |
|
|
activation_fn=activation_fn, |
|
|
norm_type=norm_type, |
|
|
do_concat_trick=do_concat_trick, |
|
|
recompute_activation=recompute_activation, |
|
|
) |
|
|
self.processor_decoder = GraphCastProcessor( |
|
|
aggregation=aggregation, |
|
|
processor_layers=1, |
|
|
input_dim_nodes=hidden_dim, |
|
|
input_dim_edges=hidden_dim, |
|
|
hidden_dim=hidden_dim, |
|
|
hidden_layers=hidden_layers, |
|
|
activation_fn=activation_fn, |
|
|
norm_type=norm_type, |
|
|
do_concat_trick=do_concat_trick, |
|
|
recompute_activation=recompute_activation, |
|
|
) |
|
|
else: |
|
|
self.processor_encoder = torch.nn.Identity() |
|
|
self.processor = GraphCastProcessorGraphTransformer( |
|
|
attention_mask=self.attn_mask, |
|
|
num_attention_heads=num_attention_heads, |
|
|
processor_layers=processor_layers, |
|
|
input_dim_nodes=hidden_dim, |
|
|
hidden_dim=hidden_dim, |
|
|
) |
|
|
self.processor_decoder = torch.nn.Identity() |
|
|
|
|
|
|
|
|
self.decoder = MeshGraphDecoder( |
|
|
aggregation=aggregation, |
|
|
input_dim_src_nodes=hidden_dim, |
|
|
input_dim_dst_nodes=hidden_dim, |
|
|
input_dim_edges=hidden_dim, |
|
|
output_dim_dst_nodes=hidden_dim, |
|
|
output_dim_edges=hidden_dim, |
|
|
hidden_dim=hidden_dim, |
|
|
hidden_layers=hidden_layers, |
|
|
activation_fn=activation_fn, |
|
|
norm_type=norm_type, |
|
|
do_concat_trick=do_concat_trick, |
|
|
recompute_activation=recompute_activation, |
|
|
) |
|
|
|
|
|
|
|
|
self.finale = MeshGraphMLP( |
|
|
input_dim=hidden_dim, |
|
|
output_dim=output_dim_grid_nodes, |
|
|
hidden_dim=hidden_dim, |
|
|
hidden_layers=hidden_layers, |
|
|
activation_fn=activation_fn, |
|
|
norm_type=None, |
|
|
recompute_activation=recompute_activation, |
|
|
) |
|
|
|
|
|
def set_checkpoint_model(self, checkpoint_flag: bool): |
|
|
|
|
|
|
|
|
self.model_checkpoint_fn = set_checkpoint_fn(checkpoint_flag) |
|
|
if checkpoint_flag: |
|
|
self.processor.set_checkpoint_segments(-1) |
|
|
self.encoder_checkpoint_fn = set_checkpoint_fn(False) |
|
|
self.decoder_checkpoint_fn = set_checkpoint_fn(False) |
|
|
|
|
|
def set_checkpoint_processor(self, checkpoint_segments: int): |
|
|
|
|
|
self.processor.set_checkpoint_segments(checkpoint_segments) |
|
|
|
|
|
def set_checkpoint_encoder(self, checkpoint_flag: bool): |
|
|
|
|
|
self.encoder_checkpoint_fn = set_checkpoint_fn(checkpoint_flag) |
|
|
|
|
|
def set_checkpoint_decoder(self, checkpoint_flag: bool): |
|
|
|
|
|
self.decoder_checkpoint_fn = set_checkpoint_fn(checkpoint_flag) |
|
|
|
|
|
def encoder_forward( |
|
|
self, |
|
|
grid_nfeat: Tensor, |
|
|
    ) -> Tuple[Optional[Tensor], Tensor, Tensor]:
|
|
|
|
|
|
|
|
|
|
|
( |
|
|
grid_nfeat_embedded, |
|
|
mesh_nfeat_embedded, |
|
|
g2m_efeat_embedded, |
|
|
mesh_efeat_embedded, |
|
|
) = self.encoder_embedder( |
|
|
grid_nfeat, |
|
|
self.mesh_ndata, |
|
|
self.g2m_edata, |
|
|
self.mesh_edata, |
|
|
) |
|
|
|
|
|
|
|
|
grid_nfeat_encoded, mesh_nfeat_encoded = self.encoder( |
|
|
g2m_efeat_embedded, |
|
|
grid_nfeat_embedded, |
|
|
mesh_nfeat_embedded, |
|
|
self.g2m_graph, |
|
|
) |
|
|
|
|
|
|
|
|
if self.processor_type == "MessagePassing": |
|
|
mesh_efeat_processed, mesh_nfeat_processed = self.processor_encoder( |
|
|
mesh_efeat_embedded, |
|
|
mesh_nfeat_encoded, |
|
|
self.mesh_graph, |
|
|
) |
|
|
else: |
|
|
mesh_nfeat_processed = self.processor_encoder( |
|
|
mesh_nfeat_encoded, |
|
|
) |
|
|
mesh_efeat_processed = None |
|
|
return mesh_efeat_processed, mesh_nfeat_processed, grid_nfeat_encoded |
|
|
|
|
|
def decoder_forward( |
|
|
self, |
|
|
mesh_efeat_processed: Tensor, |
|
|
mesh_nfeat_processed: Tensor, |
|
|
grid_nfeat_encoded: Tensor, |
|
|
) -> Tensor: |
|
|
|
|
|
|
|
|
if self.processor_type == "MessagePassing": |
|
|
_, mesh_nfeat_processed = self.processor_decoder( |
|
|
mesh_efeat_processed, |
|
|
mesh_nfeat_processed, |
|
|
self.mesh_graph, |
|
|
) |
|
|
else: |
|
|
mesh_nfeat_processed = self.processor_decoder( |
|
|
mesh_nfeat_processed, |
|
|
) |
|
|
|
|
|
m2g_efeat_embedded = self.decoder_embedder(self.m2g_edata) |
|
|
|
|
|
|
|
|
grid_nfeat_decoded = self.decoder( |
|
|
m2g_efeat_embedded, grid_nfeat_encoded, mesh_nfeat_processed, self.m2g_graph |
|
|
) |
|
|
|
|
|
|
|
|
grid_nfeat_finale = self.finale( |
|
|
grid_nfeat_decoded, |
|
|
) |
|
|
|
|
|
return grid_nfeat_finale |
|
|
|
|
|
def custom_forward(self, grid_nfeat: Tensor) -> Tensor: |
|
|
|
|
|
( |
|
|
mesh_efeat_processed, |
|
|
mesh_nfeat_processed, |
|
|
grid_nfeat_encoded, |
|
|
) = self.encoder_checkpoint_fn( |
|
|
self.encoder_forward, |
|
|
grid_nfeat, |
|
|
use_reentrant=False, |
|
|
preserve_rng_state=False, |
|
|
) |
|
|
|
|
|
|
|
|
if self.processor_type == "MessagePassing": |
|
|
mesh_efeat_processed, mesh_nfeat_processed = self.processor( |
|
|
mesh_efeat_processed, |
|
|
mesh_nfeat_processed, |
|
|
self.mesh_graph, |
|
|
) |
|
|
else: |
|
|
mesh_nfeat_processed = self.processor( |
|
|
mesh_nfeat_processed, |
|
|
) |
|
|
mesh_efeat_processed = None |
|
|
|
|
|
grid_nfeat_finale = self.decoder_checkpoint_fn( |
|
|
self.decoder_forward, |
|
|
mesh_efeat_processed, |
|
|
mesh_nfeat_processed, |
|
|
grid_nfeat_encoded, |
|
|
use_reentrant=False, |
|
|
preserve_rng_state=False, |
|
|
) |
|
|
|
|
|
return grid_nfeat_finale |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
grid_nfeat: Tensor, |
|
|
) -> Tensor: |
|
|
invar = self.prepare_input( |
|
|
grid_nfeat, self.expect_partitioned_input, self.global_features_on_rank_0 |
|
|
) |
|
|
outvar = self.model_checkpoint_fn( |
|
|
self.custom_forward, |
|
|
invar, |
|
|
use_reentrant=False, |
|
|
preserve_rng_state=False, |
|
|
) |
|
|
outvar = self.prepare_output( |
|
|
outvar, |
|
|
self.produce_aggregated_output, |
|
|
self.produce_aggregated_output_on_all_ranks, |
|
|
) |
|
|
return outvar |
|
|
|
|
|
def prepare_input( |
|
|
self, |
|
|
invar: Tensor, |
|
|
expect_partitioned_input: bool, |
|
|
global_features_on_rank_0: bool, |
|
|
) -> Tensor: |
|
|
|
|
|
if global_features_on_rank_0 and expect_partitioned_input: |
|
|
raise ValueError( |
|
|
"global_features_on_rank_0 and expect_partitioned_input cannot be set at the same time." |
|
|
) |
|
|
|
|
|
if not self.is_distributed: |
|
|
if invar.size(0) != 1: |
|
|
raise ValueError("GraphCast does not support batch size > 1") |
|
|
invar = invar[0].view(self.input_dim_grid_nodes, -1).permute(1, 0) |
|
|
|
|
|
else: |
|
|
|
|
|
if not expect_partitioned_input: |
|
|
|
|
|
if invar.size(0) != 1: |
|
|
raise ValueError("GraphCast does not support batch size > 1") |
|
|
|
|
|
invar = invar[0].view(self.input_dim_grid_nodes, -1).permute(1, 0) |
|
|
|
|
|
|
|
|
invar = self.g2m_graph.get_src_node_features_in_partition( |
|
|
invar, |
|
|
scatter_features=global_features_on_rank_0, |
|
|
) |
|
|
|
|
|
return invar |
|
|
|
|
|
def prepare_output( |
|
|
self, |
|
|
outvar: Tensor, |
|
|
produce_aggregated_output: bool, |
|
|
produce_aggregated_output_on_all_ranks: bool = True, |
|
|
) -> Tensor: |
|
|
|
|
|
if produce_aggregated_output or not self.is_distributed: |
|
|
|
|
|
if self.is_distributed: |
|
|
outvar = self.m2g_graph.get_global_dst_node_features( |
|
|
outvar, |
|
|
get_on_all_ranks=produce_aggregated_output_on_all_ranks, |
|
|
) |
|
|
|
|
|
outvar = outvar.permute(1, 0) |
|
|
outvar = outvar.view(self.output_dim_grid_nodes, *self.input_res) |
|
|
outvar = torch.unsqueeze(outvar, dim=0) |
|
|
|
|
|
return outvar |
|
|
|
|
|
def to(self, *args: Any, **kwargs: Any) -> Self: |
|
|
|
|
|
        self = super().to(*args, **kwargs)
|
|
|
|
|
self.g2m_edata = self.g2m_edata.to(*args, **kwargs) |
|
|
self.m2g_edata = self.m2g_edata.to(*args, **kwargs) |
|
|
self.mesh_ndata = self.mesh_ndata.to(*args, **kwargs) |
|
|
self.mesh_edata = self.mesh_edata.to(*args, **kwargs) |
|
|
|
|
|
device, _, _, _ = torch._C._nn._parse_to(*args, **kwargs) |
|
|
self.g2m_graph = self.g2m_graph.to(device) |
|
|
self.mesh_graph = self.mesh_graph.to(device) |
|
|
self.m2g_graph = self.m2g_graph.to(device) |
|
|
|
|
|
return self |
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
|
net = GraphCast().to(device) |
|
|
|
|
|
    invar = torch.randn(1, 69, 120, 240).to(device)


    output = net(invar)
|
|
|
|
|
print(output.shape) |
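    # Expected: torch.Size([1, 69, 120, 240]) - output_dim_grid_nodes
    # channels on the default (120, 240) lat-lon grid.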
|
|
|