# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from physicsnemo.utils.profiling import profile
from physicsnemo.utils.version_check import check_module_requirements

check_module_requirements("physicsnemo.distributed.shard_tensor")

from torch.distributed.tensor.placement_types import (  # noqa: E402
Shard,
)
from physicsnemo.distributed import ShardTensor # noqa: E402
from physicsnemo.distributed.shard_utils.patch_core import ( # noqa: E402
MissingShardPatch,
)


def compute_local_padding_and_output_shape(
input_tensor_shape: tuple[int, ...],
pad: tuple[int, ...],
mesh_coords: tuple[int, ...],
mesh_sizes: tuple[int, ...],
tensor_sharding_map: dict[int, int],
) -> tuple[tuple[int, ...], tuple[int, ...]]:
"""
Compute the local padding and output shape for a given input tensor shape, pad, mode, and value.
Args:
input_tensor_shape: The shape of the input tensor
pad: The padding size(s)
mesh_coords: The coordinates of the tensor in the mesh
mesh_sizes: The sizes of the mesh
tensor_sharding_map: A map from tensor dimension to mesh dimension
Returns:
tuple of the local padding and output shape
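
    Example:
        Illustrative values; one padded dim sharded over a 1-D mesh of size 2,
        where only the edge rank keeps its side of the padding:

        >>> compute_local_padding_and_output_shape(
        ...     input_tensor_shape=(4, 8),
        ...     pad=(1, 2),
        ...     mesh_coords=(0,),
        ...     mesh_sizes=(2,),
        ...     tensor_sharding_map={1: 0},
        ... )
        ((4, 9), (1, 0))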
"""
tensor_rank = len(input_tensor_shape)
pad_dims = len(pad) // 2
output_padding = []
local_output_shape = list(input_tensor_shape)
    # Padding pairs in `pad` apply to tensor dimensions from the last dim backwards:
    for dim_from_last in range(pad_dims):
        tensor_dim = tensor_rank - 1 - dim_from_last
        left = pad[2 * dim_from_last]
        right = pad[2 * dim_from_last + 1]
        # If this axis of the tensor is not sharded, keep the padding as-is:
        if tensor_dim not in tensor_sharding_map:
output_padding.append(left)
output_padding.append(right)
local_output_shape[tensor_dim] += left + right
        else:
            # The tensor is sharded on this dim, so determine whether this shard
            # sits on an edge of the mesh. Four cases:
            # - left edge only
            # - right edge only
            # - both left and right edges (mesh size 1 on this dim)
            # - interior (neither edge)
mesh_dim = tensor_sharding_map[tensor_dim]
is_left = mesh_coords[mesh_dim] == 0
is_right = mesh_coords[mesh_dim] == mesh_sizes[mesh_dim] - 1
if is_left and not is_right:
output_padding.append(left)
output_padding.append(0)
local_output_shape[tensor_dim] += left
elif not is_left and is_right:
output_padding.append(0)
output_padding.append(right)
local_output_shape[tensor_dim] += right
elif is_left and is_right:
output_padding.append(left)
output_padding.append(right)
local_output_shape[tensor_dim] += left + right
else:
output_padding.append(0)
output_padding.append(0)
return tuple(local_output_shape), tuple(output_padding)


def generic_pad_nd_wrapper(func: callable, types: tuple, args: tuple, kwargs: dict):
    """Wrapper function for N-dimensional padding operations on ShardTensor inputs.

    Args:
        func: The padding function to be wrapped
        types: tuple of input types (unused)
        args: Positional arguments to the padding function
        kwargs: Keyword arguments to the padding function

    Returns:
        The result of the padding operation as a ShardTensor
    """
# Padding is a no-communication operation unless it's circular padding
# Circular padding is not implemented yet, and probably won't get implemented
# until it's requested.
inputs, pad, mode, value = repackage_pad_args(*args, **kwargs)
if mode == "circular":
raise MissingShardPatch(
"Circular padding is not implemented yet. Please open an issue at https://github.com/NVIDIA/PhysicsNemo/issues if you need this functionality."
)
# Now, get the local tensor:
local_input = inputs.to_local()
    # We have to update the padding values based on where this shard sits in the
    # mesh, i.e. whether it is on an edge or not.
    #
    # The only way to do that is to loop over the paddings to determine
    # the tensor axes, and then loop over the mesh / spec to see if each axis is
    # sharded.
    #
    # Further, because we don't want to communicate across GPUs unless it's needed,
    # we compute the padded output shape for every shard in the spec.
# Sanity checks
if len(pad) % 2 != 0:
raise ValueError("Sharded Padding requires len(pad) to be divisible by 2.")
pad_dims = len(pad) // 2
    if pad_dims > len(inputs.shape):
        raise ValueError(
            f"Sharded padding specified for {pad_dims} dimensions, but the tensor "
            f"has only {len(inputs.shape)} dimensions."
        )
    # Map each sharded tensor dim to the mesh dim it is sharded over;
    # dims that do not appear in this map are unsharded.
    tensor_sharding_map = {}
mesh_sizes = []
spec = inputs._spec
# Loop over the mesh spec and extract sharding vs tensor dim:
for mesh_dim, placement in enumerate(spec.placements):
if isinstance(placement, Shard):
tensor_sharding_map[placement.dim] = mesh_dim
mesh_sizes.append(spec.mesh.size(mesh_dim))
    # If tensor_sharding_map is empty (so no dimension is sharded),
    # we can just run the padding locally and be done.
if len(tensor_sharding_map) == 0:
local_output = func(local_input, pad, mode, value)
return ShardTensor.from_local(local_output, spec.mesh, spec.placements)
    # At this point, at least one dimension is sharded (possibly more).
    # So, loop over the sharding shapes along each mesh dimension and compute the
    # local output shape and padding for each chunk:
output_shapes = {}
self_mesh_coords = [spec.mesh.get_local_rank(m) for m in range(spec.mesh.ndim)]
self_padding = None
for mesh_dim, sharding_shapes in spec.sharding_shapes().items():
output_shapes[mesh_dim] = []
        for i, local_shape in enumerate(sharding_shapes):
            # Update the mesh coordinates to point at chunk i along this mesh dim:
            mesh_coords = list(self_mesh_coords)
            mesh_coords[mesh_dim] = i
            # Use the shape of chunk i here (not this rank's local shape) so that
            # uneven shards get the correct padded output shape:
            output_shape, local_padding = compute_local_padding_and_output_shape(
                local_shape, pad, mesh_coords, mesh_sizes, tensor_sharding_map
            )
# Catch and cache the one that applies to this rank:
if mesh_coords == self_mesh_coords:
self_padding = local_padding
output_shapes[mesh_dim].append(output_shape)
# From here, apply the local padding to this tensor:
local_output = func(local_input, self_padding, mode, value)
# Now, convert back to shard tensor.
# We already have all the output shapes
return ShardTensor.from_local(
local_output, spec.mesh, spec.placements, sharding_shapes=output_shapes
)
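

# Worked example for generic_pad_nd_wrapper (illustrative numbers): a global
# tensor of shape (2, 3, 10), sharded on dim 2 across a 2-rank mesh as two
# (2, 3, 5) shards and padded with pad=(1, 1) in "constant" mode, gets a local
# pad of (1, 0) on rank 0 and (0, 1) on rank 1. Each local output is (2, 3, 6)
# and the global result is (2, 3, 12), with no inter-rank communication.
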
@profile
def repackage_pad_args(
inputs: ShardTensor,
pad: int | tuple[int, ...] = 0,
mode: str = "constant",
value: float | None = None,
*args,
**kwargs,
) -> tuple[
    ShardTensor,
    int | tuple[int, ...],
    str,
    float | None,
]:
    """Repackages pad arguments into a standard format.

    Takes the full set of arguments that could be passed to a pad operation and
    normalizes them into the core inputs (inputs, pad, mode, value), discarding
    any extra positional or keyword arguments.

    Args:
        inputs: Input tensor to pad
        pad: Padding size(s)
        mode: Padding mode
        value: Padding value
        *args: Additional positional args (unused)
        **kwargs: Additional keyword args (unused)

    Returns:
        tuple containing:
            - Input tensor
            - Padding size(s)
            - Padding mode
            - Padding value
    """
return inputs, pad, mode, value


ShardTensor.register_function_handler(torch.nn.functional.pad, generic_pad_nd_wrapper)
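

# Usage sketch (illustrative; assumes an initialized distributed environment, a
# DeviceMesh called `mesh`, and a local tensor `local_chunk` on each rank). With
# the handler registered above, torch.nn.functional.pad called on a ShardTensor
# pads each local shard independently and returns a ShardTensor with updated
# sharding shapes:
#
#   sharded = ShardTensor.from_local(local_chunk, mesh, [Shard(2)])
#   padded = torch.nn.functional.pad(sharded, (1, 1), mode="constant", value=0.0)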